online_statistics/
mean.rs1use num::{Float, FromPrimitive};
2use std::ops::{AddAssign, SubAssign};
3
4use crate::count::Count;
5use crate::stats::{Revertable, RollableUnivariate, Univariate};
6use serde::{Deserialize, Serialize};
7
8#[derive(Clone, Copy, Default, Debug, Serialize, Deserialize)]
32pub struct Mean<F: Float + FromPrimitive + AddAssign + SubAssign> {
33 pub mean: F,
34 pub n: Count<F>,
35}
36impl<F: Float + FromPrimitive + AddAssign + SubAssign> Mean<F> {
37 pub fn new() -> Self {
38 Self {
39 mean: F::from_f64(0.0).unwrap(),
40 n: Count::new(),
41 }
42 }
43}
44
45impl<F: Float + FromPrimitive + AddAssign + SubAssign> Univariate<F> for Mean<F> {
46 fn update(&mut self, x: F) {
47 self.n.update(x);
48 self.mean += (F::from_f64(1.).unwrap() / self.n.get()) * (x - self.mean);
49 }
50 fn get(&self) -> F {
51 self.mean
52 }
53}
54
55impl<F: Float + FromPrimitive + AddAssign + SubAssign> Revertable<F> for Mean<F> {
56 fn revert(&mut self, x: F) -> Result<(), &'static str> {
57 match self.n.revert(x) {
58 Ok(it) => it,
59 Err(err) => return Err(err),
60 };
61
62 let count = self.n.get();
63 if count == F::from_f64(0.).unwrap() {
64 self.mean = F::from_f64(0.0).unwrap();
65 } else {
66 self.mean -= (F::from_f64(1.0).unwrap() / count) * (x - self.mean);
67 }
68 Ok(())
69 }
70}
71
72impl<F: Float + FromPrimitive + AddAssign + SubAssign> RollableUnivariate<F> for Mean<F> {}