use crate::FloatScalar;
use super::{DiscreteDistribution, StatsError};
/// Bernoulli distribution over the outcomes `{0, 1}`: outcome `1` has
/// probability `p` and outcome `0` has probability `1 - p`.
///
/// Construct with [`Bernoulli::new`], which validates `p`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Bernoulli<T> {
    /// Success probability, i.e. the probability mass at outcome `1`.
    p: T,
}
impl<T: FloatScalar> Bernoulli<T> {
    /// Creates a Bernoulli distribution with success probability `p`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidParameter`] if `p` is not in the closed
    /// interval `[0, 1]`, including when `p` is NaN.
    pub fn new(p: T) -> Result<Self, StatsError> {
        // An inclusive-range containment test (`0 <= p && p <= 1`) is false for NaN,
        // so NaN is rejected. The previous `p < 0 || p > 1` form is also false for
        // NaN and therefore let it through.
        if !(T::zero()..=T::one()).contains(&p) {
            return Err(StatsError::InvalidParameter);
        }
        Ok(Self { p })
    }
}
impl<T: FloatScalar> Bernoulli<T> {
    /// Draws a single outcome.
    ///
    /// Pulls one float from `rng` and returns `1` if it is strictly below `p`,
    /// `0` otherwise. (Assumes `next_float` yields values in `[0, 1)`; confirm
    /// against `Rng::next_float`.)
    pub fn sample(&self, rng: &mut super::Rng) -> u64 {
        let draw = rng.next_float::<T>();
        u64::from(draw < self.p)
    }

    /// Draws `K` outcomes by calling [`Bernoulli::sample`] once per element,
    /// filling the array from index `0` upward.
    pub fn sample_array<const K: usize>(&self, rng: &mut super::Rng) -> [u64; K] {
        let mut draws = [0u64; K];
        draws.iter_mut().for_each(|slot| *slot = self.sample(rng));
        draws
    }
}
impl<T: FloatScalar> DiscreteDistribution<T> for Bernoulli<T> {
    /// Probability mass: `1 - p` at `k = 0`, `p` at `k = 1`, zero for any other `k`.
    fn pmf(&self, k: u64) -> T {
        if k == 1 {
            self.p
        } else if k == 0 {
            T::one() - self.p
        } else {
            T::zero()
        }
    }

    /// Natural log of the probability mass.
    ///
    /// Outcomes other than `0` and `1` map directly to negative infinity.
    fn ln_pmf(&self, k: u64) -> T {
        let mass = match k {
            0 => T::one() - self.p,
            1 => self.p,
            _ => return T::neg_infinity(),
        };
        mass.ln()
    }

    /// Cumulative probability `P(X <= k)`: `1 - p` at `k = 0`, one for every larger `k`.
    fn cdf(&self, k: u64) -> T {
        match k {
            0 => T::one() - self.p,
            _ => T::one(),
        }
    }

    /// Returns `0` when `p <= 0` or when `cdf(0)` (that is, `1 - p_success`) is at
    /// least `p`, and `1` otherwise. The `1` branch includes `p > 1` and a NaN `p`.
    fn quantile(&self, p: T) -> u64 {
        // Keep these comparisons exactly as written. Rewriting them via De Morgan
        // would send a NaN `p` down the other branch.
        let at_zero = p <= T::zero() || self.cdf(0) >= p;
        if at_zero {
            0
        } else {
            1
        }
    }

    /// Mean of the distribution, equal to `p`.
    fn mean(&self) -> T {
        self.p
    }

    /// Variance of the distribution, `p * (1 - p)`.
    fn variance(&self) -> T {
        let failure = T::one() - self.p;
        self.p * failure
    }
}