Skip to main content

agent_sdk/
models.rs

1//! Centralized model catalog with pricing, capabilities, and provider metadata.
2//!
3//! The [`ModelRegistry`] is the single source of truth for every model the
4//! system knows about. It ships with embedded defaults (from `defaults/models.toml`)
//! and supports overlaying user-provided TOML at runtime — parse it with
//! [`ModelRegistry::from_toml`] and apply it with [`ModelRegistry::merge`] —
//! so catalog updates don't require recompilation.
7//!
8//! # Example
9//!
10//! ```rust
11//! use agent_sdk::models::ModelRegistry;
12//!
13//! let registry = ModelRegistry::with_defaults();
14//! let info = registry.get("anthropic", "claude-sonnet-4-5").unwrap();
15//! assert!(info.pricing.input_per_million > 0.0);
16//! assert_eq!(info.context_window, Some(200_000));
17//! ```
18
19use std::collections::HashMap;
20
21use serde::Deserialize;
22use tracing::debug;
23
24use crate::provider::CostRates;
25
/// Embedded default catalog, compiled into the binary with `include_str!`
/// from `defaults/models.toml` (resolved relative to this source file).
/// Parsed by [`ModelRegistry::with_defaults`].
const DEFAULTS_TOML: &str = include_str!("defaults/models.toml");
28
29// ── TOML serde types ──────────────────────────────────────────────────
30
/// Top-level shape of a catalog TOML document: every top-level table is a
/// provider, keyed by provider name.
#[derive(Debug, Deserialize)]
struct CatalogFile {
    // `flatten` collects the top-level tables directly into the map, so the
    // document needs no wrapper key around the providers.
    #[serde(flatten)]
    providers: HashMap<String, ProviderEntry>,
}
36
37#[derive(Debug, Deserialize)]
38struct ProviderEntry {
39    #[serde(default)]
40    default_model: Option<String>,
41    #[serde(default)]
42    api_key_env: Option<String>,
43    #[serde(default)]
44    cache_read_multiplier: Option<f64>,
45    #[serde(default)]
46    cache_creation_multiplier: Option<f64>,
47    #[serde(default)]
48    models: HashMap<String, ModelEntry>,
49}
50
/// One model table (`[<provider>.models.<id>]`) in the catalog TOML.
#[derive(Debug, Deserialize)]
struct ModelEntry {
    /// Required; becomes `CostRates::input_per_million`.
    input: f64,
    /// Required; becomes `CostRates::output_per_million`.
    output: f64,
    /// Context window in tokens; `None` when omitted.
    #[serde(default)]
    context_window: Option<u64>,
    /// Defaults to `true` when omitted (see `default_true`).
    #[serde(default = "default_true")]
    supports_tool_use: bool,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    supports_vision: bool,
    /// Per-model override; the provider-level value is used when omitted.
    #[serde(default)]
    cache_read_multiplier: Option<f64>,
    /// Per-model override; the provider-level value is used when omitted.
    #[serde(default)]
    cache_creation_multiplier: Option<f64>,
}
66
/// Serde default for flags that are on unless the catalog says otherwise
/// (used for `ModelEntry::supports_tool_use`).
const fn default_true() -> bool {
    true
}
70
71// ── Public types ──────────────────────────────────────────────────────
72
/// Full information about a model: pricing + capabilities.
///
/// [`ModelRegistry::from_toml`] builds one of these per
/// `[<provider>.models.<id>]` table.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model ID as registered in the catalog (the `<id>` table key).
    pub id: String,
    /// Provider name (e.g. "anthropic", "openai").
    pub provider: String,
    /// Cost rates for this model. When parsed from TOML, cache multipliers the
    /// model does not set itself are filled from the provider-level values.
    pub pricing: CostRates,
    /// Maximum context window in tokens; `None` when the catalog omits it.
    pub context_window: Option<u64>,
    /// Whether this model supports tool use (`true` when the catalog omits it).
    pub supports_tool_use: bool,
    /// Whether this model supports vision/images (`false` when the catalog omits it).
    pub supports_vision: bool,
}
89
/// Provider-level metadata, taken from the top-level keys of a provider table.
#[derive(Debug, Clone)]
pub struct ProviderInfo {
    /// Provider name (e.g. "anthropic"); the same string as the TOML table key.
    pub name: String,
    /// Default model for this provider, if the catalog names one.
    pub default_model: Option<String>,
    /// Name of the environment variable holding the API key, if any.
    pub api_key_env: Option<String>,
    /// Provider-level cache read multiplier. Used as the fallback for models
    /// that do not set their own, and by `ModelRegistry::get_pricing` for
    /// models the registry does not know.
    pub cache_read_multiplier: Option<f64>,
    /// Provider-level cache creation multiplier (same fallback rules as
    /// `cache_read_multiplier`).
    pub cache_creation_multiplier: Option<f64>,
}
104
105// ── Registry ──────────────────────────────────────────────────────────
106
/// Composite lookup key of the form `"provider::model"`.
type ModelKey = String;

/// Build the [`ModelKey`] for `model` under `provider` by joining the two
/// parts with `::`.
///
/// NOTE: the encoding is not injective — a `::` inside either part can make
/// two different `(provider, model)` pairs produce the same key.
fn make_key(provider: &str, model: &str) -> ModelKey {
    [provider, model].join("::")
}
113
/// Centralized model catalog with pricing, capabilities, and provider metadata.
///
/// Lookup order for pricing/model queries:
/// 1. Exact match on `"provider::model"`
/// 2. Fuzzy match — any registered model whose name is a substring of the
///    query (or vice-versa), scoped to the same provider
/// 3. Provider-level default entry (cache multipliers only, via `get_pricing`)
///
/// Step 3 yields [`CostRates`] whose `input_per_million` and
/// `output_per_million` are `0.0`. Callers that need real token prices
/// should treat that result as "unknown model", not "free model".
///
/// [`ModelRegistry::new`] creates an empty registry, while
/// [`ModelRegistry::with_defaults`] (also used by `Default`) loads the
/// embedded catalog.
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// Model entries keyed by `"provider::model"` (built with `make_key`).
    models: HashMap<ModelKey, ModelInfo>,
    /// Provider metadata keyed by provider name.
    providers: HashMap<String, ProviderInfo>,
}
126
127impl ModelRegistry {
128    /// Create an empty registry.
129    pub fn new() -> Self {
130        Self {
131            models: HashMap::new(),
132            providers: HashMap::new(),
133        }
134    }
135
136    /// Create a registry pre-loaded with the embedded defaults.
137    pub fn with_defaults() -> Self {
138        Self::from_toml(DEFAULTS_TOML).expect("embedded models.toml must be valid")
139    }
140
141    /// Parse a TOML string into a registry.
142    pub fn from_toml(toml_str: &str) -> Result<Self, String> {
143        let file: CatalogFile =
144            toml::from_str(toml_str).map_err(|e| format!("models TOML parse error: {e}"))?;
145
146        let mut models = HashMap::new();
147        let mut providers = HashMap::new();
148
149        for (prov_name, pe) in &file.providers {
150            providers.insert(
151                prov_name.clone(),
152                ProviderInfo {
153                    name: prov_name.clone(),
154                    default_model: pe.default_model.clone(),
155                    api_key_env: pe.api_key_env.clone(),
156                    cache_read_multiplier: pe.cache_read_multiplier,
157                    cache_creation_multiplier: pe.cache_creation_multiplier,
158                },
159            );
160
161            for (model_id, me) in &pe.models {
162                let info = ModelInfo {
163                    id: model_id.clone(),
164                    provider: prov_name.clone(),
165                    pricing: CostRates {
166                        input_per_million: me.input,
167                        output_per_million: me.output,
168                        cache_read_multiplier: me.cache_read_multiplier.or(pe.cache_read_multiplier),
169                        cache_creation_multiplier: me
170                            .cache_creation_multiplier
171                            .or(pe.cache_creation_multiplier),
172                    },
173                    context_window: me.context_window,
174                    supports_tool_use: me.supports_tool_use,
175                    supports_vision: me.supports_vision,
176                };
177                models.insert(make_key(prov_name, model_id), info);
178            }
179        }
180
181        Ok(Self { models, providers })
182    }
183
184    /// Merge another registry on top (overrides win).
185    pub fn merge(&mut self, other: Self) {
186        for (key, info) in other.models {
187            self.models.insert(key, info);
188        }
189        for (key, info) in other.providers {
190            if let Some(existing) = self.providers.get_mut(&key) {
191                if info.default_model.is_some() {
192                    existing.default_model = info.default_model;
193                }
194                if info.api_key_env.is_some() {
195                    existing.api_key_env = info.api_key_env;
196                }
197                if info.cache_read_multiplier.is_some() {
198                    existing.cache_read_multiplier = info.cache_read_multiplier;
199                }
200                if info.cache_creation_multiplier.is_some() {
201                    existing.cache_creation_multiplier = info.cache_creation_multiplier;
202                }
203            } else {
204                self.providers.insert(key, info);
205            }
206        }
207    }
208
209    // ── Model lookups ─────────────────────────────────────────────────
210
211    /// Exact-match lookup.
212    pub fn get(&self, provider: &str, model: &str) -> Option<&ModelInfo> {
213        self.models.get(&make_key(provider, model))
214    }
215
216    /// Fuzzy lookup: tries exact match first, then substring matching
217    /// against all models for the given provider.
218    pub fn get_fuzzy(&self, provider: &str, model: &str) -> Option<&ModelInfo> {
219        if let Some(info) = self.get(provider, model) {
220            return Some(info);
221        }
222
223        let prefix = format!("{provider}::");
224
225        let mut best: Option<(&str, &ModelInfo)> = None;
226        for (key, info) in &self.models {
227            if let Some(registered) = key.strip_prefix(&prefix) {
228                if model.contains(registered) || registered.contains(model) {
229                    let dominated = best
230                        .map(|(prev, _)| registered.len() > prev.len())
231                        .unwrap_or(true);
232                    if dominated {
233                        best = Some((registered, info));
234                    }
235                }
236            }
237        }
238        if let Some((matched, info)) = best {
239            debug!(provider, model, matched, "fuzzy model match");
240            return Some(info);
241        }
242
243        None
244    }
245
246    /// Get pricing for a model (convenience wrapper returning just `CostRates`).
247    /// Falls back to provider-level cache multipliers for unknown models.
248    pub fn get_pricing(&self, provider: &str, model: &str) -> Option<CostRates> {
249        if let Some(info) = self.get_fuzzy(provider, model) {
250            return Some(info.pricing.clone());
251        }
252
253        self.providers.get(provider).and_then(|p| {
254            if p.cache_read_multiplier.is_some() || p.cache_creation_multiplier.is_some() {
255                Some(CostRates {
256                    input_per_million: 0.0,
257                    output_per_million: 0.0,
258                    cache_read_multiplier: p.cache_read_multiplier,
259                    cache_creation_multiplier: p.cache_creation_multiplier,
260                })
261            } else {
262                None
263            }
264        })
265    }
266
267    // ── Provider lookups ──────────────────────────────────────────────
268
269    /// Get provider metadata.
270    pub fn provider(&self, name: &str) -> Option<&ProviderInfo> {
271        self.providers.get(name)
272    }
273
274    /// List all known provider names, sorted alphabetically.
275    pub fn provider_names(&self) -> Vec<&str> {
276        let mut names: Vec<&str> = self.providers.keys().map(|s| s.as_str()).collect();
277        names.sort();
278        names
279    }
280
281    /// Get the default model for a provider.
282    pub fn default_model(&self, provider: &str) -> Option<&str> {
283        self.providers
284            .get(provider)
285            .and_then(|p| p.default_model.as_deref())
286    }
287
288    /// Get the API key env var for a provider.
289    pub fn api_key_env(&self, provider: &str) -> Option<&str> {
290        self.providers
291            .get(provider)
292            .and_then(|p| p.api_key_env.as_deref())
293    }
294
295    /// List all model IDs for a provider, sorted alphabetically.
296    pub fn models_for_provider(&self, provider: &str) -> Vec<&str> {
297        let prefix = format!("{provider}::");
298        let mut out: Vec<&str> = self
299            .models
300            .iter()
301            .filter_map(|(key, info)| {
302                if key.starts_with(&prefix) {
303                    Some(info.id.as_str())
304                } else {
305                    None
306                }
307            })
308            .collect();
309        out.sort();
310        out
311    }
312
313    /// Get a map of provider → model list, suitable for the settings API.
314    pub fn models_by_provider(&self) -> HashMap<String, Vec<String>> {
315        let mut result: HashMap<String, Vec<String>> = HashMap::new();
316        for prov in self.providers.keys() {
317            result.insert(
318                prov.clone(),
319                self.models_for_provider(prov)
320                    .into_iter()
321                    .map(String::from)
322                    .collect(),
323            );
324        }
325        result
326    }
327
328    /// Number of models in the registry.
329    pub fn len(&self) -> usize {
330        self.models.len()
331    }
332
333    /// Whether the registry is empty.
334    pub fn is_empty(&self) -> bool {
335        self.models.is_empty()
336    }
337}
338
339impl Default for ModelRegistry {
340    fn default() -> Self {
341        Self::with_defaults()
342    }
343}
344
/// Backward-compatible alias for [`ModelRegistry`], kept so code that uses
/// this name keeps compiling.
pub type PricingRegistry = ModelRegistry;
347
// Unit tests for the model registry. Most tests load the embedded catalog via
// `ModelRegistry::with_defaults`, so the asserted prices and model ids mirror
// `defaults/models.toml` and must be updated together with it.
#[cfg(test)]
mod tests {
    use super::*;

    // The embedded catalog parses and contains at least one model.
    #[test]
    fn defaults_load_successfully() {
        let reg = ModelRegistry::with_defaults();
        assert!(!reg.is_empty());
    }

    // Exact lookup returns pricing, resolved cache multipliers and capabilities.
    #[test]
    fn exact_match() {
        let reg = ModelRegistry::with_defaults();
        let info = reg.get("anthropic", "claude-sonnet-4-5").unwrap();
        assert!((info.pricing.input_per_million - 3.0).abs() < 1e-9);
        assert!((info.pricing.output_per_million - 15.0).abs() < 1e-9);
        assert!((info.pricing.cache_read_multiplier.unwrap() - 0.1).abs() < 1e-9);
        assert!((info.pricing.cache_creation_multiplier.unwrap() - 1.25).abs() < 1e-9);
        assert_eq!(info.context_window, Some(200_000));
        assert!(info.supports_tool_use);
        assert!(info.supports_vision);
    }

    // A dated snapshot id resolves to its base model through substring matching.
    #[test]
    fn fuzzy_match_longer_model_id() {
        let reg = ModelRegistry::with_defaults();
        let info = reg.get_fuzzy("anthropic", "claude-sonnet-4-5-20250514").unwrap();
        assert!((info.pricing.input_per_million - 3.0).abs() < 1e-9);
    }

    // When both "claude-sonnet" and "claude-sonnet-4-5" are substrings of the
    // query, the longer (more specific) registered id must win.
    #[test]
    fn fuzzy_match_picks_most_specific() {
        let mut reg = ModelRegistry::new();
        let short_key = make_key("test", "claude-sonnet");
        reg.models.insert(short_key, ModelInfo {
            id: "claude-sonnet".into(),
            provider: "test".into(),
            pricing: CostRates {
                input_per_million: 1.0,
                output_per_million: 5.0,
                cache_read_multiplier: None,
                cache_creation_multiplier: None,
            },
            context_window: None,
            supports_tool_use: true,
            supports_vision: false,
        });
        let long_key = make_key("test", "claude-sonnet-4-5");
        reg.models.insert(long_key, ModelInfo {
            id: "claude-sonnet-4-5".into(),
            provider: "test".into(),
            pricing: CostRates {
                input_per_million: 3.0,
                output_per_million: 15.0,
                cache_read_multiplier: None,
                cache_creation_multiplier: None,
            },
            context_window: None,
            supports_tool_use: true,
            supports_vision: false,
        });
        let info = reg.get_fuzzy("test", "claude-sonnet-4-5-20250514").unwrap();
        assert!((info.pricing.input_per_million - 3.0).abs() < 1e-9);
    }

    // An unknown model falls back to the provider-level cache multipliers.
    #[test]
    fn provider_default_cache_multipliers() {
        let reg = ModelRegistry::with_defaults();
        let pricing = reg.get_pricing("anthropic", "claude-unknown-99").unwrap();
        assert!((pricing.cache_read_multiplier.unwrap() - 0.1).abs() < 1e-9);
    }

    // An override registry replaces the base model's price. This test only
    // pins `input_per_million`; the override sets no cache multipliers, so
    // what the merged entry reports for those depends on `merge`.
    #[test]
    fn merge_overrides() {
        let mut base = ModelRegistry::with_defaults();
        let overrides = ModelRegistry::from_toml(r#"
[anthropic.models.claude-sonnet-4-5]
input = 99.0
output = 99.0
"#).unwrap();
        base.merge(overrides);
        let info = base.get("anthropic", "claude-sonnet-4-5").unwrap();
        assert!((info.pricing.input_per_million - 99.0).abs() < 1e-9);
    }

    #[test]
    fn openai_cache_rates() {
        let reg = ModelRegistry::with_defaults();
        let info = reg.get("openai", "gpt-4o").unwrap();
        assert!((info.pricing.cache_read_multiplier.unwrap() - 0.1).abs() < 1e-9);
        assert!((info.pricing.cache_creation_multiplier.unwrap() - 1.0).abs() < 1e-9);
    }

    // NOTE(review): the query uses dashes ("gemini-2-5-flash") while the
    // gemini default model is spelled "gemini-2.5-pro"; confirm which id the
    // catalog actually registers, since this relies on get_fuzzy finding it.
    #[test]
    fn gemini_cache_rates() {
        let reg = ModelRegistry::with_defaults();
        let info = reg.get_fuzzy("gemini", "gemini-2-5-flash").unwrap();
        assert!((info.pricing.cache_read_multiplier.unwrap() - 0.1).abs() < 1e-9);
    }

    // A model without its own multipliers inherits the provider's; a
    // multiplier set nowhere stays `None`.
    #[test]
    fn from_toml_custom() {
        let toml = r#"
[custom]
cache_read_multiplier = 0.3

[custom.models.my-model]
input = 5.0
output = 20.0
"#;
        let reg = ModelRegistry::from_toml(toml).unwrap();
        let info = reg.get("custom", "my-model").unwrap();
        assert!((info.pricing.input_per_million - 5.0).abs() < 1e-9);
        assert!((info.pricing.cache_read_multiplier.unwrap() - 0.3).abs() < 1e-9);
        assert!(info.pricing.cache_creation_multiplier.is_none());
    }

    // A per-model multiplier overrides the provider's; the other one is inherited.
    #[test]
    fn per_model_cache_override() {
        let toml = r#"
[prov]
cache_read_multiplier = 0.1
cache_creation_multiplier = 1.25

[prov.models.special]
input = 10.0
output = 50.0
cache_read_multiplier = 0.05
"#;
        let reg = ModelRegistry::from_toml(toml).unwrap();
        let info = reg.get("prov", "special").unwrap();
        assert!((info.pricing.cache_read_multiplier.unwrap() - 0.05).abs() < 1e-9);
        assert!((info.pricing.cache_creation_multiplier.unwrap() - 1.25).abs() < 1e-9);
    }

    // A provider table with no keys and no models parses and yields no matches.
    #[test]
    fn empty_provider_no_panic() {
        let toml = r#"
[empty]
"#;
        let reg = ModelRegistry::from_toml(toml).unwrap();
        assert!(reg.get("empty", "anything").is_none());
        assert!(reg.get_fuzzy("empty", "anything").is_none());
    }

    // ── Provider metadata ─────────────────────────────────────────────

    #[test]
    fn default_model_per_provider() {
        let reg = ModelRegistry::with_defaults();
        assert_eq!(reg.default_model("anthropic"), Some("claude-haiku-4-5"));
        assert_eq!(reg.default_model("openai"), Some("gpt-4o"));
        assert_eq!(reg.default_model("gemini"), Some("gemini-2.5-pro"));
        assert_eq!(reg.default_model("groq"), Some("llama-3.3-70b-versatile"));
        assert_eq!(reg.default_model("deepseek"), Some("deepseek-chat"));
        assert_eq!(reg.default_model("ollama"), Some("llama3.3"));
    }

    // Providers without `api_key_env` in the catalog report `None`.
    #[test]
    fn api_key_env_per_provider() {
        let reg = ModelRegistry::with_defaults();
        assert_eq!(reg.api_key_env("anthropic"), Some("ANTHROPIC_API_KEY"));
        assert_eq!(reg.api_key_env("openai"), Some("OPENAI_API_KEY"));
        assert_eq!(reg.api_key_env("ollama"), None);
    }

    #[test]
    fn models_for_provider_lists_all() {
        let reg = ModelRegistry::with_defaults();
        let anthropic = reg.models_for_provider("anthropic");
        assert!(anthropic.contains(&"claude-haiku-4-5"));
        assert!(anthropic.contains(&"claude-sonnet-4-6"));
        assert!(anthropic.contains(&"claude-opus-4-6"));
        assert!(anthropic.len() >= 4);
    }

    // Providers with no models still appear in the map, with an empty list.
    #[test]
    fn models_by_provider_for_settings_api() {
        let reg = ModelRegistry::with_defaults();
        let map = reg.models_by_provider();
        assert!(map.contains_key("anthropic"));
        assert!(map.contains_key("openai"));
        assert!(map.contains_key("ollama"));
        assert!(map["ollama"].is_empty());
    }

    #[test]
    fn provider_names_returns_all() {
        let reg = ModelRegistry::with_defaults();
        let names = reg.provider_names();
        assert!(names.contains(&"anthropic"));
        assert!(names.contains(&"openai"));
        assert!(names.contains(&"gemini"));
        assert!(names.contains(&"groq"));
        assert!(names.contains(&"deepseek"));
        assert!(names.contains(&"openrouter"));
        assert!(names.contains(&"ollama"));
    }

    // Capability flags come from the catalog; gpt-4.1 is expected to have
    // vision disabled there.
    #[test]
    fn model_capabilities() {
        let reg = ModelRegistry::with_defaults();
        let haiku = reg.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!(haiku.supports_tool_use);
        assert!(haiku.supports_vision);

        let gpt41 = reg.get("openai", "gpt-4.1").unwrap();
        assert!(gpt41.supports_tool_use);
        assert!(!gpt41.supports_vision);
    }
}