// oxicuda_rl/normalize/running_stats.rs
use crate::error::{RlError, RlResult};

19#[derive(Debug, Clone)]
23pub struct RunningStats {
24 dim: usize,
25 mean: Vec<f64>,
27 m2: Vec<f64>,
29 count: u64,
31}
32
33impl RunningStats {
34 #[must_use]
36 pub fn new(dim: usize) -> Self {
37 assert!(dim > 0, "dim must be > 0");
38 Self {
39 dim,
40 mean: vec![0.0_f64; dim],
41 m2: vec![0.0_f64; dim],
42 count: 0,
43 }
44 }
45
46 #[must_use]
48 #[inline]
49 pub fn dim(&self) -> usize {
50 self.dim
51 }
52
53 #[must_use]
55 #[inline]
56 pub fn count(&self) -> u64 {
57 self.count
58 }
59
60 #[must_use]
62 pub fn mean_f32(&self) -> Vec<f32> {
63 self.mean.iter().map(|&m| m as f32).collect()
64 }
65
66 #[must_use]
68 pub fn std_f32(&self) -> Vec<f32> {
69 if self.count < 2 {
70 return vec![1.0_f32; self.dim];
71 }
72 let n = (self.count - 1) as f64;
73 self.m2
74 .iter()
75 .map(|&m2| ((m2 / n).max(1e-8)).sqrt() as f32)
76 .collect()
77 }
78
79 #[must_use]
81 pub fn var_f32(&self) -> Vec<f32> {
82 if self.count < 2 {
83 return vec![1.0_f32; self.dim];
84 }
85 let n = (self.count - 1) as f64;
86 self.m2.iter().map(|&m2| (m2 / n) as f32).collect()
87 }
88
89 pub fn update(&mut self, obs: &[f32]) -> RlResult<()> {
95 if obs.len() != self.dim {
96 return Err(RlError::DimensionMismatch {
97 expected: self.dim,
98 got: obs.len(),
99 });
100 }
101 self.count += 1;
102 let n = self.count as f64;
103 for (i, &x) in obs.iter().enumerate() {
104 let x64 = x as f64;
105 let delta = x64 - self.mean[i];
106 self.mean[i] += delta / n;
107 let delta2 = x64 - self.mean[i];
108 self.m2[i] += delta * delta2;
109 }
110 Ok(())
111 }
112
113 pub fn update_batch(&mut self, batch: &[f32]) -> RlResult<()> {
119 if batch.len() % self.dim != 0 {
120 return Err(RlError::DimensionMismatch {
121 expected: self.dim,
122 got: batch.len(),
123 });
124 }
125 for chunk in batch.chunks_exact(self.dim) {
126 self.update(chunk)?;
127 }
128 Ok(())
129 }
130
131 pub fn normalise(&self, obs: &[f32]) -> RlResult<Vec<f32>> {
137 if obs.len() != self.dim {
138 return Err(RlError::DimensionMismatch {
139 expected: self.dim,
140 got: obs.len(),
141 });
142 }
143 let std = self.std_f32();
144 let mean = self.mean_f32();
145 Ok(obs
146 .iter()
147 .zip(mean.iter())
148 .zip(std.iter())
149 .map(|((&x, &m), &s)| (x - m) / (s + 1e-8))
150 .collect())
151 }
152
153 pub fn reset(&mut self) {
155 self.mean.iter_mut().for_each(|v| *v = 0.0);
156 self.m2.iter_mut().for_each(|v| *v = 0.0);
157 self.count = 0;
158 }
159}
160
#[cfg(test)]
mod tests {
    //! Unit tests for `RunningStats`: counting, convergence of mean/std,
    //! normalisation, batch updates, reset, and dimension-mismatch errors.

    use super::*;

    #[test]
    fn initial_count_zero() {
        let rs = RunningStats::new(3);
        assert_eq!(rs.count(), 0);
    }

    #[test]
    fn single_update_count_one() {
        let mut rs = RunningStats::new(2);
        rs.update(&[1.0, 2.0]).unwrap();
        assert_eq!(rs.count(), 1);
    }

    #[test]
    fn mean_converges_to_true_mean() {
        let mut rs = RunningStats::new(1);
        for _ in 0..1000 {
            rs.update(&[3.0]).unwrap();
        }
        let mean = rs.mean_f32()[0];
        assert!((mean - 3.0).abs() < 0.01, "mean={mean}");
    }

    #[test]
    fn std_converges_to_true_std() {
        let mut rs = RunningStats::new(1);
        // Alternating +1 / -1 has mean 0 and standard deviation ~1.
        for i in 0..2000 {
            let v = if i % 2 == 0 { 1.0 } else { -1.0 };
            rs.update(&[v]).unwrap();
        }
        let std = rs.std_f32()[0];
        assert!((std - 1.0).abs() < 0.05, "std={std}");
    }

    #[test]
    fn normalise_close_to_zero_mean() {
        let mut rs = RunningStats::new(1);
        for i in 0..100 {
            rs.update(&[i as f32]).unwrap();
        }
        // 50 sits next to the mean of 0..100 (49.5), so it maps near zero.
        let norm = rs.normalise(&[50.0]).unwrap();
        assert!(
            norm[0].abs() < 0.5,
            "normalised mean should be near 0, got {}",
            norm[0]
        );
    }

    #[test]
    fn normalise_dimension_error() {
        let rs = RunningStats::new(3);
        assert!(rs.normalise(&[1.0, 2.0]).is_err());
    }

    #[test]
    fn normalise_is_identity_before_any_sample() {
        // With no samples the mean is 0 and the default std is 1.
        let rs = RunningStats::new(2);
        assert_eq!(rs.normalise(&[1.5, -2.0]).unwrap(), vec![1.5, -2.0]);
    }

    #[test]
    fn update_dimension_error_leaves_count() {
        let mut rs = RunningStats::new(2);
        assert!(rs.update(&[1.0]).is_err());
        assert_eq!(rs.count(), 0);
    }

    #[test]
    fn update_batch_increments_count() {
        let mut rs = RunningStats::new(2);
        // Ten values with dim = 2 form five observations.
        let batch = vec![1.0_f32; 10];
        rs.update_batch(&batch).unwrap();
        assert_eq!(rs.count(), 5);
    }

    #[test]
    fn update_batch_rejects_partial_rows() {
        let mut rs = RunningStats::new(2);
        assert!(rs.update_batch(&[1.0, 2.0, 3.0]).is_err());
        assert_eq!(rs.count(), 0);
    }

    #[test]
    fn var_is_sample_variance() {
        let mut rs = RunningStats::new(1);
        rs.update(&[1.0]).unwrap();
        rs.update(&[3.0]).unwrap();
        // Mean 2, squared deviations 1 + 1, divisor n - 1 = 1.
        assert_eq!(rs.var_f32(), vec![2.0]);
    }

    #[test]
    fn reset_zeroes_stats() {
        let mut rs = RunningStats::new(2);
        rs.update(&[3.0, 4.0]).unwrap();
        rs.reset();
        assert_eq!(rs.count(), 0);
        let mean = rs.mean_f32();
        assert!(mean.iter().all(|&m| m.abs() < 1e-9));
    }

    #[test]
    fn std_default_before_two_samples() {
        let mut rs = RunningStats::new(2);
        rs.update(&[1.0, 2.0]).unwrap();
        let std = rs.std_f32();
        assert_eq!(std, vec![1.0, 1.0]);
    }
}