1use std::collections::HashMap;
20
21use serde::Deserialize;
22use tracing::debug;
23
24use crate::provider::CostRates;
25
/// Built-in model catalog, embedded at compile time from `defaults/models.toml`
/// (resolved relative to this source file) and parsed by
/// [`ModelRegistry::with_defaults`].
const DEFAULTS_TOML: &str = include_str!("defaults/models.toml");
28
/// Top-level shape of a models TOML document.
///
/// Every top-level table is a provider section keyed by provider name, e.g.
/// `[anthropic]`, with its models under `[anthropic.models.<model-id>]`.
#[derive(Debug, Deserialize)]
struct CatalogFile {
    // `flatten` gathers all top-level tables into this map, so the document
    // needs no wrapper key around the provider sections.
    #[serde(flatten)]
    providers: HashMap<String, ProviderEntry>,
}
36
/// One provider section (`[<provider>]`) exactly as written in the TOML.
/// Every field is optional.
#[derive(Debug, Deserialize)]
struct ProviderEntry {
    /// Value later returned by `ModelRegistry::default_model` for this provider.
    #[serde(default)]
    default_model: Option<String>,
    /// Environment-variable name later returned by `ModelRegistry::api_key_env`
    /// (presumably the variable holding the provider's API key).
    #[serde(default)]
    api_key_env: Option<String>,
    /// Provider-wide cache-read multiplier, inherited by models that do not
    /// set their own.
    #[serde(default)]
    cache_read_multiplier: Option<f64>,
    /// Provider-wide cache-creation multiplier, inherited by models that do
    /// not set their own.
    #[serde(default)]
    cache_creation_multiplier: Option<f64>,
    /// Models offered by this provider, keyed by model id; may be empty.
    #[serde(default)]
    models: HashMap<String, ModelEntry>,
}
50
/// One model entry (`[<provider>.models.<model-id>]`) exactly as written in
/// the TOML.
#[derive(Debug, Deserialize)]
struct ModelEntry {
    /// Input price; becomes `CostRates::input_per_million`. Required.
    input: f64,
    /// Output price; becomes `CostRates::output_per_million`. Required.
    output: f64,
    /// Context window, if the catalog states one.
    #[serde(default)]
    context_window: Option<u64>,
    /// Tool-use support; `true` when the key is omitted.
    #[serde(default = "default_true")]
    supports_tool_use: bool,
    /// Vision support; `false` when the key is omitted.
    #[serde(default)]
    supports_vision: bool,
    /// Per-model cache-read multiplier; takes precedence over the provider's.
    #[serde(default)]
    cache_read_multiplier: Option<f64>,
    /// Per-model cache-creation multiplier; takes precedence over the provider's.
    #[serde(default)]
    cache_creation_multiplier: Option<f64>,
}

/// Serde default for `ModelEntry::supports_tool_use`: tool use is assumed
/// unless a catalog entry opts out.
fn default_true() -> bool {
    true
}
70
/// Resolved catalog entry for one model of one provider.
///
/// When built by [`ModelRegistry::from_toml`], the cache multipliers in
/// `pricing` are already resolved: the model's own value if set, otherwise
/// the provider-level value.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model id as written in the catalog (the key under `[<provider>.models]`).
    pub id: String,
    /// Name of the provider section the model was declared under.
    pub provider: String,
    /// Per-million-token prices plus optional cache multipliers.
    pub pricing: CostRates,
    /// Context window from the catalog, if given (presumably in tokens).
    pub context_window: Option<u64>,
    /// Whether the catalog marks the model as supporting tool use.
    pub supports_tool_use: bool,
    /// Whether the catalog marks the model as vision-capable.
    pub supports_vision: bool,
}
89
/// Provider-level metadata taken from a provider's TOML section.
#[derive(Debug, Clone)]
pub struct ProviderInfo {
    /// Provider name (the TOML table name, e.g. `anthropic`).
    pub name: String,
    /// Model returned by `ModelRegistry::default_model`, if configured.
    pub default_model: Option<String>,
    /// Environment-variable name returned by `ModelRegistry::api_key_env`
    /// (presumably the variable holding the provider's API key).
    pub api_key_env: Option<String>,
    /// Provider-wide cache-read multiplier, inherited by models without their own.
    pub cache_read_multiplier: Option<f64>,
    /// Provider-wide cache-creation multiplier, inherited by models without their own.
    pub cache_creation_multiplier: Option<f64>,
}
104
/// Composite key for the registry's model map: `"<provider>::<model>"`.
type ModelKey = String;

/// Build the [`ModelKey`] for `model` under `provider`.
///
/// The `::` separator is not escaped, so a provider or model name that itself
/// contains `::` can produce the same key as a different pair.
fn make_key(provider: &str, model: &str) -> ModelKey {
    const SEPARATOR: &str = "::";
    let mut key = String::with_capacity(provider.len() + SEPARATOR.len() + model.len());
    key.push_str(provider);
    key.push_str(SEPARATOR);
    key.push_str(model);
    key
}
113
/// In-memory catalog of providers and their models, loaded from TOML.
///
/// Models are stored under a `"<provider>::<model>"` key built by `make_key`.
/// Provider metadata is stored by provider name.
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// Model entries keyed by `make_key(provider, model_id)`.
    models: HashMap<ModelKey, ModelInfo>,
    /// Provider metadata keyed by provider name.
    providers: HashMap<String, ProviderInfo>,
}
126
127impl ModelRegistry {
128 pub fn new() -> Self {
130 Self {
131 models: HashMap::new(),
132 providers: HashMap::new(),
133 }
134 }
135
136 pub fn with_defaults() -> Self {
138 Self::from_toml(DEFAULTS_TOML).expect("embedded models.toml must be valid")
139 }
140
141 pub fn from_toml(toml_str: &str) -> Result<Self, String> {
143 let file: CatalogFile =
144 toml::from_str(toml_str).map_err(|e| format!("models TOML parse error: {e}"))?;
145
146 let mut models = HashMap::new();
147 let mut providers = HashMap::new();
148
149 for (prov_name, pe) in &file.providers {
150 providers.insert(
151 prov_name.clone(),
152 ProviderInfo {
153 name: prov_name.clone(),
154 default_model: pe.default_model.clone(),
155 api_key_env: pe.api_key_env.clone(),
156 cache_read_multiplier: pe.cache_read_multiplier,
157 cache_creation_multiplier: pe.cache_creation_multiplier,
158 },
159 );
160
161 for (model_id, me) in &pe.models {
162 let info = ModelInfo {
163 id: model_id.clone(),
164 provider: prov_name.clone(),
165 pricing: CostRates {
166 input_per_million: me.input,
167 output_per_million: me.output,
168 cache_read_multiplier: me.cache_read_multiplier.or(pe.cache_read_multiplier),
169 cache_creation_multiplier: me
170 .cache_creation_multiplier
171 .or(pe.cache_creation_multiplier),
172 },
173 context_window: me.context_window,
174 supports_tool_use: me.supports_tool_use,
175 supports_vision: me.supports_vision,
176 };
177 models.insert(make_key(prov_name, model_id), info);
178 }
179 }
180
181 Ok(Self { models, providers })
182 }
183
184 pub fn merge(&mut self, other: Self) {
186 for (key, info) in other.models {
187 self.models.insert(key, info);
188 }
189 for (key, info) in other.providers {
190 if let Some(existing) = self.providers.get_mut(&key) {
191 if info.default_model.is_some() {
192 existing.default_model = info.default_model;
193 }
194 if info.api_key_env.is_some() {
195 existing.api_key_env = info.api_key_env;
196 }
197 if info.cache_read_multiplier.is_some() {
198 existing.cache_read_multiplier = info.cache_read_multiplier;
199 }
200 if info.cache_creation_multiplier.is_some() {
201 existing.cache_creation_multiplier = info.cache_creation_multiplier;
202 }
203 } else {
204 self.providers.insert(key, info);
205 }
206 }
207 }
208
209 pub fn get(&self, provider: &str, model: &str) -> Option<&ModelInfo> {
213 self.models.get(&make_key(provider, model))
214 }
215
216 pub fn get_fuzzy(&self, provider: &str, model: &str) -> Option<&ModelInfo> {
219 if let Some(info) = self.get(provider, model) {
220 return Some(info);
221 }
222
223 let prefix = format!("{provider}::");
224
225 let mut best: Option<(&str, &ModelInfo)> = None;
226 for (key, info) in &self.models {
227 if let Some(registered) = key.strip_prefix(&prefix) {
228 if model.contains(registered) || registered.contains(model) {
229 let dominated = best
230 .map(|(prev, _)| registered.len() > prev.len())
231 .unwrap_or(true);
232 if dominated {
233 best = Some((registered, info));
234 }
235 }
236 }
237 }
238 if let Some((matched, info)) = best {
239 debug!(provider, model, matched, "fuzzy model match");
240 return Some(info);
241 }
242
243 None
244 }
245
246 pub fn get_pricing(&self, provider: &str, model: &str) -> Option<CostRates> {
249 if let Some(info) = self.get_fuzzy(provider, model) {
250 return Some(info.pricing.clone());
251 }
252
253 self.providers.get(provider).and_then(|p| {
254 if p.cache_read_multiplier.is_some() || p.cache_creation_multiplier.is_some() {
255 Some(CostRates {
256 input_per_million: 0.0,
257 output_per_million: 0.0,
258 cache_read_multiplier: p.cache_read_multiplier,
259 cache_creation_multiplier: p.cache_creation_multiplier,
260 })
261 } else {
262 None
263 }
264 })
265 }
266
267 pub fn provider(&self, name: &str) -> Option<&ProviderInfo> {
271 self.providers.get(name)
272 }
273
274 pub fn provider_names(&self) -> Vec<&str> {
276 let mut names: Vec<&str> = self.providers.keys().map(|s| s.as_str()).collect();
277 names.sort();
278 names
279 }
280
281 pub fn default_model(&self, provider: &str) -> Option<&str> {
283 self.providers
284 .get(provider)
285 .and_then(|p| p.default_model.as_deref())
286 }
287
288 pub fn api_key_env(&self, provider: &str) -> Option<&str> {
290 self.providers
291 .get(provider)
292 .and_then(|p| p.api_key_env.as_deref())
293 }
294
295 pub fn models_for_provider(&self, provider: &str) -> Vec<&str> {
297 let prefix = format!("{provider}::");
298 let mut out: Vec<&str> = self
299 .models
300 .iter()
301 .filter_map(|(key, info)| {
302 if key.starts_with(&prefix) {
303 Some(info.id.as_str())
304 } else {
305 None
306 }
307 })
308 .collect();
309 out.sort();
310 out
311 }
312
313 pub fn models_by_provider(&self) -> HashMap<String, Vec<String>> {
315 let mut result: HashMap<String, Vec<String>> = HashMap::new();
316 for prov in self.providers.keys() {
317 result.insert(
318 prov.clone(),
319 self.models_for_provider(prov)
320 .into_iter()
321 .map(String::from)
322 .collect(),
323 );
324 }
325 result
326 }
327
328 pub fn len(&self) -> usize {
330 self.models.len()
331 }
332
333 pub fn is_empty(&self) -> bool {
335 self.models.is_empty()
336 }
337}
338
339impl Default for ModelRegistry {
340 fn default() -> Self {
341 Self::with_defaults()
342 }
343}
344
/// Alternate name for [`ModelRegistry`].
///
/// NOTE(review): presumably kept so callers that still refer to
/// `PricingRegistry` keep compiling; confirm before removing.
pub type PricingRegistry = ModelRegistry;
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Float comparison with the tolerance used by every pricing assertion.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-9
    }

    /// Minimal hand-built entry: no cache multipliers, tool use on, no vision.
    fn bare_model(provider: &str, id: &str, input: f64, output: f64) -> ModelInfo {
        ModelInfo {
            id: id.into(),
            provider: provider.into(),
            pricing: CostRates {
                input_per_million: input,
                output_per_million: output,
                cache_read_multiplier: None,
                cache_creation_multiplier: None,
            },
            context_window: None,
            supports_tool_use: true,
            supports_vision: false,
        }
    }

    #[test]
    fn defaults_load_successfully() {
        assert!(!ModelRegistry::with_defaults().is_empty());
    }

    #[test]
    fn exact_match() {
        let reg = ModelRegistry::with_defaults();
        let info = reg.get("anthropic", "claude-sonnet-4-5").unwrap();
        let rates = &info.pricing;
        assert!(approx(rates.input_per_million, 3.0));
        assert!(approx(rates.output_per_million, 15.0));
        assert!(approx(rates.cache_read_multiplier.unwrap(), 0.1));
        assert!(approx(rates.cache_creation_multiplier.unwrap(), 1.25));
        assert_eq!(info.context_window, Some(200_000));
        assert!(info.supports_tool_use);
        assert!(info.supports_vision);
    }

    #[test]
    fn fuzzy_match_longer_model_id() {
        let reg = ModelRegistry::with_defaults();
        let info = reg
            .get_fuzzy("anthropic", "claude-sonnet-4-5-20250514")
            .unwrap();
        assert!(approx(info.pricing.input_per_million, 3.0));
    }

    #[test]
    fn fuzzy_match_picks_most_specific() {
        let mut reg = ModelRegistry::new();
        for (id, input, output) in [("claude-sonnet", 1.0, 5.0), ("claude-sonnet-4-5", 3.0, 15.0)] {
            reg.models
                .insert(make_key("test", id), bare_model("test", id, input, output));
        }
        let info = reg.get_fuzzy("test", "claude-sonnet-4-5-20250514").unwrap();
        assert!(approx(info.pricing.input_per_million, 3.0));
    }

    #[test]
    fn provider_default_cache_multipliers() {
        let pricing = ModelRegistry::with_defaults()
            .get_pricing("anthropic", "claude-unknown-99")
            .unwrap();
        assert!(approx(pricing.cache_read_multiplier.unwrap(), 0.1));
    }

    #[test]
    fn merge_overrides() {
        let overrides_toml = r#"
[anthropic.models.claude-sonnet-4-5]
input = 99.0
output = 99.0
"#;
        let mut base = ModelRegistry::with_defaults();
        base.merge(ModelRegistry::from_toml(overrides_toml).unwrap());
        let info = base.get("anthropic", "claude-sonnet-4-5").unwrap();
        assert!(approx(info.pricing.input_per_million, 99.0));
    }

    #[test]
    fn openai_cache_rates() {
        let reg = ModelRegistry::with_defaults();
        let rates = &reg.get("openai", "gpt-4o").unwrap().pricing;
        assert!(approx(rates.cache_read_multiplier.unwrap(), 0.1));
        assert!(approx(rates.cache_creation_multiplier.unwrap(), 1.0));
    }

    #[test]
    fn gemini_cache_rates() {
        let reg = ModelRegistry::with_defaults();
        let rates = &reg.get_fuzzy("gemini", "gemini-2-5-flash").unwrap().pricing;
        assert!(approx(rates.cache_read_multiplier.unwrap(), 0.1));
    }

    #[test]
    fn from_toml_custom() {
        let toml = r#"
[custom]
cache_read_multiplier = 0.3

[custom.models.my-model]
input = 5.0
output = 20.0
"#;
        let reg = ModelRegistry::from_toml(toml).unwrap();
        let rates = &reg.get("custom", "my-model").unwrap().pricing;
        assert!(approx(rates.input_per_million, 5.0));
        assert!(approx(rates.cache_read_multiplier.unwrap(), 0.3));
        assert!(rates.cache_creation_multiplier.is_none());
    }

    #[test]
    fn per_model_cache_override() {
        let toml = r#"
[prov]
cache_read_multiplier = 0.1
cache_creation_multiplier = 1.25

[prov.models.special]
input = 10.0
output = 50.0
cache_read_multiplier = 0.05
"#;
        let reg = ModelRegistry::from_toml(toml).unwrap();
        let rates = &reg.get("prov", "special").unwrap().pricing;
        assert!(approx(rates.cache_read_multiplier.unwrap(), 0.05));
        assert!(approx(rates.cache_creation_multiplier.unwrap(), 1.25));
    }

    #[test]
    fn empty_provider_no_panic() {
        let toml = r#"
[empty]
"#;
        let reg = ModelRegistry::from_toml(toml).unwrap();
        assert!(reg.get("empty", "anything").is_none());
        assert!(reg.get_fuzzy("empty", "anything").is_none());
    }

    #[test]
    fn default_model_per_provider() {
        let reg = ModelRegistry::with_defaults();
        for (provider, expected) in [
            ("anthropic", "claude-haiku-4-5"),
            ("openai", "gpt-4o"),
            ("gemini", "gemini-2.5-pro"),
            ("groq", "llama-3.3-70b-versatile"),
            ("deepseek", "deepseek-chat"),
            ("ollama", "llama3.3"),
        ] {
            assert_eq!(reg.default_model(provider), Some(expected));
        }
    }

    #[test]
    fn api_key_env_per_provider() {
        let reg = ModelRegistry::with_defaults();
        assert_eq!(reg.api_key_env("anthropic"), Some("ANTHROPIC_API_KEY"));
        assert_eq!(reg.api_key_env("openai"), Some("OPENAI_API_KEY"));
        assert_eq!(reg.api_key_env("ollama"), None);
    }

    #[test]
    fn models_for_provider_lists_all() {
        let reg = ModelRegistry::with_defaults();
        let anthropic = reg.models_for_provider("anthropic");
        for id in ["claude-haiku-4-5", "claude-sonnet-4-6", "claude-opus-4-6"] {
            assert!(anthropic.contains(&id));
        }
        assert!(anthropic.len() >= 4);
    }

    #[test]
    fn models_by_provider_for_settings_api() {
        let map = ModelRegistry::with_defaults().models_by_provider();
        for provider in ["anthropic", "openai", "ollama"] {
            assert!(map.contains_key(provider));
        }
        assert!(map["ollama"].is_empty());
    }

    #[test]
    fn provider_names_returns_all() {
        let reg = ModelRegistry::with_defaults();
        let names = reg.provider_names();
        for expected in [
            "anthropic",
            "openai",
            "gemini",
            "groq",
            "deepseek",
            "openrouter",
            "ollama",
        ] {
            assert!(names.contains(&expected));
        }
    }

    #[test]
    fn model_capabilities() {
        let reg = ModelRegistry::with_defaults();

        let haiku = reg.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!(haiku.supports_tool_use);
        assert!(haiku.supports_vision);

        let gpt41 = reg.get("openai", "gpt-4.1").unwrap();
        assert!(gpt41.supports_tool_use);
        assert!(!gpt41.supports_vision);
    }
}