Skip to main content

oxicuda_rl/normalize/
running_stats.rs

1//! # Online Welford Mean/Variance Tracker
2//!
3//! Welford's online algorithm computes the running mean and variance in a single
4//! pass with O(1) memory:
5//!
6//! ```text
7//! n   += 1
8//! δ    = x - mean
9//! mean += δ / n
10//! δ₂   = x - mean  (new mean)
11//! M2  += δ * δ₂
12//! var  = M2 / (n - 1)  (sample variance)
13//! ```
14//!
15//! Supports scalar and vector (per-dimension) tracking.
16
17use crate::error::{RlError, RlResult};
18
19// ─── RunningStats ─────────────────────────────────────────────────────────────
20
21/// Per-dimension running mean and variance tracker using Welford's algorithm.
22#[derive(Debug, Clone)]
23pub struct RunningStats {
24    dim: usize,
25    /// Running mean per dimension.
26    mean: Vec<f64>,
27    /// Running M2 (sum of squared deviations) per dimension.
28    m2: Vec<f64>,
29    /// Sample count.
30    count: u64,
31}
32
33impl RunningStats {
34    /// Create a new tracker for vectors of dimension `dim`.
35    #[must_use]
36    pub fn new(dim: usize) -> Self {
37        assert!(dim > 0, "dim must be > 0");
38        Self {
39            dim,
40            mean: vec![0.0_f64; dim],
41            m2: vec![0.0_f64; dim],
42            count: 0,
43        }
44    }
45
46    /// Dimension (number of tracked features).
47    #[must_use]
48    #[inline]
49    pub fn dim(&self) -> usize {
50        self.dim
51    }
52
53    /// Number of samples seen so far.
54    #[must_use]
55    #[inline]
56    pub fn count(&self) -> u64 {
57        self.count
58    }
59
60    /// Per-dimension mean.
61    #[must_use]
62    pub fn mean_f32(&self) -> Vec<f32> {
63        self.mean.iter().map(|&m| m as f32).collect()
64    }
65
66    /// Per-dimension sample standard deviation (returns `[1.0; dim]` until `count ≥ 2`).
67    #[must_use]
68    pub fn std_f32(&self) -> Vec<f32> {
69        if self.count < 2 {
70            return vec![1.0_f32; self.dim];
71        }
72        let n = (self.count - 1) as f64;
73        self.m2
74            .iter()
75            .map(|&m2| ((m2 / n).max(1e-8)).sqrt() as f32)
76            .collect()
77    }
78
79    /// Per-dimension variance.
80    #[must_use]
81    pub fn var_f32(&self) -> Vec<f32> {
82        if self.count < 2 {
83            return vec![1.0_f32; self.dim];
84        }
85        let n = (self.count - 1) as f64;
86        self.m2.iter().map(|&m2| (m2 / n) as f32).collect()
87    }
88
89    /// Update statistics with a new observation vector.
90    ///
91    /// # Errors
92    ///
93    /// * [`RlError::DimensionMismatch`] if `obs.len() != dim`.
94    pub fn update(&mut self, obs: &[f32]) -> RlResult<()> {
95        if obs.len() != self.dim {
96            return Err(RlError::DimensionMismatch {
97                expected: self.dim,
98                got: obs.len(),
99            });
100        }
101        self.count += 1;
102        let n = self.count as f64;
103        for (i, &x) in obs.iter().enumerate() {
104            let x64 = x as f64;
105            let delta = x64 - self.mean[i];
106            self.mean[i] += delta / n;
107            let delta2 = x64 - self.mean[i];
108            self.m2[i] += delta * delta2;
109        }
110        Ok(())
111    }
112
113    /// Update statistics with a batch of observations `[B × dim]`.
114    ///
115    /// # Errors
116    ///
117    /// * [`RlError::DimensionMismatch`] if `batch.len() % dim != 0`.
118    pub fn update_batch(&mut self, batch: &[f32]) -> RlResult<()> {
119        if batch.len() % self.dim != 0 {
120            return Err(RlError::DimensionMismatch {
121                expected: self.dim,
122                got: batch.len(),
123            });
124        }
125        for chunk in batch.chunks_exact(self.dim) {
126            self.update(chunk)?;
127        }
128        Ok(())
129    }
130
131    /// Normalise a single observation: `(obs - mean) / (std + eps)`.
132    ///
133    /// # Errors
134    ///
135    /// * [`RlError::DimensionMismatch`] if `obs.len() != dim`.
136    pub fn normalise(&self, obs: &[f32]) -> RlResult<Vec<f32>> {
137        if obs.len() != self.dim {
138            return Err(RlError::DimensionMismatch {
139                expected: self.dim,
140                got: obs.len(),
141            });
142        }
143        let std = self.std_f32();
144        let mean = self.mean_f32();
145        Ok(obs
146            .iter()
147            .zip(mean.iter())
148            .zip(std.iter())
149            .map(|((&x, &m), &s)| (x - m) / (s + 1e-8))
150            .collect())
151    }
152
153    /// Reset all statistics to zero.
154    pub fn reset(&mut self) {
155        self.mean.iter_mut().for_each(|v| *v = 0.0);
156        self.m2.iter_mut().for_each(|v| *v = 0.0);
157        self.count = 0;
158    }
159}
160
161// ─── Tests ───────────────────────────────────────────────────────────────────
162
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed tracker has seen no samples.
    #[test]
    fn initial_count_zero() {
        assert_eq!(RunningStats::new(3).count(), 0);
    }

    /// One `update` call bumps the sample count to exactly one.
    #[test]
    fn single_update_count_one() {
        let mut stats = RunningStats::new(2);
        stats.update(&[1.0, 2.0]).unwrap();
        assert_eq!(stats.count(), 1);
    }

    /// Feeding a constant stream keeps the mean at that constant.
    #[test]
    fn mean_converges_to_true_mean() {
        let mut stats = RunningStats::new(1);
        for sample in std::iter::repeat([3.0_f32]).take(1000) {
            stats.update(&sample).unwrap();
        }
        let mean = stats.mean_f32()[0];
        assert!((mean - 3.0).abs() < 0.01, "mean={mean}");
    }

    /// A stream alternating between +1 and -1 has mean 0 and std close to 1.
    #[test]
    fn std_converges_to_true_std() {
        let mut stats = RunningStats::new(1);
        for &sign in [1.0_f32, -1.0].iter().cycle().take(2000) {
            stats.update(&[sign]).unwrap();
        }
        let std = stats.std_f32()[0];
        assert!((std - 1.0).abs() < 0.05, "std={std}");
    }

    /// An observation near the running mean normalises to a value near zero.
    #[test]
    fn normalise_close_to_zero_mean() {
        let mut stats = RunningStats::new(1);
        (0..100)
            .map(|i| i as f32)
            .for_each(|x| stats.update(&[x]).unwrap());
        // 50.0 sits next to the mean of 0..100.
        let norm = stats.normalise(&[50.0]).unwrap();
        assert!(
            norm[0].abs() < 0.5,
            "normalised mean should be near 0, got {}",
            norm[0]
        );
    }

    /// `normalise` rejects an observation of the wrong width.
    #[test]
    fn normalise_dimension_error() {
        let stats = RunningStats::new(3);
        let outcome = stats.normalise(&[1.0, 2.0]);
        assert!(outcome.is_err());
    }

    /// A flat batch of ten values with dim 2 counts as five observations.
    #[test]
    fn update_batch_increments_count() {
        let mut stats = RunningStats::new(2);
        let flat = [1.0_f32; 10];
        stats.update_batch(&flat).unwrap();
        assert_eq!(stats.count(), 5);
    }

    /// `reset` clears both the sample count and the running mean.
    #[test]
    fn reset_zeroes_stats() {
        let mut stats = RunningStats::new(2);
        stats.update(&[3.0, 4.0]).unwrap();
        stats.reset();
        assert_eq!(stats.count(), 0);
        let all_zero = stats.mean_f32().into_iter().all(|m| m.abs() < 1e-9);
        assert!(all_zero);
    }

    /// With fewer than two samples the reported std is the default of ones.
    #[test]
    fn std_default_before_two_samples() {
        let mut stats = RunningStats::new(2);
        stats.update(&[1.0, 2.0]).unwrap();
        let expected = vec![1.0_f32, 1.0];
        assert_eq!(stats.std_f32(), expected);
    }
}