use crate::error::RusTorchError; use num_traits::Float;
use std::f64::consts::PI;
/// Lanczos approximation parameter `g`. `lanczos_gamma` and `lanczos_lgamma`
/// shift their argument to `z = x - 1` and evaluate the closed form at
/// `z + LANCZOS_G + 0.5`.
const LANCZOS_G: f64 = 7.0;
/// Series coefficients paired with `LANCZOS_G`. The Lanczos sum is
/// `LANCZOS_COEF[0] + Σ_{i=1..8} LANCZOS_COEF[i] / (z + i)` with `z = x - 1`.
/// NOTE(review): these look like the widely published g = 7, n = 9 set;
/// changing `LANCZOS_G` would require a matching coefficient set.
const LANCZOS_COEF: [f64; 9] = [
0.99999999999980993,
676.5203681218851,
-1259.1392167224028,
771.32342877765313,
-176.61502916214059,
12.507343278686905,
-0.13857109526572012,
9.9843695780195716e-6,
1.5056327351493116e-7,
];
/// Computes the gamma function Γ(x) for a single floating-point value.
///
/// The value is evaluated in `f64` and converted back to `T`.
/// - Positive arguments go through the Lanczos approximation (`lanczos_gamma`).
/// - Negative non-integer arguments use the reflection formula
///   Γ(x) = π / (sin(πx) · Γ(1 − x)).
///
/// # Errors
///
/// * `RusTorchError::DomainError` in any of these cases:
///   - `x` cannot be converted to `f64`.
///   - `x` is a non-positive integer (a pole of Γ).
///   - `|sin(πx)| < 1e-10` for `x < 0.5`. Such inputs are treated as too close
///     to a pole, which includes very small positive `x`.
/// * `RusTorchError::OverflowError` if the `f64` result cannot be converted
///   back to `T`.
pub fn gamma_scalar<T: Float>(x: T) -> Result<T, RusTorchError> {
    let x_f64 = x
        .to_f64()
        .ok_or_else(|| RusTorchError::DomainError("Cannot convert to f64".to_string()))?;

    let result = if x_f64 <= 0.0 {
        if x_f64 == x_f64.floor() {
            return Err(RusTorchError::DomainError(format!(
                "Gamma undefined for x = {}",
                x_f64
            )));
        }
        let sin_pi_x = (PI * x_f64).sin();
        if sin_pi_x.abs() < 1e-10 {
            return Err(RusTorchError::DomainError(
                "Near pole of gamma function".to_string(),
            ));
        }
        // Here 1 − x > 1, so Γ(1 − x) comes straight from the Lanczos series.
        // Staying in f64 avoids rounding 1 − x to T (lossy for f32) and
        // needs no unwraps.
        PI / (sin_pi_x * lanczos_gamma(1.0 - x_f64)?)
    } else {
        lanczos_gamma(x_f64)?
    };

    T::from(result).ok_or(RusTorchError::OverflowError("Gamma function overflow"))
}
/// Evaluates Γ(x) in `f64` with the Lanczos approximation (`LANCZOS_G`,
/// `LANCZOS_COEF`).
///
/// For `x < 0.5` the reflection formula Γ(x) = π / (sin(πx) · Γ(1 − x)) is
/// applied, so the series itself is only evaluated at arguments ≥ 0.5.
/// Arguments above 171.7 return +∞, because Γ(x) exceeds `f64::MAX` there.
///
/// # Errors
///
/// `RusTorchError::DomainError` when `x < 0.5` and `|sin(πx)| < 1e-10`.
/// Such inputs are treated as too close to a pole.
fn lanczos_gamma(x: f64) -> Result<f64, RusTorchError> {
    // Γ(x) passes f64::MAX at x ≈ 171.62, so the true result is +∞ beyond
    // this point. Answering directly also covers x = +∞ and very large x.
    // There, exp(-tmp) below would underflow to 0 and the product would
    // become ∞ · 0 = NaN.
    if x > 171.7 {
        return Ok(f64::INFINITY);
    }
    if x < 0.5 {
        let sin_pi_x = (PI * x).sin();
        if sin_pi_x.abs() < 1e-10 {
            return Err(RusTorchError::DomainError(
                "Near pole of gamma function".to_string(),
            ));
        }
        Ok(PI / (sin_pi_x * lanczos_gamma(1.0 - x)?))
    } else {
        let x = x - 1.0;
        let mut ag = LANCZOS_COEF[0];
        for (i, &coef) in LANCZOS_COEF.iter().enumerate().skip(1) {
            ag += coef / (x + i as f64);
        }
        let tmp = x + LANCZOS_G + 0.5;
        let sqrt_2pi = (2.0 * PI).sqrt();
        // On its own, tmp^(x + 0.5) overflows to +∞ for arguments above ~143,
        // even though Γ is still finite there. Split the power into two halves
        // and apply the decaying exp(-tmp) factor in between. This keeps every
        // intermediate in range until the final product.
        let half_pow = tmp.powf(0.5 * (x + 0.5));
        Ok(sqrt_2pi * ag * (half_pow * (-tmp).exp()) * half_pow)
    }
}
/// Computes ln|Γ(x)| for a single floating-point value.
///
/// The value is evaluated in `f64` and converted back to `T`.
/// - Arguments above 12 use `stirling_lgamma`.
/// - All other arguments use `lanczos_lgamma`, which applies the reflection
///   formula below 0.5.
///
/// # Errors
///
/// * `RusTorchError::DomainError` if `x` cannot be converted to `f64`, is a
///   non-positive integer, or is rejected by `lanczos_lgamma` as too close to
///   a pole.
/// * `RusTorchError::OverflowError` if the result cannot be converted back to `T`.
pub fn lgamma_scalar<T: Float>(x: T) -> Result<T, RusTorchError> {
    let value = x
        .to_f64()
        .ok_or_else(|| RusTorchError::DomainError("Cannot convert to f64".to_string()))?;

    // Γ has poles at 0, -1, -2, ..., where ln|Γ| is undefined.
    let at_pole = value <= 0.0 && value == value.floor();
    if at_pole {
        return Err(RusTorchError::DomainError(format!(
            "lgamma undefined for non-positive integer x = {}",
            value
        )));
    }

    let (log_gamma, overflow_msg) = if value > 12.0 {
        (stirling_lgamma(value), "Lgamma Stirling overflow")
    } else {
        (lanczos_lgamma(value)?, "Lgamma Lanczos overflow")
    };
    T::from(log_gamma).ok_or(RusTorchError::OverflowError(overflow_msg))
}
/// Stirling's asymptotic series for ln Γ(x), intended for large positive `x`.
/// `lgamma_scalar` uses it for x > 12.
///
/// ln Γ(x) ≈ (x − ½)·ln x − x + ½·ln(2π)
///           + 1/(12x) − 1/(360x³) + 1/(1260x⁵) − 1/(1680x⁷)
///
/// Returns +∞ for x = +∞, where the closed form would evaluate ∞ − ∞ = NaN.
fn stirling_lgamma(x: f64) -> f64 {
    if x == f64::INFINITY {
        return f64::INFINITY;
    }
    let x_inv = 1.0 / x;
    let x_inv2 = x_inv * x_inv;
    // Only odd powers of 1/x appear in the correction series. Nest in Horner
    // form: x⁻¹·(1/12 − x⁻²·(1/360 − x⁻²·(1/1260 − x⁻²/1680))).
    let series = x_inv
        * (1.0 / 12.0 - x_inv2 * (1.0 / 360.0 - x_inv2 * (1.0 / 1260.0 - x_inv2 / 1680.0)));
    (x - 0.5) * x.ln() - x + 0.5 * (2.0 * PI).ln() + series
}
/// Evaluates ln|Γ(x)| in `f64` with the Lanczos approximation (`LANCZOS_G`,
/// `LANCZOS_COEF`).
///
/// For `x < 0.5` it uses the reflection formula
/// ln|Γ(x)| = ln π − ln|sin(πx)| − ln Γ(1 − x).
///
/// # Errors
///
/// `RusTorchError::DomainError` when `x < 0.5` and `|sin(πx)| < 1e-10`.
/// Such inputs are treated as too close to a pole.
fn lanczos_lgamma(x: f64) -> Result<f64, RusTorchError> {
    // Keep the `x < 0.5` test (not `x >= 0.5`) so NaN falls through to the
    // series below instead of recursing through the reflection formula.
    if x < 0.5 {
        let sin_pi_x = (PI * x).sin();
        if sin_pi_x.abs() < 1e-10 {
            return Err(RusTorchError::DomainError(
                "Near pole of gamma function".to_string(),
            ));
        }
        return Ok(PI.ln() - sin_pi_x.abs().ln() - lanczos_lgamma(1.0 - x)?);
    }

    let z = x - 1.0;
    // Left fold starting from the leading coefficient, in index order.
    let series = LANCZOS_COEF
        .iter()
        .enumerate()
        .skip(1)
        .fold(LANCZOS_COEF[0], |acc, (k, &coef)| acc + coef / (z + k as f64));
    let t = z + LANCZOS_G + 0.5;
    Ok(0.5 * (2.0 * PI).ln() + (z + 0.5) * t.ln() - t + series.ln())
}
/// Computes the digamma function ψ(x) = d/dx ln Γ(x) for a single value.
///
/// The value is evaluated in `f64` and converted back to `T`:
/// 1. Negative non-integer arguments use the reflection formula
///    ψ(x) = ψ(1 − x) − π·cot(πx).
/// 2. The resulting positive argument is shifted up with ψ(z) = ψ(z + 1) − 1/z
///    until z ≥ 6.
/// 3. `digamma_asymptotic` is applied at that point.
///
/// # Errors
///
/// * `RusTorchError::DomainError` if `x` cannot be converted to `f64` or is
///   a non-positive integer (a pole of ψ).
/// * `RusTorchError::OverflowError` if the result cannot be converted back to `T`.
pub fn digamma_scalar<T: Float>(x: T) -> Result<T, RusTorchError> {
    let x_f64 = x
        .to_f64()
        .ok_or_else(|| RusTorchError::DomainError("Cannot convert to f64".to_string()))?;
    if x_f64 <= 0.0 && x_f64 == x_f64.floor() {
        return Err(RusTorchError::DomainError(format!(
            "Digamma has poles at non-positive integers, x = {}",
            x_f64
        )));
    }

    // Reflection for negative x. The cosine must take πx as well: the term is
    // π·cos(πx)/sin(πx), not π·cos(x)/sin(πx). Evaluating ψ(1 − x) in f64
    // avoids rounding 1 − x through T.
    let (arg, pi_cot_pi_x) = if x_f64 < 0.0 {
        let pi_x = PI * x_f64;
        (1.0 - x_f64, PI * pi_x.cos() / pi_x.sin())
    } else {
        (x_f64, 0.0)
    };

    // Recurrence ψ(z) = ψ(z + 1) − 1/z lifts the argument into the range
    // where the asymptotic expansion is used. NaN skips the loop.
    let mut shift_sum = 0.0;
    let mut z = arg;
    while z < 6.0 {
        shift_sum -= 1.0 / z;
        z += 1.0;
    }
    let result = shift_sum + digamma_asymptotic(z) - pi_cot_pi_x;

    T::from(result).ok_or(RusTorchError::OverflowError("Digamma overflow"))
}
/// Asymptotic expansion of the digamma function ψ(x) for large positive `x`:
///
/// ψ(x) ≈ ln x − 1/(2x) − 1/(12x²) + 1/(120x⁴) − 1/(252x⁶) + 1/(240x⁸)
///        − 1/(132x¹⁰) + 691/(32760x¹²)
///
/// `digamma_scalar` calls this only with arguments ≥ 6 (or NaN). At x = 6
/// the first omitted term, 1/(12x¹⁴), is about 1e-12. The shorter series
/// ending at 1/(252x⁶) left an error of about 2.4e-9 there.
fn digamma_asymptotic(x: f64) -> f64 {
    let x_inv2 = 1.0 / (x * x);
    // Bernoulli-number tail in Horner form, in powers of 1/x²:
    // y·(1/12 − y·(1/120 − y·(1/252 − y·(1/240 − y·(1/132 − y·691/32760))))).
    let tail = x_inv2
        * (1.0 / 12.0
            - x_inv2
                * (1.0 / 120.0
                    - x_inv2
                        * (1.0 / 252.0
                            - x_inv2
                                * (1.0 / 240.0
                                    - x_inv2 * (1.0 / 132.0 - x_inv2 * (691.0 / 32760.0))))));
    x.ln() - 0.5 / x - tail
}
/// Computes the beta function B(a, b) = Γ(a)·Γ(b) / Γ(a + b).
///
/// The magnitude is `exp(lbeta(a, b))`. `lbeta` is built from ln|Γ|, so the
/// sign is restored separately from the signs of Γ(a), Γ(b) and Γ(a + b).
/// This makes negative non-integer arguments come out correctly signed,
/// e.g. B(−0.5, 1) = −2.
///
/// # Errors
///
/// * Any error returned by `lbeta`, e.g. `a`, `b` or `a + b` at a pole of Γ.
/// * `RusTorchError::OverflowError` if the result cannot be converted back to `T`.
pub fn beta<T: Float>(a: T, b: T) -> Result<T, RusTorchError> {
    // Sign of Γ(z) away from its poles. It is positive for z > 0. For z < 0 it
    // alternates per unit interval, negative on (-1, 0) (where floor(z) is odd).
    fn gamma_sign(z: f64) -> f64 {
        if z < 0.0 && z.floor().rem_euclid(2.0) == 1.0 {
            -1.0
        } else {
            1.0
        }
    }

    let lbeta_val = lbeta(a, b)?;
    // lbeta only succeeds after lgamma_scalar has converted a, b and a + b to
    // f64, so these conversions are not expected to fail. Fall back to NaN
    // rather than panicking.
    let as_f64 = |v: T| v.to_f64().unwrap_or(f64::NAN);
    let magnitude = as_f64(lbeta_val).exp();
    let sign = gamma_sign(as_f64(a)) * gamma_sign(as_f64(b)) * gamma_sign(as_f64(a + b));

    T::from(sign * magnitude).ok_or(RusTorchError::OverflowError("Beta function overflow"))
}
/// Computes ln|B(a, b)| = ln|Γ(a)| + ln|Γ(b)| − ln|Γ(a + b)|.
///
/// # Errors
///
/// Propagates the first error from `lgamma_scalar`, evaluated for `a`,
/// then `b`, then `a + b`.
pub fn lbeta<T: Float>(a: T, b: T) -> Result<T, RusTorchError> {
    let log_numerator = lgamma_scalar(a)? + lgamma_scalar(b)?;
    let log_denominator = lgamma_scalar(a + b)?;
    Ok(log_numerator - log_denominator)
}
/// Applies `gamma_scalar` to every element of `x`, keeping its shape.
///
/// # Errors
///
/// Returns the first error produced by `gamma_scalar`, in iteration order.
pub fn gamma<T: Float + 'static>(
    x: &crate::tensor::Tensor<T>,
) -> Result<crate::tensor::Tensor<T>, RusTorchError> {
    let values = x
        .data
        .iter()
        .map(|&v| gamma_scalar(v))
        .collect::<Result<Vec<T>, RusTorchError>>()?;
    Ok(crate::tensor::Tensor::from_vec(values, x.shape().to_vec()))
}
/// Applies `lgamma_scalar` to every element of `x`, keeping its shape.
///
/// # Errors
///
/// Returns the first error produced by `lgamma_scalar`, in iteration order.
pub fn lgamma<T: Float + 'static>(
    x: &crate::tensor::Tensor<T>,
) -> Result<crate::tensor::Tensor<T>, RusTorchError> {
    let values = x
        .data
        .iter()
        .map(|&v| lgamma_scalar(v))
        .collect::<Result<Vec<T>, RusTorchError>>()?;
    Ok(crate::tensor::Tensor::from_vec(values, x.shape().to_vec()))
}
/// Applies `digamma_scalar` to every element of `x`, keeping its shape.
///
/// # Errors
///
/// Returns the first error produced by `digamma_scalar`, in iteration order.
pub fn digamma<T: Float + 'static>(
    x: &crate::tensor::Tensor<T>,
) -> Result<crate::tensor::Tensor<T>, RusTorchError> {
    let values = x
        .data
        .iter()
        .map(|&v| digamma_scalar(v))
        .collect::<Result<Vec<T>, RusTorchError>>()?;
    Ok(crate::tensor::Tensor::from_vec(values, x.shape().to_vec()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Γ(n) = (n − 1)! at small positive integers, and Γ(1/2) = √π.
    #[test]
    fn test_gamma_function() {
        let cases = [
            (1.0_f64, 1.0_f64),
            (2.0, 1.0),
            (3.0, 2.0),
            (4.0, 6.0),
            (5.0, 24.0),
            (0.5, PI.sqrt()),
        ];
        for &(input, expected) in cases.iter() {
            assert_relative_eq!(gamma_scalar(input).unwrap(), expected, epsilon = 1e-10);
        }
    }

    /// ln Γ(n) = ln((n − 1)!) on the Lanczos path, plus the Stirling dispatch.
    #[test]
    fn test_lgamma_function() {
        let cases = [
            (1.0_f64, 0.0_f64),
            (2.0, 0.0),
            (3.0, 2.0_f64.ln()),
            (4.0, 6.0_f64.ln()),
        ];
        for &(input, expected) in cases.iter() {
            assert_relative_eq!(lgamma_scalar(input).unwrap(), expected, epsilon = 1e-10);
        }

        // NOTE(review): this compares lgamma_scalar against stirling_lgamma itself,
        // so it only checks that x > 12 dispatches to the Stirling series, not
        // that the series is accurate.
        let large = 100.0_f64;
        assert_relative_eq!(
            lgamma_scalar(large).unwrap(),
            stirling_lgamma(large),
            epsilon = 1e-10
        );
    }

    /// ψ(1) = −γ and the recurrence ψ(x + 1) = ψ(x) + 1/x.
    #[test]
    fn test_digamma_function() {
        const EULER_MASCHERONI: f64 = 0.5772156649015329;
        assert_relative_eq!(
            digamma_scalar(1.0_f64).unwrap(),
            -EULER_MASCHERONI,
            epsilon = 1e-7
        );

        let x = 2.5_f64;
        let (psi_x, psi_next) = (
            digamma_scalar(x).unwrap(),
            digamma_scalar(x + 1.0).unwrap(),
        );
        assert_relative_eq!(psi_next, psi_x + 1.0 / x, epsilon = 1e-10);
    }

    /// Symmetry B(a, b) = B(b, a) and a couple of closed-form values.
    #[test]
    fn test_beta_function() {
        assert_relative_eq!(
            beta(2.0_f64, 3.0).unwrap(),
            beta(3.0_f64, 2.0).unwrap(),
            epsilon = 1e-10
        );

        let cases = [((1.0_f64, 1.0_f64), 1.0_f64), ((2.0, 2.0), 1.0 / 6.0)];
        for &((a, b), expected) in cases.iter() {
            assert_relative_eq!(beta(a, b).unwrap(), expected, epsilon = 1e-10);
        }
    }
}