1use std::collections::HashMap;
7use std::sync::OnceLock;
8
9use anyhow::Result;
10use serde::Deserialize;
11
/// Raw contents of `../templates/models.yaml`, embedded into the binary at
/// compile time via `include_str!`.
///
/// This is the text that `ModelRegistry::load` deserializes.
pub(crate) const MODELS_YAML: &str = include_str!("../templates/models.yaml");
14
/// Output-token limit returned when neither a model specification nor
/// provider-level defaults can be found for an identifier.
const FALLBACK_MAX_OUTPUT_TOKENS: usize = 4_096;

/// Input-context size returned when neither a model specification nor
/// provider-level defaults can be found for an identifier.
const FALLBACK_INPUT_CONTEXT: usize = 100_000;
20
21#[derive(Debug, Deserialize, Clone)]
23pub struct BetaHeader {
24 pub key: String,
26 pub value: String,
28 #[serde(default)]
30 pub max_output_tokens: Option<usize>,
31 #[serde(default)]
33 pub input_context: Option<usize>,
34}
35
36#[derive(Debug, Deserialize, Clone)]
38pub struct ModelSpec {
39 pub provider: String,
41 pub model: String,
43 pub api_identifier: String,
45 pub max_output_tokens: usize,
47 pub input_context: usize,
49 pub generation: f32,
51 pub tier: String,
53 #[serde(default)]
55 pub legacy: bool,
56 #[serde(default)]
58 pub beta_headers: Vec<BetaHeader>,
59}
60
61#[derive(Debug, Deserialize)]
63pub struct TierInfo {
64 pub description: String,
66 pub use_cases: Vec<String>,
68}
69
70#[derive(Debug, Deserialize)]
72pub struct DefaultConfig {
73 pub max_output_tokens: usize,
75 pub input_context: usize,
77}
78
79#[derive(Debug, Deserialize)]
81pub struct ProviderConfig {
82 pub name: String,
84 pub api_base: String,
86 pub default_model: String,
88 pub tiers: HashMap<String, TierInfo>,
90 pub defaults: DefaultConfig,
92}
93
94#[derive(Debug, Deserialize)]
96pub struct ModelConfiguration {
97 pub models: Vec<ModelSpec>,
99 pub providers: HashMap<String, ProviderConfig>,
101}
102
/// Lookup structure over the model configuration.
///
/// Besides the parsed configuration itself, it keeps two indexes that are
/// filled in by `ModelRegistry::load`.
pub struct ModelRegistry {
    /// The deserialized configuration the indexes were built from.
    config: ModelConfiguration,
    /// Model specifications keyed by their `api_identifier`.
    by_identifier: HashMap<String, ModelSpec>,
    /// Model specifications grouped by their `provider` key.
    by_provider: HashMap<String, Vec<ModelSpec>>,
}
109
110impl ModelRegistry {
111 pub fn load() -> Result<Self> {
113 let config: ModelConfiguration = serde_yaml::from_str(MODELS_YAML)?;
114
115 let mut by_identifier = HashMap::new();
117 let mut by_provider: HashMap<String, Vec<ModelSpec>> = HashMap::new();
118
119 for model in &config.models {
120 by_identifier.insert(model.api_identifier.clone(), model.clone());
121 by_provider
122 .entry(model.provider.clone())
123 .or_default()
124 .push(model.clone());
125 }
126
127 Ok(Self {
128 config,
129 by_identifier,
130 by_provider,
131 })
132 }
133
134 #[must_use]
136 pub fn get_model_spec(&self, api_identifier: &str) -> Option<&ModelSpec> {
137 if let Some(spec) = self.by_identifier.get(api_identifier) {
139 return Some(spec);
140 }
141
142 self.find_model_by_normalized_id(api_identifier)
144 }
145
146 #[must_use]
148 pub fn get_max_output_tokens(&self, api_identifier: &str) -> usize {
149 if let Some(spec) = self.get_model_spec(api_identifier) {
150 return spec.max_output_tokens;
151 }
152
153 if let Some(provider) = self.infer_provider(api_identifier) {
155 if let Some(provider_config) = self.config.providers.get(&provider) {
156 return provider_config.defaults.max_output_tokens;
157 }
158 }
159
160 FALLBACK_MAX_OUTPUT_TOKENS
162 }
163
164 #[must_use]
166 pub fn get_input_context(&self, api_identifier: &str) -> usize {
167 if let Some(spec) = self.get_model_spec(api_identifier) {
168 return spec.input_context;
169 }
170
171 if let Some(provider) = self.infer_provider(api_identifier) {
173 if let Some(provider_config) = self.config.providers.get(&provider) {
174 return provider_config.defaults.input_context;
175 }
176 }
177
178 FALLBACK_INPUT_CONTEXT
180 }
181
182 fn infer_provider(&self, api_identifier: &str) -> Option<String> {
184 if api_identifier.starts_with("claude") || api_identifier.contains("anthropic") {
185 Some("claude".to_string())
186 } else {
187 None
188 }
189 }
190
191 fn find_model_by_normalized_id(&self, api_identifier: &str) -> Option<&ModelSpec> {
196 let core_identifier = self.extract_core_model_identifier(api_identifier);
197 self.by_identifier.get(&core_identifier)
198 }
199
200 fn extract_core_model_identifier(&self, api_identifier: &str) -> String {
202 let mut identifier = api_identifier.to_string();
203
204 if let Some(dot_pos) = identifier.find('.') {
206 if identifier[..dot_pos].len() <= 3 {
207 identifier = identifier[dot_pos + 1..].to_string();
209 }
210 }
211
212 if identifier.starts_with("anthropic.") {
214 identifier = identifier["anthropic.".len()..].to_string();
215 }
216
217 if let Some(version_pos) = identifier.rfind("-v") {
219 if identifier[version_pos..].contains(':') {
220 identifier = identifier[..version_pos].to_string();
221 }
222 }
223
224 identifier
225 }
226
227 #[must_use]
229 pub fn is_legacy_model(&self, api_identifier: &str) -> bool {
230 self.get_model_spec(api_identifier)
231 .is_some_and(|spec| spec.legacy)
232 }
233
234 #[must_use]
236 pub fn get_all_models(&self) -> &[ModelSpec] {
237 &self.config.models
238 }
239
240 #[must_use]
242 pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelSpec> {
243 self.by_provider
244 .get(provider)
245 .map(|models| models.iter().collect())
246 .unwrap_or_default()
247 }
248
249 #[must_use]
251 pub fn get_models_by_provider_and_tier(&self, provider: &str, tier: &str) -> Vec<&ModelSpec> {
252 self.get_models_by_provider(provider)
253 .into_iter()
254 .filter(|model| model.tier == tier)
255 .collect()
256 }
257
258 #[must_use]
260 pub fn get_default_model(&self, provider: &str) -> Option<&str> {
261 self.config
262 .providers
263 .get(provider)
264 .map(|p| p.default_model.as_str())
265 }
266
267 #[must_use]
269 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
270 self.config.providers.get(provider)
271 }
272
273 #[must_use]
275 pub fn get_tier_info(&self, provider: &str, tier: &str) -> Option<&TierInfo> {
276 self.config.providers.get(provider)?.tiers.get(tier)
277 }
278
279 #[must_use]
281 pub fn get_beta_headers(&self, api_identifier: &str) -> &[BetaHeader] {
282 self.get_model_spec(api_identifier)
283 .map(|spec| spec.beta_headers.as_slice())
284 .unwrap_or_default()
285 }
286
287 #[must_use]
289 pub fn get_max_output_tokens_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
290 if let Some(spec) = self.get_model_spec(api_identifier) {
291 if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
292 if let Some(max) = bh.max_output_tokens {
293 return max;
294 }
295 }
296 return spec.max_output_tokens;
297 }
298 self.get_max_output_tokens(api_identifier)
299 }
300
301 #[must_use]
303 pub fn get_input_context_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
304 if let Some(spec) = self.get_model_spec(api_identifier) {
305 if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
306 if let Some(ctx) = bh.input_context {
307 return ctx;
308 }
309 }
310 return spec.input_context;
311 }
312 self.get_input_context(api_identifier)
313 }
314}
315
316static MODEL_REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
318
319#[must_use]
321pub fn get_model_registry() -> &'static ModelRegistry {
322 #[allow(clippy::expect_used)] MODEL_REGISTRY.get_or_init(|| ModelRegistry::load().expect("Failed to load model registry"))
324}
325
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Loads a fresh registry from the embedded configuration.
    fn registry() -> ModelRegistry {
        ModelRegistry::load().unwrap()
    }

    #[test]
    fn load_model_registry() {
        let reg = registry();

        assert!(!reg.config.models.is_empty());
        assert!(reg.config.providers.contains_key("claude"));
    }

    #[test]
    fn claude_model_lookup() {
        let reg = registry();

        // Legacy Claude 3 Opus.
        let opus = reg.get_model_spec("claude-3-opus-20240229");
        assert!(opus.is_some());
        let opus = opus.unwrap();
        assert_eq!(opus.max_output_tokens, 4096);
        assert_eq!(opus.provider, "claude");
        assert!(reg.is_legacy_model("claude-3-opus-20240229"));

        // Sonnet 4.5.
        assert_eq!(
            reg.get_max_output_tokens("claude-sonnet-4-5-20250929"),
            64000
        );

        // Sonnet 4 is flagged as legacy.
        assert_eq!(reg.get_max_output_tokens("claude-sonnet-4-20250514"), 64000);
        assert!(reg.is_legacy_model("claude-sonnet-4-20250514"));

        // Unlisted Claude models fall back to the provider defaults.
        assert_eq!(reg.get_max_output_tokens("claude-unknown-model"), 4096);
    }

    #[test]
    fn provider_filtering() {
        let reg = registry();

        assert!(!reg.get_models_by_provider("claude").is_empty());
        assert!(!reg
            .get_models_by_provider_and_tier("claude", "fast")
            .is_empty());
        assert!(reg.get_tier_info("claude", "fast").is_some());
    }

    #[test]
    fn provider_config() {
        let reg = registry();

        let claude = reg.get_provider_config("claude");
        assert!(claude.is_some());
        assert_eq!(claude.unwrap().name, "Anthropic Claude");
    }

    #[test]
    fn default_model_per_provider() {
        let reg = registry();

        let expectations = [
            ("claude", Some("claude-sonnet-4-6")),
            ("openai", Some("gpt-5-mini")),
            ("gemini", Some("gemini-2.5-flash")),
            ("nonexistent", None),
        ];
        for (provider, expected) in expectations {
            assert_eq!(reg.get_default_model(provider), expected);
        }
    }

    #[test]
    fn normalized_id_matching() {
        let reg = registry();

        // Decorated identifiers must resolve to the plain registry entry.
        let decorated = [
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
                64000,
            ),
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
                4096,
            ),
            (
                "eu.anthropic.claude-3-opus-20240229-v2:1",
                "claude-3-opus-20240229",
                4096,
            ),
        ];
        for (input, expected_id, expected_max) in decorated {
            let spec = reg.get_model_spec(input);
            assert!(spec.is_some());
            let spec = spec.unwrap();
            assert_eq!(spec.api_identifier, expected_id);
            assert_eq!(spec.max_output_tokens, expected_max);
        }

        // Exact identifiers keep working unchanged.
        let exact = [
            ("claude-sonnet-4-5-20250929", 64000),
            ("claude-sonnet-4-20250514", 64000),
        ];
        for (input, expected_max) in exact {
            let spec = reg.get_model_spec(input);
            assert!(spec.is_some());
            assert_eq!(spec.unwrap().max_output_tokens, expected_max);
        }
    }

    #[test]
    fn extract_core_model_identifier() {
        let reg = registry();

        let cases = [
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
            ),
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
            ),
            ("claude-3-opus-20240229", "claude-3-opus-20240229"),
            (
                "eu.anthropic.claude-sonnet-4-20250514-v2:1",
                "claude-sonnet-4-20250514",
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(reg.extract_core_model_identifier(input), expected);
        }
    }

    #[test]
    fn beta_header_lookups() {
        let reg = registry();
        let opus = "claude-opus-4-6";
        let sonnet37 = "claude-3-7-sonnet-20250219";
        let context_beta = "context-1m-2025-08-07";

        // Base limits without any beta header.
        assert_eq!(reg.get_max_output_tokens(opus), 128_000);
        assert_eq!(reg.get_input_context(opus), 200_000);

        // The 1M-context beta raises the context but not the output limit.
        assert_eq!(
            reg.get_input_context_with_beta(opus, context_beta),
            1_000_000
        );
        assert_eq!(
            reg.get_max_output_tokens_with_beta(opus, context_beta),
            128_000
        );

        // The 128k-output beta raises the output limit of Sonnet 3.7.
        assert_eq!(
            reg.get_max_output_tokens_with_beta(sonnet37, "output-128k-2025-02-19"),
            128_000
        );
        assert_eq!(reg.get_max_output_tokens(sonnet37), 64000);

        // Declared headers per model.
        let opus_headers = reg.get_beta_headers(opus);
        assert_eq!(opus_headers.len(), 1);
        assert_eq!(opus_headers[0].key, "anthropic-beta");
        assert_eq!(opus_headers[0].value, context_beta);

        assert_eq!(reg.get_beta_headers(sonnet37).len(), 2);
        assert!(reg.get_beta_headers("claude-3-haiku-20240307").is_empty());
        assert!(reg.get_beta_headers("unknown-model").is_empty());
    }
}