use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::str::FromStr;
/// Coarse capability tier used to pick a model for a harness.
///
/// Variants are declared in ascending order, so the derived `Ord` ranks
/// `Auto < Low < Med < High`. `Auto` (the default) means "no explicit tier":
/// it never resolves to a model and is skipped when composing the effective tier.
///
/// Serialized in lowercase (`"auto"`, `"low"`, `"med"`, `"high"`). When
/// deserializing, `"medium"` is also accepted for `Med`, matching the `FromStr`
/// impl. Serde matching stays case-sensitive, unlike `FromStr`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Tier {
    /// No explicit tier; defer to the next source or the heuristic.
    #[default]
    Auto,
    Low,
    // Keep serde in line with `FromStr`, which accepts both spellings.
    #[serde(alias = "medium")]
    Med,
    High,
}
impl std::fmt::Display for Tier {
    /// Writes the lowercase tier name (the spelling `FromStr` accepts).
    ///
    /// Uses `Formatter::pad`, so width, fill and alignment flags such as
    /// `{:>6}` are honoured. `write!` would silently ignore them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Auto => "auto",
            Self::Low => "low",
            Self::Med => "med",
            Self::High => "high",
        };
        f.pad(name)
    }
}
impl FromStr for Tier {
    type Err = anyhow::Error;

    /// Parses a tier name, ignoring surrounding whitespace and ASCII case.
    ///
    /// Accepts `auto`, `low`, `med` (or `medium`) and `high`.
    ///
    /// # Errors
    /// Returns an error quoting the trimmed, lowercased input when it is not a
    /// recognised tier name.
    fn from_str(s: &str) -> Result<Self> {
        let normalized = s.trim().to_ascii_lowercase();
        let tier = match normalized.as_str() {
            "auto" => Self::Auto,
            "low" => Self::Low,
            "med" | "medium" => Self::Med,
            "high" => Self::High,
            unknown => {
                return Err(anyhow!(
                    "invalid tier `{}`: expected one of auto|low|med|high",
                    unknown
                ));
            }
        };
        Ok(tier)
    }
}
/// Concrete model names for each explicit tier of one harness.
///
/// Any slot may be unset. Unset slots are omitted when serializing, and
/// missing keys become `None` when deserializing (via the container-level
/// `default`, which takes them from `TierMap::default()`).
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct TierMap {
    /// Model name used for `Tier::Low`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub low: Option<String>,
    /// Model name used for `Tier::Med`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub med: Option<String>,
    /// Model name used for `Tier::High`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub high: Option<String>,
}
impl TierMap {
    /// Returns the model configured for `tier`, or `None` if that slot is
    /// unset. `Tier::Auto` never maps to a model.
    pub fn get(&self, tier: Tier) -> Option<&str> {
        let slot = match tier {
            Tier::Auto => return None,
            Tier::Low => &self.low,
            Tier::Med => &self.med,
            Tier::High => &self.high,
        };
        slot.as_deref()
    }

    /// Reverse lookup: the tier whose configured model is exactly `model_name`.
    ///
    /// Matching is case-sensitive. If several slots hold the same name, the
    /// first in `Low`, `Med`, `High` order wins.
    pub fn tier_of(&self, model_name: &str) -> Option<Tier> {
        [Tier::Low, Tier::Med, Tier::High]
            .into_iter()
            .find(|&tier| self.get(tier) == Some(model_name))
    }
}
/// Model-selection configuration.
///
/// `tiers` maps a harness name (as passed to `resolve_tier_to_model`, e.g.
/// `"claude-code"`) to per-tier model overrides. Tiers left unset there fall
/// back to the built-in table for that harness.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Defaults to `true`, both when the key is missing from serialized config
    /// and for `ModelConfig::default()`.
    #[serde(default = "default_auto")]
    pub auto: bool,
    #[serde(default)]
    pub tiers: BTreeMap<String, TierMap>,
}

impl Default for ModelConfig {
    // Hand-written rather than derived: a derived `Default` would set
    // `auto = false`, disagreeing with the serde default of `true`.
    fn default() -> Self {
        Self {
            auto: default_auto(),
            tiers: BTreeMap::new(),
        }
    }
}

/// Default value for `ModelConfig::auto` (serde and `Default`).
fn default_auto() -> bool {
    true
}
/// Built-in tier table for the `claude-code` harness.
fn builtin_claude_code() -> TierMap {
    let model = |name: &str| Some(String::from(name));
    TierMap {
        low: model("haiku"),
        med: model("sonnet"),
        high: model("opus"),
    }
}
/// Built-in tier table for the `codex` harness.
fn builtin_codex() -> TierMap {
    let model = |name: &str| Some(String::from(name));
    TierMap {
        low: model("gpt-4o-mini"),
        med: model("gpt-4o"),
        high: model("o3"),
    }
}
/// Fallback tier table for harnesses without a dedicated built-in table.
fn builtin_default() -> TierMap {
    let model = |name: &str| Some(String::from(name));
    TierMap {
        low: model("haiku"),
        med: model("sonnet"),
        high: model("opus"),
    }
}
/// Selects the built-in tier table for `harness`. Any harness other than
/// `claude-code` or `codex` gets the generic fallback table.
fn builtin_for(harness: &str) -> TierMap {
    if harness == "claude-code" {
        builtin_claude_code()
    } else if harness == "codex" {
        builtin_codex()
    } else {
        builtin_default()
    }
}
/// Guesses the agent harness from environment variables.
///
/// Returns:
/// - `"claude-code"` if `CLAUDE_CODE_SESSION` or `CLAUDECODE` is set;
/// - otherwise `"codex"` if `CODEX_SESSION` is set;
/// - otherwise `"default"`.
pub fn detect_harness() -> String {
    // Presence check only. `var_os` also sees variables whose value is not
    // valid UTF-8, which `env::var(..).is_ok()` would report as unset.
    let is_set = |key: &str| std::env::var_os(key).is_some();
    let harness = if is_set("CLAUDE_CODE_SESSION") || is_set("CLAUDECODE") {
        "claude-code"
    } else if is_set("CODEX_SESSION") {
        "codex"
    } else {
        "default"
    };
    harness.to_string()
}
/// Maps `tier` to a concrete model name for `harness`.
///
/// An entry in `model_config.tiers[harness]` for that tier wins; otherwise the
/// built-in table for the harness is consulted. Returns `None` for `Tier::Auto`.
pub fn resolve_tier_to_model(
    tier: Tier,
    harness: &str,
    model_config: &ModelConfig,
) -> Option<String> {
    if tier == Tier::Auto {
        return None;
    }
    let user_choice = model_config
        .tiers
        .get(harness)
        .and_then(|map| map.get(tier));
    match user_choice {
        Some(name) => Some(name.to_owned()),
        None => builtin_for(harness).get(tier).map(str::to_owned),
    }
}
/// Reverse-maps a concrete model name to its tier for `harness`.
///
/// The user's table for the harness is checked first, then the built-in
/// table. Returns `None` if neither lists `model_name` (exact match).
pub fn tier_from_model_name(
    model_name: &str,
    harness: &str,
    model_config: &ModelConfig,
) -> Option<Tier> {
    let from_user = model_config
        .tiers
        .get(harness)
        .and_then(|map| map.tier_of(model_name));
    from_user.or_else(|| builtin_for(harness).tier_of(model_name))
}
/// Returns the trimmed inner text of the first `model` component in `content`.
///
/// Returns `None` in any of these cases:
/// - component parsing fails;
/// - there is no `model` component;
/// - the inner text is blank;
/// - the span reported by the parser is not a valid slice of `content`
///   (checked slicing instead of a panic).
pub fn extract_model_component(content: &str) -> Option<String> {
    let comps = crate::component::parse(content).ok()?;
    let comp = comps.into_iter().find(|c| c.name == "model")?;
    let inner = content.get(comp.open_end..comp.close_start)?;
    let trimmed = inner.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_string())
}
/// Interprets a `model` component value as a tier.
///
/// The value may be a tier name (parsed case-insensitively, including `auto`)
/// or a concrete model name known to the harness's user or built-in table.
pub fn component_value_to_tier(
    value: &str,
    harness: &str,
    model_config: &ModelConfig,
) -> Option<Tier> {
    value
        .parse::<Tier>()
        .ok()
        .or_else(|| tier_from_model_name(value, harness, model_config))
}
/// Heuristic tier suggestion used when no explicit tier is set.
///
/// The base tier comes from `diff_type`:
/// - `simple_question`, `approval`, `boundary_artifact` and `annotation` are `Low`;
/// - `content_addition` is `Low` below 10 added lines and `Med` otherwise;
/// - everything else (including `multi_topic`, `structural_change` and `None`) is `Med`.
///
/// If `doc_path` contains one of the boost markers, the tier is raised by one
/// step, capped at `High`. The markers are matched against a `/`-separated
/// copy of the path with a leading `/`. This makes them work for Windows
/// paths, and lets `/specs/` or `/plan.md` match at the start of a relative path.
pub fn suggested_tier(diff_type: Option<&str>, lines_added: usize, doc_path: &std::path::Path) -> Tier {
    const BOOST_MARKERS: [&str; 5] = [
        "tasks/software/",
        "/specs/",
        "agent-doc-bugs",
        "plan-",
        "/plan.md",
    ];

    let base = match diff_type {
        Some("simple_question" | "approval" | "boundary_artifact" | "annotation") => Tier::Low,
        Some("content_addition") if lines_added < 10 => Tier::Low,
        _ => Tier::Med,
    };

    let raw = doc_path.to_string_lossy();
    // On Windows `\` is a separator; elsewhere it is an ordinary filename byte.
    let forward = if cfg!(windows) {
        raw.replace('\\', "/")
    } else {
        raw.into_owned()
    };
    let normalized = format!("/{forward}");
    let boost = BOOST_MARKERS
        .iter()
        .any(|marker| normalized.contains(marker));

    if !boost {
        return base;
    }
    match base {
        Tier::Auto | Tier::Low => Tier::Med,
        Tier::Med | Tier::High => Tier::High,
    }
}
/// Outcome of `scan_model_switch` over one diff.
#[derive(Debug, Clone)]
pub struct ModelSwitchScan {
    /// Model name from the first `/model` directive that resolved, if any.
    pub model_switch: Option<String>,
    /// Tier of that directive (named directly, or looked up from the model name).
    pub model_switch_tier: Option<Tier>,
    /// The diff with `/model` directive lines removed, lines joined by `\n`.
    pub stripped_diff: String,
}
/// Scans a unified diff for `/model <arg>` directives on added lines.
///
/// A directive is an added (`+`) line whose content starts with the word
/// `/model`, followed by whitespace and an argument. Directives are ignored
/// and left in place when they are:
/// - inside a fenced code block (three or more `` ` `` or `~`);
/// - on a blockquote line (content starts with `>`).
///
/// Every recognised directive line is removed from `stripped_diff`, even
/// when its argument resolves to no known tier or model. Only the first
/// directive that does resolve sets `model_switch` / `model_switch_tier`.
///
/// `---`, `+++` and `@@` header lines, context lines and removed lines are
/// always kept. Output lines are joined with `\n`, so a trailing newline in
/// `diff` is not preserved.
pub fn scan_model_switch(
    diff: &str,
    harness: &str,
    model_config: &ModelConfig,
) -> ModelSwitchScan {
    let mut model_switch: Option<String> = None;
    let mut model_switch_tier: Option<Tier> = None;
    let mut kept_lines: Vec<&str> = Vec::with_capacity(diff.lines().count());
    // `Some((fence_char, fence_len))` while inside a fenced code block.
    let mut open_fence: Option<(char, usize)> = None;

    for line in diff.lines() {
        if line.starts_with("---") || line.starts_with("+++") || line.starts_with("@@") {
            kept_lines.push(line);
            continue;
        }
        let is_added = line.starts_with('+');
        let is_removed = line.starts_with('-');
        // Drop the one-character diff marker. The markers are ASCII, so byte 1
        // is always a char boundary.
        let content = if is_added || is_removed || line.starts_with(' ') {
            &line[1..]
        } else {
            line
        };

        // Removed lines exist only in the old document, so they must not
        // open or close a fence for the new-document lines that follow.
        if !is_removed && update_fence(&mut open_fence, content) {
            kept_lines.push(line);
            continue;
        }

        // Only added lines outside fences and blockquotes can carry a directive.
        if !is_added || open_fence.is_some() || content.starts_with('>') {
            kept_lines.push(line);
            continue;
        }

        if let Some(arg) = model_directive_arg(content.trim_end()) {
            if model_switch.is_none()
                && let Some((tier, name)) = parse_model_arg(arg, harness, model_config)
            {
                model_switch = Some(name);
                model_switch_tier = Some(tier);
            }
            // Directive lines are always stripped, resolved or not.
            continue;
        }
        kept_lines.push(line);
    }

    ModelSwitchScan {
        model_switch,
        model_switch_tier,
        stripped_diff: kept_lines.join("\n"),
    }
}

/// Advances the fence state by one line of new-document content.
///
/// Returns `true` if `content` opens or closes a fenced code block.
/// - An opener is three or more `` ` `` or `~` characters (after leading whitespace).
/// - A closer uses the same character, is at least as long as the opener,
///   and is followed only by whitespace.
fn update_fence(open_fence: &mut Option<(char, usize)>, content: &str) -> bool {
    let trimmed = content.trim_start();
    let lead = trimmed.chars().next().unwrap_or('\0');
    let run = trimmed.chars().take_while(|&c| c == lead).count();
    match *open_fence {
        None if (lead == '`' || lead == '~') && run >= 3 => {
            *open_fence = Some((lead, run));
            true
        }
        // `lead == ch` is checked first, and `ch` is ASCII, so `run` is a byte offset.
        Some((ch, len)) if lead == ch && run >= len && trimmed[run..].trim().is_empty() => {
            *open_fence = None;
            true
        }
        _ => false,
    }
}

/// Returns the first argument of a `/model <arg>` line.
///
/// Returns `None` when the line is not a directive. That includes words that
/// merely start with `/model` (`/models`, `/modeling`) and a bare `/model`.
fn model_directive_arg(line: &str) -> Option<&str> {
    let rest = line.strip_prefix("/model")?;
    if rest.chars().next().is_some_and(|c| !c.is_whitespace()) {
        return None;
    }
    rest.split_whitespace().next()
}
/// Picks the effective tier from the explicit sources, in precedence order.
///
/// The order is: `/model` switch, then `model` component, then frontmatter.
/// The first source that is present and not `Tier::Auto` wins. If none
/// qualifies, the heuristic `suggested` tier is returned.
pub fn compose_effective_tier(
    model_switch_tier: Option<Tier>,
    component_tier: Option<Tier>,
    frontmatter_tier: Option<Tier>,
    suggested: Tier,
) -> Tier {
    [model_switch_tier, component_tier, frontmatter_tier]
        .into_iter()
        .flatten()
        .find(|&tier| tier != Tier::Auto)
        .unwrap_or(suggested)
}
/// Interprets the argument of a `/model` directive.
///
/// - A tier name (case-insensitive) is resolved to the harness's model for
///   that tier. The tier text itself is used if no model resolves.
/// - Otherwise the argument is treated as a concrete model name and looked up
///   in the harness tables.
///
/// Returns `None` for `auto` or an unknown name.
pub fn parse_model_arg(
    arg: &str,
    harness: &str,
    model_config: &ModelConfig,
) -> Option<(Tier, String)> {
    let arg = arg.trim();
    match arg.parse::<Tier>() {
        Ok(Tier::Auto) => None,
        Ok(tier) => {
            let name = resolve_tier_to_model(tier, harness, model_config)
                .unwrap_or_else(|| arg.to_string());
            Some((tier, name))
        }
        Err(_) => tier_from_model_name(arg, harness, model_config)
            .map(|tier| (tier, arg.to_string())),
    }
}
// Unit tests for tier parsing, tier/model resolution, the path heuristic,
// precedence composition and `/model` directive scanning.
#[cfg(test)]
mod tests {
use super::*;
// The derived `Ord` follows declaration order: Auto < Low < Med < High.
#[test]
fn tier_ordering() {
assert!(Tier::Auto < Tier::Low);
assert!(Tier::Low < Tier::Med);
assert!(Tier::Med < Tier::High);
assert!(Tier::High > Tier::Low);
assert!(Tier::Med >= Tier::Med);
}
// `FromStr` ignores ASCII case and accepts `medium` as an alias of `med`.
#[test]
fn tier_from_str_case_insensitive() {
assert_eq!("LOW".parse::<Tier>().unwrap(), Tier::Low);
assert_eq!("low".parse::<Tier>().unwrap(), Tier::Low);
assert_eq!("Low".parse::<Tier>().unwrap(), Tier::Low);
assert_eq!("AUTO".parse::<Tier>().unwrap(), Tier::Auto);
assert_eq!("med".parse::<Tier>().unwrap(), Tier::Med);
assert_eq!("medium".parse::<Tier>().unwrap(), Tier::Med);
assert_eq!("HIGH".parse::<Tier>().unwrap(), Tier::High);
}
// Concrete model names such as `opus` are not tier names.
#[test]
fn tier_from_str_invalid() {
assert!("ultra".parse::<Tier>().is_err());
assert!("".parse::<Tier>().is_err());
assert!("opus".parse::<Tier>().is_err());
}
// `Display` emits the lowercase tier spelling.
#[test]
fn tier_display() {
assert_eq!(Tier::Low.to_string(), "low");
assert_eq!(Tier::Med.to_string(), "med");
assert_eq!(Tier::High.to_string(), "high");
assert_eq!(Tier::Auto.to_string(), "auto");
}
// The result depends on the test process's environment, so only the set of
// possible values is checked.
#[test]
fn harness_detection_returns_known_value() {
let h = detect_harness();
assert!(
matches!(h.as_str(), "claude-code" | "codex" | "default"),
"unexpected harness: {h}"
);
}
// Built-in Claude Code table; `Auto` never resolves to a model.
#[test]
fn resolve_builtin_claude_code() {
let cfg = ModelConfig::default();
assert_eq!(
resolve_tier_to_model(Tier::High, "claude-code", &cfg).as_deref(),
Some("opus")
);
assert_eq!(
resolve_tier_to_model(Tier::Med, "claude-code", &cfg).as_deref(),
Some("sonnet")
);
assert_eq!(
resolve_tier_to_model(Tier::Low, "claude-code", &cfg).as_deref(),
Some("haiku")
);
assert_eq!(resolve_tier_to_model(Tier::Auto, "claude-code", &cfg), None);
}
#[test]
fn resolve_builtin_codex() {
let cfg = ModelConfig::default();
assert_eq!(
resolve_tier_to_model(Tier::High, "codex", &cfg).as_deref(),
Some("o3")
);
assert_eq!(
resolve_tier_to_model(Tier::Low, "codex", &cfg).as_deref(),
Some("gpt-4o-mini")
);
}
// Unrecognised harness names fall back to the generic default table.
#[test]
fn resolve_unknown_harness_uses_default() {
let cfg = ModelConfig::default();
assert_eq!(
resolve_tier_to_model(Tier::High, "junie", &cfg).as_deref(),
Some("opus")
);
}
// A user table for the harness takes precedence over the built-in one.
#[test]
fn user_config_overrides_builtin() {
let mut cfg = ModelConfig::default();
let mut tiers = BTreeMap::new();
tiers.insert(
"claude-code".to_string(),
TierMap {
low: Some("haiku-3".to_string()),
med: Some("sonnet-4".to_string()),
high: Some("opus-4-1".to_string()),
},
);
cfg.tiers = tiers;
assert_eq!(
resolve_tier_to_model(Tier::High, "claude-code", &cfg).as_deref(),
Some("opus-4-1")
);
}
// Reverse lookup against the built-in table; unknown names give `None`.
#[test]
fn tier_from_model_name_builtin() {
let cfg = ModelConfig::default();
assert_eq!(
tier_from_model_name("opus", "claude-code", &cfg),
Some(Tier::High)
);
assert_eq!(
tier_from_model_name("sonnet", "claude-code", &cfg),
Some(Tier::Med)
);
assert_eq!(
tier_from_model_name("haiku", "claude-code", &cfg),
Some(Tier::Low)
);
assert_eq!(tier_from_model_name("unknown", "claude-code", &cfg), None);
}
// A tier name argument is resolved to the harness's model for that tier.
#[test]
fn parse_model_arg_tier_name() {
let cfg = ModelConfig::default();
let (tier, name) = parse_model_arg("high", "claude-code", &cfg).unwrap();
assert_eq!(tier, Tier::High);
assert_eq!(name, "opus");
}
// A known concrete model name is kept as-is, and its tier is looked up.
#[test]
fn parse_model_arg_concrete_name() {
let cfg = ModelConfig::default();
let (tier, name) = parse_model_arg("opus", "claude-code", &cfg).unwrap();
assert_eq!(tier, Tier::High);
assert_eq!(name, "opus");
}
#[test]
fn parse_model_arg_unknown() {
let cfg = ModelConfig::default();
assert!(parse_model_arg("xyz-3000", "claude-code", &cfg).is_none());
}
// `auto` is not a valid switch target.
#[test]
fn parse_model_arg_auto_rejected() {
let cfg = ModelConfig::default();
assert!(parse_model_arg("auto", "claude-code", &cfg).is_none());
}
// These rely on `crate::component::parse` recognising the
// `<!-- agent:model -->` ... `<!-- /agent:model -->` markers.
#[test]
fn extract_model_component_present() {
let doc = "# Title\n\n<!-- agent:model -->\nhigh\n<!-- /agent:model -->\n\nbody\n";
assert_eq!(extract_model_component(doc).as_deref(), Some("high"));
}
#[test]
fn extract_model_component_absent() {
let doc = "# Title\n\nbody only\n";
assert_eq!(extract_model_component(doc), None);
}
// A blank component body is treated as absent.
#[test]
fn extract_model_component_empty_inner() {
let doc = "<!-- agent:model -->\n<!-- /agent:model -->\n";
assert_eq!(extract_model_component(doc), None);
}
#[test]
fn extract_model_component_concrete_name() {
let doc = "<!-- agent:model -->\nopus\n<!-- /agent:model -->\n";
assert_eq!(extract_model_component(doc).as_deref(), Some("opus"));
}
#[test]
fn component_value_to_tier_tier_name() {
let cfg = ModelConfig::default();
assert_eq!(
component_value_to_tier("high", "claude-code", &cfg),
Some(Tier::High)
);
}
#[test]
fn component_value_to_tier_concrete_name() {
let cfg = ModelConfig::default();
assert_eq!(
component_value_to_tier("opus", "claude-code", &cfg),
Some(Tier::High)
);
}
#[test]
fn component_value_to_tier_unknown() {
let cfg = ModelConfig::default();
assert_eq!(component_value_to_tier("xyz", "claude-code", &cfg), None);
}
// `tasks/research/` paths match no boost marker, so the base tier is returned.
#[test]
fn suggested_tier_simple_question() {
let path = std::path::Path::new("tasks/research/x.md");
assert_eq!(suggested_tier(Some("simple_question"), 1, path), Tier::Low);
}
// `content_addition` is `Low` below 10 added lines...
#[test]
fn suggested_tier_small_addition() {
let path = std::path::Path::new("tasks/research/x.md");
assert_eq!(suggested_tier(Some("content_addition"), 5, path), Tier::Low);
}
// ...and `Med` at or above it.
#[test]
fn suggested_tier_large_addition() {
let path = std::path::Path::new("tasks/research/x.md");
assert_eq!(suggested_tier(Some("content_addition"), 50, path), Tier::Med);
}
#[test]
fn suggested_tier_default_for_unknown() {
let path = std::path::Path::new("tasks/research/x.md");
assert_eq!(suggested_tier(None, 0, path), Tier::Med);
}
// A `tasks/software/` path raises the base tier by one step.
#[test]
fn suggested_tier_path_boost_software() {
let path = std::path::Path::new("tasks/software/foo.md");
assert_eq!(
suggested_tier(Some("simple_question"), 1, path),
Tier::Med
);
assert_eq!(
suggested_tier(Some("content_addition"), 50, path),
Tier::High
);
}
#[test]
fn suggested_tier_path_boost_caps_at_high() {
let path = std::path::Path::new("tasks/software/foo.md");
let t = suggested_tier(Some("content_addition"), 50, path);
assert_eq!(t, Tier::High);
}
// Precedence: model switch > component > frontmatter > heuristic.
#[test]
fn compose_effective_tier_model_switch_wins() {
let t = compose_effective_tier(
Some(Tier::High),
Some(Tier::Low),
Some(Tier::Med),
Tier::Low,
);
assert_eq!(t, Tier::High);
}
#[test]
fn compose_effective_tier_component_beats_frontmatter() {
let t = compose_effective_tier(None, Some(Tier::High), Some(Tier::Low), Tier::Med);
assert_eq!(t, Tier::High);
}
#[test]
fn compose_effective_tier_frontmatter_beats_heuristic() {
let t = compose_effective_tier(None, None, Some(Tier::High), Tier::Low);
assert_eq!(t, Tier::High);
}
#[test]
fn compose_effective_tier_falls_through_to_heuristic() {
let t = compose_effective_tier(None, None, None, Tier::Med);
assert_eq!(t, Tier::Med);
}
// A recognised directive is recorded and stripped; other added lines are kept.
#[test]
fn scan_model_switch_concrete_name() {
let cfg = ModelConfig::default();
let diff = "@@ -1,3 +1,4 @@\n context\n+/model opus\n+real edit\n";
let result = scan_model_switch(diff, "claude-code", &cfg);
assert_eq!(result.model_switch.as_deref(), Some("opus"));
assert_eq!(result.model_switch_tier, Some(Tier::High));
assert!(!result.stripped_diff.contains("/model opus"));
assert!(result.stripped_diff.contains("real edit"));
}
// A tier-name argument is resolved to the harness model name.
#[test]
fn scan_model_switch_tier_name() {
let cfg = ModelConfig::default();
let diff = "+/model high\n+other line\n";
let result = scan_model_switch(diff, "claude-code", &cfg);
assert_eq!(result.model_switch_tier, Some(Tier::High));
assert_eq!(result.model_switch.as_deref(), Some("opus"));
assert!(!result.stripped_diff.contains("/model high"));
}
#[test]
fn scan_model_switch_haiku() {
let cfg = ModelConfig::default();
let diff = "+/model haiku\n";
let result = scan_model_switch(diff, "claude-code", &cfg);
assert_eq!(result.model_switch_tier, Some(Tier::Low));
}
// Directives inside fenced code are neither applied nor stripped.
#[test]
fn scan_model_switch_inside_fenced_code_ignored() {
let cfg = ModelConfig::default();
let diff = "+```\n+/model opus\n+```\n+real line\n";
let result = scan_model_switch(diff, "claude-code", &cfg);
assert_eq!(result.model_switch, None);
assert!(result.stripped_diff.contains("/model opus"));
}
// Quoted directives (`> /model ...`) are neither applied nor stripped.
#[test]
fn scan_model_switch_inside_blockquote_ignored() {
let cfg = ModelConfig::default();
let diff = "+> /model opus\n+real line\n";
let result = scan_model_switch(diff, "claude-code", &cfg);
assert_eq!(result.model_switch, None);
assert!(result.stripped_diff.contains("/model opus"));
}
// A context line (leading space) is never treated as a directive.
#[test]
fn scan_model_switch_only_added_lines() {
let cfg = ModelConfig::default();
let diff = " /model opus\n+real line\n";
let result = scan_model_switch(diff, "claude-code", &cfg);
assert_eq!(result.model_switch, None);
}
#[test]
fn scan_model_switch_no_match() {
let cfg = ModelConfig::default();
let diff = "+just a normal line\n+another\n";
let result = scan_model_switch(diff, "claude-code", &cfg);
assert_eq!(result.model_switch, None);
assert!(result.stripped_diff.contains("just a normal line"));
assert!(result.stripped_diff.contains("another"));
}
// A directive with an unresolvable argument is still removed from the diff.
#[test]
fn scan_model_switch_unknown_arg_still_stripped() {
let cfg = ModelConfig::default();
let diff = "+/model xyz-3000\n+real line\n";
let result = scan_model_switch(diff, "claude-code", &cfg);
assert_eq!(result.model_switch, None);
assert!(!result.stripped_diff.contains("/model xyz-3000"));
assert!(result.stripped_diff.contains("real line"));
}
// The first resolved directive wins, but every directive line is stripped.
#[test]
fn scan_model_switch_first_match_wins() {
let cfg = ModelConfig::default();
let diff = "+/model opus\n+/model haiku\n";
let result = scan_model_switch(diff, "claude-code", &cfg);
assert_eq!(result.model_switch.as_deref(), Some("opus"));
assert!(!result.stripped_diff.contains("/model"));
}
// `Auto` at a higher-precedence source defers to the next source.
#[test]
fn compose_effective_tier_auto_falls_through() {
let t = compose_effective_tier(
Some(Tier::Auto),
Some(Tier::Auto),
Some(Tier::High),
Tier::Low,
);
assert_eq!(t, Tier::High);
}
}