// rs_stats/distributions/negative_binomial.rs
//! # Negative Binomial Distribution
//!
//! The Negative Binomial distribution NegBinom(r, p) models the number of failures
//! before achieving r successes in Bernoulli trials with success probability p.
//! It also serves as the canonical model for **overdispersed count data** (variance > mean).
//!
//! Support: k = 0, 1, 2, …  (number of failures)
//!
//! **PMF**: P(X = k) = C(k+r−1, k) · p^r · (1−p)^k
//!
//! **Mean**: r(1−p)/p   **Variance**: r(1−p)/p²
//!
//! ## When to use over Poisson
//!
//! When count data has **variance > mean** (overdispersion), Poisson underfits.
//! Use Negative Binomial — it adds a free parameter r to absorb the extra variability.
//! Common in healthcare where patients are heterogeneous (different baseline risks).
//!
//! ## Medical applications
//!
//! - **Hospital readmissions** before stable remission (heterogeneous patient risk)
//! - **Overdispersed adverse event counts** in pharmacovigilance
//! - **Number of disease recurrences** before sustained response
//! - **Emergency department visits** per patient per year (high inter-patient variability)
//!
//! ## Example — re-admissions before remission
//!
//! ```rust
//! use rs_stats::distributions::negative_binomial::NegativeBinomial;
//! use rs_stats::DiscreteDistribution;
//!
//! // Re-admissions data across 20 patients — variance >> mean → overdispersed
//! let admissions = vec![
//!     0.0, 2.0, 1.0, 5.0, 3.0, 0.0, 4.0, 1.0, 2.0, 6.0,
//!     1.0, 0.0, 3.0, 2.0, 1.0, 4.0, 0.0, 2.0, 3.0, 1.0,
//! ];
//! let nb = NegativeBinomial::fit(&admissions).unwrap();
//! println!("NegBin(r={:.2}, p={:.3})", nb.r, nb.p);
//! println!("P(0 re-admissions) = {:.1}%", nb.pmf(0).unwrap() * 100.0);
//! ```

use serde::{Deserialize, Serialize};

use crate::distributions::traits::DiscreteDistribution;
use crate::error::{StatsError, StatsResult};
use crate::utils::special_functions::ln_gamma;

47/// Negative Binomial distribution NegBinom(r, p).
48///
49/// # Examples
50/// ```
51/// use rs_stats::distributions::negative_binomial::NegativeBinomial;
52/// use rs_stats::distributions::traits::DiscreteDistribution;
53///
54/// let nb = NegativeBinomial::new(5.0, 0.5).unwrap();
55/// assert!((nb.mean() - 5.0).abs() < 1e-10);
56/// ```
57#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
58pub struct NegativeBinomial {
59    /// Number of successes r > 0 (can be non-integer, i.e. the overdispersion parameter)
60    pub r: f64,
61    /// Success probability p ∈ (0, 1)
62    pub p: f64,
63}
64
65impl NegativeBinomial {
66    /// Creates a `NegBinom(r, p)` distribution.
67    pub fn new(r: f64, p: f64) -> StatsResult<Self> {
68        if r <= 0.0 {
69            return Err(StatsError::InvalidInput {
70                message: "NegativeBinomial::new: r must be positive".to_string(),
71            });
72        }
73        if !(0.0 < p && p < 1.0) {
74            return Err(StatsError::InvalidInput {
75                message: "NegativeBinomial::new: p must be in (0, 1)".to_string(),
76            });
77        }
78        Ok(Self { r, p })
79    }
80
81    /// Method-of-moments fitting.
82    ///
83    /// - mean = r(1−p)/p  → p = mean/variance (requires variance > mean)
84    /// - r = mean² / (variance − mean)
85    pub fn fit(data: &[f64]) -> StatsResult<Self> {
86        if data.is_empty() {
87            return Err(StatsError::InvalidInput {
88                message: "NegativeBinomial::fit: data must not be empty".to_string(),
89            });
90        }
91        if data.iter().any(|&x| x < 0.0 || x.fract() != 0.0) {
92            return Err(StatsError::InvalidInput {
93                message: "NegativeBinomial::fit: all data values must be non-negative integers"
94                    .to_string(),
95            });
96        }
97        let n = data.len() as f64;
98        let mean = data.iter().sum::<f64>() / n;
99        let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
100
101        if variance <= mean {
102            // Data appears Poisson-like; fall back to large r (approximates Poisson)
103            return Self::new(mean.max(0.01) * 10.0, 1.0 - 1.0 / 11.0);
104        }
105
106        let p = mean / variance;
107        let r = mean * p / (1.0 - p);
108        Self::new(r.max(0.01), p.clamp(1e-9, 1.0 - 1e-9))
109    }
110}
111
112impl DiscreteDistribution for NegativeBinomial {
113    fn name(&self) -> &str {
114        "NegativeBinomial"
115    }
116    fn num_params(&self) -> usize {
117        2
118    }
119
120    fn pmf(&self, k: u64) -> StatsResult<f64> {
121        Ok(self.logpmf(k)?.exp())
122    }
123
124    fn logpmf(&self, k: u64) -> StatsResult<f64> {
125        let kf = k as f64;
126        // log C(k+r-1, k) = ln_gamma(k+r) - ln_gamma(r) - ln_gamma(k+1)
127        let log_binom = ln_gamma(kf + self.r) - ln_gamma(self.r) - ln_gamma(kf + 1.0);
128        Ok(log_binom + self.r * self.p.ln() + kf * (1.0 - self.p).ln())
129    }
130
131    fn cdf(&self, k: u64) -> StatsResult<f64> {
132        // Sum PMF from 0 to k
133        let mut sum = 0.0_f64;
134        for i in 0..=k {
135            sum += self.pmf(i)?;
136            // Early exit if essentially 1
137            if sum >= 1.0 - 1e-15 {
138                return Ok(1.0);
139            }
140        }
141        Ok(sum.clamp(0.0, 1.0))
142    }
143
144    fn mean(&self) -> f64 {
145        self.r * (1.0 - self.p) / self.p
146    }
147
148    fn variance(&self) -> f64 {
149        self.r * (1.0 - self.p) / (self.p * self.p)
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;

    #[test]
    fn test_neg_binom_mean_variance() {
        // NegBinom(5, 0.5): mean = 5·0.5/0.5 = 5, variance = 5·0.5/0.25 = 10.
        let nb = NegativeBinomial::new(5.0, 0.5).unwrap();
        assert!((nb.mean() - 5.0).abs() < TOL);
        assert!((nb.variance() - 10.0).abs() < TOL);
    }

    #[test]
    fn test_neg_binom_pmf_k0() {
        // At k = 0 the PMF collapses to p^r = 0.5² = 0.25.
        let nb = NegativeBinomial::new(2.0, 0.5).unwrap();
        let p0 = nb.pmf(0).unwrap();
        assert!((p0 - 0.25).abs() < TOL);
    }

    #[test]
    fn test_neg_binom_cdf_monotone() {
        // The CDF must never decrease as k grows.
        let nb = NegativeBinomial::new(3.0, 0.4).unwrap();
        let mut prev = 0.0;
        for k in 0..20 {
            let c = nb.cdf(k).unwrap();
            assert!(c >= prev, "CDF not monotone at k={k}");
            prev = c;
        }
    }

    #[test]
    fn test_neg_binom_fit() {
        let data = vec![0.0, 1.0, 2.0, 0.0, 3.0, 1.0, 0.0, 4.0, 1.0, 2.0];
        let nb = NegativeBinomial::fit(&data).unwrap();
        assert!(nb.r > 0.0, "fitted r must be positive");
        assert!(nb.p > 0.0 && nb.p < 1.0, "fitted p must lie in (0, 1)");
    }

    #[test]
    fn test_neg_binom_invalid() {
        // Boundary and out-of-domain parameters must all be rejected.
        for (r, p) in [(0.0, 0.5), (1.0, 0.0), (1.0, 1.0)] {
            assert!(NegativeBinomial::new(r, p).is_err(), "new({r}, {p}) should fail");
        }
    }
}