use crate::distributions::Distribution;
use rand::Rng;
pub struct Bernoulli {
p: f64,
}
impl Bernoulli {
pub fn new(p: f64) -> Result<Bernoulli, String> {
if !(0f64..=1f64).contains(&p) {
Err(format! {"Bernoulli: illegal probability `{}` should be between 0 and 1", p})
} else {
Ok(Bernoulli { p })
}
}
}
impl<R: Rng + ?Sized> Distribution<R> for Bernoulli {
type Domain = bool;
fn sample(&self, rng: &mut R) -> bool {
let val: f64 = rng.r#gen();
val < self.p
}
fn log_prob(&self, x: &bool) -> f64 {
if *x { self.p.ln() } else { (1.0 - self.p).ln() }
}
fn is_discrete(&self) -> bool {
true
}
}
impl std::fmt::Display for Bernoulli {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Bernoulli {{ p = {} }}", self.p)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::distributions::Distribution;
use rand::thread_rng;
#[test]
fn bernoulli_sample() {
let mut rng = thread_rng();
let dist = Bernoulli::new(0.1).unwrap();
println!("dist = {}", dist);
let mut succ = 0;
let trials = 10000;
for _ in 0..trials {
if dist.sample(&mut rng) {
succ += 1;
}
}
let mean = (succ as f64) / (trials as f64);
assert!((mean - 0.1).abs() < 0.01);
let _dist2 = Bernoulli::new(0.0).unwrap();
let _dist3 = Bernoulli::new(1.0).unwrap();
}
#[test]
#[should_panic]
fn bernoulli_too_low() {
let _dist = Bernoulli::new(-0.01).unwrap();
}
#[test]
#[should_panic]
fn bernoulli_too_high() {
let _dist = Bernoulli::new(1.01).unwrap();
}
#[test]
fn bernoulli_log_prob_and_display() {
let dist = Bernoulli::new(0.25).unwrap();
let lp_true = Distribution::<rand::rngs::ThreadRng>::log_prob(&dist, &true);
let lp_false = Distribution::<rand::rngs::ThreadRng>::log_prob(&dist, &false);
assert_eq!(lp_true, 0.25f64.ln());
assert_eq!(lp_false, (1.0f64 - 0.25).ln());
let d0 = Bernoulli::new(0.0).unwrap();
assert!(Distribution::<rand::rngs::ThreadRng>::log_prob(&d0, &true).is_infinite());
assert_eq!(
Distribution::<rand::rngs::ThreadRng>::log_prob(&d0, &false),
(1.0f64 - 0.0).ln()
);
let d1 = Bernoulli::new(1.0).unwrap();
assert_eq!(
Distribution::<rand::rngs::ThreadRng>::log_prob(&d1, &true),
(1.0f64).ln()
);
assert!(Distribution::<rand::rngs::ThreadRng>::log_prob(&d1, &false).is_infinite());
assert!(format!("{}", dist).contains("Bernoulli"));
assert!(Distribution::<rand::rngs::ThreadRng>::is_discrete(&dist));
}
}