rv/dist/
gamma.rs

1//! Gamma distribution over x in (0, ∞)
2#[cfg(feature = "serde1")]
3use serde::{Deserialize, Serialize};
4
5use crate::impl_display;
6use crate::misc::ln_gammafn;
7use crate::traits::{
8    Cdf, ContinuousDistr, Entropy, HasDensity, Kurtosis, Mean, Mode,
9    Parameterized, Sampleable, Scalable, Shiftable, Skewness, Support,
10    Variance,
11};
12use rand::Rng;
13use special::Gamma as _;
14use std::fmt;
15use std::sync::OnceLock;
16
17mod poisson_prior;
18
/// [Gamma distribution](https://en.wikipedia.org/wiki/Gamma_distribution) G(α, β)
/// over x in (0, ∞).
///
/// **NOTE**: This type uses the shape/rate parameterization: α is the shape
/// and β is the rate (the reciprocal of the scale, i.e. scale = 1/β).
///
/// ```math
///             β^α
/// f(x|α, β) = ----  x^(α-1) e^(-βx)
///             Γ(α)
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub struct Gamma {
    /// Shape parameter, α
    shape: f64,
    /// Rate parameter, β
    rate: f64,
    /// Cached ln Γ(shape), filled on first use; never serialized.
    #[cfg_attr(feature = "serde1", serde(skip))]
    ln_gamma_shape: OnceLock<f64>,
    /// Cached ln(rate), filled on first use; never serialized.
    #[cfg_attr(feature = "serde1", serde(skip))]
    ln_rate: OnceLock<f64>,
}
43
/// Plain-data snapshot of a [`Gamma`]'s parameters.
///
/// Produced by `Parameterized::emit_params` and consumed by
/// `Parameterized::from_params` (which does not validate the values).
#[derive(Debug, Clone, PartialEq)]
pub struct GammaParameters {
    /// Shape parameter, α
    pub shape: f64,
    /// Rate parameter, β
    pub rate: f64,
}
48
49impl Parameterized for Gamma {
50    type Parameters = GammaParameters;
51
52    fn emit_params(&self) -> Self::Parameters {
53        Self::Parameters {
54            shape: self.shape(),
55            rate: self.rate(),
56        }
57    }
58
59    fn from_params(params: Self::Parameters) -> Self {
60        Self::new_unchecked(params.shape, params.rate)
61    }
62}
63
64impl PartialEq for Gamma {
65    fn eq(&self, other: &Gamma) -> bool {
66        self.shape == other.shape && self.rate == other.rate
67    }
68}
69
/// Reasons a shape/rate pair can be rejected when building a [`Gamma`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub enum GammaError {
    /// The shape, α, is zero or negative.
    ShapeTooLow { shape: f64 },
    /// The shape, α, is infinite or NaN.
    ShapeNotFinite { shape: f64 },
    /// The rate, β, is zero or negative.
    RateTooLow { rate: f64 },
    /// The rate, β, is infinite or NaN.
    RateNotFinite { rate: f64 },
}
83
84impl Gamma {
85    /// Create a new `Gamma` distribution with shape (α) and rate (β).
86    pub fn new(shape: f64, rate: f64) -> Result<Self, GammaError> {
87        if shape <= 0.0 {
88            Err(GammaError::ShapeTooLow { shape })
89        } else if rate <= 0.0 {
90            Err(GammaError::RateTooLow { rate })
91        } else if !shape.is_finite() {
92            Err(GammaError::ShapeNotFinite { shape })
93        } else if !rate.is_finite() {
94            Err(GammaError::RateNotFinite { rate })
95        } else {
96            Ok(Gamma::new_unchecked(shape, rate))
97        }
98    }
99
100    /// Creates a new Gamma without checking whether the parameters are valid.
101    #[inline]
102    #[must_use]
103    pub fn new_unchecked(shape: f64, rate: f64) -> Self {
104        Gamma {
105            shape,
106            rate,
107            ln_gamma_shape: OnceLock::new(),
108            ln_rate: OnceLock::new(),
109        }
110    }
111
112    /// Get ln(rate)
113    #[inline]
114    fn ln_rate(&self) -> f64 {
115        *self.ln_rate.get_or_init(|| self.rate.ln())
116    }
117
118    /// Get ln(gamma(rate))
119    #[inline]
120    fn ln_gamma_shape(&self) -> f64 {
121        *self.ln_gamma_shape.get_or_init(|| ln_gammafn(self.shape))
122    }
123
124    /// Get the shape parameter
125    ///
126    /// # Example
127    ///
128    /// ```rust
129    /// # use rv::dist::Gamma;
130    /// let gam = Gamma::new(2.0, 1.0).unwrap();
131    /// assert_eq!(gam.shape(), 2.0);
132    /// ```
133    #[inline]
134    pub fn shape(&self) -> f64 {
135        self.shape
136    }
137
138    /// Set the shape parameter
139    ///
140    /// # Example
141    ///
142    /// ```rust
143    /// # use rv::dist::Gamma;
144    /// let mut gam = Gamma::new(2.0, 1.0).unwrap();
145    /// assert_eq!(gam.shape(), 2.0);
146    ///
147    /// gam.set_shape(1.1).unwrap();
148    /// assert_eq!(gam.shape(), 1.1);
149    /// ```
150    ///
151    /// Will error for invalid values
152    ///
153    /// ```rust
154    /// # use rv::dist::Gamma;
155    /// # let mut gam = Gamma::new(2.0, 1.0).unwrap();
156    /// assert!(gam.set_shape(1.1).is_ok());
157    /// assert!(gam.set_shape(0.0).is_err());
158    /// assert!(gam.set_shape(-1.0).is_err());
159    /// assert!(gam.set_shape(f64::INFINITY).is_err());
160    /// assert!(gam.set_shape(f64::NEG_INFINITY).is_err());
161    /// assert!(gam.set_shape(f64::NAN).is_err());
162    /// ```
163    #[inline]
164    pub fn set_shape(&mut self, shape: f64) -> Result<(), GammaError> {
165        if shape <= 0.0 {
166            Err(GammaError::ShapeTooLow { shape })
167        } else if !shape.is_finite() {
168            Err(GammaError::ShapeNotFinite { shape })
169        } else {
170            self.set_shape_unchecked(shape);
171            Ok(())
172        }
173    }
174
175    /// Set the shape parameter without input validation
176    #[inline]
177    pub fn set_shape_unchecked(&mut self, shape: f64) {
178        self.shape = shape;
179        self.ln_gamma_shape = OnceLock::new();
180    }
181
182    /// Get the rate parameter
183    ///
184    /// # Example
185    ///
186    /// ```rust
187    /// # use rv::dist::Gamma;
188    /// let gam = Gamma::new(2.0, 1.0).unwrap();
189    /// assert_eq!(gam.rate(), 1.0);
190    /// ```
191    #[inline]
192    pub fn rate(&self) -> f64 {
193        self.rate
194    }
195
196    /// Set the rate parameter
197    ///
198    /// # Example
199    ///
200    /// ```rust
201    /// # use rv::dist::Gamma;
202    /// let mut gam = Gamma::new(2.0, 1.0).unwrap();
203    /// assert_eq!(gam.rate(), 1.0);
204    ///
205    /// gam.set_rate(1.1).unwrap();
206    /// assert_eq!(gam.rate(), 1.1);
207    /// ```
208    ///
209    /// Will error for invalid values
210    ///
211    /// ```rust
212    /// # use rv::dist::Gamma;
213    /// # let mut gam = Gamma::new(2.0, 1.0).unwrap();
214    /// assert!(gam.set_rate(1.1).is_ok());
215    /// assert!(gam.set_rate(0.0).is_err());
216    /// assert!(gam.set_rate(-1.0).is_err());
217    /// assert!(gam.set_rate(f64::INFINITY).is_err());
218    /// assert!(gam.set_rate(f64::NEG_INFINITY).is_err());
219    /// assert!(gam.set_rate(f64::NAN).is_err());
220    /// ```
221    #[inline]
222    pub fn set_rate(&mut self, rate: f64) -> Result<(), GammaError> {
223        if rate <= 0.0 {
224            Err(GammaError::RateTooLow { rate })
225        } else if !rate.is_finite() {
226            Err(GammaError::RateNotFinite { rate })
227        } else {
228            self.set_rate_unchecked(rate);
229            Ok(())
230        }
231    }
232
233    /// Set the rate parameter without input validation
234    #[inline]
235    pub fn set_rate_unchecked(&mut self, rate: f64) {
236        self.rate = rate;
237        self.ln_rate = OnceLock::new();
238    }
239}
240
241impl Default for Gamma {
242    fn default() -> Self {
243        Gamma::new_unchecked(1.0, 1.0)
244    }
245}
246
247impl From<&Gamma> for String {
248    fn from(gam: &Gamma) -> String {
249        format!("G(α: {}, β: {})", gam.shape, gam.rate)
250    }
251}
252
253impl_display!(Gamma);
254
// Emits the trait impls that are generic over the sample type, once per
// float width it is invoked with.
macro_rules! impl_traits {
    ($kind:ty) => {
        impl HasDensity<$kind> for Gamma {
            fn ln_f(&self, x: &$kind) -> f64 {
                let xf = f64::from(*x);
                // α·ln(β) − ln Γ(α)
                let log_norm =
                    self.shape.mul_add(self.ln_rate(), -self.ln_gamma_shape());
                // (α − 1)·ln(x) − β·x
                let log_kernel =
                    (self.shape - 1.0).mul_add(xf.ln(), -(self.rate * xf));
                log_norm + log_kernel
            }
        }

        impl Sampleable<$kind> for Gamma {
            fn draw<R: Rng>(&self, rng: &mut R) -> $kind {
                // rand_distr is parameterized by scale, the reciprocal of rate
                let sampler =
                    rand_distr::Gamma::new(self.shape, 1.0 / self.rate)
                        .unwrap();
                let value: f64 = rng.sample(sampler);
                value as $kind
            }

            fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<$kind> {
                let sampler =
                    rand_distr::Gamma::new(self.shape, 1.0 / self.rate)
                        .unwrap();
                let mut out = Vec::with_capacity(n);
                for _ in 0..n {
                    out.push(rng.sample(sampler) as $kind);
                }
                out
            }
        }

        impl ContinuousDistr<$kind> for Gamma {}

        impl Support<$kind> for Gamma {
            fn supports(&self, x: &$kind) -> bool {
                *x > 0.0 && x.is_finite()
            }
        }

        impl Cdf<$kind> for Gamma {
            fn cdf(&self, x: &$kind) -> f64 {
                // No mass at or below zero
                if *x <= 0.0 {
                    return 0.0;
                }
                // Regularized lower incomplete gamma P(α, βx)
                let scaled = self.rate * f64::from(*x);
                scaled.inc_gamma(self.shape)
            }
        }

        impl Mean<$kind> for Gamma {
            fn mean(&self) -> Option<$kind> {
                let m = self.shape / self.rate;
                Some(m as $kind)
            }
        }

        impl Mode<$kind> for Gamma {
            fn mode(&self) -> Option<$kind> {
                // For α < 1 the density is unbounded at 0, so no mode is
                // reported.
                (self.shape >= 1.0)
                    .then(|| ((self.shape - 1.0) / self.rate) as $kind)
            }
        }
    };
}
317
318impl Variance<f64> for Gamma {
319    fn variance(&self) -> Option<f64> {
320        Some(self.shape / (self.rate * self.rate))
321    }
322}
323
324impl Entropy for Gamma {
325    fn entropy(&self) -> f64 {
326        self.shape - self.ln_rate()
327            + (1.0 - self.shape)
328                .mul_add(self.shape.digamma(), self.ln_gamma_shape())
329    }
330}
331
332impl Skewness for Gamma {
333    fn skewness(&self) -> Option<f64> {
334        Some(2.0 / self.shape.sqrt())
335    }
336}
337
338impl Kurtosis for Gamma {
339    fn kurtosis(&self) -> Option<f64> {
340        Some(6.0 / self.shape)
341    }
342}
343
344impl_traits!(f32);
345impl_traits!(f64);
346
347impl std::error::Error for GammaError {}
348
349#[cfg_attr(coverage_nightly, coverage(off))]
350impl fmt::Display for GammaError {
351    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352        match self {
353            Self::ShapeTooLow { shape } => {
354                write!(f, "rate ({shape}) must be greater than zero")
355            }
356            Self::ShapeNotFinite { shape } => {
357                write!(f, "non-finite rate: {shape}")
358            }
359            Self::RateTooLow { rate } => {
360                write!(f, "rate ({rate}) must be greater than zero")
361            }
362            Self::RateNotFinite { rate } => {
363                write!(f, "non-finite rate: {rate}")
364            }
365        }
366    }
367}
368
369crate::impl_shiftable!(Gamma);
370
371impl Scalable for Gamma {
372    type Output = Gamma;
373    type Error = GammaError;
374
375    fn scaled(self, scale: f64) -> Result<Self::Output, Self::Error> {
376        Ok(Gamma::new_unchecked(self.shape, self.rate / scale))
377    }
378
379    fn scaled_unchecked(self, scale: f64) -> Self::Output {
380        Gamma::new_unchecked(self.shape, self.rate / scale)
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use crate::misc::ks_test;
    use crate::test_basic_impls;

    const TOL: f64 = 1E-12;
    const KS_PVAL: f64 = 0.2;
    const N_TRIES: usize = 5;

    test_basic_impls!(f64, Gamma, Gamma::new_unchecked(1.0, 2.0));

    #[test]
    fn new() {
        let gam = Gamma::new(1.0, 2.0).unwrap();
        assert::close(gam.shape, 1.0, TOL);
        assert::close(gam.rate, 2.0, TOL);
    }

    #[test]
    fn ln_pdf_low_value() {
        let gam = Gamma::new(1.2, 3.4).unwrap();
        assert::close(gam.ln_pdf(&0.1_f64), 0.753_387_589_351_045_6, TOL);
    }

    // x = 100 is deep in the right tail; the mean is 1.2 / 3.4 ≈ 0.353.
    #[test]
    fn ln_pdf_far_right_tail() {
        let gam = Gamma::new(1.2, 3.4).unwrap();
        assert::close(gam.ln_pdf(&100.0_f64), -337.525_061_354_852_54, TOL);
    }

    #[test]
    fn cdf() {
        let gam = Gamma::new(1.2, 3.4).unwrap();
        assert::close(gam.cdf(&0.5_f32), 0.759_436_544_318_054_6, TOL);
        assert::close(
            gam.cdf(&0.352_941_176_470_588_26_f64),
            0.620_918_065_523_85,
            TOL,
        );
        assert::close(gam.cdf(&100.0_f64), 1.0, TOL);
    }

    // 0.352_941_176_470_588_26 is the mean, 1.2 / 3.4.
    #[test]
    fn ln_pdf_at_mean() {
        let gam = Gamma::new(1.2, 3.4).unwrap();
        assert::close(
            gam.ln_pdf(&0.352_941_176_470_588_26_f64),
            0.145_613_832_984_222_48,
            TOL,
        );
    }

    #[test]
    fn setters_invalidate_cached_terms() {
        let mut gam = Gamma::new(1.2, 3.4).unwrap();
        // Populate the ln(rate) and ln Γ(shape) caches before mutating.
        let _ = gam.ln_pdf(&0.5_f64);
        gam.set_shape(2.5).unwrap();
        gam.set_rate(0.7).unwrap();
        let fresh = Gamma::new(2.5, 0.7).unwrap();
        assert::close(gam.ln_pdf(&0.5_f64), fresh.ln_pdf(&0.5_f64), TOL);
        assert::close(gam.entropy(), fresh.entropy(), TOL);
    }

    #[test]
    fn mean_should_be_ratio_of_params() {
        let m1: f64 = Gamma::new(1.0, 2.0).unwrap().mean().unwrap();
        let m2: f64 = Gamma::new(1.0, 1.0).unwrap().mean().unwrap();
        let m3: f64 = Gamma::new(3.0, 1.0).unwrap().mean().unwrap();
        let m4: f64 = Gamma::new(0.3, 0.1).unwrap().mean().unwrap();
        assert::close(m1, 0.5, TOL);
        assert::close(m2, 1.0, TOL);
        assert::close(m3, 3.0, TOL);
        assert::close(m4, 3.0, TOL);
    }

    #[test]
    fn mode_undefined_for_shape_less_than_one() {
        let m1_opt: Option<f64> = Gamma::new(1.0, 2.0).unwrap().mode();
        let m2_opt: Option<f64> = Gamma::new(0.999, 2.0).unwrap().mode();
        let m3_opt: Option<f64> = Gamma::new(0.5, 2.0).unwrap().mode();
        let m4_opt: Option<f64> = Gamma::new(0.1, 2.0).unwrap().mode();
        assert!(m1_opt.is_some());
        assert!(m2_opt.is_none());
        assert!(m3_opt.is_none());
        assert!(m4_opt.is_none());
    }

    #[test]
    fn mode() {
        let m1: f64 = Gamma::new(2.0, 2.0).unwrap().mode().unwrap();
        let m2: f64 = Gamma::new(1.0, 2.0).unwrap().mode().unwrap();
        let m3: f64 = Gamma::new(2.0, 1.0).unwrap().mode().unwrap();
        assert::close(m1, 0.5, TOL);
        assert::close(m2, 0.0, TOL);
        assert::close(m3, 1.0, TOL);
    }

    #[test]
    fn variance() {
        assert::close(
            Gamma::new(2.0, 2.0).unwrap().variance().unwrap(),
            0.5,
            TOL,
        );
        assert::close(
            Gamma::new(0.5, 2.0).unwrap().variance().unwrap(),
            1.0 / 8.0,
            TOL,
        );
    }

    #[test]
    fn skewness() {
        assert::close(
            Gamma::new(4.0, 3.0).unwrap().skewness().unwrap(),
            1.0,
            TOL,
        );
        assert::close(
            Gamma::new(16.0, 4.0).unwrap().skewness().unwrap(),
            0.5,
            TOL,
        );
        assert::close(
            Gamma::new(16.0, 1.0).unwrap().skewness().unwrap(),
            0.5,
            TOL,
        );
    }

    #[test]
    fn kurtosis() {
        assert::close(
            Gamma::new(6.0, 3.0).unwrap().kurtosis().unwrap(),
            1.0,
            TOL,
        );
        assert::close(
            Gamma::new(6.0, 1.0).unwrap().kurtosis().unwrap(),
            1.0,
            TOL,
        );
        assert::close(
            Gamma::new(12.0, 1.0).unwrap().kurtosis().unwrap(),
            0.5,
            TOL,
        );
    }

    #[test]
    fn entropy() {
        let gam1 = Gamma::new(2.0, 1.0).unwrap();
        let gam2 = Gamma::new(1.2, 3.4).unwrap();
        assert::close(gam1.entropy(), 1.577_215_664_901_532_8, TOL);
        assert::close(gam2.entropy(), -0.051_341_542_306_993_84, TOL);
    }

    #[test]
    fn draw_test() {
        let mut rng = rand::rng();
        let gam = Gamma::new(1.2, 3.4).unwrap();
        let cdf = |x: f64| gam.cdf(&x);

        // test is flaky, try a few times
        let passes = (0..N_TRIES).fold(0, |acc, _| {
            let xs: Vec<f64> = gam.sample(1000, &mut rng);
            let (_, p) = ks_test(&xs, cdf);
            if p > KS_PVAL { acc + 1 } else { acc }
        });

        assert!(passes > 0);
    }

    use crate::test_scalable_cdf;
    use crate::test_scalable_density;
    use crate::test_scalable_entropy;
    use crate::test_scalable_method;

    test_scalable_method!(Gamma::new(2.0, 4.0).unwrap(), mean);
    test_scalable_method!(Gamma::new(2.0, 4.0).unwrap(), variance);
    test_scalable_method!(Gamma::new(2.0, 4.0).unwrap(), skewness);
    test_scalable_method!(Gamma::new(2.0, 4.0).unwrap(), kurtosis);
    test_scalable_density!(Gamma::new(2.0, 4.0).unwrap());
    test_scalable_entropy!(Gamma::new(2.0, 4.0).unwrap());
    test_scalable_cdf!(Gamma::new(2.0, 4.0).unwrap());

    #[test]
    fn emit_and_from_params_are_identity() {
        let dist_a = Gamma::new(3.0, 5.0).unwrap();
        let dist_b = Gamma::from_params(dist_a.emit_params());
        assert_eq!(dist_a, dist_b);
    }
}