use self::GammaRepr::*;
use crate::{Distribution, Exp, Exp1, Open01, StandardNormal};
use core::fmt;
use num_traits::Float;
use rand::{Rng, RngExt};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// The Gamma distribution `Gamma(shape, scale)`, also written `Γ(k, θ)`.
///
/// `scale` is a scale parameter, not a rate: `Gamma(1, scale)` is the
/// exponential distribution with rate `1 / scale`. For `shape = k` and
/// `scale = θ` the density is
/// `f(x) = x^(k - 1) · exp(-x / θ) / (Γ(k) · θ^k)` for `x > 0`.
///
/// Construct with [`Gamma::new`], which validates the parameters and chooses a
/// sampling strategy once; each draw then only dispatches on that choice.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Gamma<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Sampling strategy selected by [`Gamma::new`] from the parameters.
repr: GammaRepr<F>,
}
/// Reasons [`Gamma::new`] can refuse to build a distribution.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `shape` was not strictly positive (NaN is included here).
    ShapeTooSmall,
    /// `scale` was not strictly positive (NaN is included here).
    ScaleTooSmall,
    /// `scale` was infinite.
    ///
    /// NOTE: [`Gamma::new`] currently accepts an infinite `scale` and does not
    /// return this variant.
    ScaleTooLarge,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the fixed, human-readable message first, then emit it unpadded.
        let message = match *self {
            Error::ShapeTooSmall => "shape is not positive in gamma distribution",
            Error::ScaleTooSmall => "scale is not positive in gamma distribution",
            Error::ScaleTooLarge => "scale is infinity in gamma distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
/// Sampling strategy for [`Gamma`], chosen once by [`Gamma::new`].
///
/// NOTE(review): serde's derived encoding may depend on variant order, so keep
/// the order below stable.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
enum GammaRepr<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Finite `shape > 1`: rejection sampler working directly on `shape`.
Large(GammaLargeShape<F>),
/// `shape == 1` (exponential with rate `1 / scale`), or an infinite `shape`
/// or `scale` (exponential with rate 0, whose draws are `+∞`).
One(Exp<F>),
/// `0 < shape < 1`: sample `Gamma(shape + 1, scale)`, then correct with `U^(1/shape)`.
Small(GammaSmallShape<F>),
}
/// Sampler for `Gamma(shape, scale)` with `0 < shape < 1`.
///
/// A draw takes `U` uniform on the open interval (0, 1) and a variate from
/// `Gamma(shape + 1, scale)` (via `large_shape`), and returns their product
/// after raising `U` to `1 / shape`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct GammaSmallShape<F>
where
F: Float,
StandardNormal: Distribution<F>,
Open01: Distribution<F>,
{
/// `1 / shape` for the original (small) shape; the exponent applied to `U`.
inv_shape: F,
/// Sampler for `Gamma(shape + 1, scale)`; it also carries `scale`.
large_shape: GammaLargeShape<F>,
}
/// Marsaglia–Tsang (2000) rejection sampler for `Gamma(shape, scale)`.
///
/// Within this module it is only constructed with `shape > 1`: either a large
/// shape directly, or a small shape plus one.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct GammaLargeShape<F>
where
F: Float,
StandardNormal: Distribution<F>,
Open01: Distribution<F>,
{
/// Scale parameter; multiplied in after a variate has been accepted.
scale: F,
/// `1 / sqrt(9 * d)`.
c: F,
/// `shape - 1/3`; also the factor applied to accepted variates.
d: F,
}
impl<F> Gamma<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Builds the `Gamma(shape, scale)` distribution.
    ///
    /// The sampler is chosen here, once:
    /// * an infinite `shape` or `scale` gives an exponential with rate 0 (every draw is `+∞`);
    /// * `shape < 1` uses the small-shape sampler;
    /// * `shape > 1` uses the large-shape sampler;
    /// * `shape == 1` is exactly the exponential distribution with rate `1 / scale`.
    ///
    /// # Errors
    ///
    /// * [`Error::ShapeTooSmall`] if `shape` is not strictly positive (NaN included).
    /// * [`Error::ScaleTooSmall`] if `scale` is not strictly positive (NaN included).
    ///
    /// `shape` is checked first, so it is the error reported when both are invalid.
    #[inline]
    pub fn new(shape: F, scale: F) -> Result<Gamma<F>, Error> {
        // `x > 0` is false for NaN, so negating it rejects NaN together with
        // zero and negative values.
        let shape_positive = shape > F::zero();
        let scale_positive = scale > F::zero();
        if !shape_positive {
            return Err(Error::ShapeTooSmall);
        }
        if !scale_positive {
            return Err(Error::ScaleTooSmall);
        }

        let one = F::one();
        let unbounded = shape == F::infinity() || scale == F::infinity();
        // Past the checks above `shape` is positive and not NaN; if it is also
        // finite, exactly one of `<`, `>` or `==` against one holds.
        let repr = if unbounded {
            One(Exp::new(F::zero()).unwrap())
        } else if shape < one {
            Small(GammaSmallShape::new_raw(shape, scale))
        } else if shape > one {
            Large(GammaLargeShape::new_raw(shape, scale))
        } else {
            One(Exp::new(one / scale).unwrap())
        };
        Ok(Gamma { repr })
    }
}
impl<F> GammaSmallShape<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Prepares a sampler for `Gamma(shape, scale)` with `0 < shape < 1`.
    ///
    /// The heavy lifting is delegated to a large-shape sampler for
    /// `shape + 1`. `1 / shape` is stored for the `U^(1/shape)` correction
    /// applied at sampling time.
    fn new_raw(shape: F, scale: F) -> GammaSmallShape<F> {
        let boosted_shape = shape + F::one();
        let large_shape = GammaLargeShape::new_raw(boosted_shape, scale);
        let inv_shape = F::one() / shape;
        GammaSmallShape {
            inv_shape,
            large_shape,
        }
    }
}
impl<F> GammaLargeShape<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Precomputes the Marsaglia–Tsang constants `d = shape - 1/3` and
    /// `c = 1 / sqrt(9 * d)` for `Gamma(shape, scale)`.
    fn new_raw(shape: F, scale: F) -> GammaLargeShape<F> {
        let one_third = F::from(1. / 3.).unwrap();
        let nine = F::from(9.).unwrap();
        let d = shape - one_third;
        let c = F::one() / (nine * d).sqrt();
        GammaLargeShape { scale, c, d }
    }

    /// Draws one accepted Marsaglia–Tsang variate `V`, *before* it is multiplied
    /// by `d * scale`.
    ///
    /// Each attempt draws a standard normal `X` and forms `V = (1 + c·X)^3`. It
    /// then accepts `V` with a uniform `U`, first against the cheap squeeze
    /// `1 - 0.0331·X^4` and otherwise against the exact log-space test.
    fn sample_unscaled<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        let squeeze = F::from(0.0331).unwrap();
        let half = F::from(0.5).unwrap();
        loop {
            let z: F = rng.sample(StandardNormal);
            let cube_root = F::one() + self.c * z;
            // a^3 <= 0 exactly when a <= 0: reject before drawing a uniform.
            if cube_root <= F::zero() {
                continue;
            }
            let v = cube_root * cube_root * cube_root;
            let u: F = rng.sample(Open01);
            let z_sq = z * z;

            // Squeeze test: accepts most proposals without computing a logarithm.
            if u < F::one() - squeeze * z_sq * z_sq {
                return v;
            }
            // Exact acceptance test, only evaluated when the squeeze fails.
            if u.ln() < half * z_sq + self.d * (F::one() - v + v.ln()) {
                return v;
            }
        }
    }
}
impl<F> Distribution<F> for Gamma<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Forwards the draw to whichever sampler [`Gamma::new`] selected.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        match &self.repr {
            One(exp) => exp.sample(rng),
            Small(small) => small.sample(rng),
            Large(large) => large.sample(rng),
        }
    }
}
impl<F> Distribution<F> for GammaSmallShape<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Returns `V · U^(1/shape) · d · scale`, where `V` is an unscaled draw from
    /// the `shape + 1` sampler and `U` is uniform on (0, 1).
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        let inner = &self.large_shape;
        // The uniform is drawn before the inner variate. Reordering these draws
        // would change the values produced for a given RNG state.
        let uniform: F = rng.sample(Open01);
        let boosted = inner.sample_unscaled(rng);
        let correction = uniform.powf(self.inv_shape);
        // Left-to-right: ((boosted * correction) * d) * scale.
        boosted * correction * inner.d * inner.scale
    }
}
impl<F> Distribution<F> for GammaLargeShape<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Scales an accepted Marsaglia–Tsang variate by `d * scale`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        let unscaled = self.sample_unscaled(rng);
        let factor = self.d * self.scale;
        unscaled * factor
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn gamma_distributions_can_be_compared() {
        assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
    }

    // Non-positive and NaN parameters are rejected, and `shape` is validated
    // before `scale`.
    #[test]
    fn gamma_rejects_invalid_parameters() {
        assert_eq!(Gamma::new(0.0, 1.0), Err(Error::ShapeTooSmall));
        assert_eq!(Gamma::new(-1.0, 1.0), Err(Error::ShapeTooSmall));
        assert_eq!(Gamma::new(f64::NAN, 1.0), Err(Error::ShapeTooSmall));
        assert_eq!(Gamma::new(1.0, 0.0), Err(Error::ScaleTooSmall));
        assert_eq!(Gamma::new(1.0, -2.0), Err(Error::ScaleTooSmall));
        assert_eq!(Gamma::new(1.0, f64::NAN), Err(Error::ScaleTooSmall));
        assert_eq!(Gamma::new(0.0, 0.0), Err(Error::ShapeTooSmall));
    }

    // `Gamma::new` chooses the sampler from the shape (and infinities).
    #[test]
    fn gamma_selects_sampler_by_shape() {
        assert!(matches!(Gamma::new(0.5, 1.0).unwrap().repr, GammaRepr::Small(_)));
        assert!(matches!(Gamma::new(1.0, 1.0).unwrap().repr, GammaRepr::One(_)));
        assert!(matches!(Gamma::new(2.5, 1.0).unwrap().repr, GammaRepr::Large(_)));
        assert!(matches!(
            Gamma::new(f64::infinity(), 1.0).unwrap().repr,
            GammaRepr::One(_)
        ));
        assert!(matches!(
            Gamma::new(0.5, f64::infinity()).unwrap().repr,
            GammaRepr::One(_)
        ));
    }

    #[test]
    fn gamma_extreme_values() {
        let d = Gamma::new(f64::infinity(), 2.0).unwrap();
        assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity());
        let d = Gamma::new(2.0, f64::infinity()).unwrap();
        assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity());
    }

    // Every sampling regime (small, one, large) yields finite, non-negative values.
    #[test]
    fn gamma_samples_are_finite_and_non_negative() {
        let params: [(f64, f64); 5] = [(0.1, 1.0), (0.5, 3.0), (1.0, 2.0), (2.0, 0.5), (100.0, 1.0)];
        let mut rng = crate::test::rng(22);
        for &(shape, scale) in params.iter() {
            let d = Gamma::new(shape, scale).unwrap();
            for _ in 0..100 {
                let x: f64 = d.sample(&mut rng);
                assert!(
                    x.is_finite() && x >= 0.0,
                    "Gamma({}, {}) produced {}",
                    shape,
                    scale,
                    x
                );
            }
        }
    }
}