Skip to main content

numra_stats/distributions/
gamma_dist.rs

//! Gamma distribution.
//!
//! Author: Moussa Leblouba
//! Date: 9 February 2026
//! Modified: 2 May 2026
6
7use numra_core::Scalar;
8use numra_special::{gammainc, lgamma};
9use rand::RngCore;
10
11use super::normal::random_uniform_01;
12use super::ContinuousDistribution;
13
14/// Gamma distribution with shape alpha and rate beta.
15///
16/// PDF: f(x) = beta^alpha * x^(alpha-1) * exp(-beta*x) / Gamma(alpha)
17#[derive(Clone, Debug)]
18pub struct GammaDist<S: Scalar> {
19    pub shape: S,
20    pub rate: S,
21}
22
23impl<S: Scalar> GammaDist<S> {
24    pub fn new(shape: S, rate: S) -> Self {
25        Self { shape, rate }
26    }
27}
28
29impl<S: Scalar> ContinuousDistribution<S> for GammaDist<S> {
30    fn pdf(&self, x: S) -> S {
31        if x < S::ZERO {
32            return S::ZERO;
33        }
34        if x == S::ZERO {
35            return if self.shape == S::ONE {
36                self.rate
37            } else if self.shape > S::ONE {
38                S::ZERO
39            } else {
40                S::INFINITY
41            };
42        }
43        let log_pdf = self.shape * self.rate.ln() + (self.shape - S::ONE) * x.ln()
44            - self.rate * x
45            - lgamma(self.shape);
46        log_pdf.exp()
47    }
48
49    fn cdf(&self, x: S) -> S {
50        if x <= S::ZERO {
51            return S::ZERO;
52        }
53        // CDF = P(a, b*x) = regularized lower incomplete gamma
54        gammainc(self.shape, self.rate * x)
55    }
56
57    fn quantile(&self, p: S) -> S {
58        if p <= S::ZERO {
59            return S::ZERO;
60        }
61        if p >= S::ONE {
62            return S::INFINITY;
63        }
64        // Newton iteration on CDF
65        // Initial guess using Wilson-Hilferty approximation
66        let mu = self.shape / self.rate;
67        let sig = (self.shape / (self.rate * self.rate)).sqrt();
68        let mut x = mu + sig * normal_quantile_approx(p);
69        if x <= S::ZERO {
70            x = mu * S::from_f64(0.01);
71        }
72        for _ in 0..50 {
73            let f_val = self.cdf(x) - p;
74            let f_prime = self.pdf(x);
75            if f_prime.to_f64().abs() < 1e-300 {
76                break;
77            }
78            let step = f_val / f_prime;
79            x -= step;
80            if x <= S::ZERO {
81                x = S::from_f64(1e-10);
82            }
83            if step.to_f64().abs() < 1e-12 * x.to_f64().abs() {
84                break;
85            }
86        }
87        x
88    }
89
90    fn mean(&self) -> S {
91        self.shape / self.rate
92    }
93
94    fn variance(&self) -> S {
95        self.shape / (self.rate * self.rate)
96    }
97
98    fn sample(&self, rng: &mut dyn RngCore) -> S {
99        // Marsaglia and Tsang's method for shape >= 1
100        let one = S::ONE;
101        let shape = if self.shape < one {
102            self.shape + one
103        } else {
104            self.shape
105        };
106
107        let d = shape - S::from_f64(1.0 / 3.0);
108        let c = S::ONE / (S::from_f64(9.0) * d).sqrt();
109
110        loop {
111            let x = sample_standard_normal::<S>(rng);
112            let v = S::ONE + c * x;
113            if v <= S::ZERO {
114                continue;
115            }
116            let v = v * v * v;
117            let u = random_uniform_01::<S>(rng);
118            let x2 = x * x;
119            if u < S::ONE - S::from_f64(0.0331) * x2 * x2 {
120                let result = d * v / self.rate;
121                if self.shape < one {
122                    let u2 = random_uniform_01::<S>(rng);
123                    return result * u2.ln().exp() / self.shape.ln().exp()
124                        * (S::ONE / self.shape).ln().exp();
125                }
126                return result;
127            }
128            if u.ln() < S::HALF * x2 + d * (S::ONE - v + v.ln()) {
129                let result = d * v / self.rate;
130                if self.shape < one {
131                    let u2 = random_uniform_01::<S>(rng);
132                    return result * u2.powf(S::ONE / self.shape);
133                }
134                return result;
135            }
136        }
137    }
138}
139
140/// Quick normal quantile approximation for initial guesses.
141pub(crate) fn normal_quantile_approx<S: Scalar>(p: S) -> S {
142    let p_f64 = p.to_f64();
143    // Rational approximation from Abramowitz & Stegun 26.2.23
144    let t = if p_f64 < 0.5 {
145        (-2.0 * p_f64.ln()).sqrt()
146    } else {
147        (-2.0 * (1.0 - p_f64).ln()).sqrt()
148    };
149    let c0 = 2.515517;
150    let c1 = 0.802853;
151    let c2 = 0.010328;
152    let d1 = 1.432788;
153    let d2 = 0.189269;
154    let d3 = 0.001308;
155    let val = t - (c0 + c1 * t + c2 * t * t) / (1.0 + d1 * t + d2 * t * t + d3 * t * t * t);
156    if p_f64 < 0.5 {
157        S::from_f64(-val)
158    } else {
159        S::from_f64(val)
160    }
161}
162
163/// Sample from standard normal using Box-Muller.
164pub(crate) fn sample_standard_normal<S: Scalar>(rng: &mut dyn RngCore) -> S {
165    let u1 = random_uniform_01::<S>(rng);
166    let u2 = random_uniform_01::<S>(rng);
167    let two = S::TWO;
168    let pi2 = S::from_f64(core::f64::consts::TAU);
169    (S::ZERO - two * u1.ln()).sqrt() * (pi2 * u2).cos()
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `a` and `b` differ by strictly less than `tol`.
    fn within(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn test_gamma_pdf_exponential_case() {
        // With shape 1, Gamma(1, rate) reduces to Exponential(rate): f(x) = rate * e^(-rate x).
        let rate = 2.0_f64;
        let dist = GammaDist::new(1.0_f64, rate);
        assert!(within(dist.pdf(0.0), rate, 1e-12));
        assert!(within(dist.pdf(1.0), rate * (-rate).exp(), 1e-12));
    }

    #[test]
    fn test_gamma_cdf() {
        // Gamma(1, 1) is Exp(1), whose CDF at x = 1 is 1 - e^(-1).
        let dist = GammaDist::new(1.0_f64, 1.0);
        let expected = 1.0 - (-1.0_f64).exp();
        assert!(within(dist.cdf(1.0), expected, 1e-8));
    }

    #[test]
    fn test_gamma_quantile_roundtrip() {
        // Feeding the quantile back through the CDF should recover the probability.
        let dist = GammaDist::new(3.0_f64, 2.0);
        [0.1, 0.5, 0.9].iter().copied().for_each(|p: f64| {
            let p2 = dist.cdf(dist.quantile(p));
            assert!((p - p2).abs() < 1e-6, "p={}, p2={}", p, p2);
        });
    }

    #[test]
    fn test_gamma_mean_variance() {
        // mean = shape / rate, variance = shape / rate^2.
        let dist = GammaDist::new(5.0_f64, 2.0);
        assert!(within(dist.mean(), 2.5, 1e-14));
        assert!(within(dist.variance(), 1.25, 1e-14));
    }

    #[test]
    fn test_gamma_sample_mean() {
        use rand::SeedableRng;
        // 10k draws from Gamma(3, 1) should average close to the true mean of 3.
        let dist = GammaDist::new(3.0_f64, 1.0);
        let mut rng = rand::rngs::StdRng::seed_from_u64(45);
        let draws = dist.sample_n(&mut rng, 10000);
        let mean = draws.iter().sum::<f64>() / draws.len() as f64;
        assert!((mean - 3.0).abs() < 0.2, "sample mean = {}", mean);
    }
}