1use anyhow::{anyhow, Result};
45use serde::{Deserialize, Serialize};
46use std::collections::BTreeMap;
47use std::str::FromStr;
48
49#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
55#[serde(rename_all = "lowercase")]
56pub enum Tier {
57 #[default]
59 Auto,
60 Low,
62 Med,
64 High,
66}
67
68impl std::fmt::Display for Tier {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 match self {
71 Self::Auto => write!(f, "auto"),
72 Self::Low => write!(f, "low"),
73 Self::Med => write!(f, "med"),
74 Self::High => write!(f, "high"),
75 }
76 }
77}
78
79impl FromStr for Tier {
80 type Err = anyhow::Error;
81
82 fn from_str(s: &str) -> Result<Self> {
83 match s.trim().to_ascii_lowercase().as_str() {
84 "auto" => Ok(Self::Auto),
85 "low" => Ok(Self::Low),
86 "med" | "medium" => Ok(Self::Med),
87 "high" => Ok(Self::High),
88 other => Err(anyhow!(
89 "invalid tier `{}`: expected one of auto|low|med|high",
90 other
91 )),
92 }
93 }
94}
95
96#[derive(Debug, Default, Clone, Serialize, Deserialize)]
100pub struct TierMap {
101 #[serde(default, skip_serializing_if = "Option::is_none")]
102 pub low: Option<String>,
103 #[serde(default, skip_serializing_if = "Option::is_none")]
104 pub med: Option<String>,
105 #[serde(default, skip_serializing_if = "Option::is_none")]
106 pub high: Option<String>,
107}
108
109impl TierMap {
110 pub fn get(&self, tier: Tier) -> Option<&str> {
111 match tier {
112 Tier::Auto => None,
113 Tier::Low => self.low.as_deref(),
114 Tier::Med => self.med.as_deref(),
115 Tier::High => self.high.as_deref(),
116 }
117 }
118
119 pub fn tier_of(&self, model_name: &str) -> Option<Tier> {
121 if self.low.as_deref() == Some(model_name) {
122 Some(Tier::Low)
123 } else if self.med.as_deref() == Some(model_name) {
124 Some(Tier::Med)
125 } else if self.high.as_deref() == Some(model_name) {
126 Some(Tier::High)
127 } else {
128 None
129 }
130 }
131}
132
133#[derive(Debug, Default, Clone, Serialize, Deserialize)]
135pub struct ModelConfig {
136 #[serde(default = "default_auto")]
138 pub auto: bool,
139 #[serde(default)]
142 pub tiers: BTreeMap<String, TierMap>,
143}
144
145fn default_auto() -> bool {
146 true
147}
148
149fn builtin_claude_code() -> TierMap {
151 TierMap {
152 low: Some("haiku".to_string()),
153 med: Some("sonnet".to_string()),
154 high: Some("opus".to_string()),
155 }
156}
157
158fn builtin_codex() -> TierMap {
160 TierMap {
161 low: Some("gpt-4o-mini".to_string()),
162 med: Some("gpt-4o".to_string()),
163 high: Some("o3".to_string()),
164 }
165}
166
167fn builtin_default() -> TierMap {
169 TierMap {
170 low: Some("haiku".to_string()),
171 med: Some("sonnet".to_string()),
172 high: Some("opus".to_string()),
173 }
174}
175
176fn builtin_for(harness: &str) -> TierMap {
178 match harness {
179 "claude-code" => builtin_claude_code(),
180 "codex" => builtin_codex(),
181 _ => builtin_default(),
182 }
183}
184
/// Identifies the agent harness from environment variables.
///
/// Returns `"claude-code"` if `CLAUDE_CODE_SESSION` or `CLAUDECODE` is set,
/// otherwise `"codex"` if `CODEX_SESSION` is set, otherwise `"default"`.
pub fn detect_harness() -> String {
    // Only presence matters. `var_os` also reports variables whose value is
    // not valid Unicode, which `env::var(..).is_ok()` would treat as unset.
    let is_set = |name: &str| std::env::var_os(name).is_some();
    let harness = if is_set("CLAUDE_CODE_SESSION") || is_set("CLAUDECODE") {
        "claude-code"
    } else if is_set("CODEX_SESSION") {
        "codex"
    } else {
        "default"
    };
    harness.to_string()
}
198
199pub fn resolve_tier_to_model(
204 tier: Tier,
205 harness: &str,
206 model_config: &ModelConfig,
207) -> Option<String> {
208 if matches!(tier, Tier::Auto) {
209 return None;
210 }
211 if let Some(map) = model_config.tiers.get(harness)
212 && let Some(name) = map.get(tier)
213 {
214 return Some(name.to_string());
215 }
216 builtin_for(harness).get(tier).map(|s| s.to_string())
217}
218
219pub fn tier_from_model_name(
224 model_name: &str,
225 harness: &str,
226 model_config: &ModelConfig,
227) -> Option<Tier> {
228 if let Some(map) = model_config.tiers.get(harness)
229 && let Some(t) = map.tier_of(model_name)
230 {
231 return Some(t);
232 }
233 builtin_for(harness).tier_of(model_name)
234}
235
236pub fn extract_model_component(content: &str) -> Option<String> {
242 let comps = crate::component::parse(content).ok()?;
243 let comp = comps.into_iter().find(|c| c.name == "model")?;
244 let inner = &content[comp.open_end..comp.close_start];
245 let trimmed = inner.trim();
246 if trimmed.is_empty() {
247 None
248 } else {
249 Some(trimmed.to_string())
250 }
251}
252
253pub fn component_value_to_tier(
258 value: &str,
259 harness: &str,
260 model_config: &ModelConfig,
261) -> Option<Tier> {
262 if let Ok(tier) = Tier::from_str(value) {
263 return Some(tier);
264 }
265 tier_from_model_name(value, harness, model_config)
266}
267
268pub fn suggested_tier(diff_type: Option<&str>, lines_added: usize, doc_path: &std::path::Path) -> Tier {
286 let base = match diff_type {
287 Some("simple_question") | Some("approval") | Some("boundary_artifact") | Some("annotation") => {
288 Tier::Low
289 }
290 Some("content_addition") => {
291 if lines_added < 10 {
292 Tier::Low
293 } else {
294 Tier::Med
295 }
296 }
297 Some("multi_topic") | Some("structural_change") => Tier::Med,
298 _ => Tier::Med,
299 };
300
301 let path_str = doc_path.to_string_lossy();
303 let boost = path_str.contains("tasks/software/")
304 || path_str.contains("/specs/")
305 || path_str.contains("agent-doc-bugs")
306 || path_str.contains("plan-")
307 || path_str.contains("/plan.md");
308 if boost {
309 match base {
310 Tier::Auto | Tier::Low => Tier::Med,
311 Tier::Med => Tier::High,
312 Tier::High => Tier::High,
313 }
314 } else {
315 base
316 }
317}
318
319#[derive(Debug, Clone)]
321pub struct ModelSwitchScan {
322 pub model_switch: Option<String>,
324 pub model_switch_tier: Option<Tier>,
326 pub stripped_diff: String,
328}
329
330pub fn scan_model_switch(
344 diff: &str,
345 harness: &str,
346 model_config: &ModelConfig,
347) -> ModelSwitchScan {
348 let mut model_switch: Option<String> = None;
349 let mut model_switch_tier: Option<Tier> = None;
350 let mut kept_lines: Vec<&str> = Vec::with_capacity(diff.lines().count());
351
352 let mut in_fence = false;
353 let mut fence_char = '`';
354 let mut fence_len = 0usize;
355
356 for line in diff.lines() {
357 if line.starts_with("---") || line.starts_with("+++") || line.starts_with("@@") {
359 kept_lines.push(line);
360 continue;
361 }
362
363 let content = if line.starts_with('+') || line.starts_with('-') || line.starts_with(' ') {
365 &line[1..]
366 } else {
367 line
368 };
369
370 let trimmed = content.trim_start();
372 if !in_fence {
373 let fc = trimmed.chars().next().unwrap_or('\0');
374 if (fc == '`' || fc == '~')
375 && let fl = trimmed.chars().take_while(|&c| c == fc).count()
376 && fl >= 3
377 {
378 in_fence = true;
379 fence_char = fc;
380 fence_len = fl;
381 kept_lines.push(line);
382 continue;
383 }
384 } else {
385 let fc = trimmed.chars().next().unwrap_or('\0');
386 if fc == fence_char {
387 let fl = trimmed.chars().take_while(|&c| c == fc).count();
388 if fl >= fence_len && trimmed[fl..].trim().is_empty() {
389 in_fence = false;
390 kept_lines.push(line);
391 continue;
392 }
393 }
394 }
395
396 let is_added = line.starts_with('+') && !line.starts_with("+++");
398 if !is_added {
399 kept_lines.push(line);
400 continue;
401 }
402
403 if in_fence {
405 kept_lines.push(line);
406 continue;
407 }
408
409 if content.starts_with('>') {
411 kept_lines.push(line);
412 continue;
413 }
414
415 let stripped = content.trim_end();
417 if let Some(rest) = stripped.strip_prefix("/model")
418 && let Some(arg) = rest.split_whitespace().next()
419 && !arg.is_empty()
420 {
421 if let Some((tier, name)) = parse_model_arg(arg, harness, model_config) {
423 if model_switch.is_none() {
424 model_switch = Some(name);
425 model_switch_tier = Some(tier);
426 }
427 continue;
429 }
430 continue;
432 }
433
434 kept_lines.push(line);
435 }
436
437 ModelSwitchScan {
438 model_switch,
439 model_switch_tier,
440 stripped_diff: kept_lines.join("\n"),
441 }
442}
443
444pub fn compose_effective_tier(
450 model_switch_tier: Option<Tier>,
451 component_tier: Option<Tier>,
452 frontmatter_tier: Option<Tier>,
453 suggested: Tier,
454) -> Tier {
455 for candidate in [model_switch_tier, component_tier, frontmatter_tier] {
456 if let Some(t) = candidate
457 && !matches!(t, Tier::Auto)
458 {
459 return t;
460 }
461 }
462 suggested
463}
464
465pub fn parse_model_arg(
472 arg: &str,
473 harness: &str,
474 model_config: &ModelConfig,
475) -> Option<(Tier, String)> {
476 let trimmed = arg.trim();
477 if let Ok(tier) = Tier::from_str(trimmed) {
479 if matches!(tier, Tier::Auto) {
480 return None;
481 }
482 let name = resolve_tier_to_model(tier, harness, model_config)
483 .unwrap_or_else(|| trimmed.to_string());
484 return Some((tier, name));
485 }
486 if let Some(tier) = tier_from_model_name(trimmed, harness, model_config) {
488 return Some((tier, trimmed.to_string()));
489 }
490 None
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// A configuration without user overrides, so only built-in tables apply.
    fn builtin_only() -> ModelConfig {
        ModelConfig::default()
    }

    /// Runs `scan_model_switch` for the `claude-code` harness with built-ins.
    fn scan_claude(diff: &str) -> ModelSwitchScan {
        scan_model_switch(diff, "claude-code", &builtin_only())
    }

    /// A document path that matches no tier-boost marker.
    fn plain_doc() -> &'static Path {
        Path::new("tasks/research/x.md")
    }

    /// A document path under `tasks/software/`, which boosts the suggestion.
    fn software_doc() -> &'static Path {
        Path::new("tasks/software/foo.md")
    }

    #[test]
    fn tier_ordering() {
        let ascending = [Tier::Auto, Tier::Low, Tier::Med, Tier::High];
        for pair in ascending.windows(2) {
            assert!(pair[0] < pair[1], "{:?} should sort below {:?}", pair[0], pair[1]);
        }
        assert!(Tier::High > Tier::Low);
        assert!(Tier::Med.cmp(&Tier::Med).is_ge());
    }

    #[test]
    fn tier_from_str_case_insensitive() {
        let cases = [
            ("LOW", Tier::Low),
            ("low", Tier::Low),
            ("Low", Tier::Low),
            ("AUTO", Tier::Auto),
            ("med", Tier::Med),
            ("medium", Tier::Med),
            ("HIGH", Tier::High),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<Tier>().unwrap(), expected, "input {input:?}");
        }
    }

    #[test]
    fn tier_from_str_invalid() {
        for bad in ["ultra", "", "opus"] {
            assert!(bad.parse::<Tier>().is_err(), "{bad:?} should be rejected");
        }
    }

    #[test]
    fn tier_display() {
        let cases = [
            (Tier::Low, "low"),
            (Tier::Med, "med"),
            (Tier::High, "high"),
            (Tier::Auto, "auto"),
        ];
        for (tier, text) in cases {
            assert_eq!(tier.to_string(), text);
        }
    }

    #[test]
    fn harness_detection_returns_known_value() {
        // The real environment decides the value; only its shape is checked.
        let harness = detect_harness();
        assert!(
            ["claude-code", "codex", "default"].contains(&harness.as_str()),
            "unexpected harness: {harness}"
        );
    }

    #[test]
    fn resolve_builtin_claude_code() {
        let cfg = builtin_only();
        for (tier, model) in [(Tier::High, "opus"), (Tier::Med, "sonnet"), (Tier::Low, "haiku")] {
            assert_eq!(
                resolve_tier_to_model(tier, "claude-code", &cfg).as_deref(),
                Some(model)
            );
        }
        assert_eq!(resolve_tier_to_model(Tier::Auto, "claude-code", &cfg), None);
    }

    #[test]
    fn resolve_builtin_codex() {
        let cfg = builtin_only();
        for (tier, model) in [(Tier::High, "o3"), (Tier::Low, "gpt-4o-mini")] {
            assert_eq!(resolve_tier_to_model(tier, "codex", &cfg).as_deref(), Some(model));
        }
    }

    #[test]
    fn resolve_unknown_harness_uses_default() {
        // A harness without its own table falls back to the default table.
        let resolved = resolve_tier_to_model(Tier::High, "junie", &builtin_only());
        assert_eq!(resolved.as_deref(), Some("opus"));
    }

    #[test]
    fn user_config_overrides_builtin() {
        let overrides = TierMap {
            low: Some("haiku-3".to_string()),
            med: Some("sonnet-4".to_string()),
            high: Some("opus-4-1".to_string()),
        };
        let cfg = ModelConfig {
            tiers: BTreeMap::from([("claude-code".to_string(), overrides)]),
            ..ModelConfig::default()
        };
        assert_eq!(
            resolve_tier_to_model(Tier::High, "claude-code", &cfg).as_deref(),
            Some("opus-4-1")
        );
    }

    #[test]
    fn tier_from_model_name_builtin() {
        let cfg = builtin_only();
        let cases = [
            ("opus", Some(Tier::High)),
            ("sonnet", Some(Tier::Med)),
            ("haiku", Some(Tier::Low)),
            ("unknown", None),
        ];
        for (model, expected) in cases {
            assert_eq!(
                tier_from_model_name(model, "claude-code", &cfg),
                expected,
                "model {model:?}"
            );
        }
    }

    #[test]
    fn parse_model_arg_tier_name() {
        let parsed = parse_model_arg("high", "claude-code", &builtin_only());
        assert_eq!(parsed, Some((Tier::High, "opus".to_string())));
    }

    #[test]
    fn parse_model_arg_concrete_name() {
        let parsed = parse_model_arg("opus", "claude-code", &builtin_only());
        assert_eq!(parsed, Some((Tier::High, "opus".to_string())));
    }

    #[test]
    fn parse_model_arg_unknown() {
        assert!(parse_model_arg("xyz-3000", "claude-code", &builtin_only()).is_none());
    }

    #[test]
    fn parse_model_arg_auto_rejected() {
        assert!(parse_model_arg("auto", "claude-code", &builtin_only()).is_none());
    }

    #[test]
    fn extract_model_component_present() {
        let doc = "# Title\n\n<!-- agent:model -->\nhigh\n<!-- /agent:model -->\n\nbody\n";
        assert_eq!(extract_model_component(doc).as_deref(), Some("high"));
    }

    #[test]
    fn extract_model_component_absent() {
        assert_eq!(extract_model_component("# Title\n\nbody only\n"), None);
    }

    #[test]
    fn extract_model_component_empty_inner() {
        let doc = "<!-- agent:model -->\n<!-- /agent:model -->\n";
        assert_eq!(extract_model_component(doc), None);
    }

    #[test]
    fn extract_model_component_concrete_name() {
        let doc = "<!-- agent:model -->\nopus\n<!-- /agent:model -->\n";
        assert_eq!(extract_model_component(doc).as_deref(), Some("opus"));
    }

    #[test]
    fn component_value_to_tier_tier_name() {
        let tier = component_value_to_tier("high", "claude-code", &builtin_only());
        assert_eq!(tier, Some(Tier::High));
    }

    #[test]
    fn component_value_to_tier_concrete_name() {
        let tier = component_value_to_tier("opus", "claude-code", &builtin_only());
        assert_eq!(tier, Some(Tier::High));
    }

    #[test]
    fn component_value_to_tier_unknown() {
        assert_eq!(component_value_to_tier("xyz", "claude-code", &builtin_only()), None);
    }

    #[test]
    fn suggested_tier_simple_question() {
        assert_eq!(suggested_tier(Some("simple_question"), 1, plain_doc()), Tier::Low);
    }

    #[test]
    fn suggested_tier_small_addition() {
        assert_eq!(suggested_tier(Some("content_addition"), 5, plain_doc()), Tier::Low);
    }

    #[test]
    fn suggested_tier_large_addition() {
        assert_eq!(suggested_tier(Some("content_addition"), 50, plain_doc()), Tier::Med);
    }

    #[test]
    fn suggested_tier_default_for_unknown() {
        assert_eq!(suggested_tier(None, 0, plain_doc()), Tier::Med);
    }

    #[test]
    fn suggested_tier_path_boost_software() {
        // Low → Med and Med → High under a boosted path.
        assert_eq!(suggested_tier(Some("simple_question"), 1, software_doc()), Tier::Med);
        assert_eq!(suggested_tier(Some("content_addition"), 50, software_doc()), Tier::High);
    }

    #[test]
    fn suggested_tier_path_boost_caps_at_high() {
        let boosted = suggested_tier(Some("content_addition"), 50, software_doc());
        assert_eq!(boosted, Tier::High);
    }

    #[test]
    fn compose_effective_tier_model_switch_wins() {
        let effective = compose_effective_tier(
            Some(Tier::High),
            Some(Tier::Low),
            Some(Tier::Med),
            Tier::Low,
        );
        assert_eq!(effective, Tier::High);
    }

    #[test]
    fn compose_effective_tier_component_beats_frontmatter() {
        let effective = compose_effective_tier(None, Some(Tier::High), Some(Tier::Low), Tier::Med);
        assert_eq!(effective, Tier::High);
    }

    #[test]
    fn compose_effective_tier_frontmatter_beats_heuristic() {
        let effective = compose_effective_tier(None, None, Some(Tier::High), Tier::Low);
        assert_eq!(effective, Tier::High);
    }

    #[test]
    fn compose_effective_tier_falls_through_to_heuristic() {
        assert_eq!(compose_effective_tier(None, None, None, Tier::Med), Tier::Med);
    }

    #[test]
    fn scan_model_switch_concrete_name() {
        let scan = scan_claude("@@ -1,3 +1,4 @@\n context\n+/model opus\n+real edit\n");
        assert_eq!(scan.model_switch.as_deref(), Some("opus"));
        assert_eq!(scan.model_switch_tier, Some(Tier::High));
        assert!(!scan.stripped_diff.contains("/model opus"));
        assert!(scan.stripped_diff.contains("real edit"));
    }

    #[test]
    fn scan_model_switch_tier_name() {
        let scan = scan_claude("+/model high\n+other line\n");
        assert_eq!(scan.model_switch_tier, Some(Tier::High));
        assert_eq!(scan.model_switch.as_deref(), Some("opus"));
        assert!(!scan.stripped_diff.contains("/model high"));
    }

    #[test]
    fn scan_model_switch_haiku() {
        assert_eq!(scan_claude("+/model haiku\n").model_switch_tier, Some(Tier::Low));
    }

    #[test]
    fn scan_model_switch_inside_fenced_code_ignored() {
        let scan = scan_claude("+```\n+/model opus\n+```\n+real line\n");
        assert_eq!(scan.model_switch, None);
        assert!(scan.stripped_diff.contains("/model opus"));
    }

    #[test]
    fn scan_model_switch_inside_blockquote_ignored() {
        let scan = scan_claude("+> /model opus\n+real line\n");
        assert_eq!(scan.model_switch, None);
        assert!(scan.stripped_diff.contains("/model opus"));
    }

    #[test]
    fn scan_model_switch_only_added_lines() {
        // A context line (leading space) is never treated as a command.
        assert_eq!(scan_claude(" /model opus\n+real line\n").model_switch, None);
    }

    #[test]
    fn scan_model_switch_no_match() {
        let scan = scan_claude("+just a normal line\n+another\n");
        assert_eq!(scan.model_switch, None);
        assert!(scan.stripped_diff.contains("just a normal line"));
        assert!(scan.stripped_diff.contains("another"));
    }

    #[test]
    fn scan_model_switch_unknown_arg_still_stripped() {
        let scan = scan_claude("+/model xyz-3000\n+real line\n");
        assert_eq!(scan.model_switch, None);
        assert!(!scan.stripped_diff.contains("/model xyz-3000"));
        assert!(scan.stripped_diff.contains("real line"));
    }

    #[test]
    fn scan_model_switch_first_match_wins() {
        let scan = scan_claude("+/model opus\n+/model haiku\n");
        assert_eq!(scan.model_switch.as_deref(), Some("opus"));
        assert!(!scan.stripped_diff.contains("/model"));
    }

    #[test]
    fn compose_effective_tier_auto_falls_through() {
        // `Auto` from higher-priority sources is skipped, not honoured.
        let effective = compose_effective_tier(
            Some(Tier::Auto),
            Some(Tier::Auto),
            Some(Tier::High),
            Tier::Low,
        );
        assert_eq!(effective, Tier::High);
    }
}
873}