use serde::Serialize;
/// Summary statistics (mean, sample variance, extremes, count) over a set of `f64` samples.
///
/// Normally built with [`StatSummary::from_values`], which returns an all-zero summary
/// (including `count == 0`) for an empty input.
#[derive(Debug, Clone, Serialize)]
pub struct StatSummary {
    /// Arithmetic mean of the samples.
    pub mean: f64,
    /// Sample variance as computed by `from_values`: sum of squared deviations divided by
    /// `count - 1`, or `0.0` when there are fewer than two samples.
    pub variance: f64,
    /// Smallest sample value.
    pub min: f64,
    /// Largest sample value.
    pub max: f64,
    /// Number of samples the summary was computed from.
    pub count: usize,
}
impl StatSummary {
    /// Computes the mean, sample variance, minimum, maximum and count of `values`.
    ///
    /// - An empty slice yields a summary whose fields are all zero.
    /// - The variance uses the `n - 1` (sample) denominator; with a single value it is
    ///   reported as `0.0` instead of the undefined `0 / 0`.
    /// - `f64::min`/`f64::max` return the non-NaN operand, so NaN samples are skipped
    ///   when finding `min`/`max` but propagate into `mean` and `variance`. If every
    ///   sample is NaN, `min` stays `f64::INFINITY` and `max` stays `f64::NEG_INFINITY`.
    pub fn from_values(values: &[f64]) -> Self {
        let n = values.len();
        if n == 0 {
            return Self {
                mean: 0.0,
                variance: 0.0,
                min: 0.0,
                max: 0.0,
                count: 0,
            };
        }

        let total: f64 = values.iter().sum();
        let mean = total / n as f64;

        // Track both extremes in one left-to-right pass, starting from the identities
        // of min (+inf) and max (-inf).
        let (min, max) = values.iter().fold(
            (f64::INFINITY, f64::NEG_INFINITY),
            |(lo, hi), &x| (lo.min(x), hi.max(x)),
        );

        let variance = match n {
            // One sample has no spread; avoid dividing by zero.
            1 => 0.0,
            _ => {
                let squared_deviations: f64 = values.iter().map(|x| (x - mean).powi(2)).sum();
                squared_deviations / (n - 1) as f64
            }
        };

        Self {
            mean,
            variance,
            min,
            max,
            count: n,
        }
    }

    /// Sample standard deviation: the square root of [`StatSummary::variance`].
    pub fn std_dev(&self) -> f64 {
        f64::sqrt(self.variance)
    }
}
/// A bundle of [`StatSummary`] values, one per tracked metric.
///
/// NOTE(review): going by the type name, each summary presumably aggregates one sample
/// per trial; nothing here enforces that — confirm against the code that builds it.
#[derive(Debug, Clone, Serialize)]
pub struct TrialStatistics {
    /// Summary of pass rates (scale, e.g. fraction vs. percent, is not established here).
    pub pass_rate: StatSummary,
    /// Summary of elapsed times; the field name suggests seconds — confirm with the producer.
    pub elapsed_secs: StatSummary,
    /// Summary of total input token counts.
    pub total_input_tokens: StatSummary,
    /// Summary of total output token counts.
    pub total_output_tokens: StatSummary,
    /// Summary of iteration counts.
    pub iterations: StatSummary,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture for the multi-value tests: the integers 1 through 5.
    const ONE_TO_FIVE: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];

    /// Asserts that every field of `s` equals the expected value exactly.
    fn assert_fields(s: &StatSummary, mean: f64, variance: f64, min: f64, max: f64, count: usize) {
        assert_eq!(s.mean, mean);
        assert_eq!(s.variance, variance);
        assert_eq!(s.min, min);
        assert_eq!(s.max, max);
        assert_eq!(s.count, count);
    }

    #[test]
    fn test_stat_summary_empty() {
        let s = StatSummary::from_values(&[]);
        assert_fields(&s, 0.0, 0.0, 0.0, 0.0, 0);
    }

    #[test]
    fn test_stat_summary_single() {
        let s = StatSummary::from_values(&[42.0]);
        assert_fields(&s, 42.0, 0.0, 42.0, 42.0, 1);
    }

    #[test]
    fn test_stat_summary_multiple() {
        let s = StatSummary::from_values(&ONE_TO_FIVE);
        // Sample variance of 1..=5: (4 + 1 + 0 + 1 + 4) / (5 - 1) = 2.5.
        assert_fields(&s, 3.0, 2.5, 1.0, 5.0, 5);
    }

    #[test]
    fn test_stat_summary_std_dev() {
        let s = StatSummary::from_values(&ONE_TO_FIVE);
        let sd = s.std_dev();
        assert!((sd - s.variance.sqrt()).abs() < 1e-10);
        // sqrt(2.5)
        assert!((sd - 1.5811388300841898).abs() < 1e-10);
    }
}