use crate::error::{SpecialError, SpecialResult};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use super::approximations::stirling_approximation_ln;
use super::constants::{EULER_MASCHERONI, LANCZOS_COEFFICIENTS};
/// Converts an `f64` constant into the working float type `F`.
///
/// If `F` cannot represent `value`, the result saturates by sign: `+∞` for positive
/// values, `-∞` for negative values, and zero otherwise (zero or NaN input).
#[inline(always)]
fn const_f64<F: Float + FromPrimitive>(value: f64) -> F {
    match F::from(value) {
        Some(converted) => converted,
        None if value > 0.0 => F::infinity(),
        None if value < 0.0 => F::neg_infinity(),
        None => F::zero(),
    }
}
/// Gamma function Γ(x) for real `x`, dispatching to the best-suited approximation.
///
/// * NaN → NaN; `±0` → `±∞` (the pole at zero, signed like C99 `tgamma`); `+∞` → `+∞`.
/// * `x < 0` → handled by `gamma_negative_enhanced` (reflection formula / pole expansion).
/// * `0 < x < 1e-10` → Laurent series about zero (`gamma_near_zero`).
/// * `1e-10 ≤ x < 0.5` → recurrence `Γ(x) = Γ(x + 1) / x` (`gamma_small_positive`).
/// * `0.5 ≤ x ≤ 12` → Lanczos approximation (`gamma_moderate_positive`).
/// * `12 < x ≤ 171` → exponentiated Stirling approximation (`gamma_large_positive`).
/// * `x > 171` → `exp(ln Γ(x))`, or `+∞` once `ln Γ(x)` exceeds `ln(F::MAX)`.
///
/// # Errors
///
/// Returns `SpecialError::ValueError` if `x` cannot be converted to `f64`, and propagates
/// `SpecialError::DomainError` from the negative-argument path when `x` sits on a pole.
#[allow(dead_code)]
pub fn gamma_enhanced<F>(x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
    let x_f64 = x
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert x to f64".to_string()))?;
    if x.is_nan() {
        return Ok(F::nan());
    }
    if x == F::zero() {
        // Γ(±0) = ±∞: the pole at zero is approached from opposite sides.
        return Ok(if x.is_sign_negative() {
            F::neg_infinity()
        } else {
            F::infinity()
        });
    }
    if x < F::zero() {
        return gamma_negative_enhanced(x);
    }
    if x.is_infinite() {
        return Ok(F::infinity());
    }
    if x < const_f64::<F>(1e-10) {
        return gamma_near_zero(x);
    }
    if x < const_f64::<F>(0.5) {
        return gamma_small_positive(x);
    }
    if x_f64 > 171.0 {
        let log_gamma = gammaln_enhanced(x)?;
        // ln(F::MAX) is the exact overflow boundary; any scaled-down bound would report
        // finite values such as Γ(171.5) ≈ 9.5e307 as overflow.
        if log_gamma > F::max_value().ln() {
            return Ok(F::infinity());
        }
        return Ok(log_gamma.exp());
    }
    if x_f64 > 12.0 {
        return gamma_large_positive(x);
    }
    gamma_moderate_positive(x)
}
/// Γ(x) for negative `x` (called by `gamma_enhanced` for `x < 0`).
///
/// Within `1e-4` of a pole at a non-positive integer no further out than `-170`, the Laurent
/// expansion in `gamma_near_pole` is used. Elsewhere the reflection formula
/// `Γ(x) = π / (sin(πx) · Γ(1 - x))` is evaluated, switching to log space when `Γ(1 - x)`
/// overflows so that tiny-but-representable results are not flushed to zero.
///
/// Returns NaN for `-∞`, where Γ has no limit.
///
/// # Errors
///
/// `SpecialError::ValueError` if `x` cannot be converted to `f64`; `SpecialError::DomainError`
/// if `x` lies within `1e-14` of a non-positive integer (a pole).
#[allow(dead_code)]
fn gamma_negative_enhanced<F>(x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
    let x_f64 = x
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert x to f64".to_string()))?;
    if x == F::neg_infinity() {
        // Γ has a pole at every negative integer, so it has no limit as x → -∞.
        return Ok(F::nan());
    }
    // Keep the nearest integer in f64: casting it to i32 saturates for |x| > i32::MAX, which
    // would hide poles at large negative integers.
    let nearest = x_f64.round();
    let distance_to_int = (x_f64 - nearest).abs();
    if nearest <= 0.0 && distance_to_int < 1e-14 {
        return Err(SpecialError::DomainError(format!(
            "Gamma function has a pole at x = {x_f64}"
        )));
    }
    // The expansion needs n!, which overflows beyond 170; farther poles use the reflection
    // path below.
    if (-170.0..=0.0).contains(&nearest) && distance_to_int < 1e-4 {
        return gamma_near_pole(x, nearest as i32);
    }
    let pi = const_f64::<F>(std::f64::consts::PI);
    // sin(πx) = (-1)^n · sin(π(x - n)). Subtracting the nearest integer first is exact in
    // IEEE arithmetic (Sterbenz) and keeps the sine's full relative accuracy near poles.
    let sin_pi_reduced = (pi * (x - const_f64::<F>(nearest))).sin();
    let sin_pi_x = if nearest % 2.0 == 0.0 {
        sin_pi_reduced
    } else {
        -sin_pi_reduced
    };
    if sin_pi_x.abs() < const_f64::<F>(1e-14) {
        return Err(SpecialError::DomainError(
            "Gamma function pole detected".to_string(),
        ));
    }
    let one_minus_x = F::one() - x;
    // Γ(1 - x) overflows f64 shortly above 171, so skip the direct evaluation there.
    let gamma_complement = if one_minus_x > const_f64::<F>(171.0) {
        F::infinity()
    } else {
        gamma_enhanced(one_minus_x)?
    };
    if gamma_complement.is_finite() {
        return Ok(pi / (sin_pi_x * gamma_complement));
    }
    // Γ(1 - x) is not representable but Γ(x) may still be (just very small): evaluate the
    // reflection formula in log space and let exp() underflow gracefully. Γ(1 - x) > 0 for
    // x < 0, so Γ(x) carries the sign of sin(πx).
    let log_magnitude = pi.ln() - sin_pi_x.abs().ln() - gammaln_enhanced(one_minus_x)?;
    let magnitude = log_magnitude.exp();
    Ok(if sin_pi_x < F::zero() {
        -magnitude
    } else {
        magnitude
    })
}
#[allow(dead_code)]
fn gamma_near_pole<F>(x: F, n: i32) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
let epsilon = x + const_f64::<F>(n as f64);
let n_abs = (-n) as u32;
let n_factorial = factorial_float::<F>(n_abs);
let sign = if n_abs.is_multiple_of(2) {
F::one()
} else {
-F::one()
};
let leading = sign / (n_factorial * epsilon);
let harmonic = harmonic_number::<F>(n_abs);
let correction = F::one() - epsilon * harmonic;
let harmonic_sq_sum = harmonic_sum_squared::<F>(n_abs);
let second_order =
epsilon * epsilon * (harmonic * harmonic - harmonic_sq_sum) / const_f64::<F>(2.0);
Ok(leading * (correction + second_order))
}
/// Γ(x) for tiny positive `x` (`gamma_enhanced` uses it for `0 < x < 1e-10`), from the
/// Laurent series about 0:
///
/// `Γ(x) = 1/x - γ + (γ²/2 + π²/12)·x - (γ³/6 + γ·π²/12 + ζ(3)/3)·x² + O(x³)`
#[allow(dead_code)]
fn gamma_near_zero<F>(x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug,
{
    let gamma = const_f64::<F>(EULER_MASCHERONI);
    let pi_sq = const_f64::<F>(std::f64::consts::PI * std::f64::consts::PI);
    // Apéry's constant ζ(3).
    let zeta3 = const_f64::<F>(1.2020569031595942);
    let leading = F::one() / x;
    let c1 = -gamma;
    let c2 = gamma * gamma / const_f64::<F>(2.0) + pi_sq / const_f64::<F>(12.0);
    // ζ(3) enters through the -ζ(3)·x³/3 term of ln Γ(1 + x), hence the division by 3.
    let c3 = -(gamma * gamma * gamma / const_f64::<F>(6.0)
        + pi_sq * gamma / const_f64::<F>(12.0)
        + zeta3 / const_f64::<F>(3.0));
    Ok(leading + c1 + c2 * x + c3 * x * x)
}
/// Γ(x) for small positive `x` via the recurrence `Γ(x) = Γ(x + 1) / x`, which moves the
/// Lanczos evaluation onto `x + 1` (`gamma_enhanced` uses it for `1e-10 ≤ x < 0.5`).
#[allow(dead_code)]
fn gamma_small_positive<F>(x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
    gamma_moderate_positive(x + F::one()).map(|gamma_shifted| gamma_shifted / x)
}
/// Γ(x) from the Lanczos approximation with `g = 7`:
///
/// `Γ(x) ≈ √(2π) · A_g(z) · t^(z + ½) · e^(−t)`,  with `z = x − 1`, `t = z + g + ½`,
///
/// where `A_g(z) = c₀ + Σₖ cₖ / (z + k)` over `LANCZOS_COEFFICIENTS`.
#[allow(dead_code)]
fn gamma_moderate_positive<F>(x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
    const LANCZOS_G: f64 = 7.0;
    let z = x - F::one();
    let series = LANCZOS_COEFFICIENTS
        .iter()
        .enumerate()
        .skip(1)
        .fold(const_f64::<F>(LANCZOS_COEFFICIENTS[0]), |acc, (k, &c)| {
            acc + const_f64::<F>(c) / (z + const_f64::<F>(k as f64))
        });
    let scale = const_f64::<F>((2.0 * std::f64::consts::PI).sqrt());
    let t = z + const_f64::<F>(LANCZOS_G + 0.5);
    Ok(scale * series * t.powf(z + const_f64::<F>(0.5)) * (-t).exp())
}
/// Γ(x) for large positive `x` (`gamma_enhanced` uses it for `12 < x ≤ 171`) by
/// exponentiating `stirling_approximation_ln(x)`.
///
/// Returns `+∞` when the logarithm exceeds `ln(F::MAX)`, i.e. when Γ(x) is not representable.
#[allow(dead_code)]
fn gamma_large_positive<F>(x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
    let log_gamma = stirling_approximation_ln(x);
    // ln(F::MAX) (≈ 709.78 for f64) is the exact overflow boundary; a smaller bound would turn
    // finite results such as Γ(160) ≈ 4.7e282 into infinity.
    if log_gamma > F::max_value().ln() {
        return Ok(F::infinity());
    }
    Ok(log_gamma.exp())
}
/// Natural logarithm of the gamma function, `ln Γ(x)`, for `x > 0`.
///
/// * NaN → NaN; `+∞` → `+∞`.
/// * `x < 1e-8`: `-ln x - γx` (leading terms of the series about 0).
/// * `1e-8 ≤ x < 0.5`: `ln Γ(x + 1) - ln x`, with the Lanczos form evaluated at `x + 1`.
/// * `0.5 ≤ x < 50`: Lanczos approximation (`gammaln_lanczos`).
/// * `x ≥ 50`: Stirling approximation.
///
/// # Errors
///
/// `SpecialError::ValueError` if `x` cannot be converted to `f64`; `SpecialError::DomainError`
/// for `x <= 0`.
#[allow(dead_code)]
pub fn gammaln_enhanced<F>(x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
    let x_f64 = x
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert x to f64".to_string()))?;
    if x.is_nan() {
        return Ok(F::nan());
    }
    if x <= F::zero() {
        return Err(SpecialError::DomainError(
            "log-gamma requires positive argument".to_string(),
        ));
    }
    if x.is_infinite() {
        return Ok(F::infinity());
    }
    if x < const_f64::<F>(1e-8) {
        let gamma = const_f64::<F>(EULER_MASCHERONI);
        return Ok(-x.ln() - gamma * x);
    }
    if x < const_f64::<F>(0.5) {
        // The Lanczos form works with z = x - 1 and rebuilds x as z + 1. The subtraction is
        // exact only for 0.5 <= x <= 2; below that its rounding error lands on the dominant
        // c₁/x term. Shift the argument instead: ln Γ(x) = ln Γ(x + 1) - ln x.
        return Ok(gammaln_lanczos(x + F::one())? - x.ln());
    }
    if x_f64 < 50.0 {
        return gammaln_lanczos(x);
    }
    Ok(stirling_approximation_ln(x))
}
/// `ln Γ(x)` from the Lanczos approximation with `g = 7`, evaluated directly in log space:
///
/// `ln Γ(x) ≈ ½·ln(2π) + ln A_g(z) + (z + ½)·ln t − t`,  with `z = x − 1`, `t = z + g + ½`,
///
/// where `A_g(z) = c₀ + Σₖ cₖ / (z + k)` over `LANCZOS_COEFFICIENTS`.
#[allow(dead_code)]
fn gammaln_lanczos<F>(x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
    const LANCZOS_G: f64 = 7.0;
    let z = x - F::one();
    let series = LANCZOS_COEFFICIENTS
        .iter()
        .enumerate()
        .skip(1)
        .fold(const_f64::<F>(LANCZOS_COEFFICIENTS[0]), |acc, (k, &c)| {
            acc + const_f64::<F>(c) / (z + const_f64::<F>(k as f64))
        });
    let half_ln_two_pi = const_f64::<F>((2.0 * std::f64::consts::PI).ln() / 2.0);
    let t = z + const_f64::<F>(LANCZOS_G + 0.5);
    let exponent = z + const_f64::<F>(0.5);
    Ok(half_ln_two_pi + series.ln() + exponent * t.ln() - t)
}
#[allow(dead_code)]
pub fn rgamma<F>(x: F) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
let x_f64 = x
.to_f64()
.ok_or_else(|| SpecialError::ValueError("Failed to convert x to f64".to_string()))?;
if x_f64 <= 0.0 && x_f64.fract() == 0.0 {
return Ok(F::zero());
}
if x_f64 > 171.0 {
return Ok(F::zero());
}
let gamma_x = gamma_enhanced(x)?;
if gamma_x.is_infinite() {
return Ok(F::zero());
}
Ok(F::one() / gamma_x)
}
/// `ln|Γ(x)|` together with the sign of `Γ(x)` (`+1` or `-1`).
///
/// For `x > 0` this is `(ln Γ(x), 1)`. For `x <= 0` the reflection formula gives
/// `ln|Γ(x)| = ln π - ln|sin(πx)| - ln Γ(1 - x)`. Because `Γ(1 - x) > 0` there, the sign of
/// `Γ(x)` is exactly the sign of `sin(πx)`. NaN input yields `(NaN, NaN)`.
///
/// # Errors
///
/// `SpecialError::ValueError` if `x` cannot be converted to `f64`; `SpecialError::DomainError`
/// when `x` lies within `1e-14` of a non-positive integer; errors from `gammaln_enhanced` are
/// propagated.
#[allow(dead_code)]
pub fn lgamma_with_sign<F>(x: F) -> SpecialResult<(F, F)>
where
    F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
    let x_f64 = x
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert x to f64".to_string()))?;
    if x.is_nan() {
        return Ok((F::nan(), F::nan()));
    }
    if x > F::zero() {
        let log_gamma = gammaln_enhanced(x)?;
        return Ok((log_gamma, F::one()));
    }
    // Here x <= 0, so the nearest integer is non-positive. Rounding in f64 (rather than
    // casting to i32, which saturates) keeps poles detectable for |x| > i32::MAX.
    let nearest = x_f64.round();
    if (x_f64 - nearest).abs() < 1e-14 {
        return Err(SpecialError::DomainError(format!(
            "log-gamma undefined at pole x = {x_f64}"
        )));
    }
    let pi = const_f64::<F>(std::f64::consts::PI);
    // sin(πx) = (-1)^n · sin(π(x - n)); the reduced argument keeps the sine, and hence the
    // sign, accurate even for large |x|.
    let sin_pi_reduced = (pi * (x - const_f64::<F>(nearest))).sin();
    let sin_pi_x = if nearest % 2.0 == 0.0 {
        sin_pi_reduced
    } else {
        -sin_pi_reduced
    };
    let log_gamma_complement = gammaln_enhanced(F::one() - x)?;
    let log_abs_gamma = pi.ln() - sin_pi_x.abs().ln() - log_gamma_complement;
    let sign = if sin_pi_x > F::zero() {
        F::one()
    } else {
        -F::one()
    };
    Ok((log_abs_gamma, sign))
}
/// Ratio `Γ(a) / Γ(b)`.
///
/// When `a - b` is an integer `n` with `|n| < 10`, the ratio is a finite product:
/// `Γ(b + n) / Γ(b) = b(b+1)…(b+n-1)` for `n ≥ 0`, and the reciprocal of `a(a+1)…(a+|n|-1)`
/// for `n < 0`. Otherwise it is `exp(ln Γ(a) - ln Γ(b))`, which requires `a, b > 0`. That
/// result is `+∞` above the largest finite `F` and underflows to zero through `exp` when it
/// is too small.
///
/// # Errors
///
/// `SpecialError::ValueError` if `a` or `b` cannot be converted to `f64`; errors from
/// `gammaln_enhanced` (e.g. a non-positive argument) on the logarithmic path.
#[allow(dead_code)]
pub fn gamma_ratio<F>(a: F, b: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
    let a_f64 = a
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert a to f64".to_string()))?;
    let b_f64 = b
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert b to f64".to_string()))?;
    let diff = a_f64 - b_f64;
    if diff.abs() < 10.0 && diff.abs() == diff.abs().floor() {
        // Rising factorial start·(start+1)·…·(start+steps-1), multiplied left to right.
        let rising = |start: F, steps: i32| {
            (0..steps).fold(F::one(), |acc, i| acc * (start + const_f64::<F>(f64::from(i))))
        };
        let n = diff as i32;
        return Ok(if n >= 0 {
            rising(b, n)
        } else {
            F::one() / rising(a, -n)
        });
    }
    let log_ratio = gammaln_enhanced(a)? - gammaln_enhanced(b)?;
    // ln(F::MAX) is the exact overflow boundary for the exponentiated ratio.
    if log_ratio > F::max_value().ln() {
        return Ok(F::infinity());
    }
    // Very negative log-ratios underflow to zero inside exp() itself, so no lower guard
    // is needed.
    Ok(log_ratio.exp())
}
/// Pochhammer symbol (rising factorial) `(a)_n = Γ(a + n) / Γ(a)`.
///
/// For integer `n` in `0..=20` the product `a(a+1)…(a+n-1)` is multiplied out directly;
/// every other `n` is delegated to `gamma_ratio(a + n, a)`.
///
/// # Errors
///
/// `SpecialError::ValueError` if `n` cannot be converted to `f64`; errors from `gamma_ratio`
/// are propagated.
#[allow(dead_code)]
pub fn pochhammer_enhanced<F>(a: F, n: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
    let n_f64 = n
        .to_f64()
        .ok_or_else(|| SpecialError::ValueError("Failed to convert n to f64".to_string()))?;
    let is_small_count = n_f64.fract() == 0.0 && (0.0..=20.0).contains(&n_f64);
    if !is_small_count {
        return gamma_ratio(a + n, a);
    }
    let steps = n_f64 as usize;
    Ok((0..steps).fold(F::one(), |acc, k| acc * (a + const_f64::<F>(k as f64))))
}
/// `n!` computed in the float type `F` as the product `2·3·…·n` (empty product = 1 for
/// `n = 0` and `n = 1`). Overflows to `+∞` once `n!` exceeds the range of `F`.
#[allow(dead_code)]
fn factorial_float<F: Float + FromPrimitive>(n: u32) -> F {
    (2..=n).fold(F::one(), |product, k| product * const_f64::<F>(f64::from(k)))
}
/// Harmonic number `Hₙ = 1 + 1/2 + … + 1/n`, summed in increasing `k` (zero for `n = 0`).
#[allow(dead_code)]
fn harmonic_number<F: Float + FromPrimitive>(n: u32) -> F {
    (1..=n)
        .map(|k| F::one() / const_f64::<F>(f64::from(k)))
        .fold(F::zero(), |sum, term| sum + term)
}
/// Second-order harmonic number `H⁽²⁾ₙ = 1 + 1/2² + … + 1/n²`, summed in increasing `k`
/// (zero for `n = 0`).
#[allow(dead_code)]
fn harmonic_sum_squared<F: Float + FromPrimitive>(n: u32) -> F {
    (1..=n).fold(F::zero(), |sum, k| {
        // Square in f64: `k * k` in u32 overflows once k exceeds 65_535.
        let k = f64::from(k);
        sum + F::one() / const_f64::<F>(k * k)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Absolute tolerance shared by all comparisons in this module.
    const TOL: f64 = 1e-10;

    /// Γ(n) = (n - 1)! for positive integers.
    #[test]
    fn test_gamma_enhanced_integers() {
        let cases = [(1.0, 1.0), (2.0, 1.0), (5.0, 24.0), (6.0, 120.0)];
        for &(x, expected) in &cases {
            let got = gamma_enhanced(x).expect("test should succeed");
            assert_relative_eq!(got, expected, epsilon = TOL);
        }
    }

    /// Γ(1/2) = √π and Γ(3/2) = √π / 2.
    #[test]
    fn test_gamma_enhanced_half_integers() {
        let sqrt_pi = std::f64::consts::PI.sqrt();
        let cases = [(0.5, sqrt_pi), (1.5, sqrt_pi / 2.0)];
        for &(x, expected) in &cases {
            let got = gamma_enhanced(x).expect("test should succeed");
            assert_relative_eq!(got, expected, epsilon = TOL);
        }
    }

    /// Γ(-1/2) = -2√π.
    #[test]
    fn test_gamma_enhanced_negative() {
        let expected = -2.0 * std::f64::consts::PI.sqrt();
        let got = gamma_enhanced(-0.5).expect("test should succeed");
        assert_relative_eq!(got, expected, epsilon = TOL);
    }

    /// Negative integers are poles and must be reported as errors.
    #[test]
    fn test_gamma_at_pole() {
        for &pole in &[-1.0_f64, -2.0_f64] {
            assert!(gamma_enhanced(pole).is_err());
        }
    }

    /// Just off a pole the function must still return a finite value.
    #[test]
    fn test_gamma_near_pole() {
        let value = gamma_enhanced(-1.0 + 1e-10_f64).expect("test should succeed");
        assert!(value.is_finite());
    }

    /// ln Γ(5) = ln 24 and ln Γ(1/2) = ln √π.
    #[test]
    fn test_gammaln_enhanced() {
        let cases = [(5.0, 24.0_f64.ln()), (0.5, std::f64::consts::PI.sqrt().ln())];
        for &(x, expected) in &cases {
            let got = gammaln_enhanced(x).expect("test should succeed");
            assert_relative_eq!(got, expected, epsilon = TOL);
        }
    }

    /// 1/Γ(5) = 1/24, and 1/Γ vanishes at the pole x = -1.
    #[test]
    fn test_rgamma() {
        let cases = [(5.0, 1.0 / 24.0), (-1.0, 0.0)];
        for &(x, expected) in &cases {
            let got = rgamma(x).expect("test should succeed");
            assert_relative_eq!(got, expected, epsilon = TOL);
        }
    }

    /// Γ(5)/Γ(3) = 12 and its reciprocal.
    #[test]
    fn test_gamma_ratio() {
        let cases = [(5.0, 3.0, 12.0), (3.0, 5.0, 1.0 / 12.0)];
        for &(a, b, expected) in &cases {
            let got = gamma_ratio(a, b).expect("test should succeed");
            assert_relative_eq!(got, expected, epsilon = TOL);
        }
    }

    /// (1)_4 = 4! = 24 and (3)_2 = 3·4 = 12.
    #[test]
    fn test_pochhammer_enhanced() {
        let cases = [(1.0, 4.0, 24.0), (3.0, 2.0, 12.0)];
        for &(a, n, expected) in &cases {
            let got = pochhammer_enhanced(a, n).expect("test should succeed");
            assert_relative_eq!(got, expected, epsilon = TOL);
        }
    }
}