Skip to main content

omni_dev/claude/
model_config.rs

1//! AI model configuration and specifications.
2//!
3//! This module provides model specifications loaded from embedded YAML templates
4//! to ensure correct API parameters for different AI models.
5
6use std::collections::HashMap;
7use std::sync::OnceLock;
8
9use anyhow::Result;
10use serde::Deserialize;
11
12/// Beta header that unlocks enhanced model limits.
13#[derive(Debug, Deserialize, Clone)]
14pub struct BetaHeader {
15    /// HTTP header name (e.g., "anthropic-beta").
16    pub key: String,
17    /// Header value (e.g., "context-1m-2025-08-07").
18    pub value: String,
19    /// Overridden max output tokens when this header is active.
20    #[serde(default)]
21    pub max_output_tokens: Option<usize>,
22    /// Overridden input context when this header is active.
23    #[serde(default)]
24    pub input_context: Option<usize>,
25}
26
27/// Model specification from YAML configuration.
28#[derive(Debug, Deserialize, Clone)]
29pub struct ModelSpec {
30    /// AI provider name (e.g., "claude").
31    pub provider: String,
32    /// Human-readable model name (e.g., "Claude Opus 4").
33    pub model: String,
34    /// API identifier used for requests (e.g., "claude-3-opus-20240229").
35    pub api_identifier: String,
36    /// Maximum number of tokens that can be generated in a single response.
37    pub max_output_tokens: usize,
38    /// Maximum number of tokens that can be included in the input context.
39    pub input_context: usize,
40    /// Model generation number (e.g., 3.0, 3.5, 4.0).
41    pub generation: f32,
42    /// Performance tier (e.g., "fast", "balanced", "flagship").
43    pub tier: String,
44    /// Whether this is a legacy model that may be deprecated.
45    #[serde(default)]
46    pub legacy: bool,
47    /// Beta headers that unlock enhanced limits for this model.
48    #[serde(default)]
49    pub beta_headers: Vec<BetaHeader>,
50}
51
52/// Model tier information.
53#[derive(Debug, Deserialize)]
54pub struct TierInfo {
55    /// Human-readable description of the tier.
56    pub description: String,
57    /// List of recommended use cases for this tier.
58    pub use_cases: Vec<String>,
59}
60
61/// Default fallback configuration for a provider.
62#[derive(Debug, Deserialize)]
63pub struct DefaultConfig {
64    /// Default maximum output tokens for unknown models from this provider.
65    pub max_output_tokens: usize,
66    /// Default input context limit for unknown models from this provider.
67    pub input_context: usize,
68}
69
70/// Provider-specific configuration.
71#[derive(Debug, Deserialize)]
72pub struct ProviderConfig {
73    /// Human-readable provider name.
74    pub name: String,
75    /// Base URL for API requests.
76    pub api_base: String,
77    /// Default model identifier to use if none specified.
78    pub default_model: String,
79    /// Available performance tiers and their descriptions.
80    pub tiers: HashMap<String, TierInfo>,
81    /// Default configuration for unknown models.
82    pub defaults: DefaultConfig,
83}
84
85/// Complete model configuration.
86#[derive(Debug, Deserialize)]
87pub struct ModelConfiguration {
88    /// List of all available models.
89    pub models: Vec<ModelSpec>,
90    /// Provider-specific configurations.
91    pub providers: HashMap<String, ProviderConfig>,
92}
93
94/// Model registry for looking up specifications.
95pub struct ModelRegistry {
96    config: ModelConfiguration,
97    by_identifier: HashMap<String, ModelSpec>,
98    by_provider: HashMap<String, Vec<ModelSpec>>,
99}
100
101impl ModelRegistry {
102    /// Loads the model registry from embedded YAML.
103    pub fn load() -> Result<Self> {
104        let yaml_content = include_str!("../templates/models.yaml");
105        let config: ModelConfiguration = serde_yaml::from_str(yaml_content)?;
106
107        // Build lookup maps
108        let mut by_identifier = HashMap::new();
109        let mut by_provider: HashMap<String, Vec<ModelSpec>> = HashMap::new();
110
111        for model in &config.models {
112            by_identifier.insert(model.api_identifier.clone(), model.clone());
113            by_provider
114                .entry(model.provider.clone())
115                .or_default()
116                .push(model.clone());
117        }
118
119        Ok(Self {
120            config,
121            by_identifier,
122            by_provider,
123        })
124    }
125
126    /// Returns the model specification for the given API identifier.
127    pub fn get_model_spec(&self, api_identifier: &str) -> Option<&ModelSpec> {
128        // Try exact match first
129        if let Some(spec) = self.by_identifier.get(api_identifier) {
130            return Some(spec);
131        }
132
133        // Try fuzzy matching for Bedrock-style identifiers
134        self.find_model_by_fuzzy_match(api_identifier)
135    }
136
137    /// Returns the max output tokens for a model, with fallback to provider defaults.
138    pub fn get_max_output_tokens(&self, api_identifier: &str) -> usize {
139        if let Some(spec) = self.get_model_spec(api_identifier) {
140            return spec.max_output_tokens;
141        }
142
143        // Try to infer provider from model identifier and use defaults
144        if let Some(provider) = self.infer_provider(api_identifier) {
145            if let Some(provider_config) = self.config.providers.get(&provider) {
146                return provider_config.defaults.max_output_tokens;
147            }
148        }
149
150        // Ultimate fallback
151        4096
152    }
153
154    /// Returns the input context limit for a model, with fallback to provider defaults.
155    pub fn get_input_context(&self, api_identifier: &str) -> usize {
156        if let Some(spec) = self.get_model_spec(api_identifier) {
157            return spec.input_context;
158        }
159
160        // Try to infer provider from model identifier and use defaults
161        if let Some(provider) = self.infer_provider(api_identifier) {
162            if let Some(provider_config) = self.config.providers.get(&provider) {
163                return provider_config.defaults.input_context;
164            }
165        }
166
167        // Ultimate fallback
168        100000
169    }
170
171    /// Infers the provider from a model identifier.
172    fn infer_provider(&self, api_identifier: &str) -> Option<String> {
173        if api_identifier.starts_with("claude") || api_identifier.contains("anthropic") {
174            Some("claude".to_string())
175        } else {
176            None
177        }
178    }
179
180    /// Finds a model by fuzzy matching for various identifier formats.
181    fn find_model_by_fuzzy_match(&self, api_identifier: &str) -> Option<&ModelSpec> {
182        // Extract core model identifier from various formats:
183        // - Bedrock: "us.anthropic.claude-3-7-sonnet-20250219-v1:0" -> "claude-3-7-sonnet-20250219"
184        // - AWS: "anthropic.claude-3-haiku-20240307-v1:0" -> "claude-3-haiku-20240307"
185        // - Standard: "claude-3-opus-20240229" -> "claude-3-opus-20240229"
186
187        let core_identifier = self.extract_core_model_identifier(api_identifier);
188
189        // Try to find exact match with core identifier
190        if let Some(spec) = self.by_identifier.get(&core_identifier) {
191            return Some(spec);
192        }
193
194        // Try partial matching - look for models that contain the core parts
195        for (stored_id, spec) in &self.by_identifier {
196            if self.models_match_fuzzy(&core_identifier, stored_id) {
197                return Some(spec);
198            }
199        }
200
201        None
202    }
203
204    /// Extracts the core model identifier from various formats.
205    fn extract_core_model_identifier(&self, api_identifier: &str) -> String {
206        let mut identifier = api_identifier.to_string();
207
208        // Remove region prefixes (us., eu., etc.)
209        if let Some(dot_pos) = identifier.find('.') {
210            if identifier[..dot_pos].len() <= 3 {
211                // likely a region code
212                identifier = identifier[dot_pos + 1..].to_string();
213            }
214        }
215
216        // Remove provider prefixes (anthropic.)
217        if identifier.starts_with("anthropic.") {
218            identifier = identifier["anthropic.".len()..].to_string();
219        }
220
221        // Remove version suffixes (-v1:0, -v2:1, etc.)
222        if let Some(version_pos) = identifier.rfind("-v") {
223            if identifier[version_pos..].contains(':') {
224                identifier = identifier[..version_pos].to_string();
225            }
226        }
227
228        identifier
229    }
230
231    /// Checks if two model identifiers represent the same model.
232    fn models_match_fuzzy(&self, input_id: &str, stored_id: &str) -> bool {
233        // For now, just check if they're the same after extraction
234        // This could be enhanced with more sophisticated matching
235        input_id == stored_id
236    }
237
238    /// Checks if a model is legacy.
239    #[must_use]
240    pub fn is_legacy_model(&self, api_identifier: &str) -> bool {
241        self.get_model_spec(api_identifier)
242            .map(|spec| spec.legacy)
243            .unwrap_or(false)
244    }
245
246    /// Returns all available models.
247    pub fn get_all_models(&self) -> &[ModelSpec] {
248        &self.config.models
249    }
250
251    /// Returns models filtered by provider.
252    pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelSpec> {
253        self.by_provider
254            .get(provider)
255            .map(|models| models.iter().collect())
256            .unwrap_or_default()
257    }
258
259    /// Returns models filtered by provider and tier.
260    pub fn get_models_by_provider_and_tier(&self, provider: &str, tier: &str) -> Vec<&ModelSpec> {
261        self.get_models_by_provider(provider)
262            .into_iter()
263            .filter(|model| model.tier == tier)
264            .collect()
265    }
266
267    /// Returns the provider configuration.
268    pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
269        self.config.providers.get(provider)
270    }
271
272    /// Returns tier information for a provider.
273    pub fn get_tier_info(&self, provider: &str, tier: &str) -> Option<&TierInfo> {
274        self.config.providers.get(provider)?.tiers.get(tier)
275    }
276
277    /// Returns the beta headers for a model.
278    pub fn get_beta_headers(&self, api_identifier: &str) -> &[BetaHeader] {
279        self.get_model_spec(api_identifier)
280            .map(|spec| spec.beta_headers.as_slice())
281            .unwrap_or_default()
282    }
283
284    /// Returns the max output tokens for a model with a specific beta header active.
285    pub fn get_max_output_tokens_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
286        if let Some(spec) = self.get_model_spec(api_identifier) {
287            if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
288                if let Some(max) = bh.max_output_tokens {
289                    return max;
290                }
291            }
292            return spec.max_output_tokens;
293        }
294        self.get_max_output_tokens(api_identifier)
295    }
296
297    /// Returns the input context for a model with a specific beta header active.
298    pub fn get_input_context_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
299        if let Some(spec) = self.get_model_spec(api_identifier) {
300            if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
301                if let Some(ctx) = bh.input_context {
302                    return ctx;
303                }
304            }
305            return spec.input_context;
306        }
307        self.get_input_context(api_identifier)
308    }
309}
310
311/// Global model registry instance.
312static MODEL_REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
313
314/// Returns the global model registry instance.
315pub fn get_model_registry() -> &'static ModelRegistry {
316    MODEL_REGISTRY.get_or_init(|| ModelRegistry::load().expect("Failed to load model registry"))
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a fresh registry from the embedded YAML, failing the test on error.
    fn load_registry() -> ModelRegistry {
        ModelRegistry::load().expect("embedded models.yaml should load")
    }

    #[test]
    fn load_model_registry() {
        let registry = load_registry();
        assert!(!registry.config.models.is_empty());
        assert!(registry.config.providers.contains_key("claude"));
    }

    #[test]
    fn claude_model_lookup() {
        let registry = load_registry();

        // Legacy Claude 3 Opus
        let opus_spec = registry
            .get_model_spec("claude-3-opus-20240229")
            .expect("Claude 3 Opus should be registered");
        assert_eq!(opus_spec.max_output_tokens, 4096);
        assert_eq!(opus_spec.provider, "claude");
        assert!(registry.is_legacy_model("claude-3-opus-20240229"));

        // Claude 4.5 Sonnet (current generation)
        assert_eq!(
            registry.get_max_output_tokens("claude-sonnet-4-5-20250929"),
            64000
        );

        // Legacy Claude 4 Sonnet
        assert_eq!(
            registry.get_max_output_tokens("claude-sonnet-4-20250514"),
            64000
        );
        assert!(registry.is_legacy_model("claude-sonnet-4-20250514"));

        // Unknown Claude model falls back to the Claude provider defaults
        assert_eq!(registry.get_max_output_tokens("claude-unknown-model"), 4096);
    }

    #[test]
    fn unknown_provider_uses_ultimate_fallbacks() {
        let registry = load_registry();

        // Neither a spec nor an inferable provider exists for this identifier
        assert_eq!(registry.get_max_output_tokens("unknown-model"), 4096);
        assert_eq!(registry.get_input_context("unknown-model"), 100_000);
        assert!(!registry.is_legacy_model("unknown-model"));
    }

    #[test]
    fn provider_filtering() {
        let registry = load_registry();

        assert!(!registry.get_models_by_provider("claude").is_empty());
        assert!(!registry
            .get_models_by_provider_and_tier("claude", "fast")
            .is_empty());
        assert!(registry.get_tier_info("claude", "fast").is_some());
    }

    #[test]
    fn provider_config() {
        let registry = load_registry();

        let claude_config = registry
            .get_provider_config("claude")
            .expect("claude provider should be configured");
        assert_eq!(claude_config.name, "Anthropic Claude");
    }

    #[test]
    fn fuzzy_model_matching() {
        let registry = load_registry();

        // (identifier as sent, canonical api_identifier, max output tokens)
        let cases: [(&str, &str, usize); 5] = [
            // Bedrock cross-region identifier
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
                64000,
            ),
            // AWS identifier without a region prefix
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
                4096,
            ),
            // European region
            (
                "eu.anthropic.claude-3-opus-20240229-v2:1",
                "claude-3-opus-20240229",
                4096,
            ),
            // Exact matches still work
            ("claude-sonnet-4-5-20250929", "claude-sonnet-4-5-20250929", 64000),
            ("claude-sonnet-4-20250514", "claude-sonnet-4-20250514", 64000),
        ];

        for (identifier, expected_id, expected_tokens) in cases {
            let spec = registry
                .get_model_spec(identifier)
                .unwrap_or_else(|| panic!("{identifier} should resolve to a model"));
            assert_eq!(spec.api_identifier, expected_id, "identifier: {identifier}");
            assert_eq!(
                spec.max_output_tokens, expected_tokens,
                "identifier: {identifier}"
            );
        }
    }

    #[test]
    fn extract_core_model_identifier() {
        let registry = load_registry();

        assert_eq!(
            registry.extract_core_model_identifier("us.anthropic.claude-3-7-sonnet-20250219-v1:0"),
            "claude-3-7-sonnet-20250219"
        );
        assert_eq!(
            registry.extract_core_model_identifier("anthropic.claude-3-haiku-20240307-v1:0"),
            "claude-3-haiku-20240307"
        );
        assert_eq!(
            registry.extract_core_model_identifier("claude-3-opus-20240229"),
            "claude-3-opus-20240229"
        );
        assert_eq!(
            registry.extract_core_model_identifier("eu.anthropic.claude-sonnet-4-20250514-v2:1"),
            "claude-sonnet-4-20250514"
        );
    }

    #[test]
    fn beta_header_lookups() {
        let registry = load_registry();

        // Opus 4.6 base limits
        assert_eq!(registry.get_max_output_tokens("claude-opus-4-6"), 128000);
        assert_eq!(registry.get_input_context("claude-opus-4-6"), 200000);

        // Opus 4.6 with 1M context beta
        assert_eq!(
            registry.get_input_context_with_beta("claude-opus-4-6", "context-1m-2025-08-07"),
            1000000
        );
        // max_output_tokens unchanged with context beta
        assert_eq!(
            registry.get_max_output_tokens_with_beta("claude-opus-4-6", "context-1m-2025-08-07"),
            128000
        );

        // Sonnet 3.7 with output-128k beta
        assert_eq!(
            registry.get_max_output_tokens_with_beta(
                "claude-3-7-sonnet-20250219",
                "output-128k-2025-02-19"
            ),
            128000
        );

        // Sonnet 3.7 base max_output_tokens without beta
        assert_eq!(
            registry.get_max_output_tokens("claude-3-7-sonnet-20250219"),
            64000
        );

        // Beta headers accessor
        let headers = registry.get_beta_headers("claude-opus-4-6");
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].key, "anthropic-beta");
        assert_eq!(headers[0].value, "context-1m-2025-08-07");

        // Sonnet 3.7 has two beta headers
        assert_eq!(
            registry.get_beta_headers("claude-3-7-sonnet-20250219").len(),
            2
        );

        // Model without beta headers returns empty slice
        assert!(registry.get_beta_headers("claude-3-haiku-20240307").is_empty());

        // Unknown model returns empty slice
        assert!(registry.get_beta_headers("unknown-model").is_empty());
    }

    #[test]
    fn beta_lookups_fall_back_to_base_limits() {
        let registry = load_registry();

        // A beta value the model does not list leaves the base limits in place
        assert_eq!(
            registry.get_max_output_tokens_with_beta("claude-opus-4-6", "no-such-beta"),
            128000
        );
        assert_eq!(
            registry.get_input_context_with_beta("claude-opus-4-6", "no-such-beta"),
            200000
        );

        // Unknown models use the same fallbacks as the beta-free getters
        assert_eq!(
            registry.get_max_output_tokens_with_beta("unknown-model", "context-1m-2025-08-07"),
            4096
        );
        assert_eq!(
            registry.get_input_context_with_beta("unknown-model", "context-1m-2025-08-07"),
            100_000
        );

        // Bedrock-style identifiers resolve before the beta header is applied
        assert_eq!(
            registry.get_max_output_tokens_with_beta(
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "output-128k-2025-02-19"
            ),
            128000
        );
    }

    #[test]
    fn global_registry_is_initialized_once() {
        assert!(std::ptr::eq(get_model_registry(), get_model_registry()));
    }
}