ai_agent/utils/model/
model.rs1use crate::constants::env::{ai, ai_code};
7use std::sync::OnceLock;
8
/// Short model identifier, as returned by `first_party_name_to_canonical`.
pub type ModelShortName = String;

/// Full model identifier string.
pub type ModelName = String;

/// Optional model setting; `model_display_string` renders `None` as the default model.
pub type ModelSetting = Option<ModelNameOrAlias>;

/// Either a full model name or one of the aliases recognised by `is_model_alias`.
pub type ModelNameOrAlias = String;

/// A model alias string.
pub type ModelAlias = String;
23
24pub fn get_small_fast_model() -> ModelName {
30 std::env::var(ai::SMALL_FAST_MODEL)
31 .ok()
32 .filter(|s| !s.is_empty())
33 .unwrap_or_else(get_default_haiku_model)
34}
35
36pub fn is_non_custom_opus_model(model: &ModelName) -> bool {
38 model == &get_model_strings().opus_40
39 || model == &get_model_strings().opus_41
40 || model == &get_model_strings().opus_45
41 || model == &get_model_strings().opus_46
42}
43
44pub fn get_user_specified_model_setting() -> Option<String> {
47 if let Some(override_model) = get_main_loop_model_override() {
49 if is_model_allowed(&override_model) {
50 return Some(override_model);
51 } else {
52 return None;
53 }
54 }
55
56 if let Ok(env_model) = std::env::var(ai::MODEL) {
58 if !env_model.is_empty() && is_model_allowed(&env_model) {
59 return Some(env_model);
60 }
61 }
62
63 None
68}
69
70pub fn get_main_loop_model() -> ModelName {
78 if let Some(model) = get_user_specified_model_setting() {
79 return parse_user_specified_model(model);
80 }
81 get_default_main_loop_model()
82}
83
/// Returns the model the `"best"` alias resolves to; currently this is the
/// default Opus model.
pub fn get_best_model() -> ModelName {
    get_default_opus_model()
}
88
89pub fn get_default_opus_model() -> ModelName {
91 std::env::var(ai::DEFAULT_OPUS_MODEL)
92 .ok()
93 .filter(|s| !s.is_empty())
94 .unwrap_or_else(|| get_model_strings().opus_46.clone())
95}
96
97pub fn get_default_sonnet_model() -> ModelName {
99 if let Ok(model) = std::env::var(ai::DEFAULT_SONNET_MODEL) {
100 if !model.is_empty() {
101 return model;
102 }
103 }
104
105 if get_api_provider() != "firstParty" {
107 return get_model_strings().sonnet_45.clone();
108 }
109 get_model_strings().sonnet_46.clone()
110}
111
112pub fn get_default_haiku_model() -> ModelName {
114 std::env::var(ai::DEFAULT_HAIKU_MODEL)
115 .ok()
116 .filter(|s| !s.is_empty())
117 .unwrap_or_else(|| get_model_strings().haiku_45.clone())
118}
119
120pub fn get_runtime_main_loop_model(
122 permission_mode: &str,
123 main_loop_model: &str,
124 exceeds_200k_tokens: bool,
125) -> ModelName {
126 if get_user_specified_model_setting() == Some("opusplan".to_string())
128 && permission_mode == "plan"
129 && !exceeds_200k_tokens
130 {
131 return get_default_opus_model();
132 }
133
134 if get_user_specified_model_setting() == Some("haiku".to_string()) && permission_mode == "plan"
136 {
137 return get_default_sonnet_model();
138 }
139
140 main_loop_model.to_string()
141}
142
143pub fn get_default_main_loop_model_setting() -> ModelNameOrAlias {
145 if let Ok(user_type) = std::env::var(ai::USER_TYPE) {
147 if user_type == "ant" {
148 if let Some(ant_config) = get_ant_model_override_config() {
149 return ant_config.default_model;
150 }
151 return format!("{}[1m]", get_default_opus_model());
152 }
153 }
154
155 if is_max_subscriber() {
157 return if is_opus_1m_merge_enabled() {
158 format!("{}[1m]", get_default_opus_model())
159 } else {
160 get_default_opus_model()
161 };
162 }
163
164 if is_team_premium_subscriber() {
166 return if is_opus_1m_merge_enabled() {
167 format!("{}[1m]", get_default_opus_model())
168 } else {
169 get_default_opus_model()
170 };
171 }
172
173 get_default_sonnet_model()
175}
176
/// Returns the default main-loop model as a concrete model name, by passing
/// the default setting through `parse_user_specified_model`.
pub fn get_default_main_loop_model() -> ModelName {
    parse_user_specified_model(get_default_main_loop_model_setting())
}
181
182pub fn first_party_name_to_canonical(name: &ModelName) -> ModelShortName {
184 let name_lower = name.to_lowercase();
185
186 if name_lower.contains("claude-opus-4-6") {
188 return "claude-opus-4-6".to_string();
189 }
190 if name_lower.contains("claude-opus-4-5") {
191 return "claude-opus-4-5".to_string();
192 }
193 if name_lower.contains("claude-opus-4-1") {
194 return "claude-opus-4-1".to_string();
195 }
196 if name_lower.contains("claude-opus-4") {
197 return "claude-opus-4".to_string();
198 }
199 if name_lower.contains("claude-sonnet-4-6") {
200 return "claude-sonnet-4-6".to_string();
201 }
202 if name_lower.contains("claude-sonnet-4-5") {
203 return "claude-sonnet-4-5".to_string();
204 }
205 if name_lower.contains("claude-sonnet-4") {
206 return "claude-sonnet-4".to_string();
207 }
208 if name_lower.contains("claude-haiku-4-5") {
209 return "claude-haiku-4-5".to_string();
210 }
211
212 if name_lower.contains("claude-3-7-sonnet") {
214 return "claude-3-7-sonnet".to_string();
215 }
216 if name_lower.contains("claude-3-5-sonnet") {
217 return "claude-3-5-sonnet".to_string();
218 }
219 if name_lower.contains("claude-3-5-haiku") {
220 return "claude-3-5-haiku".to_string();
221 }
222 if name_lower.contains("claude-3-opus") {
223 return "claude-3-opus".to_string();
224 }
225 if name_lower.contains("claude-3-sonnet") {
226 return "claude-3-sonnet".to_string();
227 }
228 if name_lower.contains("claude-3-haiku") {
229 return "claude-3-haiku".to_string();
230 }
231
232 if let Some(captures) = regex::Regex::new(r"(claude-(\d+-\d+-)?\w+)")
234 .ok()
235 .and_then(|re| re.captures(&name_lower))
236 {
237 if let Some(m) = captures.get(1) {
238 return m.as_str().to_string();
239 }
240 }
241
242 name.clone()
244}
245
246pub fn get_canonical_name(full_model_name: &str) -> ModelShortName {
248 let resolved = resolve_overridden_model(full_model_name);
249 first_party_name_to_canonical(&resolved)
250}
251
252pub fn get_claude_ai_user_default_model_description(fast_mode: bool) -> String {
254 if is_max_subscriber() || is_team_premium_subscriber() {
255 let base = if is_opus_1m_merge_enabled() {
256 "Opus 4.6 with 1M context"
257 } else {
258 "Opus 4.6"
259 };
260 let suffix = if fast_mode {
261 get_opus_46_pricing_suffix(true)
262 } else {
263 "".to_string()
264 };
265 format!("{} · Most capable for complex work{}", base, suffix)
266 } else {
267 "Sonnet 4.6 · Best for everyday tasks".to_string()
268 }
269}
270
271pub fn render_default_model_setting(setting: &ModelNameOrAlias) -> String {
273 if setting == "opusplan" {
274 return "Opus 4.6 in plan mode, else Sonnet 4.6".to_string();
275 }
276 render_model_name(&parse_user_specified_model(setting.clone()))
277}
278
279pub fn get_opus_46_pricing_suffix(fast_mode: bool) -> String {
281 if get_api_provider() != "firstParty" {
282 return "".to_string();
283 }
284 let pricing = "pricing_placeholder".to_string();
286 let fast_mode_indicator = if fast_mode { " (lightning)" } else { "" };
287 format!(" ·{} {}", fast_mode_indicator, pricing)
288}
289
290pub fn is_opus_1m_merge_enabled() -> bool {
292 if is_1m_context_disabled() || is_pro_subscriber() || get_api_provider() != "firstParty" {
293 return false;
294 }
295
296 if is_claude_ai_subscriber() && get_subscription_type().is_none() {
298 return false;
299 }
300
301 true
302}
303
304pub fn render_model_setting(setting: &ModelNameOrAlias) -> String {
306 if setting == "opusplan" {
307 return "Opus Plan".to_string();
308 }
309 if is_model_alias(setting) {
310 return capitalize(setting);
311 }
312 render_model_name(setting)
313}
314
315pub fn get_public_model_display_name(model: &ModelName) -> Option<String> {
317 let model_strings = get_model_strings();
318
319 if model == &model_strings.opus_46 {
320 return Some("Opus 4.6".to_string());
321 }
322 if model == &format!("{}[1m]", model_strings.opus_46) {
323 return Some("Opus 4.6 (1M context)".to_string());
324 }
325 if model == &model_strings.opus_45 {
326 return Some("Opus 4.5".to_string());
327 }
328 if model == &model_strings.opus_41 {
329 return Some("Opus 4.1".to_string());
330 }
331 if model == &model_strings.opus_40 {
332 return Some("Opus 4".to_string());
333 }
334 if model == &format!("{}[1m]", model_strings.sonnet_46) {
335 return Some("Sonnet 4.6 (1M context)".to_string());
336 }
337 if model == &model_strings.sonnet_46 {
338 return Some("Sonnet 4.6".to_string());
339 }
340 if model == &format!("{}[1m]", model_strings.sonnet_45) {
341 return Some("Sonnet 4.5 (1M context)".to_string());
342 }
343 if model == &model_strings.sonnet_45 {
344 return Some("Sonnet 4.5".to_string());
345 }
346 if model == &model_strings.sonnet_40 {
347 return Some("Sonnet 4".to_string());
348 }
349 if model == &format!("{}[1m]", model_strings.sonnet_40) {
350 return Some("Sonnet 4 (1M context)".to_string());
351 }
352 if model == &model_strings.sonnet_37 {
353 return Some("Sonnet 3.7".to_string());
354 }
355 if model == &model_strings.sonnet_35 {
356 return Some("Sonnet 3.5".to_string());
357 }
358 if model == &model_strings.haiku_45 {
359 return Some("Haiku 4.5".to_string());
360 }
361 if model == &model_strings.haiku_35 {
362 return Some("Haiku 3.5".to_string());
363 }
364
365 None
366}
367
/// Masks the codename portion of a model name.
///
/// The codename is everything before the first `-`. Its first three
/// characters are kept and every further character is replaced by `*`; the
/// rest of the name (from the first `-` on) is left untouched.
///
/// Masking works per character rather than per byte, so a codename containing
/// multi-byte UTF-8 characters neither panics nor gets a wrong star count.
fn mask_model_codename(base_name: &str) -> String {
    let (codename, rest) = match base_name.split_once('-') {
        Some((head, tail)) => (head, Some(tail)),
        None => (base_name, None),
    };

    let mut masked = String::with_capacity(base_name.len());
    masked.extend(
        codename
            .chars()
            .enumerate()
            .map(|(idx, ch)| if idx < 3 { ch } else { '*' }),
    );

    if let Some(tail) = rest {
        masked.push('-');
        masked.push_str(tail);
    }
    masked
}
391
392pub fn render_model_name(model: &ModelName) -> String {
394 if let Some(public_name) = get_public_model_display_name(model) {
395 return public_name;
396 }
397
398 if let Ok(user_type) = std::env::var(ai::USER_TYPE) {
399 if user_type == "ant" {
400 let resolved = parse_user_specified_model(model.clone());
401 if let Some(ant_model) = resolve_ant_model(model) {
402 let base_name = ant_model.model.replace("[1m]", "");
403 let masked = mask_model_codename(&base_name);
404 let suffix = if has_1m_context(&resolved) {
405 "[1m]"
406 } else {
407 ""
408 };
409 return format!("{}{}", masked, suffix);
410 }
411 if resolved != *model {
412 return format!("{} ({})", model, resolved);
413 }
414 return resolved;
415 }
416 }
417
418 model.clone()
419}
420
421pub fn get_public_model_name(model: &ModelName) -> String {
423 if let Some(public_name) = get_public_model_display_name(model) {
424 return format!("Claude {}", public_name);
425 }
426 format!("Claude ({})", model)
427}
428
429pub fn parse_user_specified_model(model_input: ModelNameOrAlias) -> ModelName {
431 let model_input_trimmed = model_input.trim().to_string();
432 let normalized_model = model_input_trimmed.to_lowercase();
433
434 let has_1m_tag = has_1m_context(&normalized_model);
435 let model_string = if has_1m_tag {
436 normalized_model.replace("[1m]", "").trim().to_string()
437 } else {
438 normalized_model.clone()
439 };
440
441 if is_model_alias(&model_string) {
442 match model_string.as_str() {
443 "opusplan" => {
444 return format!(
445 "{}{}",
446 get_default_sonnet_model(),
447 if has_1m_tag { "[1m]" } else { "" }
448 );
449 }
450 "sonnet" => {
451 return format!(
452 "{}{}",
453 get_default_sonnet_model(),
454 if has_1m_tag { "[1m]" } else { "" }
455 );
456 }
457 "haiku" => {
458 return format!(
459 "{}{}",
460 get_default_haiku_model(),
461 if has_1m_tag { "[1m]" } else { "" }
462 );
463 }
464 "opus" => {
465 return format!(
466 "{}{}",
467 get_default_opus_model(),
468 if has_1m_tag { "[1m]" } else { "" }
469 );
470 }
471 "best" => {
472 return get_best_model();
473 }
474 _ => {}
475 }
476 }
477
478 if get_api_provider() == "firstParty"
480 && is_legacy_opus_first_party(&model_string)
481 && is_legacy_model_remap_enabled()
482 {
483 return format!(
484 "{}{}",
485 get_default_opus_model(),
486 if has_1m_tag { "[1m]" } else { "" }
487 );
488 }
489
490 if let Ok(user_type) = std::env::var(ai::USER_TYPE) {
492 if user_type == "ant" {
493 let has_1m_ant_tag = has_1m_context(&normalized_model);
494 let base_ant_model = normalized_model.replace("[1m]", "").trim().to_string();
495
496 if let Some(ant_model) = resolve_ant_model(&base_ant_model) {
497 let suffix = if has_1m_ant_tag { "[1m]" } else { "" };
498 return format!("{}{}", ant_model.model, suffix);
499 }
500 }
501 }
502
503 if has_1m_tag {
505 return format!("{}[1m]", model_input_trimmed.replace("[1m]", "").trim());
506 }
507 model_input_trimmed
508}
509
510pub fn resolve_skill_model_override(skill_model: &str, current_model: &str) -> String {
512 if has_1m_context(skill_model) || !has_1m_context(current_model) {
513 return skill_model.to_string();
514 }
515
516 if model_supports_1m(&parse_user_specified_model(skill_model.to_string())) {
517 return format!("{}[1m]", skill_model);
518 }
519 skill_model.to_string()
520}
521
/// First-party Opus model ids treated as legacy. `parse_user_specified_model`
/// remaps these to the default Opus model when legacy remapping is enabled.
const LEGACY_OPUS_FIRSTPARTY: &[&str] = &[
    "claude-opus-4-20250514",
    "claude-opus-4-1-20250805",
    "claude-opus-4-0",
    "claude-opus-4-1",
];
529
530fn is_legacy_opus_first_party(model: &str) -> bool {
531 LEGACY_OPUS_FIRSTPARTY.contains(&model)
532}
533
534pub fn is_legacy_model_remap_enabled() -> bool {
536 !is_env_truthy(&std::env::var(ai_code::DISABLE_LEGACY_MODEL_REMAP).unwrap_or_default())
537}
538
539pub fn model_display_string(model: &ModelSetting) -> String {
541 if model.is_none() {
542 if let Ok(user_type) = std::env::var(ai::USER_TYPE) {
543 if user_type == "ant" {
544 return format!(
545 "Default for Ants ({})",
546 render_default_model_setting(&get_default_main_loop_model_setting())
547 );
548 }
549 }
550 if is_claude_ai_subscriber() {
551 return format!(
552 "Default ({})",
553 get_claude_ai_user_default_model_description(false)
554 );
555 }
556 return format!("Default ({})", get_default_main_loop_model());
557 }
558
559 let model = model.as_ref().unwrap();
560 let resolved_model = parse_user_specified_model(model.clone());
561 if model == &resolved_model {
562 resolved_model
563 } else {
564 format!("{} ({})", model, resolved_model)
565 }
566}
567
568pub fn get_marketing_name_for_model(model_id: &str) -> Option<String> {
570 if get_api_provider() == "foundry" {
571 return None;
572 }
573
574 let has_1m = model_id.to_lowercase().contains("[1m]");
575 let canonical = get_canonical_name(model_id);
576
577 if canonical.contains("claude-opus-4-6") {
578 return Some(if has_1m {
579 "Opus 4.6 (with 1M context)".to_string()
580 } else {
581 "Opus 4.6".to_string()
582 });
583 }
584 if canonical.contains("claude-opus-4-5") {
585 return Some("Opus 4.5".to_string());
586 }
587 if canonical.contains("claude-opus-4-1") {
588 return Some("Opus 4.1".to_string());
589 }
590 if canonical.contains("claude-opus-4") {
591 return Some("Opus 4".to_string());
592 }
593 if canonical.contains("claude-sonnet-4-6") {
594 return Some(if has_1m {
595 "Sonnet 4.6 (with 1M context)".to_string()
596 } else {
597 "Sonnet 4.6".to_string()
598 });
599 }
600 if canonical.contains("claude-sonnet-4-5") {
601 return Some(if has_1m {
602 "Sonnet 4.5 (with 1M context)".to_string()
603 } else {
604 "Sonnet 4.5".to_string()
605 });
606 }
607 if canonical.contains("claude-sonnet-4") {
608 return Some(if has_1m {
609 "Sonnet 4 (with 1M context)".to_string()
610 } else {
611 "Sonnet 4".to_string()
612 });
613 }
614 if canonical.contains("claude-3-7-sonnet") {
615 return Some("Claude 3.7 Sonnet".to_string());
616 }
617 if canonical.contains("claude-3-5-sonnet") {
618 return Some("Claude 3.5 Sonnet".to_string());
619 }
620 if canonical.contains("claude-haiku-4-5") {
621 return Some("Haiku 4.5".to_string());
622 }
623 if canonical.contains("claude-3-5-haiku") {
624 return Some("Claude 3.5 Haiku".to_string());
625 }
626
627 None
628}
629
/// Strips every `[1m]` / `[2m]` context tag from a model string so it can be
/// sent to the API.
///
/// Tags are matched ASCII-case-insensitively (`[1M]` is removed too), in line
/// with how `has_1m_context` detects them. The scan uses only the standard
/// library, so no regex is compiled per call.
pub fn normalize_model_string_for_api(model: &str) -> String {
    let bytes = model.as_bytes();
    let mut cleaned = String::with_capacity(model.len());
    let mut kept_from = 0;
    let mut i = 0;

    while i < bytes.len() {
        let is_tag = matches!(
            bytes.get(i..i + 4),
            Some([b'[', b'1' | b'2', b'm' | b'M', b']'])
        );
        if is_tag {
            // `i` sits on an ASCII '[' and `i + 4` just past an ASCII ']', so
            // both are valid char boundaries for slicing.
            cleaned.push_str(&model[kept_from..i]);
            i += 4;
            kept_from = i;
        } else {
            i += 1;
        }
    }

    cleaned.push_str(&model[kept_from..]);
    cleaned
}
636
/// Storage for the model-id table; initialised on first use by
/// `get_model_strings`.
static MODEL_STRINGS: OnceLock<ModelStrings> = OnceLock::new();
643
/// Table of built-in model ids, one field per model family and version
/// (e.g. `opus_46` is the Opus 4.6 id).
#[derive(Debug, Clone)]
struct ModelStrings {
    opus_40: ModelName,
    opus_41: ModelName,
    opus_45: ModelName,
    opus_46: ModelName,
    sonnet_35: ModelName,
    sonnet_37: ModelName,
    sonnet_40: ModelName,
    sonnet_45: ModelName,
    sonnet_46: ModelName,
    haiku_35: ModelName,
    haiku_45: ModelName,
}
658
659fn get_model_strings() -> &'static ModelStrings {
660 MODEL_STRINGS.get_or_init(|| ModelStrings {
661 opus_40: "claude-opus-4-0-20250514".to_string(),
662 opus_41: "claude-opus-4-1-20250805".to_string(),
663 opus_45: "claude-opus-4-5-20250514".to_string(),
664 opus_46: "claude-opus-4-6-20251106".to_string(),
665 sonnet_35: "claude-sonnet-3-5-20241022".to_string(),
666 sonnet_37: "claude-sonnet-3-7-20250120".to_string(),
667 sonnet_40: "claude-sonnet-4-0-20250514".to_string(),
668 sonnet_45: "claude-sonnet-4-5-20241022".to_string(),
669 sonnet_46: "claude-sonnet-4-6-20251106".to_string(),
670 haiku_35: "claude-haiku-3-5-20241022".to_string(),
671 haiku_45: "claude-haiku-4-5-20250513".to_string(),
672 })
673}
674
675fn get_api_provider() -> String {
677 std::env::var(ai::API_PROVIDER)
678 .ok()
679 .unwrap_or_else(|| "firstParty".to_string())
680}
681
/// Returns the main-loop model override, if any. This implementation always
/// returns `None`.
fn get_main_loop_model_override() -> Option<ModelName> {
    None
}
687
/// Reports whether `_model` may be used. This implementation allows every
/// model and ignores its argument.
fn is_model_allowed(_model: &str) -> bool {
    true
}
696
/// Returns `true` when `model` is exactly one of the supported aliases:
/// `opus`, `sonnet`, `haiku`, `opusplan` or `best` (case-sensitive).
fn is_model_alias(model: &str) -> bool {
    const ALIASES: [&str; 5] = ["opus", "sonnet", "haiku", "opusplan", "best"];
    ALIASES.contains(&model)
}
701
/// Upper-cases the first character of `s` (using its full Unicode uppercase
/// mapping) and leaves the rest unchanged. An empty input gives an empty
/// string.
fn capitalize(s: &str) -> String {
    let mut rest = s.chars();
    let Some(first) = rest.next() else {
        return String::new();
    };

    let mut capitalized: String = first.to_uppercase().collect();
    capitalized.push_str(rest.as_str());
    capitalized
}
710
/// Reports whether 1M context is disabled. This implementation always returns
/// `false`.
fn is_1m_context_disabled() -> bool {
    false
}
716
/// Returns `true` when `model`, compared case-insensitively, ends with the
/// `[1m]` tag.
fn has_1m_context(model: &str) -> bool {
    let lowered = model.to_lowercase();
    lowered.strip_suffix("[1m]").is_some()
}
721
722fn model_supports_1m(model: &ModelName) -> bool {
724 let canonical = get_canonical_name(model);
726 matches!(
727 canonical.as_str(),
728 "claude-opus-4-6" | "claude-opus-4-5" | "claude-sonnet-4-6" | "claude-sonnet-4-5"
729 )
730}
731
/// Resolves a possibly overridden model id. This implementation returns
/// `model` unchanged.
fn resolve_overridden_model(model: &str) -> ModelName {
    model.to_string()
}
737
738fn is_max_subscriber() -> bool {
740 get_subscription_type() == Some("max".to_string())
741}
742
743fn is_team_premium_subscriber() -> bool {
745 get_subscription_type() == Some("team".to_string())
746 && get_rate_limit_tier() == Some("default_claude_max_5x".to_string())
747}
748
749fn is_pro_subscriber() -> bool {
751 get_subscription_type() == Some("pro".to_string())
752}
753
754fn get_rate_limit_tier() -> Option<String> {
756 use crate::session_history::get_claude_ai_oauth_tokens;
757 get_claude_ai_oauth_tokens().and_then(|t| t.rate_limit_tier.clone())
758}
759
760pub fn is_claude_ai_subscriber() -> bool {
763 use crate::session_history::get_claude_ai_oauth_tokens;
764 use crate::utils::env_utils::is_env_truthy;
765
766 if is_env_truthy(Some("AI_CODE_USE_BEDROCK"))
768 || is_env_truthy(Some("AI_CODE_USE_VERTEX"))
769 || is_env_truthy(Some("AI_CODE_USE_FOUNDRY"))
770 {
771 return false;
772 }
773
774 if let Some(tokens) = get_claude_ai_oauth_tokens() {
776 return tokens.scopes.iter().any(|s| s.contains("user")) && !tokens.access_token.is_empty();
777 }
778
779 false
780}
781
782pub fn get_subscription_type() -> Option<String> {
784 use crate::session_history::get_claude_ai_oauth_tokens;
785
786 get_claude_ai_oauth_tokens().and_then(|t| t.subscription_type.clone())
787}
788
/// Returns `true` when `value`, lowercased and trimmed, is one of `1`, `true`,
/// `yes` or `on`.
fn is_env_truthy(value: &str) -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    let lowered = value.to_lowercase();
    TRUTHY.contains(&lowered.trim())
}
794
/// Model override configuration for `ai::USER_TYPE == "ant"`.
#[derive(Debug, Clone)]
struct AntModelConfig {
    // Returned by `get_default_main_loop_model_setting` as the default model setting.
    default_model: String,
    // Concrete model id used when an ant model is resolved.
    model: String,
}
801
/// Returns the ant model override configuration, if any. This implementation
/// always returns `None`.
fn get_ant_model_override_config() -> Option<AntModelConfig> {
    None
}
807
/// Looks up the ant model configuration for `_model`. This implementation
/// ignores its argument and always returns `None`.
fn resolve_ant_model(_model: &str) -> Option<AntModelConfig> {
    None
}