1use anyhow::Result;
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::OnceLock;
10
11#[derive(Debug, Deserialize, Clone)]
13pub struct BetaHeader {
14 pub key: String,
16 pub value: String,
18 #[serde(default)]
20 pub max_output_tokens: Option<usize>,
21 #[serde(default)]
23 pub input_context: Option<usize>,
24}
25
26#[derive(Debug, Deserialize, Clone)]
28pub struct ModelSpec {
29 pub provider: String,
31 pub model: String,
33 pub api_identifier: String,
35 pub max_output_tokens: usize,
37 pub input_context: usize,
39 pub generation: f32,
41 pub tier: String,
43 #[serde(default)]
45 pub legacy: bool,
46 #[serde(default)]
48 pub beta_headers: Vec<BetaHeader>,
49}
50
51#[derive(Debug, Deserialize)]
53pub struct TierInfo {
54 pub description: String,
56 pub use_cases: Vec<String>,
58}
59
60#[derive(Debug, Deserialize)]
62pub struct DefaultConfig {
63 pub max_output_tokens: usize,
65 pub input_context: usize,
67}
68
69#[derive(Debug, Deserialize)]
71pub struct ProviderConfig {
72 pub name: String,
74 pub api_base: String,
76 pub default_model: String,
78 pub tiers: HashMap<String, TierInfo>,
80 pub defaults: DefaultConfig,
82}
83
84#[derive(Debug, Deserialize)]
86pub struct ModelConfiguration {
87 pub models: Vec<ModelSpec>,
89 pub providers: HashMap<String, ProviderConfig>,
91}
92
93pub struct ModelRegistry {
95 config: ModelConfiguration,
96 by_identifier: HashMap<String, ModelSpec>,
97 by_provider: HashMap<String, Vec<ModelSpec>>,
98}
99
100impl ModelRegistry {
101 pub fn load() -> Result<Self> {
103 let yaml_content = include_str!("../templates/models.yaml");
104 let config: ModelConfiguration = serde_yaml::from_str(yaml_content)?;
105
106 let mut by_identifier = HashMap::new();
108 let mut by_provider: HashMap<String, Vec<ModelSpec>> = HashMap::new();
109
110 for model in &config.models {
111 by_identifier.insert(model.api_identifier.clone(), model.clone());
112 by_provider
113 .entry(model.provider.clone())
114 .or_default()
115 .push(model.clone());
116 }
117
118 Ok(Self {
119 config,
120 by_identifier,
121 by_provider,
122 })
123 }
124
125 pub fn get_model_spec(&self, api_identifier: &str) -> Option<&ModelSpec> {
127 if let Some(spec) = self.by_identifier.get(api_identifier) {
129 return Some(spec);
130 }
131
132 self.find_model_by_fuzzy_match(api_identifier)
134 }
135
136 pub fn get_max_output_tokens(&self, api_identifier: &str) -> usize {
138 if let Some(spec) = self.get_model_spec(api_identifier) {
139 return spec.max_output_tokens;
140 }
141
142 if let Some(provider) = self.infer_provider(api_identifier) {
144 if let Some(provider_config) = self.config.providers.get(&provider) {
145 return provider_config.defaults.max_output_tokens;
146 }
147 }
148
149 4096
151 }
152
153 pub fn get_input_context(&self, api_identifier: &str) -> usize {
155 if let Some(spec) = self.get_model_spec(api_identifier) {
156 return spec.input_context;
157 }
158
159 if let Some(provider) = self.infer_provider(api_identifier) {
161 if let Some(provider_config) = self.config.providers.get(&provider) {
162 return provider_config.defaults.input_context;
163 }
164 }
165
166 100000
168 }
169
170 fn infer_provider(&self, api_identifier: &str) -> Option<String> {
172 if api_identifier.starts_with("claude") || api_identifier.contains("anthropic") {
173 Some("claude".to_string())
174 } else {
175 None
176 }
177 }
178
179 fn find_model_by_fuzzy_match(&self, api_identifier: &str) -> Option<&ModelSpec> {
181 let core_identifier = self.extract_core_model_identifier(api_identifier);
187
188 if let Some(spec) = self.by_identifier.get(&core_identifier) {
190 return Some(spec);
191 }
192
193 for (stored_id, spec) in &self.by_identifier {
195 if self.models_match_fuzzy(&core_identifier, stored_id) {
196 return Some(spec);
197 }
198 }
199
200 None
201 }
202
203 fn extract_core_model_identifier(&self, api_identifier: &str) -> String {
205 let mut identifier = api_identifier.to_string();
206
207 if let Some(dot_pos) = identifier.find('.') {
209 if identifier[..dot_pos].len() <= 3 {
210 identifier = identifier[dot_pos + 1..].to_string();
212 }
213 }
214
215 if identifier.starts_with("anthropic.") {
217 identifier = identifier["anthropic.".len()..].to_string();
218 }
219
220 if let Some(version_pos) = identifier.rfind("-v") {
222 if identifier[version_pos..].contains(':') {
223 identifier = identifier[..version_pos].to_string();
224 }
225 }
226
227 identifier
228 }
229
230 fn models_match_fuzzy(&self, input_id: &str, stored_id: &str) -> bool {
232 input_id == stored_id
235 }
236
237 pub fn is_legacy_model(&self, api_identifier: &str) -> bool {
239 self.get_model_spec(api_identifier)
240 .map(|spec| spec.legacy)
241 .unwrap_or(false)
242 }
243
244 pub fn get_all_models(&self) -> &[ModelSpec] {
246 &self.config.models
247 }
248
249 pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelSpec> {
251 self.by_provider
252 .get(provider)
253 .map(|models| models.iter().collect())
254 .unwrap_or_default()
255 }
256
257 pub fn get_models_by_provider_and_tier(&self, provider: &str, tier: &str) -> Vec<&ModelSpec> {
259 self.get_models_by_provider(provider)
260 .into_iter()
261 .filter(|model| model.tier == tier)
262 .collect()
263 }
264
265 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
267 self.config.providers.get(provider)
268 }
269
270 pub fn get_tier_info(&self, provider: &str, tier: &str) -> Option<&TierInfo> {
272 self.config.providers.get(provider)?.tiers.get(tier)
273 }
274
275 pub fn get_beta_headers(&self, api_identifier: &str) -> &[BetaHeader] {
277 self.get_model_spec(api_identifier)
278 .map(|spec| spec.beta_headers.as_slice())
279 .unwrap_or_default()
280 }
281
282 pub fn get_max_output_tokens_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
284 if let Some(spec) = self.get_model_spec(api_identifier) {
285 if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
286 if let Some(max) = bh.max_output_tokens {
287 return max;
288 }
289 }
290 return spec.max_output_tokens;
291 }
292 self.get_max_output_tokens(api_identifier)
293 }
294
295 pub fn get_input_context_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
297 if let Some(spec) = self.get_model_spec(api_identifier) {
298 if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
299 if let Some(ctx) = bh.input_context {
300 return ctx;
301 }
302 }
303 return spec.input_context;
304 }
305 self.get_input_context(api_identifier)
306 }
307}
308
309static MODEL_REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
311
312pub fn get_model_registry() -> &'static ModelRegistry {
314 MODEL_REGISTRY.get_or_init(|| ModelRegistry::load().expect("Failed to load model registry"))
315}
316
317#[cfg(test)]
318mod tests {
319 use super::*;
320
321 #[test]
322 fn test_load_model_registry() {
323 let registry = ModelRegistry::load().unwrap();
324 assert!(!registry.config.models.is_empty());
325 assert!(registry.config.providers.contains_key("claude"));
326 }
327
328 #[test]
329 fn test_claude_model_lookup() {
330 let registry = ModelRegistry::load().unwrap();
331
332 let opus_spec = registry.get_model_spec("claude-3-opus-20240229");
334 assert!(opus_spec.is_some());
335 assert_eq!(opus_spec.unwrap().max_output_tokens, 4096);
336 assert_eq!(opus_spec.unwrap().provider, "claude");
337 assert!(registry.is_legacy_model("claude-3-opus-20240229"));
338
339 let sonnet45_tokens = registry.get_max_output_tokens("claude-sonnet-4-5-20250929");
341 assert_eq!(sonnet45_tokens, 64000);
342
343 let sonnet4_tokens = registry.get_max_output_tokens("claude-sonnet-4-20250514");
345 assert_eq!(sonnet4_tokens, 64000);
346 assert!(registry.is_legacy_model("claude-sonnet-4-20250514"));
347
348 let unknown_tokens = registry.get_max_output_tokens("claude-unknown-model");
350 assert_eq!(unknown_tokens, 4096); }
352
353 #[test]
354 fn test_provider_filtering() {
355 let registry = ModelRegistry::load().unwrap();
356
357 let claude_models = registry.get_models_by_provider("claude");
358 assert!(!claude_models.is_empty());
359
360 let fast_claude_models = registry.get_models_by_provider_and_tier("claude", "fast");
361 assert!(!fast_claude_models.is_empty());
362
363 let tier_info = registry.get_tier_info("claude", "fast");
364 assert!(tier_info.is_some());
365 }
366
367 #[test]
368 fn test_provider_config() {
369 let registry = ModelRegistry::load().unwrap();
370
371 let claude_config = registry.get_provider_config("claude");
372 assert!(claude_config.is_some());
373 assert_eq!(claude_config.unwrap().name, "Anthropic Claude");
374 }
375
376 #[test]
377 fn test_fuzzy_model_matching() {
378 let registry = ModelRegistry::load().unwrap();
379
380 let bedrock_3_7_sonnet = "us.anthropic.claude-3-7-sonnet-20250219-v1:0";
382 let spec = registry.get_model_spec(bedrock_3_7_sonnet);
383 assert!(spec.is_some());
384 assert_eq!(spec.unwrap().api_identifier, "claude-3-7-sonnet-20250219");
385 assert_eq!(spec.unwrap().max_output_tokens, 64000);
386
387 let aws_haiku = "anthropic.claude-3-haiku-20240307-v1:0";
389 let spec = registry.get_model_spec(aws_haiku);
390 assert!(spec.is_some());
391 assert_eq!(spec.unwrap().api_identifier, "claude-3-haiku-20240307");
392 assert_eq!(spec.unwrap().max_output_tokens, 4096);
393
394 let eu_opus = "eu.anthropic.claude-3-opus-20240229-v2:1";
396 let spec = registry.get_model_spec(eu_opus);
397 assert!(spec.is_some());
398 assert_eq!(spec.unwrap().api_identifier, "claude-3-opus-20240229");
399 assert_eq!(spec.unwrap().max_output_tokens, 4096);
400
401 let exact_sonnet45 = "claude-sonnet-4-5-20250929";
403 let spec = registry.get_model_spec(exact_sonnet45);
404 assert!(spec.is_some());
405 assert_eq!(spec.unwrap().max_output_tokens, 64000);
406
407 let exact_sonnet4 = "claude-sonnet-4-20250514";
409 let spec = registry.get_model_spec(exact_sonnet4);
410 assert!(spec.is_some());
411 assert_eq!(spec.unwrap().max_output_tokens, 64000);
412 }
413
414 #[test]
415 fn test_extract_core_model_identifier() {
416 let registry = ModelRegistry::load().unwrap();
417
418 assert_eq!(
420 registry.extract_core_model_identifier("us.anthropic.claude-3-7-sonnet-20250219-v1:0"),
421 "claude-3-7-sonnet-20250219"
422 );
423
424 assert_eq!(
425 registry.extract_core_model_identifier("anthropic.claude-3-haiku-20240307-v1:0"),
426 "claude-3-haiku-20240307"
427 );
428
429 assert_eq!(
430 registry.extract_core_model_identifier("claude-3-opus-20240229"),
431 "claude-3-opus-20240229"
432 );
433
434 assert_eq!(
435 registry.extract_core_model_identifier("eu.anthropic.claude-sonnet-4-20250514-v2:1"),
436 "claude-sonnet-4-20250514"
437 );
438 }
439
440 #[test]
441 fn test_beta_header_lookups() {
442 let registry = ModelRegistry::load().unwrap();
443
444 assert_eq!(registry.get_max_output_tokens("claude-opus-4-6"), 128000);
446 assert_eq!(registry.get_input_context("claude-opus-4-6"), 200000);
447
448 assert_eq!(
450 registry.get_input_context_with_beta("claude-opus-4-6", "context-1m-2025-08-07"),
451 1000000
452 );
453 assert_eq!(
455 registry.get_max_output_tokens_with_beta("claude-opus-4-6", "context-1m-2025-08-07"),
456 128000
457 );
458
459 assert_eq!(
461 registry.get_max_output_tokens_with_beta(
462 "claude-3-7-sonnet-20250219",
463 "output-128k-2025-02-19"
464 ),
465 128000
466 );
467
468 assert_eq!(
470 registry.get_max_output_tokens("claude-3-7-sonnet-20250219"),
471 64000
472 );
473
474 let headers = registry.get_beta_headers("claude-opus-4-6");
476 assert_eq!(headers.len(), 1);
477 assert_eq!(headers[0].key, "anthropic-beta");
478 assert_eq!(headers[0].value, "context-1m-2025-08-07");
479
480 let headers = registry.get_beta_headers("claude-3-7-sonnet-20250219");
482 assert_eq!(headers.len(), 2);
483
484 let headers = registry.get_beta_headers("claude-3-haiku-20240307");
486 assert!(headers.is_empty());
487
488 let headers = registry.get_beta_headers("unknown-model");
490 assert!(headers.is_empty());
491 }
492}