use crate::distributions::{Distribution, DistributionTrait, DistributionUtils};
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use num_traits::Float;
use std::marker::PhantomData;
/// Gamma distribution over tensors of element type `T`.
///
/// The distribution is described by a `concentration` tensor together with
/// either a `rate` or a `scale` tensor. The constructors in this file fill in
/// exactly one of the two optional fields; `get_rate` / `get_scale` derive the
/// missing one as the element-wise reciprocal of the other.
#[derive(Debug, Clone)]
pub struct Gamma<T: Float> {
/// Concentration (shape) parameter.
pub concentration: Tensor<T>,
/// Rate parameter; `Some` when the distribution was built from a rate.
pub rate: Option<Tensor<T>>,
/// Scale parameter; `Some` when the distribution was built from a scale.
pub scale: Option<Tensor<T>>,
/// Shared distribution state, built from the batch shape, the event shape
/// and the `validate_args` flag.
pub base: Distribution,
/// Zero-sized marker for the element type `T`.
_phantom: PhantomData<T>,
}
impl<T: Float + 'static> Gamma<T>
where
T: rand::distributions::uniform::SampleUniform + num_traits::FromPrimitive + std::fmt::Display,
{
/// Builds a Gamma distribution from a concentration and a rate tensor.
///
/// When `validate_args` is `true`, both tensors are first checked with
/// `DistributionUtils::validate_positive`. The batch shape is the broadcast
/// of the two parameter shapes and the event shape is empty.
///
/// # Errors
///
/// Propagates any error from validation or from shape broadcasting.
pub fn from_concentration_rate(
    concentration: Tensor<T>,
    rate: Tensor<T>,
    validate_args: bool,
) -> RusTorchResult<Self> {
    if validate_args {
        DistributionUtils::validate_positive(&concentration, "concentration")?;
        DistributionUtils::validate_positive(&rate, "rate")?;
    }

    let base = Distribution::new(
        Distribution::broadcast_shapes(concentration.shape(), rate.shape())?,
        vec![],
        validate_args,
    );

    Ok(Self {
        concentration,
        rate: Some(rate),
        scale: None,
        base,
        _phantom: PhantomData,
    })
}
/// Builds a Gamma distribution from a concentration and a scale tensor.
///
/// When `validate_args` is `true`, both tensors are first checked with
/// `DistributionUtils::validate_positive`. The batch shape is the broadcast
/// of the two parameter shapes and the event shape is empty.
///
/// # Errors
///
/// Propagates any error from validation or from shape broadcasting.
pub fn from_concentration_scale(
    concentration: Tensor<T>,
    scale: Tensor<T>,
    validate_args: bool,
) -> RusTorchResult<Self> {
    if validate_args {
        DistributionUtils::validate_positive(&concentration, "concentration")?;
        DistributionUtils::validate_positive(&scale, "scale")?;
    }

    let base = Distribution::new(
        Distribution::broadcast_shapes(concentration.shape(), scale.shape())?,
        vec![],
        validate_args,
    );

    Ok(Self {
        concentration,
        rate: None,
        scale: Some(scale),
        base,
        _phantom: PhantomData,
    })
}
/// Builds the exponential special case: a Gamma with concentration one and
/// the given `rate`, both stored as single-element tensors with an empty
/// shape.
///
/// # Errors
///
/// Propagates any error from [`Self::from_concentration_rate`].
pub fn exponential(rate: T, validate_args: bool) -> RusTorchResult<Self> {
    let unit_concentration = Tensor::from_vec(vec![T::one()], vec![]);
    let scalar_rate = Tensor::from_vec(vec![rate], vec![]);
    Self::from_concentration_rate(unit_concentration, scalar_rate, validate_args)
}
/// Returns the rate tensor.
///
/// A stored rate is cloned and returned as is; otherwise the rate is
/// computed as the element-wise reciprocal of the stored scale.
///
/// # Errors
///
/// Returns an invalid-parameter error when neither rate nor scale is set.
///
/// # Panics
///
/// Panics if the scale tensor's data cannot be viewed as a slice.
pub fn get_rate(&self) -> RusTorchResult<Tensor<T>> {
    if let Some(rate) = self.rate.as_ref() {
        return Ok(rate.clone());
    }

    match self.scale.as_ref() {
        Some(scale) => {
            let reciprocals: Vec<T> = scale
                .data
                .as_slice()
                .unwrap()
                .iter()
                .map(|&theta| T::one() / theta)
                .collect();
            Ok(Tensor::from_vec(reciprocals, scale.shape().to_vec()))
        }
        None => Err(RusTorchError::invalid_parameter(
            "Either rate or scale must be specified",
        )),
    }
}
/// Returns the scale tensor.
///
/// A stored scale is cloned and returned as is; otherwise the scale is
/// computed as the element-wise reciprocal of the stored rate.
///
/// # Errors
///
/// Returns an invalid-parameter error when neither rate nor scale is set.
///
/// # Panics
///
/// Panics if the rate tensor's data cannot be viewed as a slice.
pub fn get_scale(&self) -> RusTorchResult<Tensor<T>> {
    if let Some(scale) = self.scale.as_ref() {
        return Ok(scale.clone());
    }

    match self.rate.as_ref() {
        Some(rate) => {
            let reciprocals: Vec<T> = rate
                .data
                .as_slice()
                .unwrap()
                .iter()
                .map(|&beta| T::one() / beta)
                .collect();
            Ok(Tensor::from_vec(reciprocals, rate.shape().to_vec()))
        }
        None => Err(RusTorchError::invalid_parameter(
            "Either rate or scale must be specified",
        )),
    }
}
/// Approximates `ln Γ(x)` for `x > 0`.
///
/// The argument is first shifted upwards with the recurrence
/// `ln Γ(x) = ln Γ(x + 1) - ln x` until it reaches at least 8, where the
/// Stirling series with three correction terms is accurate to far better
/// than single precision. The shift is an iterative loop of at most eight
/// steps, so the call depth is bounded for every input.
///
/// Special cases: returns `+inf` for `x == 0` and `x == +inf`, and NaN for
/// negative or NaN input.
fn log_gamma_approx(x: T) -> T {
    if x.is_nan() || x < T::zero() {
        return T::nan();
    }
    if x == T::zero() || x.is_infinite() {
        return T::infinity();
    }

    let num = |v: f64| T::from(v).unwrap();
    let half = num(0.5);
    let half_ln_two_pi = half * num(2.0 * std::f64::consts::PI).ln();

    // Accumulate ln(x) + ln(x + 1) + ... while moving the argument into the
    // region where the asymptotic series converges quickly.
    let threshold = num(8.0);
    let mut z = x;
    let mut shift = T::zero();
    while z < threshold {
        shift = shift + z.ln();
        z = z + T::one();
    }

    // Stirling series: 1/(12 z) - 1/(360 z^3) + 1/(1260 z^5).
    let inv = z.recip();
    let inv_sq = inv * inv;
    let correction =
        inv * (num(1.0 / 12.0) - inv_sq * (num(1.0 / 360.0) - inv_sq * num(1.0 / 1260.0)));

    (z - half) * z.ln() - z + half_ln_two_pi + correction - shift
}
/// Computes the log normalising constant `α·ln β − ln Γ(α)` element-wise.
///
/// The shorter of the concentration and rate buffers is cycled against the
/// longer one, and the result takes the shape of the longer operand (the
/// concentration's shape on a tie). This keeps a scalar concentration
/// paired with a vector rate from being truncated to a single element.
/// Only this cyclic form of broadcasting is handled here.
///
/// # Errors
///
/// Propagates the error from `get_rate` when neither rate nor scale is set.
///
/// # Panics
///
/// Panics if a parameter tensor's data cannot be viewed as a slice.
fn log_normalizing_constant(&self) -> RusTorchResult<Tensor<T>> {
    let rate = self.get_rate()?;
    let concentration_data = self.concentration.data.as_slice().unwrap();
    let rate_data = rate.data.as_slice().unwrap();

    let rate_is_longer =
        !concentration_data.is_empty() && rate_data.len() > concentration_data.len();
    let (len, shape) = if rate_is_longer {
        (rate_data.len(), rate.shape().to_vec())
    } else {
        (
            concentration_data.len(),
            self.concentration.shape().to_vec(),
        )
    };

    let result_data: Vec<T> = concentration_data
        .iter()
        .cycle()
        .zip(rate_data.iter().cycle())
        .take(len)
        .map(|(&alpha, &beta)| alpha * beta.ln() - Self::log_gamma_approx(alpha))
        .collect();

    Ok(Tensor::from_vec(result_data, shape))
}
/// Draws one sample from a unit-scale Gamma distribution with the given
/// `shape` (concentration), using the Marsaglia–Tsang squeeze method.
///
/// The method needs a standard-normal proposal, which is produced here with
/// the Box–Muller transform from two uniform draws. For `shape < 1` a
/// sample for `shape + 1` is drawn and multiplied by `u^(1/shape)` with `u`
/// uniform, which is the standard boosting identity.
///
/// Returns NaN when `shape` is NaN or negative instead of recursing on an
/// invalid parameter.
fn sample_gamma(shape: T) -> T
where
    T: rand::distributions::uniform::SampleUniform,
{
    use rand::{thread_rng, Rng};

    if shape.is_nan() || shape < T::zero() {
        return T::nan();
    }

    let num = |v: f64| T::from(v).unwrap();
    let one = T::one();
    let zero = T::zero();
    let two = num(2.0);
    let two_pi = num(2.0 * std::f64::consts::PI);

    let mut rng = thread_rng();

    // Shapes below one are sampled at `shape + 1` and corrected afterwards.
    let boosted = shape < one;
    let alpha = if boosted { shape + one } else { shape };

    let d = alpha - num(1.0 / 3.0);
    let c = one / (num(9.0) * d).sqrt();

    let draw = loop {
        // Box–Muller: `1 - u` lies in (0, 1], so the logarithm stays finite.
        let radius_u: T = one - rng.gen_range(zero..one);
        let angle_u: T = rng.gen_range(zero..one);
        let normal = (-two * radius_u.ln()).sqrt() * (two_pi * angle_u).cos();

        let v = (one + c * normal).powi(3);
        if v <= zero {
            continue;
        }

        let u: T = rng.gen_range(zero..one);
        let normal_sq = normal * normal;

        // Cheap squeeze test first, exact log test second.
        let squeeze_ok = u < one - num(0.0331) * normal_sq * normal_sq;
        if squeeze_ok || u.ln() < num(0.5) * normal_sq + d * (one - v + v.ln()) {
            break d * v;
        }
    };

    if boosted {
        let u: T = rng.gen_range(num(1e-10)..one);
        draw * u.powf(one / shape)
    } else {
        draw
    }
}
}
impl<T: Float + 'static> DistributionTrait<T> for Gamma<T>
where
T: rand::distributions::uniform::SampleUniform + num_traits::FromPrimitive + std::fmt::Display,
{
fn sample(&self, shape: Option<&[usize]>) -> RusTorchResult<Tensor<T>> {
let sample_shape = self.base.expand_shape(shape);
let concentration_data = self.concentration.data.as_slice().unwrap();
let scale = self.get_scale()?;
let scale_data = scale.data.as_slice().unwrap();
let sample_size: usize = sample_shape.iter().product();
let mut result_data = Vec::with_capacity(sample_size);
for i in 0..sample_size {
let batch_idx = i % concentration_data.len();
let alpha = concentration_data[batch_idx];
let theta = scale_data[batch_idx % scale_data.len()];
let gamma_sample = Self::sample_gamma(alpha);
result_data.push(gamma_sample * theta);
}
Ok(Tensor::from_vec(result_data, sample_shape))
}
fn log_prob(&self, value: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
let concentration_data = self.concentration.data.as_slice().unwrap();
let rate = self.get_rate()?;
let rate_data = rate.data.as_slice().unwrap();
let value_data = value.data.as_slice().unwrap();
let log_norm = self.log_normalizing_constant()?;
let log_norm_data = log_norm.data.as_slice().unwrap();
let result_data: Vec<T> = (0..value_data.len())
.map(|i| {
let x = value_data[i];
let alpha = concentration_data[i % concentration_data.len()];
let beta = rate_data[i % rate_data.len()];
let ln_norm = log_norm_data[i % log_norm_data.len()];
if x <= T::zero() {
T::from(-1e10).unwrap() } else {
(alpha - T::one()) * x.ln() - beta * x + ln_norm
}
})
.collect();
Ok(Tensor::from_vec(result_data, value.shape().to_vec()))
}
/// Cumulative distribution function; not available for this distribution.
///
/// # Errors
///
/// Always returns `RusTorchError::UnsupportedOperation`.
fn cdf(&self, _value: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
    let unsupported = RusTorchError::UnsupportedOperation(
        "Gamma CDF requires incomplete gamma function (not implemented)",
    );
    Err(unsupported)
}
/// Inverse cumulative distribution function; not available for this
/// distribution.
///
/// # Errors
///
/// Always returns `RusTorchError::UnsupportedOperation`.
fn icdf(&self, _value: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
    let unsupported = RusTorchError::UnsupportedOperation("Gamma inverse CDF not implemented");
    Err(unsupported)
}
/// Computes the mean `α·θ` element-wise (concentration times scale).
///
/// The shorter parameter buffer is cycled against the longer one and the
/// result takes the shape of the longer operand (the concentration's shape
/// on a tie), so a scalar concentration with a vector scale is no longer
/// truncated to one element. Only this cyclic form of broadcasting is
/// handled here.
///
/// # Errors
///
/// Propagates the error from `get_scale` when neither rate nor scale is set.
///
/// # Panics
///
/// Panics if a parameter tensor's data cannot be viewed as a slice.
fn mean(&self) -> RusTorchResult<Tensor<T>> {
    let scale = self.get_scale()?;
    let concentration_data = self.concentration.data.as_slice().unwrap();
    let scale_data = scale.data.as_slice().unwrap();

    let scale_is_longer =
        !concentration_data.is_empty() && scale_data.len() > concentration_data.len();
    let (len, shape) = if scale_is_longer {
        (scale_data.len(), scale.shape().to_vec())
    } else {
        (
            concentration_data.len(),
            self.concentration.shape().to_vec(),
        )
    };

    let result_data: Vec<T> = concentration_data
        .iter()
        .cycle()
        .zip(scale_data.iter().cycle())
        .take(len)
        .map(|(&alpha, &theta)| alpha * theta)
        .collect();

    Ok(Tensor::from_vec(result_data, shape))
}
/// Computes the variance `α·θ²` element-wise (concentration times scale
/// squared).
///
/// The shorter parameter buffer is cycled against the longer one and the
/// result takes the shape of the longer operand (the concentration's shape
/// on a tie), so a scalar concentration with a vector scale is no longer
/// truncated to one element. Only this cyclic form of broadcasting is
/// handled here.
///
/// # Errors
///
/// Propagates the error from `get_scale` when neither rate nor scale is set.
///
/// # Panics
///
/// Panics if a parameter tensor's data cannot be viewed as a slice.
fn variance(&self) -> RusTorchResult<Tensor<T>> {
    let scale = self.get_scale()?;
    let concentration_data = self.concentration.data.as_slice().unwrap();
    let scale_data = scale.data.as_slice().unwrap();

    let scale_is_longer =
        !concentration_data.is_empty() && scale_data.len() > concentration_data.len();
    let (len, shape) = if scale_is_longer {
        (scale_data.len(), scale.shape().to_vec())
    } else {
        (
            concentration_data.len(),
            self.concentration.shape().to_vec(),
        )
    };

    let result_data: Vec<T> = concentration_data
        .iter()
        .cycle()
        .zip(scale_data.iter().cycle())
        .take(len)
        .map(|(&alpha, &theta)| alpha * theta * theta)
        .collect();

    Ok(Tensor::from_vec(result_data, shape))
}
fn entropy(&self) -> RusTorchResult<Tensor<T>> {
let concentration_data = self.concentration.data.as_slice().unwrap();
let scale = self.get_scale()?;
let scale_data = scale.data.as_slice().unwrap();
let result_data: Vec<T> = concentration_data
.iter()
.zip(scale_data.iter().cycle())
.map(|(&alpha, &theta)| alpha + theta.ln() + Self::log_gamma_approx(alpha))
.collect();
Ok(Tensor::from_vec(
result_data,
self.concentration.shape().to_vec(),
))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Wraps a single `f32` in a tensor of shape `[1]`.
    fn single(value: f32) -> Tensor<f32> {
        Tensor::from_vec(vec![value], vec![1])
    }

    /// Builds a validated, rate-parameterised Gamma from two scalars.
    fn gamma_with_rate(concentration: f32, rate: f32) -> Gamma<f32> {
        Gamma::from_concentration_rate(single(concentration), single(rate), true).unwrap()
    }

    /// Reads the first element of a tensor.
    fn first(tensor: &Tensor<f32>) -> f32 {
        tensor.data.as_slice().unwrap()[0]
    }

    #[test]
    fn test_gamma_creation() {
        let gamma = gamma_with_rate(2.0, 1.0);
        assert_eq!(gamma.base.batch_shape, vec![1]);
    }

    #[test]
    fn test_gamma_exponential() {
        let gamma = Gamma::<f32>::exponential(1.0, true).unwrap();
        assert_eq!(first(&gamma.mean().unwrap()), 1.0);
        assert_eq!(first(&gamma.variance().unwrap()), 1.0);
    }

    #[test]
    fn test_gamma_rate_scale_conversion() {
        let gamma = gamma_with_rate(2.0, 0.5);
        assert_eq!(first(&gamma.get_scale().unwrap()), 2.0);
    }

    #[test]
    fn test_gamma_sampling() {
        let gamma = Gamma::<f32>::exponential(1.0, true).unwrap();
        let samples = gamma.sample(Some(&[1000])).unwrap();
        assert_eq!(samples.shape(), &[1000]);

        // Every draw must be strictly positive.
        for &draw in samples.data.as_slice().unwrap() {
            assert!(draw > 0.0);
        }
    }

    #[test]
    fn test_gamma_log_prob() {
        let gamma = gamma_with_rate(2.0, 1.0);
        let values = Tensor::from_vec(vec![1.0f32, 2.0], vec![2]);
        let log_probs = gamma.log_prob(&values).unwrap();
        let densities = log_probs.data.as_slice().unwrap();

        assert!(densities[0].is_finite());
        assert!(densities[1].is_finite());
        // For concentration 2 and rate 1 the density is higher at 1 than at 2.
        assert!(densities[0] > densities[1]);
    }

    #[test]
    fn test_gamma_mean_variance() {
        let gamma = gamma_with_rate(3.0, 2.0);
        assert_eq!(first(&gamma.mean().unwrap()), 1.5);
        assert_eq!(first(&gamma.variance().unwrap()), 0.75);
    }

    #[test]
    fn test_invalid_parameters() {
        // Negative concentration is rejected.
        assert!(Gamma::from_concentration_rate(single(-1.0), single(1.0), true).is_err());
        // Zero rate is rejected.
        assert!(Gamma::from_concentration_rate(single(1.0), single(0.0), true).is_err());
    }

    #[test]
    fn test_log_gamma_approx() {
        // ln Γ(1) = ln Γ(2) = 0 and ln Γ(3) = ln 2.
        let cases = [(1.0f32, 0.0f32), (2.0, 0.0), (3.0, 2.0f32.ln())];
        for &(input, expected) in cases.iter() {
            let actual = Gamma::<f32>::log_gamma_approx(input);
            assert_abs_diff_eq!(actual, expected, epsilon = 0.1);
        }
    }
}