// llm_codegen/lib.rs

1#![doc = include_str!("../README.md")]
2
3use serde::Deserialize;
4use std::collections::{BTreeMap, HashMap};
5use std::fmt::Write;
6use std::path::Path;
7
/// Top-level shape of the models.dev JSON dump: provider key → provider data.
type ModelsDevData = HashMap<String, ProviderData>;
/// Per-provider hook: `(model_id, source_context_window) -> effective_context_window`.
type ContextWindowOverride = fn(&str, u32) -> u32;
10
/// One provider entry as deserialized from the models.dev JSON.
#[derive(Debug, Deserialize)]
struct ProviderData {
    // Provider ID as declared by models.dev (not read; kept for schema completeness).
    #[allow(dead_code)]
    id: String,
    // Display name from models.dev (not read; we use our own `display_name`).
    #[allow(dead_code)]
    name: String,
    // Env vars models.dev associates with the provider (not read; we track our own).
    #[serde(default)]
    #[allow(dead_code)]
    env: Vec<String>,
    // Model ID → model data; the only field codegen actually consumes.
    #[serde(default)]
    models: HashMap<String, ModelData>,
}
23
/// One model entry as deserialized from the models.dev JSON.
#[derive(Debug, Deserialize)]
struct ModelData {
    // Canonical model ID (e.g. "claude-sonnet-4-5-20250929").
    id: String,
    // Human-readable model name.
    name: String,
    // Whether the model supports tool calling; models without it are skipped.
    #[serde(default)]
    tool_call: Option<bool>,
    // Whether the model supports reasoning/extended thinking.
    #[serde(default)]
    reasoning: Option<bool>,
    // Pricing data (not consumed by codegen).
    #[serde(default)]
    #[allow(dead_code)]
    cost: Option<CostData>,
    // Token limits; `None` is treated as an unknown (0) context window.
    #[serde(default)]
    limit: Option<LimitData>,
    // Input modalities; `None` is treated as text-only.
    #[serde(default)]
    modalities: Option<ModalitiesData>,
}
40
/// Input modalities a model accepts (e.g. "text", "image", "audio").
#[derive(Debug, Deserialize, Default)]
struct ModalitiesData {
    #[serde(default)]
    input: Vec<String>,
}
46
/// Pricing in the models.dev schema (deserialized but unused by codegen).
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct CostData {
    #[serde(default)]
    input: f64,
    #[serde(default)]
    output: f64,
}
55
/// Token limits in the models.dev schema.
#[derive(Debug, Deserialize)]
struct LimitData {
    // Context window in tokens; 0 (the serde default) means "unknown".
    #[serde(default)]
    context: u32,
    // Max output tokens (not consumed by codegen).
    #[serde(default)]
    #[allow(dead_code)]
    output: u32,
}
64
/// Provider configuration for codegen (catalog providers with known model lists).
///
/// Most providers are built via [`ProviderConfig::standard`]; only fields that
/// differ from the defaults are spelled out at the use site.
struct ProviderConfig {
    /// Unique provider key used in `provider_models` map (e.g. "codex")
    dev_id: &'static str,
    /// models.dev provider ID to read models from (defaults to `dev_id` when `None`)
    source_dev_id: Option<&'static str>,
    /// Additional models.dev keys whose models are merged into this provider
    extra_source_ids: &'static [&'static str],
    /// Only include models whose ID passes this filter (None = include all)
    model_filter: Option<fn(&str) -> bool>,
    /// Provider-specific generated context window override
    context_window_override: Option<ContextWindowOverride>,
    /// Our Rust enum name (e.g. "Gemini")
    enum_name: &'static str,
    /// Our internal provider name used for parsing (e.g. "gemini")
    parser_name: &'static str,
    /// Human-readable provider name (e.g. "AWS Bedrock")
    display_name: &'static str,
    /// Env var our code actually checks (None for providers with complex credential chains)
    env_var: Option<&'static str>,
    /// OAuth provider ID for providers that require OAuth login (e.g. "codex")
    oauth_provider_id: Option<&'static str>,
    /// Default reasoning levels for models that support reasoning (empty = use standard 3)
    default_reasoning_levels: &'static [&'static str],
}
90
impl ProviderConfig {
    /// Shorthand for providers with default `source_dev_id`, `model_filter`, and `oauth_provider_id`.
    ///
    /// Reasoning models default to the standard three levels; providers that
    /// deviate (e.g. codex's "xhigh") use a full struct literal instead.
    const fn standard(
        dev_id: &'static str,
        enum_name: &'static str,
        parser_name: &'static str,
        display_name: &'static str,
        env_var: Option<&'static str>,
    ) -> Self {
        Self {
            dev_id,
            source_dev_id: None,
            extra_source_ids: &[],
            model_filter: None,
            context_window_override: None,
            enum_name,
            parser_name,
            display_name,
            env_var,
            oauth_provider_id: None,
            default_reasoning_levels: &["low", "medium", "high"],
        }
    }

    /// The models.dev key to look up in the JSON data.
    fn json_key(&self) -> &'static str {
        self.source_dev_id.unwrap_or(self.dev_id)
    }
}
120
/// Dynamic provider — model name is user-supplied at runtime, no fixed enum
#[allow(clippy::struct_field_names)]
struct DynamicProviderConfig {
    /// Rust variant name in `LlmModel` (e.g. "Ollama")
    enum_name: &'static str,
    /// Parser name used in "provider:model" strings (e.g. "ollama")
    parser_name: &'static str,
    /// Human-readable provider name (e.g. "Ollama")
    display_name: &'static str,
}
131
/// Catalog providers, in the order their enums, impls, and docs are emitted.
const PROVIDERS: &[ProviderConfig] = &[
    ProviderConfig::standard("anthropic", "Anthropic", "anthropic", "Anthropic", Some("ANTHROPIC_API_KEY")),
    // Codex has no models.dev entry of its own: it borrows OpenAI's model list,
    // keeps only codex/gpt-5 models, authenticates via OAuth, and overrides the
    // context window for subscription models.
    ProviderConfig {
        dev_id: "codex",
        source_dev_id: Some("openai"),
        extra_source_ids: &[],
        model_filter: Some(|id| id.contains("codex") || id.starts_with("gpt-5.") || id == "gpt-5"),
        context_window_override: Some(codex_subscription_context_window),
        enum_name: "Codex",
        parser_name: "codex",
        display_name: "Codex",
        env_var: None,
        oauth_provider_id: Some("codex"),
        default_reasoning_levels: &["low", "medium", "high", "xhigh"],
    },
    ProviderConfig::standard("deepseek", "DeepSeek", "deepseek", "DeepSeek", Some("DEEPSEEK_API_KEY")),
    ProviderConfig::standard("google", "Gemini", "gemini", "Gemini", Some("GEMINI_API_KEY")),
    ProviderConfig::standard("moonshotai", "Moonshot", "moonshot", "Moonshot", Some("MOONSHOT_API_KEY")),
    ProviderConfig::standard("openai", "Openai", "openai", "OpenAI", Some("OPENAI_API_KEY")),
    ProviderConfig::standard("openrouter", "OpenRouter", "openrouter", "OpenRouter", Some("OPENROUTER_API_KEY")),
    // ZAI also pulls in models from the separate "zai-coding-plan" dump entry.
    ProviderConfig {
        extra_source_ids: &["zai-coding-plan"],
        ..ProviderConfig::standard("zai", "ZAi", "zai", "ZAI", Some("ZAI_API_KEY"))
    },
    ProviderConfig::standard("amazon-bedrock", "Bedrock", "bedrock", "AWS Bedrock", None),
];
158
/// Providers whose model name is supplied at runtime (local inference servers).
const DYNAMIC_PROVIDERS: &[DynamicProviderConfig] = &[
    DynamicProviderConfig { enum_name: "Ollama", parser_name: "ollama", display_name: "Ollama" },
    DynamicProviderConfig { enum_name: "LlamaCpp", parser_name: "llamacpp", display_name: "LlamaCpp" },
];
163
/// Fixed context window used for the Codex subscription models listed below.
const CODEX_SUBSCRIPTION_CONTEXT_WINDOW: u32 = 272_000;

/// Context-window override for the Codex provider: the listed subscription
/// models always report 272k tokens; every other model keeps the
/// models.dev-sourced value.
fn codex_subscription_context_window(model_id: &str, default_context_window: u32) -> u32 {
    const SUBSCRIPTION_MODELS: &[&str] =
        &["gpt-5.5", "gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex", "gpt-5.2", "codex-auto-review"];

    if SUBSCRIPTION_MODELS.contains(&model_id) {
        CODEX_SUBSCRIPTION_CONTEXT_WINDOW
    } else {
        default_context_window
    }
}
174
/// Normalized model record used by all emitters.
#[derive(Debug, Clone)]
struct ModelInfo {
    // PascalCase Rust enum variant name derived from `model_id`.
    variant_name: String,
    // Raw provider model ID (e.g. "claude-sonnet-4-5-20250929").
    model_id: String,
    // Human-readable name from models.dev.
    display_name: String,
    // Effective context window in tokens; 0 means unknown.
    context_window: u32,
    // Supported reasoning levels; empty for non-reasoning models.
    reasoning_levels: Vec<String>,
    // Input modalities (e.g. "text", "image", "audio").
    input_modalities: Vec<String>,
}
184
/// Provider `dev_id` → its models, sorted by model ID (BTreeMap for stable output).
type ProviderModels = BTreeMap<&'static str, Vec<ModelInfo>>;

/// State shared by all emit functions.
struct CodegenCtx {
    provider_models: ProviderModels,
}
190
/// Output of the code generator.
pub struct GeneratedOutput {
    /// The generated Rust source (for `generated.rs`).
    pub rust_source: String,
    /// Per-provider markdown documentation keyed by provider identifier.
    ///
    /// Keys are provider `dev_ids` (e.g. `"anthropic"`) for catalog providers
    /// and `parser_name`s (e.g. `"ollama"`) for dynamic providers; values are
    /// markdown strings suitable for `#![doc = include_str!(...)]`.
    pub provider_docs: HashMap<String, String>,
}
201
202/// Run the codegen, returning the generated Rust source and per-provider docs.
203pub fn generate(models_json_path: &Path) -> Result<GeneratedOutput, String> {
204    let json_bytes = std::fs::read_to_string(models_json_path).map_err(|e| format!("read: {e}"))?;
205    let data: ModelsDevData = serde_json::from_str(&json_bytes).map_err(|e| format!("parse: {e}"))?;
206
207    let provider_models = build_provider_models(&data)?;
208    let ctx = CodegenCtx { provider_models };
209    Ok(GeneratedOutput { rust_source: emit_generated_source(&ctx), provider_docs: emit_provider_docs(&ctx) })
210}
211
212fn build_provider_models(data: &ModelsDevData) -> Result<ProviderModels, String> {
213    let mut provider_models = ProviderModels::new();
214
215    for cfg in PROVIDERS {
216        let json_key = cfg.json_key();
217        let provider_data =
218            data.get(json_key).ok_or_else(|| format!("Provider '{json_key}' not found in models.dev data"))?;
219
220        let mut models: Vec<ModelInfo> = collect_models_from(cfg, &provider_data.models);
221
222        for &extra_key in cfg.extra_source_ids {
223            if let Some(extra_data) = data.get(extra_key) {
224                let extra = collect_models_from(cfg, &extra_data.models);
225                let existing_ids: std::collections::HashSet<String> =
226                    models.iter().map(|m| m.model_id.clone()).collect();
227                models.extend(extra.into_iter().filter(|m| !existing_ids.contains(&m.model_id)));
228            }
229        }
230
231        models.sort_by(|a, b| a.model_id.cmp(&b.model_id));
232        provider_models.insert(cfg.dev_id, models);
233    }
234
235    Ok(provider_models)
236}
237
/// Build `ModelInfo`s from one provider's raw model map, applying the
/// provider's filter, default reasoning levels, and context-window override.
fn collect_models_from(cfg: &ProviderConfig, models: &HashMap<String, ModelData>) -> Vec<ModelInfo> {
    models
        .values()
        // Only tool-calling models are included in the catalog.
        .filter(|m| m.tool_call == Some(true))
        // Drop "-latest" alias entries that just point at a concrete model.
        .filter(|m| !is_alias(&m.id))
        .filter(|m| cfg.model_filter.is_none_or(|f| f(&m.id)))
        .map(|m| {
            // Reasoning models inherit the provider's default level list.
            let reasoning_levels = if m.reasoning.unwrap_or(false) {
                cfg.default_reasoning_levels.iter().map(|s| (*s).to_string()).collect()
            } else {
                Vec::new()
            };
            // Missing modality data is treated as text-only input.
            let input_modalities =
                m.modalities.as_ref().map_or_else(|| vec!["text".to_string()], |md| md.input.clone());
            // 0 means "unknown" (rendered as such in the provider docs).
            let source_context_window = m.limit.as_ref().map_or(0, |l| l.context);
            let context_window = cfg.context_window_override.map_or(source_context_window, |override_context_window| {
                override_context_window(&m.id, source_context_window)
            });
            ModelInfo {
                variant_name: model_id_to_variant(&m.id),
                model_id: m.id.clone(),
                display_name: m.name.clone(),
                context_window,
                reasoning_levels,
                input_modalities,
            }
        })
        .collect()
}
267
/// Returns true for "latest" alias IDs that just point to another model.
fn is_alias(id: &str) -> bool {
    id.strip_suffix("-latest").is_some()
}
272
/// Convert a model ID like "claude-sonnet-4-5-20250929" into a `PascalCase` variant name.
/// Treats `-`, `.`, `/`, and `:` as word separators: each separator-delimited
/// segment contributes its first character ASCII-uppercased plus the rest verbatim.
fn model_id_to_variant(id: &str) -> String {
    let mut variant = String::with_capacity(id.len());

    for segment in id.split(['-', '.', '/', ':']) {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            variant.push(first.to_ascii_uppercase());
            variant.extend(chars);
        }
    }

    // A Rust identifier cannot start with a digit; prefix with underscore.
    if variant.starts_with(|c: char| c.is_ascii_digit()) {
        variant.insert(0, '_');
    }

    variant
}
297
/// Assemble the full `generated.rs` source by running each emitter in order.
fn emit_generated_source(ctx: &CodegenCtx) -> String {
    // Pre-size generously; the generated file is tens of KB.
    let mut out = String::with_capacity(64_000);
    emit_header(&mut out);
    emit_provider_enums(&mut out, &ctx.provider_models);
    emit_provider_impls(&mut out, &ctx.provider_models);
    emit_llm_model_enum(&mut out);
    emit_from_impls(&mut out);
    emit_llm_model_impl(&mut out);
    emit_display_impl(&mut out);
    emit_fromstr_impl(&mut out);
    out
}
310
/// Emit the do-not-edit banner and the imports the generated code needs.
fn emit_header(out: &mut String) {
    pushln(out, "// Auto-generated from models.dev — do not edit manually");
    pushln(out, "// Regenerated automatically by build.rs");
    blank(out);
    pushln(out, "use std::borrow::Cow;");
    pushln(out, "use std::sync::LazyLock;");
    pushln(out, "use crate::ReasoningEffort;");
    blank(out);
}
320
/// Emit one `pub enum {Provider}Model { ... }` per catalog provider.
fn emit_provider_enums(out: &mut String, provider_models: &ProviderModels) {
    for cfg in PROVIDERS {
        pushln(out, "#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]");
        pushln(out, format!("pub enum {}Model {{", cfg.enum_name));
        for model in &provider_models[cfg.dev_id] {
            pushln(out, format!("    {},", model.variant_name));
        }
        pushln(out, "}");
        blank(out);
    }
}
332
/// Emit the inherent `impl {Provider}Model` block (accessors + `ALL`) and the
/// per-provider `FromStr` impl for every catalog provider.
fn emit_provider_impls(out: &mut String, provider_models: &ProviderModels) {
    for cfg in PROVIDERS {
        let models = &provider_models[cfg.dev_id];
        let enum_name = format!("{}Model", cfg.enum_name);

        pushln(out, format!("impl {enum_name} {{"));

        // model_id — each model has a unique ID, no grouping needed
        pushln(out, "    #[allow(clippy::too_many_lines)]");
        pushln(out, "    fn model_id(self) -> &'static str {");
        pushln(out, "        match self {");
        for model in models {
            pushln(out, format!("            Self::{} => \"{}\",", model.variant_name, model.model_id));
        }
        pushln(out, "        }");
        pushln(out, "    }");
        blank(out);

        // display_name — group variants that share the same name
        pushln(out, "    #[allow(clippy::too_many_lines)]");
        pushln(out, "    fn display_name(self) -> &'static str {");
        pushln(out, "        match self {");
        emit_grouped_arms(out, models, |m| escape_rust_string(&m.display_name), |name| format!("\"{name}\""));
        pushln(out, "        }");
        pushln(out, "    }");
        blank(out);

        // context_window — group variants that share the same value
        pushln(out, "    fn context_window(self) -> u32 {");
        pushln(out, "        match self {");
        // The grouping key is the number as a string; re-parse to format with
        // underscore separators on the RHS.
        emit_grouped_arms(
            out,
            models,
            |m| m.context_window.to_string(),
            |val| format_number(val.parse::<u32>().unwrap()),
        );
        pushln(out, "        }");
        pushln(out, "    }");
        blank(out);

        // reasoning_levels — per-model reasoning level list
        pushln(out, "    pub fn reasoning_levels(self) -> &'static [ReasoningEffort] {");
        pushln(out, "        match self {");
        emit_grouped_arms(out, models, |m| m.reasoning_levels.join(","), format_reasoning_levels_rhs);
        pushln(out, "        }");
        pushln(out, "    }");
        blank(out);

        // supports_reasoning — derived from reasoning_levels
        pushln(out, "    pub fn supports_reasoning(self) -> bool {");
        pushln(out, "        !self.reasoning_levels().is_empty()");
        pushln(out, "    }");
        blank(out);

        // supports_image / supports_audio — grouped bool arms per modality
        for modality in ["image", "audio"] {
            pushln(out, format!("    pub fn supports_{modality}(self) -> bool {{"));
            pushln(out, "        match self {");
            let mod_owned = modality.to_string();
            emit_grouped_arms(
                out,
                models,
                |m| m.input_modalities.contains(&mod_owned).to_string(),
                std::string::ToString::to_string,
            );
            pushln(out, "        }");
            pushln(out, "    }");
            blank(out);
        }

        // ALL constant
        pushln(out, format!("    const ALL: &[{enum_name}] = &["));
        for model in models {
            pushln(out, format!("        Self::{},", model.variant_name));
        }
        pushln(out, "    ];");

        pushln(out, "}");
        blank(out);

        emit_from_str_impl(out, &enum_name, cfg.parser_name, models);
    }
}
415
/// Emit `impl FromStr` for one provider enum, mapping raw model IDs to variants.
fn emit_from_str_impl(out: &mut String, enum_name: &str, parser_name: &str, models: &[ModelInfo]) {
    pushln(out, format!("impl std::str::FromStr for {enum_name} {{"));
    pushln(out, "    type Err = String;");
    blank(out);
    pushln(out, "    #[allow(clippy::too_many_lines)]");
    pushln(out, "    fn from_str(s: &str) -> Result<Self, Self::Err> {");
    pushln(out, "        match s {");
    for model in models {
        pushln(out, format!("            \"{}\" => Ok(Self::{}),", model.model_id, model.variant_name));
    }
    pushln(out, format!("            _ => Err(format!(\"Unknown {parser_name} model: '{{s}}'\")),"));
    pushln(out, "        }");
    pushln(out, "    }");
    pushln(out, "}");
    blank(out);
}
432
/// Emit match arms grouped by value to avoid clippy `match_same_arms`.
///
/// `key_fn` extracts a grouping key from each model (e.g. `context_window` as string).
/// `fmt_val` formats the key into the match arm's RHS.
fn emit_grouped_arms(
    out: &mut String,
    models: &[ModelInfo],
    key_fn: impl Fn(&ModelInfo) -> String,
    fmt_val: impl Fn(&str) -> String,
) {
    // Group variants by value. Note: BTreeMap sorts by key, so arms are emitted
    // in lexicographic key order (deterministic), NOT insertion order. Arm order
    // is irrelevant here since every `Self::X` pattern is disjoint.
    let mut groups: BTreeMap<String, Vec<&str>> = BTreeMap::new();
    for model in models {
        groups.entry(key_fn(model)).or_default().push(&model.variant_name);
    }

    for (key, variants) in &groups {
        let rhs = fmt_val(key);
        if variants.len() == 1 {
            pushln(out, format!("            Self::{} => {rhs},", variants[0]));
        } else {
            // Multiple variants share this value — emit a single or-pattern arm.
            let patterns: Vec<String> = variants.iter().map(|v| format!("Self::{v}")).collect();
            pushln(out, format!("            {} => {rhs},", patterns.join(" | ")));
        }
    }
}
459
460/// Format the RHS of a `reasoning_levels()` match arm from a comma-joined key.
461fn format_reasoning_levels_rhs(key: &str) -> String {
462    if key.is_empty() {
463        return "&[]".to_string();
464    }
465    let items: Vec<String> = key
466        .split(',')
467        .map(|l| {
468            let variant = level_str_to_variant(l);
469            format!("ReasoningEffort::{variant}")
470        })
471        .collect();
472    format!("&[{}]", items.join(", "))
473}
474
/// Map a reasoning level string to its `ReasoningEffort` variant name.
/// Panics at build time if an unknown level is encountered.
fn level_str_to_variant(level: &str) -> &'static str {
    // Small fixed table — a linear scan is clearer than a match for 4 entries.
    const LEVELS: [(&str, &str); 4] =
        [("low", "Low"), ("medium", "Medium"), ("high", "High"), ("xhigh", "Xhigh")];

    LEVELS
        .iter()
        .find(|(name, _)| *name == level)
        .map(|(_, variant)| *variant)
        .unwrap_or_else(|| panic!("Unknown reasoning level: {level}"))
}
486
/// Emit the top-level `LlmModel` enum: one variant per catalog provider
/// (wrapping its model enum) plus one `String` variant per dynamic provider.
fn emit_llm_model_enum(out: &mut String) {
    pushln(out, "/// A model from a specific provider");
    pushln(out, "#[derive(Debug, Clone, PartialEq, Eq, Hash)]");
    pushln(out, "pub enum LlmModel {");
    for cfg in PROVIDERS {
        pushln(out, format!("    {provider}({provider}Model),", provider = cfg.enum_name));
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(out, format!("    {}(String),", dyn_cfg.enum_name));
    }
    pushln(out, "}");
    blank(out);
}
500
/// Emit `From<{Provider}Model> for LlmModel` conversions for catalog providers.
fn emit_from_impls(out: &mut String) {
    for cfg in PROVIDERS {
        pushln(out, format!("impl From<{}Model> for LlmModel {{", cfg.enum_name));
        pushln(out, format!("    fn from(m: {}Model) -> Self {{ LlmModel::{}(m) }}", cfg.enum_name, cfg.enum_name));
        pushln(out, "}");
        blank(out);
    }
}
509
/// Emit the `impl LlmModel` block, delegating to the per-method emitters below.
fn emit_llm_model_impl(out: &mut String) {
    pushln(out, "impl LlmModel {");
    emit_llm_model_id(out);
    emit_llm_display_name(out);
    emit_llm_provider(out);
    emit_llm_provider_display_name(out);
    emit_llm_context_window(out);
    emit_llm_required_env_var(out);
    emit_llm_all_required_env_vars(out);
    emit_llm_oauth_provider_id(out);
    emit_llm_reasoning_levels(out);
    emit_llm_supports_reasoning(out);
    for modality in ["image", "audio"] {
        emit_llm_supports_modality(out, modality);
    }
    emit_llm_all(out);
    pushln(out, "}");
    blank(out);
}
529
/// Emit `LlmModel::model_id`: catalog models borrow the static ID, dynamic
/// models clone their runtime string into an owned `Cow`.
fn emit_llm_model_id(out: &mut String) {
    pushln(out, "    /// Raw model ID (e.g. `claude-opus-4-6`, `llama3.2`)");
    pushln(out, "    pub fn model_id(&self) -> Cow<'static, str> {");
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(m) => Cow::Borrowed(m.model_id()),", cfg.enum_name));
    }
    pushln(out, format!("            {} => Cow::Owned(s.clone()),", dynamic_pattern_with_binding("s")));
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
542
543fn emit_llm_display_name(out: &mut String) {
544    pushln(out, "    /// Human-readable display name (e.g. `Claude Opus 4.6`)");
545    pushln(out, "    pub fn display_name(&self) -> Cow<'static, str> {");
546    pushln(out, "        match self {");
547    for cfg in PROVIDERS {
548        pushln(out, format!("            Self::{}(m) => Cow::Borrowed(m.display_name()),", cfg.enum_name));
549    }
550    for dyn_cfg in DYNAMIC_PROVIDERS {
551        pushln(
552            out,
553            format!(
554                "            Self::{}(s) => Cow::Owned(format!(\"{} {{s}}\")),",
555                dyn_cfg.enum_name, dyn_cfg.enum_name
556            ),
557        );
558    }
559    pushln(out, "        }");
560    pushln(out, "    }");
561    blank(out);
562}
563
/// Emit `LlmModel::provider`, returning the machine-readable `parser_name`.
fn emit_llm_provider(out: &mut String) {
    pushln(out, "    /// Provider identifier (e.g. `anthropic`)");
    pushln(out, "    pub fn provider(&self) -> &'static str {");
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(_) => \"{}\",", cfg.enum_name, cfg.parser_name));
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(out, format!("            Self::{}(_) => \"{}\",", dyn_cfg.enum_name, dyn_cfg.parser_name));
    }
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
578
/// Emit `LlmModel::provider_display_name`, returning the human-readable name.
fn emit_llm_provider_display_name(out: &mut String) {
    pushln(out, "    /// Human-readable provider name (e.g. `AWS Bedrock`)");
    pushln(out, "    pub fn provider_display_name(&self) -> &'static str {");
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(_) => \"{}\",", cfg.enum_name, cfg.display_name));
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(out, format!("            Self::{}(_) => \"{}\",", dyn_cfg.enum_name, dyn_cfg.display_name));
    }
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
593
/// Emit `LlmModel::context_window`; dynamic providers have no catalog data,
/// so they return `None`.
fn emit_llm_context_window(out: &mut String) {
    pushln(out, "    /// Context window size in tokens (None for dynamic providers)");
    pushln(out, "    pub fn context_window(&self) -> Option<u32> {");
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(m) => Some(m.context_window()),", cfg.enum_name));
    }
    pushln(out, format!("            {} => None,", dynamic_pattern_with_binding("_")));
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
606
/// Emit `LlmModel::required_env_var`. Providers without an env var are folded
/// into a single `None` arm together with all dynamic providers, so the
/// combined arm is never empty (DYNAMIC_PROVIDERS is non-empty).
fn emit_llm_required_env_var(out: &mut String) {
    pushln(out, "    /// Required env var for this model's provider (None for local providers)");
    pushln(out, "    pub fn required_env_var(&self) -> Option<&'static str> {");
    pushln(out, "        match self {");
    let mut none_arms: Vec<String> = Vec::new();
    for cfg in PROVIDERS {
        match cfg.env_var {
            Some(var) => pushln(out, format!("            Self::{}(_) => Some(\"{}\"),", cfg.enum_name, var)),
            None => none_arms.push(format!("Self::{}(_)", cfg.enum_name)),
        }
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        none_arms.push(format!("Self::{}(_)", dyn_cfg.enum_name));
    }
    pushln(out, format!("            {} => None,", none_arms.join(" | ")));
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
626
627fn emit_llm_all_required_env_vars(out: &mut String) {
628    let vars: Vec<&str> = PROVIDERS.iter().filter_map(|cfg| cfg.env_var).collect();
629    pushln(out, "    /// All provider API key env var names (deduplicated, static)");
630    pushln(
631        out,
632        format!(
633            "    pub const ALL_REQUIRED_ENV_VARS: &[&str] = &[{}];",
634            vars.iter().map(|v| format!("\"{v}\"")).collect::<Vec<_>>().join(", ")
635        ),
636    );
637    blank(out);
638}
639
/// Emit `LlmModel::oauth_provider_id`. Non-OAuth catalog providers and all
/// dynamic providers share one combined `None` arm (never empty, since
/// DYNAMIC_PROVIDERS is non-empty).
fn emit_llm_oauth_provider_id(out: &mut String) {
    pushln(out, "    /// OAuth provider ID if this model requires OAuth login (e.g. `\"codex\"`)");
    pushln(out, "    pub fn oauth_provider_id(&self) -> Option<&'static str> {");
    pushln(out, "        match self {");
    let mut none_arms: Vec<String> = Vec::new();
    for cfg in PROVIDERS {
        match cfg.oauth_provider_id {
            Some(id) => pushln(out, format!("            Self::{}(_) => Some(\"{}\"),", cfg.enum_name, id)),
            None => none_arms.push(format!("Self::{}(_)", cfg.enum_name)),
        }
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        none_arms.push(format!("Self::{}(_)", dyn_cfg.enum_name));
    }
    pushln(out, format!("            {} => None,", none_arms.join(" | ")));
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
659
/// Emit `LlmModel::reasoning_levels`; dynamic providers report no levels.
fn emit_llm_reasoning_levels(out: &mut String) {
    pushln(out, "    /// Reasoning levels supported by this model (empty if not a reasoning model)");
    pushln(out, "    pub fn reasoning_levels(&self) -> &'static [ReasoningEffort] {");
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(m) => m.reasoning_levels(),", cfg.enum_name));
    }
    pushln(out, format!("            {} => &[],", dynamic_pattern_with_binding("_")));
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
672
/// Emit `LlmModel::supports_reasoning`, derived from `reasoning_levels()`.
fn emit_llm_supports_reasoning(out: &mut String) {
    pushln(out, "    /// Whether this model supports reasoning/extended thinking");
    pushln(out, "    pub fn supports_reasoning(&self) -> bool {");
    pushln(out, "        !self.reasoning_levels().is_empty()");
    pushln(out, "    }");
    blank(out);
}
680
/// Emit `LlmModel::supports_{modality}` (called for "image" and "audio");
/// dynamic providers conservatively report `false`.
fn emit_llm_supports_modality(out: &mut String, modality: &str) {
    pushln(out, format!("    /// Whether this model supports {modality} input"));
    pushln(out, format!("    pub fn supports_{modality}(&self) -> bool {{"));
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(m) => m.supports_{modality}(),", cfg.enum_name));
    }
    pushln(out, format!("            {} => false,", dynamic_pattern_with_binding("_")));
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
693
/// Emit `LlmModel::all`, a lazily-built static list of every catalog model
/// (dynamic providers are excluded since their model set is open-ended).
fn emit_llm_all(out: &mut String) {
    pushln(out, "    /// All catalog models (excludes dynamic providers)");
    pushln(out, "    pub fn all() -> &'static [LlmModel] {");
    pushln(out, "        static ALL: LazyLock<Vec<LlmModel>> = LazyLock::new(|| {");
    pushln(out, "            let mut v = Vec::new();");
    for cfg in PROVIDERS {
        pushln(
            out,
            format!(
                "            v.extend({}Model::ALL.iter().copied().map(LlmModel::{}));",
                cfg.enum_name, cfg.enum_name
            ),
        );
    }
    pushln(out, "            v");
    pushln(out, "        });");
    pushln(out, "        &ALL");
    pushln(out, "    }");
}
713
/// Emit `Display for LlmModel` in the `provider:model` format (the inverse of
/// the generated `FromStr`).
fn emit_display_impl(out: &mut String) {
    pushln(out, "impl std::fmt::Display for LlmModel {");
    pushln(out, "    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {");
    pushln(out, "        write!(f, \"{}:{}\", self.provider(), self.model_id())");
    pushln(out, "    }");
    pushln(out, "}");
    blank(out);
}
722
/// Emit `FromStr for LlmModel`, parsing `provider:model` strings.
///
/// An input without a colon is treated as a bare provider with an empty model
/// string, which then fails in the per-provider model parse.
fn emit_fromstr_impl(out: &mut String) {
    pushln(out, "impl std::str::FromStr for LlmModel {");
    pushln(out, "    type Err = String;");
    blank(out);
    pushln(out, "    /// Parse a `provider:model` string into an `LlmModel`");
    pushln(out, "    fn from_str(s: &str) -> Result<Self, Self::Err> {");
    pushln(out, "        let (provider_str, model_str) = s.split_once(':').unwrap_or((s, \"\"));");
    pushln(out, "        match provider_str {");
    for cfg in PROVIDERS {
        pushln(
            out,
            format!(
                "            \"{}\" => model_str.parse::<{}Model>().map(Self::{}),",
                cfg.parser_name, cfg.enum_name, cfg.enum_name
            ),
        );
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(
            out,
            format!(
                "            \"{}\" => Ok(Self::{}(model_str.to_string())),",
                dyn_cfg.parser_name, dyn_cfg.enum_name
            ),
        );
    }
    pushln(out, "            _ => Err(format!(\"Unknown provider: '{provider_str}'\")),");
    pushln(out, "        }");
    pushln(out, "    }");
    pushln(out, "}");
}
754
755/// Build a combined `|` pattern for all dynamic providers with a binding variable.
756/// e.g. `Self::Ollama(s) | Self::LlamaCpp(s)` or `Self::Ollama(_) | Self::LlamaCpp(_)`
757fn dynamic_pattern_with_binding(binding: &str) -> String {
758    DYNAMIC_PROVIDERS.iter().map(|d| format!("Self::{}({binding})", d.enum_name)).collect::<Vec<_>>().join(" | ")
759}
760
/// Format a number with underscore separators (e.g. `200000` → `200_000`).
/// Numbers of four digits or fewer are left unseparated, matching rustfmt style.
fn format_number(n: u32) -> String {
    let digits = n.to_string();
    if digits.len() <= 4 {
        return digits;
    }
    let mut grouped = String::with_capacity(digits.len() + digits.len() / 3);
    for (pos, digit) in digits.chars().enumerate() {
        // Insert '_' wherever the number of remaining digits is a multiple of
        // three (but never before the first digit).
        if pos > 0 && (digits.len() - pos) % 3 == 0 {
            grouped.push('_');
        }
        grouped.push(digit);
    }
    grouped
}
776
/// Escape backslashes and double quotes so `raw` can be embedded in a Rust
/// string literal in the generated source.
fn escape_rust_string(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            _ => escaped.push(ch),
        }
    }
    escaped
}
780
/// Build the per-provider markdown docs: an authentication section and a model
/// table for catalog providers, and a short free-form note for dynamic ones.
/// Catalog docs are keyed by `dev_id`; dynamic docs by `parser_name`.
fn emit_provider_docs(ctx: &CodegenCtx) -> HashMap<String, String> {
    let mut docs = HashMap::new();

    for cfg in PROVIDERS {
        let models = &ctx.provider_models[cfg.dev_id];
        let mut doc = String::new();

        pushln(&mut doc, format!("`{}` LLM provider.", cfg.display_name));
        blank(&mut doc);

        // Authentication
        pushln(&mut doc, "# Authentication");
        blank(&mut doc);
        match cfg.env_var {
            Some(var) => pushln(&mut doc, format!("Set the `{var}` environment variable.")),
            None if cfg.oauth_provider_id.is_some() => {
                pushln(&mut doc, "This provider uses OAuth authentication.");
            }
            // No env var and no OAuth — currently only AWS Bedrock hits this arm.
            None => {
                pushln(
                    &mut doc,
                    "Uses the default AWS credential chain (environment variables, config files, IAM roles).",
                );
            }
        }
        blank(&mut doc);

        // Supported models table
        pushln(&mut doc, "# Supported models");
        blank(&mut doc);
        pushln(&mut doc, "| Model ID | Name | Context | Reasoning | Image | Audio |");
        pushln(&mut doc, "|----------|------|---------|-----------|-------|-------|");
        for model in models {
            let ctx_str = format_context_window(model.context_window);
            // Empty cell = unsupported; "yes" = supported.
            let reasoning = if model.reasoning_levels.is_empty() { "" } else { "yes" };
            let image = if model.input_modalities.contains(&"image".to_string()) { "yes" } else { "" };
            let audio = if model.input_modalities.contains(&"audio".to_string()) { "yes" } else { "" };
            pushln(
                &mut doc,
                format!(
                    "| `{}` | `{}` | `{}` | {} | {} | {} |",
                    model.model_id, model.display_name, ctx_str, reasoning, image, audio
                ),
            );
        }

        docs.insert(cfg.dev_id.to_string(), doc);
    }

    // Dynamic providers
    for dyn_cfg in DYNAMIC_PROVIDERS {
        let mut doc = String::new();
        pushln(&mut doc, format!("`{}` LLM provider.", dyn_cfg.display_name));
        blank(&mut doc);
        pushln(
            &mut doc,
            format!("This provider accepts any model name at runtime (e.g. `{}:my-model`).", dyn_cfg.parser_name),
        );
        pushln(&mut doc, "No API key is required.");
        docs.insert(dyn_cfg.parser_name.to_string(), doc);
    }

    docs
}
845
846/// Format a token count as human-readable (e.g. `1_000_000` → `1M`, `200_000` → `200k`).
847fn format_context_window(tokens: u32) -> String {
848    if tokens == 0 {
849        return "unknown".to_string();
850    }
851    if tokens >= 1_000_000 && tokens.is_multiple_of(1_000_000) {
852        format!("{}M", tokens / 1_000_000)
853    } else if tokens >= 1_000 && tokens.is_multiple_of(1_000) {
854        format!("{}k", tokens / 1_000)
855    } else {
856        format_number(tokens)
857    }
858}
859
/// Append `line` to `out`, followed by a newline.
fn pushln(out: &mut String, line: impl AsRef<str>) {
    // Appending to a String cannot fail, so no fmt::Write plumbing is needed.
    out.push_str(line.as_ref());
    out.push('\n');
}
863
/// Append a blank line (a lone newline) to `out`.
fn blank(out: &mut String) {
    out.push('\n');
}
867
868// ── Tests ────────────────────────────────────────────────────────────────
869
#[cfg(test)]
mod tests {
    //! End-to-end codegen tests. Each test builds a synthetic models.dev JSON
    //! fixture (starting from `minimal_models_dev_json`), runs `generate`
    //! through a temp file, and asserts on the emitted Rust source text.

    use super::*;
    use serde_json::Value;
    use serde_json::json;
    use tempfile::NamedTempFile;

    #[test]
    fn generate_sorts_and_filters_models() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "b-model": {
                    "id": "b-model",
                    "name": "B Model",
                    "tool_call": true,
                    "limit": {"context": 2000, "output": 0}
                },
                "a-model": {
                    "id": "a-model",
                    "name": "A Model",
                    "tool_call": true,
                    "limit": {"context": 1000, "output": 0}
                },
                "alpha-latest": {
                    "id": "alpha-latest",
                    "name": "Alias",
                    "tool_call": true,
                    "limit": {"context": 500, "output": 0}
                },
                "no-tools": {
                    "id": "no-tools",
                    "name": "No Tools",
                    "tool_call": false,
                    "limit": {"context": 500, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // Provider-level FromStr: sorted model IDs
        let a_model = "\"a-model\" => Ok(Self::AModel),";
        let b_model = "\"b-model\" => Ok(Self::BModel),";
        let a_pos = source.find(a_model).expect("a-model parse arm");
        let b_pos = source.find(b_model).expect("b-model parse arm");
        assert!(a_pos < b_pos);
        // Aliases and non-tool-call models are excluded
        assert!(!source.contains("AlphaLatest"));
        assert!(!source.contains("NoTools"));
    }

    #[test]
    fn generate_contains_core_sections() {
        // Structural smoke test: the main enum, its trait impls, and the
        // env-var accessor must always be present in the generated source.
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("pub enum LlmModel {"));
        assert!(source.contains("impl std::str::FromStr for LlmModel {"));
        assert!(source.contains("impl std::fmt::Display for LlmModel {"));
        assert!(source.contains("pub fn required_env_var(&self) -> Option<&'static str> {"));
    }

    #[test]
    fn generate_contains_dynamic_provider_arms() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("\"ollama\" => Ok(Self::Ollama(model_str.to_string())),"));
        assert!(source.contains("\"llamacpp\" => Ok(Self::LlamaCpp(model_str.to_string())),"));
        // Dynamic providers are combined with | for None-returning arms
        assert!(source.contains("Self::Ollama(_) | Self::LlamaCpp(_) => None,"));
    }

    #[test]
    fn generate_codex_is_catalog_provider() {
        let source = generate_from_value(&minimal_models_dev_json());
        // Codex is a catalog provider, not dynamic
        assert!(source.contains("pub enum CodexModel {"));
        assert!(source.contains("\"codex\" => model_str.parse::<CodexModel>().map(Self::Codex),"));
        assert!(source.contains("Self::Codex(m) => Some(m.context_window()),"));
    }

    #[test]
    fn generate_oauth_provider_id_for_codex() {
        let source = generate_from_value(&minimal_models_dev_json());
        // Codex models return Some("codex") for oauth_provider_id
        assert!(source.contains("Self::Codex(_) => Some(\"codex\"),"));
        // Non-OAuth providers return None
        assert!(source.contains("pub fn oauth_provider_id(&self) -> Option<&'static str>"));
    }

    #[test]
    fn generate_delegates_to_provider_impls() {
        let source = generate_from_value(&minimal_models_dev_json());
        // LlmModel delegates to per-provider methods
        assert!(source.contains("Self::Anthropic(m) => Cow::Borrowed(m.model_id()),"));
        assert!(source.contains("Self::Anthropic(m) => Some(m.context_window()),"));
        // Provider-level FromStr is used by LlmModel::FromStr
        assert!(source.contains("\"anthropic\" => model_str.parse::<AnthropicModel>().map(Self::Anthropic),"));
    }

    #[test]
    fn generate_formats_large_numbers_with_separators() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "big-model": {
                    "id": "big-model",
                    "name": "Big Model",
                    "tool_call": true,
                    "limit": {"context": 200_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // The context-window literal must be emitted via format_number.
        assert!(source.contains("200_000"));
        assert!(!source.contains("200000"));
    }

    #[test]
    fn generate_groups_identical_match_arms() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "model-a": {
                    "id": "model-a",
                    "name": "Same Name",
                    "tool_call": true,
                    "limit": {"context": 100_000, "output": 0}
                },
                "model-b": {
                    "id": "model-b",
                    "name": "Same Name",
                    "tool_call": true,
                    "limit": {"context": 100_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // Both context_window and display_name should combine arms
        assert!(source.contains("Self::ModelA | Self::ModelB => 100_000,"));
        assert!(source.contains("Self::ModelA | Self::ModelB => \"Same Name\","));
    }

    #[test]
    fn test_model_id_to_variant() {
        assert_eq!(model_id_to_variant("claude-sonnet-4-5-20250929"), "ClaudeSonnet4520250929");
        assert_eq!(model_id_to_variant("gemini-2.5-flash"), "Gemini25Flash");
        assert_eq!(model_id_to_variant("deepseek-chat"), "DeepseekChat");
        assert_eq!(model_id_to_variant("glm-4.5"), "Glm45");
    }

    #[test]
    fn test_model_id_to_variant_with_slash_and_colon() {
        // Namespaced IDs ("provider/model" and ":variant" suffixes) must still
        // produce valid UpperCamelCase enum variant names.
        assert_eq!(model_id_to_variant("anthropic/claude-opus-4.6"), "AnthropicClaudeOpus46");
        assert_eq!(model_id_to_variant("openai/gpt-5.1-codex-max"), "OpenaiGpt51CodexMax");
        assert_eq!(model_id_to_variant("deepseek/deepseek-r1:free"), "DeepseekDeepseekR1Free");
    }

    #[test]
    fn test_is_alias() {
        assert!(is_alias("claude-sonnet-4-5-latest"));
        assert!(is_alias("claude-3-7-sonnet-latest"));
        assert!(!is_alias("claude-sonnet-4-5-20250929"));
    }

    #[test]
    fn generate_contains_reasoning_levels_and_supports_reasoning() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "thinker": {
                    "id": "thinker",
                    "name": "Thinker",
                    "tool_call": true,
                    "reasoning": true,
                    "limit": {"context": 1000, "output": 0}
                },
                "fast": {
                    "id": "fast",
                    "name": "Fast",
                    "tool_call": true,
                    "reasoning": false,
                    "limit": {"context": 1000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // Provider enum should have reasoning_levels method
        assert!(source.contains("pub fn reasoning_levels(self) -> &'static [ReasoningEffort] {"));
        // Thinker (anthropic) gets standard 3 levels
        assert!(
            source
                .contains("Self::Thinker => &[ReasoningEffort::Low, ReasoningEffort::Medium, ReasoningEffort::High],")
        );
        // Fast gets empty
        assert!(source.contains("Self::Fast => &[],"));
        // supports_reasoning delegates to reasoning_levels
        assert!(source.contains("pub fn supports_reasoning(self) -> bool {"));
        assert!(source.contains("!self.reasoning_levels().is_empty()"));
        // LlmModel level
        assert!(source.contains("pub fn reasoning_levels(&self) -> &'static [ReasoningEffort] {"));
        // Dynamic providers return empty
        assert!(source.contains("Self::Ollama(_) | Self::LlamaCpp(_) => &[],"));
    }

    #[test]
    fn generate_codex_gets_four_reasoning_levels() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let openai = root.get_mut("openai").and_then(Value::as_object_mut).expect("openai provider");

        openai.insert(
            "models".to_string(),
            json!({
                "gpt-5.4-codex": {
                    "id": "gpt-5.4-codex",
                    "name": "GPT-5.4 Codex",
                    "tool_call": true,
                    "reasoning": true,
                    "limit": {"context": 200_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        assert!(
            source.contains(
                "ReasoningEffort::Low, ReasoningEffort::Medium, ReasoningEffort::High, ReasoningEffort::Xhigh"
            ),
            "Codex reasoning model should have 4 levels including Xhigh"
        );
    }

    #[test]
    fn generate_codex_overrides_gpt55_subscription_context_window() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let openai = root.get_mut("openai").and_then(Value::as_object_mut).expect("openai provider");

        openai.insert(
            "models".to_string(),
            json!({
                "gpt-5.5": {
                    "id": "gpt-5.5",
                    "name": "GPT-5.5",
                    "tool_call": true,
                    "reasoning": true,
                    "limit": {"context": 1_050_000, "output": 128_000}
                }
            }),
        );

        let source = generate_from_value(&data);
        let codex_impl = generated_impl_block(&source, "CodexModel");
        let openai_impl = generated_impl_block(&source, "OpenaiModel");

        // Same model, different context windows: Codex (subscription) is
        // clamped while the plain OpenAI provider keeps the catalog value.
        assert!(codex_impl.contains("Self::Gpt55 => 272_000,"));
        assert!(openai_impl.contains("Self::Gpt55 => 1_050_000,"));
    }

    #[test]
    fn codex_subscription_context_window_overrides_known_codex_models() {
        for model_id in ["gpt-5.5", "gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex", "gpt-5.2", "codex-auto-review"] {
            assert_eq!(codex_subscription_context_window(model_id, 1_050_000), 272_000);
        }
    }

    #[test]
    fn codex_subscription_context_window_leaves_unknown_models_unchanged() {
        assert_eq!(codex_subscription_context_window("gpt-5.3-codex-spark", 128_000), 128_000);
        assert_eq!(codex_subscription_context_window("some-future-model", 400_000), 400_000);
    }

    /// Serializes `data` to a temp file, runs `generate` on it, and returns
    /// the generated Rust source text.
    fn generate_from_value(data: &Value) -> String {
        let tmp = NamedTempFile::new().expect("temp file");
        let json = serde_json::to_string(data).expect("serialize fixture");
        std::fs::write(tmp.path(), json).expect("write fixture");
        generate(tmp.path()).expect("codegen succeeds").rust_source
    }

    /// Slices out the inherent `impl <enum_name> { ... }` block from generated
    /// source, using the enum's `FromStr` impl as the end marker.
    fn generated_impl_block<'a>(source: &'a str, enum_name: &str) -> &'a str {
        let start_marker = format!("impl {enum_name} {{");
        let start = source.find(&start_marker).expect("provider impl block start");
        let rest = &source[start..];
        let end_marker = format!("impl std::str::FromStr for {enum_name}");
        let end = rest.find(&end_marker).expect("provider impl block end");
        &rest[..end]
    }
    /// Builds a models.dev fixture that contains every configured provider
    /// (including `extra_source_ids` entries), each with an empty model map,
    /// so `generate` succeeds before tests add their own models.
    fn minimal_models_dev_json() -> Value {
        let mut root = serde_json::Map::new();
        for cfg in PROVIDERS {
            let json_key = cfg.json_key();
            root.entry(json_key.to_string()).or_insert_with(|| {
                json!({
                    "id": json_key,
                    "name": json_key,
                    "env": [],
                    "models": {}
                })
            });
            for &extra in cfg.extra_source_ids {
                root.entry(extra.to_string()).or_insert_with(|| {
                    json!({
                        "id": extra,
                        "name": extra,
                        "env": [],
                        "models": {}
                    })
                });
            }
        }
        Value::Object(root)
    }

    #[test]
    fn extra_source_ids_merges_models_into_provider() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().unwrap();

        // Add a model to the zai-coding-plan extra source that doesn't exist in zai
        let extra = root.get_mut("zai-coding-plan").unwrap().as_object_mut().unwrap();
        extra.insert(
            "models".to_string(),
            json!({
                "extra-model": {
                    "id": "extra-model",
                    "name": "Extra Model",
                    "tool_call": true,
                    "limit": {"context": 4000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // The extra model should appear under ZAi
        assert!(source.contains("\"extra-model\" => Ok(Self::ExtraModel),"));
    }

    #[test]
    fn extra_source_ids_does_not_duplicate_existing_models() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().unwrap();

        // Add same model ID to both zai and zai-coding-plan
        let zai = root.get_mut("zai").unwrap().as_object_mut().unwrap();
        zai.insert(
            "models".to_string(),
            json!({
                "shared-model": {
                    "id": "shared-model",
                    "name": "Shared Model",
                    "tool_call": true,
                    "limit": {"context": 1000, "output": 0}
                }
            }),
        );
        let extra = root.get_mut("zai-coding-plan").unwrap().as_object_mut().unwrap();
        extra.insert(
            "models".to_string(),
            json!({
                "shared-model": {
                    "id": "shared-model",
                    "name": "Shared Model Duplicate",
                    "tool_call": true,
                    "limit": {"context": 2000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // The primary source wins; the extra-source duplicate is dropped.
        let from_str_matches = source.matches("\"shared-model\" => Ok(Self::SharedModel),").count();
        assert_eq!(from_str_matches, 1);
    }

    #[test]
    fn generate_emits_provider_docs() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().unwrap();
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).unwrap();

        anthropic.insert(
            "models".to_string(),
            json!({
                "claude-test": {
                    "id": "claude-test",
                    "name": "Claude Test",
                    "tool_call": true,
                    "reasoning": true,
                    "limit": {"context": 200_000, "output": 0},
                    "modalities": {"input": ["text", "image"]}
                }
            }),
        );

        let tmp = NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), serde_json::to_string(&data).unwrap()).unwrap();
        let output = generate(tmp.path()).unwrap();

        let anthropic_doc = &output.provider_docs["anthropic"];
        assert!(anthropic_doc.contains("`Anthropic` LLM provider."));
        assert!(anthropic_doc.contains("`ANTHROPIC_API_KEY`"));
        assert!(anthropic_doc.contains("| `claude-test` | `Claude Test` | `200k` | yes | yes |  |"));

        // Dynamic providers get a short doc
        let ollama_doc = &output.provider_docs["ollama"];
        assert!(ollama_doc.contains("`Ollama` LLM provider."));
        assert!(ollama_doc.contains("any model name at runtime"));
    }

    #[test]
    fn format_context_window_formats_correctly() {
        assert_eq!(format_context_window(1_000_000), "1M");
        assert_eq!(format_context_window(200_000), "200k");
        assert_eq!(format_context_window(8_000), "8k");
        assert_eq!(format_context_window(0), "unknown");
    }

    #[test]
    fn level_str_to_variant_covers_all_reasoning_efforts() {
        for effort in utils::ReasoningEffort::all() {
            // Should not panic for any known variant
            let _ = level_str_to_variant(effort.as_str());
        }
    }
}