Skip to main content

claude_pool/
route_test.rs

1//! Test harness for auto-routing accuracy.
2//!
3//! Provides structured test cases, result collection, and multiple output
4//! formats for validating routing prompt changes. Used by the `route_stress`
5//! example and designed to feed into CI tracking (#286).
6//!
7//! # Usage
8//!
9//! ```rust,no_run
10//! use claude_pool::{Pool, route_test::{RouteTestCase, RouteTestRunner}};
11//! use claude_wrapper::Claude;
12//!
13//! # async fn example() -> anyhow::Result<()> {
14//! let claude = Claude::builder().build()?;
15//! let pool = Pool::builder(claude).slots(1).build().await?;
16//!
17//! let cases = vec![
18//!     RouteTestCase::new("trivial", "What is 2+2?", &["single"]),
19//! ];
20//!
21//! let runner = RouteTestRunner::new(&pool);
22//! let summary = runner.run(&cases).await;
23//! println!("{summary}");
24//! # Ok(())
25//! # }
26//! ```
27
28use std::time::Instant;
29
30use serde::{Deserialize, Serialize};
31
32use crate::auto::{AutoHint, AutoRoute};
33use crate::pool::Pool;
34use crate::store::PoolStore;
35
36/// A single test case for the routing classifier.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct RouteTestCase {
39    /// Short label for display and tracking.
40    pub label: String,
41    /// The task prompt to classify.
42    pub prompt: String,
43    /// Acceptable route types. Use `["any"]` to accept anything.
44    pub expected: Vec<String>,
45    /// Optional hints to pass to the router.
46    #[serde(default, skip_serializing_if = "Option::is_none")]
47    pub hints: Option<AutoHint>,
48}
49
50impl RouteTestCase {
51    /// Create a test case with one or more acceptable outcomes.
52    pub fn new(label: impl Into<String>, prompt: impl Into<String>, expected: &[&str]) -> Self {
53        Self {
54            label: label.into(),
55            prompt: prompt.into(),
56            expected: expected.iter().map(|s| (*s).to_string()).collect(),
57            hints: None,
58        }
59    }
60
61    /// Add hints to this test case.
62    pub fn with_hints(mut self, hints: AutoHint) -> Self {
63        self.hints = Some(hints);
64        self
65    }
66
67    /// Check if a route name matches the expected outcomes.
68    pub fn matches(&self, got: &str) -> bool {
69        self.expected.iter().any(|e| e == "any" || e == got)
70    }
71}
72
73/// Result of running a single test case.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct RouteTestResult {
76    /// The test case label.
77    pub label: String,
78    /// The prompt that was classified.
79    pub prompt: String,
80    /// Expected route types.
81    pub expected: Vec<String>,
82    /// The route type returned, if successful.
83    pub got: Option<String>,
84    /// The full route decision, if successful.
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub route: Option<AutoRoute>,
87    /// Whether the result matched expectations.
88    pub pass: bool,
89    /// Error message, if the routing call failed.
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub error: Option<String>,
92    /// Wall-clock time for this case in milliseconds.
93    pub elapsed_ms: u64,
94    /// Number of sub-items (parallel prompts or chain steps).
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub sub_count: Option<usize>,
97}
98
99impl RouteTestResult {
100    /// Outcome as a short string: "OK", "MISMATCH", or "ERROR".
101    pub fn outcome(&self) -> &'static str {
102        if self.error.is_some() {
103            "ERROR"
104        } else if self.pass {
105            "OK"
106        } else {
107            "MISMATCH"
108        }
109    }
110}
111
112/// Summary of a full test run.
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct RouteTestSummary {
115    /// Per-case results in execution order.
116    pub results: Vec<RouteTestResult>,
117    /// Total number of cases.
118    pub total: usize,
119    /// Cases that matched expectations.
120    pub correct: usize,
121    /// Cases that returned a route but didn't match.
122    pub wrong: usize,
123    /// Cases that errored (no route returned).
124    pub errors: usize,
125    /// Total wall-clock time in milliseconds.
126    pub total_elapsed_ms: u64,
127}
128
129impl RouteTestSummary {
130    /// Accuracy as a percentage (correct / (total - errors)).
131    pub fn accuracy(&self) -> f64 {
132        let denominator = self.total - self.errors;
133        if denominator == 0 {
134            return 0.0;
135        }
136        self.correct as f64 / denominator as f64 * 100.0
137    }
138
139    /// Serialize to pretty JSON.
140    pub fn to_json(&self) -> String {
141        serde_json::to_string_pretty(self).unwrap_or_default()
142    }
143
144    /// Return only the failed/errored results.
145    pub fn failures(&self) -> Vec<&RouteTestResult> {
146        self.results.iter().filter(|r| !r.pass).collect()
147    }
148}
149
150impl std::fmt::Display for RouteTestSummary {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        for r in &self.results {
153            let expected = r.expected.join("|");
154            write!(f, "{:<40} expected={:<10} ", r.label, expected)?;
155
156            match r.outcome() {
157                "OK" => {
158                    let got = r.got.as_deref().unwrap_or("?");
159                    if let Some(n) = r.sub_count {
160                        write!(f, "(n={n}) ")?;
161                    }
162                    writeln!(f, "got={got:<10} OK  ({} ms)", r.elapsed_ms)?;
163                }
164                "MISMATCH" => {
165                    let got = r.got.as_deref().unwrap_or("?");
166                    if let Some(n) = r.sub_count {
167                        write!(f, "(n={n}) ")?;
168                    }
169                    writeln!(f, "got={got:<10} MISMATCH  ({} ms)", r.elapsed_ms)?;
170                    if let Some(route) = &r.route {
171                        let json = serde_json::to_string_pretty(route).unwrap_or_default();
172                        writeln!(f, "  prompt: {:?}", r.prompt)?;
173                        writeln!(f, "  route:  {json}")?;
174                    }
175                }
176                _ => {
177                    let err = r.error.as_deref().unwrap_or("unknown");
178                    writeln!(f, "ERROR: {err}  ({} ms)", r.elapsed_ms)?;
179                    writeln!(f, "  prompt: {:?}", r.prompt)?;
180                }
181            }
182        }
183
184        writeln!(f)?;
185        writeln!(f, "--- Results ---")?;
186        writeln!(
187            f,
188            "Total: {}  Correct: {}  Wrong: {}  Errors: {}  Accuracy: {:.0}%  Time: {} ms",
189            self.total,
190            self.correct,
191            self.wrong,
192            self.errors,
193            self.accuracy(),
194            self.total_elapsed_ms,
195        )?;
196        Ok(())
197    }
198}
199
200/// Runs routing test cases against a pool.
201pub struct RouteTestRunner<'a, S: PoolStore> {
202    pool: &'a Pool<S>,
203}
204
205impl<'a, S: PoolStore + 'static> RouteTestRunner<'a, S> {
206    /// Create a runner for the given pool.
207    pub fn new(pool: &'a Pool<S>) -> Self {
208        Self { pool }
209    }
210
211    /// Run all test cases and return a summary.
212    pub async fn run(&self, cases: &[RouteTestCase]) -> RouteTestSummary {
213        let run_start = Instant::now();
214        let mut results = Vec::with_capacity(cases.len());
215
216        for case in cases {
217            results.push(self.run_case(case).await);
218        }
219
220        let total = results.len();
221        let correct = results.iter().filter(|r| r.pass).count();
222        let errors = results.iter().filter(|r| r.error.is_some()).count();
223        let wrong = total - correct - errors;
224
225        RouteTestSummary {
226            results,
227            total,
228            correct,
229            wrong,
230            errors,
231            total_elapsed_ms: run_start.elapsed().as_millis() as u64,
232        }
233    }
234
235    /// Run a single test case.
236    async fn run_case(&self, case: &RouteTestCase) -> RouteTestResult {
237        let start = Instant::now();
238
239        let result = if let Some(hints) = &case.hints {
240            self.pool.route_with_hints(&case.prompt, hints).await
241        } else {
242            self.pool.route(&case.prompt).await
243        };
244
245        let elapsed_ms = start.elapsed().as_millis() as u64;
246
247        match result {
248            Ok(route) => {
249                let (got, sub_count) = match &route {
250                    AutoRoute::Single { .. } => ("single".to_string(), None),
251                    AutoRoute::Parallel { prompts } => {
252                        ("parallel".to_string(), Some(prompts.len()))
253                    }
254                    AutoRoute::Chain { steps } => ("chain".to_string(), Some(steps.len())),
255                };
256                let pass = case.matches(&got);
257                RouteTestResult {
258                    label: case.label.clone(),
259                    prompt: case.prompt.clone(),
260                    expected: case.expected.clone(),
261                    got: Some(got),
262                    route: Some(route),
263                    pass,
264                    error: None,
265                    elapsed_ms,
266                    sub_count,
267                }
268            }
269            Err(e) => RouteTestResult {
270                label: case.label.clone(),
271                prompt: case.prompt.clone(),
272                expected: case.expected.clone(),
273                got: None,
274                route: None,
275                pass: false,
276                error: Some(e.to_string()),
277                elapsed_ms,
278                sub_count: None,
279            },
280        }
281    }
282}