omni_dev/claude/
model_config.rs1use anyhow::Result;
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::OnceLock;
10
11#[derive(Debug, Deserialize, Clone)]
13pub struct ModelSpec {
14 pub provider: String,
16 pub model: String,
18 pub api_identifier: String,
20 pub max_output_tokens: usize,
22 pub input_context: usize,
24 pub generation: f32,
26 pub tier: String,
28 #[serde(default)]
30 pub legacy: bool,
31}
32
33#[derive(Debug, Deserialize)]
35pub struct TierInfo {
36 pub description: String,
38 pub use_cases: Vec<String>,
40}
41
42#[derive(Debug, Deserialize)]
44pub struct DefaultConfig {
45 pub max_output_tokens: usize,
47 pub input_context: usize,
49}
50
51#[derive(Debug, Deserialize)]
53pub struct ProviderConfig {
54 pub name: String,
56 pub api_base: String,
58 pub default_model: String,
60 pub tiers: HashMap<String, TierInfo>,
62 pub defaults: DefaultConfig,
64}
65
66#[derive(Debug, Deserialize)]
68pub struct ModelConfiguration {
69 pub models: Vec<ModelSpec>,
71 pub providers: HashMap<String, ProviderConfig>,
73}
74
75pub struct ModelRegistry {
77 config: ModelConfiguration,
78 by_identifier: HashMap<String, ModelSpec>,
79 by_provider: HashMap<String, Vec<ModelSpec>>,
80}
81
82impl ModelRegistry {
83 pub fn load() -> Result<Self> {
85 let yaml_content = include_str!("../templates/models.yaml");
86 let config: ModelConfiguration = serde_yaml::from_str(yaml_content)?;
87
88 let mut by_identifier = HashMap::new();
90 let mut by_provider: HashMap<String, Vec<ModelSpec>> = HashMap::new();
91
92 for model in &config.models {
93 by_identifier.insert(model.api_identifier.clone(), model.clone());
94 by_provider
95 .entry(model.provider.clone())
96 .or_default()
97 .push(model.clone());
98 }
99
100 Ok(Self {
101 config,
102 by_identifier,
103 by_provider,
104 })
105 }
106
107 pub fn get_model_spec(&self, api_identifier: &str) -> Option<&ModelSpec> {
109 if let Some(spec) = self.by_identifier.get(api_identifier) {
111 return Some(spec);
112 }
113
114 self.find_model_by_fuzzy_match(api_identifier)
116 }
117
118 pub fn get_max_output_tokens(&self, api_identifier: &str) -> usize {
120 if let Some(spec) = self.get_model_spec(api_identifier) {
121 return spec.max_output_tokens;
122 }
123
124 if let Some(provider) = self.infer_provider(api_identifier) {
126 if let Some(provider_config) = self.config.providers.get(&provider) {
127 return provider_config.defaults.max_output_tokens;
128 }
129 }
130
131 4096
133 }
134
135 pub fn get_input_context(&self, api_identifier: &str) -> usize {
137 if let Some(spec) = self.get_model_spec(api_identifier) {
138 return spec.input_context;
139 }
140
141 if let Some(provider) = self.infer_provider(api_identifier) {
143 if let Some(provider_config) = self.config.providers.get(&provider) {
144 return provider_config.defaults.input_context;
145 }
146 }
147
148 100000
150 }
151
152 fn infer_provider(&self, api_identifier: &str) -> Option<String> {
154 if api_identifier.starts_with("claude") || api_identifier.contains("anthropic") {
155 Some("claude".to_string())
156 } else {
157 None
158 }
159 }
160
161 fn find_model_by_fuzzy_match(&self, api_identifier: &str) -> Option<&ModelSpec> {
163 let core_identifier = self.extract_core_model_identifier(api_identifier);
169
170 if let Some(spec) = self.by_identifier.get(&core_identifier) {
172 return Some(spec);
173 }
174
175 for (stored_id, spec) in &self.by_identifier {
177 if self.models_match_fuzzy(&core_identifier, stored_id) {
178 return Some(spec);
179 }
180 }
181
182 None
183 }
184
185 fn extract_core_model_identifier(&self, api_identifier: &str) -> String {
187 let mut identifier = api_identifier.to_string();
188
189 if let Some(dot_pos) = identifier.find('.') {
191 if identifier[..dot_pos].len() <= 3 {
192 identifier = identifier[dot_pos + 1..].to_string();
194 }
195 }
196
197 if identifier.starts_with("anthropic.") {
199 identifier = identifier["anthropic.".len()..].to_string();
200 }
201
202 if let Some(version_pos) = identifier.rfind("-v") {
204 if identifier[version_pos..].contains(':') {
205 identifier = identifier[..version_pos].to_string();
206 }
207 }
208
209 identifier
210 }
211
212 fn models_match_fuzzy(&self, input_id: &str, stored_id: &str) -> bool {
214 input_id == stored_id
217 }
218
219 pub fn is_legacy_model(&self, api_identifier: &str) -> bool {
221 self.get_model_spec(api_identifier)
222 .map(|spec| spec.legacy)
223 .unwrap_or(false)
224 }
225
226 pub fn get_all_models(&self) -> &[ModelSpec] {
228 &self.config.models
229 }
230
231 pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelSpec> {
233 self.by_provider
234 .get(provider)
235 .map(|models| models.iter().collect())
236 .unwrap_or_default()
237 }
238
239 pub fn get_models_by_provider_and_tier(&self, provider: &str, tier: &str) -> Vec<&ModelSpec> {
241 self.get_models_by_provider(provider)
242 .into_iter()
243 .filter(|model| model.tier == tier)
244 .collect()
245 }
246
247 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
249 self.config.providers.get(provider)
250 }
251
252 pub fn get_tier_info(&self, provider: &str, tier: &str) -> Option<&TierInfo> {
254 self.config.providers.get(provider)?.tiers.get(tier)
255 }
256}
257
258static MODEL_REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
260
261pub fn get_model_registry() -> &'static ModelRegistry {
263 MODEL_REGISTRY.get_or_init(|| ModelRegistry::load().expect("Failed to load model registry"))
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a fresh registry from the embedded catalogue.
    fn fresh_registry() -> ModelRegistry {
        ModelRegistry::load().unwrap()
    }

    #[test]
    fn test_load_model_registry() {
        let registry = fresh_registry();

        assert!(!registry.config.models.is_empty());
        assert!(registry.config.providers.contains_key("claude"));
    }

    #[test]
    fn test_claude_model_lookup() {
        let registry = fresh_registry();

        // Exact identifier of a legacy model.
        let opus_spec = registry
            .get_model_spec("claude-3-opus-20240229")
            .expect("claude-3-opus-20240229 should be in the catalogue");
        assert_eq!(opus_spec.max_output_tokens, 4096);
        assert_eq!(opus_spec.provider, "claude");
        assert!(registry.is_legacy_model("claude-3-opus-20240229"));

        assert_eq!(
            registry.get_max_output_tokens("claude-sonnet-4-5-20250929"),
            64000
        );

        assert_eq!(
            registry.get_max_output_tokens("claude-sonnet-4-20250514"),
            64000
        );
        assert!(registry.is_legacy_model("claude-sonnet-4-20250514"));

        // No catalogue entry for this identifier, so the lookup falls
        // through to a fallback value.
        assert_eq!(registry.get_max_output_tokens("claude-unknown-model"), 4096);
    }

    #[test]
    fn test_provider_filtering() {
        let registry = fresh_registry();

        assert!(!registry.get_models_by_provider("claude").is_empty());
        assert!(!registry
            .get_models_by_provider_and_tier("claude", "fast")
            .is_empty());
        assert!(registry.get_tier_info("claude", "fast").is_some());
    }

    #[test]
    fn test_provider_config() {
        let registry = fresh_registry();

        let claude_config = registry
            .get_provider_config("claude")
            .expect("the claude provider should be configured");
        assert_eq!(claude_config.name, "Anthropic Claude");
    }

    #[test]
    fn test_fuzzy_model_matching() {
        let registry = fresh_registry();

        // (decorated identifier, expected catalogue identifier, expected max output tokens)
        let decorated_cases = [
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
                64000,
            ),
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
                4096,
            ),
            (
                "eu.anthropic.claude-3-opus-20240229-v2:1",
                "claude-3-opus-20240229",
                4096,
            ),
        ];
        for (input, expected_id, expected_tokens) in decorated_cases {
            let spec = registry
                .get_model_spec(input)
                .unwrap_or_else(|| panic!("no spec resolved for {input}"));
            assert_eq!(spec.api_identifier, expected_id);
            assert_eq!(spec.max_output_tokens, expected_tokens);
        }

        // Exact identifiers must keep resolving as well.
        let exact_cases = [
            ("claude-sonnet-4-5-20250929", 64000),
            ("claude-sonnet-4-20250514", 64000),
        ];
        for (input, expected_tokens) in exact_cases {
            let spec = registry
                .get_model_spec(input)
                .unwrap_or_else(|| panic!("no spec resolved for {input}"));
            assert_eq!(spec.max_output_tokens, expected_tokens);
        }
    }

    #[test]
    fn test_extract_core_model_identifier() {
        let registry = fresh_registry();

        // (input identifier, expected normalised identifier)
        let cases = [
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
            ),
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
            ),
            ("claude-3-opus-20240229", "claude-3-opus-20240229"),
            (
                "eu.anthropic.claude-sonnet-4-20250514-v2:1",
                "claude-sonnet-4-20250514",
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(registry.extract_core_model_identifier(input), expected);
        }
    }
}