/// Running per-dimension summary statistics over a stream of `f32` vectors.
///
/// The mean and the sum of squared deviations are maintained incrementally
/// (Welford's online update), so individual observations are never stored.
/// Two accumulators can be combined with [`OutcomeStats::merge`].
#[derive(Debug, Clone)]
pub struct OutcomeStats {
    /// Number of observations folded in so far.
    count: u64,
    /// Running mean of each dimension.
    mean: Vec<f32>,
    /// Running sum of squared deviations from the mean, per dimension.
    m2: Vec<f32>,
    /// Smallest value seen per dimension (`+inf` until the first update).
    min: Vec<f32>,
    /// Largest value seen per dimension (`-inf` until the first update).
    max: Vec<f32>,
}

impl OutcomeStats {
    /// Creates an empty accumulator tracking `dim` dimensions.
    #[must_use]
    pub fn new(dim: usize) -> Self {
        Self {
            count: 0,
            mean: vec![0.0; dim],
            m2: vec![0.0; dim],
            min: vec![f32::INFINITY; dim],
            max: vec![f32::NEG_INFINITY; dim],
        }
    }

    /// Folds one observation into the running statistics.
    ///
    /// # Panics
    ///
    /// Panics if `outcome.len()` differs from [`OutcomeStats::dim`].
    pub fn update(&mut self, outcome: &[f32]) {
        assert_eq!(
            outcome.len(),
            self.dim(),
            "Outcome dimension mismatch: expected {}, got {}",
            self.dim(),
            outcome.len()
        );
        self.count += 1;
        let n = self.count as f32;

        let moments = self.mean.iter_mut().zip(self.m2.iter_mut());
        let extremes = self.min.iter_mut().zip(self.max.iter_mut());
        for (((mean, m2), (lo, hi)), &x) in moments.zip(extremes).zip(outcome) {
            // Welford: the second factor must use the *updated* mean.
            let delta = x - *mean;
            *mean += delta / n;
            *m2 += delta * (x - *mean);
            *lo = f32::min(*lo, x);
            *hi = f32::max(*hi, x);
        }
    }

    /// Returns the statistics of the union of both observation streams.
    ///
    /// An accumulator with no observations acts as the identity: the other
    /// operand is returned as-is, and in that case dimensions are not compared.
    ///
    /// # Panics
    ///
    /// Panics if both operands hold observations and their dimensions differ.
    #[must_use]
    pub fn merge(&self, other: &Self) -> Self {
        if self.count == 0 {
            return other.clone();
        }
        if other.count == 0 {
            return self.clone();
        }
        assert_eq!(self.dim(), other.dim(), "Dimension mismatch in merge");

        let count = self.count + other.count;
        let n_self = self.count as f32;
        let n_other = other.count as f32;
        let n_total = count as f32;
        // Weights of the pairwise (parallel) combination formulas.
        let other_weight = n_other / n_total;
        let cross_weight = n_self * n_other / n_total;

        let mean = self
            .mean
            .iter()
            .zip(&other.mean)
            .map(|(&a, &b)| a + (b - a) * other_weight)
            .collect();
        let m2 = self
            .m2
            .iter()
            .zip(&other.m2)
            .zip(self.mean.iter().zip(&other.mean))
            .map(|((&m2_a, &m2_b), (&mean_a, &mean_b))| {
                let delta = mean_b - mean_a;
                m2_a + m2_b + delta * delta * cross_weight
            })
            .collect();
        let min = self
            .min
            .iter()
            .zip(&other.min)
            .map(|(&a, &b)| a.min(b))
            .collect();
        let max = self
            .max
            .iter()
            .zip(&other.max)
            .map(|(&a, &b)| a.max(b))
            .collect();

        Self {
            count,
            mean,
            m2,
            min,
            max,
        }
    }

    /// Convenience wrapper around [`OutcomeStats::update`] for one-dimensional
    /// accumulators; `value` is narrowed to `f32`.
    ///
    /// # Panics
    ///
    /// Panics if the accumulator's dimension is not 1.
    pub fn update_scalar(&mut self, value: f64) {
        self.update(&[value as f32]);
    }

    /// Number of observations folded in so far.
    #[must_use]
    pub const fn count(&self) -> u64 {
        self.count
    }

    /// Mean of the first dimension, or `None` if there are no observations or
    /// the accumulator has zero dimensions.
    #[must_use]
    pub fn mean_scalar(&self) -> Option<f64> {
        self.mean()
            .and_then(|m| m.first())
            .map(|&m| f64::from(m))
    }

    /// Population variance of the first dimension, or `None` when
    /// [`OutcomeStats::variance`] is `None` or the accumulator has zero
    /// dimensions.
    #[must_use]
    pub fn variance_scalar(&self) -> Option<f64> {
        self.variance()
            .and_then(|v| v.first().copied())
            .map(f64::from)
    }

    /// Population standard deviation of the first dimension, or `None` when
    /// [`OutcomeStats::std`] is `None` or the accumulator has zero dimensions.
    #[must_use]
    pub fn std_scalar(&self) -> Option<f64> {
        self.std()
            .and_then(|s| s.first().copied())
            .map(f64::from)
    }

    /// Number of dimensions tracked by this accumulator.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.mean.len()
    }

    /// Per-dimension mean, or `None` before the first observation.
    #[must_use]
    pub fn mean(&self) -> Option<&[f32]> {
        if self.count == 0 {
            None
        } else {
            Some(self.mean.as_slice())
        }
    }

    /// Per-dimension population variance (`m2 / n`).
    ///
    /// Returns `None` with fewer than two observations.
    #[must_use]
    pub fn variance(&self) -> Option<Vec<f32>> {
        if self.count < 2 {
            return None;
        }
        let n = self.count as f32;
        Some(self.m2.iter().map(|m2| m2 / n).collect())
    }

    /// Per-dimension population standard deviation (square root of
    /// [`OutcomeStats::variance`]).
    #[must_use]
    pub fn std(&self) -> Option<Vec<f32>> {
        self.variance()
            .map(|variances| variances.iter().map(|v| v.sqrt()).collect())
    }

    /// Per-dimension sample variance (`m2 / (n - 1)`).
    ///
    /// Returns `None` with fewer than two observations.
    #[must_use]
    pub fn sample_variance(&self) -> Option<Vec<f32>> {
        if self.count < 2 {
            return None;
        }
        let degrees_of_freedom = (self.count - 1) as f32;
        Some(self.m2.iter().map(|m2| m2 / degrees_of_freedom).collect())
    }

    /// Per-dimension minimum, or `None` before the first observation.
    #[must_use]
    pub fn min(&self) -> Option<&[f32]> {
        if self.count == 0 {
            None
        } else {
            Some(self.min.as_slice())
        }
    }

    /// Per-dimension maximum, or `None` before the first observation.
    #[must_use]
    pub fn max(&self) -> Option<&[f32]> {
        if self.count == 0 {
            None
        } else {
            Some(self.max.as_slice())
        }
    }

    /// Two-sided confidence interval for the mean of each dimension, returned
    /// as `(lower, upper)`.
    ///
    /// The standard error is derived from the sample standard deviation
    /// (`n - 1` denominator). Supported levels are 0.90, 0.95 and 0.99 (each
    /// matched within ±0.01); any other value falls back to 0.95. With fewer
    /// than 30 observations a Student-t critical value is used, otherwise the
    /// normal approximation.
    ///
    /// Returns `None` with fewer than two observations.
    #[must_use]
    pub fn confidence_interval(&self, confidence: f32) -> Option<(Vec<f32>, Vec<f32>)> {
        let sample_variance = self.sample_variance()?;
        let root_n = (self.count as f32).sqrt();
        let critical = Self::critical_value(self.count, confidence);

        let half_widths: Vec<f32> = sample_variance
            .iter()
            .map(|var| critical * (var.sqrt() / root_n))
            .collect();
        let lower = self
            .mean
            .iter()
            .zip(&half_widths)
            .map(|(m, h)| m - h)
            .collect();
        let upper = self
            .mean
            .iter()
            .zip(&half_widths)
            .map(|(m, h)| m + h)
            .collect();
        Some((lower, upper))
    }

    /// Two-sided critical value for `count` observations at `confidence`.
    fn critical_value(count: u64, confidence: f32) -> f32 {
        // Columns are the 90 %, 95 % and 99 % two-sided levels.
        const Z_VALUES: [f32; 3] = [1.645, 1.96, 2.576];
        // Student-t critical values; row `i` holds `i + 1` degrees of freedom.
        const T_VALUES: [[f32; 3]; 28] = [
            [6.314, 12.706, 63.657],
            [2.920, 4.303, 9.925],
            [2.353, 3.182, 5.841],
            [2.132, 2.776, 4.604],
            [2.015, 2.571, 4.032],
            [1.943, 2.447, 3.707],
            [1.895, 2.365, 3.499],
            [1.860, 2.306, 3.355],
            [1.833, 2.262, 3.250],
            [1.812, 2.228, 3.169],
            [1.796, 2.201, 3.106],
            [1.782, 2.179, 3.055],
            [1.771, 2.160, 3.012],
            [1.761, 2.145, 2.977],
            [1.753, 2.131, 2.947],
            [1.746, 2.120, 2.921],
            [1.740, 2.110, 2.898],
            [1.734, 2.101, 2.878],
            [1.729, 2.093, 2.861],
            [1.725, 2.086, 2.845],
            [1.721, 2.080, 2.831],
            [1.717, 2.074, 2.819],
            [1.714, 2.069, 2.807],
            [1.711, 2.064, 2.797],
            [1.708, 2.060, 2.787],
            [1.706, 2.056, 2.779],
            [1.703, 2.052, 2.771],
            [1.701, 2.048, 2.763],
        ];

        let column = if (confidence - 0.90).abs() < 0.01 {
            0
        } else if (confidence - 0.99).abs() < 0.01 {
            2
        } else {
            // 0.95 and every unsupported level.
            1
        };

        match count {
            // Degrees of freedom are `count - 1`, stored at row `count - 2`.
            2..=29 => T_VALUES[(count - 2) as usize][column],
            _ => Z_VALUES[column],
        }
    }
}
impl Default for OutcomeStats {
fn default() -> Self {
Self::new(0)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is within `1e-6` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-6);
    }

    /// A fresh accumulator reports no observations and no statistics.
    #[test]
    fn test_empty_stats() {
        let fresh = OutcomeStats::new(3);
        assert_eq!(fresh.count(), 0);
        assert!(fresh.mean().is_none());
        assert!(fresh.variance().is_none());
    }

    /// One observation defines the mean but not yet a variance.
    #[test]
    fn test_single_update() {
        let observation = [1.0_f32, 2.0, 3.0];
        let mut acc = OutcomeStats::new(observation.len());
        acc.update(&observation);
        assert_eq!(acc.count(), 1);
        assert_eq!(acc.mean(), Some(&observation[..]));
        assert!(acc.variance().is_none());
    }

    /// The running mean matches the arithmetic mean of all rows.
    #[test]
    fn test_multiple_updates() {
        let rows = [[1.0_f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut acc = OutcomeStats::new(2);
        for row in &rows {
            acc.update(row);
        }
        assert_eq!(acc.count(), 3);
        let mean = acc.mean().unwrap();
        assert_close(mean[0], 3.0);
        assert_close(mean[1], 4.0);
    }

    /// Merging two accumulators combines counts and means.
    #[test]
    fn test_merge() {
        let mut left = OutcomeStats::new(2);
        for row in &[[1.0_f32, 2.0], [2.0, 3.0]] {
            left.update(row);
        }
        let mut right = OutcomeStats::new(2);
        for row in &[[3.0_f32, 4.0], [4.0, 5.0]] {
            right.update(row);
        }

        let merged = left.merge(&right);
        assert_eq!(merged.count(), 4);
        let mean = merged.mean().unwrap();
        assert_close(mean[0], 2.5);
        assert_close(mean[1], 3.5);
    }

    /// Large offsets must not push the mean away or make the variance negative.
    #[test]
    fn test_numerical_stability() {
        let base = 1e9_f32;
        let mut acc = OutcomeStats::new(1);
        (0..1000)
            .map(|i| base + (i as f32) * 0.001)
            .for_each(|x| acc.update(&[x]));

        let mean = acc.mean().unwrap()[0];
        assert!((mean - base).abs() < 1.0);
        let var = acc.variance().unwrap()[0];
        assert!(var >= 0.0);
    }

    /// Minimum and maximum are tracked independently per dimension.
    #[test]
    fn test_min_max() {
        let mut acc = OutcomeStats::new(2);
        for row in &[[1.0_f32, 5.0], [3.0, 2.0], [2.0, 8.0]] {
            acc.update(row);
        }
        assert_eq!(acc.min(), Some(&[1.0_f32, 2.0][..]));
        assert_eq!(acc.max(), Some(&[3.0_f32, 8.0][..]));
    }
}