claude_agent/models/
registry.rs

1use std::collections::HashMap;
2use std::sync::{OnceLock, RwLock, RwLockReadGuard};
3
4use super::builtin;
5use super::family::{ModelFamily, ModelRole};
6use super::provider::CloudProvider;
7use super::spec::{ModelId, ModelSpec};
8
9static REGISTRY: OnceLock<RwLock<ModelRegistry>> = OnceLock::new();
10
11pub fn registry() -> &'static RwLock<ModelRegistry> {
12    REGISTRY.get_or_init(|| RwLock::new(ModelRegistry::with_builtins()))
13}
14
15pub fn read_registry() -> RwLockReadGuard<'static, ModelRegistry> {
16    registry()
17        .read()
18        .expect("model registry lock poisoned - this indicates a panic during write")
19}
20
21#[derive(Debug, Default)]
22pub struct ModelRegistry {
23    models: HashMap<ModelId, ModelSpec>,
24    aliases: HashMap<String, ModelId>,
25    by_family: HashMap<ModelFamily, Vec<ModelId>>,
26    defaults: HashMap<ModelRole, ModelId>,
27}
28
29impl ModelRegistry {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    pub fn with_builtins() -> Self {
35        let mut registry = Self::new();
36        builtin::register_all(&mut registry);
37        registry
38    }
39
40    pub fn register(&mut self, spec: ModelSpec) {
41        let id = spec.id.clone();
42        let family = spec.family;
43
44        self.models.insert(id.clone(), spec);
45
46        self.by_family.entry(family).or_default().push(id.clone());
47
48        for alias in family.aliases() {
49            self.aliases.insert(alias.to_string(), id.clone());
50        }
51    }
52
53    pub fn set_default(&mut self, role: ModelRole, id: ModelId) {
54        self.defaults.insert(role, id);
55    }
56
57    pub fn add_alias(&mut self, alias: impl Into<String>, id: ModelId) {
58        self.aliases.insert(alias.into(), id);
59    }
60
61    pub fn get(&self, id: &str) -> Option<&ModelSpec> {
62        self.models.get(id)
63    }
64
65    pub fn resolve(&self, alias_or_id: &str) -> Option<&ModelSpec> {
66        // Direct ID lookup
67        if let Some(spec) = self.models.get(alias_or_id) {
68            return Some(spec);
69        }
70
71        // Alias lookup
72        if let Some(canonical) = self.aliases.get(alias_or_id) {
73            return self.models.get(canonical);
74        }
75
76        // Fallback: substring matching for model family
77        let lower = alias_or_id.to_lowercase();
78        let fallback = if lower.contains("opus") {
79            self.latest(ModelFamily::Opus)
80        } else if lower.contains("sonnet") {
81            self.latest(ModelFamily::Sonnet)
82        } else if lower.contains("haiku") {
83            self.latest(ModelFamily::Haiku)
84        } else {
85            None
86        };
87
88        if let Some(spec) = &fallback {
89            tracing::debug!(
90                input = alias_or_id,
91                resolved = %spec.id,
92                "model resolved via substring fallback"
93            );
94        }
95
96        fallback
97    }
98
99    pub fn default_for_role(&self, role: ModelRole) -> Option<&ModelSpec> {
100        let id = self.defaults.get(&role)?;
101        self.models.get(id)
102    }
103
104    pub fn latest(&self, family: ModelFamily) -> Option<&ModelSpec> {
105        let ids = self.by_family.get(&family)?;
106        let id = ids.first()?;
107        self.models.get(id)
108    }
109
110    pub fn for_provider(&self, provider: CloudProvider, provider_id: &str) -> Option<&ModelSpec> {
111        self.models
112            .values()
113            .find(|spec| spec.provider_ids.for_provider(provider) == Some(provider_id))
114    }
115
116    pub fn family_models(&self, family: ModelFamily) -> Vec<&ModelSpec> {
117        self.by_family
118            .get(&family)
119            .map(|ids| ids.iter().filter_map(|id| self.models.get(id)).collect())
120            .unwrap_or_default()
121    }
122
123    pub fn all(&self) -> impl Iterator<Item = &ModelSpec> {
124        self.models.values()
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Each family shorthand must resolve against the built-in catalogue.
    #[test]
    fn test_registry_resolve() {
        let builtins = ModelRegistry::with_builtins();

        for name in ["sonnet", "haiku", "opus"] {
            assert!(builtins.resolve(name).is_some());
        }
    }

    /// The built-in catalogue must provide a default model for every role
    /// checked here.
    #[test]
    fn test_registry_default_roles() {
        let builtins = ModelRegistry::with_builtins();

        for role in [ModelRole::Primary, ModelRole::Small, ModelRole::Reasoning] {
            assert!(builtins.default_for_role(role).is_some());
        }
    }

    /// The lazily initialised global registry must be usable through its read
    /// guard.
    #[test]
    fn test_registry_global() {
        let guard = read_registry();
        assert!(guard.resolve("sonnet").is_some());
    }
}