#![allow(clippy::float_cmp)]
use crate::distributions::*;
/// Bernoulli distribution over the two outcomes `0` and `1`.
///
/// The single parameter `p` is the probability mass assigned to the outcome
/// `1`; the outcome `0` receives `1 - p`.
#[derive(Debug, Clone, Copy)]
pub struct Bernoulli {
/// Probability of the outcome `1`. `new` and `set_p` only accept values
/// in `[0, 1]`, so the field is kept private to preserve that range.
p: f64,
}
impl Bernoulli {
    /// Creates a Bernoulli distribution whose outcome `1` has probability `p`.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in `[0, 1]`. NaN is rejected as well, because it
    /// is not contained in any range.
    pub fn new(p: f64) -> Self {
        Bernoulli {
            p: Self::checked(p),
        }
    }

    /// Replaces the probability of the outcome `1` and returns `self` so that
    /// further calls can be chained.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in `[0, 1]`; the stored value is left untouched
    /// in that case because validation happens before the assignment.
    pub fn set_p(&mut self, p: f64) -> &mut Self {
        self.p = Self::checked(p);
        self
    }

    /// Returns `p` unchanged when it is a valid probability, panics otherwise.
    fn checked(p: f64) -> f64 {
        assert!((0. ..=1.).contains(&p), "`p` must be in [0, 1].");
        p
    }
}
impl Default for Bernoulli {
fn default() -> Self {
Self::new(0.5)
}
}
impl Distribution for Bernoulli {
    type Output = f64;

    /// Draws one sample, returning either `0.0` or `1.0`.
    ///
    /// When `p` is exactly `0` or `1` the outcome is fixed and no random
    /// number is requested. Otherwise a single value is taken from
    /// `alea::f64()` and the result is `1.0` when that value is strictly
    /// below `p`.
    fn sample(&self) -> f64 {
        let p = self.p;

        // Degenerate cases: answer directly without consuming a random draw.
        // Literals are returned (rather than `p`) so `-0.0` still yields `0.0`.
        if p == 1. {
            return 1.;
        }
        if p == 0. {
            return 0.;
        }

        // NOTE(review): presumably `alea::f64()` is uniform on [0, 1) — confirm
        // against the `alea` crate; the success probability relies on it.
        let draw = alea::f64();
        if draw < p {
            1.
        } else {
            0.
        }
    }
}
impl Distribution1D for Bernoulli {
    /// Sets `p` from the first element of `params`; any further elements are
    /// ignored.
    ///
    /// # Panics
    ///
    /// Panics if `params` is empty, or if `params[0]` is rejected by `set_p`.
    fn update(&mut self, params: &[f64]) {
        let new_p = params[0];
        self.set_p(new_p);
    }
}
impl Discrete for Bernoulli {
    /// Probability mass at outcome `k`: `1 - p` for `0`, `p` for `1`, and
    /// `0` for every other integer.
    fn pmf(&self, k: i64) -> f64 {
        match k {
            0 => 1. - self.p,
            1 => self.p,
            // Outside the support {0, 1}.
            _ => 0.,
        }
    }
}
impl Mean for Bernoulli {
    type MeanType = f64;

    /// Expected value of the distribution, which equals `p`.
    fn mean(&self) -> Self::MeanType {
        self.p
    }
}
impl Variance for Bernoulli {
    type VarianceType = f64;

    /// Variance of the distribution, computed as `p * (1 - p)`.
    fn var(&self) -> Self::VarianceType {
        let p = self.p;
        p * (1. - p)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::statistics::{mean, var};
    use approx_eq::assert_approx_eq;

    /// Draws a large sample and checks that every value is `0` or `1`, that
    /// the empirical mean and variance are close to the theoretical ones, and
    /// that the default distribution reports the expected point masses.
    #[test]
    fn test_bernoulli() {
        const P: f64 = 0.75;
        const DRAWS: usize = 1_000_000;
        const TOLERANCE: f64 = 1e-2;

        let samples = Bernoulli::new(P).sample_n(DRAWS);
        for &x in &samples {
            assert!(x == 0. || x == 1.);
        }

        assert_approx_eq!(P, mean(&samples), TOLERANCE);
        assert_approx_eq!(P * 0.25, var(&samples), TOLERANCE);

        let fair = Bernoulli::default();
        assert_eq!(fair.pmf(2), 0.);
        assert_eq!(fair.pmf(0), 0.5);
    }
}