Skip to main content

brainwires_eval/
suite.rs

1//! Evaluation suite — N-trial Monte Carlo runner.
2//!
3//! [`EvaluationSuite`] runs each registered [`EvaluationCase`] N times,
4//! collects [`TrialResult`]s, and computes [`EvaluationStats`] for every case.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use anyhow::Result;
10use serde::{Deserialize, Serialize};
11
12use super::case::EvaluationCase;
13use super::trial::{EvaluationStats, TrialResult};
14
15// ── Suite result ──────────────────────────────────────────────────────────────
16
/// Aggregated results for all cases in a suite run.
///
/// As populated by `EvaluationSuite::run_suite`, both maps are keyed by each
/// case's `name()` and therefore share the same key set. If two cases report
/// the same name, the later case's entries replace the earlier one's.
#[derive(Debug, Serialize, Deserialize)]
pub struct SuiteResult {
    /// Raw trial results keyed by case name (one entry per executed trial
    /// when produced by `run_suite`).
    pub case_results: HashMap<String, Vec<TrialResult>>,
    /// Summary statistics keyed by case name, computed from the matching
    /// `case_results` entry.
    pub stats: HashMap<String, EvaluationStats>,
}
25
26impl SuiteResult {
27    /// Overall success rate across *all* cases and trials.
28    pub fn overall_success_rate(&self) -> f64 {
29        let total: usize = self.case_results.values().map(|v| v.len()).sum();
30        if total == 0 {
31            return 0.0;
32        }
33        let successes: usize = self
34            .case_results
35            .values()
36            .flat_map(|v| v.iter())
37            .filter(|r| r.success)
38            .count();
39        successes as f64 / total as f64
40    }
41
42    /// Returns all cases whose success rate is strictly below `threshold`.
43    pub fn failing_cases(&self, threshold: f64) -> Vec<&str> {
44        self.stats
45            .iter()
46            .filter(|(_, s)| s.success_rate < threshold)
47            .map(|(name, _)| name.as_str())
48            .collect()
49    }
50}
51
52// ── Suite configuration ───────────────────────────────────────────────────────
53
/// Configuration for [`EvaluationSuite`].
#[derive(Debug, Clone)]
pub struct SuiteConfig {
    /// Number of times each case is run. Must be at least 1;
    /// [`EvaluationSuite::new`] clamps smaller values up to 1.
    pub n_trials: usize,
    /// Maximum number of trials to execute concurrently per case.
    /// `0` and `1` both mean sequential execution in trial order.
    pub max_parallel: usize,
    /// How a hard error returned by a case is recorded. A hard error is an
    /// `Err` from the case, as opposed to a trial that runs and reports
    /// failure.
    ///
    /// In both modes the error becomes a failed trial and is never propagated
    /// to the caller. When `true`, the stored failure message is the error
    /// text. When `false`, the error is additionally logged at `error` level
    /// and the stored message is prefixed with `Trial errored:`.
    pub catch_errors_as_failures: bool,
}

impl Default for SuiteConfig {
    /// Ten sequential trials per case, with hard errors recorded as failures.
    fn default() -> Self {
        Self {
            n_trials: 10,
            max_parallel: 1,
            catch_errors_as_failures: true,
        }
    }
}
77
78// ── Suite ─────────────────────────────────────────────────────────────────────
79
80/// N-trial Monte Carlo evaluation runner.
81///
82/// ## Quick start
83/// ```rust,ignore
84/// use brainwires_eval::{EvaluationSuite, AlwaysPassCase};
85/// use std::sync::Arc;
86///
87/// #[tokio::main]
88/// async fn main() {
89///     let suite = EvaluationSuite::new(30);
90///     let case = Arc::new(AlwaysPassCase::new("smoke"));
91///     let results = suite.run_suite(&[case]).await;
92///     println!("overall: {:.1}%", results.overall_success_rate() * 100.0);
93/// }
94/// ```
95pub struct EvaluationSuite {
96    config: SuiteConfig,
97}
98
99impl EvaluationSuite {
100    /// Create a suite that runs each case `n_trials` times sequentially.
101    pub fn new(n_trials: usize) -> Self {
102        Self {
103            config: SuiteConfig {
104                n_trials: n_trials.max(1),
105                ..SuiteConfig::default()
106            },
107        }
108    }
109
110    /// Override the full configuration.
111    pub fn with_config(config: SuiteConfig) -> Self {
112        Self { config }
113    }
114
115    /// Run `n_trials` for a single case and return the raw results.
116    ///
117    /// Accepts `Arc<dyn EvaluationCase>` so the case can be shared across
118    /// parallel async tasks when `max_parallel > 1`.
119    pub async fn run_case(&self, case: Arc<dyn EvaluationCase>) -> Vec<TrialResult> {
120        let mut results = Vec::with_capacity(self.config.n_trials);
121
122        if self.config.max_parallel <= 1 {
123            // Sequential
124            for trial_id in 0..self.config.n_trials {
125                let result = case.run(trial_id).await;
126                results.push(self.resolve(result, trial_id));
127            }
128        } else {
129            // Bounded parallel execution using a semaphore
130            use tokio::sync::Semaphore;
131            let sem = Arc::new(Semaphore::new(self.config.max_parallel));
132            let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<TrialResult>();
133
134            for trial_id in 0..self.config.n_trials {
135                let permit = sem.clone().acquire_owned().await.unwrap();
136                let tx = tx.clone();
137                let case_arc = Arc::clone(&case);
138                let catch_errors = self.config.catch_errors_as_failures;
139
140                tokio::spawn(async move {
141                    let _permit = permit;
142                    let result = case_arc.run(trial_id).await;
143                    let trial = match result {
144                        Ok(t) => t,
145                        Err(e) if catch_errors => TrialResult::failure(trial_id, 0, e.to_string()),
146                        Err(e) => {
147                            tracing::error!(
148                                "Trial {} errored (catch_errors_as_failures=false): {}",
149                                trial_id,
150                                e
151                            );
152                            TrialResult::failure(trial_id, 0, format!("Trial errored: {e}"))
153                        }
154                    };
155                    if tx.send(trial).is_err() {
156                        tracing::warn!("Trial {} result dropped: receiver closed", trial_id);
157                    }
158                });
159            }
160            drop(tx);
161
162            // Drain the channel once all producers have finished.
163            while let Some(t) = rx.recv().await {
164                results.push(t);
165            }
166
167            // Sort by trial_id for deterministic output order.
168            results.sort_by_key(|r| r.trial_id);
169        }
170
171        results
172    }
173
174    /// Run the full suite: execute each case N times and return aggregated results.
175    pub async fn run_suite(&self, cases: &[Arc<dyn EvaluationCase>]) -> SuiteResult {
176        let mut case_results: HashMap<String, Vec<TrialResult>> = HashMap::new();
177        let mut stats: HashMap<String, EvaluationStats> = HashMap::new();
178
179        for case in cases {
180            let results = self.run_case(Arc::clone(case)).await;
181            let case_stats =
182                EvaluationStats::from_trials(&results).expect("case must have at least one trial");
183            let name = case.name().to_string();
184            tracing::info!(
185                case = %name,
186                n = case_stats.n_trials,
187                success_rate = %format!("{:.1}%", case_stats.success_rate * 100.0),
188                ci_low = %format!("{:.3}", case_stats.confidence_interval_95.lower),
189                ci_high = %format!("{:.3}", case_stats.confidence_interval_95.upper),
190                "EvaluationSuite: case complete"
191            );
192            case_results.insert(name.clone(), results);
193            stats.insert(name, case_stats);
194        }
195
196        SuiteResult {
197            case_results,
198            stats,
199        }
200    }
201
202    /// Resolve a `Result<TrialResult>` from a case run into a `TrialResult`,
203    /// converting errors into failures when `catch_errors_as_failures` is set.
204    fn resolve(&self, result: Result<TrialResult>, trial_id: usize) -> TrialResult {
205        match result {
206            Ok(t) => t,
207            Err(e) if self.config.catch_errors_as_failures => {
208                TrialResult::failure(trial_id, 0, e.to_string())
209            }
210            Err(e) => {
211                tracing::error!(
212                    "Trial {} errored (catch_errors_as_failures=false): {}",
213                    trial_id,
214                    e
215                );
216                TrialResult::failure(trial_id, 0, format!("Trial errored: {e}"))
217            }
218        }
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::case::{AlwaysFailCase, AlwaysPassCase, StochasticCase};

    #[tokio::test]
    async fn test_suite_all_pass() {
        let suite = EvaluationSuite::new(5);
        let case = Arc::new(AlwaysPassCase::new("ok"));
        let result = suite.run_suite(&[case]).await;

        let stats = result.stats.get("ok").unwrap();
        assert_eq!(stats.n_trials, 5);
        assert_eq!(stats.successes, 5);
        assert!((stats.success_rate - 1.0).abs() < 1e-9);
        assert!((result.overall_success_rate() - 1.0).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_suite_all_fail() {
        let suite = EvaluationSuite::new(3);
        let case = Arc::new(AlwaysFailCase::new("bad", "expected"));
        let result = suite.run_suite(&[case]).await;

        let stats = result.stats.get("bad").unwrap();
        assert_eq!(stats.successes, 0);
        assert_eq!(stats.success_rate, 0.0);
    }

    #[tokio::test]
    async fn test_suite_multiple_cases() {
        let suite = EvaluationSuite::new(10);
        let cases: Vec<Arc<dyn EvaluationCase>> = vec![
            Arc::new(AlwaysPassCase::new("pass")),
            Arc::new(AlwaysFailCase::new("fail", "x")),
        ];
        let result = suite.run_suite(&cases).await;
        assert!(result.stats.contains_key("pass"));
        assert!(result.stats.contains_key("fail"));
        assert!((result.overall_success_rate() - 0.5).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_suite_n_trials_minimum_one() {
        let suite = EvaluationSuite::new(0); // Should clamp to 1
        let case = Arc::new(AlwaysPassCase::new("x"));
        let result = suite.run_suite(&[case]).await;
        assert_eq!(result.stats["x"].n_trials, 1);
    }

    #[tokio::test]
    async fn test_run_case_returns_correct_count() {
        let suite = EvaluationSuite::new(7);
        let case = Arc::new(AlwaysPassCase::new("seven"));
        let results = suite.run_case(case).await;
        assert_eq!(results.len(), 7);
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.trial_id, i);
        }
    }

    #[tokio::test]
    async fn test_parallel_run_case_returns_all_trials_in_order() {
        // Exercises the bounded-parallel path (`max_parallel > 1`).
        let suite = EvaluationSuite::with_config(SuiteConfig {
            n_trials: 12,
            max_parallel: 4,
            catch_errors_as_failures: true,
        });
        let case = Arc::new(AlwaysPassCase::new("parallel"));
        let results = suite.run_case(case).await;
        assert_eq!(results.len(), 12);
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.trial_id, i);
            assert!(r.success);
        }
    }

    #[tokio::test]
    async fn test_parallel_suite_matches_sequential_stats() {
        let parallel = EvaluationSuite::with_config(SuiteConfig {
            n_trials: 6,
            max_parallel: 3,
            catch_errors_as_failures: true,
        });
        let cases: Vec<Arc<dyn EvaluationCase>> = vec![
            Arc::new(AlwaysPassCase::new("pass")),
            Arc::new(AlwaysFailCase::new("fail", "x")),
        ];
        let result = parallel.run_suite(&cases).await;
        assert_eq!(result.stats["pass"].n_trials, 6);
        assert_eq!(result.stats["pass"].successes, 6);
        assert_eq!(result.stats["fail"].n_trials, 6);
        assert_eq!(result.stats["fail"].successes, 0);
        assert!((result.overall_success_rate() - 0.5).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_failing_cases_filter() {
        let suite = EvaluationSuite::new(10);
        let cases: Vec<Arc<dyn EvaluationCase>> = vec![
            Arc::new(AlwaysPassCase::new("good")),
            Arc::new(StochasticCase::new("flaky", 0.0)), // always fails
        ];
        let result = suite.run_suite(&cases).await;
        let failing = result.failing_cases(0.5);
        assert!(
            failing.contains(&"flaky"),
            "flaky should be in failing list"
        );
        assert!(
            !failing.contains(&"good"),
            "good should not be in failing list"
        );
    }

    #[tokio::test]
    async fn test_confidence_interval_in_suite_result() {
        let suite = EvaluationSuite::new(50);
        let case = Arc::new(StochasticCase::new("ci_test", 0.8));
        let result = suite.run_suite(&[case]).await;
        let stats = &result.stats["ci_test"];
        let ci = stats.confidence_interval_95;
        assert!(ci.lower <= ci.upper);
        // With ~40/50 successes the 95 % CI should overlap the band
        // [0.65, 0.85] around the true success probability of 0.8.
        assert!(ci.lower < 0.85 && ci.upper > 0.65);
    }
}