ferric/distributions/
poisson.rs

1// Copyright 2022 The Ferric AI Project Developers
2use rand_distr::Distribution as Distribution2;
3use rand_distr::Poisson as Poisson2;
4
5use crate::distributions::Distribution;
6use rand::Rng;
7
/// A Poisson distribution over non-negative integer counts.
///
/// Construct it with [`Poisson::new`], which validates the rate.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Poisson {
    /// Rate parameter (lambda): the expected number of events, which is
    /// also the mean of the distribution.
    rate: f64,
}
12
13impl Poisson {
14    pub fn new(rate: f64) -> Result<Poisson, String> {
15        if rate <= 0f64 {
16            Err(format! {"Poisson: illegal rate `{}` should be greater than 0", rate})
17        } else {
18            Ok(Poisson { rate })
19        }
20    }
21}
22
23impl<R: Rng + ?Sized> Distribution<R> for Poisson {
24    type Domain = u64;
25    fn sample(&self, rng: &mut R) -> u64 {
26        Poisson2::new(self.rate).unwrap().sample(rng) as u64
27    }
28}
29
30impl std::fmt::Display for Poisson {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        write!(f, "Poisson {{ rate = {} }}", self.rate)
33    }
34}
35
#[cfg(test)]
mod tests {
    use super::*;
    use rand::thread_rng;

    /// The empirical mean of many samples should land within five standard
    /// errors of the configured rate.
    #[test]
    fn poisson_sample() {
        let mut rng = thread_rng();
        let rate = 2.7f64;
        let dist = Poisson::new(rate).unwrap();
        println!("dist = {}", dist);
        let mut total = 0u64;
        let trials = 10000;
        for _ in 0..trials {
            total += dist.sample(&mut rng);
        }
        let mean = (total as f64) / (trials as f64);
        // A Poisson variable's variance equals its rate, so the standard
        // error of the sample mean is sqrt(rate / trials).
        let err = 5.0 * (rate / (trials as f64)).sqrt();
        println!("empirical mean is {} 5sigma error is {}", mean, err);
        // Compare against `rate` itself so the check follows the parameter.
        assert!((mean - rate).abs() < err);
    }

    /// Non-positive rates must be rejected by the constructor.
    #[test]
    fn poisson_too_low() {
        // Assert on the returned `Err` directly instead of relying on
        // `#[should_panic]`, which would also pass on an unrelated panic.
        assert!(Poisson::new(-0.01).is_err());
        assert!(Poisson::new(0.0).is_err());
    }
}