import argparse
import json
import subprocess
import sys
import tempfile
from dataclasses import (
dataclass,
)
from difflib import unified_diff
from pathlib import Path
from shutil import copy2
from typing import Any, Literal
# MCP schema revision targeted by the generator; main() also uses it to locate
# the default schema/<version>/schema.json.
SCHEMA_VERSION = "2025-06-18"
JSONRPC_VERSION = "2.0"

# `#[derive(...)]` lines placed before generated Rust items. The hashable
# flavour additionally derives Hash and Eq (inserted before JsonSchema/TS).
_BASE_DERIVES = ("Debug", "Clone", "PartialEq", "Deserialize", "Serialize")
_TOOLING_DERIVES = ("JsonSchema", "TS")
STANDARD_DERIVE = f"#[derive({', '.join(_BASE_DERIVES + _TOOLING_DERIVES)})]\n"
STANDARD_HASHABLE_DERIVE = (
    f"#[derive({', '.join(_BASE_DERIVES + ('Hash', 'Eq') + _TOOLING_DERIVES)})]\n"
)

# The schema's "definitions" mapping; replaced by generate_lib_rs().
DEFINITIONS: dict[str, Any] = {}
# Referenced type names of the ClientRequest / ServerNotification unions,
# recorded by define_any_of() and consumed when emitting the TryFrom impls.
CLIENT_REQUEST_TYPE_NAMES: list[str] = []
SERVER_NOTIFICATION_TYPE_NAMES: list[str] = []
# anyOf enums that get `#[allow(clippy::large_enum_variant)]`.
LARGE_ENUMS = {"ServerResult"}
def main() -> int:
    """Command-line entry point.

    Without ``--check``, regenerates ``src/lib.rs`` (next to this script) from
    the given or default ``schema.json`` and formats it with ``cargo fmt``.
    With ``--check``, regenerates into a temporary crate and compares the
    result with the checked-in ``src/lib.rs`` instead.

    Returns:
        The process exit status: 0 after regeneration, otherwise whatever
        ``run_check`` returns.
    """
    crate_dir = Path(__file__).resolve().parent
    default_schema_file = crate_dir / "schema" / SCHEMA_VERSION / "schema.json"
    default_lib_rs = crate_dir / "src" / "lib.rs"

    parser = argparse.ArgumentParser(
        description=(
            "Generate Rust types (src/lib.rs) for the Model Context Protocol "
            "from its JSON schema."
        ),
    )
    parser.add_argument(
        "schema_file",
        nargs="?",
        default=default_schema_file,
        help="schema.json file to process",
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Regenerate lib.rs in a sandbox and ensure the checked-in file matches",
    )
    args = parser.parse_args()
    # argparse leaves a non-string default untouched and yields str for a
    # user-supplied value, so normalise to Path either way.
    schema_file = Path(args.schema_file)

    if args.check:
        return run_check(schema_file, crate_dir, default_lib_rs)
    generate_lib_rs(schema_file, default_lib_rs, fmt=True)
    return 0
def generate_lib_rs(schema_file: Path, lib_rs: Path, fmt: bool) -> None:
    """Generate the Rust MCP type definitions from ``schema_file`` into ``lib_rs``.

    Args:
        schema_file: Path to the MCP ``schema.json`` to read (UTF-8 JSON with a
            top-level ``"definitions"`` object).
        lib_rs: Destination file; its parent directory is created if needed.
        fmt: When true, run ``cargo fmt`` (with ``imports_granularity=Item``)
            in ``lib_rs.parent.parent`` after writing the file.

    Side effects:
        Replaces the module-level ``DEFINITIONS``. ``CLIENT_REQUEST_TYPE_NAMES``
        and ``SERVER_NOTIFICATION_TYPE_NAMES`` are filled in while the
        definitions are processed and are read afterwards to emit the
        ``TryFrom`` impls.

    Raises:
        subprocess.CalledProcessError: If ``cargo fmt`` fails.
    """
    global DEFINITIONS
    lib_rs.parent.mkdir(parents=True, exist_ok=True)
    with schema_file.open(encoding="utf-8") as f:
        schema_json = json.load(f)
    definitions = schema_json["definitions"]
    DEFINITIONS = definitions

    out = [
        f"""
// @generated
// DO NOT EDIT THIS FILE DIRECTLY.
// Run the following in the crate root to regenerate this file:
//
// ```shell
// ./generate_mcp_types.py
// ```
use serde::Deserialize;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::convert::TryFrom;
use schemars::JsonSchema;
use ts_rs::TS;
pub const MCP_SCHEMA_VERSION: &str = "{SCHEMA_VERSION}";
pub const JSONRPC_VERSION: &str = "{JSONRPC_VERSION}";
/// Paired request/response types for the Model Context Protocol (MCP).
pub trait ModelContextProtocolRequest {{
const METHOD: &'static str;
type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
type Result: DeserializeOwned + Serialize + Send + Sync + 'static;
}}
/// One-way message in the Model Context Protocol (MCP).
pub trait ModelContextProtocolNotification {{
const METHOD: &'static str;
type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
}}
fn default_jsonrpc() -> String {{ JSONRPC_VERSION.to_owned() }}
"""
    ]
    for name, definition in definitions.items():
        add_definition(name, definition, out)

    # These must come after the loop above: the variant name lists are only
    # populated while the ClientRequest / ServerNotification unions are emitted.
    out.extend(
        _try_from_impl_lines(
            source_type="JSONRPCRequest",
            target_enum="ClientRequest",
            trait_name="ModelContextProtocolRequest",
            var="req",
            type_names=CLIENT_REQUEST_TYPE_NAMES,
            definitions=definitions,
            trailer="}\n\n",
        )
    )
    out.extend(
        _try_from_impl_lines(
            source_type="JSONRPCNotification",
            target_enum="ServerNotification",
            trait_name="ModelContextProtocolNotification",
            var="n",
            type_names=SERVER_NOTIFICATION_TYPE_NAMES,
            definitions=definitions,
            trailer="}\n",
        )
    )

    lib_rs.write_text("".join(out), encoding="utf-8")

    if fmt:
        subprocess.check_call(
            ["cargo", "fmt", "--", "--config", "imports_granularity=Item"],
            cwd=lib_rs.parent.parent,
            stderr=subprocess.DEVNULL,
        )


def _try_from_impl_lines(
    source_type: str,
    target_enum: str,
    trait_name: str,
    var: str,
    type_names: list[str],
    definitions: dict[str, Any],
    trailer: str,
) -> list[str]:
    """Render ``impl TryFrom<source_type> for target_enum`` dispatching on ``method``.

    Each name in ``type_names`` gets a match arm keyed by its schema
    ``method`` const (falling back to the type name). The arm deserializes
    ``<var>.params`` (``Null`` when absent) into ``<Type as trait_name>::Params``.
    Unknown methods map to an ``InvalidData`` I/O error. ``trailer`` is the
    text that closes the impl.
    """
    lines = [
        f"impl TryFrom<{source_type}> for {target_enum} {{\n",
        " type Error = serde_json::Error;\n",
        f" fn try_from({var}: {source_type}) -> std::result::Result<Self, Self::Error> {{\n",
        f" match {var}.method.as_str() {{\n",
    ]
    for type_name in type_names:
        method_const = (
            definitions[type_name].get("properties", {}).get("method", {}).get("const", type_name)
        )
        payload_type = f"<{type_name} as {trait_name}>::Params"
        lines.append(f' "{method_const}" => {{\n')
        lines.append(f" let params_json = {var}.params.unwrap_or(serde_json::Value::Null);\n")
        lines.append(f" let params: {payload_type} = serde_json::from_value(params_json)?;\n")
        lines.append(f" Ok({target_enum}::{type_name}(params))\n")
        lines.append(" },\n")
    lines.append(
        f' _ => Err(serde_json::Error::io(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Unknown method: {{}}", {var}.method)))),\n'
    )
    lines.append(" }\n")
    lines.append(" }\n")
    lines.append(trailer)
    return lines
def run_check(schema_file: Path, crate_dir: Path, checked_in_lib: Path) -> int:
    """Regenerate lib.rs in a scratch crate and compare it with the checked-in copy.

    A throwaway crate is assembled in a temporary directory. The crate's
    Cargo.toml is copied in and made standalone: ``version = { workspace = true }``
    becomes ``version = "0.0.0"`` and the ``[lints] workspace = true`` table is
    dropped. lib.rs is then generated there unformatted and run through rustfmt
    using ``<crate_dir>/../rustfmt.toml``.

    Returns:
        0 when the formatted output equals ``checked_in_lib``. Otherwise 1,
        after printing a unified diff to stderr.
    """
    rustfmt_config = crate_dir.parent / "rustfmt.toml"
    eprint(f"Running --check with schema {schema_file}")
    with tempfile.TemporaryDirectory() as scratch:
        workspace = Path(scratch)
        eprint(f"Created temporary workspace at {workspace}")

        manifest = workspace / "Cargo.toml"
        eprint(f"Copying Cargo.toml into {manifest}")
        copy2(crate_dir / "Cargo.toml", manifest)
        # The copy lives outside the workspace, so it cannot inherit from it.
        standalone = manifest.read_text(encoding="utf-8")
        standalone = standalone.replace("version = { workspace = true }", 'version = "0.0.0"')
        standalone = standalone.replace("\n[lints]\nworkspace = true\n", "\n")
        manifest.write_text(standalone, encoding="utf-8")

        src = workspace / "src"
        src.mkdir(parents=True, exist_ok=True)
        eprint(f"Generating lib.rs into {src}")
        candidate = src / "lib.rs"
        generate_lib_rs(schema_file, candidate, fmt=False)

        eprint("Formatting generated lib.rs with rustfmt")
        rustfmt_cmd = ["rustfmt", "--config-path", str(rustfmt_config), str(candidate)]
        subprocess.check_call(rustfmt_cmd, cwd=workspace, stderr=subprocess.DEVNULL)

        eprint("Comparing generated lib.rs with checked-in version")
        expected = checked_in_lib.read_text(encoding="utf-8")
        actual = candidate.read_text(encoding="utf-8")
        if expected == actual:
            eprint("lib.rs matches checked-in version")
            return 0

        diff_text = "".join(
            unified_diff(
                expected.splitlines(keepends=True),
                actual.splitlines(keepends=True),
                fromfile=str(checked_in_lib),
                tofile=str(candidate),
            )
        )
        eprint("Generated lib.rs does not match the checked-in version. Diff:")
        if diff_text:
            eprint(diff_text, end="")
        eprint("Re-run generate_mcp_types.py without --check to update src/lib.rs.")
        return 1
def add_definition(name: str, definition: dict[str, Any], out: list[str]) -> None:
    """Append the Rust code for one top-level schema definition to ``out``.

    The definition is classified in this order:

    1. The special name ``Result``, aliased to ``serde_json::Value``.
    2. Definitions with ``properties``, handled by ``define_struct``. ``*Result``
       structs also get ``impl From<...> for serde_json::Value``.
    3. String ``enum`` definitions.
    4. ``anyOf`` unions.
    5. ``type``-based definitions: a string newtype, an untagged enum over a
       list of types, or an array alias.
    6. ``$ref`` aliases.

    Args:
        name: Definition name, used as the Rust type name.
        definition: The JSON-schema object for that definition.
        out: Accumulator of generated Rust source chunks, appended to in place.

    Raises:
        ValueError: If the definition has an unsupported or unrecognized shape.
    """
    if name == "Result":
        out.append("pub type Result = serde_json::Value;\n\n")
        return

    description = definition.get("description")

    properties = definition.get("properties", {})
    if properties:
        required_props = set(definition.get("required", []))
        out.extend(define_struct(name, properties, required_props, description))
        if name.endswith("Result"):
            # Concrete results convert into serde_json::Value, the type the
            # generic `Result` alias above maps to.
            out.append(f"impl From<{name}> for serde_json::Value {{\n")
            out.append(f" fn from(value: {name}) -> Self {{\n")
            out.append(" // Leave this as it should never fail\n")
            out.append(" #[expect(clippy::unwrap_used)]\n")
            out.append(" serde_json::to_value(value).unwrap()\n")
            out.append(" }\n")
            out.append("}\n\n")
        return

    enum_values = definition.get("enum", [])
    if enum_values:
        enum_type = definition.get("type")
        if enum_type != "string":
            raise ValueError(
                f"Only string enums are supported, got type {enum_type!r} in {name}"
            )
        define_string_enum(name, enum_values, out, description)
        return

    any_of = definition.get("anyOf", [])
    if any_of:
        if not isinstance(any_of, list):
            raise ValueError(f"anyOf must be a list in {name}")
        out.extend(define_any_of(name, any_of, description))
        return

    type_prop = definition.get("type", None)
    if type_prop:
        if type_prop == "string":
            out.append(STANDARD_DERIVE)
            out.append(f"pub struct {name}(String);\n\n")
            return
        if types := check_string_list(type_prop):
            define_untagged_enum(name, types, out)
            return
        if type_prop == "array":
            item_name = name + "Item"
            out.extend(define_any_of(item_name, definition["items"]["anyOf"]))
            out.append(f"pub type {name} = Vec<{item_name}>;\n\n")
            return
        raise ValueError(f"Unknown type: {type_prop} in {name}")

    ref_prop = definition.get("$ref", None)
    if ref_prop:
        out.append(f"pub type {name} = {type_from_ref(ref_prop)};\n\n")
        return

    raise ValueError(f"Definition for {name} could not be processed.")
# Rust definitions for anonymous helper types (inline anyOf unions and inline
# objects with properties) queued by map_type() while a struct's fields are
# being mapped; define_struct() appends them after the struct and clears it.
extra_defs: list[str] = []
@dataclass
class StructField:
viz: Literal["pub"] | Literal["const"]
name: str
type_name: str
serde: str | None = None
ts: str | None = None
comment: str | None = None
def append(self, out: list[str], supports_const: bool) -> None:
if self.comment:
out.append(f" // {self.comment}\n")
if self.serde:
out.append(f" {self.serde}\n")
if self.ts:
out.append(f" {self.ts}\n")
if self.viz == "const":
if supports_const:
out.append(f" const {self.name}: {self.type_name};\n")
else:
out.append(f" pub {self.name}: String, // {self.type_name}\n")
else:
out.append(f" pub {self.name}: {self.type_name},\n")
def define_struct(
    name: str,
    properties: dict[str, Any],
    required_props: set[str],
    description: str | None,
) -> list[str]:
    """Build the Rust code for an object definition with ``properties``.

    Request and notification types (per ``implements_request_trait`` /
    ``implements_notification_trait``) go through ``add_trait_impl``. Every
    other definition becomes a documented ``pub struct``. ``_meta`` is
    skipped, ``jsonrpc`` gets a fixed ``String`` field with a serde default,
    and ``Implementation`` gets an extra optional ``user_agent`` field.
    Helper types queued in ``extra_defs`` while mapping field types are
    appended after the main item, and the queue is then cleared.

    Returns:
        The generated Rust source chunks.
    """
    fields: list[StructField] = []
    for prop_name, prop in properties.items():
        if prop_name == "_meta":
            continue
        if prop_name == "jsonrpc":
            fields.append(
                StructField(
                    "pub",
                    "jsonrpc",
                    "String",
                    '#[serde(rename = "jsonrpc", default = "default_jsonrpc")]',
                )
            )
            continue

        field_type = map_type(prop, prop_name, name)
        optional = prop_name not in required_props
        if optional:
            field_type = f"Option<{field_type}>"
        rust_prop = rust_prop_name(prop_name, optional)
        # map_type encodes string consts as "&'static str = ...".
        visibility = "const" if field_type.startswith("&'static str") else "pub"
        fields.append(
            StructField(visibility, rust_prop.name, field_type, rust_prop.serde, rust_prop.ts)
        )

    if name == "Implementation":
        fields.append(
            StructField(
                "pub",
                "user_agent",
                "Option<String>",
                serde='#[serde(default, skip_serializing_if = "Option::is_none")]',
                ts='#[ts(optional)]',
                comment="This is an extra field that the Codex MCP server sends as part of InitializeResult.",
            )
        )

    rendered: list[str] = []
    if implements_request_trait(name):
        add_trait_impl(name, "ModelContextProtocolRequest", fields, rendered)
    elif implements_notification_trait(name):
        add_trait_impl(name, "ModelContextProtocolNotification", fields, rendered)
    else:
        emit_doc_comment(description, rendered)
        rendered.append(STANDARD_DERIVE)
        rendered.append(f"pub struct {name} {{\n")
        for field in fields:
            field.append(rendered, supports_const=False)
        rendered.append("}\n\n")

    if extra_defs:
        rendered.extend(extra_defs)
        extra_defs.clear()
    return rendered
def infer_result_type(request_type_name: str) -> str:
    """Return the name of the result type paired with a ``*Request`` type.

    ``FooRequest`` maps to ``FooResult`` when that name exists in the loaded
    schema ``DEFINITIONS``. Otherwise, and for names not ending in
    ``Request``, the generic ``Result`` alias is returned.
    """
    if not request_type_name.endswith("Request"):
        return "Result"
    candidate = request_type_name.removesuffix("Request") + "Result"
    return candidate if candidate in DEFINITIONS else "Result"
def implements_request_trait(name: str) -> bool:
    """Tell whether ``name`` gets a ``ModelContextProtocolRequest`` impl.

    True for names ending in ``Request``, except ``Request``,
    ``JSONRPCRequest`` and ``PaginatedRequest``.
    """
    if not name.endswith("Request"):
        return False
    excluded = {"Request", "JSONRPCRequest", "PaginatedRequest"}
    return name not in excluded
def implements_notification_trait(name: str) -> bool:
    """Tell whether ``name`` gets a ``ModelContextProtocolNotification`` impl.

    True for names ending in ``Notification``, except ``Notification`` and
    ``JSONRPCNotification``.
    """
    base_types = ("Notification", "JSONRPCNotification")
    return name not in base_types and name.endswith("Notification")
def add_trait_impl(
    type_name: str, trait_name: str, fields: list[StructField], out: list[str]
) -> None:
    """Emit an uninhabited ``pub enum type_name {}`` plus ``impl trait_name`` for it.

    The ``method`` field is rendered as ``const METHOD`` and ``params`` as
    ``type Params``. For ``ModelContextProtocolRequest``, a ``type Result`` is
    added via ``infer_result_type``. Any other field is reported on stderr
    and skipped. ``fields`` is not modified.
    """
    from dataclasses import replace

    out.append(STANDARD_DERIVE)
    out.append(f"pub enum {type_name} {{}}\n\n")
    out.append(f"impl {trait_name} for {type_name} {{\n")
    for field in fields:
        if field.name == "method":
            # Render under the trait's constant name using a copy, so the
            # caller's StructField keeps its original name.
            replace(field, name="METHOD").append(out, supports_const=True)
        elif field.name == "params":
            out.append(f" type Params = {field.type_name};\n")
        else:
            eprint(f"Warning: {type_name} has unexpected field {field.name}.")
    if trait_name == "ModelContextProtocolRequest":
        result_type = infer_result_type(type_name)
        out.append(f" type Result = {result_type};\n")
    out.append("}\n\n")
def define_string_enum(
    name: str, enum_values: Any, out: list[str], description: str | None
) -> None:
    """Emit a Rust enum with one unit variant per string in ``enum_values``.

    Each variant is the capitalized value and carries a serde ``rename`` back
    to the original string.
    """
    emit_doc_comment(description, out)
    out.extend([STANDARD_DERIVE, f"pub enum {name} {{\n"])
    for literal in enum_values:
        assert isinstance(literal, str)
        out.append(f' #[serde(rename = "{literal}")]\n')
        variant = capitalize(literal)
        out.append(f" {variant},\n")
    out.append("}\n\n")
def define_untagged_enum(name: str, type_list: list[str], out: list[str]) -> None:
    """Emit a hashable ``#[serde(untagged)]`` enum with one variant per JSON type.

    Only ``"string"`` (``String(String)``) and ``"integer"``
    (``Integer(i64)``) are supported.

    Raises:
        ValueError: For any other JSON type name.
    """
    out.extend([STANDARD_HASHABLE_DERIVE, "#[serde(untagged)]\n", f"pub enum {name} {{\n"])
    for json_type in type_list:
        if json_type == "string":
            out.append(" String(String),\n")
        elif json_type == "integer":
            out.append(" Integer(i64),\n")
        else:
            raise ValueError(f"Unknown type in untagged enum: {json_type} in {name}")
    out.append("}\n\n")
def define_any_of(name: str, list_of_refs: list[Any], description: str | None = None) -> list[str]:
    """Build a Rust enum with one variant per ``$ref`` entry of an ``anyOf`` list.

    Non-dict entries are skipped.

    - ``ClientRequest`` and ``ServerNotification`` become ``method``-tagged
      enums. Each variant wraps the referenced type's trait ``Params`` and is
      renamed to that type's ``method`` const. Their referenced type names are
      also stored in ``CLIENT_REQUEST_TYPE_NAMES`` /
      ``SERVER_NOTIFICATION_TYPE_NAMES``.
    - Inside ``JSONRPCMessage``, a leading ``JSONRPC`` is dropped from
      variant names.
    - Every other union's variants wrap the referenced type directly.

    Returns:
        The generated Rust source chunks.
    """
    global CLIENT_REQUEST_TYPE_NAMES, SERVER_NOTIFICATION_TYPE_NAMES

    refs = [entry["$ref"] for entry in list_of_refs if isinstance(entry, dict)]

    chunks: list[str] = []
    if description:
        emit_doc_comment(description, chunks)
    chunks.append(STANDARD_DERIVE)
    container_attr = get_serde_annotation_for_anyof_type(name)
    if container_attr:
        chunks.append(f"{container_attr}\n")
    if name in LARGE_ENUMS:
        chunks.append("#[allow(clippy::large_enum_variant)]\n")
    chunks.append(f"pub enum {name} {{\n")

    if name == "ClientRequest":
        CLIENT_REQUEST_TYPE_NAMES = [type_from_ref(ref) for ref in refs]
    elif name == "ServerNotification":
        SERVER_NOTIFICATION_TYPE_NAMES = [type_from_ref(ref) for ref in refs]

    # Trait whose associated Params type each variant of a method-tagged enum
    # carries; None for plain unions.
    params_trait = {
        "ClientRequest": "ModelContextProtocolRequest",
        "ServerNotification": "ModelContextProtocolNotification",
    }.get(name)

    for ref in refs:
        ref_name = type_from_ref(ref)
        if params_trait is None:
            variant = ref_name.removeprefix("JSONRPC") if name == "JSONRPCMessage" else ref_name
            chunks.append(f" {variant}({ref_name}),\n")
            continue
        method = (
            DEFINITIONS.get(ref_name, {}).get("properties", {}).get("method", {}).get("const", ref_name)
        )
        chunks.append(f' #[serde(rename = "{method}")]\n')
        chunks.append(f" {ref_name}(<{ref_name} as {params_trait}>::Params),\n")

    chunks.append("}\n\n")
    return chunks
def get_serde_annotation_for_anyof_type(type_name: str) -> str | None:
match type_name:
case "ClientRequest":
return '#[serde(tag = "method", content = "params")]'
case "ServerNotification":
return '#[serde(tag = "method", content = "params")]'
case _:
return "#[serde(untagged)]"
def map_type(
typedef: dict[str, Any],
prop_name: str | None = None,
struct_name: str | None = None,
) -> str:
ref_prop = typedef.get("$ref", None)
if ref_prop:
return type_from_ref(ref_prop)
any_of = typedef.get("anyOf", None)
if any_of:
assert prop_name is not None
assert struct_name is not None
custom_type = struct_name + capitalize(prop_name)
extra_defs.extend(define_any_of(custom_type, any_of))
return custom_type
type_prop = typedef.get("type", None)
if type_prop is None:
return "serde_json::Value"
if type_prop == "string":
if const_prop := typedef.get("const", None):
assert isinstance(const_prop, str)
return f'&\'static str = "{const_prop}"'
else:
return "String"
elif type_prop == "integer":
return "i64"
elif type_prop == "number":
return "f64"
elif type_prop == "boolean":
return "bool"
elif type_prop == "array":
item_type = typedef.get("items", None)
if item_type:
item_type = map_type(item_type, prop_name, struct_name)
assert isinstance(item_type, str)
return f"Vec<{item_type}>"
else:
raise ValueError("Array type without items.")
elif type_prop == "object":
if typedef.get("additionalProperties") is not None:
return "serde_json::Value"
if not typedef.get("properties"):
return "serde_json::Value"
assert prop_name is not None
assert struct_name is not None
custom_type = struct_name + capitalize(prop_name)
extra_defs.extend(
define_struct(
custom_type,
typedef["properties"],
set(typedef.get("required", [])),
typedef.get("description"),
)
)
return custom_type
else:
raise ValueError(f"Unknown type: {type_prop} in {typedef}")
@dataclass
class RustProp:
name: str
serde: str | None = None
ts: str | None = None
def rust_prop_name(name: str, is_optional: bool) -> RustProp:
    """Map a JSON property name to a Rust field name and its attributes.

    - ``type``, ``ref`` and ``enum`` become raw identifiers.
    - camelCase names are snake_cased and get ``rename = "<original>"``.
    - Optional fields get ``default`` and
      ``skip_serializing_if = "Option::is_none"``, plus ``#[ts(optional)]``.
    """
    raw_identifiers = {"type": "r#type", "ref": "r#ref", "enum": "r#enum"}
    renamed = False
    if name in raw_identifiers:
        rust_name = raw_identifiers[name]
    else:
        snake = to_snake_case(name)
        renamed = bool(snake)
        rust_name = snake if snake else name

    serde_parts: list[str] = []
    if renamed:
        serde_parts.append(f'rename = "{name}"')
    if is_optional:
        serde_parts += ["default", 'skip_serializing_if = "Option::is_none"']

    serde_attr = f"#[serde({', '.join(serde_parts)})]" if serde_parts else None
    ts_attr = "#[ts(optional)]" if is_optional and serde_attr else None
    return RustProp(rust_name, serde_attr, ts_attr)
def to_snake_case(name: str) -> str | None:
snake_case = name[0].lower() + "".join("_" + c.lower() if c.isupper() else c for c in name[1:])
if snake_case != name:
return snake_case
else:
return None
def capitalize(name: str) -> str:
    """Uppercase the first character of ``name`` and keep the rest unchanged.

    Unlike ``str.capitalize``, the remainder is not lowercased
    (``"listChanged"`` becomes ``"ListChanged"``). An empty string is
    returned as-is instead of raising ``IndexError``.
    """
    return name[:1].upper() + name[1:]
def check_string_list(value: Any) -> list[str] | None:
if not isinstance(value, list):
return None
for item in value:
if not isinstance(item, str):
return None
return value
def type_from_ref(ref: str) -> str:
    """Return the type name referenced by a local JSON-schema ``$ref``.

    Example: ``"#/definitions/CallToolRequest"`` -> ``"CallToolRequest"``.

    Raises:
        ValueError: If ``ref`` does not start with ``#/definitions/``. This is
            an explicit check rather than an ``assert``, so it survives ``-O``.
    """
    prefix = "#/definitions/"
    if not ref.startswith(prefix):
        raise ValueError(f"Unsupported $ref (expected {prefix}...): {ref}")
    return ref.split("/")[-1]
def emit_doc_comment(text: str | None, out: list[str]) -> None:
if not text:
return
for line in text.strip().split("\n"):
out.append(f"/// {line.rstrip()}\n")
def eprint(*args: Any, **kwargs: Any) -> None:
    """Like ``print``, but always writes to the current ``sys.stderr``."""
    stream = sys.stderr
    print(*args, file=stream, **kwargs)
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())