use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use super::case::EvaluationCase;
use super::trial::{EvaluationStats, TrialResult};
/// Aggregated outcome of a full suite run: raw per-case trial results plus
/// derived statistics, both keyed by case name.
#[derive(Debug, Serialize, Deserialize)]
pub struct SuiteResult {
/// Every trial result for each case, keyed by `EvaluationCase::name()`.
pub case_results: HashMap<String, Vec<TrialResult>>,
/// Aggregate statistics per case, keyed by the same case name.
pub stats: HashMap<String, EvaluationStats>,
}
impl SuiteResult {
    /// Fraction of successful trials across every case, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no trials have been recorded at all.
    pub fn overall_success_rate(&self) -> f64 {
        // Single pass over all trials: count successes and total together.
        let (successes, total) = self
            .case_results
            .values()
            .flatten()
            .fold((0usize, 0usize), |(ok, n), trial| {
                (ok + usize::from(trial.success), n + 1)
            });
        if total == 0 {
            0.0
        } else {
            successes as f64 / total as f64
        }
    }

    /// Names of the cases whose success rate is strictly below `threshold`.
    pub fn failing_cases(&self, threshold: f64) -> Vec<&str> {
        self.stats
            .iter()
            .filter_map(|(name, s)| (s.success_rate < threshold).then_some(name.as_str()))
            .collect()
    }
}
/// Tunable knobs for an evaluation-suite run.
#[derive(Debug, Clone)]
pub struct SuiteConfig {
/// Number of trials to run per case.
pub n_trials: usize,
/// Maximum number of trials in flight at once; `<= 1` runs sequentially.
pub max_parallel: usize,
/// When `true`, a trial that returns `Err` is recorded as a plain failure;
/// when `false`, the error is additionally logged at `error` level
/// (the trial is still recorded as a failure either way).
pub catch_errors_as_failures: bool,
}
impl Default for SuiteConfig {
/// Defaults: 10 trials per case, sequential execution (`max_parallel = 1`),
/// and trial errors recorded as ordinary failures.
fn default() -> Self {
Self {
n_trials: 10,
max_parallel: 1,
catch_errors_as_failures: true,
}
}
}
/// Runs evaluation cases for a configured number of trials and collects
/// raw results plus per-case statistics into a `SuiteResult`.
pub struct EvaluationSuite {
// Immutable run configuration, fixed at construction time.
config: SuiteConfig,
}
impl EvaluationSuite {
    /// Creates a suite running `n_trials` trials per case (clamped to at
    /// least 1), with the remaining settings from `SuiteConfig::default()`.
    pub fn new(n_trials: usize) -> Self {
        Self {
            config: SuiteConfig {
                n_trials: n_trials.max(1),
                ..SuiteConfig::default()
            },
        }
    }

    /// Creates a suite from an explicit configuration.
    ///
    /// `n_trials` is clamped to at least 1, matching [`EvaluationSuite::new`].
    /// Previously a zero-trial config slipped through and panicked later in
    /// `run_suite` when `EvaluationStats::from_trials` saw an empty slice.
    pub fn with_config(config: SuiteConfig) -> Self {
        Self {
            config: SuiteConfig {
                n_trials: config.n_trials.max(1),
                ..config
            },
        }
    }

    /// Runs every trial for a single case.
    ///
    /// Sequential when `max_parallel <= 1`; otherwise trials run on spawned
    /// tokio tasks, bounded by a semaphore. Either way, the returned vector
    /// is ordered by `trial_id` and has exactly `n_trials` entries.
    pub async fn run_case(&self, case: Arc<dyn EvaluationCase>) -> Vec<TrialResult> {
        let mut results = Vec::with_capacity(self.config.n_trials);
        if self.config.max_parallel <= 1 {
            for trial_id in 0..self.config.n_trials {
                let result = case.run(trial_id).await;
                results.push(Self::resolve_trial(
                    self.config.catch_errors_as_failures,
                    result,
                    trial_id,
                ));
            }
        } else {
            use tokio::sync::Semaphore;
            let sem = Arc::new(Semaphore::new(self.config.max_parallel));
            let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<TrialResult>();
            for trial_id in 0..self.config.n_trials {
                // Acquire before spawning so at most `max_parallel` trials
                // are ever in flight. The semaphore is never closed, so
                // acquisition can only fail on a programming error here.
                let permit = sem
                    .clone()
                    .acquire_owned()
                    .await
                    .expect("suite semaphore closed while trials are running");
                let tx = tx.clone();
                let case_arc = Arc::clone(&case);
                let catch_errors = self.config.catch_errors_as_failures;
                tokio::spawn(async move {
                    // Hold the permit for the lifetime of this trial.
                    let _permit = permit;
                    let trial =
                        Self::resolve_trial(catch_errors, case_arc.run(trial_id).await, trial_id);
                    if tx.send(trial).is_err() {
                        tracing::warn!("Trial {} result dropped: receiver closed", trial_id);
                    }
                });
            }
            // Drop our sender so `rx` terminates once every task has sent.
            drop(tx);
            while let Some(t) = rx.recv().await {
                results.push(t);
            }
            // Tasks complete in arbitrary order; restore trial order.
            results.sort_by_key(|r| r.trial_id);
        }
        results
    }

    /// Runs every case in order, logging a per-case summary and collecting
    /// raw results plus statistics into a [`SuiteResult`].
    pub async fn run_suite(&self, cases: &[Arc<dyn EvaluationCase>]) -> SuiteResult {
        let mut case_results: HashMap<String, Vec<TrialResult>> = HashMap::new();
        let mut stats: HashMap<String, EvaluationStats> = HashMap::new();
        for case in cases {
            let results = self.run_case(Arc::clone(case)).await;
            // Both constructors clamp n_trials to >= 1, so `results` is
            // never empty and this expect states a true invariant.
            let case_stats = EvaluationStats::from_trials(&results)
                .expect("case must have at least one trial");
            let name = case.name().to_string();
            tracing::info!(
                case = %name,
                n = case_stats.n_trials,
                success_rate = %format!("{:.1}%", case_stats.success_rate * 100.0),
                ci_low = %format!("{:.3}", case_stats.confidence_interval_95.lower),
                ci_high = %format!("{:.3}", case_stats.confidence_interval_95.upper),
                "EvaluationSuite: case complete"
            );
            case_results.insert(name.clone(), results);
            stats.insert(name, case_stats);
        }
        SuiteResult {
            case_results,
            stats,
        }
    }

    /// Maps a trial outcome to a `TrialResult`, turning `Err` into a
    /// synthetic failure. When `catch_errors` is `false` the error is also
    /// logged loudly; the suite keeps running in both modes.
    ///
    /// Associated (not a method) so the `'static` spawned tasks in
    /// `run_case` can share it instead of duplicating this match inline.
    fn resolve_trial(catch_errors: bool, result: Result<TrialResult>, trial_id: usize) -> TrialResult {
        match result {
            Ok(t) => t,
            Err(e) if catch_errors => TrialResult::failure(trial_id, 0, e.to_string()),
            Err(e) => {
                tracing::error!(
                    "Trial {} errored (catch_errors_as_failures=false): {}",
                    trial_id,
                    e
                );
                TrialResult::failure(trial_id, 0, format!("Trial errored: {e}"))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eval::case::{AlwaysFailCase, AlwaysPassCase, StochasticCase};

    // A case that always passes must produce a perfect success rate.
    #[tokio::test]
    async fn test_suite_all_pass() {
        let suite = EvaluationSuite::new(5);
        let case = Arc::new(AlwaysPassCase::new("ok"));
        let outcome = suite.run_suite(&[case]).await;
        let stats = outcome.stats.get("ok").unwrap();
        assert_eq!(stats.n_trials, 5);
        assert_eq!(stats.successes, 5);
        assert!((stats.success_rate - 1.0).abs() < 1e-9);
        assert!((outcome.overall_success_rate() - 1.0).abs() < 1e-9);
    }

    // A case that always fails must produce a zero success rate.
    #[tokio::test]
    async fn test_suite_all_fail() {
        let suite = EvaluationSuite::new(3);
        let case = Arc::new(AlwaysFailCase::new("bad", "expected"));
        let outcome = suite.run_suite(&[case]).await;
        let stats = outcome.stats.get("bad").unwrap();
        assert_eq!(stats.successes, 0);
        assert_eq!(stats.success_rate, 0.0);
    }

    // One all-pass and one all-fail case average out to 50% overall.
    #[tokio::test]
    async fn test_suite_multiple_cases() {
        let suite = EvaluationSuite::new(10);
        let cases: Vec<Arc<dyn EvaluationCase>> = vec![
            Arc::new(AlwaysPassCase::new("pass")),
            Arc::new(AlwaysFailCase::new("fail", "x")),
        ];
        let outcome = suite.run_suite(&cases).await;
        assert!(outcome.stats.contains_key("pass"));
        assert!(outcome.stats.contains_key("fail"));
        assert!((outcome.overall_success_rate() - 0.5).abs() < 1e-9);
    }

    // n_trials = 0 is clamped so every case still runs at least once.
    #[tokio::test]
    async fn test_suite_n_trials_minimum_one() {
        let suite = EvaluationSuite::new(0);
        let case = Arc::new(AlwaysPassCase::new("x"));
        let outcome = suite.run_suite(&[case]).await;
        assert_eq!(outcome.stats["x"].n_trials, 1);
    }

    // run_case returns exactly n_trials results, ordered by trial_id.
    #[tokio::test]
    async fn test_run_case_returns_correct_count() {
        let suite = EvaluationSuite::new(7);
        let case = Arc::new(AlwaysPassCase::new("seven"));
        let trials = suite.run_case(case).await;
        assert_eq!(trials.len(), 7);
        for (expected_id, trial) in trials.iter().enumerate() {
            assert_eq!(trial.trial_id, expected_id);
        }
    }

    // failing_cases picks out only the cases below the threshold.
    #[tokio::test]
    async fn test_failing_cases_filter() {
        let suite = EvaluationSuite::new(10);
        let cases: Vec<Arc<dyn EvaluationCase>> = vec![
            Arc::new(AlwaysPassCase::new("good")),
            Arc::new(StochasticCase::new("flaky", 0.0)),
        ];
        let outcome = suite.run_suite(&cases).await;
        let failing = outcome.failing_cases(0.5);
        assert!(
            failing.contains(&"flaky"),
            "flaky should be in failing list"
        );
        assert!(
            !failing.contains(&"good"),
            "good should not be in failing list"
        );
    }

    // The 95% CI for a p=0.8 case over 50 trials should bracket 0.8 loosely.
    #[tokio::test]
    async fn test_confidence_interval_in_suite_result() {
        let suite = EvaluationSuite::new(50);
        let case = Arc::new(StochasticCase::new("ci_test", 0.8));
        let outcome = suite.run_suite(&[case]).await;
        let stats = &outcome.stats["ci_test"];
        let ci = stats.confidence_interval_95;
        assert!(ci.lower < 0.85 && ci.upper > 0.65);
    }
}