// numra_stats/distributions/binomial.rs
use numra_special::betainc;
use rand::RngCore;

use super::normal::random_uniform_01;
use super::DiscreteDistribution;
12
/// Binomial distribution: the number of successes among `n` independent
/// trials that each succeed with probability `p`.
#[derive(Debug, Clone)]
pub struct Binomial {
    /// Number of trials.
    pub n: usize,
    /// Per-trial success probability. The field is public and is not
    /// validated anywhere in this type; values outside `[0, 1]` make the
    /// probability functions meaningless.
    pub p: f64,
}
19
20impl Binomial {
21 pub fn new(n: usize, p: f64) -> Self {
22 Self { n, p }
23 }
24}
25
26impl DiscreteDistribution<f64> for Binomial {
27 fn pmf(&self, k: usize) -> f64 {
28 if k > self.n {
29 return 0.0;
30 }
31 let log_pmf = log_binom_coeff(self.n, k)
32 + k as f64 * self.p.ln()
33 + (self.n - k) as f64 * (1.0 - self.p).ln();
34 log_pmf.exp()
35 }
36
37 fn cdf(&self, k: usize) -> f64 {
38 if k >= self.n {
39 return 1.0;
40 }
41 betainc((self.n - k) as f64, (k + 1) as f64, 1.0 - self.p)
43 }
44
45 fn mean(&self) -> f64 {
46 self.n as f64 * self.p
47 }
48
49 fn variance(&self) -> f64 {
50 self.n as f64 * self.p * (1.0 - self.p)
51 }
52
53 fn sample(&self, rng: &mut dyn RngCore) -> usize {
54 if self.n < 50 {
56 let mut count = 0;
57 for _ in 0..self.n {
58 let u = random_uniform_01::<f64>(rng);
59 if u < self.p {
60 count += 1;
61 }
62 }
63 count
64 } else {
65 let mu = self.n as f64 * self.p;
67 let sigma = (mu * (1.0 - self.p)).sqrt();
68 let u1 = random_uniform_01::<f64>(rng);
69 let u2 = random_uniform_01::<f64>(rng);
70 let z = (-2.0 * u1.ln()).sqrt() * (core::f64::consts::TAU * u2).cos();
71 let x = (mu + sigma * z).round().max(0.0).min(self.n as f64);
72 x as usize
73 }
74 }
75}
76
77fn log_binom_coeff(n: usize, k: usize) -> f64 {
79 let lf = |m: usize| -> f64 {
80 if m <= 1 {
81 0.0
82 } else {
83 numra_special::lgamma((m + 1) as f64)
84 }
85 };
86 lf(n) - lf(k) - lf(n - k)
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// The PMF summed over the whole support `0..=n` must equal one.
    #[test]
    fn test_binomial_pmf_sum() {
        let dist = Binomial::new(10, 0.3);
        let mut total = 0.0_f64;
        for k in 0..=10 {
            total += dist.pmf(k);
        }
        assert!((total - 1.0).abs() < 1e-10, "sum = {}", total);
    }

    /// The CDF never decreases as `k` grows and reaches one at `k = n`.
    #[test]
    fn test_binomial_cdf_monotone() {
        let dist = Binomial::new(10, 0.5);
        let last = (0..=10).fold(0.0, |prev, k| {
            let current = dist.cdf(k);
            assert!(current >= prev - 1e-14, "CDF not monotone at k={}", k);
            current
        });
        assert!((last - 1.0).abs() < 1e-6);
    }

    /// Closed-form moments: `n * p` and `n * p * (1 - p)`.
    #[test]
    fn test_binomial_mean_variance() {
        let dist = Binomial::new(20, 0.4);
        let (mean, variance) = (dist.mean(), dist.variance());
        assert!((mean - 8.0).abs() < 1e-14);
        assert!((variance - 4.8).abs() < 1e-14);
    }

    /// A single fair trial gives each outcome probability one half.
    #[test]
    fn test_binomial_coin_flip() {
        let coin = Binomial::new(1, 0.5);
        for k in 0..=1 {
            assert!((coin.pmf(k) - 0.5).abs() < 1e-14);
        }
    }

    /// The empirical mean of seeded samples lands near `n * p`.
    #[test]
    fn test_binomial_sample_mean() {
        use rand::SeedableRng;
        let dist = Binomial::new(20, 0.5);
        let mut rng = rand::rngs::StdRng::seed_from_u64(44);
        let samples = dist.sample_n(&mut rng, 5000);
        let total: usize = samples.iter().sum();
        let mean = total as f64 / samples.len() as f64;
        assert!((mean - 10.0).abs() < 1.0, "sample mean = {}", mean);
    }
}