// oxi/model_resolver.rs

//! Model name resolution and matching
//!
//! Provides utilities for parsing model patterns, resolving model names,
//! and finding the best model for startup.
5
6use crate::settings::Settings;
7use std::collections::HashMap;
8
/// Known AI providers
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Provider {
    /// Identifier of the provider.
    pub id: String,
    /// Name of the provider.
    pub name: String,
    /// Website of the provider; `None` unless set through [`Provider::with_website`].
    pub website: Option<String>,
}

impl Provider {
    /// Creates a provider from an id and a name, leaving `website` unset.
    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
        let (id, name) = (id.into(), name.into());
        Provider {
            id,
            name,
            website: None,
        }
    }

    /// Consumes the provider and returns it with `website` set to the given value.
    pub fn with_website(self, website: impl Into<String>) -> Self {
        Provider {
            website: Some(website.into()),
            ..self
        }
    }
}
31
/// A discovered model
#[derive(Debug, Clone)]
pub struct Model {
    /// Provider the model belongs to; the first component of [`Model::full_id`].
    pub provider: String,
    /// Model identifier; the second component of [`Model::full_id`].
    pub id: String,
    /// Optional display name.
    pub name: Option<String>,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Optional context window size (presumably in tokens; confirm against the producer).
    pub context_window: Option<u32>,
    /// Feature labels advertised for the model.
    pub supported_features: Vec<String>,
}

impl Model {
    /// Returns the fully qualified identifier in the form `provider/id`.
    pub fn full_id(&self) -> String {
        [self.provider.as_str(), self.id.as_str()].join("/")
    }
}
49
/// Result of parsing a model pattern
#[derive(Debug)]
pub struct ParsedModelResult {
    /// Provider of the matched model, or `None` when no provider could be determined.
    pub provider: Option<String>,
    /// Identifier of the matched model, or the caller-supplied text when nothing matched.
    pub model_id: String,
    /// Thinking level taken from a recognised `:<level>` suffix, if the pattern had one.
    pub thinking_level: Option<String>,
    /// Human-readable note about an ambiguous, empty, or unmatched pattern.
    pub warning: Option<String>,
}
58
59/// Result of resolving a CLI model
60#[derive(Debug)]
61pub struct ResolveCliModelResult {
62    pub model: Option<Model>,
63    pub thinking_level: Option<String>,
64    pub warning: Option<String>,
65    pub error: Option<String>,
66}
67
68/// Result of finding initial model
69#[derive(Debug)]
70pub struct InitialModelResult {
71    pub model: Option<Model>,
72    pub thinking_level: String,
73    pub fallback_message: Option<String>,
74}
75
/// Default models per provider
///
/// Returns a freshly built map from provider id to that provider's default model id.
pub fn default_model_per_provider() -> HashMap<String, String> {
    // (provider id, default model id)
    const DEFAULTS: &[(&str, &str)] = &[
        ("anthropic", "claude-sonnet-4-5"),
        ("openai", "gpt-4o"),
        ("google", "gemini-2.5-pro"),
        ("deepseek", "deepseek-v3"),
        ("openrouter", "anthropic/claude-sonnet-4"),
        ("groq", "mixtral-8x7b"),
        ("cerebras", "llama-3.3-70b"),
        ("mistral", "mistral-large"),
        ("xai", "grok-2"),
        ("amazon-bedrock", "anthropic.claude-v2"),
        ("azure-openai", "gpt-4o"),
    ];

    DEFAULTS
        .iter()
        .map(|&(provider, model_id)| (provider.to_string(), model_id.to_string()))
        .collect()
}
92
/// Check if a model ID looks like an alias (no date suffix)
///
/// An id is treated as an alias when it ends with `-latest`, or when it does
/// not end with a `-YYYYMMDD` date suffix (a `-` followed by exactly eight
/// ASCII digits at the very end of the id).
///
/// The check is done with plain string operations so that no regular
/// expression has to be compiled on every call; this function is invoked once
/// per candidate while filtering partial matches.
fn is_alias(id: &str) -> bool {
    if id.ends_with("-latest") {
        return true;
    }

    // Digits never contain '-', so a date suffix, if present, is exactly the
    // text after the last '-'.
    let has_date_suffix = match id.rfind('-') {
        Some(dash) => {
            let tail = &id[dash + 1..];
            tail.len() == 8 && tail.bytes().all(|b| b.is_ascii_digit())
        }
        None => false,
    };

    !has_date_suffix
}
106
107/// Parse a model pattern into components
108///
109/// # Arguments
110/// * `pattern` - The model pattern (e.g., "anthropic/claude-3.5-sonnet" or "sonnet:high")
111/// * `available_models` - List of available models for validation
112///
113/// # Returns
114/// A parsed model result with provider, model_id, and optional thinking level
115pub fn parse_model_pattern(
116    pattern: &str,
117    available_models: &[Model],
118) -> ParsedModelResult {
119    let pattern = pattern.trim();
120    if pattern.is_empty() {
121        return ParsedModelResult {
122            provider: None,
123            model_id: String::new(),
124            thinking_level: None,
125            warning: Some("Empty model pattern".to_string()),
126        };
127    }
128
129    // Check for thinking level suffix (e.g., "sonnet:high")
130    let thinking_levels = ["off", "minimal", "low", "medium", "high", "xhigh"];
131    let last_colon = pattern.rfind(':');
132    let (base_pattern, thinking_level) = if let Some(idx) = last_colon {
133        let suffix = &pattern[idx + 1..];
134        if thinking_levels.contains(&suffix) {
135            (&pattern[..idx], Some(suffix.to_string()))
136        } else {
137            (pattern, None)
138        }
139    } else {
140        (pattern, None)
141    };
142
143    // Try to find an exact match first
144    let exact_match = available_models.iter().find(|m| {
145        m.id.eq_ignore_ascii_case(base_pattern)
146            || m.full_id().eq_ignore_ascii_case(base_pattern)
147    });
148
149    if let Some(model) = exact_match {
150        return ParsedModelResult {
151            provider: Some(model.provider.clone()),
152            model_id: model.id.clone(),
153            thinking_level,
154            warning: None,
155        };
156    }
157
158    // Try to parse provider/model format
159    if let Some(slash_idx) = base_pattern.find('/') {
160        let provider = &base_pattern[..slash_idx];
161        let model_id = &base_pattern[slash_idx + 1..];
162
163        // Check if provider exists in available models
164        let provider_exists = available_models.iter().any(|m| {
165            m.provider.eq_ignore_ascii_case(provider)
166        });
167
168        if provider_exists {
169            return ParsedModelResult {
170                provider: Some(provider.to_string()),
171                model_id: model_id.to_string(),
172                thinking_level,
173                warning: None,
174            };
175        }
176    }
177
178    // Try partial matching
179    let partial_matches: Vec<&Model> = available_models
180        .iter()
181        .filter(|m| {
182            m.id.to_lowercase().contains(&base_pattern.to_lowercase())
183                || m.name
184                    .as_ref()
185                    .map(|n| n.to_lowercase().contains(&base_pattern.to_lowercase()))
186                    .unwrap_or(false)
187        })
188        .collect();
189
190    if partial_matches.len() == 1 {
191        let model = partial_matches[0];
192        return ParsedModelResult {
193            provider: Some(model.provider.clone()),
194            model_id: model.id.clone(),
195            thinking_level,
196            warning: None,
197        };
198    } else if partial_matches.len() > 1 {
199        // Prefer aliases over dated versions
200        let aliases: Vec<_> = partial_matches.iter().filter(|m| is_alias(&m.id)).collect();
201        if !aliases.is_empty() {
202            let model = aliases[0];
203            return ParsedModelResult {
204                provider: Some(model.provider.clone()),
205                model_id: model.id.clone(),
206                thinking_level,
207                warning: Some(format!(
208                    "Multiple models match '{}', selected '{}'",
209                    base_pattern,
210                    model.full_id()
211                )),
212            };
213        }
214        // Use the latest dated version (sort descending)
215        let mut sorted = partial_matches.to_vec();
216        sorted.sort_by(|a, b| b.id.cmp(&a.id));
217        let model = sorted[0];
218        return ParsedModelResult {
219            provider: Some(model.provider.clone()),
220            model_id: model.id.clone(),
221            thinking_level,
222            warning: Some(format!(
223                "Multiple models match '{}', selected '{}'",
224                base_pattern,
225                model.full_id()
226            )),
227        };
228    }
229
230    // No match found - return as raw pattern
231    ParsedModelResult {
232        provider: None,
233        model_id: pattern.to_string(),
234        thinking_level,
235        warning: Some(format!(
236            "Model '{}' not found in available models. Treating as custom model ID.",
237            pattern
238        )),
239    }
240}
241
242/// Find all models matching a glob pattern
243pub fn find_models_by_pattern(pattern: &str, models: &[Model]) -> Vec<Model> {
244    let pattern_lower = pattern.to_lowercase();
245    models
246        .iter()
247        .filter(|m| {
248            m.id.to_lowercase().contains(&pattern_lower)
249                || m.full_id().to_lowercase().contains(&pattern_lower)
250                || m.name
251                    .as_ref()
252                    .map(|n| n.to_lowercase().contains(&pattern_lower))
253                    .unwrap_or(false)
254        })
255        .cloned()
256        .collect()
257}
258
259/// Resolve a model from CLI arguments
260pub fn resolve_cli_model(
261    cli_provider: Option<&str>,
262    cli_model: Option<&str>,
263    available_models: &[Model],
264    _settings: Option<&Settings>,
265) -> ResolveCliModelResult {
266    let cli_model = match cli_model {
267        Some(m) => m,
268        None => {
269            return ResolveCliModelResult {
270                model: None,
271                thinking_level: None,
272                warning: None,
273                error: None,
274            };
275        }
276    };
277
278    // Build provider map for case-insensitive lookup
279    let mut provider_map: HashMap<String, String> = HashMap::new();
280    for model in available_models {
281        provider_map.insert(model.provider.to_lowercase(), model.provider.clone());
282    }
283
284    // Try to resolve provider
285    let provider = if let Some(p) = cli_provider {
286        provider_map.get(&p.to_lowercase()).cloned()
287    } else if let Some(slash_idx) = cli_model.find('/') {
288        let maybe_provider = &cli_model[..slash_idx];
289        provider_map.get(&maybe_provider.to_lowercase()).cloned()
290    } else {
291        None
292    };
293
294    // Extract the model pattern
295    let model_pattern = if let Some(ref p) = provider {
296        if cli_model.to_lowercase().starts_with(&format!("{}/", p.to_lowercase())) {
297            &cli_model[p.len() + 1..]
298        } else {
299            cli_model
300        }
301    } else {
302        cli_model
303    };
304
305    // Parse the pattern
306    let parsed = parse_model_pattern(model_pattern, available_models);
307
308    // Find the model
309    let model = if let Some(ref p) = provider {
310        available_models
311            .iter()
312            .find(|m| {
313                m.provider.eq_ignore_ascii_case(p) && m.id.eq_ignore_ascii_case(&parsed.model_id)
314            })
315            .cloned()
316    } else if let Some(ref p) = parsed.provider {
317        available_models
318            .iter()
319            .find(|m| {
320                m.provider.eq_ignore_ascii_case(p) && m.id.eq_ignore_ascii_case(&parsed.model_id)
321            })
322            .cloned()
323    } else {
324        // Try matching without provider
325        available_models
326            .iter()
327            .find(|m| m.id.eq_ignore_ascii_case(&parsed.model_id))
328            .cloned()
329    };
330
331    if let Some(ref m) = model {
332        ResolveCliModelResult {
333            model: Some(m.clone()),
334            thinking_level: parsed.thinking_level,
335            warning: parsed.warning,
336            error: None,
337        }
338    } else {
339        // Try building a fallback custom model
340        let fallback_model = if let Some(ref p) = provider {
341            Some(Model {
342                provider: p.clone(),
343                id: parsed.model_id.clone(),
344                name: Some(parsed.model_id.clone()),
345                description: None,
346                context_window: None,
347                supported_features: vec![],
348            })
349        } else {
350            None
351        };
352
353        ResolveCliModelResult {
354            model: fallback_model.clone(),
355            thinking_level: parsed.thinking_level,
356            warning: parsed.warning,
357            error: fallback_model.is_none().then(|| {
358                format!(
359                    "Model '{}' not found. Use --list-models to see available models.",
360                    cli_model
361                )
362            }),
363        }
364    }
365}
366
367/// Find the initial model to use based on priority:
368/// 1. CLI args
369/// 2. First model from scoped models
370/// 3. Saved default from settings
371/// 4. First available model
372pub fn find_initial_model(
373    cli_provider: Option<&str>,
374    cli_model: Option<&str>,
375    scoped_models: &[Model],
376    is_continuing: bool,
377    settings: Option<&Settings>,
378    available_models: &[Model],
379) -> InitialModelResult {
380    // 1. CLI args take priority
381    if cli_provider.is_some() || cli_model.is_some() {
382        let result = resolve_cli_model(cli_provider, cli_model, available_models, settings);
383        if result.error.is_none() {
384            return InitialModelResult {
385                model: result.model,
386                thinking_level: result.thinking_level.unwrap_or_else(|| "medium".to_string()),
387                fallback_message: None,
388            };
389        }
390    }
391
392    // 2. Use first model from scoped models (skip if continuing)
393    if !scoped_models.is_empty() && !is_continuing {
394        return InitialModelResult {
395            model: Some(scoped_models[0].clone()),
396            thinking_level: "medium".to_string(),
397            fallback_message: None,
398        };
399    }
400
401    // 3. Try saved default from settings
402    if let Some(ref s) = settings {
403        if let Some(default_model) = &s.default_model {
404            let parsed = parse_model_pattern(default_model, available_models);
405            if let Some(ref p) = parsed.provider {
406                let model = available_models
407                    .iter()
408                    .find(|m| m.provider.eq_ignore_ascii_case(p) && m.id.eq_ignore_ascii_case(&parsed.model_id))
409                    .cloned();
410                if model.is_some() {
411                    return InitialModelResult {
412                        model,
413                        thinking_level: format!("{:?}", s.thinking_level),
414                        fallback_message: None,
415                    };
416                }
417            }
418        }
419    }
420
421    // 4. Try default models from known providers
422    let defaults = default_model_per_provider();
423    for (provider, default_id) in &defaults {
424        if let Some(model) = available_models
425            .iter()
426            .find(|m| m.provider.eq_ignore_ascii_case(provider) && m.id.eq_ignore_ascii_case(default_id))
427        {
428            return InitialModelResult {
429                model: Some(model.clone()),
430                thinking_level: "medium".to_string(),
431                fallback_message: None,
432            };
433        }
434    }
435
436    // 5. Use first available model
437    if let Some(model) = available_models.first() {
438        return InitialModelResult {
439            model: Some(model.clone()),
440            thinking_level: "medium".to_string(),
441            fallback_message: None,
442        };
443    }
444
445    // No model found
446    InitialModelResult {
447        model: None,
448        thinking_level: "medium".to_string(),
449        fallback_message: Some("No models available. Check your installation.".to_string()),
450    }
451}
452
453/// Restore model from session with fallback
454pub fn restore_model_from_session(
455    saved_provider: &str,
456    saved_model_id: &str,
457    current_model: Option<&Model>,
458    should_print_messages: bool,
459    available_models: &[Model],
460) -> (Option<Model>, Option<String>) {
461    let restored = available_models
462        .iter()
463        .find(|m| {
464            m.provider.eq_ignore_ascii_case(saved_provider) && m.id.eq_ignore_ascii_case(saved_model_id)
465        })
466        .cloned();
467
468    match (&restored, current_model) {
469        (Some(ref model), _) => {
470            if should_print_messages {
471                eprintln!("Restored model: {}/{}", saved_provider, saved_model_id);
472            }
473            (Some(model.clone()), None)
474        }
475        (None, Some(current)) => {
476            if should_print_messages {
477                eprintln!(
478                    "Warning: Could not restore model {}/{} (model not found). Falling back to current model.",
479                    saved_provider, saved_model_id
480                );
481                eprintln!("Falling back to: {}/{}", current.provider, current.id);
482            }
483            (
484                Some(current.clone()),
485                Some(format!(
486                    "Could not restore model {}/{} (model not found). Using current model.",
487                    saved_provider, saved_model_id
488                )),
489            )
490        }
491        (None, None) => {
492            // Try to find any available model
493            if let Some(model) = available_models.first() {
494                if should_print_messages {
495                    eprintln!(
496                        "Warning: Could not restore model {}/{} (model not found).",
497                        saved_provider, saved_model_id
498                    );
499                    eprintln!("Using first available model: {}/{}", model.provider, model.id);
500                }
501                (
502                    Some(model.clone()),
503                    Some(format!(
504                        "Could not restore model {}/{}. Using first available model.",
505                        saved_provider, saved_model_id
506                    )),
507                )
508            } else {
509                (None, Some("No models available.".to_string()))
510            }
511        }
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one catalogue entry; description is left unset.
    fn model(provider: &str, id: &str, name: &str, context_window: u32, features: &[&str]) -> Model {
        Model {
            provider: provider.to_string(),
            id: id.to_string(),
            name: Some(name.to_string()),
            description: None,
            context_window: Some(context_window),
            supported_features: features.iter().map(|f| f.to_string()).collect(),
        }
    }

    /// Two Anthropic models plus one each from OpenAI and Google.
    fn sample_models() -> Vec<Model> {
        vec![
            model("anthropic", "claude-sonnet-4-5", "Claude Sonnet 4.5", 200000, &["tools", "vision"]),
            model("anthropic", "claude-opus-4-7", "Claude Opus 4.7", 200000, &["tools", "vision"]),
            model("openai", "gpt-4o", "GPT-4o", 128000, &["tools"]),
            model("google", "gemini-2.5-pro", "Gemini 2.5 Pro", 1000000, &["tools"]),
        ]
    }

    #[test]
    fn test_parse_model_pattern_exact() {
        let models = sample_models();
        let result = parse_model_pattern("claude-sonnet-4-5", &models);

        assert_eq!(result.model_id, "claude-sonnet-4-5");
        assert_eq!(result.provider, Some("anthropic".to_string()));
        assert!(result.thinking_level.is_none());
        assert!(result.warning.is_none());
    }

    #[test]
    fn test_parse_model_pattern_with_provider() {
        let models = sample_models();
        let result = parse_model_pattern("anthropic/claude-sonnet-4-5", &models);

        assert_eq!(result.model_id, "claude-sonnet-4-5");
        assert_eq!(result.provider, Some("anthropic".to_string()));
        assert!(result.warning.is_none());
    }

    #[test]
    fn test_parse_model_pattern_with_thinking_level() {
        let models = sample_models();
        let result = parse_model_pattern("sonnet:high", &models);

        assert_eq!(result.thinking_level, Some("high".to_string()));
        // The suffix must not prevent the model itself from being resolved.
        assert_eq!(result.model_id, "claude-sonnet-4-5");
        assert_eq!(result.provider, Some("anthropic".to_string()));
    }

    #[test]
    fn test_parse_model_pattern_partial_match() {
        let models = sample_models();
        let result = parse_model_pattern("sonnet", &models);

        // Exactly one sample model mentions "sonnet", so the match is
        // unambiguous and carries no warning.
        assert_eq!(result.model_id, "claude-sonnet-4-5");
        assert_eq!(result.provider, Some("anthropic".to_string()));
        assert!(result.warning.is_none());
    }

    #[test]
    fn test_parse_model_pattern_ambiguous_match() {
        let models = sample_models();
        let result = parse_model_pattern("claude", &models);

        // Both Anthropic models match and both are aliases; the first wins
        // and the ambiguity is reported.
        assert_eq!(result.model_id, "claude-sonnet-4-5");
        assert_eq!(result.provider, Some("anthropic".to_string()));
        assert!(result.warning.is_some());
    }

    #[test]
    fn test_parse_model_pattern_not_found() {
        let models = sample_models();
        let result = parse_model_pattern("nonexistent-model", &models);

        assert_eq!(result.model_id, "nonexistent-model");
        assert!(result.provider.is_none());
        assert!(result.warning.is_some());
    }

    #[test]
    fn test_resolve_cli_model_with_provider() {
        let models = sample_models();
        let result = resolve_cli_model(Some("anthropic"), Some("claude-sonnet-4-5"), &models, None);

        assert!(result.error.is_none());
        let resolved = result.model.expect("model should resolve");
        assert_eq!(resolved.provider, "anthropic");
        assert_eq!(resolved.id, "claude-sonnet-4-5");
    }

    #[test]
    fn test_resolve_cli_model_with_slash() {
        let models = sample_models();
        let result = resolve_cli_model(None, Some("anthropic/claude-sonnet-4-5"), &models, None);

        assert!(result.error.is_none());
        let resolved = result.model.expect("model should resolve");
        assert_eq!(resolved.provider, "anthropic");
        assert_eq!(resolved.id, "claude-sonnet-4-5");
    }

    #[test]
    fn test_resolve_cli_model_not_found() {
        let models = sample_models();
        let result = resolve_cli_model(None, Some("nonexistent-model"), &models, None);

        // Without a provider there is no custom-model fallback: both hold.
        assert!(result.model.is_none());
        assert!(result.error.is_some());
    }

    #[test]
    fn test_resolve_cli_model_without_model_argument() {
        let models = sample_models();
        let result = resolve_cli_model(None, None, &models, None);

        assert!(result.model.is_none());
        assert!(result.error.is_none());
    }

    #[test]
    fn test_find_models_by_pattern() {
        let models = sample_models();
        let results = find_models_by_pattern("sonnet", &models);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "claude-sonnet-4-5");

        // Matching ignores case.
        let upper = find_models_by_pattern("SONNET", &models);
        assert_eq!(upper.len(), 1);
        assert_eq!(upper[0].id, "claude-sonnet-4-5");

        assert!(find_models_by_pattern("nonexistent", &models).is_empty());
    }

    #[test]
    fn test_find_initial_model_from_cli() {
        let models = sample_models();
        let result = find_initial_model(
            Some("openai"),
            Some("gpt-4o"),
            &[],
            false,
            None,
            &models,
        );

        let chosen = result.model.expect("model should resolve");
        assert_eq!(chosen.provider, "openai");
        assert_eq!(chosen.id, "gpt-4o");
        assert_eq!(result.thinking_level, "medium");
    }

    #[test]
    fn test_find_initial_model_from_scoped_models() {
        let models = sample_models();
        let result = find_initial_model(None, None, &models[1..2], false, None, &models);

        assert_eq!(result.model.expect("scoped model").id, "claude-opus-4-7");
        assert_eq!(result.thinking_level, "medium");
    }

    #[test]
    fn test_find_initial_model_fallback_to_available() {
        let models = sample_models();
        let result = find_initial_model(None, None, &[], false, None, &models);

        // Several sample models are provider defaults, so one of those is
        // selected before the "first available" step is ever reached.
        let chosen = result.model.expect("a provider default should be selected");
        let defaults = default_model_per_provider();
        assert_eq!(defaults.get(&chosen.provider), Some(&chosen.id));
        assert!(result.fallback_message.is_none());
    }

    #[test]
    fn test_find_initial_model_no_models() {
        let result = find_initial_model(None, None, &[], false, None, &[]);

        assert!(result.model.is_none());
        assert!(result.fallback_message.is_some());
    }

    #[test]
    fn test_restore_model_from_session_success() {
        let models = sample_models();
        let (model, message) = restore_model_from_session(
            "anthropic",
            "claude-sonnet-4-5",
            None,
            false,
            &models,
        );

        assert_eq!(model.expect("restored model").id, "claude-sonnet-4-5");
        assert!(message.is_none());
    }

    #[test]
    fn test_restore_model_from_session_fallback() {
        let models = sample_models();
        let current = &models[0];
        let (model, message) = restore_model_from_session(
            "nonexistent",
            "model",
            Some(current),
            false,
            &models,
        );

        // The current model is kept and the caller is told why.
        assert_eq!(model.expect("fallback model").id, current.id);
        assert!(message.is_some());
    }

    #[test]
    fn test_is_alias() {
        assert!(is_alias("claude-sonnet-4-latest"));
        assert!(!is_alias("claude-sonnet-4-20250929"));
        assert!(is_alias("simple-model"));
    }
}