#[cfg(feature = "alloc")]
use crate::math;
#[cfg(feature = "alloc")]
use alloc::vec;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
/// Streaming per-dimension standardizer based on Welford's one-pass
/// algorithm: keeps a running mean and sum of squared deviations per
/// dimension so inputs can be normalized without storing past samples.
#[cfg(feature = "alloc")]
pub struct WelfordNormalizer {
// Running mean of each dimension; length fixed at construction.
mean: Vec<f64>,
// Sum of squared deviations from the running mean per dimension;
// Bessel-corrected variance is m2 / (count - 1).
m2: Vec<f64>,
// Number of samples folded in so far.
count: u64,
}
#[cfg(feature = "alloc")]
impl WelfordNormalizer {
    /// Variance floor added under the square root so near-constant features
    /// never divide by (almost) zero. Previously duplicated as a local const
    /// in two methods; hoisted here so the paths cannot drift apart.
    const EPS: f64 = 1e-8;

    /// Creates a normalizer for `d`-dimensional inputs with all statistics
    /// zeroed (no samples observed yet).
    pub fn new(d: usize) -> Self {
        Self {
            mean: vec![0.0; d],
            m2: vec![0.0; d],
            count: 0,
        }
    }

    /// Number of samples accumulated via `update` / `update_and_normalize`.
    #[inline]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Folds one sample into the running statistics using Welford's
    /// numerically stable one-pass recurrence.
    ///
    /// `x.len()` must equal the dimension given to `new`. This is checked
    /// only in debug builds; in release a shorter input silently updates
    /// just the leading components (the `zip` stops early).
    pub fn update(&mut self, x: &[f64]) {
        debug_assert_eq!(
            x.len(),
            self.mean.len(),
            "WelfordNormalizer: input length {} != expected {}",
            x.len(),
            self.mean.len()
        );
        self.count += 1;
        let n = self.count as f64;
        for ((m, m2), &xi) in self.mean.iter_mut().zip(self.m2.iter_mut()).zip(x.iter()) {
            let delta = xi - *m;
            *m += delta / n;
            // Using the *updated* mean for the second delta is what makes
            // Welford's update resistant to catastrophic cancellation.
            let delta2 = xi - *m;
            *m2 += delta * delta2;
        }
    }

    /// Writes the standardized value `(x - mean) / std` into `out`, where
    /// `std` uses the Bessel-corrected (n - 1) variance plus `EPS`.
    ///
    /// With fewer than two samples the variance is undefined, so `out` is
    /// filled with zeros instead.
    pub fn normalize(&self, x: &[f64], out: &mut [f64]) {
        debug_assert_eq!(x.len(), self.mean.len());
        debug_assert_eq!(out.len(), self.mean.len());
        if self.count < 2 {
            out.fill(0.0);
            return;
        }
        let n_minus_1 = (self.count - 1) as f64;
        for (((m, m2), &xi), o) in self
            .mean
            .iter()
            .zip(self.m2.iter())
            .zip(x.iter())
            .zip(out.iter_mut())
        {
            let var = *m2 / n_minus_1;
            let std = math::sqrt(var + Self::EPS);
            *o = (xi - m) / std;
        }
    }

    /// Updates the running statistics with `x`, then standardizes `x`
    /// against the *post-update* statistics.
    ///
    /// Exactly equivalent to `update` followed by `normalize` (the unit
    /// test `update_and_normalize_matches_separate_calls` pins this), so it
    /// delegates rather than duplicating the Welford math inline.
    pub fn update_and_normalize(&mut self, x: &[f64], out: &mut [f64]) {
        self.update(x);
        self.normalize(x, out);
    }

    /// Zeroes all statistics, returning the normalizer to the cold-start
    /// state produced by `new` (the dimension is preserved).
    pub fn reset(&mut self) {
        self.mean.fill(0.0);
        self.m2.fill(0.0);
        self.count = 0;
    }
}
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use super::*;

    // Feeding 1..=5 must produce the textbook mean (3) and Bessel variance (2.5).
    #[test]
    fn welford_known_mean_and_variance() {
        let mut w = WelfordNormalizer::new(1);
        for v in [1.0f64, 2.0, 3.0, 4.0, 5.0] {
            w.update(&[v]);
        }
        assert!((w.mean[0] - 3.0).abs() < 1e-12, "mean should be 3.0");
        let bessel_var = w.m2[0] / 4.0;
        assert!((bessel_var - 2.5).abs() < 1e-12, "Bessel var should be 2.5");
    }

    // After fitting, the mean of the normalized samples should sit near zero.
    #[test]
    fn normalized_output_is_standardized() {
        let data = [2.0f64, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let mut w = WelfordNormalizer::new(1);
        for v in data {
            w.update(&[v]);
        }
        let mut scratch = [0.0f64; 1];
        let mut total = 0.0;
        for v in data {
            w.normalize(&[v], &mut scratch);
            total += scratch[0];
        }
        let normalized_mean = total / data.len() as f64;
        assert!(
            normalized_mean.abs() < 0.1,
            "normalized mean {normalized_mean} should be near 0"
        );
    }

    // With zero or one sample the variance is undefined; output must be zeros.
    #[test]
    fn cold_start_returns_zeros() {
        let mut w = WelfordNormalizer::new(3);
        let mut buf = [0.0f64; 3];
        w.normalize(&[1.0, 2.0, 3.0], &mut buf);
        assert_eq!(buf, [0.0, 0.0, 0.0]);
        w.update(&[1.0, 2.0, 3.0]);
        w.normalize(&[1.0, 2.0, 3.0], &mut buf);
        assert_eq!(buf, [0.0, 0.0, 0.0]);
    }

    // reset() must return the normalizer to its freshly-constructed state.
    #[test]
    fn reset_clears_all_state() {
        let mut w = WelfordNormalizer::new(2);
        for i in 0..10 {
            let f = i as f64;
            w.update(&[f, f * 2.0]);
        }
        w.reset();
        assert_eq!(w.count(), 0);
        assert_eq!(w.mean[0], 0.0);
        assert_eq!(w.m2[0], 0.0);
    }

    // The fused path must agree with update() followed by normalize().
    #[test]
    fn update_and_normalize_matches_separate_calls() {
        let mut fused = WelfordNormalizer::new(1);
        let mut split = WelfordNormalizer::new(1);
        let mut fused_out = [0.0f64; 1];
        let mut split_out = [0.0f64; 1];
        for x in [10.0f64, 20.0, 30.0, 40.0, 50.0] {
            fused.update_and_normalize(&[x], &mut fused_out);
            split.update(&[x]);
            split.normalize(&[x], &mut split_out);
            assert!(
                (fused_out[0] - split_out[0]).abs() < 1e-10,
                "combined and separate paths diverged at x={x}: combined={}, separate={}",
                fused_out[0],
                split_out[0]
            );
        }
    }
}