use Rng;
use distributions::Distribution;
/// The Bernoulli distribution: samples are `true` with probability `p` and
/// `false` otherwise.
///
/// The probability is stored as a fixed-point integer threshold against which
/// a random `u64` is compared when sampling.
#[derive(Clone, Copy, Debug)]
pub struct Bernoulli {
    /// `p` scaled to the `u64` range (`p * 2^64`, truncated). The value
    /// `u64::MAX` is reserved for `p == 1.0` and means "always `true`".
    p_int: u64,
}

impl Bernoulli {
    /// Construct a new `Bernoulli` with success probability `p`.
    ///
    /// # Panics
    ///
    /// Panics unless `0 <= p <= 1`; a NaN `p` also panics because it fails
    /// both comparisons.
    #[inline]
    pub fn new(p: f64) -> Bernoulli {
        assert!(p >= 0.0 && p <= 1.0, "Bernoulli::new not called with 0 <= p <= 1");
        // `u64::MAX as f64` rounds to exactly 2^64. For every `p < 1.0` the
        // product is therefore strictly below 2^64, and the cast cannot overflow.
        const MAX_P_INT: f64 = ::core::u64::MAX as f64;
        let p_int = if p < 1.0 {
            (p * MAX_P_INT) as u64
        } else {
            // p == 1.0: store the sentinel that sampling treats as "always true".
            ::core::u64::MAX
        };
        Bernoulli { p_int }
    }
}
impl Distribution<bool> for Bernoulli {
    /// Draw a single Bernoulli trial.
    ///
    /// Returns `true` when a `u64` obtained from `rng.gen()` is strictly below
    /// the stored threshold. A threshold of `u64::MAX` short-circuits to `true`
    /// without drawing from `rng`. This guarantees that `p == 1.0` never yields
    /// `false`, even though a drawn value equal to `u64::MAX` would otherwise
    /// fail the comparison.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
        let threshold = self.p_int;
        let certain = threshold == ::core::u64::MAX;
        certain || rng.gen::<u64>() < threshold
    }
}
#[cfg(test)]
mod test {
    use Rng;
    use distributions::Distribution;
    use super::Bernoulli;

    /// `p = 0` must always give `false` and `p = 1` always `true`, both via
    /// `Rng::sample` and via `Distribution::sample` called directly.
    #[test]
    fn test_trivial() {
        let mut rng = ::test::rng(1);
        let never = Bernoulli::new(0.0);
        let always = Bernoulli::new(1.0);
        for _ in 0..5 {
            assert!(!rng.sample::<bool, _>(&never));
            assert!(rng.sample::<bool, _>(&always));
            assert!(!Distribution::<bool>::sample(&never, &mut rng));
            assert!(Distribution::<bool>::sample(&always, &mut rng));
        }
    }

    /// The empirical frequency of `true` over many draws should be close to `P`.
    #[test]
    fn test_average() {
        const P: f64 = 0.3;
        const N: u32 = 10_000_000;
        let distr = Bernoulli::new(P);
        let mut rng = ::test::rng(2);
        let hits = (0..N).filter(|_| distr.sample(&mut rng)).count();
        let avg = hits as f64 / N as f64;
        // The standard error is sqrt(P * (1 - P) / N) ~= 1.45e-4, so 1e-3 is a
        // wide margin.
        assert!((avg - P).abs() < 1e-3);
    }
}