Skip to main content

mur_common/
route.rs

1//! Route data types for the cost-router orchestrator.
2//!
3//! These types are shared between `mur-core` (router logic) and
4//! `mur-agent-runtime` (future Phase 2 spawn decisions).
5
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use std::fmt;
8
/// Cost class of a model: cheap local inference versus a paid frontier API.
///
/// On the wire this is the bare lowercase string `"local"` or `"frontier"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RouteTier {
    /// Local model with little or no per-call cost (e.g. Ollama, llama.cpp, mlx).
    Local,
    /// Paid, cloud-hosted frontier model (e.g. Claude, GPT, Gemini).
    Frontier,
}
17
18impl Serialize for RouteTier {
19    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
20        serializer.serialize_str(match self {
21            RouteTier::Local => "local",
22            RouteTier::Frontier => "frontier",
23        })
24    }
25}
26
27impl<'de> Deserialize<'de> for RouteTier {
28    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
29        let s = String::deserialize(deserializer)?;
30        match s.to_lowercase().as_str() {
31            "local" => Ok(RouteTier::Local),
32            "frontier" => Ok(RouteTier::Frontier),
33            other => Err(serde::de::Error::unknown_variant(
34                other,
35                &["local", "frontier"],
36            )),
37        }
38    }
39}
40
41/// Which model and tier to use for a sub-task.
42#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
43#[serde(tag = "kind", rename_all = "snake_case")]
44pub enum RouteDecision {
45    /// Route to a cheap/local model.
46    Local { model_id: String, reason: String },
47    /// Escalate to a frontier model.
48    Escalate { model_id: String, reason: String },
49}
50
51impl fmt::Display for RouteDecision {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        match self {
54            RouteDecision::Local { model_id, reason } => {
55                write!(f, "local → {model_id} ({reason})")
56            }
57            RouteDecision::Escalate { model_id, reason } => {
58                write!(f, "escalate → {model_id} ({reason})")
59            }
60        }
61    }
62}
63
64/// Categories of sub-tasks for difficulty scoring.
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
66#[serde(rename_all = "snake_case")]
67pub enum TaskType {
68    /// Writing or modifying code.
69    CodeGen,
70    /// Reviewing / auditing existing code.
71    CodeReview,
72    /// Searching / retrieving information.
73    Retrieval,
74    /// Refactoring across multiple files.
75    Refactor,
76    /// Writing or updating documentation.
77    Documentation,
78    /// Debugging / investigating issues.
79    Debugging,
80    /// Running tests or commands.
81    Execution,
82    /// General chat / Q&A.
83    General,
84}
85
86impl TaskType {
87    /// Base difficulty score 0.0–1.0 before other factors.
88    pub fn base_difficulty(&self) -> f64 {
89        match self {
90            TaskType::CodeGen => 0.65,
91            TaskType::Refactor => 0.70,
92            TaskType::Debugging => 0.60,
93            TaskType::CodeReview => 0.55,
94            TaskType::Retrieval => 0.30,
95            TaskType::Documentation => 0.35,
96            TaskType::Execution => 0.25,
97            TaskType::General => 0.40,
98        }
99    }
100}
101
102/// One routing decision recorded in the escalation audit ledger.
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct EscalationEvent {
105    /// ISO-8601 timestamp.
106    pub timestamp: String,
107    /// Human-readable summary of what was asked.
108    pub task_summary: String,
109    /// 0.0–1.0 difficulty score from the heuristic.
110    pub difficulty_score: f64,
111    /// Classified task type.
112    pub task_type: TaskType,
113    /// Estimated context window tokens needed.
114    pub estimated_context_tokens: u64,
115    /// The routing decision made.
116    pub decision: RouteDecision,
117    /// Role that originated this task (if any).
118    #[serde(default, skip_serializing_if = "Option::is_none")]
119    pub role: Option<String>,
120    /// Which local model would have been used if not escalated.
121    #[serde(default, skip_serializing_if = "Option::is_none")]
122    pub escalation_from: Option<String>,
123    /// Estimated USD cost of the route actually taken (0.0 for local).
124    #[serde(default)]
125    pub estimated_cost_usd: f64,
126    /// Estimated USD this task would have cost on the frontier model —
127    /// i.e. the cost avoided when routed local. Equals `estimated_cost_usd`
128    /// for escalations.
129    #[serde(default)]
130    pub counterfactual_cost_usd: f64,
131}
132
133/// Per-role routing override, stored in `RoleEntry.route_policy`.
134#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
135#[serde(rename_all = "snake_case")]
136pub enum RoutePolicy {
137    /// Let the heuristic decide (default).
138    Auto,
139    /// Bias toward local models; only escalate above a higher threshold.
140    PreferLocal,
141    /// Always use local models for this role.
142    ForceLocal,
143    /// Always use a specific frontier model for this role.
144    ForceFrontier {
145        /// Model registry key (e.g. "anthropic_opus_4_7").
146        model_id: String,
147    },
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Every current `TaskType` variant; keep in sync when a variant is added.
    const ALL_TASK_TYPES: [TaskType; 8] = [
        TaskType::CodeGen,
        TaskType::CodeReview,
        TaskType::Retrieval,
        TaskType::Refactor,
        TaskType::Documentation,
        TaskType::Debugging,
        TaskType::Execution,
        TaskType::General,
    ];

    #[test]
    fn route_tier_serializes_as_lowercase() {
        assert_eq!(
            serde_json::to_string(&RouteTier::Local).unwrap(),
            r#""local""#
        );
        assert_eq!(
            serde_json::to_string(&RouteTier::Frontier).unwrap(),
            r#""frontier""#
        );
    }

    #[test]
    fn route_tier_deserializes_case_insensitive() {
        let cases = [
            (r#""local""#, RouteTier::Local),
            (r#""LOCAL""#, RouteTier::Local),
            (r#""Frontier""#, RouteTier::Frontier),
            (r#""fRoNtIeR""#, RouteTier::Frontier),
        ];
        for (input, expected) in cases {
            let tier: RouteTier = serde_json::from_str(input).unwrap();
            assert_eq!(tier, expected, "input: {input}");
        }
    }

    #[test]
    fn route_tier_rejects_unknown_or_non_string_input() {
        // Unknown names, padded names, the empty string and non-string JSON
        // must all fail instead of silently mapping to some tier.
        for input in [r#""cloud""#, r#""""#, r#""local ""#, "1", "null"] {
            assert!(
                serde_json::from_str::<RouteTier>(input).is_err(),
                "input {input} should be rejected"
            );
        }
    }

    #[test]
    fn route_decision_display() {
        let d = RouteDecision::Local {
            model_id: "ollama_llama3".into(),
            reason: "low difficulty (score 0.15)".into(),
        };
        assert_eq!(
            d.to_string(),
            "local → ollama_llama3 (low difficulty (score 0.15))"
        );

        let d = RouteDecision::Escalate {
            model_id: "anthropic_opus_4_7".into(),
            reason: "high complexity (score 0.82)".into(),
        };
        assert_eq!(
            d.to_string(),
            "escalate → anthropic_opus_4_7 (high complexity (score 0.82))"
        );
    }

    #[test]
    fn route_decision_is_internally_tagged() {
        let d = RouteDecision::Escalate {
            model_id: "anthropic_opus_4_7".into(),
            reason: "high complexity".into(),
        };
        let json = serde_json::to_string(&d).unwrap();
        assert_eq!(
            json,
            r#"{"kind":"escalate","model_id":"anthropic_opus_4_7","reason":"high complexity"}"#
        );
        let parsed: RouteDecision = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, d);
    }

    #[test]
    fn escalation_event_roundtrip() {
        let event = EscalationEvent {
            timestamp: "2026-06-01T12:00:00Z".into(),
            task_summary: "Refactor auth module".into(),
            difficulty_score: 0.82,
            task_type: TaskType::CodeGen,
            estimated_context_tokens: 3500,
            decision: RouteDecision::Escalate {
                model_id: "anthropic_opus_4_7".into(),
                reason: "high complexity".into(),
            },
            role: Some("dev".into()),
            escalation_from: Some("ollama_llama3".into()),
            estimated_cost_usd: 0.0525,
            counterfactual_cost_usd: 0.0525,
        };
        let json = serde_json::to_string(&event).unwrap();
        let parsed: EscalationEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.timestamp, "2026-06-01T12:00:00Z");
        assert_eq!(parsed.task_summary, "Refactor auth module");
        assert_eq!(parsed.difficulty_score, 0.82);
        assert_eq!(parsed.task_type, TaskType::CodeGen);
        assert_eq!(parsed.estimated_context_tokens, 3500);
        assert_eq!(parsed.decision, event.decision);
        assert_eq!(parsed.role.as_deref(), Some("dev"));
        assert_eq!(parsed.escalation_from.as_deref(), Some("ollama_llama3"));
        assert_eq!(parsed.estimated_cost_usd, 0.0525);
        assert_eq!(parsed.counterfactual_cost_usd, 0.0525);
    }

    #[test]
    fn escalation_event_defaults_missing_optional_fields() {
        // A minimal ledger line carrying only the required fields.
        let json = r#"{
            "timestamp": "2026-06-01T12:00:00Z",
            "task_summary": "List project files",
            "difficulty_score": 0.2,
            "task_type": "retrieval",
            "estimated_context_tokens": 120,
            "decision": {"kind": "local", "model_id": "ollama_llama3", "reason": "easy"}
        }"#;
        let parsed: EscalationEvent = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.task_type, TaskType::Retrieval);
        assert_eq!(parsed.role, None);
        assert_eq!(parsed.escalation_from, None);
        assert_eq!(parsed.estimated_cost_usd, 0.0);
        assert_eq!(parsed.counterfactual_cost_usd, 0.0);
        assert!(matches!(parsed.decision, RouteDecision::Local { .. }));

        // `None` optionals are omitted when the record is written back out.
        let out = serde_json::to_string(&parsed).unwrap();
        assert!(!out.contains("\"role\""), "{out}");
        assert!(!out.contains("\"escalation_from\""), "{out}");
    }

    #[test]
    fn route_policy_serde_roundtrip() {
        let policies = [
            RoutePolicy::Auto,
            RoutePolicy::PreferLocal,
            RoutePolicy::ForceLocal,
            RoutePolicy::ForceFrontier {
                model_id: "claude".into(),
            },
        ];
        for policy in policies {
            let yaml = serde_yaml_ng::to_string(&policy).unwrap();
            let parsed: RoutePolicy = serde_yaml_ng::from_str(&yaml).unwrap();
            assert_eq!(parsed, policy, "yaml:\n{yaml}");
        }
    }

    #[test]
    fn route_policy_uses_snake_case_names() {
        let parsed: RoutePolicy = serde_json::from_str(r#""prefer_local""#).unwrap();
        assert_eq!(parsed, RoutePolicy::PreferLocal);

        let parsed: RoutePolicy = serde_json::from_str(r#""force_local""#).unwrap();
        assert_eq!(parsed, RoutePolicy::ForceLocal);

        let parsed: RoutePolicy =
            serde_json::from_str(r#"{"force_frontier":{"model_id":"claude"}}"#).unwrap();
        assert_eq!(
            parsed,
            RoutePolicy::ForceFrontier {
                model_id: "claude".into()
            }
        );
    }

    #[test]
    fn task_type_serializes_as_snake_case() {
        assert_eq!(
            serde_json::to_string(&TaskType::CodeGen).unwrap(),
            r#""code_gen""#
        );
        assert_eq!(
            serde_json::to_string(&TaskType::CodeReview).unwrap(),
            r#""code_review""#
        );
        for tt in ALL_TASK_TYPES {
            let json = serde_json::to_string(&tt).unwrap();
            let back: TaskType = serde_json::from_str(&json).unwrap();
            assert_eq!(back, tt, "json: {json}");
        }
    }

    #[test]
    fn task_type_base_difficulty_ordering() {
        let execution = TaskType::Execution.base_difficulty();
        let retrieval = TaskType::Retrieval.base_difficulty();
        let refactor = TaskType::Refactor.base_difficulty();

        for tt in ALL_TASK_TYPES {
            let s = tt.base_difficulty();
            // All scores in [0.0, 1.0].
            assert!(
                (0.0..=1.0).contains(&s),
                "{tt:?} base_difficulty {s} out of range"
            );
            // Refactor is the hardest category.
            assert!(s <= refactor, "{tt:?} ({s}) is harder than Refactor");
            // Execution and retrieval are the two easiest categories.
            if !matches!(tt, TaskType::Execution | TaskType::Retrieval) {
                assert!(execution < s, "{tt:?} ({s}) is not harder than Execution");
                assert!(retrieval < s, "{tt:?} ({s}) is not harder than Retrieval");
            }
        }
    }
}