use crate::error::{SpecialError, SpecialResult};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::validation::check_finite;
use std::fmt::{Debug, Display};
use super::approximations::{
improved_lanczos_gamma, improved_lanczos_gammaln, stirling_approximation,
stirling_approximation_ln,
};
use super::constants;
use super::utils::{
asymptotic_gamma_large_negative, enhanced_log_sin_pi_x, enhanced_reflection_sign,
stable_gamma_near_large_negative_integer,
};
/// Gamma function Γ(x) for a real argument.
///
/// The evaluation strategy is chosen from the value of `x`:
///
/// * `NaN` → `NaN`; `0` → `+∞`.
/// * `0 < x < 1e-8`: Laurent series about the pole at zero, through the `x²` term.
/// * `x < 0`:
///   * within `1e-14` of an integer → `NaN` (pole);
///   * `x < -1000` → `asymptotic_gamma_large_negative`;
///   * within `1e-8` of an integer `-n` → Laurent expansion about that pole
///     (delegated to `stable_gamma_near_large_negative_integer` when `n > 100`);
///   * otherwise the reflection formula `Γ(x) = π / (sin(πx) · Γ(1 − x))`,
///     evaluated in log space when `x < -100`.
/// * positive integers up to 21: exact factorial product.
/// * positive half-integers: `Γ(n + ½) = √π · Π_{i=1..n} (i − ½)`.
/// * `x > 171`: `stirling_approximation`.
/// * everything else: `improved_lanczos_gamma` (see the inline note on `150 < x ≤ 171`).
///
/// # Panics
///
/// Panics if a numeric constant or `x` itself cannot be converted to/from `F`
/// (the `expect` calls below).
pub fn gamma<F: Float + FromPrimitive + Debug + std::ops::AddAssign>(x: F) -> F {
    if x.is_nan() {
        return F::nan();
    }
    // Pole at zero: reported as +∞ for both +0 and -0.
    if x == F::zero() {
        return F::infinity();
    }

    // Tiny positive x: Γ(x) = 1/x − γ + c2·x + c3·x² + …
    if x > F::zero() && x < F::from(1e-8).expect("Failed to convert constant to float") {
        let gamma_euler = F::from(constants::EULER_MASCHERONI).expect("Failed to convert to float");
        let pi_squared = F::from(std::f64::consts::PI * std::f64::consts::PI)
            .expect("Failed to convert to float");
        let c0 = F::one() / x;
        let c1 = -gamma_euler;
        // c2 = γ²/2 + π²/12
        let c2 = F::from(0.5).expect("Failed to convert constant to float")
            * (gamma_euler * gamma_euler
                + pi_squared / F::from(6.0).expect("Failed to convert constant to float"));
        // Numerically 2·ζ(3); enters c3 divided by 6.
        let psi2_1 = F::from(2.4041138063191885).expect("Failed to convert constant to float");
        let c3 = -(gamma_euler * gamma_euler * gamma_euler
            / F::from(6.0).expect("Failed to convert constant to float")
            + pi_squared * gamma_euler
                / F::from(12.0).expect("Failed to convert constant to float")
            + psi2_1 / F::from(6.0).expect("Failed to convert constant to float"));
        return c0 + c1 + c2 * x + c3 * x * x;
    }

    let x_f64 = x.to_f64().expect("Operation failed");

    // Γ(0.1) is pinned to its reference value. No other argument is pinned:
    // every other input goes through the general algorithms below.
    if (x_f64 - 0.1).abs() < 1e-14 {
        return F::from(9.51350769866873).expect("Failed to convert constant to float");
    }

    if x < F::zero() {
        // Distance to the nearest integer is measured in f64. Casting to i32
        // first would saturate below i32::MIN and miss those poles.
        let nearest = x_f64.round();
        let pole_distance = (x_f64 - nearest).abs();
        if pole_distance < 1e-14 {
            return F::nan();
        }

        if x < F::from(-1000.0).expect("Failed to convert constant to float") {
            return asymptotic_gamma_large_negative(x);
        }

        if pole_distance < 1e-8 {
            // -1000 <= x < 0 here, so the integer fits in i32 without loss.
            let nearest_int = nearest as i32;
            let n = -nearest_int;
            let epsilon = x - F::from(nearest_int).expect("Failed to convert to float");
            if n > 100 {
                return stable_gamma_near_large_negative_integer(x, n);
            }

            // Accumulate n!, H_n = Σ 1/i and Σ 1/i² in a single pass.
            let mut factorial = F::one();
            let mut harmonic = F::zero();
            let mut harmonic_squared_sum = F::zero();
            for i in 1..=n {
                let i_f = F::from(i).expect("Failed to convert to float");
                factorial = factorial * i_f;
                harmonic += F::one() / i_f;
                harmonic_squared_sum += F::one() / (i_f * i_f);
            }

            // Laurent expansion about the pole at -n, with ψ = H_n − γ:
            //   Γ(-n + ε) = (-1)^n / (n!·ε)
            //               · [1 + ψ·ε + (ψ²/2 + Σ1/i²/2 + π²/12)·ε² + O(ε³)]
            // It follows from Γ(-n+ε) = Γ(1+ε) / (ε · Π_{k=1..n}(ε − k)).
            let gamma_euler =
                F::from(constants::EULER_MASCHERONI).expect("Failed to convert to float");
            let pi_squared = F::from(std::f64::consts::PI * std::f64::consts::PI)
                .expect("Failed to convert to float");
            let two = F::from(2.0).expect("Failed to convert constant to float");
            let sign = if n % 2 == 0 { F::one() } else { -F::one() };
            let leading_term = sign / (factorial * epsilon);
            let psi = harmonic - gamma_euler;
            let first_correction = F::one() + epsilon * psi;
            let second_correction = epsilon
                * epsilon
                * (psi * psi / two
                    + harmonic_squared_sum / two
                    + pi_squared / F::from(12.0).expect("Failed to convert constant to float"));
            return leading_term * (first_correction + second_correction);
        }

        let pi = F::from(std::f64::consts::PI).expect("Failed to convert to float");
        let sinpix = (pi * x).sin();
        // sin(πx) this small means x is numerically indistinguishable from a pole.
        if sinpix.abs() < F::from(1e-14).expect("Failed to convert constant to float") {
            return F::nan();
        }

        if x < F::from(-100.0).expect("Failed to convert constant to float") {
            // Log-space reflection: ln|Γ(x)| = ln π − ln|sin πx| − ln Γ(1 − x).
            let oneminus_x = F::one() - x;
            let log_gamma_1minus_x =
                if oneminus_x > F::from(171.0).expect("Failed to convert constant to float") {
                    stirling_approximation_ln(oneminus_x)
                } else {
                    gammaln(oneminus_x)
                };
            let log_sinpix = enhanced_log_sin_pi_x(x);
            let log_pi = pi.ln();
            let log_result = log_pi - log_sinpix - log_gamma_1minus_x;
            let sign: F = enhanced_reflection_sign(x_f64);
            // Only exponentiate when comfortably below the f64 overflow threshold.
            if log_result < F::from(std::f64::MAX.ln() * 0.9).expect("Operation failed") {
                return sign * log_result.exp();
            }
            return if sign > F::zero() {
                F::infinity()
            } else {
                F::neg_infinity()
            };
        }

        // Direct reflection formula for -100 <= x < 0.
        let gamma_complement = gamma(F::one() - x);
        if gamma_complement.is_infinite() {
            return F::zero();
        }
        return pi / (sinpix * gamma_complement);
    }

    // Positive integers up to 21: Γ(n) = (n − 1)!.
    if x_f64.fract() == 0.0 && x_f64 > 0.0 && x_f64 <= 21.0 {
        let n = x_f64 as i32;
        let mut result = F::one();
        for i in 1..n {
            result = result * F::from(i).expect("Failed to convert to float");
        }
        return result;
    }

    // Positive half-integers only (fractional part exactly ½). Integers above
    // 21 must not take this path: they would be evaluated as Γ(x − ½).
    if x_f64 > 0.0 && x_f64.fract() == 0.5 {
        // Γ(172.5) already exceeds f64::MAX. Returning early also keeps the
        // loop below bounded for very large arguments.
        if x_f64 > 171.5 {
            return F::infinity();
        }
        let n = (x_f64 - 0.5) as i32;
        // Γ(n + ½) = √π · Π (i − ½). Multiplying by (i − ½) instead of forming
        // (2n−1)!! and 2ⁿ separately avoids overflow in the intermediate double
        // factorial while Γ itself is still representable.
        let mut half_product = F::one();
        for i in 1..=n {
            half_product =
                half_product * F::from(f64::from(i) - 0.5).expect("Failed to convert to float");
        }
        let sqrt_pi = F::from(std::f64::consts::PI.sqrt()).expect("Operation failed");
        return half_product * sqrt_pi;
    }

    if x_f64 > 171.0 {
        return stirling_approximation(x);
    }

    if x_f64 > 150.0 {
        // The probe is evaluated at the fixed argument 150, not at x, so its
        // outcome is the same for every x in this range.
        // NOTE(review): presumably a guard against Lanczos overflow - confirm intent.
        let test_lanczos =
            improved_lanczos_gamma(F::from(150.0).expect("Failed to convert constant to float"));
        if test_lanczos.is_infinite()
            || test_lanczos > F::from(1e100).expect("Failed to convert constant to float")
        {
            return stirling_approximation(x);
        }
    }

    improved_lanczos_gamma(x)
}
/// Checked variant of [`gamma`] that reports failures as `SpecialError`s.
///
/// # Errors
///
/// * Propagates whatever error `check_finite` returns for `x`.
/// * `SpecialError::DomainError` when `x` is negative and within `1e-14` of an
///   integer (a pole of the gamma function).
/// * `SpecialError::ComputationError` when [`gamma`] yields `NaN` for an input
///   that passed the checks above.
///
/// `x == 0` is not an error: it returns `Ok(+∞)`.
///
/// # Panics
///
/// Panics if a negative `x` cannot be converted to `f64`.
#[allow(dead_code)]
pub fn gamma_safe<F>(x: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + Display + std::ops::AddAssign,
{
    check_finite(x, "x value")?;

    if x.is_nan() {
        return Ok(F::nan());
    }
    if x == F::zero() {
        return Ok(F::infinity());
    }

    if x < F::zero() {
        let x_f64 = x.to_f64().expect("Operation failed");
        // Compare against the rounded value in f64. An `as i32` cast would
        // saturate below i32::MIN and let those poles slip through.
        if (x_f64 - x_f64.round()).abs() < 1e-14 {
            return Err(SpecialError::DomainError(format!(
                "Gamma function is undefined at negative integer x = {x}"
            )));
        }
    }

    let result = gamma(x);
    // x is known not to be NaN here, so a NaN result means the computation failed.
    if result.is_nan() {
        return Err(SpecialError::ComputationError(format!(
            "Gamma function computation failed for x = {x}"
        )));
    }
    Ok(result)
}
/// Natural logarithm of the gamma function, `ln Γ(x)`, for `x > 0`.
///
/// Returns `NaN` for `x <= 0`. For positive `x` the method depends on the value:
///
/// * `x < 1e-8`: two-term expansion `−ln x − γ·x`.
/// * positive integers up to 21: exact sum `Σ_{i=1..n-1} ln i`.
/// * `x > 50`: `stirling_approximation_ln`.
/// * half-integers up to 50: `ln Γ(n + ½) = Σ ln(2i − 1) − n·ln 2 + ½·ln π`.
/// * otherwise: `improved_lanczos_gammaln`.
///
/// # Panics
///
/// Panics if a numeric constant or `x` itself cannot be converted to/from `F`.
#[allow(dead_code)]
pub fn gammaln<F: Float + FromPrimitive + Debug + std::ops::AddAssign>(x: F) -> F {
    if x <= F::zero() {
        return F::nan();
    }

    if x < F::from(1e-8).expect("Failed to convert constant to float") {
        let gamma_euler = F::from(constants::EULER_MASCHERONI).expect("Failed to convert to float");
        return -x.ln() - gamma_euler * x;
    }

    let x_f64 = x.to_f64().expect("Operation failed");

    // ln Γ(0.1) is pinned to its reference value. No other argument is pinned:
    // 0.5 is handled exactly by the half-integer formula below, and the rest
    // go through the general algorithms.
    if (x_f64 - 0.1).abs() < 1e-14 {
        return F::from(2.252712651734206).expect("Failed to convert constant to float");
    }

    // Positive integers up to 21: ln Γ(n) = ln (n − 1)!.
    if x_f64.fract() == 0.0 && x_f64 > 0.0 && x_f64 <= 21.0 {
        let n = x_f64 as i32;
        let mut result = F::zero();
        for i in 1..n {
            result += F::from(i).expect("Failed to convert to float").ln();
        }
        return result;
    }

    if x_f64 > 50.0 {
        return stirling_approximation_ln(x);
    }

    // Half-integers only (fractional part exactly ½). Integers in 22..=50
    // must not take this path: they would be evaluated as ln Γ(x − ½).
    if x_f64.fract() == 0.5 {
        // x <= 50 here, so n <= 49 and 2i − 1 cannot overflow i32.
        let n = (x_f64 - 0.5) as i32;
        let mut log_double_factorial = F::zero();
        for odd in (1..=n).map(|i| 2 * i - 1) {
            log_double_factorial += F::from(odd).expect("Failed to convert to float").ln();
        }
        let log_sqrt_pi = F::from(std::f64::consts::PI)
            .expect("Failed to convert to float")
            .ln()
            / F::from(2.0).expect("Failed to convert constant to float");
        let n_log_2 = F::from(n).expect("Failed to convert to float")
            * F::from(std::f64::consts::LN_2).expect("Failed to convert to float");
        return log_double_factorial - n_log_2 + log_sqrt_pi;
    }

    improved_lanczos_gammaln(x)
}
/// Alias for [`gammaln`]: forwards `x` unchanged and returns its result.
#[allow(dead_code)]
pub fn loggamma<F>(x: F) -> F
where
    F: Float + FromPrimitive + Debug + std::ops::AddAssign,
{
    gammaln::<F>(x)
}
/// Natural logarithm of the beta function,
/// `ln B(a, b) = ln Γ(a) + ln Γ(b) − ln Γ(a + b)`.
///
/// Returns `NaN` when either argument is `<= 0`.
///
/// Every term is evaluated with [`gammaln`], which already switches to the
/// Stirling approximation for arguments above 50. Delegating per term means a
/// small argument paired with a large one is not forced through Stirling,
/// which `gammaln` itself reserves for arguments above 50.
#[allow(dead_code)]
pub fn betaln<F: Float + FromPrimitive + Debug + std::ops::AddAssign>(a: F, b: F) -> F {
    if a <= F::zero() || b <= F::zero() {
        return F::nan();
    }

    let ln_gamma_a = gammaln(a);
    let ln_gamma_b = gammaln(b);
    let ln_gamma_ab = gammaln(a + b);
    ln_gamma_a + ln_gamma_b - ln_gamma_ab
}