#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::data::CategoricalDatum;
use crate::data::DataOrSuffStat;
use crate::dist::Categorical;
use crate::traits::SuffStat;
/// Sufficient statistic for categorical data: a running observation total
/// together with one running count per category.
///
/// With the `serde1` feature enabled the type additionally derives
/// `Serialize` and `Deserialize`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct CategoricalSuffStat {
/// Number of observations recorded (`observe` adds one, `forget` removes
/// one, `merge` adds the other statistic's total).
n: usize,
/// Per-category counts, indexed by each datum's `into_usize()` value and
/// stored as `f64`.
counts: Vec<f64>,
}
impl CategoricalSuffStat {
#[inline]
#[must_use]
pub fn new(k: usize) -> Self {
CategoricalSuffStat {
n: 0,
counts: vec![0.0; k],
}
}
#[inline]
#[must_use]
pub fn from_parts_unchecked(n: usize, counts: Vec<f64>) -> Self {
CategoricalSuffStat { n, counts }
}
#[inline]
#[must_use]
pub fn n(&self) -> usize {
self.n
}
#[inline]
#[must_use]
pub fn counts(&self) -> &Vec<f64> {
&self.counts
}
}
/// Lets a borrowed [`CategoricalSuffStat`] be passed wherever a
/// `DataOrSuffStat` for a `Categorical` is expected.
impl<'a, X: CategoricalDatum> From<&'a CategoricalSuffStat>
    for DataOrSuffStat<'a, X, Categorical>
{
    /// Wraps the borrowed statistic in the `SuffStat` variant.
    fn from(suffstat: &'a CategoricalSuffStat) -> Self {
        Self::SuffStat(suffstat)
    }
}
/// Lets a borrowed `Vec` of categorical data be passed wherever a
/// `DataOrSuffStat` for a `Categorical` is expected.
impl<'a, X: CategoricalDatum> From<&'a Vec<X>> for DataOrSuffStat<'a, X, Categorical> {
    /// Borrows the vector's contents as a slice and wraps it in the `Data`
    /// variant.
    fn from(data: &'a Vec<X>) -> Self {
        Self::Data(&data[..])
    }
}
/// Lets a borrowed slice of categorical data be passed wherever a
/// `DataOrSuffStat` for a `Categorical` is expected.
impl<'a, X: CategoricalDatum> From<&'a [X]> for DataOrSuffStat<'a, X, Categorical> {
    /// Wraps the slice in the `Data` variant.
    fn from(data: &'a [X]) -> Self {
        Self::Data(data)
    }
}
impl<X: CategoricalDatum> SuffStat<X> for CategoricalSuffStat {
    /// Returns the number of observations currently recorded.
    fn n(&self) -> usize {
        self.n
    }

    /// Records one observation of `x`: adds one to the total and to the count
    /// of the category `x.into_usize()`.
    ///
    /// # Panics
    ///
    /// Panics if `x.into_usize()` is not a valid index into the counts. The
    /// statistic is left untouched in that case.
    fn observe(&mut self, x: &X) {
        let ix = x.into_usize();
        // Index first so an out-of-range category panics before `n` moves
        // ahead of the counts.
        self.counts[ix] += 1.0;
        self.n += 1;
    }

    /// Removes one previously recorded observation of `x`: subtracts one from
    /// the total and from the count of the category `x.into_usize()`.
    ///
    /// Forgetting more observations than were recorded is a caller error; the
    /// `n -= 1` underflow panics in builds with overflow checks enabled.
    ///
    /// # Panics
    ///
    /// Panics if `x.into_usize()` is not a valid index into the counts. The
    /// statistic is left untouched in that case.
    fn forget(&mut self, x: &X) {
        let ix = x.into_usize();
        // Resolve the slot up front so a bad index fails before any mutation.
        let slot = &mut self.counts[ix];
        self.n -= 1;
        *slot -= 1.0;
    }

    /// Adds the observations summarized by `other` into `self`.
    ///
    /// If `other` tracks more categories than `self`, `self` is first extended
    /// with zero-count categories so that none of `other`'s counts are lost
    /// and `n` stays equal to the total of the counts. Merging a statistic
    /// with the same or fewer categories leaves the number of categories in
    /// `self` unchanged.
    fn merge(&mut self, other: Self) {
        // `zip` stops at the shorter side; without growing first, the surplus
        // counts of a longer `other` would be dropped while `n` still grew.
        if other.counts.len() > self.counts.len() {
            self.counts.resize(other.counts.len(), 0.0);
        }
        self.n += other.n;
        for (ct, &ct_o) in self.counts.iter_mut().zip(&other.counts) {
            *ct += ct_o;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh statistic has one zeroed count per category and no
    /// observations.
    #[test]
    fn new() {
        let k = 4;
        let sf = CategoricalSuffStat::new(k);

        assert_eq!(sf.counts.len(), k);
        assert_eq!(sf.n(), 0);
        for ct in &sf.counts {
            assert!(ct.abs() < 1E-12);
        }
    }

    /// The unchecked constructor stores `n` and the counts exactly as given.
    #[test]
    fn from_parts_unchecked() {
        let expected = [1.0, 2.0, 3.0, 4.0];
        let stat = CategoricalSuffStat::from_parts_unchecked(
            10,
            vec![1.0, 2.0, 3.0, 4.0],
        );

        assert_eq!(stat.n(), 10);
        for (ix, want) in expected.iter().enumerate() {
            assert_eq!(stat.counts()[ix], *want);
        }
    }

    /// Merging two statistics equals observing both data sets in one.
    #[test]
    fn merge() {
        let mut lhs = CategoricalSuffStat::new(5);
        lhs.observe_many(&[1_usize, 2, 3]);

        let mut rhs = CategoricalSuffStat::new(5);
        rhs.observe_many(&[3_usize, 3, 4]);

        let mut combined = CategoricalSuffStat::new(5);
        combined.observe_many(&[1_usize, 2, 3, 3, 3, 4]);

        <CategoricalSuffStat as SuffStat<usize>>::merge(&mut lhs, rhs);
        assert_eq!(lhs, combined);
    }

    /// `forget` undoes exactly one earlier `observe` of the same category.
    #[test]
    fn observe_forget() {
        let mut stat = CategoricalSuffStat::new(5);

        stat.observe_many(&[0_usize, 0, 1, 1, 2]);
        assert_eq!(stat.n(), 5);
        assert_eq!(stat.counts(), &vec![2.0, 2.0, 1.0, 0.0, 0.0]);

        stat.forget(&1_usize);
        assert_eq!(stat.n(), 4);
        assert_eq!(stat.counts(), &vec![2.0, 1.0, 1.0, 0.0, 0.0]);
    }
}