/// Streaming mean / variance accumulator based on Welford's online algorithm.
///
/// Each [`update`](Self::update) folds one observation in O(1) time and memory,
/// without storing past values and without the catastrophic cancellation of the
/// naive `Σx² − (Σx)²/n` formula. Partial accumulators (for example, built over
/// separate shards of the data) can be combined with [`merge`](Self::merge).
#[derive(Debug, Clone, PartialEq)]
pub struct WelfordAccumulator {
    /// Number of observations folded in so far.
    count: u64,
    /// Running arithmetic mean; `0.0` while `count == 0`.
    mean: f64,
    /// Running sum of squared deviations from the mean (Welford's `M2`).
    m2: f64,
}

impl WelfordAccumulator {
    /// Two-sided 95% quantile of the standard normal distribution.
    const Z_95: f64 = 1.96;

    /// Creates an empty accumulator with no observations.
    #[must_use]
    pub fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
        }
    }

    /// Folds one observation into the running statistics.
    ///
    /// Applies Welford's recurrence with `n` the new count:
    /// `mean += (x − mean_old) / n` and `m2 += (x − mean_old) · (x − mean_new)`.
    /// Non-finite inputs (NaN, ±∞) are not filtered; they make the mean and
    /// spread statistics non-finite from then on.
    pub fn update(&mut self, value: f64) {
        self.count += 1;
        let delta = value - self.mean;
        self.mean += delta / self.count_as_f64();
        // The second factor deliberately uses the *updated* mean; pairing the old
        // and new deviations is what makes the M2 recurrence exact.
        let delta2 = value - self.mean;
        self.m2 += delta * delta2;
    }

    /// Combines the observations summarized by `other` into `self`.
    ///
    /// The result equals (up to floating-point rounding) feeding every value
    /// that `other` saw to [`update`](Self::update) on `self`. Uses the pairwise
    /// formula of Chan, Golub & LeVeque, with `δ = mean_b − mean_a`:
    /// `mean = mean_a + δ · n_b / n` and `m2 = m2_a + m2_b + δ² · n_a · n_b / n`.
    pub fn merge(&mut self, other: &Self) {
        if other.count == 0 {
            return;
        }
        if self.count == 0 {
            self.clone_from(other);
            return;
        }
        let n_a = self.count_as_f64();
        let n_b = other.count_as_f64();
        let n = n_a + n_b;
        let delta = other.mean - self.mean;
        self.mean += delta * (n_b / n);
        self.m2 += other.m2 + delta * delta * (n_a * n_b / n);
        self.count += other.count;
    }

    /// Number of observations folded in so far.
    #[must_use]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Arithmetic mean of the observations; `0.0` when there are none.
    #[must_use]
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Population variance `M2 / n`; `0.0` when there are fewer than two observations.
    #[must_use]
    pub fn variance(&self) -> f64 {
        if self.count < 2 {
            0.0
        } else {
            self.m2 / self.count_as_f64()
        }
    }

    /// Population standard deviation: the square root of [`variance`](Self::variance).
    #[must_use]
    pub fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }

    /// Sample (Bessel-corrected) variance `M2 / (n − 1)`; `0.0` when there are
    /// fewer than two observations.
    #[must_use]
    pub fn sample_variance(&self) -> f64 {
        if self.count < 2 {
            0.0
        } else {
            self.m2 / (self.count_as_f64() - 1.0)
        }
    }

    /// Sample standard deviation: the square root of
    /// [`sample_variance`](Self::sample_variance).
    #[must_use]
    pub fn sample_std_dev(&self) -> f64 {
        self.sample_variance().sqrt()
    }

    /// Half-width `1.96 · σ / √n` of a normal-approximation 95% confidence
    /// interval for the mean, using the *population* standard deviation
    /// ([`std_dev`](Self::std_dev)). Returns `0.0` with fewer than two observations.
    ///
    /// See [`sample_ci_95_half_width`](Self::sample_ci_95_half_width) for the
    /// variant based on the `n − 1` estimate.
    #[must_use]
    pub fn ci_95_half_width(&self) -> f64 {
        if self.count < 2 {
            0.0
        } else {
            Self::Z_95 * self.std_dev() / self.count_as_f64().sqrt()
        }
    }

    /// Half-width `1.96 · s / √n` of a normal-approximation 95% confidence
    /// interval for the mean, using the sample standard deviation
    /// ([`sample_std_dev`](Self::sample_std_dev)). Returns `0.0` with fewer than
    /// two observations.
    #[must_use]
    pub fn sample_ci_95_half_width(&self) -> f64 {
        if self.count < 2 {
            0.0
        } else {
            Self::Z_95 * self.sample_std_dev() / self.count_as_f64().sqrt()
        }
    }

    /// Observation count converted to `f64` for the floating-point formulas.
    // u64 -> f64 is exact for counts below 2^53 (~9e15), far beyond any realistic
    // number of updates, so the precision-loss lint is silenced once, here.
    #[allow(clippy::cast_precision_loss)]
    fn count_as_f64(&self) -> f64 {
        self.count as f64
    }
}

impl Default for WelfordAccumulator {
    /// Equivalent to [`WelfordAccumulator::new`]: an empty accumulator.
    fn default() -> Self {
        Self::new()
    }
}

impl Extend<f64> for WelfordAccumulator {
    /// Feeds every value from `iter`, in order, through [`WelfordAccumulator::update`].
    fn extend<I: IntoIterator<Item = f64>>(&mut self, iter: I) {
        for value in iter {
            self.update(value);
        }
    }
}

impl std::iter::FromIterator<f64> for WelfordAccumulator {
    /// Builds an accumulator from all values of `iter`, in order.
    fn from_iter<I: IntoIterator<Item = f64>>(iter: I) -> Self {
        let mut acc = Self::new();
        acc.extend(iter);
        acc
    }
}
#[cfg(test)]
mod tests {
    use super::WelfordAccumulator;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPS: f64 = 1e-10;

    /// Textbook dataset: mean 5, population variance 4 (sum of squared deviations 32).
    const KNOWN: [f64; 8] = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];

    /// Builds an accumulator by feeding `values`, in order, through `update`.
    fn accumulate(values: &[f64]) -> WelfordAccumulator {
        let mut acc = WelfordAccumulator::new();
        for &v in values {
            acc.update(v);
        }
        acc
    }

    /// Asserts `actual` is within `EPS` of `expected`, naming the quantity on failure.
    #[track_caller]
    fn assert_close(what: &str, actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "{}: expected {}, got {}",
            what,
            expected,
            actual
        );
    }

    #[test]
    fn welford_known_dataset_mean_variance_std() {
        let acc = accumulate(&KNOWN);
        assert_close("mean", acc.mean(), 5.0);
        assert_close("variance", acc.variance(), 4.0);
        assert_close("std_dev", acc.std_dev(), 2.0);
    }

    #[test]
    fn welford_known_dataset_sample_statistics_and_ci() {
        let acc = accumulate(&KNOWN);
        let sqrt_n = 8.0_f64.sqrt();
        let sample_var: f64 = 32.0 / 7.0;
        assert_close("sample_variance", acc.sample_variance(), sample_var);
        assert_close("sample_std_dev", acc.sample_std_dev(), sample_var.sqrt());
        assert_close("ci_95_half_width", acc.ci_95_half_width(), 1.96 * 2.0 / sqrt_n);
        assert_close(
            "sample_ci_95_half_width",
            acc.sample_ci_95_half_width(),
            1.96 * sample_var.sqrt() / sqrt_n,
        );
    }

    #[test]
    fn welford_large_offset_stays_accurate() {
        // The naive Σx² − (Σx)²/n formula cancels catastrophically at this offset;
        // Welford's recurrence only ever works with small deviations from the mean.
        let offset = 1.0e9;
        let acc = accumulate(&[offset + 4.0, offset + 7.0, offset + 13.0, offset + 16.0]);
        assert_close("mean", acc.mean(), offset + 10.0);
        assert_close("variance", acc.variance(), 22.5);
        assert_close("sample_variance", acc.sample_variance(), 30.0);
    }

    #[test]
    fn welford_single_value_no_variance() {
        let acc = accumulate(&[42.0]);
        assert_close("mean", acc.mean(), 42.0);
        assert_eq!(
            acc.std_dev(),
            0.0,
            "std_dev must be 0.0 with one observation"
        );
        assert_eq!(
            acc.sample_variance(),
            0.0,
            "sample_variance must be 0.0 with one observation"
        );
        assert_eq!(
            acc.ci_95_half_width(),
            0.0,
            "ci_95_half_width must be 0.0 with one observation"
        );
        assert_eq!(
            acc.sample_ci_95_half_width(),
            0.0,
            "sample_ci_95_half_width must be 0.0 with one observation"
        );
    }

    #[test]
    fn welford_zero_updates() {
        let acc = WelfordAccumulator::new();
        assert_eq!(acc.count(), 0, "count must be 0 with no observations");
        assert_eq!(acc.mean(), 0.0, "mean must be 0.0 with no observations");
        assert_eq!(
            acc.std_dev(),
            0.0,
            "std_dev must be 0.0 with no observations"
        );
        assert_eq!(
            acc.sample_std_dev(),
            0.0,
            "sample_std_dev must be 0.0 with no observations"
        );
        assert_eq!(
            acc.ci_95_half_width(),
            0.0,
            "ci_95_half_width must be 0.0 with no observations"
        );
    }

    #[test]
    fn welford_default_is_empty() {
        let acc = WelfordAccumulator::default();
        assert_eq!(acc.count(), 0, "default must start with no observations");
        assert_eq!(acc.mean(), 0.0, "default mean must be 0.0");
        assert_eq!(acc.variance(), 0.0, "default variance must be 0.0");
    }

    #[test]
    fn welford_count_tracks_updates() {
        let mut acc = WelfordAccumulator::new();
        assert_eq!(acc.count(), 0, "count must be 0 before any updates");
        acc.update(1.0);
        acc.update(2.0);
        acc.update(3.0);
        assert_eq!(acc.count(), 3, "count must be 3 after 3 updates");
    }
}