claude_agent/models/
registry.rs1use std::collections::HashMap;
2use std::sync::OnceLock;
3
4use super::builtin;
5use super::family::{ModelFamily, ModelRole};
6use super::provider::ProviderKind;
7use super::spec::{ModelId, ModelSpec};
8
9static REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
10
11pub fn registry() -> &'static ModelRegistry {
12 REGISTRY.get_or_init(ModelRegistry::builtins)
13}
14
15#[derive(Debug, Default)]
16pub struct ModelRegistry {
17 models: HashMap<ModelId, ModelSpec>,
18 aliases: HashMap<String, ModelId>,
19 by_family: HashMap<ModelFamily, Vec<ModelId>>,
20 defaults: HashMap<ModelRole, ModelId>,
21}
22
23impl ModelRegistry {
24 pub fn new() -> Self {
25 Self::default()
26 }
27
28 pub fn builtins() -> Self {
29 let mut registry = Self::new();
30 builtin::register_all(&mut registry);
31 registry
32 }
33
34 pub fn register(&mut self, spec: ModelSpec) {
35 let id = spec.id.clone();
36 let family = spec.family;
37
38 self.models.insert(id.clone(), spec);
39
40 self.by_family.entry(family).or_default().push(id.clone());
41
42 for alias in family.aliases() {
43 self.aliases.insert(alias.to_string(), id.clone());
44 }
45 }
46
47 pub fn set_default(&mut self, role: ModelRole, id: ModelId) {
48 self.defaults.insert(role, id);
49 }
50
51 pub fn add_alias(&mut self, alias: impl Into<String>, id: ModelId) {
52 self.aliases.insert(alias.into(), id);
53 }
54
55 pub fn get(&self, id: &str) -> Option<&ModelSpec> {
56 self.models.get(id)
57 }
58
59 pub fn resolve(&self, alias_or_id: &str) -> Option<&ModelSpec> {
60 if let Some(spec) = self.models.get(alias_or_id) {
62 return Some(spec);
63 }
64
65 if let Some(canonical) = self.aliases.get(alias_or_id) {
67 return self.models.get(canonical);
68 }
69
70 let lower = alias_or_id.to_lowercase();
72 let fallback = if lower.contains("opus") {
73 self.latest(ModelFamily::Opus)
74 } else if lower.contains("sonnet") {
75 self.latest(ModelFamily::Sonnet)
76 } else if lower.contains("haiku") {
77 self.latest(ModelFamily::Haiku)
78 } else {
79 None
80 };
81
82 if let Some(spec) = &fallback {
83 tracing::debug!(
84 input = alias_or_id,
85 resolved = %spec.id,
86 "model resolved via substring fallback"
87 );
88 }
89
90 fallback
91 }
92
93 pub fn default_for_role(&self, role: ModelRole) -> Option<&ModelSpec> {
94 let id = self.defaults.get(&role)?;
95 self.models.get(id)
96 }
97
98 pub fn latest(&self, family: ModelFamily) -> Option<&ModelSpec> {
99 let ids = self.by_family.get(&family)?;
100 let id = ids.first()?;
101 self.models.get(id)
102 }
103
104 pub fn for_provider(&self, provider: ProviderKind, provider_id: &str) -> Option<&ModelSpec> {
105 self.models
106 .values()
107 .find(|spec| spec.provider_ids.for_provider(provider) == Some(provider_id))
108 }
109
110 pub fn family_models(&self, family: ModelFamily) -> Vec<&ModelSpec> {
111 self.by_family
112 .get(&family)
113 .map(|ids| ids.iter().filter_map(|id| self.models.get(id)).collect())
114 .unwrap_or_default()
115 }
116
117 pub fn all(&self) -> impl Iterator<Item = &ModelSpec> {
118 self.models.values()
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// The bare family names all resolve against the built-in catalogue.
    #[test]
    fn test_registry_resolve() {
        let reg = ModelRegistry::builtins();

        for name in ["sonnet", "haiku", "opus"] {
            assert!(reg.resolve(name).is_some());
        }
    }

    /// Every role has a built-in default model.
    #[test]
    fn test_registry_default_roles() {
        let reg = ModelRegistry::builtins();

        for role in [ModelRole::Primary, ModelRole::Small, ModelRole::Reasoning] {
            assert!(reg.default_for_role(role).is_some());
        }
    }

    /// The process-wide accessor serves the built-in catalogue.
    #[test]
    fn test_registry_global() {
        let shared = registry();
        assert!(shared.resolve("sonnet").is_some());
    }
}