use std::collections::HashMap;
use std::sync::OnceLock;
use anyhow::Result;
use serde::Deserialize;
/// Model catalogue YAML, embedded into the binary at compile time from
/// `../templates/models.yaml` (path relative to this source file).
pub(crate) const MODELS_YAML: &str = include_str!("../templates/models.yaml");
/// Output-token limit returned when a model is not registered and no provider
/// defaults can be resolved for its identifier.
const FALLBACK_MAX_OUTPUT_TOKENS: usize = 4096;
/// Input-context size returned when a model is not registered and no provider
/// defaults can be resolved for its identifier.
const FALLBACK_INPUT_CONTEXT: usize = 100_000;
/// A header attached to a model entry, together with any limit overrides that
/// apply when that header is selected.
#[derive(Debug, Deserialize, Clone)]
pub struct BetaHeader {
/// Header name (for example `anthropic-beta`).
pub key: String,
/// Header value; the `*_with_beta` lookups on `ModelRegistry` match on this field.
pub value: String,
/// Output-token limit used instead of the model's own limit when this header
/// is selected. `None` (also the default when the key is absent) keeps the
/// model's limit.
#[serde(default)]
pub max_output_tokens: Option<usize>,
/// Input-context size used instead of the model's own size when this header
/// is selected. `None` (also the default when the key is absent) keeps the
/// model's size.
#[serde(default)]
pub input_context: Option<usize>,
}
/// One model entry as deserialized from the `models` list of the catalogue.
#[derive(Debug, Deserialize, Clone)]
pub struct ModelSpec {
/// Provider key this model belongs to (for example `claude`); the registry
/// groups models by this value.
pub provider: String,
/// Model name. Not read in this file — presumably a human-readable name;
/// TODO confirm against `models.yaml`.
pub model: String,
/// Identifier the registry indexes this model under and looks it up by.
pub api_identifier: String,
/// Output-token limit reported for this model.
pub max_output_tokens: usize,
/// Input-context size reported for this model.
pub input_context: usize,
/// Generation number. Not read in this file — presumably the model family
/// version; TODO confirm against `models.yaml`.
pub generation: f32,
/// Tier name (for example `fast`), compared verbatim when filtering by tier.
pub tier: String,
/// Whether the model is flagged as legacy; `false` when the key is absent.
#[serde(default)]
pub legacy: bool,
/// Headers available for this model; empty when the key is absent.
#[serde(default)]
pub beta_headers: Vec<BetaHeader>,
}
/// Descriptive metadata for one tier of a provider, stored in
/// `ProviderConfig::tiers` under the tier's name.
#[derive(Debug, Deserialize)]
pub struct TierInfo {
/// Free-text description of the tier.
pub description: String,
/// Use cases listed for the tier.
pub use_cases: Vec<String>,
}
/// Per-provider limits the registry returns for identifiers that are not
/// registered but whose provider can be inferred.
#[derive(Debug, Deserialize)]
pub struct DefaultConfig {
/// Default output-token limit for the provider.
pub max_output_tokens: usize,
/// Default input-context size for the provider.
pub input_context: usize,
}
/// Configuration of a single provider, stored in
/// `ModelConfiguration::providers` under the provider's key.
#[derive(Debug, Deserialize)]
pub struct ProviderConfig {
/// Display name of the provider (for example `Anthropic Claude`).
pub name: String,
/// API base for the provider. Not read in this file — presumably the base
/// URL used by the HTTP client; TODO confirm against callers.
pub api_base: String,
/// Identifier of the model returned by `ModelRegistry::get_default_model`.
pub default_model: String,
/// Tier metadata keyed by tier name.
pub tiers: HashMap<String, TierInfo>,
/// Limits used for this provider's models that are not registered.
pub defaults: DefaultConfig,
}
/// Top-level shape of the model catalogue that `ModelRegistry::load`
/// deserializes from `MODELS_YAML`.
#[derive(Debug, Deserialize)]
pub struct ModelConfiguration {
/// Every model entry, in catalogue order.
pub models: Vec<ModelSpec>,
/// Provider configuration keyed by provider key (for example `claude`).
pub providers: HashMap<String, ProviderConfig>,
}
/// In-memory view of the model catalogue with lookup indexes built by
/// `ModelRegistry::load`.
pub struct ModelRegistry {
// The deserialized catalogue; source of the model list and provider settings.
config: ModelConfiguration,
// Copy of each model keyed by its `api_identifier`.
by_identifier: HashMap<String, ModelSpec>,
// Copies of the models grouped by their `provider` key, in catalogue order.
by_provider: HashMap<String, Vec<ModelSpec>>,
}
impl ModelRegistry {
pub fn load() -> Result<Self> {
let config: ModelConfiguration = serde_yaml::from_str(MODELS_YAML)?;
let mut by_identifier = HashMap::new();
let mut by_provider: HashMap<String, Vec<ModelSpec>> = HashMap::new();
for model in &config.models {
by_identifier.insert(model.api_identifier.clone(), model.clone());
by_provider
.entry(model.provider.clone())
.or_default()
.push(model.clone());
}
Ok(Self {
config,
by_identifier,
by_provider,
})
}
#[must_use]
pub fn get_model_spec(&self, api_identifier: &str) -> Option<&ModelSpec> {
if let Some(spec) = self.by_identifier.get(api_identifier) {
return Some(spec);
}
self.find_model_by_normalized_id(api_identifier)
}
#[must_use]
pub fn get_max_output_tokens(&self, api_identifier: &str) -> usize {
if let Some(spec) = self.get_model_spec(api_identifier) {
return spec.max_output_tokens;
}
if let Some(provider) = self.infer_provider(api_identifier) {
if let Some(provider_config) = self.config.providers.get(&provider) {
return provider_config.defaults.max_output_tokens;
}
}
FALLBACK_MAX_OUTPUT_TOKENS
}
#[must_use]
pub fn get_input_context(&self, api_identifier: &str) -> usize {
if let Some(spec) = self.get_model_spec(api_identifier) {
return spec.input_context;
}
if let Some(provider) = self.infer_provider(api_identifier) {
if let Some(provider_config) = self.config.providers.get(&provider) {
return provider_config.defaults.input_context;
}
}
FALLBACK_INPUT_CONTEXT
}
fn infer_provider(&self, api_identifier: &str) -> Option<String> {
if api_identifier.starts_with("claude") || api_identifier.contains("anthropic") {
Some("claude".to_string())
} else {
None
}
}
fn find_model_by_normalized_id(&self, api_identifier: &str) -> Option<&ModelSpec> {
let core_identifier = self.extract_core_model_identifier(api_identifier);
self.by_identifier.get(&core_identifier)
}
fn extract_core_model_identifier(&self, api_identifier: &str) -> String {
let mut identifier = api_identifier.to_string();
if let Some(dot_pos) = identifier.find('.') {
if identifier[..dot_pos].len() <= 3 {
identifier = identifier[dot_pos + 1..].to_string();
}
}
if identifier.starts_with("anthropic.") {
identifier = identifier["anthropic.".len()..].to_string();
}
if let Some(version_pos) = identifier.rfind("-v") {
if identifier[version_pos..].contains(':') {
identifier = identifier[..version_pos].to_string();
}
}
identifier
}
#[must_use]
pub fn is_legacy_model(&self, api_identifier: &str) -> bool {
self.get_model_spec(api_identifier)
.is_some_and(|spec| spec.legacy)
}
#[must_use]
pub fn get_all_models(&self) -> &[ModelSpec] {
&self.config.models
}
#[must_use]
pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelSpec> {
self.by_provider
.get(provider)
.map(|models| models.iter().collect())
.unwrap_or_default()
}
#[must_use]
pub fn get_models_by_provider_and_tier(&self, provider: &str, tier: &str) -> Vec<&ModelSpec> {
self.get_models_by_provider(provider)
.into_iter()
.filter(|model| model.tier == tier)
.collect()
}
#[must_use]
pub fn get_default_model(&self, provider: &str) -> Option<&str> {
self.config
.providers
.get(provider)
.map(|p| p.default_model.as_str())
}
#[must_use]
pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
self.config.providers.get(provider)
}
#[must_use]
pub fn get_tier_info(&self, provider: &str, tier: &str) -> Option<&TierInfo> {
self.config.providers.get(provider)?.tiers.get(tier)
}
#[must_use]
pub fn get_beta_headers(&self, api_identifier: &str) -> &[BetaHeader] {
self.get_model_spec(api_identifier)
.map(|spec| spec.beta_headers.as_slice())
.unwrap_or_default()
}
#[must_use]
pub fn get_max_output_tokens_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
if let Some(spec) = self.get_model_spec(api_identifier) {
if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
if let Some(max) = bh.max_output_tokens {
return max;
}
}
return spec.max_output_tokens;
}
self.get_max_output_tokens(api_identifier)
}
#[must_use]
pub fn get_input_context_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
if let Some(spec) = self.get_model_spec(api_identifier) {
if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
if let Some(ctx) = bh.input_context {
return ctx;
}
}
return spec.input_context;
}
self.get_input_context(api_identifier)
}
}
/// Process-wide registry instance, filled in on the first call to
/// [`get_model_registry`].
static MODEL_REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();

/// Returns the shared [`ModelRegistry`], loading it on first use.
///
/// # Panics
///
/// Panics with `Failed to load model registry` if [`ModelRegistry::load`]
/// returns an error on the initialising call.
#[must_use]
// `load` only parses the catalogue embedded at compile time, so a failure here
// is a packaging defect rather than a condition callers could recover from.
#[allow(clippy::expect_used)]
pub fn get_model_registry() -> &'static ModelRegistry {
    MODEL_REGISTRY.get_or_init(|| {
        let loaded = ModelRegistry::load();
        loaded.expect("Failed to load model registry")
    })
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Loads a fresh registry from the embedded catalogue for one test.
    fn load_registry() -> ModelRegistry {
        ModelRegistry::load().unwrap()
    }

    /// The embedded catalogue parses and contains models and the `claude` provider.
    #[test]
    fn load_model_registry() {
        let registry = load_registry();
        assert!(!registry.config.models.is_empty());
        assert!(registry.config.providers.contains_key("claude"));
    }

    /// Exact-identifier lookups, legacy flags, and the provider-default fallback.
    #[test]
    fn claude_model_lookup() {
        let registry = load_registry();

        let opus = registry
            .get_model_spec("claude-3-opus-20240229")
            .expect("claude-3-opus-20240229 should be registered");
        assert_eq!(opus.max_output_tokens, 4096);
        assert_eq!(opus.provider, "claude");
        assert!(registry.is_legacy_model("claude-3-opus-20240229"));

        assert_eq!(
            registry.get_max_output_tokens("claude-sonnet-4-5-20250929"),
            64000
        );

        assert_eq!(
            registry.get_max_output_tokens("claude-sonnet-4-20250514"),
            64000
        );
        assert!(registry.is_legacy_model("claude-sonnet-4-20250514"));

        // Unregistered identifier: the value comes from the fallback chain.
        assert_eq!(
            registry.get_max_output_tokens("claude-unknown-model"),
            4096
        );
    }

    /// Provider and tier filters return entries, and tier metadata is present.
    #[test]
    fn provider_filtering() {
        let registry = load_registry();

        assert!(!registry.get_models_by_provider("claude").is_empty());
        assert!(!registry
            .get_models_by_provider_and_tier("claude", "fast")
            .is_empty());
        assert!(registry.get_tier_info("claude", "fast").is_some());
    }

    /// Provider configuration is exposed with its display name.
    #[test]
    fn provider_config() {
        let registry = load_registry();
        let claude = registry
            .get_provider_config("claude")
            .expect("claude provider should be configured");
        assert_eq!(claude.name, "Anthropic Claude");
    }

    /// Each configured provider reports its default model; unknown ones report none.
    #[test]
    fn default_model_per_provider() {
        let registry = load_registry();
        let expectations = [
            ("claude", Some("claude-sonnet-4-6")),
            ("openai", Some("gpt-5-mini")),
            ("gemini", Some("gemini-2.5-flash")),
            ("nonexistent", None),
        ];
        for (provider, expected) in expectations {
            assert_eq!(registry.get_default_model(provider), expected);
        }
    }

    /// Platform-qualified identifiers resolve to the registered model, and exact
    /// identifiers keep resolving directly.
    #[test]
    fn normalized_id_matching() {
        let registry = load_registry();

        // (qualified identifier, registered identifier, max output tokens)
        let qualified_cases = [
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
                64000,
            ),
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
                4096,
            ),
            (
                "eu.anthropic.claude-3-opus-20240229-v2:1",
                "claude-3-opus-20240229",
                4096,
            ),
        ];
        for (qualified, registered, max_tokens) in qualified_cases {
            let spec = registry
                .get_model_spec(qualified)
                .expect("qualified identifier should resolve");
            assert_eq!(spec.api_identifier, registered);
            assert_eq!(spec.max_output_tokens, max_tokens);
        }

        // (exact identifier, max output tokens)
        let exact_cases = [
            ("claude-sonnet-4-5-20250929", 64000),
            ("claude-sonnet-4-20250514", 64000),
        ];
        for (exact, max_tokens) in exact_cases {
            let spec = registry
                .get_model_spec(exact)
                .expect("exact identifier should resolve");
            assert_eq!(spec.max_output_tokens, max_tokens);
        }
    }

    /// Normalisation strips routing prefix, vendor namespace and version suffix.
    #[test]
    fn extract_core_model_identifier() {
        let registry = load_registry();
        let cases = [
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
            ),
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
            ),
            ("claude-3-opus-20240229", "claude-3-opus-20240229"),
            (
                "eu.anthropic.claude-sonnet-4-20250514-v2:1",
                "claude-sonnet-4-20250514",
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(registry.extract_core_model_identifier(input), expected);
        }
    }

    /// Header-aware lookups apply overrides only where the header declares them.
    #[test]
    fn beta_header_lookups() {
        let registry = load_registry();
        let opus = "claude-opus-4-6";
        let sonnet = "claude-3-7-sonnet-20250219";
        let context_beta = "context-1m-2025-08-07";

        assert_eq!(registry.get_max_output_tokens(opus), 128_000);
        assert_eq!(registry.get_input_context(opus), 200_000);

        assert_eq!(
            registry.get_input_context_with_beta(opus, context_beta),
            1_000_000
        );
        assert_eq!(
            registry.get_max_output_tokens_with_beta(opus, context_beta),
            128_000
        );

        assert_eq!(
            registry.get_max_output_tokens_with_beta(sonnet, "output-128k-2025-02-19"),
            128_000
        );
        assert_eq!(registry.get_max_output_tokens(sonnet), 64000);

        let opus_headers = registry.get_beta_headers(opus);
        assert_eq!(opus_headers.len(), 1);
        assert_eq!(opus_headers[0].key, "anthropic-beta");
        assert_eq!(opus_headers[0].value, context_beta);

        assert_eq!(registry.get_beta_headers(sonnet).len(), 2);
        assert!(registry
            .get_beta_headers("claude-3-haiku-20240307")
            .is_empty());
        assert!(registry.get_beta_headers("unknown-model").is_empty());
    }
}