1use crate::settings::Settings;
7use std::collections::HashMap;
8
/// A model provider, identified by a short id and shown to users by `name`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Provider {
    /// Short identifier of the provider.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Optional website URL; `None` unless set via [`Provider::with_website`].
    pub website: Option<String>,
}

impl Provider {
    /// Creates a provider with the given id and display name and no website.
    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
        let id = id.into();
        let name = name.into();
        Provider {
            id,
            name,
            website: None,
        }
    }

    /// Consumes the provider and returns it with `website` set, replacing any
    /// previously stored value.
    pub fn with_website(self, website: impl Into<String>) -> Self {
        Provider {
            website: Some(website.into()),
            ..self
        }
    }
}
31
/// A concrete model offered by a provider.
#[derive(Clone, Debug)]
pub struct Model {
    /// Identifier of the provider serving this model, e.g. `"anthropic"`.
    pub provider: String,
    /// Provider-local model identifier, e.g. `"claude-sonnet-4-5"`.
    pub id: String,
    /// Optional human-readable name, e.g. `"Claude Sonnet 4.5"`.
    pub name: Option<String>,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Context window size if known (presumably in tokens — confirm with the model source).
    pub context_window: Option<u32>,
    /// Feature tags such as `"tools"` or `"vision"`.
    pub supported_features: Vec<String>,
}

impl Model {
    /// Returns the fully qualified identifier `"{provider}/{id}"`.
    pub fn full_id(&self) -> String {
        let mut qualified = String::with_capacity(self.provider.len() + 1 + self.id.len());
        qualified.push_str(&self.provider);
        qualified.push('/');
        qualified.push_str(&self.id);
        qualified
    }
}
49
/// Outcome of [`parse_model_pattern`]: the provider/model a pattern refers to,
/// plus any thinking-level suffix that was stripped from it.
///
/// Derives `Clone`/`PartialEq`/`Eq` so results can be copied and compared
/// directly (every field is a plain owned string or option of one).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedModelResult {
    /// Provider of the matched model, or `None` when the pattern was not
    /// resolved to a known provider.
    pub provider: Option<String>,
    /// Model id the pattern resolved to (may be a custom, unknown id).
    pub model_id: String,
    /// Thinking level taken from a trailing `:level` suffix, if present.
    pub thinking_level: Option<String>,
    /// Human-readable note about ambiguous or unresolved patterns.
    pub warning: Option<String>,
}
58
59#[derive(Debug)]
61pub struct ResolveCliModelResult {
62 pub model: Option<Model>,
63 pub thinking_level: Option<String>,
64 pub warning: Option<String>,
65 pub error: Option<String>,
66}
67
68#[derive(Debug)]
70pub struct InitialModelResult {
71 pub model: Option<Model>,
72 pub thinking_level: String,
73 pub fallback_message: Option<String>,
74}
75
/// Returns the preferred default model id for each known provider, keyed by
/// lowercase provider id.
///
/// A fresh `HashMap` is built on every call; its iteration order is arbitrary.
pub fn default_model_per_provider() -> HashMap<String, String> {
    const DEFAULTS: [(&str, &str); 11] = [
        ("anthropic", "claude-sonnet-4-5"),
        ("openai", "gpt-4o"),
        ("google", "gemini-2.5-pro"),
        ("deepseek", "deepseek-v3"),
        ("openrouter", "anthropic/claude-sonnet-4"),
        ("groq", "mixtral-8x7b"),
        ("cerebras", "llama-3.3-70b"),
        ("mistral", "mistral-large"),
        ("xai", "grok-2"),
        ("amazon-bedrock", "anthropic.claude-v2"),
        ("azure-openai", "gpt-4o"),
    ];

    DEFAULTS
        .iter()
        .map(|&(provider, model_id)| (provider.to_string(), model_id.to_string()))
        .collect()
}
92
/// Returns `true` when `id` looks like a moving alias rather than a pinned,
/// dated release.
///
/// An id is an alias if it ends in `-latest`, or if it does NOT end in a
/// date stamp: a `-` followed by exactly eight ASCII digits
/// (e.g. `claude-sonnet-4-20250929` is dated, `claude-sonnet-4` is an alias).
///
/// Implemented with plain string operations instead of compiling a regex on
/// every call (this runs once per candidate while filtering matches). Only
/// ASCII digits count as a date, unlike regex `\d`, which also matches
/// non-ASCII Unicode digits.
fn is_alias(id: &str) -> bool {
    if id.ends_with("-latest") {
        return true;
    }
    // The date stamp, if any, must be the text after the last '-'.
    let is_dated = match id.rsplit_once('-') {
        Some((_, suffix)) => suffix.len() == 8 && suffix.bytes().all(|b| b.is_ascii_digit()),
        None => false,
    };
    !is_dated
}
106
107pub fn parse_model_pattern(
116 pattern: &str,
117 available_models: &[Model],
118) -> ParsedModelResult {
119 let pattern = pattern.trim();
120 if pattern.is_empty() {
121 return ParsedModelResult {
122 provider: None,
123 model_id: String::new(),
124 thinking_level: None,
125 warning: Some("Empty model pattern".to_string()),
126 };
127 }
128
129 let thinking_levels = ["off", "minimal", "low", "medium", "high", "xhigh"];
131 let last_colon = pattern.rfind(':');
132 let (base_pattern, thinking_level) = if let Some(idx) = last_colon {
133 let suffix = &pattern[idx + 1..];
134 if thinking_levels.contains(&suffix) {
135 (&pattern[..idx], Some(suffix.to_string()))
136 } else {
137 (pattern, None)
138 }
139 } else {
140 (pattern, None)
141 };
142
143 let exact_match = available_models.iter().find(|m| {
145 m.id.eq_ignore_ascii_case(base_pattern)
146 || m.full_id().eq_ignore_ascii_case(base_pattern)
147 });
148
149 if let Some(model) = exact_match {
150 return ParsedModelResult {
151 provider: Some(model.provider.clone()),
152 model_id: model.id.clone(),
153 thinking_level,
154 warning: None,
155 };
156 }
157
158 if let Some(slash_idx) = base_pattern.find('/') {
160 let provider = &base_pattern[..slash_idx];
161 let model_id = &base_pattern[slash_idx + 1..];
162
163 let provider_exists = available_models.iter().any(|m| {
165 m.provider.eq_ignore_ascii_case(provider)
166 });
167
168 if provider_exists {
169 return ParsedModelResult {
170 provider: Some(provider.to_string()),
171 model_id: model_id.to_string(),
172 thinking_level,
173 warning: None,
174 };
175 }
176 }
177
178 let partial_matches: Vec<&Model> = available_models
180 .iter()
181 .filter(|m| {
182 m.id.to_lowercase().contains(&base_pattern.to_lowercase())
183 || m.name
184 .as_ref()
185 .map(|n| n.to_lowercase().contains(&base_pattern.to_lowercase()))
186 .unwrap_or(false)
187 })
188 .collect();
189
190 if partial_matches.len() == 1 {
191 let model = partial_matches[0];
192 return ParsedModelResult {
193 provider: Some(model.provider.clone()),
194 model_id: model.id.clone(),
195 thinking_level,
196 warning: None,
197 };
198 } else if partial_matches.len() > 1 {
199 let aliases: Vec<_> = partial_matches.iter().filter(|m| is_alias(&m.id)).collect();
201 if !aliases.is_empty() {
202 let model = aliases[0];
203 return ParsedModelResult {
204 provider: Some(model.provider.clone()),
205 model_id: model.id.clone(),
206 thinking_level,
207 warning: Some(format!(
208 "Multiple models match '{}', selected '{}'",
209 base_pattern,
210 model.full_id()
211 )),
212 };
213 }
214 let mut sorted = partial_matches.to_vec();
216 sorted.sort_by(|a, b| b.id.cmp(&a.id));
217 let model = sorted[0];
218 return ParsedModelResult {
219 provider: Some(model.provider.clone()),
220 model_id: model.id.clone(),
221 thinking_level,
222 warning: Some(format!(
223 "Multiple models match '{}', selected '{}'",
224 base_pattern,
225 model.full_id()
226 )),
227 };
228 }
229
230 ParsedModelResult {
232 provider: None,
233 model_id: pattern.to_string(),
234 thinking_level,
235 warning: Some(format!(
236 "Model '{}' not found in available models. Treating as custom model ID.",
237 pattern
238 )),
239 }
240}
241
242pub fn find_models_by_pattern(pattern: &str, models: &[Model]) -> Vec<Model> {
244 let pattern_lower = pattern.to_lowercase();
245 models
246 .iter()
247 .filter(|m| {
248 m.id.to_lowercase().contains(&pattern_lower)
249 || m.full_id().to_lowercase().contains(&pattern_lower)
250 || m.name
251 .as_ref()
252 .map(|n| n.to_lowercase().contains(&pattern_lower))
253 .unwrap_or(false)
254 })
255 .cloned()
256 .collect()
257}
258
259pub fn resolve_cli_model(
261 cli_provider: Option<&str>,
262 cli_model: Option<&str>,
263 available_models: &[Model],
264 _settings: Option<&Settings>,
265) -> ResolveCliModelResult {
266 let cli_model = match cli_model {
267 Some(m) => m,
268 None => {
269 return ResolveCliModelResult {
270 model: None,
271 thinking_level: None,
272 warning: None,
273 error: None,
274 };
275 }
276 };
277
278 let mut provider_map: HashMap<String, String> = HashMap::new();
280 for model in available_models {
281 provider_map.insert(model.provider.to_lowercase(), model.provider.clone());
282 }
283
284 let provider = if let Some(p) = cli_provider {
286 provider_map.get(&p.to_lowercase()).cloned()
287 } else if let Some(slash_idx) = cli_model.find('/') {
288 let maybe_provider = &cli_model[..slash_idx];
289 provider_map.get(&maybe_provider.to_lowercase()).cloned()
290 } else {
291 None
292 };
293
294 let model_pattern = if let Some(ref p) = provider {
296 if cli_model.to_lowercase().starts_with(&format!("{}/", p.to_lowercase())) {
297 &cli_model[p.len() + 1..]
298 } else {
299 cli_model
300 }
301 } else {
302 cli_model
303 };
304
305 let parsed = parse_model_pattern(model_pattern, available_models);
307
308 let model = if let Some(ref p) = provider {
310 available_models
311 .iter()
312 .find(|m| {
313 m.provider.eq_ignore_ascii_case(p) && m.id.eq_ignore_ascii_case(&parsed.model_id)
314 })
315 .cloned()
316 } else if let Some(ref p) = parsed.provider {
317 available_models
318 .iter()
319 .find(|m| {
320 m.provider.eq_ignore_ascii_case(p) && m.id.eq_ignore_ascii_case(&parsed.model_id)
321 })
322 .cloned()
323 } else {
324 available_models
326 .iter()
327 .find(|m| m.id.eq_ignore_ascii_case(&parsed.model_id))
328 .cloned()
329 };
330
331 if let Some(ref m) = model {
332 ResolveCliModelResult {
333 model: Some(m.clone()),
334 thinking_level: parsed.thinking_level,
335 warning: parsed.warning,
336 error: None,
337 }
338 } else {
339 let fallback_model = if let Some(ref p) = provider {
341 Some(Model {
342 provider: p.clone(),
343 id: parsed.model_id.clone(),
344 name: Some(parsed.model_id.clone()),
345 description: None,
346 context_window: None,
347 supported_features: vec![],
348 })
349 } else {
350 None
351 };
352
353 ResolveCliModelResult {
354 model: fallback_model.clone(),
355 thinking_level: parsed.thinking_level,
356 warning: parsed.warning,
357 error: fallback_model.is_none().then(|| {
358 format!(
359 "Model '{}' not found. Use --list-models to see available models.",
360 cli_model
361 )
362 }),
363 }
364 }
365}
366
367pub fn find_initial_model(
373 cli_provider: Option<&str>,
374 cli_model: Option<&str>,
375 scoped_models: &[Model],
376 is_continuing: bool,
377 settings: Option<&Settings>,
378 available_models: &[Model],
379) -> InitialModelResult {
380 if cli_provider.is_some() || cli_model.is_some() {
382 let result = resolve_cli_model(cli_provider, cli_model, available_models, settings);
383 if result.error.is_none() {
384 return InitialModelResult {
385 model: result.model,
386 thinking_level: result.thinking_level.unwrap_or_else(|| "medium".to_string()),
387 fallback_message: None,
388 };
389 }
390 }
391
392 if !scoped_models.is_empty() && !is_continuing {
394 return InitialModelResult {
395 model: Some(scoped_models[0].clone()),
396 thinking_level: "medium".to_string(),
397 fallback_message: None,
398 };
399 }
400
401 if let Some(ref s) = settings {
403 if let Some(default_model) = &s.default_model {
404 let parsed = parse_model_pattern(default_model, available_models);
405 if let Some(ref p) = parsed.provider {
406 let model = available_models
407 .iter()
408 .find(|m| m.provider.eq_ignore_ascii_case(p) && m.id.eq_ignore_ascii_case(&parsed.model_id))
409 .cloned();
410 if model.is_some() {
411 return InitialModelResult {
412 model,
413 thinking_level: format!("{:?}", s.thinking_level),
414 fallback_message: None,
415 };
416 }
417 }
418 }
419 }
420
421 let defaults = default_model_per_provider();
423 for (provider, default_id) in &defaults {
424 if let Some(model) = available_models
425 .iter()
426 .find(|m| m.provider.eq_ignore_ascii_case(provider) && m.id.eq_ignore_ascii_case(default_id))
427 {
428 return InitialModelResult {
429 model: Some(model.clone()),
430 thinking_level: "medium".to_string(),
431 fallback_message: None,
432 };
433 }
434 }
435
436 if let Some(model) = available_models.first() {
438 return InitialModelResult {
439 model: Some(model.clone()),
440 thinking_level: "medium".to_string(),
441 fallback_message: None,
442 };
443 }
444
445 InitialModelResult {
447 model: None,
448 thinking_level: "medium".to_string(),
449 fallback_message: Some("No models available. Check your installation.".to_string()),
450 }
451}
452
453pub fn restore_model_from_session(
455 saved_provider: &str,
456 saved_model_id: &str,
457 current_model: Option<&Model>,
458 should_print_messages: bool,
459 available_models: &[Model],
460) -> (Option<Model>, Option<String>) {
461 let restored = available_models
462 .iter()
463 .find(|m| {
464 m.provider.eq_ignore_ascii_case(saved_provider) && m.id.eq_ignore_ascii_case(saved_model_id)
465 })
466 .cloned();
467
468 match (&restored, current_model) {
469 (Some(ref model), _) => {
470 if should_print_messages {
471 eprintln!("Restored model: {}/{}", saved_provider, saved_model_id);
472 }
473 (Some(model.clone()), None)
474 }
475 (None, Some(current)) => {
476 if should_print_messages {
477 eprintln!(
478 "Warning: Could not restore model {}/{} (model not found). Falling back to current model.",
479 saved_provider, saved_model_id
480 );
481 eprintln!("Falling back to: {}/{}", current.provider, current.id);
482 }
483 (
484 Some(current.clone()),
485 Some(format!(
486 "Could not restore model {}/{} (model not found). Using current model.",
487 saved_provider, saved_model_id
488 )),
489 )
490 }
491 (None, None) => {
492 if let Some(model) = available_models.first() {
494 if should_print_messages {
495 eprintln!(
496 "Warning: Could not restore model {}/{} (model not found).",
497 saved_provider, saved_model_id
498 );
499 eprintln!("Using first available model: {}/{}", model.provider, model.id);
500 }
501 (
502 Some(model.clone()),
503 Some(format!(
504 "Could not restore model {}/{}. Using first available model.",
505 saved_provider, saved_model_id
506 )),
507 )
508 } else {
509 (None, Some("No models available.".to_string()))
510 }
511 }
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model with a display name, context window and feature tags.
    fn model(provider: &str, id: &str, name: &str, context_window: u32, features: &[&str]) -> Model {
        Model {
            provider: provider.to_string(),
            id: id.to_string(),
            name: Some(name.to_string()),
            description: None,
            context_window: Some(context_window),
            supported_features: features.iter().map(|f| f.to_string()).collect(),
        }
    }

    fn sample_models() -> Vec<Model> {
        vec![
            model("anthropic", "claude-sonnet-4-5", "Claude Sonnet 4.5", 200_000, &["tools", "vision"]),
            model("anthropic", "claude-opus-4-7", "Claude Opus 4.7", 200_000, &["tools", "vision"]),
            model("openai", "gpt-4o", "GPT-4o", 128_000, &["tools"]),
            model("google", "gemini-2.5-pro", "Gemini 2.5 Pro", 1_000_000, &["tools"]),
        ]
    }

    #[test]
    fn test_parse_model_pattern_exact() {
        let result = parse_model_pattern("claude-sonnet-4-5", &sample_models());

        assert_eq!(result.model_id, "claude-sonnet-4-5");
        assert_eq!(result.provider.as_deref(), Some("anthropic"));
        assert!(result.thinking_level.is_none());
        assert!(result.warning.is_none());
    }

    #[test]
    fn test_parse_model_pattern_exact_is_case_insensitive() {
        let result = parse_model_pattern("GPT-4O:low", &sample_models());

        assert_eq!(result.provider.as_deref(), Some("openai"));
        assert_eq!(result.model_id, "gpt-4o");
        assert_eq!(result.thinking_level.as_deref(), Some("low"));
        assert!(result.warning.is_none());
    }

    #[test]
    fn test_parse_model_pattern_with_provider() {
        let result = parse_model_pattern("anthropic/claude-sonnet-4-5", &sample_models());

        assert_eq!(result.model_id, "claude-sonnet-4-5");
        assert_eq!(result.provider.as_deref(), Some("anthropic"));
        assert!(result.warning.is_none());
    }

    #[test]
    fn test_parse_model_pattern_unknown_model_of_known_provider() {
        let result = parse_model_pattern("openai/gpt-9", &sample_models());

        assert_eq!(result.provider.as_deref(), Some("openai"));
        assert_eq!(result.model_id, "gpt-9");
        assert!(result.warning.is_none());
    }

    #[test]
    fn test_parse_model_pattern_with_thinking_level() {
        let result = parse_model_pattern("sonnet:high", &sample_models());

        assert_eq!(result.thinking_level.as_deref(), Some("high"));
        assert_eq!(result.model_id, "claude-sonnet-4-5");
    }

    #[test]
    fn test_parse_model_pattern_unknown_suffix_is_kept() {
        let result = parse_model_pattern("gpt-4o:turbo", &sample_models());

        assert!(result.thinking_level.is_none());
        assert!(result.provider.is_none());
        assert_eq!(result.model_id, "gpt-4o:turbo");
        assert!(result.warning.is_some());
    }

    #[test]
    fn test_parse_model_pattern_partial_match() {
        // Only one model mentions "sonnet", so it is selected without a warning.
        let result = parse_model_pattern("sonnet", &sample_models());

        assert_eq!(result.model_id, "claude-sonnet-4-5");
        assert_eq!(result.provider.as_deref(), Some("anthropic"));
        assert!(result.warning.is_none());
    }

    #[test]
    fn test_parse_model_pattern_not_found() {
        let result = parse_model_pattern("nonexistent-model", &sample_models());

        assert_eq!(result.model_id, "nonexistent-model");
        assert!(result.provider.is_none());
        assert!(result.thinking_level.is_none());
        assert!(result.warning.is_some());
    }

    #[test]
    fn test_parse_model_pattern_empty() {
        let result = parse_model_pattern("   ", &sample_models());

        assert!(result.provider.is_none());
        assert_eq!(result.model_id, "");
        assert_eq!(result.warning.as_deref(), Some("Empty model pattern"));
    }

    #[test]
    fn test_resolve_cli_model_with_provider() {
        let models = sample_models();
        let result = resolve_cli_model(Some("anthropic"), Some("claude-sonnet-4-5"), &models, None);

        assert!(result.error.is_none());
        assert_eq!(result.model.map(|m| m.id), Some("claude-sonnet-4-5".to_string()));
    }

    #[test]
    fn test_resolve_cli_model_with_slash() {
        let models = sample_models();
        let result = resolve_cli_model(None, Some("anthropic/claude-sonnet-4-5"), &models, None);

        assert!(result.error.is_none());
        let resolved = result.model.expect("model should resolve");
        assert_eq!(resolved.provider, "anthropic");
        assert_eq!(resolved.id, "claude-sonnet-4-5");
    }

    #[test]
    fn test_resolve_cli_model_prefix_is_case_insensitive() {
        let models = sample_models();
        let result = resolve_cli_model(None, Some("ANTHROPIC/claude-sonnet-4-5"), &models, None);

        assert!(result.error.is_none());
        let resolved = result.model.expect("model should resolve");
        assert_eq!(resolved.provider, "anthropic");
        assert_eq!(resolved.id, "claude-sonnet-4-5");
    }

    #[test]
    fn test_resolve_cli_model_custom_id_under_known_provider() {
        let models = sample_models();
        let result = resolve_cli_model(Some("openai"), Some("gpt-9"), &models, None);

        assert!(result.error.is_none());
        assert!(result.warning.is_some());
        let custom = result.model.expect("custom model should be synthesised");
        assert_eq!(custom.provider, "openai");
        assert_eq!(custom.id, "gpt-9");
    }

    #[test]
    fn test_resolve_cli_model_without_model_is_empty() {
        let result = resolve_cli_model(Some("openai"), None, &sample_models(), None);

        assert!(result.model.is_none());
        assert!(result.thinking_level.is_none());
        assert!(result.warning.is_none());
        assert!(result.error.is_none());
    }

    #[test]
    fn test_resolve_cli_model_not_found() {
        let result = resolve_cli_model(None, Some("nonexistent-model"), &sample_models(), None);

        assert!(result.error.is_some());
        assert!(result.model.is_none());
    }

    #[test]
    fn test_find_models_by_pattern() {
        let models = sample_models();
        let results = find_models_by_pattern("SONNET", &models);
        let ids: Vec<&str> = results.iter().map(|m| m.id.as_str()).collect();

        assert_eq!(ids, ["claude-sonnet-4-5"]);
    }

    #[test]
    fn test_find_models_by_pattern_matches_full_id() {
        let models = sample_models();
        let results = find_models_by_pattern("openai/", &models);
        let ids: Vec<&str> = results.iter().map(|m| m.id.as_str()).collect();

        assert_eq!(ids, ["gpt-4o"]);
    }

    #[test]
    fn test_find_initial_model_from_cli() {
        let models = sample_models();
        let result = find_initial_model(Some("openai"), Some("gpt-4o"), &[], false, None, &models);

        assert_eq!(result.model.map(|m| m.id), Some("gpt-4o".to_string()));
        assert_eq!(result.thinking_level, "medium");
        assert!(result.fallback_message.is_none());
    }

    #[test]
    fn test_find_initial_model_thinking_level_from_cli() {
        let models = sample_models();
        let result = find_initial_model(None, Some("gpt-4o:high"), &[], false, None, &models);

        assert_eq!(result.model.map(|m| m.id), Some("gpt-4o".to_string()));
        assert_eq!(result.thinking_level, "high");
    }

    #[test]
    fn test_find_initial_model_prefers_scoped_models() {
        let models = sample_models();
        let scoped = vec![models[3].clone()];
        let result = find_initial_model(None, None, &scoped, false, None, &models);

        assert_eq!(result.model.map(|m| m.id), Some("gemini-2.5-pro".to_string()));
        assert_eq!(result.thinking_level, "medium");
    }

    #[test]
    fn test_find_initial_model_fallback_to_available() {
        let models = sample_models();
        let result = find_initial_model(None, None, &[], false, None, &models);

        // Several sample models are provider defaults; whichever is picked must be one.
        let defaults = default_model_per_provider();
        let chosen = result.model.expect("a model should be selected");
        assert_eq!(defaults.get(&chosen.provider).map(String::as_str), Some(chosen.id.as_str()));
        assert_eq!(result.thinking_level, "medium");
        assert!(result.fallback_message.is_none());
    }

    #[test]
    fn test_find_initial_model_without_models() {
        let result = find_initial_model(None, None, &[], false, None, &[]);

        assert!(result.model.is_none());
        assert_eq!(result.thinking_level, "medium");
        assert!(result.fallback_message.is_some());
    }

    #[test]
    fn test_restore_model_from_session_success() {
        let models = sample_models();
        let (model, message) =
            restore_model_from_session("ANTHROPIC", "claude-sonnet-4-5", None, false, &models);

        assert_eq!(model.map(|m| m.id), Some("claude-sonnet-4-5".to_string()));
        assert!(message.is_none());
    }

    #[test]
    fn test_restore_model_from_session_fallback() {
        let models = sample_models();
        let current = &models[2];
        let (model, message) =
            restore_model_from_session("nonexistent", "model", Some(current), false, &models);

        assert_eq!(model.map(|m| m.id), Some("gpt-4o".to_string()));
        assert!(message.is_some());
    }

    #[test]
    fn test_restore_model_from_session_first_available() {
        let models = sample_models();
        let (model, message) = restore_model_from_session("nonexistent", "model", None, false, &models);

        assert_eq!(model.map(|m| m.id), Some("claude-sonnet-4-5".to_string()));
        assert!(message.is_some());
    }

    #[test]
    fn test_restore_model_from_session_no_models() {
        let (model, message) = restore_model_from_session("nonexistent", "model", None, false, &[]);

        assert!(model.is_none());
        assert_eq!(message.as_deref(), Some("No models available."));
    }

    #[test]
    fn test_default_model_per_provider() {
        let defaults = default_model_per_provider();

        assert_eq!(defaults.get("anthropic").map(String::as_str), Some("claude-sonnet-4-5"));
        assert_eq!(defaults.get("openai").map(String::as_str), Some("gpt-4o"));
    }

    #[test]
    fn test_is_alias() {
        assert!(is_alias("claude-sonnet-4-latest"));
        assert!(!is_alias("claude-sonnet-4-20250929"));
        assert!(!is_alias("gpt-4o-20240806"));
        assert!(is_alias("simple-model"));
        assert!(is_alias("model-2025092"));
        assert!(is_alias("foo-20250929-latest"));
    }
}