Skip to main content

omni_dev/claude/
model_config.rs

1//! AI model configuration and specifications
2//!
3//! This module provides model specifications loaded from embedded YAML templates
4//! to ensure correct API parameters for different AI models.
5
6use anyhow::Result;
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::OnceLock;
10
11/// Beta header that unlocks enhanced model limits
12#[derive(Debug, Deserialize, Clone)]
13pub struct BetaHeader {
14    /// HTTP header name (e.g., "anthropic-beta")
15    pub key: String,
16    /// Header value (e.g., "context-1m-2025-08-07")
17    pub value: String,
18    /// Overridden max output tokens when this header is active
19    #[serde(default)]
20    pub max_output_tokens: Option<usize>,
21    /// Overridden input context when this header is active
22    #[serde(default)]
23    pub input_context: Option<usize>,
24}
25
26/// Model specification from YAML configuration
27#[derive(Debug, Deserialize, Clone)]
28pub struct ModelSpec {
29    /// AI provider name (e.g., "claude")
30    pub provider: String,
31    /// Human-readable model name (e.g., "Claude Opus 4")
32    pub model: String,
33    /// API identifier used for requests (e.g., "claude-3-opus-20240229")
34    pub api_identifier: String,
35    /// Maximum number of tokens that can be generated in a single response
36    pub max_output_tokens: usize,
37    /// Maximum number of tokens that can be included in the input context
38    pub input_context: usize,
39    /// Model generation number (e.g., 3.0, 3.5, 4.0)
40    pub generation: f32,
41    /// Performance tier (e.g., "fast", "balanced", "flagship")
42    pub tier: String,
43    /// Whether this is a legacy model that may be deprecated
44    #[serde(default)]
45    pub legacy: bool,
46    /// Beta headers that unlock enhanced limits for this model
47    #[serde(default)]
48    pub beta_headers: Vec<BetaHeader>,
49}
50
51/// Model tier information
52#[derive(Debug, Deserialize)]
53pub struct TierInfo {
54    /// Human-readable description of the tier
55    pub description: String,
56    /// List of recommended use cases for this tier
57    pub use_cases: Vec<String>,
58}
59
60/// Default fallback configuration for a provider
61#[derive(Debug, Deserialize)]
62pub struct DefaultConfig {
63    /// Default maximum output tokens for unknown models from this provider
64    pub max_output_tokens: usize,
65    /// Default input context limit for unknown models from this provider
66    pub input_context: usize,
67}
68
69/// Provider-specific configuration
70#[derive(Debug, Deserialize)]
71pub struct ProviderConfig {
72    /// Human-readable provider name
73    pub name: String,
74    /// Base URL for API requests
75    pub api_base: String,
76    /// Default model identifier to use if none specified
77    pub default_model: String,
78    /// Available performance tiers and their descriptions
79    pub tiers: HashMap<String, TierInfo>,
80    /// Default configuration for unknown models
81    pub defaults: DefaultConfig,
82}
83
84/// Complete model configuration
85#[derive(Debug, Deserialize)]
86pub struct ModelConfiguration {
87    /// List of all available models
88    pub models: Vec<ModelSpec>,
89    /// Provider-specific configurations
90    pub providers: HashMap<String, ProviderConfig>,
91}
92
93/// Model registry for looking up specifications
94pub struct ModelRegistry {
95    config: ModelConfiguration,
96    by_identifier: HashMap<String, ModelSpec>,
97    by_provider: HashMap<String, Vec<ModelSpec>>,
98}
99
100impl ModelRegistry {
101    /// Load model registry from embedded YAML
102    pub fn load() -> Result<Self> {
103        let yaml_content = include_str!("../templates/models.yaml");
104        let config: ModelConfiguration = serde_yaml::from_str(yaml_content)?;
105
106        // Build lookup maps
107        let mut by_identifier = HashMap::new();
108        let mut by_provider: HashMap<String, Vec<ModelSpec>> = HashMap::new();
109
110        for model in &config.models {
111            by_identifier.insert(model.api_identifier.clone(), model.clone());
112            by_provider
113                .entry(model.provider.clone())
114                .or_default()
115                .push(model.clone());
116        }
117
118        Ok(Self {
119            config,
120            by_identifier,
121            by_provider,
122        })
123    }
124
125    /// Get model specification by API identifier
126    pub fn get_model_spec(&self, api_identifier: &str) -> Option<&ModelSpec> {
127        // Try exact match first
128        if let Some(spec) = self.by_identifier.get(api_identifier) {
129            return Some(spec);
130        }
131
132        // Try fuzzy matching for Bedrock-style identifiers
133        self.find_model_by_fuzzy_match(api_identifier)
134    }
135
136    /// Get max output tokens for a model, with fallback to provider defaults
137    pub fn get_max_output_tokens(&self, api_identifier: &str) -> usize {
138        if let Some(spec) = self.get_model_spec(api_identifier) {
139            return spec.max_output_tokens;
140        }
141
142        // Try to infer provider from model identifier and use defaults
143        if let Some(provider) = self.infer_provider(api_identifier) {
144            if let Some(provider_config) = self.config.providers.get(&provider) {
145                return provider_config.defaults.max_output_tokens;
146            }
147        }
148
149        // Ultimate fallback
150        4096
151    }
152
153    /// Get input context limit for a model, with fallback to provider defaults
154    pub fn get_input_context(&self, api_identifier: &str) -> usize {
155        if let Some(spec) = self.get_model_spec(api_identifier) {
156            return spec.input_context;
157        }
158
159        // Try to infer provider from model identifier and use defaults
160        if let Some(provider) = self.infer_provider(api_identifier) {
161            if let Some(provider_config) = self.config.providers.get(&provider) {
162                return provider_config.defaults.input_context;
163            }
164        }
165
166        // Ultimate fallback
167        100000
168    }
169
170    /// Infer provider from model identifier
171    fn infer_provider(&self, api_identifier: &str) -> Option<String> {
172        if api_identifier.starts_with("claude") || api_identifier.contains("anthropic") {
173            Some("claude".to_string())
174        } else {
175            None
176        }
177    }
178
179    /// Find model by fuzzy matching for various identifier formats
180    fn find_model_by_fuzzy_match(&self, api_identifier: &str) -> Option<&ModelSpec> {
181        // Extract core model identifier from various formats:
182        // - Bedrock: "us.anthropic.claude-3-7-sonnet-20250219-v1:0" -> "claude-3-7-sonnet-20250219"
183        // - AWS: "anthropic.claude-3-haiku-20240307-v1:0" -> "claude-3-haiku-20240307"
184        // - Standard: "claude-3-opus-20240229" -> "claude-3-opus-20240229"
185
186        let core_identifier = self.extract_core_model_identifier(api_identifier);
187
188        // Try to find exact match with core identifier
189        if let Some(spec) = self.by_identifier.get(&core_identifier) {
190            return Some(spec);
191        }
192
193        // Try partial matching - look for models that contain the core parts
194        for (stored_id, spec) in &self.by_identifier {
195            if self.models_match_fuzzy(&core_identifier, stored_id) {
196                return Some(spec);
197            }
198        }
199
200        None
201    }
202
203    /// Extract core model identifier from various formats
204    fn extract_core_model_identifier(&self, api_identifier: &str) -> String {
205        let mut identifier = api_identifier.to_string();
206
207        // Remove region prefixes (us., eu., etc.)
208        if let Some(dot_pos) = identifier.find('.') {
209            if identifier[..dot_pos].len() <= 3 {
210                // likely a region code
211                identifier = identifier[dot_pos + 1..].to_string();
212            }
213        }
214
215        // Remove provider prefixes (anthropic.)
216        if identifier.starts_with("anthropic.") {
217            identifier = identifier["anthropic.".len()..].to_string();
218        }
219
220        // Remove version suffixes (-v1:0, -v2:1, etc.)
221        if let Some(version_pos) = identifier.rfind("-v") {
222            if identifier[version_pos..].contains(':') {
223                identifier = identifier[..version_pos].to_string();
224            }
225        }
226
227        identifier
228    }
229
230    /// Check if two model identifiers represent the same model
231    fn models_match_fuzzy(&self, input_id: &str, stored_id: &str) -> bool {
232        // For now, just check if they're the same after extraction
233        // This could be enhanced with more sophisticated matching
234        input_id == stored_id
235    }
236
237    /// Check if a model is legacy
238    pub fn is_legacy_model(&self, api_identifier: &str) -> bool {
239        self.get_model_spec(api_identifier)
240            .map(|spec| spec.legacy)
241            .unwrap_or(false)
242    }
243
244    /// Get all available models
245    pub fn get_all_models(&self) -> &[ModelSpec] {
246        &self.config.models
247    }
248
249    /// Get models by provider
250    pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelSpec> {
251        self.by_provider
252            .get(provider)
253            .map(|models| models.iter().collect())
254            .unwrap_or_default()
255    }
256
257    /// Get models by provider and tier
258    pub fn get_models_by_provider_and_tier(&self, provider: &str, tier: &str) -> Vec<&ModelSpec> {
259        self.get_models_by_provider(provider)
260            .into_iter()
261            .filter(|model| model.tier == tier)
262            .collect()
263    }
264
265    /// Get provider configuration
266    pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
267        self.config.providers.get(provider)
268    }
269
270    /// Get tier information for a provider
271    pub fn get_tier_info(&self, provider: &str, tier: &str) -> Option<&TierInfo> {
272        self.config.providers.get(provider)?.tiers.get(tier)
273    }
274
275    /// Get beta headers for a model
276    pub fn get_beta_headers(&self, api_identifier: &str) -> &[BetaHeader] {
277        self.get_model_spec(api_identifier)
278            .map(|spec| spec.beta_headers.as_slice())
279            .unwrap_or_default()
280    }
281
282    /// Get max output tokens for a model with a specific beta header active
283    pub fn get_max_output_tokens_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
284        if let Some(spec) = self.get_model_spec(api_identifier) {
285            if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
286                if let Some(max) = bh.max_output_tokens {
287                    return max;
288                }
289            }
290            return spec.max_output_tokens;
291        }
292        self.get_max_output_tokens(api_identifier)
293    }
294
295    /// Get input context for a model with a specific beta header active
296    pub fn get_input_context_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
297        if let Some(spec) = self.get_model_spec(api_identifier) {
298            if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
299                if let Some(ctx) = bh.input_context {
300                    return ctx;
301                }
302            }
303            return spec.input_context;
304        }
305        self.get_input_context(api_identifier)
306    }
307}
308
309/// Global model registry instance
310static MODEL_REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
311
312/// Get the global model registry instance
313pub fn get_model_registry() -> &'static ModelRegistry {
314    MODEL_REGISTRY.get_or_init(|| ModelRegistry::load().expect("Failed to load model registry"))
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a fresh registry from the embedded YAML, failing the test on error.
    fn loaded_registry() -> ModelRegistry {
        ModelRegistry::load().unwrap()
    }

    #[test]
    fn test_load_model_registry() {
        let registry = loaded_registry();

        assert!(!registry.config.models.is_empty());
        assert!(registry.config.providers.contains_key("claude"));
    }

    #[test]
    fn test_claude_model_lookup() {
        let registry = loaded_registry();

        // Legacy Claude 3 Opus resolves by its exact identifier
        let opus = registry
            .get_model_spec("claude-3-opus-20240229")
            .expect("Claude 3 Opus should be registered");
        assert_eq!(opus.max_output_tokens, 4096);
        assert_eq!(opus.provider, "claude");
        assert!(registry.is_legacy_model("claude-3-opus-20240229"));

        // Claude 4.5 Sonnet (current generation)
        assert_eq!(
            registry.get_max_output_tokens("claude-sonnet-4-5-20250929"),
            64000
        );

        // Legacy Claude 4 Sonnet
        assert_eq!(
            registry.get_max_output_tokens("claude-sonnet-4-20250514"),
            64000
        );
        assert!(registry.is_legacy_model("claude-sonnet-4-20250514"));

        // An unrecognized Claude model uses the Claude provider defaults
        assert_eq!(registry.get_max_output_tokens("claude-unknown-model"), 4096);
    }

    #[test]
    fn test_provider_filtering() {
        let registry = loaded_registry();

        assert!(!registry.get_models_by_provider("claude").is_empty());
        assert!(!registry
            .get_models_by_provider_and_tier("claude", "fast")
            .is_empty());
        assert!(registry.get_tier_info("claude", "fast").is_some());
    }

    #[test]
    fn test_provider_config() {
        let registry = loaded_registry();

        let claude_config = registry
            .get_provider_config("claude")
            .expect("claude provider should be configured");
        assert_eq!(claude_config.name, "Anthropic Claude");
    }

    #[test]
    fn test_fuzzy_model_matching() {
        let registry = loaded_registry();

        // (input identifier, expected api_identifier, expected max_output_tokens)
        let prefixed_cases = [
            // Bedrock-style identifier
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
                64000,
            ),
            // AWS-style identifier
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
                4096,
            ),
            // European region
            (
                "eu.anthropic.claude-3-opus-20240229-v2:1",
                "claude-3-opus-20240229",
                4096,
            ),
        ];
        for (input, expected_id, expected_tokens) in prefixed_cases {
            let spec = registry
                .get_model_spec(input)
                .unwrap_or_else(|| panic!("no spec resolved for {input}"));
            assert_eq!(spec.api_identifier, expected_id);
            assert_eq!(spec.max_output_tokens, expected_tokens);
        }

        // Exact identifiers keep working: Claude 4.5 Sonnet and legacy Claude 4 Sonnet
        let exact_cases = [
            ("claude-sonnet-4-5-20250929", 64000),
            ("claude-sonnet-4-20250514", 64000),
        ];
        for (input, expected_tokens) in exact_cases {
            let spec = registry
                .get_model_spec(input)
                .unwrap_or_else(|| panic!("no spec resolved for {input}"));
            assert_eq!(spec.max_output_tokens, expected_tokens);
        }
    }

    #[test]
    fn test_extract_core_model_identifier() {
        let registry = loaded_registry();

        // (input identifier, expected core identifier)
        let cases = [
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
            ),
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
            ),
            ("claude-3-opus-20240229", "claude-3-opus-20240229"),
            (
                "eu.anthropic.claude-sonnet-4-20250514-v2:1",
                "claude-sonnet-4-20250514",
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(registry.extract_core_model_identifier(input), expected);
        }
    }

    #[test]
    fn test_beta_header_lookups() {
        let registry = loaded_registry();
        let opus = "claude-opus-4-6";
        let sonnet37 = "claude-3-7-sonnet-20250219";
        let context_beta = "context-1m-2025-08-07";

        // Opus 4.6 base limits
        assert_eq!(registry.get_max_output_tokens(opus), 128000);
        assert_eq!(registry.get_input_context(opus), 200000);

        // Opus 4.6 with the 1M context beta: context grows, output limit does not
        assert_eq!(
            registry.get_input_context_with_beta(opus, context_beta),
            1000000
        );
        assert_eq!(
            registry.get_max_output_tokens_with_beta(opus, context_beta),
            128000
        );

        // Sonnet 3.7 with and without the output-128k beta
        assert_eq!(
            registry.get_max_output_tokens_with_beta(sonnet37, "output-128k-2025-02-19"),
            128000
        );
        assert_eq!(registry.get_max_output_tokens(sonnet37), 64000);

        // Beta headers accessor
        let opus_headers = registry.get_beta_headers(opus);
        assert_eq!(opus_headers.len(), 1);
        assert_eq!(opus_headers[0].key, "anthropic-beta");
        assert_eq!(opus_headers[0].value, context_beta);

        // Sonnet 3.7 has two beta headers
        assert_eq!(registry.get_beta_headers(sonnet37).len(), 2);

        // A model without beta headers and an unknown model both give an empty slice
        assert!(registry
            .get_beta_headers("claude-3-haiku-20240307")
            .is_empty());
        assert!(registry.get_beta_headers("unknown-model").is_empty());
    }
}