use crate::{FloatExt, XError, XResult, random::PAR_THRESHOLD};
use rand::prelude::*;
use rand_distr::{Exp1, Open01, StandardNormal};
use rand_xoshiro::Xoshiro256PlusPlus;
use rayon::prelude::*;
/// A Gamma distribution described by a shape and a scale parameter.
///
/// Use [`Gamma::new`] to build a validated instance; [`Default`] gives unit shape and
/// unit scale (mathematically, the standard exponential distribution).
#[derive(Debug, Clone)]
pub struct Gamma<T: FloatExt = f64> {
    /// Shape parameter (expected to be positive; see [`Gamma::new`]).
    shape: T,
    /// Scale parameter (expected to be positive; see [`Gamma::new`]).
    scale: T,
}

impl<T: FloatExt> Default for Gamma<T> {
    /// Returns the distribution with `shape = 1` and `scale = 1`.
    fn default() -> Self {
        let unit = T::one();
        Gamma {
            shape: unit,
            scale: unit,
        }
    }
}
impl<T: FloatExt> Gamma<T> {
    /// Creates a Gamma distribution from its shape and scale parameters.
    ///
    /// # Errors
    ///
    /// Returns [`XError::InvalidParameters`] if `shape` or `scale` is not strictly
    /// greater than zero. NaN is rejected as well.
    pub fn new(shape: T, scale: T) -> XResult<Self> {
        // Check "is greater than zero" and negate it rather than checking `<= 0`:
        // every comparison involving NaN is false, so `NaN <= 0` would let NaN through.
        let is_positive = |value: T| value > T::zero();
        if !is_positive(shape) {
            return Err(XError::InvalidParameters(format!(
                "The shape parameter `shape` must be greater than 0, got {shape:?}"
            )));
        }
        if !is_positive(scale) {
            return Err(XError::InvalidParameters(format!(
                "The scale parameter `scale` must be greater than 0, got {scale:?}"
            )));
        }
        Ok(Self { shape, scale })
    }

    /// Returns the shape parameter.
    pub fn get_shape(&self) -> T {
        self.shape
    }

    /// Returns the scale parameter.
    pub fn get_scale(&self) -> T {
        self.scale
    }

    /// Draws `n` samples from this distribution by delegating to [`rands`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`rands`] for the stored parameters.
    pub fn samples(&self, n: usize) -> XResult<Vec<T>>
    where
        StandardNormal: Distribution<T>,
        Exp1: Distribution<T>,
        Open01: Distribution<T>,
    {
        rands(self.shape, self.scale, n)
    }
}
/// Draws a single sample from a Gamma distribution with the given `shape` and `scale`.
///
/// A fresh `Xoshiro256PlusPlus` generator, seeded from `rand::rng()`, is created for
/// each call.
///
/// # Errors
///
/// Returns [`XError::InvalidParameters`], carrying the `rand_distr` error message, when
/// `rand_distr::Gamma::new` rejects the parameters.
pub fn rand<T: FloatExt>(shape: T, scale: T) -> XResult<T>
where
    StandardNormal: Distribution<T>,
    Exp1: Distribution<T>,
    Open01: Distribution<T>,
{
    rand_distr::Gamma::new(shape, scale)
        .map_err(|err| XError::InvalidParameters(err.to_string()))
        .map(|dist| Xoshiro256PlusPlus::from_rng(&mut rand::rng()).sample(dist))
}
/// Draws `n` samples from a Gamma distribution with the given `shape` and `scale`.
///
/// Requests of at most `PAR_THRESHOLD` samples are drawn sequentially from a single
/// `Xoshiro256PlusPlus` generator. Larger requests are generated in parallel with rayon.
///
/// # Errors
///
/// Returns [`XError::InvalidParameters`], carrying the `rand_distr` error message, when
/// `rand_distr::Gamma::new` rejects the parameters (checked even when `n == 0`).
pub fn rands<T: FloatExt>(shape: T, scale: T, n: usize) -> XResult<Vec<T>>
where
    StandardNormal: Distribution<T>,
    Exp1: Distribution<T>,
    Open01: Distribution<T>,
{
    let dist = rand_distr::Gamma::new(shape, scale)
        .map_err(|err| XError::InvalidParameters(err.to_string()))?;

    if n > PAR_THRESHOLD {
        // `map_init` hands each rayon job its own generator, so no RNG is shared
        // between threads.
        let parallel: Vec<T> = (0..n)
            .into_par_iter()
            .map_init(
                || Xoshiro256PlusPlus::from_rng(&mut rand::rng()),
                |local_rng, _| local_rng.sample(dist),
            )
            .collect();
        return Ok(parallel);
    }

    let mut rng = Xoshiro256PlusPlus::from_rng(&mut rand::rng());
    Ok(std::iter::repeat_with(|| rng.sample(dist)).take(n).collect())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::calculate_stats;
    use num_traits::Float;

    #[test]
    fn test_rand() {
        // Explicit type so `is_finite` is not called on an ambiguous `{float}`.
        let random: f64 = rand(1.0, 1.0).unwrap();
        assert!(random.is_finite());
        assert!(random >= 0.0);
    }

    #[test]
    fn test_rands() {
        let randoms: Vec<f64> = rands(1.0, 1.0, 10).unwrap();
        assert_eq!(randoms.len(), 10);
        assert!(randoms.iter().all(|r| r.is_finite() && *r >= 0.0));
    }

    #[test]
    fn test_rands_empty() {
        let randoms: Vec<f64> = rands(1.0, 1.0, 0).unwrap();
        assert!(randoms.is_empty());
    }

    #[test]
    fn test_invalid_parameters_are_rejected() {
        assert!(Gamma::<f64>::new(0.0, 1.0).is_err());
        assert!(Gamma::<f64>::new(-1.0, 1.0).is_err());
        assert!(Gamma::<f64>::new(1.0, 0.0).is_err());
        assert!(Gamma::<f64>::new(1.0, -2.0).is_err());

        let single: XResult<f64> = rand(0.0, 1.0);
        assert!(single.is_err());
        let many: XResult<Vec<f64>> = rands(1.0, -1.0, 5);
        assert!(many.is_err());
    }

    #[test]
    fn test_default_and_getters() {
        let unit: Gamma = Gamma::default();
        assert_eq!(unit.get_shape(), 1.0);
        assert_eq!(unit.get_scale(), 1.0);

        let gamma = Gamma::new(2.5_f64, 0.5).unwrap();
        assert_eq!(gamma.get_shape(), 2.5);
        assert_eq!(gamma.get_scale(), 0.5);
    }

    #[test]
    fn test_gamma_stats() {
        // Gamma(shape = 1, scale = 1) is Exp(1): mean 1 and variance 1.
        let n = 1_000_000;
        let shape = 1.0_f64;
        let scale = 1.0_f64;
        let samples = rands(shape, scale, n).unwrap();
        assert_eq!(samples.len(), n);
        let (mean, variance) = calculate_stats(&samples);
        let std_dev = variance.sqrt();
        assert!(mean.is_finite());
        assert!(std_dev.is_finite());
        // With 10^6 samples the standard error of the mean is ~0.001 and that of the
        // sample variance ~0.003, so these tolerances leave a wide margin against flakes.
        assert!((mean - shape * scale).abs() < 0.01, "mean = {mean}");
        assert!(
            (variance - shape * scale * scale).abs() < 0.05,
            "variance = {variance}"
        );
    }

    #[test]
    fn test_samples_mean_matches_shape_times_scale() {
        let gamma = Gamma::new(2.0_f64, 3.0).unwrap();
        let n = 200_000;
        let samples = gamma.samples(n).unwrap();
        assert_eq!(samples.len(), n);
        let mean = samples.iter().sum::<f64>() / n as f64;
        // Expected mean is 6 and the per-sample std dev is sqrt(18) ~ 4.24, so the
        // standard error is ~0.01; a 0.1 tolerance is a ~10-sigma margin.
        assert!((mean - 6.0).abs() < 0.1, "mean = {mean}");
    }
}