1use std::time::Instant;
29
30use serde::{Deserialize, Serialize};
31
32use crate::auto::{AutoHint, AutoRoute};
33use crate::pool::Pool;
34use crate::store::PoolStore;
35
/// One routing expectation: a prompt plus the route kind(s) it is allowed to produce.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteTestCase {
    /// Name identifying this case; copied into the result and printed in reports.
    pub label: String,
    /// Prompt handed to the pool's router.
    pub prompt: String,
    /// Acceptable route kinds as produced by the runner (`"single"`, `"parallel"`,
    /// `"chain"`). The special value `"any"` accepts every successful route.
    pub expected: Vec<String>,
    /// Optional routing hints. When present, the runner calls `route_with_hints`
    /// instead of `route`. Omitted from serialized output when `None`, and
    /// defaults to `None` when absent from the input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hints: Option<AutoHint>,
}
49
50impl RouteTestCase {
51 pub fn new(label: impl Into<String>, prompt: impl Into<String>, expected: &[&str]) -> Self {
53 Self {
54 label: label.into(),
55 prompt: prompt.into(),
56 expected: expected.iter().map(|s| (*s).to_string()).collect(),
57 hints: None,
58 }
59 }
60
61 pub fn with_hints(mut self, hints: AutoHint) -> Self {
63 self.hints = Some(hints);
64 self
65 }
66
67 pub fn matches(&self, got: &str) -> bool {
69 self.expected.iter().any(|e| e == "any" || e == got)
70 }
71}
72
/// Outcome of running a single [`RouteTestCase`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteTestResult {
    /// Label copied from the originating case.
    pub label: String,
    /// Prompt copied from the originating case.
    pub prompt: String,
    /// Acceptable route kinds copied from the originating case.
    pub expected: Vec<String>,
    /// Route kind actually produced (`"single"`, `"parallel"` or `"chain"`);
    /// `None` when routing returned an error.
    pub got: Option<String>,
    /// Full route returned by the router; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub route: Option<AutoRoute>,
    /// Whether `got` matched the case's expectations. The runner sets this to
    /// `false` when routing failed.
    pub pass: bool,
    /// Error message from the router, if routing failed; omitted from
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Wall-clock time of the routing call, in milliseconds.
    pub elapsed_ms: u64,
    /// Number of sub-prompts (parallel) or steps (chain); `None` for single
    /// routes and errors. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sub_count: Option<usize>,
}
98
99impl RouteTestResult {
100 pub fn outcome(&self) -> &'static str {
102 if self.error.is_some() {
103 "ERROR"
104 } else if self.pass {
105 "OK"
106 } else {
107 "MISMATCH"
108 }
109 }
110}
111
/// Aggregate report over a batch of [`RouteTestResult`]s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteTestSummary {
    /// Per-case results, in the order the cases were run.
    pub results: Vec<RouteTestResult>,
    /// Number of cases run.
    pub total: usize,
    /// Cases whose route kind matched expectations.
    pub correct: usize,
    /// Cases that routed successfully but produced an unexpected kind.
    pub wrong: usize,
    /// Cases where routing returned an error.
    pub errors: usize,
    /// Wall-clock time of the whole run, in milliseconds.
    pub total_elapsed_ms: u64,
}
128
129impl RouteTestSummary {
130 pub fn accuracy(&self) -> f64 {
132 let denominator = self.total - self.errors;
133 if denominator == 0 {
134 return 0.0;
135 }
136 self.correct as f64 / denominator as f64 * 100.0
137 }
138
139 pub fn to_json(&self) -> String {
141 serde_json::to_string_pretty(self).unwrap_or_default()
142 }
143
144 pub fn failures(&self) -> Vec<&RouteTestResult> {
146 self.results.iter().filter(|r| !r.pass).collect()
147 }
148}
149
150impl std::fmt::Display for RouteTestSummary {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 for r in &self.results {
153 let expected = r.expected.join("|");
154 write!(f, "{:<40} expected={:<10} ", r.label, expected)?;
155
156 match r.outcome() {
157 "OK" => {
158 let got = r.got.as_deref().unwrap_or("?");
159 if let Some(n) = r.sub_count {
160 write!(f, "(n={n}) ")?;
161 }
162 writeln!(f, "got={got:<10} OK ({} ms)", r.elapsed_ms)?;
163 }
164 "MISMATCH" => {
165 let got = r.got.as_deref().unwrap_or("?");
166 if let Some(n) = r.sub_count {
167 write!(f, "(n={n}) ")?;
168 }
169 writeln!(f, "got={got:<10} MISMATCH ({} ms)", r.elapsed_ms)?;
170 if let Some(route) = &r.route {
171 let json = serde_json::to_string_pretty(route).unwrap_or_default();
172 writeln!(f, " prompt: {:?}", r.prompt)?;
173 writeln!(f, " route: {json}")?;
174 }
175 }
176 _ => {
177 let err = r.error.as_deref().unwrap_or("unknown");
178 writeln!(f, "ERROR: {err} ({} ms)", r.elapsed_ms)?;
179 writeln!(f, " prompt: {:?}", r.prompt)?;
180 }
181 }
182 }
183
184 writeln!(f)?;
185 writeln!(f, "--- Results ---")?;
186 writeln!(
187 f,
188 "Total: {} Correct: {} Wrong: {} Errors: {} Accuracy: {:.0}% Time: {} ms",
189 self.total,
190 self.correct,
191 self.wrong,
192 self.errors,
193 self.accuracy(),
194 self.total_elapsed_ms,
195 )?;
196 Ok(())
197 }
198}
199
/// Runs [`RouteTestCase`]s against a borrowed [`Pool`] and collects the outcomes.
pub struct RouteTestRunner<'a, S: PoolStore> {
    /// Pool whose router is exercised by every case.
    pool: &'a Pool<S>,
}
204
205impl<'a, S: PoolStore + 'static> RouteTestRunner<'a, S> {
206 pub fn new(pool: &'a Pool<S>) -> Self {
208 Self { pool }
209 }
210
211 pub async fn run(&self, cases: &[RouteTestCase]) -> RouteTestSummary {
213 let run_start = Instant::now();
214 let mut results = Vec::with_capacity(cases.len());
215
216 for case in cases {
217 results.push(self.run_case(case).await);
218 }
219
220 let total = results.len();
221 let correct = results.iter().filter(|r| r.pass).count();
222 let errors = results.iter().filter(|r| r.error.is_some()).count();
223 let wrong = total - correct - errors;
224
225 RouteTestSummary {
226 results,
227 total,
228 correct,
229 wrong,
230 errors,
231 total_elapsed_ms: run_start.elapsed().as_millis() as u64,
232 }
233 }
234
235 async fn run_case(&self, case: &RouteTestCase) -> RouteTestResult {
237 let start = Instant::now();
238
239 let result = if let Some(hints) = &case.hints {
240 self.pool.route_with_hints(&case.prompt, hints).await
241 } else {
242 self.pool.route(&case.prompt).await
243 };
244
245 let elapsed_ms = start.elapsed().as_millis() as u64;
246
247 match result {
248 Ok(route) => {
249 let (got, sub_count) = match &route {
250 AutoRoute::Single { .. } => ("single".to_string(), None),
251 AutoRoute::Parallel { prompts } => {
252 ("parallel".to_string(), Some(prompts.len()))
253 }
254 AutoRoute::Chain { steps } => ("chain".to_string(), Some(steps.len())),
255 };
256 let pass = case.matches(&got);
257 RouteTestResult {
258 label: case.label.clone(),
259 prompt: case.prompt.clone(),
260 expected: case.expected.clone(),
261 got: Some(got),
262 route: Some(route),
263 pass,
264 error: None,
265 elapsed_ms,
266 sub_count,
267 }
268 }
269 Err(e) => RouteTestResult {
270 label: case.label.clone(),
271 prompt: case.prompt.clone(),
272 expected: case.expected.clone(),
273 got: None,
274 route: None,
275 pass: false,
276 error: Some(e.to_string()),
277 elapsed_ms,
278 sub_count: None,
279 },
280 }
281 }
282}