use crate::{Distribution, Uniform};
use core::cmp::Ordering;
use core::fmt;
#[allow(unused_imports)]
use num_traits::Float;
use rand::{Rng, RngExt};
/// The binomial distribution `Binomial(n, p)`.
///
/// Models the number of successes in `n` independent trials, each of which
/// succeeds with probability `p`. Construct it with [`Binomial::new`] and draw
/// `u64` samples through the [`Distribution`] trait.
///
/// The sampling strategy is chosen once, at construction time, from `n` and
/// `p`; sampling then just dispatches on it.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Binomial {
/// Sampling strategy selected by `Binomial::new`.
method: Method,
}
/// Sampling strategy chosen by [`Binomial::new`].
///
/// For `Binv` and `Btpe`, the `bool` records whether the caller's `p > 0.5`
/// was replaced by `1 - p`; if so, a sampled value `k` is reported as `n - k`.
///
/// NOTE(review): variant order presumably matters to serde formats that
/// encode the variant index — keep it stable.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum Method {
/// Inversion sampler; used when `n * p < 10` (after flipping) and `1 - p != 1.0`.
Binv(Binv, bool),
/// BTPE sampler; used when `n * p >= 10` (after flipping).
Btpe(Btpe, bool),
/// Poisson sampler with mean `n * p`; used when `n * p < 10` and `1 - p`
/// rounds to exactly `1.0`. Never flipped, since `p` is tiny here.
Poisson(crate::poisson::KnuthMethod<f64>),
/// Fixed result: `0` when `p == 0`, `n` when `p == 1`.
Constant(u64),
}
/// Precomputed parameters for the inversion (BINV) sampler.
///
/// With `q = 1 - p`, consecutive probabilities satisfy
/// `P(X = x) = P(X = x - 1) * (a / x - s)`, starting from `P(X = 0) = r`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Binv {
/// `q^n`, the probability of zero successes.
r: f64,
/// `p / q`.
s: f64,
/// `(n + 1) * s`.
a: f64,
/// Number of trials; needed to map `k` back to `n - k` when flipped.
n: u64,
}
/// Precomputed parameters for the BTPE sampler.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Btpe {
/// Number of trials.
n: u64,
/// Success probability after flipping, so `p <= 0.5`.
p: f64,
/// `floor(n * p + p)`; the triangle region is centred at `m + 0.5`.
m: u64,
/// `floor(2.195 * sqrt(n*p*q) - 4.6 * q) + 0.5`; half-width of the triangle region.
p1: f64,
}
/// Error type returned from [`Binomial::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum Error {
/// `p < 0` or `p` is NaN.
ProbabilityTooSmall,
/// `p > 1`.
ProbabilityTooLarge,
}
impl fmt::Display for Error {
    /// Writes a short, human-readable description of the rejected `p`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match *self {
            Error::ProbabilityTooLarge => "p > 1 in binomial distribution",
            Error::ProbabilityTooSmall => "p < 0 or is NaN in binomial distribution",
        };
        f.write_str(msg)
    }
}

// `std::error::Error` is only implemented when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
impl Binomial {
    /// Construct a new `Binomial` distribution with `n` trials and success
    /// probability `p` per trial.
    ///
    /// The sampling algorithm is selected here, once:
    /// * `p == 0` or `p == 1`: the result is the constant `0` or `n`;
    /// * otherwise `p` is replaced by `1 - p` when `p > 0.5` (samples are
    ///   mirrored back as `n - k`), and with `np = n * p`:
    ///   * `np >= 10`: the BTPE rejection sampler;
    ///   * `np < 10` and `1 - p` rounds to `1.0`: a Poisson sampler with mean `np`;
    ///   * otherwise: the BINV inversion sampler.
    ///
    /// # Errors
    ///
    /// * [`Error::ProbabilityTooSmall`] if `p < 0` or `p` is NaN;
    /// * [`Error::ProbabilityTooLarge`] if `p > 1`.
    pub fn new(n: u64, p: f64) -> Result<Binomial, Error> {
        // Negated comparisons so that NaN is rejected by the first check.
        if !(p >= 0.0) {
            return Err(Error::ProbabilityTooSmall);
        }
        if !(p <= 1.0) {
            return Err(Error::ProbabilityTooLarge);
        }

        // Degenerate probabilities: every trial fails, or every trial succeeds.
        let fixed = if p == 0.0 {
            Some(0)
        } else if p == 1.0 {
            Some(n)
        } else {
            None
        };
        if let Some(value) = fixed {
            return Ok(Binomial {
                method: Method::Constant(value),
            });
        }

        // Binomial(n, p) mirrors Binomial(n, 1 - p) under k -> n - k, so always
        // sample with the smaller of the two probabilities.
        let flipped = p > 0.5;
        let p = if flipped { 1.0 - p } else { p };
        let q = 1.0 - p;
        let np = n as f64 * p;

        // Mean below which the inversion sampler is used instead of BTPE.
        const BINV_THRESHOLD: f64 = 10.;

        let method = if np >= BINV_THRESHOLD {
            let p1 = (2.195 * (np * q).sqrt() - 4.6 * q).floor() + 0.5;
            let m = f64_to_u64(np + p);
            Method::Btpe(Btpe { n, p, m, p1 }, flipped)
        } else if q == 1.0 {
            // `p` is too small to change `1 - p`; use the Poisson limit.
            // `flipped` is necessarily false here because `p` is tiny.
            Method::Poisson(crate::poisson::KnuthMethod::new(np))
        } else {
            let s = p / q;
            let params = Binv {
                r: q.powf(n as f64),
                s,
                a: (n as f64 + 1.0) * s,
                n,
            };
            Method::Binv(params, flipped)
        };

        Ok(Binomial { method })
    }
}
/// Convert a non-negative `f64` to `u64` by truncation toward zero.
///
/// # Panics
///
/// Panics if `x` is negative, NaN, or not strictly below `u64::MAX as f64`,
/// rather than letting the saturating `as` cast clamp the value silently.
fn f64_to_u64(x: f64) -> u64 {
assert!(x >= 0.0 && x < (u64::MAX as f64));
x as u64
}
/// Sample `Binomial(n, p)` by inversion (the BINV algorithm).
///
/// Starting from `P(X = 0) = binv.r`, walks up the probability mass function,
/// subtracting each term from a single uniform draw until the draw is used up.
/// Each attempt consumes exactly one `f64` from `rng`. An attempt whose walk
/// passes `BINV_MAX_X` is abandoned and retried with a fresh draw. When
/// `flipped` is set, the result `k` is returned as `binv.n - k`.
fn binv<R: Rng + ?Sized>(binv: Binv, flipped: bool, rng: &mut R) -> u64 {
    // Upper bound on the walk length; longer walks are restarted, which keeps
    // every attempt bounded.
    const BINV_MAX_X: u64 = 110;

    /// One inversion pass for the uniform `u`; `None` if it exceeds `BINV_MAX_X`.
    fn invert(params: &Binv, mut u: f64) -> Option<u64> {
        let mut pmf = params.r;
        let mut k: u64 = 0;
        while u > pmf {
            u -= pmf;
            k += 1;
            if k > BINV_MAX_X {
                return None;
            }
            // P(X = k) = P(X = k - 1) * (a / k - s)
            pmf *= params.a / (k as f64) - params.s;
        }
        Some(k)
    }

    let k = loop {
        let u: f64 = rng.random();
        if let Some(k) = invert(&binv, u) {
            break k;
        }
    };

    if flipped {
        binv.n - k
    } else {
        k
    }
}
/// Sample `Binomial(n, p)` with the BTPE algorithm of Kachitvichyanukul and
/// Schmeiser (1988); selected by [`Binomial::new`] when `n * p >= 10`.
///
/// The proposal density is split into four regions:
/// * region 1: a triangle centred at `m + 0.5`;
/// * region 2: two parallelograms beside the triangle;
/// * region 3: an exponential tail on the left;
/// * region 4: an exponential tail on the right.
///
/// A uniform `u` in `[0, p4)` picks the region, and a second uniform `v` drives
/// the accept/reject test. Region 1 is accepted immediately.
///
/// `btpe.p` is the success probability after flipping (`<= 0.5`). If `flipped`
/// is set, the accepted `y` is returned as `n - y`.
// Single-letter names follow the notation of the reference paper.
#[allow(clippy::many_single_char_names)]
fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
    // The squeeze (step 5.2) is only attempted when |y - m| exceeds this.
    const SQUEEZE_THRESHOLD: u64 = 20;

    // Step 0: constants that depend only on n and p.
    let n = btpe.n;
    let np = (n as f64) * btpe.p;
    let q = 1. - btpe.p;
    let npq = np * q;
    let f_m = np + btpe.p;
    let m = btpe.m;
    // Half-width of the triangle (its height is 1, so this is also its area).
    let p1 = btpe.p1;
    // Tip of the triangle, and its left and right edges.
    let x_m = (m as f64) + 0.5;
    let x_l = x_m - p1;
    let x_r = x_m + p1;
    let c = 0.134 + 20.5 / (15.3 + (m as f64));
    // Cumulative area: triangle plus parallelograms.
    let p2 = p1 * (1. + 2. * c);

    fn lambda(a: f64) -> f64 {
        a * (1. + 0.5 * a)
    }

    // Rates of the left and right exponential tails.
    let lambda_l = lambda((f_m - x_l) / (f_m - x_l * btpe.p));
    let lambda_r = lambda((x_r - f_m) / (x_r * q));

    // Cumulative areas after adding the left tail, then the right tail.
    let p3 = p2 + c / lambda_l;
    let p4 = p3 + c / lambda_r;

    // The accepted value, before any flip.
    let mut y: u64;

    let gen_u = Uniform::new(0., p4).unwrap();
    let gen_v = Uniform::new(0., 1.).unwrap();

    loop {
        // Step 1: `u` selects the region. A point in region 1 (the triangle)
        // is accepted without further testing.
        let u = gen_u.sample(rng);
        let mut v = gen_v.sample(rng);
        if !(u > p1) {
            y = f64_to_u64(x_m - p1 * v + u);
            break;
        }

        if !(u > p2) {
            // Step 2: region 2, the parallelograms.
            let x = x_l + (u - p1) / c;
            v = v * c + 1.0 - (x - x_m).abs() / p1;
            if v > 1. {
                continue;
            } else {
                y = f64_to_u64(x);
            }
        } else if !(u > p3) {
            // Step 3: region 3, the left exponential tail.
            let y_tmp = x_l + v.ln() / lambda_l;
            if y_tmp < 0.0 {
                continue;
            } else {
                y = f64_to_u64(y_tmp);
                v *= (u - p2) * lambda_l;
            }
        } else {
            // Step 4: region 4, the right exponential tail.
            let y_tmp = x_r - v.ln() / lambda_r;
            // `as u64` saturates. A tail point at or beyond 2^64 (e.g. `v == 0`
            // gives +inf) would therefore become u64::MAX and pass the `y > n`
            // check when n == u64::MAX. Such a point always lies past n, so
            // reject it.
            if y_tmp >= u64::MAX as f64 {
                continue;
            }
            y = y_tmp as u64;
            if y > n {
                continue;
            }
            v *= (u - p3) * lambda_r;
        }

        // Step 5.0: choose how to evaluate the target mass f(y) relative to f(m).
        let k = y.abs_diff(m);
        if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
            // Step 5.1: compute f(y)/f(m) exactly with the recurrence
            // f(i)/f(i - 1) = a / i - s, walking between the mode and y.
            let s = btpe.p / q;
            let a = s * (n as f64 + 1.);
            let mut f = 1.0;
            match m.cmp(&y) {
                Ordering::Less => {
                    let mut i = m;
                    loop {
                        i += 1;
                        f *= a / (i as f64) - s;
                        if i == y {
                            break;
                        }
                    }
                }
                Ordering::Greater => {
                    let mut i = y;
                    loop {
                        i += 1;
                        f /= a / (i as f64) - s;
                        if i == m {
                            break;
                        }
                    }
                }
                Ordering::Equal => {}
            }
            if v > f {
                continue;
            } else {
                break;
            }
        }

        // Step 5.2: squeeze. Compare ln(v) with cheap lower and upper bounds
        // on ln f(y).
        let k = k as f64;
        let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5);
        let t = -0.5 * k * k / npq;
        let alpha = v.ln();
        if alpha < t - rho {
            break;
        }
        if alpha > t + rho {
            continue;
        }

        // Step 5.3: final test against ln f(y) using Stirling corrections.
        // These are built directly in f64 so that `y + 1` and `n - y + 1`
        // cannot overflow u64 when n == u64::MAX.
        let x1 = y as f64 + 1.0;
        let f1 = m as f64 + 1.0;
        let z = (n - m) as f64 + 1.0;
        let w = (n - y) as f64 + 1.0;

        // Stirling-series correction
        // 1/(12a) - 1/(360a^3) + 1/(1260a^5) - 1/(1680a^7) + 1/(1188a^9),
        // evaluated in Horner form.
        fn stirling(a: f64) -> f64 {
            let a2 = a * a;
            (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
        }

        // Signed y - m without underflowing u64.
        let y_sub_m = if y > m {
            (y - m) as f64
        } else {
            -((m - y) as f64)
        };

        if alpha
            > x_m * (f1 / x1).ln()
                + (((n - m) as f64) + 0.5) * (z / w).ln()
                + y_sub_m * (w * btpe.p / (x1 * q)).ln()
                + stirling(f1)
                + stirling(z)
                - stirling(x1)
                - stirling(w)
        {
            continue;
        }

        break;
    }

    if flipped {
        n - y
    } else {
        y
    }
}
impl Distribution<u64> for Binomial {
    /// Draws one sample by dispatching to the strategy chosen in [`Binomial::new`].
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
        match self.method {
            Method::Constant(value) => value,
            // Only chosen when `1 - p == 1.0`, i.e. `p` is negligibly small.
            Method::Poisson(approx) => approx.sample(rng) as u64,
            Method::Binv(params, flipped) => binv(params, flipped, rng),
            Method::Btpe(params, flipped) => btpe(params, flipped, rng),
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Draws 1000 samples from `Binomial(n, p)` and checks the results.
    ///
    /// Each sample must lie in `0..=n`. The empirical mean must be within 2% of
    /// `n*p`, and the empirical variance within 10% of `n*p*(1-p)`.
    fn test_binomial_mean_and_variance<R: Rng>(n: u64, p: f64, rng: &mut R) {
        let binomial = Binomial::new(n, p).unwrap();

        let expected_mean = n as f64 * p;
        let expected_variance = n as f64 * p * (1.0 - p);

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            let k = binomial.sample(rng);
            assert!(k <= n, "sample {} exceeds n = {}", k, n);
            *i = k as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean - expected_mean).abs() < expected_mean / 50.0);

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_binomial() {
        let mut rng = crate::test::rng(351);
        test_binomial_mean_and_variance(150, 0.1, &mut rng);
        test_binomial_mean_and_variance(70, 0.6, &mut rng);
        test_binomial_mean_and_variance(40, 0.5, &mut rng);
        test_binomial_mean_and_variance(20, 0.7, &mut rng);
        test_binomial_mean_and_variance(20, 0.5, &mut rng);
        test_binomial_mean_and_variance(1 << 61, 1e-17, &mut rng);
        test_binomial_mean_and_variance(u64::MAX, 1e-19, &mut rng);
    }

    #[test]
    fn test_binomial_end_points() {
        let mut rng = crate::test::rng(352);
        assert_eq!(rng.sample(Binomial::new(20, 0.0).unwrap()), 0);
        assert_eq!(rng.sample(Binomial::new(20, 1.0).unwrap()), 20);
    }

    /// Every invalid `p` must be rejected with the matching error variant.
    #[test]
    fn test_binomial_invalid_p() {
        assert_eq!(Binomial::new(20, -10.0), Err(Error::ProbabilityTooSmall));
        assert_eq!(Binomial::new(20, f64::NAN), Err(Error::ProbabilityTooSmall));
        assert_eq!(Binomial::new(20, 1.5), Err(Error::ProbabilityTooLarge));
        assert_eq!(
            Binomial::new(20, f64::INFINITY),
            Err(Error::ProbabilityTooLarge)
        );
    }

    /// Exercises the BINV and BTPE paths, flipped and unflipped; no sample may
    /// exceed `n`.
    #[test]
    fn test_binomial_samples_within_range() {
        let mut rng = crate::test::rng(353);
        let cases: [(u64, f64); 6] = [
            (0, 0.3),
            (1, 0.9),
            (5, 0.99),
            (30, 0.4),
            (30, 0.6),
            (1000, 0.5),
        ];
        for &(n, p) in cases.iter() {
            let dist = Binomial::new(n, p).unwrap();
            for _ in 0..1000 {
                assert!(dist.sample(&mut rng) <= n);
            }
        }
    }

    #[test]
    fn binomial_distributions_can_be_compared() {
        assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0));
    }

    #[test]
    fn binomial_avoid_infinite_loop() {
        let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap();
        let mut sum: u64 = 0;
        let mut rng = crate::test::rng(742);
        for _ in 0..100_000 {
            sum = sum.wrapping_add(dist.sample(&mut rng));
        }
        assert_ne!(sum, 0);
    }
}