Skip to main content

llm_codegen/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use serde::Deserialize;
4use std::collections::{BTreeMap, HashMap};
5use std::fmt::Write;
6use std::path::Path;
7
8type ModelsDevData = HashMap<String, ProviderData>;
9
10#[derive(Debug, Deserialize)]
11struct ProviderData {
12    #[allow(dead_code)]
13    id: String,
14    #[allow(dead_code)]
15    name: String,
16    #[serde(default)]
17    #[allow(dead_code)]
18    env: Vec<String>,
19    #[serde(default)]
20    models: HashMap<String, ModelData>,
21}
22
23#[derive(Debug, Deserialize)]
24struct ModelData {
25    id: String,
26    name: String,
27    #[serde(default)]
28    tool_call: Option<bool>,
29    #[serde(default)]
30    reasoning: Option<bool>,
31    #[serde(default)]
32    #[allow(dead_code)]
33    cost: Option<CostData>,
34    #[serde(default)]
35    limit: Option<LimitData>,
36    #[serde(default)]
37    modalities: Option<ModalitiesData>,
38}
39
40#[derive(Debug, Deserialize, Default)]
41struct ModalitiesData {
42    #[serde(default)]
43    input: Vec<String>,
44}
45
46#[derive(Debug, Deserialize)]
47#[allow(dead_code)]
48struct CostData {
49    #[serde(default)]
50    input: f64,
51    #[serde(default)]
52    output: f64,
53}
54
55#[derive(Debug, Deserialize)]
56struct LimitData {
57    #[serde(default)]
58    context: u32,
59    #[serde(default)]
60    #[allow(dead_code)]
61    output: u32,
62}
63
/// Provider configuration for codegen (catalog providers with known model lists)
struct ProviderConfig {
    /// Unique provider key used in the `provider_models` map and as the key of
    /// this provider's generated docs (e.g. "codex")
    dev_id: &'static str,
    /// models.dev provider ID to read models from (defaults to `dev_id` when `None`)
    source_dev_id: Option<&'static str>,
    /// Additional models.dev keys whose models are merged into this provider.
    /// Models whose ID is already present are skipped, and keys that are absent
    /// from the models.dev data are ignored rather than treated as errors.
    extra_source_ids: &'static [&'static str],
    /// Only include models whose ID passes this filter (None = include all)
    model_filter: Option<fn(&str) -> bool>,
    /// Our Rust enum name (e.g. "Gemini"); the per-provider enum is generated as
    /// `<enum_name>Model`
    enum_name: &'static str,
    /// Our internal provider name used when parsing "provider:model" strings (e.g. "gemini")
    parser_name: &'static str,
    /// Human-readable provider name (e.g. "AWS Bedrock")
    display_name: &'static str,
    /// Env var our code actually checks (None for providers with complex credential chains)
    env_var: Option<&'static str>,
    /// OAuth provider ID for providers that require OAuth login (e.g. "codex")
    oauth_provider_id: Option<&'static str>,
    /// Reasoning levels given to every model that models.dev marks as
    /// reasoning-capable. The slice is used verbatim: an empty slice leaves such
    /// models with no levels (so they report no reasoning support) and does NOT
    /// fall back to a standard set. Entries must be `low`, `medium`, `high`, or
    /// `xhigh`; any other value makes codegen panic when a reasoning model is emitted.
    default_reasoning_levels: &'static [&'static str],
}
87
88impl ProviderConfig {
89    /// Shorthand for providers with default `source_dev_id`, `model_filter`, and `oauth_provider_id`.
90    const fn standard(
91        dev_id: &'static str,
92        enum_name: &'static str,
93        parser_name: &'static str,
94        display_name: &'static str,
95        env_var: Option<&'static str>,
96    ) -> Self {
97        Self {
98            dev_id,
99            source_dev_id: None,
100            extra_source_ids: &[],
101            model_filter: None,
102            enum_name,
103            parser_name,
104            display_name,
105            env_var,
106            oauth_provider_id: None,
107            default_reasoning_levels: &["low", "medium", "high"],
108        }
109    }
110
111    /// The models.dev key to look up in the JSON data.
112    fn json_key(&self) -> &'static str {
113        self.source_dev_id.unwrap_or(self.dev_id)
114    }
115}
116
/// Provider whose model name is free-form and only known at runtime.
///
/// No per-model enum is generated for it; the `LlmModel` variant carries the
/// model name as a `String` instead.
#[allow(clippy::struct_field_names)]
struct DynamicProviderConfig {
    /// Variant name inside `LlmModel` (e.g. "Ollama").
    enum_name: &'static str,
    /// Prefix accepted in "provider:model" strings (e.g. "ollama").
    parser_name: &'static str,
    /// Name shown to humans (e.g. "Ollama").
    display_name: &'static str,
}
127
128const PROVIDERS: &[ProviderConfig] = &[
129    ProviderConfig::standard("anthropic", "Anthropic", "anthropic", "Anthropic", Some("ANTHROPIC_API_KEY")),
130    ProviderConfig {
131        dev_id: "codex",
132        source_dev_id: Some("openai"),
133        extra_source_ids: &[],
134        model_filter: Some(|id| id.contains("codex") || id.starts_with("gpt-5.4")),
135        enum_name: "Codex",
136        parser_name: "codex",
137        display_name: "Codex",
138        env_var: None,
139        oauth_provider_id: Some("codex"),
140        default_reasoning_levels: &["low", "medium", "high", "xhigh"],
141    },
142    ProviderConfig::standard("deepseek", "DeepSeek", "deepseek", "DeepSeek", Some("DEEPSEEK_API_KEY")),
143    ProviderConfig::standard("google", "Gemini", "gemini", "Gemini", Some("GEMINI_API_KEY")),
144    ProviderConfig::standard("moonshotai", "Moonshot", "moonshot", "Moonshot", Some("MOONSHOT_API_KEY")),
145    ProviderConfig::standard("openai", "Openai", "openai", "OpenAI", Some("OPENAI_API_KEY")),
146    ProviderConfig::standard("openrouter", "OpenRouter", "openrouter", "OpenRouter", Some("OPENROUTER_API_KEY")),
147    ProviderConfig {
148        extra_source_ids: &["zai-coding-plan"],
149        ..ProviderConfig::standard("zai", "ZAi", "zai", "ZAI", Some("ZAI_API_KEY"))
150    },
151    ProviderConfig::standard("amazon-bedrock", "Bedrock", "bedrock", "AWS Bedrock", None),
152];
153
154const DYNAMIC_PROVIDERS: &[DynamicProviderConfig] = &[
155    DynamicProviderConfig { enum_name: "Ollama", parser_name: "ollama", display_name: "Ollama" },
156    DynamicProviderConfig { enum_name: "LlamaCpp", parser_name: "llamacpp", display_name: "LlamaCpp" },
157];
158
/// Per-model facts extracted from models.dev that the emitters need.
#[derive(Debug, Clone)]
struct ModelInfo {
    /// `PascalCase` enum variant derived from `model_id`.
    variant_name: String,
    /// Raw provider model ID (e.g. "claude-sonnet-4-5-20250929").
    model_id: String,
    /// Human-readable model name from models.dev.
    display_name: String,
    /// Context window in tokens; 0 when models.dev gives no limit.
    context_window: u32,
    /// Reasoning level names (e.g. "low"); empty for non-reasoning models.
    reasoning_levels: Vec<String>,
    /// Input modalities (e.g. "text", "image").
    input_modalities: Vec<String>,
}

/// Catalog models keyed by provider `dev_id`.
type ProviderModels = BTreeMap<&'static str, Vec<ModelInfo>>;

/// Shared state handed to the emitters.
struct CodegenCtx {
    provider_models: ProviderModels,
}

/// Output of the code generator.
#[derive(Debug)]
pub struct GeneratedOutput {
    /// The generated Rust source (for `generated.rs`).
    pub rust_source: String,
    /// Per-provider markdown documentation keyed by provider identifier.
    ///
    /// Catalog providers are keyed by their models.dev `dev_id` (e.g.
    /// `"anthropic"`, `"google"`, `"amazon-bedrock"`). Dynamic providers are
    /// keyed by their parser name (e.g. `"ollama"`, `"llamacpp"`). Values are
    /// markdown strings suitable for `#![doc = include_str!(...)]`.
    pub provider_docs: HashMap<String, String>,
}
185
186/// Run the codegen, returning the generated Rust source and per-provider docs.
187pub fn generate(models_json_path: &Path) -> Result<GeneratedOutput, String> {
188    let json_bytes = std::fs::read_to_string(models_json_path).map_err(|e| format!("read: {e}"))?;
189    let data: ModelsDevData = serde_json::from_str(&json_bytes).map_err(|e| format!("parse: {e}"))?;
190
191    let provider_models = build_provider_models(&data)?;
192    let ctx = CodegenCtx { provider_models };
193    Ok(GeneratedOutput { rust_source: emit_generated_source(&ctx), provider_docs: emit_provider_docs(&ctx) })
194}
195
196fn build_provider_models(data: &ModelsDevData) -> Result<ProviderModels, String> {
197    let mut provider_models = ProviderModels::new();
198
199    for cfg in PROVIDERS {
200        let json_key = cfg.json_key();
201        let provider_data =
202            data.get(json_key).ok_or_else(|| format!("Provider '{json_key}' not found in models.dev data"))?;
203
204        let mut models: Vec<ModelInfo> = collect_models_from(cfg, &provider_data.models);
205
206        for &extra_key in cfg.extra_source_ids {
207            if let Some(extra_data) = data.get(extra_key) {
208                let extra = collect_models_from(cfg, &extra_data.models);
209                let existing_ids: std::collections::HashSet<String> =
210                    models.iter().map(|m| m.model_id.clone()).collect();
211                models.extend(extra.into_iter().filter(|m| !existing_ids.contains(&m.model_id)));
212            }
213        }
214
215        models.sort_by(|a, b| a.model_id.cmp(&b.model_id));
216        provider_models.insert(cfg.dev_id, models);
217    }
218
219    Ok(provider_models)
220}
221
222fn collect_models_from(cfg: &ProviderConfig, models: &HashMap<String, ModelData>) -> Vec<ModelInfo> {
223    models
224        .values()
225        .filter(|m| m.tool_call == Some(true))
226        .filter(|m| !is_alias(&m.id))
227        .filter(|m| cfg.model_filter.is_none_or(|f| f(&m.id)))
228        .map(|m| {
229            let reasoning_levels = if m.reasoning.unwrap_or(false) {
230                cfg.default_reasoning_levels.iter().map(|s| (*s).to_string()).collect()
231            } else {
232                Vec::new()
233            };
234            let input_modalities =
235                m.modalities.as_ref().map_or_else(|| vec!["text".to_string()], |md| md.input.clone());
236            ModelInfo {
237                variant_name: model_id_to_variant(&m.id),
238                model_id: m.id.clone(),
239                display_name: m.name.clone(),
240                context_window: m.limit.as_ref().map_or(0, |l| l.context),
241                reasoning_levels,
242                input_modalities,
243            }
244        })
245        .collect()
246}
247
/// Reports whether `id` is a floating alias, i.e. it ends in `-latest` and
/// only points at another, concretely versioned model.
fn is_alias(id: &str) -> bool {
    id.strip_suffix("-latest").is_some()
}
252
/// Convert a model ID like "claude-sonnet-4-5-20250929" into a `PascalCase` variant name.
///
/// Any character that cannot appear in a Rust identifier acts as a word
/// separator and is dropped. That means everything except alphanumerics and
/// `_`, e.g. `-`, `.`, `/`, `:`, `@`, `+`, or spaces. The character after a
/// separator (and the first character) is upper-cased. A leading digit gets
/// an `_` prefix so the result is always a valid identifier start.
fn model_id_to_variant(id: &str) -> String {
    let mut result = String::with_capacity(id.len() + 1);
    let mut capitalize_next = true;

    for ch in id.chars() {
        let is_ident_char = ch.is_alphanumeric() || ch == '_';
        if !is_ident_char {
            capitalize_next = true;
        } else if capitalize_next {
            result.push(ch.to_ascii_uppercase());
            capitalize_next = false;
        } else {
            result.push(ch);
        }
    }

    // Identifiers cannot start with a digit.
    if result.starts_with(|c: char| c.is_numeric()) {
        result.insert(0, '_');
    }

    result
}
277
278fn emit_generated_source(ctx: &CodegenCtx) -> String {
279    let mut out = String::with_capacity(64_000);
280    emit_header(&mut out);
281    emit_provider_enums(&mut out, &ctx.provider_models);
282    emit_provider_impls(&mut out, &ctx.provider_models);
283    emit_llm_model_enum(&mut out);
284    emit_from_impls(&mut out);
285    emit_llm_model_impl(&mut out);
286    emit_display_impl(&mut out);
287    emit_fromstr_impl(&mut out);
288    out
289}
290
/// Writes the "do not edit" banner followed by the `use` items the generated code relies on.
fn emit_header(out: &mut String) {
    const HEADER: &str = concat!(
        "// Auto-generated from models.dev — do not edit manually\n",
        "// Regenerated automatically by build.rs\n",
        "\n",
        "use std::borrow::Cow;\n",
        "use std::sync::LazyLock;\n",
        "use crate::ReasoningEffort;\n",
        "\n",
    );
    out.push_str(HEADER);
}
300
301fn emit_provider_enums(out: &mut String, provider_models: &ProviderModels) {
302    for cfg in PROVIDERS {
303        pushln(out, "#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]");
304        pushln(out, format!("pub enum {}Model {{", cfg.enum_name));
305        for model in &provider_models[cfg.dev_id] {
306            pushln(out, format!("    {},", model.variant_name));
307        }
308        pushln(out, "}");
309        blank(out);
310    }
311}
312
313fn emit_provider_impls(out: &mut String, provider_models: &ProviderModels) {
314    for cfg in PROVIDERS {
315        let models = &provider_models[cfg.dev_id];
316        let enum_name = format!("{}Model", cfg.enum_name);
317
318        pushln(out, format!("impl {enum_name} {{"));
319
320        // model_id — each model has a unique ID, no grouping needed
321        pushln(out, "    #[allow(clippy::too_many_lines)]");
322        pushln(out, "    fn model_id(self) -> &'static str {");
323        pushln(out, "        match self {");
324        for model in models {
325            pushln(out, format!("            Self::{} => \"{}\",", model.variant_name, model.model_id));
326        }
327        pushln(out, "        }");
328        pushln(out, "    }");
329        blank(out);
330
331        // display_name — group variants that share the same name
332        pushln(out, "    #[allow(clippy::too_many_lines)]");
333        pushln(out, "    fn display_name(self) -> &'static str {");
334        pushln(out, "        match self {");
335        emit_grouped_arms(out, models, |m| escape_rust_string(&m.display_name), |name| format!("\"{name}\""));
336        pushln(out, "        }");
337        pushln(out, "    }");
338        blank(out);
339
340        // context_window — group variants that share the same value
341        pushln(out, "    fn context_window(self) -> u32 {");
342        pushln(out, "        match self {");
343        emit_grouped_arms(
344            out,
345            models,
346            |m| m.context_window.to_string(),
347            |val| format_number(val.parse::<u32>().unwrap()),
348        );
349        pushln(out, "        }");
350        pushln(out, "    }");
351        blank(out);
352
353        // reasoning_levels — per-model reasoning level list
354        pushln(out, "    pub fn reasoning_levels(self) -> &'static [ReasoningEffort] {");
355        pushln(out, "        match self {");
356        emit_grouped_arms(out, models, |m| m.reasoning_levels.join(","), format_reasoning_levels_rhs);
357        pushln(out, "        }");
358        pushln(out, "    }");
359        blank(out);
360
361        // supports_reasoning — derived from reasoning_levels
362        pushln(out, "    pub fn supports_reasoning(self) -> bool {");
363        pushln(out, "        !self.reasoning_levels().is_empty()");
364        pushln(out, "    }");
365        blank(out);
366
367        for modality in ["image", "audio"] {
368            pushln(out, format!("    pub fn supports_{modality}(self) -> bool {{"));
369            pushln(out, "        match self {");
370            let mod_owned = modality.to_string();
371            emit_grouped_arms(
372                out,
373                models,
374                |m| m.input_modalities.contains(&mod_owned).to_string(),
375                std::string::ToString::to_string,
376            );
377            pushln(out, "        }");
378            pushln(out, "    }");
379            blank(out);
380        }
381
382        // ALL constant
383        pushln(out, format!("    const ALL: &[{enum_name}] = &["));
384        for model in models {
385            pushln(out, format!("        Self::{},", model.variant_name));
386        }
387        pushln(out, "    ];");
388
389        pushln(out, "}");
390        blank(out);
391
392        emit_from_str_impl(out, &enum_name, cfg.parser_name, models);
393    }
394}
395
396fn emit_from_str_impl(out: &mut String, enum_name: &str, parser_name: &str, models: &[ModelInfo]) {
397    pushln(out, format!("impl std::str::FromStr for {enum_name} {{"));
398    pushln(out, "    type Err = String;");
399    blank(out);
400    pushln(out, "    #[allow(clippy::too_many_lines)]");
401    pushln(out, "    fn from_str(s: &str) -> Result<Self, Self::Err> {");
402    pushln(out, "        match s {");
403    for model in models {
404        pushln(out, format!("            \"{}\" => Ok(Self::{}),", model.model_id, model.variant_name));
405    }
406    pushln(out, format!("            _ => Err(format!(\"Unknown {parser_name} model: '{{s}}'\")),"));
407    pushln(out, "        }");
408    pushln(out, "    }");
409    pushln(out, "}");
410    blank(out);
411}
412
413/// Emit match arms grouped by value to avoid clippy `match_same_arms`.
414///
415/// `key_fn` extracts a grouping key from each model (e.g. `context_window` as string).
416/// `fmt_val` formats the key into the match arm's RHS.
417fn emit_grouped_arms(
418    out: &mut String,
419    models: &[ModelInfo],
420    key_fn: impl Fn(&ModelInfo) -> String,
421    fmt_val: impl Fn(&str) -> String,
422) {
423    // Group variants by value, preserving insertion order via BTreeMap
424    let mut groups: BTreeMap<String, Vec<&str>> = BTreeMap::new();
425    for model in models {
426        groups.entry(key_fn(model)).or_default().push(&model.variant_name);
427    }
428
429    for (key, variants) in &groups {
430        let rhs = fmt_val(key);
431        if variants.len() == 1 {
432            pushln(out, format!("            Self::{} => {rhs},", variants[0]));
433        } else {
434            let patterns: Vec<String> = variants.iter().map(|v| format!("Self::{v}")).collect();
435            pushln(out, format!("            {} => {rhs},", patterns.join(" | ")));
436        }
437    }
438}
439
440/// Format the RHS of a `reasoning_levels()` match arm from a comma-joined key.
441fn format_reasoning_levels_rhs(key: &str) -> String {
442    if key.is_empty() {
443        return "&[]".to_string();
444    }
445    let items: Vec<String> = key
446        .split(',')
447        .map(|l| {
448            let variant = level_str_to_variant(l);
449            format!("ReasoningEffort::{variant}")
450        })
451        .collect();
452    format!("&[{}]", items.join(", "))
453}
454
/// Map a reasoning level string to its `ReasoningEffort` variant name.
///
/// # Panics
///
/// Panics, failing code generation, on any level other than `low`, `medium`,
/// `high`, or `xhigh`.
fn level_str_to_variant(level: &str) -> &'static str {
    const LEVELS: [(&str, &str); 4] = [("low", "Low"), ("medium", "Medium"), ("high", "High"), ("xhigh", "Xhigh")];
    LEVELS
        .iter()
        .find(|(name, _)| *name == level)
        .map(|&(_, variant)| variant)
        .unwrap_or_else(|| panic!("Unknown reasoning level: {level}"))
}
466
467fn emit_llm_model_enum(out: &mut String) {
468    pushln(out, "/// A model from a specific provider");
469    pushln(out, "#[derive(Debug, Clone, PartialEq, Eq, Hash)]");
470    pushln(out, "pub enum LlmModel {");
471    for cfg in PROVIDERS {
472        pushln(out, format!("    {provider}({provider}Model),", provider = cfg.enum_name));
473    }
474    for dyn_cfg in DYNAMIC_PROVIDERS {
475        pushln(out, format!("    {}(String),", dyn_cfg.enum_name));
476    }
477    pushln(out, "}");
478    blank(out);
479}
480
481fn emit_from_impls(out: &mut String) {
482    for cfg in PROVIDERS {
483        pushln(out, format!("impl From<{}Model> for LlmModel {{", cfg.enum_name));
484        pushln(out, format!("    fn from(m: {}Model) -> Self {{ LlmModel::{}(m) }}", cfg.enum_name, cfg.enum_name));
485        pushln(out, "}");
486        blank(out);
487    }
488}
489
490fn emit_llm_model_impl(out: &mut String) {
491    pushln(out, "impl LlmModel {");
492    emit_llm_model_id(out);
493    emit_llm_display_name(out);
494    emit_llm_provider(out);
495    emit_llm_provider_display_name(out);
496    emit_llm_context_window(out);
497    emit_llm_required_env_var(out);
498    emit_llm_all_required_env_vars(out);
499    emit_llm_oauth_provider_id(out);
500    emit_llm_reasoning_levels(out);
501    emit_llm_supports_reasoning(out);
502    for modality in ["image", "audio"] {
503        emit_llm_supports_modality(out, modality);
504    }
505    emit_llm_all(out);
506    pushln(out, "}");
507    blank(out);
508}
509
510fn emit_llm_model_id(out: &mut String) {
511    pushln(out, "    /// Raw model ID (e.g. `claude-opus-4-6`, `llama3.2`)");
512    pushln(out, "    pub fn model_id(&self) -> Cow<'static, str> {");
513    pushln(out, "        match self {");
514    for cfg in PROVIDERS {
515        pushln(out, format!("            Self::{}(m) => Cow::Borrowed(m.model_id()),", cfg.enum_name));
516    }
517    pushln(out, format!("            {} => Cow::Owned(s.clone()),", dynamic_pattern_with_binding("s")));
518    pushln(out, "        }");
519    pushln(out, "    }");
520    blank(out);
521}
522
523fn emit_llm_display_name(out: &mut String) {
524    pushln(out, "    /// Human-readable display name (e.g. `Claude Opus 4.6`)");
525    pushln(out, "    pub fn display_name(&self) -> Cow<'static, str> {");
526    pushln(out, "        match self {");
527    for cfg in PROVIDERS {
528        pushln(out, format!("            Self::{}(m) => Cow::Borrowed(m.display_name()),", cfg.enum_name));
529    }
530    for dyn_cfg in DYNAMIC_PROVIDERS {
531        pushln(
532            out,
533            format!(
534                "            Self::{}(s) => Cow::Owned(format!(\"{} {{s}}\")),",
535                dyn_cfg.enum_name, dyn_cfg.enum_name
536            ),
537        );
538    }
539    pushln(out, "        }");
540    pushln(out, "    }");
541    blank(out);
542}
543
544fn emit_llm_provider(out: &mut String) {
545    pushln(out, "    /// Provider identifier (e.g. `anthropic`)");
546    pushln(out, "    pub fn provider(&self) -> &'static str {");
547    pushln(out, "        match self {");
548    for cfg in PROVIDERS {
549        pushln(out, format!("            Self::{}(_) => \"{}\",", cfg.enum_name, cfg.parser_name));
550    }
551    for dyn_cfg in DYNAMIC_PROVIDERS {
552        pushln(out, format!("            Self::{}(_) => \"{}\",", dyn_cfg.enum_name, dyn_cfg.parser_name));
553    }
554    pushln(out, "        }");
555    pushln(out, "    }");
556    blank(out);
557}
558
559fn emit_llm_provider_display_name(out: &mut String) {
560    pushln(out, "    /// Human-readable provider name (e.g. `AWS Bedrock`)");
561    pushln(out, "    pub fn provider_display_name(&self) -> &'static str {");
562    pushln(out, "        match self {");
563    for cfg in PROVIDERS {
564        pushln(out, format!("            Self::{}(_) => \"{}\",", cfg.enum_name, cfg.display_name));
565    }
566    for dyn_cfg in DYNAMIC_PROVIDERS {
567        pushln(out, format!("            Self::{}(_) => \"{}\",", dyn_cfg.enum_name, dyn_cfg.display_name));
568    }
569    pushln(out, "        }");
570    pushln(out, "    }");
571    blank(out);
572}
573
574fn emit_llm_context_window(out: &mut String) {
575    pushln(out, "    /// Context window size in tokens (None for dynamic providers)");
576    pushln(out, "    pub fn context_window(&self) -> Option<u32> {");
577    pushln(out, "        match self {");
578    for cfg in PROVIDERS {
579        pushln(out, format!("            Self::{}(m) => Some(m.context_window()),", cfg.enum_name));
580    }
581    pushln(out, format!("            {} => None,", dynamic_pattern_with_binding("_")));
582    pushln(out, "        }");
583    pushln(out, "    }");
584    blank(out);
585}
586
587fn emit_llm_required_env_var(out: &mut String) {
588    pushln(out, "    /// Required env var for this model's provider (None for local providers)");
589    pushln(out, "    pub fn required_env_var(&self) -> Option<&'static str> {");
590    pushln(out, "        match self {");
591    let mut none_arms: Vec<String> = Vec::new();
592    for cfg in PROVIDERS {
593        match cfg.env_var {
594            Some(var) => pushln(out, format!("            Self::{}(_) => Some(\"{}\"),", cfg.enum_name, var)),
595            None => none_arms.push(format!("Self::{}(_)", cfg.enum_name)),
596        }
597    }
598    for dyn_cfg in DYNAMIC_PROVIDERS {
599        none_arms.push(format!("Self::{}(_)", dyn_cfg.enum_name));
600    }
601    pushln(out, format!("            {} => None,", none_arms.join(" | ")));
602    pushln(out, "        }");
603    pushln(out, "    }");
604    blank(out);
605}
606
607fn emit_llm_all_required_env_vars(out: &mut String) {
608    let vars: Vec<&str> = PROVIDERS.iter().filter_map(|cfg| cfg.env_var).collect();
609    pushln(out, "    /// All provider API key env var names (deduplicated, static)");
610    pushln(
611        out,
612        format!(
613            "    pub const ALL_REQUIRED_ENV_VARS: &[&str] = &[{}];",
614            vars.iter().map(|v| format!("\"{v}\"")).collect::<Vec<_>>().join(", ")
615        ),
616    );
617    blank(out);
618}
619
620fn emit_llm_oauth_provider_id(out: &mut String) {
621    pushln(out, "    /// OAuth provider ID if this model requires OAuth login (e.g. `\"codex\"`)");
622    pushln(out, "    pub fn oauth_provider_id(&self) -> Option<&'static str> {");
623    pushln(out, "        match self {");
624    let mut none_arms: Vec<String> = Vec::new();
625    for cfg in PROVIDERS {
626        match cfg.oauth_provider_id {
627            Some(id) => pushln(out, format!("            Self::{}(_) => Some(\"{}\"),", cfg.enum_name, id)),
628            None => none_arms.push(format!("Self::{}(_)", cfg.enum_name)),
629        }
630    }
631    for dyn_cfg in DYNAMIC_PROVIDERS {
632        none_arms.push(format!("Self::{}(_)", dyn_cfg.enum_name));
633    }
634    pushln(out, format!("            {} => None,", none_arms.join(" | ")));
635    pushln(out, "        }");
636    pushln(out, "    }");
637    blank(out);
638}
639
640fn emit_llm_reasoning_levels(out: &mut String) {
641    pushln(out, "    /// Reasoning levels supported by this model (empty if not a reasoning model)");
642    pushln(out, "    pub fn reasoning_levels(&self) -> &'static [ReasoningEffort] {");
643    pushln(out, "        match self {");
644    for cfg in PROVIDERS {
645        pushln(out, format!("            Self::{}(m) => m.reasoning_levels(),", cfg.enum_name));
646    }
647    pushln(out, format!("            {} => &[],", dynamic_pattern_with_binding("_")));
648    pushln(out, "        }");
649    pushln(out, "    }");
650    blank(out);
651}
652
/// Emits `LlmModel::supports_reasoning`, which is true whenever the model
/// reports at least one reasoning level.
fn emit_llm_supports_reasoning(out: &mut String) {
    out.push_str(concat!(
        "    /// Whether this model supports reasoning/extended thinking\n",
        "    pub fn supports_reasoning(&self) -> bool {\n",
        "        !self.reasoning_levels().is_empty()\n",
        "    }\n",
        "\n",
    ));
}
660
661fn emit_llm_supports_modality(out: &mut String, modality: &str) {
662    pushln(out, format!("    /// Whether this model supports {modality} input"));
663    pushln(out, format!("    pub fn supports_{modality}(&self) -> bool {{"));
664    pushln(out, "        match self {");
665    for cfg in PROVIDERS {
666        pushln(out, format!("            Self::{}(m) => m.supports_{modality}(),", cfg.enum_name));
667    }
668    pushln(out, format!("            {} => false,", dynamic_pattern_with_binding("_")));
669    pushln(out, "        }");
670    pushln(out, "    }");
671    blank(out);
672}
673
674fn emit_llm_all(out: &mut String) {
675    pushln(out, "    /// All catalog models (excludes dynamic providers)");
676    pushln(out, "    pub fn all() -> &'static [LlmModel] {");
677    pushln(out, "        static ALL: LazyLock<Vec<LlmModel>> = LazyLock::new(|| {");
678    pushln(out, "            let mut v = Vec::new();");
679    for cfg in PROVIDERS {
680        pushln(
681            out,
682            format!(
683                "            v.extend({}Model::ALL.iter().copied().map(LlmModel::{}));",
684                cfg.enum_name, cfg.enum_name
685            ),
686        );
687    }
688    pushln(out, "            v");
689    pushln(out, "        });");
690    pushln(out, "        &ALL");
691    pushln(out, "    }");
692}
693
/// Emits `impl Display for LlmModel`, which renders a model as `provider:model_id`.
fn emit_display_impl(out: &mut String) {
    out.push_str(concat!(
        "impl std::fmt::Display for LlmModel {\n",
        "    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n",
        "        write!(f, \"{}:{}\", self.provider(), self.model_id())\n",
        "    }\n",
        "}\n",
        "\n",
    ));
}
702
703fn emit_fromstr_impl(out: &mut String) {
704    pushln(out, "impl std::str::FromStr for LlmModel {");
705    pushln(out, "    type Err = String;");
706    blank(out);
707    pushln(out, "    /// Parse a `provider:model` string into an `LlmModel`");
708    pushln(out, "    fn from_str(s: &str) -> Result<Self, Self::Err> {");
709    pushln(out, "        let (provider_str, model_str) = s.split_once(':').unwrap_or((s, \"\"));");
710    pushln(out, "        match provider_str {");
711    for cfg in PROVIDERS {
712        pushln(
713            out,
714            format!(
715                "            \"{}\" => model_str.parse::<{}Model>().map(Self::{}),",
716                cfg.parser_name, cfg.enum_name, cfg.enum_name
717            ),
718        );
719    }
720    for dyn_cfg in DYNAMIC_PROVIDERS {
721        pushln(
722            out,
723            format!(
724                "            \"{}\" => Ok(Self::{}(model_str.to_string())),",
725                dyn_cfg.parser_name, dyn_cfg.enum_name
726            ),
727        );
728    }
729    pushln(out, "            _ => Err(format!(\"Unknown provider: '{provider_str}'\")),");
730    pushln(out, "        }");
731    pushln(out, "    }");
732    pushln(out, "}");
733}
734
735/// Build a combined `|` pattern for all dynamic providers with a binding variable.
736/// e.g. `Self::Ollama(s) | Self::LlamaCpp(s)` or `Self::Ollama(_) | Self::LlamaCpp(_)`
737fn dynamic_pattern_with_binding(binding: &str) -> String {
738    DYNAMIC_PROVIDERS.iter().map(|d| format!("Self::{}({binding})", d.enum_name)).collect::<Vec<_>>().join(" | ")
739}
740
/// Inserts `_` separators every three digits from the right, e.g. `200000` →
/// `200_000`. Values with at most four digits are left untouched (`1000`
/// stays `1000`).
fn format_number(n: u32) -> String {
    let digits = n.to_string();
    if digits.len() <= 4 {
        return digits;
    }

    // The leading group holds the 1–2 "extra" digits; when the length is a
    // multiple of three there is no short leading group.
    let head_len = digits.len() % 3;
    let mut groups: Vec<&str> = Vec::with_capacity(digits.len() / 3 + 1);
    if head_len > 0 {
        groups.push(&digits[..head_len]);
    }
    let mut start = head_len;
    while start < digits.len() {
        groups.push(&digits[start..start + 3]);
        start += 3;
    }
    groups.join("_")
}
756
/// Escapes `raw` so it can be placed between double quotes in generated Rust source.
///
/// - Backslashes and double quotes are escaped.
/// - Newline, carriage return, and tab become `\n`, `\r`, `\t`. A bare CR is
///   not allowed inside a Rust string literal, so data arriving from
///   models.dev could otherwise break the generated file.
/// - Any other control character becomes a `\u{..}` escape.
///
/// The string value denoted by the resulting literal is always exactly `raw`.
fn escape_rust_string(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            c if c.is_control() => escaped.extend(c.escape_unicode()),
            c => escaped.push(c),
        }
    }
    escaped
}
760
761fn emit_provider_docs(ctx: &CodegenCtx) -> HashMap<String, String> {
762    let mut docs = HashMap::new();
763
764    for cfg in PROVIDERS {
765        let models = &ctx.provider_models[cfg.dev_id];
766        let mut doc = String::new();
767
768        pushln(&mut doc, format!("`{}` LLM provider.", cfg.display_name));
769        blank(&mut doc);
770
771        // Authentication
772        pushln(&mut doc, "# Authentication");
773        blank(&mut doc);
774        match cfg.env_var {
775            Some(var) => pushln(&mut doc, format!("Set the `{var}` environment variable.")),
776            None if cfg.oauth_provider_id.is_some() => {
777                pushln(&mut doc, "This provider uses OAuth authentication.");
778            }
779            None => {
780                pushln(
781                    &mut doc,
782                    "Uses the default AWS credential chain (environment variables, config files, IAM roles).",
783                );
784            }
785        }
786        blank(&mut doc);
787
788        // Supported models table
789        pushln(&mut doc, "# Supported models");
790        blank(&mut doc);
791        pushln(&mut doc, "| Model ID | Name | Context | Reasoning | Image | Audio |");
792        pushln(&mut doc, "|----------|------|---------|-----------|-------|-------|");
793        for model in models {
794            let ctx_str = format_context_window(model.context_window);
795            let reasoning = if model.reasoning_levels.is_empty() { "" } else { "yes" };
796            let image = if model.input_modalities.contains(&"image".to_string()) { "yes" } else { "" };
797            let audio = if model.input_modalities.contains(&"audio".to_string()) { "yes" } else { "" };
798            pushln(
799                &mut doc,
800                format!(
801                    "| `{}` | `{}` | `{}` | {} | {} | {} |",
802                    model.model_id, model.display_name, ctx_str, reasoning, image, audio
803                ),
804            );
805        }
806
807        docs.insert(cfg.dev_id.to_string(), doc);
808    }
809
810    // Dynamic providers
811    for dyn_cfg in DYNAMIC_PROVIDERS {
812        let mut doc = String::new();
813        pushln(&mut doc, format!("`{}` LLM provider.", dyn_cfg.display_name));
814        blank(&mut doc);
815        pushln(
816            &mut doc,
817            format!("This provider accepts any model name at runtime (e.g. `{}:my-model`).", dyn_cfg.parser_name),
818        );
819        pushln(&mut doc, "No API key is required.");
820        docs.insert(dyn_cfg.parser_name.to_string(), doc);
821    }
822
823    docs
824}
825
826/// Format a token count as human-readable (e.g. `1_000_000` → `1M`, `200_000` → `200k`).
827fn format_context_window(tokens: u32) -> String {
828    if tokens == 0 {
829        return "unknown".to_string();
830    }
831    if tokens >= 1_000_000 && tokens.is_multiple_of(1_000_000) {
832        format!("{}M", tokens / 1_000_000)
833    } else if tokens >= 1_000 && tokens.is_multiple_of(1_000) {
834        format!("{}k", tokens / 1_000)
835    } else {
836        format_number(tokens)
837    }
838}
839
/// Append `line` to `out`, terminated by a single `\n`.
///
/// Accepts anything string-like (`&str`, `String`, ...), so callers can pass
/// literals or `format!` results directly.
///
/// # Panics
///
/// Only if writing to the `String` reports an error, which `String`'s
/// `fmt::Write` implementation never does.
fn pushln(out: &mut String, line: impl AsRef<str>) {
    let text = line.as_ref();
    out.write_str(text)
        .and_then(|()| out.write_char('\n'))
        .expect("writing to String should not fail");
}
843
/// Append an empty line to `out`; equivalent to `pushln(out, "")`.
fn blank(out: &mut String) {
    out.push('\n');
}
847
848// ── Tests ────────────────────────────────────────────────────────────────
849
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use serde_json::json;
    use tempfile::NamedTempFile;

    // ── Fixture helpers ──────────────────────────────────────────────────

    /// Build a models.dev-shaped fixture containing one empty provider entry for
    /// every key referenced by `PROVIDERS` (each provider's `json_key()` plus its
    /// `extra_source_ids`), so codegen finds every provider it looks up.
    fn minimal_models_dev_json() -> Value {
        let mut root = serde_json::Map::new();
        for cfg in PROVIDERS {
            let json_key = cfg.json_key();
            root.entry(json_key.to_string()).or_insert_with(|| {
                json!({
                    "id": json_key,
                    "name": json_key,
                    "env": [],
                    "models": {}
                })
            });
            for &extra in cfg.extra_source_ids {
                root.entry(extra.to_string()).or_insert_with(|| {
                    json!({
                        "id": extra,
                        "name": extra,
                        "env": [],
                        "models": {}
                    })
                });
            }
        }
        Value::Object(root)
    }

    /// Replace the `models` object of `provider` in a fixture built by
    /// `minimal_models_dev_json`.
    ///
    /// Panics when the fixture root is not a JSON object or has no object entry
    /// for `provider`. Either case means the fixture itself is broken.
    fn set_models(data: &mut Value, provider: &str, models: Value) {
        let entry = data
            .as_object_mut()
            .expect("root object")
            .get_mut(provider)
            .and_then(Value::as_object_mut)
            .unwrap_or_else(|| panic!("`{provider}` provider missing from fixture"));
        entry.insert("models".to_string(), models);
    }

    /// Serialize `data` into a fresh temporary file and return its handle.
    ///
    /// Keep the returned handle alive while reading from its path: the file is
    /// removed when the `NamedTempFile` is dropped.
    fn write_fixture(data: &Value) -> NamedTempFile {
        let tmp = NamedTempFile::new().expect("temp file");
        let json = serde_json::to_string(data).expect("serialize fixture");
        std::fs::write(tmp.path(), json).expect("write fixture");
        tmp
    }

    /// Run codegen on `data` and return only the generated Rust source.
    fn generate_from_value(data: &Value) -> String {
        let tmp = write_fixture(data);
        generate(tmp.path()).expect("codegen succeeds").rust_source
    }

    // ── Generated source ─────────────────────────────────────────────────

    #[test]
    fn generate_sorts_and_filters_models() {
        let mut data = minimal_models_dev_json();
        set_models(
            &mut data,
            "anthropic",
            json!({
                "b-model": {
                    "id": "b-model",
                    "name": "B Model",
                    "tool_call": true,
                    "limit": {"context": 2000, "output": 0}
                },
                "a-model": {
                    "id": "a-model",
                    "name": "A Model",
                    "tool_call": true,
                    "limit": {"context": 1000, "output": 0}
                },
                "alpha-latest": {
                    "id": "alpha-latest",
                    "name": "Alias",
                    "tool_call": true,
                    "limit": {"context": 500, "output": 0}
                },
                "no-tools": {
                    "id": "no-tools",
                    "name": "No Tools",
                    "tool_call": false,
                    "limit": {"context": 500, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // Provider-level FromStr: sorted model IDs
        let a_model = "\"a-model\" => Ok(Self::AModel),";
        let b_model = "\"b-model\" => Ok(Self::BModel),";
        let a_pos = source.find(a_model).expect("a-model parse arm");
        let b_pos = source.find(b_model).expect("b-model parse arm");
        assert!(a_pos < b_pos);
        // Aliases and non-tool-call models are excluded
        assert!(!source.contains("AlphaLatest"));
        assert!(!source.contains("NoTools"));
    }

    #[test]
    fn generate_contains_core_sections() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("pub enum LlmModel {"));
        assert!(source.contains("impl std::str::FromStr for LlmModel {"));
        assert!(source.contains("impl std::fmt::Display for LlmModel {"));
        assert!(source.contains("pub fn required_env_var(&self) -> Option<&'static str> {"));
    }

    #[test]
    fn generate_contains_dynamic_provider_arms() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("\"ollama\" => Ok(Self::Ollama(model_str.to_string())),"));
        assert!(source.contains("\"llamacpp\" => Ok(Self::LlamaCpp(model_str.to_string())),"));
        // Dynamic providers are combined with | for None-returning arms
        assert!(source.contains("Self::Ollama(_) | Self::LlamaCpp(_) => None,"));
    }

    #[test]
    fn generate_codex_is_catalog_provider() {
        let source = generate_from_value(&minimal_models_dev_json());
        // Codex is a catalog provider, not dynamic
        assert!(source.contains("pub enum CodexModel {"));
        assert!(source.contains("\"codex\" => model_str.parse::<CodexModel>().map(Self::Codex),"));
        assert!(source.contains("Self::Codex(m) => Some(m.context_window()),"));
    }

    #[test]
    fn generate_oauth_provider_id_for_codex() {
        let source = generate_from_value(&minimal_models_dev_json());
        // Codex models return Some("codex") for oauth_provider_id
        assert!(source.contains("Self::Codex(_) => Some(\"codex\"),"));
        // The LlmModel-level accessor is emitted. Only its presence is checked here;
        // the `None` arms for non-OAuth providers are not asserted.
        assert!(source.contains("pub fn oauth_provider_id(&self) -> Option<&'static str>"));
    }

    #[test]
    fn generate_delegates_to_provider_impls() {
        let source = generate_from_value(&minimal_models_dev_json());
        // LlmModel delegates to per-provider methods
        assert!(source.contains("Self::Anthropic(m) => Cow::Borrowed(m.model_id()),"));
        assert!(source.contains("Self::Anthropic(m) => Some(m.context_window()),"));
        // Provider-level FromStr is used by LlmModel::FromStr
        assert!(source.contains("\"anthropic\" => model_str.parse::<AnthropicModel>().map(Self::Anthropic),"));
    }

    #[test]
    fn generate_formats_large_numbers_with_separators() {
        let mut data = minimal_models_dev_json();
        set_models(
            &mut data,
            "anthropic",
            json!({
                "big-model": {
                    "id": "big-model",
                    "name": "Big Model",
                    "tool_call": true,
                    "limit": {"context": 200_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        assert!(source.contains("200_000"));
        assert!(!source.contains("200000"));
    }

    #[test]
    fn generate_groups_identical_match_arms() {
        let mut data = minimal_models_dev_json();
        set_models(
            &mut data,
            "anthropic",
            json!({
                "model-a": {
                    "id": "model-a",
                    "name": "Same Name",
                    "tool_call": true,
                    "limit": {"context": 100_000, "output": 0}
                },
                "model-b": {
                    "id": "model-b",
                    "name": "Same Name",
                    "tool_call": true,
                    "limit": {"context": 100_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // Both context_window and display_name should combine arms
        assert!(source.contains("Self::ModelA | Self::ModelB => 100_000,"));
        assert!(source.contains("Self::ModelA | Self::ModelB => \"Same Name\","));
    }

    #[test]
    fn test_model_id_to_variant() {
        assert_eq!(model_id_to_variant("claude-sonnet-4-5-20250929"), "ClaudeSonnet4520250929");
        assert_eq!(model_id_to_variant("gemini-2.5-flash"), "Gemini25Flash");
        assert_eq!(model_id_to_variant("deepseek-chat"), "DeepseekChat");
        assert_eq!(model_id_to_variant("glm-4.5"), "Glm45");
    }

    #[test]
    fn test_model_id_to_variant_with_slash_and_colon() {
        assert_eq!(model_id_to_variant("anthropic/claude-opus-4.6"), "AnthropicClaudeOpus46");
        assert_eq!(model_id_to_variant("openai/gpt-5.1-codex-max"), "OpenaiGpt51CodexMax");
        assert_eq!(model_id_to_variant("deepseek/deepseek-r1:free"), "DeepseekDeepseekR1Free");
    }

    #[test]
    fn test_is_alias() {
        assert!(is_alias("claude-sonnet-4-5-latest"));
        assert!(is_alias("claude-3-7-sonnet-latest"));
        assert!(!is_alias("claude-sonnet-4-5-20250929"));
    }

    #[test]
    fn generate_contains_reasoning_levels_and_supports_reasoning() {
        let mut data = minimal_models_dev_json();
        set_models(
            &mut data,
            "anthropic",
            json!({
                "thinker": {
                    "id": "thinker",
                    "name": "Thinker",
                    "tool_call": true,
                    "reasoning": true,
                    "limit": {"context": 1000, "output": 0}
                },
                "fast": {
                    "id": "fast",
                    "name": "Fast",
                    "tool_call": true,
                    "reasoning": false,
                    "limit": {"context": 1000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // Provider enum should have reasoning_levels method
        assert!(source.contains("pub fn reasoning_levels(self) -> &'static [ReasoningEffort] {"));
        // Thinker (anthropic) gets standard 3 levels
        assert!(
            source
                .contains("Self::Thinker => &[ReasoningEffort::Low, ReasoningEffort::Medium, ReasoningEffort::High],")
        );
        // Fast gets empty
        assert!(source.contains("Self::Fast => &[],"));
        // supports_reasoning delegates to reasoning_levels
        assert!(source.contains("pub fn supports_reasoning(self) -> bool {"));
        assert!(source.contains("!self.reasoning_levels().is_empty()"));
        // LlmModel level
        assert!(source.contains("pub fn reasoning_levels(&self) -> &'static [ReasoningEffort] {"));
        // Dynamic providers return empty
        assert!(source.contains("Self::Ollama(_) | Self::LlamaCpp(_) => &[],"));
    }

    #[test]
    fn generate_codex_gets_four_reasoning_levels() {
        let mut data = minimal_models_dev_json();
        set_models(
            &mut data,
            "openai",
            json!({
                "gpt-5.4-codex": {
                    "id": "gpt-5.4-codex",
                    "name": "GPT-5.4 Codex",
                    "tool_call": true,
                    "reasoning": true,
                    "limit": {"context": 200_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        assert!(
            source.contains(
                "ReasoningEffort::Low, ReasoningEffort::Medium, ReasoningEffort::High, ReasoningEffort::Xhigh"
            ),
            "Codex reasoning model should have 4 levels including Xhigh"
        );
    }

    // ── extra_source_ids merging ─────────────────────────────────────────

    #[test]
    fn extra_source_ids_merges_models_into_provider() {
        let mut data = minimal_models_dev_json();

        // Add a model to the zai-coding-plan extra source that doesn't exist in zai
        set_models(
            &mut data,
            "zai-coding-plan",
            json!({
                "extra-model": {
                    "id": "extra-model",
                    "name": "Extra Model",
                    "tool_call": true,
                    "limit": {"context": 4000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        // The extra model should appear under ZAi
        assert!(source.contains("\"extra-model\" => Ok(Self::ExtraModel),"));
    }

    #[test]
    fn extra_source_ids_does_not_duplicate_existing_models() {
        let mut data = minimal_models_dev_json();

        // Add same model ID to both zai and zai-coding-plan
        set_models(
            &mut data,
            "zai",
            json!({
                "shared-model": {
                    "id": "shared-model",
                    "name": "Shared Model",
                    "tool_call": true,
                    "limit": {"context": 1000, "output": 0}
                }
            }),
        );
        set_models(
            &mut data,
            "zai-coding-plan",
            json!({
                "shared-model": {
                    "id": "shared-model",
                    "name": "Shared Model Duplicate",
                    "tool_call": true,
                    "limit": {"context": 2000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        let from_str_matches = source.matches("\"shared-model\" => Ok(Self::SharedModel),").count();
        assert_eq!(from_str_matches, 1);
    }

    // ── Provider docs ────────────────────────────────────────────────────

    #[test]
    fn generate_emits_provider_docs() {
        let mut data = minimal_models_dev_json();
        set_models(
            &mut data,
            "anthropic",
            json!({
                "claude-test": {
                    "id": "claude-test",
                    "name": "Claude Test",
                    "tool_call": true,
                    "reasoning": true,
                    "limit": {"context": 200_000, "output": 0},
                    "modalities": {"input": ["text", "image"]}
                }
            }),
        );

        // Needs the full codegen output (not just `rust_source`), so call
        // `generate` directly; `tmp` must outlive the call.
        let tmp = write_fixture(&data);
        let output = generate(tmp.path()).expect("codegen succeeds");

        let anthropic_doc = &output.provider_docs["anthropic"];
        assert!(anthropic_doc.contains("`Anthropic` LLM provider."));
        assert!(anthropic_doc.contains("`ANTHROPIC_API_KEY`"));
        assert!(anthropic_doc.contains("| `claude-test` | `Claude Test` | `200k` | yes | yes |  |"));

        // Dynamic providers get a short doc
        let ollama_doc = &output.provider_docs["ollama"];
        assert!(ollama_doc.contains("`Ollama` LLM provider."));
        assert!(ollama_doc.contains("any model name at runtime"));
    }

    #[test]
    fn format_context_window_formats_correctly() {
        assert_eq!(format_context_window(1_000_000), "1M");
        assert_eq!(format_context_window(2_000_000), "2M");
        assert_eq!(format_context_window(200_000), "200k");
        assert_eq!(format_context_window(128_000), "128k");
        assert_eq!(format_context_window(8_000), "8k");
        assert_eq!(format_context_window(1_000), "1k");
        assert_eq!(format_context_window(0), "unknown");
    }

    #[test]
    fn level_str_to_variant_covers_all_reasoning_efforts() {
        for effort in utils::ReasoningEffort::all() {
            // Should not panic for any known variant
            let _ = level_str_to_variant(effort.as_str());
        }
    }
}