use crate::Rng;
use crate::distributions::Distribution;
/// The Bernoulli distribution: each sample is `true` with probability `p`
/// and `false` otherwise.
///
/// The probability is stored as a 64-bit fixed-point threshold, so sampling
/// only needs one random `u64` and an integer comparison.
#[derive(Clone, Copy, Debug)]
pub struct Bernoulli {
    // Fixed-point threshold `floor(p * 2^64)`, or the `ALWAYS_TRUE`
    // sentinel when `p == 1` (2^64 itself does not fit in a `u64`).
    p_int: u64,
}

/// Sentinel stored in `p_int` to mean "probability exactly 1".
///
/// A strict `v < p_int` test against `u64::MAX` would still be `false` for
/// `v == u64::MAX`, so the certain-`true` case is special-cased via this value.
const ALWAYS_TRUE: u64 = ::core::u64::MAX;

/// `2^64` as an `f64`: maps a probability in `[0, 1)` onto the `u64` range.
/// Written as `2 * 2^63` because `1u64 << 64` would overflow.
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;

/// Error returned by [`Bernoulli::new`] and [`Bernoulli::from_ratio`] when
/// the requested probability is not a valid probability.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BernoulliError {
    /// The probability was outside `[0, 1]` or NaN, or the ratio had
    /// `numerator > denominator` or a zero denominator.
    InvalidProbability,
}

impl Bernoulli {
    /// Constructs a distribution that yields `true` with probability `p`.
    ///
    /// Values in `[0, 1)` are stored with 64 bits of fixed-point precision
    /// (rounded down); `p == 1.0` is represented exactly.
    ///
    /// # Errors
    ///
    /// Returns [`BernoulliError::InvalidProbability`] if `p` is negative,
    /// greater than `1.0`, or NaN.
    #[inline]
    pub fn new(p: f64) -> Result<Bernoulli, BernoulliError> {
        // Written as a positive range test: every comparison involving NaN is
        // false, so NaN lands in the error branch instead of saturating to
        // `p_int == 0` (a silently always-false distribution).
        let in_half_open_unit = p >= 0.0 && p < 1.0;
        if !in_half_open_unit {
            if p == 1.0 {
                return Ok(Bernoulli { p_int: ALWAYS_TRUE });
            }
            return Err(BernoulliError::InvalidProbability);
        }
        // Scaling by a power of two is exact, so the cast yields
        // floor(p * 2^64); p < 1 keeps the result strictly below 2^64.
        Ok(Bernoulli { p_int: (p * SCALE) as u64 })
    }

    /// Constructs a distribution that yields `true` with probability
    /// `numerator / denominator`.
    ///
    /// The ratio is evaluated in `f64` before scaling, so it carries about
    /// 53 bits of precision; `numerator == denominator` is represented
    /// exactly as certain `true`.
    ///
    /// # Errors
    ///
    /// Returns [`BernoulliError::InvalidProbability`] if
    /// `numerator > denominator` or `denominator == 0`.
    #[inline]
    pub fn from_ratio(numerator: u32, denominator: u32) -> Result<Bernoulli, BernoulliError> {
        // Without the zero check, 0/0 would reach the equality branch below
        // and be accepted as an always-true distribution.
        if numerator > denominator || denominator == 0 {
            return Err(BernoulliError::InvalidProbability);
        }
        if numerator == denominator {
            return Ok(Bernoulli { p_int: ALWAYS_TRUE });
        }
        // Here numerator < denominator, so the quotient lies in [0, 1).
        let p_int = ((f64::from(numerator) / f64::from(denominator)) * SCALE) as u64;
        Ok(Bernoulli { p_int })
    }
}
/// Sampling: one `u64` is requested from the generator via `rng.gen` and
/// compared against the stored threshold. When the probability is exactly 1,
/// no value is requested at all.
impl Distribution<bool> for Bernoulli {
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
        match self.p_int {
            // Probability 1: the strict `<` below would still be false for a
            // draw of `u64::MAX`, so answer without consulting the RNG.
            ALWAYS_TRUE => true,
            threshold => rng.gen::<u64>() < threshold,
        }
    }
}
#[cfg(test)]
mod test {
    use super::Bernoulli;
    use crate::distributions::Distribution;
    use crate::Rng;

    /// p = 0 and p = 1 must be deterministic through both `Rng::sample` and
    /// `Distribution::sample`.
    #[test]
    fn test_trivial() {
        let mut r = crate::test::rng(1);
        let never = Bernoulli::new(0.0).unwrap();
        let always = Bernoulli::new(1.0).unwrap();
        for _ in 0..5 {
            assert!(!r.sample::<bool, _>(&never));
            assert!(r.sample::<bool, _>(&always));
            assert!(!Distribution::<bool>::sample(&never, &mut r));
            assert!(Distribution::<bool>::sample(&always, &mut r));
        }
    }

    /// The empirical means of `new(0.3)` and `from_ratio(3, 10)` should both
    /// land close to 0.3.
    #[test]
    #[cfg(not(miri))]
    fn test_average() {
        const P: f64 = 0.3;
        const NUM: u32 = 3;
        const DENOM: u32 = 10;
        const N: u32 = 100_000;

        let by_prob = Bernoulli::new(P).unwrap();
        let by_ratio = Bernoulli::from_ratio(NUM, DENOM).unwrap();
        let mut rng = crate::test::rng(2);

        // Draw from the two distributions alternately on one shared generator,
        // counting the `true` outcomes of each.
        let (hits_prob, hits_ratio) = (0..N).fold((0u32, 0u32), |(a, b), _| {
            let a = a + u32::from(by_prob.sample(&mut rng));
            let b = b + u32::from(by_ratio.sample(&mut rng));
            (a, b)
        });

        let mean = |hits: u32| f64::from(hits) / f64::from(N);
        assert!((mean(hits_prob) - P).abs() < 5e-3);
        assert!((mean(hits_ratio) - f64::from(NUM) / f64::from(DENOM)).abs() < 5e-3);
    }
}