1use std::collections::HashMap;
7use std::sync::OnceLock;
8
9use anyhow::Result;
10use serde::Deserialize;
11
12#[derive(Debug, Deserialize, Clone)]
14pub struct BetaHeader {
15 pub key: String,
17 pub value: String,
19 #[serde(default)]
21 pub max_output_tokens: Option<usize>,
22 #[serde(default)]
24 pub input_context: Option<usize>,
25}
26
27#[derive(Debug, Deserialize, Clone)]
29pub struct ModelSpec {
30 pub provider: String,
32 pub model: String,
34 pub api_identifier: String,
36 pub max_output_tokens: usize,
38 pub input_context: usize,
40 pub generation: f32,
42 pub tier: String,
44 #[serde(default)]
46 pub legacy: bool,
47 #[serde(default)]
49 pub beta_headers: Vec<BetaHeader>,
50}
51
52#[derive(Debug, Deserialize)]
54pub struct TierInfo {
55 pub description: String,
57 pub use_cases: Vec<String>,
59}
60
61#[derive(Debug, Deserialize)]
63pub struct DefaultConfig {
64 pub max_output_tokens: usize,
66 pub input_context: usize,
68}
69
70#[derive(Debug, Deserialize)]
72pub struct ProviderConfig {
73 pub name: String,
75 pub api_base: String,
77 pub default_model: String,
79 pub tiers: HashMap<String, TierInfo>,
81 pub defaults: DefaultConfig,
83}
84
85#[derive(Debug, Deserialize)]
87pub struct ModelConfiguration {
88 pub models: Vec<ModelSpec>,
90 pub providers: HashMap<String, ProviderConfig>,
92}
93
94pub struct ModelRegistry {
96 config: ModelConfiguration,
97 by_identifier: HashMap<String, ModelSpec>,
98 by_provider: HashMap<String, Vec<ModelSpec>>,
99}
100
101impl ModelRegistry {
102 pub fn load() -> Result<Self> {
104 let yaml_content = include_str!("../templates/models.yaml");
105 let config: ModelConfiguration = serde_yaml::from_str(yaml_content)?;
106
107 let mut by_identifier = HashMap::new();
109 let mut by_provider: HashMap<String, Vec<ModelSpec>> = HashMap::new();
110
111 for model in &config.models {
112 by_identifier.insert(model.api_identifier.clone(), model.clone());
113 by_provider
114 .entry(model.provider.clone())
115 .or_default()
116 .push(model.clone());
117 }
118
119 Ok(Self {
120 config,
121 by_identifier,
122 by_provider,
123 })
124 }
125
126 pub fn get_model_spec(&self, api_identifier: &str) -> Option<&ModelSpec> {
128 if let Some(spec) = self.by_identifier.get(api_identifier) {
130 return Some(spec);
131 }
132
133 self.find_model_by_fuzzy_match(api_identifier)
135 }
136
137 pub fn get_max_output_tokens(&self, api_identifier: &str) -> usize {
139 if let Some(spec) = self.get_model_spec(api_identifier) {
140 return spec.max_output_tokens;
141 }
142
143 if let Some(provider) = self.infer_provider(api_identifier) {
145 if let Some(provider_config) = self.config.providers.get(&provider) {
146 return provider_config.defaults.max_output_tokens;
147 }
148 }
149
150 4096
152 }
153
154 pub fn get_input_context(&self, api_identifier: &str) -> usize {
156 if let Some(spec) = self.get_model_spec(api_identifier) {
157 return spec.input_context;
158 }
159
160 if let Some(provider) = self.infer_provider(api_identifier) {
162 if let Some(provider_config) = self.config.providers.get(&provider) {
163 return provider_config.defaults.input_context;
164 }
165 }
166
167 100000
169 }
170
171 fn infer_provider(&self, api_identifier: &str) -> Option<String> {
173 if api_identifier.starts_with("claude") || api_identifier.contains("anthropic") {
174 Some("claude".to_string())
175 } else {
176 None
177 }
178 }
179
180 fn find_model_by_fuzzy_match(&self, api_identifier: &str) -> Option<&ModelSpec> {
182 let core_identifier = self.extract_core_model_identifier(api_identifier);
188
189 if let Some(spec) = self.by_identifier.get(&core_identifier) {
191 return Some(spec);
192 }
193
194 for (stored_id, spec) in &self.by_identifier {
196 if self.models_match_fuzzy(&core_identifier, stored_id) {
197 return Some(spec);
198 }
199 }
200
201 None
202 }
203
204 fn extract_core_model_identifier(&self, api_identifier: &str) -> String {
206 let mut identifier = api_identifier.to_string();
207
208 if let Some(dot_pos) = identifier.find('.') {
210 if identifier[..dot_pos].len() <= 3 {
211 identifier = identifier[dot_pos + 1..].to_string();
213 }
214 }
215
216 if identifier.starts_with("anthropic.") {
218 identifier = identifier["anthropic.".len()..].to_string();
219 }
220
221 if let Some(version_pos) = identifier.rfind("-v") {
223 if identifier[version_pos..].contains(':') {
224 identifier = identifier[..version_pos].to_string();
225 }
226 }
227
228 identifier
229 }
230
231 fn models_match_fuzzy(&self, input_id: &str, stored_id: &str) -> bool {
233 input_id == stored_id
236 }
237
238 #[must_use]
240 pub fn is_legacy_model(&self, api_identifier: &str) -> bool {
241 self.get_model_spec(api_identifier)
242 .map(|spec| spec.legacy)
243 .unwrap_or(false)
244 }
245
246 pub fn get_all_models(&self) -> &[ModelSpec] {
248 &self.config.models
249 }
250
251 pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelSpec> {
253 self.by_provider
254 .get(provider)
255 .map(|models| models.iter().collect())
256 .unwrap_or_default()
257 }
258
259 pub fn get_models_by_provider_and_tier(&self, provider: &str, tier: &str) -> Vec<&ModelSpec> {
261 self.get_models_by_provider(provider)
262 .into_iter()
263 .filter(|model| model.tier == tier)
264 .collect()
265 }
266
267 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
269 self.config.providers.get(provider)
270 }
271
272 pub fn get_tier_info(&self, provider: &str, tier: &str) -> Option<&TierInfo> {
274 self.config.providers.get(provider)?.tiers.get(tier)
275 }
276
277 pub fn get_beta_headers(&self, api_identifier: &str) -> &[BetaHeader] {
279 self.get_model_spec(api_identifier)
280 .map(|spec| spec.beta_headers.as_slice())
281 .unwrap_or_default()
282 }
283
284 pub fn get_max_output_tokens_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
286 if let Some(spec) = self.get_model_spec(api_identifier) {
287 if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
288 if let Some(max) = bh.max_output_tokens {
289 return max;
290 }
291 }
292 return spec.max_output_tokens;
293 }
294 self.get_max_output_tokens(api_identifier)
295 }
296
297 pub fn get_input_context_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
299 if let Some(spec) = self.get_model_spec(api_identifier) {
300 if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
301 if let Some(ctx) = bh.input_context {
302 return ctx;
303 }
304 }
305 return spec.input_context;
306 }
307 self.get_input_context(api_identifier)
308 }
309}
310
311static MODEL_REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
313
314pub fn get_model_registry() -> &'static ModelRegistry {
316 MODEL_REGISTRY.get_or_init(|| ModelRegistry::load().expect("Failed to load model registry"))
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a registry from the embedded configuration; a load error fails the test.
    fn fresh_registry() -> ModelRegistry {
        ModelRegistry::load().unwrap()
    }

    #[test]
    fn load_model_registry() {
        let registry = fresh_registry();

        assert!(!registry.config.models.is_empty());
        assert!(registry.config.providers.contains_key("claude"));
    }

    #[test]
    fn claude_model_lookup() {
        let registry = fresh_registry();

        // Registered legacy model: spec fields and the legacy flag.
        let opus = registry.get_model_spec("claude-3-opus-20240229").unwrap();
        assert_eq!(opus.max_output_tokens, 4096);
        assert_eq!(opus.provider, "claude");
        assert!(registry.is_legacy_model("claude-3-opus-20240229"));

        assert_eq!(
            registry.get_max_output_tokens("claude-sonnet-4-5-20250929"),
            64000
        );

        assert_eq!(
            registry.get_max_output_tokens("claude-sonnet-4-20250514"),
            64000
        );
        assert!(registry.is_legacy_model("claude-sonnet-4-20250514"));

        // Unregistered identifier: resolved through the fallback path.
        assert_eq!(registry.get_max_output_tokens("claude-unknown-model"), 4096);
    }

    #[test]
    fn provider_filtering() {
        let registry = fresh_registry();

        assert!(!registry.get_models_by_provider("claude").is_empty());
        assert!(!registry
            .get_models_by_provider_and_tier("claude", "fast")
            .is_empty());
        assert!(registry.get_tier_info("claude", "fast").is_some());
    }

    #[test]
    fn provider_config() {
        let registry = fresh_registry();

        let claude = registry.get_provider_config("claude").unwrap();
        assert_eq!(claude.name, "Anthropic Claude");
    }

    #[test]
    fn fuzzy_model_matching() {
        let registry = fresh_registry();

        // (decorated identifier, expected registered id, expected output limit)
        let decorated = [
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
                64000,
            ),
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
                4096,
            ),
            (
                "eu.anthropic.claude-3-opus-20240229-v2:1",
                "claude-3-opus-20240229",
                4096,
            ),
        ];
        for (input, expected_id, expected_tokens) in decorated {
            let spec = registry.get_model_spec(input).unwrap();
            assert_eq!(spec.api_identifier, expected_id);
            assert_eq!(spec.max_output_tokens, expected_tokens);
        }

        // Undecorated identifiers must keep resolving through the exact path.
        for exact in ["claude-sonnet-4-5-20250929", "claude-sonnet-4-20250514"] {
            let spec = registry.get_model_spec(exact).unwrap();
            assert_eq!(spec.max_output_tokens, 64000);
        }
    }

    #[test]
    fn extract_core_model_identifier() {
        let registry = fresh_registry();

        // (input, expected normalized identifier)
        let cases = [
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
            ),
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
            ),
            ("claude-3-opus-20240229", "claude-3-opus-20240229"),
            (
                "eu.anthropic.claude-sonnet-4-20250514-v2:1",
                "claude-sonnet-4-20250514",
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(registry.extract_core_model_identifier(input), expected);
        }
    }

    #[test]
    fn beta_header_lookups() {
        let registry = fresh_registry();
        let opus = "claude-opus-4-6";
        let sonnet37 = "claude-3-7-sonnet-20250219";

        // Baseline limits without any beta header.
        assert_eq!(registry.get_max_output_tokens(opus), 128000);
        assert_eq!(registry.get_input_context(opus), 200000);

        // The 1M-context header changes the context size only.
        assert_eq!(
            registry.get_input_context_with_beta(opus, "context-1m-2025-08-07"),
            1000000
        );
        assert_eq!(
            registry.get_max_output_tokens_with_beta(opus, "context-1m-2025-08-07"),
            128000
        );

        // The 128k-output header changes the output limit; the baseline is untouched.
        assert_eq!(
            registry.get_max_output_tokens_with_beta(sonnet37, "output-128k-2025-02-19"),
            128000
        );
        assert_eq!(registry.get_max_output_tokens(sonnet37), 64000);

        let opus_headers = registry.get_beta_headers(opus);
        assert_eq!(opus_headers.len(), 1);
        assert_eq!(opus_headers[0].key, "anthropic-beta");
        assert_eq!(opus_headers[0].value, "context-1m-2025-08-07");

        assert_eq!(registry.get_beta_headers(sonnet37).len(), 2);

        // Models without headers and unknown models both yield an empty slice.
        assert!(registry
            .get_beta_headers("claude-3-haiku-20240307")
            .is_empty());
        assert!(registry.get_beta_headers("unknown-model").is_empty());
    }
}