1use std::collections::HashMap;
7use std::sync::Arc;
8
9use anyhow::Result;
10use serde::{Deserialize, Serialize};
11
12use super::case::EvaluationCase;
13use super::trial::{EvaluationStats, TrialResult};
14
/// Aggregated outcome of [`EvaluationSuite::run_suite`].
///
/// Both maps are keyed by the case name reported by `EvaluationCase::name`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SuiteResult {
    /// Every trial result recorded for each case, as returned by `run_case`.
    pub case_results: HashMap<String, Vec<TrialResult>>,
    /// Summary statistics computed from each case's trials via `EvaluationStats::from_trials`.
    pub stats: HashMap<String, EvaluationStats>,
}
25
26impl SuiteResult {
27 pub fn overall_success_rate(&self) -> f64 {
29 let total: usize = self.case_results.values().map(|v| v.len()).sum();
30 if total == 0 {
31 return 0.0;
32 }
33 let successes: usize = self
34 .case_results
35 .values()
36 .flat_map(|v| v.iter())
37 .filter(|r| r.success)
38 .count();
39 successes as f64 / total as f64
40 }
41
42 pub fn failing_cases(&self, threshold: f64) -> Vec<&str> {
44 self.stats
45 .iter()
46 .filter(|(_, s)| s.success_rate < threshold)
47 .map(|(name, _)| name.as_str())
48 .collect()
49 }
50}
51
/// Settings for an [`EvaluationSuite`].
#[derive(Debug, Clone)]
pub struct SuiteConfig {
    /// Number of trials run for each case.
    ///
    /// `EvaluationSuite::new` clamps this to at least 1.
    pub n_trials: usize,
    /// Maximum number of trials of a single case that run concurrently.
    ///
    /// Values `<= 1` run the trials sequentially on the calling task; larger
    /// values spawn trials onto the Tokio runtime, gated by a semaphore.
    pub max_parallel: usize,
    /// Controls how an `Err` returned by a trial is reported.
    ///
    /// Either way the error is recorded as a failed `TrialResult`. When this is
    /// `false`, the error is also logged at error level and its message is
    /// prefixed with "Trial errored: ".
    pub catch_errors_as_failures: bool,
}
67
68impl Default for SuiteConfig {
69 fn default() -> Self {
70 Self {
71 n_trials: 10,
72 max_parallel: 1,
73 catch_errors_as_failures: true,
74 }
75 }
76}
77
78pub struct EvaluationSuite {
96 config: SuiteConfig,
97}
98
99impl EvaluationSuite {
100 pub fn new(n_trials: usize) -> Self {
102 Self {
103 config: SuiteConfig {
104 n_trials: n_trials.max(1),
105 ..SuiteConfig::default()
106 },
107 }
108 }
109
110 pub fn with_config(config: SuiteConfig) -> Self {
112 Self { config }
113 }
114
115 pub async fn run_case(&self, case: Arc<dyn EvaluationCase>) -> Vec<TrialResult> {
120 let mut results = Vec::with_capacity(self.config.n_trials);
121
122 if self.config.max_parallel <= 1 {
123 for trial_id in 0..self.config.n_trials {
125 let result = case.run(trial_id).await;
126 results.push(self.resolve(result, trial_id));
127 }
128 } else {
129 use tokio::sync::Semaphore;
131 let sem = Arc::new(Semaphore::new(self.config.max_parallel));
132 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<TrialResult>();
133
134 for trial_id in 0..self.config.n_trials {
135 let permit = sem.clone().acquire_owned().await.unwrap();
136 let tx = tx.clone();
137 let case_arc = Arc::clone(&case);
138 let catch_errors = self.config.catch_errors_as_failures;
139
140 tokio::spawn(async move {
141 let _permit = permit;
142 let result = case_arc.run(trial_id).await;
143 let trial = match result {
144 Ok(t) => t,
145 Err(e) if catch_errors => TrialResult::failure(trial_id, 0, e.to_string()),
146 Err(e) => {
147 tracing::error!(
148 "Trial {} errored (catch_errors_as_failures=false): {}",
149 trial_id,
150 e
151 );
152 TrialResult::failure(trial_id, 0, format!("Trial errored: {e}"))
153 }
154 };
155 if tx.send(trial).is_err() {
156 tracing::warn!("Trial {} result dropped: receiver closed", trial_id);
157 }
158 });
159 }
160 drop(tx);
161
162 while let Some(t) = rx.recv().await {
164 results.push(t);
165 }
166
167 results.sort_by_key(|r| r.trial_id);
169 }
170
171 results
172 }
173
174 pub async fn run_suite(&self, cases: &[Arc<dyn EvaluationCase>]) -> SuiteResult {
176 let mut case_results: HashMap<String, Vec<TrialResult>> = HashMap::new();
177 let mut stats: HashMap<String, EvaluationStats> = HashMap::new();
178
179 for case in cases {
180 let results = self.run_case(Arc::clone(case)).await;
181 let case_stats =
182 EvaluationStats::from_trials(&results).expect("case must have at least one trial");
183 let name = case.name().to_string();
184 tracing::info!(
185 case = %name,
186 n = case_stats.n_trials,
187 success_rate = %format!("{:.1}%", case_stats.success_rate * 100.0),
188 ci_low = %format!("{:.3}", case_stats.confidence_interval_95.lower),
189 ci_high = %format!("{:.3}", case_stats.confidence_interval_95.upper),
190 "EvaluationSuite: case complete"
191 );
192 case_results.insert(name.clone(), results);
193 stats.insert(name, case_stats);
194 }
195
196 SuiteResult {
197 case_results,
198 stats,
199 }
200 }
201
202 fn resolve(&self, result: Result<TrialResult>, trial_id: usize) -> TrialResult {
205 match result {
206 Ok(t) => t,
207 Err(e) if self.config.catch_errors_as_failures => {
208 TrialResult::failure(trial_id, 0, e.to_string())
209 }
210 Err(e) => {
211 tracing::error!(
212 "Trial {} errored (catch_errors_as_failures=false): {}",
213 trial_id,
214 e
215 );
216 TrialResult::failure(trial_id, 0, format!("Trial errored: {e}"))
217 }
218 }
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::case::{AlwaysFailCase, AlwaysPassCase, StochasticCase};

    #[tokio::test]
    async fn test_suite_all_pass() {
        let suite = EvaluationSuite::new(5);
        let case = Arc::new(AlwaysPassCase::new("ok"));
        let result = suite.run_suite(&[case]).await;

        let stats = result.stats.get("ok").unwrap();
        assert_eq!(stats.n_trials, 5);
        assert_eq!(stats.successes, 5);
        assert!((stats.success_rate - 1.0).abs() < 1e-9);
        assert!((result.overall_success_rate() - 1.0).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_suite_all_fail() {
        let suite = EvaluationSuite::new(3);
        let case = Arc::new(AlwaysFailCase::new("bad", "expected"));
        let result = suite.run_suite(&[case]).await;

        let stats = result.stats.get("bad").unwrap();
        assert_eq!(stats.successes, 0);
        assert_eq!(stats.success_rate, 0.0);
    }

    #[tokio::test]
    async fn test_suite_multiple_cases() {
        let suite = EvaluationSuite::new(10);
        let cases: Vec<Arc<dyn EvaluationCase>> = vec![
            Arc::new(AlwaysPassCase::new("pass")),
            Arc::new(AlwaysFailCase::new("fail", "x")),
        ];
        let result = suite.run_suite(&cases).await;
        assert!(result.stats.contains_key("pass"));
        assert!(result.stats.contains_key("fail"));
        assert!((result.overall_success_rate() - 0.5).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_suite_n_trials_minimum_one() {
        let suite = EvaluationSuite::new(0);
        let case = Arc::new(AlwaysPassCase::new("x"));
        let result = suite.run_suite(&[case]).await;
        assert_eq!(result.stats["x"].n_trials, 1);
    }

    #[tokio::test]
    async fn test_run_case_returns_correct_count() {
        let suite = EvaluationSuite::new(7);
        let case = Arc::new(AlwaysPassCase::new("seven"));
        let results = suite.run_case(case).await;
        assert_eq!(results.len(), 7);
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.trial_id, i);
        }
    }

    #[tokio::test]
    async fn test_parallel_run_case_preserves_count_and_order() {
        let suite = EvaluationSuite::with_config(SuiteConfig {
            n_trials: 8,
            max_parallel: 3,
            ..SuiteConfig::default()
        });
        let case = Arc::new(AlwaysPassCase::new("par"));
        let results = suite.run_case(case).await;
        assert_eq!(results.len(), 8);
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.trial_id, i);
            assert!(r.success);
        }
    }

    #[tokio::test]
    async fn test_parallel_suite_all_fail() {
        let suite = EvaluationSuite::with_config(SuiteConfig {
            n_trials: 4,
            max_parallel: 2,
            ..SuiteConfig::default()
        });
        let case = Arc::new(AlwaysFailCase::new("par_bad", "expected"));
        let result = suite.run_suite(&[case]).await;

        let stats = &result.stats["par_bad"];
        assert_eq!(stats.n_trials, 4);
        assert_eq!(stats.successes, 0);
    }

    #[tokio::test]
    async fn test_failing_cases_filter() {
        let suite = EvaluationSuite::new(10);
        let cases: Vec<Arc<dyn EvaluationCase>> = vec![
            Arc::new(AlwaysPassCase::new("good")),
            Arc::new(StochasticCase::new("flaky", 0.0)),
        ];
        let result = suite.run_suite(&cases).await;
        let failing = result.failing_cases(0.5);
        assert!(
            failing.contains(&"flaky"),
            "flaky should be in failing list"
        );
        assert!(
            !failing.contains(&"good"),
            "good should not be in failing list"
        );
    }

    #[test]
    fn test_empty_suite_result() {
        let result = SuiteResult {
            case_results: HashMap::new(),
            stats: HashMap::new(),
        };
        assert_eq!(result.overall_success_rate(), 0.0);
        assert!(result.failing_cases(0.5).is_empty());
    }

    #[tokio::test]
    async fn test_confidence_interval_in_suite_result() {
        let suite = EvaluationSuite::new(50);
        let case = Arc::new(StochasticCase::new("ci_test", 0.8));
        let result = suite.run_suite(&[case]).await;
        let stats = &result.stats["ci_test"];
        let ci = stats.confidence_interval_95;
        assert!(ci.lower < 0.85 && ci.upper > 0.65);
    }
}