1#![no_std]
2
3extern crate alloc;
4
5use core::{
6 fmt::{self, Debug},
7 ops::AddAssign,
8};
9
10use num_traits::{cast::FromPrimitive, float::Float, identities::One, identities::Zero};
11
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
/// Running summary statistics over a stream of numeric samples.
///
/// Samples are fed in one at a time with [`Stats::update`], and two accumulators
/// can be combined with [`Stats::merge`]. Only the scalar fields below are
/// stored; the samples themselves are not kept.
///
/// The numeric requirements on `T` are declared on the `impl` blocks that need
/// them, so the bare data structure places no constraints on `T`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct Stats<T> {
    /// Smallest sample seen so far (positive infinity for a freshly created accumulator).
    pub min: T,

    /// Largest sample seen so far (negative infinity for a freshly created accumulator).
    pub max: T,

    /// Arithmetic mean of the samples seen so far (zero for a freshly created accumulator).
    pub mean: T,

    /// Sample standard deviation, computed with a `count - 1` denominator.
    ///
    /// Stays at zero until at least two samples have been recorded.
    pub std_dev: T,

    /// Number of samples recorded.
    pub count: usize,

    /// Running sum of squared deviations from the mean; `std_dev` is derived from it.
    mean2: T,
}
70
71impl<T> fmt::Display for Stats<T>
73where
74 T: fmt::Display + Float + Zero + One + AddAssign + FromPrimitive + PartialEq + Debug,
75{
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 let precision = f.precision().unwrap_or(2);
79
80 write!(f, "(avg: {:.precision$}, std_dev: {:.precision$}, min: {:.precision$}, max: {:.precision$}, count: {})", self.mean, self.std_dev, self.min, self.max, self.count, precision=precision)
81 }
82}
83
84impl<T> Default for Stats<T>
85where
86 T: Float + Zero + One + AddAssign + FromPrimitive + PartialEq + Debug,
87{
88 fn default() -> Stats<T> {
89 Stats::new()
90 }
91}
92
93impl<T> Stats<T>
94where
95 T: Float + Zero + One + AddAssign + FromPrimitive + PartialEq + Debug,
96{
97 pub fn new() -> Stats<T> {
99 Stats {
100 count: 0,
101 min: T::infinity(),
102 max: T::neg_infinity(),
103 mean: T::zero(),
104 std_dev: T::zero(),
105 mean2: T::zero(),
106 }
107 }
108
109 pub fn update(&mut self, value: T) {
111 if value > self.max {
113 self.max = value;
114 }
115 if value < self.min {
116 self.min = value;
117 }
118
119 self.count += 1;
121 let count = T::from(self.count).unwrap();
122
123 let delta = value - self.mean;
125 self.mean += delta / count;
126
127 let delta2 = value - self.mean;
129 self.mean2 += delta * delta2;
130
131 if self.count > 1 {
133 self.std_dev = (self.mean2 / (count - T::one())).sqrt();
134 }
135 }
136
137 pub fn merge(&self, other: &Self) -> Self {
188 let mut merged = Stats::<T>::new();
189
190 if self.count + other.count == 0 {
192 return merged;
193 }
194
195 if self.count == 0 {
197 return other.clone();
198 } else if other.count == 0 {
199 return self.clone();
200 }
201
202 merged.max = if other.max > self.max {
203 other.max
204 } else {
205 self.max
206 };
207
208 merged.min = if other.min < self.min {
209 other.min
210 } else {
211 self.min
212 };
213
214 merged.count = self.count + other.count;
215
216 let merged_count = T::from(merged.count).unwrap();
218 let self_count = T::from(self.count).unwrap();
219 let other_count = T::from(other.count).unwrap();
220
221 let delta = other.mean - self.mean;
222
223 merged.mean = (self.mean * self_count + other.mean * other_count) / merged_count;
224
225 merged.mean2 =
226 self.mean2 + other.mean2 + delta * delta * self_count * other_count / merged_count;
227
228 merged.std_dev = (merged.mean2 / (merged_count - T::one())).sqrt();
229
230 merged
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::vec;
    use alloc::vec::Vec;

    use float_cmp::{ApproxEq, ApproxEqUlps};
    use rand::SeedableRng;
    use rand_distr::{Distribution, Normal};
    use rayon::prelude::*;

    /// Sample type used by the randomized tests.
    type T = f64;

    /// Number of samples per chunk in the merge tests.
    const CHUNK_SIZE: usize = 1000;

    /// Smoke test on a tiny hand-checked data set.
    #[test]
    fn it_works() {
        let vals: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let mut s: Stats<f32> = Stats::new();
        for v in &vals {
            s.update(*v);
        }

        assert_eq!(s.count, vals.len());

        assert_eq!(s.min, 1.0);
        assert_eq!(s.max, 5.0);

        assert!(s.mean.approx_eq_ulps(&3.0, 2));
        assert!(s.std_dev.approx_eq_ulps(&1.5811388, 2));
    }

    /// Reference mean: plain left-to-right sum divided by the number of values.
    fn calc_mean(vals: &[T]) -> T {
        let sum = vals.iter().fold(T::zero(), |acc, x| acc + *x);

        sum / T::from_usize(vals.len()).unwrap()
    }

    /// Reference sample standard deviation (two-pass, `len - 1` denominator).
    ///
    /// `vals` must contain at least one value, otherwise `len - 1` underflows.
    fn calc_std_dev(vals: &[T]) -> T {
        let mean = calc_mean(vals);
        let sum_sq = vals
            .iter()
            .fold(T::zero(), |acc, x| acc + (*x - mean).powi(2));

        (sum_sq / T::from_usize(vals.len() - 1).unwrap()).sqrt()
    }

    /// Reference maximum, found by a linear scan.
    fn get_max(vals: &[T]) -> T {
        vals.iter()
            .copied()
            .fold(T::min_value(), |best, v| if v > best { v } else { best })
    }

    /// Reference minimum, found by a linear scan.
    fn get_min(vals: &[T]) -> T {
        vals.iter()
            .copied()
            .fold(T::max_value(), |best, v| if v < best { v } else { best })
    }

    /// Draws `num_samples` values from a normal distribution (mean 2, standard
    /// deviation 3) using a fixed seed, so every test sees the same sequence.
    fn normal_samples(num_samples: usize) -> Vec<T> {
        const MEAN: T = 2.0;
        const STD_DEV: T = 3.0;
        const SEED: u64 = 42;

        let mut rng = rand::rngs::StdRng::seed_from_u64(SEED);
        let normal = Normal::<T>::new(MEAN, STD_DEV).unwrap();

        (0..num_samples).map(|_| normal.sample(&mut rng)).collect()
    }

    /// Accumulates `vals` in order into a fresh `Stats`.
    fn stats_of(vals: &[T]) -> Stats<T> {
        let mut s: Stats<T> = Stats::new();
        for v in vals {
            s.update(*v);
        }
        s
    }

    /// Checks merged statistics against the reference values computed directly
    /// from `vals` and against the statistics accumulated sequentially over the
    /// same data.
    fn assert_merged_matches(merged: &Stats<T>, sequential: &Stats<T>, vals: &[T]) {
        assert!(merged.mean.approx_eq(calc_mean(vals), (1.0e-13, 2)));
        assert!(merged.std_dev.approx_eq(calc_std_dev(vals), (1.0e-13, 2)));
        assert_eq!(merged.max, get_max(vals));
        assert_eq!(merged.min, get_min(vals));
        assert_eq!(merged.count, vals.len());

        assert!(merged.mean.approx_eq(sequential.mean, (1.0e-13, 2)));
        assert!(merged.std_dev.approx_eq(sequential.std_dev, (1.0e-13, 2)));
        assert_eq!(merged.max, sequential.max);
        assert_eq!(merged.min, sequential.min);
        assert_eq!(merged.count, sequential.count);
    }

    /// Incremental statistics must agree with the reference computations.
    #[test]
    fn stats_for_large_random_data() {
        const NUM_SAMPLES: usize = 10_000;

        let random_data = normal_samples(NUM_SAMPLES);
        let s = stats_of(&random_data);

        assert!(s.mean.approx_eq(calc_mean(&random_data), (1.0e-13, 2)));
        assert!(s.std_dev.approx_eq(calc_std_dev(&random_data), (1.0e-13, 2)));

        assert_eq!(s.count, random_data.len());

        assert_eq!(s.max, get_max(&random_data));
        assert_eq!(s.min, get_min(&random_data));
    }

    /// Merging per-chunk statistics must agree with processing all data at once.
    #[test]
    fn stats_merge() {
        const NUM_SAMPLES: usize = 10_000;

        let random_data = normal_samples(NUM_SAMPLES);
        let s = stats_of(&random_data);

        let stats: Vec<Stats<T>> = random_data.chunks(CHUNK_SIZE).map(stats_of).collect();

        assert_eq!(stats.len(), NUM_SAMPLES / CHUNK_SIZE);

        let merged_stats = stats
            .into_iter()
            .reduce(|acc, part| acc.merge(&part))
            .unwrap();

        assert_merged_matches(&merged_stats, &s, &random_data);

        // Merging with an empty accumulator on either side keeps every sample.
        let empty_stats: Stats<T> = Stats::new();
        assert_eq!(s.merge(&empty_stats).count, s.count);
        assert_eq!(empty_stats.merge(&s).count, s.count);

        // Two empty accumulators merge into an empty accumulator.
        let other_empty_stats: Stats<T> = Stats::new();
        assert_eq!(empty_stats.merge(&other_empty_stats).count, 0);
    }

    /// Same as `stats_merge`, but the per-chunk statistics are built in parallel.
    #[test]
    fn stats_merge_parallel() {
        const NUM_SAMPLES: usize = 500_000;

        let random_data = normal_samples(NUM_SAMPLES);
        let s = stats_of(&random_data);

        let stats: Vec<Stats<T>> = random_data
            .par_chunks(CHUNK_SIZE)
            .map(stats_of)
            .collect();

        assert!(stats.len() >= NUM_SAMPLES / CHUNK_SIZE);

        let merged_stats = stats
            .into_iter()
            .reduce(|acc, part| acc.merge(&part))
            .unwrap();

        assert_merged_matches(&merged_stats, &s, &random_data);
    }
}