use crate::distributions::traits::Distribution;
use crate::error::{StatsError, StatsResult};
use crate::utils::special_functions::{bisect_inverse_cdf, ln_gamma, regularized_incomplete_gamma};
/// Gamma distribution in the shape/rate parameterization.
///
/// Prefer [`Gamma::new`] (or [`Gamma::fit`]) over building this struct by hand:
/// the fields are public, so direct construction skips parameter validation.
#[derive(Debug, Clone, Copy)]
pub struct Gamma {
    /// Shape parameter; `Gamma::new` requires it to be positive.
    pub alpha: f64,
    /// Rate parameter (inverse scale); `Gamma::new` requires it to be positive.
    pub beta: f64,
}
impl Gamma {
    /// Creates a Gamma distribution with shape `alpha` and rate `beta`.
    ///
    /// With this parameterization the mean is `alpha / beta` and the variance is
    /// `alpha / beta^2`.
    ///
    /// # Errors
    /// Returns [`StatsError::InvalidInput`] unless both parameters are finite and
    /// strictly positive. NaN and infinite values are rejected.
    pub fn new(alpha: f64, beta: f64) -> StatsResult<Self> {
        // Phrase the check positively: `NaN <= 0.0` is false, so the negated form
        // `alpha <= 0.0 || beta <= 0.0` would silently accept NaN parameters.
        let is_valid = |v: f64| v.is_finite() && v > 0.0;
        if !(is_valid(alpha) && is_valid(beta)) {
            return Err(StatsError::InvalidInput {
                message: "Gamma::new: alpha and beta must be positive".to_string(),
            });
        }
        Ok(Self { alpha, beta })
    }

    /// Fits shape and rate to strictly positive `data` by approximate maximum likelihood.
    ///
    /// With `s = ln(mean(x)) - mean(ln(x))`, the shape uses Minka's closed-form
    /// approximation to the MLE, `alpha ≈ (3 - s + sqrt((s - 3)^2 + 24 s)) / (12 s)`.
    /// The rate is then `alpha / mean(x)`, so the fitted mean equals the sample mean.
    /// If `s <= 0` (e.g. every observation is identical), the shape falls back to `1.0`.
    ///
    /// # Errors
    /// Returns [`StatsError::InvalidInput`] if `data` is empty, contains a non-finite
    /// value (NaN or infinity), or contains a value `<= 0`. Errors from
    /// [`Gamma::new`] are propagated.
    pub fn fit(data: &[f64]) -> StatsResult<Self> {
        if data.is_empty() {
            return Err(StatsError::InvalidInput {
                message: "Gamma::fit: data must not be empty".to_string(),
            });
        }
        // Must run before the sign check: NaN compares false with `<= 0.0` and would
        // otherwise poison the mean and produce a NaN rate.
        if data.iter().any(|x| !x.is_finite()) {
            return Err(StatsError::InvalidInput {
                message: "Gamma::fit: all data values must be finite".to_string(),
            });
        }
        if data.iter().any(|&x| x <= 0.0) {
            return Err(StatsError::InvalidInput {
                message: "Gamma::fit: all data values must be positive".to_string(),
            });
        }
        let n = data.len() as f64;
        let mean = data.iter().sum::<f64>() / n;
        let log_mean = data.iter().map(|&x| x.ln()).sum::<f64>() / n;
        // Mathematically s >= 0 (Jensen's inequality), with 0 only for constant data;
        // rounding can push it to zero or slightly below, hence the fallback branch.
        let s = mean.ln() - log_mean;
        let alpha = if s > 0.0 {
            (3.0 - s + ((s - 3.0).powi(2) + 24.0 * s).sqrt()) / (12.0 * s)
        } else {
            1.0
        };
        let beta = alpha / mean;
        Self::new(alpha, beta)
    }
}
/// [`Distribution`] implementation for the shape/rate Gamma distribution.
///
/// Density: `f(x) = beta^alpha * x^(alpha - 1) * exp(-beta * x) / Gamma(alpha)` for
/// `x > 0`. `pdf`, `logpdf` and `cdf` treat every `x <= 0` (including `x == 0`) as
/// outside the support.
impl Distribution for Gamma {
    fn name(&self) -> &str {
        "Gamma"
    }

    fn num_params(&self) -> usize {
        2
    }

    /// Probability density; `0` for `x <= 0`.
    ///
    /// NOTE: `x == 0` reports `0` even when `alpha <= 1`. In that case the density's
    /// limit at `0+` is `beta` (for `alpha == 1`) or unbounded (for `alpha < 1`).
    fn pdf(&self, x: f64) -> StatsResult<f64> {
        if x <= 0.0 {
            return Ok(0.0);
        }
        Ok(self.logpdf(x)?.exp())
    }

    /// Log-density, evaluated in log space so that `beta^alpha` and `Gamma(alpha)`
    /// never need to be formed directly; `-inf` for `x <= 0`.
    fn logpdf(&self, x: f64) -> StatsResult<f64> {
        if x <= 0.0 {
            return Ok(f64::NEG_INFINITY);
        }
        Ok(self.alpha * self.beta.ln() + (self.alpha - 1.0) * x.ln()
            - self.beta * x
            - ln_gamma(self.alpha))
    }

    /// Cumulative distribution function, via the regularized incomplete gamma
    /// function evaluated at `(alpha, beta * x)`; `0` for `x <= 0`, `1` at `+inf`.
    fn cdf(&self, x: f64) -> StatsResult<f64> {
        if x <= 0.0 {
            return Ok(0.0);
        }
        // The limit is exactly 1. Answer directly instead of relying on the special
        // function to accept an infinite argument; this keeps
        // `cdf(inverse_cdf(1.0)) == 1.0`.
        if x == f64::INFINITY {
            return Ok(1.0);
        }
        Ok(regularized_incomplete_gamma(self.alpha, self.beta * x))
    }

    /// Quantile function: the `x` with `cdf(x) == p`, located by bisection on `[0, hi]`.
    ///
    /// `p == 0` maps to `0` and `p == 1` maps to `+inf`.
    ///
    /// # Errors
    /// Returns `StatsError::InvalidInput` if `p` is outside `[0, 1]` or is NaN.
    fn inverse_cdf(&self, p: f64) -> StatsResult<f64> {
        if !(0.0..=1.0).contains(&p) {
            return Err(StatsError::InvalidInput {
                message: "Gamma::inverse_cdf: p must be in [0, 1]".to_string(),
            });
        }
        if p == 0.0 {
            return Ok(0.0);
        }
        if p == 1.0 {
            return Ok(f64::INFINITY);
        }
        let alpha = self.alpha;
        let beta = self.beta;
        let cdf_at = move |x: f64| regularized_incomplete_gamma(alpha, beta * x);
        // Initial upper bound: mean + 10 standard deviations, at least 1.
        let mut hi = ((alpha / beta) + 10.0 * (alpha / beta / beta).sqrt()).max(1.0);
        // Right-skewed shapes (small alpha) and quantiles close to 1 can lie beyond
        // that guess, and bisection cannot leave its bracket. Widen the bracket until
        // it contains the quantile. The loop stops before `hi` overflows, and also
        // stops if the CDF returns NaN (`NaN < p` is false).
        while cdf_at(hi) < p && hi < f64::MAX / 2.0 {
            hi *= 2.0;
        }
        Ok(bisect_inverse_cdf(cdf_at, p, 0.0, hi))
    }

    /// `alpha / beta`.
    fn mean(&self) -> f64 {
        self.alpha / self.beta
    }

    /// `alpha / beta^2`.
    fn variance(&self) -> f64 {
        self.alpha / (self.beta * self.beta)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_gamma_mean_variance() {
        let dist = Gamma::new(3.0, 2.0).unwrap();
        assert!(close(dist.mean(), 1.5, 1e-10));
        assert!(close(dist.variance(), 0.75, 1e-10));
    }

    #[test]
    fn test_gamma_pdf_positive() {
        // Gamma(2, 1) has density x * e^(-x), so the value at 1 is e^(-1).
        let dist = Gamma::new(2.0, 1.0).unwrap();
        let density = dist.pdf(1.0).unwrap();
        assert!(close(density, (-1.0_f64).exp(), 1e-8));
    }

    #[test]
    fn test_gamma_cdf_zero() {
        let dist = Gamma::new(2.0, 1.0).unwrap();
        let at_origin = dist.cdf(0.0).unwrap();
        assert_eq!(at_origin, 0.0);
    }

    #[test]
    fn test_gamma_inverse_cdf_roundtrip() {
        let dist = Gamma::new(3.0, 0.5).unwrap();
        let levels = [0.1, 0.25, 0.5, 0.75, 0.9];
        for &p in levels.iter() {
            let quantile = dist.inverse_cdf(p).unwrap();
            let recovered = dist.cdf(quantile).unwrap();
            assert!(close(p, recovered, 1e-6), "p={p}: roundtrip failed");
        }
    }

    #[test]
    fn test_gamma_fit() {
        let sample = [1.0, 2.0, 3.0, 1.5, 2.5, 1.8, 2.2, 0.8, 3.2, 1.0];
        let fitted = Gamma::fit(&sample).unwrap();
        let sample_mean = sample.iter().sum::<f64>() / sample.len() as f64;
        assert!(close(fitted.mean(), sample_mean, 1e-10));
    }

    #[test]
    fn test_gamma_invalid() {
        for (alpha, beta) in [(0.0, 1.0), (1.0, -1.0)] {
            assert!(Gamma::new(alpha, beta).is_err());
        }
    }
}