Skip to main content

omni_dev/claude/
model_config.rs

1//! AI model configuration and specifications.
2//!
3//! # Data model
4//!
5//! [`ModelConfiguration`] is the top-level container. It owns a
6//! <code>Vec<[ModelSpec]></code> of every known model and a
7//! <code>HashMap<String, [ProviderConfig]></code> keyed by provider name.
8//! [`ModelSpec`] records the per-model limits, generation, tier name, and
9//! any [`BetaHeader`]s that unlock enhanced limits. [`ProviderConfig`]
10//! records provider-wide settings — including a [`TierInfo`] map describing
11//! each named tier and a [`DefaultConfig`] block used as the fallback for
12//! unknown identifiers from that provider. Every entry carries a
13//! [`ModelSource`] tag identifying which layer contributed it.
14//!
15//! [`ModelRegistry`] wraps a fully merged [`ModelConfiguration`] and adds
16//! identifier-normalised lookup (so a Bedrock or AWS-direct identifier
17//! resolves to the same [`ModelSpec`] as the canonical Anthropic form).
18//!
19//! # Loader
20//!
21//! [`ModelRegistry::load`] builds the registry from a layered set of YAML
22//! sources: an embedded catalog (compile-time `include_str!`), an optional
23//! user-level file at `~/.omni-dev/models.yaml`, and an optional
24//! project-local file at `./.omni-dev/models.yaml`. Layers are deep-merged
25//! with project > user > embedded precedence; an explicit override path
26//! provided via `OMNI_DEV_MODELS_YAML` short-circuits the user/project
27//! lookup. See [ADR-0022](../../docs/adrs/adr-0022.md) for the layered
28//! loader rationale and [ADR-0011](../../docs/adrs/adr-0011.md) for the
29//! original compile-time design.
30
31use std::collections::HashMap;
32use std::path::{Path, PathBuf};
33use std::sync::OnceLock;
34
35use anyhow::{anyhow, Result};
36use serde::{Deserialize, Serialize};
37
38/// Embedded models YAML configuration, loaded at compile time.
39pub(crate) const MODELS_YAML: &str = include_str!("../templates/models.yaml");
40
/// Version of the models YAML schema understood by this build of omni-dev.
///
/// A user-side layer that declares a different `version:`, or omits the
/// field altogether, produces a warning at load time; the file is still
/// accepted for backwards compatibility.
pub const MODELS_SCHEMA_VERSION: &str = "1";

/// Name of the environment variable carrying an explicit catalog path.
///
/// When the variable holds a non-empty value, that single file is used and
/// the standard user-level/project-local lookup is skipped.
pub const OMNI_DEV_MODELS_YAML_ENV: &str = "OMNI_DEV_MODELS_YAML";

/// Last-resort max output tokens, returned when neither a model entry nor a
/// provider default matches the requested identifier.
const FALLBACK_MAX_OUTPUT_TOKENS: usize = 4_096;

/// Last-resort input context size, returned when neither a model entry nor
/// a provider default matches the requested identifier.
const FALLBACK_INPUT_CONTEXT: usize = 100_000;
57
58/// Layer that contributed a model or provider entry.
59#[derive(
60    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Default,
61)]
62#[serde(rename_all = "lowercase")]
63pub enum ModelSource {
64    /// Compile-time embedded catalog (`src/templates/models.yaml`).
65    #[default]
66    Embedded,
67    /// User-level catalog at `~/.omni-dev/models.yaml`.
68    User,
69    /// Project-local catalog at `./.omni-dev/models.yaml`.
70    Project,
71    /// File explicitly pointed to by `OMNI_DEV_MODELS_YAML`/`--models-yaml`.
72    Override,
73}
74
75impl std::fmt::Display for ModelSource {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.write_str(match self {
78            Self::Embedded => "embedded",
79            Self::User => "user",
80            Self::Project => "project",
81            Self::Override => "override",
82        })
83    }
84}
85
86/// HTTP header that, when sent on a request, unlocks enhanced limits for a
87/// model.
88///
89/// A [`BetaHeader`] is a leaf of a [`ModelSpec`]: it names the header to
90/// send (`key`/`value`) and records the new ceiling for [`max_output_tokens`]
91/// and/or [`input_context`] that the header makes available. An absent
92/// override field means that header does not move that limit; the model's
93/// base value still applies. Callers consult these via
94/// [`ModelRegistry::get_max_output_tokens_with_beta`] and
95/// [`ModelRegistry::get_input_context_with_beta`].
96///
97/// [`max_output_tokens`]: ModelSpec::max_output_tokens
98/// [`input_context`]: ModelSpec::input_context
99#[derive(Debug, Deserialize, Serialize, Clone)]
100pub struct BetaHeader {
101    /// HTTP header name (e.g., "anthropic-beta").
102    pub key: String,
103    /// Header value (e.g., "context-1m-2025-08-07").
104    pub value: String,
105    /// Overridden max output tokens when this header is active.
106    #[serde(default, skip_serializing_if = "Option::is_none")]
107    pub max_output_tokens: Option<usize>,
108    /// Overridden input context when this header is active.
109    #[serde(default, skip_serializing_if = "Option::is_none")]
110    pub input_context: Option<usize>,
111}
112
113/// Specification for a single model: its identity, limits, tier, and any
114/// beta-header unlocks.
115///
116/// A [`ModelSpec`] is the central row of the registry. `provider` and
117/// `tier` cross-reference into a [`ProviderConfig`] (via
118/// [`ModelConfiguration::providers`] and [`ProviderConfig::tiers`]).
119/// `max_output_tokens` and `input_context` are the *base* limits; entries
120/// in `beta_headers` raise them when the corresponding HTTP header is sent.
121/// `source` is loader-populated and records which layer contributed the
122/// entry — never read from YAML.
123///
124/// # Identifier normalization
125///
126/// The same underlying model is addressable through several identifier
127/// formats depending on how the API is reached:
128///
129/// - Canonical (Anthropic direct): `claude-3-7-sonnet-20250219`
130/// - Bedrock with region prefix: `us.anthropic.claude-3-7-sonnet-20250219-v1:0`
131/// - AWS-direct without region: `anthropic.claude-3-haiku-20240307-v1:0`
132/// - Regional gateways: `eu.anthropic.claude-3-opus-20240229-v2:1`
133///
134/// All four resolve to the same [`ModelSpec`]:
135/// [`ModelRegistry::get_model_spec`] tries an exact match first, and on
136/// miss strips region/provider prefixes and version suffixes before
137/// retrying. See [ADR-0011](../../docs/adrs/adr-0011.md) for the design
138/// rationale.
139#[derive(Debug, Deserialize, Serialize, Clone)]
140pub struct ModelSpec {
141    /// AI provider name (e.g., "claude").
142    pub provider: String,
143    /// Human-readable model name (e.g., "Claude Opus 4").
144    pub model: String,
145    /// API identifier used for requests (e.g., "claude-3-opus-20240229").
146    pub api_identifier: String,
147    /// Maximum number of tokens that can be generated in a single response.
148    pub max_output_tokens: usize,
149    /// Maximum number of tokens that can be included in the input context.
150    pub input_context: usize,
151    /// Model generation number (e.g., 3.0, 3.5, 4.0).
152    pub generation: f32,
153    /// Performance tier (e.g., "fast", "balanced", "flagship").
154    pub tier: String,
155    /// Whether this is a legacy model that may be deprecated.
156    #[serde(default)]
157    pub legacy: bool,
158    /// Beta headers that unlock enhanced limits for this model.
159    #[serde(default, skip_serializing_if = "Vec::is_empty")]
160    pub beta_headers: Vec<BetaHeader>,
161    /// Layer that contributed this entry. Populated by the loader; never
162    /// read from YAML.
163    #[serde(default, skip_deserializing)]
164    pub source: ModelSource,
165}
166
167/// Human-readable metadata for a named performance tier.
168///
169/// A tier groups models with comparable speed/capability trade-offs
170/// (e.g. `fast`, `balanced`, `flagship`). [`TierInfo`] holds only the
171/// *description* and recommended use cases — the *limits* (output tokens,
172/// input context, beta-header unlocks) live on each [`ModelSpec`], not
173/// here. [`TierInfo`] is stored in [`ProviderConfig::tiers`] keyed by tier
174/// name, and the same tier name appears on [`ModelSpec::tier`] to link a
175/// model into its tier.
176#[derive(Debug, Deserialize, Serialize, Clone)]
177pub struct TierInfo {
178    /// Human-readable description of the tier.
179    pub description: String,
180    /// List of recommended use cases for this tier.
181    pub use_cases: Vec<String>,
182}
183
184/// Provider-wide fallback limits used when a requested identifier does not
185/// match any [`ModelSpec`].
186///
187/// [`ModelRegistry::get_max_output_tokens`] and
188/// [`ModelRegistry::get_input_context`] consult these values whenever the
189/// caller passes an identifier the registry has not seen — typically a
190/// brand-new model the embedded catalog has not yet been updated for, but
191/// whose provider can still be inferred from the identifier shape. If the
192/// provider itself cannot be inferred, an ultimate hard-coded fallback in
193/// this module applies instead.
194#[derive(Debug, Deserialize, Serialize, Clone)]
195pub struct DefaultConfig {
196    /// Default maximum output tokens for unknown models from this provider.
197    pub max_output_tokens: usize,
198    /// Default input context limit for unknown models from this provider.
199    pub input_context: usize,
200}
201
202/// Per-provider settings: endpoint, default model, named tiers, and the
203/// fallback limits for unknown identifiers.
204///
205/// One [`ProviderConfig`] exists per AI vendor (Anthropic Claude, OpenAI,
206/// Bedrock, Ollama, …) and is stored in [`ModelConfiguration::providers`]
207/// keyed by provider name. `tiers` maps tier names to [`TierInfo`]
208/// descriptions; the same names appear on [`ModelSpec::tier`]. `defaults`
209/// is the per-provider [`DefaultConfig`] used as a fallback when a model
210/// identifier does not match any [`ModelSpec`]. `source` is
211/// loader-populated and records the highest-precedence layer that
212/// contributed any field to this provider block.
213#[derive(Debug, Deserialize, Serialize, Clone)]
214pub struct ProviderConfig {
215    /// Human-readable provider name.
216    pub name: String,
217    /// Base URL for API requests.
218    pub api_base: String,
219    /// Default model identifier to use if none specified.
220    pub default_model: String,
221    /// Available performance tiers and their descriptions.
222    pub tiers: HashMap<String, TierInfo>,
223    /// Default configuration for unknown models.
224    pub defaults: DefaultConfig,
225    /// Layer that contributed this provider block. Populated by the loader.
226    #[serde(default, skip_deserializing)]
227    pub source: ModelSource,
228}
229
230/// Top-level deserialised model catalog: every known model plus every
231/// provider's settings.
232///
233/// [`ModelConfiguration`] is the result of merging the embedded
234/// `src/templates/models.yaml` with any optional user
235/// (`~/.omni-dev/models.yaml`) and project (`./.omni-dev/models.yaml`)
236/// overrides, in that precedence order. See
237/// [ADR-0022](../../docs/adrs/adr-0022.md) for the layered loader and
238/// merge semantics. The canonical entry point that produces a fully merged
239/// instance — and wraps it in lookup indices — is [`ModelRegistry::load`];
240/// the raw configuration is reachable from there via
241/// [`ModelRegistry::config`].
242#[derive(Debug, Deserialize, Serialize, Clone)]
243pub struct ModelConfiguration {
244    /// Schema version declared by the source YAML, if any.
245    #[serde(default, skip_serializing_if = "Option::is_none")]
246    pub version: Option<String>,
247    /// List of all available models.
248    pub models: Vec<ModelSpec>,
249    /// Provider-specific configurations.
250    pub providers: HashMap<String, ProviderConfig>,
251}
252
253/// Indexed view over a [`ModelConfiguration`] with identifier-normalised
254/// lookup.
255///
256/// [`ModelRegistry`] owns the merged catalog and two auxiliary indices —
257/// by API identifier and by provider — populated at construction time.
258/// Construct one with [`ModelRegistry::load`], which performs the layered
259/// YAML load described on [`ModelConfiguration`]. Most callers use the
260/// process-wide singleton returned by [`get_model_registry`] rather than
261/// loading their own instance.
262pub struct ModelRegistry {
263    config: ModelConfiguration,
264    by_identifier: HashMap<String, ModelSpec>,
265    by_provider: HashMap<String, Vec<ModelSpec>>,
266}
267
268impl ModelRegistry {
269    /// Loads the model registry, layering an optional user-side catalog
270    /// over the embedded one.
271    ///
272    /// Lookup order (highest precedence wins):
273    /// 1. `OMNI_DEV_MODELS_YAML` — explicit override path; short-circuits 2 & 3.
274    /// 2. `./.omni-dev/models.yaml` — project-local catalog (if present).
275    /// 3. `~/.omni-dev/models.yaml` — user-level catalog (if present).
276    /// 4. Embedded `src/templates/models.yaml` — always present, lowest layer.
277    ///
278    /// Missing user-side files fall through silently. Malformed user-side
279    /// files log an error and are skipped. A malformed embedded catalog is
280    /// a hard failure (compile-time invariant).
281    pub fn load() -> Result<Self> {
282        let override_path = std::env::var(OMNI_DEV_MODELS_YAML_ENV)
283            .ok()
284            .filter(|s| !s.is_empty())
285            .map(PathBuf::from);
286        let project_path = default_project_path();
287        let user_path = default_user_path();
288        Self::load_layered_from_paths(
289            project_path.as_deref(),
290            user_path.as_deref(),
291            override_path.as_deref(),
292        )
293    }
294
295    /// Loads the registry with explicit paths for the user-side layers.
296    ///
297    /// Exposed primarily for testing — the public entry point is `load()`.
298    pub fn load_layered_from_paths(
299        project_path: Option<&Path>,
300        user_path: Option<&Path>,
301        override_path: Option<&Path>,
302    ) -> Result<Self> {
303        let mut layers: Vec<(ModelSource, String)> = Vec::new();
304        layers.push((ModelSource::Embedded, MODELS_YAML.to_string()));
305
306        if let Some(path) = override_path {
307            match read_optional_yaml(path) {
308                Some(yaml) => layers.push((ModelSource::Override, yaml)),
309                None => {
310                    tracing::warn!(
311                        "{OMNI_DEV_MODELS_YAML_ENV} points at {} but the file is missing or unreadable; falling back to embedded catalog",
312                        path.display()
313                    );
314                }
315            }
316        } else {
317            if let Some(path) = user_path {
318                if let Some(yaml) = read_optional_yaml(path) {
319                    layers.push((ModelSource::User, yaml));
320                }
321            }
322            if let Some(path) = project_path {
323                if let Some(yaml) = read_optional_yaml(path) {
324                    layers.push((ModelSource::Project, yaml));
325                }
326            }
327        }
328
329        Self::from_layers(&layers)
330    }
331
332    /// Builds the registry from already-loaded YAML sources.
333    ///
334    /// `layers` must be ordered from lowest to highest precedence; the
335    /// first entry is treated as the embedded catalog and a parse failure
336    /// there is a hard error.
337    pub(crate) fn from_layers(layers: &[(ModelSource, String)]) -> Result<Self> {
338        let mut merged: serde_yaml::Value =
339            serde_yaml::Value::Mapping(serde_yaml::Mapping::default());
340        let mut model_sources: HashMap<String, ModelSource> = HashMap::new();
341        let mut provider_sources: HashMap<String, ModelSource> = HashMap::new();
342        let mut declared_versions: Vec<(ModelSource, Option<String>)> = Vec::new();
343
344        for (source, yaml) in layers {
345            let value: serde_yaml::Value = match serde_yaml::from_str(yaml) {
346                Ok(v) => v,
347                Err(e) => {
348                    if matches!(source, ModelSource::Embedded) {
349                        return Err(anyhow!(
350                            "Embedded models.yaml is malformed at compile time: {e}"
351                        ));
352                    }
353                    tracing::error!(
354                        "Malformed {source} models.yaml: {e}. Falling through to lower-precedence layers."
355                    );
356                    continue;
357                }
358            };
359
360            // Track version declared by this layer.
361            let version = value
362                .get("version")
363                .and_then(|v| v.as_str())
364                .map(String::from);
365            declared_versions.push((*source, version));
366
367            merge_layer_into(
368                &mut merged,
369                value,
370                *source,
371                &mut model_sources,
372                &mut provider_sources,
373            );
374        }
375
376        warn_on_version_mismatch(&declared_versions);
377
378        let mut config: ModelConfiguration = serde_yaml::from_value(merged)
379            .map_err(|e| anyhow!("Failed to deserialize merged model configuration: {e}"))?;
380
381        for spec in &mut config.models {
382            spec.source = model_sources
383                .get(&spec.api_identifier)
384                .copied()
385                .unwrap_or_default();
386        }
387        for (name, prov) in &mut config.providers {
388            prov.source = provider_sources.get(name).copied().unwrap_or_default();
389        }
390
391        let mut by_identifier = HashMap::new();
392        let mut by_provider: HashMap<String, Vec<ModelSpec>> = HashMap::new();
393        for model in &config.models {
394            by_identifier.insert(model.api_identifier.clone(), model.clone());
395            by_provider
396                .entry(model.provider.clone())
397                .or_default()
398                .push(model.clone());
399        }
400
401        Ok(Self {
402            config,
403            by_identifier,
404            by_provider,
405        })
406    }
407
408    /// Returns the merged model configuration.
409    #[must_use]
410    pub fn config(&self) -> &ModelConfiguration {
411        &self.config
412    }
413
414    /// Returns the model specification for the given API identifier.
415    #[must_use]
416    pub fn get_model_spec(&self, api_identifier: &str) -> Option<&ModelSpec> {
417        // Try exact match first
418        if let Some(spec) = self.by_identifier.get(api_identifier) {
419            return Some(spec);
420        }
421
422        // Try normalizing the identifier and looking up again
423        self.find_model_by_normalized_id(api_identifier)
424    }
425
426    /// Returns the max output tokens for a model, with fallback to provider defaults.
427    #[must_use]
428    pub fn get_max_output_tokens(&self, api_identifier: &str) -> usize {
429        if let Some(spec) = self.get_model_spec(api_identifier) {
430            return spec.max_output_tokens;
431        }
432
433        // Try to infer provider from model identifier and use defaults
434        if let Some(provider) = self.infer_provider(api_identifier) {
435            if let Some(provider_config) = self.config.providers.get(&provider) {
436                return provider_config.defaults.max_output_tokens;
437            }
438        }
439
440        // Ultimate fallback
441        FALLBACK_MAX_OUTPUT_TOKENS
442    }
443
444    /// Returns the input context limit for a model, with fallback to provider defaults.
445    #[must_use]
446    pub fn get_input_context(&self, api_identifier: &str) -> usize {
447        if let Some(spec) = self.get_model_spec(api_identifier) {
448            return spec.input_context;
449        }
450
451        // Try to infer provider from model identifier and use defaults
452        if let Some(provider) = self.infer_provider(api_identifier) {
453            if let Some(provider_config) = self.config.providers.get(&provider) {
454                return provider_config.defaults.input_context;
455            }
456        }
457
458        // Ultimate fallback
459        FALLBACK_INPUT_CONTEXT
460    }
461
462    /// Infers the provider from a model identifier.
463    fn infer_provider(&self, api_identifier: &str) -> Option<String> {
464        if api_identifier.starts_with("claude") || api_identifier.contains("anthropic") {
465            Some("claude".to_string())
466        } else {
467            None
468        }
469    }
470
471    /// Finds a model by normalizing the identifier and performing an exact lookup.
472    ///
473    /// Handles Bedrock-style (`us.anthropic.claude-3-7-sonnet-20250219-v1:0`),
474    /// AWS-style (`anthropic.claude-3-haiku-20240307-v1:0`), and standard identifiers.
475    fn find_model_by_normalized_id(&self, api_identifier: &str) -> Option<&ModelSpec> {
476        let core_identifier = self.extract_core_model_identifier(api_identifier);
477        self.by_identifier.get(&core_identifier)
478    }
479
480    /// Extracts the core model identifier from various formats.
481    fn extract_core_model_identifier(&self, api_identifier: &str) -> String {
482        let mut identifier = api_identifier.to_string();
483
484        // Remove region prefixes (us., eu., etc.)
485        if let Some(dot_pos) = identifier.find('.') {
486            if identifier[..dot_pos].len() <= 3 {
487                // likely a region code
488                identifier = identifier[dot_pos + 1..].to_string();
489            }
490        }
491
492        // Remove provider prefixes (anthropic.)
493        if identifier.starts_with("anthropic.") {
494            identifier = identifier["anthropic.".len()..].to_string();
495        }
496
497        // Remove version suffixes (-v1:0, -v2:1, etc.)
498        if let Some(version_pos) = identifier.rfind("-v") {
499            if identifier[version_pos..].contains(':') {
500                identifier = identifier[..version_pos].to_string();
501            }
502        }
503
504        identifier
505    }
506
507    /// Checks if a model is legacy.
508    #[must_use]
509    pub fn is_legacy_model(&self, api_identifier: &str) -> bool {
510        self.get_model_spec(api_identifier)
511            .is_some_and(|spec| spec.legacy)
512    }
513
514    /// Returns all available models.
515    #[must_use]
516    pub fn get_all_models(&self) -> &[ModelSpec] {
517        &self.config.models
518    }
519
520    /// Returns models filtered by provider.
521    #[must_use]
522    pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelSpec> {
523        self.by_provider
524            .get(provider)
525            .map(|models| models.iter().collect())
526            .unwrap_or_default()
527    }
528
529    /// Returns models filtered by provider and tier.
530    #[must_use]
531    pub fn get_models_by_provider_and_tier(&self, provider: &str, tier: &str) -> Vec<&ModelSpec> {
532        self.get_models_by_provider(provider)
533            .into_iter()
534            .filter(|model| model.tier == tier)
535            .collect()
536    }
537
538    /// Returns the default model identifier for a provider, as defined in `models.yaml`.
539    #[must_use]
540    pub fn get_default_model(&self, provider: &str) -> Option<&str> {
541        self.config
542            .providers
543            .get(provider)
544            .map(|p| p.default_model.as_str())
545    }
546
547    /// Returns the provider configuration.
548    #[must_use]
549    pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
550        self.config.providers.get(provider)
551    }
552
553    /// Returns tier information for a provider.
554    #[must_use]
555    pub fn get_tier_info(&self, provider: &str, tier: &str) -> Option<&TierInfo> {
556        self.config.providers.get(provider)?.tiers.get(tier)
557    }
558
559    /// Returns the beta headers for a model.
560    #[must_use]
561    pub fn get_beta_headers(&self, api_identifier: &str) -> &[BetaHeader] {
562        self.get_model_spec(api_identifier)
563            .map(|spec| spec.beta_headers.as_slice())
564            .unwrap_or_default()
565    }
566
567    /// Returns the max output tokens for a model with a specific beta header active.
568    #[must_use]
569    pub fn get_max_output_tokens_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
570        if let Some(spec) = self.get_model_spec(api_identifier) {
571            if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
572                if let Some(max) = bh.max_output_tokens {
573                    return max;
574                }
575            }
576            return spec.max_output_tokens;
577        }
578        self.get_max_output_tokens(api_identifier)
579    }
580
581    /// Returns the input context for a model with a specific beta header active.
582    #[must_use]
583    pub fn get_input_context_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
584        if let Some(spec) = self.get_model_spec(api_identifier) {
585            if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
586                if let Some(ctx) = bh.input_context {
587                    return ctx;
588                }
589            }
590            return spec.input_context;
591        }
592        self.get_input_context(api_identifier)
593    }
594}
595
/// Computes the project-local catalog location, `<cwd>/.omni-dev/models.yaml`.
///
/// Returns `None` when the current working directory cannot be determined.
fn default_project_path() -> Option<PathBuf> {
    let cwd = std::env::current_dir().ok()?;
    Some(cwd.join(".omni-dev").join("models.yaml"))
}
602
603/// Default user-level catalog path: `~/.omni-dev/models.yaml`.
604fn default_user_path() -> Option<PathBuf> {
605    dirs::home_dir().map(|h| h.join(".omni-dev").join("models.yaml"))
606}
607
608/// Reads `path` if it exists. Returns `None` for missing files; logs and
609/// returns `None` for read errors so the caller can fall through.
610fn read_optional_yaml(path: &Path) -> Option<String> {
611    if !path.exists() {
612        return None;
613    }
614    match std::fs::read_to_string(path) {
615        Ok(s) => Some(s),
616        Err(e) => {
617            tracing::error!(
618                "Failed to read {}: {e}. Falling through to lower-precedence layers.",
619                path.display()
620            );
621            None
622        }
623    }
624}
625
626/// Merges a single layer's parsed YAML value into the accumulator.
627///
628/// The structure is treated specially at two top-level keys:
629/// - `models`: a sequence merged by `api_identifier`. Existing entries are
630///   deep-merged with the incoming entry; new entries are appended.
631/// - `providers`: a mapping deep-merged per provider name (so a user file
632///   can override e.g. `default_model` on the embedded `claude` provider
633///   without having to re-declare every tier).
634///
635/// All other top-level keys (such as `version`) are last-writer-wins.
636fn merge_layer_into(
637    dest: &mut serde_yaml::Value,
638    src: serde_yaml::Value,
639    source: ModelSource,
640    model_sources: &mut HashMap<String, ModelSource>,
641    provider_sources: &mut HashMap<String, ModelSource>,
642) {
643    use serde_yaml::Value;
644
645    let Value::Mapping(src_map) = src else {
646        // Top-level isn't a mapping — treat the layer as a wholesale
647        // replacement. (The embedded YAML is well-formed, so this is only
648        // exercised by adversarial user input.)
649        *dest = src;
650        return;
651    };
652
653    if !matches!(dest, Value::Mapping(_)) {
654        *dest = Value::Mapping(serde_yaml::Mapping::new());
655    }
656    let Value::Mapping(dest_map) = dest else {
657        unreachable!("dest is a mapping after the check above");
658    };
659
660    for (k, v) in src_map {
661        match k.as_str() {
662            Some("models") => merge_models_into(dest_map, k, v, source, model_sources),
663            Some("providers") => merge_providers_into(dest_map, k, v, source, provider_sources),
664            _ => {
665                dest_map.insert(k, v);
666            }
667        }
668    }
669}
670
671fn merge_models_into(
672    dest_map: &mut serde_yaml::Mapping,
673    key: serde_yaml::Value,
674    incoming: serde_yaml::Value,
675    source: ModelSource,
676    model_sources: &mut HashMap<String, ModelSource>,
677) {
678    use serde_yaml::Value;
679
680    let Value::Sequence(incoming_seq) = incoming else {
681        // Not a sequence — replace whatever is there.
682        dest_map.insert(key, incoming);
683        return;
684    };
685
686    let dest_value = dest_map
687        .entry(key)
688        .or_insert_with(|| Value::Sequence(Vec::new()));
689    if !matches!(dest_value, Value::Sequence(_)) {
690        *dest_value = Value::Sequence(Vec::new());
691    }
692    let Value::Sequence(dest_seq) = dest_value else {
693        unreachable!("dest is a sequence after the check above");
694    };
695
696    for entry in incoming_seq {
697        let api_id = entry
698            .get("api_identifier")
699            .and_then(|v| v.as_str())
700            .map(String::from);
701
702        let Some(api_id) = api_id else {
703            tracing::warn!(
704                "Skipping model entry without `api_identifier` from {source} models.yaml"
705            );
706            continue;
707        };
708
709        if let Some(existing) = dest_seq
710            .iter_mut()
711            .find(|e| e.get("api_identifier").and_then(serde_yaml::Value::as_str) == Some(&api_id))
712        {
713            deep_merge(existing, entry);
714        } else {
715            dest_seq.push(entry);
716        }
717
718        model_sources.insert(api_id, source);
719    }
720}
721
722fn merge_providers_into(
723    dest_map: &mut serde_yaml::Mapping,
724    key: serde_yaml::Value,
725    incoming: serde_yaml::Value,
726    source: ModelSource,
727    provider_sources: &mut HashMap<String, ModelSource>,
728) {
729    use serde_yaml::Value;
730
731    let Value::Mapping(incoming_providers) = incoming else {
732        dest_map.insert(key, incoming);
733        return;
734    };
735
736    let dest_value = dest_map
737        .entry(key)
738        .or_insert_with(|| Value::Mapping(serde_yaml::Mapping::new()));
739    if !matches!(dest_value, Value::Mapping(_)) {
740        *dest_value = Value::Mapping(serde_yaml::Mapping::new());
741    }
742    let Value::Mapping(dest_providers) = dest_value else {
743        unreachable!("dest is a mapping after the check above");
744    };
745
746    for (pname, pvalue) in incoming_providers {
747        let pname_str = pname.as_str().map(String::from);
748
749        if let Some(existing) = dest_providers.get_mut(&pname) {
750            deep_merge(existing, pvalue);
751        } else {
752            dest_providers.insert(pname.clone(), pvalue);
753        }
754
755        if let Some(name) = pname_str {
756            provider_sources.insert(name, source);
757        }
758    }
759}
760
761/// Recursive deep-merge: mappings are merged key-by-key, sequences and
762/// scalars are replaced wholesale.
763fn deep_merge(dest: &mut serde_yaml::Value, src: serde_yaml::Value) {
764    use serde_yaml::Value;
765    match (dest, src) {
766        (Value::Mapping(d), Value::Mapping(s)) => {
767            for (k, v) in s {
768                if let Some(existing) = d.get_mut(&k) {
769                    deep_merge(existing, v);
770                } else {
771                    d.insert(k, v);
772                }
773            }
774        }
775        (d, s) => *d = s,
776    }
777}
778
779/// Logs a warning for each user-side layer whose `version` field differs
780/// from the schema version this build understands.
781fn warn_on_version_mismatch(declared: &[(ModelSource, Option<String>)]) {
782    for (source, version) in declared {
783        if matches!(source, ModelSource::Embedded) {
784            continue;
785        }
786        match version {
787            None => {
788                tracing::warn!(
789                    "{source} models.yaml has no `version:` field; assuming compatibility with schema version {MODELS_SCHEMA_VERSION}. Add `version: \"{MODELS_SCHEMA_VERSION}\"` to silence this warning."
790                );
791            }
792            Some(v) if v == MODELS_SCHEMA_VERSION => {}
793            Some(v) => {
794                tracing::warn!(
795                    "{source} models.yaml declares schema version {v}; this build understands {MODELS_SCHEMA_VERSION}. Continuing — unrecognised fields may be ignored."
796                );
797            }
798        }
799    }
800}
801
802/// Global model registry instance.
803static MODEL_REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
804
805/// Returns the global model registry instance.
806#[must_use]
807pub fn get_model_registry() -> &'static ModelRegistry {
808    #[allow(clippy::expect_used)] // YAML is embedded via include_str! at compile time
809    MODEL_REGISTRY.get_or_init(|| ModelRegistry::load().expect("Failed to load model registry"))
810}
811
812#[cfg(test)]
813#[allow(clippy::unwrap_used, clippy::expect_used)]
814mod tests {
815    use super::*;
816    use std::io::Write;
817
818    fn embedded_only() -> ModelRegistry {
819        ModelRegistry::load_layered_from_paths(None, None, None).unwrap()
820    }
821
822    fn write_yaml(dir: &Path, name: &str, contents: &str) -> PathBuf {
823        let path = dir.join(name);
824        let mut f = std::fs::File::create(&path).unwrap();
825        f.write_all(contents.as_bytes()).unwrap();
826        path
827    }
828
829    #[test]
830    fn load_model_registry() {
831        let registry = embedded_only();
832        assert!(!registry.config.models.is_empty());
833        assert!(registry.config.providers.contains_key("claude"));
834        assert_eq!(
835            registry.config.version.as_deref(),
836            Some(MODELS_SCHEMA_VERSION)
837        );
838    }
839
840    #[test]
841    fn claude_model_lookup() {
842        let registry = embedded_only();
843
844        // Test legacy Claude 3 Opus
845        let opus_spec = registry.get_model_spec("claude-3-opus-20240229");
846        assert!(opus_spec.is_some());
847        assert_eq!(opus_spec.unwrap().max_output_tokens, 4096);
848        assert_eq!(opus_spec.unwrap().provider, "claude");
849        assert!(registry.is_legacy_model("claude-3-opus-20240229"));
850
851        // Test Claude 4.5 Sonnet (current generation)
852        let sonnet45_tokens = registry.get_max_output_tokens("claude-sonnet-4-5-20250929");
853        assert_eq!(sonnet45_tokens, 64000);
854
855        // Test legacy Claude 4 Sonnet
856        let sonnet4_tokens = registry.get_max_output_tokens("claude-sonnet-4-20250514");
857        assert_eq!(sonnet4_tokens, 64000);
858        assert!(registry.is_legacy_model("claude-sonnet-4-20250514"));
859
860        // Test unknown model falls back to provider defaults
861        let unknown_tokens = registry.get_max_output_tokens("claude-unknown-model");
862        assert_eq!(unknown_tokens, 4096); // Should use Claude provider defaults
863    }
864
865    #[test]
866    fn unknown_provider_uses_ultimate_fallback() {
867        let registry = embedded_only();
868
869        // Unknown identifier with no recognisable provider → ultimate fallback.
870        assert_eq!(
871            registry.get_max_output_tokens("totally-unknown-vendor-x"),
872            FALLBACK_MAX_OUTPUT_TOKENS
873        );
874        assert_eq!(
875            registry.get_input_context("totally-unknown-vendor-x"),
876            FALLBACK_INPUT_CONTEXT
877        );
878    }
879
880    #[test]
881    fn provider_filtering() {
882        let registry = embedded_only();
883
884        let claude_models = registry.get_models_by_provider("claude");
885        assert!(!claude_models.is_empty());
886
887        let fast_claude_models = registry.get_models_by_provider_and_tier("claude", "fast");
888        assert!(!fast_claude_models.is_empty());
889
890        let tier_info = registry.get_tier_info("claude", "fast");
891        assert!(tier_info.is_some());
892    }
893
894    #[test]
895    fn provider_config() {
896        let registry = embedded_only();
897
898        let claude_config = registry.get_provider_config("claude");
899        assert!(claude_config.is_some());
900        assert_eq!(claude_config.unwrap().name, "Anthropic Claude");
901    }
902
903    #[test]
904    fn default_model_per_provider() {
905        let registry = embedded_only();
906
907        assert_eq!(
908            registry.get_default_model("claude"),
909            Some("claude-sonnet-4-6")
910        );
911        assert_eq!(registry.get_default_model("openai"), Some("gpt-5-mini"));
912        assert_eq!(
913            registry.get_default_model("gemini"),
914            Some("gemini-2.5-flash")
915        );
916        assert_eq!(registry.get_default_model("nonexistent"), None);
917    }
918
919    #[test]
920    fn normalized_id_matching() {
921        let registry = embedded_only();
922
923        // Test Bedrock-style identifiers
924        let bedrock_3_7_sonnet = "us.anthropic.claude-3-7-sonnet-20250219-v1:0";
925        let spec = registry.get_model_spec(bedrock_3_7_sonnet);
926        assert!(spec.is_some());
927        assert_eq!(spec.unwrap().api_identifier, "claude-3-7-sonnet-20250219");
928        assert_eq!(spec.unwrap().max_output_tokens, 64000);
929
930        // Test AWS-style identifiers
931        let aws_haiku = "anthropic.claude-3-haiku-20240307-v1:0";
932        let spec = registry.get_model_spec(aws_haiku);
933        assert!(spec.is_some());
934        assert_eq!(spec.unwrap().api_identifier, "claude-3-haiku-20240307");
935        assert_eq!(spec.unwrap().max_output_tokens, 4096);
936
937        // Test European region
938        let eu_opus = "eu.anthropic.claude-3-opus-20240229-v2:1";
939        let spec = registry.get_model_spec(eu_opus);
940        assert!(spec.is_some());
941        assert_eq!(spec.unwrap().api_identifier, "claude-3-opus-20240229");
942        assert_eq!(spec.unwrap().max_output_tokens, 4096);
943
944        // Test exact match still works for Claude 4.5 Sonnet
945        let exact_sonnet45 = "claude-sonnet-4-5-20250929";
946        let spec = registry.get_model_spec(exact_sonnet45);
947        assert!(spec.is_some());
948        assert_eq!(spec.unwrap().max_output_tokens, 64000);
949
950        // Test legacy Claude 4 Sonnet
951        let exact_sonnet4 = "claude-sonnet-4-20250514";
952        let spec = registry.get_model_spec(exact_sonnet4);
953        assert!(spec.is_some());
954        assert_eq!(spec.unwrap().max_output_tokens, 64000);
955    }
956
957    #[test]
958    fn extract_core_model_identifier() {
959        let registry = embedded_only();
960
961        // Test various formats
962        assert_eq!(
963            registry.extract_core_model_identifier("us.anthropic.claude-3-7-sonnet-20250219-v1:0"),
964            "claude-3-7-sonnet-20250219"
965        );
966
967        assert_eq!(
968            registry.extract_core_model_identifier("anthropic.claude-3-haiku-20240307-v1:0"),
969            "claude-3-haiku-20240307"
970        );
971
972        assert_eq!(
973            registry.extract_core_model_identifier("claude-3-opus-20240229"),
974            "claude-3-opus-20240229"
975        );
976
977        assert_eq!(
978            registry.extract_core_model_identifier("eu.anthropic.claude-sonnet-4-20250514-v2:1"),
979            "claude-sonnet-4-20250514"
980        );
981    }
982
983    #[test]
984    fn beta_header_lookups() {
985        let registry = embedded_only();
986
987        // Opus 4.6 base limits
988        assert_eq!(registry.get_max_output_tokens("claude-opus-4-6"), 128_000);
989        assert_eq!(registry.get_input_context("claude-opus-4-6"), 200_000);
990
991        // Opus 4.6 with 1M context beta
992        assert_eq!(
993            registry.get_input_context_with_beta("claude-opus-4-6", "context-1m-2025-08-07"),
994            1_000_000
995        );
996        // max_output_tokens unchanged with context beta
997        assert_eq!(
998            registry.get_max_output_tokens_with_beta("claude-opus-4-6", "context-1m-2025-08-07"),
999            128_000
1000        );
1001
1002        // Sonnet 3.7 with output-128k beta
1003        assert_eq!(
1004            registry.get_max_output_tokens_with_beta(
1005                "claude-3-7-sonnet-20250219",
1006                "output-128k-2025-02-19"
1007            ),
1008            128_000
1009        );
1010
1011        // Sonnet 3.7 base max_output_tokens without beta
1012        assert_eq!(
1013            registry.get_max_output_tokens("claude-3-7-sonnet-20250219"),
1014            64000
1015        );
1016
1017        // Beta headers accessor
1018        let headers = registry.get_beta_headers("claude-opus-4-6");
1019        assert_eq!(headers.len(), 1);
1020        assert_eq!(headers[0].key, "anthropic-beta");
1021        assert_eq!(headers[0].value, "context-1m-2025-08-07");
1022
1023        // Sonnet 3.7 has two beta headers
1024        let headers = registry.get_beta_headers("claude-3-7-sonnet-20250219");
1025        assert_eq!(headers.len(), 2);
1026
1027        // Model without beta headers returns empty slice
1028        let headers = registry.get_beta_headers("claude-3-haiku-20240307");
1029        assert!(headers.is_empty());
1030
1031        // Unknown model returns empty slice
1032        let headers = registry.get_beta_headers("unknown-model");
1033        assert!(headers.is_empty());
1034    }
1035
1036    #[test]
1037    fn beta_lookups_for_unknown_model_fall_through_to_provider_defaults() {
1038        let registry = embedded_only();
1039
1040        // Unknown model with arbitrary beta value: get_max_output_tokens_with_beta
1041        // and get_input_context_with_beta should both delegate to the no-beta
1042        // resolver, which in turn returns provider defaults for "claude-…".
1043        assert_eq!(
1044            registry
1045                .get_max_output_tokens_with_beta("claude-unknown-model", "context-1m-2025-08-07"),
1046            4096
1047        );
1048        assert_eq!(
1049            registry.get_input_context_with_beta("claude-unknown-model", "context-1m-2025-08-07"),
1050            200_000
1051        );
1052    }
1053
1054    #[test]
1055    fn embedded_models_default_to_embedded_source() {
1056        let registry = embedded_only();
1057        let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1058        assert_eq!(spec.source, ModelSource::Embedded);
1059
1060        let provider = registry.get_provider_config("claude").unwrap();
1061        assert_eq!(provider.source, ModelSource::Embedded);
1062    }
1063
1064    #[test]
1065    fn missing_user_and_project_files_fall_through_silently() {
1066        let dir = tempfile::tempdir().unwrap();
1067        let project_path = dir.path().join("missing-project.yaml");
1068        let user_path = dir.path().join("missing-user.yaml");
1069        let registry =
1070            ModelRegistry::load_layered_from_paths(Some(&project_path), Some(&user_path), None)
1071                .unwrap();
1072
1073        // Behaviour identical to embedded-only.
1074        let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1075        assert_eq!(spec.source, ModelSource::Embedded);
1076        assert_eq!(spec.max_output_tokens, 128_000);
1077    }
1078
1079    #[test]
1080    fn user_layer_overrides_embedded_entry() {
1081        let dir = tempfile::tempdir().unwrap();
1082        let user = write_yaml(
1083            dir.path(),
1084            "user.yaml",
1085            r#"
1086version: "1"
1087models:
1088  - provider: "claude"
1089    model: "Claude Opus 4.6 (custom)"
1090    api_identifier: "claude-opus-4-6"
1091    max_output_tokens: 999999
1092    input_context: 200000
1093    generation: 4.6
1094    tier: "flagship"
1095"#,
1096        );
1097
1098        let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1099        let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1100        assert_eq!(spec.max_output_tokens, 999_999);
1101        assert_eq!(spec.model, "Claude Opus 4.6 (custom)");
1102        assert_eq!(spec.source, ModelSource::User);
1103    }
1104
1105    #[test]
1106    fn project_layer_takes_precedence_over_user_layer() {
1107        let dir = tempfile::tempdir().unwrap();
1108        let user = write_yaml(
1109            dir.path(),
1110            "user.yaml",
1111            r#"
1112version: "1"
1113models:
1114  - provider: "claude"
1115    model: "From User"
1116    api_identifier: "claude-opus-4-6"
1117    max_output_tokens: 1
1118    input_context: 1
1119    generation: 4.6
1120    tier: "flagship"
1121"#,
1122        );
1123        let project = write_yaml(
1124            dir.path(),
1125            "project.yaml",
1126            r#"
1127version: "1"
1128models:
1129  - provider: "claude"
1130    model: "From Project"
1131    api_identifier: "claude-opus-4-6"
1132    max_output_tokens: 2
1133    input_context: 2
1134    generation: 4.6
1135    tier: "flagship"
1136"#,
1137        );
1138
1139        let registry =
1140            ModelRegistry::load_layered_from_paths(Some(&project), Some(&user), None).unwrap();
1141        let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1142        assert_eq!(spec.model, "From Project");
1143        assert_eq!(spec.max_output_tokens, 2);
1144        assert_eq!(spec.source, ModelSource::Project);
1145    }
1146
1147    #[test]
1148    fn additive_user_entry_is_appended() {
1149        let dir = tempfile::tempdir().unwrap();
1150        let user = write_yaml(
1151            dir.path(),
1152            "user.yaml",
1153            r#"
1154version: "1"
1155models:
1156  - provider: "claude"
1157    model: "Claude Custom Future"
1158    api_identifier: "claude-future-9000"
1159    max_output_tokens: 250000
1160    input_context: 5000000
1161    generation: 9.0
1162    tier: "flagship"
1163"#,
1164        );
1165
1166        let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1167        let spec = registry.get_model_spec("claude-future-9000").unwrap();
1168        assert_eq!(spec.max_output_tokens, 250_000);
1169        assert_eq!(spec.input_context, 5_000_000);
1170        assert_eq!(spec.source, ModelSource::User);
1171
1172        // And a pre-existing model is still present, sourced from embedded.
1173        let opus = registry.get_model_spec("claude-opus-4-6").unwrap();
1174        assert_eq!(opus.source, ModelSource::Embedded);
1175    }
1176
1177    #[test]
1178    fn provider_fields_can_be_partially_overridden() {
1179        let dir = tempfile::tempdir().unwrap();
1180        // User only changes claude.default_model. Other fields (tiers,
1181        // defaults, api_base, name) must be preserved from the embedded layer.
1182        let user = write_yaml(
1183            dir.path(),
1184            "user.yaml",
1185            r#"
1186version: "1"
1187providers:
1188  claude:
1189    default_model: "claude-opus-4-6"
1190"#,
1191        );
1192
1193        let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1194        let claude = registry.get_provider_config("claude").unwrap();
1195        assert_eq!(claude.default_model, "claude-opus-4-6");
1196        // Embedded fields must survive the partial override.
1197        assert_eq!(claude.name, "Anthropic Claude");
1198        assert_eq!(claude.api_base, "https://api.anthropic.com/v1");
1199        assert!(claude.tiers.contains_key("flagship"));
1200        // Provider source reflects the most-recent contributing layer.
1201        assert_eq!(claude.source, ModelSource::User);
1202    }
1203
1204    #[test]
1205    fn malformed_user_yaml_logs_and_falls_through() {
1206        let dir = tempfile::tempdir().unwrap();
1207        let user = write_yaml(
1208            dir.path(),
1209            "user.yaml",
1210            "this: is: definitely: not: valid: yaml: [unbalanced",
1211        );
1212
1213        let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1214        // Embedded catalog is intact.
1215        let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1216        assert_eq!(spec.source, ModelSource::Embedded);
1217        assert_eq!(spec.max_output_tokens, 128_000);
1218    }
1219
1220    #[test]
1221    fn override_path_short_circuits_user_and_project() {
1222        let dir = tempfile::tempdir().unwrap();
1223        let user = write_yaml(
1224            dir.path(),
1225            "user.yaml",
1226            r#"
1227version: "1"
1228models:
1229  - provider: "claude"
1230    model: "From User"
1231    api_identifier: "claude-opus-4-6"
1232    max_output_tokens: 1
1233    input_context: 1
1234    generation: 4.6
1235    tier: "flagship"
1236"#,
1237        );
1238        let project = write_yaml(
1239            dir.path(),
1240            "project.yaml",
1241            r#"
1242version: "1"
1243models:
1244  - provider: "claude"
1245    model: "From Project"
1246    api_identifier: "claude-opus-4-6"
1247    max_output_tokens: 2
1248    input_context: 2
1249    generation: 4.6
1250    tier: "flagship"
1251"#,
1252        );
1253        let override_file = write_yaml(
1254            dir.path(),
1255            "override.yaml",
1256            r#"
1257version: "1"
1258models:
1259  - provider: "claude"
1260    model: "From Override"
1261    api_identifier: "claude-opus-4-6"
1262    max_output_tokens: 3
1263    input_context: 3
1264    generation: 4.6
1265    tier: "flagship"
1266"#,
1267        );
1268
1269        let registry = ModelRegistry::load_layered_from_paths(
1270            Some(&project),
1271            Some(&user),
1272            Some(&override_file),
1273        )
1274        .unwrap();
1275        let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1276        assert_eq!(spec.model, "From Override");
1277        assert_eq!(spec.max_output_tokens, 3);
1278        assert_eq!(spec.source, ModelSource::Override);
1279    }
1280
1281    #[test]
1282    fn missing_override_path_falls_back_to_embedded() {
1283        let dir = tempfile::tempdir().unwrap();
1284        let missing = dir.path().join("does-not-exist.yaml");
1285        let registry = ModelRegistry::load_layered_from_paths(None, None, Some(&missing)).unwrap();
1286        let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1287        assert_eq!(spec.source, ModelSource::Embedded);
1288    }
1289
1290    #[test]
1291    fn version_mismatch_is_warned_not_fatal() {
1292        let dir = tempfile::tempdir().unwrap();
1293        let user = write_yaml(
1294            dir.path(),
1295            "user.yaml",
1296            r#"
1297version: "9999"
1298models:
1299  - provider: "claude"
1300    model: "From Future"
1301    api_identifier: "claude-future-9000"
1302    max_output_tokens: 1
1303    input_context: 1
1304    generation: 9.0
1305    tier: "flagship"
1306"#,
1307        );
1308        let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1309        // Loaded successfully despite version mismatch.
1310        assert!(registry.get_model_spec("claude-future-9000").is_some());
1311    }
1312
1313    #[test]
1314    fn missing_version_is_accepted() {
1315        let dir = tempfile::tempdir().unwrap();
1316        let user = write_yaml(
1317            dir.path(),
1318            "user.yaml",
1319            r#"
1320models:
1321  - provider: "claude"
1322    model: "Versionless"
1323    api_identifier: "claude-versionless"
1324    max_output_tokens: 1
1325    input_context: 1
1326    generation: 1.0
1327    tier: "flagship"
1328"#,
1329        );
1330        let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1331        assert!(registry.get_model_spec("claude-versionless").is_some());
1332    }
1333
1334    #[test]
1335    fn model_entry_without_api_identifier_is_skipped() {
1336        let dir = tempfile::tempdir().unwrap();
1337        let user = write_yaml(
1338            dir.path(),
1339            "user.yaml",
1340            r#"
1341version: "1"
1342models:
1343  - provider: "claude"
1344    model: "No Id"
1345    max_output_tokens: 1
1346    input_context: 1
1347    generation: 1.0
1348    tier: "flagship"
1349"#,
1350        );
1351        let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1352        // Registry still loads; embedded catalog unchanged.
1353        let opus = registry.get_model_spec("claude-opus-4-6").unwrap();
1354        assert_eq!(opus.source, ModelSource::Embedded);
1355    }
1356
1357    #[test]
1358    fn model_source_display() {
1359        assert_eq!(ModelSource::Embedded.to_string(), "embedded");
1360        assert_eq!(ModelSource::User.to_string(), "user");
1361        assert_eq!(ModelSource::Project.to_string(), "project");
1362        assert_eq!(ModelSource::Override.to_string(), "override");
1363    }
1364
1365    #[test]
1366    fn embedded_yaml_must_not_be_malformed() {
1367        // Sanity-check: a malformed embedded layer would be a hard error.
1368        let layers = [(ModelSource::Embedded, "::: not yaml :::".to_string())];
1369        let result = ModelRegistry::from_layers(&layers);
1370        assert!(result.is_err());
1371    }
1372
1373    #[test]
1374    fn user_layer_with_scalar_top_level_returns_error() {
1375        // Adversarial: user YAML root is a string, not a mapping. The
1376        // wholesale-replacement branch in `merge_layer_into` discards the
1377        // embedded mapping; deserialise then fails cleanly.
1378        let dir = tempfile::tempdir().unwrap();
1379        let user = write_yaml(dir.path(), "user.yaml", "\"just a string\"\n");
1380        let result = ModelRegistry::load_layered_from_paths(None, Some(&user), None);
1381        assert!(result.is_err());
1382    }
1383
1384    #[test]
1385    fn user_layer_with_non_sequence_models_returns_error() {
1386        // Adversarial: `models: 42` triggers the non-sequence branch in
1387        // `merge_models_into`, which writes the scalar through. The final
1388        // `from_value` fails because `models` must be a sequence.
1389        let dir = tempfile::tempdir().unwrap();
1390        let user = write_yaml(
1391            dir.path(),
1392            "user.yaml",
1393            r#"
1394version: "1"
1395models: 42
1396"#,
1397        );
1398        let result = ModelRegistry::load_layered_from_paths(None, Some(&user), None);
1399        assert!(result.is_err());
1400    }
1401
1402    #[test]
1403    fn user_layer_with_non_mapping_providers_returns_error() {
1404        // Adversarial: `providers: 42` triggers the non-mapping branch in
1405        // `merge_providers_into`. The final `from_value` then fails.
1406        let dir = tempfile::tempdir().unwrap();
1407        let user = write_yaml(
1408            dir.path(),
1409            "user.yaml",
1410            r#"
1411version: "1"
1412providers: 42
1413"#,
1414        );
1415        let result = ModelRegistry::load_layered_from_paths(None, Some(&user), None);
1416        assert!(result.is_err());
1417    }
1418
1419    #[test]
1420    fn deep_merge_inserts_new_keys_into_existing_mapping() {
1421        // Exercises the "key not in dest" branch of `deep_merge`. Adding a
1422        // new tier under `providers.claude.tiers` requires the merger to
1423        // *insert* (not overwrite) within an existing mapping.
1424        let dir = tempfile::tempdir().unwrap();
1425        let user = write_yaml(
1426            dir.path(),
1427            "user.yaml",
1428            r#"
1429version: "1"
1430providers:
1431  claude:
1432    tiers:
1433      experimental:
1434        description: "Experimental tier"
1435        use_cases: ["bleeding edge"]
1436"#,
1437        );
1438        let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1439        let claude = registry.get_provider_config("claude").unwrap();
1440        // Embedded tiers preserved…
1441        assert!(claude.tiers.contains_key("flagship"));
1442        assert!(claude.tiers.contains_key("balanced"));
1443        assert!(claude.tiers.contains_key("fast"));
1444        // …and the new tier was inserted.
1445        let experimental = claude.tiers.get("experimental").unwrap();
1446        assert_eq!(experimental.description, "Experimental tier");
1447        assert_eq!(experimental.use_cases, vec!["bleeding edge".to_string()]);
1448    }
1449
1450    #[test]
1451    #[cfg(unix)]
1452    fn user_path_pointing_at_a_directory_logs_and_falls_through() {
1453        // A directory exists at the path, so `path.exists()` is true, but
1454        // `read_to_string` errors. The loader logs and falls through.
1455        let dir = tempfile::tempdir().unwrap();
1456        let bogus = dir.path().join("models.yaml");
1457        std::fs::create_dir(&bogus).unwrap();
1458        let registry = ModelRegistry::load_layered_from_paths(None, Some(&bogus), None).unwrap();
1459        let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1460        assert_eq!(spec.source, ModelSource::Embedded);
1461    }
1462
1463    #[test]
1464    #[cfg(unix)]
1465    fn override_path_pointing_at_a_directory_warns_and_falls_through() {
1466        let dir = tempfile::tempdir().unwrap();
1467        let bogus = dir.path().join("override.yaml");
1468        std::fs::create_dir(&bogus).unwrap();
1469        let registry = ModelRegistry::load_layered_from_paths(None, None, Some(&bogus)).unwrap();
1470        let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1471        assert_eq!(spec.source, ModelSource::Embedded);
1472    }
1473
1474    #[test]
1475    fn project_layer_recovers_after_user_replaces_top_level_with_scalar() {
1476        // Layer-2 (user) wholesale-replaces the merged accumulator with a
1477        // scalar (early-return branch in `merge_layer_into`). Layer-3
1478        // (project) must hit the "dest is not a mapping" recovery branch
1479        // and rebuild a mapping before merging its own content. Project
1480        // must redeclare `providers` since the user layer wiped them.
1481        let dir = tempfile::tempdir().unwrap();
1482        let user = write_yaml(dir.path(), "user.yaml", "\"junk\"\n");
1483        let project = write_yaml(
1484            dir.path(),
1485            "project.yaml",
1486            r#"
1487version: "1"
1488models:
1489  - provider: "claude"
1490    model: "Project Rescue"
1491    api_identifier: "claude-rescue"
1492    max_output_tokens: 1
1493    input_context: 1
1494    generation: 1.0
1495    tier: "flagship"
1496providers:
1497  custom-provider:
1498    name: "Custom"
1499    api_base: "https://example.invalid"
1500    default_model: "custom-default"
1501    tiers: {}
1502    defaults:
1503      max_output_tokens: 100
1504      input_context: 1000
1505"#,
1506        );
1507        let registry =
1508            ModelRegistry::load_layered_from_paths(Some(&project), Some(&user), None).unwrap();
1509        // Project's model survives the user layer's top-level scalar wipe.
1510        let spec = registry.get_model_spec("claude-rescue").unwrap();
1511        assert_eq!(spec.source, ModelSource::Project);
1512    }
1513
1514    #[test]
1515    fn project_layer_recovers_after_user_replaces_models_with_scalar() {
1516        // Layer-2 sets `models: 42`, replacing the embedded sequence with
1517        // a scalar. Layer-3 must trigger the "dest is not a sequence"
1518        // recovery branch in `merge_models_into` and rebuild the sequence.
1519        let dir = tempfile::tempdir().unwrap();
1520        let user = write_yaml(
1521            dir.path(),
1522            "user.yaml",
1523            r#"
1524version: "1"
1525models: 42
1526"#,
1527        );
1528        let project = write_yaml(
1529            dir.path(),
1530            "project.yaml",
1531            r#"
1532version: "1"
1533models:
1534  - provider: "claude"
1535    model: "Project Rescue"
1536    api_identifier: "claude-rescue"
1537    max_output_tokens: 1
1538    input_context: 1
1539    generation: 1.0
1540    tier: "flagship"
1541"#,
1542        );
1543        let registry =
1544            ModelRegistry::load_layered_from_paths(Some(&project), Some(&user), None).unwrap();
1545        let spec = registry.get_model_spec("claude-rescue").unwrap();
1546        assert_eq!(spec.source, ModelSource::Project);
1547    }
1548
1549    #[test]
1550    fn project_layer_recovers_after_user_replaces_providers_with_scalar() {
1551        // Layer-2 sets `providers: 42`. Layer-3 must trigger the "dest is
1552        // not a mapping" recovery branch in `merge_providers_into`.
1553        let dir = tempfile::tempdir().unwrap();
1554        let user = write_yaml(
1555            dir.path(),
1556            "user.yaml",
1557            r#"
1558version: "1"
1559providers: 42
1560"#,
1561        );
1562        let project = write_yaml(
1563            dir.path(),
1564            "project.yaml",
1565            r#"
1566version: "1"
1567providers:
1568  custom-provider:
1569    name: "Custom"
1570    api_base: "https://example.invalid"
1571    default_model: "custom-default"
1572    tiers: {}
1573    defaults:
1574      max_output_tokens: 100
1575      input_context: 1000
1576"#,
1577        );
1578        let registry =
1579            ModelRegistry::load_layered_from_paths(Some(&project), Some(&user), None).unwrap();
1580        let provider = registry.get_provider_config("custom-provider").unwrap();
1581        assert_eq!(provider.name, "Custom");
1582        assert_eq!(provider.source, ModelSource::Project);
1583    }
1584
1585    #[test]
1586    fn empty_omni_dev_models_yaml_env_var_is_ignored() {
1587        // Exercises the `.filter(|s| !s.is_empty())` branch from `load()`
1588        // directly. The `load()` entry point is not safely callable from
1589        // a unit test because it consults a process-wide OnceLock.
1590        let resolved: Option<PathBuf> = Some(String::new())
1591            .filter(|s| !s.is_empty())
1592            .map(PathBuf::from);
1593        assert!(resolved.is_none());
1594        let resolved: Option<PathBuf> = Some("/some/path".to_string())
1595            .filter(|s| !s.is_empty())
1596            .map(PathBuf::from);
1597        assert_eq!(resolved.as_deref(), Some(Path::new("/some/path")));
1598    }
1599}