Skip to main content

claude_agent/models/
registry.rs

1use std::collections::HashMap;
2use std::sync::OnceLock;
3
4use super::builtin;
5use super::family::{ModelFamily, ModelRole};
6use super::provider::ProviderKind;
7use super::spec::{ModelId, ModelSpec};
8
9static REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
10
11pub fn registry() -> &'static ModelRegistry {
12    REGISTRY.get_or_init(ModelRegistry::builtins)
13}
14
15#[derive(Debug, Default)]
16pub struct ModelRegistry {
17    models: HashMap<ModelId, ModelSpec>,
18    aliases: HashMap<String, ModelId>,
19    by_family: HashMap<ModelFamily, Vec<ModelId>>,
20    defaults: HashMap<ModelRole, ModelId>,
21}
22
23impl ModelRegistry {
24    pub fn new() -> Self {
25        Self::default()
26    }
27
28    pub fn builtins() -> Self {
29        let mut registry = Self::new();
30        builtin::register_all(&mut registry);
31        registry
32    }
33
34    pub fn register(&mut self, spec: ModelSpec) {
35        let id = spec.id.clone();
36        let family = spec.family;
37
38        self.models.insert(id.clone(), spec);
39
40        self.by_family.entry(family).or_default().push(id.clone());
41
42        for alias in family.aliases() {
43            self.aliases.insert(alias.to_string(), id.clone());
44        }
45    }
46
47    pub fn set_default(&mut self, role: ModelRole, id: ModelId) {
48        self.defaults.insert(role, id);
49    }
50
51    pub fn add_alias(&mut self, alias: impl Into<String>, id: ModelId) {
52        self.aliases.insert(alias.into(), id);
53    }
54
55    pub fn get(&self, id: &str) -> Option<&ModelSpec> {
56        self.models.get(id)
57    }
58
59    pub fn resolve(&self, alias_or_id: &str) -> Option<&ModelSpec> {
60        // Direct ID lookup
61        if let Some(spec) = self.models.get(alias_or_id) {
62            return Some(spec);
63        }
64
65        // Alias lookup
66        if let Some(canonical) = self.aliases.get(alias_or_id) {
67            return self.models.get(canonical);
68        }
69
70        // Fallback: substring matching for model family
71        let lower = alias_or_id.to_lowercase();
72        let fallback = if lower.contains("opus") {
73            self.latest(ModelFamily::Opus)
74        } else if lower.contains("sonnet") {
75            self.latest(ModelFamily::Sonnet)
76        } else if lower.contains("haiku") {
77            self.latest(ModelFamily::Haiku)
78        } else {
79            None
80        };
81
82        if let Some(spec) = &fallback {
83            tracing::debug!(
84                input = alias_or_id,
85                resolved = %spec.id,
86                "model resolved via substring fallback"
87            );
88        }
89
90        fallback
91    }
92
93    pub fn default_for_role(&self, role: ModelRole) -> Option<&ModelSpec> {
94        let id = self.defaults.get(&role)?;
95        self.models.get(id)
96    }
97
98    pub fn latest(&self, family: ModelFamily) -> Option<&ModelSpec> {
99        let ids = self.by_family.get(&family)?;
100        let id = ids.first()?;
101        self.models.get(id)
102    }
103
104    pub fn for_provider(&self, provider: ProviderKind, provider_id: &str) -> Option<&ModelSpec> {
105        self.models
106            .values()
107            .find(|spec| spec.provider_ids.for_provider(provider) == Some(provider_id))
108    }
109
110    pub fn family_models(&self, family: ModelFamily) -> Vec<&ModelSpec> {
111        self.by_family
112            .get(&family)
113            .map(|ids| ids.iter().filter_map(|id| self.models.get(id)).collect())
114            .unwrap_or_default()
115    }
116
117    pub fn all(&self) -> impl Iterator<Item = &ModelSpec> {
118        self.models.values()
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// The short family names must all resolve against the built-in registry.
    #[test]
    fn test_registry_resolve() {
        let registry = ModelRegistry::builtins();

        for name in ["sonnet", "haiku", "opus"] {
            assert!(registry.resolve(name).is_some());
        }
    }

    /// The built-in registry must provide a default model for each checked role.
    #[test]
    fn test_registry_default_roles() {
        let registry = ModelRegistry::builtins();

        for role in [ModelRole::Primary, ModelRole::Small, ModelRole::Reasoning] {
            assert!(registry.default_for_role(role).is_some());
        }
    }

    /// The lazily initialised global registry must be usable for resolution.
    #[test]
    fn test_registry_global() {
        let global = registry();
        assert!(global.resolve("sonnet").is_some());
    }
}