1#[cfg(feature = "serde1")]
3use serde::{Deserialize, Serialize};
4
5use crate::impl_display;
6use crate::misc::ln_gammafn;
7use crate::traits::{
8 Cdf, ContinuousDistr, Entropy, HasDensity, Kurtosis, Mean, Mode,
9 Parameterized, Sampleable, Scalable, Shiftable, Skewness, Support,
10 Variance,
11};
12use rand::Rng;
13use special::Gamma as _;
14use std::fmt;
15use std::sync::OnceLock;
16
17mod poisson_prior;
18
/// Gamma distribution G(α, β) over x in (0, ∞), parameterized by a shape α
/// and a rate β (the density is proportional to x^(α-1) e^(-βx)).
///
/// `ln Γ(α)` and `ln β` are computed lazily on first use and cached. The
/// caches are skipped during (de)serialization and are ignored by the
/// `PartialEq` impl, which compares only `shape` and `rate`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct Gamma {
    /// Shape parameter, α.
    shape: f64,
    /// Rate parameter, β.
    rate: f64,
    /// Cache of `ln Γ(shape)`; cleared by `set_shape_unchecked`.
    #[cfg_attr(feature = "serde1", serde(skip))]
    ln_gamma_shape: OnceLock<f64>,
    /// Cache of `ln(rate)`; cleared by `set_rate_unchecked`.
    #[cfg_attr(feature = "serde1", serde(skip))]
    ln_rate: OnceLock<f64>,
}
43
/// Plain-data snapshot of a [`Gamma`]'s parameters, as produced by
/// `Parameterized::emit_params` and consumed by `Parameterized::from_params`.
///
/// No validation is attached to this type; `from_params` rebuilds the
/// distribution with `Gamma::new_unchecked`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct GammaParameters {
    /// Shape parameter, α.
    pub shape: f64,
    /// Rate parameter, β.
    pub rate: f64,
}
48
49impl Parameterized for Gamma {
50 type Parameters = GammaParameters;
51
52 fn emit_params(&self) -> Self::Parameters {
53 Self::Parameters {
54 shape: self.shape(),
55 rate: self.rate(),
56 }
57 }
58
59 fn from_params(params: Self::Parameters) -> Self {
60 Self::new_unchecked(params.shape, params.rate)
61 }
62}
63
64impl PartialEq for Gamma {
65 fn eq(&self, other: &Gamma) -> bool {
66 self.shape == other.shape && self.rate == other.rate
67 }
68}
69
/// Reasons a [`Gamma`] could not be constructed or have a parameter set.
///
/// `Gamma::new` checks, in order: shape too low, rate too low, shape not
/// finite, rate not finite; the first failing check is reported.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum GammaError {
    /// The shape is less than or equal to zero (this includes -∞).
    ShapeTooLow { shape: f64 },
    /// The shape is +∞ or NaN.
    ShapeNotFinite { shape: f64 },
    /// The rate is less than or equal to zero (this includes -∞).
    RateTooLow { rate: f64 },
    /// The rate is +∞ or NaN.
    RateNotFinite { rate: f64 },
}
83
84impl Gamma {
85 pub fn new(shape: f64, rate: f64) -> Result<Self, GammaError> {
87 if shape <= 0.0 {
88 Err(GammaError::ShapeTooLow { shape })
89 } else if rate <= 0.0 {
90 Err(GammaError::RateTooLow { rate })
91 } else if !shape.is_finite() {
92 Err(GammaError::ShapeNotFinite { shape })
93 } else if !rate.is_finite() {
94 Err(GammaError::RateNotFinite { rate })
95 } else {
96 Ok(Gamma::new_unchecked(shape, rate))
97 }
98 }
99
100 #[inline]
102 #[must_use]
103 pub fn new_unchecked(shape: f64, rate: f64) -> Self {
104 Gamma {
105 shape,
106 rate,
107 ln_gamma_shape: OnceLock::new(),
108 ln_rate: OnceLock::new(),
109 }
110 }
111
112 #[inline]
114 fn ln_rate(&self) -> f64 {
115 *self.ln_rate.get_or_init(|| self.rate.ln())
116 }
117
118 #[inline]
120 fn ln_gamma_shape(&self) -> f64 {
121 *self.ln_gamma_shape.get_or_init(|| ln_gammafn(self.shape))
122 }
123
124 #[inline]
134 pub fn shape(&self) -> f64 {
135 self.shape
136 }
137
138 #[inline]
164 pub fn set_shape(&mut self, shape: f64) -> Result<(), GammaError> {
165 if shape <= 0.0 {
166 Err(GammaError::ShapeTooLow { shape })
167 } else if !shape.is_finite() {
168 Err(GammaError::ShapeNotFinite { shape })
169 } else {
170 self.set_shape_unchecked(shape);
171 Ok(())
172 }
173 }
174
175 #[inline]
177 pub fn set_shape_unchecked(&mut self, shape: f64) {
178 self.shape = shape;
179 self.ln_gamma_shape = OnceLock::new();
180 }
181
182 #[inline]
192 pub fn rate(&self) -> f64 {
193 self.rate
194 }
195
196 #[inline]
222 pub fn set_rate(&mut self, rate: f64) -> Result<(), GammaError> {
223 if rate <= 0.0 {
224 Err(GammaError::RateTooLow { rate })
225 } else if !rate.is_finite() {
226 Err(GammaError::RateNotFinite { rate })
227 } else {
228 self.set_rate_unchecked(rate);
229 Ok(())
230 }
231 }
232
233 #[inline]
235 pub fn set_rate_unchecked(&mut self, rate: f64) {
236 self.rate = rate;
237 self.ln_rate = OnceLock::new();
238 }
239}
240
241impl Default for Gamma {
242 fn default() -> Self {
243 Gamma::new_unchecked(1.0, 1.0)
244 }
245}
246
247impl From<&Gamma> for String {
248 fn from(gam: &Gamma) -> String {
249 format!("G(α: {}, β: {})", gam.shape, gam.rate)
250 }
251}
252
253impl_display!(Gamma);
254
// Implements the traits that are generic over the sample type `$kind` for
// `Gamma`. This file instantiates it for `f32` and `f64`.
macro_rules! impl_traits {
    ($kind:ty) => {
        impl HasDensity<$kind> for Gamma {
            // ln f(x) = α ln β - ln Γ(α) + (α - 1) ln x - β x
            fn ln_f(&self, x: &$kind) -> f64 {
                self.shape.mul_add(self.ln_rate(), -self.ln_gamma_shape())
                    + (self.shape - 1.0).mul_add(
                        f64::from(*x).ln(),
                        -(self.rate * f64::from(*x)),
                    )
            }
        }

        impl Sampleable<$kind> for Gamma {
            fn draw<R: Rng>(&self, rng: &mut R) -> $kind {
                // rand_distr parameterizes by scale, the reciprocal of the rate.
                // Construction only fails for invalid parameters, which can
                // only get in through the `*_unchecked` APIs.
                let g = rand_distr::Gamma::new(self.shape, 1.0 / self.rate)
                    .expect("Gamma shape and rate must be positive and finite to sample");
                rng.sample(g) as $kind
            }

            fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<$kind> {
                // Build the sampler once and reuse it for all `n` draws.
                let g = rand_distr::Gamma::new(self.shape, 1.0 / self.rate)
                    .expect("Gamma shape and rate must be positive and finite to sample");
                (0..n).map(|_| rng.sample(g) as $kind).collect()
            }
        }

        impl ContinuousDistr<$kind> for Gamma {}

        impl Support<$kind> for Gamma {
            // Support is the open interval (0, ∞).
            fn supports(&self, x: &$kind) -> bool {
                x.is_finite() && *x > 0.0
            }
        }

        impl Cdf<$kind> for Gamma {
            // F(x) = P(α, βx), the regularized lower incomplete gamma
            // function (`special::Gamma::inc_gamma`); zero at or below 0.
            fn cdf(&self, x: &$kind) -> f64 {
                if *x <= 0.0 {
                    0.0
                } else {
                    (self.rate * f64::from(*x)).inc_gamma(self.shape)
                }
            }
        }

        impl Mean<$kind> for Gamma {
            // E[X] = α / β
            fn mean(&self) -> Option<$kind> {
                Some((self.shape / self.rate) as $kind)
            }
        }

        impl Mode<$kind> for Gamma {
            // (α - 1) / β for α >= 1 (0 when α == 1); for α < 1 the density
            // diverges at 0, so no mode is reported.
            fn mode(&self) -> Option<$kind> {
                if self.shape >= 1.0 {
                    let m = (self.shape - 1.0) / self.rate;
                    Some(m as $kind)
                } else {
                    None
                }
            }
        }
    };
}
317
318impl Variance<f64> for Gamma {
319 fn variance(&self) -> Option<f64> {
320 Some(self.shape / (self.rate * self.rate))
321 }
322}
323
324impl Entropy for Gamma {
325 fn entropy(&self) -> f64 {
326 self.shape - self.ln_rate()
327 + (1.0 - self.shape)
328 .mul_add(self.shape.digamma(), self.ln_gamma_shape())
329 }
330}
331
332impl Skewness for Gamma {
333 fn skewness(&self) -> Option<f64> {
334 Some(2.0 / self.shape.sqrt())
335 }
336}
337
338impl Kurtosis for Gamma {
339 fn kurtosis(&self) -> Option<f64> {
340 Some(6.0 / self.shape)
341 }
342}
343
// Density, sampling, support, CDF, mean and mode for both float widths.
impl_traits!(f32);
impl_traits!(f64);
346
347impl std::error::Error for GammaError {}
348
349#[cfg_attr(coverage_nightly, coverage(off))]
350impl fmt::Display for GammaError {
351 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352 match self {
353 Self::ShapeTooLow { shape } => {
354 write!(f, "rate ({shape}) must be greater than zero")
355 }
356 Self::ShapeNotFinite { shape } => {
357 write!(f, "non-finite rate: {shape}")
358 }
359 Self::RateTooLow { rate } => {
360 write!(f, "rate ({rate}) must be greater than zero")
361 }
362 Self::RateNotFinite { rate } => {
363 write!(f, "non-finite rate: {rate}")
364 }
365 }
366 }
367}
368
369crate::impl_shiftable!(Gamma);
370
371impl Scalable for Gamma {
372 type Output = Gamma;
373 type Error = GammaError;
374
375 fn scaled(self, scale: f64) -> Result<Self::Output, Self::Error> {
376 Ok(Gamma::new_unchecked(self.shape, self.rate / scale))
377 }
378
379 fn scaled_unchecked(self, scale: f64) -> Self::Output {
380 Gamma::new_unchecked(self.shape, self.rate / scale)
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use crate::misc::ks_test;
    use crate::test_basic_impls;
    use std::f64;

    const TOL: f64 = 1E-12;
    const KS_PVAL: f64 = 0.2;
    const N_TRIES: usize = 5;

    test_basic_impls!(f64, Gamma, Gamma::new_unchecked(1.0, 2.0));

    #[test]
    fn new() {
        let gam = Gamma::new(1.0, 2.0).unwrap();
        assert::close(gam.shape, 1.0, TOL);
        assert::close(gam.rate, 2.0, TOL);
    }

    #[test]
    fn new_rejects_non_positive_params() {
        // Shape is validated before rate, so a bad shape is reported first.
        assert_eq!(
            Gamma::new(0.0, -1.0),
            Err(GammaError::ShapeTooLow { shape: 0.0 })
        );
        assert_eq!(
            Gamma::new(1.0, -2.0),
            Err(GammaError::RateTooLow { rate: -2.0 })
        );
    }

    #[test]
    fn setters_reject_invalid_values_without_mutating() {
        let mut gam = Gamma::new(1.2, 3.4).unwrap();
        assert_eq!(
            gam.set_shape(0.0),
            Err(GammaError::ShapeTooLow { shape: 0.0 })
        );
        assert_eq!(
            gam.set_rate(-1.0),
            Err(GammaError::RateTooLow { rate: -1.0 })
        );
        assert_eq!(gam, Gamma::new(1.2, 3.4).unwrap());
    }

    #[test]
    fn setters_invalidate_cached_terms() {
        let x = 0.7_f64;
        let mut gam = Gamma::new(1.2, 3.4).unwrap();
        // Evaluate once so ln Γ(shape) and ln(rate) are cached.
        let _ = gam.ln_pdf(&x);
        gam.set_shape(2.5).unwrap();
        gam.set_rate(0.8).unwrap();
        let fresh = Gamma::new(2.5, 0.8).unwrap();
        assert::close(gam.ln_pdf(&x), fresh.ln_pdf(&x), TOL);
        assert::close(gam.entropy(), fresh.entropy(), TOL);
    }

    #[test]
    fn ln_pdf_low_value() {
        let gam = Gamma::new(1.2, 3.4).unwrap();
        assert::close(gam.ln_pdf(&0.1_f64), 0.753_387_589_351_045_6, TOL);
    }

    // x = 100 lies far in the right tail of G(1.2, 3.4).
    #[test]
    fn ln_pdf_high_value() {
        let gam = Gamma::new(1.2, 3.4).unwrap();
        assert::close(gam.ln_pdf(&100.0_f64), -337.525_061_354_852_54, TOL);
    }

    #[test]
    fn cdf() {
        let gam = Gamma::new(1.2, 3.4).unwrap();
        assert::close(gam.cdf(&0.5_f32), 0.759_436_544_318_054_6, TOL);
        assert::close(
            gam.cdf(&0.352_941_176_470_588_26_f64),
            0.620_918_065_523_85,
            TOL,
        );
        assert::close(gam.cdf(&100.0_f64), 1.0, TOL);
    }

    // 0.352_941_176_470_588_26 = 1.2 / 3.4, the mean of G(1.2, 3.4).
    #[test]
    fn ln_pdf_at_mean() {
        let gam = Gamma::new(1.2, 3.4).unwrap();
        assert::close(
            gam.ln_pdf(&0.352_941_176_470_588_26_f64),
            0.145_613_832_984_222_48,
            TOL,
        );
    }

    #[test]
    fn mean_should_be_ratio_of_params() {
        let m1: f64 = Gamma::new(1.0, 2.0).unwrap().mean().unwrap();
        let m2: f64 = Gamma::new(1.0, 1.0).unwrap().mean().unwrap();
        let m3: f64 = Gamma::new(3.0, 1.0).unwrap().mean().unwrap();
        let m4: f64 = Gamma::new(0.3, 0.1).unwrap().mean().unwrap();
        assert::close(m1, 0.5, TOL);
        assert::close(m2, 1.0, TOL);
        assert::close(m3, 3.0, TOL);
        assert::close(m4, 3.0, TOL);
    }

    #[test]
    fn mode_undefined_for_shape_less_than_one() {
        let m1_opt: Option<f64> = Gamma::new(1.0, 2.0).unwrap().mode();
        let m2_opt: Option<f64> = Gamma::new(0.999, 2.0).unwrap().mode();
        let m3_opt: Option<f64> = Gamma::new(0.5, 2.0).unwrap().mode();
        let m4_opt: Option<f64> = Gamma::new(0.1, 2.0).unwrap().mode();
        assert!(m1_opt.is_some());
        assert!(m2_opt.is_none());
        assert!(m3_opt.is_none());
        assert!(m4_opt.is_none());
    }

    #[test]
    fn mode() {
        let m1: f64 = Gamma::new(2.0, 2.0).unwrap().mode().unwrap();
        let m2: f64 = Gamma::new(1.0, 2.0).unwrap().mode().unwrap();
        let m3: f64 = Gamma::new(2.0, 1.0).unwrap().mode().unwrap();
        assert::close(m1, 0.5, TOL);
        assert::close(m2, 0.0, TOL);
        assert::close(m3, 1.0, TOL);
    }

    #[test]
    fn variance() {
        assert::close(
            Gamma::new(2.0, 2.0).unwrap().variance().unwrap(),
            0.5,
            TOL,
        );
        assert::close(
            Gamma::new(0.5, 2.0).unwrap().variance().unwrap(),
            1.0 / 8.0,
            TOL,
        );
    }

    #[test]
    fn skewness() {
        assert::close(
            Gamma::new(4.0, 3.0).unwrap().skewness().unwrap(),
            1.0,
            TOL,
        );
        assert::close(
            Gamma::new(16.0, 4.0).unwrap().skewness().unwrap(),
            0.5,
            TOL,
        );
        assert::close(
            Gamma::new(16.0, 1.0).unwrap().skewness().unwrap(),
            0.5,
            TOL,
        );
    }

    #[test]
    fn kurtosis() {
        assert::close(
            Gamma::new(6.0, 3.0).unwrap().kurtosis().unwrap(),
            1.0,
            TOL,
        );
        assert::close(
            Gamma::new(6.0, 1.0).unwrap().kurtosis().unwrap(),
            1.0,
            TOL,
        );
        assert::close(
            Gamma::new(12.0, 1.0).unwrap().kurtosis().unwrap(),
            0.5,
            TOL,
        );
    }

    #[test]
    fn entropy() {
        let gam1 = Gamma::new(2.0, 1.0).unwrap();
        let gam2 = Gamma::new(1.2, 3.4).unwrap();
        assert::close(gam1.entropy(), 1.577_215_664_901_532_8, TOL);
        assert::close(gam2.entropy(), -0.051_341_542_306_993_84, TOL);
    }

    #[test]
    fn draw_test() {
        let mut rng = rand::rng();
        let gam = Gamma::new(1.2, 3.4).unwrap();
        let cdf = |x: f64| gam.cdf(&x);

        // The KS test is stochastic; require at least one of N_TRIES
        // independent samples to pass.
        let passes = (0..N_TRIES).fold(0, |acc, _| {
            let xs: Vec<f64> = gam.sample(1000, &mut rng);
            let (_, p) = ks_test(&xs, cdf);
            if p > KS_PVAL { acc + 1 } else { acc }
        });

        assert!(passes > 0);
    }

    use crate::test_scalable_cdf;
    use crate::test_scalable_density;
    use crate::test_scalable_entropy;
    use crate::test_scalable_method;

    test_scalable_method!(Gamma::new(2.0, 4.0).unwrap(), mean);
    test_scalable_method!(Gamma::new(2.0, 4.0).unwrap(), variance);
    test_scalable_method!(Gamma::new(2.0, 4.0).unwrap(), skewness);
    test_scalable_method!(Gamma::new(2.0, 4.0).unwrap(), kurtosis);
    test_scalable_density!(Gamma::new(2.0, 4.0).unwrap());
    test_scalable_entropy!(Gamma::new(2.0, 4.0).unwrap());
    test_scalable_cdf!(Gamma::new(2.0, 4.0).unwrap());

    #[test]
    fn emit_and_from_params_are_identity() {
        let dist_a = Gamma::new(3.0, 5.0).unwrap();
        let dist_b = Gamma::from_params(dist_a.emit_params());
        assert_eq!(dist_a, dist_b);
    }
}