Skip to main content

forest/utils/stats/
mod.rs

1// Copyright 2019-2026 ChainSafe Systems
2// SPDX-License-Identifier: Apache-2.0, MIT
3
4use anyhow::Context as _;
5use std::ops::AddAssign;
6
/// Running accumulator of a sample count and sum, from which [`Stats::mean`] is derived.
///
/// Trait bounds live on the `impl` block rather than on the struct, so the type itself
/// places no requirements on `T`; derived trait impls are conditional on `T` as usual.
#[derive(Debug, Clone, Default)]
pub struct Stats<T> {
    /// Number of samples recorded so far.
    n: usize,
    /// Sum of all samples recorded so far.
    sum: T,
}
12
13impl<T> Stats<T>
14where
15    T: num::Num + num::NumCast + Copy + PartialOrd + AddAssign + Default,
16{
17    pub fn new() -> Self {
18        Default::default()
19    }
20
21    /// Update the moments with the given value.
22    pub fn update(&mut self, x: T) {
23        self.sum += x;
24        self.n += 1;
25    }
26
27    pub fn mean(&self) -> anyhow::Result<T> {
28        if self.n == 0 {
29            anyhow::bail!("not enough data");
30        }
31        let sum_f64: f64 = num::NumCast::from(self.sum).context("error casting T to f64")?;
32        let n_f64: f64 = num::NumCast::from(self.n).context("error casting T to f64")?;
33        let result: T = num::NumCast::from(sum_f64 / n_f64).context("error casting f64 to T")?;
34        Ok(result)
35    }
36}
37
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stats_mean() {
        let mut stats = Stats::new();
        stats.mean().unwrap_err();
        stats.update(10);
        assert_eq!(stats.mean().unwrap(), 10);
        stats.update(5);
        assert_eq!(stats.mean().unwrap(), 7);
        stats.update(3);
        assert_eq!(stats.mean().unwrap(), 6);
    }

    #[test]
    fn test_stats_mean_negative_truncates_toward_zero() {
        let mut stats = Stats::<i64>::new();
        stats.update(-10);
        stats.update(-5);
        // -15 / 2 = -7.5; integer means truncate toward zero.
        assert_eq!(stats.mean().unwrap(), -7);
    }

    #[test]
    fn test_stats_mean_float() {
        let mut stats = Stats::<f64>::new();
        stats.mean().unwrap_err();
        stats.update(1.0);
        stats.update(2.0);
        assert_eq!(stats.mean().unwrap(), 1.5);
    }
}