rolling_stats/
lib.rs

1#![no_std]
2
3extern crate alloc;
4
5use core::{
6    fmt::{self, Debug},
7    ops::AddAssign,
8};
9
10use num_traits::{cast::FromPrimitive, float::Float, identities::One, identities::Zero};
11
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
15/// A statistics object that continuously calculates min, max, mean, and deviation for tracking time-varying statistics.
16/// Utilizes Welford's Online algorithm. More details on the algorithm can be found at:
17/// "https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm"
18///
19///
20/// # Example
21///
22/// ```
23/// use rolling_stats::Stats;
24/// use rand_distr::{Distribution, Normal};
25/// use rand::SeedableRng;
26///
27/// type T = f64;
28///
29/// const MEAN: T = 0.0;
30/// const STD_DEV: T = 1.0;
31/// const NUM_SAMPLES: usize = 10_000;
32/// const SEED: u64 = 42;
33///
34/// let mut stats: Stats<T> = Stats::new();
35/// let mut rng = rand::rngs::StdRng::seed_from_u64(SEED); // Seed the RNG for reproducibility
36/// let normal = Normal::<T>::new(MEAN, STD_DEV).unwrap();
37///
38/// // Generate random data
39/// let random_data: Vec<T> = (0..NUM_SAMPLES).map(|_x| normal.sample(&mut rng)).collect();
40///
41/// // Update the stats one by one
42/// random_data.iter().for_each(|v| stats.update(*v));
43///
44/// // Print the stats
45/// println!("{}", stats);
46/// // Output: (avg: 0.00, std_dev: 1.00, min: -3.53, max: 4.11, count: 10000)
47///
48/// ```
49#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
50#[derive(Clone, Debug)]
51pub struct Stats<T: Float + Zero + One + AddAssign + FromPrimitive + PartialEq + Debug> {
52    /// The smallest value seen so far.
53    pub min: T,
54
55    /// The largest value seen so far.
56    pub max: T,
57
58    /// The calculated mean (average) of all the values seen so far.
59    pub mean: T,
60
61    /// The calculated standard deviation of all the values seen so far.
62    pub std_dev: T,
63
64    /// The count of the total values seen.
65    pub count: usize,
66
67    /// The square of the mean value. This is an internal value used in the calculation of the standard deviation.
68    mean2: T,
69}
70
71/// Implementing the Display trait for the Stats struct to present the statistics in a readable format.
72impl<T> fmt::Display for Stats<T>
73where
74    T: fmt::Display + Float + Zero + One + AddAssign + FromPrimitive + PartialEq + Debug,
75{
76    /// Formats the output of the statistics.
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        let precision = f.precision().unwrap_or(2);
79
80        write!(f, "(avg: {:.precision$}, std_dev: {:.precision$}, min: {:.precision$}, max: {:.precision$}, count: {})", self.mean, self.std_dev, self.min, self.max, self.count, precision=precision)
81    }
82}
83
84impl<T> Default for Stats<T>
85where
86    T: Float + Zero + One + AddAssign + FromPrimitive + PartialEq + Debug,
87{
88    fn default() -> Stats<T> {
89        Stats::new()
90    }
91}
92
93impl<T> Stats<T>
94where
95    T: Float + Zero + One + AddAssign + FromPrimitive + PartialEq + Debug,
96{
97    /// Creates a new stats object with all values set to their initial states.
98    pub fn new() -> Stats<T> {
99        Stats {
100            count: 0,
101            min: T::infinity(),
102            max: T::neg_infinity(),
103            mean: T::zero(),
104            std_dev: T::zero(),
105            mean2: T::zero(),
106        }
107    }
108
109    /// Updates the stats object with a new value. The statistics are recalculated using the new value.
110    pub fn update(&mut self, value: T) {
111        // Track min and max
112        if value > self.max {
113            self.max = value;
114        }
115        if value < self.min {
116            self.min = value;
117        }
118
119        // Increment counter
120        self.count += 1;
121        let count = T::from(self.count).unwrap();
122
123        // Calculate mean
124        let delta = value - self.mean;
125        self.mean += delta / count;
126
127        // Mean2 used internally for standard deviation calculation
128        let delta2 = value - self.mean;
129        self.mean2 += delta * delta2;
130
131        // Calculate standard deviation
132        if self.count > 1 {
133            self.std_dev = (self.mean2 / (count - T::one())).sqrt();
134        }
135    }
136
137    /// Merges another stats object into new one. This is done by combining the statistics of the two objects
138    /// in accordance with the formula provided at:
139    /// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
140    ///
141    /// This is useful for combining statistics from multiple threads or processes.
142    ///
143    /// # Example
144    ///
145    /// ```
146    /// use rolling_stats::Stats;
147    /// use rand_distr::{Distribution, Normal};
148    /// use rand::SeedableRng;
149    /// use rayon::prelude::*;
150    ///
151    /// type T = f64;
152    ///
153    /// const MEAN: T = 0.0;
154    /// const STD_DEV: T = 1.0;
155    /// const NUM_SAMPLES: usize = 500_000;
156    /// const SEED: u64 = 42;
157    /// const CHUNK_SIZE: usize = 1000;
158    ///
159    /// let mut stats: Stats<T> = Stats::new();
160    /// let mut rng = rand::rngs::StdRng::seed_from_u64(SEED); // Seed the RNG for reproducibility
161    /// let normal = Normal::<T>::new(MEAN, STD_DEV).unwrap();
162    ///
163    /// // Generate random data
164    /// let random_data: Vec<T> = (0..NUM_SAMPLES).map(|_x| normal.sample(&mut rng)).collect();
165    ///
166    /// // Update the stats in parallel. New stats objects are created for each chunk of data.
167    /// let stats: Vec<Stats<T>> = random_data
168    ///     .par_chunks(CHUNK_SIZE) // Multi-threaded parallelization via Rayon
169    ///     .map(|chunk| {
170    ///             let mut s: Stats<T> = Stats::new();
171    ///             chunk.iter().for_each(|v| s.update(*v));
172    ///             s
173    ///      })
174    ///     .collect();
175    ///
176    /// // Check if there's more than one stat object
177    /// assert!(stats.len() > 1);
178    ///
179    /// // Accumulate the stats using the reduce method. The last stats object is returned.
180    /// let merged_stats = stats.into_iter().reduce(|acc, s| acc.merge(&s)).unwrap();
181    ///
182    /// // Print the stats
183    /// println!("{}", merged_stats);
184    ///
185    /// // Output: (avg: -0.00, std_dev: 1.00, min: -4.53, max: 4.57, count: 500000)
186    ///```
187    pub fn merge(&self, other: &Self) -> Self {
188        let mut merged = Stats::<T>::new();
189
190        // If both stats objects are empty, return an empty stats object
191        if self.count + other.count == 0 {
192            return merged;
193        }
194
195        // If one of the stats objects is empty, return the other one
196        if self.count == 0 {
197            return other.clone();
198        } else if other.count == 0 {
199            return self.clone();
200        }
201
202        merged.max = if other.max > self.max {
203            other.max
204        } else {
205            self.max
206        };
207
208        merged.min = if other.min < self.min {
209            other.min
210        } else {
211            self.min
212        };
213
214        merged.count = self.count + other.count;
215
216        // Convert to T to avoid overflow
217        let merged_count = T::from(merged.count).unwrap();
218        let self_count = T::from(self.count).unwrap();
219        let other_count = T::from(other.count).unwrap();
220
221        let delta = other.mean - self.mean;
222
223        merged.mean = (self.mean * self_count + other.mean * other_count) / merged_count;
224
225        merged.mean2 =
226            self.mean2 + other.mean2 + delta * delta * self_count * other_count / merged_count;
227
228        merged.std_dev = (merged.mean2 / (merged_count - T::one())).sqrt();
229
230        merged
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::vec;
    use alloc::vec::Vec;

    use float_cmp::{ApproxEq, ApproxEqUlps};
    use rand::SeedableRng;
    use rand_distr::{Distribution, Normal};
    use rayon::prelude::*;

    type T = f64;

    /// Mean of the normal distribution the randomized tests sample from.
    const MEAN: T = 2.0;
    /// Standard deviation of the normal distribution the randomized tests sample from.
    const STD_DEV: T = 3.0;
    /// Fixed RNG seed so every run sees the same data.
    const SEED: u64 = 42;
    /// Number of values per partial `Stats` in the merge tests.
    const CHUNK_SIZE: usize = 1000;

    /// Asserts that two floats agree within an absolute epsilon of 1e-13 or 2 ULPs.
    /// On failure, both values are reported.
    fn assert_close(actual: T, expected: T) {
        assert!(
            actual.approx_eq(expected, (1.0e-13, 2)),
            "{} is not approximately equal to {}",
            actual,
            expected
        );
    }

    /// Draws `n` reproducible samples from Normal(MEAN, STD_DEV) using a generator seeded with SEED.
    fn normal_samples(n: usize) -> Vec<T> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(SEED);
        let normal = Normal::<T>::new(MEAN, STD_DEV).unwrap();
        (0..n).map(|_| normal.sample(&mut rng)).collect()
    }

    /// Builds a `Stats` by feeding every value of `vals` through `update`, in order.
    fn stats_of(vals: &[T]) -> Stats<T> {
        let mut s = Stats::new();
        for &v in vals {
            s.update(v);
        }
        s
    }

    /// Calculates the mean of `vals` with the two-pass sum/count method.
    fn calc_mean(vals: &[T]) -> T {
        vals.iter().sum::<T>() / T::from_usize(vals.len()).unwrap()
    }

    /// Calculates the sample (n - 1) standard deviation of `vals` with the two-pass method.
    fn calc_std_dev(vals: &[T]) -> T {
        let mean = calc_mean(vals);
        let sum_sq: T = vals.iter().map(|x| (*x - mean).powi(2)).sum();
        (sum_sq / T::from_usize(vals.len() - 1).unwrap()).sqrt()
    }

    /// Returns the maximum value in `vals`.
    fn get_max(vals: &[T]) -> T {
        vals.iter()
            .copied()
            .fold(T::min_value(), |acc, v| if v > acc { v } else { acc })
    }

    /// Returns the minimum value in `vals`.
    fn get_min(vals: &[T]) -> T {
        vals.iter()
            .copied()
            .fold(T::max_value(), |acc, v| if v < acc { v } else { acc })
    }

    /// Checks `merged` against the two-pass reference values for `data`, and against the
    /// `sequential` stats obtained by updating with every value of `data` in order.
    fn assert_matches_reference(merged: &Stats<T>, sequential: &Stats<T>, data: &[T]) {
        assert_close(merged.mean, calc_mean(data));
        assert_close(merged.std_dev, calc_std_dev(data));
        assert_eq!(merged.max, get_max(data));
        assert_eq!(merged.min, get_min(data));
        assert_eq!(merged.count, data.len());

        assert_close(merged.mean, sequential.mean);
        assert_close(merged.std_dev, sequential.std_dev);
        assert_eq!(merged.max, sequential.max);
        assert_eq!(merged.min, sequential.min);
        assert_eq!(merged.count, sequential.count);
    }

    #[test]
    fn known_small_sequence() {
        let vals: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut s: Stats<f32> = Stats::new();
        for &v in &vals {
            s.update(v);
        }

        assert_eq!(s.count, vals.len());
        assert_eq!(s.min, 1.0);
        assert_eq!(s.max, 5.0);
        assert!(s.mean.approx_eq_ulps(&3.0, 2));
        assert!(s.std_dev.approx_eq_ulps(&1.5811388, 2));
    }

    #[test]
    fn single_value_has_zero_std_dev() {
        let s = stats_of(&[4.0]);

        assert_eq!(s.count, 1);
        assert_eq!(s.min, 4.0);
        assert_eq!(s.max, 4.0);
        assert_eq!(s.mean, 4.0);
        assert_eq!(s.std_dev, 0.0);
    }

    #[test]
    fn merge_two_single_values() {
        let merged = stats_of(&[1.0]).merge(&stats_of(&[3.0]));

        assert_eq!(merged.count, 2);
        assert_eq!(merged.min, 1.0);
        assert_eq!(merged.max, 3.0);
        assert_eq!(merged.mean, 2.0);
        // Sample std dev of {1, 3} is sqrt(((1-2)^2 + (3-2)^2) / 1) = sqrt(2).
        let two: T = 2.0;
        assert_close(merged.std_dev, two.sqrt());
    }

    #[test]
    fn stats_for_large_random_data() {
        let data = normal_samples(10_000);
        let s = stats_of(&data);

        assert_close(s.mean, calc_mean(&data));
        assert_close(s.std_dev, calc_std_dev(&data));
        assert_eq!(s.count, data.len());
        assert_eq!(s.max, get_max(&data));
        assert_eq!(s.min, get_min(&data));
    }

    #[test]
    fn stats_merge() {
        const NUM_SAMPLES: usize = 10_000;

        let data = normal_samples(NUM_SAMPLES);
        let sequential = stats_of(&data);

        let chunk_stats: Vec<Stats<T>> = data.chunks(CHUNK_SIZE).map(stats_of).collect();
        assert_eq!(chunk_stats.len(), NUM_SAMPLES / CHUNK_SIZE);

        let merged = chunk_stats
            .into_iter()
            .reduce(|acc, s| acc.merge(&s))
            .unwrap();
        assert_matches_reference(&merged, &sequential, &data);

        // Merging with an empty object, on either side, returns the non-empty one unchanged.
        let empty: Stats<T> = Stats::new();
        for merged in [sequential.merge(&empty), empty.merge(&sequential)] {
            assert_eq!(merged.count, sequential.count);
            assert_eq!(merged.mean, sequential.mean);
            assert_eq!(merged.std_dev, sequential.std_dev);
            assert_eq!(merged.min, sequential.min);
            assert_eq!(merged.max, sequential.max);
        }

        // Merging two empty objects yields an empty object.
        let both_empty = empty.merge(&Stats::new());
        assert_eq!(both_empty.count, 0);
    }

    #[test]
    fn stats_merge_parallel() {
        const NUM_SAMPLES: usize = 500_000;

        let data = normal_samples(NUM_SAMPLES);
        let sequential = stats_of(&data);

        let chunk_stats: Vec<Stats<T>> = data
            .par_chunks(CHUNK_SIZE) // <--- Parallelization by Rayon
            .map(stats_of)
            .collect();
        assert_eq!(chunk_stats.len(), NUM_SAMPLES / CHUNK_SIZE);

        let merged = chunk_stats
            .into_iter()
            .reduce(|acc, s| acc.merge(&s))
            .unwrap();
        assert_matches_reference(&merged, &sequential, &data);
    }
}