dnorm/
lib.rs

//! Sampling from the discrete normal distribution, following C. F. F. Karney,
//! "Sampling exactly from the normal distribution".
2use rand::distributions::Distribution;
3use rand::Rng;
4
/// `exp(-1/2)`, the success probability of the sampler's Bernoulli trial.
///
/// Written with enough digits to round to the nearest `f64`.
const EXP_MINUS_HALF: f64 = 0.606_530_659_712_633_4;
6
/// Discrete normal distribution over the integers.
///
/// Parameterised by a mean `mu` and a standard deviation `sigma`; samples are
/// drawn as `i32` values through the `Distribution<i32>` implementation.
///
/// # Examples
///
/// ```
/// use dnorm::DiscreteNormal;
/// use rand::distributions::Distribution;
///
/// let d = DiscreteNormal::new(0.0, 3.0);
/// let v = d.sample(&mut rand::thread_rng());
/// println!("{} is from a discrete N(0, 9) distribution", v)
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DiscreteNormal {
    /// Mean of the distribution.
    mu: f64,
    /// Standard deviation of the distribution.
    sigma: f64,
}
23
24impl DiscreteNormal {
25    pub fn new(mu: f64, sigma: f64) -> Self {
26        Self { mu, sigma }
27    }
28
29    #[inline]
30    fn bernoulli(&self) -> bool {
31        rand::thread_rng().gen::<f64>() < EXP_MINUS_HALF
32    }
33
34    #[inline]
35    fn g(&self) -> u32 {
36        let mut res: u32 = 0;
37        while self.bernoulli() {
38            res += 1;
39        }
40        res
41    }
42
43    #[inline]
44    fn s(&self) -> u32 {
45        loop {
46            let k = self.g() as i32;
47            if k < 2 {
48                return k as u32;
49            }
50            let mut z = k * (k - 1) - 1;
51            while z != 0 && self.bernoulli() {
52                z -= 1;
53            }
54            if z < 0 {
55                return k as u32;
56            };
57        }
58    }
59}
60
61impl Distribution<i32> for DiscreteNormal {
62    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> i32 {
63        loop {
64            let k = self.s();
65
66            let s = if rng.gen::<bool>() { 1 } else { -1 };
67
68            let mut xn0 = (k as f64) * self.sigma + (s as f64) * self.mu;
69            let i0 = xn0.ceil();
70            xn0 = (i0 - xn0) / self.sigma;
71            let j = rng.gen::<u32>() % (self.sigma.ceil() as u32);
72
73            let x = xn0 + (j as f64) / self.sigma;
74            if x < 1.0 && !(x == 0.0 && s < 0 && k == 0) {
75                xn0 = (-x * ((k << 1) as f64 + x) / 2.0).exp();
76                if x == 0.0 || rng.gen::<f64>() <= xn0 {
77                    return s * (i0 as u32 + j) as i32;
78                }
79            }
80        }
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use libm::erf;
    use std::collections::HashMap;

    /// Cumulative distribution function of the continuous normal
    /// distribution with mean `mu` and standard deviation `sigma`, at `x`.
    fn normcdf(x: f64, mu: f64, sigma: f64) -> f64 {
        0.5 * (1.0 + erf((x - mu) / (sigma * 2.0_f64.sqrt())))
    }

    /// Draws `n` samples and computes a chi-square statistic against
    /// reference frequencies.
    ///
    /// The reference probability of each observed value is the increase of
    /// the continuous normal CDF since the previous observed value (the
    /// smallest observed value takes the whole lower tail). That is only an
    /// approximation of the discrete distribution, so the statistic is
    /// printed rather than compared with a critical value; the test asserts
    /// only that the statistic is well defined.
    #[test]
    fn test_chisq() {
        let n: usize = 1000;
        let mu = 0.0;
        let sigma = 3.0;

        let dn = DiscreteNormal::new(mu, sigma);
        let mut rng = rand::thread_rng();

        // Histogram of observed values; the entry API needs one lookup per sample.
        let mut hist: HashMap<i32, u32> = HashMap::new();
        for _ in 0..n {
            *hist.entry(dn.sample(&mut rng)).or_insert(0) += 1;
        }

        let mut values: Vec<i32> = hist.keys().copied().collect();
        values.sort_unstable();

        let cdf: Vec<f64> = values
            .iter()
            .map(|&v| normcdf(f64::from(v), mu, sigma))
            .collect();

        // First bin: CDF at the smallest value; later bins: CDF increments.
        let expected = cdf
            .first()
            .copied()
            .into_iter()
            .chain(cdf.windows(2).map(|pair| pair[1] - pair[0]));

        let chisq: f64 = values
            .iter()
            .zip(expected)
            .map(|(v, p)| {
                let observed = f64::from(hist[v]);
                let e = p * n as f64;
                (observed - e).powi(2) / e
            })
            .sum();

        println!("chi-square statistic: {}", chisq);
        // A zero or non-finite reference frequency would make the statistic
        // infinite or NaN.
        assert!(
            chisq.is_finite() && chisq >= 0.0,
            "chi-square statistic is not well defined: {}",
            chisq
        );
    }
}