#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::data::DataOrSuffStat;
use crate::dist::Gaussian;
use crate::traits::SuffStat;
#[derive(Debug, Clone, PartialEq, Copy)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct GaussianSuffStat {
n: usize,
mean: f64,
sx: f64,
}
impl GaussianSuffStat {
#[inline]
#[must_use]
pub fn new() -> Self {
GaussianSuffStat {
n: 0,
mean: 0.0,
sx: 0.0,
}
}
#[inline]
#[must_use]
pub fn from_parts_unchecked(n: usize, mean: f64, sx: f64) -> Self {
GaussianSuffStat { n, mean, sx }
}
#[inline]
#[must_use]
pub fn n(&self) -> usize {
self.n
}
#[inline]
#[must_use]
pub fn mean(&self) -> f64 {
self.mean
}
#[inline]
#[must_use]
pub fn sum_x(&self) -> f64 {
self.mean * self.n as f64
}
#[inline]
#[must_use]
pub fn sum_x_sq(&self) -> f64 {
let nf = self.n as f64;
(self.mean() * self.mean()).mul_add(nf, self.sx)
}
#[inline]
#[must_use]
pub fn sum_sq_diff(&self) -> f64 {
self.sx
}
}
impl Default for GaussianSuffStat {
fn default() -> Self {
GaussianSuffStat::new()
}
}
macro_rules! impl_gaussian_suffstat {
($kind:ty) => {
impl<'a> From<&'a GaussianSuffStat>
for DataOrSuffStat<'a, $kind, Gaussian>
{
fn from(stat: &'a GaussianSuffStat) -> Self {
DataOrSuffStat::SuffStat(stat)
}
}
impl<'a> From<&'a Vec<$kind>> for DataOrSuffStat<'a, $kind, Gaussian> {
fn from(xs: &'a Vec<$kind>) -> Self {
DataOrSuffStat::Data(xs)
}
}
impl<'a> From<&'a [$kind]> for DataOrSuffStat<'a, $kind, Gaussian> {
fn from(xs: &'a [$kind]) -> Self {
DataOrSuffStat::Data(xs)
}
}
impl From<&Vec<$kind>> for GaussianSuffStat {
fn from(xs: &Vec<$kind>) -> Self {
let mut stat = GaussianSuffStat::new();
stat.observe_many(xs);
stat
}
}
impl From<&[$kind]> for GaussianSuffStat {
fn from(xs: &[$kind]) -> Self {
let mut stat = GaussianSuffStat::new();
stat.observe_many(xs);
stat
}
}
impl SuffStat<$kind> for GaussianSuffStat {
fn n(&self) -> usize {
self.n
}
fn observe(&mut self, x: &$kind) {
let xf = f64::from(*x);
if self.n == 0 {
*self = GaussianSuffStat {
n: 1,
mean: xf,
sx: 0.0,
};
} else {
let n = self.n;
let mean = self.mean;
let sx = self.sx;
let n1 = n + 1;
let mean_xn =
(xf - mean).mul_add((n1 as f64).recip(), mean);
self.n = n + 1;
self.mean = mean_xn;
self.sx = (xf - mean).mul_add(xf - mean_xn, sx);
}
}
fn forget(&mut self, x: &$kind) {
let n = self.n;
if n > 1 {
let xf = f64::from(*x);
let mean = self.mean;
let n_float = n as f64;
let nm1 = n_float - 1.0;
let nm1_recip = nm1.recip();
let old_mean =
(n_float * nm1_recip).mul_add(mean, -xf * nm1_recip);
let sx = (xf - old_mean).mul_add(-(xf - mean), self.sx);
*self = GaussianSuffStat {
n: n - 1,
mean: old_mean,
sx,
};
} else {
*self = GaussianSuffStat {
n: 0,
mean: 0.0,
sx: 0.0,
};
}
}
fn merge(&mut self, other: Self) {
if other.n == 0 {
return;
}
let n1 = self.n as f64;
let n2 = other.n as f64;
let m1 = self.mean;
let m2 = other.mean;
let sum = n1 + n2;
let mean = n1.mul_add(m1, n2 * m2) / sum;
let d1 = m1 - mean;
let d2 = m2 - mean;
let sx = (n2 * d2)
.mul_add(d2, (n1 * d1).mul_add(d1, self.sx + other.sx));
self.mean = mean;
self.sx = sx;
self.n += other.n;
}
}
};
}
#[cfg(feature = "experimental")]
impl_gaussian_suffstat!(f16);
impl_gaussian_suffstat!(f32);
impl_gaussian_suffstat!(f64);
#[cfg(test)]
mod tests {
use crate::traits::Sampleable;
use super::*;
#[test]
fn from_parts_unchecked() {
let stat = GaussianSuffStat::from_parts_unchecked(10, 0.5, 1.2);
assert_eq!(stat.n(), 10);
assert_eq!(stat.mean(), 0.5);
assert_eq!(stat.sx, 1.2);
}
#[test]
fn suffstat_increments_correctly() {
let xs: Vec<f64> = vec![0.0, 1.2, 2.3, 4.6];
let mut suffstat = GaussianSuffStat::new();
for x in xs {
suffstat.observe(&x);
}
assert_eq!(suffstat.n(), 4);
assert::close(suffstat.mean(), 2.025, 1e-14);
assert::close(suffstat.sum_x(), 8.1, 1e-14);
assert::close(suffstat.sum_x_sq(), 27.889_999_999_999_993, 1e-14);
}
#[test]
fn suffstat_decrements_correctly() {
let xs: Vec<f64> = vec![0.0, 1.2, 2.3, 4.6];
let mut suffstat = GaussianSuffStat::new();
for x in xs {
suffstat.observe(&x);
}
suffstat.observe(&5.0);
suffstat.forget(&5.0);
assert_eq!(suffstat.n(), 4);
assert::close(suffstat.mean(), 2.025, 1e-14);
assert::close(suffstat.sum_x(), 8.1, 1e-14);
assert::close(suffstat.sum_x_sq(), 27.889_999_999_999_993, 1e-13);
}
#[test]
fn incremental_merge() {
let mut rng = rand::rng();
let g = crate::dist::Gaussian::standard();
let xs: Vec<f64> = g.sample(5, &mut rng);
let stat_a = {
let mut stat = GaussianSuffStat::new();
stat.observe_many(&xs);
stat
};
let mut stat_b = {
let mut stat = GaussianSuffStat::new();
stat.observe(&xs[0]);
stat
};
for x in xs.iter().skip(1) {
let mut stat_temp = GaussianSuffStat::new();
stat_temp.observe(x);
<GaussianSuffStat as SuffStat<f64>>::merge(&mut stat_b, stat_temp);
}
assert_eq!(stat_a.n, stat_b.n);
assert::close(stat_a.mean, stat_b.mean, 1e-10);
assert::close(stat_a.sx, stat_b.sx, 1e-10);
}
}