use crate::error::{RlError, RlResult};
/// Running per-feature mean and variance over a stream of `f32` observations.
///
/// Statistics are accumulated in `f64` using Welford's online update, so each
/// observation is folded in with one pass over its features and no history of
/// past observations is stored.
#[derive(Debug, Clone)]
pub struct RunningStats {
    /// Number of features per observation; fixed at construction.
    dim: usize,
    /// Running mean of each feature.
    mean: Vec<f64>,
    /// Running sum of squared deviations from the mean, per feature.
    m2: Vec<f64>,
    /// Number of observations folded in so far.
    count: u64,
}

impl RunningStats {
    /// Lower bound applied to the variance before taking its square root.
    const MIN_VARIANCE: f64 = 1e-8;
    /// Added to the standard deviation in the denominator of [`Self::normalise`].
    const NORMALISE_EPS: f64 = 1e-8;

    /// Creates an empty accumulator for observations with `dim` features.
    ///
    /// # Panics
    ///
    /// Panics if `dim` is zero.
    #[must_use]
    pub fn new(dim: usize) -> Self {
        assert!(dim > 0, "dim must be > 0");
        Self {
            dim,
            mean: vec![0.0_f64; dim],
            m2: vec![0.0_f64; dim],
            count: 0,
        }
    }

    /// Number of features each observation must have.
    #[must_use]
    #[inline]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of observations recorded since construction or the last [`Self::reset`].
    #[must_use]
    #[inline]
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Current per-feature mean, narrowed to `f32`.
    #[must_use]
    pub fn mean_f32(&self) -> Vec<f32> {
        self.mean.iter().map(|&m| m as f32).collect()
    }

    /// Current per-feature sample standard deviation, narrowed to `f32`.
    ///
    /// The variance is floored at `1e-8` before the square root is taken.
    /// With fewer than two observations every entry is `1.0`.
    #[must_use]
    pub fn std_f32(&self) -> Vec<f32> {
        match self.sample_denominator() {
            Some(n) => self
                .m2
                .iter()
                .map(|&m2| Self::clamped_std(m2, n) as f32)
                .collect(),
            None => vec![1.0_f32; self.dim],
        }
    }

    /// Current per-feature sample variance (divisor `count - 1`), narrowed to `f32`.
    ///
    /// Unlike [`Self::std_f32`] no floor is applied. With fewer than two
    /// observations every entry is `1.0`.
    #[must_use]
    pub fn var_f32(&self) -> Vec<f32> {
        match self.sample_denominator() {
            Some(n) => self.m2.iter().map(|&m2| (m2 / n) as f32).collect(),
            None => vec![1.0_f32; self.dim],
        }
    }

    /// Folds a single observation into the running statistics.
    ///
    /// # Errors
    ///
    /// Returns [`RlError::DimensionMismatch`] if `obs.len() != self.dim()`;
    /// the statistics are left untouched in that case.
    pub fn update(&mut self, obs: &[f32]) -> RlResult<()> {
        self.check_dim(obs.len())?;
        self.count += 1;
        let n = self.count as f64;
        for ((mean, m2), &x) in self.mean.iter_mut().zip(self.m2.iter_mut()).zip(obs) {
            let x = f64::from(x);
            let delta = x - *mean;
            *mean += delta / n;
            // Second factor uses the *updated* mean, as Welford's update requires.
            *m2 += delta * (x - *mean);
        }
        Ok(())
    }

    /// Folds a flat, row-major batch of observations into the running statistics.
    ///
    /// `batch` is read as consecutive observations of `self.dim()` values each.
    /// An empty batch is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`RlError::DimensionMismatch`] (with `got` set to `batch.len()`)
    /// if the length is not a multiple of `self.dim()`. The check happens before
    /// any observation is applied, so a rejected batch changes nothing.
    pub fn update_batch(&mut self, batch: &[f32]) -> RlResult<()> {
        // `dim` is non-zero (enforced by `new`), so the remainder cannot panic.
        if batch.len() % self.dim != 0 {
            return Err(RlError::DimensionMismatch {
                expected: self.dim,
                got: batch.len(),
            });
        }
        for row in batch.chunks_exact(self.dim) {
            self.update(row)?;
        }
        Ok(())
    }

    /// Returns `(obs - mean) / (std + 1e-8)` for each feature.
    ///
    /// The standard deviation is the floored value described on
    /// [`Self::std_f32`] (so `1.0` until two observations have been recorded).
    /// The arithmetic is carried out in `f64` against the full-precision
    /// accumulators and only the final quotient is narrowed to `f32`; this keeps
    /// the result accurate when the mean is large relative to the spread. As a
    /// consequence the output may differ by rounding from a manual `f32`
    /// computation based on [`Self::mean_f32`] and [`Self::std_f32`].
    ///
    /// # Errors
    ///
    /// Returns [`RlError::DimensionMismatch`] if `obs.len() != self.dim()`.
    pub fn normalise(&self, obs: &[f32]) -> RlResult<Vec<f32>> {
        self.check_dim(obs.len())?;
        let denom = self.sample_denominator();
        Ok(obs
            .iter()
            .zip(&self.mean)
            .zip(&self.m2)
            .map(|((&x, &mean), &m2)| {
                let std = denom.map_or(1.0, |n| Self::clamped_std(m2, n));
                ((f64::from(x) - mean) / (std + Self::NORMALISE_EPS)) as f32
            })
            .collect())
    }

    /// Discards all recorded observations; `dim` is kept.
    pub fn reset(&mut self) {
        self.mean.fill(0.0);
        self.m2.fill(0.0);
        self.count = 0;
    }

    /// Checks that a slice of length `len` holds exactly one observation.
    fn check_dim(&self, len: usize) -> RlResult<()> {
        if len == self.dim {
            Ok(())
        } else {
            Err(RlError::DimensionMismatch {
                expected: self.dim,
                got: len,
            })
        }
    }

    /// Sample-variance divisor `count - 1`, or `None` with fewer than two observations.
    fn sample_denominator(&self) -> Option<f64> {
        (self.count >= 2).then(|| (self.count - 1) as f64)
    }

    /// Standard deviation for one feature, with the variance floored at `MIN_VARIANCE`.
    fn clamped_std(m2: f64, denom: f64) -> f64 {
        (m2 / denom).max(Self::MIN_VARIANCE).sqrt()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built accumulator has recorded nothing.
    #[test]
    fn initial_count_zero() {
        let stats = RunningStats::new(3);
        assert_eq!(stats.count(), 0);
    }

    /// One accepted observation bumps the counter to one.
    #[test]
    fn single_update_count_one() {
        let mut stats = RunningStats::new(2);
        stats.update(&[1.0, 2.0]).unwrap();
        assert_eq!(stats.count(), 1);
    }

    /// Feeding a constant stream yields that constant as the mean.
    #[test]
    fn mean_converges_to_true_mean() {
        let mut stats = RunningStats::new(1);
        (0..1000).for_each(|_| stats.update(&[3.0]).unwrap());
        let mean = stats.mean_f32()[0];
        assert!((mean - 3.0).abs() < 0.01, "mean={mean}");
    }

    /// An alternating +1 / -1 stream has a standard deviation close to one.
    #[test]
    fn std_converges_to_true_std() {
        let mut stats = RunningStats::new(1);
        for &sample in [1.0_f32, -1.0].iter().cycle().take(2000) {
            stats.update(&[sample]).unwrap();
        }
        let std = stats.std_f32()[0];
        assert!((std - 1.0).abs() < 0.05, "std={std}");
    }

    /// A value next to the stream's mean normalises to something near zero.
    #[test]
    fn normalise_close_to_zero_mean() {
        let mut stats = RunningStats::new(1);
        (0..100)
            .map(|i| i as f32)
            .for_each(|x| stats.update(&[x]).unwrap());
        let norm = stats.normalise(&[50.0]).unwrap();
        assert!(
            norm[0].abs() < 0.5,
            "normalised mean should be near 0, got {}",
            norm[0]
        );
    }

    /// Normalising an observation of the wrong length is rejected.
    #[test]
    fn normalise_dimension_error() {
        let stats = RunningStats::new(3);
        let outcome = stats.normalise(&[1.0, 2.0]);
        assert!(outcome.is_err());
    }

    /// Ten values at dimension two count as five observations.
    #[test]
    fn update_batch_increments_count() {
        let mut stats = RunningStats::new(2);
        let batch = [1.0_f32; 10];
        stats.update_batch(&batch).unwrap();
        assert_eq!(stats.count(), 5);
    }

    /// `reset` clears both the counter and the means.
    #[test]
    fn reset_zeroes_stats() {
        let mut stats = RunningStats::new(2);
        stats.update(&[3.0, 4.0]).unwrap();
        stats.reset();
        assert_eq!(stats.count(), 0);
        let mean = stats.mean_f32();
        assert!(mean.iter().all(|&m| m.abs() < 1e-9));
    }

    /// With a single observation the reported std falls back to one.
    #[test]
    fn std_default_before_two_samples() {
        let mut stats = RunningStats::new(2);
        stats.update(&[1.0, 2.0]).unwrap();
        assert_eq!(stats.std_f32(), vec![1.0, 1.0]);
    }
}