use Rng;
use distributions::Distribution;
/// The Bernoulli distribution: each sample is `true` with a fixed probability.
///
/// The probability is kept as a 64-bit fixed-point threshold `p_int`, read as
/// the fraction `p_int / 2^64`. The single value `ALWAYS_TRUE` (`u64::MAX`) is
/// reserved to mean "probability exactly 1", which `p_int / 2^64` cannot
/// express.
#[derive(Debug, Clone, Copy)]
pub struct Bernoulli {
    /// Fixed-point probability threshold; see the type-level docs.
    p_int: u64,
}

/// Sentinel threshold meaning "always `true`" (all bits set, i.e. `u64::MAX`).
const ALWAYS_TRUE: u64 = !0;

/// `2^64` as an `f64`: the factor that maps a probability in `[0, 1)` onto the
/// `u64` threshold range. The literal is exactly representable in `f64`.
const SCALE: f64 = 18_446_744_073_709_551_616.0;
impl Bernoulli {
    /// Constructs a `Bernoulli` distribution that yields `true` with
    /// probability `p`.
    ///
    /// For `0.0 <= p < 1.0` the probability is stored as `p * 2^64` truncated
    /// to a `u64`. `p == 1.0` is stored as the `ALWAYS_TRUE` sentinel instead,
    /// because `2^64` does not fit in a `u64`.
    ///
    /// # Panics
    ///
    /// Panics if `p` is less than `0.0`, greater than `1.0`, or NaN.
    #[inline]
    pub fn new(p: f64) -> Bernoulli {
        if p == 1.0 {
            return Bernoulli { p_int: ALWAYS_TRUE };
        }
        // Phrased as "p lies inside [0, 1)" so that NaN, which fails every
        // comparison, is rejected. Written as `p < 0.0 || p >= 1.0`, NaN
        // slipped through and was cast to `p_int == 0` (always false).
        assert!(p >= 0.0 && p < 1.0, "Bernoulli::new not called with 0.0 <= p <= 1.0");
        Bernoulli { p_int: (p * SCALE) as u64 }
    }

    /// Constructs a `Bernoulli` distribution that yields `true` with
    /// probability `numerator / denominator`.
    ///
    /// A ratio equal to one (`numerator == denominator`) maps to the
    /// `ALWAYS_TRUE` sentinel. Any other ratio is computed in `f64` and scaled
    /// the same way as in [`Bernoulli::new`].
    ///
    /// # Panics
    ///
    /// Panics if `numerator > denominator`, or if `denominator == 0`.
    #[inline]
    pub fn from_ratio(numerator: u32, denominator: u32) -> Bernoulli {
        assert!(numerator <= denominator);
        // 0/0 is undefined. Without this check it would fall into the
        // `numerator == denominator` branch below and silently become an
        // always-true distribution.
        assert!(denominator != 0, "Bernoulli::from_ratio called with denominator == 0");
        if numerator == denominator {
            return Bernoulli { p_int: ALWAYS_TRUE };
        }
        let p = f64::from(numerator) / f64::from(denominator);
        Bernoulli { p_int: (p * SCALE) as u64 }
    }
}
impl Distribution<bool> for Bernoulli {
    /// Draws a single `bool`.
    ///
    /// Returns `true` exactly when a `u64` drawn from `rng` is strictly below
    /// the stored threshold. If `rng.gen::<u64>()` is uniform (assumed, not
    /// checked here), that happens with probability `p_int / 2^64`.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
        match self.p_int {
            // Special-cased without touching the RNG: even a threshold of
            // `u64::MAX` would reject a draw equal to `u64::MAX`.
            ALWAYS_TRUE => true,
            threshold => {
                let draw: u64 = rng.gen();
                draw < threshold
            }
        }
    }
}
#[cfg(test)]
mod test {
    use Rng;
    use distributions::Distribution;
    use super::Bernoulli;

    /// p = 0 must never yield `true` and p = 1 must always yield `true`,
    /// whether sampled through `Rng::sample` or `Distribution::sample`.
    #[test]
    fn test_trivial() {
        let mut rng = ::test::rng(1);
        let (never, always) = (Bernoulli::new(0.0), Bernoulli::new(1.0));
        for _round in 0..5 {
            assert_eq!(rng.sample::<bool, _>(&never), false);
            assert_eq!(rng.sample::<bool, _>(&always), true);
            assert_eq!(Distribution::<bool>::sample(&never, &mut rng), false);
            assert_eq!(Distribution::<bool>::sample(&always, &mut rng), true);
        }
    }

    /// The empirical frequency of `true` must be close to the requested
    /// probability, for both the `new` and `from_ratio` constructors.
    #[test]
    fn test_average() {
        const P: f64 = 0.3;
        const NUM: u32 = 3;
        const DENOM: u32 = 10;
        const N: u32 = 100_000;

        let from_float = Bernoulli::new(P);
        let from_ratio = Bernoulli::from_ratio(NUM, DENOM);
        let mut rng = ::test::rng(2);

        // Each iteration samples `from_float` first and then `from_ratio`,
        // both from the same RNG, counting the hits of each.
        let (sum1, sum2) = (0..N).fold((0u32, 0u32), |(hits1, hits2), _| {
            let next1 = hits1 + from_float.sample(&mut rng) as u32;
            let next2 = hits2 + from_ratio.sample(&mut rng) as u32;
            (next1, next2)
        });

        let mean = |hits: u32| (hits as f64) / (N as f64);
        let avg1 = mean(sum1);
        assert!((avg1 - P).abs() < 5e-3);
        let avg2 = mean(sum2);
        assert!((avg2 - (NUM as f64)/(DENOM as f64)).abs() < 5e-3);
    }
}