1#![doc = include_str!("../README.md")]
2
3use serde::Deserialize;
4use std::collections::{BTreeMap, HashMap};
5use std::fmt::Write;
6use std::path::Path;
7
/// Top-level shape of the models.dev JSON document: provider id -> provider data.
type ModelsDevData = HashMap<String, ProviderData>;
/// Hook mapping (model id, default context window) -> effective context window.
type ContextWindowOverride = fn(&str, u32) -> u32;
10
/// One provider entry as deserialized from the models.dev JSON.
#[derive(Debug, Deserialize)]
struct ProviderData {
    // Provider id as declared by models.dev; present in the JSON but unused here.
    #[allow(dead_code)]
    id: String,
    // Human-readable provider name; unused by codegen.
    #[allow(dead_code)]
    name: String,
    // Env vars models.dev associates with the provider; we use our own table instead.
    #[serde(default)]
    #[allow(dead_code)]
    env: Vec<String>,
    // Model id -> model data; the only field codegen actually consumes.
    #[serde(default)]
    models: HashMap<String, ModelData>,
}
23
/// One model entry from models.dev.
#[derive(Debug, Deserialize)]
struct ModelData {
    // Raw model id (e.g. `claude-opus-4-6`).
    id: String,
    // Human-readable display name.
    name: String,
    // Whether the model supports tool calling; only Some(true) models are kept.
    #[serde(default)]
    tool_call: Option<bool>,
    // Whether the model supports reasoning/extended thinking.
    #[serde(default)]
    reasoning: Option<bool>,
    // Pricing info; parsed but not used by codegen.
    #[serde(default)]
    #[allow(dead_code)]
    cost: Option<CostData>,
    // Context/output token limits.
    #[serde(default)]
    limit: Option<LimitData>,
    // Input/output modalities; absence is treated as text-only.
    #[serde(default)]
    modalities: Option<ModalitiesData>,
}
40
/// Supported modalities for a model; only the input side is consumed.
#[derive(Debug, Deserialize, Default)]
struct ModalitiesData {
    // Input modality names, e.g. "text", "image", "audio".
    #[serde(default)]
    input: Vec<String>,
}
46
/// Per-token pricing from models.dev; deserialized for completeness, unused.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct CostData {
    #[serde(default)]
    input: f64,
    #[serde(default)]
    output: f64,
}
55
/// Token limits from models.dev; a missing/zero `context` renders as "unknown".
#[derive(Debug, Deserialize)]
struct LimitData {
    // Context window size in tokens.
    #[serde(default)]
    context: u32,
    // Max output tokens; unused by codegen.
    #[serde(default)]
    #[allow(dead_code)]
    output: u32,
}
64
/// Static configuration describing how one catalog provider is generated.
struct ProviderConfig {
    // Key under which this provider appears in our generated output/docs.
    dev_id: &'static str,
    // models.dev key to read models from, when it differs from `dev_id`
    // (e.g. Codex reads from the "openai" entry).
    source_dev_id: Option<&'static str>,
    // Additional models.dev entries merged into this provider's model list.
    extra_source_ids: &'static [&'static str],
    // Optional predicate restricting which model ids are included.
    model_filter: Option<fn(&str) -> bool>,
    // Optional hook replacing the source context window per model.
    context_window_override: Option<ContextWindowOverride>,
    // Rust enum name prefix (e.g. "Anthropic" -> `AnthropicModel`).
    enum_name: &'static str,
    // Lowercase id used in `provider:model` strings and FromStr.
    parser_name: &'static str,
    // Human-readable provider name for docs and display.
    display_name: &'static str,
    // API-key env var, or None for OAuth/credential-chain providers.
    env_var: Option<&'static str>,
    // OAuth provider id when authentication is OAuth-based.
    oauth_provider_id: Option<&'static str>,
    // Reasoning levels emitted for models that report `reasoning: true`.
    default_reasoning_levels: &'static [&'static str],
}
90
91impl ProviderConfig {
92 const fn standard(
94 dev_id: &'static str,
95 enum_name: &'static str,
96 parser_name: &'static str,
97 display_name: &'static str,
98 env_var: Option<&'static str>,
99 ) -> Self {
100 Self {
101 dev_id,
102 source_dev_id: None,
103 extra_source_ids: &[],
104 model_filter: None,
105 context_window_override: None,
106 enum_name,
107 parser_name,
108 display_name,
109 env_var,
110 oauth_provider_id: None,
111 default_reasoning_levels: &["low", "medium", "high"],
112 }
113 }
114
115 fn json_key(&self) -> &'static str {
117 self.source_dev_id.unwrap_or(self.dev_id)
118 }
119}
120
/// Configuration for a provider whose model names are free-form strings at
/// runtime (no generated enum, no catalog).
#[allow(clippy::struct_field_names)]
struct DynamicProviderConfig {
    // `LlmModel` variant name (holds a `String` payload).
    enum_name: &'static str,
    // Lowercase id used in `provider:model` strings.
    parser_name: &'static str,
    // Human-readable provider name.
    display_name: &'static str,
}
131
/// Static catalog of providers that get dedicated generated enums.
const PROVIDERS: &[ProviderConfig] = &[
    ProviderConfig::standard("anthropic", "Anthropic", "anthropic", "Anthropic", Some("ANTHROPIC_API_KEY")),
    // Codex is a filtered view over the OpenAI catalog (codex/gpt-5 family),
    // authenticated via OAuth rather than an API key, with a subscription-level
    // context window override and an extra "xhigh" reasoning level.
    ProviderConfig {
        dev_id: "codex",
        source_dev_id: Some("openai"),
        extra_source_ids: &[],
        model_filter: Some(|id| id.contains("codex") || id.starts_with("gpt-5.") || id == "gpt-5"),
        context_window_override: Some(codex_subscription_context_window),
        enum_name: "Codex",
        parser_name: "codex",
        display_name: "Codex",
        env_var: None,
        oauth_provider_id: Some("codex"),
        default_reasoning_levels: &["low", "medium", "high", "xhigh"],
    },
    ProviderConfig::standard("deepseek", "DeepSeek", "deepseek", "DeepSeek", Some("DEEPSEEK_API_KEY")),
    ProviderConfig::standard("google", "Gemini", "gemini", "Gemini", Some("GEMINI_API_KEY")),
    ProviderConfig::standard("moonshotai", "Moonshot", "moonshot", "Moonshot", Some("MOONSHOT_API_KEY")),
    ProviderConfig::standard("openai", "Openai", "openai", "OpenAI", Some("OPENAI_API_KEY")),
    ProviderConfig::standard("openrouter", "OpenRouter", "openrouter", "OpenRouter", Some("OPENROUTER_API_KEY")),
    // ZAI additionally merges models from the separate "zai-coding-plan" entry.
    ProviderConfig {
        extra_source_ids: &["zai-coding-plan"],
        ..ProviderConfig::standard("zai", "ZAi", "zai", "ZAI", Some("ZAI_API_KEY"))
    },
    // Bedrock has no API-key env var: it uses the AWS credential chain.
    ProviderConfig::standard("amazon-bedrock", "Bedrock", "bedrock", "AWS Bedrock", None),
];
158
/// Providers that accept any model name at runtime (local runtimes, no catalog).
const DYNAMIC_PROVIDERS: &[DynamicProviderConfig] = &[
    DynamicProviderConfig { enum_name: "Ollama", parser_name: "ollama", display_name: "Ollama" },
    DynamicProviderConfig { enum_name: "LlamaCpp", parser_name: "llamacpp", display_name: "LlamaCpp" },
];
163
/// Context window granted to Codex subscription models, overriding the figure
/// published on models.dev.
const CODEX_SUBSCRIPTION_CONTEXT_WINDOW: u32 = 272_000;

/// Returns the subscription context window for the known Codex subscription
/// models, and `default_context_window` for everything else.
fn codex_subscription_context_window(model_id: &str, default_context_window: u32) -> u32 {
    // Models covered by the Codex subscription plan.
    const SUBSCRIPTION_MODELS: [&str; 6] =
        ["gpt-5.5", "gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex", "gpt-5.2", "codex-auto-review"];

    if SUBSCRIPTION_MODELS.contains(&model_id) {
        CODEX_SUBSCRIPTION_CONTEXT_WINDOW
    } else {
        default_context_window
    }
}
174
/// Normalized model record used by all emitters.
#[derive(Debug, Clone)]
struct ModelInfo {
    // UpperCamelCase Rust enum variant name derived from the model id.
    variant_name: String,
    // Raw model id as it appears in models.dev.
    model_id: String,
    // Human-readable display name.
    display_name: String,
    // Context window in tokens (0 = unknown), after any provider override.
    context_window: u32,
    // Reasoning levels; empty when the model does not support reasoning.
    reasoning_levels: Vec<String>,
    // Input modalities; defaults to ["text"] when the source omits them.
    input_modalities: Vec<String>,
}
184
/// Provider dev_id -> sorted model list. BTreeMap keeps output deterministic.
type ProviderModels = BTreeMap<&'static str, Vec<ModelInfo>>;

/// Shared state passed to every emitter.
struct CodegenCtx {
    provider_models: ProviderModels,
}
190
/// Everything produced by a single [`generate`] run.
pub struct GeneratedOutput {
    /// The complete generated Rust source file.
    pub rust_source: String,
    /// Markdown documentation per provider, keyed by provider id
    /// (dev_id for catalog providers, parser name for dynamic ones).
    pub provider_docs: HashMap<String, String>,
}
201
202pub fn generate(models_json_path: &Path) -> Result<GeneratedOutput, String> {
204 let json_bytes = std::fs::read_to_string(models_json_path).map_err(|e| format!("read: {e}"))?;
205 let data: ModelsDevData = serde_json::from_str(&json_bytes).map_err(|e| format!("parse: {e}"))?;
206
207 let provider_models = build_provider_models(&data)?;
208 let ctx = CodegenCtx { provider_models };
209 Ok(GeneratedOutput { rust_source: emit_generated_source(&ctx), provider_docs: emit_provider_docs(&ctx) })
210}
211
212fn build_provider_models(data: &ModelsDevData) -> Result<ProviderModels, String> {
213 let mut provider_models = ProviderModels::new();
214
215 for cfg in PROVIDERS {
216 let json_key = cfg.json_key();
217 let provider_data =
218 data.get(json_key).ok_or_else(|| format!("Provider '{json_key}' not found in models.dev data"))?;
219
220 let mut models: Vec<ModelInfo> = collect_models_from(cfg, &provider_data.models);
221
222 for &extra_key in cfg.extra_source_ids {
223 if let Some(extra_data) = data.get(extra_key) {
224 let extra = collect_models_from(cfg, &extra_data.models);
225 let existing_ids: std::collections::HashSet<String> =
226 models.iter().map(|m| m.model_id.clone()).collect();
227 models.extend(extra.into_iter().filter(|m| !existing_ids.contains(&m.model_id)));
228 }
229 }
230
231 models.sort_by(|a, b| a.model_id.cmp(&b.model_id));
232 provider_models.insert(cfg.dev_id, models);
233 }
234
235 Ok(provider_models)
236}
237
238fn collect_models_from(cfg: &ProviderConfig, models: &HashMap<String, ModelData>) -> Vec<ModelInfo> {
239 models
240 .values()
241 .filter(|m| m.tool_call == Some(true))
242 .filter(|m| !is_alias(&m.id))
243 .filter(|m| cfg.model_filter.is_none_or(|f| f(&m.id)))
244 .map(|m| {
245 let reasoning_levels = if m.reasoning.unwrap_or(false) {
246 cfg.default_reasoning_levels.iter().map(|s| (*s).to_string()).collect()
247 } else {
248 Vec::new()
249 };
250 let input_modalities =
251 m.modalities.as_ref().map_or_else(|| vec!["text".to_string()], |md| md.input.clone());
252 let source_context_window = m.limit.as_ref().map_or(0, |l| l.context);
253 let context_window = cfg.context_window_override.map_or(source_context_window, |override_context_window| {
254 override_context_window(&m.id, source_context_window)
255 });
256 ModelInfo {
257 variant_name: model_id_to_variant(&m.id),
258 model_id: m.id.clone(),
259 display_name: m.name.clone(),
260 context_window,
261 reasoning_levels,
262 input_modalities,
263 }
264 })
265 .collect()
266}
267
/// True when `id` is a rolling `-latest` alias rather than a pinned model id.
fn is_alias(id: &str) -> bool {
    id.strip_suffix("-latest").is_some()
}
272
/// Converts a model id into an UpperCamelCase Rust enum variant name.
///
/// `-`, `.`, `/` and `:` act as word separators (and are dropped), the
/// character following a separator is uppercased, and a leading `_` is
/// prepended when the result would start with a digit, since Rust
/// identifiers cannot begin with one.
fn model_id_to_variant(id: &str) -> String {
    const SEPARATORS: [char; 4] = ['-', '.', '/', ':'];

    let mut variant = String::new();
    let mut at_word_start = true;
    for ch in id.chars() {
        if SEPARATORS.contains(&ch) {
            at_word_start = true;
            continue;
        }
        if at_word_start {
            variant.push(ch.to_ascii_uppercase());
            at_word_start = false;
        } else {
            variant.push(ch);
        }
    }

    if variant.starts_with(|c: char| c.is_ascii_digit()) {
        variant.insert(0, '_');
    }

    variant
}
297
/// Assembles the full generated Rust source; section order is significant
/// since it defines the final file layout.
fn emit_generated_source(ctx: &CodegenCtx) -> String {
    // 64 KB preallocation roughly matches the expected output size.
    let mut out = String::with_capacity(64_000);
    emit_header(&mut out);
    emit_provider_enums(&mut out, &ctx.provider_models);
    emit_provider_impls(&mut out, &ctx.provider_models);
    emit_llm_model_enum(&mut out);
    emit_from_impls(&mut out);
    emit_llm_model_impl(&mut out);
    emit_display_impl(&mut out);
    emit_fromstr_impl(&mut out);
    out
}
310
/// Emits the file banner and the imports the generated code relies on.
fn emit_header(out: &mut String) {
    pushln(out, "// Auto-generated from models.dev — do not edit manually");
    pushln(out, "// Regenerated automatically by build.rs");
    blank(out);
    pushln(out, "use std::borrow::Cow;");
    pushln(out, "use std::sync::LazyLock;");
    pushln(out, "use crate::ReasoningEffort;");
    blank(out);
}
320
/// Emits one `pub enum {Name}Model { … }` per catalog provider, with a
/// variant for every collected model.
fn emit_provider_enums(out: &mut String, provider_models: &ProviderModels) {
    for cfg in PROVIDERS {
        pushln(out, "#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]");
        pushln(out, format!("pub enum {}Model {{", cfg.enum_name));
        for model in &provider_models[cfg.dev_id] {
            pushln(out, format!("    {},", model.variant_name));
        }
        pushln(out, "}");
        blank(out);
    }
}
332
/// Emits the inherent impl for each provider enum (id/name/context-window
/// accessors, reasoning and modality queries, and the `ALL` constant),
/// followed by its `FromStr` impl.
fn emit_provider_impls(out: &mut String, provider_models: &ProviderModels) {
    for cfg in PROVIDERS {
        let models = &provider_models[cfg.dev_id];
        let enum_name = format!("{}Model", cfg.enum_name);

        pushln(out, format!("impl {enum_name} {{"));

        // model_id(): one arm per variant (ids are unique, no grouping possible).
        pushln(out, "    #[allow(clippy::too_many_lines)]");
        pushln(out, "    fn model_id(self) -> &'static str {");
        pushln(out, "        match self {");
        for model in models {
            pushln(out, format!("            Self::{} => \"{}\",", model.variant_name, model.model_id));
        }
        pushln(out, "        }");
        pushln(out, "    }");
        blank(out);

        // display_name(): grouped, since several models may share a name.
        pushln(out, "    #[allow(clippy::too_many_lines)]");
        pushln(out, "    fn display_name(self) -> &'static str {");
        pushln(out, "        match self {");
        emit_grouped_arms(out, models, |m| escape_rust_string(&m.display_name), |name| format!("\"{name}\""));
        pushln(out, "        }");
        pushln(out, "    }");
        blank(out);

        // context_window(): grouped; values re-rendered with `_` separators.
        pushln(out, "    fn context_window(self) -> u32 {");
        pushln(out, "        match self {");
        emit_grouped_arms(
            out,
            models,
            |m| m.context_window.to_string(),
            |val| format_number(val.parse::<u32>().unwrap()),
        );
        pushln(out, "        }");
        pushln(out, "    }");
        blank(out);

        // reasoning_levels(): grouped on the joined level list.
        pushln(out, "    pub fn reasoning_levels(self) -> &'static [ReasoningEffort] {");
        pushln(out, "        match self {");
        emit_grouped_arms(out, models, |m| m.reasoning_levels.join(","), format_reasoning_levels_rhs);
        pushln(out, "        }");
        pushln(out, "    }");
        blank(out);

        pushln(out, "    pub fn supports_reasoning(self) -> bool {");
        pushln(out, "        !self.reasoning_levels().is_empty()");
        pushln(out, "    }");
        blank(out);

        // supports_image()/supports_audio(): boolean per input modality.
        for modality in ["image", "audio"] {
            pushln(out, format!("    pub fn supports_{modality}(self) -> bool {{"));
            pushln(out, "        match self {");
            // Owned copy so the key closure can borrow it.
            let mod_owned = modality.to_string();
            emit_grouped_arms(
                out,
                models,
                |m| m.input_modalities.contains(&mod_owned).to_string(),
                std::string::ToString::to_string,
            );
            pushln(out, "        }");
            pushln(out, "    }");
            blank(out);
        }

        // ALL: every variant, in sorted model-id order.
        pushln(out, format!("    const ALL: &[{enum_name}] = &["));
        for model in models {
            pushln(out, format!("        Self::{},", model.variant_name));
        }
        pushln(out, "    ];");

        pushln(out, "}");
        blank(out);

        emit_from_str_impl(out, &enum_name, cfg.parser_name, models);
    }
}
415
/// Emits a `FromStr` impl mapping raw model-id strings back to enum variants,
/// with a descriptive error for unknown ids.
fn emit_from_str_impl(out: &mut String, enum_name: &str, parser_name: &str, models: &[ModelInfo]) {
    pushln(out, format!("impl std::str::FromStr for {enum_name} {{"));
    pushln(out, "    type Err = String;");
    blank(out);
    pushln(out, "    #[allow(clippy::too_many_lines)]");
    pushln(out, "    fn from_str(s: &str) -> Result<Self, Self::Err> {");
    pushln(out, "        match s {");
    for model in models {
        pushln(out, format!("            \"{}\" => Ok(Self::{}),", model.model_id, model.variant_name));
    }
    // `{{s}}` escapes to `{s}` so interpolation happens in the generated code.
    pushln(out, format!("            _ => Err(format!(\"Unknown {parser_name} model: '{{s}}'\")),"));
    pushln(out, "        }");
    pushln(out, "    }");
    pushln(out, "}");
    blank(out);
}
432
433fn emit_grouped_arms(
438 out: &mut String,
439 models: &[ModelInfo],
440 key_fn: impl Fn(&ModelInfo) -> String,
441 fmt_val: impl Fn(&str) -> String,
442) {
443 let mut groups: BTreeMap<String, Vec<&str>> = BTreeMap::new();
445 for model in models {
446 groups.entry(key_fn(model)).or_default().push(&model.variant_name);
447 }
448
449 for (key, variants) in &groups {
450 let rhs = fmt_val(key);
451 if variants.len() == 1 {
452 pushln(out, format!(" Self::{} => {rhs},", variants[0]));
453 } else {
454 let patterns: Vec<String> = variants.iter().map(|v| format!("Self::{v}")).collect();
455 pushln(out, format!(" {} => {rhs},", patterns.join(" | ")));
456 }
457 }
458}
459
460fn format_reasoning_levels_rhs(key: &str) -> String {
462 if key.is_empty() {
463 return "&[]".to_string();
464 }
465 let items: Vec<String> = key
466 .split(',')
467 .map(|l| {
468 let variant = level_str_to_variant(l);
469 format!("ReasoningEffort::{variant}")
470 })
471 .collect();
472 format!("&[{}]", items.join(", "))
473}
474
/// Maps a models.dev reasoning level string to its `ReasoningEffort` variant
/// name.
///
/// # Panics
/// Panics on an unknown level: codegen should fail loudly rather than emit a
/// nonexistent variant.
fn level_str_to_variant(level: &str) -> &'static str {
    const LEVELS: [(&str, &str); 4] =
        [("low", "Low"), ("medium", "Medium"), ("high", "High"), ("xhigh", "Xhigh")];

    LEVELS
        .iter()
        .find(|(name, _)| *name == level)
        .map(|(_, variant)| *variant)
        .unwrap_or_else(|| panic!("Unknown reasoning level: {level}"))
}
486
/// Emits the top-level `LlmModel` enum: one wrapping variant per catalog
/// provider enum plus a `String`-payload variant per dynamic provider.
fn emit_llm_model_enum(out: &mut String) {
    pushln(out, "/// A model from a specific provider");
    pushln(out, "#[derive(Debug, Clone, PartialEq, Eq, Hash)]");
    pushln(out, "pub enum LlmModel {");
    for cfg in PROVIDERS {
        pushln(out, format!("    {provider}({provider}Model),", provider = cfg.enum_name));
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(out, format!("    {}(String),", dyn_cfg.enum_name));
    }
    pushln(out, "}");
    blank(out);
}
500
/// Emits `From<{Provider}Model> for LlmModel` conversions for catalog providers.
fn emit_from_impls(out: &mut String) {
    for cfg in PROVIDERS {
        pushln(out, format!("impl From<{}Model> for LlmModel {{", cfg.enum_name));
        pushln(out, format!("    fn from(m: {}Model) -> Self {{ LlmModel::{}(m) }}", cfg.enum_name, cfg.enum_name));
        pushln(out, "}");
        blank(out);
    }
}
509
/// Emits the `impl LlmModel` block; each helper below contributes one method
/// (or constant), so call order here determines method order in the output.
fn emit_llm_model_impl(out: &mut String) {
    pushln(out, "impl LlmModel {");
    emit_llm_model_id(out);
    emit_llm_display_name(out);
    emit_llm_provider(out);
    emit_llm_provider_display_name(out);
    emit_llm_context_window(out);
    emit_llm_required_env_var(out);
    emit_llm_all_required_env_vars(out);
    emit_llm_oauth_provider_id(out);
    emit_llm_reasoning_levels(out);
    emit_llm_supports_reasoning(out);
    for modality in ["image", "audio"] {
        emit_llm_supports_modality(out, modality);
    }
    emit_llm_all(out);
    pushln(out, "}");
    blank(out);
}
529
/// Emits `LlmModel::model_id`: catalog variants borrow the static id,
/// dynamic variants clone their owned string into a `Cow::Owned`.
fn emit_llm_model_id(out: &mut String) {
    pushln(out, "    /// Raw model ID (e.g. `claude-opus-4-6`, `llama3.2`)");
    pushln(out, "    pub fn model_id(&self) -> Cow<'static, str> {");
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(m) => Cow::Borrowed(m.model_id()),", cfg.enum_name));
    }
    // Binds the dynamic payload as `s` in a single or-pattern arm.
    pushln(out, format!("            {} => Cow::Owned(s.clone()),", dynamic_pattern_with_binding("s")));
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
542
/// Emits `LlmModel::display_name`; dynamic providers prepend their enum name
/// to the runtime model string (e.g. "Ollama llama3.2").
fn emit_llm_display_name(out: &mut String) {
    pushln(out, "    /// Human-readable display name (e.g. `Claude Opus 4.6`)");
    pushln(out, "    pub fn display_name(&self) -> Cow<'static, str> {");
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(m) => Cow::Borrowed(m.display_name()),", cfg.enum_name));
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(
            out,
            format!(
                "            Self::{}(s) => Cow::Owned(format!(\"{} {{s}}\")),",
                dyn_cfg.enum_name, dyn_cfg.enum_name
            ),
        );
    }
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
563
/// Emits `LlmModel::provider`, returning the lowercase parser name.
fn emit_llm_provider(out: &mut String) {
    pushln(out, "    /// Provider identifier (e.g. `anthropic`)");
    pushln(out, "    pub fn provider(&self) -> &'static str {");
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(_) => \"{}\",", cfg.enum_name, cfg.parser_name));
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(out, format!("            Self::{}(_) => \"{}\",", dyn_cfg.enum_name, dyn_cfg.parser_name));
    }
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
578
/// Emits `LlmModel::provider_display_name`, returning the human-readable name.
fn emit_llm_provider_display_name(out: &mut String) {
    pushln(out, "    /// Human-readable provider name (e.g. `AWS Bedrock`)");
    pushln(out, "    pub fn provider_display_name(&self) -> &'static str {");
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(_) => \"{}\",", cfg.enum_name, cfg.display_name));
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(out, format!("            Self::{}(_) => \"{}\",", dyn_cfg.enum_name, dyn_cfg.display_name));
    }
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
593
/// Emits `LlmModel::context_window`; dynamic providers have no catalog data
/// and therefore return `None`.
fn emit_llm_context_window(out: &mut String) {
    pushln(out, "    /// Context window size in tokens (None for dynamic providers)");
    pushln(out, "    pub fn context_window(&self) -> Option<u32> {");
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(m) => Some(m.context_window()),", cfg.enum_name));
    }
    pushln(out, format!("            {} => None,", dynamic_pattern_with_binding("_")));
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
606
/// Emits `LlmModel::required_env_var`; providers without an env var (OAuth,
/// credential-chain, dynamic) are folded into a single `=> None` arm.
fn emit_llm_required_env_var(out: &mut String) {
    pushln(out, "    /// Required env var for this model's provider (None for local providers)");
    pushln(out, "    pub fn required_env_var(&self) -> Option<&'static str> {");
    pushln(out, "        match self {");
    let mut none_arms: Vec<String> = Vec::new();
    for cfg in PROVIDERS {
        match cfg.env_var {
            Some(var) => pushln(out, format!("            Self::{}(_) => Some(\"{}\"),", cfg.enum_name, var)),
            None => none_arms.push(format!("Self::{}(_)", cfg.enum_name)),
        }
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        none_arms.push(format!("Self::{}(_)", dyn_cfg.enum_name));
    }
    pushln(out, format!("            {} => None,", none_arms.join(" | ")));
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
626
627fn emit_llm_all_required_env_vars(out: &mut String) {
628 let vars: Vec<&str> = PROVIDERS.iter().filter_map(|cfg| cfg.env_var).collect();
629 pushln(out, " /// All provider API key env var names (deduplicated, static)");
630 pushln(
631 out,
632 format!(
633 " pub const ALL_REQUIRED_ENV_VARS: &[&str] = &[{}];",
634 vars.iter().map(|v| format!("\"{v}\"")).collect::<Vec<_>>().join(", ")
635 ),
636 );
637 blank(out);
638}
639
/// Emits `LlmModel::oauth_provider_id`; non-OAuth and dynamic providers are
/// folded into a single `=> None` arm.
fn emit_llm_oauth_provider_id(out: &mut String) {
    pushln(out, "    /// OAuth provider ID if this model requires OAuth login (e.g. `\"codex\"`)");
    pushln(out, "    pub fn oauth_provider_id(&self) -> Option<&'static str> {");
    pushln(out, "        match self {");
    let mut none_arms: Vec<String> = Vec::new();
    for cfg in PROVIDERS {
        match cfg.oauth_provider_id {
            Some(id) => pushln(out, format!("            Self::{}(_) => Some(\"{}\"),", cfg.enum_name, id)),
            None => none_arms.push(format!("Self::{}(_)", cfg.enum_name)),
        }
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        none_arms.push(format!("Self::{}(_)", dyn_cfg.enum_name));
    }
    pushln(out, format!("            {} => None,", none_arms.join(" | ")));
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
659
/// Emits `LlmModel::reasoning_levels`, delegating to the per-provider enums;
/// dynamic providers report no levels.
fn emit_llm_reasoning_levels(out: &mut String) {
    pushln(out, "    /// Reasoning levels supported by this model (empty if not a reasoning model)");
    pushln(out, "    pub fn reasoning_levels(&self) -> &'static [ReasoningEffort] {");
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(m) => m.reasoning_levels(),", cfg.enum_name));
    }
    pushln(out, format!("            {} => &[],", dynamic_pattern_with_binding("_")));
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
672
/// Emits `LlmModel::supports_reasoning`, defined via `reasoning_levels()`.
fn emit_llm_supports_reasoning(out: &mut String) {
    pushln(out, "    /// Whether this model supports reasoning/extended thinking");
    pushln(out, "    pub fn supports_reasoning(&self) -> bool {");
    pushln(out, "        !self.reasoning_levels().is_empty()");
    pushln(out, "    }");
    blank(out);
}
680
/// Emits `LlmModel::supports_{modality}` (called for "image" and "audio");
/// dynamic providers conservatively report `false`.
fn emit_llm_supports_modality(out: &mut String, modality: &str) {
    pushln(out, format!("    /// Whether this model supports {modality} input"));
    pushln(out, format!("    pub fn supports_{modality}(&self) -> bool {{"));
    pushln(out, "        match self {");
    for cfg in PROVIDERS {
        pushln(out, format!("            Self::{}(m) => m.supports_{modality}(),", cfg.enum_name));
    }
    pushln(out, format!("            {} => false,", dynamic_pattern_with_binding("_")));
    pushln(out, "        }");
    pushln(out, "    }");
    blank(out);
}
693
/// Emits `LlmModel::all`, a lazily-built static list concatenating every
/// provider's `ALL` constant; dynamic providers are necessarily excluded.
fn emit_llm_all(out: &mut String) {
    pushln(out, "    /// All catalog models (excludes dynamic providers)");
    pushln(out, "    pub fn all() -> &'static [LlmModel] {");
    pushln(out, "        static ALL: LazyLock<Vec<LlmModel>> = LazyLock::new(|| {");
    pushln(out, "            let mut v = Vec::new();");
    for cfg in PROVIDERS {
        pushln(
            out,
            format!(
                "            v.extend({}Model::ALL.iter().copied().map(LlmModel::{}));",
                cfg.enum_name, cfg.enum_name
            ),
        );
    }
    pushln(out, "            v");
    pushln(out, "        });");
    pushln(out, "        &ALL");
    pushln(out, "    }");
}
713
/// Emits `Display` for `LlmModel` in `provider:model` form, the inverse of
/// the generated `FromStr`.
fn emit_display_impl(out: &mut String) {
    pushln(out, "impl std::fmt::Display for LlmModel {");
    pushln(out, "    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {");
    pushln(out, "        write!(f, \"{}:{}\", self.provider(), self.model_id())");
    pushln(out, "    }");
    pushln(out, "}");
    blank(out);
}
722
/// Emits `FromStr` for `LlmModel`, parsing `provider:model` strings.
/// Catalog providers delegate to their enum's `FromStr`; dynamic providers
/// accept any model string verbatim.
fn emit_fromstr_impl(out: &mut String) {
    pushln(out, "impl std::str::FromStr for LlmModel {");
    pushln(out, "    type Err = String;");
    blank(out);
    pushln(out, "    /// Parse a `provider:model` string into an `LlmModel`");
    pushln(out, "    fn from_str(s: &str) -> Result<Self, Self::Err> {");
    // Without a ':' the whole input is treated as the provider, model is "".
    pushln(out, "        let (provider_str, model_str) = s.split_once(':').unwrap_or((s, \"\"));");
    pushln(out, "        match provider_str {");
    for cfg in PROVIDERS {
        pushln(
            out,
            format!(
                "            \"{}\" => model_str.parse::<{}Model>().map(Self::{}),",
                cfg.parser_name, cfg.enum_name, cfg.enum_name
            ),
        );
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(
            out,
            format!(
                "            \"{}\" => Ok(Self::{}(model_str.to_string())),",
                dyn_cfg.parser_name, dyn_cfg.enum_name
            ),
        );
    }
    // Plain string literal: `{provider_str}` interpolates in the generated code.
    pushln(out, "            _ => Err(format!(\"Unknown provider: '{provider_str}'\")),");
    pushln(out, "        }");
    pushln(out, "    }");
    pushln(out, "}");
}
754
755fn dynamic_pattern_with_binding(binding: &str) -> String {
758 DYNAMIC_PROVIDERS.iter().map(|d| format!("Self::{}({binding})", d.enum_name)).collect::<Vec<_>>().join(" | ")
759}
760
/// Formats `n` with `_` thousands separators, Rust-literal style; numbers of
/// four digits or fewer are returned unchanged (matching common convention).
fn format_number(n: u32) -> String {
    let digits = n.to_string();
    if digits.len() <= 4 {
        return digits;
    }

    let mut formatted = String::with_capacity(digits.len() + digits.len() / 3);
    for (i, ch) in digits.char_indices() {
        // Insert a separator before every group of three trailing digits.
        if i > 0 && (digits.len() - i) % 3 == 0 {
            formatted.push('_');
        }
        formatted.push(ch);
    }
    formatted
}
776
/// Escapes backslashes and double quotes so `raw` can be embedded inside a
/// generated double-quoted Rust string literal.
fn escape_rust_string(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        if ch == '\\' || ch == '"' {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
780
/// Builds one markdown documentation snippet per provider.
///
/// Catalog providers are keyed by `dev_id` and get an authentication section
/// plus a model table; dynamic providers are keyed by parser name and get a
/// short free-form note.
fn emit_provider_docs(ctx: &CodegenCtx) -> HashMap<String, String> {
    let mut docs = HashMap::new();

    for cfg in PROVIDERS {
        let models = &ctx.provider_models[cfg.dev_id];
        let mut doc = String::new();

        pushln(&mut doc, format!("`{}` LLM provider.", cfg.display_name));
        blank(&mut doc);

        pushln(&mut doc, "# Authentication");
        blank(&mut doc);
        match cfg.env_var {
            Some(var) => pushln(&mut doc, format!("Set the `{var}` environment variable.")),
            None if cfg.oauth_provider_id.is_some() => {
                pushln(&mut doc, "This provider uses OAuth authentication.");
            }
            None => {
                // NOTE(review): this fallback text is AWS-specific; it currently
                // only applies to Bedrock, the sole env-less non-OAuth provider —
                // revisit if another such provider is added.
                pushln(
                    &mut doc,
                    "Uses the default AWS credential chain (environment variables, config files, IAM roles).",
                );
            }
        }
        blank(&mut doc);

        pushln(&mut doc, "# Supported models");
        blank(&mut doc);
        pushln(&mut doc, "| Model ID | Name | Context | Reasoning | Image | Audio |");
        pushln(&mut doc, "|----------|------|---------|-----------|-------|-------|");
        for model in models {
            let ctx_str = format_context_window(model.context_window);
            let reasoning = if model.reasoning_levels.is_empty() { "" } else { "yes" };
            let image = if model.input_modalities.contains(&"image".to_string()) { "yes" } else { "" };
            let audio = if model.input_modalities.contains(&"audio".to_string()) { "yes" } else { "" };
            pushln(
                &mut doc,
                format!(
                    "| `{}` | `{}` | `{}` | {} | {} | {} |",
                    model.model_id, model.display_name, ctx_str, reasoning, image, audio
                ),
            );
        }

        docs.insert(cfg.dev_id.to_string(), doc);
    }

    for dyn_cfg in DYNAMIC_PROVIDERS {
        let mut doc = String::new();
        pushln(&mut doc, format!("`{}` LLM provider.", dyn_cfg.display_name));
        blank(&mut doc);
        pushln(
            &mut doc,
            format!("This provider accepts any model name at runtime (e.g. `{}:my-model`).", dyn_cfg.parser_name),
        );
        pushln(&mut doc, "No API key is required.");
        docs.insert(dyn_cfg.parser_name.to_string(), doc);
    }

    docs
}
845
846fn format_context_window(tokens: u32) -> String {
848 if tokens == 0 {
849 return "unknown".to_string();
850 }
851 if tokens >= 1_000_000 && tokens.is_multiple_of(1_000_000) {
852 format!("{}M", tokens / 1_000_000)
853 } else if tokens >= 1_000 && tokens.is_multiple_of(1_000) {
854 format!("{}k", tokens / 1_000)
855 } else {
856 format_number(tokens)
857 }
858}
859
/// Appends `line` followed by a newline to `out`.
fn pushln(out: &mut String, line: impl AsRef<str>) {
    out.push_str(line.as_ref());
    out.push('\n');
}
863
/// Appends an empty line to `out`.
fn blank(out: &mut String) {
    out.push('\n');
}
867
868#[cfg(test)]
871mod tests {
872 use super::*;
873 use serde_json::Value;
874 use serde_json::json;
875 use tempfile::NamedTempFile;
876
    // Parse arms must be sorted by model id; `-latest` aliases and models
    // without tool-call support must be filtered out entirely.
    #[test]
    fn generate_sorts_and_filters_models() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "b-model": {
                    "id": "b-model",
                    "name": "B Model",
                    "tool_call": true,
                    "limit": {"context": 2000, "output": 0}
                },
                "a-model": {
                    "id": "a-model",
                    "name": "A Model",
                    "tool_call": true,
                    "limit": {"context": 1000, "output": 0}
                },
                "alpha-latest": {
                    "id": "alpha-latest",
                    "name": "Alias",
                    "tool_call": true,
                    "limit": {"context": 500, "output": 0}
                },
                "no-tools": {
                    "id": "no-tools",
                    "name": "No Tools",
                    "tool_call": false,
                    "limit": {"context": 500, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        let a_model = "\"a-model\" => Ok(Self::AModel),";
        let b_model = "\"b-model\" => Ok(Self::BModel),";
        let a_pos = source.find(a_model).expect("a-model parse arm");
        let b_pos = source.find(b_model).expect("b-model parse arm");
        assert!(a_pos < b_pos);
        assert!(!source.contains("AlphaLatest"));
        assert!(!source.contains("NoTools"));
    }
924
    // Smoke test: all major generated sections are present.
    #[test]
    fn generate_contains_core_sections() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("pub enum LlmModel {"));
        assert!(source.contains("impl std::str::FromStr for LlmModel {"));
        assert!(source.contains("impl std::fmt::Display for LlmModel {"));
        assert!(source.contains("pub fn required_env_var(&self) -> Option<&'static str> {"));
    }
933
    // Dynamic providers must parse any model string and fold into shared None arms.
    #[test]
    fn generate_contains_dynamic_provider_arms() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("\"ollama\" => Ok(Self::Ollama(model_str.to_string())),"));
        assert!(source.contains("\"llamacpp\" => Ok(Self::LlamaCpp(model_str.to_string())),"));
        assert!(source.contains("Self::Ollama(_) | Self::LlamaCpp(_) => None,"));
    }
942
    // Codex is a full catalog provider (its own enum), not a dynamic one.
    #[test]
    fn generate_codex_is_catalog_provider() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("pub enum CodexModel {"));
        assert!(source.contains("\"codex\" => model_str.parse::<CodexModel>().map(Self::Codex),"));
        assert!(source.contains("Self::Codex(m) => Some(m.context_window()),"));
    }
951
    // Codex is OAuth-authenticated and must report its OAuth provider id.
    #[test]
    fn generate_oauth_provider_id_for_codex() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("Self::Codex(_) => Some(\"codex\"),"));
        assert!(source.contains("pub fn oauth_provider_id(&self) -> Option<&'static str>"));
    }
960
    // LlmModel methods must delegate to the per-provider enum impls.
    #[test]
    fn generate_delegates_to_provider_impls() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("Self::Anthropic(m) => Cow::Borrowed(m.model_id()),"));
        assert!(source.contains("Self::Anthropic(m) => Some(m.context_window()),"));
        assert!(source.contains("\"anthropic\" => model_str.parse::<AnthropicModel>().map(Self::Anthropic),"));
    }
970
    // Context windows over four digits must be emitted with `_` separators.
    #[test]
    fn generate_formats_large_numbers_with_separators() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "big-model": {
                    "id": "big-model",
                    "name": "Big Model",
                    "tool_call": true,
                    "limit": {"context": 200_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        assert!(source.contains("200_000"));
        assert!(!source.contains("200000"));
    }
993
    // Variants with identical RHS values must share one `A | B => rhs` arm.
    #[test]
    fn generate_groups_identical_match_arms() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "model-a": {
                    "id": "model-a",
                    "name": "Same Name",
                    "tool_call": true,
                    "limit": {"context": 100_000, "output": 0}
                },
                "model-b": {
                    "id": "model-b",
                    "name": "Same Name",
                    "tool_call": true,
                    "limit": {"context": 100_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        assert!(source.contains("Self::ModelA | Self::ModelB => 100_000,"));
        assert!(source.contains("Self::ModelA | Self::ModelB => \"Same Name\","));
    }
1023
    // `-` and `.` act as word separators and are dropped.
    #[test]
    fn test_model_id_to_variant() {
        assert_eq!(model_id_to_variant("claude-sonnet-4-5-20250929"), "ClaudeSonnet4520250929");
        assert_eq!(model_id_to_variant("gemini-2.5-flash"), "Gemini25Flash");
        assert_eq!(model_id_to_variant("deepseek-chat"), "DeepseekChat");
        assert_eq!(model_id_to_variant("glm-4.5"), "Glm45");
    }
1031
    // `/` and `:` (OpenRouter-style ids) are also separators.
    #[test]
    fn test_model_id_to_variant_with_slash_and_colon() {
        assert_eq!(model_id_to_variant("anthropic/claude-opus-4.6"), "AnthropicClaudeOpus46");
        assert_eq!(model_id_to_variant("openai/gpt-5.1-codex-max"), "OpenaiGpt51CodexMax");
        assert_eq!(model_id_to_variant("deepseek/deepseek-r1:free"), "DeepseekDeepseekR1Free");
    }
1038
    // Only a literal `-latest` suffix marks an alias.
    #[test]
    fn test_is_alias() {
        assert!(is_alias("claude-sonnet-4-5-latest"));
        assert!(is_alias("claude-3-7-sonnet-latest"));
        assert!(!is_alias("claude-sonnet-4-5-20250929"));
    }
1045
1046 #[test]
1047 fn generate_contains_reasoning_levels_and_supports_reasoning() {
1048 let mut data = minimal_models_dev_json();
1049 let root = data.as_object_mut().expect("root object");
1050 let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");
1051
1052 anthropic.insert(
1053 "models".to_string(),
1054 json!({
1055 "thinker": {
1056 "id": "thinker",
1057 "name": "Thinker",
1058 "tool_call": true,
1059 "reasoning": true,
1060 "limit": {"context": 1000, "output": 0}
1061 },
1062 "fast": {
1063 "id": "fast",
1064 "name": "Fast",
1065 "tool_call": true,
1066 "reasoning": false,
1067 "limit": {"context": 1000, "output": 0}
1068 }
1069 }),
1070 );
1071
1072 let source = generate_from_value(&data);
1073 assert!(source.contains("pub fn reasoning_levels(self) -> &'static [ReasoningEffort] {"));
1075 assert!(
1077 source
1078 .contains("Self::Thinker => &[ReasoningEffort::Low, ReasoningEffort::Medium, ReasoningEffort::High],")
1079 );
1080 assert!(source.contains("Self::Fast => &[],"));
1082 assert!(source.contains("pub fn supports_reasoning(self) -> bool {"));
1084 assert!(source.contains("!self.reasoning_levels().is_empty()"));
1085 assert!(source.contains("pub fn reasoning_levels(&self) -> &'static [ReasoningEffort] {"));
1087 assert!(source.contains("Self::Ollama(_) | Self::LlamaCpp(_) => &[],"));
1089 }
1090
1091 #[test]
1092 fn generate_codex_gets_four_reasoning_levels() {
1093 let mut data = minimal_models_dev_json();
1094 let root = data.as_object_mut().expect("root object");
1095 let openai = root.get_mut("openai").and_then(Value::as_object_mut).expect("openai provider");
1096
1097 openai.insert(
1098 "models".to_string(),
1099 json!({
1100 "gpt-5.4-codex": {
1101 "id": "gpt-5.4-codex",
1102 "name": "GPT-5.4 Codex",
1103 "tool_call": true,
1104 "reasoning": true,
1105 "limit": {"context": 200_000, "output": 0}
1106 }
1107 }),
1108 );
1109
1110 let source = generate_from_value(&data);
1111 assert!(
1112 source.contains(
1113 "ReasoningEffort::Low, ReasoningEffort::Medium, ReasoningEffort::High, ReasoningEffort::Xhigh"
1114 ),
1115 "Codex reasoning model should have 4 levels including Xhigh"
1116 );
1117 }
1118
1119 #[test]
1120 fn generate_codex_overrides_gpt55_subscription_context_window() {
1121 let mut data = minimal_models_dev_json();
1122 let root = data.as_object_mut().expect("root object");
1123 let openai = root.get_mut("openai").and_then(Value::as_object_mut).expect("openai provider");
1124
1125 openai.insert(
1126 "models".to_string(),
1127 json!({
1128 "gpt-5.5": {
1129 "id": "gpt-5.5",
1130 "name": "GPT-5.5",
1131 "tool_call": true,
1132 "reasoning": true,
1133 "limit": {"context": 1_050_000, "output": 128_000}
1134 }
1135 }),
1136 );
1137
1138 let source = generate_from_value(&data);
1139 let codex_impl = generated_impl_block(&source, "CodexModel");
1140 let openai_impl = generated_impl_block(&source, "OpenaiModel");
1141
1142 assert!(codex_impl.contains("Self::Gpt55 => 272_000,"));
1143 assert!(openai_impl.contains("Self::Gpt55 => 1_050_000,"));
1144 }
1145
1146 #[test]
1147 fn codex_subscription_context_window_overrides_known_codex_models() {
1148 for model_id in ["gpt-5.5", "gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex", "gpt-5.2", "codex-auto-review"] {
1149 assert_eq!(codex_subscription_context_window(model_id, 1_050_000), 272_000);
1150 }
1151 }
1152
1153 #[test]
1154 fn codex_subscription_context_window_leaves_unknown_models_unchanged() {
1155 assert_eq!(codex_subscription_context_window("gpt-5.3-codex-spark", 128_000), 128_000);
1156 assert_eq!(codex_subscription_context_window("some-future-model", 400_000), 400_000);
1157 }
1158
1159 fn generate_from_value(data: &Value) -> String {
1160 let tmp = NamedTempFile::new().expect("temp file");
1161 let json = serde_json::to_string(data).expect("serialize fixture");
1162 std::fs::write(tmp.path(), json).expect("write fixture");
1163 generate(tmp.path()).expect("codegen succeeds").rust_source
1164 }
1165
1166 fn generated_impl_block<'a>(source: &'a str, enum_name: &str) -> &'a str {
1167 let start_marker = format!("impl {enum_name} {{");
1168 let start = source.find(&start_marker).expect("provider impl block start");
1169 let rest = &source[start..];
1170 let end_marker = format!("impl std::str::FromStr for {enum_name}");
1171 let end = rest.find(&end_marker).expect("provider impl block end");
1172 &rest[..end]
1173 }
1174 fn minimal_models_dev_json() -> Value {
1175 let mut root = serde_json::Map::new();
1176 for cfg in PROVIDERS {
1177 let json_key = cfg.json_key();
1178 root.entry(json_key.to_string()).or_insert_with(|| {
1179 json!({
1180 "id": json_key,
1181 "name": json_key,
1182 "env": [],
1183 "models": {}
1184 })
1185 });
1186 for &extra in cfg.extra_source_ids {
1187 root.entry(extra.to_string()).or_insert_with(|| {
1188 json!({
1189 "id": extra,
1190 "name": extra,
1191 "env": [],
1192 "models": {}
1193 })
1194 });
1195 }
1196 }
1197 Value::Object(root)
1198 }
1199
1200 #[test]
1201 fn extra_source_ids_merges_models_into_provider() {
1202 let mut data = minimal_models_dev_json();
1203 let root = data.as_object_mut().unwrap();
1204
1205 let extra = root.get_mut("zai-coding-plan").unwrap().as_object_mut().unwrap();
1207 extra.insert(
1208 "models".to_string(),
1209 json!({
1210 "extra-model": {
1211 "id": "extra-model",
1212 "name": "Extra Model",
1213 "tool_call": true,
1214 "limit": {"context": 4000, "output": 0}
1215 }
1216 }),
1217 );
1218
1219 let source = generate_from_value(&data);
1220 assert!(source.contains("\"extra-model\" => Ok(Self::ExtraModel),"));
1222 }
1223
1224 #[test]
1225 fn extra_source_ids_does_not_duplicate_existing_models() {
1226 let mut data = minimal_models_dev_json();
1227 let root = data.as_object_mut().unwrap();
1228
1229 let zai = root.get_mut("zai").unwrap().as_object_mut().unwrap();
1231 zai.insert(
1232 "models".to_string(),
1233 json!({
1234 "shared-model": {
1235 "id": "shared-model",
1236 "name": "Shared Model",
1237 "tool_call": true,
1238 "limit": {"context": 1000, "output": 0}
1239 }
1240 }),
1241 );
1242 let extra = root.get_mut("zai-coding-plan").unwrap().as_object_mut().unwrap();
1243 extra.insert(
1244 "models".to_string(),
1245 json!({
1246 "shared-model": {
1247 "id": "shared-model",
1248 "name": "Shared Model Duplicate",
1249 "tool_call": true,
1250 "limit": {"context": 2000, "output": 0}
1251 }
1252 }),
1253 );
1254
1255 let source = generate_from_value(&data);
1256 let from_str_matches = source.matches("\"shared-model\" => Ok(Self::SharedModel),").count();
1257 assert_eq!(from_str_matches, 1);
1258 }
1259
1260 #[test]
1261 fn generate_emits_provider_docs() {
1262 let mut data = minimal_models_dev_json();
1263 let root = data.as_object_mut().unwrap();
1264 let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).unwrap();
1265
1266 anthropic.insert(
1267 "models".to_string(),
1268 json!({
1269 "claude-test": {
1270 "id": "claude-test",
1271 "name": "Claude Test",
1272 "tool_call": true,
1273 "reasoning": true,
1274 "limit": {"context": 200_000, "output": 0},
1275 "modalities": {"input": ["text", "image"]}
1276 }
1277 }),
1278 );
1279
1280 let tmp = NamedTempFile::new().unwrap();
1281 std::fs::write(tmp.path(), serde_json::to_string(&data).unwrap()).unwrap();
1282 let output = generate(tmp.path()).unwrap();
1283
1284 let anthropic_doc = &output.provider_docs["anthropic"];
1285 assert!(anthropic_doc.contains("`Anthropic` LLM provider."));
1286 assert!(anthropic_doc.contains("`ANTHROPIC_API_KEY`"));
1287 assert!(anthropic_doc.contains("| `claude-test` | `Claude Test` | `200k` | yes | yes | |"));
1288
1289 let ollama_doc = &output.provider_docs["ollama"];
1291 assert!(ollama_doc.contains("`Ollama` LLM provider."));
1292 assert!(ollama_doc.contains("any model name at runtime"));
1293 }
1294
1295 #[test]
1296 fn format_context_window_formats_correctly() {
1297 assert_eq!(format_context_window(1_000_000), "1M");
1298 assert_eq!(format_context_window(200_000), "200k");
1299 assert_eq!(format_context_window(8_000), "8k");
1300 assert_eq!(format_context_window(0), "unknown");
1301 }
1302
1303 #[test]
1304 fn level_str_to_variant_covers_all_reasoning_efforts() {
1305 for effort in utils::ReasoningEffort::all() {
1306 let _ = level_str_to_variant(effort.as_str());
1308 }
1309 }
1310}