Skip to main content

omni_dev/claude/
model_config.rs

1//! AI model configuration and specifications.
2//!
3//! This module provides model specifications loaded from embedded YAML templates
4//! to ensure correct API parameters for different AI models.
5
6use std::collections::HashMap;
7use std::sync::OnceLock;
8
9use anyhow::Result;
10use serde::Deserialize;
11
12/// Embedded models YAML configuration, loaded at compile time.
13pub(crate) const MODELS_YAML: &str = include_str!("../templates/models.yaml");
14
/// Last-resort output-token limit, used when neither a model entry nor a
/// provider default applies.
const FALLBACK_MAX_OUTPUT_TOKENS: usize = 4_096;

/// Last-resort input-context size, used when neither a model entry nor a
/// provider default applies.
const FALLBACK_INPUT_CONTEXT: usize = 100_000;
20
21/// Beta header that unlocks enhanced model limits.
22#[derive(Debug, Deserialize, Clone)]
23pub struct BetaHeader {
24    /// HTTP header name (e.g., "anthropic-beta").
25    pub key: String,
26    /// Header value (e.g., "context-1m-2025-08-07").
27    pub value: String,
28    /// Overridden max output tokens when this header is active.
29    #[serde(default)]
30    pub max_output_tokens: Option<usize>,
31    /// Overridden input context when this header is active.
32    #[serde(default)]
33    pub input_context: Option<usize>,
34}
35
36/// Model specification from YAML configuration.
37#[derive(Debug, Deserialize, Clone)]
38pub struct ModelSpec {
39    /// AI provider name (e.g., "claude").
40    pub provider: String,
41    /// Human-readable model name (e.g., "Claude Opus 4").
42    pub model: String,
43    /// API identifier used for requests (e.g., "claude-3-opus-20240229").
44    pub api_identifier: String,
45    /// Maximum number of tokens that can be generated in a single response.
46    pub max_output_tokens: usize,
47    /// Maximum number of tokens that can be included in the input context.
48    pub input_context: usize,
49    /// Model generation number (e.g., 3.0, 3.5, 4.0).
50    pub generation: f32,
51    /// Performance tier (e.g., "fast", "balanced", "flagship").
52    pub tier: String,
53    /// Whether this is a legacy model that may be deprecated.
54    #[serde(default)]
55    pub legacy: bool,
56    /// Beta headers that unlock enhanced limits for this model.
57    #[serde(default)]
58    pub beta_headers: Vec<BetaHeader>,
59}
60
61/// Model tier information.
62#[derive(Debug, Deserialize)]
63pub struct TierInfo {
64    /// Human-readable description of the tier.
65    pub description: String,
66    /// List of recommended use cases for this tier.
67    pub use_cases: Vec<String>,
68}
69
70/// Default fallback configuration for a provider.
71#[derive(Debug, Deserialize)]
72pub struct DefaultConfig {
73    /// Default maximum output tokens for unknown models from this provider.
74    pub max_output_tokens: usize,
75    /// Default input context limit for unknown models from this provider.
76    pub input_context: usize,
77}
78
79/// Provider-specific configuration.
80#[derive(Debug, Deserialize)]
81pub struct ProviderConfig {
82    /// Human-readable provider name.
83    pub name: String,
84    /// Base URL for API requests.
85    pub api_base: String,
86    /// Default model identifier to use if none specified.
87    pub default_model: String,
88    /// Available performance tiers and their descriptions.
89    pub tiers: HashMap<String, TierInfo>,
90    /// Default configuration for unknown models.
91    pub defaults: DefaultConfig,
92}
93
94/// Complete model configuration.
95#[derive(Debug, Deserialize)]
96pub struct ModelConfiguration {
97    /// List of all available models.
98    pub models: Vec<ModelSpec>,
99    /// Provider-specific configurations.
100    pub providers: HashMap<String, ProviderConfig>,
101}
102
103/// Model registry for looking up specifications.
104pub struct ModelRegistry {
105    config: ModelConfiguration,
106    by_identifier: HashMap<String, ModelSpec>,
107    by_provider: HashMap<String, Vec<ModelSpec>>,
108}
109
110impl ModelRegistry {
111    /// Loads the model registry from embedded YAML.
112    pub fn load() -> Result<Self> {
113        let config: ModelConfiguration = serde_yaml::from_str(MODELS_YAML)?;
114
115        // Build lookup maps
116        let mut by_identifier = HashMap::new();
117        let mut by_provider: HashMap<String, Vec<ModelSpec>> = HashMap::new();
118
119        for model in &config.models {
120            by_identifier.insert(model.api_identifier.clone(), model.clone());
121            by_provider
122                .entry(model.provider.clone())
123                .or_default()
124                .push(model.clone());
125        }
126
127        Ok(Self {
128            config,
129            by_identifier,
130            by_provider,
131        })
132    }
133
134    /// Returns the model specification for the given API identifier.
135    #[must_use]
136    pub fn get_model_spec(&self, api_identifier: &str) -> Option<&ModelSpec> {
137        // Try exact match first
138        if let Some(spec) = self.by_identifier.get(api_identifier) {
139            return Some(spec);
140        }
141
142        // Try normalizing the identifier and looking up again
143        self.find_model_by_normalized_id(api_identifier)
144    }
145
146    /// Returns the max output tokens for a model, with fallback to provider defaults.
147    #[must_use]
148    pub fn get_max_output_tokens(&self, api_identifier: &str) -> usize {
149        if let Some(spec) = self.get_model_spec(api_identifier) {
150            return spec.max_output_tokens;
151        }
152
153        // Try to infer provider from model identifier and use defaults
154        if let Some(provider) = self.infer_provider(api_identifier) {
155            if let Some(provider_config) = self.config.providers.get(&provider) {
156                return provider_config.defaults.max_output_tokens;
157            }
158        }
159
160        // Ultimate fallback
161        FALLBACK_MAX_OUTPUT_TOKENS
162    }
163
164    /// Returns the input context limit for a model, with fallback to provider defaults.
165    #[must_use]
166    pub fn get_input_context(&self, api_identifier: &str) -> usize {
167        if let Some(spec) = self.get_model_spec(api_identifier) {
168            return spec.input_context;
169        }
170
171        // Try to infer provider from model identifier and use defaults
172        if let Some(provider) = self.infer_provider(api_identifier) {
173            if let Some(provider_config) = self.config.providers.get(&provider) {
174                return provider_config.defaults.input_context;
175            }
176        }
177
178        // Ultimate fallback
179        FALLBACK_INPUT_CONTEXT
180    }
181
182    /// Infers the provider from a model identifier.
183    fn infer_provider(&self, api_identifier: &str) -> Option<String> {
184        if api_identifier.starts_with("claude") || api_identifier.contains("anthropic") {
185            Some("claude".to_string())
186        } else {
187            None
188        }
189    }
190
191    /// Finds a model by normalizing the identifier and performing an exact lookup.
192    ///
193    /// Handles Bedrock-style (`us.anthropic.claude-3-7-sonnet-20250219-v1:0`),
194    /// AWS-style (`anthropic.claude-3-haiku-20240307-v1:0`), and standard identifiers.
195    fn find_model_by_normalized_id(&self, api_identifier: &str) -> Option<&ModelSpec> {
196        let core_identifier = self.extract_core_model_identifier(api_identifier);
197        self.by_identifier.get(&core_identifier)
198    }
199
200    /// Extracts the core model identifier from various formats.
201    fn extract_core_model_identifier(&self, api_identifier: &str) -> String {
202        let mut identifier = api_identifier.to_string();
203
204        // Remove region prefixes (us., eu., etc.)
205        if let Some(dot_pos) = identifier.find('.') {
206            if identifier[..dot_pos].len() <= 3 {
207                // likely a region code
208                identifier = identifier[dot_pos + 1..].to_string();
209            }
210        }
211
212        // Remove provider prefixes (anthropic.)
213        if identifier.starts_with("anthropic.") {
214            identifier = identifier["anthropic.".len()..].to_string();
215        }
216
217        // Remove version suffixes (-v1:0, -v2:1, etc.)
218        if let Some(version_pos) = identifier.rfind("-v") {
219            if identifier[version_pos..].contains(':') {
220                identifier = identifier[..version_pos].to_string();
221            }
222        }
223
224        identifier
225    }
226
227    /// Checks if a model is legacy.
228    #[must_use]
229    pub fn is_legacy_model(&self, api_identifier: &str) -> bool {
230        self.get_model_spec(api_identifier)
231            .is_some_and(|spec| spec.legacy)
232    }
233
234    /// Returns all available models.
235    #[must_use]
236    pub fn get_all_models(&self) -> &[ModelSpec] {
237        &self.config.models
238    }
239
240    /// Returns models filtered by provider.
241    #[must_use]
242    pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelSpec> {
243        self.by_provider
244            .get(provider)
245            .map(|models| models.iter().collect())
246            .unwrap_or_default()
247    }
248
249    /// Returns models filtered by provider and tier.
250    #[must_use]
251    pub fn get_models_by_provider_and_tier(&self, provider: &str, tier: &str) -> Vec<&ModelSpec> {
252        self.get_models_by_provider(provider)
253            .into_iter()
254            .filter(|model| model.tier == tier)
255            .collect()
256    }
257
258    /// Returns the default model identifier for a provider, as defined in `models.yaml`.
259    #[must_use]
260    pub fn get_default_model(&self, provider: &str) -> Option<&str> {
261        self.config
262            .providers
263            .get(provider)
264            .map(|p| p.default_model.as_str())
265    }
266
267    /// Returns the provider configuration.
268    #[must_use]
269    pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
270        self.config.providers.get(provider)
271    }
272
273    /// Returns tier information for a provider.
274    #[must_use]
275    pub fn get_tier_info(&self, provider: &str, tier: &str) -> Option<&TierInfo> {
276        self.config.providers.get(provider)?.tiers.get(tier)
277    }
278
279    /// Returns the beta headers for a model.
280    #[must_use]
281    pub fn get_beta_headers(&self, api_identifier: &str) -> &[BetaHeader] {
282        self.get_model_spec(api_identifier)
283            .map(|spec| spec.beta_headers.as_slice())
284            .unwrap_or_default()
285    }
286
287    /// Returns the max output tokens for a model with a specific beta header active.
288    #[must_use]
289    pub fn get_max_output_tokens_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
290        if let Some(spec) = self.get_model_spec(api_identifier) {
291            if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
292                if let Some(max) = bh.max_output_tokens {
293                    return max;
294                }
295            }
296            return spec.max_output_tokens;
297        }
298        self.get_max_output_tokens(api_identifier)
299    }
300
301    /// Returns the input context for a model with a specific beta header active.
302    #[must_use]
303    pub fn get_input_context_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
304        if let Some(spec) = self.get_model_spec(api_identifier) {
305            if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
306                if let Some(ctx) = bh.input_context {
307                    return ctx;
308                }
309            }
310            return spec.input_context;
311        }
312        self.get_input_context(api_identifier)
313    }
314}
315
316/// Global model registry instance.
317static MODEL_REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
318
319/// Returns the global model registry instance.
320#[must_use]
321pub fn get_model_registry() -> &'static ModelRegistry {
322    #[allow(clippy::expect_used)] // YAML is embedded via include_str! at compile time
323    MODEL_REGISTRY.get_or_init(|| ModelRegistry::load().expect("Failed to load model registry"))
324}
325
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Parses the embedded YAML into a fresh registry for one test.
    fn load_registry() -> ModelRegistry {
        ModelRegistry::load().unwrap()
    }

    #[test]
    fn load_model_registry() {
        let registry = load_registry();

        assert!(!registry.config.models.is_empty());
        assert!(registry.config.providers.contains_key("claude"));
    }

    #[test]
    fn claude_model_lookup() {
        let registry = load_registry();

        // Legacy Claude 3 Opus
        let opus = registry
            .get_model_spec("claude-3-opus-20240229")
            .expect("Claude 3 Opus should be registered");
        assert_eq!(opus.max_output_tokens, 4096);
        assert_eq!(opus.provider, "claude");
        assert!(registry.is_legacy_model("claude-3-opus-20240229"));

        // Claude 4.5 Sonnet (current generation)
        assert_eq!(
            registry.get_max_output_tokens("claude-sonnet-4-5-20250929"),
            64000
        );

        // Legacy Claude 4 Sonnet
        assert_eq!(
            registry.get_max_output_tokens("claude-sonnet-4-20250514"),
            64000
        );
        assert!(registry.is_legacy_model("claude-sonnet-4-20250514"));

        // An unknown Claude model should use the Claude provider defaults
        assert_eq!(registry.get_max_output_tokens("claude-unknown-model"), 4096);
    }

    #[test]
    fn provider_filtering() {
        let registry = load_registry();

        assert!(!registry.get_models_by_provider("claude").is_empty());
        assert!(!registry
            .get_models_by_provider_and_tier("claude", "fast")
            .is_empty());
        assert!(registry.get_tier_info("claude", "fast").is_some());
    }

    #[test]
    fn provider_config() {
        let registry = load_registry();

        let claude = registry
            .get_provider_config("claude")
            .expect("claude provider should be configured");
        assert_eq!(claude.name, "Anthropic Claude");
    }

    #[test]
    fn default_model_per_provider() {
        let registry = load_registry();

        let expectations = [
            ("claude", Some("claude-sonnet-4-6")),
            ("openai", Some("gpt-5-mini")),
            ("gemini", Some("gemini-2.5-flash")),
            ("nonexistent", None),
        ];
        for (provider, expected) in expectations {
            assert_eq!(registry.get_default_model(provider), expected);
        }
    }

    #[test]
    fn normalized_id_matching() {
        let registry = load_registry();

        // (input identifier, expected canonical identifier, expected max output tokens)
        let decorated = [
            // Bedrock-style identifier
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
                64000,
            ),
            // AWS-style identifier
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
                4096,
            ),
            // European region
            (
                "eu.anthropic.claude-3-opus-20240229-v2:1",
                "claude-3-opus-20240229",
                4096,
            ),
        ];
        for (input, canonical, max_output) in decorated {
            let spec = registry.get_model_spec(input).unwrap();
            assert_eq!(spec.api_identifier, canonical);
            assert_eq!(spec.max_output_tokens, max_output);
        }

        // Exact matches still work: Claude 4.5 Sonnet and legacy Claude 4 Sonnet
        for exact in ["claude-sonnet-4-5-20250929", "claude-sonnet-4-20250514"] {
            let spec = registry.get_model_spec(exact).unwrap();
            assert_eq!(spec.max_output_tokens, 64000);
        }
    }

    #[test]
    fn extract_core_model_identifier() {
        let registry = load_registry();

        // (input identifier, expected core identifier)
        let cases = [
            (
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "claude-3-7-sonnet-20250219",
            ),
            (
                "anthropic.claude-3-haiku-20240307-v1:0",
                "claude-3-haiku-20240307",
            ),
            ("claude-3-opus-20240229", "claude-3-opus-20240229"),
            (
                "eu.anthropic.claude-sonnet-4-20250514-v2:1",
                "claude-sonnet-4-20250514",
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(registry.extract_core_model_identifier(input), expected);
        }
    }

    #[test]
    fn beta_header_lookups() {
        const OPUS: &str = "claude-opus-4-6";
        const SONNET_37: &str = "claude-3-7-sonnet-20250219";
        const CONTEXT_1M: &str = "context-1m-2025-08-07";

        let registry = load_registry();

        // Opus 4.6 base limits
        assert_eq!(registry.get_max_output_tokens(OPUS), 128_000);
        assert_eq!(registry.get_input_context(OPUS), 200_000);

        // Opus 4.6 with the 1M context beta: context grows, output limit does not
        assert_eq!(
            registry.get_input_context_with_beta(OPUS, CONTEXT_1M),
            1_000_000
        );
        assert_eq!(
            registry.get_max_output_tokens_with_beta(OPUS, CONTEXT_1M),
            128_000
        );

        // Sonnet 3.7 with the output-128k beta, then its base limit without it
        assert_eq!(
            registry.get_max_output_tokens_with_beta(SONNET_37, "output-128k-2025-02-19"),
            128_000
        );
        assert_eq!(registry.get_max_output_tokens(SONNET_37), 64000);

        // Beta headers accessor
        let opus_headers = registry.get_beta_headers(OPUS);
        assert_eq!(opus_headers.len(), 1);
        assert_eq!(opus_headers[0].key, "anthropic-beta");
        assert_eq!(opus_headers[0].value, CONTEXT_1M);

        // Sonnet 3.7 has two beta headers
        assert_eq!(registry.get_beta_headers(SONNET_37).len(), 2);

        // A model without beta headers and an unknown model both yield an empty slice
        assert!(registry
            .get_beta_headers("claude-3-haiku-20240307")
            .is_empty());
        assert!(registry.get_beta_headers("unknown-model").is_empty());
    }
}