use rand_distr::Binomial as Binomial2;
use rand_distr::Distribution as Distribution2;
use crate::distributions::Distribution;
use rand::Rng;
/// Binomial distribution: the number of successes in `n` independent
/// Bernoulli trials, each succeeding with probability `p`.
///
/// Values are constructed through [`Binomial::new`], which validates the
/// parameters.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Binomial {
    /// Number of trials; at least 1 for values built by [`Binomial::new`].
    n: u64,
    /// Per-trial success probability; strictly inside `(0, 1)` for values
    /// built by [`Binomial::new`].
    p: f64,
}

impl Binomial {
    /// Creates a binomial distribution with `n` trials and per-trial success
    /// probability `p`.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message if `n == 0`, or if `p` does not lie in
    /// the open interval `(0, 1)` — this includes `p == 0.0`, `p == 1.0`, and
    /// `NaN`.
    pub fn new(n: u64, p: f64) -> Result<Binomial, String> {
        if n == 0 {
            return Err("Binomial: n must be at least 1".to_string());
        }
        // Negated range check on purpose: NaN fails every comparison, so it is
        // rejected here as well.
        if !(p > 0.0 && p < 1.0) {
            return Err(format!(
                "Binomial: illegal p `{}` should be in the open interval (0, 1)",
                p
            ));
        }
        Ok(Binomial { n, p })
    }
}
/// Sampling and log-probability for [`Binomial`].
///
/// The support is `{0, 1, ..., n}` with mass function
/// `P(K = k) = C(n, k) * p^k * (1 - p)^(n - k)`.
impl<R: Rng + ?Sized> Distribution<R> for Binomial {
    type Domain = u64;

    /// Draws one sample by delegating to `rand_distr::Binomial`.
    fn sample(&self, rng: &mut R) -> u64 {
        // The parameters are expected to have passed `Binomial::new`
        // (n >= 1, 0 < p < 1), so constructing the sampler should not fail.
        Binomial2::new(self.n, self.p)
            .expect("Binomial parameters are validated in Binomial::new")
            .sample(rng)
    }

    /// Returns `ln P(K = k)`, or negative infinity when `k > n` (outside the
    /// support).
    fn log_prob(&self, k: &u64) -> f64 {
        let k = *k;
        if k > self.n {
            return f64::NEG_INFINITY;
        }
        let k_f = k as f64;
        let n_f = self.n as f64;
        // Subtract as integers (k <= n holds here) so the failure count stays
        // exact even when n exceeds f64's 2^53 exact-integer range.
        let failures = (self.n - k) as f64;
        // ln C(n, k) via log-gamma, avoiding overflow of the factorials.
        let log_binom =
            libm::lgamma(n_f + 1.0) - libm::lgamma(k_f + 1.0) - libm::lgamma(failures + 1.0);
        // `ln_1p(-p)` computes ln(1 - p) without the cancellation that
        // `(1.0 - p).ln()` suffers for tiny p.
        log_binom + k_f * self.p.ln() + failures * (-self.p).ln_1p()
    }

    fn is_discrete(&self) -> bool {
        true
    }
}
impl std::fmt::Display for Binomial {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Binomial {{ n = {}, p = {} }}", self.n, self.p)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::ThreadRng;
    use rand::thread_rng;

    /// Evaluates `log_prob` through the trait, fixing the generic RNG
    /// parameter to a concrete type so the call is unambiguous.
    fn lp(dist: &Binomial, k: u64) -> f64 {
        <Binomial as Distribution<ThreadRng>>::log_prob(dist, &k)
    }

    #[test]
    fn binomial_sample() {
        let mut rng = thread_rng();
        let n = 20u64;
        let p = 0.4f64;
        let dist = Binomial::new(n, p).unwrap();
        println!("dist = {}", dist);
        let mut total = 0u64;
        let trials = 10000;
        for _ in 0..trials {
            total += dist.sample(&mut rng);
        }
        let empirical_mean = (total as f64) / (trials as f64);
        let expected_mean = (n as f64) * p;
        let expected_std = ((n as f64) * p * (1.0 - p)).sqrt();
        // 5 standard errors: a spurious failure is astronomically unlikely.
        let err = 5.0 * expected_std / (trials as f64).sqrt();
        assert!((empirical_mean - expected_mean).abs() < err);
    }

    #[test]
    fn binomial_log_prob() {
        let dist = Binomial::new(10, 0.5).unwrap();
        let expected = (252.0f64 / 1024.0).ln();
        assert!((lp(&dist, 5) - expected).abs() < 1e-10);
        assert_eq!(lp(&dist, 11), f64::NEG_INFINITY);
        assert!(<Binomial as Distribution<ThreadRng>>::is_discrete(&dist));
    }

    #[test]
    fn binomial_log_prob_asymmetric() {
        // With p != 0.5, swapping the roles of p and 1 - p changes the answer,
        // which the symmetric case above cannot detect.
        let dist = Binomial::new(4, 0.25).unwrap();
        let cases = [(0u64, 0.31640625f64), (1, 0.421875), (4, 0.00390625)];
        for &(k, prob) in &cases {
            assert!((lp(&dist, k) - prob.ln()).abs() < 1e-10, "k = {}", k);
        }
    }

    #[test]
    fn binomial_log_prob_normalized() {
        // The probabilities over the whole support {0, ..., n} must sum to 1.
        let dist = Binomial::new(30, 0.3).unwrap();
        let total: f64 = (0..=30).map(|k| lp(&dist, k).exp()).sum();
        assert!((total - 1.0).abs() < 1e-10);
    }

    #[test]
    fn binomial_zero_n() {
        assert_eq!(
            Binomial::new(0, 0.5).err().as_deref(),
            Some("Binomial: n must be at least 1")
        );
    }

    #[test]
    fn binomial_p_too_low() {
        assert!(Binomial::new(10, 0.0).is_err());
    }

    #[test]
    fn binomial_p_too_high() {
        assert!(Binomial::new(10, 1.0).is_err());
    }
}