use rand_distr::Distribution as Distribution2;
use rand_distr::Poisson as Poisson2;
use crate::distributions::Distribution;
use rand::Rng;
/// Poisson distribution over non-negative integer counts (`u64`).
///
/// Build instances through [`Poisson::new`], which validates the rate.
/// `Debug`, `Clone` and `PartialEq` are derived so values can be printed with
/// `{:?}`, duplicated, and compared in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct Poisson {
    /// Rate parameter of the distribution; the constructor only accepts
    /// values greater than zero.
    rate: f64,
}
impl Poisson {
    /// Creates a Poisson distribution with the given `rate`.
    ///
    /// # Errors
    ///
    /// Returns `Err` carrying a human-readable message when `rate` is `NaN`,
    /// zero, or negative.
    pub fn new(rate: f64) -> Result<Poisson, String> {
        // Every comparison with NaN is false, so `rate <= 0` alone would let
        // NaN through; it has to be rejected explicitly.
        if rate.is_nan() || rate <= 0f64 {
            Err(format!(
                "Poisson: illegal rate `{}` should be greater than 0",
                rate
            ))
        } else {
            Ok(Poisson { rate })
        }
    }
}
impl<R: Rng + ?Sized> Distribution<R> for Poisson {
    type Domain = u64;

    /// Draws one count from the distribution using `rng`.
    ///
    /// Sampling is delegated to `rand_distr::Poisson`; the sampled value is
    /// converted to `u64` with an `as` cast.
    ///
    /// # Panics
    ///
    /// Panics if `rand_distr::Poisson::new` rejects the stored rate.
    fn sample(&self, rng: &mut R) -> u64 {
        Poisson2::new(self.rate)
            .expect("Poisson: rate rejected by rand_distr::Poisson::new")
            .sample(rng) as u64
    }

    /// Returns `ln P(X = x) = x * ln(rate) - rate - ln(x!)`.
    ///
    /// `ln(x!)` is accumulated as `ln(1) + ln(2) + ... + ln(x)`, so the cost
    /// is linear in `x`.
    fn log_prob(&self, x: &u64) -> f64 {
        let k = *x as f64;
        let log_factorial: f64 = (1..=*x).map(|i| (i as f64).ln()).sum();
        k * self.rate.ln() - self.rate - log_factorial
    }

    /// Returns `ln P(X <= x)`, clamped so the result never exceeds `0`.
    ///
    /// The terms `ln P(X = k)` for `k = 0..=x` are generated in one pass
    /// (reusing the running `ln(k!)`), with the same arithmetic as
    /// `log_prob`, and combined with a running log-sum-exp. Working relative
    /// to the largest term seen so far keeps the result finite even when
    /// every individual probability would underflow to zero in linear space.
    fn log_cum_prob(&self, x: &u64) -> f64 {
        let ln_rate = self.rate.ln();
        let mut log_factorial = 0.0f64;
        // The partial sum is represented as `exp(pivot) * scaled_sum`, where
        // `pivot` is the largest log-term encountered so far.
        let mut pivot = f64::NEG_INFINITY;
        let mut scaled_sum = 0.0f64;
        for k in 0..=*x {
            if k > 0 {
                log_factorial += (k as f64).ln();
            }
            let log_term = (k as f64) * ln_rate - self.rate - log_factorial;
            if log_term > pivot {
                // New maximum: rescale what has been accumulated so far.
                scaled_sum = scaled_sum * (pivot - log_term).exp() + 1.0;
                pivot = log_term;
            } else {
                scaled_sum += (log_term - pivot).exp();
            }
        }
        // Rounding can push the total marginally above probability one.
        (pivot + scaled_sum.ln()).min(0.0)
    }

    /// The Poisson distribution is discrete.
    fn is_discrete(&self) -> bool {
        true
    }
}
impl std::fmt::Display for Poisson {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Poisson {{ rate = {} }}", self.rate)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributions::Distribution;
    use rand::thread_rng;

    /// The empirical mean of many draws must lie within five standard errors
    /// of the configured rate.
    #[test]
    fn poisson_sample() {
        let mut rng = thread_rng();
        let rate = 2.7f64;
        let dist = Poisson::new(rate).unwrap();
        println!("dist = {}", dist);

        let trials = 10000;
        let total: u64 = (0..trials).map(|_| dist.sample(&mut rng)).sum();

        let draws = trials as f64;
        let mean = (total as f64) / draws;
        let err = 5.0 * (rate / draws).sqrt();
        println!("empirical mean is {} 5sigma error is {}", mean, err);
        assert!((mean - 2.7).abs() < err);
    }

    /// Unwrapping the result of constructing with a negative rate panics.
    #[test]
    #[should_panic]
    fn poisson_too_low() {
        let _dist = Poisson::new(-0.01).unwrap();
    }

    /// Checks `log_prob`, `Display`, `is_discrete`, `log_cum_prob` and the
    /// rejection of a zero rate.
    #[test]
    fn poisson_log_prob_and_display() {
        // The trait is generic over the RNG type, so one has to be named even
        // for methods that never touch an RNG.
        type TestRng = rand::rngs::ThreadRng;

        let rate = 3.0f64;
        let dist = Poisson::new(rate).unwrap();
        let log_pmf = |k: u64| Distribution::<TestRng>::log_prob(&dist, &k);

        // k = 0: both the power term and the factorial term vanish.
        assert_eq!(log_pmf(0), -rate);

        // k = 1: ln(1!) = 0.
        let expected_lp1 = 1.0 * rate.ln() - rate - 0.0;
        assert!((log_pmf(1) - expected_lp1).abs() < 1e-12);

        // k = 5: compare against an independently accumulated ln(5!).
        let lp5 = log_pmf(5);
        let k = 5u64;
        let log_fact: f64 = (1..=k).map(|i| (i as f64).ln()).sum();
        let expected_lp5 = (k as f64) * rate.ln() - rate - log_fact;
        assert!((lp5 - expected_lp5).abs() < 1e-12);

        assert!(format!("{}", dist).contains("Poisson"));
        assert!(Distribution::<TestRng>::is_discrete(&dist));

        // The cumulative probability up to 2 equals the sum of the first
        // three point masses.
        let below = Distribution::<TestRng>::log_cum_prob(&dist, &2u64);
        let expected_below: f64 = (0..=2u64).map(|j| log_pmf(j).exp()).sum();
        assert!((below.exp() - expected_below).abs() < 1e-12);

        assert!(Poisson::new(0.0).is_err());
    }
}