use std::{
marker::PhantomData,
ops::{AddAssign, Div, Mul},
};
use num_traits::{Float, NumOps, One, Signed};
use crate::FillWithWeighted;
/// Running accumulator for a weighted mean and the spread estimates derived from it.
///
/// Type parameters:
/// * `T` – type of the observed values and of the value-weighted sums,
/// * `W` – type of the weights and of the weight sums,
/// * `O` – type in which derived statistics (mean, variance, …) are reported;
///   it is only tracked at the type level and never stored,
/// * `C` – type of the observation counter.
///
/// Field order is significant: the derived `PartialOrd`/`Ord`/`Hash` (and the
/// optional serde representation) follow declaration order.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WeightedMean<T = f64, W = f64, O = f64, C = u32> {
    /// Running Σ value·weight.
    sumwt: T,
    /// Running Σ value²·weight.
    sumwt2: T,
    /// Running Σ weight.
    sumw: W,
    /// Running Σ weight².
    sumw2: W,
    /// Number of observations added so far.
    count: C,
    /// Marker recording the output type `O` without storing a value of it.
    phantom_output_type: PhantomData<O>,
}
impl<T, W, O, C> WeightedMean<T, W, O, C>
where
    T: Copy,
    W: Copy,
    // `O` only has to be built from the value sums (`T`) and the weight sums
    // (`W`). No method here converts `count`, so there is deliberately no
    // `From<C>` bound. Such a bound would make every statistic unavailable for
    // counters like `u64`/`usize` when `O = f64`, because `f64` has no lossless
    // `From` conversion from those types.
    O: From<T> + From<W> + NumOps + Signed + Copy,
    C: Copy,
{
    /// Creates an accumulator and fills it with every `(value, weight)` pair
    /// yielded by `values`, in iteration order.
    pub fn new<I>(values: I) -> Self
    where
        I: IntoIterator<Item = (T, W)>,
        Self: FillWithWeighted<T, W> + Default,
    {
        let mut r = Self::default();
        for (value, weight) in values {
            r.fill_with_weighted(value, weight);
        }
        r
    }

    /// Returns the weighted mean; identical to [`mean`](Self::mean).
    pub fn get(&self) -> O {
        self.mean()
    }

    /// Weighted mean `Σ(w·t) / Σw`.
    ///
    /// The return type is `O`: the `NumOps` bound fixes `<O as Div>::Output = O`.
    /// When the accumulated total weight is zero (e.g. nothing has been filled
    /// yet), the result follows `O`'s division-by-zero semantics. For floats
    /// that means NaN or ±infinity.
    pub fn mean(&self) -> <O as Div>::Output {
        O::from(self.sumwt) / O::from(self.sumw)
    }

    /// Number of observations that have been filled into the accumulator.
    pub fn num_samples(&self) -> C {
        self.count
    }

    /// Weighted variance of the samples, `Σw(t − μ)² / Σw`.
    ///
    /// There is no Bessel-style correction. The value is evaluated from the
    /// running sums as `μ² + (Σw·t² − 2·μ·Σw·t) / Σw`.
    ///
    /// Because the formula subtracts large raw sums, floating-point
    /// cancellation can make the result slightly negative for (nearly)
    /// constant data.
    pub fn variance_of_samples(&self) -> O {
        let mu = self.mean();
        let mu2 = mu * mu;
        let two = O::one() + O::one();
        mu2 + (O::from(self.sumwt2) - two * O::from(self.sumwt) * mu) / O::from(self.sumw)
    }

    /// Square root of [`variance_of_samples`](Self::variance_of_samples).
    ///
    /// This is NaN if that variance is negative.
    pub fn standard_deviation_of_samples(&self) -> O
    where
        O: Float,
    {
        self.variance_of_samples().sqrt()
    }

    /// Variance of the weighted mean, `σ² · Σw² / (Σw)²`.
    ///
    /// Here `σ²` is [`variance_of_samples`](Self::variance_of_samples). This is
    /// `σ²` divided by the effective sample size `(Σw)² / Σw²`.
    pub fn variance_of_mean(&self) -> O {
        let total_weight = O::from(self.sumw);
        self.variance_of_samples() * O::from(self.sumw2) / (total_weight * total_weight)
    }

    /// Square root of [`variance_of_mean`](Self::variance_of_mean).
    ///
    /// This is the standard error of the weighted mean, and NaN if that
    /// variance is negative.
    pub fn standard_error_of_mean(&self) -> O
    where
        O: Float,
    {
        self.variance_of_mean().sqrt()
    }
}
impl<T, W, O, C> FillWithWeighted<T, W> for WeightedMean<T, W, O, C>
where
    T: Copy + AddAssign + Mul<W, Output = T> + Mul<T, Output = T>,
    W: Copy + AddAssign + Mul<W, Output = W>,
    C: AddAssign + One,
{
    /// Records one `(value, weight)` observation.
    ///
    /// Adds to the running sums Σw·t, Σw·t², Σw and Σw², then bumps the
    /// observation counter by one.
    #[inline]
    fn fill_with_weighted(&mut self, value: T, weight: W) {
        let Self {
            sumwt,
            sumwt2,
            sumw,
            sumw2,
            count,
            ..
        } = self;

        // Value-side sums. The squared term is grouped as (t·t)·w.
        *sumwt += value * weight;
        *sumwt2 += value * value * weight;

        // Weight-side sums.
        *sumw += weight;
        *sumw2 += weight * weight;

        *count += C::one();
    }
}