Skip to main content

oxi/
model_resolver.rs

1//! Model name resolution and matching
2//!
3//! Provides utilities for parsing model patterns, resolving model names,
4//! finding models by glob, and determining the best model for startup.
5//!
6//! This module integrates with:
7//! - oxi_ai::model_db::ModelEntry for cost/modality info
8//! - oxi_ai::model_registry for runtime model lookup
//! - environment variables for auth validation (see `has_configured_auth`; stored credentials are not consulted here)
10
11use crate::settings::Settings;
12use std::collections::HashMap;
13
14// =============================================================================
15// Constants
16// =============================================================================
17
/// Thinking level applied when the caller does not request one explicitly.
pub const DEFAULT_THINKING_LEVEL: &str = "medium";

/// Every recognised thinking level, ordered from least to most intensive.
///
/// Order is significant: level clamping walks this list by index.
pub const THINKING_LEVELS: &[&str] = &[
    "off",
    "minimal",
    "low",
    "medium",
    "high",
    "xhigh",
];
23
24// =============================================================================
25// Model Types
26// =============================================================================
27
/// Descriptor for an AI provider: a stable identifier, a display name and an
/// optional homepage URL.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Provider {
    pub id: String,
    pub name: String,
    pub website: Option<String>,
}

impl Provider {
    /// Build a provider with the given id and display name and no website.
    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
        let id = id.into();
        let name = name.into();
        Provider {
            id,
            name,
            website: None,
        }
    }

    /// Builder-style setter that attaches a website URL, replacing any previous one.
    pub fn with_website(self, website: impl Into<String>) -> Self {
        Provider {
            website: Some(website.into()),
            ..self
        }
    }
}
50
51/// A discovered model
52#[derive(Debug, Clone)]
53pub struct Model {
54    pub provider: String,
55    pub id: String,
56    pub name: Option<String>,
57    pub description: Option<String>,
58    pub context_window: Option<u32>,
59    pub supported_features: Vec<String>,
60    // Unified from ModelEntry - cost info
61    pub cost_input: Option<f64>,
62    pub cost_output: Option<f64>,
63    pub cost_cache_read: Option<f64>,
64    pub cost_cache_write: Option<f64>,
65    // Unified from ModelEntry - input modalities
66    pub input_modalities: Vec<String>,
67}
68
69impl Model {
70    /// Get the full model identifier (provider/model_id)
71    pub fn full_id(&self) -> String {
72        format!("{}/{}", self.provider, self.id)
73    }
74
75    /// Create from a model_db::ModelEntry
76    pub fn from_entry(entry: &oxi_ai::ModelEntry) -> Self {
77        Self {
78            provider: entry.provider.to_string(),
79            id: entry.id.to_string(),
80            name: Some(entry.name.to_string()),
81            description: None,
82            context_window: Some(entry.context_window),
83            supported_features: vec![],
84            cost_input: Some(entry.cost_input),
85            cost_output: Some(entry.cost_output),
86            cost_cache_read: Some(entry.cost_cache_read),
87            cost_cache_write: Some(entry.cost_cache_write),
88            input_modalities: entry.input.iter().map(|m| format!("{:?}", m).to_lowercase()).collect(),
89        }
90    }
91
92    /// Create from a model_registry::Model
93    pub fn from_registry_model(model: &oxi_ai::Model) -> Self {
94        Self {
95            provider: model.provider.clone(),
96            id: model.id.clone(),
97            name: Some(model.name.clone()),
98            description: None,
99            context_window: Some(model.context_window as u32),
100            supported_features: vec![],
101            cost_input: Some(model.cost.input),
102            cost_output: Some(model.cost.output),
103            cost_cache_read: Some(model.cost.cache_read),
104            cost_cache_write: Some(model.cost.cache_write),
105            input_modalities: model.input.iter().map(|m| format!("{:?}", m).to_lowercase()).collect(),
106        }
107    }
108}
109
/// Outcome of `parse_model_pattern`.
#[derive(Debug)]
pub struct ParsedModelResult {
    /// Provider of the matched model, or the provider named in a `provider/model`
    /// pattern; `None` when the pattern was empty or matched no known model.
    pub provider: Option<String>,
    /// Matched model id; when nothing matched, the unmatched pattern text is
    /// passed through so it can be used as a custom model id.
    pub model_id: String,
    /// Level taken from a trailing `:<level>` suffix, only if it names one of
    /// `THINKING_LEVELS`.
    pub thinking_level: Option<String>,
    /// Human-readable note: empty pattern, ambiguous match, or unknown model.
    pub warning: Option<String>,
}

/// Outcome of `resolve_cli_model`.
#[derive(Debug)]
pub struct ResolveCliModelResult {
    /// Resolved model. May be a synthesized custom model when the provider is
    /// known but the model id is not; `None` when no model was requested or
    /// nothing could be resolved.
    pub model: Option<Model>,
    /// Thinking level parsed from a `:<level>` suffix on the model argument.
    pub thinking_level: Option<String>,
    /// Non-fatal note forwarded from pattern parsing.
    pub warning: Option<String>,
    /// Set when a model was requested but neither found nor synthesized.
    pub error: Option<String>,
}

/// Outcome of `find_initial_model`.
#[derive(Debug)]
pub struct InitialModelResult {
    /// Model to start the session with; `None` when no suitable model was found.
    pub model: Option<Model>,
    /// Thinking level to use. Always set: `DEFAULT_THINKING_LEVEL` unless a level
    /// came from the CLI or settings.
    pub thinking_level: String,
    /// Explanation shown when no model could be chosen.
    pub fallback_message: Option<String>,
}

/// Outcome of `restore_model_from_session`.
#[derive(Debug)]
pub struct RestoreModelResult {
    /// Restored model, a fallback model, or `None` when nothing is available.
    pub model: Option<Model>,
    /// User-facing explanation when the saved model could not be restored.
    pub fallback_message: Option<String>,
    /// Machine-readable cause when the saved model was not restored: one of
    /// `"no_auth"`, `"model_not_found"` or `"no_models"`; `None` on success.
    /// NOTE(review): a dedicated enum would be safer than free-form strings.
    pub reason: Option<String>,
}
143
144// =============================================================================
145// Core Functions
146// =============================================================================
147
148/// Check if two models are equal (same provider and id)
149///
150/// Used for deduplication in model scope resolution.
151pub fn models_are_equal(a: &Model, b: &Model) -> bool {
152    a.provider == b.provider && a.id == b.id
153}
154
/// Report whether a model id looks like an alias rather than a dated snapshot.
///
/// An id is an alias when it ends in `-latest`, or when it does not end in a
/// `-YYYYMMDD` date stamp (a hyphen followed by exactly eight ASCII digits).
///
/// The check is written by hand instead of with a regex so that no pattern has
/// to be compiled on each call; this runs once per candidate when
/// disambiguating partial matches.
fn is_alias(id: &str) -> bool {
    if id.ends_with("-latest") {
        return true;
    }
    // A dated snapshot has exactly eight ASCII digits after its final hyphen.
    let has_date_suffix = id.rsplit_once('-').map_or(false, |(_, tail)| {
        tail.len() == 8 && tail.bytes().all(|b| b.is_ascii_digit())
    });
    !has_date_suffix
}
168
169/// Match a glob pattern against text
170///
171/// Supports:
172/// - `*` - matches any characters
173/// - `?` - matches any single character
174/// - `[abc]` - matches any character in the set
175///
176/// Note: Matching is case-insensitive.
177pub fn match_glob(pattern: &str, text: &str) -> bool {
178    // Simple implementation: convert glob to regex
179    let mut regex_pattern = String::new();
180    let mut in_class = false;
181    let mut chars = pattern.chars().peekable();
182    
183    while let Some(c) = chars.next() {
184        match c {
185            '*' => {
186                if in_class {
187                    regex_pattern.push_str("\\*");
188                } else {
189                    regex_pattern.push_str(".*");
190                }
191            }
192            '?' => {
193                if in_class {
194                    regex_pattern.push('?');
195                } else {
196                    regex_pattern.push('.');
197                }
198            }
199            '[' => {
200                in_class = true;
201                regex_pattern.push('[');
202            }
203            ']' => {
204                in_class = false;
205                regex_pattern.push(']');
206            }
207            '.' | '+' | '^' | '$' | '\\' | '(' | ')' | '{' | '}' | '|' => {
208                // Escape regex special characters (but not those in char classes)
209                if !in_class {
210                    regex_pattern.push('\\');
211                }
212                regex_pattern.push(c);
213            }
214            _ => regex_pattern.push(c),
215        }
216    }
217    
218    // Handle trailing ** in patterns
219    if pattern.ends_with("**") {
220        regex_pattern.push_str(".*");
221    }
222    
223    // Use case-insensitive regex matching
224    regex::RegexBuilder::new(&format!("^{}$", regex_pattern))
225        .case_insensitive(true)
226        .build()
227        .map(|re| re.is_match(text))
228        .unwrap_or_else(|_| pattern.eq_ignore_ascii_case(text))
229}
230
231/// Find all models matching a provider and glob pattern
232pub fn find_models_by_glob<'a>(provider: &str, pattern: &str, models: &'a [Model]) -> Vec<&'a Model> {
233    models
234        .iter()
235        .filter(|m| m.provider == provider && match_glob(pattern, &m.id))
236        .collect()
237}
238
239/// Find all models matching a pattern (glob or substring) from the full model database
240pub fn find_models_by_pattern(pattern: &str, models: &[Model]) -> Vec<Model> {
241    let pattern_lower = pattern.to_lowercase();
242    models
243        .iter()
244        .filter(|m| {
245            m.id.to_lowercase().contains(&pattern_lower)
246                || m.full_id().to_lowercase().contains(&pattern_lower)
247                || m.name
248                    .as_ref()
249                    .map(|n| n.to_lowercase().contains(&pattern_lower))
250                    .unwrap_or(false)
251        })
252        .cloned()
253        .collect()
254}
255
/// Map thinking levels to concrete model variants for families that ship several.
///
/// A trailing `-latest` or `-YYYYMMDD` suffix is stripped from `model_id` first.
/// Each family name is then tested as a substring of the result, in table order,
/// and the first hit wins. For example, `claude-sonnet-4-20250514` and
/// `claude-sonnet-4-5` both resolve to the `claude-sonnet-4` family.
///
/// For `claude-3-5-sonnet-*` the returned map is:
/// - `"high"` -> `"claude-3-5-sonnet-20240620"`
/// - `"medium"` -> `"claude-3-5-sonnet-latest"`
/// - `"low"` -> `"claude-3-5-sonnet-latest"`
///
/// Returns `None` for models without known thinking variants.
pub fn get_thinking_level_map(model_id: &str) -> Option<HashMap<String, String>> {
    /// `head` of `head-YYYYMMDD`, when the id ends in an eight-digit date stamp.
    fn strip_date_suffix(id: &str) -> Option<&str> {
        let (head, tail) = id.rsplit_once('-')?;
        (tail.len() == 8 && tail.bytes().all(|b| b.is_ascii_digit())).then(|| head)
    }

    /// Family name -> (thinking level, variant id) pairs.
    const THINKING_FAMILIES: &[(&str, &[(&str, &str)])] = &[
        (
            "claude-3-5-sonnet",
            &[
                ("high", "claude-3-5-sonnet-20240620"),
                ("medium", "claude-3-5-sonnet-latest"),
                ("low", "claude-3-5-sonnet-latest"),
            ],
        ),
        (
            "claude-3-7-sonnet",
            &[
                ("high", "claude-3-7-sonnet-20250219"),
                ("medium", "claude-3-7-sonnet-20250219"),
                ("low", "claude-3-7-sonnet-latest"),
            ],
        ),
        (
            "claude-opus-4",
            &[
                ("high", "claude-opus-4-5-20251101"),
                ("medium", "claude-opus-4-5"),
                ("low", "claude-opus-4-1"),
                ("off", "claude-opus-4-0"),
            ],
        ),
        (
            "claude-sonnet-4",
            &[
                ("high", "claude-sonnet-4-20250514"),
                ("medium", "claude-sonnet-4-5"),
                ("low", "claude-sonnet-4-0"),
                ("off", "claude-sonnet-4-0"),
            ],
        ),
    ];

    let base = model_id
        .strip_suffix("-latest")
        .or_else(|| strip_date_suffix(model_id))
        .unwrap_or(model_id);

    THINKING_FAMILIES
        .iter()
        .find(|&&(family, _)| base.contains(family))
        .map(|&(_, levels)| {
            levels
                .iter()
                .map(|&(level, id)| (level.to_string(), id.to_string()))
                .collect()
        })
}
311
312/// Clamp a requested thinking level to what the model supports
313///
314/// If the model doesn't support the requested level, returns the closest
315/// supported level.
316pub fn clamp_thinking_level(model_id: &str, requested_level: &str) -> String {
317    if let Some(map) = get_thinking_level_map(model_id) {
318        // If the requested level is in the map, use it
319        if map.contains_key(requested_level) {
320            return map.get(requested_level).cloned().unwrap_or_else(|| model_id.to_string());
321        }
322        
323        // Find the closest level that IS in the map
324        let level_idx = THINKING_LEVELS.iter().position(|&l| l == requested_level);
325        if let Some(idx) = level_idx {
326            // Search downward for the closest supported level
327            for i in (0..idx).rev() {
328                let level = THINKING_LEVELS[i];
329                if map.contains_key(level) {
330                    return map.get(level).cloned().unwrap_or_else(|| model_id.to_string());
331                }
332            }
333        }
334    }
335    
336    // Fallback: return the requested level as-is if model has no special mapping
337    requested_level.to_string()
338}
339
340/// Check if auth is configured for a provider/model
341///
342/// This checks both environment variables and stored auth credentials.
343/// Note: This only checks environment variables as a simple, reliable approach.
344pub fn has_configured_auth(provider: &str, _model: &Model) -> bool {
345    // Check environment variables
346    let env_var = match provider {
347        "anthropic" => "ANTHROPIC_API_KEY",
348        "openai" => "OPENAI_API_KEY",
349        "google" => "GOOGLE_API_KEY",
350        "deepseek" => "DEEPSEEK_API_KEY",
351        "mistral" => "MISTRAL_API_KEY",
352        "groq" => "GROQ_API_KEY",
353        "cerebras" => "CEREBRAS_API_KEY",
354        "xai" => "XAI_API_KEY",
355        "openrouter" => "OPENROUTER_API_KEY",
356        "azure-openai" | "azure-openai-responses" => "AZURE_OPENAI_API_KEY",
357        "amazon-bedrock" => "AWS_ACCESS_KEY_ID",
358        _ => return false,
359    };
360    
361    std::env::var(env_var).is_ok()
362}
363
364/// Parse a model pattern into components
365///
366/// # Arguments
367/// * `pattern` - The model pattern (e.g., "anthropic/claude-3.5-sonnet" or "sonnet:high")
368/// * `available_models` - List of available models for validation
369///
370/// # Returns
371/// A parsed model result with provider, model_id, and optional thinking level
372pub fn parse_model_pattern(
373    pattern: &str,
374    available_models: &[Model],
375) -> ParsedModelResult {
376    let pattern = pattern.trim();
377    if pattern.is_empty() {
378        return ParsedModelResult {
379            provider: None,
380            model_id: String::new(),
381            thinking_level: None,
382            warning: Some("Empty model pattern".to_string()),
383        };
384    }
385
386    // Check for thinking level suffix (e.g., "sonnet:high")
387    let last_colon = pattern.rfind(':');
388    let (base_pattern, thinking_level) = if let Some(idx) = last_colon {
389        let suffix = &pattern[idx + 1..];
390        if THINKING_LEVELS.contains(&suffix) {
391            (&pattern[..idx], Some(suffix.to_string()))
392        } else {
393            (pattern, None)
394        }
395    } else {
396        (pattern, None)
397    };
398
399    // Try to find an exact match first
400    let exact_match = available_models.iter().find(|m| {
401        m.id.eq_ignore_ascii_case(base_pattern)
402            || m.full_id().eq_ignore_ascii_case(base_pattern)
403    });
404
405    if let Some(model) = exact_match {
406        return ParsedModelResult {
407            provider: Some(model.provider.clone()),
408            model_id: model.id.clone(),
409            thinking_level,
410            warning: None,
411        };
412    }
413
414    // Try to parse provider/model format
415    if let Some(slash_idx) = base_pattern.find('/') {
416        let provider = &base_pattern[..slash_idx];
417        let model_id = &base_pattern[slash_idx + 1..];
418
419        // Check if provider exists in available models
420        let provider_exists = available_models.iter().any(|m| {
421            m.provider.eq_ignore_ascii_case(provider)
422        });
423
424        if provider_exists {
425            return ParsedModelResult {
426                provider: Some(provider.to_string()),
427                model_id: model_id.to_string(),
428                thinking_level,
429                warning: None,
430            };
431        }
432    }
433
434    // Try partial matching
435    let partial_matches: Vec<&Model> = available_models
436        .iter()
437        .filter(|m| {
438            m.id.to_lowercase().contains(&base_pattern.to_lowercase())
439                || m.name
440                    .as_ref()
441                    .map(|n| n.to_lowercase().contains(&base_pattern.to_lowercase()))
442                    .unwrap_or(false)
443        })
444        .collect();
445
446    if partial_matches.len() == 1 {
447        let model = partial_matches[0];
448        return ParsedModelResult {
449            provider: Some(model.provider.clone()),
450            model_id: model.id.clone(),
451            thinking_level,
452            warning: None,
453        };
454    } else if partial_matches.len() > 1 {
455        // Prefer aliases over dated versions
456        let aliases: Vec<_> = partial_matches.iter().filter(|m| is_alias(&m.id)).collect();
457        if !aliases.is_empty() {
458            let model = aliases[0];
459            return ParsedModelResult {
460                provider: Some(model.provider.clone()),
461                model_id: model.id.clone(),
462                thinking_level,
463                warning: Some(format!(
464                    "Multiple models match '{}', selected '{}'",
465                    base_pattern,
466                    model.full_id()
467                )),
468            };
469        }
470        // Use the latest dated version (sort descending)
471        let mut sorted = partial_matches.to_vec();
472        sorted.sort_by(|a, b| b.id.cmp(&a.id));
473        let model = sorted[0];
474        return ParsedModelResult {
475            provider: Some(model.provider.clone()),
476            model_id: model.id.clone(),
477            thinking_level,
478            warning: Some(format!(
479                "Multiple models match '{}', selected '{}'",
480                base_pattern,
481                model.full_id()
482            )),
483        };
484    }
485
486    // No match found - return as raw pattern
487    ParsedModelResult {
488        provider: None,
489        model_id: pattern.to_string(),
490        thinking_level,
491        warning: Some(format!(
492            "Model '{}' not found in available models. Treating as custom model ID.",
493            pattern
494        )),
495    }
496}
497
/// Preferred model id for each known provider, keyed by provider id.
///
/// Used when choosing a model without an explicit selection.
pub fn default_model_per_provider() -> HashMap<String, String> {
    const DEFAULTS: &[(&str, &str)] = &[
        ("anthropic", "claude-sonnet-4-5"),
        ("openai", "gpt-4o"),
        ("google", "gemini-2.5-pro"),
        ("deepseek", "deepseek-v3"),
        ("openrouter", "anthropic/claude-sonnet-4"),
        ("groq", "mixtral-8x7b"),
        ("cerebras", "llama-3.3-70b"),
        ("mistral", "mistral-large"),
        ("xai", "grok-2"),
        ("amazon-bedrock", "anthropic.claude-v2"),
        ("azure-openai", "gpt-4o"),
    ];

    DEFAULTS
        .iter()
        .map(|&(provider, model)| (provider.to_string(), model.to_string()))
        .collect()
}
514
515/// Resolve a model from CLI arguments
516pub fn resolve_cli_model(
517    cli_provider: Option<&str>,
518    cli_model: Option<&str>,
519    available_models: &[Model],
520    _settings: Option<&Settings>,
521) -> ResolveCliModelResult {
522    let cli_model = match cli_model {
523        Some(m) => m,
524        None => {
525            return ResolveCliModelResult {
526                model: None,
527                thinking_level: None,
528                warning: None,
529                error: None,
530            };
531        }
532    };
533
534    // Build provider map for case-insensitive lookup
535    let mut provider_map: HashMap<String, String> = HashMap::new();
536    for model in available_models {
537        provider_map.insert(model.provider.to_lowercase(), model.provider.clone());
538    }
539
540    // Try to resolve provider
541    let provider = if let Some(p) = cli_provider {
542        provider_map.get(&p.to_lowercase()).cloned()
543    } else if let Some(slash_idx) = cli_model.find('/') {
544        let maybe_provider = &cli_model[..slash_idx];
545        provider_map.get(&maybe_provider.to_lowercase()).cloned()
546    } else {
547        None
548    };
549
550    // Extract the model pattern
551    let model_pattern = if let Some(ref p) = provider {
552        if cli_model.to_lowercase().starts_with(&format!("{}/", p.to_lowercase())) {
553            &cli_model[p.len() + 1..]
554        } else {
555            cli_model
556        }
557    } else {
558        cli_model
559    };
560
561    // Parse the pattern
562    let parsed = parse_model_pattern(model_pattern, available_models);
563
564    // Find the model
565    let model = if let Some(ref p) = provider {
566        available_models
567            .iter()
568            .find(|m| {
569                m.provider.eq_ignore_ascii_case(p) && m.id.eq_ignore_ascii_case(&parsed.model_id)
570            })
571            .cloned()
572    } else if let Some(ref p) = parsed.provider {
573        available_models
574            .iter()
575            .find(|m| {
576                m.provider.eq_ignore_ascii_case(p) && m.id.eq_ignore_ascii_case(&parsed.model_id)
577            })
578            .cloned()
579    } else {
580        // Try matching without provider
581        available_models
582            .iter()
583            .find(|m| m.id.eq_ignore_ascii_case(&parsed.model_id))
584            .cloned()
585    };
586
587    if let Some(ref m) = model {
588        ResolveCliModelResult {
589            model: Some(m.clone()),
590            thinking_level: parsed.thinking_level,
591            warning: parsed.warning,
592            error: None,
593        }
594    } else {
595        // Try building a fallback custom model
596        let fallback_model = if let Some(ref p) = provider {
597            Some(Model {
598                provider: p.clone(),
599                id: parsed.model_id.clone(),
600                name: Some(parsed.model_id.clone()),
601                description: None,
602                context_window: None,
603                supported_features: vec![],
604                cost_input: None,
605                cost_output: None,
606                cost_cache_read: None,
607                cost_cache_write: None,
608                input_modalities: vec!["text".to_string()],
609            })
610        } else {
611            None
612        };
613
614        ResolveCliModelResult {
615            model: fallback_model.clone(),
616            thinking_level: parsed.thinking_level,
617            warning: parsed.warning,
618            error: fallback_model.is_none().then(|| {
619                format!(
620                    "Model '{}' not found. Use --list-models to see available models.",
621                    cli_model
622                )
623            }),
624        }
625    }
626}
627
628/// Find the initial model to use based on priority:
629/// 1. CLI args
630/// 2. First model from scoped models
631/// 3. Saved default from settings
632/// 4. First available model
633pub fn find_initial_model(
634    cli_provider: Option<&str>,
635    cli_model: Option<&str>,
636    scoped_models: &[Model],
637    is_continuing: bool,
638    settings: Option<&Settings>,
639    available_models: &[Model],
640) -> InitialModelResult {
641    // 1. CLI args take priority
642    if cli_provider.is_some() || cli_model.is_some() {
643        let result = resolve_cli_model(cli_provider, cli_model, available_models, settings);
644        if result.error.is_none() {
645            return InitialModelResult {
646                model: result.model,
647                thinking_level: result.thinking_level.unwrap_or_else(|| DEFAULT_THINKING_LEVEL.to_string()),
648                fallback_message: None,
649            };
650        }
651    }
652
653    // 2. Use first model from scoped models (skip if continuing)
654    if !scoped_models.is_empty() && !is_continuing {
655        return InitialModelResult {
656            model: Some(scoped_models[0].clone()),
657            thinking_level: DEFAULT_THINKING_LEVEL.to_string(),
658            fallback_message: None,
659        };
660    }
661
662    // 3. Try saved default from settings
663    if let Some(ref s) = settings {
664        if let Some(default_model) = &s.default_model {
665            let parsed = parse_model_pattern(default_model, available_models);
666            if let Some(ref p) = parsed.provider {
667                let model = available_models
668                    .iter()
669                    .find(|m| m.provider.eq_ignore_ascii_case(p) && m.id.eq_ignore_ascii_case(&parsed.model_id))
670                    .cloned();
671                if model.is_some() {
672                    return InitialModelResult {
673                        model,
674                        thinking_level: format!("{:?}", s.thinking_level),
675                        fallback_message: None,
676                    };
677                }
678            }
679        }
680    }
681
682    // 4. Try default models from known providers
683    let defaults = default_model_per_provider();
684    for (provider, default_id) in &defaults {
685        if let Some(model) = available_models
686            .iter()
687            .find(|m| m.provider.eq_ignore_ascii_case(provider) && m.id.eq_ignore_ascii_case(default_id))
688        {
689            return InitialModelResult {
690                model: Some(model.clone()),
691                thinking_level: DEFAULT_THINKING_LEVEL.to_string(),
692                fallback_message: None,
693            };
694        }
695    }
696
697    // 5. Use first available model
698    if let Some(model) = available_models.first() {
699        return InitialModelResult {
700            model: Some(model.clone()),
701            thinking_level: DEFAULT_THINKING_LEVEL.to_string(),
702            fallback_message: None,
703        };
704    }
705
706    // No model found
707    InitialModelResult {
708        model: None,
709        thinking_level: DEFAULT_THINKING_LEVEL.to_string(),
710        fallback_message: Some("No models available. Check your installation.".to_string()),
711    }
712}
713
714/// Restore model from session with fallback and auth validation
715///
716/// This function now properly checks auth configuration before restoring a model.
717pub fn restore_model_from_session(
718    saved_provider: &str,
719    saved_model_id: &str,
720    current_model: Option<&Model>,
721    should_print_messages: bool,
722    available_models: &[Model],
723) -> RestoreModelResult {
724    let restored = available_models
725        .iter()
726        .find(|m| {
727            m.provider.eq_ignore_ascii_case(saved_provider) && m.id.eq_ignore_ascii_case(saved_model_id)
728        })
729        .cloned();
730
731    match (&restored, current_model) {
732        (Some(ref model), _) => {
733            // Check if the model has auth configured
734            if has_configured_auth(saved_provider, model) {
735                if should_print_messages {
736                    eprintln!("Restored model: {}/{}", saved_provider, saved_model_id);
737                }
738                RestoreModelResult {
739                    model: Some((*model).clone()),
740                    fallback_message: None,
741                    reason: None,
742                }
743            } else {
744                // Model exists but no auth - try fallback
745                if should_print_messages {
746                    eprintln!(
747                        "Warning: Could not restore model {}/{} (no auth configured).",
748                        saved_provider, saved_model_id
749                    );
750                }
751                
752                if let Some(current) = current_model {
753                    if should_print_messages {
754                        eprintln!("Falling back to: {}/{}", current.provider, current.id);
755                    }
756                    RestoreModelResult {
757                        model: Some((*current).clone()),
758                        fallback_message: Some(format!(
759                            "Could not restore model {}/{} (no auth configured). Using current model.",
760                            saved_provider, saved_model_id
761                        )),
762                        reason: Some("no_auth".to_string()),
763                    }
764                } else if let Some(fallback) = available_models.first() {
765                    if should_print_messages {
766                        eprintln!("Using first available model: {}/{}", fallback.provider, fallback.id);
767                    }
768                    RestoreModelResult {
769                        model: Some(fallback.clone()),
770                        fallback_message: Some(format!(
771                            "Could not restore model {}/{} (no auth configured). Using first available model.",
772                            saved_provider, saved_model_id
773                        )),
774                        reason: Some("no_auth".to_string()),
775                    }
776                } else {
777                    RestoreModelResult {
778                        model: None,
779                        fallback_message: Some("No models available.".to_string()),
780                        reason: Some("no_auth".to_string()),
781                    }
782                }
783            }
784        }
785        (None, Some(current)) => {
786            if should_print_messages {
787                eprintln!(
788                    "Warning: Could not restore model {}/{} (model not found). Falling back to current model.",
789                    saved_provider, saved_model_id
790                );
791                eprintln!("Falling back to: {}/{}", current.provider, current.id);
792            }
793            RestoreModelResult {
794                model: Some((*current).clone()),
795                fallback_message: Some(format!(
796                    "Could not restore model {}/{} (model not found). Using current model.",
797                    saved_provider, saved_model_id
798                )),
799                reason: Some("model_not_found".to_string()),
800            }
801        }
802        (None, None) => {
803            // Try to find any available model
804            if let Some(model) = available_models.first() {
805                if should_print_messages {
806                    eprintln!(
807                        "Warning: Could not restore model {}/{} (model not found).",
808                        saved_provider, saved_model_id
809                    );
810                    eprintln!("Using first available model: {}/{}", model.provider, model.id);
811                }
812                RestoreModelResult {
813                    model: Some(model.clone()),
814                    fallback_message: Some(format!(
815                        "Could not restore model {}/{}. Using first available model.",
816                        saved_provider, saved_model_id
817                    )),
818                    reason: Some("model_not_found".to_string()),
819                }
820            } else {
821                RestoreModelResult {
822                    model: None,
823                    fallback_message: Some("No models available.".to_string()),
824                    reason: Some("no_models".to_string()),
825                }
826            }
827        }
828    }
829}
830
#[cfg(test)]
mod tests {
    //! Unit tests for model equality, glob/pattern matching, thinking-level mapping,
    //! auth checks, CLI resolution, initial model selection and session restore.

    use super::*;

    /// Four-model catalog spanning three providers, shared by most tests below.
    fn sample_models() -> Vec<Model> {
        vec![
            Model {
                provider: "anthropic".to_string(),
                id: "claude-sonnet-4-5".to_string(),
                name: Some("Claude Sonnet 4.5".to_string()),
                description: None,
                context_window: Some(200000),
                supported_features: vec!["tools".to_string(), "vision".to_string()],
                cost_input: Some(3.0),
                cost_output: Some(15.0),
                cost_cache_read: Some(0.3),
                cost_cache_write: Some(3.75),
                input_modalities: vec!["text".to_string(), "image".to_string()],
            },
            Model {
                provider: "anthropic".to_string(),
                id: "claude-opus-4-7".to_string(),
                name: Some("Claude Opus 4.7".to_string()),
                description: None,
                context_window: Some(200000),
                supported_features: vec!["tools".to_string(), "vision".to_string()],
                cost_input: Some(15.0),
                cost_output: Some(75.0),
                cost_cache_read: Some(0.5),
                cost_cache_write: Some(6.25),
                input_modalities: vec!["text".to_string(), "image".to_string()],
            },
            Model {
                provider: "openai".to_string(),
                id: "gpt-4o".to_string(),
                name: Some("GPT-4o".to_string()),
                description: None,
                context_window: Some(128000),
                supported_features: vec!["tools".to_string()],
                cost_input: Some(2.5),
                cost_output: Some(10.0),
                cost_cache_read: Some(1.25),
                cost_cache_write: Some(0.0),
                input_modalities: vec!["text".to_string(), "image".to_string()],
            },
            Model {
                provider: "google".to_string(),
                id: "gemini-2.5-pro".to_string(),
                name: Some("Gemini 2.5 Pro".to_string()),
                description: None,
                context_window: Some(1000000),
                supported_features: vec!["tools".to_string()],
                cost_input: Some(1.25),
                cost_output: Some(5.0),
                cost_cache_read: Some(0.0),
                cost_cache_write: Some(0.0),
                input_modalities: vec!["text".to_string(), "image".to_string()],
            },
        ]
    }

    /// Builds a model with only `provider` and `id` set; every optional field is empty.
    ///
    /// Tests that need extra metadata override individual fields with struct-update syntax
    /// (`Model { name: ..., ..bare_model(p, id) }`).
    fn bare_model(provider: &str, id: &str) -> Model {
        Model {
            provider: provider.to_string(),
            id: id.to_string(),
            name: None,
            description: None,
            context_window: None,
            supported_features: vec![],
            cost_input: None,
            cost_output: None,
            cost_cache_read: None,
            cost_cache_write: None,
            input_modalities: vec![],
        }
    }

    // =============================================================================
    // Model Equality Tests
    // =============================================================================

    #[test]
    fn test_models_are_equal_same() {
        // Models that differ only in display name must still compare equal.
        let model1 = Model {
            name: Some("Claude Sonnet 4.5".to_string()),
            context_window: Some(200000),
            ..bare_model("anthropic", "claude-sonnet-4-5")
        };
        let model2 = Model {
            name: Some("Claude Sonnet 4.5 (different name)".to_string()),
            context_window: Some(200000),
            ..bare_model("anthropic", "claude-sonnet-4-5")
        };

        assert!(models_are_equal(&model1, &model2));
    }

    #[test]
    fn test_models_are_equal_different_provider() {
        let model1 = bare_model("anthropic", "claude-sonnet-4-5");
        let model2 = bare_model("openai", "claude-sonnet-4-5");

        assert!(!models_are_equal(&model1, &model2));
    }

    #[test]
    fn test_models_are_equal_different_id() {
        let model1 = bare_model("anthropic", "claude-sonnet-4-5");
        let model2 = bare_model("anthropic", "claude-opus-4-7");

        assert!(!models_are_equal(&model1, &model2));
    }

    // =============================================================================
    // Glob Pattern Tests
    // =============================================================================

    #[test]
    fn test_match_glob_exact() {
        assert!(match_glob("claude-sonnet-4-5", "claude-sonnet-4-5"));
        assert!(!match_glob("claude-sonnet-4-5", "claude-opus-4-7"));
    }

    #[test]
    fn test_match_glob_asterisk() {
        assert!(match_glob("claude-*", "claude-sonnet-4-5"));
        assert!(match_glob("claude-*", "claude-opus-4-7"));
        assert!(!match_glob("claude-*", "gpt-4o"));
    }

    #[test]
    fn test_match_glob_question() {
        // `?` consumes exactly one character, never two.
        assert!(match_glob("claude-?-sonnet-4-5", "claude-3-sonnet-4-5"));
        assert!(!match_glob("claude-?-sonnet-4-5", "claude-35-sonnet-4-5"));
    }

    #[test]
    fn test_match_glob_char_class() {
        assert!(match_glob("claude-[a-z]-sonnet", "claude-a-sonnet"));
        assert!(match_glob("claude-[a-z]-sonnet", "claude-b-sonnet"));
        // Glob matching is case-insensitive as a whole, so a lowercase class also
        // accepts an uppercase letter.
        assert!(match_glob("claude-[a-z]-sonnet", "claude-A-sonnet"));
    }

    #[test]
    fn test_match_glob_case_insensitive() {
        assert!(match_glob("CLAUDE-*", "claude-sonnet-4-5"));
    }

    #[test]
    fn test_find_models_by_glob() {
        let models = sample_models();
        let results = find_models_by_glob("anthropic", "claude-*", &models);
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|m| m.provider == "anthropic"));
    }

    #[test]
    fn test_find_models_by_glob_single_match() {
        let models = sample_models();
        let results = find_models_by_glob("openai", "gpt-*", &models);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_find_models_by_glob_no_match() {
        let models = sample_models();
        let results = find_models_by_glob("openai", "nonexistent-*", &models);
        assert!(results.is_empty());
    }

    // =============================================================================
    // Thinking Level Mapping Tests
    // =============================================================================

    #[test]
    fn test_get_thinking_level_map_claude_35_sonnet() {
        let map = get_thinking_level_map("claude-3-5-sonnet-latest")
            .expect("claude-3-5-sonnet-latest should have a thinking level map");
        assert_eq!(map.get("high"), Some(&"claude-3-5-sonnet-20240620".to_string()));
    }

    #[test]
    fn test_get_thinking_level_map_claude_opus_4() {
        let map = get_thinking_level_map("claude-opus-4-5")
            .expect("claude-opus-4-5 should have a thinking level map");
        assert_eq!(map.get("high"), Some(&"claude-opus-4-5-20251101".to_string()));
        assert_eq!(map.get("medium"), Some(&"claude-opus-4-5".to_string()));
    }

    #[test]
    fn test_get_thinking_level_map_no_match() {
        assert!(get_thinking_level_map("gpt-4o").is_none());
    }

    #[test]
    fn test_clamp_thinking_level_supported() {
        let result = clamp_thinking_level("claude-3-5-sonnet-latest", "high");
        assert_eq!(result, "claude-3-5-sonnet-20240620");
    }

    #[test]
    fn test_clamp_thinking_level_clamp_down() {
        // "xhigh" has no entry in the map, so it should clamp down to "high".
        let result = clamp_thinking_level("claude-3-5-sonnet-latest", "xhigh");
        assert_eq!(result, "claude-3-5-sonnet-20240620");
    }

    #[test]
    fn test_clamp_thinking_level_no_mapping() {
        // gpt-4o has no mapping, so the requested level is returned unchanged.
        let result = clamp_thinking_level("gpt-4o", "high");
        assert_eq!(result, "high");
    }

    // =============================================================================
    // Auth Validation Tests
    // =============================================================================

    #[test]
    fn test_has_configured_auth_unknown_provider() {
        let model = bare_model("unknown", "test");

        // No credentials can exist for an unknown provider.
        assert!(!has_configured_auth("unknown", &model));
    }

    #[test]
    fn test_has_configured_auth_known_provider_no_env() {
        let model = bare_model("anthropic", "claude-sonnet-4-5");

        // The answer depends on the machine running the tests (e.g. it is true when
        // ANTHROPIC_API_KEY is set), so this only checks that the call completes.
        let _ = has_configured_auth("anthropic", &model);
    }

    // =============================================================================
    // Model Parsing Tests
    // =============================================================================

    #[test]
    fn test_parse_model_pattern_exact() {
        let models = sample_models();
        let result = parse_model_pattern("claude-sonnet-4-5", &models);

        assert_eq!(result.model_id, "claude-sonnet-4-5");
        assert_eq!(result.provider, Some("anthropic".to_string()));
        assert!(result.warning.is_none());
    }

    #[test]
    fn test_parse_model_pattern_with_provider() {
        let models = sample_models();
        let result = parse_model_pattern("anthropic/claude-sonnet-4-5", &models);

        assert_eq!(result.model_id, "claude-sonnet-4-5");
        assert_eq!(result.provider, Some("anthropic".to_string()));
    }

    #[test]
    fn test_parse_model_pattern_with_thinking_level() {
        let models = sample_models();
        let result = parse_model_pattern("sonnet:high", &models);

        assert_eq!(result.thinking_level, Some("high".to_string()));
    }

    #[test]
    fn test_parse_model_pattern_invalid_thinking_level() {
        let models = sample_models();
        let result = parse_model_pattern("sonnet:invalid", &models);

        assert!(result.thinking_level.is_none());
    }

    #[test]
    fn test_parse_model_pattern_partial_match() {
        let models = sample_models();
        let result = parse_model_pattern("sonnet", &models);

        // A bare fragment resolves to an id containing it; either a provider is
        // inferred or a warning is reported.
        assert!(result.model_id.contains("sonnet"));
        assert!(result.warning.is_some() || result.provider.is_some());
    }

    #[test]
    fn test_parse_model_pattern_not_found() {
        let models = sample_models();
        let result = parse_model_pattern("nonexistent-model", &models);

        assert_eq!(result.model_id, "nonexistent-model");
        assert!(result.warning.is_some());
    }

    // =============================================================================
    // CLI Resolution Tests
    // =============================================================================

    #[test]
    fn test_resolve_cli_model_with_provider() {
        let models = sample_models();
        let result =
            resolve_cli_model(Some("anthropic"), Some("claude-sonnet-4-5"), &models, None);

        assert!(result.error.is_none());
        let model = result.model.expect("explicit provider + id should resolve");
        assert_eq!(model.id, "claude-sonnet-4-5");
    }

    #[test]
    fn test_resolve_cli_model_with_slash() {
        let models = sample_models();
        let result = resolve_cli_model(None, Some("anthropic/claude-sonnet-4-5"), &models, None);

        assert!(result.error.is_none());
        assert!(result.model.is_some());
    }

    #[test]
    fn test_resolve_cli_model_not_found() {
        let models = sample_models();
        let result = resolve_cli_model(None, Some("nonexistent-model"), &models, None);

        assert!(result.error.is_some() || result.model.is_none());
    }

    #[test]
    fn test_resolve_cli_model_no_args() {
        let models = sample_models();
        let result = resolve_cli_model(None, None, &models, None);

        // Nothing requested: no model and no error.
        assert!(result.model.is_none());
        assert!(result.error.is_none());
    }

    // =============================================================================
    // Search Tests
    // =============================================================================

    #[test]
    fn test_find_models_by_pattern() {
        let models = sample_models();
        let results = find_models_by_pattern("sonnet", &models);

        assert!(!results.is_empty());
        assert!(results.iter().all(|m| {
            m.id.contains("sonnet")
                || m.name.as_ref().map(|n| n.contains("sonnet")).unwrap_or(false)
        }));
    }

    #[test]
    fn test_find_models_by_pattern_full_id() {
        let models = sample_models();
        let results = find_models_by_pattern("anthropic/claude-sonnet-4-5", &models);
        assert!(!results.is_empty());
    }

    // =============================================================================
    // Initial Model Selection Tests
    // =============================================================================

    #[test]
    fn test_find_initial_model_from_cli() {
        let models = sample_models();
        let result = find_initial_model(Some("openai"), Some("gpt-4o"), &[], false, None, &models);

        let model = result.model.expect("CLI-selected model should be returned");
        assert_eq!(model.id, "gpt-4o");
    }

    #[test]
    fn test_find_initial_model_fallback_to_available() {
        let models = sample_models();
        let result = find_initial_model(None, None, &[], false, None, &models);

        assert!(result.model.is_some());
        assert!(result.fallback_message.is_none());
    }

    #[test]
    fn test_find_initial_model_default_thinking_level() {
        let models = sample_models();
        let result = find_initial_model(Some("openai"), Some("gpt-4o"), &[], false, None, &models);

        assert_eq!(result.thinking_level, DEFAULT_THINKING_LEVEL);
    }

    // =============================================================================
    // Restore Model Tests
    // =============================================================================

    #[test]
    fn test_restore_model_from_session_success() {
        let models = sample_models();
        let result =
            restore_model_from_session("anthropic", "claude-sonnet-4-5", None, false, &models);

        // The saved model exists, so it should be found.
        assert!(result.model.is_some());
        // Without API keys in the test environment a fallback message may still be set;
        // when it is, the only acceptable reason is missing auth.
        if result.fallback_message.is_some() {
            assert_eq!(result.reason, Some("no_auth".to_string()));
        }
    }

    #[test]
    fn test_restore_model_from_session_not_found() {
        let models = sample_models();
        let current = &models[0];
        let result =
            restore_model_from_session("nonexistent", "model", Some(current), false, &models);

        // An unknown saved model keeps the current model and explains why.
        assert!(result.fallback_message.is_some());
        assert_eq!(result.reason, Some("model_not_found".to_string()));
        let restored = result.model.expect("should fall back to the current model");
        assert_eq!(restored.id, current.id);
    }

    // =============================================================================
    // Alias Detection Tests
    // =============================================================================

    #[test]
    fn test_is_alias() {
        assert!(is_alias("claude-sonnet-4-latest"));
        assert!(is_alias("simple-model"));
        // Ids ending in a date stamp are pinned snapshots, not aliases.
        assert!(!is_alias("claude-sonnet-4-20250929"));
        assert!(!is_alias("claude-sonnet-4-20250514"));
    }

    #[test]
    fn test_default_thinking_level_constant() {
        assert_eq!(DEFAULT_THINKING_LEVEL, "medium");
    }

    #[test]
    fn test_thinking_levels_constant() {
        assert_eq!(THINKING_LEVELS, &["off", "minimal", "low", "medium", "high", "xhigh"]);
    }
}