Skip to main content

llm_codegen/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use serde::Deserialize;
4use std::collections::{BTreeMap, HashMap};
5use std::fmt::Write;
6use std::path::Path;
7
8type ModelsDevData = HashMap<String, ProviderData>;
9type ContextWindowOverride = fn(&str, u32) -> u32;
10
11#[derive(Debug, Deserialize)]
12struct ProviderData {
13    #[allow(dead_code)]
14    id: String,
15    #[allow(dead_code)]
16    name: String,
17    #[serde(default)]
18    #[allow(dead_code)]
19    env: Vec<String>,
20    #[serde(default)]
21    models: HashMap<String, ModelData>,
22}
23
24#[derive(Debug, Deserialize)]
25struct ModelData {
26    id: String,
27    name: String,
28    #[serde(default)]
29    tool_call: Option<bool>,
30    #[serde(default)]
31    reasoning: Option<bool>,
32    #[serde(default)]
33    #[allow(dead_code)]
34    cost: Option<CostData>,
35    #[serde(default)]
36    limit: Option<LimitData>,
37    #[serde(default)]
38    modalities: Option<ModalitiesData>,
39}
40
41#[derive(Debug, Deserialize, Default)]
42struct ModalitiesData {
43    #[serde(default)]
44    input: Vec<String>,
45}
46
47#[derive(Debug, Deserialize)]
48#[allow(dead_code)]
49struct CostData {
50    #[serde(default)]
51    input: f64,
52    #[serde(default)]
53    output: f64,
54}
55
56#[derive(Debug, Deserialize)]
57struct LimitData {
58    #[serde(default)]
59    context: u32,
60    #[serde(default)]
61    #[allow(dead_code)]
62    output: u32,
63}
64
65/// Provider configuration for codegen (catalog providers with known model lists)
66struct ProviderConfig {
67    /// Unique provider key used in `provider_models` map (e.g. "codex")
68    dev_id: &'static str,
69    /// models.dev provider ID to read models from (defaults to `dev_id` when `None`)
70    source_dev_id: Option<&'static str>,
71    /// Additional models.dev keys whose models are merged into this provider
72    extra_source_ids: &'static [&'static str],
73    /// Only include models whose ID passes this filter (None = include all)
74    model_filter: Option<fn(&str) -> bool>,
75    /// Provider-specific generated context window override
76    context_window_override: Option<ContextWindowOverride>,
77    /// Our Rust enum name (e.g. "Gemini")
78    enum_name: &'static str,
79    /// Our internal provider name used for parsing (e.g. "gemini")
80    parser_name: &'static str,
81    /// Human-readable provider name (e.g. "AWS Bedrock")
82    display_name: &'static str,
83    /// Env var our code actually checks (None for providers with complex credential chains)
84    env_var: Option<&'static str>,
85    /// OAuth provider ID for providers that require OAuth login (e.g. "codex")
86    oauth_provider_id: Option<&'static str>,
87    /// Default reasoning levels for models that support reasoning (empty = use standard 3)
88    default_reasoning_levels: &'static [&'static str],
89    /// Emit a hybrid `{Enum}Model` that wraps `{Enum}FoundationModel` (catalog) plus
90    /// a `Profile(String)` variant. Used for Bedrock to accept arbitrary inference
91    /// profile IDs and ARNs at runtime.
92    is_hybrid_dynamic: bool,
93}
94
95impl ProviderConfig {
96    /// Shorthand for providers with default `source_dev_id`, `model_filter`, and `oauth_provider_id`.
97    const fn standard(
98        dev_id: &'static str,
99        enum_name: &'static str,
100        parser_name: &'static str,
101        display_name: &'static str,
102        env_var: Option<&'static str>,
103    ) -> Self {
104        Self {
105            dev_id,
106            source_dev_id: None,
107            extra_source_ids: &[],
108            model_filter: None,
109            context_window_override: None,
110            enum_name,
111            parser_name,
112            display_name,
113            env_var,
114            oauth_provider_id: None,
115            default_reasoning_levels: &["low", "medium", "high"],
116            is_hybrid_dynamic: false,
117        }
118    }
119
120    /// Inner catalog-enum name. For hybrid providers the outer `{enum_name}Model`
121    /// is a wrapper; the catalog enum is `{enum_name}FoundationModel`.
122    fn inner_enum_name(&self) -> String {
123        if self.is_hybrid_dynamic {
124            format!("{}FoundationModel", self.enum_name)
125        } else {
126            format!("{}Model", self.enum_name)
127        }
128    }
129
130    /// Outer enum name as referenced by `LlmModel::{enum_name}(...)`.
131    fn outer_enum_name(&self) -> String {
132        format!("{}Model", self.enum_name)
133    }
134
135    /// The models.dev key to look up in the JSON data.
136    fn json_key(&self) -> &'static str {
137        self.source_dev_id.unwrap_or(self.dev_id)
138    }
139}
140
/// A provider whose model name is chosen by the user at runtime. No catalog
/// enum is generated; its `LlmModel` variant simply wraps a `String`.
// Every field deliberately ends in `_name`, which trips clippy's
// `struct_field_names` lint.
#[allow(clippy::struct_field_names)]
struct DynamicProviderConfig {
    /// Variant name inside `LlmModel` (e.g. "Ollama").
    enum_name: &'static str,
    /// Provider prefix in "provider:model" strings (e.g. "ollama").
    parser_name: &'static str,
    /// Provider name shown to humans (e.g. "Ollama").
    display_name: &'static str,
}
151
152const PROVIDERS: &[ProviderConfig] = &[
153    ProviderConfig::standard("anthropic", "Anthropic", "anthropic", "Anthropic", Some("ANTHROPIC_API_KEY")),
154    ProviderConfig {
155        dev_id: "codex",
156        source_dev_id: Some("openai"),
157        extra_source_ids: &[],
158        model_filter: Some(|id| id.contains("codex") || id.starts_with("gpt-5.") || id == "gpt-5"),
159        context_window_override: Some(codex_subscription_context_window),
160        enum_name: "Codex",
161        parser_name: "codex",
162        display_name: "Codex",
163        env_var: None,
164        oauth_provider_id: Some("codex"),
165        default_reasoning_levels: &["low", "medium", "high", "xhigh"],
166        is_hybrid_dynamic: false,
167    },
168    ProviderConfig::standard("deepseek", "DeepSeek", "deepseek", "DeepSeek", Some("DEEPSEEK_API_KEY")),
169    ProviderConfig::standard("google", "Gemini", "gemini", "Gemini", Some("GEMINI_API_KEY")),
170    ProviderConfig::standard("moonshotai", "Moonshot", "moonshot", "Moonshot", Some("MOONSHOT_API_KEY")),
171    ProviderConfig::standard("openai", "Openai", "openai", "OpenAI", Some("OPENAI_API_KEY")),
172    ProviderConfig::standard("openrouter", "OpenRouter", "openrouter", "OpenRouter", Some("OPENROUTER_API_KEY")),
173    ProviderConfig {
174        extra_source_ids: &["zai-coding-plan"],
175        ..ProviderConfig::standard("zai", "ZAi", "zai", "ZAI", Some("ZAI_API_KEY"))
176    },
177    ProviderConfig {
178        is_hybrid_dynamic: true,
179        ..ProviderConfig::standard("amazon-bedrock", "Bedrock", "bedrock", "AWS Bedrock", None)
180    },
181];
182
183const DYNAMIC_PROVIDERS: &[DynamicProviderConfig] = &[
184    DynamicProviderConfig { enum_name: "Ollama", parser_name: "ollama", display_name: "Ollama" },
185    DynamicProviderConfig { enum_name: "LlamaCpp", parser_name: "llamacpp", display_name: "LlamaCpp" },
186];
187
/// Context window, in tokens, used for the Codex subscription models listed in
/// `codex_subscription_context_window` instead of the models.dev value.
const CODEX_SUBSCRIPTION_CONTEXT_WINDOW: u32 = 272_000;

/// Context-window override for the Codex provider.
///
/// Returns `CODEX_SUBSCRIPTION_CONTEXT_WINDOW` for the known subscription
/// models and passes `default_context_window` (the models.dev value) through
/// unchanged for every other model.
fn codex_subscription_context_window(model_id: &str, default_context_window: u32) -> u32 {
    const SUBSCRIPTION_MODELS: [&str; 6] =
        ["gpt-5.5", "gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex", "gpt-5.2", "codex-auto-review"];

    if SUBSCRIPTION_MODELS.contains(&model_id) {
        CODEX_SUBSCRIPTION_CONTEXT_WINDOW
    } else {
        default_context_window
    }
}
198
/// Normalized per-model data extracted from models.dev, ready for emission.
#[derive(Debug, Clone)]
struct ModelInfo {
    /// `PascalCase` enum variant name derived from `model_id`.
    variant_name: String,
    /// Raw models.dev model ID.
    model_id: String,
    /// Human-readable model name, as given by models.dev.
    display_name: String,
    /// Context window in tokens after any provider override; `0` when models.dev
    /// gave no limit and no override applied.
    context_window: u32,
    /// Reasoning level names (e.g. "low"); empty for non-reasoning models.
    reasoning_levels: Vec<String>,
    /// Input modalities (e.g. "text", "image").
    input_modalities: Vec<String>,
}
208
209type ProviderModels = BTreeMap<&'static str, Vec<ModelInfo>>;
210
211struct CodegenCtx {
212    provider_models: ProviderModels,
213}
214
/// Everything produced by one run of the code generator.
pub struct GeneratedOutput {
    /// Complete Rust source intended to be written to `generated.rs`.
    pub rust_source: String,
    /// Markdown documentation per provider, keyed by provider identifier
    /// (e.g. `"anthropic"`, `"ollama"`). Each value is meant to be embedded
    /// with `#![doc = include_str!(...)]`.
    pub provider_docs: HashMap<String, String>,
}
225
226/// Run the codegen, returning the generated Rust source and per-provider docs.
227pub fn generate(models_json_path: &Path) -> Result<GeneratedOutput, String> {
228    let json_bytes = std::fs::read_to_string(models_json_path).map_err(|e| format!("read: {e}"))?;
229    let data: ModelsDevData = serde_json::from_str(&json_bytes).map_err(|e| format!("parse: {e}"))?;
230
231    let provider_models = build_provider_models(&data)?;
232    let ctx = CodegenCtx { provider_models };
233    Ok(GeneratedOutput { rust_source: emit_generated_source(&ctx), provider_docs: emit_provider_docs(&ctx) })
234}
235
236fn build_provider_models(data: &ModelsDevData) -> Result<ProviderModels, String> {
237    let mut provider_models = ProviderModels::new();
238
239    for cfg in PROVIDERS {
240        let json_key = cfg.json_key();
241        let provider_data =
242            data.get(json_key).ok_or_else(|| format!("Provider '{json_key}' not found in models.dev data"))?;
243
244        let mut models: Vec<ModelInfo> = collect_models_from(cfg, &provider_data.models);
245
246        for &extra_key in cfg.extra_source_ids {
247            if let Some(extra_data) = data.get(extra_key) {
248                let extra = collect_models_from(cfg, &extra_data.models);
249                let existing_ids: std::collections::HashSet<String> =
250                    models.iter().map(|m| m.model_id.clone()).collect();
251                models.extend(extra.into_iter().filter(|m| !existing_ids.contains(&m.model_id)));
252            }
253        }
254
255        models.sort_by(|a, b| a.model_id.cmp(&b.model_id));
256        provider_models.insert(cfg.dev_id, models);
257    }
258
259    Ok(provider_models)
260}
261
262fn collect_models_from(cfg: &ProviderConfig, models: &HashMap<String, ModelData>) -> Vec<ModelInfo> {
263    models
264        .values()
265        .filter(|m| m.tool_call == Some(true))
266        .filter(|m| !is_alias(&m.id))
267        .filter(|m| cfg.model_filter.is_none_or(|f| f(&m.id)))
268        .map(|m| {
269            let reasoning_levels = if m.reasoning.unwrap_or(false) {
270                cfg.default_reasoning_levels.iter().map(|s| (*s).to_string()).collect()
271            } else {
272                Vec::new()
273            };
274            let input_modalities =
275                m.modalities.as_ref().map_or_else(|| vec!["text".to_string()], |md| md.input.clone());
276            let source_context_window = m.limit.as_ref().map_or(0, |l| l.context);
277            let context_window = cfg.context_window_override.map_or(source_context_window, |override_context_window| {
278                override_context_window(&m.id, source_context_window)
279            });
280            ModelInfo {
281                variant_name: model_id_to_variant(&m.id),
282                model_id: m.id.clone(),
283                display_name: m.name.clone(),
284                context_window,
285                reasoning_levels,
286                input_modalities,
287            }
288        })
289        .collect()
290}
291
/// Reports whether `id` is a moving "latest" alias (its ID ends in `-latest`)
/// rather than a concrete model; such IDs are excluded from the catalog.
fn is_alias(id: &str) -> bool {
    id.strip_suffix("-latest").is_some()
}
296
/// Convert a model ID like "claude-sonnet-4-5-20250929" into a `PascalCase` variant name.
///
/// Any character that is neither alphanumeric nor `_` acts as a word separator.
/// The separator is dropped and the next kept character is upper-cased. This
/// covers the common `-`, `.`, `/` and `:`, and also rarer characters such as
/// `@`, `+` or spaces, which would otherwise end up in the generated enum
/// variant and break compilation. `_` is kept as-is so existing variant names
/// do not change.
///
/// If the result starts with a digit, it is prefixed with `_` so it forms a
/// valid identifier.
fn model_id_to_variant(id: &str) -> String {
    let mut result = String::with_capacity(id.len() + 1);
    let mut capitalize_next = true;

    for ch in id.chars() {
        let is_ident_char = ch.is_alphanumeric() || ch == '_';
        if !is_ident_char {
            // Separator: drop it and start a new word.
            capitalize_next = true;
        } else if capitalize_next {
            result.push(ch.to_ascii_uppercase());
            capitalize_next = false;
        } else {
            result.push(ch);
        }
    }

    // Identifiers may not start with a digit.
    if result.starts_with(|c: char| c.is_ascii_digit()) {
        result.insert(0, '_');
    }

    result
}
321
322fn emit_generated_source(ctx: &CodegenCtx) -> String {
323    let mut out = String::with_capacity(64_000);
324    emit_header(&mut out);
325    emit_provider_enums(&mut out, &ctx.provider_models);
326    emit_provider_impls(&mut out, &ctx.provider_models);
327    emit_llm_model_enum(&mut out);
328    emit_from_impls(&mut out);
329    emit_llm_model_impl(&mut out);
330    emit_display_impl(&mut out);
331    emit_fromstr_impl(&mut out);
332    out
333}
334
335fn emit_header(out: &mut String) {
336    pushln(out, "// Auto-generated from models.dev — do not edit manually");
337    pushln(out, "// Regenerated automatically by build.rs");
338    blank(out);
339    pushln(out, "use std::borrow::Cow;");
340    pushln(out, "use std::sync::LazyLock;");
341    pushln(out, "use crate::ReasoningEffort;");
342    blank(out);
343}
344
345fn emit_provider_enums(out: &mut String, provider_models: &ProviderModels) {
346    for cfg in PROVIDERS {
347        let inner = cfg.inner_enum_name();
348        pushln(out, "#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]");
349        pushln(out, format!("pub enum {inner} {{"));
350        for model in &provider_models[cfg.dev_id] {
351            pushln(out, format!("    {},", model.variant_name));
352        }
353        pushln(out, "}");
354        blank(out);
355
356        if cfg.is_hybrid_dynamic {
357            let outer = cfg.outer_enum_name();
358            pushln(out, "#[derive(Debug, Clone, PartialEq, Eq, Hash)]");
359            pushln(out, format!("pub enum {outer} {{"));
360            pushln(out, format!("    Foundation({inner}),"));
361            pushln(out, "    Profile(String),");
362            pushln(out, "}");
363            blank(out);
364        }
365    }
366}
367
368fn emit_provider_impls(out: &mut String, provider_models: &ProviderModels) {
369    for cfg in PROVIDERS {
370        let models = &provider_models[cfg.dev_id];
371        let enum_name = cfg.inner_enum_name();
372
373        pushln(out, format!("impl {enum_name} {{"));
374
375        // model_id — each model has a unique ID, no grouping needed
376        pushln(out, "    #[allow(clippy::too_many_lines)]");
377        pushln(out, "    fn model_id(self) -> &'static str {");
378        pushln(out, "        match self {");
379        for model in models {
380            pushln(out, format!("            Self::{} => \"{}\",", model.variant_name, model.model_id));
381        }
382        pushln(out, "        }");
383        pushln(out, "    }");
384        blank(out);
385
386        // display_name — group variants that share the same name
387        pushln(out, "    #[allow(clippy::too_many_lines)]");
388        pushln(out, "    fn display_name(self) -> &'static str {");
389        pushln(out, "        match self {");
390        emit_grouped_arms(out, models, |m| escape_rust_string(&m.display_name), |name| format!("\"{name}\""));
391        pushln(out, "        }");
392        pushln(out, "    }");
393        blank(out);
394
395        // context_window — group variants that share the same value
396        pushln(out, "    fn context_window(self) -> u32 {");
397        pushln(out, "        match self {");
398        emit_grouped_arms(
399            out,
400            models,
401            |m| m.context_window.to_string(),
402            |val| format_number(val.parse::<u32>().unwrap()),
403        );
404        pushln(out, "        }");
405        pushln(out, "    }");
406        blank(out);
407
408        // reasoning_levels — per-model reasoning level list
409        pushln(out, "    pub fn reasoning_levels(self) -> &'static [ReasoningEffort] {");
410        pushln(out, "        match self {");
411        emit_grouped_arms(out, models, |m| m.reasoning_levels.join(","), format_reasoning_levels_rhs);
412        pushln(out, "        }");
413        pushln(out, "    }");
414        blank(out);
415
416        // supports_reasoning — derived from reasoning_levels
417        pushln(out, "    pub fn supports_reasoning(self) -> bool {");
418        pushln(out, "        !self.reasoning_levels().is_empty()");
419        pushln(out, "    }");
420        blank(out);
421
422        for modality in ["image", "audio"] {
423            pushln(out, format!("    pub fn supports_{modality}(self) -> bool {{"));
424            pushln(out, "        match self {");
425            let mod_owned = modality.to_string();
426            emit_grouped_arms(
427                out,
428                models,
429                |m| m.input_modalities.contains(&mod_owned).to_string(),
430                std::string::ToString::to_string,
431            );
432            pushln(out, "        }");
433            pushln(out, "    }");
434            blank(out);
435        }
436
437        // ALL constant
438        pushln(out, format!("    const ALL: &[{enum_name}] = &["));
439        for model in models {
440            pushln(out, format!("        Self::{},", model.variant_name));
441        }
442        pushln(out, "    ];");
443
444        pushln(out, "}");
445        blank(out);
446
447        emit_from_str_impl(out, &enum_name, cfg.parser_name, models);
448
449        if cfg.is_hybrid_dynamic {
450            emit_hybrid_wrapper_impl(out, &cfg.outer_enum_name(), cfg.enum_name);
451            emit_hybrid_wrapper_from_str(out, &cfg.outer_enum_name(), &enum_name);
452        }
453    }
454}
455
456fn emit_hybrid_wrapper_impl(out: &mut String, outer_name: &str, display_prefix: &str) {
457    pushln(out, format!("impl {outer_name} {{"));
458
459    pushln(out, "    pub fn model_id(&self) -> Cow<'static, str> {");
460    pushln(out, "        match self {");
461    pushln(out, "            Self::Foundation(m) => Cow::Borrowed(m.model_id()),");
462    pushln(out, "            Self::Profile(s) => Cow::Owned(s.clone()),");
463    pushln(out, "        }");
464    pushln(out, "    }");
465    blank(out);
466
467    pushln(out, "    pub fn display_name(&self) -> Cow<'static, str> {");
468    pushln(out, "        match self {");
469    pushln(out, "            Self::Foundation(m) => Cow::Borrowed(m.display_name()),");
470    pushln(out, format!("            Self::Profile(s) => Cow::Owned(format!(\"{display_prefix} {{s}}\")),"));
471    pushln(out, "        }");
472    pushln(out, "    }");
473    blank(out);
474
475    pushln(out, "    pub fn context_window(&self) -> Option<u32> {");
476    pushln(out, "        match self {");
477    pushln(out, "            Self::Foundation(m) => Some(m.context_window()),");
478    pushln(out, "            Self::Profile(_) => None,");
479    pushln(out, "        }");
480    pushln(out, "    }");
481    blank(out);
482
483    pushln(out, "    pub fn reasoning_levels(&self) -> &'static [ReasoningEffort] {");
484    pushln(out, "        match self {");
485    pushln(out, "            Self::Foundation(m) => m.reasoning_levels(),");
486    pushln(out, "            Self::Profile(_) => &[],");
487    pushln(out, "        }");
488    pushln(out, "    }");
489    blank(out);
490
491    pushln(out, "    pub fn supports_reasoning(&self) -> bool {");
492    pushln(out, "        !self.reasoning_levels().is_empty()");
493    pushln(out, "    }");
494    blank(out);
495
496    for modality in ["image", "audio"] {
497        pushln(out, format!("    pub fn supports_{modality}(&self) -> bool {{"));
498        pushln(out, "        match self {");
499        pushln(out, format!("            Self::Foundation(m) => m.supports_{modality}(),"));
500        pushln(out, "            Self::Profile(_) => false,");
501        pushln(out, "        }");
502        pushln(out, "    }");
503        blank(out);
504    }
505
506    pushln(out, "}");
507    blank(out);
508}
509
510fn emit_hybrid_wrapper_from_str(out: &mut String, outer_name: &str, inner_name: &str) {
511    pushln(out, format!("impl std::str::FromStr for {outer_name} {{"));
512    pushln(out, "    type Err = String;");
513    blank(out);
514    pushln(out, "    fn from_str(s: &str) -> Result<Self, Self::Err> {");
515    pushln(out, format!("        match s.parse::<{inner_name}>() {{"));
516    pushln(out, "            Ok(m) => Ok(Self::Foundation(m)),");
517    pushln(out, "            Err(_) => Ok(Self::Profile(s.to_string())),");
518    pushln(out, "        }");
519    pushln(out, "    }");
520    pushln(out, "}");
521    blank(out);
522}
523
524fn emit_from_str_impl(out: &mut String, enum_name: &str, parser_name: &str, models: &[ModelInfo]) {
525    pushln(out, format!("impl std::str::FromStr for {enum_name} {{"));
526    pushln(out, "    type Err = String;");
527    blank(out);
528    pushln(out, "    #[allow(clippy::too_many_lines)]");
529    pushln(out, "    fn from_str(s: &str) -> Result<Self, Self::Err> {");
530    pushln(out, "        match s {");
531    for model in models {
532        pushln(out, format!("            \"{}\" => Ok(Self::{}),", model.model_id, model.variant_name));
533    }
534    pushln(out, format!("            _ => Err(format!(\"Unknown {parser_name} model: '{{s}}'\")),"));
535    pushln(out, "        }");
536    pushln(out, "    }");
537    pushln(out, "}");
538    blank(out);
539}
540
541/// Emit match arms grouped by value to avoid clippy `match_same_arms`.
542///
543/// `key_fn` extracts a grouping key from each model (e.g. `context_window` as string).
544/// `fmt_val` formats the key into the match arm's RHS.
545fn emit_grouped_arms(
546    out: &mut String,
547    models: &[ModelInfo],
548    key_fn: impl Fn(&ModelInfo) -> String,
549    fmt_val: impl Fn(&str) -> String,
550) {
551    // Group variants by value, preserving insertion order via BTreeMap
552    let mut groups: BTreeMap<String, Vec<&str>> = BTreeMap::new();
553    for model in models {
554        groups.entry(key_fn(model)).or_default().push(&model.variant_name);
555    }
556
557    for (key, variants) in &groups {
558        let rhs = fmt_val(key);
559        if variants.len() == 1 {
560            pushln(out, format!("            Self::{} => {rhs},", variants[0]));
561        } else {
562            let patterns: Vec<String> = variants.iter().map(|v| format!("Self::{v}")).collect();
563            pushln(out, format!("            {} => {rhs},", patterns.join(" | ")));
564        }
565    }
566}
567
/// Builds the right-hand side of a generated `reasoning_levels()` arm from a
/// comma-joined level list. For example, `"low,high"` becomes
/// `&[ReasoningEffort::Low, ReasoningEffort::High]` and `""` becomes `&[]`.
fn format_reasoning_levels_rhs(key: &str) -> String {
    let mut rhs = String::from("&[");
    if !key.is_empty() {
        for (index, level) in key.split(',').enumerate() {
            if index > 0 {
                rhs.push_str(", ");
            }
            rhs.push_str("ReasoningEffort::");
            rhs.push_str(level_str_to_variant(level));
        }
    }
    rhs.push(']');
    rhs
}

/// Maps a reasoning level string to its `ReasoningEffort` variant name.
///
/// # Panics
///
/// Panics (failing the build script) on any level it does not know.
fn level_str_to_variant(level: &str) -> &'static str {
    const LEVELS: [(&str, &str); 4] = [("low", "Low"), ("medium", "Medium"), ("high", "High"), ("xhigh", "Xhigh")];

    LEVELS
        .iter()
        .find(|(name, _)| *name == level)
        .map(|(_, variant)| *variant)
        .unwrap_or_else(|| panic!("Unknown reasoning level: {level}"))
}
594
595fn emit_llm_model_enum(out: &mut String) {
596    pushln(out, "/// A model from a specific provider");
597    pushln(out, "#[derive(Debug, Clone, PartialEq, Eq, Hash)]");
598    pushln(out, "pub enum LlmModel {");
599    for cfg in PROVIDERS {
600        pushln(out, format!("    {provider}({provider}Model),", provider = cfg.enum_name));
601    }
602    for dyn_cfg in DYNAMIC_PROVIDERS {
603        pushln(out, format!("    {}(String),", dyn_cfg.enum_name));
604    }
605    pushln(out, "}");
606    blank(out);
607}
608
609fn emit_from_impls(out: &mut String) {
610    for cfg in PROVIDERS {
611        pushln(out, format!("impl From<{}Model> for LlmModel {{", cfg.enum_name));
612        pushln(out, format!("    fn from(m: {}Model) -> Self {{ LlmModel::{}(m) }}", cfg.enum_name, cfg.enum_name));
613        pushln(out, "}");
614        blank(out);
615    }
616}
617
618fn emit_llm_model_impl(out: &mut String) {
619    pushln(out, "impl LlmModel {");
620    emit_llm_model_id(out);
621    emit_llm_display_name(out);
622    emit_llm_provider(out);
623    emit_llm_provider_display_name(out);
624    emit_llm_context_window(out);
625    emit_llm_required_env_var(out);
626    emit_llm_all_required_env_vars(out);
627    emit_llm_oauth_provider_id(out);
628    emit_llm_reasoning_levels(out);
629    emit_llm_supports_reasoning(out);
630    for modality in ["image", "audio"] {
631        emit_llm_supports_modality(out, modality);
632    }
633    emit_llm_all(out);
634    pushln(out, "}");
635    blank(out);
636}
637
638fn emit_llm_model_id(out: &mut String) {
639    pushln(out, "    /// Raw model ID (e.g. `claude-opus-4-6`, `llama3.2`)");
640    pushln(out, "    pub fn model_id(&self) -> Cow<'static, str> {");
641    pushln(out, "        match self {");
642    for cfg in PROVIDERS {
643        if cfg.is_hybrid_dynamic {
644            pushln(out, format!("            Self::{}(m) => m.model_id(),", cfg.enum_name));
645        } else {
646            pushln(out, format!("            Self::{}(m) => Cow::Borrowed(m.model_id()),", cfg.enum_name));
647        }
648    }
649    pushln(out, format!("            {} => Cow::Owned(s.clone()),", dynamic_pattern_with_binding("s")));
650    pushln(out, "        }");
651    pushln(out, "    }");
652    blank(out);
653}
654
655fn emit_llm_display_name(out: &mut String) {
656    pushln(out, "    /// Human-readable display name (e.g. `Claude Opus 4.6`)");
657    pushln(out, "    pub fn display_name(&self) -> Cow<'static, str> {");
658    pushln(out, "        match self {");
659    for cfg in PROVIDERS {
660        if cfg.is_hybrid_dynamic {
661            pushln(out, format!("            Self::{}(m) => m.display_name(),", cfg.enum_name));
662        } else {
663            pushln(out, format!("            Self::{}(m) => Cow::Borrowed(m.display_name()),", cfg.enum_name));
664        }
665    }
666    for dyn_cfg in DYNAMIC_PROVIDERS {
667        pushln(
668            out,
669            format!(
670                "            Self::{}(s) => Cow::Owned(format!(\"{} {{s}}\")),",
671                dyn_cfg.enum_name, dyn_cfg.enum_name
672            ),
673        );
674    }
675    pushln(out, "        }");
676    pushln(out, "    }");
677    blank(out);
678}
679
680fn emit_llm_provider(out: &mut String) {
681    pushln(out, "    /// Provider identifier (e.g. `anthropic`)");
682    pushln(out, "    pub fn provider(&self) -> &'static str {");
683    pushln(out, "        match self {");
684    for cfg in PROVIDERS {
685        pushln(out, format!("            Self::{}(_) => \"{}\",", cfg.enum_name, cfg.parser_name));
686    }
687    for dyn_cfg in DYNAMIC_PROVIDERS {
688        pushln(out, format!("            Self::{}(_) => \"{}\",", dyn_cfg.enum_name, dyn_cfg.parser_name));
689    }
690    pushln(out, "        }");
691    pushln(out, "    }");
692    blank(out);
693}
694
695fn emit_llm_provider_display_name(out: &mut String) {
696    pushln(out, "    /// Human-readable provider name (e.g. `AWS Bedrock`)");
697    pushln(out, "    pub fn provider_display_name(&self) -> &'static str {");
698    pushln(out, "        match self {");
699    for cfg in PROVIDERS {
700        pushln(out, format!("            Self::{}(_) => \"{}\",", cfg.enum_name, cfg.display_name));
701    }
702    for dyn_cfg in DYNAMIC_PROVIDERS {
703        pushln(out, format!("            Self::{}(_) => \"{}\",", dyn_cfg.enum_name, dyn_cfg.display_name));
704    }
705    pushln(out, "        }");
706    pushln(out, "    }");
707    blank(out);
708}
709
710fn emit_llm_context_window(out: &mut String) {
711    pushln(out, "    /// Context window size in tokens (None for dynamic providers)");
712    pushln(out, "    pub fn context_window(&self) -> Option<u32> {");
713    pushln(out, "        match self {");
714    for cfg in PROVIDERS {
715        if cfg.is_hybrid_dynamic {
716            pushln(out, format!("            Self::{}(m) => m.context_window(),", cfg.enum_name));
717        } else {
718            pushln(out, format!("            Self::{}(m) => Some(m.context_window()),", cfg.enum_name));
719        }
720    }
721    pushln(out, format!("            {} => None,", dynamic_pattern_with_binding("_")));
722    pushln(out, "        }");
723    pushln(out, "    }");
724    blank(out);
725}
726
727fn emit_llm_required_env_var(out: &mut String) {
728    pushln(out, "    /// Required env var for this model's provider (None for local providers)");
729    pushln(out, "    pub fn required_env_var(&self) -> Option<&'static str> {");
730    pushln(out, "        match self {");
731    let mut none_arms: Vec<String> = Vec::new();
732    for cfg in PROVIDERS {
733        match cfg.env_var {
734            Some(var) => pushln(out, format!("            Self::{}(_) => Some(\"{}\"),", cfg.enum_name, var)),
735            None => none_arms.push(format!("Self::{}(_)", cfg.enum_name)),
736        }
737    }
738    for dyn_cfg in DYNAMIC_PROVIDERS {
739        none_arms.push(format!("Self::{}(_)", dyn_cfg.enum_name));
740    }
741    pushln(out, format!("            {} => None,", none_arms.join(" | ")));
742    pushln(out, "        }");
743    pushln(out, "    }");
744    blank(out);
745}
746
747fn emit_llm_all_required_env_vars(out: &mut String) {
748    let vars: Vec<&str> = PROVIDERS.iter().filter_map(|cfg| cfg.env_var).collect();
749    pushln(out, "    /// All provider API key env var names (deduplicated, static)");
750    pushln(
751        out,
752        format!(
753            "    pub const ALL_REQUIRED_ENV_VARS: &[&str] = &[{}];",
754            vars.iter().map(|v| format!("\"{v}\"")).collect::<Vec<_>>().join(", ")
755        ),
756    );
757    blank(out);
758}
759
760fn emit_llm_oauth_provider_id(out: &mut String) {
761    pushln(out, "    /// OAuth provider ID if this model requires OAuth login (e.g. `\"codex\"`)");
762    pushln(out, "    pub fn oauth_provider_id(&self) -> Option<&'static str> {");
763    pushln(out, "        match self {");
764    let mut none_arms: Vec<String> = Vec::new();
765    for cfg in PROVIDERS {
766        match cfg.oauth_provider_id {
767            Some(id) => pushln(out, format!("            Self::{}(_) => Some(\"{}\"),", cfg.enum_name, id)),
768            None => none_arms.push(format!("Self::{}(_)", cfg.enum_name)),
769        }
770    }
771    for dyn_cfg in DYNAMIC_PROVIDERS {
772        none_arms.push(format!("Self::{}(_)", dyn_cfg.enum_name));
773    }
774    pushln(out, format!("            {} => None,", none_arms.join(" | ")));
775    pushln(out, "        }");
776    pushln(out, "    }");
777    blank(out);
778}
779
780fn emit_llm_reasoning_levels(out: &mut String) {
781    pushln(out, "    /// Reasoning levels supported by this model (empty if not a reasoning model)");
782    pushln(out, "    pub fn reasoning_levels(&self) -> &'static [ReasoningEffort] {");
783    pushln(out, "        match self {");
784    for cfg in PROVIDERS {
785        pushln(out, format!("            Self::{}(m) => m.reasoning_levels(),", cfg.enum_name));
786    }
787    pushln(out, format!("            {} => &[],", dynamic_pattern_with_binding("_")));
788    pushln(out, "        }");
789    pushln(out, "    }");
790    blank(out);
791}
792
793fn emit_llm_supports_reasoning(out: &mut String) {
794    pushln(out, "    /// Whether this model supports reasoning/extended thinking");
795    pushln(out, "    pub fn supports_reasoning(&self) -> bool {");
796    pushln(out, "        !self.reasoning_levels().is_empty()");
797    pushln(out, "    }");
798    blank(out);
799}
800
801fn emit_llm_supports_modality(out: &mut String, modality: &str) {
802    pushln(out, format!("    /// Whether this model supports {modality} input"));
803    pushln(out, format!("    pub fn supports_{modality}(&self) -> bool {{"));
804    pushln(out, "        match self {");
805    for cfg in PROVIDERS {
806        pushln(out, format!("            Self::{}(m) => m.supports_{modality}(),", cfg.enum_name));
807    }
808    pushln(out, format!("            {} => false,", dynamic_pattern_with_binding("_")));
809    pushln(out, "        }");
810    pushln(out, "    }");
811    blank(out);
812}
813
814fn emit_llm_all(out: &mut String) {
815    pushln(out, "    /// All catalog models (excludes dynamic providers)");
816    pushln(out, "    pub fn all() -> &'static [LlmModel] {");
817    pushln(out, "        static ALL: LazyLock<Vec<LlmModel>> = LazyLock::new(|| {");
818    pushln(out, "            let mut v = Vec::new();");
819    for cfg in PROVIDERS {
820        if cfg.is_hybrid_dynamic {
821            pushln(
822                out,
823                format!(
824                    "            v.extend({inner}::ALL.iter().copied().map({outer}::Foundation).map(LlmModel::{variant}));",
825                    inner = cfg.inner_enum_name(),
826                    outer = cfg.outer_enum_name(),
827                    variant = cfg.enum_name,
828                ),
829            );
830        } else {
831            pushln(
832                out,
833                format!(
834                    "            v.extend({}Model::ALL.iter().copied().map(LlmModel::{}));",
835                    cfg.enum_name, cfg.enum_name
836                ),
837            );
838        }
839    }
840    pushln(out, "            v");
841    pushln(out, "        });");
842    pushln(out, "        &ALL");
843    pushln(out, "    }");
844}
845
846fn emit_display_impl(out: &mut String) {
847    pushln(out, "impl std::fmt::Display for LlmModel {");
848    pushln(out, "    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {");
849    pushln(out, "        write!(f, \"{}:{}\", self.provider(), self.model_id())");
850    pushln(out, "    }");
851    pushln(out, "}");
852    blank(out);
853}
854
855fn emit_fromstr_impl(out: &mut String) {
856    pushln(out, "impl std::str::FromStr for LlmModel {");
857    pushln(out, "    type Err = String;");
858    blank(out);
859    pushln(out, "    /// Parse a `provider:model` string into an `LlmModel`");
860    pushln(out, "    fn from_str(s: &str) -> Result<Self, Self::Err> {");
861    pushln(out, "        let (provider_str, model_str) = s.split_once(':').unwrap_or((s, \"\"));");
862    pushln(out, "        match provider_str {");
863    for cfg in PROVIDERS {
864        pushln(
865            out,
866            format!(
867                "            \"{}\" => model_str.parse::<{}Model>().map(Self::{}),",
868                cfg.parser_name, cfg.enum_name, cfg.enum_name
869            ),
870        );
871    }
872    for dyn_cfg in DYNAMIC_PROVIDERS {
873        pushln(
874            out,
875            format!(
876                "            \"{}\" => Ok(Self::{}(model_str.to_string())),",
877                dyn_cfg.parser_name, dyn_cfg.enum_name
878            ),
879        );
880    }
881    pushln(out, "            _ => Err(format!(\"Unknown provider: '{provider_str}'\")),");
882    pushln(out, "        }");
883    pushln(out, "    }");
884    pushln(out, "}");
885}
886
887/// Build a combined `|` pattern for all dynamic providers with a binding variable.
888/// e.g. `Self::Ollama(s) | Self::LlamaCpp(s)` or `Self::Ollama(_) | Self::LlamaCpp(_)`
889fn dynamic_pattern_with_binding(binding: &str) -> String {
890    DYNAMIC_PROVIDERS.iter().map(|d| format!("Self::{}({binding})", d.enum_name)).collect::<Vec<_>>().join(" | ")
891}
892
/// Render `n` with `_` thousands separators, as used for numeric literals in generated code.
///
/// Numbers of at most four digits are returned unchanged (`1000` stays `1000`). Longer ones
/// are grouped in threes from the right (`200000` → `200_000`, `1050000` → `1_050_000`).
fn format_number(n: u32) -> String {
    let digits = n.to_string();
    if digits.len() < 5 {
        return digits;
    }
    // The leading group holds 1–3 digits so that the remainder splits evenly into triples.
    let head_len = match digits.len() % 3 {
        0 => 3,
        rem => rem,
    };
    let (head, tail) = digits.split_at(head_len);
    let mut grouped = String::with_capacity(digits.len() + digits.len() / 3);
    grouped.push_str(head);
    // `tail` is ASCII digits whose length is a multiple of 3, so these slices are in bounds.
    for start in (0..tail.len()).step_by(3) {
        grouped.push('_');
        grouped.push_str(&tail[start..start + 3]);
    }
    grouped
}
908
/// Escape `raw` so it can be placed between double quotes in a generated Rust string literal.
///
/// Escaping rules:
/// - Backslashes and double quotes are backslash-escaped.
/// - Newline, carriage return and tab become `\n`, `\r` and `\t`.
/// - Any other control character becomes a `\u{..}` escape.
///
/// A bare carriage return is a compile error inside a Rust string literal. Other control
/// characters would otherwise end up embedded invisibly in the generated source.
/// All other characters, including non-ASCII ones, are copied unchanged.
fn escape_rust_string(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            c if c.is_control() => escaped.extend(c.escape_unicode()),
            c => escaped.push(c),
        }
    }
    escaped
}
912
913fn emit_provider_docs(ctx: &CodegenCtx) -> HashMap<String, String> {
914    let mut docs = HashMap::new();
915
916    for cfg in PROVIDERS {
917        let models = &ctx.provider_models[cfg.dev_id];
918        let mut doc = String::new();
919
920        pushln(&mut doc, format!("`{}` LLM provider.", cfg.display_name));
921        blank(&mut doc);
922
923        // Authentication
924        pushln(&mut doc, "# Authentication");
925        blank(&mut doc);
926        match cfg.env_var {
927            Some(var) => pushln(&mut doc, format!("Set the `{var}` environment variable.")),
928            None if cfg.oauth_provider_id.is_some() => {
929                pushln(&mut doc, "This provider uses OAuth authentication.");
930            }
931            None => {
932                pushln(
933                    &mut doc,
934                    "Uses the default AWS credential chain (environment variables, config files, IAM roles).",
935                );
936            }
937        }
938        blank(&mut doc);
939
940        // Supported models table
941        pushln(&mut doc, "# Supported models");
942        blank(&mut doc);
943        pushln(&mut doc, "| Model ID | Name | Context | Reasoning | Image | Audio |");
944        pushln(&mut doc, "|----------|------|---------|-----------|-------|-------|");
945        for model in models {
946            let ctx_str = format_context_window(model.context_window);
947            let reasoning = if model.reasoning_levels.is_empty() { "" } else { "yes" };
948            let image = if model.input_modalities.contains(&"image".to_string()) { "yes" } else { "" };
949            let audio = if model.input_modalities.contains(&"audio".to_string()) { "yes" } else { "" };
950            pushln(
951                &mut doc,
952                format!(
953                    "| `{}` | `{}` | `{}` | {} | {} | {} |",
954                    model.model_id, model.display_name, ctx_str, reasoning, image, audio
955                ),
956            );
957        }
958
959        docs.insert(cfg.dev_id.to_string(), doc);
960    }
961
962    // Dynamic providers
963    for dyn_cfg in DYNAMIC_PROVIDERS {
964        let mut doc = String::new();
965        pushln(&mut doc, format!("`{}` LLM provider.", dyn_cfg.display_name));
966        blank(&mut doc);
967        pushln(
968            &mut doc,
969            format!("This provider accepts any model name at runtime (e.g. `{}:my-model`).", dyn_cfg.parser_name),
970        );
971        pushln(&mut doc, "No API key is required.");
972        docs.insert(dyn_cfg.parser_name.to_string(), doc);
973    }
974
975    docs
976}
977
978/// Format a token count as human-readable (e.g. `1_000_000` → `1M`, `200_000` → `200k`).
979fn format_context_window(tokens: u32) -> String {
980    if tokens == 0 {
981        return "unknown".to_string();
982    }
983    if tokens >= 1_000_000 && tokens.is_multiple_of(1_000_000) {
984        format!("{}M", tokens / 1_000_000)
985    } else if tokens >= 1_000 && tokens.is_multiple_of(1_000) {
986        format!("{}k", tokens / 1_000)
987    } else {
988        format_number(tokens)
989    }
990}
991
/// Append `line` followed by a single `\n` to `out`.
fn pushln(out: &mut String, line: impl AsRef<str>) {
    let text = line.as_ref();
    out.write_str(text)
        .and_then(|()| out.write_char('\n'))
        .expect("writing to String should not fail");
}
995
/// Append an empty line to `out` (equivalent to `pushln(out, "")`).
fn blank(out: &mut String) {
    out.push('\n');
}
999
1000// ── Tests ────────────────────────────────────────────────────────────────
1001
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use serde_json::json;
    use tempfile::NamedTempFile;

    #[test]
    fn generate_sorts_and_filters_models() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "b-model": {
                    "id": "b-model",
                    "name": "B Model",
                    "tool_call": true,
                    "limit": {"context": 2000, "output": 0}
                },
                "a-model": {
                    "id": "a-model",
                    "name": "A Model",
                    "tool_call": true,
                    "limit": {"context": 1000, "output": 0}
                },
                "alpha-latest": {
                    "id": "alpha-latest",
                    "name": "Alias",
                    "tool_call": true,
                    "limit": {"context": 500, "output": 0}
                },
                "no-tools": {
                    "id": "no-tools",
                    "name": "No Tools",
                    "tool_call": false,
                    "limit": {"context": 500, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // Provider-level FromStr: sorted model IDs
        let a_model = "\"a-model\" => Ok(Self::AModel),";
        let b_model = "\"b-model\" => Ok(Self::BModel),";
        let a_pos = source.find(a_model).expect("a-model parse arm");
        let b_pos = source.find(b_model).expect("b-model parse arm");
        assert!(a_pos < b_pos);
        // Aliases and non-tool-call models are excluded
        assert!(!source.contains("AlphaLatest"));
        assert!(!source.contains("NoTools"));
    }

    #[test]
    fn generate_contains_core_sections() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("pub enum LlmModel {"));
        assert!(source.contains("impl std::str::FromStr for LlmModel {"));
        assert!(source.contains("impl std::fmt::Display for LlmModel {"));
        assert!(source.contains("pub fn required_env_var(&self) -> Option<&'static str> {"));
    }

    #[test]
    fn generate_contains_dynamic_provider_arms() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("\"ollama\" => Ok(Self::Ollama(model_str.to_string())),"));
        assert!(source.contains("\"llamacpp\" => Ok(Self::LlamaCpp(model_str.to_string())),"));
        // Dynamic providers are combined with | for None-returning arms
        assert!(source.contains("Self::Ollama(_) | Self::LlamaCpp(_) => None,"));
    }

    #[test]
    fn generate_codex_is_catalog_provider() {
        let source = generate_from_value(&minimal_models_dev_json());
        // Codex is a catalog provider, not dynamic
        assert!(source.contains("pub enum CodexModel {"));
        assert!(source.contains("\"codex\" => model_str.parse::<CodexModel>().map(Self::Codex),"));
        assert!(source.contains("Self::Codex(m) => Some(m.context_window()),"));
    }

    #[test]
    fn generate_bedrock_is_hybrid_provider() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().unwrap();
        let bedrock = root.get_mut("amazon-bedrock").and_then(Value::as_object_mut).unwrap();

        bedrock.insert(
            "models".to_string(),
            json!({
                "anthropic.foo-v1:0": {
                    "id": "anthropic.foo-v1:0",
                    "name": "Foo Model",
                    "tool_call": true,
                    "limit": {"context": 200_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);

        assert!(source.contains("pub enum BedrockFoundationModel {"));
        assert!(source.contains("AnthropicFooV10"));
        assert!(source.contains("pub enum BedrockModel {"));
        assert!(source.contains("Foundation(BedrockFoundationModel)"));
        assert!(source.contains("Profile(String)"));
        assert!(source.contains("impl std::str::FromStr for BedrockFoundationModel"));
        assert!(source.contains("\"anthropic.foo-v1:0\" => Ok(Self::AnthropicFooV10),"));
        assert!(source.contains("Unknown bedrock model"));

        assert!(source.contains("impl std::str::FromStr for BedrockModel"));
        assert!(source.contains("Err(_) => Ok(Self::Profile(s.to_string())),"));
        assert!(source.contains("Self::Profile(s) => Cow::Owned(s.clone()),"));
        assert!(source.contains("Self::Profile(_) => None,"));
        assert!(source.contains("Self::Profile(_) => &[],"));
        assert!(source.contains("Self::Profile(_) => false,"));
        assert!(source.contains("Self::Profile(s) => Cow::Owned(format!(\"Bedrock {s}\"))"));

        assert!(source.contains(
            "BedrockFoundationModel::ALL.iter().copied().map(BedrockModel::Foundation).map(LlmModel::Bedrock)"
        ));

        assert!(source.contains("Self::Bedrock(m) => m.model_id(),"));
        assert!(source.contains("Self::Bedrock(m) => m.display_name(),"));
        assert!(source.contains("Self::Bedrock(m) => m.context_window(),"));
    }

    #[test]
    fn generate_non_hybrid_providers_keep_flat_enum() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("pub enum AnthropicModel {"));
        assert!(!source.contains("AnthropicFoundationModel"));
        assert!(!source.contains("Foundation(AnthropicModel)"));
        assert!(source.contains("Self::Anthropic(m) => Cow::Borrowed(m.model_id()),"));
        assert!(source.contains("Self::Anthropic(m) => Some(m.context_window()),"));
    }

    #[test]
    fn generate_oauth_provider_id_for_codex() {
        let source = generate_from_value(&minimal_models_dev_json());
        let start = source
            .find("pub fn oauth_provider_id(&self) -> Option<&'static str>")
            .expect("oauth_provider_id method");
        let oauth_fn = &source[start..];
        // Codex models return Some("codex") for oauth_provider_id
        assert!(oauth_fn.contains("Self::Codex(_) => Some(\"codex\"),"));
        // Non-OAuth providers share one `=> None` fallback arm; dynamic providers are appended
        // last to its pattern, so the first `=> None,` after the signature follows LlamaCpp.
        let none_pos = oauth_fn.find("=> None,").expect("oauth_provider_id None arm");
        assert!(oauth_fn[..none_pos].trim_end().ends_with("Self::LlamaCpp(_)"));
    }

    #[test]
    fn generate_delegates_to_provider_impls() {
        let source = generate_from_value(&minimal_models_dev_json());
        // LlmModel delegates to per-provider methods
        assert!(source.contains("Self::Anthropic(m) => Cow::Borrowed(m.model_id()),"));
        assert!(source.contains("Self::Anthropic(m) => Some(m.context_window()),"));
        // Provider-level FromStr is used by LlmModel::FromStr
        assert!(source.contains("\"anthropic\" => model_str.parse::<AnthropicModel>().map(Self::Anthropic),"));
    }

    #[test]
    fn generate_formats_large_numbers_with_separators() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "big-model": {
                    "id": "big-model",
                    "name": "Big Model",
                    "tool_call": true,
                    "limit": {"context": 200_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        assert!(source.contains("200_000"));
        assert!(!source.contains("200000"));
    }

    #[test]
    fn generate_groups_identical_match_arms() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "model-a": {
                    "id": "model-a",
                    "name": "Same Name",
                    "tool_call": true,
                    "limit": {"context": 100_000, "output": 0}
                },
                "model-b": {
                    "id": "model-b",
                    "name": "Same Name",
                    "tool_call": true,
                    "limit": {"context": 100_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // Both context_window and display_name should combine arms
        assert!(source.contains("Self::ModelA | Self::ModelB => 100_000,"));
        assert!(source.contains("Self::ModelA | Self::ModelB => \"Same Name\","));
    }

    #[test]
    fn test_model_id_to_variant() {
        assert_eq!(model_id_to_variant("claude-sonnet-4-5-20250929"), "ClaudeSonnet4520250929");
        assert_eq!(model_id_to_variant("gemini-2.5-flash"), "Gemini25Flash");
        assert_eq!(model_id_to_variant("deepseek-chat"), "DeepseekChat");
        assert_eq!(model_id_to_variant("glm-4.5"), "Glm45");
    }

    #[test]
    fn test_model_id_to_variant_with_slash_and_colon() {
        assert_eq!(model_id_to_variant("anthropic/claude-opus-4.6"), "AnthropicClaudeOpus46");
        assert_eq!(model_id_to_variant("openai/gpt-5.1-codex-max"), "OpenaiGpt51CodexMax");
        assert_eq!(model_id_to_variant("deepseek/deepseek-r1:free"), "DeepseekDeepseekR1Free");
    }

    #[test]
    fn test_is_alias() {
        assert!(is_alias("claude-sonnet-4-5-latest"));
        assert!(is_alias("claude-3-7-sonnet-latest"));
        assert!(!is_alias("claude-sonnet-4-5-20250929"));
    }

    #[test]
    fn generate_contains_reasoning_levels_and_supports_reasoning() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "thinker": {
                    "id": "thinker",
                    "name": "Thinker",
                    "tool_call": true,
                    "reasoning": true,
                    "limit": {"context": 1000, "output": 0}
                },
                "fast": {
                    "id": "fast",
                    "name": "Fast",
                    "tool_call": true,
                    "reasoning": false,
                    "limit": {"context": 1000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // Provider enum should have reasoning_levels method
        assert!(source.contains("pub fn reasoning_levels(self) -> &'static [ReasoningEffort] {"));
        // Thinker (anthropic) gets standard 3 levels
        assert!(
            source
                .contains("Self::Thinker => &[ReasoningEffort::Low, ReasoningEffort::Medium, ReasoningEffort::High],")
        );
        // Fast gets empty
        assert!(source.contains("Self::Fast => &[],"));
        // supports_reasoning delegates to reasoning_levels
        assert!(source.contains("pub fn supports_reasoning(self) -> bool {"));
        assert!(source.contains("!self.reasoning_levels().is_empty()"));
        // LlmModel level
        assert!(source.contains("pub fn reasoning_levels(&self) -> &'static [ReasoningEffort] {"));
        // Dynamic providers return empty
        assert!(source.contains("Self::Ollama(_) | Self::LlamaCpp(_) => &[],"));
    }

    #[test]
    fn generate_codex_gets_four_reasoning_levels() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let openai = root.get_mut("openai").and_then(Value::as_object_mut).expect("openai provider");

        openai.insert(
            "models".to_string(),
            json!({
                "gpt-5.4-codex": {
                    "id": "gpt-5.4-codex",
                    "name": "GPT-5.4 Codex",
                    "tool_call": true,
                    "reasoning": true,
                    "limit": {"context": 200_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        assert!(
            source.contains(
                "ReasoningEffort::Low, ReasoningEffort::Medium, ReasoningEffort::High, ReasoningEffort::Xhigh"
            ),
            "Codex reasoning model should have 4 levels including Xhigh"
        );
    }

    #[test]
    fn generate_codex_overrides_gpt55_subscription_context_window() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let openai = root.get_mut("openai").and_then(Value::as_object_mut).expect("openai provider");

        openai.insert(
            "models".to_string(),
            json!({
                "gpt-5.5": {
                    "id": "gpt-5.5",
                    "name": "GPT-5.5",
                    "tool_call": true,
                    "reasoning": true,
                    "limit": {"context": 1_050_000, "output": 128_000}
                }
            }),
        );

        let source = generate_from_value(&data);
        let codex_impl = generated_impl_block(&source, "CodexModel");
        let openai_impl = generated_impl_block(&source, "OpenaiModel");

        assert!(codex_impl.contains("Self::Gpt55 => 272_000,"));
        assert!(openai_impl.contains("Self::Gpt55 => 1_050_000,"));
    }

    #[test]
    fn codex_subscription_context_window_overrides_known_codex_models() {
        for model_id in ["gpt-5.5", "gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex", "gpt-5.2", "codex-auto-review"] {
            assert_eq!(codex_subscription_context_window(model_id, 1_050_000), 272_000);
        }
    }

    #[test]
    fn codex_subscription_context_window_leaves_unknown_models_unchanged() {
        assert_eq!(codex_subscription_context_window("gpt-5.3-codex-spark", 128_000), 128_000);
        assert_eq!(codex_subscription_context_window("some-future-model", 400_000), 400_000);
    }

    fn generate_from_value(data: &Value) -> String {
        let tmp = NamedTempFile::new().expect("temp file");
        let json = serde_json::to_string(data).expect("serialize fixture");
        std::fs::write(tmp.path(), json).expect("write fixture");
        generate(tmp.path()).expect("codegen succeeds").rust_source
    }

    fn generated_impl_block<'a>(source: &'a str, enum_name: &str) -> &'a str {
        let start_marker = format!("impl {enum_name} {{");
        let start = source.find(&start_marker).expect("provider impl block start");
        let rest = &source[start..];
        let end_marker = format!("impl std::str::FromStr for {enum_name}");
        let end = rest.find(&end_marker).expect("provider impl block end");
        &rest[..end]
    }

    fn minimal_models_dev_json() -> Value {
        let mut root = serde_json::Map::new();
        for cfg in PROVIDERS {
            let json_key = cfg.json_key();
            root.entry(json_key.to_string()).or_insert_with(|| {
                json!({
                    "id": json_key,
                    "name": json_key,
                    "env": [],
                    "models": {}
                })
            });
            for &extra in cfg.extra_source_ids {
                root.entry(extra.to_string()).or_insert_with(|| {
                    json!({
                        "id": extra,
                        "name": extra,
                        "env": [],
                        "models": {}
                    })
                });
            }
        }
        Value::Object(root)
    }

    #[test]
    fn extra_source_ids_merges_models_into_provider() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().unwrap();

        // Add a model to the zai-coding-plan extra source that doesn't exist in zai
        let extra = root.get_mut("zai-coding-plan").unwrap().as_object_mut().unwrap();
        extra.insert(
            "models".to_string(),
            json!({
                "extra-model": {
                    "id": "extra-model",
                    "name": "Extra Model",
                    "tool_call": true,
                    "limit": {"context": 4000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // The extra model should appear under ZAi
        assert!(source.contains("\"extra-model\" => Ok(Self::ExtraModel),"));
    }

    #[test]
    fn extra_source_ids_does_not_duplicate_existing_models() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().unwrap();

        // Add same model ID to both zai and zai-coding-plan
        let zai = root.get_mut("zai").unwrap().as_object_mut().unwrap();
        zai.insert(
            "models".to_string(),
            json!({
                "shared-model": {
                    "id": "shared-model",
                    "name": "Shared Model",
                    "tool_call": true,
                    "limit": {"context": 1000, "output": 0}
                }
            }),
        );
        let extra = root.get_mut("zai-coding-plan").unwrap().as_object_mut().unwrap();
        extra.insert(
            "models".to_string(),
            json!({
                "shared-model": {
                    "id": "shared-model",
                    "name": "Shared Model Duplicate",
                    "tool_call": true,
                    "limit": {"context": 2000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        let from_str_matches = source.matches("\"shared-model\" => Ok(Self::SharedModel),").count();
        assert_eq!(from_str_matches, 1);
    }

    #[test]
    fn generate_emits_provider_docs() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().unwrap();
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).unwrap();

        anthropic.insert(
            "models".to_string(),
            json!({
                "claude-test": {
                    "id": "claude-test",
                    "name": "Claude Test",
                    "tool_call": true,
                    "reasoning": true,
                    "limit": {"context": 200_000, "output": 0},
                    "modalities": {"input": ["text", "image"]}
                }
            }),
        );

        let tmp = NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), serde_json::to_string(&data).unwrap()).unwrap();
        let output = generate(tmp.path()).unwrap();

        let anthropic_doc = &output.provider_docs["anthropic"];
        assert!(anthropic_doc.contains("`Anthropic` LLM provider."));
        assert!(anthropic_doc.contains("`ANTHROPIC_API_KEY`"));
        assert!(anthropic_doc.contains("| `claude-test` | `Claude Test` | `200k` | yes | yes |  |"));

        // Dynamic providers get a short doc
        let ollama_doc = &output.provider_docs["ollama"];
        assert!(ollama_doc.contains("`Ollama` LLM provider."));
        assert!(ollama_doc.contains("any model name at runtime"));
    }

    #[test]
    fn format_context_window_formats_correctly() {
        assert_eq!(format_context_window(1_000_000), "1M");
        assert_eq!(format_context_window(200_000), "200k");
        assert_eq!(format_context_window(8_000), "8k");
        assert_eq!(format_context_window(0), "unknown");
        // Non-round values: multiples of 1k that are not whole millions, and fallbacks
        assert_eq!(format_context_window(1_500_000), "1500k");
        assert_eq!(format_context_window(128_500), "128_500");
        assert_eq!(format_context_window(999), "999");
    }

    #[test]
    fn format_number_groups_digits_in_threes() {
        assert_eq!(format_number(0), "0");
        assert_eq!(format_number(9_999), "9999");
        assert_eq!(format_number(10_000), "10_000");
        assert_eq!(format_number(1_050_000), "1_050_000");
        assert_eq!(format_number(u32::MAX), "4_294_967_295");
    }

    #[test]
    fn escape_rust_string_escapes_quotes_and_backslashes() {
        assert_eq!(escape_rust_string("plain"), "plain");
        assert_eq!(escape_rust_string(r#"say "hi" \o/"#), r#"say \"hi\" \\o/"#);
    }

    #[test]
    fn level_str_to_variant_covers_all_reasoning_efforts() {
        for effort in utils::ReasoningEffort::all() {
            // Should not panic for any known variant
            let _ = level_str_to_variant(effort.as_str());
        }
    }
}