#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::data::Booleable;
use crate::data::DataOrSuffStat;
use crate::dist::Bernoulli;
use crate::traits::SuffStat;
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct BernoulliSuffStat {
n: usize,
k: usize,
}
impl BernoulliSuffStat {
#[inline]
#[must_use]
pub fn new() -> Self {
BernoulliSuffStat { n: 0, k: 0 }
}
#[inline]
#[must_use]
pub fn from_parts_unchecked(n: usize, k: usize) -> Self {
BernoulliSuffStat { n, k }
}
#[inline]
#[must_use]
pub fn n(&self) -> usize {
self.n
}
#[inline]
#[must_use]
pub fn k(&self) -> usize {
self.k
}
}
impl Default for BernoulliSuffStat {
fn default() -> Self {
BernoulliSuffStat::new()
}
}
impl<'a, X> From<&'a BernoulliSuffStat> for DataOrSuffStat<'a, X, Bernoulli>
where
X: Booleable,
{
fn from(stat: &'a BernoulliSuffStat) -> Self {
DataOrSuffStat::SuffStat(stat)
}
}
impl<'a, X> From<&'a Vec<X>> for DataOrSuffStat<'a, X, Bernoulli>
where
X: Booleable,
{
fn from(xs: &'a Vec<X>) -> Self {
DataOrSuffStat::Data(xs.as_slice())
}
}
impl<'a, X> From<&'a [X]> for DataOrSuffStat<'a, X, Bernoulli>
where
X: Booleable,
{
fn from(xs: &'a [X]) -> Self {
DataOrSuffStat::Data(xs)
}
}
impl<X: Booleable> SuffStat<X> for BernoulliSuffStat {
fn n(&self) -> usize {
self.n
}
fn observe(&mut self, x: &X) {
self.n += 1;
if x.into_bool() {
self.k += 1;
}
}
fn forget(&mut self, x: &X) {
self.n -= 1;
if x.into_bool() {
self.k -= 1;
}
}
fn merge(&mut self, other: Self) {
self.n += other.n;
self.k += other.k;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn new_should_be_empty() {
let stat = BernoulliSuffStat::new();
assert_eq!(stat.n, 0);
assert_eq!(stat.k, 0);
}
#[test]
fn from_parts_unchecked() {
let stat = BernoulliSuffStat::from_parts_unchecked(10, 3);
assert_eq!(stat.n(), 10);
assert_eq!(stat.k(), 3);
}
#[test]
fn observe_1() {
let mut stat = BernoulliSuffStat::new();
stat.observe(&1_u8);
assert_eq!(stat.n, 1);
assert_eq!(stat.k, 1);
}
#[test]
fn observe_true() {
let mut stat = BernoulliSuffStat::new();
stat.observe(&true);
assert_eq!(stat.n, 1);
assert_eq!(stat.k, 1);
}
#[test]
fn observe_0() {
let mut stat = BernoulliSuffStat::new();
stat.observe(&0_i8);
assert_eq!(stat.n, 1);
assert_eq!(stat.k, 0);
}
#[test]
fn observe_false() {
let mut stat = BernoulliSuffStat::new();
stat.observe(&false);
assert_eq!(stat.n, 1);
assert_eq!(stat.k, 0);
}
}