claude_agent/models/
registry.rs1use std::collections::HashMap;
2use std::sync::{OnceLock, RwLock, RwLockReadGuard};
3
4use super::builtin;
5use super::family::{ModelFamily, ModelRole};
6use super::provider::CloudProvider;
7use super::spec::{ModelId, ModelSpec};
8
9static REGISTRY: OnceLock<RwLock<ModelRegistry>> = OnceLock::new();
10
11pub fn registry() -> &'static RwLock<ModelRegistry> {
12 REGISTRY.get_or_init(|| RwLock::new(ModelRegistry::with_builtins()))
13}
14
15pub fn read_registry() -> RwLockReadGuard<'static, ModelRegistry> {
16 registry()
17 .read()
18 .expect("model registry lock poisoned - this indicates a panic during write")
19}
20
21#[derive(Debug, Default)]
22pub struct ModelRegistry {
23 models: HashMap<ModelId, ModelSpec>,
24 aliases: HashMap<String, ModelId>,
25 by_family: HashMap<ModelFamily, Vec<ModelId>>,
26 defaults: HashMap<ModelRole, ModelId>,
27}
28
29impl ModelRegistry {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn with_builtins() -> Self {
35 let mut registry = Self::new();
36 builtin::register_all(&mut registry);
37 registry
38 }
39
40 pub fn register(&mut self, spec: ModelSpec) {
41 let id = spec.id.clone();
42 let family = spec.family;
43
44 self.models.insert(id.clone(), spec);
45
46 self.by_family.entry(family).or_default().push(id.clone());
47
48 for alias in family.aliases() {
49 self.aliases.insert(alias.to_string(), id.clone());
50 }
51 }
52
53 pub fn set_default(&mut self, role: ModelRole, id: ModelId) {
54 self.defaults.insert(role, id);
55 }
56
57 pub fn add_alias(&mut self, alias: impl Into<String>, id: ModelId) {
58 self.aliases.insert(alias.into(), id);
59 }
60
61 pub fn get(&self, id: &str) -> Option<&ModelSpec> {
62 self.models.get(id)
63 }
64
65 pub fn resolve(&self, alias_or_id: &str) -> Option<&ModelSpec> {
66 if let Some(spec) = self.models.get(alias_or_id) {
68 return Some(spec);
69 }
70
71 if let Some(canonical) = self.aliases.get(alias_or_id) {
73 return self.models.get(canonical);
74 }
75
76 let lower = alias_or_id.to_lowercase();
78 let fallback = if lower.contains("opus") {
79 self.latest(ModelFamily::Opus)
80 } else if lower.contains("sonnet") {
81 self.latest(ModelFamily::Sonnet)
82 } else if lower.contains("haiku") {
83 self.latest(ModelFamily::Haiku)
84 } else {
85 None
86 };
87
88 if let Some(spec) = &fallback {
89 tracing::debug!(
90 input = alias_or_id,
91 resolved = %spec.id,
92 "model resolved via substring fallback"
93 );
94 }
95
96 fallback
97 }
98
99 pub fn default_for_role(&self, role: ModelRole) -> Option<&ModelSpec> {
100 let id = self.defaults.get(&role)?;
101 self.models.get(id)
102 }
103
104 pub fn latest(&self, family: ModelFamily) -> Option<&ModelSpec> {
105 let ids = self.by_family.get(&family)?;
106 let id = ids.first()?;
107 self.models.get(id)
108 }
109
110 pub fn for_provider(&self, provider: CloudProvider, provider_id: &str) -> Option<&ModelSpec> {
111 self.models
112 .values()
113 .find(|spec| spec.provider_ids.for_provider(provider) == Some(provider_id))
114 }
115
116 pub fn family_models(&self, family: ModelFamily) -> Vec<&ModelSpec> {
117 self.by_family
118 .get(&family)
119 .map(|ids| ids.iter().filter_map(|id| self.models.get(id)).collect())
120 .unwrap_or_default()
121 }
122
123 pub fn all(&self) -> impl Iterator<Item = &ModelSpec> {
124 self.models.values()
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Bare family names must resolve against the built-in registry.
    #[test]
    fn test_registry_resolve() {
        let builtins = ModelRegistry::with_builtins();

        for name in ["sonnet", "haiku", "opus"] {
            assert!(builtins.resolve(name).is_some());
        }
    }

    /// Every role the crate relies on must have a registered default.
    #[test]
    fn test_registry_default_roles() {
        let builtins = ModelRegistry::with_builtins();

        let roles = [ModelRole::Primary, ModelRole::Small, ModelRole::Reasoning];
        for role in roles {
            assert!(builtins.default_for_role(role).is_some());
        }
    }

    /// The lazily-initialised global registry is usable through `read_registry`.
    #[test]
    fn test_registry_global() {
        let guard = read_registry();
        assert!(guard.resolve("sonnet").is_some());
    }
}