// oxi_store/model_registry.rs

//! Model registry — manages built-in and custom models, provides API key resolution.
//!
//! Originally inspired by pi-mono's model-registry.
//!
//! This module provides `CliModelRegistry` (also exported as `ModelRegistry`), which:
//! - Loads built-in models from `oxi_ai::model_db`
//! - Loads custom models and provider overrides from a `models.json` file
//! - Resolves API keys via `AuthStorage` (env vars, OAuth tokens, stored credentials)
//! - Supports dynamic provider registration (extensions)
//! - Provides model filtering by provider, capability, and modality

use std::collections::HashMap;
use std::path::{Path, PathBuf};

use oxi_ai::model_db;
use oxi_ai::register_builtins::get_builtin_provider;
use oxi_ai::{Api, CompatSettings, Cost, InputModality, Model};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};

use crate::auth_storage::{AuthStatus, AuthStorage};

// =============================================================================
// JSON Schema types for models.json
// =============================================================================

25/// Per-model override fields (all optional, merged with built-in model).
26#[derive(Debug, Clone, Serialize, Deserialize, Default)]
27#[serde(rename_all = "camelCase")]
28pub struct ModelOverride {
29    /// pub.
30    pub name: Option<String>,
31    /// pub.
32    pub reasoning: Option<bool>,
33    /// pub.
34    pub thinking_level_map: Option<HashMap<String, Option<String>>>,
35    /// pub.
36    pub input: Option<Vec<InputModality>>,
37    /// pub.
38    pub cost: Option<PartialCost>,
39    /// pub.
40    pub context_window: Option<usize>,
41    /// pub.
42    pub max_tokens: Option<usize>,
43    /// pub.
44    pub headers: Option<HashMap<String, String>>,
45    /// pub.
46    pub compat: Option<CompatSettings>,
47}
48
49/// Partial cost override — each field is optional.
50#[derive(Debug, Clone, Serialize, Deserialize, Default)]
51#[serde(rename_all = "camelCase")]
52pub struct PartialCost {
53    /// pub.
54    pub input: Option<f64>,
55    /// pub.
56    pub output: Option<f64>,
57    /// pub.
58    pub cache_read: Option<f64>,
59    /// pub.
60    pub cache_write: Option<f64>,
61}
62
63/// Custom model definition in models.json.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65#[serde(rename_all = "camelCase")]
66pub struct ModelDefinition {
67    /// pub.
68    pub id: String,
69    /// pub.
70    pub name: Option<String>,
71    /// pub.
72    pub api: Option<Api>,
73    /// pub.
74    pub base_url: Option<String>,
75    /// pub.
76    pub reasoning: Option<bool>,
77    /// pub.
78    pub thinking_level_map: Option<HashMap<String, Option<String>>>,
79    /// pub.
80    pub input: Option<Vec<InputModality>>,
81    /// pub.
82    pub cost: Option<Cost>,
83    /// pub.
84    pub context_window: Option<usize>,
85    /// pub.
86    pub max_tokens: Option<usize>,
87    /// pub.
88    pub headers: Option<HashMap<String, String>>,
89    /// pub.
90    pub compat: Option<CompatSettings>,
91}
92
93/// Provider configuration in models.json.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95#[serde(rename_all = "camelCase")]
96pub struct ProviderConfig {
97    /// pub.
98    pub name: Option<String>,
99    /// pub.
100    pub base_url: Option<String>,
101    /// pub.
102    pub api_key: Option<String>,
103    /// pub.
104    pub api: Option<Api>,
105    /// pub.
106    pub headers: Option<HashMap<String, String>>,
107    /// pub.
108    pub compat: Option<CompatSettings>,
109    /// pub.
110    pub auth_header: Option<bool>,
111    /// pub.
112    pub models: Option<Vec<ModelDefinition>>,
113    /// pub.
114    pub model_overrides: Option<HashMap<String, ModelOverride>>,
115}
116
117/// Top-level models.json configuration.
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct ModelsConfig {
120    /// pub.
121    pub providers: HashMap<String, ProviderConfig>,
122}
123
// =============================================================================
// Provider override (baseUrl, compat) — not request auth/headers
// =============================================================================

128/// Provider-level overrides loaded from models.json (baseUrl, compat).
129#[derive(Debug, Clone, Default)]
130struct ProviderOverride {
131    base_url: Option<String>,
132    compat: Option<CompatSettings>,
133}
134
/// Per-provider request settings from `models.json` (`apiKey`, `headers`, `authHeader`).
#[derive(Clone, Debug, Default)]
struct ProviderRequestConfig {
    /// Raw `apiKey` value. A `$VAR` / `${VAR}` reference is resolved only at lookup time.
    api_key: Option<String>,
    /// Extra request headers for the provider.
    headers: Option<HashMap<String, String>>,
    /// Auth-header flag. NOTE(review): semantics not visible here — confirm.
    auth_header: bool,
}

143/// Result of loading custom models from models.json.
144struct CustomModelsResult {
145    models: Vec<Model>,
146    overrides: HashMap<String, ProviderOverride>,
147    model_overrides: HashMap<String, HashMap<String, ModelOverride>>,
148    error: Option<String>,
149}
150
151fn empty_custom_models_result(error: Option<String>) -> CustomModelsResult {
152    CustomModelsResult {
153        models: vec![],
154        overrides: HashMap::new(),
155        model_overrides: HashMap::new(),
156        error,
157    }
158}
159
// =============================================================================
// Resolved request auth
// =============================================================================

/// Outcome of resolving the API key and request headers for a model.
#[derive(Clone, Debug)]
pub struct ResolvedRequestAuth {
    /// `true` when resolution succeeded.
    pub ok: bool,
    /// Resolved API key, if one was found.
    pub api_key: Option<String>,
    /// Merged request headers (model + provider + models.json).
    pub headers: Option<HashMap<String, String>>,
    /// Failure description; set when `ok` is `false`.
    pub error: Option<String>,
}

177impl ResolvedRequestAuth {
178    fn ok(api_key: Option<String>, headers: Option<HashMap<String, String>>) -> Self {
179        Self {
180            ok: true,
181            api_key,
182            headers,
183            error: None,
184        }
185    }
186
187    fn err(msg: impl Into<String>) -> Self {
188        Self {
189            ok: false,
190            api_key: None,
191            headers: None,
192            error: Some(msg.into()),
193        }
194    }
195}
196
// =============================================================================
// Helper: merge compat
// =============================================================================

201/// Merge a base compat with an override compat. Override fields win.
202fn merge_compat(
203    base: Option<&CompatSettings>,
204    override_compat: Option<&CompatSettings>,
205) -> Option<CompatSettings> {
206    match (base, override_compat) {
207        (None, None) => None,
208        (None, Some(ov)) => Some(ov.clone()),
209        (Some(b), None) => Some(b.clone()),
210        (Some(b), Some(ov)) => {
211            // Merge: override fields take precedence
212            Some(CompatSettings {
213                supports_store: ov.supports_store,
214                supports_developer_role: ov.supports_developer_role,
215                supports_reasoning_effort: ov.supports_reasoning_effort,
216                supports_usage_in_streaming: ov.supports_usage_in_streaming,
217                max_tokens_field: ov.max_tokens_field.or(b.max_tokens_field),
218                requires_tool_result_name: ov.requires_tool_result_name,
219                requires_assistant_after_tool_result: ov.requires_assistant_after_tool_result,
220                requires_thinking_as_text: ov.requires_thinking_as_text,
221                thinking_format: ov.thinking_format.or(b.thinking_format),
222            })
223        }
224    }
225}
226
// =============================================================================
// Helper: apply model override
// =============================================================================

231/// Deep merge a model override into a model.
232fn apply_model_override(model: &Model, override_def: &ModelOverride) -> Model {
233    let mut result = model.clone();
234
235    if let Some(ref name) = override_def.name {
236        result.name = name.clone();
237    }
238    if let Some(reasoning) = override_def.reasoning {
239        result.reasoning = reasoning;
240    }
241    if let Some(ref input) = override_def.input {
242        result.input = input.clone();
243    }
244    if let Some(ctx) = override_def.context_window {
245        result.context_window = ctx;
246    }
247    if let Some(mt) = override_def.max_tokens {
248        result.max_tokens = mt;
249    }
250
251    // Merge cost (partial override)
252    if let Some(ref cost) = override_def.cost {
253        result.cost = Cost {
254            input: cost.input.unwrap_or(result.cost.input),
255            output: cost.output.unwrap_or(result.cost.output),
256            cache_read: cost.cache_read.unwrap_or(result.cost.cache_read),
257            cache_write: cost.cache_write.unwrap_or(result.cost.cache_write),
258        };
259    }
260
261    // Deep merge compat
262    result.compat = merge_compat(result.compat.as_ref(), override_def.compat.as_ref());
263
264    result
265}
266
// =============================================================================
// Helper: resolve a config value string
// =============================================================================

271/// Resolve a config value. Supports environment variable references:
272/// - `$VAR` or `${VAR}` — resolves to the value of the environment variable.
273/// - Plain string — returned as-is.
274///
275/// The `!` command execution prefix has been removed for security.
276/// Use `$ENV_VAR` references instead.
277fn resolve_config_value(value: &str) -> Option<String> {
278    // Support environment variable references: $VAR or ${VAR}
279    if let Some(stripped) = value.strip_prefix('$') {
280        let var_name = value
281            .strip_prefix("${")
282            .and_then(|s| s.strip_suffix('}'))
283            .unwrap_or(stripped);
284        return std::env::var(var_name).ok().filter(|s| !s.is_empty());
285    }
286    // Command execution (! prefix) has been removed for security
287    if value.starts_with('!') {
288        tracing::warn!(
289            "Command execution in config values (! prefix) is no longer supported for security. Use $ENV_VAR instead. Value: {}",
290            value
291        );
292        return None;
293    }
294    Some(value.to_string())
295}
296
297/// Resolve a config value or throw with a descriptive error message.
298///
299/// Supports `$VAR` / `${VAR}` environment variable references.
300/// The `!` command execution prefix has been removed for security.
301fn resolve_config_value_or_throw(value: &str, label: &str) -> Result<String, String> {
302    if let Some(stripped) = value.strip_prefix('$') {
303        let var_name = value
304            .strip_prefix("${")
305            .and_then(|s| s.strip_suffix('}'))
306            .unwrap_or(stripped);
307        std::env::var(var_name)
308            .map_err(|_| format!("Environment variable {} not set for {}", var_name, label))
309    } else if value.starts_with('!') {
310        tracing::warn!(
311            "Command execution in config values (! prefix) is no longer supported for security. Use $ENV_VAR instead. Value: {}",
312            value
313        );
314        Err(format!(
315            "Command execution (! prefix) is no longer supported for {}. Use $ENV_VAR instead.",
316            label
317        ))
318    } else {
319        Ok(value.to_string())
320    }
321}
322
323/// Resolve optional headers map: if any value is an env-var ref, resolve it.
324fn resolve_headers(
325    headers: Option<&HashMap<String, String>>,
326    label: &str,
327) -> Result<Option<HashMap<String, String>>, String> {
328    let Some(h) = headers else {
329        return Ok(None);
330    };
331    if h.is_empty() {
332        return Ok(None);
333    }
334    let mut resolved = HashMap::new();
335    for (k, v) in h {
336        // If value looks like an env var (uppercase + underscores), try resolving
337        if v.chars().all(|c| c.is_uppercase() || c == '_' || c == ' ') && !v.contains(' ') {
338            if let Ok(rv) = resolve_config_value_or_throw(v, &format!("{}.{}", label, k)) {
339                resolved.insert(k.clone(), rv);
340            } else {
341                resolved.insert(k.clone(), v.clone());
342            }
343        } else {
344            resolved.insert(k.clone(), v.clone());
345        }
346    }
347    Ok(Some(resolved))
348}
349
// =============================================================================
// Model Registry
// =============================================================================

354/// CLI-specific model registry with auth storage integration.
355///
356/// This extends the base model data from `oxi_ai::model_db` with:
357/// - API key resolution via `AuthStorage`
358/// - `models.json` file parsing
359/// - OAuth token detection
360/// - Available model filtering (only models with configured auth)
361///
362/// For SDK usage without CLI-specific features, use `oxi_ai::ModelRegistry` instead.
363pub struct CliModelRegistry {
364    /// All loaded models (built-in + custom).
365    models: RwLock<Vec<Model>>,
366    /// Per-provider request configs loaded from models.json.
367    provider_request_configs: RwLock<HashMap<String, ProviderRequestConfig>>,
368    /// Per-model request headers loaded from models.json.
369    model_request_headers: RwLock<HashMap<String, HashMap<String, String>>>,
370    /// Dynamically registered providers (from extensions).
371    registered_providers: RwLock<HashMap<String, ProviderConfigInput>>,
372    /// Error from loading models.json (if any).
373    load_error: RwLock<Option<String>>,
374    /// Auth storage for credential lookup.
375    auth_storage: AuthStorage,
376    /// Path to models.json (None for in-memory).
377    models_json_path: Option<PathBuf>,
378}
379
380/// Backward-compatible alias for [`CliModelRegistry`].
381pub type ModelRegistry = CliModelRegistry;
382
383impl CliModelRegistry {
384    /// Create a new `ModelRegistry` that loads models from the given path.
385    ///
386    /// If `models_json_path` is `None`, falls back to
387    /// `$HOME/.oxi/models.json` (or the XDG config dir equivalent).
388    pub fn create(auth_storage: AuthStorage, models_json_path: Option<PathBuf>) -> Self {
389        let models_json_path = models_json_path
390            .or_else(|| dirs::config_dir().map(|p| p.join("oxi").join("models.json")));
391
392        let registry = Self {
393            models: RwLock::new(Vec::new()),
394            provider_request_configs: RwLock::new(HashMap::new()),
395            model_request_headers: RwLock::new(HashMap::new()),
396            registered_providers: RwLock::new(HashMap::new()),
397            load_error: RwLock::new(None),
398            auth_storage,
399            models_json_path,
400        };
401
402        registry.load_models_internal();
403        registry
404    }
405
406    /// Create an in-memory registry (no models.json loading).
407    pub fn in_memory(auth_storage: AuthStorage) -> Self {
408        Self {
409            models: RwLock::new(Vec::new()),
410            provider_request_configs: RwLock::new(HashMap::new()),
411            model_request_headers: RwLock::new(HashMap::new()),
412            registered_providers: RwLock::new(HashMap::new()),
413            load_error: RwLock::new(None),
414            auth_storage,
415            models_json_path: None,
416        }
417    }
418
    // =========================================================================
    // Public API
    // =========================================================================

423    /// Reload models from disk (built-in + custom from models.json).
424    pub fn refresh(&self) {
425        // Clear caches
426        self.provider_request_configs.write().clear();
427        self.model_request_headers.write().clear();
428        *self.load_error.write() = None;
429
430        self.load_models_internal();
431
432        // Re-apply registered providers
433        let providers = self.registered_providers.read().clone();
434        for (name, config) in &providers {
435            self.apply_provider_config(name, config);
436        }
437    }
438
439    /// Get any error from loading models.json.
440    pub fn get_error(&self) -> Option<String> {
441        self.load_error.read().clone()
442    }
443
444    /// Get all models (built-in + custom).
445    pub fn get_all(&self) -> Vec<Model> {
446        self.models.read().clone()
447    }
448
449    /// Get only models that have auth configured.
450    pub fn get_available(&self) -> Vec<Model> {
451        self.models
452            .read()
453            .iter()
454            .filter(|m| self.has_configured_auth(m))
455            .cloned()
456            .collect()
457    }
458
459    /// Find a model by provider and ID.
460    pub fn find(&self, provider: &str, model_id: &str) -> Option<Model> {
461        self.models
462            .read()
463            .iter()
464            .find(|m| m.provider == provider && m.id == model_id)
465            .cloned()
466    }
467
468    /// Resolve a model string (e.g. `"anthropic/claude-sonnet-4-20250514"`)
469    /// into a `Model` object.
470    ///
471    /// Supports:
472    /// - `"provider/model-id"`
473    /// - `"model-id"` (searches all providers)
474    pub fn resolve_model(&self, model_str: &str) -> Option<Model> {
475        let models = self.models.read();
476
477        if let Some(slash) = model_str.find('/') {
478            let provider = &model_str[..slash];
479            let id = &model_str[slash + 1..];
480            return models
481                .iter()
482                .find(|m| m.provider == provider && m.id == id)
483                .cloned();
484        }
485
486        // Search by ID across all providers
487        let matches: Vec<_> = models.iter().filter(|m| m.id == model_str).collect();
488
489        if matches.len() == 1 {
490            return Some(matches[0].clone());
491        }
492
493        // Multiple matches — prefer first, warn about ambiguity
494        if !matches.is_empty() {
495            if matches.len() > 1 {
496                tracing::warn!(
497                    "Ambiguous model ID '{}' matches providers: {}. Using first match.",
498                    model_str,
499                    matches
500                        .iter()
501                        .map(|m| m.provider.as_str())
502                        .collect::<Vec<_>>()
503                        .join(", ")
504                );
505            }
506            return Some(matches[0].clone());
507        }
508
509        // Fuzzy: try case-insensitive contains
510        let lower = model_str.to_lowercase();
511        models
512            .iter()
513            .find(|m| {
514                m.id.to_lowercase().contains(&lower) || m.name.to_lowercase().contains(&lower)
515            })
516            .cloned()
517    }
518
519    /// Check if a model has auth configured.
520    pub fn has_configured_auth(&self, model: &Model) -> bool {
521        self.auth_storage.has_auth(&model.provider)
522            || self
523                .provider_request_configs
524                .read()
525                .get(&model.provider)
526                .and_then(|c| c.api_key.as_ref())
527                .is_some()
528    }
529
530    /// Get API key and request headers for a model.
531    pub fn get_api_key_and_headers(&self, model: &Model) -> ResolvedRequestAuth {
532        self.get_api_key_and_headers_impl(model)
533    }
534
535    /// Check if a model is using OAuth credentials.
536    pub fn is_using_oauth(&self, model: &Model) -> bool {
537        let cred = self.auth_storage.get_oauth_credential(&model.provider);
538        cred.is_some()
539    }
540
541    /// Get display name for a provider.
542    ///
543    /// Returns a static string for known providers, or the raw provider string.
544    pub fn get_provider_display_name(&self, provider: &str) -> String {
545        // Check registered providers for a custom name
546        // Note: dynamic names are handled by get_provider_display_name_owned()
547
548        get_builtin_provider(provider)
549            .map(|p| p.display_name)
550            .unwrap_or(provider)
551            .to_string()
552    }
553
554    /// Get the display name for a provider, allocating if necessary for
555    /// dynamically registered names.
556    pub fn get_provider_display_name_owned(&self, provider: &str) -> String {
557        if let Some(config) = self.registered_providers.read().get(provider) {
558            if let Some(ref name) = config.name {
559                return name.clone();
560            }
561        }
562
563        get_builtin_provider(provider)
564            .map(|p| p.display_name)
565            .unwrap_or(provider)
566            .to_string()
567    }
568
569    /// Return auth status for a provider, including request auth configured
570    /// in models.json.
571    pub fn get_provider_auth_status(&self, provider: &str) -> AuthStatus {
572        let auth_status = self.auth_storage.get_status(provider);
573        if auth_status.source.is_some() {
574            return auth_status;
575        }
576
577        let provider_api_key = self
578            .provider_request_configs
579            .read()
580            .get(provider)
581            .and_then(|c| c.api_key.clone());
582
583        let Some(ref api_key_ref) = provider_api_key else {
584            return auth_status;
585        };
586
587        if api_key_ref.starts_with('$') {
588            return AuthStatus {
589                configured: true,
590                source: Some("models_json_env_var".to_string()),
591                label: None,
592            };
593        } else if api_key_ref.starts_with('!') {
594            // Deprecated: ! prefix no longer supported
595            tracing::warn!(
596                "Command execution (! prefix) in apiKey is no longer supported. Use $ENV_VAR instead."
597            );
598            return AuthStatus {
599                configured: false,
600                source: Some("models_json_command_deprecated".to_string()),
601                label: None,
602            };
603        }
604
605        // Plain string value — always available as config
606        AuthStatus {
607            configured: true,
608            source: Some("models_json_key".to_string()),
609            label: Some(api_key_ref.clone()),
610        }
611    }
612
613    /// Get API key for a provider.
614    pub fn get_api_key_for_provider(&self, provider: &str) -> Option<String> {
615        // Try auth storage first
616        if let Some(key) = self.auth_storage.get_api_key(provider) {
617            return Some(key);
618        }
619
620        // Try provider request config from models.json
621        let api_key_str = self
622            .provider_request_configs
623            .read()
624            .get(provider)
625            .and_then(|c| c.api_key.clone())?;
626
627        resolve_config_value(&api_key_str)
628    }
629
630    /// Get all unique providers from loaded models.
631    pub fn get_available_providers(&self) -> Vec<String> {
632        let mut providers: Vec<String> = self
633            .models
634            .read()
635            .iter()
636            .map(|m| m.provider.clone())
637            .collect();
638        providers.sort();
639        providers.dedup();
640        providers
641    }
642
643    /// Get all providers that have credentials configured.
644    pub fn get_providers_with_credentials(&self) -> Vec<String> {
645        let providers = self.get_available_providers();
646        providers
647            .into_iter()
648            .filter(|p| {
649                self.auth_storage.has_auth(p)
650                    || self
651                        .provider_request_configs
652                        .read()
653                        .get(p)
654                        .and_then(|c| c.api_key.as_ref())
655                        .is_some()
656            })
657            .collect()
658    }
659
660    /// Get all models that have credentials configured.
661    pub fn get_available_models(&self) -> Vec<Model> {
662        self.get_available()
663    }
664
665    /// Get the user's default model.
666    ///
667    /// Returns the first available model, preferring:
668    /// 1. Anthropic Claude Sonnet 4
669    /// 2. OpenAI GPT-4o
670    /// 3. Any available model
671    pub fn get_default_model(&self) -> Option<Model> {
672        let available = self.get_available();
673
674        // Prefer known defaults
675        let preferred = [
676            ("anthropic", "claude-sonnet-4-20250514"),
677            ("anthropic", "claude-sonnet-4-5"),
678            ("anthropic", "claude-sonnet-4-6"),
679            ("openai", "gpt-4o"),
680        ];
681
682        for (provider, id) in &preferred {
683            if let Some(model) = available
684                .iter()
685                .find(|m| m.provider == *provider && m.id == *id)
686            {
687                return Some(model.clone());
688            }
689        }
690
691        available.into_iter().next()
692    }
693
694    /// Register a provider dynamically (from extensions).
695    pub fn register_provider(&self, provider_name: &str, config: ProviderConfigInput) {
696        self.apply_provider_config(provider_name, &config);
697        self.upsert_registered_provider(provider_name, config);
698    }
699
700    /// Unregister a previously registered provider.
701    pub fn unregister_provider(&self, provider_name: &str) {
702        if !self.registered_providers.read().contains_key(provider_name) {
703            return;
704        }
705        self.registered_providers.write().remove(provider_name);
706        self.refresh();
707    }
708
    // =========================================================================
    // Model filtering
    // =========================================================================

713    /// Filter models by provider.
714    pub fn filter_by_provider(&self, provider: &str) -> Vec<Model> {
715        self.models
716            .read()
717            .iter()
718            .filter(|m| m.provider == provider)
719            .cloned()
720            .collect()
721    }
722
723    /// Filter models by capability (reasoning).
724    pub fn filter_by_capability(&self, reasoning: bool) -> Vec<Model> {
725        self.models
726            .read()
727            .iter()
728            .filter(|m| m.reasoning == reasoning)
729            .cloned()
730            .collect()
731    }
732
733    /// Filter models by input modality (e.g., vision/image support).
734    pub fn filter_by_modality(&self, modality: InputModality) -> Vec<Model> {
735        self.models
736            .read()
737            .iter()
738            .filter(|m| m.input.contains(&modality))
739            .cloned()
740            .collect()
741    }
742
743    /// Search models by pattern (case-insensitive substring match on id or name).
744    pub fn search(&self, pattern: &str) -> Vec<Model> {
745        let lower = pattern.to_lowercase();
746        self.models
747            .read()
748            .iter()
749            .filter(|m| {
750                m.id.to_lowercase().contains(&lower) || m.name.to_lowercase().contains(&lower)
751            })
752            .cloned()
753            .collect()
754    }
755
    // =========================================================================
    // Private: model loading
    // =========================================================================

760    fn load_models_internal(&self) {
761        // Load custom models and overrides from models.json
762        let custom_result = match self.models_json_path {
763            Some(ref path) => self.load_custom_models(path),
764            None => empty_custom_models_result(None),
765        };
766
767        if let Some(ref error) = custom_result.error {
768            *self.load_error.write() = Some(error.clone());
769            // Keep built-in models even if custom models failed to load
770        }
771
772        let built_in =
773            self.load_built_in_models(&custom_result.overrides, &custom_result.model_overrides);
774        let combined = self.merge_custom_models(built_in, &custom_result.models);
775
776        *self.models.write() = combined;
777    }
778
779    /// Load built-in models from model_db and apply provider/model overrides.
780    fn load_built_in_models(
781        &self,
782        overrides: &HashMap<String, ProviderOverride>,
783        model_overrides: &HashMap<String, HashMap<String, ModelOverride>>,
784    ) -> Vec<Model> {
785        let mut result = Vec::new();
786
787        for provider_name in model_db::get_providers() {
788            let entries = model_db::get_provider_models(provider_name);
789            let provider_override = overrides.get(provider_name);
790            let per_model_overrides = model_overrides.get(provider_name);
791
792            for entry in entries {
793                // Convert ModelEntry -> Model
794                let mut model = Model {
795                    id: entry.id.to_string(),
796                    name: entry.name.to_string(),
797                    api: entry.api,
798                    provider: entry.provider.to_string(),
799                    base_url: self.default_base_url_for_provider(entry.provider),
800                    reasoning: entry.reasoning,
801                    input: entry.input.to_vec(),
802                    cost: Cost {
803                        input: entry.cost_input,
804                        output: entry.cost_output,
805                        cache_read: entry.cost_cache_read,
806                        cache_write: entry.cost_cache_write,
807                    },
808                    context_window: entry.context_window as usize,
809                    max_tokens: entry.max_tokens as usize,
810                    headers: HashMap::new(),
811                    compat: None,
812                };
813
814                // Apply provider-level baseUrl/compat override
815                if let Some(po) = provider_override {
816                    if let Some(ref url) = po.base_url {
817                        model.base_url = url.clone();
818                    }
819                    model.compat = merge_compat(model.compat.as_ref(), po.compat.as_ref());
820                }
821
822                // Apply per-model override
823                if let Some(per_model) = per_model_overrides {
824                    if let Some(mo) = per_model.get(entry.id) {
825                        model = apply_model_override(&model, mo);
826                    }
827                }
828
829                result.push(model);
830            }
831        }
832
833        result
834    }
835
836    /// Merge custom models into built-in list by provider+id (custom wins on conflicts).
837    fn merge_custom_models(&self, built_in: Vec<Model>, custom: &[Model]) -> Vec<Model> {
838        let mut merged = built_in;
839
840        for custom_model in custom {
841            if let Some(idx) = merged
842                .iter()
843                .position(|m| m.provider == custom_model.provider && m.id == custom_model.id)
844            {
845                merged[idx] = custom_model.clone();
846            } else {
847                merged.push(custom_model.clone());
848            }
849        }
850
851        merged
852    }
853
    // =========================================================================
    // Private: models.json loading
    // =========================================================================

858    fn load_custom_models(&self, path: &Path) -> CustomModelsResult {
859        if !path.exists() {
860            return empty_custom_models_result(None);
861        }
862
863        let content = match std::fs::read_to_string(path) {
864            Ok(c) => c,
865            Err(e) => {
866                return empty_custom_models_result(Some(format!(
867                    "Failed to read models.json: {}\n\nFile: {}",
868                    e,
869                    path.display()
870                )));
871            }
872        };
873
874        let config: ModelsConfig = match serde_json::from_str(&content) {
875            Ok(c) => c,
876            Err(e) => {
877                return empty_custom_models_result(Some(format!(
878                    "Failed to parse models.json: {}\n\nFile: {}",
879                    e,
880                    path.display()
881                )));
882            }
883        };
884
885        // Additional validation
886        if let Err(e) = self.validate_config(&config) {
887            return empty_custom_models_result(Some(format!(
888                "Invalid models.json: {}\n\nFile: {}",
889                e,
890                path.display()
891            )));
892        }
893
894        let mut overrides = HashMap::new();
895        let mut model_overrides_map: HashMap<String, HashMap<String, ModelOverride>> =
896            HashMap::new();
897
898        let built_in_providers: Vec<&str> = model_db::get_providers();
899
900        for (provider_name, provider_config) in &config.providers {
901            // Store provider-level overrides
902            if provider_config.base_url.is_some() || provider_config.compat.is_some() {
903                overrides.insert(
904                    provider_name.clone(),
905                    ProviderOverride {
906                        base_url: provider_config.base_url.clone(),
907                        compat: provider_config.compat.clone(),
908                    },
909                );
910            }
911
912            // Store provider request config (apiKey, headers, authHeader)
913            self.store_provider_request_config(provider_name, provider_config);
914
915            // Store per-model overrides
916            if let Some(ref model_overrides) = provider_config.model_overrides {
917                model_overrides_map.insert(provider_name.clone(), model_overrides.clone());
918                for (model_id, model_override) in model_overrides {
919                    self.store_model_headers(
920                        provider_name,
921                        model_id,
922                        model_override.headers.as_ref(),
923                    );
924                }
925            }
926        }
927
928        // Warn about apiKey in models.json — recommend env var references instead
929        for (provider_name, provider_config) in &config.providers {
930            if provider_config.api_key.is_some() {
931                tracing::warn!(
932                    "models.json contains apiKey for provider '{}'. Consider using $ENV_VAR reference instead.",
933                    provider_name
934                );
935            }
936        }
937
938        let models = self.parse_models(&config, &built_in_providers);
939
940        CustomModelsResult {
941            models,
942            overrides,
943            model_overrides: model_overrides_map,
944            error: None,
945        }
946    }
947
948    fn validate_config(&self, config: &ModelsConfig) -> Result<(), String> {
949        let built_in_providers: Vec<&str> = model_db::get_providers();
950
951        for (provider_name, provider_config) in &config.providers {
952            let is_built_in = built_in_providers.contains(&provider_name.as_str());
953            let models = provider_config.models.as_deref().unwrap_or(&[]);
954            let has_model_overrides = provider_config
955                .model_overrides
956                .as_ref()
957                .map(|m| !m.is_empty())
958                .unwrap_or(false);
959
960            if models.is_empty() {
961                // Override-only config
962                if provider_config.base_url.is_none()
963                    && provider_config.headers.is_none()
964                    && provider_config.compat.is_none()
965                    && !has_model_overrides
966                {
967                    return Err(format!(
968                        "Provider {}: must specify \"baseUrl\", \"headers\", \"compat\", \"modelOverrides\", or \"models\".",
969                        provider_name
970                    ));
971                }
972            } else if !is_built_in {
973                // Non-built-in providers with custom models require endpoint + auth
974                if provider_config.base_url.is_none() {
975                    return Err(format!(
976                        "Provider {}: \"baseUrl\" is required when defining custom models.",
977                        provider_name
978                    ));
979                }
980                if provider_config.api_key.is_none() {
981                    return Err(format!(
982                        "Provider {}: \"apiKey\" is required when defining custom models.",
983                        provider_name
984                    ));
985                }
986            }
987
988            for model_def in models {
989                let has_model_api = model_def.api.is_some();
990                let has_provider_api = provider_config.api.is_some();
991
992                if !has_provider_api && !has_model_api && !is_built_in {
993                    return Err(format!(
994                        "Provider {}, model {}: no \"api\" specified. Set at provider or model level.",
995                        provider_name, model_def.id
996                    ));
997                }
998
999                if model_def.context_window.is_some_and(|cw| cw == 0) {
1000                    return Err(format!(
1001                        "Provider {}, model {}: invalid contextWindow",
1002                        provider_name, model_def.id
1003                    ));
1004                }
1005                if model_def.max_tokens.is_some_and(|mt| mt == 0) {
1006                    return Err(format!(
1007                        "Provider {}, model {}: invalid maxTokens",
1008                        provider_name, model_def.id
1009                    ));
1010                }
1011            }
1012        }
1013
1014        Ok(())
1015    }
1016
1017    fn parse_models(&self, config: &ModelsConfig, built_in_providers: &[&str]) -> Vec<Model> {
1018        let mut models = Vec::new();
1019
1020        // Cache built-in defaults per provider
1021        let mut defaults_cache: HashMap<String, (Api, String)> = HashMap::new();
1022
1023        for (provider_name, provider_config) in &config.providers {
1024            let model_defs = match provider_config.models {
1025                Some(ref m) if !m.is_empty() => m,
1026                _ => continue,
1027            };
1028
1029            let is_built_in = built_in_providers.contains(&provider_name.as_str());
1030
1031            // Get built-in defaults (api, baseUrl) for this provider
1032            let built_in_defaults = if is_built_in {
1033                if !defaults_cache.contains_key(provider_name) {
1034                    let entries = model_db::get_provider_models(provider_name.as_str());
1035                    if let Some(first) = entries.first() {
1036                        defaults_cache.insert(
1037                            provider_name.clone(),
1038                            (
1039                                first.api,
1040                                self.default_base_url_for_provider(provider_name.as_str()),
1041                            ),
1042                        );
1043                    }
1044                }
1045                defaults_cache.get(provider_name)
1046            } else {
1047                None
1048            };
1049
1050            for model_def in model_defs {
1051                let api = model_def
1052                    .api
1053                    .or(provider_config.api)
1054                    .or(built_in_defaults.map(|(a, _)| *a));
1055
1056                let Some(api) = api else { continue };
1057
1058                let base_url = model_def
1059                    .base_url
1060                    .as_deref()
1061                    .or(provider_config.base_url.as_deref())
1062                    .or(built_in_defaults.map(|(_, u)| u.as_str()));
1063
1064                let Some(base_url) = base_url else { continue };
1065
1066                let compat =
1067                    merge_compat(provider_config.compat.as_ref(), model_def.compat.as_ref());
1068
1069                self.store_model_headers(provider_name, &model_def.id, model_def.headers.as_ref());
1070
1071                models.push(Model {
1072                    id: model_def.id.clone(),
1073                    name: model_def
1074                        .name
1075                        .clone()
1076                        .unwrap_or_else(|| model_def.id.clone()),
1077                    api,
1078                    provider: provider_name.clone(),
1079                    base_url: base_url.to_string(),
1080                    reasoning: model_def.reasoning.unwrap_or(false),
1081                    input: model_def
1082                        .input
1083                        .clone()
1084                        .unwrap_or_else(|| vec![InputModality::Text]),
1085                    cost: model_def.cost.clone().unwrap_or(Cost {
1086                        input: 0.0,
1087                        output: 0.0,
1088                        cache_read: 0.0,
1089                        cache_write: 0.0,
1090                    }),
1091                    context_window: model_def.context_window.unwrap_or(128_000),
1092                    max_tokens: model_def.max_tokens.unwrap_or(16_384),
1093                    headers: HashMap::new(),
1094                    compat,
1095                });
1096            }
1097        }
1098
1099        models
1100    }
1101
1102    // =========================================================================
1103    // Private: provider config helpers
1104    // =========================================================================
1105
1106    fn store_provider_request_config(&self, provider_name: &str, config: &ProviderConfig) {
1107        if config.api_key.is_none() && config.headers.is_none() && config.auth_header.is_none() {
1108            return;
1109        }
1110
1111        self.provider_request_configs.write().insert(
1112            provider_name.to_string(),
1113            ProviderRequestConfig {
1114                api_key: config.api_key.clone(),
1115                headers: config.headers.clone(),
1116                auth_header: config.auth_header.unwrap_or(false),
1117            },
1118        );
1119    }
1120
1121    fn store_model_headers(
1122        &self,
1123        provider_name: &str,
1124        model_id: &str,
1125        headers: Option<&HashMap<String, String>>,
1126    ) {
1127        let key = format!("{}:{}", provider_name, model_id);
1128        let mut hdr_map = self.model_request_headers.write();
1129
1130        match headers {
1131            Some(h) if !h.is_empty() => {
1132                hdr_map.insert(key, h.clone());
1133            }
1134            _ => {
1135                hdr_map.remove(&key);
1136            }
1137        }
1138    }
1139
1140    fn store_provider_request_config_from_input(
1141        &self,
1142        provider_name: &str,
1143        config: &ProviderConfigInput,
1144    ) {
1145        if config.api_key.is_none() && config.headers.is_none() && !config.auth_header {
1146            return;
1147        }
1148
1149        self.provider_request_configs.write().insert(
1150            provider_name.to_string(),
1151            ProviderRequestConfig {
1152                api_key: config.api_key.clone(),
1153                headers: config.headers.clone(),
1154                auth_header: config.auth_header,
1155            },
1156        );
1157    }
1158
1159    // =========================================================================
1160    // Private: API key resolution
1161    // =========================================================================
1162
1163    fn get_api_key_and_headers_impl(&self, model: &Model) -> ResolvedRequestAuth {
1164        let provider_config = self
1165            .provider_request_configs
1166            .read()
1167            .get(&model.provider)
1168            .cloned();
1169
1170        // Try auth storage
1171        let api_key_from_storage = self.auth_storage.get_api_key(&model.provider);
1172
1173        // Try provider config from models.json
1174        let api_key = match api_key_from_storage {
1175            Some(key) => Some(key),
1176            None => provider_config
1177                .as_ref()
1178                .and_then(|c| c.api_key.clone())
1179                .and_then(|raw| {
1180                    resolve_config_value_or_throw(
1181                        &raw,
1182                        &format!("API key for provider \"{}\"", model.provider),
1183                    )
1184                    .ok()
1185                }),
1186        };
1187
1188        // Resolve headers
1189        let provider_headers = resolve_headers(
1190            provider_config.as_ref().and_then(|c| c.headers.as_ref()),
1191            &format!("provider \"{}\"", model.provider),
1192        );
1193
1194        let model_headers_key = format!("{}:{}", model.provider, model.id);
1195        let model_headers_raw = self
1196            .model_request_headers
1197            .read()
1198            .get(&model_headers_key)
1199            .cloned();
1200        let model_headers = resolve_headers(
1201            model_headers_raw.as_ref(),
1202            &format!("model \"{}/{}\"", model.provider, model.id),
1203        );
1204
1205        // Merge headers: model.headers < provider_headers < model_headers
1206        let mut headers: HashMap<String, String> = HashMap::new();
1207        if !model.headers.is_empty() {
1208            headers.extend(model.headers.clone());
1209        }
1210        if let Ok(Some(ph)) = provider_headers {
1211            headers.extend(ph);
1212        }
1213        if let Ok(Some(mh)) = model_headers {
1214            headers.extend(mh);
1215        }
1216
1217        // If authHeader is set, add Authorization: Bearer
1218        if provider_config
1219            .as_ref()
1220            .map(|c| c.auth_header)
1221            .unwrap_or(false)
1222        {
1223            let Some(ref key) = api_key else {
1224                return ResolvedRequestAuth::err(format!(
1225                    "No API key found for \"{}\"",
1226                    model.provider
1227                ));
1228            };
1229            headers.insert("Authorization".to_string(), format!("Bearer {}", key));
1230        }
1231
1232        let headers = if headers.is_empty() {
1233            None
1234        } else {
1235            Some(headers)
1236        };
1237
1238        ResolvedRequestAuth::ok(api_key, headers)
1239    }
1240
1241    // =========================================================================
1242    // Private: provider registration
1243    // =========================================================================
1244
1245    fn apply_provider_config(&self, provider_name: &str, config: &ProviderConfigInput) {
1246        self.store_provider_request_config_from_input(provider_name, config);
1247
1248        if let Some(ref models) = config.models {
1249            if !models.is_empty() {
1250                // Full replacement: remove existing models for this provider
1251                let mut all_models = self.models.write();
1252                all_models.retain(|m| m.provider != provider_name);
1253
1254                for model_def in models {
1255                    let api = model_def.api.or(config.api);
1256                    let base_url = model_def
1257                        .base_url
1258                        .as_deref()
1259                        .or(config.base_url.as_deref())
1260                        .unwrap_or("");
1261
1262                    self.store_model_headers(
1263                        provider_name,
1264                        &model_def.id,
1265                        model_def.headers.as_ref(),
1266                    );
1267
1268                    all_models.push(Model {
1269                        id: model_def.id.clone(),
1270                        name: model_def
1271                            .name
1272                            .clone()
1273                            .unwrap_or_else(|| model_def.id.clone()),
1274                        api: api.unwrap_or(Api::OpenAiCompletions),
1275                        provider: provider_name.to_string(),
1276                        base_url: base_url.to_string(),
1277                        reasoning: model_def.reasoning.unwrap_or(false),
1278                        input: model_def
1279                            .input
1280                            .clone()
1281                            .unwrap_or_else(|| vec![InputModality::Text]),
1282                        cost: model_def.cost.clone().unwrap_or_default(),
1283                        context_window: model_def.context_window.unwrap_or(128_000),
1284                        max_tokens: model_def.max_tokens.unwrap_or(16_384),
1285                        headers: HashMap::new(),
1286                        compat: model_def.compat.clone(),
1287                    });
1288                }
1289            }
1290        } else if config.base_url.is_some() {
1291            // Override-only: update baseUrl for existing models
1292            let mut all_models = self.models.write();
1293            if let Some(ref base_url) = config.base_url {
1294                for m in all_models.iter_mut() {
1295                    if m.provider == provider_name {
1296                        m.base_url = base_url.clone();
1297                    }
1298                }
1299            }
1300        }
1301    }
1302
1303    fn upsert_registered_provider(&self, provider_name: &str, config: ProviderConfigInput) {
1304        let mut providers = self.registered_providers.write();
1305        match providers.get_mut(provider_name) {
1306            Some(existing) => {
1307                // Merge: defined values in incoming override, preserve undefined
1308                if config.name.is_some() {
1309                    existing.name = config.name.clone();
1310                }
1311                if config.base_url.is_some() {
1312                    existing.base_url = config.base_url.clone();
1313                }
1314                if config.api_key.is_some() {
1315                    existing.api_key = config.api_key.clone();
1316                }
1317                if config.api.is_some() {
1318                    existing.api = config.api;
1319                }
1320                if config.headers.is_some() {
1321                    existing.headers = config.headers.clone();
1322                }
1323                if config.auth_header {
1324                    existing.auth_header = config.auth_header;
1325                }
1326                if config.models.is_some() {
1327                    existing.models = config.models.clone();
1328                }
1329            }
1330            None => {
1331                providers.insert(provider_name.to_string(), config);
1332            }
1333        }
1334    }
1335
1336    // =========================================================================
1337    // Private: utilities
1338    // =========================================================================
1339
1340    /// Get the default base URL for a known provider.
1341    fn default_base_url_for_provider(&self, provider: &str) -> String {
1342        // Try model_db to find a model and extract its base URL
1343        // We don't store base URLs in model_db, so use a lookup table.
1344        match provider {
1345            "anthropic" => "https://api.anthropic.com".to_string(),
1346            "openai" => "https://api.openai.com/v1".to_string(),
1347            "google" => "https://generativelanguage.googleapis.com".to_string(),
1348            "google-vertex" => "https://us-central1-aiplatform.googleapis.com".to_string(),
1349            "deepseek" => "https://api.deepseek.com".to_string(),
1350            "mistral" => "https://api.mistral.ai".to_string(),
1351            "groq" => "https://api.groq.com/openai/v1".to_string(),
1352            "cerebras" => "https://api.cerebras.ai".to_string(),
1353            "xai" => "https://api.x.ai/v1".to_string(),
1354            "openrouter" => "https://openrouter.ai/api/v1".to_string(),
1355            "azure-openai-responses" => "https://{resource}.openai.azure.com".to_string(),
1356            "amazon-bedrock" => "https://bedrock-runtime.us-east-1.amazonaws.com".to_string(),
1357            _ => "".to_string(),
1358        }
1359    }
1360}
1361
1362// =============================================================================
1363// Provider config input (for dynamic registration from extensions)
1364// =============================================================================
1365
/// Input type for `register_provider` API (from extensions).
///
/// When a provider is registered again, `upsert_registered_provider` merges
/// field-by-field: only `Some(..)` values, and a `true` `auth_header`,
/// overwrite what was stored before.
#[derive(Debug, Clone, Default)]
pub struct ProviderConfigInput {
    /// Display name for the provider.
    pub name: Option<String>,
    /// Base URL for the provider's API.
    ///
    /// Used as the fallback for models that set no `base_url` of their own
    /// (an empty string is used if neither is set). When `models` is `None`,
    /// this URL is applied to every already-registered model of the provider.
    pub base_url: Option<String>,
    /// API key (may be an env var name or a literal key).
    ///
    /// Only consulted when `AuthStorage` has no key for the provider; the raw
    /// value is then passed to `resolve_config_value_or_throw`.
    /// NOTE(review): the exact reference syntax is defined by that helper — confirm.
    pub api_key: Option<String>,
    /// API protocol to use.
    ///
    /// Fallback for models that set no `api`; if neither is set the model is
    /// registered with `Api::OpenAiCompletions`.
    pub api: Option<Api>,
    /// Additional headers to send with every request.
    ///
    /// When request headers are resolved, these override `Model::headers` and
    /// are themselves overridden by per-model headers.
    pub headers: Option<HashMap<String, String>>,
    /// Whether to send `Authorization: Bearer <apiKey>` header.
    ///
    /// If set and no API key can be found, request-auth resolution reports an
    /// error instead of sending the request unauthenticated.
    pub auth_header: bool,
    /// Models provided by this provider.
    ///
    /// A non-empty list replaces every model currently registered for the provider.
    pub models: Option<Vec<ModelDefinition>>,
}
1384
1385// =============================================================================
1386// Tests
1387// =============================================================================
1388
1389#[cfg(test)]
1390mod tests {
1391    use super::*;
1392
1393    fn test_registry() -> ModelRegistry {
1394        ModelRegistry::in_memory(AuthStorage::in_memory())
1395    }
1396
1397    #[test]
1398    fn test_in_memory_registry() {
1399        let registry = test_registry();
1400        // No models loaded for in-memory
1401        assert!(registry.get_all().is_empty());
1402    }
1403
1404    #[test]
1405    fn test_get_all_providers() {
1406        let registry = ModelRegistry::create(AuthStorage::in_memory(), None);
1407        let providers = registry.get_available_providers();
1408        assert!(!providers.is_empty());
1409        assert!(providers.contains(&"anthropic".to_string()));
1410        assert!(providers.contains(&"openai".to_string()));
1411    }
1412
1413    #[test]
1414    fn test_resolve_model_by_provider_id() {
1415        let registry = ModelRegistry::create(AuthStorage::in_memory(), None);
1416        let model = registry.resolve_model("anthropic/claude-sonnet-4-20250514");
1417        assert!(model.is_some());
1418        let m = model.unwrap();
1419        assert_eq!(m.provider, "anthropic");
1420        assert_eq!(m.id, "claude-sonnet-4-20250514");
1421    }
1422
1423    #[test]
1424    fn test_resolve_model_by_id_only() {
1425        let registry = ModelRegistry::create(AuthStorage::in_memory(), None);
1426        let model = registry.resolve_model("claude-sonnet-4-20250514");
1427        assert!(model.is_some());
1428        assert_eq!(model.unwrap().id, "claude-sonnet-4-20250514");
1429    }
1430
1431    #[test]
1432    fn test_resolve_model_fuzzy() {
1433        let registry = ModelRegistry::create(AuthStorage::in_memory(), None);
1434        let model = registry.resolve_model("sonnet 4");
1435        assert!(model.is_some());
1436    }
1437
1438    #[test]
1439    fn test_find_model() {
1440        let registry = ModelRegistry::create(AuthStorage::in_memory(), None);
1441        let model = registry.find("anthropic", "claude-sonnet-4-20250514");
1442        assert!(model.is_some());
1443    }
1444
1445    #[test]
1446    fn test_find_model_not_found() {
1447        let registry = ModelRegistry::create(AuthStorage::in_memory(), None);
1448        let model = registry.find("nonexistent", "model");
1449        assert!(model.is_none());
1450    }
1451
1452    #[test]
1453    fn test_filter_by_provider() {
1454        let registry = ModelRegistry::create(AuthStorage::in_memory(), None);
1455        let anthropic = registry.filter_by_provider("anthropic");
1456        assert!(!anthropic.is_empty());
1457        assert!(anthropic.iter().all(|m| m.provider == "anthropic"));
1458    }
1459
1460    #[test]
1461    fn test_filter_by_capability() {
1462        let registry = ModelRegistry::create(AuthStorage::in_memory(), None);
1463        let reasoning = registry.filter_by_capability(true);
1464        assert!(!reasoning.is_empty());
1465        assert!(reasoning.iter().all(|m| m.reasoning));
1466    }
1467
1468    #[test]
1469    fn test_filter_by_modality() {
1470        let registry = ModelRegistry::create(AuthStorage::in_memory(), None);
1471        let vision = registry.filter_by_modality(InputModality::Image);
1472        assert!(!vision.is_empty());
1473    }
1474
1475    #[test]
1476    fn test_search_models() {
1477        let registry = ModelRegistry::create(AuthStorage::in_memory(), None);
1478        let results = registry.search("claude");
1479        assert!(!results.is_empty());
1480        assert!(results
1481            .iter()
1482            .all(|m| m.id.to_lowercase().contains("claude")
1483                || m.name.to_lowercase().contains("claude")));
1484    }
1485
1486    #[test]
1487    fn test_provider_display_name() {
1488        let registry = test_registry();
1489        assert_eq!(registry.get_provider_display_name("anthropic"), "Anthropic");
1490        assert_eq!(registry.get_provider_display_name("unknown"), "unknown");
1491    }
1492
1493    #[test]
1494    fn test_has_configured_auth_no_auth() {
1495        // Remove all known provider env keys to avoid race conditions with parallel tests
1496        for key in &[
1497            "ANTHROPIC_API_KEY",
1498            "OPENAI_API_KEY",
1499            "GOOGLE_API_KEY",
1500            "GEMINI_API_KEY",
1501            "GROQ_API_KEY",
1502            "MISTRAL_API_KEY",
1503            "DEEPSEEK_API_KEY",
1504            "XAI_API_KEY",
1505            "COHERE_API_KEY",
1506            "CO_API_KEY",
1507            "PERPLEXITY_API_KEY",
1508            "ZAI_API_KEY",
1509            "FIREWORKS_API_KEY",
1510            "OPENROUTER_API_KEY",
1511            "CEREBRAS_API_KEY",
1512            "KIMI_API_KEY",
1513            "MOONSHOT_API_KEY",
1514            "XIAOMI_API_KEY",
1515            "CLOUDFLARE_API_KEY",
1516            "MINIMAX_API_KEY",
1517            "MINIMAX_CN_API_KEY",
1518        ] {
1519            std::env::remove_var(key);
1520        }
1521        let registry = ModelRegistry::create(AuthStorage::in_memory(), None);
1522        let model = registry
1523            .find("anthropic", "claude-sonnet-4-20250514")
1524            .unwrap();
1525        // No auth configured in test
1526        assert!(!registry.has_configured_auth(&model));
1527    }
1528
1529    #[test]
1530    fn test_has_configured_auth_with_env() {
1531        let auth = AuthStorage::in_memory();
1532        auth.set_runtime_key("anthropic", "test-key".to_string());
1533        let registry = ModelRegistry::create(auth, None);
1534        let model = registry
1535            .find("anthropic", "claude-sonnet-4-20250514")
1536            .unwrap();
1537        assert!(registry.has_configured_auth(&model));
1538    }
1539
1540    #[test]
1541    fn test_get_api_key_and_headers_no_auth() {
1542        // Remove all known provider env keys to avoid race conditions with parallel tests
1543        for key in &[
1544            "ANTHROPIC_API_KEY",
1545            "OPENAI_API_KEY",
1546            "GOOGLE_API_KEY",
1547            "GEMINI_API_KEY",
1548            "GROQ_API_KEY",
1549            "MISTRAL_API_KEY",
1550            "DEEPSEEK_API_KEY",
1551            "XAI_API_KEY",
1552            "COHERE_API_KEY",
1553            "CO_API_KEY",
1554            "PERPLEXITY_API_KEY",
1555            "ZAI_API_KEY",
1556            "FIREWORKS_API_KEY",
1557            "OPENROUTER_API_KEY",
1558            "CEREBRAS_API_KEY",
1559            "KIMI_API_KEY",
1560            "MOONSHOT_API_KEY",
1561            "XIAOMI_API_KEY",
1562            "CLOUDFLARE_API_KEY",
1563            "MINIMAX_API_KEY",
1564            "MINIMAX_CN_API_KEY",
1565        ] {
1566            std::env::remove_var(key);
1567        }
1568        let registry = ModelRegistry::create(AuthStorage::in_memory(), None);
1569        let model = registry
1570            .find("anthropic", "claude-sonnet-4-20250514")
1571            .unwrap();
1572        let result = registry.get_api_key_and_headers(&model);
1573        assert!(result.ok);
1574        assert!(result.api_key.is_none());
1575    }
1576
1577    #[test]
1578    fn test_is_using_oauth_false() {
1579        let registry = test_registry();
1580        let model = Model {
1581            id: "test".to_string(),
1582            name: "Test".to_string(),
1583            api: Api::AnthropicMessages,
1584            provider: "anthropic".to_string(),
1585            base_url: "https://test.com".to_string(),
1586            reasoning: false,
1587            input: vec![InputModality::Text],
1588            cost: Cost::default(),
1589            context_window: 128_000,
1590            max_tokens: 8192,
1591            headers: HashMap::new(),
1592            compat: None,
1593        };
1594        assert!(!registry.is_using_oauth(&model));
1595    }
1596
1597    #[test]
1598    fn test_register_provider() {
1599        let registry = test_registry();
1600
1601        let config = ProviderConfigInput {
1602            name: Some("Test Provider".to_string()),
1603            base_url: Some("https://test.example.com/v1".to_string()),
1604            api_key: Some("test-api-key".to_string()),
1605            api: Some(Api::OpenAiCompletions),
1606            models: Some(vec![ModelDefinition {
1607                id: "test-model".to_string(),
1608                name: Some("Test Model".to_string()),
1609                api: None,
1610                base_url: None,
1611                reasoning: Some(false),
1612                thinking_level_map: None,
1613                input: Some(vec![InputModality::Text]),
1614                cost: None,
1615                context_window: Some(128_000),
1616                max_tokens: Some(8192),
1617                headers: None,
1618                compat: None,
1619            }]),
1620            ..Default::default()
1621        };
1622
1623        registry.register_provider("test-provider", config);
1624
1625        let model = registry.find("test-provider", "test-model");
1626        assert!(model.is_some());
1627        assert_eq!(model.unwrap().name, "Test Model");
1628    }
1629
1630    #[test]
1631    fn test_unregister_provider() {
1632        let registry = test_registry();
1633
1634        let config = ProviderConfigInput {
1635            base_url: Some("https://test.example.com/v1".to_string()),
1636            api_key: Some("test-api-key".to_string()),
1637            api: Some(Api::OpenAiCompletions),
1638            models: Some(vec![ModelDefinition {
1639                id: "test-model".to_string(),
1640                name: Some("Test Model".to_string()),
1641                api: None,
1642                base_url: None,
1643                reasoning: None,
1644                thinking_level_map: None,
1645                input: None,
1646                cost: None,
1647                context_window: None,
1648                max_tokens: None,
1649                headers: None,
1650                compat: None,
1651            }]),
1652            ..Default::default()
1653        };
1654
1655        registry.register_provider("test-provider", config);
1656        assert!(registry.find("test-provider", "test-model").is_some());
1657
1658        registry.unregister_provider("test-provider");
1659        assert!(registry.find("test-provider", "test-model").is_none());
1660    }
1661
1662    #[test]
1663    fn test_get_default_model() {
1664        // Clear all known provider API keys to avoid interference from environment
1665        for key in &[
1666            "ANTHROPIC_API_KEY",
1667            "OPENAI_API_KEY",
1668            "GOOGLE_API_KEY",
1669            "GEMINI_API_KEY",
1670            "GROQ_API_KEY",
1671            "MISTRAL_API_KEY",
1672            "DEEPSEEK_API_KEY",
1673            "XAI_API_KEY",
1674            "COHERE_API_KEY",
1675            "CO_API_KEY",
1676            "PERPLEXITY_API_KEY",
1677            "MINIMAX_API_KEY",
1678            "ZAI_API_KEY",
1679            "FIREWORKS_API_KEY",
1680            "OPENROUTER_API_KEY",
1681            "CEREBRAS_API_KEY",
1682            "KIMI_API_KEY",
1683            "MOONSHOT_API_KEY",
1684            "XIAOMI_API_KEY",
1685            "CLOUDFLARE_API_KEY",
1686            "CLOUDFLARE_AI_GATEWAY_API_KEY",
1687            "AI_GATEWAY_API_KEY",
1688            "AZURE_OPENAI_API_KEY",
1689            "GOOGLE_CLOUD_API_KEY",
1690            "MINIMAX_CN_API_KEY",
1691        ] {
1692            std::env::remove_var(key);
1693        }
1694
1695        // Use runtime key (not env var) to avoid parallel test interference
1696        let auth = AuthStorage::in_memory();
1697        auth.set_runtime_key("anthropic", "test-key".to_string());
1698        let registry = ModelRegistry::create(auth, None);
1699        let model = registry.get_default_model();
1700        assert!(model.is_some());
1701        assert_eq!(model.unwrap().provider, "anthropic");
1702    }
1703
1704    #[test]
1705    fn test_apply_model_override() {
1706        let base = Model {
1707            id: "test".to_string(),
1708            name: "Test Model".to_string(),
1709            api: Api::OpenAiCompletions,
1710            provider: "openai".to_string(),
1711            base_url: "https://api.openai.com/v1".to_string(),
1712            reasoning: false,
1713            input: vec![InputModality::Text, InputModality::Image],
1714            cost: Cost {
1715                input: 2.5,
1716                output: 10.0,
1717                cache_read: 1.25,
1718                cache_write: 0.0,
1719            },
1720            context_window: 128_000,
1721            max_tokens: 16_384,
1722            headers: HashMap::new(),
1723            compat: None,
1724        };
1725
1726        let override_def = ModelOverride {
1727            name: Some("Overridden Name".to_string()),
1728            reasoning: Some(true),
1729            cost: Some(PartialCost {
1730                input: Some(5.0),
1731                ..Default::default()
1732            }),
1733            ..Default::default()
1734        };
1735
1736        let result = apply_model_override(&base, &override_def);
1737        assert_eq!(result.name, "Overridden Name");
1738        assert!(result.reasoning);
1739        assert_eq!(result.cost.input, 5.0);
1740        assert_eq!(result.cost.output, 10.0); // Preserved from base
1741    }
1742
1743    #[test]
1744    fn test_load_custom_models_file_not_found() {
1745        let registry = test_registry();
1746        let result = registry.load_custom_models(Path::new("/nonexistent/models.json"));
1747        assert!(result.error.is_none());
1748        assert!(result.models.is_empty());
1749    }
1750
1751    #[test]
1752    fn test_resolve_config_value_env() {
1753        // Test env var expansion ($ prefix)
1754        // Use a unique var name to avoid parallel test interference
1755        let var_name = format!("OXI_TEST_KEY_DOLLAR_{}", std::process::id());
1756        std::env::set_var(&var_name, "test-value-123");
1757        let result = resolve_config_value(&format!("${}", var_name));
1758        assert_eq!(result, Some("test-value-123".to_string()));
1759        std::env::remove_var(&var_name);
1760    }
1761
1762    #[test]
1763    fn test_resolve_config_value_env_braces() {
1764        // Test env var expansion with ${VAR} syntax
1765        let var_name = format!("OXI_TEST_KEY_{}", std::process::id());
1766        std::env::set_var(&var_name, "test-value-456");
1767        let result = resolve_config_value(&format!("${{{}}}", var_name));
1768        assert_eq!(result, Some("test-value-456".to_string()));
1769        std::env::remove_var(&var_name);
1770    }
1771
1772    #[test]
1773    fn test_resolve_config_value_command_rejected() {
1774        // ! prefix is now rejected for security
1775        let result = resolve_config_value("!echo hello");
1776        assert!(result.is_none());
1777    }
1778
1779    #[test]
1780    fn test_merge_compat_none_none() {
1781        assert!(merge_compat(None, None).is_none());
1782    }
1783
1784    #[test]
1785    fn test_merge_compat_some_none() {
1786        let base = CompatSettings::default();
1787        let result = merge_compat(Some(&base), None);
1788        assert!(result.is_some());
1789    }
1790}