use crate::distributions::*;
use crate::functions::binom_coeff;
/// Binomial distribution: the number of successes in `n` independent trials that each
/// succeed with probability `p`.
#[derive(Debug, Clone, Copy)]
pub struct Binomial {
    /// Number of trials.
    n: u64,
    /// Per-trial success probability; the constructor and `set_p` keep it inside `[0, 1]`.
    p: f64,
}

impl Binomial {
    /// Creates a binomial distribution with `n` trials and success probability `p`.
    ///
    /// # Panics
    /// Panics when `p` is not inside `[0, 1]`; a NaN `p` is rejected as well because it is
    /// not contained in the range.
    pub fn new(n: u64, p: f64) -> Self {
        assert!((0. ..=1.).contains(&p), "`p` must be in [0, 1]");
        Self { n, p }
    }

    /// Replaces the number of trials and returns `self` so calls can be chained.
    pub fn set_n(&mut self, n: u64) -> &mut Self {
        self.n = n;
        self
    }

    /// Replaces the success probability and returns `self` so calls can be chained.
    ///
    /// # Panics
    /// Panics when `p` is not inside `[0, 1]`; the stored value is left untouched in that case.
    pub fn set_p(&mut self, p: f64) -> &mut Self {
        assert!((0. ..=1.).contains(&p), "`p` must be in [0, 1]");
        self.p = p;
        self
    }
}
impl Default for Binomial {
fn default() -> Self {
Self::new(1, 0.5)
}
}
impl Distribution for Binomial {
type Output = f64;
fn sample(&self) -> f64 {
if self.n == 0 || self.p == 0. {
return 0.;
} else if (self.p - 1.).abs() <= f64::EPSILON {
return self.n as f64;
}
let switch = self.p > 0.5;
let p = if switch { 1. - self.p } else { self.p };
let res = if p * self.n as f64 <= 30. {
binomial_inversion(self.n, p)
} else {
binomial_btpe(self.n, p)
};
if switch {
(self.n - res) as f64
} else {
res as f64
}
}
}
impl Distribution1D for Binomial {
    /// Overwrites both parameters from a slice: `params[0]` becomes the number of trials
    /// (converted with an `as u64` cast) and `params[1]` the success probability.
    ///
    /// # Panics
    /// Panics when `params` has fewer than two elements or when `params[1]` is rejected
    /// by `set_p`. The number of trials is written before `params[1]` is read.
    fn update(&mut self, params: &[f64]) {
        self.set_n(params[0] as u64).set_p(params[1]);
    }
}
/// Draws one binomial variate by sequential inversion of the cumulative distribution.
///
/// Starting from `P(X = 0) = (1 - p)^n`, successive point probabilities are produced with
/// the recurrence `P(X = x) = P(X = x - 1) * ((n + 1) * s / x - s)`, where
/// `s = p / (1 - p)`, and subtracted from a single uniform draw until the draw is used up.
/// The number of iterations equals the returned value, so the routine is meant for small
/// `n * p`; `Binomial::sample` uses it when `n * p <= 30`.
///
/// # Arguments
/// * `n` - number of trials.
/// * `p` - success probability, expected to lie in `[0, 1]`.
///
/// # Returns
/// The number of successes, always in `0..=n`.
pub fn binomial_inversion(n: u64, p: f64) -> u64 {
    // With p >= 1 the odds ratio below is infinite and the recurrence turns into NaN;
    // every trial succeeds, so answer directly.
    if p >= 1. {
        return n;
    }
    let q = 1. - p;
    let s = p / q;
    // Computed in floating point so that `n == u64::MAX` cannot overflow the addition.
    let a = (n as f64 + 1.) * s;
    // `powi` takes an `i32`; casting a larger `n` would wrap and give a wrong P(X = 0).
    // Fall back to `powf` only when the exponent does not fit.
    let mut r = if n <= i32::MAX as u64 {
        q.powi(n as i32)
    } else {
        q.powf(n as f64)
    };
    // NOTE(review): `alea::f64()` is presumably uniform on [0, 1) - confirm against the crate.
    let mut u = alea::f64();
    let mut x: u64 = 0;
    // `x < n` caps the walk at the top of the support: rounding in the recurrence can leave
    // a sliver of `u` after all n + 1 probabilities have been subtracted.
    while u > r && x < n {
        u -= r;
        x += 1;
        r *= a / (x as f64) - s;
    }
    x
}
/// Draws one binomial variate with the BTPE acceptance/rejection algorithm
/// (Kachitvichyanukul & Schmeiser).
///
/// The sampler works with `r = min(p, 1 - p)` and mirrors the accepted value (`n - y`)
/// when `p > 0.5`. Each iteration draws `u` on `[0, p4]` to pick one of four proposal
/// regions and `v` on `[0, 1]` for the acceptance test; a candidate `y` is then accepted
/// either by evaluating the probability ratio `f(y) / f(m)` exactly (near the mode `m` or
/// far in the tails) or through a squeeze followed by a Stirling-series bound.
///
/// NOTE(review): the set-up constants assume `n * min(p, 1 - p)` is large enough for `p1`
/// to be positive; `Binomial::sample` only calls this routine when `n * p > 30`. Behavior
/// for smaller means is not guarded here.
///
/// # Arguments
/// * `n` - number of trials.
/// * `p` - success probability, expected to lie in `[0, 1]`.
///
/// # Returns
/// The number of successes.
pub fn binomial_btpe(n: u64, p: f64) -> u64 {
    // Set-up: constants of the proposal (majorizing) function.
    let nf = n as f64;
    let r = if p <= 0.5 { p } else { 1. - p };
    let q = 1. - r;
    let nrq = nf * r * q;
    let fm = nf * r + r;
    // Mode of the distribution for success probability `r`.
    let m = fm.floor();
    let p1 = (2.195 * nrq.sqrt() - 4.6 * q).floor() + 0.5;
    let xm = m + 0.5;
    let xl = xm - p1;
    let xr = xm + p1;
    let lambda = |x: f64| x * (1. + x / 2.);
    let c = 0.134 + 20.5 / (15.3 + m);
    // Decay rates of the left and right tails.
    let ll = lambda((fm - xl) / (fm - xl * r));
    let lr = lambda((xr - fm) / (xr * q));
    // Cumulative areas of the four regions; `u` is drawn on [0, p4].
    let p2 = p1 * (1. + 2. * c);
    let p3 = p2 + c / ll;
    let p4 = p3 + c / lr;

    // Constants of the recurrence f(i) / f(i - 1) = a / i - s. They must be built from the
    // folded probability `r`, not the raw `p`, otherwise a direct call with `p > 0.5` would
    // evaluate the ratio for the wrong distribution.
    let s = r / q;
    let a = s * (nf + 1.);

    // `true` unless `lhs > rhs`; an unordered comparison (NaN) also counts as "not above".
    let not_above = |lhs: f64, rhs: f64| {
        !matches!(lhs.partial_cmp(&rhs), Some(std::cmp::Ordering::Greater))
    };

    let mut y: f64;
    let ugen = Uniform::new(0., p4);
    let vgen = Uniform::new(0., 1.);
    loop {
        let u = ugen.sample();
        let mut v = vgen.sample();

        // Region 1: central triangle, accepted immediately.
        if not_above(u, p1) {
            y = (xm - p1 * v + u).floor();
            break;
        }

        if not_above(u, p2) {
            // Region 2: parallelograms on both sides of the triangle.
            let x = xl + (u - p1) / c;
            v = v * c + 1. - (m - x + 0.5).abs() / p1;
            if v > 1. {
                continue;
            }
            y = x.floor();
        } else if not_above(u, p3) {
            // Region 3: left exponential tail.
            y = (xl + v.ln() / ll).floor();
            if y < 0. {
                continue;
            }
            v *= (u - p2) * ll;
        } else {
            // Region 4: right exponential tail.
            y = (xr - v.ln() / lr).floor();
            if y > nf {
                continue;
            }
            v *= (u - p3) * lr;
        }

        // Acceptance test: compare `v` with the scaled probability ratio f(y) / f(m).
        let k = (y - m).abs();
        if !(k > 20. && k < 0.5 * (nrq) - 1.) {
            // Close to the mode or far out in the tails: evaluate the ratio exactly by
            // walking the recurrence between `m` and `y`.
            let mut f = 1.;
            if m < y {
                let mut i = m;
                loop {
                    i += 1.;
                    f *= (a / i) - s;
                    if (i - y).abs() < f64::EPSILON {
                        break;
                    }
                }
            } else if m > y {
                let mut i = y;
                loop {
                    i += 1.;
                    f /= (a / i) - s;
                    if (i - m).abs() < f64::EPSILON {
                        break;
                    }
                }
            }
            if v > f {
                continue;
            }
            break;
        }

        // Squeeze: cheap bounds on ln(f(y) / f(m)) settle most candidates.
        let rho = (k / nrq) * ((k * (k / 3. + 0.625) + 1. / 6.) / nrq + 0.5);
        let t = -k * k / (2. * nrq);
        let biga = v.ln();
        if biga < t - rho {
            break;
        }
        if biga > t + rho {
            continue;
        }

        // Final test: bound built from Stirling-series correction terms.
        let x1 = y + 1.;
        let f1 = m + 1.;
        let z = nf + 1. - m;
        let w = nf - y + 1.;
        let st = |x: f64| {
            (13860. - (462. - (132. - (99. - 140. / (x * x)) / (x * x)) / (x * x)) / (x * x))
                / x
                / 166320.
        };
        if biga
            > xm * (f1 / x1).ln()
                + (nf - m + 0.5) * (z / w).ln()
                + (y - m) * (w * r / (x1 * q)).ln()
                + st(f1)
                + st(z)
                + st(x1)
                + st(w)
        {
            continue;
        }
        break;
    }

    // The draw was made for `r = 1 - p`; mirror it back to the requested `p`.
    if p > 0.5 {
        y = nf - y;
    }
    y as u64
}
impl Discrete for Binomial {
    /// Probability mass at `k`: `C(n, k) * p^k * (1 - p)^(n - k)`.
    ///
    /// Returns `0` for any `k` outside the support `0..=n`, so negative or too-large `k`
    /// can neither wrap in the unsigned conversion nor underflow `n - k`.
    fn pmf(&self, k: i64) -> f64 {
        if k < 0 {
            return 0.;
        }
        let k = k as u64;
        if k > self.n {
            return 0.;
        }
        // `powi` takes an `i32`; use it whenever the exponent fits and fall back to `powf`
        // instead of letting a cast truncate a larger exponent.
        let pow = |base: f64, exp: u64| {
            if exp <= i32::MAX as u64 {
                base.powi(exp as i32)
            } else {
                base.powf(exp as f64)
            }
        };
        binom_coeff(self.n, k) as f64 * pow(self.p, k) * pow(1. - self.p, self.n - k)
    }
}
impl Mean for Binomial {
    type MeanType = f64;

    /// Expected number of successes, `n * p`.
    fn mean(&self) -> f64 {
        let trials = self.n as f64;
        trials * self.p
    }
}
impl Variance for Binomial {
    type VarianceType = f64;

    /// Variance of the number of successes, `n * p * (1 - p)`.
    fn var(&self) -> f64 {
        let expected_successes = self.n as f64 * self.p;
        let failure_prob = 1. - self.p;
        expected_successes * failure_prob
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::statistics::{mean, var};
    use approx_eq::assert_approx_eq;

    /// The empirical mean and variance of a large sample must match `n * p` and
    /// `n * p * (1 - p)`.
    #[test]
    fn test_moments() {
        const DRAWS: usize = 1_000_000;
        // (distribution, expected mean, expected variance).
        // The first case has n * p = 4.5 and is sampled by inversion; the second has
        // n * p = 35 and is sampled by BTPE.
        let cases = [
            (Binomial::new(15, 0.3), 4.5, 3.15),
            (Binomial::new(70, 0.5), 35., 17.5),
        ];
        for &(distr, expected_mean, expected_var) in cases.iter() {
            let draws = distr.sample_n(DRAWS);
            let sample_mean = mean(&draws);
            let sample_var = var(&draws);
            assert_approx_eq!(sample_mean, expected_mean, 1e-2);
            assert_approx_eq!(sample_var, expected_var, 1e-2);
        }
    }
}