use crate::dist::{Continuous, DistError, Distribution, Moments};
use crate::num;
use crate::rng::RngCore;
#[derive(Debug, Clone, Copy)]
pub struct Gamma {
shape: f64, scale: f64, inv_scale: f64,
ln_gamma_shape: f64,
}
impl Gamma {
pub fn new(shape: f64, scale: f64) -> Result<Self, DistError> {
if !(shape > 0.0 && scale > 0.0) || !shape.is_finite() || !scale.is_finite() {
return Err(DistError::InvalidParameter);
}
let inv_scale = 1.0 / scale;
let ln_gamma_shape = ln_gamma(shape);
Ok(Self {
shape,
scale,
inv_scale,
ln_gamma_shape,
})
}
#[inline]
pub fn shape(&self) -> f64 {
self.shape
}
#[inline]
pub fn scale(&self) -> f64 {
self.scale
}
#[inline]
fn x_to_z(&self, x: f64) -> f64 {
x * self.inv_scale
}
}
impl Distribution for Gamma {
type Value = f64;
fn cdf(&self, x: f64) -> f64 {
if x <= 0.0 || !x.is_finite() {
return 0.0;
}
let z = self.x_to_z(x);
reg_lower_gamma(self.shape, z)
}
fn in_support(&self, x: f64) -> bool {
x >= 0.0 && x.is_finite()
}
fn sample<R: RngCore>(&self, rng: &mut R) -> f64 {
let k = self.shape;
if k >= 1.0 {
let d = k - 1.0 / 3.0;
let c = (1.0 / (9.0 * d)).sqrt();
loop {
let x = standard_normal(rng);
let v = 1.0 + c * x;
if v <= 0.0 {
continue;
}
let v3 = v * v * v;
let u = rng.next_f64();
if u < 1.0 - 0.0331 * x * x * x * x {
return self.scale * d * v3;
}
if u.ln() < 0.5 * x * x + d * (1.0 - v3 + (v3).ln()) {
return self.scale * d * v3;
}
}
} else {
let mut tmp = Gamma::new(k + 1.0, 1.0).unwrap().sample(rng);
let u = rng.next_f64();
tmp *= u.powf(1.0 / k);
self.scale * tmp
}
}
}
impl Continuous for Gamma {
fn pdf(&self, x: f64) -> f64 {
if !self.in_support(x) {
return 0.0;
}
let z = self.x_to_z(x);
((self.shape - 1.0) * z.ln() - z - self.ln_gamma_shape - self.shape * self.inv_scale.ln())
.exp()
* self.inv_scale
}
fn inv_cdf(&self, p: f64) -> f64 {
debug_assert!(p > 0.0 && p < 1.0);
let mean = self.shape * self.scale;
let std = (self.shape).sqrt() * self.scale;
let mut x = mean + std * num::standard_normal_inv_cdf(p);
if x <= 0.0 {
x = mean.max(1e-12);
}
let mut lo = 0.0_f64;
let mut hi = mean.max(x) * 2.0 + 10.0 * self.scale;
for _ in 0..50 {
let fx = self.cdf(x) - p;
if fx.abs() < 1e-10 {
break;
}
if fx < 0.0 {
lo = x;
} else {
hi = x;
}
let dfx = self.pdf(x).max(1e-300);
let mut x_new = x - fx / dfx;
if x_new <= lo || x_new >= hi || !x_new.is_finite() {
x_new = 0.5 * (lo + hi);
}
x = x_new;
}
x
}
}
impl Moments for Gamma {
fn mean(&self) -> f64 {
self.shape * self.scale
}
fn variance(&self) -> f64 {
self.shape * self.scale * self.scale
}
fn skewness(&self) -> f64 { 2.0 / self.shape.sqrt() }
fn kurtosis(&self) -> f64 { 6.0 / self.shape }
fn entropy(&self) -> f64 {
let k = self.shape;
let theta = self.scale;
k + theta.ln() + super::gamma::ln_gamma(k) + (1.0 - k) * crate::num::digamma(k)
}
}
fn standard_normal<R: RngCore>(rng: &mut R) -> f64 {
loop {
let u1 = 2.0 * rng.next_f64() - 1.0;
let u2 = 2.0 * rng.next_f64() - 1.0;
let s = u1 * u1 + u2 * u2;
if s == 0.0 || s >= 1.0 {
continue;
}
let m = (-2.0 * s.ln() / s).sqrt();
return u1 * m;
}
}
pub(crate) fn ln_gamma(z: f64) -> f64 {
const COF: [f64; 9] = [
0.99999999999980993,
676.5203681218851,
-1259.1392167224028,
771.32342877765313,
-176.61502916214059,
12.507343278686905,
-0.13857109526572012,
9.9843695780195716e-6,
1.5056327351493116e-7,
];
if z < 0.5 {
return std::f64::consts::PI.ln()
- (std::f64::consts::PI * z).sin().ln()
- ln_gamma(1.0 - z);
}
let z = z - 1.0;
let mut x = COF[0];
for (i, &c) in COF.iter().enumerate().skip(1) {
x += c / (z + i as f64);
}
let t = z + 7.5;
0.5 * (2.0 * std::f64::consts::PI).ln() + (z + 0.5) * t.ln() - t + x.ln()
}
fn reg_lower_gamma(a: f64, x: f64) -> f64 {
if x <= 0.0 {
return 0.0;
}
if x < a + 1.0 {
let mut sum = 1.0 / a;
let mut del = sum;
let mut ap = a;
for _ in 0..1000 {
ap += 1.0;
del *= x / ap;
sum += del;
if del.abs() < sum.abs() * 1e-14 {
break;
}
}
(sum * (-x + a * x.ln() - ln_gamma(a))).exp()
} else {
let mut b0 = x + 1.0 - a;
let mut c = 1.0 / 1e-30;
let mut d = 1.0 / b0;
let mut h = d;
for i in 1..=1000 {
let an = -(i as f64) * (i as f64 - a);
b0 += 2.0;
d = an * d + b0;
if d.abs() < 1e-30 {
d = 1e-30;
}
c = b0 + an / c;
if c.abs() < 1e-30 {
c = 1e-30;
}
d = 1.0 / d;
let del = d * c;
h *= del;
if (del - 1.0).abs() < 1e-14 {
break;
}
}
1.0 - (h * (-x + a * x.ln() - ln_gamma(a))).exp()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn moments() {
let g = Gamma::new(2.0, 3.0).unwrap();
assert!((g.mean() - 6.0).abs() < 1e-12);
assert!((g.variance() - 18.0).abs() < 1e-12);
}
#[test]
fn cdf_monotone() {
let g = Gamma::new(3.0, 2.0).unwrap();
assert!(g.cdf(1.0) < g.cdf(5.0));
}
#[test]
fn moments_higher() {
let g = Gamma::new(4.0, 1.0).unwrap();
assert!((g.skewness() - (2.0/4.0f64.sqrt())).abs() < 1e-15);
assert!((g.kurtosis() - (6.0/4.0)).abs() < 1e-15);
}
}