use crate::error::{StatsError, StatsResult};
/// Online (single-pass) estimator of the mean and variance of a stream of `f64` values.
///
/// The state is the observation count `n`, the running `mean`, and `m2`, the
/// running sum of squared deviations from the mean. Observations can be added
/// with [`Welford::push`], removed again with [`Welford::pop`], and two
/// estimators can be combined with [`Welford::merge`].
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Welford {
    /// Number of observations currently accounted for.
    n: u64,
    /// Running mean; `0.0` while the estimator is empty.
    mean: f64,
    /// Running sum of squared deviations from the mean.
    m2: f64,
}

impl Welford {
    /// Creates an empty estimator (`n = 0`, `mean = 0.0`, `m2 = 0.0`).
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds the observation `x`, updating count, mean and `m2` incrementally.
    #[inline]
    pub fn push(&mut self, x: f64) {
        self.n += 1;
        // Deviation from the mean *before* the update ...
        let delta = x - self.mean;
        self.mean += delta / self.n as f64;
        // ... and from the mean *after* it; their product is the M2 increment.
        let delta2 = x - self.mean;
        self.m2 += delta * delta2;
    }

    /// Removes the observation `x`, reversing the arithmetic of [`Welford::push`].
    ///
    /// `x` is not checked against the values that were actually pushed. Removing
    /// the last remaining observation resets the estimator to its empty state.
    ///
    /// Because the reverse update subtracts nearly equal quantities, rounding can
    /// drive the stored sum of squares slightly below zero; such a result is
    /// clamped to `0.0` so that variances never go negative and standard
    /// deviations never become NaN because of it.
    ///
    /// # Errors
    ///
    /// Returns an `empty_data` error when the estimator holds no observations.
    pub fn pop(&mut self, x: f64) -> StatsResult<()> {
        if self.n == 0 {
            return Err(StatsError::empty_data(
                "Welford::pop: cannot pop from empty estimator",
            ));
        }
        let new_n = self.n - 1;
        if new_n == 0 {
            self.mean = 0.0;
            self.m2 = 0.0;
            self.n = 0;
            return Ok(());
        }
        // Deviation from the current mean, then from the mean with `x` removed.
        let delta = x - self.mean;
        self.mean -= delta / new_n as f64;
        let delta2 = x - self.mean;
        let m2 = self.m2 - delta * delta2;
        // A sum of squares cannot be negative. The comparison is false for NaN,
        // so a NaN produced by non-finite input is still propagated, not hidden.
        self.m2 = if m2 < 0.0 { 0.0 } else { m2 };
        self.n = new_n;
        Ok(())
    }

    /// Folds the statistics of `other` into `self`, as if every observation of
    /// `other` had also been pushed into `self`.
    ///
    /// Merging an empty `other` is a no-op; merging into an empty `self` copies
    /// `other`.
    pub fn merge(&mut self, other: &Self) {
        if other.n == 0 {
            return;
        }
        if self.n == 0 {
            *self = other.clone();
            return;
        }
        let n_total = self.n + other.n;
        let delta = other.mean - self.mean;
        let new_mean = self.mean + delta * (other.n as f64) / (n_total as f64);
        // Combined M2 = both partial sums plus a correction for the gap between
        // the two partial means.
        self.m2 += other.m2 + delta * delta * (self.n as f64 * other.n as f64) / (n_total as f64);
        self.mean = new_mean;
        self.n = n_total;
    }

    /// Number of observations currently accounted for.
    #[inline]
    pub fn count(&self) -> u64 {
        self.n
    }

    /// Current running mean (`0.0` for an empty estimator).
    #[inline]
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Sample variance, `m2 / (n - 1)`.
    ///
    /// # Errors
    ///
    /// Returns an `empty_data` error when fewer than two observations are held.
    pub fn variance(&self) -> StatsResult<f64> {
        if self.n < 2 {
            return Err(StatsError::empty_data(
                "Welford::variance: need at least 2 observations",
            ));
        }
        Ok(self.m2 / (self.n - 1) as f64)
    }

    /// Population variance, `m2 / n`; `0.0` for an empty estimator.
    #[inline]
    pub fn population_variance(&self) -> f64 {
        if self.n == 0 {
            0.0
        } else {
            self.m2 / self.n as f64
        }
    }

    /// Sample standard deviation: the square root of [`Welford::variance`].
    ///
    /// # Errors
    ///
    /// Fails under the same condition as [`Welford::variance`].
    pub fn std_dev(&self) -> StatsResult<f64> {
        Ok(self.variance()?.sqrt())
    }

    /// Raw running sum of squared deviations from the mean.
    #[inline]
    pub fn m2(&self) -> f64 {
        self.m2
    }

    /// Rebuilds an estimator from previously extracted raw state.
    ///
    /// The three values are stored as given; no validation is performed.
    pub fn from_raw(n: u64, mean: f64, m2: f64) -> Self {
        Self { n, mean, m2 }
    }
}
/// Online estimator of the per-component mean and variance of fixed-length
/// `f64` vectors.
///
/// Each of the `dim` components is tracked independently: `mean[i]` is the
/// running mean of component `i` and `m2[i]` its running sum of squared
/// deviations. No cross-component statistics are kept.
#[derive(Clone, Debug, PartialEq)]
pub struct WelfordVector {
    /// Length every pushed vector must have.
    dim: usize,
    /// Number of vectors pushed so far.
    n: u64,
    /// Running mean of each component.
    mean: Vec<f64>,
    /// Running sum of squared deviations of each component.
    m2: Vec<f64>,
}

impl WelfordVector {
    /// Creates an empty estimator for vectors of length `dim`.
    pub fn new(dim: usize) -> Self {
        let zeros = vec![0.0; dim];
        Self {
            dim,
            n: 0,
            mean: zeros.clone(),
            m2: zeros,
        }
    }

    /// Adds the observation `x`, updating every component.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_input` error, leaving the state untouched, when
    /// `x.len()` differs from the configured dimension.
    pub fn push(&mut self, x: &[f64]) -> StatsResult<()> {
        let got = x.len();
        if got != self.dim {
            return Err(StatsError::invalid_input(format!(
                "WelfordVector::push: expected {} dims, got {}",
                self.dim, got
            )));
        }
        self.n += 1;
        let weight = 1.0 / self.n as f64;
        let state = self.mean.iter_mut().zip(self.m2.iter_mut());
        for ((mu, acc), &xi) in state.zip(x) {
            // Deviation from the old mean, then from the freshly updated mean.
            let before = xi - *mu;
            *mu += before * weight;
            let after = xi - *mu;
            *acc += before * after;
        }
        Ok(())
    }

    /// Folds the statistics of `other` into `self`, component by component.
    ///
    /// Merging an empty `other` is a no-op; merging into an empty `self`
    /// replaces `self` with a clone of `other`.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_input` error when the two dimensions differ.
    pub fn merge(&mut self, other: &Self) -> StatsResult<()> {
        if self.dim != other.dim {
            return Err(StatsError::invalid_input(format!(
                "WelfordVector::merge: dim mismatch ({} vs {})",
                self.dim, other.dim
            )));
        }
        match (self.n, other.n) {
            // Nothing to add.
            (_, 0) => return Ok(()),
            // Nothing of our own yet: adopt the other side wholesale.
            (0, _) => {
                *self = other.clone();
                return Ok(());
            }
            _ => {}
        }
        let their_n = other.n as f64;
        let combined = self.n + other.n;
        let combined_f = combined as f64;
        let cross = (self.n as f64) * their_n;
        let ours = self.mean.iter_mut().zip(self.m2.iter_mut());
        let theirs = other.mean.iter().zip(other.m2.iter());
        for ((mu, acc), (&their_mu, &their_acc)) in ours.zip(theirs) {
            let gap = their_mu - *mu;
            let merged_mu = *mu + gap * their_n / combined_f;
            *acc += their_acc + gap * gap * cross / combined_f;
            *mu = merged_mu;
        }
        self.n = combined;
        Ok(())
    }

    /// Number of vectors accounted for.
    pub fn count(&self) -> u64 {
        self.n
    }

    /// Configured vector length.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Running mean of each component.
    pub fn mean(&self) -> &[f64] {
        &self.mean
    }

    /// Running sum of squared deviations of each component.
    pub fn m2(&self) -> &[f64] {
        &self.m2
    }

    /// Sample variance of each component, returned as a new vector.
    ///
    /// # Errors
    ///
    /// Fails under the same condition as [`WelfordVector::variance_into`].
    pub fn variance(&self) -> StatsResult<Vec<f64>> {
        let mut buf = vec![0.0; self.m2.len()];
        self.variance_into(&mut buf).map(|()| buf)
    }

    /// Writes the sample variance (`m2 / (n - 1)`) of each component into `out`.
    ///
    /// # Errors
    ///
    /// Returns an `empty_data` error when fewer than two observations are held,
    /// otherwise an `invalid_input` error when `out` does not have one slot per
    /// component.
    pub fn variance_into(&self, out: &mut [f64]) -> StatsResult<()> {
        if self.n < 2 {
            return Err(StatsError::empty_data(
                "WelfordVector::variance: need at least 2 observations",
            ));
        }
        let expected = self.m2.len();
        if out.len() != expected {
            return Err(StatsError::invalid_input(format!(
                "WelfordVector::variance_into: out len {} != dim {}",
                out.len(),
                expected
            )));
        }
        let scale = (self.n - 1) as f64;
        out.iter_mut()
            .zip(&self.m2)
            .for_each(|(slot, &acc)| *slot = acc / scale);
        Ok(())
    }

    /// Sample standard deviation of each component, returned as a new vector.
    ///
    /// # Errors
    ///
    /// Fails under the same condition as [`WelfordVector::std_dev_into`].
    pub fn std_dev(&self) -> StatsResult<Vec<f64>> {
        let mut buf = vec![0.0; self.m2.len()];
        self.std_dev_into(&mut buf).map(|()| buf)
    }

    /// Writes the sample standard deviation of each component into `out`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`WelfordVector::variance_into`].
    pub fn std_dev_into(&self, out: &mut [f64]) -> StatsResult<()> {
        self.variance_into(out)?;
        out.iter_mut().for_each(|slot| *slot = slot.sqrt());
        Ok(())
    }
}
/// Online estimator of the mean vector and covariance matrix of fixed-length
/// `f64` vectors.
///
/// `m2` is a flattened `dim * dim` matrix in row-major order: entry
/// `i * dim + j` accumulates the co-deviation of components `i` and `j`.
/// `delta` is scratch space reused by [`WelfordCovariance::push`] and
/// [`WelfordCovariance::merge`]; it carries no statistical meaning.
#[derive(Clone, Debug)]
pub struct WelfordCovariance {
    /// Length every pushed vector must have.
    dim: usize,
    /// Number of vectors pushed so far.
    n: u64,
    /// Running mean of each component.
    mean: Vec<f64>,
    /// Row-major `dim * dim` matrix of accumulated co-deviations.
    m2: Vec<f64>,
    /// Scratch buffer of per-component deviations.
    delta: Vec<f64>,
}

/// Two estimators are equal when their statistics (`dim`, `n`, `mean`, `m2`)
/// are equal. The scratch buffer is deliberately ignored: it only reflects the
/// last update performed, so comparing it would make estimators holding
/// identical statistics unequal (for instance after merging into an empty
/// estimator, which copies the statistics but not the scratch values).
impl PartialEq for WelfordCovariance {
    fn eq(&self, other: &Self) -> bool {
        self.dim == other.dim
            && self.n == other.n
            && self.mean == other.mean
            && self.m2 == other.m2
    }
}

impl WelfordCovariance {
    /// Creates an empty estimator for vectors of length `dim`.
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            n: 0,
            mean: vec![0.0; dim],
            m2: vec![0.0; dim * dim],
            delta: vec![0.0; dim],
        }
    }

    /// Adds the observation `x`, updating the mean vector and every entry of
    /// the co-deviation matrix.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_input` error, leaving the state untouched, when
    /// `x.len()` differs from the configured dimension.
    pub fn push(&mut self, x: &[f64]) -> StatsResult<()> {
        if x.len() != self.dim {
            return Err(StatsError::invalid_input(format!(
                "WelfordCovariance::push: expected {} dims, got {}",
                self.dim,
                x.len()
            )));
        }
        self.n += 1;
        let n_inv = 1.0 / self.n as f64;
        // Pass 1: remember each deviation from the old mean, then move the mean.
        let state = self.delta.iter_mut().zip(self.mean.iter_mut());
        for ((d, mu), &xi) in state.zip(x) {
            *d = xi - *mu;
            *mu += *d * n_inv;
        }
        // Pass 2: m2[i][j] += (x[i] - old_mean[i]) * (x[j] - new_mean[j]).
        // The matrix is walked one contiguous row at a time; every cell still
        // receives exactly one product, so the result does not depend on the
        // traversal order.
        let dim = self.dim;
        for i in 0..dim {
            let delta_pre = self.delta[i];
            let row = &mut self.m2[i * dim..(i + 1) * dim];
            for (cell, (&xj, &mean_j)) in row.iter_mut().zip(x.iter().zip(self.mean.iter())) {
                *cell += delta_pre * (xj - mean_j);
            }
        }
        Ok(())
    }

    /// Folds the statistics of `other` into `self`.
    ///
    /// Merging an empty `other` is a no-op; merging into an empty `self` copies
    /// the count, mean and matrix of `other`.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_input` error when the two dimensions differ.
    pub fn merge(&mut self, other: &Self) -> StatsResult<()> {
        if self.dim != other.dim {
            return Err(StatsError::invalid_input(format!(
                "WelfordCovariance::merge: dim mismatch ({} vs {})",
                self.dim, other.dim
            )));
        }
        if other.n == 0 {
            return Ok(());
        }
        if self.n == 0 {
            self.n = other.n;
            self.mean.copy_from_slice(&other.mean);
            self.m2.copy_from_slice(&other.m2);
            return Ok(());
        }
        let n_total = self.n + other.n;
        let n_total_f = n_total as f64;
        let other_n = other.n as f64;
        let prod_n = (self.n as f64) * other_n;
        // Gap between the two partial means, per component.
        let gaps = self.delta.iter_mut().zip(other.mean.iter());
        for ((d, &their_mu), &our_mu) in gaps.zip(self.mean.iter()) {
            *d = their_mu - our_mu;
        }
        // Both partial matrices plus the between-group correction term.
        let dim = self.dim;
        for i in 0..dim {
            let gap_i = self.delta[i];
            let start = i * dim;
            let row = &mut self.m2[start..start + dim];
            let their_row = &other.m2[start..start + dim];
            for ((cell, &theirs), &gap_j) in row.iter_mut().zip(their_row).zip(self.delta.iter()) {
                *cell += theirs + gap_i * gap_j * prod_n / n_total_f;
            }
        }
        for (mu, &gap) in self.mean.iter_mut().zip(self.delta.iter()) {
            *mu += gap * other_n / n_total_f;
        }
        self.n = n_total;
        Ok(())
    }

    /// Number of vectors accounted for.
    pub fn count(&self) -> u64 {
        self.n
    }

    /// Configured vector length.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Running mean of each component.
    pub fn mean(&self) -> &[f64] {
        &self.mean
    }

    /// Raw row-major `dim * dim` matrix of accumulated co-deviations.
    pub fn m2(&self) -> &[f64] {
        &self.m2
    }

    /// Sample covariance matrix (`m2 / (n - 1)`), row-major, as a new vector.
    ///
    /// # Errors
    ///
    /// Returns an `empty_data` error when fewer than two observations are held.
    pub fn covariance(&self) -> StatsResult<Vec<f64>> {
        if self.n < 2 {
            return Err(StatsError::empty_data(
                "WelfordCovariance::covariance: need at least 2 observations",
            ));
        }
        let denom = (self.n - 1) as f64;
        Ok(self.m2.iter().map(|m| m / denom).collect())
    }

    /// Writes the sample covariance matrix (`m2 / (n - 1)`), row-major, into `out`.
    ///
    /// # Errors
    ///
    /// Returns an `empty_data` error when fewer than two observations are held,
    /// otherwise an `invalid_input` error when `out.len()` is not `dim * dim`.
    pub fn covariance_into(&self, out: &mut [f64]) -> StatsResult<()> {
        if self.n < 2 {
            return Err(StatsError::empty_data(
                "WelfordCovariance::covariance_into: need at least 2 observations",
            ));
        }
        if out.len() != self.dim * self.dim {
            return Err(StatsError::invalid_input(format!(
                "WelfordCovariance::covariance_into: out len {} != expected {}",
                out.len(),
                self.dim * self.dim
            )));
        }
        let denom = (self.n - 1) as f64;
        for (slot, &m) in out.iter_mut().zip(self.m2.iter()) {
            *slot = m / denom;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the scalar, vector and covariance estimators.

    use super::*;

    /// True when `a` and `b` differ by strictly less than `tol`.
    fn approx(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Mean and sample variance of a small hand-checked data set.
    #[test]
    fn welford_basic() {
        let mut w = Welford::new();
        for x in [4.0, 7.0, 13.0, 16.0] {
            w.push(x);
        }
        assert_eq!(w.count(), 4);
        assert!(approx(w.mean(), 10.0, 1e-12));
        assert!(approx(w.variance().unwrap(), 30.0, 1e-12));
    }

    /// Pushing a value and popping it again restores the previous state.
    #[test]
    fn welford_pop_inverts_push() {
        let mut w = Welford::new();
        for x in [4.0, 7.0, 13.0, 16.0] {
            w.push(x);
        }
        let before = w.clone();
        w.push(99.0);
        w.pop(99.0).unwrap();
        assert!(approx(w.mean(), before.mean(), 1e-12));
        assert!(approx(w.m2(), before.m2(), 1e-9));
        assert_eq!(w.count(), before.count());
    }

    /// Push/pop round trip on values with a large common offset.
    #[test]
    fn welford_pop_stable_far_from_origin() {
        let mut w = Welford::new();
        for i in 0..1000 {
            w.push(1_700_000_000.0 + (i as f64) * 1e-3);
        }
        let before = w.clone();
        w.push(1_700_000_500.123_456);
        w.pop(1_700_000_500.123_456).unwrap();
        assert!(approx(w.mean(), before.mean(), 1e-9));
        assert_eq!(w.count(), before.count());
    }

    /// Popping from an empty estimator is an error.
    #[test]
    fn welford_pop_empty_errors() {
        let mut w = Welford::new();
        assert!(w.pop(1.0).is_err());
    }

    /// Popping the only observation resets every field to zero.
    #[test]
    fn welford_pop_to_empty_resets() {
        let mut w = Welford::new();
        w.push(42.0);
        w.pop(42.0).unwrap();
        assert_eq!(w.count(), 0);
        assert_eq!(w.mean(), 0.0);
        assert_eq!(w.m2(), 0.0);
    }

    /// Merging two halves matches pushing all values into one estimator.
    #[test]
    fn welford_merge_chan() {
        let mut a = Welford::new();
        for x in [1.0, 2.0, 3.0] {
            a.push(x);
        }
        let mut b = Welford::new();
        for x in [4.0, 5.0, 6.0] {
            b.push(x);
        }
        let mut full = Welford::new();
        for x in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] {
            full.push(x);
        }
        let mut merged = a.clone();
        merged.merge(&b);
        assert!(approx(merged.mean(), full.mean(), 1e-12));
        assert!(approx(
            merged.variance().unwrap(),
            full.variance().unwrap(),
            1e-12
        ));
    }

    /// An empty estimator is the identity element of `merge` on either side.
    #[test]
    fn welford_merge_with_empty() {
        let mut a = Welford::new();
        a.push(5.0);
        a.push(10.0);
        let snapshot = a.clone();
        a.merge(&Welford::new());
        assert_eq!(a, snapshot);
        let mut b = Welford::new();
        b.merge(&snapshot);
        assert_eq!(b, snapshot);
    }

    /// Sample variance equals population variance scaled by n / (n - 1).
    #[test]
    fn welford_population_vs_sample() {
        let mut w = Welford::new();
        for x in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
            w.push(x);
        }
        let pop_var = w.population_variance();
        let samp_var = w.variance().unwrap();
        assert!(approx(samp_var, pop_var * 8.0 / 7.0, 1e-12));
    }

    /// Per-component mean and variance of three vectors.
    #[test]
    fn welford_vector_basic() {
        let mut wv = WelfordVector::new(3);
        wv.push(&[1.0, 10.0, 100.0]).unwrap();
        wv.push(&[2.0, 20.0, 200.0]).unwrap();
        wv.push(&[3.0, 30.0, 300.0]).unwrap();
        assert_eq!(wv.mean(), &[2.0, 20.0, 200.0]);
        let var = wv.variance().unwrap();
        assert!(approx(var[0], 1.0, 1e-12));
        assert!(approx(var[1], 100.0, 1e-12));
        assert!(approx(var[2], 10000.0, 1e-12));
    }

    /// Merging two halves matches pushing all vectors into one estimator,
    /// for the means and for the second moments.
    #[test]
    fn welford_vector_merge() {
        let mut a = WelfordVector::new(2);
        a.push(&[1.0, 1.0]).unwrap();
        a.push(&[2.0, 2.0]).unwrap();
        let mut b = WelfordVector::new(2);
        b.push(&[3.0, 3.0]).unwrap();
        b.push(&[4.0, 4.0]).unwrap();
        let mut full = WelfordVector::new(2);
        for v in [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]] {
            full.push(&v).unwrap();
        }
        let mut merged = a.clone();
        merged.merge(&b).unwrap();
        assert_eq!(merged.count(), 4);
        // Checking only the means would let an error in the merged sum of
        // squares slip through, so the second moments are compared as well.
        let m_var = merged.variance().unwrap();
        let f_var = full.variance().unwrap();
        for i in 0..2 {
            assert!(approx(merged.mean()[i], full.mean()[i], 1e-12));
            assert!(approx(merged.m2()[i], full.m2()[i], 1e-12));
            assert!(approx(m_var[i], f_var[i], 1e-12));
        }
    }

    /// Pushing a vector of the wrong length is rejected.
    #[test]
    fn welford_vector_dim_mismatch_errors() {
        let mut wv = WelfordVector::new(3);
        assert!(wv.push(&[1.0, 2.0]).is_err());
    }

    /// Covariance of perfectly correlated components (second = 2 * first).
    #[test]
    fn welford_covariance_basic() {
        let mut wc = WelfordCovariance::new(2);
        wc.push(&[1.0, 2.0]).unwrap();
        wc.push(&[2.0, 4.0]).unwrap();
        wc.push(&[3.0, 6.0]).unwrap();
        wc.push(&[4.0, 8.0]).unwrap();
        let cov = wc.covariance().unwrap();
        assert!(approx(cov[0], 5.0 / 3.0, 1e-12));
        assert!(approx(cov[1], 2.0 * 5.0 / 3.0, 1e-12));
        assert!(approx(cov[2], 2.0 * 5.0 / 3.0, 1e-12));
        assert!(approx(cov[3], 4.0 * 5.0 / 3.0, 1e-12));
    }

    /// Merging two halves matches pushing all vectors into one estimator.
    #[test]
    fn welford_covariance_merge() {
        let mut a = WelfordCovariance::new(2);
        a.push(&[1.0, 2.0]).unwrap();
        a.push(&[2.0, 4.0]).unwrap();
        let mut b = WelfordCovariance::new(2);
        b.push(&[3.0, 6.0]).unwrap();
        b.push(&[4.0, 8.0]).unwrap();
        let mut full = WelfordCovariance::new(2);
        for v in [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]] {
            full.push(&v).unwrap();
        }
        let mut merged = a.clone();
        merged.merge(&b).unwrap();
        assert_eq!(merged.count(), 4);
        for i in 0..2 {
            assert!(approx(merged.mean()[i], full.mean()[i], 1e-12));
        }
        let m_cov = merged.covariance().unwrap();
        let f_cov = full.covariance().unwrap();
        for i in 0..4 {
            assert!(approx(m_cov[i], f_cov[i], 1e-9));
        }
    }

    /// `covariance_into` fills the buffer with the values `covariance` returns.
    #[test]
    fn welford_covariance_into_matches_covariance() {
        let mut wc = WelfordCovariance::new(3);
        for v in [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]] {
            wc.push(&v).unwrap();
        }
        let owned = wc.covariance().unwrap();
        let mut buf = vec![0.0; 9];
        wc.covariance_into(&mut buf).unwrap();
        for i in 0..9 {
            assert!(approx(owned[i], buf[i], 1e-15));
        }
    }

    /// `covariance_into` rejects a buffer whose length is not `dim * dim`.
    #[test]
    fn welford_covariance_into_wrong_size_errors() {
        let mut wc = WelfordCovariance::new(2);
        wc.push(&[1.0, 2.0]).unwrap();
        wc.push(&[3.0, 4.0]).unwrap();
        let mut wrong = vec![0.0; 3];
        assert!(wc.covariance_into(&mut wrong).is_err());
    }
}