1use std::collections::HashMap;
8
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TrialResult {
16 pub trial_id: usize,
18 pub success: bool,
20 pub duration_ms: u64,
22 pub error: Option<String>,
24 pub metadata: HashMap<String, serde_json::Value>,
27}
28
29impl TrialResult {
30 pub fn success(trial_id: usize, duration_ms: u64) -> Self {
32 Self {
33 trial_id,
34 success: true,
35 duration_ms,
36 error: None,
37 metadata: HashMap::new(),
38 }
39 }
40
41 pub fn failure(trial_id: usize, duration_ms: u64, error: impl Into<String>) -> Self {
43 Self {
44 trial_id,
45 success: false,
46 duration_ms,
47 error: Some(error.into()),
48 metadata: HashMap::new(),
49 }
50 }
51
52 pub fn with_meta(
54 mut self,
55 key: impl Into<String>,
56 value: impl Into<serde_json::Value>,
57 ) -> Self {
58 self.metadata.insert(key.into(), value.into());
59 self
60 }
61}
62
63#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
67pub struct ConfidenceInterval95 {
68 pub lower: f64,
70 pub upper: f64,
72}
73
74impl ConfidenceInterval95 {
75 pub fn wilson(successes: usize, n: usize) -> Self {
83 if n == 0 {
84 return Self {
85 lower: 0.0,
86 upper: 1.0,
87 };
88 }
89
90 const Z: f64 = 1.96; let p = successes as f64 / n as f64;
92 let nf = n as f64;
93 let z2 = Z * Z;
94
95 let centre = p + z2 / (2.0 * nf);
96 let margin = Z * (p * (1.0 - p) / nf + z2 / (4.0 * nf * nf)).sqrt();
97 let denom = 1.0 + z2 / nf;
98
99 Self {
100 lower: ((centre - margin) / denom).clamp(0.0, 1.0),
101 upper: ((centre + margin) / denom).clamp(0.0, 1.0),
102 }
103 }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct EvaluationStats {
111 pub n_trials: usize,
113 pub successes: usize,
115 pub success_rate: f64,
117 pub confidence_interval_95: ConfidenceInterval95,
119 pub mean_duration_ms: f64,
121 pub p50_duration_ms: f64,
123 pub p95_duration_ms: f64,
125}
126
127impl EvaluationStats {
128 pub fn from_trials(results: &[TrialResult]) -> Option<Self> {
132 let n = results.len();
133 if n == 0 {
134 return None;
135 }
136
137 let successes = results.iter().filter(|r| r.success).count();
138 let success_rate = successes as f64 / n as f64;
139 let ci = ConfidenceInterval95::wilson(successes, n);
140
141 let mut durations: Vec<f64> = results.iter().map(|r| r.duration_ms as f64).collect();
142 durations.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
143
144 let mean_duration_ms = durations.iter().sum::<f64>() / n as f64;
145 let p50_duration_ms = percentile(&durations, 50.0);
146 let p95_duration_ms = percentile(&durations, 95.0);
147
148 Some(Self {
149 n_trials: n,
150 successes,
151 success_rate,
152 confidence_interval_95: ci,
153 mean_duration_ms,
154 p50_duration_ms,
155 p95_duration_ms,
156 })
157 }
158}
159
/// Returns the `p`-th percentile of `sorted` using linear interpolation
/// between the two nearest ranks.
///
/// `sorted` must already be in ascending order. `p` is expressed on the
/// `0..=100` scale; values outside that range are clamped to it so that an
/// oversized `p` cannot index past the end of the slice. An empty slice
/// yields `0.0`, a single-element slice yields that element, and a NaN `p`
/// yields NaN.
fn percentile(sorted: &[f64], p: f64) -> f64 {
    match sorted {
        [] => 0.0,
        [only] => *only,
        _ => {
            // After clamping, `rank` lies in `[0, len - 1]`, so both indices
            // derived from it are in bounds.
            let p = p.clamp(0.0, 100.0);
            let rank = p / 100.0 * (sorted.len() - 1) as f64;
            let lower = rank.floor() as usize;
            let upper = rank.ceil() as usize;
            let frac = rank - lower as f64;
            sorted[lower] * (1.0 - frac) + sorted[upper] * frac
        }
    }
}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons.
    const EPS: f64 = 1e-9;

    // ---- TrialResult builders ----

    #[test]
    fn test_trial_success_builder() {
        let trial = TrialResult::success(0, 42);
        assert!(trial.success);
        assert_eq!(trial.trial_id, 0);
        assert_eq!(trial.duration_ms, 42);
        assert!(trial.error.is_none());
    }

    #[test]
    fn test_trial_failure_builder() {
        let trial = TrialResult::failure(1, 100, "timeout");
        assert!(!trial.success);
        assert_eq!(trial.error.as_deref(), Some("timeout"));
    }

    #[test]
    fn test_trial_with_meta() {
        let trial = TrialResult::success(0, 10)
            .with_meta("iterations", serde_json::json!(7))
            .with_meta("model", serde_json::json!("claude-sonnet"));
        let expected = serde_json::json!(7);
        assert_eq!(trial.metadata["iterations"], expected);
    }

    // ---- Wilson interval ----

    #[test]
    fn test_wilson_ci_all_successes() {
        let interval = ConfidenceInterval95::wilson(10, 10);
        assert!(
            interval.lower > 0.7,
            "lower bound should be well above 0 for 10/10"
        );
        assert!(
            (interval.upper - 1.0).abs() < EPS,
            "upper bound should be 1.0"
        );
    }

    #[test]
    fn test_wilson_ci_no_successes() {
        let interval = ConfidenceInterval95::wilson(0, 10);
        assert_eq!(interval.lower, 0.0);
        assert!(interval.upper < 0.3, "upper bound should be low for 0/10");
    }

    #[test]
    fn test_wilson_ci_zero_trials() {
        let interval = ConfidenceInterval95::wilson(0, 0);
        assert_eq!(interval.lower, 0.0);
        assert_eq!(interval.upper, 1.0);
    }

    #[test]
    fn test_wilson_ci_contains_true_rate() {
        let interval = ConfidenceInterval95::wilson(70, 100);
        assert!(interval.lower < 0.70 && interval.upper > 0.70);
    }

    // ---- EvaluationStats ----

    #[test]
    fn test_evaluation_stats_empty() {
        let stats = EvaluationStats::from_trials(&[]);
        assert!(stats.is_none());
    }

    #[test]
    fn test_evaluation_stats_all_success() {
        let trials: Vec<_> = (0..10).map(|id| TrialResult::success(id, 100)).collect();
        let stats = EvaluationStats::from_trials(&trials).unwrap();
        assert_eq!(stats.n_trials, 10);
        assert_eq!(stats.successes, 10);
        assert!((stats.success_rate - 1.0).abs() < EPS);
    }

    #[test]
    fn test_evaluation_stats_mixed() {
        // Seven fast successes followed by three slow failures.
        let trials: Vec<_> = (0..7)
            .map(|id| TrialResult::success(id, 50))
            .chain((7..10).map(|id| TrialResult::failure(id, 200, "err")))
            .collect();
        let stats = EvaluationStats::from_trials(&trials).unwrap();
        assert_eq!(stats.successes, 7);
        assert!((stats.success_rate - 0.7).abs() < EPS);
        assert!(stats.p95_duration_ms >= stats.p50_duration_ms);
        assert!(stats.p50_duration_ms >= stats.mean_duration_ms * 0.5);
    }

    // ---- percentile helper ----

    #[test]
    fn test_percentile_single_element() {
        assert_eq!(percentile(&[42.0], 50.0), 42.0);
    }

    #[test]
    fn test_percentile_interpolation() {
        let samples = vec![0.0, 10.0, 20.0, 30.0, 40.0];
        let median = percentile(&samples, 50.0);
        assert!((median - 20.0).abs() < EPS);
    }
}