1#![doc = include_str!("../README.md")]
2
3use serde::Deserialize;
4use std::collections::{BTreeMap, HashMap};
5use std::fmt::Write;
6use std::path::Path;
7
/// Root shape of the models.dev JSON document: provider key → provider payload.
type ModelsDevData = HashMap<String, ProviderData>;
9
/// One provider entry from models.dev. Only `models` is consumed by codegen;
/// the other fields are deserialized for debugging visibility.
#[derive(Debug, Deserialize)]
struct ProviderData {
    // Provider identifier as reported by models.dev (unused by codegen).
    #[allow(dead_code)]
    id: String,
    // Human-readable provider name (unused by codegen).
    #[allow(dead_code)]
    name: String,
    // Env vars models.dev associates with the provider (unused by codegen).
    #[serde(default)]
    #[allow(dead_code)]
    env: Vec<String>,
    // Model ID → model metadata: the part codegen actually reads.
    #[serde(default)]
    models: HashMap<String, ModelData>,
}
22
/// Per-model metadata from models.dev; optional fields default to `None`.
#[derive(Debug, Deserialize)]
struct ModelData {
    // Canonical model identifier, e.g. `claude-sonnet-4-5-20250929`.
    id: String,
    // Human-readable model name.
    name: String,
    // Tool/function-calling support; only `Some(true)` models are emitted.
    #[serde(default)]
    tool_call: Option<bool>,
    // Reasoning/extended-thinking support; drives emitted reasoning levels.
    #[serde(default)]
    reasoning: Option<bool>,
    // Pricing info — parsed but not used by codegen.
    #[serde(default)]
    #[allow(dead_code)]
    cost: Option<CostData>,
    // Token limits; `context` feeds the generated `context_window()`.
    #[serde(default)]
    limit: Option<LimitData>,
    // Supported input modalities (text/image/audio).
    #[serde(default)]
    modalities: Option<ModalitiesData>,
}
39
/// Modality lists for a model; only `input` is consumed by codegen.
#[derive(Debug, Deserialize, Default)]
struct ModalitiesData {
    // Accepted input modalities, e.g. `["text", "image"]`.
    #[serde(default)]
    input: Vec<String>,
}
45
/// Per-token pricing — deserialized for completeness but unused by codegen.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct CostData {
    // Input-token price (presumably USD per million tokens — confirm against the models.dev schema).
    #[serde(default)]
    input: f64,
    // Output-token price (same assumed unit as `input` — confirm against the models.dev schema).
    #[serde(default)]
    output: f64,
}
54
/// Token limits for a model.
#[derive(Debug, Deserialize)]
struct LimitData {
    // Context window in tokens; 0 when absent from the source data.
    #[serde(default)]
    context: u32,
    // Maximum output tokens — unused by codegen.
    #[serde(default)]
    #[allow(dead_code)]
    output: u32,
}
63
/// Static description of one catalog provider: where its models come from in
/// the models.dev data and how the generated code should name and parse it.
struct ProviderConfig {
    // Key used for this provider in our own tables and emitted docs.
    dev_id: &'static str,
    // Override for the models.dev JSON key when it differs from `dev_id`.
    source_dev_id: Option<&'static str>,
    // Additional models.dev providers whose models are merged in (deduplicated by model ID).
    extra_source_ids: &'static [&'static str],
    // Optional predicate restricting which model IDs are included.
    model_filter: Option<fn(&str) -> bool>,
    // Prefix for the generated `<Name>Model` enum type.
    enum_name: &'static str,
    // Lowercase provider token used in `provider:model` strings.
    parser_name: &'static str,
    // Human-readable provider name for docs and display.
    display_name: &'static str,
    // API-key env var; `None` for OAuth or credential-chain providers.
    env_var: Option<&'static str>,
    // OAuth provider ID when authentication is OAuth-based.
    oauth_provider_id: Option<&'static str>,
    // Reasoning effort levels assigned to reasoning-capable models.
    default_reasoning_levels: &'static [&'static str],
}
87
88impl ProviderConfig {
89 const fn standard(
91 dev_id: &'static str,
92 enum_name: &'static str,
93 parser_name: &'static str,
94 display_name: &'static str,
95 env_var: Option<&'static str>,
96 ) -> Self {
97 Self {
98 dev_id,
99 source_dev_id: None,
100 extra_source_ids: &[],
101 model_filter: None,
102 enum_name,
103 parser_name,
104 display_name,
105 env_var,
106 oauth_provider_id: None,
107 default_reasoning_levels: &["low", "medium", "high"],
108 }
109 }
110
111 fn json_key(&self) -> &'static str {
113 self.source_dev_id.unwrap_or(self.dev_id)
114 }
115}
116
/// Provider whose model names are free-form strings resolved at runtime
/// (e.g. local inference servers) instead of a generated enum.
#[allow(clippy::struct_field_names)]
struct DynamicProviderConfig {
    // Variant name in the generated `LlmModel` enum.
    enum_name: &'static str,
    // Lowercase provider token used in `provider:model` strings.
    parser_name: &'static str,
    // Human-readable provider name for docs and display.
    display_name: &'static str,
}
127
/// Catalog providers, in generated-output order. Each entry drives one
/// `<Name>Model` enum plus its impls, docs, and `LlmModel` arms.
const PROVIDERS: &[ProviderConfig] = &[
    ProviderConfig::standard("anthropic", "Anthropic", "anthropic", "Anthropic", Some("ANTHROPIC_API_KEY")),
    // Codex: OpenAI's model list filtered to codex-style models, authenticated
    // via OAuth rather than an API key, with an extra `xhigh` effort level.
    ProviderConfig {
        dev_id: "codex",
        source_dev_id: Some("openai"),
        extra_source_ids: &[],
        model_filter: Some(|id| id.contains("codex") || id.starts_with("gpt-5.4")),
        enum_name: "Codex",
        parser_name: "codex",
        display_name: "Codex",
        env_var: None,
        oauth_provider_id: Some("codex"),
        default_reasoning_levels: &["low", "medium", "high", "xhigh"],
    },
    ProviderConfig::standard("deepseek", "DeepSeek", "deepseek", "DeepSeek", Some("DEEPSEEK_API_KEY")),
    ProviderConfig::standard("google", "Gemini", "gemini", "Gemini", Some("GEMINI_API_KEY")),
    ProviderConfig::standard("moonshotai", "Moonshot", "moonshot", "Moonshot", Some("MOONSHOT_API_KEY")),
    ProviderConfig::standard("openai", "Openai", "openai", "OpenAI", Some("OPENAI_API_KEY")),
    ProviderConfig::standard("openrouter", "OpenRouter", "openrouter", "OpenRouter", Some("OPENROUTER_API_KEY")),
    // ZAI additionally merges models from the separate "zai-coding-plan" source.
    ProviderConfig {
        extra_source_ids: &["zai-coding-plan"],
        ..ProviderConfig::standard("zai", "ZAi", "zai", "ZAI", Some("ZAI_API_KEY"))
    },
    // Bedrock has no single API key; it relies on the AWS credential chain.
    ProviderConfig::standard("amazon-bedrock", "Bedrock", "bedrock", "AWS Bedrock", None),
];
153
/// Providers that accept any model name at runtime (no generated enum, no API key).
const DYNAMIC_PROVIDERS: &[DynamicProviderConfig] = &[
    DynamicProviderConfig { enum_name: "Ollama", parser_name: "ollama", display_name: "Ollama" },
    DynamicProviderConfig { enum_name: "LlamaCpp", parser_name: "llamacpp", display_name: "LlamaCpp" },
];
158
/// Normalized model record consumed by all emitters.
#[derive(Debug, Clone)]
struct ModelInfo {
    // Rust enum variant name derived from `model_id`.
    variant_name: String,
    // Raw model identifier.
    model_id: String,
    // Human-readable name (escaped before being emitted inside string literals).
    display_name: String,
    // Context window in tokens; 0 when unknown.
    context_window: u32,
    // Reasoning effort level names; empty for non-reasoning models.
    reasoning_levels: Vec<String>,
    // Accepted input modalities.
    input_modalities: Vec<String>,
}
168
/// Provider `dev_id` → its models; BTreeMap keeps codegen output deterministic.
type ProviderModels = BTreeMap<&'static str, Vec<ModelInfo>>;
170
/// Shared state passed to every emitter.
struct CodegenCtx {
    // Resolved, filtered, and sorted models per provider.
    provider_models: ProviderModels,
}
174
/// Everything codegen produces in one pass.
pub struct GeneratedOutput {
    /// Generated Rust source for the provider enums and `LlmModel`.
    pub rust_source: String,
    /// Provider key (catalog `dev_id` or dynamic parser name) → markdown docs.
    pub provider_docs: HashMap<String, String>,
}
185
186pub fn generate(models_json_path: &Path) -> Result<GeneratedOutput, String> {
188 let json_bytes = std::fs::read_to_string(models_json_path).map_err(|e| format!("read: {e}"))?;
189 let data: ModelsDevData = serde_json::from_str(&json_bytes).map_err(|e| format!("parse: {e}"))?;
190
191 let provider_models = build_provider_models(&data)?;
192 let ctx = CodegenCtx { provider_models };
193 Ok(GeneratedOutput { rust_source: emit_generated_source(&ctx), provider_docs: emit_provider_docs(&ctx) })
194}
195
196fn build_provider_models(data: &ModelsDevData) -> Result<ProviderModels, String> {
197 let mut provider_models = ProviderModels::new();
198
199 for cfg in PROVIDERS {
200 let json_key = cfg.json_key();
201 let provider_data =
202 data.get(json_key).ok_or_else(|| format!("Provider '{json_key}' not found in models.dev data"))?;
203
204 let mut models: Vec<ModelInfo> = collect_models_from(cfg, &provider_data.models);
205
206 for &extra_key in cfg.extra_source_ids {
207 if let Some(extra_data) = data.get(extra_key) {
208 let extra = collect_models_from(cfg, &extra_data.models);
209 let existing_ids: std::collections::HashSet<String> =
210 models.iter().map(|m| m.model_id.clone()).collect();
211 models.extend(extra.into_iter().filter(|m| !existing_ids.contains(&m.model_id)));
212 }
213 }
214
215 models.sort_by(|a, b| a.model_id.cmp(&b.model_id));
216 provider_models.insert(cfg.dev_id, models);
217 }
218
219 Ok(provider_models)
220}
221
222fn collect_models_from(cfg: &ProviderConfig, models: &HashMap<String, ModelData>) -> Vec<ModelInfo> {
223 models
224 .values()
225 .filter(|m| m.tool_call == Some(true))
226 .filter(|m| !is_alias(&m.id))
227 .filter(|m| cfg.model_filter.is_none_or(|f| f(&m.id)))
228 .map(|m| {
229 let reasoning_levels = if m.reasoning.unwrap_or(false) {
230 cfg.default_reasoning_levels.iter().map(|s| (*s).to_string()).collect()
231 } else {
232 Vec::new()
233 };
234 let input_modalities =
235 m.modalities.as_ref().map_or_else(|| vec!["text".to_string()], |md| md.input.clone());
236 ModelInfo {
237 variant_name: model_id_to_variant(&m.id),
238 model_id: m.id.clone(),
239 display_name: m.name.clone(),
240 context_window: m.limit.as_ref().map_or(0, |l| l.context),
241 reasoning_levels,
242 input_modalities,
243 }
244 })
245 .collect()
246}
247
/// True when `id` is a rolling `-latest` alias rather than a pinned model.
fn is_alias(id: &str) -> bool {
    id.strip_suffix("-latest").is_some()
}
252
/// Convert a raw model ID into a valid Rust enum variant name.
///
/// Separator characters (`-`, `.`, `/`, `:`) are dropped and the character
/// after each one is upper-cased; remaining characters are kept as-is. A
/// leading `_` is prepended when the result would start with a digit, since
/// identifiers cannot begin with one.
fn model_id_to_variant(id: &str) -> String {
    let mut variant = String::new();

    // Each run between separators becomes one capitalized segment; empty
    // segments from consecutive separators contribute nothing.
    for segment in id.split(['-', '.', '/', ':']) {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            variant.push(first.to_ascii_uppercase());
            variant.extend(chars);
        }
    }

    if variant.starts_with(|c: char| c.is_ascii_digit()) {
        variant.insert(0, '_');
    }

    variant
}
277
/// Assemble the complete generated Rust source. The emission order below is
/// part of the output file's layout; do not reorder the calls.
fn emit_generated_source(ctx: &CodegenCtx) -> String {
    // 64 KB starting capacity avoids repeated reallocation for the large output.
    let mut out = String::with_capacity(64_000);
    emit_header(&mut out);
    emit_provider_enums(&mut out, &ctx.provider_models);
    emit_provider_impls(&mut out, &ctx.provider_models);
    emit_llm_model_enum(&mut out);
    emit_from_impls(&mut out);
    emit_llm_model_impl(&mut out);
    emit_display_impl(&mut out);
    emit_fromstr_impl(&mut out);
    out
}
290
291fn emit_header(out: &mut String) {
292 pushln(out, "// Auto-generated from models.dev — do not edit manually");
293 pushln(out, "// Regenerated automatically by build.rs");
294 blank(out);
295 pushln(out, "use std::borrow::Cow;");
296 pushln(out, "use std::sync::LazyLock;");
297 pushln(out, "use crate::ReasoningEffort;");
298 blank(out);
299}
300
/// Emit one fieldless `<Name>Model` enum per catalog provider.
fn emit_provider_enums(out: &mut String, provider_models: &ProviderModels) {
    for cfg in PROVIDERS {
        pushln(out, "#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]");
        pushln(out, format!("pub enum {}Model {{", cfg.enum_name));
        // Indexing is safe: build_provider_models inserts an entry for every cfg.
        for model in &provider_models[cfg.dev_id] {
            pushln(out, format!(" {},", model.variant_name));
        }
        pushln(out, "}");
        blank(out);
    }
}
312
/// Emit the inherent `impl` block (accessors plus the `ALL` variant list) and
/// a `FromStr` impl for each provider's generated model enum.
fn emit_provider_impls(out: &mut String, provider_models: &ProviderModels) {
    for cfg in PROVIDERS {
        let models = &provider_models[cfg.dev_id];
        let enum_name = format!("{}Model", cfg.enum_name);

        pushln(out, format!("impl {enum_name} {{"));

        // model_id(): one arm per variant — IDs are unique, so no grouping.
        pushln(out, " #[allow(clippy::too_many_lines)]");
        pushln(out, " fn model_id(self) -> &'static str {");
        pushln(out, " match self {");
        for model in models {
            pushln(out, format!(" Self::{} => \"{}\",", model.variant_name, model.model_id));
        }
        pushln(out, " }");
        pushln(out, " }");
        blank(out);

        // display_name(): names can collide, so identical arms are grouped.
        pushln(out, " #[allow(clippy::too_many_lines)]");
        pushln(out, " fn display_name(self) -> &'static str {");
        pushln(out, " match self {");
        emit_grouped_arms(out, models, |m| escape_rust_string(&m.display_name), |name| format!("\"{name}\""));
        pushln(out, " }");
        pushln(out, " }");
        blank(out);

        // context_window(): numbers re-emitted with `_` separators for readability.
        pushln(out, " fn context_window(self) -> u32 {");
        pushln(out, " match self {");
        emit_grouped_arms(
            out,
            models,
            |m| m.context_window.to_string(),
            |val| format_number(val.parse::<u32>().unwrap()),
        );
        pushln(out, " }");
        pushln(out, " }");
        blank(out);

        // reasoning_levels(): grouped on the comma-joined level list (empty → `&[]`).
        pushln(out, " pub fn reasoning_levels(self) -> &'static [ReasoningEffort] {");
        pushln(out, " match self {");
        emit_grouped_arms(out, models, |m| m.reasoning_levels.join(","), format_reasoning_levels_rhs);
        pushln(out, " }");
        pushln(out, " }");
        blank(out);

        pushln(out, " pub fn supports_reasoning(self) -> bool {");
        pushln(out, " !self.reasoning_levels().is_empty()");
        pushln(out, " }");
        blank(out);

        // One boolean accessor per input modality we expose.
        for modality in ["image", "audio"] {
            pushln(out, format!(" pub fn supports_{modality}(self) -> bool {{"));
            pushln(out, " match self {");
            // Owned copy so the key closure can borrow it for `contains`.
            let mod_owned = modality.to_string();
            emit_grouped_arms(
                out,
                models,
                |m| m.input_modalities.contains(&mod_owned).to_string(),
                std::string::ToString::to_string,
            );
            pushln(out, " }");
            pushln(out, " }");
            blank(out);
        }

        // Exhaustive variant list consumed by the generated `LlmModel::all()`.
        pushln(out, format!(" const ALL: &[{enum_name}] = &["));
        for model in models {
            pushln(out, format!(" Self::{},", model.variant_name));
        }
        pushln(out, " ];");

        pushln(out, "}");
        blank(out);

        emit_from_str_impl(out, &enum_name, cfg.parser_name, models);
    }
}
395
/// Emit a `FromStr` impl for one provider enum, mapping each raw model ID
/// string back to its variant and erroring on unknown IDs.
fn emit_from_str_impl(out: &mut String, enum_name: &str, parser_name: &str, models: &[ModelInfo]) {
    pushln(out, format!("impl std::str::FromStr for {enum_name} {{"));
    pushln(out, " type Err = String;");
    blank(out);
    pushln(out, " #[allow(clippy::too_many_lines)]");
    pushln(out, " fn from_str(s: &str) -> Result<Self, Self::Err> {");
    pushln(out, " match s {");
    for model in models {
        pushln(out, format!(" \"{}\" => Ok(Self::{}),", model.model_id, model.variant_name));
    }
    // `{{s}}` stays literal here so the *generated* code interpolates `s`.
    pushln(out, format!(" _ => Err(format!(\"Unknown {parser_name} model: '{{s}}'\")),"));
    pushln(out, " }");
    pushln(out, " }");
    pushln(out, "}");
    blank(out);
}
412
/// Emit match arms grouped by identical right-hand side: variants whose
/// `key_fn` values match are merged into a single `A | B => rhs` arm.
/// `fmt_val` turns the group key into the emitted RHS expression. The
/// `BTreeMap` keeps arm order deterministic across runs.
fn emit_grouped_arms(
    out: &mut String,
    models: &[ModelInfo],
    key_fn: impl Fn(&ModelInfo) -> String,
    fmt_val: impl Fn(&str) -> String,
) {
    let mut groups: BTreeMap<String, Vec<&str>> = BTreeMap::new();
    for model in models {
        groups.entry(key_fn(model)).or_default().push(&model.variant_name);
    }

    for (key, variants) in &groups {
        let rhs = fmt_val(key);
        if variants.len() == 1 {
            pushln(out, format!(" Self::{} => {rhs},", variants[0]));
        } else {
            let patterns: Vec<String> = variants.iter().map(|v| format!("Self::{v}")).collect();
            pushln(out, format!(" {} => {rhs},", patterns.join(" | ")));
        }
    }
}
439
440fn format_reasoning_levels_rhs(key: &str) -> String {
442 if key.is_empty() {
443 return "&[]".to_string();
444 }
445 let items: Vec<String> = key
446 .split(',')
447 .map(|l| {
448 let variant = level_str_to_variant(l);
449 format!("ReasoningEffort::{variant}")
450 })
451 .collect();
452 format!("&[{}]", items.join(", "))
453}
454
/// Map a lowercase reasoning-level name to its `ReasoningEffort` variant name.
///
/// # Panics
/// Panics on an unrecognized level — this is build-time codegen, so a bad
/// table entry should abort the build loudly.
fn level_str_to_variant(level: &str) -> &'static str {
    const LEVELS: &[(&str, &str)] =
        &[("low", "Low"), ("medium", "Medium"), ("high", "High"), ("xhigh", "Xhigh")];

    LEVELS
        .iter()
        .find(|(raw, _)| *raw == level)
        .map(|(_, variant)| *variant)
        .unwrap_or_else(|| panic!("Unknown reasoning level: {level}"))
}
466
/// Emit the umbrella `LlmModel` enum: one tuple variant wrapping each catalog
/// provider's enum, plus a `String`-carrying variant per dynamic provider.
fn emit_llm_model_enum(out: &mut String) {
    pushln(out, "/// A model from a specific provider");
    pushln(out, "#[derive(Debug, Clone, PartialEq, Eq, Hash)]");
    pushln(out, "pub enum LlmModel {");
    for cfg in PROVIDERS {
        pushln(out, format!(" {provider}({provider}Model),", provider = cfg.enum_name));
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(out, format!(" {}(String),", dyn_cfg.enum_name));
    }
    pushln(out, "}");
    blank(out);
}
480
481fn emit_from_impls(out: &mut String) {
482 for cfg in PROVIDERS {
483 pushln(out, format!("impl From<{}Model> for LlmModel {{", cfg.enum_name));
484 pushln(out, format!(" fn from(m: {}Model) -> Self {{ LlmModel::{}(m) }}", cfg.enum_name, cfg.enum_name));
485 pushln(out, "}");
486 blank(out);
487 }
488}
489
/// Emit the inherent `impl LlmModel` block by delegating each method body to
/// its dedicated emitter; the order here is the method order in the output.
fn emit_llm_model_impl(out: &mut String) {
    pushln(out, "impl LlmModel {");
    emit_llm_model_id(out);
    emit_llm_display_name(out);
    emit_llm_provider(out);
    emit_llm_provider_display_name(out);
    emit_llm_context_window(out);
    emit_llm_required_env_var(out);
    emit_llm_all_required_env_vars(out);
    emit_llm_oauth_provider_id(out);
    emit_llm_reasoning_levels(out);
    emit_llm_supports_reasoning(out);
    for modality in ["image", "audio"] {
        emit_llm_supports_modality(out, modality);
    }
    emit_llm_all(out);
    pushln(out, "}");
    blank(out);
}
509
/// Emit `LlmModel::model_id`: borrowed `&'static str` for catalog enums, an
/// owned clone of the runtime string for dynamic providers.
fn emit_llm_model_id(out: &mut String) {
    pushln(out, " /// Raw model ID (e.g. `claude-opus-4-6`, `llama3.2`)");
    pushln(out, " pub fn model_id(&self) -> Cow<'static, str> {");
    pushln(out, " match self {");
    for cfg in PROVIDERS {
        pushln(out, format!(" Self::{}(m) => Cow::Borrowed(m.model_id()),", cfg.enum_name));
    }
    // All dynamic variants bind their inner String as `s` in one grouped arm.
    pushln(out, format!(" {} => Cow::Owned(s.clone()),", dynamic_pattern_with_binding("s")));
    pushln(out, " }");
    pushln(out, " }");
    blank(out);
}
522
523fn emit_llm_display_name(out: &mut String) {
524 pushln(out, " /// Human-readable display name (e.g. `Claude Opus 4.6`)");
525 pushln(out, " pub fn display_name(&self) -> Cow<'static, str> {");
526 pushln(out, " match self {");
527 for cfg in PROVIDERS {
528 pushln(out, format!(" Self::{}(m) => Cow::Borrowed(m.display_name()),", cfg.enum_name));
529 }
530 for dyn_cfg in DYNAMIC_PROVIDERS {
531 pushln(
532 out,
533 format!(
534 " Self::{}(s) => Cow::Owned(format!(\"{} {{s}}\")),",
535 dyn_cfg.enum_name, dyn_cfg.enum_name
536 ),
537 );
538 }
539 pushln(out, " }");
540 pushln(out, " }");
541 blank(out);
542}
543
/// Emit `LlmModel::provider`: the lowercase parser token for every variant,
/// catalog and dynamic alike.
fn emit_llm_provider(out: &mut String) {
    pushln(out, " /// Provider identifier (e.g. `anthropic`)");
    pushln(out, " pub fn provider(&self) -> &'static str {");
    pushln(out, " match self {");
    for cfg in PROVIDERS {
        pushln(out, format!(" Self::{}(_) => \"{}\",", cfg.enum_name, cfg.parser_name));
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(out, format!(" Self::{}(_) => \"{}\",", dyn_cfg.enum_name, dyn_cfg.parser_name));
    }
    pushln(out, " }");
    pushln(out, " }");
    blank(out);
}
558
/// Emit `LlmModel::provider_display_name`: the human-readable provider name
/// for every variant.
fn emit_llm_provider_display_name(out: &mut String) {
    pushln(out, " /// Human-readable provider name (e.g. `AWS Bedrock`)");
    pushln(out, " pub fn provider_display_name(&self) -> &'static str {");
    pushln(out, " match self {");
    for cfg in PROVIDERS {
        pushln(out, format!(" Self::{}(_) => \"{}\",", cfg.enum_name, cfg.display_name));
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(out, format!(" Self::{}(_) => \"{}\",", dyn_cfg.enum_name, dyn_cfg.display_name));
    }
    pushln(out, " }");
    pushln(out, " }");
    blank(out);
}
573
/// Emit `LlmModel::context_window`: `Some` for catalog models, `None` for
/// dynamic providers whose models are unknown at build time.
fn emit_llm_context_window(out: &mut String) {
    pushln(out, " /// Context window size in tokens (None for dynamic providers)");
    pushln(out, " pub fn context_window(&self) -> Option<u32> {");
    pushln(out, " match self {");
    for cfg in PROVIDERS {
        pushln(out, format!(" Self::{}(m) => Some(m.context_window()),", cfg.enum_name));
    }
    pushln(out, format!(" {} => None,", dynamic_pattern_with_binding("_")));
    pushln(out, " }");
    pushln(out, " }");
    blank(out);
}
586
/// Emit `LlmModel::required_env_var`: providers with an API-key env var get a
/// `Some(..)` arm; all others are merged into one grouped `None` arm.
fn emit_llm_required_env_var(out: &mut String) {
    pushln(out, " /// Required env var for this model's provider (None for local providers)");
    pushln(out, " pub fn required_env_var(&self) -> Option<&'static str> {");
    pushln(out, " match self {");
    // Collect key-less providers (and all dynamic ones) into a single arm.
    let mut none_arms: Vec<String> = Vec::new();
    for cfg in PROVIDERS {
        match cfg.env_var {
            Some(var) => pushln(out, format!(" Self::{}(_) => Some(\"{}\"),", cfg.enum_name, var)),
            None => none_arms.push(format!("Self::{}(_)", cfg.enum_name)),
        }
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        none_arms.push(format!("Self::{}(_)", dyn_cfg.enum_name));
    }
    pushln(out, format!(" {} => None,", none_arms.join(" | ")));
    pushln(out, " }");
    pushln(out, " }");
    blank(out);
}
606
607fn emit_llm_all_required_env_vars(out: &mut String) {
608 let vars: Vec<&str> = PROVIDERS.iter().filter_map(|cfg| cfg.env_var).collect();
609 pushln(out, " /// All provider API key env var names (deduplicated, static)");
610 pushln(
611 out,
612 format!(
613 " pub const ALL_REQUIRED_ENV_VARS: &[&str] = &[{}];",
614 vars.iter().map(|v| format!("\"{v}\"")).collect::<Vec<_>>().join(", ")
615 ),
616 );
617 blank(out);
618}
619
/// Emit `LlmModel::oauth_provider_id`: OAuth providers get `Some(..)` arms;
/// everything else (including dynamic providers) shares one grouped `None` arm.
fn emit_llm_oauth_provider_id(out: &mut String) {
    pushln(out, " /// OAuth provider ID if this model requires OAuth login (e.g. `\"codex\"`)");
    pushln(out, " pub fn oauth_provider_id(&self) -> Option<&'static str> {");
    pushln(out, " match self {");
    let mut none_arms: Vec<String> = Vec::new();
    for cfg in PROVIDERS {
        match cfg.oauth_provider_id {
            Some(id) => pushln(out, format!(" Self::{}(_) => Some(\"{}\"),", cfg.enum_name, id)),
            None => none_arms.push(format!("Self::{}(_)", cfg.enum_name)),
        }
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        none_arms.push(format!("Self::{}(_)", dyn_cfg.enum_name));
    }
    pushln(out, format!(" {} => None,", none_arms.join(" | ")));
    pushln(out, " }");
    pushln(out, " }");
    blank(out);
}
639
/// Emit `LlmModel::reasoning_levels`: catalog variants delegate to their
/// provider enum; dynamic providers always report no levels.
fn emit_llm_reasoning_levels(out: &mut String) {
    pushln(out, " /// Reasoning levels supported by this model (empty if not a reasoning model)");
    pushln(out, " pub fn reasoning_levels(&self) -> &'static [ReasoningEffort] {");
    pushln(out, " match self {");
    for cfg in PROVIDERS {
        pushln(out, format!(" Self::{}(m) => m.reasoning_levels(),", cfg.enum_name));
    }
    pushln(out, format!(" {} => &[],", dynamic_pattern_with_binding("_")));
    pushln(out, " }");
    pushln(out, " }");
    blank(out);
}
652
653fn emit_llm_supports_reasoning(out: &mut String) {
654 pushln(out, " /// Whether this model supports reasoning/extended thinking");
655 pushln(out, " pub fn supports_reasoning(&self) -> bool {");
656 pushln(out, " !self.reasoning_levels().is_empty()");
657 pushln(out, " }");
658 blank(out);
659}
660
/// Emit one `LlmModel::supports_<modality>` predicate: catalog variants
/// delegate to their provider enum; dynamic providers report `false`.
fn emit_llm_supports_modality(out: &mut String, modality: &str) {
    pushln(out, format!(" /// Whether this model supports {modality} input"));
    pushln(out, format!(" pub fn supports_{modality}(&self) -> bool {{"));
    pushln(out, " match self {");
    for cfg in PROVIDERS {
        pushln(out, format!(" Self::{}(m) => m.supports_{modality}(),", cfg.enum_name));
    }
    pushln(out, format!(" {} => false,", dynamic_pattern_with_binding("_")));
    pushln(out, " }");
    pushln(out, " }");
    blank(out);
}
673
/// Emit `LlmModel::all()`: a lazily-built static list of every catalog model.
/// Dynamic providers are excluded because their model set is open-ended.
fn emit_llm_all(out: &mut String) {
    pushln(out, " /// All catalog models (excludes dynamic providers)");
    pushln(out, " pub fn all() -> &'static [LlmModel] {");
    pushln(out, " static ALL: LazyLock<Vec<LlmModel>> = LazyLock::new(|| {");
    pushln(out, " let mut v = Vec::new();");
    for cfg in PROVIDERS {
        pushln(
            out,
            format!(
                " v.extend({}Model::ALL.iter().copied().map(LlmModel::{}));",
                cfg.enum_name, cfg.enum_name
            ),
        );
    }
    pushln(out, " v");
    pushln(out, " });");
    pushln(out, " &ALL");
    pushln(out, " }");
}
693
694fn emit_display_impl(out: &mut String) {
695 pushln(out, "impl std::fmt::Display for LlmModel {");
696 pushln(out, " fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {");
697 pushln(out, " write!(f, \"{}:{}\", self.provider(), self.model_id())");
698 pushln(out, " }");
699 pushln(out, "}");
700 blank(out);
701}
702
/// Emit the `FromStr` impl for `LlmModel`, parsing `provider:model` strings.
/// Catalog providers delegate to their enum's `FromStr`; dynamic providers
/// accept any model string verbatim.
fn emit_fromstr_impl(out: &mut String) {
    pushln(out, "impl std::str::FromStr for LlmModel {");
    pushln(out, " type Err = String;");
    blank(out);
    pushln(out, " /// Parse a `provider:model` string into an `LlmModel`");
    pushln(out, " fn from_str(s: &str) -> Result<Self, Self::Err> {");
    // No ':' means the whole input is treated as the provider with an empty model.
    pushln(out, " let (provider_str, model_str) = s.split_once(':').unwrap_or((s, \"\"));");
    pushln(out, " match provider_str {");
    for cfg in PROVIDERS {
        pushln(
            out,
            format!(
                " \"{}\" => model_str.parse::<{}Model>().map(Self::{}),",
                cfg.parser_name, cfg.enum_name, cfg.enum_name
            ),
        );
    }
    for dyn_cfg in DYNAMIC_PROVIDERS {
        pushln(
            out,
            format!(
                " \"{}\" => Ok(Self::{}(model_str.to_string())),",
                dyn_cfg.parser_name, dyn_cfg.enum_name
            ),
        );
    }
    pushln(out, " _ => Err(format!(\"Unknown provider: '{provider_str}'\")),");
    pushln(out, " }");
    pushln(out, " }");
    pushln(out, "}");
}
734
735fn dynamic_pattern_with_binding(binding: &str) -> String {
738 DYNAMIC_PROVIDERS.iter().map(|d| format!("Self::{}({binding})", d.enum_name)).collect::<Vec<_>>().join(" | ")
739}
740
/// Render `n` with `_` separators every three digits, Rust-literal style.
/// Four digits or fewer are left unseparated (`1000`, not `1_000`),
/// matching rustfmt's convention for short literals.
fn format_number(n: u32) -> String {
    let digits = n.to_string();
    if digits.len() <= 4 {
        return digits;
    }
    let mut grouped = String::with_capacity(digits.len() + digits.len() / 3);
    // Digits are ASCII, so byte indexing and positions line up exactly.
    for (i, &b) in digits.as_bytes().iter().enumerate() {
        if i > 0 && (digits.len() - i) % 3 == 0 {
            grouped.push('_');
        }
        grouped.push(b as char);
    }
    grouped
}
756
/// Escape backslashes and double quotes so `raw` can be embedded inside a
/// generated Rust string literal.
fn escape_rust_string(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            other => escaped.push(other),
        }
    }
    escaped
}
760
/// Build markdown documentation per provider: an authentication section and a
/// model table for catalog providers, a short runtime note for dynamic ones.
/// Keys are the catalog `dev_id` or the dynamic `parser_name`.
fn emit_provider_docs(ctx: &CodegenCtx) -> HashMap<String, String> {
    let mut docs = HashMap::new();

    for cfg in PROVIDERS {
        let models = &ctx.provider_models[cfg.dev_id];
        let mut doc = String::new();

        pushln(&mut doc, format!("`{}` LLM provider.", cfg.display_name));
        blank(&mut doc);

        pushln(&mut doc, "# Authentication");
        blank(&mut doc);
        // Three auth styles: env var, OAuth, or (no var, no OAuth) AWS chain.
        match cfg.env_var {
            Some(var) => pushln(&mut doc, format!("Set the `{var}` environment variable.")),
            None if cfg.oauth_provider_id.is_some() => {
                pushln(&mut doc, "This provider uses OAuth authentication.");
            }
            None => {
                pushln(
                    &mut doc,
                    "Uses the default AWS credential chain (environment variables, config files, IAM roles).",
                );
            }
        }
        blank(&mut doc);

        pushln(&mut doc, "# Supported models");
        blank(&mut doc);
        pushln(&mut doc, "| Model ID | Name | Context | Reasoning | Image | Audio |");
        pushln(&mut doc, "|----------|------|---------|-----------|-------|-------|");
        for model in models {
            let ctx_str = format_context_window(model.context_window);
            let reasoning = if model.reasoning_levels.is_empty() { "" } else { "yes" };
            let image = if model.input_modalities.contains(&"image".to_string()) { "yes" } else { "" };
            let audio = if model.input_modalities.contains(&"audio".to_string()) { "yes" } else { "" };
            pushln(
                &mut doc,
                format!(
                    "| `{}` | `{}` | `{}` | {} | {} | {} |",
                    model.model_id, model.display_name, ctx_str, reasoning, image, audio
                ),
            );
        }

        docs.insert(cfg.dev_id.to_string(), doc);
    }

    for dyn_cfg in DYNAMIC_PROVIDERS {
        let mut doc = String::new();
        pushln(&mut doc, format!("`{}` LLM provider.", dyn_cfg.display_name));
        blank(&mut doc);
        pushln(
            &mut doc,
            format!("This provider accepts any model name at runtime (e.g. `{}:my-model`).", dyn_cfg.parser_name),
        );
        pushln(&mut doc, "No API key is required.");
        docs.insert(dyn_cfg.parser_name.to_string(), doc);
    }

    docs
}
825
826fn format_context_window(tokens: u32) -> String {
828 if tokens == 0 {
829 return "unknown".to_string();
830 }
831 if tokens >= 1_000_000 && tokens.is_multiple_of(1_000_000) {
832 format!("{}M", tokens / 1_000_000)
833 } else if tokens >= 1_000 && tokens.is_multiple_of(1_000) {
834 format!("{}k", tokens / 1_000)
835 } else {
836 format_number(tokens)
837 }
838}
839
/// Append `line` plus a trailing newline to `out`.
fn pushln(out: &mut String, line: impl AsRef<str>) {
    // fmt::Write into a String is infallible; the expect documents that invariant.
    writeln!(out, "{}", line.as_ref()).expect("writing to String should not fail");
}
843
/// Append a single empty line to `out`.
fn blank(out: &mut String) {
    pushln(out, "");
}
847
848#[cfg(test)]
851mod tests {
852 use super::*;
853 use serde_json::Value;
854 use serde_json::json;
855 use tempfile::NamedTempFile;
856
    // FromStr arms must appear in sorted model-ID order, with `-latest`
    // aliases and non-tool-calling models excluded from the output entirely.
    #[test]
    fn generate_sorts_and_filters_models() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "b-model": {
                    "id": "b-model",
                    "name": "B Model",
                    "tool_call": true,
                    "limit": {"context": 2000, "output": 0}
                },
                "a-model": {
                    "id": "a-model",
                    "name": "A Model",
                    "tool_call": true,
                    "limit": {"context": 1000, "output": 0}
                },
                "alpha-latest": {
                    "id": "alpha-latest",
                    "name": "Alias",
                    "tool_call": true,
                    "limit": {"context": 500, "output": 0}
                },
                "no-tools": {
                    "id": "no-tools",
                    "name": "No Tools",
                    "tool_call": false,
                    "limit": {"context": 500, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        let a_model = "\"a-model\" => Ok(Self::AModel),";
        let b_model = "\"b-model\" => Ok(Self::BModel),";
        let a_pos = source.find(a_model).expect("a-model parse arm");
        let b_pos = source.find(b_model).expect("b-model parse arm");
        assert!(a_pos < b_pos);
        assert!(!source.contains("AlphaLatest"));
        assert!(!source.contains("NoTools"));
    }
904
    // Smoke test: the major generated sections are all present.
    #[test]
    fn generate_contains_core_sections() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("pub enum LlmModel {"));
        assert!(source.contains("impl std::str::FromStr for LlmModel {"));
        assert!(source.contains("impl std::fmt::Display for LlmModel {"));
        assert!(source.contains("pub fn required_env_var(&self) -> Option<&'static str> {"));
    }
913
    // Dynamic providers get pass-through FromStr arms and a grouped None arm.
    #[test]
    fn generate_contains_dynamic_provider_arms() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("\"ollama\" => Ok(Self::Ollama(model_str.to_string())),"));
        assert!(source.contains("\"llamacpp\" => Ok(Self::LlamaCpp(model_str.to_string())),"));
        assert!(source.contains("Self::Ollama(_) | Self::LlamaCpp(_) => None,"));
    }
922
    // Codex is a full catalog provider (own enum + parse/delegation arms),
    // despite sourcing its models from the openai entry.
    #[test]
    fn generate_codex_is_catalog_provider() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("pub enum CodexModel {"));
        assert!(source.contains("\"codex\" => model_str.parse::<CodexModel>().map(Self::Codex),"));
        assert!(source.contains("Self::Codex(m) => Some(m.context_window()),"));
    }
931
    // Codex must advertise its OAuth provider ID via oauth_provider_id().
    #[test]
    fn generate_oauth_provider_id_for_codex() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("Self::Codex(_) => Some(\"codex\"),"));
        assert!(source.contains("pub fn oauth_provider_id(&self) -> Option<&'static str>"));
    }
940
    // LlmModel methods must delegate to the per-provider enum impls.
    #[test]
    fn generate_delegates_to_provider_impls() {
        let source = generate_from_value(&minimal_models_dev_json());
        assert!(source.contains("Self::Anthropic(m) => Cow::Borrowed(m.model_id()),"));
        assert!(source.contains("Self::Anthropic(m) => Some(m.context_window()),"));
        assert!(source.contains("\"anthropic\" => model_str.parse::<AnthropicModel>().map(Self::Anthropic),"));
    }
950
    // Context-window literals beyond four digits get `_` separators.
    #[test]
    fn generate_formats_large_numbers_with_separators() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "big-model": {
                    "id": "big-model",
                    "name": "Big Model",
                    "tool_call": true,
                    "limit": {"context": 200_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        assert!(source.contains("200_000"));
        assert!(!source.contains("200000"));
    }
973
    // Variants sharing an identical RHS must be merged into `A | B => rhs` arms.
    #[test]
    fn generate_groups_identical_match_arms() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "model-a": {
                    "id": "model-a",
                    "name": "Same Name",
                    "tool_call": true,
                    "limit": {"context": 100_000, "output": 0}
                },
                "model-b": {
                    "id": "model-b",
                    "name": "Same Name",
                    "tool_call": true,
                    "limit": {"context": 100_000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        assert!(source.contains("Self::ModelA | Self::ModelB => 100_000,"));
        assert!(source.contains("Self::ModelA | Self::ModelB => \"Same Name\","));
    }
1003
    // Separator characters are dropped and the following character upper-cased.
    #[test]
    fn test_model_id_to_variant() {
        assert_eq!(model_id_to_variant("claude-sonnet-4-5-20250929"), "ClaudeSonnet4520250929");
        assert_eq!(model_id_to_variant("gemini-2.5-flash"), "Gemini25Flash");
        assert_eq!(model_id_to_variant("deepseek-chat"), "DeepseekChat");
        assert_eq!(model_id_to_variant("glm-4.5"), "Glm45");
    }
1011
    // `/` and `:` (openrouter-style IDs) are treated like `-` and `.`.
    #[test]
    fn test_model_id_to_variant_with_slash_and_colon() {
        assert_eq!(model_id_to_variant("anthropic/claude-opus-4.6"), "AnthropicClaudeOpus46");
        assert_eq!(model_id_to_variant("openai/gpt-5.1-codex-max"), "OpenaiGpt51CodexMax");
        assert_eq!(model_id_to_variant("deepseek/deepseek-r1:free"), "DeepseekDeepseekR1Free");
    }
1018
    // Only IDs ending in `-latest` count as rolling aliases.
    #[test]
    fn test_is_alias() {
        assert!(is_alias("claude-sonnet-4-5-latest"));
        assert!(is_alias("claude-3-7-sonnet-latest"));
        assert!(!is_alias("claude-sonnet-4-5-20250929"));
    }
1025
    // Reasoning models get the provider's level list; non-reasoning models get
    // `&[]`; dynamic providers always report empty levels.
    #[test]
    fn generate_contains_reasoning_levels_and_supports_reasoning() {
        let mut data = minimal_models_dev_json();
        let root = data.as_object_mut().expect("root object");
        let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).expect("anthropic provider");

        anthropic.insert(
            "models".to_string(),
            json!({
                "thinker": {
                    "id": "thinker",
                    "name": "Thinker",
                    "tool_call": true,
                    "reasoning": true,
                    "limit": {"context": 1000, "output": 0}
                },
                "fast": {
                    "id": "fast",
                    "name": "Fast",
                    "tool_call": true,
                    "reasoning": false,
                    "limit": {"context": 1000, "output": 0}
                }
            }),
        );

        let source = generate_from_value(&data);
        assert!(source.contains("pub fn reasoning_levels(self) -> &'static [ReasoningEffort] {"));
        assert!(
            source
                .contains("Self::Thinker => &[ReasoningEffort::Low, ReasoningEffort::Medium, ReasoningEffort::High],")
        );
        assert!(source.contains("Self::Fast => &[],"));
        assert!(source.contains("pub fn supports_reasoning(self) -> bool {"));
        assert!(source.contains("!self.reasoning_levels().is_empty()"));
        assert!(source.contains("pub fn reasoning_levels(&self) -> &'static [ReasoningEffort] {"));
        assert!(source.contains("Self::Ollama(_) | Self::LlamaCpp(_) => &[],"));
    }
1070
1071 #[test]
1072 fn generate_codex_gets_four_reasoning_levels() {
1073 let mut data = minimal_models_dev_json();
1074 let root = data.as_object_mut().expect("root object");
1075 let openai = root.get_mut("openai").and_then(Value::as_object_mut).expect("openai provider");
1076
1077 openai.insert(
1078 "models".to_string(),
1079 json!({
1080 "gpt-5.4-codex": {
1081 "id": "gpt-5.4-codex",
1082 "name": "GPT-5.4 Codex",
1083 "tool_call": true,
1084 "reasoning": true,
1085 "limit": {"context": 200_000, "output": 0}
1086 }
1087 }),
1088 );
1089
1090 let source = generate_from_value(&data);
1091 assert!(
1092 source.contains(
1093 "ReasoningEffort::Low, ReasoningEffort::Medium, ReasoningEffort::High, ReasoningEffort::Xhigh"
1094 ),
1095 "Codex reasoning model should have 4 levels including Xhigh"
1096 );
1097 }
1098
1099 fn generate_from_value(data: &Value) -> String {
1100 let tmp = NamedTempFile::new().expect("temp file");
1101 let json = serde_json::to_string(data).expect("serialize fixture");
1102 std::fs::write(tmp.path(), json).expect("write fixture");
1103 generate(tmp.path()).expect("codegen succeeds").rust_source
1104 }
1105
1106 fn minimal_models_dev_json() -> Value {
1107 let mut root = serde_json::Map::new();
1108 for cfg in PROVIDERS {
1109 let json_key = cfg.json_key();
1110 root.entry(json_key.to_string()).or_insert_with(|| {
1111 json!({
1112 "id": json_key,
1113 "name": json_key,
1114 "env": [],
1115 "models": {}
1116 })
1117 });
1118 for &extra in cfg.extra_source_ids {
1119 root.entry(extra.to_string()).or_insert_with(|| {
1120 json!({
1121 "id": extra,
1122 "name": extra,
1123 "env": [],
1124 "models": {}
1125 })
1126 });
1127 }
1128 }
1129 Value::Object(root)
1130 }
1131
1132 #[test]
1133 fn extra_source_ids_merges_models_into_provider() {
1134 let mut data = minimal_models_dev_json();
1135 let root = data.as_object_mut().unwrap();
1136
1137 let extra = root.get_mut("zai-coding-plan").unwrap().as_object_mut().unwrap();
1139 extra.insert(
1140 "models".to_string(),
1141 json!({
1142 "extra-model": {
1143 "id": "extra-model",
1144 "name": "Extra Model",
1145 "tool_call": true,
1146 "limit": {"context": 4000, "output": 0}
1147 }
1148 }),
1149 );
1150
1151 let source = generate_from_value(&data);
1152 assert!(source.contains("\"extra-model\" => Ok(Self::ExtraModel),"));
1154 }
1155
1156 #[test]
1157 fn extra_source_ids_does_not_duplicate_existing_models() {
1158 let mut data = minimal_models_dev_json();
1159 let root = data.as_object_mut().unwrap();
1160
1161 let zai = root.get_mut("zai").unwrap().as_object_mut().unwrap();
1163 zai.insert(
1164 "models".to_string(),
1165 json!({
1166 "shared-model": {
1167 "id": "shared-model",
1168 "name": "Shared Model",
1169 "tool_call": true,
1170 "limit": {"context": 1000, "output": 0}
1171 }
1172 }),
1173 );
1174 let extra = root.get_mut("zai-coding-plan").unwrap().as_object_mut().unwrap();
1175 extra.insert(
1176 "models".to_string(),
1177 json!({
1178 "shared-model": {
1179 "id": "shared-model",
1180 "name": "Shared Model Duplicate",
1181 "tool_call": true,
1182 "limit": {"context": 2000, "output": 0}
1183 }
1184 }),
1185 );
1186
1187 let source = generate_from_value(&data);
1188 let from_str_matches = source.matches("\"shared-model\" => Ok(Self::SharedModel),").count();
1189 assert_eq!(from_str_matches, 1);
1190 }
1191
1192 #[test]
1193 fn generate_emits_provider_docs() {
1194 let mut data = minimal_models_dev_json();
1195 let root = data.as_object_mut().unwrap();
1196 let anthropic = root.get_mut("anthropic").and_then(Value::as_object_mut).unwrap();
1197
1198 anthropic.insert(
1199 "models".to_string(),
1200 json!({
1201 "claude-test": {
1202 "id": "claude-test",
1203 "name": "Claude Test",
1204 "tool_call": true,
1205 "reasoning": true,
1206 "limit": {"context": 200_000, "output": 0},
1207 "modalities": {"input": ["text", "image"]}
1208 }
1209 }),
1210 );
1211
1212 let tmp = NamedTempFile::new().unwrap();
1213 std::fs::write(tmp.path(), serde_json::to_string(&data).unwrap()).unwrap();
1214 let output = generate(tmp.path()).unwrap();
1215
1216 let anthropic_doc = &output.provider_docs["anthropic"];
1217 assert!(anthropic_doc.contains("`Anthropic` LLM provider."));
1218 assert!(anthropic_doc.contains("`ANTHROPIC_API_KEY`"));
1219 assert!(anthropic_doc.contains("| `claude-test` | `Claude Test` | `200k` | yes | yes | |"));
1220
1221 let ollama_doc = &output.provider_docs["ollama"];
1223 assert!(ollama_doc.contains("`Ollama` LLM provider."));
1224 assert!(ollama_doc.contains("any model name at runtime"));
1225 }
1226
1227 #[test]
1228 fn format_context_window_formats_correctly() {
1229 assert_eq!(format_context_window(1_000_000), "1M");
1230 assert_eq!(format_context_window(200_000), "200k");
1231 assert_eq!(format_context_window(8_000), "8k");
1232 assert_eq!(format_context_window(0), "unknown");
1233 }
1234
1235 #[test]
1236 fn level_str_to_variant_covers_all_reasoning_efforts() {
1237 for effort in utils::ReasoningEffort::all() {
1238 let _ = level_str_to_variant(effort.as_str());
1240 }
1241 }
1242}