1use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use std::fmt;
8
/// Tier of model that a task is routed to.
///
/// The wire format is defined by this type's hand-written `Serialize` /
/// `Deserialize` impls: the bare strings `"local"` and `"frontier"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RouteTier {
    /// The local tier (serialized as `"local"`).
    Local,
    /// The frontier tier (serialized as `"frontier"`).
    Frontier,
}
17
18impl Serialize for RouteTier {
19 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
20 serializer.serialize_str(match self {
21 RouteTier::Local => "local",
22 RouteTier::Frontier => "frontier",
23 })
24 }
25}
26
27impl<'de> Deserialize<'de> for RouteTier {
28 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
29 let s = String::deserialize(deserializer)?;
30 match s.to_lowercase().as_str() {
31 "local" => Ok(RouteTier::Local),
32 "frontier" => Ok(RouteTier::Frontier),
33 other => Err(serde::de::Error::unknown_variant(
34 other,
35 &["local", "frontier"],
36 )),
37 }
38 }
39}
40
41#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
43#[serde(tag = "kind", rename_all = "snake_case")]
44pub enum RouteDecision {
45 Local { model_id: String, reason: String },
47 Escalate { model_id: String, reason: String },
49}
50
51impl fmt::Display for RouteDecision {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 match self {
54 RouteDecision::Local { model_id, reason } => {
55 write!(f, "local → {model_id} ({reason})")
56 }
57 RouteDecision::Escalate { model_id, reason } => {
58 write!(f, "escalate → {model_id} ({reason})")
59 }
60 }
61 }
62}
63
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
66#[serde(rename_all = "snake_case")]
67pub enum TaskType {
68 CodeGen,
70 CodeReview,
72 Retrieval,
74 Refactor,
76 Documentation,
78 Debugging,
80 Execution,
82 General,
84}
85
86impl TaskType {
87 pub fn base_difficulty(&self) -> f64 {
89 match self {
90 TaskType::CodeGen => 0.65,
91 TaskType::Refactor => 0.70,
92 TaskType::Debugging => 0.60,
93 TaskType::CodeReview => 0.55,
94 TaskType::Retrieval => 0.30,
95 TaskType::Documentation => 0.35,
96 TaskType::Execution => 0.25,
97 TaskType::General => 0.40,
98 }
99 }
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct EscalationEvent {
105 pub timestamp: String,
107 pub task_summary: String,
109 pub difficulty_score: f64,
111 pub task_type: TaskType,
113 pub estimated_context_tokens: u64,
115 pub decision: RouteDecision,
117 #[serde(default, skip_serializing_if = "Option::is_none")]
119 pub role: Option<String>,
120 #[serde(default, skip_serializing_if = "Option::is_none")]
122 pub escalation_from: Option<String>,
123 #[serde(default)]
125 pub estimated_cost_usd: f64,
126 #[serde(default)]
130 pub counterfactual_cost_usd: f64,
131}
132
133#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
135#[serde(rename_all = "snake_case")]
136pub enum RoutePolicy {
137 Auto,
139 PreferLocal,
141 ForceLocal,
143 ForceFrontier {
145 model_id: String,
147 },
148}
149
#[cfg(test)]
mod tests {
    //! Unit tests for the routing types: wire formats (JSON / YAML),
    //! `Display` output, serde defaults, and difficulty baselines.

    use super::*;

    #[test]
    fn route_tier_serializes_as_lowercase() {
        let json = serde_json::to_string(&RouteTier::Local).unwrap();
        assert_eq!(json, r#""local""#);

        let json = serde_json::to_string(&RouteTier::Frontier).unwrap();
        assert_eq!(json, r#""frontier""#);
    }

    #[test]
    fn route_tier_deserializes_case_insensitive() {
        let tier: RouteTier = serde_json::from_str(r#""local""#).unwrap();
        assert_eq!(tier, RouteTier::Local);

        let tier: RouteTier = serde_json::from_str(r#""LOCAL""#).unwrap();
        assert_eq!(tier, RouteTier::Local);

        let tier: RouteTier = serde_json::from_str(r#""Frontier""#).unwrap();
        assert_eq!(tier, RouteTier::Frontier);

        let tier: RouteTier = serde_json::from_str(r#""FRONTIER""#).unwrap();
        assert_eq!(tier, RouteTier::Frontier);
    }

    #[test]
    fn route_tier_rejects_unknown_or_non_string_input() {
        // Unknown names must fail with serde's unknown-variant error.
        let err = serde_json::from_str::<RouteTier>(r#""cloud""#).unwrap_err();
        assert!(
            err.to_string().contains("unknown variant"),
            "unexpected error: {err}"
        );

        // Empty strings and non-string values are rejected too.
        assert!(serde_json::from_str::<RouteTier>(r#""""#).is_err());
        assert!(serde_json::from_str::<RouteTier>("1").is_err());
    }

    #[test]
    fn route_tier_roundtrip() {
        for tier in [RouteTier::Local, RouteTier::Frontier] {
            let json = serde_json::to_string(&tier).unwrap();
            let parsed: RouteTier = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, tier);
        }
    }

    #[test]
    fn route_decision_display() {
        let d = RouteDecision::Local {
            model_id: "ollama_llama3".into(),
            reason: "low difficulty (score 0.15)".into(),
        };
        assert_eq!(
            d.to_string(),
            "local → ollama_llama3 (low difficulty (score 0.15))"
        );

        let d = RouteDecision::Escalate {
            model_id: "anthropic_opus_4_7".into(),
            reason: "high complexity (score 0.82)".into(),
        };
        assert_eq!(
            d.to_string(),
            "escalate → anthropic_opus_4_7 (high complexity (score 0.82))"
        );
    }

    #[test]
    fn route_decision_serializes_with_kind_tag() {
        let d = RouteDecision::Local {
            model_id: "m".into(),
            reason: "r".into(),
        };
        let json = serde_json::to_string(&d).unwrap();
        assert_eq!(json, r#"{"kind":"local","model_id":"m","reason":"r"}"#);

        let parsed: RouteDecision =
            serde_json::from_str(r#"{"kind":"escalate","model_id":"m","reason":"r"}"#).unwrap();
        assert_eq!(
            parsed,
            RouteDecision::Escalate {
                model_id: "m".into(),
                reason: "r".into(),
            }
        );
    }

    #[test]
    fn task_type_serializes_as_snake_case() {
        assert_eq!(
            serde_json::to_string(&TaskType::CodeGen).unwrap(),
            r#""code_gen""#
        );
        assert_eq!(
            serde_json::to_string(&TaskType::CodeReview).unwrap(),
            r#""code_review""#
        );
        let parsed: TaskType = serde_json::from_str(r#""documentation""#).unwrap();
        assert_eq!(parsed, TaskType::Documentation);
    }

    #[test]
    fn escalation_event_roundtrip() {
        let event = EscalationEvent {
            timestamp: "2026-06-01T12:00:00Z".into(),
            task_summary: "Refactor auth module".into(),
            difficulty_score: 0.82,
            task_type: TaskType::CodeGen,
            estimated_context_tokens: 3500,
            decision: RouteDecision::Escalate {
                model_id: "anthropic_opus_4_7".into(),
                reason: "high complexity".into(),
            },
            role: Some("dev".into()),
            escalation_from: Some("ollama_llama3".into()),
            estimated_cost_usd: 0.0525,
            counterfactual_cost_usd: 0.0525,
        };
        let json = serde_json::to_string(&event).unwrap();
        let parsed: EscalationEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.task_summary, "Refactor auth module");
        assert_eq!(parsed.difficulty_score, 0.82);
        assert_eq!(parsed.task_type, TaskType::CodeGen);
        assert_eq!(parsed.estimated_context_tokens, 3500);
        assert_eq!(parsed.decision, event.decision);
        assert_eq!(parsed.role.as_deref(), Some("dev"));
        assert_eq!(parsed.escalation_from.as_deref(), Some("ollama_llama3"));
        assert_eq!(parsed.counterfactual_cost_usd, 0.0525);
    }

    #[test]
    fn escalation_event_optional_fields_default_and_are_omitted() {
        let json = r#"{
            "timestamp": "2026-06-01T12:00:00Z",
            "task_summary": "Summarize docs",
            "difficulty_score": 0.2,
            "task_type": "documentation",
            "estimated_context_tokens": 100,
            "decision": {"kind": "local", "model_id": "ollama_llama3", "reason": "easy"}
        }"#;
        let event: EscalationEvent = serde_json::from_str(json).unwrap();
        assert_eq!(event.role, None);
        assert_eq!(event.escalation_from, None);
        assert_eq!(event.estimated_cost_usd, 0.0);
        assert_eq!(event.counterfactual_cost_usd, 0.0);

        // `None` optionals are skipped on output rather than written as null.
        let out = serde_json::to_string(&event).unwrap();
        assert!(!out.contains(r#""role""#), "unexpected role key: {out}");
        assert!(
            !out.contains(r#""escalation_from""#),
            "unexpected escalation_from key: {out}"
        );
    }

    #[test]
    fn route_policy_serde_roundtrip() {
        let policies = [
            RoutePolicy::Auto,
            RoutePolicy::PreferLocal,
            RoutePolicy::ForceLocal,
            RoutePolicy::ForceFrontier {
                model_id: "claude".into(),
            },
        ];
        for policy in policies {
            let yaml = serde_yaml_ng::to_string(&policy).unwrap();
            let parsed: RoutePolicy = serde_yaml_ng::from_str(&yaml).unwrap();
            assert_eq!(parsed, policy, "yaml was: {yaml}");
        }
    }

    #[test]
    fn task_type_base_difficulty_ordering() {
        assert!(TaskType::Execution.base_difficulty() < TaskType::CodeGen.base_difficulty());
        assert!(TaskType::Retrieval.base_difficulty() < TaskType::CodeGen.base_difficulty());
        assert!(TaskType::Refactor.base_difficulty() > TaskType::Documentation.base_difficulty());

        let all = [
            TaskType::CodeGen,
            TaskType::CodeReview,
            TaskType::Retrieval,
            TaskType::Refactor,
            TaskType::Documentation,
            TaskType::Debugging,
            TaskType::Execution,
            TaskType::General,
        ];
        for tt in &all {
            let s = tt.base_difficulty();
            assert!(
                (0.0..=1.0).contains(&s),
                "{tt:?} base_difficulty {s} out of range"
            );
        }
    }
}