// probability/distribution/binomial.rs
use alloc::{vec, vec::Vec};
2#[allow(unused_imports)]
3use special::Primitive;
4
5use distribution;
6use source::Source;
7
/// A binomial distribution with `n` trials and success probability `p`.
#[derive(Clone, Copy, Debug)]
pub struct Binomial {
    // Number of trials.
    n: usize,
    // Success probability.
    p: f64,
    // Failure probability, `1 - p`.
    q: f64,
    // Cached mean, `n * p`.
    np: f64,
    // Cached `n * q`.
    nq: f64,
    // Cached variance, `n * p * q`.
    npq: f64,
}
18
19impl Binomial {
20 pub fn new(n: usize, p: f64) -> Self {
25 should!(0.0 < p && p < 1.0);
26 let q = 1.0 - p;
27 let np = n as f64 * p;
28 let nq = n as f64 * q;
29 Binomial {
30 n,
31 p,
32 q,
33 np,
34 nq,
35 npq: np * q,
36 }
37 }
38
39 pub fn with_failure(n: usize, q: f64) -> Self {
45 should!(0.0 < q && q < 1.0);
46 let p = 1.0 - q;
47 let np = n as f64 * p;
48 let nq = n as f64 * q;
49 Binomial {
50 n,
51 p,
52 q,
53 np,
54 nq,
55 npq: np * q,
56 }
57 }
58
59 #[inline(always)]
61 pub fn n(&self) -> usize {
62 self.n
63 }
64
65 #[inline(always)]
67 pub fn p(&self) -> f64 {
68 self.p
69 }
70
71 #[inline(always)]
73 pub fn q(&self) -> f64 {
74 self.q
75 }
76}
77
impl distribution::Discrete for Binomial {
    /// Compute the probability mass function at point `x`, i.e.,
    /// `C(n, x) p^x q^(n - x)`.
    ///
    /// The general case is evaluated in the style of Loader's `dbinom`:
    /// the logarithm of the mass is expressed via Stirling errors
    /// (`stirlerr`) and deviance terms (`ln_d0`), which avoids overflowing
    /// the binomial coefficient for large `n`.
    fn mass(&self, x: usize) -> f64 {
        use core::f64::consts::PI;

        // Degenerate endpoints: all mass at 0 or at n. (The constructors
        // reject p outside (0, 1), but mass itself tolerates the endpoints.)
        if self.p == 0.0 {
            return if x == 0 { 1.0 } else { 0.0 };
        }
        if self.p == 1.0 {
            return if x == self.n { 1.0 } else { 0.0 };
        }

        let n = self.n as f64;
        if x == 0 {
            // P(0) = q^n, computed as exp(n ln q) for stability.
            (n * self.q.ln()).exp()
        } else if x == self.n {
            // P(n) = p^n, computed as exp(n ln p).
            (n * self.p.ln()).exp()
        } else {
            let x = x as f64;
            let n_m_x = n - x;
            // Log of the corrected part of the mass; the remaining factor
            // sqrt(n / (2 π x (n - x))) comes from Stirling's formula.
            let ln_c = stirlerr(n)
                - stirlerr(x)
                - stirlerr(n_m_x)
                - ln_d0(x, self.np)
                - ln_d0(n_m_x, self.nq);
            ln_c.exp() * (n / (2.0 * PI * x * (n_m_x))).sqrt()
        }
    }
}
115
impl distribution::Distribution for Binomial {
    type Value = usize;

    /// Compute the cumulative distribution function, `P(X <= x)`.
    fn distribution(&self, x: f64) -> f64 {
        use special::Beta;
        if x < 0.0 {
            return 0.0;
        }
        // The CDF is a step function; truncate to the integer part.
        let x = x as usize;
        if x == 0 {
            // P(X <= 0) = q^n.
            return self.q.powi(self.n as i32);
        }
        if x >= self.n {
            return 1.0;
        }
        // P(X <= x) = I_q(n - x, x + 1): the regularized incomplete beta
        // function evaluated at q, with the log-beta term precomputed as
        // `inc_beta` expects.
        let (p, q) = ((self.n - x) as f64, (x + 1) as f64);
        self.q.inc_beta(p, q, p.ln_beta(q))
    }
}
138
139impl distribution::Entropy for Binomial {
140 fn entropy(&self) -> f64 {
141 use core::f64::consts::PI;
142 use distribution::Discrete;
143
144 if self.n > 10000 && self.npq > 80.0 {
145 0.5 * ((2.0 * PI * self.npq).ln() + 1.0)
147 } else {
148 -(0..(self.n + 1)).fold(0.0, |sum, i| sum + self.mass(i) * self.mass(i).ln())
149 }
150 }
151}
152
impl distribution::Inverse for Binomial {
    /// Compute the quantile function: the smallest `x` with `P(X <= x) >= p`.
    ///
    /// Three strategies are used depending on the parameters: direct
    /// summation of masses for small `n`, a normal-based approximation when
    /// the variance is large, and a damped Newton-style search otherwise.
    fn inverse(&self, p: f64) -> usize {
        use distribution::{Discrete, Distribution, Modes};

        should!((0.0..=1.0).contains(&p));

        // Accumulate masses upward from 0 until the running CDF reaches `p`.
        // Each mass is derived from the previous one via the recurrence
        // P(k) = P(k - 1) * (p / q) * (n - k + 1) / k supplied as `$prod_term`.
        macro_rules! sum_bottom_up(
            ($prod_term: expr) => ({
                let mut k = 1;
                let mut a = self.q.powi(self.n as i32);
                let mut sum = a - p;
                while sum < 0.0 {
                    a *= $prod_term(k);
                    sum += a;
                    k += 1;
                }
                k - 1
            });
        );
        // Mirror image: accumulate masses downward from n against `1 - p`.
        macro_rules! sum_top_down(
            ($prod_term: expr) => ({
                let mut k = 1;
                let mut a = self.p.powi(self.n as i32);
                let mut sum = (1.0 - p) - a;
                while sum >= 0.0 {
                    a *= $prod_term(k);
                    sum -= a;
                    k += 1;
                }
                self.n - k + 1
            });
        );

        if p == 0.0 {
            0
        } else if p == 1.0 {
            self.n
        } else if self.n < 1000 {
            // Small n: exact summation from whichever tail is closer.
            if p <= self.distribution((self.n / 2) as f64) {
                sum_bottom_up!(|k| self.p / self.q * ((self.n - k + 1) as f64 / k as f64))
            } else {
                sum_top_down!(|k| self.q / self.p * ((self.n - k + 1) as f64 / k as f64))
            }
        } else if self.npq > 80.0 {
            // Large variance: normal-expansion approximation of the quantile.
            inverse_normal(self.p, self.np, self.npq, p).floor() as usize
        } else {
            // Damped Newton-like iteration on the CDF, started at a mode;
            // alpha decays by ALPHA each step to keep the search from
            // oscillating, and the loop stops once the step is below 0.5.
            const ALPHA: f64 = 0.999;
            let mut q = self.modes()[0] as f64;
            let mut alpha = 1.0;
            loop {
                let delta = alpha * (p - self.distribution(q)) / self.mass(q as usize);
                if delta.abs() < 0.5 {
                    return q as usize;
                }
                q += delta;
                alpha *= ALPHA;
            }
        }
    }
}
227
impl distribution::Kurtosis for Binomial {
    /// Compute the excess kurtosis, `(1 - 6 p q) / (n p q)`.
    #[inline]
    fn kurtosis(&self) -> f64 {
        (1.0 - 6.0 * self.p * self.q) / self.npq
    }
}
234
impl distribution::Mean for Binomial {
    /// Compute the mean, `n p`.
    #[inline]
    fn mean(&self) -> f64 {
        self.np
    }
}
241
impl distribution::Median for Binomial {
    /// Compute the median.
    ///
    /// Closed-form cases where the median equals `np` or `round(np)` are
    /// handled first; otherwise the CDF is inverted at 0.5.
    fn median(&self) -> f64 {
        use core::f64::consts::LN_2;
        use distribution::Inverse;

        if (self.np - self.np.trunc()) == 0.0 || (self.p == 0.5 && self.n % 2 != 0) {
            // np is an integer, or p = 1/2 with odd n: the median is np.
            self.np
        } else if self.p <= 1.0 - LN_2
            || self.p >= LN_2
            || (self.np.round() - self.np).abs() <= self.p.min(self.q)
        {
            // Conditions under which the median equals round(np)
            // (presumably the Kaas–Buhrman bounds — TODO confirm citation).
            self.np.round()
        } else if self.n > 1000 && self.npq > 80.0 {
            // Large n and variance: take floor(np); the median lies within
            // one unit of np, so this is used as the approximation here.
            self.np.floor()
        } else {
            self.inverse(0.5) as f64
        }
    }
}
262
263impl distribution::Modes for Binomial {
264 fn modes(&self) -> Vec<usize> {
265 let r = self.p * (self.n + 1) as f64;
266 if r == 0.0 {
267 vec![0]
268 } else if self.p == 1.0 {
269 vec![self.n]
270 } else if (r - r.trunc()) != 0.0 {
271 vec![r.floor() as usize]
272 } else {
273 vec![r as usize - 1, r as usize]
274 }
275 }
276}
277
278impl distribution::Sample for Binomial {
279 #[inline]
280 fn sample<S>(&self, source: &mut S) -> usize
281 where
282 S: Source,
283 {
284 use distribution::Inverse;
285 self.inverse(source.read::<f64>())
286 }
287}
288
impl distribution::Skewness for Binomial {
    /// Compute the skewness, `(1 - 2 p) / sqrt(n p q)`.
    #[inline]
    fn skewness(&self) -> f64 {
        (1.0 - 2.0 * self.p) / self.npq.sqrt()
    }
}
295
impl distribution::Variance for Binomial {
    /// Compute the variance, `n p q`.
    #[inline]
    fn variance(&self) -> f64 {
        self.npq
    }
}
302
/// Approximate the binomial quantile via a normal (Cornish–Fisher-type)
/// expansion: the Gaussian quantile `w` of the level `u` is corrected by a
/// series in powers of `1 / sqrt(v)` with polynomial coefficients in `p`
/// and `w`.
///
/// `p` is the success probability, `np` the mean, `v` the variance `n p q`,
/// and `u` the probability level whose quantile is sought.
#[rustfmt::skip]
fn inverse_normal(p: f64, np: f64, v: f64, u: f64) -> f64 {
    use distribution::gaussian;

    let w = gaussian::inverse(u);
    let w2 = w * w;
    let w3 = w2 * w;
    let w4 = w3 * w;
    let w5 = w4 * w;
    let w6 = w5 * w;
    let sd = v.sqrt();
    let sd_em1 = sd.recip();      // v^(-1/2)
    let sd_em2 = v.recip();       // v^(-1)
    let sd_em3 = sd_em1 * sd_em2; // v^(-3/2)
    let sd_em4 = sd_em2 * sd_em2; // v^(-2)
    let p2 = p * p;
    let p3 = p2 * p;
    let p4 = p2 * p2;

    np +
    sd * w +
    (p + 1.0) / 3.0 -
    (2.0 * p - 1.0) * w2 / 6.0 +
    sd_em1 * w3 * (2.0 * p2 - 2.0 * p - 1.0) / 72.0 -
    w * (7.0 * p2 - 7.0 * p + 1.0) / 36.0 +
    // NOTE(review): every sibling term divides the whole polynomial by its
    // constant, but here "16.0 / 1620.0" is divided inside the parentheses.
    // Possibly "(3.0 * w4 + 7.0 * w2 - 16.0) / 1620.0" was intended —
    // confirm against the reference expansion before changing, since the
    // crate's inverse() regression values depend on this branch.
    sd_em2 * (2.0 * p - 1.0) * (p + 1.0) * (p - 2.0) * (3.0 * w4 + 7.0 * w2 - 16.0 / 1620.0) +
    sd_em3 * (
        w5 * (4.0 * p4 - 8.0 * p3 - 48.0 * p2 + 52.0 * p - 23.0) / 17280.0 +
        w3 * (256.0 * p4 - 512.0 * p3 - 147.0 * p2 + 403.0 * p - 137.0) / 38880.0 -
        w * (433.0 * p4 - 866.0 * p3 - 921.0 * p2 + 1354.0 * p - 671.0) / 38880.0
    ) +
    sd_em4 * (
        w6 * (2.0 * p - 1.0) * (p2 - p + 1.0) * (p2 - p + 19.0) / 34020.0 +
        w4 * (2.0 * p - 1.0) * (9.0 * p4 - 18.0 * p3 - 35.0 * p2 + 44.0 * p - 25.0) / 15120.0 +
        w2 * (2.0 * p - 1.0) * (
            923.0 * p4 - 1846.0 * p3 + 5271.0 * p2 - 4348.0 * p + 5189.0
        ) / 408240.0 -
        4.0 * (2.0 * p - 1.0) * (p + 1.0) * (p - 2.0) * (23.0 * p2 - 23.0 * p + 2.0) / 25515.0
    )
}
345
/// Evaluate the deviance term `x ln(x / np) + np - x` carefully.
///
/// When `x` and `np` are close, the direct formula suffers cancellation, so
/// the value is computed from a series in `v = (x - np) / (x + np)` instead,
/// summed until it stops changing in f64 arithmetic.
fn ln_d0(x: f64, np: f64) -> f64 {
    let diff = x - np;
    let total = x + np;

    if diff.abs() < 0.1 * total {
        // Series branch: `s` holds the partial sums and `term` carries
        // 2 x v^(2j + 1).
        let v = diff / total;
        let v2 = v * v;
        let mut s = diff.powi(2) / total;
        let mut term = 2.0 * x * v;
        let mut j = 1;
        loop {
            term *= v2;
            let next = s + term / (2 * j + 1) as f64;
            if next == s {
                return next;
            }
            s = next;
            j += 1;
        }
    }

    // The arguments are far apart: the direct formula is stable.
    x * (x / np).ln() + np - x
}
366
/// Compute the Stirling error, the correction term in Stirling's formula
/// for `ln(n!)`.
///
/// Arguments below 16 are looked up in a precomputed table; larger arguments
/// use a truncated asymptotic series in `1 / n²` whose length shrinks as `n`
/// grows, since fewer terms are needed for accuracy.
fn stirlerr(n: f64) -> f64 {
    // Coefficients of the asymptotic series.
    const S0: f64 = 1.0 / 12.0;
    const S1: f64 = 1.0 / 360.0;
    const S2: f64 = 1.0 / 1260.0;
    const S3: f64 = 1.0 / 1680.0;
    const S4: f64 = 1.0 / 1188.0;

    // Precomputed values for arguments 0 through 15.
    #[allow(clippy::excessive_precision)]
    const SFE: [f64; 16] = [
        0.000000000000000000e+00,
        8.106146679532725822e-02,
        4.134069595540929409e-02,
        2.767792568499833915e-02,
        2.079067210376509311e-02,
        1.664469118982119216e-02,
        1.387612882307074800e-02,
        1.189670994589177010e-02,
        1.041126526197209650e-02,
        9.255462182712732918e-03,
        8.330563433362871256e-03,
        7.757367548795184079e-03,
        6.942840107209529866e-03,
        6.408994188004207068e-03,
        5.951370112758847736e-03,
        5.554733551962801371e-03,
    ];

    if n < 16.0 {
        return SFE[n as usize];
    }

    let nn = n * n;
    if n > 500.0 {
        return (S0 - S1 / nn) / n;
    }
    if n > 80.0 {
        return (S0 - (S1 - S2 / nn) / nn) / n;
    }
    if n > 35.0 {
        return (S0 - (S1 - (S2 - S3 / nn) / nn) / nn) / n;
    }
    (S0 - (S1 - (S2 - (S3 - S4 / nn) / nn) / nn) / nn) / n
}
412
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};
    use assert;
    use prelude::*;

    // Shorthand for constructing the distribution under test.
    macro_rules! new {
        ($n:expr, $p:expr) => {
            Binomial::new($n, $p)
        };
    }

    #[test]
    fn distribution() {
        let d = new!(16, 0.75);
        let p = vec![
            0.000000000000000e+00,
            2.328306436538699e-10,
            2.628657966852194e-07,
            3.810715861618527e-05,
            1.644465373829007e-03,
            2.712995628826319e-02,
            1.896545726340262e-01,
            5.950128899421541e-01,
            9.365235602017492e-01,
            1.000000000000000e+00,
        ];

        // CDF at even integers, including -2 (below the support, where it
        // must be 0).
        let x = (-1..9)
            .map(|i| d.distribution(2.0 * i as f64))
            .collect::<Vec<_>>();
        assert::close(&x, &p, 1e-14);

        // The CDF is a step function, so shifting the arguments by 0.5 must
        // produce the same values.
        let x = (-1..9)
            .map(|i| d.distribution(2.0 * i as f64 + 0.5))
            .collect::<Vec<_>>();
        assert::close(&x, &p, 1e-14);
    }

    #[test]
    fn entropy() {
        // Small n exercises direct summation; huge n exercises the
        // Gaussian-approximation branch.
        assert_eq!(new!(16, 0.25).entropy(), 1.9588018945068573);
        assert_eq!(new!(10_000_000, 0.5).entropy(), 8.784839178123887);
    }

    #[test]
    fn inverse() {
        // Endpoints of the probability range.
        let d = new!(10, 0.5);
        assert_eq!(d.inverse(0.0), 0);
        assert_eq!(d.inverse(1.0), 10);

        // n < 1000: exact tail summation.
        let d = new!(250, 0.55);
        assert_eq!(d.inverse(0.025), 122);
        assert_eq!(d.inverse(0.1), 127);

        // npq > 80: normal-approximation branch.
        let d = new!(2500, 0.55);
        assert_eq!(d.inverse(d.distribution(1298.0)), 1298);
        assert_eq!(new!(1001, 0.25).inverse(0.5), 250);
        assert_eq!(new!(1500, 0.15).inverse(0.2), 213);

        // Large n with small npq: the Newton-style search.
        assert_eq!(new!(1_000_000, 2.5e-5).inverse(0.9995), 42);
        assert_eq!(new!(1_000_000_000, 6.66e-9).inverse(0.8), 8);
    }

    #[test]
    fn inverse_convergence() {
        // Regression cases for the Newton-style search (n >= 1000,
        // npq <= 80).
        let d = new!(1024, 0.009765625);
        assert_eq!(d.inverse(0.32185663510619567), 8);

        let d = new!(3666, 0.9810204628647335);
        assert_eq!(d.inverse(0.0033333333333332993), 3573);
    }

    #[test]
    fn kurtosis() {
        assert_eq!(new!(16, 0.25).kurtosis(), -0.041666666666666664);
    }

    #[test]
    fn mass() {
        let d = new!(16, 0.25);
        let p = vec![
            1.002259575761855e-02,
            1.336346101015806e-01,
            2.251990651711821e-01,
            1.100973207503558e-01,
            1.966023584827779e-02,
            1.359226182103156e-03,
            3.432389348745344e-05,
            2.514570951461788e-07,
            2.328306436538698e-10,
        ];

        // Masses at the even points 0, 2, ..., 16, covering both endpoint
        // branches and the general saddle-point branch.
        assert::close(
            &(0..9).map(|i| d.mass(2 * i)).collect::<Vec<_>>(),
            &p,
            1e-14,
        );
    }

    #[test]
    fn mean() {
        assert_eq!(new!(16, 0.25).mean(), 4.0);
    }

    #[test]
    fn median() {
        // Covers several branches: integer np, odd n with p = 1/2, and the
        // round(np) cases.
        assert_eq!(new!(16, 0.25).median(), 4.0);
        assert_eq!(new!(3, 0.5).median(), 1.5);
        assert_eq!(new!(1000, 0.015).median(), 15.0);
        assert_eq!(new!(39, 0.1).median(), 4.0);
    }

    #[test]
    fn modes() {
        // Both the single-mode and the two-mode ((n + 1) p integral) cases.
        assert_eq!(new!(16, 0.25).modes(), vec![4]);
        assert_eq!(new!(3, 0.5).modes(), vec![1, 2]);
        assert_eq!(new!(1000, 0.015).modes(), vec![15]);
        assert_eq!(new!(39, 0.1).modes(), vec![3, 4]);
    }

    #[test]
    fn skewness() {
        assert_eq!(new!(16, 0.25).skewness(), 0.2886751345948129);
    }

    #[test]
    fn variance() {
        assert_eq!(new!(16, 0.25).variance(), 3.0);
    }
}