// rs_stats/distributions/negative_binomial.rs
use crate::distributions::traits::DiscreteDistribution;
use crate::error::{StatsError, StatsResult};
use crate::utils::special_functions::ln_gamma;
use serde::{Deserialize, Serialize};

47#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
58pub struct NegativeBinomial {
59 pub r: f64,
61 pub p: f64,
63}
64
65impl NegativeBinomial {
66 pub fn new(r: f64, p: f64) -> StatsResult<Self> {
68 if r <= 0.0 {
69 return Err(StatsError::InvalidInput {
70 message: "NegativeBinomial::new: r must be positive".to_string(),
71 });
72 }
73 if !(0.0 < p && p < 1.0) {
74 return Err(StatsError::InvalidInput {
75 message: "NegativeBinomial::new: p must be in (0, 1)".to_string(),
76 });
77 }
78 Ok(Self { r, p })
79 }
80
81 pub fn fit(data: &[f64]) -> StatsResult<Self> {
86 if data.is_empty() {
87 return Err(StatsError::InvalidInput {
88 message: "NegativeBinomial::fit: data must not be empty".to_string(),
89 });
90 }
91 if data.iter().any(|&x| x < 0.0 || x.fract() != 0.0) {
92 return Err(StatsError::InvalidInput {
93 message: "NegativeBinomial::fit: all data values must be non-negative integers"
94 .to_string(),
95 });
96 }
97 let n = data.len() as f64;
98 let mean = data.iter().sum::<f64>() / n;
99 let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
100
101 if variance <= mean {
102 return Self::new(mean.max(0.01) * 10.0, 1.0 - 1.0 / 11.0);
104 }
105
106 let p = mean / variance;
107 let r = mean * p / (1.0 - p);
108 Self::new(r.max(0.01), p.clamp(1e-9, 1.0 - 1e-9))
109 }
110}
111
112impl DiscreteDistribution for NegativeBinomial {
113 fn name(&self) -> &str {
114 "NegativeBinomial"
115 }
116 fn num_params(&self) -> usize {
117 2
118 }
119
120 fn pmf(&self, k: u64) -> StatsResult<f64> {
121 Ok(self.logpmf(k)?.exp())
122 }
123
124 fn logpmf(&self, k: u64) -> StatsResult<f64> {
125 let kf = k as f64;
126 let log_binom = ln_gamma(kf + self.r) - ln_gamma(self.r) - ln_gamma(kf + 1.0);
128 Ok(log_binom + self.r * self.p.ln() + kf * (1.0 - self.p).ln())
129 }
130
131 fn cdf(&self, k: u64) -> StatsResult<f64> {
132 let mut sum = 0.0_f64;
134 for i in 0..=k {
135 sum += self.pmf(i)?;
136 if sum >= 1.0 - 1e-15 {
138 return Ok(1.0);
139 }
140 }
141 Ok(sum.clamp(0.0, 1.0))
142 }
143
144 fn mean(&self) -> f64 {
145 self.r * (1.0 - self.p) / self.p
146 }
147
148 fn variance(&self) -> f64 {
149 self.r * (1.0 - self.p) / (self.p * self.p)
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_neg_binom_mean_variance() {
        // r = 5, p = 0.5: mean = r(1-p)/p = 5, variance = mean/p = 10.
        let dist = NegativeBinomial::new(5.0, 0.5).unwrap();
        assert!((dist.mean() - 5.0).abs() < 1e-10);
        assert!((dist.variance() - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_neg_binom_pmf_k0() {
        // pmf(0) = p^r = 0.5^2 = 0.25.
        let dist = NegativeBinomial::new(2.0, 0.5).unwrap();
        let p0 = dist.pmf(0).unwrap();
        assert!((p0 - 0.25).abs() < 1e-10);
    }

    #[test]
    fn test_neg_binom_cdf_monotone() {
        let dist = NegativeBinomial::new(3.0, 0.4).unwrap();
        let mut last = 0.0;
        for k in 0..20 {
            let cur = dist.cdf(k).unwrap();
            assert!(cur >= last, "CDF not monotone at k={k}");
            last = cur;
        }
    }

    #[test]
    fn test_neg_binom_fit() {
        let sample = vec![0.0, 1.0, 2.0, 0.0, 3.0, 1.0, 0.0, 4.0, 1.0, 2.0];
        let fitted = NegativeBinomial::fit(&sample).unwrap();
        assert!(fitted.r > 0.0);
        assert!(fitted.p > 0.0 && fitted.p < 1.0);
    }

    #[test]
    fn test_neg_binom_invalid() {
        // r must be positive; p must lie strictly inside (0, 1).
        assert!(NegativeBinomial::new(0.0, 0.5).is_err());
        assert!(NegativeBinomial::new(1.0, 0.0).is_err());
        assert!(NegativeBinomial::new(1.0, 1.0).is_err());
    }
}