1use std::collections::HashMap;
20
21use serde::Deserialize;
22use tracing::debug;
23
24use crate::provider::CostRates;
25
/// Default model catalog, embedded into the binary at compile time and parsed
/// by `ModelRegistry::with_defaults`.
const DEFAULTS_TOML: &str = include_str!("defaults/models.toml");
28
29#[derive(Debug, Deserialize)]
32struct CatalogFile {
33 #[serde(flatten)]
34 providers: HashMap<String, ProviderEntry>,
35}
36
37#[derive(Debug, Deserialize)]
38struct ProviderEntry {
39 #[serde(default)]
40 default_model: Option<String>,
41 #[serde(default)]
42 api_key_env: Option<String>,
43 #[serde(default)]
44 cache_read_multiplier: Option<f64>,
45 #[serde(default)]
46 cache_creation_multiplier: Option<f64>,
47 #[serde(default)]
48 models: HashMap<String, ModelEntry>,
49}
50
51#[derive(Debug, Deserialize)]
52struct ModelEntry {
53 input: f64,
54 output: f64,
55 #[serde(default)]
56 context_window: Option<u64>,
57 #[serde(default = "default_true")]
58 supports_tool_use: bool,
59 #[serde(default)]
60 supports_vision: bool,
61 #[serde(default)]
62 cache_read_multiplier: Option<f64>,
63 #[serde(default)]
64 cache_creation_multiplier: Option<f64>,
65}
66
/// Serde default used for `ModelEntry::supports_tool_use`: a catalog entry
/// that omits the key is treated as supporting tool use.
fn default_true() -> bool {
    true
}
70
/// Resolved catalog entry for a single model of a single provider.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model identifier (the TOML table name when loaded from a catalog).
    pub id: String,
    /// Name of the provider this model belongs to.
    pub provider: String,
    /// Per-million-token prices and cache multipliers. When loaded from TOML,
    /// multipliers the model omits are already filled from the provider table.
    pub pricing: CostRates,
    /// Context window size from the catalog, if given (presumably in tokens — confirm).
    pub context_window: Option<u64>,
    /// Whether the catalog marks the model as supporting tool use.
    pub supports_tool_use: bool,
    /// Whether the catalog marks the model as vision-capable.
    pub supports_vision: bool,
}
89
/// Provider-level metadata from the catalog.
#[derive(Debug, Clone)]
pub struct ProviderInfo {
    /// Provider name (the top-level TOML table name).
    pub name: String,
    /// The provider's `default_model` setting, if any.
    pub default_model: Option<String>,
    /// Name of the environment variable holding the provider's API key, if any.
    pub api_key_env: Option<String>,
    /// Default cache-read multiplier for models that do not set their own.
    pub cache_read_multiplier: Option<f64>,
    /// Default cache-creation multiplier for models that do not set their own.
    pub cache_creation_multiplier: Option<f64>,
}
104
/// Composite `"<provider>::<model>"` key used to index the model map.
type ModelKey = String;

/// Builds the `"<provider>::<model>"` lookup key for a provider/model pair.
fn make_key(provider: &str, model: &str) -> ModelKey {
    const SEPARATOR: &str = "::";
    let mut key = String::with_capacity(provider.len() + SEPARATOR.len() + model.len());
    key.push_str(provider);
    key.push_str(SEPARATOR);
    key.push_str(model);
    key
}
113
/// In-memory catalog of providers and their models.
///
/// Models are stored under a `"<provider>::<model>"` key; provider metadata is
/// stored separately by provider name. A model can exist without provider
/// metadata (e.g. after `register`), and vice versa (a provider with no models).
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// Model entries keyed by `make_key(provider, model_id)`.
    models: HashMap<ModelKey, ModelInfo>,
    /// Provider metadata keyed by provider name.
    providers: HashMap<String, ProviderInfo>,
}
126
127impl ModelRegistry {
128 pub fn new() -> Self {
130 Self {
131 models: HashMap::new(),
132 providers: HashMap::new(),
133 }
134 }
135
136 pub fn with_defaults() -> Self {
138 Self::from_toml(DEFAULTS_TOML).expect("embedded models.toml must be valid")
139 }
140
141 pub fn from_toml(toml_str: &str) -> Result<Self, String> {
143 let file: CatalogFile =
144 toml::from_str(toml_str).map_err(|e| format!("models TOML parse error: {e}"))?;
145
146 let mut models = HashMap::new();
147 let mut providers = HashMap::new();
148
149 for (prov_name, pe) in &file.providers {
150 providers.insert(
151 prov_name.clone(),
152 ProviderInfo {
153 name: prov_name.clone(),
154 default_model: pe.default_model.clone(),
155 api_key_env: pe.api_key_env.clone(),
156 cache_read_multiplier: pe.cache_read_multiplier,
157 cache_creation_multiplier: pe.cache_creation_multiplier,
158 },
159 );
160
161 for (model_id, me) in &pe.models {
162 let info = ModelInfo {
163 id: model_id.clone(),
164 provider: prov_name.clone(),
165 pricing: CostRates {
166 input_per_million: me.input,
167 output_per_million: me.output,
168 cache_read_multiplier: me
169 .cache_read_multiplier
170 .or(pe.cache_read_multiplier),
171 cache_creation_multiplier: me
172 .cache_creation_multiplier
173 .or(pe.cache_creation_multiplier),
174 },
175 context_window: me.context_window,
176 supports_tool_use: me.supports_tool_use,
177 supports_vision: me.supports_vision,
178 };
179 models.insert(make_key(prov_name, model_id), info);
180 }
181 }
182
183 Ok(Self { models, providers })
184 }
185
186 pub fn merge(&mut self, other: Self) {
188 for (key, info) in other.models {
189 self.models.insert(key, info);
190 }
191 for (key, info) in other.providers {
192 if let Some(existing) = self.providers.get_mut(&key) {
193 if info.default_model.is_some() {
194 existing.default_model = info.default_model;
195 }
196 if info.api_key_env.is_some() {
197 existing.api_key_env = info.api_key_env;
198 }
199 if info.cache_read_multiplier.is_some() {
200 existing.cache_read_multiplier = info.cache_read_multiplier;
201 }
202 if info.cache_creation_multiplier.is_some() {
203 existing.cache_creation_multiplier = info.cache_creation_multiplier;
204 }
205 } else {
206 self.providers.insert(key, info);
207 }
208 }
209 }
210
211 pub fn register(&mut self, provider: &str, model_id: &str, info: ModelInfo) {
215 self.models.insert(make_key(provider, model_id), info);
216 }
217
218 pub fn get(&self, provider: &str, model: &str) -> Option<&ModelInfo> {
222 self.models.get(&make_key(provider, model))
223 }
224
225 pub fn get_fuzzy(&self, provider: &str, model: &str) -> Option<&ModelInfo> {
228 if let Some(info) = self.get(provider, model) {
229 return Some(info);
230 }
231
232 let prefix = format!("{provider}::");
233
234 let mut best: Option<(&str, &ModelInfo)> = None;
235 for (key, info) in &self.models {
236 if let Some(registered) = key.strip_prefix(&prefix) {
237 if model.contains(registered) || registered.contains(model) {
238 let dominated = best
239 .map(|(prev, _)| registered.len() > prev.len())
240 .unwrap_or(true);
241 if dominated {
242 best = Some((registered, info));
243 }
244 }
245 }
246 }
247 if let Some((matched, info)) = best {
248 debug!(provider, model, matched, "fuzzy model match");
249 return Some(info);
250 }
251
252 None
253 }
254
255 pub fn get_pricing(&self, provider: &str, model: &str) -> Option<CostRates> {
258 if let Some(info) = self.get_fuzzy(provider, model) {
259 return Some(info.pricing.clone());
260 }
261
262 self.providers.get(provider).and_then(|p| {
263 if p.cache_read_multiplier.is_some() || p.cache_creation_multiplier.is_some() {
264 Some(CostRates {
265 input_per_million: 0.0,
266 output_per_million: 0.0,
267 cache_read_multiplier: p.cache_read_multiplier,
268 cache_creation_multiplier: p.cache_creation_multiplier,
269 })
270 } else {
271 None
272 }
273 })
274 }
275
276 pub fn provider(&self, name: &str) -> Option<&ProviderInfo> {
280 self.providers.get(name)
281 }
282
283 pub fn provider_names(&self) -> Vec<&str> {
285 let mut names: Vec<&str> = self.providers.keys().map(|s| s.as_str()).collect();
286 names.sort();
287 names
288 }
289
290 pub fn default_model(&self, provider: &str) -> Option<&str> {
292 self.providers
293 .get(provider)
294 .and_then(|p| p.default_model.as_deref())
295 }
296
297 pub fn api_key_env(&self, provider: &str) -> Option<&str> {
299 self.providers
300 .get(provider)
301 .and_then(|p| p.api_key_env.as_deref())
302 }
303
304 pub fn models_for_provider(&self, provider: &str) -> Vec<&str> {
306 let prefix = format!("{provider}::");
307 let mut out: Vec<&str> = self
308 .models
309 .iter()
310 .filter_map(|(key, info)| {
311 if key.starts_with(&prefix) {
312 Some(info.id.as_str())
313 } else {
314 None
315 }
316 })
317 .collect();
318 out.sort();
319 out
320 }
321
322 pub fn models_by_provider(&self) -> HashMap<String, Vec<String>> {
324 let mut result: HashMap<String, Vec<String>> = HashMap::new();
325 for prov in self.providers.keys() {
326 result.insert(
327 prov.clone(),
328 self.models_for_provider(prov)
329 .into_iter()
330 .map(String::from)
331 .collect(),
332 );
333 }
334 result
335 }
336
337 pub fn len(&self) -> usize {
339 self.models.len()
340 }
341
342 pub fn is_empty(&self) -> bool {
344 self.models.is_empty()
345 }
346}
347
348impl Default for ModelRegistry {
349 fn default() -> Self {
350 Self::with_defaults()
351 }
352}
353
/// Alias for [`ModelRegistry`]. It is presumably kept for callers that predate
/// the rename — confirm there are no users before removing it.
pub type PricingRegistry = ModelRegistry;
356
// Unit tests for the registry. Most of them load the embedded catalog, so they
// double as a smoke test of `defaults/models.toml`. Float comparisons use a
// 1e-9 tolerance instead of `==`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn defaults_load_successfully() {
        let reg = ModelRegistry::with_defaults();
        assert!(!reg.is_empty());
    }

    #[test]
    fn exact_match() {
        let reg = ModelRegistry::with_defaults();
        let info = reg.get("anthropic", "claude-sonnet-4-5").unwrap();
        assert!((info.pricing.input_per_million - 3.0).abs() < 1e-9);
        assert!((info.pricing.output_per_million - 15.0).abs() < 1e-9);
        assert!((info.pricing.cache_read_multiplier.unwrap() - 0.1).abs() < 1e-9);
        assert!((info.pricing.cache_creation_multiplier.unwrap() - 1.25).abs() < 1e-9);
        assert_eq!(info.context_window, Some(200_000));
        assert!(info.supports_tool_use);
        assert!(info.supports_vision);
    }

    // A dated snapshot id should resolve to its base catalog entry.
    #[test]
    fn fuzzy_match_longer_model_id() {
        let reg = ModelRegistry::with_defaults();
        let info = reg
            .get_fuzzy("anthropic", "claude-sonnet-4-5-20250514")
            .unwrap();
        assert!((info.pricing.input_per_million - 3.0).abs() < 1e-9);
    }

    // Both "claude-sonnet" and "claude-sonnet-4-5" are substrings of the query;
    // the longer (more specific) entry must win.
    #[test]
    fn fuzzy_match_picks_most_specific() {
        let mut reg = ModelRegistry::new();
        let short_key = make_key("test", "claude-sonnet");
        reg.models.insert(
            short_key,
            ModelInfo {
                id: "claude-sonnet".into(),
                provider: "test".into(),
                pricing: CostRates {
                    input_per_million: 1.0,
                    output_per_million: 5.0,
                    cache_read_multiplier: None,
                    cache_creation_multiplier: None,
                },
                context_window: None,
                supports_tool_use: true,
                supports_vision: false,
            },
        );
        let long_key = make_key("test", "claude-sonnet-4-5");
        reg.models.insert(
            long_key,
            ModelInfo {
                id: "claude-sonnet-4-5".into(),
                provider: "test".into(),
                pricing: CostRates {
                    input_per_million: 3.0,
                    output_per_million: 15.0,
                    cache_read_multiplier: None,
                    cache_creation_multiplier: None,
                },
                context_window: None,
                supports_tool_use: true,
                supports_vision: false,
            },
        );
        let info = reg.get_fuzzy("test", "claude-sonnet-4-5-20250514").unwrap();
        assert!((info.pricing.input_per_million - 3.0).abs() < 1e-9);
    }

    // An unknown model still yields the provider's cache multipliers via get_pricing.
    #[test]
    fn provider_default_cache_multipliers() {
        let reg = ModelRegistry::with_defaults();
        let pricing = reg.get_pricing("anthropic", "claude-unknown-99").unwrap();
        assert!((pricing.cache_read_multiplier.unwrap() - 0.1).abs() < 1e-9);
    }

    // A model entry from a later registry replaces the same-keyed base entry.
    #[test]
    fn merge_overrides() {
        let mut base = ModelRegistry::with_defaults();
        let overrides = ModelRegistry::from_toml(
            r#"
[anthropic.models.claude-sonnet-4-5]
input = 99.0
output = 99.0
"#,
        )
        .unwrap();
        base.merge(overrides);
        let info = base.get("anthropic", "claude-sonnet-4-5").unwrap();
        assert!((info.pricing.input_per_million - 99.0).abs() < 1e-9);
    }

    #[test]
    fn openai_cache_rates() {
        let reg = ModelRegistry::with_defaults();
        let info = reg.get("openai", "gpt-4o").unwrap();
        assert!((info.pricing.cache_read_multiplier.unwrap() - 0.1).abs() < 1e-9);
        assert!((info.pricing.cache_creation_multiplier.unwrap() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn gemini_cache_rates() {
        let reg = ModelRegistry::with_defaults();
        let info = reg.get_fuzzy("gemini", "gemini-2-5-flash").unwrap();
        assert!((info.pricing.cache_read_multiplier.unwrap() - 0.1).abs() < 1e-9);
    }

    // A model without its own multipliers inherits the provider's read multiplier;
    // the creation multiplier stays None because the provider does not set one.
    #[test]
    fn from_toml_custom() {
        let toml = r#"
[custom]
cache_read_multiplier = 0.3

[custom.models.my-model]
input = 5.0
output = 20.0
"#;
        let reg = ModelRegistry::from_toml(toml).unwrap();
        let info = reg.get("custom", "my-model").unwrap();
        assert!((info.pricing.input_per_million - 5.0).abs() < 1e-9);
        assert!((info.pricing.cache_read_multiplier.unwrap() - 0.3).abs() < 1e-9);
        assert!(info.pricing.cache_creation_multiplier.is_none());
    }

    // A model-level multiplier overrides the provider's; unset ones still inherit.
    #[test]
    fn per_model_cache_override() {
        let toml = r#"
[prov]
cache_read_multiplier = 0.1
cache_creation_multiplier = 1.25

[prov.models.special]
input = 10.0
output = 50.0
cache_read_multiplier = 0.05
"#;
        let reg = ModelRegistry::from_toml(toml).unwrap();
        let info = reg.get("prov", "special").unwrap();
        assert!((info.pricing.cache_read_multiplier.unwrap() - 0.05).abs() < 1e-9);
        assert!((info.pricing.cache_creation_multiplier.unwrap() - 1.25).abs() < 1e-9);
    }

    // A bare provider table parses and simply has no models.
    #[test]
    fn empty_provider_no_panic() {
        let toml = r#"
[empty]
"#;
        let reg = ModelRegistry::from_toml(toml).unwrap();
        assert!(reg.get("empty", "anything").is_none());
        assert!(reg.get_fuzzy("empty", "anything").is_none());
    }

    #[test]
    fn default_model_per_provider() {
        let reg = ModelRegistry::with_defaults();
        assert_eq!(reg.default_model("anthropic"), Some("claude-haiku-4-5"));
        assert_eq!(reg.default_model("openai"), Some("gpt-4o"));
        assert_eq!(reg.default_model("gemini"), Some("gemini-2.5-pro"));
        assert_eq!(reg.default_model("groq"), Some("llama-3.3-70b-versatile"));
        assert_eq!(reg.default_model("deepseek"), Some("deepseek-chat"));
        assert_eq!(reg.default_model("ollama"), Some("qwen3.5:9b"));
    }

    #[test]
    fn api_key_env_per_provider() {
        let reg = ModelRegistry::with_defaults();
        assert_eq!(reg.api_key_env("anthropic"), Some("ANTHROPIC_API_KEY"));
        assert_eq!(reg.api_key_env("openai"), Some("OPENAI_API_KEY"));
        assert_eq!(reg.api_key_env("ollama"), None);
    }

    #[test]
    fn models_for_provider_lists_all() {
        let reg = ModelRegistry::with_defaults();
        let anthropic = reg.models_for_provider("anthropic");
        assert!(anthropic.contains(&"claude-haiku-4-5"));
        assert!(anthropic.contains(&"claude-sonnet-4-6"));
        assert!(anthropic.contains(&"claude-opus-4-6"));
        assert!(anthropic.len() >= 4);
    }

    // Providers with metadata but no catalog models (ollama) map to an empty list.
    #[test]
    fn models_by_provider_for_settings_api() {
        let reg = ModelRegistry::with_defaults();
        let map = reg.models_by_provider();
        assert!(map.contains_key("anthropic"));
        assert!(map.contains_key("openai"));
        assert!(map.contains_key("ollama"));
        assert!(map["ollama"].is_empty());
    }

    #[test]
    fn provider_names_returns_all() {
        let reg = ModelRegistry::with_defaults();
        let names = reg.provider_names();
        assert!(names.contains(&"anthropic"));
        assert!(names.contains(&"openai"));
        assert!(names.contains(&"gemini"));
        assert!(names.contains(&"groq"));
        assert!(names.contains(&"deepseek"));
        assert!(names.contains(&"openrouter"));
        assert!(names.contains(&"ollama"));
    }

    #[test]
    fn model_capabilities() {
        let reg = ModelRegistry::with_defaults();
        let haiku = reg.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!(haiku.supports_tool_use);
        assert!(haiku.supports_vision);

        let gpt41 = reg.get("openai", "gpt-4.1").unwrap();
        assert!(gpt41.supports_tool_use);
        assert!(!gpt41.supports_vision);
    }

    // Runtime registration (e.g. locally discovered models) is visible to get().
    #[test]
    fn register_makes_model_visible_via_get() {
        let mut reg = ModelRegistry::new();
        reg.register(
            "ollama",
            "qwen3.5:9b",
            ModelInfo {
                id: "qwen3.5:9b".into(),
                provider: "ollama".into(),
                pricing: CostRates {
                    input_per_million: 0.0,
                    output_per_million: 0.0,
                    cache_read_multiplier: None,
                    cache_creation_multiplier: None,
                },
                context_window: Some(262_144),
                supports_tool_use: true,
                supports_vision: true,
            },
        );
        let info = reg.get("ollama", "qwen3.5:9b").unwrap();
        assert_eq!(info.context_window, Some(262_144));
        assert!(info.supports_vision);
    }

    #[test]
    fn register_appears_in_models_for_provider() {
        let mut reg = ModelRegistry::with_defaults();
        assert!(reg.models_for_provider("ollama").is_empty());

        reg.register(
            "ollama",
            "llama3:8b",
            ModelInfo {
                id: "llama3:8b".into(),
                provider: "ollama".into(),
                pricing: CostRates {
                    input_per_million: 0.0,
                    output_per_million: 0.0,
                    cache_read_multiplier: None,
                    cache_creation_multiplier: None,
                },
                context_window: Some(131_072),
                supports_tool_use: true,
                supports_vision: false,
            },
        );
        let models = reg.models_for_provider("ollama");
        assert_eq!(models, vec!["llama3:8b"]);
    }

    // Registering under an existing key replaces the catalog entry.
    #[test]
    fn register_overrides_existing() {
        let mut reg = ModelRegistry::with_defaults();
        let original = reg.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!(original.pricing.input_per_million > 0.0);

        reg.register(
            "anthropic",
            "claude-haiku-4-5",
            ModelInfo {
                id: "claude-haiku-4-5".into(),
                provider: "anthropic".into(),
                pricing: CostRates {
                    input_per_million: 99.0,
                    output_per_million: 99.0,
                    cache_read_multiplier: None,
                    cache_creation_multiplier: None,
                },
                context_window: Some(200_000),
                supports_tool_use: true,
                supports_vision: true,
            },
        );
        let updated = reg.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!((updated.pricing.input_per_million - 99.0).abs() < 1e-9);
    }
}