omni_dev/claude/
model_config.rs

1//! AI model configuration and specifications
2//!
3//! This module provides model specifications loaded from embedded YAML templates
4//! to ensure correct API parameters for different AI models.
5
6use anyhow::Result;
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::OnceLock;
10
11/// Model specification from YAML configuration
12#[derive(Debug, Deserialize, Clone)]
13pub struct ModelSpec {
14    /// AI provider name (e.g., "claude")
15    pub provider: String,
16    /// Human-readable model name (e.g., "Claude Opus 4")
17    pub model: String,
18    /// API identifier used for requests (e.g., "claude-3-opus-20240229")
19    pub api_identifier: String,
20    /// Maximum number of tokens that can be generated in a single response
21    pub max_output_tokens: usize,
22    /// Maximum number of tokens that can be included in the input context
23    pub input_context: usize,
24    /// Model generation number (e.g., 3.0, 3.5, 4.0)
25    pub generation: f32,
26    /// Performance tier (e.g., "fast", "balanced", "flagship")
27    pub tier: String,
28    /// Whether this is a legacy model that may be deprecated
29    #[serde(default)]
30    pub legacy: bool,
31}
32
33/// Model tier information
34#[derive(Debug, Deserialize)]
35pub struct TierInfo {
36    /// Human-readable description of the tier
37    pub description: String,
38    /// List of recommended use cases for this tier
39    pub use_cases: Vec<String>,
40}
41
42/// Default fallback configuration for a provider
43#[derive(Debug, Deserialize)]
44pub struct DefaultConfig {
45    /// Default maximum output tokens for unknown models from this provider
46    pub max_output_tokens: usize,
47    /// Default input context limit for unknown models from this provider
48    pub input_context: usize,
49}
50
51/// Provider-specific configuration
52#[derive(Debug, Deserialize)]
53pub struct ProviderConfig {
54    /// Human-readable provider name
55    pub name: String,
56    /// Base URL for API requests
57    pub api_base: String,
58    /// Default model identifier to use if none specified
59    pub default_model: String,
60    /// Available performance tiers and their descriptions
61    pub tiers: HashMap<String, TierInfo>,
62    /// Default configuration for unknown models
63    pub defaults: DefaultConfig,
64}
65
66/// Complete model configuration
67#[derive(Debug, Deserialize)]
68pub struct ModelConfiguration {
69    /// List of all available models
70    pub models: Vec<ModelSpec>,
71    /// Provider-specific configurations
72    pub providers: HashMap<String, ProviderConfig>,
73}
74
75/// Model registry for looking up specifications
76pub struct ModelRegistry {
77    config: ModelConfiguration,
78    by_identifier: HashMap<String, ModelSpec>,
79    by_provider: HashMap<String, Vec<ModelSpec>>,
80}
81
82impl ModelRegistry {
83    /// Load model registry from embedded YAML
84    pub fn load() -> Result<Self> {
85        let yaml_content = include_str!("../templates/models.yaml");
86        let config: ModelConfiguration = serde_yaml::from_str(yaml_content)?;
87
88        // Build lookup maps
89        let mut by_identifier = HashMap::new();
90        let mut by_provider: HashMap<String, Vec<ModelSpec>> = HashMap::new();
91
92        for model in &config.models {
93            by_identifier.insert(model.api_identifier.clone(), model.clone());
94            by_provider
95                .entry(model.provider.clone())
96                .or_default()
97                .push(model.clone());
98        }
99
100        Ok(Self {
101            config,
102            by_identifier,
103            by_provider,
104        })
105    }
106
107    /// Get model specification by API identifier
108    pub fn get_model_spec(&self, api_identifier: &str) -> Option<&ModelSpec> {
109        // Try exact match first
110        if let Some(spec) = self.by_identifier.get(api_identifier) {
111            return Some(spec);
112        }
113
114        // Try fuzzy matching for Bedrock-style identifiers
115        self.find_model_by_fuzzy_match(api_identifier)
116    }
117
118    /// Get max output tokens for a model, with fallback to provider defaults
119    pub fn get_max_output_tokens(&self, api_identifier: &str) -> usize {
120        if let Some(spec) = self.get_model_spec(api_identifier) {
121            return spec.max_output_tokens;
122        }
123
124        // Try to infer provider from model identifier and use defaults
125        if let Some(provider) = self.infer_provider(api_identifier) {
126            if let Some(provider_config) = self.config.providers.get(&provider) {
127                return provider_config.defaults.max_output_tokens;
128            }
129        }
130
131        // Ultimate fallback
132        4096
133    }
134
135    /// Get input context limit for a model, with fallback to provider defaults
136    pub fn get_input_context(&self, api_identifier: &str) -> usize {
137        if let Some(spec) = self.get_model_spec(api_identifier) {
138            return spec.input_context;
139        }
140
141        // Try to infer provider from model identifier and use defaults
142        if let Some(provider) = self.infer_provider(api_identifier) {
143            if let Some(provider_config) = self.config.providers.get(&provider) {
144                return provider_config.defaults.input_context;
145            }
146        }
147
148        // Ultimate fallback
149        100000
150    }
151
152    /// Infer provider from model identifier
153    fn infer_provider(&self, api_identifier: &str) -> Option<String> {
154        if api_identifier.starts_with("claude") || api_identifier.contains("anthropic") {
155            Some("claude".to_string())
156        } else {
157            None
158        }
159    }
160
161    /// Find model by fuzzy matching for various identifier formats
162    fn find_model_by_fuzzy_match(&self, api_identifier: &str) -> Option<&ModelSpec> {
163        // Extract core model identifier from various formats:
164        // - Bedrock: "us.anthropic.claude-3-7-sonnet-20250219-v1:0" -> "claude-3-7-sonnet-20250219"
165        // - AWS: "anthropic.claude-3-haiku-20240307-v1:0" -> "claude-3-haiku-20240307"
166        // - Standard: "claude-3-opus-20240229" -> "claude-3-opus-20240229"
167
168        let core_identifier = self.extract_core_model_identifier(api_identifier);
169
170        // Try to find exact match with core identifier
171        if let Some(spec) = self.by_identifier.get(&core_identifier) {
172            return Some(spec);
173        }
174
175        // Try partial matching - look for models that contain the core parts
176        for (stored_id, spec) in &self.by_identifier {
177            if self.models_match_fuzzy(&core_identifier, stored_id) {
178                return Some(spec);
179            }
180        }
181
182        None
183    }
184
185    /// Extract core model identifier from various formats
186    fn extract_core_model_identifier(&self, api_identifier: &str) -> String {
187        let mut identifier = api_identifier.to_string();
188
189        // Remove region prefixes (us., eu., etc.)
190        if let Some(dot_pos) = identifier.find('.') {
191            if identifier[..dot_pos].len() <= 3 {
192                // likely a region code
193                identifier = identifier[dot_pos + 1..].to_string();
194            }
195        }
196
197        // Remove provider prefixes (anthropic.)
198        if identifier.starts_with("anthropic.") {
199            identifier = identifier["anthropic.".len()..].to_string();
200        }
201
202        // Remove version suffixes (-v1:0, -v2:1, etc.)
203        if let Some(version_pos) = identifier.rfind("-v") {
204            if identifier[version_pos..].contains(':') {
205                identifier = identifier[..version_pos].to_string();
206            }
207        }
208
209        identifier
210    }
211
212    /// Check if two model identifiers represent the same model
213    fn models_match_fuzzy(&self, input_id: &str, stored_id: &str) -> bool {
214        // For now, just check if they're the same after extraction
215        // This could be enhanced with more sophisticated matching
216        input_id == stored_id
217    }
218
219    /// Check if a model is legacy
220    pub fn is_legacy_model(&self, api_identifier: &str) -> bool {
221        self.get_model_spec(api_identifier)
222            .map(|spec| spec.legacy)
223            .unwrap_or(false)
224    }
225
226    /// Get all available models
227    pub fn get_all_models(&self) -> &[ModelSpec] {
228        &self.config.models
229    }
230
231    /// Get models by provider
232    pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelSpec> {
233        self.by_provider
234            .get(provider)
235            .map(|models| models.iter().collect())
236            .unwrap_or_default()
237    }
238
239    /// Get models by provider and tier
240    pub fn get_models_by_provider_and_tier(&self, provider: &str, tier: &str) -> Vec<&ModelSpec> {
241        self.get_models_by_provider(provider)
242            .into_iter()
243            .filter(|model| model.tier == tier)
244            .collect()
245    }
246
247    /// Get provider configuration
248    pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
249        self.config.providers.get(provider)
250    }
251
252    /// Get tier information for a provider
253    pub fn get_tier_info(&self, provider: &str, tier: &str) -> Option<&TierInfo> {
254        self.config.providers.get(provider)?.tiers.get(tier)
255    }
256}
257
258/// Global model registry instance
259static MODEL_REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
260
261/// Get the global model registry instance
262pub fn get_model_registry() -> &'static ModelRegistry {
263    MODEL_REGISTRY.get_or_init(|| ModelRegistry::load().expect("Failed to load model registry"))
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a fresh registry from the embedded YAML, failing the test on error.
    fn load_registry() -> ModelRegistry {
        ModelRegistry::load().unwrap()
    }

    #[test]
    fn test_load_model_registry() {
        let registry = load_registry();
        assert!(!registry.config.models.is_empty());
        assert!(registry.config.providers.contains_key("claude"));
    }

    #[test]
    fn test_claude_model_lookup() {
        let registry = load_registry();

        // Legacy Claude 3 Opus resolves exactly and is flagged as legacy.
        let opus = registry.get_model_spec("claude-3-opus-20240229");
        assert!(opus.is_some());
        let opus = opus.unwrap();
        assert_eq!(opus.max_output_tokens, 4096);
        assert_eq!(opus.provider, "claude");
        assert!(registry.is_legacy_model("claude-3-opus-20240229"));

        // A current-generation model reports its own output limit.
        assert_eq!(registry.get_max_output_tokens("claude-sonnet-4-20250514"), 64000);

        // Unknown Claude identifiers fall back to the Claude provider defaults.
        assert_eq!(registry.get_max_output_tokens("claude-unknown-model"), 4096);
    }

    #[test]
    fn test_provider_filtering() {
        let registry = load_registry();

        assert!(!registry.get_models_by_provider("claude").is_empty());
        assert!(!registry
            .get_models_by_provider_and_tier("claude", "fast")
            .is_empty());
        assert!(registry.get_tier_info("claude", "fast").is_some());
    }

    #[test]
    fn test_provider_config() {
        let registry = load_registry();

        let claude_config = registry.get_provider_config("claude");
        assert!(claude_config.is_some());
        assert_eq!(claude_config.unwrap().name, "Anthropic Claude");
    }

    #[test]
    fn test_fuzzy_model_matching() {
        let registry = load_registry();

        // (input identifier, expected canonical identifier, expected output limit)
        let cases = [
            // Bedrock cross-region identifier
            ("us.anthropic.claude-3-7-sonnet-20250219-v1:0", "claude-3-7-sonnet-20250219", 64000),
            // AWS identifier without a region prefix
            ("anthropic.claude-3-haiku-20240307-v1:0", "claude-3-haiku-20240307", 4096),
            // European region with a different version suffix
            ("eu.anthropic.claude-3-opus-20240229-v2:1", "claude-3-opus-20240229", 4096),
        ];

        for (input, expected_id, expected_tokens) in cases {
            let spec = registry.get_model_spec(input);
            assert!(spec.is_some());
            let spec = spec.unwrap();
            assert_eq!(spec.api_identifier, expected_id);
            assert_eq!(spec.max_output_tokens, expected_tokens);
        }

        // Exact identifiers still resolve directly.
        let exact = registry.get_model_spec("claude-sonnet-4-20250514");
        assert!(exact.is_some());
        assert_eq!(exact.unwrap().max_output_tokens, 64000);
    }

    #[test]
    fn test_extract_core_model_identifier() {
        let registry = load_registry();

        let cases = [
            ("us.anthropic.claude-3-7-sonnet-20250219-v1:0", "claude-3-7-sonnet-20250219"),
            ("anthropic.claude-3-haiku-20240307-v1:0", "claude-3-haiku-20240307"),
            ("claude-3-opus-20240229", "claude-3-opus-20240229"),
            ("eu.anthropic.claude-sonnet-4-20250514-v2:1", "claude-sonnet-4-20250514"),
        ];

        for (input, expected) in cases {
            assert_eq!(registry.extract_core_model_identifier(input), expected);
        }
    }
}