use std::time::Instant;
use serde::{Deserialize, Serialize};
use crate::auto::{AutoHint, AutoRoute};
use crate::pool::Pool;
use crate::store::PoolStore;
/// A single routing test case: a prompt plus the route kinds that count as a pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteTestCase {
/// Human-readable name of the case; printed in the first column of the report.
pub label: String,
/// Prompt text handed to the pool's router.
pub prompt: String,
/// Acceptable route kind names. The runner in this file produces `"single"`,
/// `"parallel"` or `"chain"`; the literal `"any"` accepts every kind.
pub expected: Vec<String>,
/// Optional routing hints. When present the runner calls `route_with_hints`
/// instead of `route`. Omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub hints: Option<AutoHint>,
}
impl RouteTestCase {
    /// Builds a case without hints.
    ///
    /// `expected` lists every route kind name that should be treated as a pass;
    /// each entry is copied into an owned `String`.
    pub fn new(label: impl Into<String>, prompt: impl Into<String>, expected: &[&str]) -> Self {
        let expected: Vec<String> = expected.iter().map(|name| String::from(*name)).collect();
        Self {
            label: label.into(),
            prompt: prompt.into(),
            expected,
            hints: None,
        }
    }

    /// Returns the same case with `hints` attached, replacing any previous hints.
    pub fn with_hints(self, hints: AutoHint) -> Self {
        Self {
            hints: Some(hints),
            ..self
        }
    }

    /// Reports whether `got` satisfies this case.
    ///
    /// An entry matches when it equals `got` exactly, or when it is the
    /// wildcard `"any"`. An empty `expected` list never matches.
    pub fn matches(&self, got: &str) -> bool {
        self.expected
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == "any" || candidate == got)
    }
}
/// Outcome of running one `RouteTestCase` through the router.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteTestResult {
/// Label copied from the test case.
pub label: String,
/// Prompt copied from the test case.
pub prompt: String,
/// Expected route kinds copied from the test case.
pub expected: Vec<String>,
/// Route kind name that was produced (`"single"`, `"parallel"` or `"chain"`
/// when filled in by the runner in this file); `None` when routing failed.
pub got: Option<String>,
/// The route returned by the pool; `None` when routing failed.
/// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub route: Option<AutoRoute>,
/// `true` when `got` matched one of `expected`; the runner sets it to `false` on error.
pub pass: bool,
/// Stringified routing error, if any. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Time spent in the routing call, in milliseconds.
pub elapsed_ms: u64,
/// Number of prompts (parallel route) or steps (chain route); the runner leaves it
/// `None` for single routes and errors. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub sub_count: Option<usize>,
}
impl RouteTestResult {
    /// Classifies this result as `"ERROR"`, `"OK"` or `"MISMATCH"`.
    ///
    /// A recorded error takes precedence over `pass`; otherwise `pass`
    /// decides between `"OK"` and `"MISMATCH"`.
    pub fn outcome(&self) -> &'static str {
        match (self.error.is_some(), self.pass) {
            (true, _) => "ERROR",
            (false, true) => "OK",
            (false, false) => "MISMATCH",
        }
    }
}
/// Aggregate report over a batch of routing test cases.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteTestSummary {
/// Per-case results, in the order the cases were run.
pub results: Vec<RouteTestResult>,
/// Number of cases that were run.
pub total: usize,
/// Number of results whose `pass` flag is set.
pub correct: usize,
/// Number of results that neither passed nor errored.
pub wrong: usize,
/// Number of results that carry an error.
pub errors: usize,
/// Wall-clock duration of the whole run, in milliseconds.
pub total_elapsed_ms: u64,
}
impl RouteTestSummary {
    /// Percentage (0–100) of correct results among the cases that did not error.
    ///
    /// Returns `0.0` when there is no non-error case to score.
    pub fn accuracy(&self) -> f64 {
        // The counters are public and deserializable, so `errors` may exceed
        // `total` in hand-built or externally produced summaries. Saturate
        // rather than underflow (a panic in debug builds, a wrap in release).
        let scored = self.total.saturating_sub(self.errors);
        if scored == 0 {
            return 0.0;
        }
        self.correct as f64 / scored as f64 * 100.0
    }

    /// Renders the summary as pretty-printed JSON.
    ///
    /// Best effort: if serialization fails, an empty string is returned.
    pub fn to_json(&self) -> String {
        serde_json::to_string_pretty(self).unwrap_or_default()
    }

    /// Returns every result whose `pass` flag is false.
    ///
    /// This includes both mismatches and results that carry an error with
    /// `pass` unset.
    pub fn failures(&self) -> Vec<&RouteTestResult> {
        self.results.iter().filter(|r| !r.pass).collect()
    }
}
/// Human-readable report: one line per case (plus detail lines for
/// mismatches and errors), followed by a totals footer.
impl std::fmt::Display for RouteTestSummary {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for entry in &self.results {
            // Common prefix: padded label and the `|`-joined expectations.
            write!(
                f,
                "{:<40} expected={:<10} ",
                entry.label,
                entry.expected.join("|")
            )?;

            match entry.outcome() {
                verdict @ ("OK" | "MISMATCH") => {
                    // Both non-error outcomes share the optional sub-count prefix.
                    if let Some(count) = entry.sub_count {
                        write!(f, "(n={count}) ")?;
                    }
                    let got = entry.got.as_deref().unwrap_or("?");
                    if verdict == "OK" {
                        writeln!(f, "got={got:<10} OK ({} ms)", entry.elapsed_ms)?;
                    } else {
                        writeln!(f, "got={got:<10} MISMATCH ({} ms)", entry.elapsed_ms)?;
                        // Mismatch details are only printed when a route was recorded.
                        if let Some(route) = entry.route.as_ref() {
                            let route_json =
                                serde_json::to_string_pretty(route).unwrap_or_default();
                            writeln!(f, " prompt: {:?}", entry.prompt)?;
                            writeln!(f, " route: {route_json}")?;
                        }
                    }
                }
                _ => {
                    let message = entry.error.as_deref().unwrap_or("unknown");
                    writeln!(f, "ERROR: {message} ({} ms)", entry.elapsed_ms)?;
                    writeln!(f, " prompt: {:?}", entry.prompt)?;
                }
            }
        }

        // Footer: blank separator line, heading, then the totals.
        writeln!(f)?;
        writeln!(f, "--- Results ---")?;
        writeln!(
            f,
            "Total: {} Correct: {} Wrong: {} Errors: {} Accuracy: {:.0}% Time: {} ms",
            self.total,
            self.correct,
            self.wrong,
            self.errors,
            self.accuracy(),
            self.total_elapsed_ms,
        )
    }
}
/// Runs `RouteTestCase`s against a borrowed `Pool` and collects the results.
pub struct RouteTestRunner<'a, S: PoolStore> {
// Pool whose `route` / `route_with_hints` methods are exercised by each case.
pool: &'a Pool<S>,
}
impl<'a, S: PoolStore + 'static> RouteTestRunner<'a, S> {
pub fn new(pool: &'a Pool<S>) -> Self {
Self { pool }
}
pub async fn run(&self, cases: &[RouteTestCase]) -> RouteTestSummary {
let run_start = Instant::now();
let mut results = Vec::with_capacity(cases.len());
for case in cases {
results.push(self.run_case(case).await);
}
let total = results.len();
let correct = results.iter().filter(|r| r.pass).count();
let errors = results.iter().filter(|r| r.error.is_some()).count();
let wrong = total - correct - errors;
RouteTestSummary {
results,
total,
correct,
wrong,
errors,
total_elapsed_ms: run_start.elapsed().as_millis() as u64,
}
}
async fn run_case(&self, case: &RouteTestCase) -> RouteTestResult {
let start = Instant::now();
let result = if let Some(hints) = &case.hints {
self.pool.route_with_hints(&case.prompt, hints).await
} else {
self.pool.route(&case.prompt).await
};
let elapsed_ms = start.elapsed().as_millis() as u64;
match result {
Ok(route) => {
let (got, sub_count) = match &route {
AutoRoute::Single { .. } => ("single".to_string(), None),
AutoRoute::Parallel { prompts } => {
("parallel".to_string(), Some(prompts.len()))
}
AutoRoute::Chain { steps } => ("chain".to_string(), Some(steps.len())),
};
let pass = case.matches(&got);
RouteTestResult {
label: case.label.clone(),
prompt: case.prompt.clone(),
expected: case.expected.clone(),
got: Some(got),
route: Some(route),
pass,
error: None,
elapsed_ms,
sub_count,
}
}
Err(e) => RouteTestResult {
label: case.label.clone(),
prompt: case.prompt.clone(),
expected: case.expected.clone(),
got: None,
route: None,
pass: false,
error: Some(e.to_string()),
elapsed_ms,
sub_count: None,
},
}
}
}