Skip to main content

oxi_store/
model_resolver.rs

1//! Model name resolution and matching
2//!
3//! Provides utilities for parsing model patterns, resolving model names,
4//! finding models by glob, and determining the best model for startup.
5//!
6//! This module integrates with:
7//! - oxi_ai::model_db::ModelEntry for cost/modality info
8//! - oxi_ai::model_registry for runtime model lookup
9//! - auth_storage for auth validation
10
11use crate::settings::Settings;
12use std::collections::HashMap;
13use std::sync::LazyLock;
14
/// Cached regex matching a `-` followed by exactly eight digits at the very end
/// of a string, i.e. a trailing `-YYYYMMDD` date suffix.
///
/// Compiled lazily on first use. The pattern is a fixed literal, so the
/// `expect` can only fire on a programming error.
static DATE_PATTERN_RE: LazyLock<regex::Regex> =
    LazyLock::new(|| regex::Regex::new(r"-\d{8}$").expect("date pattern regex should compile"));
18
/// Cached regex matching a `-` followed by eight digits (a `-YYYYMMDD` date
/// fragment) anywhere in a string.
///
/// Unlike `DATE_PATTERN_RE` this pattern is not anchored, so a match is not
/// necessarily a suffix; callers that only want a trailing date must check the
/// match position themselves.
static DATE_PATTERN_STRIP_RE: LazyLock<regex::Regex> = LazyLock::new(|| {
    regex::Regex::new(r"-\d{8}").expect("date pattern strip regex should compile")
});
23
24// =============================================================================
25// Constants
26// =============================================================================
27
/// Thinking level used when neither the CLI nor the caller specifies one.
///
/// One of the entries of [`THINKING_LEVELS`].
pub const DEFAULT_THINKING_LEVEL: &str = "medium";
30
/// Recognised thinking levels, ordered from least to most intensive.
///
/// The order is significant. `clamp_thinking_level` walks this slice downward
/// from the requested level to find a supported fallback. `parse_model_pattern`
/// only treats a `:<suffix>` as a thinking level when the suffix appears here
/// (exact, case-sensitive match).
pub const THINKING_LEVELS: &[&str] = &["off", "minimal", "low", "medium", "high", "xhigh"];
33
34// =============================================================================
35// Model Types
36// =============================================================================
37
/// An AI provider known to the resolver (e.g. Anthropic, OpenAI).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Provider {
    /// Unique provider identifier (e.g. "anthropic", "openai").
    pub id: String,
    /// Human-readable provider name.
    pub name: String,
    /// Optional provider website URL.
    pub website: Option<String>,
}

impl Provider {
    /// Build a provider from its identifier and display name, with no website set.
    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
        let (id, name) = (id.into(), name.into());
        Self {
            id,
            name,
            website: None,
        }
    }

    /// Consume this provider and return it with `website` recorded (builder-style).
    pub fn with_website(self, website: impl Into<String>) -> Self {
        Self {
            website: Some(website.into()),
            ..self
        }
    }
}
65
/// A discovered model: identity plus optional capability and pricing metadata.
///
/// Values are built from the static model database (`Model::from_entry`), from
/// the runtime registry (`Model::from_registry_model`), or synthesised for
/// custom model ids, in which case most optional fields are `None` or empty.
/// Two models are considered the same when `provider` and `id` match (see
/// `models_are_equal`).
#[derive(Debug, Clone)]
pub struct Model {
    /// Provider that hosts this model.
    pub provider: String,
    /// Model identifier (e.g. "claude-sonnet-4-5").
    pub id: String,
    /// Human-readable model name.
    pub name: Option<String>,
    /// Optional model description.
    pub description: Option<String>,
    /// Maximum context window size in tokens.
    pub context_window: Option<u32>,
    /// Features the model supports (e.g. "tools", "vision").
    pub supported_features: Vec<String>,
    // Pricing, carried over from ModelEntry / registry models.
    /// Cost per million input tokens.
    pub cost_input: Option<f64>,
    /// Cost per million output tokens.
    pub cost_output: Option<f64>,
    /// Cost per million cache-read tokens.
    pub cost_cache_read: Option<f64>,
    /// Cost per million cache-write tokens.
    pub cost_cache_write: Option<f64>,
    // Input modalities, carried over from ModelEntry / registry models.
    /// Supported input modalities as lowercase names (e.g. "text", "image").
    pub input_modalities: Vec<String>,
}
94
95impl Model {
96    /// Get the full model identifier (provider/model_id)
97    pub fn full_id(&self) -> String {
98        format!("{}/{}", self.provider, self.id)
99    }
100
101    /// Create from a model_db::ModelEntry
102    pub fn from_entry(entry: &oxi_ai::ModelEntry) -> Self {
103        Self {
104            provider: entry.provider.to_string(),
105            id: entry.id.to_string(),
106            name: Some(entry.name.to_string()),
107            description: None,
108            context_window: Some(entry.context_window),
109            supported_features: vec![],
110            cost_input: Some(entry.cost_input),
111            cost_output: Some(entry.cost_output),
112            cost_cache_read: Some(entry.cost_cache_read),
113            cost_cache_write: Some(entry.cost_cache_write),
114            input_modalities: entry
115                .input
116                .iter()
117                .map(|m| format!("{:?}", m).to_lowercase())
118                .collect(),
119        }
120    }
121
122    /// Create from a model_registry::Model
123    pub fn from_registry_model(model: &oxi_ai::Model) -> Self {
124        Self {
125            provider: model.provider.clone(),
126            id: model.id.clone(),
127            name: Some(model.name.clone()),
128            description: None,
129            context_window: Some(model.context_window as u32),
130            supported_features: vec![],
131            cost_input: Some(model.cost.input),
132            cost_output: Some(model.cost.output),
133            cost_cache_read: Some(model.cost.cache_read),
134            cost_cache_write: Some(model.cost.cache_write),
135            input_modalities: model
136                .input
137                .iter()
138                .map(|m| format!("{:?}", m).to_lowercase())
139                .collect(),
140        }
141    }
142}
143
/// Result of [`parse_model_pattern`].
#[derive(Debug)]
pub struct ParsedModelResult {
    /// Provider of the matched model; `None` when the pattern matched nothing
    /// known and is treated as a custom model id.
    pub provider: Option<String>,
    /// Model id resolved from the pattern (or the unresolved pattern text when
    /// no model matched).
    pub model_id: String,
    /// Thinking level taken from a `:<level>` suffix (e.g. "high"), if present.
    pub thinking_level: Option<String>,
    /// Set when the match was ambiguous, the pattern was empty, or no model matched.
    pub warning: Option<String>,
}
156
/// Result of [`resolve_cli_model`].
#[derive(Debug)]
pub struct ResolveCliModelResult {
    /// The resolved model. This may be a synthesised custom model when the id is
    /// unknown but its provider is; it is `None` when nothing was requested or
    /// resolution failed.
    pub model: Option<Model>,
    /// Thinking level requested via a `:<level>` suffix.
    pub thinking_level: Option<String>,
    /// Warning forwarded from pattern parsing (e.g. ambiguous match).
    pub warning: Option<String>,
    /// Error message when no model could be resolved.
    pub error: Option<String>,
}
169
/// Result of [`find_initial_model`].
#[derive(Debug)]
pub struct InitialModelResult {
    /// The selected model; `None` only when no model is available at all.
    pub model: Option<Model>,
    /// Thinking level to start the session with.
    pub thinking_level: String,
    /// Optional human-readable note about the selection (e.g. no models available).
    pub fallback_message: Option<String>,
}
180
/// Result of [`restore_model_from_session`].
#[derive(Debug)]
pub struct RestoreModelResult {
    /// The restored model, or the fallback chosen instead; `None` if no model exists.
    pub model: Option<Model>,
    /// Human-readable explanation when the saved model could not be restored.
    pub fallback_message: Option<String>,
    /// Machine-readable reason code ("no_auth", "model_not_found", "no_models");
    /// `None` on a successful restore.
    pub reason: Option<String>,
}
191
192// =============================================================================
193// Core Functions
194// =============================================================================
195
196/// Check if two models are equal (same provider and id)
197///
198/// Used for deduplication in model scope resolution.
199pub fn models_are_equal(a: &Model, b: &Model) -> bool {
200    a.provider == b.provider && a.id == b.id
201}
202
203/// Check if a model ID looks like an alias (no date suffix)
204fn is_alias(id: &str) -> bool {
205    // Aliases end with -latest or don't have date patterns
206    if id.ends_with("-latest") {
207        return true;
208    }
209    // Check if ends with date pattern (-YYYYMMDD)
210    !DATE_PATTERN_RE.is_match(id)
211}
212
213/// Match a glob pattern against text
214///
215/// Supports:
216/// - `*` - matches any characters
217/// - `?` - matches any single character
218/// - `[abc]` - matches any character in the set
219///
220/// Note: Matching is case-insensitive.
221pub fn match_glob(pattern: &str, text: &str) -> bool {
222    // Simple implementation: convert glob to regex
223    let mut regex_pattern = String::new();
224    let mut in_class = false;
225    let chars = pattern.chars().peekable();
226
227    for c in chars {
228        match c {
229            '*' => {
230                if in_class {
231                    regex_pattern.push_str("\\*");
232                } else {
233                    regex_pattern.push_str(".*");
234                }
235            }
236            '?' => {
237                if in_class {
238                    regex_pattern.push('?');
239                } else {
240                    regex_pattern.push('.');
241                }
242            }
243            '[' => {
244                in_class = true;
245                regex_pattern.push('[');
246            }
247            ']' => {
248                in_class = false;
249                regex_pattern.push(']');
250            }
251            '.' | '+' | '^' | '$' | '\\' | '(' | ')' | '{' | '}' | '|' => {
252                // Escape regex special characters (but not those in char classes)
253                if !in_class {
254                    regex_pattern.push('\\');
255                }
256                regex_pattern.push(c);
257            }
258            _ => regex_pattern.push(c),
259        }
260    }
261
262    // Handle trailing ** in patterns
263    if pattern.ends_with("**") {
264        regex_pattern.push_str(".*");
265    }
266
267    // Use case-insensitive regex matching
268    regex::RegexBuilder::new(&format!("^{}$", regex_pattern))
269        .case_insensitive(true)
270        .build()
271        .map(|re| re.is_match(text))
272        .unwrap_or_else(|_| pattern.eq_ignore_ascii_case(text))
273}
274
275/// Find all models matching a provider and glob pattern
276pub fn find_models_by_glob<'a>(
277    provider: &str,
278    pattern: &str,
279    models: &'a [Model],
280) -> Vec<&'a Model> {
281    models
282        .iter()
283        .filter(|m| m.provider == provider && match_glob(pattern, &m.id))
284        .collect()
285}
286
287/// Find all models matching a pattern (glob or substring) from the full model database
288pub fn find_models_by_pattern(pattern: &str, models: &[Model]) -> Vec<Model> {
289    let pattern_lower = pattern.to_lowercase();
290    models
291        .iter()
292        .filter(|m| {
293            m.id.to_lowercase().contains(&pattern_lower)
294                || m.full_id().to_lowercase().contains(&pattern_lower)
295                || m.name
296                    .as_ref()
297                    .map(|n| n.to_lowercase().contains(&pattern_lower))
298                    .unwrap_or(false)
299        })
300        .cloned()
301        .collect()
302}
303
304/// Get the thinking level mapping for a model that has thinking variants
305///
306/// For models like claude-3-5-sonnet-*, returns a mapping:
307/// - "high" -> "claude-3-5-sonnet-20240620"
308/// - "medium" -> "claude-3-5-sonnet-latest"
309pub fn get_thinking_level_map(model_id: &str) -> Option<HashMap<String, String>> {
310    // Strip common suffixes to find base model name
311    let base = if let Some(stripped) = model_id.strip_suffix("-latest") {
312        stripped
313    } else if let Some(dated) = model_id.strip_suffix(DATE_PATTERN_STRIP_RE.as_str()) {
314        dated
315    } else {
316        // Also try stripping numbered suffixes like -5, -6, etc.
317        // Match patterns like claude-opus-4-5, claude-sonnet-4-6, etc.
318        model_id
319    };
320
321    // Only certain models have thinking variants
322    let thinking_models = [
323        (
324            "claude-3-5-sonnet",
325            vec![
326                ("high", "claude-3-5-sonnet-20240620"),
327                ("medium", "claude-3-5-sonnet-latest"),
328                ("low", "claude-3-5-sonnet-latest"),
329            ],
330        ),
331        (
332            "claude-3-7-sonnet",
333            vec![
334                ("high", "claude-3-7-sonnet-20250219"),
335                ("medium", "claude-3-7-sonnet-20250219"),
336                ("low", "claude-3-7-sonnet-latest"),
337            ],
338        ),
339        (
340            "claude-opus-4",
341            vec![
342                ("high", "claude-opus-4-5-20251101"),
343                ("medium", "claude-opus-4-5"),
344                ("low", "claude-opus-4-1"),
345                ("off", "claude-opus-4-0"),
346            ],
347        ),
348        (
349            "claude-sonnet-4",
350            vec![
351                ("high", "claude-sonnet-4-20250514"),
352                ("medium", "claude-sonnet-4-5"),
353                ("low", "claude-sonnet-4-0"),
354                ("off", "claude-sonnet-4-0"),
355            ],
356        ),
357    ];
358
359    for (base_name, levels) in thinking_models {
360        if base.contains(base_name) {
361            let mut map = HashMap::new();
362            for (level, id) in levels {
363                map.insert(level.to_string(), id.to_string());
364            }
365            return Some(map);
366        }
367    }
368
369    None
370}
371
372/// Clamp a requested thinking level to what the model supports
373///
374/// If the model doesn't support the requested level, returns the closest
375/// supported level.
376pub fn clamp_thinking_level(model_id: &str, requested_level: &str) -> String {
377    if let Some(map) = get_thinking_level_map(model_id) {
378        // If the requested level is in the map, use it
379        if map.contains_key(requested_level) {
380            return map
381                .get(requested_level)
382                .cloned()
383                .unwrap_or_else(|| model_id.to_string());
384        }
385
386        // Find the closest level that IS in the map
387        let level_idx = THINKING_LEVELS.iter().position(|&l| l == requested_level);
388        if let Some(idx) = level_idx {
389            // Search downward for the closest supported level
390            for i in (0..idx).rev() {
391                let level = THINKING_LEVELS[i];
392                if map.contains_key(level) {
393                    return map
394                        .get(level)
395                        .cloned()
396                        .unwrap_or_else(|| model_id.to_string());
397                }
398            }
399        }
400    }
401
402    // Fallback: return the requested level as-is if model has no special mapping
403    requested_level.to_string()
404}
405
406/// Check if auth is configured for a provider/model
407///
408/// Checks stored credentials in auth.json. Does NOT check environment variables.
409/// Use `oxi setup` to configure credentials persistently.
410pub fn has_configured_auth(provider: &str, _model: &Model) -> bool {
411    // Check auth storage (auth.json)
412    let auth = crate::auth_storage::shared_auth_storage();
413    auth.has_auth(provider)
414}
415
416/// Parse a model pattern into components
417///
418/// # Arguments
419/// * `pattern` - The model pattern (e.g., "anthropic/claude-3.5-sonnet" or "sonnet:high")
420/// * `available_models` - List of available models for validation
421///
422/// # Returns
423/// A parsed model result with provider, model_id, and optional thinking level
424pub fn parse_model_pattern(pattern: &str, available_models: &[Model]) -> ParsedModelResult {
425    let pattern = pattern.trim();
426    if pattern.is_empty() {
427        return ParsedModelResult {
428            provider: None,
429            model_id: String::new(),
430            thinking_level: None,
431            warning: Some("Empty model pattern".to_string()),
432        };
433    }
434
435    // Check for thinking level suffix (e.g., "sonnet:high")
436    let last_colon = pattern.rfind(':');
437    let (base_pattern, thinking_level) = if let Some(idx) = last_colon {
438        let suffix = &pattern[idx + 1..];
439        if THINKING_LEVELS.contains(&suffix) {
440            (&pattern[..idx], Some(suffix.to_string()))
441        } else {
442            (pattern, None)
443        }
444    } else {
445        (pattern, None)
446    };
447
448    // Try to find an exact match first
449    let exact_match = available_models.iter().find(|m| {
450        m.id.eq_ignore_ascii_case(base_pattern) || m.full_id().eq_ignore_ascii_case(base_pattern)
451    });
452
453    if let Some(model) = exact_match {
454        return ParsedModelResult {
455            provider: Some(model.provider.clone()),
456            model_id: model.id.clone(),
457            thinking_level,
458            warning: None,
459        };
460    }
461
462    // Try to parse provider/model format
463    if let Some(slash_idx) = base_pattern.find('/') {
464        let provider = &base_pattern[..slash_idx];
465        let model_id = &base_pattern[slash_idx + 1..];
466
467        // Check if provider exists in available models
468        let provider_exists = available_models
469            .iter()
470            .any(|m| m.provider.eq_ignore_ascii_case(provider));
471
472        if provider_exists {
473            return ParsedModelResult {
474                provider: Some(provider.to_string()),
475                model_id: model_id.to_string(),
476                thinking_level,
477                warning: None,
478            };
479        }
480    }
481
482    // Try partial matching
483    let partial_matches: Vec<&Model> = available_models
484        .iter()
485        .filter(|m| {
486            m.id.to_lowercase().contains(&base_pattern.to_lowercase())
487                || m.name
488                    .as_ref()
489                    .map(|n| n.to_lowercase().contains(&base_pattern.to_lowercase()))
490                    .unwrap_or(false)
491        })
492        .collect();
493
494    if partial_matches.len() == 1 {
495        let model = partial_matches[0];
496        return ParsedModelResult {
497            provider: Some(model.provider.clone()),
498            model_id: model.id.clone(),
499            thinking_level,
500            warning: None,
501        };
502    } else if partial_matches.len() > 1 {
503        // Prefer aliases over dated versions
504        let aliases: Vec<_> = partial_matches.iter().filter(|m| is_alias(&m.id)).collect();
505        if !aliases.is_empty() {
506            let model = aliases[0];
507            return ParsedModelResult {
508                provider: Some(model.provider.clone()),
509                model_id: model.id.clone(),
510                thinking_level,
511                warning: Some(format!(
512                    "Multiple models match '{}', selected '{}'",
513                    base_pattern,
514                    model.full_id()
515                )),
516            };
517        }
518        // Use the latest dated version (sort descending)
519        let mut sorted = partial_matches.to_vec();
520        sorted.sort_by(|a, b| b.id.cmp(&a.id));
521        let model = sorted[0];
522        return ParsedModelResult {
523            provider: Some(model.provider.clone()),
524            model_id: model.id.clone(),
525            thinking_level,
526            warning: Some(format!(
527                "Multiple models match '{}', selected '{}'",
528                base_pattern,
529                model.full_id()
530            )),
531        };
532    }
533
534    // No match found - return as raw pattern
535    ParsedModelResult {
536        provider: None,
537        model_id: pattern.to_string(),
538        thinking_level,
539        warning: Some(format!(
540            "Model '{}' not found in available models. Treating as custom model ID.",
541            pattern
542        )),
543    }
544}
545
/// Default model id to prefer for each known provider, keyed by provider id.
///
/// Returns a freshly built map on every call. Keys are lowercase provider ids.
/// Iteration order of the returned `HashMap` is unspecified.
pub fn default_model_per_provider() -> HashMap<String, String> {
    const DEFAULTS: [(&str, &str); 11] = [
        ("anthropic", "claude-sonnet-4-5"),
        ("openai", "gpt-4o"),
        ("google", "gemini-2.5-pro"),
        ("deepseek", "deepseek-v3"),
        ("openrouter", "anthropic/claude-sonnet-4"),
        ("groq", "mixtral-8x7b"),
        ("cerebras", "llama-3.3-70b"),
        ("mistral", "mistral-large"),
        ("xai", "grok-2"),
        ("amazon-bedrock", "anthropic.claude-v2"),
        ("azure-openai", "gpt-4o"),
    ];

    DEFAULTS
        .iter()
        .map(|&(provider, model)| (provider.to_owned(), model.to_owned()))
        .collect()
}
568
569/// Resolve a model from CLI arguments
570pub fn resolve_cli_model(
571    cli_provider: Option<&str>,
572    cli_model: Option<&str>,
573    available_models: &[Model],
574    _settings: Option<&Settings>,
575) -> ResolveCliModelResult {
576    let cli_model = match cli_model {
577        Some(m) => m,
578        None => {
579            return ResolveCliModelResult {
580                model: None,
581                thinking_level: None,
582                warning: None,
583                error: None,
584            };
585        }
586    };
587
588    // Build provider map for case-insensitive lookup
589    let mut provider_map: HashMap<String, String> = HashMap::new();
590    for model in available_models {
591        provider_map.insert(model.provider.to_lowercase(), model.provider.clone());
592    }
593
594    // Try to resolve provider
595    let provider = if let Some(p) = cli_provider {
596        provider_map.get(&p.to_lowercase()).cloned()
597    } else if let Some(slash_idx) = cli_model.find('/') {
598        let maybe_provider = &cli_model[..slash_idx];
599        provider_map.get(&maybe_provider.to_lowercase()).cloned()
600    } else {
601        None
602    };
603
604    // Extract the model pattern
605    let model_pattern = if let Some(ref p) = provider {
606        if cli_model
607            .to_lowercase()
608            .starts_with(&format!("{}/", p.to_lowercase()))
609        {
610            &cli_model[p.len() + 1..]
611        } else {
612            cli_model
613        }
614    } else {
615        cli_model
616    };
617
618    // Parse the pattern
619    let parsed = parse_model_pattern(model_pattern, available_models);
620
621    // Find the model
622    let model = if let Some(ref p) = provider {
623        available_models
624            .iter()
625            .find(|m| {
626                m.provider.eq_ignore_ascii_case(p) && m.id.eq_ignore_ascii_case(&parsed.model_id)
627            })
628            .cloned()
629    } else if let Some(ref p) = parsed.provider {
630        available_models
631            .iter()
632            .find(|m| {
633                m.provider.eq_ignore_ascii_case(p) && m.id.eq_ignore_ascii_case(&parsed.model_id)
634            })
635            .cloned()
636    } else {
637        // Try matching without provider
638        available_models
639            .iter()
640            .find(|m| m.id.eq_ignore_ascii_case(&parsed.model_id))
641            .cloned()
642    };
643
644    if let Some(ref m) = model {
645        ResolveCliModelResult {
646            model: Some(m.clone()),
647            thinking_level: parsed.thinking_level,
648            warning: parsed.warning,
649            error: None,
650        }
651    } else {
652        // Try building a fallback custom model
653        let fallback_model = if let Some(ref p) = provider {
654            Some(Model {
655                provider: p.clone(),
656                id: parsed.model_id.clone(),
657                name: Some(parsed.model_id.clone()),
658                description: None,
659                context_window: None,
660                supported_features: vec![],
661                cost_input: None,
662                cost_output: None,
663                cost_cache_read: None,
664                cost_cache_write: None,
665                input_modalities: vec!["text".to_string()],
666            })
667        } else {
668            None
669        };
670
671        ResolveCliModelResult {
672            model: fallback_model.clone(),
673            thinking_level: parsed.thinking_level,
674            warning: parsed.warning,
675            error: fallback_model.is_none().then(|| {
676                format!(
677                    "Model '{}' not found. Use --list-models to see available models.",
678                    cli_model
679                )
680            }),
681        }
682    }
683}
684
685/// Find the initial model to use based on priority:
686/// 1. CLI args
687/// 2. First model from scoped models
688/// 3. Saved default from settings
689/// 4. First available model
690pub fn find_initial_model(
691    cli_provider: Option<&str>,
692    cli_model: Option<&str>,
693    scoped_models: &[Model],
694    is_continuing: bool,
695    settings: Option<&Settings>,
696    available_models: &[Model],
697) -> InitialModelResult {
698    // 1. CLI args take priority
699    if cli_provider.is_some() || cli_model.is_some() {
700        let result = resolve_cli_model(cli_provider, cli_model, available_models, settings);
701        if result.error.is_none() {
702            return InitialModelResult {
703                model: result.model,
704                thinking_level: result
705                    .thinking_level
706                    .unwrap_or_else(|| DEFAULT_THINKING_LEVEL.to_string()),
707                fallback_message: None,
708            };
709        }
710    }
711
712    // 2. Use first model from scoped models (skip if continuing)
713    if !scoped_models.is_empty() && !is_continuing {
714        return InitialModelResult {
715            model: Some(scoped_models[0].clone()),
716            thinking_level: DEFAULT_THINKING_LEVEL.to_string(),
717            fallback_message: None,
718        };
719    }
720
721    // 3. Try saved default from settings
722    if let Some(s) = settings {
723        if let Some(default_model) = &s.default_model {
724            let parsed = parse_model_pattern(default_model, available_models);
725            if let Some(ref p) = parsed.provider {
726                let model = available_models
727                    .iter()
728                    .find(|m| {
729                        m.provider.eq_ignore_ascii_case(p)
730                            && m.id.eq_ignore_ascii_case(&parsed.model_id)
731                    })
732                    .cloned();
733                if model.is_some() {
734                    return InitialModelResult {
735                        model,
736                        thinking_level: format!("{:?}", s.thinking_level),
737                        fallback_message: None,
738                    };
739                }
740            }
741        }
742    }
743
744    // 4. Try default models from known providers
745    let defaults = default_model_per_provider();
746    for (provider, default_id) in &defaults {
747        if let Some(model) = available_models.iter().find(|m| {
748            m.provider.eq_ignore_ascii_case(provider) && m.id.eq_ignore_ascii_case(default_id)
749        }) {
750            return InitialModelResult {
751                model: Some(model.clone()),
752                thinking_level: DEFAULT_THINKING_LEVEL.to_string(),
753                fallback_message: None,
754            };
755        }
756    }
757
758    // 5. Use first available model
759    if let Some(model) = available_models.first() {
760        return InitialModelResult {
761            model: Some(model.clone()),
762            thinking_level: DEFAULT_THINKING_LEVEL.to_string(),
763            fallback_message: None,
764        };
765    }
766
767    // No model found
768    InitialModelResult {
769        model: None,
770        thinking_level: DEFAULT_THINKING_LEVEL.to_string(),
771        fallback_message: Some("No models available. Check your installation.".to_string()),
772    }
773}
774
775/// Restore model from session with fallback and auth validation
776///
777/// This function now properly checks auth configuration before restoring a model.
778pub fn restore_model_from_session(
779    saved_provider: &str,
780    saved_model_id: &str,
781    current_model: Option<&Model>,
782    should_print_messages: bool,
783    available_models: &[Model],
784) -> RestoreModelResult {
785    let restored = available_models
786        .iter()
787        .find(|m| {
788            m.provider.eq_ignore_ascii_case(saved_provider)
789                && m.id.eq_ignore_ascii_case(saved_model_id)
790        })
791        .cloned();
792
793    match (&restored, current_model) {
794        (Some(ref model), _) => {
795            // Check if the model has auth configured
796            if has_configured_auth(saved_provider, model) {
797                if should_print_messages {
798                    eprintln!("Restored model: {}/{}", saved_provider, saved_model_id);
799                }
800                RestoreModelResult {
801                    model: Some((*model).clone()),
802                    fallback_message: None,
803                    reason: None,
804                }
805            } else {
806                // Model exists but no auth - try fallback
807                if should_print_messages {
808                    eprintln!(
809                        "Warning: Could not restore model {}/{} (no auth configured).",
810                        saved_provider, saved_model_id
811                    );
812                }
813
814                if let Some(current) = current_model {
815                    if should_print_messages {
816                        eprintln!("Falling back to: {}/{}", current.provider, current.id);
817                    }
818                    RestoreModelResult {
819                        model: Some((*current).clone()),
820                        fallback_message: Some(format!(
821                            "Could not restore model {}/{} (no auth configured). Using current model.",
822                            saved_provider, saved_model_id
823                        )),
824                        reason: Some("no_auth".to_string()),
825                    }
826                } else if let Some(fallback) = available_models.first() {
827                    if should_print_messages {
828                        eprintln!(
829                            "Using first available model: {}/{}",
830                            fallback.provider, fallback.id
831                        );
832                    }
833                    RestoreModelResult {
834                        model: Some(fallback.clone()),
835                        fallback_message: Some(format!(
836                            "Could not restore model {}/{} (no auth configured). Using first available model.",
837                            saved_provider, saved_model_id
838                        )),
839                        reason: Some("no_auth".to_string()),
840                    }
841                } else {
842                    RestoreModelResult {
843                        model: None,
844                        fallback_message: Some("No models available.".to_string()),
845                        reason: Some("no_auth".to_string()),
846                    }
847                }
848            }
849        }
850        (None, Some(current)) => {
851            if should_print_messages {
852                eprintln!(
853                    "Warning: Could not restore model {}/{} (model not found). Falling back to current model.",
854                    saved_provider, saved_model_id
855                );
856                eprintln!("Falling back to: {}/{}", current.provider, current.id);
857            }
858            RestoreModelResult {
859                model: Some((*current).clone()),
860                fallback_message: Some(format!(
861                    "Could not restore model {}/{} (model not found). Using current model.",
862                    saved_provider, saved_model_id
863                )),
864                reason: Some("model_not_found".to_string()),
865            }
866        }
867        (None, None) => {
868            // Try to find any available model
869            if let Some(model) = available_models.first() {
870                if should_print_messages {
871                    eprintln!(
872                        "Warning: Could not restore model {}/{} (model not found).",
873                        saved_provider, saved_model_id
874                    );
875                    eprintln!(
876                        "Using first available model: {}/{}",
877                        model.provider, model.id
878                    );
879                }
880                RestoreModelResult {
881                    model: Some(model.clone()),
882                    fallback_message: Some(format!(
883                        "Could not restore model {}/{}. Using first available model.",
884                        saved_provider, saved_model_id
885                    )),
886                    reason: Some("model_not_found".to_string()),
887                }
888            } else {
889                RestoreModelResult {
890                    model: None,
891                    fallback_message: Some("No models available.".to_string()),
892                    reason: Some("no_models".to_string()),
893                }
894            }
895        }
896    }
897}
898
#[cfg(test)]
mod tests {
    //! Unit tests for model equality, glob matching, thinking-level mapping,
    //! auth checks, pattern parsing, CLI resolution, and session restoration.

    use super::*;

    /// Build a model with only `provider` and `id` set; every optional field
    /// is `None` and every list is empty. Use struct-update syntax
    /// (`Model { name: ..., ..bare_model(p, id) }`) to override fields.
    fn bare_model(provider: &str, id: &str) -> Model {
        Model {
            provider: provider.to_string(),
            id: id.to_string(),
            name: None,
            description: None,
            context_window: None,
            supported_features: Vec::new(),
            cost_input: None,
            cost_output: None,
            cost_cache_read: None,
            cost_cache_write: None,
            input_modalities: Vec::new(),
        }
    }

    /// Fixture of four models across three providers: two Anthropic, one
    /// OpenAI and one Google.
    fn sample_models() -> Vec<Model> {
        vec![
            Model {
                provider: "anthropic".to_string(),
                id: "claude-sonnet-4-5".to_string(),
                name: Some("Claude Sonnet 4.5".to_string()),
                description: None,
                context_window: Some(200000),
                supported_features: vec!["tools".to_string(), "vision".to_string()],
                cost_input: Some(3.0),
                cost_output: Some(15.0),
                cost_cache_read: Some(0.3),
                cost_cache_write: Some(3.75),
                input_modalities: vec!["text".to_string(), "image".to_string()],
            },
            Model {
                provider: "anthropic".to_string(),
                id: "claude-opus-4-7".to_string(),
                name: Some("Claude Opus 4.7".to_string()),
                description: None,
                context_window: Some(200000),
                supported_features: vec!["tools".to_string(), "vision".to_string()],
                cost_input: Some(15.0),
                cost_output: Some(75.0),
                cost_cache_read: Some(0.5),
                cost_cache_write: Some(6.25),
                input_modalities: vec!["text".to_string(), "image".to_string()],
            },
            Model {
                provider: "openai".to_string(),
                id: "gpt-4o".to_string(),
                name: Some("GPT-4o".to_string()),
                description: None,
                context_window: Some(128000),
                supported_features: vec!["tools".to_string()],
                cost_input: Some(2.5),
                cost_output: Some(10.0),
                cost_cache_read: Some(1.25),
                cost_cache_write: Some(0.0),
                input_modalities: vec!["text".to_string(), "image".to_string()],
            },
            Model {
                provider: "google".to_string(),
                id: "gemini-2.5-pro".to_string(),
                name: Some("Gemini 2.5 Pro".to_string()),
                description: None,
                context_window: Some(1000000),
                supported_features: vec!["tools".to_string()],
                cost_input: Some(1.25),
                cost_output: Some(5.0),
                cost_cache_read: Some(0.0),
                cost_cache_write: Some(0.0),
                input_modalities: vec!["text".to_string(), "image".to_string()],
            },
        ]
    }

    // =============================================================================
    // Model Equality Tests
    // =============================================================================

    #[test]
    fn test_models_are_equal_same() {
        // Same provider and id, different display name: must compare equal.
        let model1 = Model {
            name: Some("Claude Sonnet 4.5".to_string()),
            context_window: Some(200000),
            ..bare_model("anthropic", "claude-sonnet-4-5")
        };
        let model2 = Model {
            name: Some("Claude Sonnet 4.5 (different name)".to_string()),
            context_window: Some(200000),
            ..bare_model("anthropic", "claude-sonnet-4-5")
        };

        assert!(models_are_equal(&model1, &model2));
    }

    #[test]
    fn test_models_are_equal_different_provider() {
        let model1 = bare_model("anthropic", "claude-sonnet-4-5");
        let model2 = bare_model("openai", "claude-sonnet-4-5");

        assert!(!models_are_equal(&model1, &model2));
    }

    #[test]
    fn test_models_are_equal_different_id() {
        let model1 = bare_model("anthropic", "claude-sonnet-4-5");
        let model2 = bare_model("anthropic", "claude-opus-4-7");

        assert!(!models_are_equal(&model1, &model2));
    }

    // =============================================================================
    // Glob Pattern Tests
    // =============================================================================

    #[test]
    fn test_match_glob_exact() {
        assert!(match_glob("claude-sonnet-4-5", "claude-sonnet-4-5"));
        assert!(!match_glob("claude-sonnet-4-5", "claude-opus-4-7"));
    }

    #[test]
    fn test_match_glob_asterisk() {
        assert!(match_glob("claude-*", "claude-sonnet-4-5"));
        assert!(match_glob("claude-*", "claude-opus-4-7"));
        assert!(!match_glob("claude-*", "gpt-4o"));
    }

    #[test]
    fn test_match_glob_question() {
        assert!(match_glob("claude-?-sonnet-4-5", "claude-3-sonnet-4-5"));
        assert!(!match_glob("claude-?-sonnet-4-5", "claude-35-sonnet-4-5"));
    }

    #[test]
    fn test_match_glob_char_class() {
        assert!(match_glob("claude-[a-z]-sonnet", "claude-a-sonnet"));
        assert!(match_glob("claude-[a-z]-sonnet", "claude-b-sonnet"));
        // Glob matching is case-insensitive overall, so character classes
        // also match the opposite case.
        assert!(match_glob("claude-[a-z]-sonnet", "claude-A-sonnet"));
    }

    #[test]
    fn test_match_glob_case_insensitive() {
        assert!(match_glob("CLAUDE-*", "claude-sonnet-4-5"));
    }

    #[test]
    fn test_find_models_by_glob() {
        let models = sample_models();
        let results = find_models_by_glob("anthropic", "claude-*", &models);
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|m| m.provider == "anthropic"));
    }

    #[test]
    fn test_find_models_by_glob_single_match() {
        let models = sample_models();
        let results = find_models_by_glob("openai", "gpt-*", &models);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_find_models_by_glob_no_match() {
        // No fixture model id or name resembles "llama-*".
        let models = sample_models();
        let results = find_models_by_glob("openai", "llama-*", &models);
        assert!(results.is_empty());
    }

    // =============================================================================
    // Thinking Level Mapping Tests
    // =============================================================================

    #[test]
    fn test_get_thinking_level_map_claude_35_sonnet() {
        let map = get_thinking_level_map("claude-3-5-sonnet-latest")
            .expect("claude-3-5-sonnet-latest should have a thinking-level map");
        assert_eq!(
            map.get("high"),
            Some(&"claude-3-5-sonnet-20240620".to_string())
        );
    }

    #[test]
    fn test_get_thinking_level_map_claude_opus_4() {
        let map = get_thinking_level_map("claude-opus-4-5")
            .expect("claude-opus-4-5 should have a thinking-level map");
        assert_eq!(
            map.get("high"),
            Some(&"claude-opus-4-5-20251101".to_string())
        );
        assert_eq!(map.get("medium"), Some(&"claude-opus-4-5".to_string()));
    }

    #[test]
    fn test_get_thinking_level_map_no_match() {
        assert!(get_thinking_level_map("gpt-4o").is_none());
    }

    #[test]
    fn test_clamp_thinking_level_supported() {
        let result = clamp_thinking_level("claude-3-5-sonnet-latest", "high");
        assert_eq!(result, "claude-3-5-sonnet-20240620");
    }

    #[test]
    fn test_clamp_thinking_level_clamp_down() {
        // "xhigh" has no mapping for this model, so it should clamp to "high".
        let result = clamp_thinking_level("claude-3-5-sonnet-latest", "xhigh");
        assert_eq!(result, "claude-3-5-sonnet-20240620");
    }

    #[test]
    fn test_clamp_thinking_level_no_mapping() {
        // gpt-4o has no mapping, so the requested level is returned unchanged.
        let result = clamp_thinking_level("gpt-4o", "high");
        assert_eq!(result, "high");
    }

    // =============================================================================
    // Auth Validation Tests
    // =============================================================================

    #[test]
    fn test_has_configured_auth_unknown_provider() {
        let model = bare_model("unknown", "test");

        // An unknown provider has no associated credentials.
        assert!(!has_configured_auth("unknown", &model));
    }

    #[test]
    fn test_has_configured_auth_known_provider_is_stable() {
        let model = bare_model("anthropic", "claude-sonnet-4-5");

        // The result depends on the host environment (e.g. whether
        // ANTHROPIC_API_KEY is set), so it cannot be pinned to a fixed value.
        // Check instead that repeated queries agree.
        let first = has_configured_auth("anthropic", &model);
        let second = has_configured_auth("anthropic", &model);
        assert_eq!(first, second);
    }

    // =============================================================================
    // Model Parsing Tests
    // =============================================================================

    #[test]
    fn test_parse_model_pattern_exact() {
        let models = sample_models();
        let result = parse_model_pattern("claude-sonnet-4-5", &models);

        assert_eq!(result.model_id, "claude-sonnet-4-5");
        assert_eq!(result.provider, Some("anthropic".to_string()));
        assert!(result.warning.is_none());
    }

    #[test]
    fn test_parse_model_pattern_with_provider() {
        let models = sample_models();
        let result = parse_model_pattern("anthropic/claude-sonnet-4-5", &models);

        assert_eq!(result.model_id, "claude-sonnet-4-5");
        assert_eq!(result.provider, Some("anthropic".to_string()));
    }

    #[test]
    fn test_parse_model_pattern_with_thinking_level() {
        let models = sample_models();
        let result = parse_model_pattern("sonnet:high", &models);

        assert_eq!(result.thinking_level, Some("high".to_string()));
    }

    #[test]
    fn test_parse_model_pattern_invalid_thinking_level() {
        let models = sample_models();
        let result = parse_model_pattern("sonnet:invalid", &models);

        assert!(result.thinking_level.is_none());
    }

    #[test]
    fn test_parse_model_pattern_partial_match() {
        let models = sample_models();
        let result = parse_model_pattern("sonnet", &models);

        assert!(result.model_id.contains("sonnet"));
        assert!(result.warning.is_some() || result.provider.is_some());
    }

    #[test]
    fn test_parse_model_pattern_not_found() {
        let models = sample_models();
        let result = parse_model_pattern("nonexistent-model", &models);

        assert_eq!(result.model_id, "nonexistent-model");
        assert!(result.warning.is_some());
    }

    // =============================================================================
    // CLI Resolution Tests
    // =============================================================================

    #[test]
    fn test_resolve_cli_model_with_provider() {
        let models = sample_models();
        let result = resolve_cli_model(Some("anthropic"), Some("claude-sonnet-4-5"), &models, None);

        assert!(result.error.is_none());
        let model = result
            .model
            .expect("anthropic/claude-sonnet-4-5 should resolve");
        assert_eq!(model.id, "claude-sonnet-4-5");
    }

    #[test]
    fn test_resolve_cli_model_with_slash() {
        let models = sample_models();
        let result = resolve_cli_model(None, Some("anthropic/claude-sonnet-4-5"), &models, None);

        assert!(result.error.is_none());
        assert!(result.model.is_some());
    }

    #[test]
    fn test_resolve_cli_model_not_found() {
        let models = sample_models();
        let result = resolve_cli_model(None, Some("nonexistent-model"), &models, None);

        // Either reporting an error or returning no model is acceptable.
        assert!(result.error.is_some() || result.model.is_none());
    }

    #[test]
    fn test_resolve_cli_model_no_args() {
        let models = sample_models();
        let result = resolve_cli_model(None, None, &models, None);

        assert!(result.model.is_none());
        assert!(result.error.is_none());
    }

    // =============================================================================
    // Search Tests
    // =============================================================================

    #[test]
    fn test_find_models_by_pattern() {
        let models = sample_models();
        let results = find_models_by_pattern("sonnet", &models);

        assert!(!results.is_empty());
        assert!(results.iter().all(|m| m.id.contains("sonnet")
            || m.name.as_deref().is_some_and(|n| n.contains("sonnet"))));
    }

    #[test]
    fn test_find_models_by_pattern_full_id() {
        let models = sample_models();
        let results = find_models_by_pattern("anthropic/claude-sonnet-4-5", &models);
        assert!(!results.is_empty());
    }

    // =============================================================================
    // Initial Model Selection Tests
    // =============================================================================

    #[test]
    fn test_find_initial_model_from_cli() {
        let models = sample_models();
        let result = find_initial_model(Some("openai"), Some("gpt-4o"), &[], false, None, &models);

        let model = result.model.expect("openai/gpt-4o should be selected");
        assert_eq!(model.id, "gpt-4o");
    }

    #[test]
    fn test_find_initial_model_fallback_to_available() {
        let models = sample_models();
        let result = find_initial_model(None, None, &[], false, None, &models);

        assert!(result.model.is_some());
        assert!(result.fallback_message.is_none());
    }

    #[test]
    fn test_find_initial_model_default_thinking_level() {
        let models = sample_models();
        let result = find_initial_model(Some("openai"), Some("gpt-4o"), &[], false, None, &models);

        assert_eq!(result.thinking_level, DEFAULT_THINKING_LEVEL);
    }

    // =============================================================================
    // Restore Model Tests
    // =============================================================================

    #[test]
    fn test_restore_model_from_session_success() {
        let models = sample_models();
        let result =
            restore_model_from_session("anthropic", "claude-sonnet-4-5", None, false, &models);

        // The saved model exists in the list, so a model must be returned.
        assert!(result.model.is_some());
        // A fallback message is still possible when no auth is configured
        // (the usual case in a test environment without API keys).
        if result.fallback_message.is_some() {
            assert_eq!(result.reason, Some("no_auth".to_string()));
        }
    }

    #[test]
    fn test_restore_model_from_session_not_found() {
        let models = sample_models();
        let current = &models[0];
        let result =
            restore_model_from_session("nonexistent", "model", Some(current), false, &models);

        assert!(result.model.is_some());
        assert!(result.fallback_message.is_some());
        assert_eq!(result.reason, Some("model_not_found".to_string()));
    }

    #[test]
    fn test_restore_model_from_session_fallback() {
        let models = sample_models();
        // Deliberately NOT the first fixture entry. This way, falling back to
        // "first available" instead of the current model is detected.
        let current = &models[2];
        let result =
            restore_model_from_session("nonexistent", "model", Some(current), false, &models);

        let message = result
            .fallback_message
            .as_deref()
            .expect("a fallback message should explain the failed restore");
        assert!(message.contains("nonexistent/model"));
        assert_eq!(
            result.model.as_ref().map(|m| m.id.as_str()),
            Some(current.id.as_str())
        );
    }

    // =============================================================================
    // Alias Detection Tests
    // =============================================================================

    #[test]
    fn test_is_alias() {
        assert!(is_alias("claude-sonnet-4-latest"));
        assert!(is_alias("simple-model"));
        assert!(!is_alias("claude-sonnet-4-20250929"));
        assert!(!is_alias("claude-sonnet-4-20250514"));
    }

    #[test]
    fn test_default_thinking_level_constant() {
        assert_eq!(DEFAULT_THINKING_LEVEL, "medium");
    }

    #[test]
    fn test_thinking_levels_constant() {
        assert_eq!(
            THINKING_LEVELS,
            &["off", "minimal", "low", "medium", "high", "xhigh"]
        );
    }
}