import re
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated, Any, Literal, Optional

import requests
import tomlkit
import typer
def log(message: str):
print(f"[INFO] {message}")
def fetch_models_dev_json() -> dict[str, Any]:
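    """Fetch the full provider/model catalog from https://models.dev/api.json."""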
log("Fetching models.dev API data...")
    response = requests.get("https://models.dev/api.json", timeout=30)
    if response.status_code != 200:
        raise RuntimeError(f"Failed to fetch models.dev JSON: HTTP {response.status_code}")
log("Successfully fetched models.dev data")
return response.json()
def to_pascal_case(identifier: str) -> str:
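    """Convert an identifier to a PascalCase Rust type name.

    A leading run of digits is moved to the end so the result starts with a
    letter, e.g. "gpt-4o" -> "Gpt4o" and "01-ai" -> "Ai01".
    """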
match = re.match(r"^(\d+)(.*)$", identifier)
if match:
digits, rest = match.groups()
identifier = rest + digits
cleaned = re.sub(r"[^a-zA-Z0-9_]", "_", identifier)
return "".join(word.capitalize() for word in cleaned.split("_") if word)
def provider_id_to_snake_case(provider_id: str) -> str:
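    """Convert a provider ID to a snake_case module name, e.g. "01-ai" -> "ai_01"."""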
match = re.match(r"^(\d+)(.*)$", provider_id)
if match:
digits, rest = match.groups()
provider_id = rest + "_" + digits
result = provider_id.replace("-", "_")
result = re.sub(r"_+", "_", result).strip("_")
return result
def to_constructor_name(identifier: str) -> str:
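    """Convert an identifier to a snake_case constructor name, e.g. "GPT-4o" -> "gpt_4o"."""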
match = re.match(r"^(\d+)(.*)$", identifier)
if match:
digits, rest = match.groups()
identifier = rest + digits
cleaned = re.sub(r"[^a-zA-Z0-9_]", "_", identifier)
cleaned = re.sub(r"_+", "_", cleaned).strip("_")
return cleaned.lower()
def get_project_root() -> Path:
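    """Locate the crate root (one directory above this script) and verify Cargo.toml exists."""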
root = Path(__file__).resolve().parent.parent
if not (root / "Cargo.toml").exists():
raise RuntimeError("Could not find project root (missing Cargo.toml)")
return root
def filter_openai_compatible_providers(
all_providers: dict[str, Any],
) -> dict[str, Any]:
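    """Keep only providers backed by the @ai-sdk/openai-compatible npm package.

    The privatemode-ai provider is explicitly excluded.
    """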
    return {
        provider_id: provider_data
        for provider_id, provider_data in all_providers.items()
        if provider_data.get("npm") == "@ai-sdk/openai-compatible"
        and provider_id != "privatemode-ai"
    }
def get_model_capabilities(model_data: dict[str, Any]) -> list[str]:
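    """Translate a models.dev model entry into the capability trait names used by the Rust macro."""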
capabilities = []
if model_data.get("tool_call", False):
capabilities.append("ToolCallSupport")
if model_data.get("reasoning", False):
capabilities.append("ReasoningSupport")
if model_data.get("structured_output", False):
capabilities.append("StructuredOutputSupport")
modalities = model_data.get("modalities", {})
input_modalities = modalities.get("input", [])
output_modalities = modalities.get("output", [])
if "text" in input_modalities:
capabilities.append("TextInputSupport")
if "audio" in input_modalities:
capabilities.append("AudioInputSupport")
if "image" in input_modalities or model_data.get("attachment", False):
capabilities.append("ImageInputSupport")
if "video" in input_modalities:
capabilities.append("VideoInputSupport")
if "text" in output_modalities:
capabilities.append("TextOutputSupport")
if "audio" in output_modalities:
capabilities.append("AudioOutputSupport")
if "image" in output_modalities:
capabilities.append("ImageOutputSupport")
if "video" in output_modalities:
capabilities.append("VideoOutputSupport")
    return sorted(set(capabilities))
@dataclass
class PendingWrite:
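    """A prepared file write (path plus content) that has not yet been flushed to disk."""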
path: Path
content: str
file: Literal["mod.rs", "capabilities.rs"]
def create_pending_write(
provider_id: str,
content: str,
file_type: Literal["mod", "capabilities"],
root: Path,
) -> PendingWrite:
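    """Build a PendingWrite targeting src/providers/<provider_id>/mod.rs or capabilities.rs."""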
providers_dir = root / "src" / "providers"
provider_dir = providers_dir / provider_id
if file_type == "mod":
file_path = provider_dir / "mod.rs"
file_type_full = "mod.rs"
elif file_type == "capabilities":
file_path = provider_dir / "capabilities.rs"
file_type_full = "capabilities.rs"
else:
raise ValueError(f"Unknown file_type: {file_type}")
return PendingWrite(path=file_path, content=content, file=file_type_full)
def batch_write_files(
pending_writes: list[PendingWrite], dry_run: bool = False
) -> list[Path]:
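    """Write all pending files; on any failure, roll back the files written so far.

    With dry_run=True, only log what would be written and return an empty list.
    """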
if not pending_writes:
log("No files to write")
return []
if dry_run:
log(f"[DRY RUN] Would write {len(pending_writes)} files:")
for pw in pending_writes:
log(f" - {pw.path} ({pw.file})")
return []
written_files: list[Path] = []
try:
for pw in pending_writes:
if pw.path.exists():
log(f"Will overwrite: {pw.path}")
for pw in pending_writes:
pw.path.parent.mkdir(parents=True, exist_ok=True)
for pw in pending_writes:
with open(pw.path, "w", encoding="utf-8") as f:
f.write(pw.content)
written_files.append(pw.path)
log(f"Wrote: {pw.path}")
return written_files
except Exception as e:
log(f"Error during file writing: {e}")
log(f"Rolling back {len(written_files)} written files...")
for path in written_files:
try:
path.unlink()
log(f"Rolled back: {path}")
except Exception as rollback_error:
log(f"Failed to rollback {path}: {rollback_error}")
raise
MANUALLY_MANAGED_MODULES = {
"openai",
"openai_compatible",
"openai_chat_completions",
"anthropic",
"groq",
"google",
"vercel",
"openrouter",
"mistral",
"amazon_bedrock",
"togetherai",
"xai",
}
def update_cargo_toml(provider_ids: list[str], root: Path | None = None):
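    """Add a Cargo feature for each generated provider and list it under the 'full' feature."""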
if root is None:
root = get_project_root()
cargo_path = root / "Cargo.toml"
with open(cargo_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
features = doc["features"]
full_feature = features["full"]
modified = False
for provider_id in sorted(provider_ids):
if provider_id not in features:
features[provider_id] = ["openaichatcompletions"]
log(f"Added feature '{provider_id}' to Cargo.toml")
modified = True
if provider_id not in full_feature:
full_feature.append(provider_id)
log(f"Added '{provider_id}' to 'full' feature list")
modified = True
if modified:
with open(cargo_path, "w", encoding="utf-8") as f:
tomlkit.dump(doc, f)
log("Updated Cargo.toml")
else:
log("Cargo.toml already up to date")
def get_codegen_provider_dirs(root: Path) -> list[str]:
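    """List directories under src/providers that are code-generated (not in MANUALLY_MANAGED_MODULES)."""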
providers_dir = root / "src" / "providers"
codegen_dirs = []
for entry in sorted(providers_dir.iterdir()):
if not entry.is_dir():
continue
dir_name = entry.name
normalized = dir_name.replace("-", "_")
if normalized in MANUALLY_MANAGED_MODULES:
continue
codegen_dirs.append(dir_name)
return codegen_dirs
def update_providers_mod_rs(root: Path | None = None):
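    """Regenerate the module declarations between the '// [codegen]' and '// [end-codegen]' markers."""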
if root is None:
root = get_project_root()
mod_rs_path = root / "src" / "providers" / "mod.rs"
content = mod_rs_path.read_text(encoding="utf-8")
start_marker = "// [codegen]"
end_marker = "// [end-codegen]"
start_idx = content.find(start_marker)
end_idx = content.find(end_marker)
if start_idx == -1 or end_idx == -1:
raise RuntimeError(
f"Could not find codegen markers in {mod_rs_path}. "
f"Expected '{start_marker}' and '{end_marker}'"
)
provider_dirs = get_codegen_provider_dirs(root)
lines = []
for dir_name in provider_dirs:
module_name = provider_id_to_snake_case(dir_name)
struct_name = to_pascal_case(dir_name)
feature_name = dir_name
needs_path_attr = dir_name != module_name
lines.append(f'#[cfg(feature = "{feature_name}")]')
if needs_path_attr:
lines.append(f'#[path = "{dir_name}/mod.rs"]')
lines.append(f"pub mod {module_name};")
lines.append(f'#[cfg(feature = "{feature_name}")]')
lines.append(f"pub use {module_name}::{struct_name};")
lines.append("")
codegen_content = "\n".join(lines)
new_block = f"{start_marker}\n{codegen_content}{end_marker}"
new_content = content[:start_idx] + new_block + content[end_idx + len(end_marker) :]
mod_rs_path.write_text(new_content, encoding="utf-8")
log(f"Updated {mod_rs_path} with {len(provider_dirs)} codegen providers")
def get_model_display_name(model_id: str, model_data: dict[str, Any]) -> str:
return model_data.get("name", model_id)
def get_model_constructor_name(model_id: str, folder_prefix: str | None = None) -> str:
if folder_prefix:
prefix_parts = folder_prefix.split("/")
prefix = "_".join(to_constructor_name(p) for p in prefix_parts)
return prefix + "_" + to_constructor_name(model_id)
return to_constructor_name(model_id)
def get_model_type_name(model_id: str, folder_prefix: str | None = None) -> str:
if folder_prefix:
prefix_parts = folder_prefix.split("/")
prefix = "".join(to_pascal_case(p) for p in prefix_parts)
return prefix + to_pascal_case(model_id)
return to_pascal_case(model_id)
def parse_model_id(model_id: str) -> tuple[str, str | None]:
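    """Split a model ID into (base_name, folder_prefix).

    "meta/llama-3" -> ("llama-3", "meta")
    "a/b/c"        -> ("c", "a/b")
    "gpt-4o"       -> ("gpt-4o", None)
    """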
if "/" in model_id:
parts = model_id.split("/")
if len(parts) == 2:
return parts[1], parts[0]
elif len(parts) > 2:
return parts[-1], "/".join(parts[:-1])
return model_id, None
def generate_capabilities_rs(provider_id: str, models: dict[str, Any]) -> str:
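    """Render capabilities.rs for a provider via the model_capabilities! macro.

    Deprecated models are skipped, as is any model whose generated type name
    duplicates one already emitted.
    """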
provider_struct_name = to_pascal_case(provider_id)
provider_module = provider_id_to_snake_case(provider_id)
lines = [
f"//! Capabilities for {provider_module} models.",
"//!",
f"//! This module defines model types and their capabilities for {
provider_module
} providers.",
"//! Users can implement additional traits on custom models.",
"",
"use crate::core::capabilities::*;",
"use crate::model_capabilities;",
f"use crate::providers::{provider_module}::{provider_struct_name};",
"",
"model_capabilities! {",
f" provider: {provider_struct_name},",
" models: {",
]
active_models = {
model_id: model_data
for model_id, model_data in models.items()
if model_data.get("status") != "deprecated"
}
seen_type_names: set[str] = set()
for model_id, model_data in sorted(active_models.items()):
base_name, folder_prefix = parse_model_id(model_id)
model_type_name = get_model_type_name(base_name, folder_prefix)
if model_type_name in seen_type_names:
log(
f"Skipping duplicate type name '{model_type_name}' for model '{model_id}'"
)
continue
seen_type_names.add(model_type_name)
        model_name = model_id
        constructor_name = get_model_constructor_name(base_name, folder_prefix)
display_name = get_model_display_name(model_id, model_data)
capabilities = get_model_capabilities(model_data)
lines.extend(
[
f" {model_type_name} {{",
f' model_name: "{model_name}",',
f" constructor_name: {constructor_name},",
f' display_name: "{display_name}",',
f" capabilities: [{', '.join(capabilities)}]",
" },",
]
)
lines.extend(
[
" }",
"}",
]
)
return "\n".join(lines) + "\n"
def generate_provider_capabilities_content(
provider_id: str, provider_data: dict[str, Any]
) -> str | None:
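    """Render capabilities.rs content for a provider, or None if it has no models."""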
models = provider_data.get("models", {})
if not models:
log(f"Warning: No models found for provider '{provider_id}'")
return None
log(f"Preparing capabilities for '{provider_id}' with {len(models)} models")
return generate_capabilities_rs(provider_id, models)
def prepare_provider_capabilities(
provider_id: str, provider_data: dict[str, Any], root: Path | None = None
) -> PendingWrite | None:
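    """Prepare the capabilities.rs write for one provider, or None if there is nothing to generate."""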
if root is None:
root = get_project_root()
content = generate_provider_capabilities_content(provider_id, provider_data)
if content is None:
return None
return create_pending_write(provider_id, content, "capabilities", root)
def prepare_all_capabilities(
all_providers: dict[str, Any],
) -> list[PendingWrite]:
root = get_project_root()
pending_writes = []
for provider_id, provider_data in all_providers.items():
try:
pending_write = prepare_provider_capabilities(
provider_id, provider_data, root
)
if pending_write:
pending_writes.append(pending_write)
except Exception as e:
log(f"Error preparing capabilities for '{provider_id}': {e}")
return pending_writes
def generate_openai_compatible_content(
provider_id: str, provider_data: dict[str, Any]
) -> str:
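    """Render mod.rs for an OpenAI-compatible provider using the three codegen macros."""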
provider_struct_name = to_pascal_case(provider_id)
api_endpoint = provider_data.get("api", "")
    env_var = (provider_data.get("env") or [""])[0]
default_model = provider_data.get("model", provider_id)
lines = [
f"//! This module provides the {
provider_struct_name
} provider, wrapping OpenAI Chat Completions for {
provider_struct_name
} requests.",
"",
"pub mod capabilities;",
"",
"// Generate the settings module",
"crate::openai_compatible_settings!(",
f" {provider_struct_name}ProviderSettings,",
f" {provider_struct_name}ProviderSettingsBuilder,",
f' "{provider_struct_name}",',
f' "{api_endpoint}",',
f' "{env_var}"',
");",
"",
"// Generate the provider struct and builder",
"crate::openai_compatible_provider!(",
f" {provider_struct_name},",
f" {provider_struct_name}Builder,",
f" {provider_struct_name}ProviderSettings,",
f' "{default_model}"',
");",
"",
"// Generate the language model implementation",
f"crate::openai_compatible_language_model!({provider_struct_name});",
"",
]
return "\n".join(lines)
def prepare_openai_compatible_provider(
provider_id: str,
provider_data: dict[str, Any],
with_capabilities: bool = False,
root: Path | None = None,
) -> list[PendingWrite]:
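    """Prepare the mod.rs (and optionally capabilities.rs) writes for one OpenAI-compatible provider."""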
if root is None:
root = get_project_root()
log(f"Preparing OpenAI-compatible provider '{provider_id}'")
pending = []
mod_content = generate_openai_compatible_content(provider_id, provider_data)
pending.append(create_pending_write(provider_id, mod_content, "mod", root))
if with_capabilities:
cap_write = prepare_provider_capabilities(provider_id, provider_data, root)
if cap_write:
pending.append(cap_write)
return pending
def prepare_all_openai_compatible_providers(
all_providers: dict[str, Any], with_capabilities: bool = False
) -> list[PendingWrite]:
root = get_project_root()
compatible_providers = filter_openai_compatible_providers(all_providers)
log(f"Found {len(compatible_providers)} OpenAI-compatible providers")
pending_writes = []
for provider_id, provider_data in compatible_providers.items():
try:
provider_writes = prepare_openai_compatible_provider(
provider_id, provider_data, with_capabilities, root
)
pending_writes.extend(provider_writes)
except Exception as e:
log(f"Error preparing OpenAI-compatible provider '{provider_id}': {e}")
return pending_writes
app = typer.Typer(
help="AI SDK Provider Code Generator - Generate Rust code for aisdk crate from models.dev",
no_args_is_help=True,
)
def run_cargo_fmt():
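    """Format the generated code with `cargo fmt --all`; failures are logged but not fatal."""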
try:
root = get_project_root()
log("Running cargo fmt...")
result = subprocess.run(
["cargo", "fmt", "--all"], cwd=root, capture_output=True, text=True
)
if result.returncode != 0:
log(f"Warning: cargo fmt failed: {result.stderr}")
else:
log("Code formatted successfully")
except Exception as e:
log(f"Warning: Could not run cargo fmt: {e}")
@app.command()
def openai_compatible(
provider_id: Annotated[
Optional[str],
typer.Argument(
help="Specific models.dev provider ID to generate (e.g., 'deepseek', 'openrouter')"
),
] = None,
with_capabilities: Annotated[
bool,
typer.Option("--with-capabilities", "-c", help="Also generate capabilities.rs"),
] = False,
):
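    """Generate OpenAI-compatible provider modules from models.dev, then update Cargo.toml and mod.rs."""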
try:
all_providers = fetch_models_dev_json()
if provider_id:
if provider_id not in all_providers:
log(f"Error: Provider '{provider_id}' not found in models.dev")
raise typer.Exit(code=1)
provider_data = all_providers[provider_id]
if provider_data.get("npm") != "@ai-sdk/openai-compatible":
log(f"Error: Provider '{provider_id}' is not OpenAI-compatible")
                log(
                    f"Use 'capabilities {provider_id}' instead for "
                    "non-OpenAI-compatible provider capabilities"
                )
raise typer.Exit(code=1)
pending_writes = prepare_openai_compatible_provider(
provider_id, provider_data, with_capabilities
)
else:
pending_writes = prepare_all_openai_compatible_providers(
all_providers, with_capabilities
)
written_files = batch_write_files(pending_writes)
        if provider_id:
            generated_provider_ids = [provider_id]
        else:
            generated_provider_ids = list(
                filter_openai_compatible_providers(all_providers).keys()
            )
update_cargo_toml(generated_provider_ids)
update_providers_mod_rs()
run_cargo_fmt()
log(f"✓ Successfully wrote {len(written_files)} files")
    except typer.Exit:
        raise
    except Exception as e:
        log(f"Error: {e}")
        raise typer.Exit(code=1)
@app.command()
def capabilities(
provider_id: Annotated[
Optional[str],
typer.Argument(
help="Specific provider ID to generate (e.g., 'openai', 'anthropic')"
),
] = None,
):
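    """Generate capabilities.rs for one provider, or for all providers in models.dev."""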
try:
all_providers = fetch_models_dev_json()
if provider_id:
if provider_id not in all_providers:
log(f"Error: Provider '{provider_id}' not found in models.dev")
raise typer.Exit(code=1)
provider_data = all_providers[provider_id]
cap_write = prepare_provider_capabilities(provider_id, provider_data)
pending_writes = [cap_write] if cap_write else []
if not pending_writes:
log(f"Warning: No capabilities to generate for '{provider_id}'")
else:
pending_writes = prepare_all_capabilities(all_providers)
written_files = batch_write_files(pending_writes)
update_providers_mod_rs()
run_cargo_fmt()
log(f"✓ Successfully wrote {len(written_files)} files")
    except typer.Exit:
        raise
    except Exception as e:
        log(f"Error: {e}")
        raise typer.Exit(code=1)
def main():
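    """Run the Typer application (commands: openai-compatible, capabilities)."""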
app()
if __name__ == "__main__":
main()