forest/utils/stats/
mod.rs1use anyhow::Context as _;
5use std::ops::AddAssign;
6
/// Running accumulator used to compute summary statistics (currently the mean)
/// over a stream of samples of type `T`.
///
/// The struct itself carries no trait bounds; the numeric requirements live on the
/// `impl` blocks that actually need them.
#[derive(Debug, Clone, Default)]
pub struct Stats<T> {
    /// Number of samples recorded so far.
    n: usize,
    /// Running total of all recorded samples.
    sum: T,
}
12
13impl<T> Stats<T>
14where
15 T: num::Num + num::NumCast + Copy + PartialOrd + AddAssign + Default,
16{
17 pub fn new() -> Self {
18 Default::default()
19 }
20
21 pub fn update(&mut self, x: T) {
23 self.sum += x;
24 self.n += 1;
25 }
26
27 pub fn mean(&self) -> anyhow::Result<T> {
28 if self.n == 0 {
29 anyhow::bail!("not enough data");
30 }
31 let sum_f64: f64 = num::NumCast::from(self.sum).context("error casting T to f64")?;
32 let n_f64: f64 = num::NumCast::from(self.n).context("error casting T to f64")?;
33 let result: T = num::NumCast::from(sum_f64 / n_f64).context("error casting f64 to T")?;
34 Ok(result)
35 }
36}
37
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer means use truncating division; an empty accumulator is an error.
    #[test]
    fn test_stats_mean() {
        let mut stats = Stats::new();
        let err = stats.mean().unwrap_err();
        assert_eq!(err.to_string(), "not enough data");
        stats.update(10);
        assert_eq!(stats.mean().unwrap(), 10);
        stats.update(5);
        assert_eq!(stats.mean().unwrap(), 7);
        stats.update(3);
        assert_eq!(stats.mean().unwrap(), 6);
    }

    /// Negative integer means truncate toward zero (-15 / 2 == -7), not toward -inf.
    #[test]
    fn test_stats_mean_negative_truncates_toward_zero() {
        let mut stats = Stats::new();
        stats.update(-10);
        stats.update(-5);
        assert_eq!(stats.mean().unwrap(), -7);
        stats.update(-3);
        assert_eq!(stats.mean().unwrap(), -6);
    }

    /// Float samples keep their fractional mean.
    #[test]
    fn test_stats_mean_f64() {
        let mut stats = Stats::new();
        stats.update(1.0_f64);
        stats.update(2.0);
        assert_eq!(stats.mean().unwrap(), 1.5);
        stats.update(4.5);
        assert_eq!(stats.mean().unwrap(), 2.5);
    }
}