Skip to main content

brainwires_eval/
trial.rs

//! Evaluation trial results and statistical analysis.
//!
//! A *trial* is one execution of an [`EvaluationCase`](crate::case::EvaluationCase). Run N trials
//! and summarise them with [`EvaluationStats`], which reports the success rate together with a
//! Wilson-score 95 % confidence interval, plus mean, P50 and P95 trial durations.
6
7use std::collections::HashMap;
8
9use serde::{Deserialize, Serialize};
10
// ── Trial result ──────────────────────────────────────────────────────────────
12
13/// Result produced by a single trial run.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TrialResult {
16    /// Sequential index of this trial (0-based).
17    pub trial_id: usize,
18    /// Whether the trial succeeded.
19    pub success: bool,
20    /// Wall-clock duration of the trial in milliseconds.
21    pub duration_ms: u64,
22    /// Error message when `success == false`.
23    pub error: Option<String>,
24    /// Arbitrary key-value metadata emitted by the case (e.g. iteration count,
25    /// token usage, tool names used).
26    pub metadata: HashMap<String, serde_json::Value>,
27}
28
29impl TrialResult {
30    /// Create a successful trial result.
31    pub fn success(trial_id: usize, duration_ms: u64) -> Self {
32        Self {
33            trial_id,
34            success: true,
35            duration_ms,
36            error: None,
37            metadata: HashMap::new(),
38        }
39    }
40
41    /// Create a failed trial result.
42    pub fn failure(trial_id: usize, duration_ms: u64, error: impl Into<String>) -> Self {
43        Self {
44            trial_id,
45            success: false,
46            duration_ms,
47            error: Some(error.into()),
48            metadata: HashMap::new(),
49        }
50    }
51
52    /// Attach an arbitrary metadata value.
53    pub fn with_meta(
54        mut self,
55        key: impl Into<String>,
56        value: impl Into<serde_json::Value>,
57    ) -> Self {
58        self.metadata.insert(key.into(), value.into());
59        self
60    }
61}
62
// ── Confidence interval ───────────────────────────────────────────────────────
64
65/// A symmetric 95 % confidence interval around a proportion.
66#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
67pub struct ConfidenceInterval95 {
68    /// Lower bound (clipped to 0).
69    pub lower: f64,
70    /// Upper bound (clipped to 1).
71    pub upper: f64,
72}
73
74impl ConfidenceInterval95 {
75    /// Compute a Wilson-score 95 % confidence interval.
76    ///
77    /// The Wilson interval is preferred over the naïve Wald interval because it
78    /// behaves well at the extremes (p = 0 or p = 1) and for small N.
79    ///
80    /// Formula: `(p̂ + z²/2n ± z√(p̂(1−p̂)/n + z²/4n²)) / (1 + z²/n)`
81    /// where `z = 1.96` for 95 % confidence.
82    pub fn wilson(successes: usize, n: usize) -> Self {
83        if n == 0 {
84            return Self {
85                lower: 0.0,
86                upper: 1.0,
87            };
88        }
89
90        const Z: f64 = 1.96; // 95 % two-tailed
91        let p = successes as f64 / n as f64;
92        let nf = n as f64;
93        let z2 = Z * Z;
94
95        let centre = p + z2 / (2.0 * nf);
96        let margin = Z * (p * (1.0 - p) / nf + z2 / (4.0 * nf * nf)).sqrt();
97        let denom = 1.0 + z2 / nf;
98
99        Self {
100            lower: ((centre - margin) / denom).clamp(0.0, 1.0),
101            upper: ((centre + margin) / denom).clamp(0.0, 1.0),
102        }
103    }
104}
105
// ── Summary statistics ────────────────────────────────────────────────────────
107
108/// Aggregate statistics for a set of [`TrialResult`]s from the same case.
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct EvaluationStats {
111    /// Total number of trials executed.
112    pub n_trials: usize,
113    /// Number of trials that succeeded.
114    pub successes: usize,
115    /// `successes / n_trials` (0.0 when n_trials == 0).
116    pub success_rate: f64,
117    /// Wilson-score 95 % confidence interval around `success_rate`.
118    pub confidence_interval_95: ConfidenceInterval95,
119    /// Mean trial duration across all trials in milliseconds.
120    pub mean_duration_ms: f64,
121    /// Median (P50) trial duration in milliseconds.
122    pub p50_duration_ms: f64,
123    /// 95th-percentile trial duration in milliseconds.
124    pub p95_duration_ms: f64,
125}
126
127impl EvaluationStats {
128    /// Compute statistics from a slice of trial results.
129    ///
130    /// Returns `None` if `results` is empty.
131    pub fn from_trials(results: &[TrialResult]) -> Option<Self> {
132        let n = results.len();
133        if n == 0 {
134            return None;
135        }
136
137        let successes = results.iter().filter(|r| r.success).count();
138        let success_rate = successes as f64 / n as f64;
139        let ci = ConfidenceInterval95::wilson(successes, n);
140
141        let mut durations: Vec<f64> = results.iter().map(|r| r.duration_ms as f64).collect();
142        durations.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
143
144        let mean_duration_ms = durations.iter().sum::<f64>() / n as f64;
145        let p50_duration_ms = percentile(&durations, 50.0);
146        let p95_duration_ms = percentile(&durations, 95.0);
147
148        Some(Self {
149            n_trials: n,
150            successes,
151            success_rate,
152            confidence_interval_95: ci,
153            mean_duration_ms,
154            p50_duration_ms,
155            p95_duration_ms,
156        })
157    }
158}
159
/// Compute the `p`-th percentile of an ascending-sorted slice, linearly
/// interpolating between the two nearest ranks.
///
/// * `p` is expressed in percent and clamped to `[0, 100]`. Out-of-range values
///   therefore return the minimum or maximum instead of indexing past the end
///   of the slice.
/// * An empty slice yields `0.0`; a single-element slice yields that element.
/// * A NaN `p` yields NaN when the slice has two or more elements.
fn percentile(sorted: &[f64], p: f64) -> f64 {
    match sorted {
        [] => return 0.0,
        [only] => return *only,
        _ => {}
    }
    let last = sorted.len() - 1;
    let rank = p.clamp(0.0, 100.0) / 100.0 * last as f64;
    // Clamping `p` keeps `rank` within `[0, last]`; the `min` calls are a cheap
    // extra guard so the indexing below can never go out of bounds.
    let lower = (rank.floor() as usize).min(last);
    let upper = (rank.ceil() as usize).min(last);
    let frac = rank - lower as f64;
    sorted[lower] * (1.0 - frac) + sorted[upper] * frac
}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that two floats agree to within `tol`, with a readable message.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected}, got {actual} (tolerance {tol})"
        );
    }

    #[test]
    fn test_trial_success_builder() {
        let t = TrialResult::success(0, 42);
        assert!(t.success);
        assert_eq!(t.trial_id, 0);
        assert_eq!(t.duration_ms, 42);
        assert!(t.error.is_none());
        assert!(t.metadata.is_empty());
    }

    #[test]
    fn test_trial_failure_builder() {
        let t = TrialResult::failure(1, 100, "timeout");
        assert!(!t.success);
        assert_eq!(t.trial_id, 1);
        assert_eq!(t.duration_ms, 100);
        assert_eq!(t.error.as_deref(), Some("timeout"));
        assert!(t.metadata.is_empty());
    }

    #[test]
    fn test_trial_with_meta() {
        let t = TrialResult::success(0, 10)
            .with_meta("iterations", serde_json::json!(7))
            .with_meta("model", serde_json::json!("claude-sonnet"));
        assert_eq!(t.metadata.len(), 2);
        assert_eq!(t.metadata["iterations"], serde_json::json!(7));
        assert_eq!(t.metadata["model"], serde_json::json!("claude-sonnet"));
    }

    #[test]
    fn test_trial_with_meta_overwrites_existing_key() {
        let t = TrialResult::success(0, 10)
            .with_meta("iterations", serde_json::json!(7))
            .with_meta("iterations", serde_json::json!(8));
        assert_eq!(t.metadata.len(), 1);
        assert_eq!(t.metadata["iterations"], serde_json::json!(8));
    }

    #[test]
    fn test_wilson_ci_all_successes() {
        let ci = ConfidenceInterval95::wilson(10, 10);
        // At p̂ = 1 the lower bound is exactly 1 / (1 + z²/n) = 1 / 1.38416 ≈ 0.7225.
        assert_close(ci.lower, 0.7225, 1e-3);
        assert_close(ci.upper, 1.0, 1e-9);
    }

    #[test]
    fn test_wilson_ci_no_successes() {
        let ci = ConfidenceInterval95::wilson(0, 10);
        assert_eq!(ci.lower, 0.0);
        // Mirror image of the 10/10 case: 1 - 1 / 1.38416 ≈ 0.2775.
        assert_close(ci.upper, 0.2775, 1e-3);
    }

    #[test]
    fn test_wilson_ci_zero_trials() {
        let ci = ConfidenceInterval95::wilson(0, 0);
        assert_eq!(ci.lower, 0.0);
        assert_eq!(ci.upper, 1.0);
    }

    #[test]
    fn test_wilson_ci_known_value() {
        // Reference Wilson 95 % interval for 7/10 is approximately (0.3968, 0.8922).
        let ci = ConfidenceInterval95::wilson(7, 10);
        assert_close(ci.lower, 0.3968, 1e-3);
        assert_close(ci.upper, 0.8922, 1e-3);
    }

    #[test]
    fn test_wilson_ci_contains_true_rate() {
        // For 70 % true rate with 100 trials the CI must contain 0.70
        let ci = ConfidenceInterval95::wilson(70, 100);
        assert!(ci.lower < 0.70 && ci.upper > 0.70);
    }

    #[test]
    fn test_evaluation_stats_empty() {
        assert!(EvaluationStats::from_trials(&[]).is_none());
    }

    #[test]
    fn test_evaluation_stats_all_success() {
        let trials: Vec<_> = (0..10).map(|i| TrialResult::success(i, 100)).collect();
        let stats = EvaluationStats::from_trials(&trials).unwrap();
        assert_eq!(stats.n_trials, 10);
        assert_eq!(stats.successes, 10);
        assert_close(stats.success_rate, 1.0, 1e-9);
        assert_close(stats.mean_duration_ms, 100.0, 1e-9);
        assert_close(stats.p50_duration_ms, 100.0, 1e-9);
        assert_close(stats.p95_duration_ms, 100.0, 1e-9);
    }

    #[test]
    fn test_evaluation_stats_mixed() {
        let mut trials: Vec<_> = (0..7).map(|i| TrialResult::success(i, 50)).collect();
        trials.extend((7..10).map(|i| TrialResult::failure(i, 200, "err")));
        let stats = EvaluationStats::from_trials(&trials).unwrap();
        assert_eq!(stats.n_trials, 10);
        assert_eq!(stats.successes, 7);
        assert_close(stats.success_rate, 0.7, 1e-9);
        // Sorted durations: seven 50s followed by three 200s.
        // Mean = (7 * 50 + 3 * 200) / 10 = 95; P50 rank 4.5 -> 50; P95 rank 8.55 -> 200.
        assert_close(stats.mean_duration_ms, 95.0, 1e-9);
        assert_close(stats.p50_duration_ms, 50.0, 1e-9);
        assert_close(stats.p95_duration_ms, 200.0, 1e-9);
        let ci = stats.confidence_interval_95;
        assert!(ci.lower < 0.7 && ci.upper > 0.7);
    }

    #[test]
    fn test_evaluation_stats_unsorted_durations() {
        let trials = [
            TrialResult::success(0, 300),
            TrialResult::failure(1, 100, "err"),
            TrialResult::success(2, 200),
        ];
        let stats = EvaluationStats::from_trials(&trials).unwrap();
        assert_eq!(stats.successes, 2);
        assert_close(stats.mean_duration_ms, 200.0, 1e-9);
        assert_close(stats.p50_duration_ms, 200.0, 1e-9);
        // P95 rank = 0.95 * 2 = 1.9 -> 200 + 0.9 * (300 - 200) = 290.
        assert_close(stats.p95_duration_ms, 290.0, 1e-9);
    }

    #[test]
    fn test_percentile_single_element() {
        assert_eq!(percentile(&[42.0], 50.0), 42.0);
    }

    #[test]
    fn test_percentile_empty() {
        assert_eq!(percentile(&[], 50.0), 0.0);
    }

    #[test]
    fn test_percentile_interpolation() {
        let data = [0.0, 10.0, 20.0, 30.0, 40.0];
        assert_close(percentile(&data, 50.0), 20.0, 1e-9);
        assert_close(percentile(&data, 0.0), 0.0, 1e-9);
        assert_close(percentile(&data, 100.0), 40.0, 1e-9);
        // Rank 0.95 * 3 = 2.85 falls between 3.0 and 4.0.
        assert_close(percentile(&[1.0, 2.0, 3.0, 4.0], 95.0), 3.85, 1e-9);
    }
}