probability/distribution/
binomial.rs

1use alloc::{vec, vec::Vec};
2#[allow(unused_imports)]
3use special::Primitive;
4
5use distribution;
6use source::Source;
7
/// A binomial distribution.
#[derive(Clone, Copy, Debug)]
pub struct Binomial {
    /// The number of trials.
    n: usize,
    /// The success probability.
    p: f64,
    /// The failure probability, `1 - p`.
    q: f64,
    /// Cached mean, `n * p`.
    np: f64,
    /// Cached `n * q`.
    nq: f64,
    /// Cached variance, `n * p * q`.
    npq: f64,
}
18
impl Binomial {
    /// Create a binomial distribution with `n` trials and success probability
    /// `p`.
    ///
    /// It should hold that `0 < p < 1`.
    pub fn new(n: usize, p: f64) -> Self {
        should!(0.0 < p && p < 1.0);
        let q = 1.0 - p;
        let np = n as f64 * p;
        let nq = n as f64 * q;
        Binomial {
            n,
            p,
            q,
            np,
            nq,
            npq: np * q,
        }
    }

    /// Create a binomial distribution with `n` trials and failure probability
    /// `q`.
    ///
    /// It should hold that `0 < q < 1`. This constructor is preferable when
    /// `q` is very small, since `q` is stored exactly as given instead of
    /// being recomputed as `1 - p` (which would round away tiny values).
    pub fn with_failure(n: usize, q: f64) -> Self {
        should!(0.0 < q && q < 1.0);
        let p = 1.0 - q;
        let np = n as f64 * p;
        let nq = n as f64 * q;
        Binomial {
            n,
            p,
            q,
            np,
            nq,
            npq: np * q,
        }
    }

    /// Return the number of trials.
    #[inline(always)]
    pub fn n(&self) -> usize {
        self.n
    }

    /// Return the success probability.
    #[inline(always)]
    pub fn p(&self) -> f64 {
        self.p
    }

    /// Return the failure probability.
    #[inline(always)]
    pub fn q(&self) -> f64 {
        self.q
    }
}
77
impl distribution::Discrete for Binomial {
    /// Compute the probability mass function.
    ///
    /// For large `n`, a saddle-point expansion is used for more accurate
    /// computation.
    ///
    /// ## References
    ///
    /// 1. C. Loader, “Fast and Accurate Computation of Binomial Probabilities,”
    ///    2000.
    fn mass(&self, x: usize) -> f64 {
        use core::f64::consts::PI;

        // Degenerate distributions put all mass on a single point. (The
        // constructors assert 0 < p < 1, but `should!` may be compiled out,
        // so these cases are handled defensively.)
        if self.p == 0.0 {
            return if x == 0 { 1.0 } else { 0.0 };
        }
        if self.p == 1.0 {
            return if x == self.n { 1.0 } else { 0.0 };
        }

        let n = self.n as f64;
        if x == 0 {
            // q^n, evaluated in log space to avoid premature underflow.
            (n * self.q.ln()).exp()
        } else if x == self.n {
            // p^n, evaluated in log space to avoid premature underflow.
            (n * self.p.ln()).exp()
        } else {
            // Saddle-point expansion [Loader, 2000]: ln(C(n, x) p^x q^(n-x))
            // is expressed via Stirling-error terms (`stirlerr`) and scaled
            // deviances (`ln_d0`), which avoids the catastrophic cancellation
            // a direct log-factorial evaluation suffers for large `n`.
            let x = x as f64;
            let n_m_x = n - x;
            let ln_c = stirlerr(n)
                - stirlerr(x)
                - stirlerr(n_m_x)
                - ln_d0(x, self.np)
                - ln_d0(n_m_x, self.nq);
            ln_c.exp() * (n / (2.0 * PI * x * (n_m_x))).sqrt()
        }
    }
}
115
116impl distribution::Distribution for Binomial {
117    type Value = usize;
118
119    /// Compute the cumulative distribution function.
120    ///
121    /// The implementation is based on the incomplete beta function.
122    fn distribution(&self, x: f64) -> f64 {
123        use special::Beta;
124        if x < 0.0 {
125            return 0.0;
126        }
127        let x = x as usize;
128        if x == 0 {
129            return self.q.powi(self.n as i32);
130        }
131        if x >= self.n {
132            return 1.0;
133        }
134        let (p, q) = ((self.n - x) as f64, (x + 1) as f64);
135        self.q.inc_beta(p, q, p.ln_beta(q))
136    }
137}
138
139impl distribution::Entropy for Binomial {
140    fn entropy(&self) -> f64 {
141        use core::f64::consts::PI;
142        use distribution::Discrete;
143
144        if self.n > 10000 && self.npq > 80.0 {
145            // Use a normal approximation.
146            0.5 * ((2.0 * PI * self.npq).ln() + 1.0)
147        } else {
148            -(0..(self.n + 1)).fold(0.0, |sum, i| sum + self.mass(i) * self.mass(i).ln())
149        }
150    }
151}
152
impl distribution::Inverse for Binomial {
    /// Compute the inverse of the cumulative distribution function.
    ///
    /// For small `n`, a simple summation is utilized. For large `n` and large
    /// variances, a normal asymptotic approximation is used. Otherwise,
    /// Newton’s method is employed.
    ///
    /// ## References
    ///
    /// 1. S. Moorhead, “Efficient evaluation of the inverse binomial cumulative
    ///    distribution function where the number of trials is large,” Oxford
    ///    University, 2013.
    fn inverse(&self, p: f64) -> usize {
        use distribution::{Discrete, Distribution, Modes};

        should!((0.0..=1.0).contains(&p));

        // Accumulate the PMF from k = 0 upward until the running CDF reaches
        // `p`. At the top of each iteration `a` holds the mass of k - 1; the
        // next mass is obtained via the recurrence P(k) = P(k - 1) * term(k).
        macro_rules! sum_bottom_up(
            ($prod_term: expr) => ({
                let mut k = 1;
                let mut a = self.q.powi(self.n as i32);
                let mut sum = a - p;
                while sum < 0.0 {
                    a *= $prod_term(k);
                    sum += a;
                    k += 1;
                }
                k - 1
            });
        );
        // Mirror image: accumulate the PMF from k = n downward until the
        // accumulated upper tail exceeds 1 - p; the answer is n - k + 1.
        macro_rules! sum_top_down(
            ($prod_term: expr) => ({
                let mut k = 1;
                let mut a = self.p.powi(self.n as i32);
                let mut sum = (1.0 - p) - a;
                while sum >= 0.0 {
                    a *= $prod_term(k);
                    sum -= a;
                    k += 1;
                }
                self.n - k + 1
            });
        );

        if p == 0.0 {
            0
        } else if p == 1.0 {
            self.n
        } else if self.n < 1000 {
            // Find if top-down or bottom-up summation is better, i.e., which
            // end of the support the answer is closer to.
            if p <= self.distribution((self.n / 2) as f64) {
                sum_bottom_up!(|k| self.p / self.q * ((self.n - k + 1) as f64 / k as f64))
            } else {
                sum_top_down!(|k| self.q / self.p * ((self.n - k + 1) as f64 / k as f64))
            }
        } else if self.npq > 80.0 {
            // Use a normal approximation.
            inverse_normal(self.p, self.np, self.npq, p).floor() as usize
        } else {
            // Use Newton’s method starting at the mode. Steps are damped by
            // `alpha`, which decays geometrically to encourage convergence;
            // iteration stops once a step can no longer change the truncated
            // result.
            const ALPHA: f64 = 0.999;
            let mut q = self.modes()[0] as f64;
            let mut alpha = 1.0;
            loop {
                let delta = alpha * (p - self.distribution(q)) / self.mass(q as usize);
                if delta.abs() < 0.5 {
                    return q as usize;
                }
                q += delta;
                alpha *= ALPHA;
            }
        }
    }
}
227
228impl distribution::Kurtosis for Binomial {
229    #[inline]
230    fn kurtosis(&self) -> f64 {
231        (1.0 - 6.0 * self.p * self.q) / self.npq
232    }
233}
234
impl distribution::Mean for Binomial {
    /// Return the mean, `n p`, cached at construction.
    #[inline]
    fn mean(&self) -> f64 {
        self.np
    }
}
241
impl distribution::Median for Binomial {
    fn median(&self) -> f64 {
        use core::f64::consts::LN_2;
        use distribution::Inverse;

        if (self.np - self.np.trunc()) == 0.0 || (self.p == 0.5 && self.n % 2 != 0) {
            // When np is an integer it is the median; for p = 1/2 with odd n,
            // the median is conventionally taken at n / 2, which equals np.
            self.np
        } else if self.p <= 1.0 - LN_2
            || self.p >= LN_2
            || (self.np.round() - self.np).abs() <= self.p.min(self.q)
        {
            // Cases where round(np) is known to be a median
            // — NOTE(review): presumably per Kaas & Buhrman (1980); confirm.
            self.np.round()
        } else if self.n > 1000 && self.npq > 80.0 {
            // Use a normal approximation, whose median is the mean.
            self.np.floor()
        } else {
            // Fall back on inverting the CDF at 1/2.
            self.inverse(0.5) as f64
        }
    }
}
262
263impl distribution::Modes for Binomial {
264    fn modes(&self) -> Vec<usize> {
265        let r = self.p * (self.n + 1) as f64;
266        if r == 0.0 {
267            vec![0]
268        } else if self.p == 1.0 {
269            vec![self.n]
270        } else if (r - r.trunc()) != 0.0 {
271            vec![r.floor() as usize]
272        } else {
273            vec![r as usize - 1, r as usize]
274        }
275    }
276}
277
278impl distribution::Sample for Binomial {
279    #[inline]
280    fn sample<S>(&self, source: &mut S) -> usize
281    where
282        S: Source,
283    {
284        use distribution::Inverse;
285        self.inverse(source.read::<f64>())
286    }
287}
288
289impl distribution::Skewness for Binomial {
290    #[inline]
291    fn skewness(&self) -> f64 {
292        (1.0 - 2.0 * self.p) / self.npq.sqrt()
293    }
294}
295
impl distribution::Variance for Binomial {
    /// Return the variance, `n p q`, cached at construction.
    #[inline]
    fn variance(&self) -> f64 {
        self.npq
    }
}
302
303// See [Moorhead, 2013, pp. 7].
304#[rustfmt::skip]
305fn inverse_normal(p: f64, np: f64, v: f64, u: f64) -> f64 {
306    use distribution::gaussian;
307
308    let w = gaussian::inverse(u);
309    let w2 = w * w;
310    let w3 = w2 * w;
311    let w4 = w3 * w;
312    let w5 = w4 * w;
313    let w6 = w5 * w;
314    let sd = v.sqrt();
315    let sd_em1 = sd.recip();
316    let sd_em2 = v.recip();
317    let sd_em3 = sd_em1 * sd_em2;
318    let sd_em4 = sd_em2 * sd_em2;
319    let p2 = p * p;
320    let p3 = p2 * p;
321    let p4 = p2 * p2;
322
323    np +
324    sd * w +
325    (p + 1.0) / 3.0 -
326    (2.0 * p - 1.0) * w2 / 6.0 +
327    sd_em1 * w3 * (2.0 * p2 - 2.0 * p - 1.0) / 72.0 -
328    w * (7.0 * p2 - 7.0 * p + 1.0) / 36.0 +
329    sd_em2 * (2.0 * p - 1.0) * (p + 1.0) * (p - 2.0) * (3.0 * w4 + 7.0 * w2 - 16.0 / 1620.0) +
330    sd_em3 * (
331        w5 * (4.0 * p4 - 8.0 * p3 - 48.0 * p2 + 52.0 * p - 23.0) / 17280.0 +
332        w3 * (256.0 * p4 - 512.0 * p3 - 147.0 * p2 + 403.0 * p - 137.0) / 38880.0 -
333        w * (433.0 * p4 - 866.0 * p3 - 921.0 * p2 + 1354.0 * p - 671.0) / 38880.0
334    ) +
335    sd_em4 * (
336        w6 * (2.0 * p - 1.0) * (p2 - p + 1.0) * (p2 - p + 19.0) / 34020.0 +
337        w4 * (2.0 * p - 1.0) * (9.0 * p4 - 18.0 * p3 - 35.0 * p2 + 44.0 * p - 25.0) / 15120.0 +
338        w2 * (2.0 * p - 1.0) * (
339                923.0 * p4 - 1846.0 * p3 + 5271.0 * p2 - 4348.0 * p + 5189.0
340        ) / 408240.0 -
341        4.0 * (2.0 * p - 1.0) * (p + 1.0) * (p - 2.0) * (23.0 * p2 - 23.0 * p + 2.0) / 25515.0
342    )
343    // + O(v.powf(-2.5)), with probabilty of 1 - 2e-9
344}
345
// ln(np * D₀) = x * ln(x / np) + np - x
//
// This is the scaled deviance term of the saddle-point expansion
// [Loader, 2000].
fn ln_d0(x: f64, np: f64) -> f64 {
    if (x - np).abs() < 0.1 * (x + np) {
        // `x / np` is close to one, and the direct formula below suffers from
        // cancellation; evaluate a convergent series in v = (x - np)/(x + np)
        // instead, stopping once adding a term no longer changes the sum.
        let v = (x - np) / (x + np);
        let mut sum = (x - np) * (x - np) / (x + np);
        let mut term = 2.0 * x * v;
        let mut j = 1u32;
        loop {
            term *= v * v;
            let next = sum + term / (2 * j + 1) as f64;
            if next == sum {
                return next;
            }
            sum = next;
            j += 1;
        }
    }
    x * (x / np).ln() + np - x
}
366
// stirlerr(n) = ln(n!) - ln(sqrt(2π * n) * (n / e)^n)
//
// The error of the Stirling approximation to ln(n!); tabulated for small
// integers and otherwise evaluated from a truncated asymptotic series.
fn stirlerr(n: f64) -> f64 {
    const S0: f64 = 1.0 / 12.0;
    const S1: f64 = 1.0 / 360.0;
    const S2: f64 = 1.0 / 1260.0;
    const S3: f64 = 1.0 / 1680.0;
    const S4: f64 = 1.0 / 1188.0;

    // See [Loader, 2000, pp. 7].
    #[allow(clippy::excessive_precision)]
    const SFE: [f64; 16] = [
        0.000000000000000000e+00,
        8.106146679532725822e-02,
        4.134069595540929409e-02,
        2.767792568499833915e-02,
        2.079067210376509311e-02,
        1.664469118982119216e-02,
        1.387612882307074800e-02,
        1.189670994589177010e-02,
        1.041126526197209650e-02,
        9.255462182712732918e-03,
        8.330563433362871256e-03,
        7.757367548795184079e-03,
        6.942840107209529866e-03,
        6.408994188004207068e-03,
        5.951370112758847736e-03,
        5.554733551962801371e-03,
    ];

    // Small whole arguments are read straight from the table.
    if n < 16.0 {
        return SFE[n as usize];
    }

    // See [Loader, 2000, eq. 4]: the larger n is, the fewer series terms are
    // needed for full double precision.
    let n2 = n * n;
    if n > 500.0 {
        return (S0 - S1 / n2) / n;
    }
    if n > 80.0 {
        return (S0 - (S1 - S2 / n2) / n2) / n;
    }
    if n > 35.0 {
        return (S0 - (S1 - (S2 - S3 / n2) / n2) / n2) / n;
    }
    (S0 - (S1 - (S2 - (S3 - S4 / n2) / n2) / n2) / n2) / n
}
412
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};
    use assert;
    use prelude::*;

    // Shorthand for constructing the distribution under test.
    macro_rules! new {
        ($n:expr, $p:expr) => {
            Binomial::new($n, $p)
        };
    }

    #[test]
    fn distribution() {
        let d = new!(16, 0.75);
        // Reference CDF values at x = -2, 0, 2, …, 16.
        let p = vec![
            0.000000000000000e+00,
            2.328306436538699e-10,
            2.628657966852194e-07,
            3.810715861618527e-05,
            1.644465373829007e-03,
            2.712995628826319e-02,
            1.896545726340262e-01,
            5.950128899421541e-01,
            9.365235602017492e-01,
            1.000000000000000e+00,
        ];

        // Evaluate the CDF at the integers themselves…
        let x = (-1..9)
            .map(|i| d.distribution(2.0 * i as f64))
            .collect::<Vec<_>>();
        assert::close(&x, &p, 1e-14);

        // …and between integers, where the step function must stay flat.
        let x = (-1..9)
            .map(|i| d.distribution(2.0 * i as f64 + 0.5))
            .collect::<Vec<_>>();
        assert::close(&x, &p, 1e-14);
    }

    #[test]
    fn entropy() {
        assert_eq!(new!(16, 0.25).entropy(), 1.9588018945068573);
        // Large n with large variance takes the normal-approximation branch.
        assert_eq!(new!(10_000_000, 0.5).entropy(), 8.784839178123887);
    }

    #[test]
    fn inverse() {
        // Check edge cases.
        let d = new!(10, 0.5);
        assert_eq!(d.inverse(0.0), 0);
        assert_eq!(d.inverse(1.0), 10);

        // Check the summation.
        let d = new!(250, 0.55);
        assert_eq!(d.inverse(0.025), 122);
        assert_eq!(d.inverse(0.1), 127);

        // Check the normal approximation.
        let d = new!(2500, 0.55);
        assert_eq!(d.inverse(d.distribution(1298.0)), 1298);
        assert_eq!(new!(1001, 0.25).inverse(0.5), 250);
        assert_eq!(new!(1500, 0.15).inverse(0.2), 213);

        // Check Newton’s method.
        assert_eq!(new!(1_000_000, 2.5e-5).inverse(0.9995), 42);
        assert_eq!(new!(1_000_000_000, 6.66e-9).inverse(0.8), 8);
    }

    #[test]
    fn inverse_convergence() {
        // Regression cases for the damped Newton iteration.
        let d = new!(1024, 0.009765625);
        assert_eq!(d.inverse(0.32185663510619567), 8);

        let d = new!(3666, 0.9810204628647335);
        assert_eq!(d.inverse(0.0033333333333332993), 3573);
    }

    #[test]
    fn kurtosis() {
        assert_eq!(new!(16, 0.25).kurtosis(), -0.041666666666666664);
    }

    #[test]
    fn mass() {
        let d = new!(16, 0.25);
        // Reference PMF values at the even points x = 0, 2, …, 16.
        let p = vec![
            1.002259575761855e-02,
            1.336346101015806e-01,
            2.251990651711821e-01,
            1.100973207503558e-01,
            1.966023584827779e-02,
            1.359226182103156e-03,
            3.432389348745344e-05,
            2.514570951461788e-07,
            2.328306436538698e-10,
        ];

        assert::close(
            &(0..9).map(|i| d.mass(2 * i)).collect::<Vec<_>>(),
            &p,
            1e-14,
        );
    }

    #[test]
    fn mean() {
        assert_eq!(new!(16, 0.25).mean(), 4.0);
    }

    #[test]
    fn median() {
        assert_eq!(new!(16, 0.25).median(), 4.0);
        assert_eq!(new!(3, 0.5).median(), 1.5);
        assert_eq!(new!(1000, 0.015).median(), 15.0);
        assert_eq!(new!(39, 0.1).median(), 4.0);
    }

    #[test]
    fn modes() {
        assert_eq!(new!(16, 0.25).modes(), vec![4]);
        assert_eq!(new!(3, 0.5).modes(), vec![1, 2]);
        assert_eq!(new!(1000, 0.015).modes(), vec![15]);
        assert_eq!(new!(39, 0.1).modes(), vec![3, 4]);
    }

    #[test]
    fn skewness() {
        assert_eq!(new!(16, 0.25).skewness(), 0.2886751345948129);
    }

    #[test]
    fn variance() {
        assert_eq!(new!(16, 0.25).variance(), 3.0);
    }
}