use rand::Rng;
use rand_distr::Distribution as Distribution2;
use rand_distr::Gamma as Gamma2;
use rand_distr::Poisson;
use crate::distributions::Distribution;
/// Negative binomial distribution over the number of failures `k` observed
/// before the `r`-th success in independent trials with success probability `p`.
///
/// The probability mass is `Γ(k + r) / (Γ(r) k!) · p^r · (1 − p)^k`, and `r`
/// may be any positive real, not only an integer.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NegativeBinomial {
    /// Target number of successes (shape); finite and strictly positive.
    r: f64,
    /// Per-trial success probability, in `(0, 1]`.
    p: f64,
}

impl NegativeBinomial {
    /// Creates a negative binomial distribution with parameters `r` and `p`.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when `r` is not a finite number greater
    /// than 0 (NaN and infinities included), or when `p` is not in `(0, 1]`
    /// (NaN included).
    pub fn new(r: f64, p: f64) -> Result<NegativeBinomial, String> {
        // Phrased as a negated positive test: `is_finite` is false for NaN and
        // ±inf, so neither can slip through as a plain `r <= 0.0` check would.
        if !(r.is_finite() && r > 0.0) {
            Err(format!(
                "NegativeBinomial: illegal r `{}` should be finite and greater than 0",
                r
            ))
        } else if !(p > 0.0 && p <= 1.0) {
            // Negated form also rejects NaN, which fails both comparisons.
            Err(format!(
                "NegativeBinomial: illegal p `{}` must be in (0, 1]",
                p
            ))
        } else {
            Ok(NegativeBinomial { r, p })
        }
    }
}
impl<R: Rng + ?Sized> Distribution<R> for NegativeBinomial {
    type Domain = u64;

    /// Draws a failure count via the Gamma–Poisson mixture:
    /// `lambda ~ Gamma(shape = r, scale = (1 − p) / p)`, then `k ~ Poisson(lambda)`.
    fn sample(&self, rng: &mut R) -> u64 {
        // With p == 1 every trial succeeds, so no failures are ever observed
        // (and the Gamma scale below would be 0).
        if self.p == 1.0 {
            return 0;
        }
        let scale = (1.0 - self.p) / self.p;
        let lambda = Gamma2::new(self.r, scale)
            .expect("NegativeBinomial: invalid Gamma parameters")
            .sample(rng);
        // For small shapes `r` the Gamma draw can underflow to exactly 0.0.
        // `Poisson::new` requires a strictly positive rate, so this used to
        // panic; a rate-0 Poisson is the point mass at 0, so return that.
        if lambda <= 0.0 {
            return 0;
        }
        Poisson::new(lambda)
            .expect("NegativeBinomial: invalid Poisson rate")
            .sample(rng) as u64
    }

    /// Log probability mass at `x` failures:
    /// `ln Γ(x + r) − ln Γ(r) − ln Γ(x + 1) + r·ln p + x·ln(1 − p)`.
    fn log_prob(&self, x: &u64) -> f64 {
        if self.p == 1.0 {
            // Degenerate point mass at 0; the general formula would evaluate
            // 0 · ln(0) = NaN at x == 0.
            return if *x == 0 { 0.0 } else { f64::NEG_INFINITY };
        }
        let k = *x as f64;
        // `(-p).ln_1p()` equals ln(1 − p) but stays accurate when p is tiny,
        // where forming `1.0 - p` first would round away most of p's digits.
        libm::lgamma(k + self.r) - libm::lgamma(self.r) - libm::lgamma(k + 1.0)
            + self.r * self.p.ln()
            + k * (-self.p).ln_1p()
    }

    /// Always `true`: the support is the non-negative integers (`u64`).
    fn is_discrete(&self) -> bool {
        true
    }
}
impl std::fmt::Display for NegativeBinomial {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "NegativeBinomial {{ r = {}, p = {} }}", self.r, self.p)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::ThreadRng;
    use rand::thread_rng;

    /// Calls `log_prob` with the trait's RNG parameter pinned; `log_prob`
    /// takes no RNG argument, so `R` cannot be inferred at the call site.
    fn lp(dist: &NegativeBinomial, k: u64) -> f64 {
        <NegativeBinomial as Distribution<ThreadRng>>::log_prob(dist, &k)
    }

    #[test]
    fn negative_binomial_sample() {
        let mut rng = thread_rng();
        let r = 5.0f64;
        let p = 0.4f64;
        let dist = NegativeBinomial::new(r, p).unwrap();
        println!("dist = {}", dist);
        let trials = 100_000;
        let total: f64 = (0..trials).map(|_| dist.sample(&mut rng) as f64).sum();
        let empirical_mean = total / trials as f64;
        let expected_mean = r * (1.0 - p) / p;
        let variance = r * (1.0 - p) / (p * p);
        // Five standard errors keeps spurious failures of this random test rare.
        let tolerance = 5.0 * variance.sqrt() / (trials as f64).sqrt();
        assert!(
            (empirical_mean - expected_mean).abs() < tolerance,
            "empirical mean {} vs expected {} (tolerance {})",
            empirical_mean,
            expected_mean,
            tolerance
        );
    }

    #[test]
    fn negative_binomial_log_prob() {
        // r = 1 is geometric, so P(0) = p = 0.5.
        let dist = NegativeBinomial::new(1.0, 0.5).unwrap();
        assert!((lp(&dist, 0) - (-2.0f64.ln())).abs() < 1e-10);
        assert!(<NegativeBinomial as Distribution<ThreadRng>>::is_discrete(
            &dist
        ));

        // A nonzero count exercises the lgamma and ln(1 - p) terms:
        // P(1 | r = 2, p = 0.5) = C(2, 1) * 0.5^2 * 0.5 = 0.25.
        let dist = NegativeBinomial::new(2.0, 0.5).unwrap();
        assert!((lp(&dist, 1) - 0.25f64.ln()).abs() < 1e-10);
    }

    #[test]
    fn negative_binomial_log_prob_normalized() {
        // The pmf must sum to 1 and have mean r(1 - p)/p, the same mean the
        // sampler is checked against; the tail beyond k = 500 is negligible.
        let (r, p) = (3.5f64, 0.3f64);
        let dist = NegativeBinomial::new(r, p).unwrap();
        let (mass, mean) = (0..500u64).fold((0.0f64, 0.0f64), |(m, mu), k| {
            let pk = lp(&dist, k).exp();
            (m + pk, mu + k as f64 * pk)
        });
        assert!((mass - 1.0).abs() < 1e-9, "total mass {}", mass);
        assert!((mean - r * (1.0 - p) / p).abs() < 1e-8, "mean {}", mean);
    }

    #[test]
    fn negative_binomial_certain_success() {
        let dist = NegativeBinomial::new(2.0, 1.0).unwrap();
        assert_eq!(dist.sample(&mut thread_rng()), 0);
        assert_eq!(lp(&dist, 0), 0.0);
        assert_eq!(lp(&dist, 1), f64::NEG_INFINITY);
    }

    #[test]
    fn negative_binomial_display() {
        let dist = NegativeBinomial::new(5.0, 0.4).unwrap();
        let s = format!("{}", dist);
        assert_eq!(s, "NegativeBinomial { r = 5, p = 0.4 }");
    }

    #[test]
    fn negative_binomial_zero_r() {
        assert!(NegativeBinomial::new(0.0, 0.5).is_err());
        assert!(NegativeBinomial::new(-1.0, 0.5).is_err());
    }

    #[test]
    fn negative_binomial_invalid_p() {
        for &p in &[0.0, -0.1, 1.5, f64::NAN] {
            assert!(NegativeBinomial::new(1.0, p).is_err(), "p = {} accepted", p);
        }
    }
}