use crate::error::{SpecialError, SpecialResult};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use super::approximations::stirling_approximation_ln;
use super::digamma::digamma;
/// Evaluates the gamma function for a negative argument via the reflection
/// formula, assembling the magnitude in log space.
///
/// The log-magnitude is `ln(pi) - ln|sin(pi * x)| - L`, where `L` is the value
/// returned by `stirling_approximation_ln(1 - x)` (presumably an approximation
/// of `ln Gamma(1 - x)`; confirm against that helper). The sign is taken from
/// `enhanced_reflection_sign`.
///
/// If the log-magnitude is not below 90% of `ln(f64::MAX)`, the result
/// saturates to a signed infinity instead of being exponentiated.
///
/// NOTE(review): integer `x` (a pole) is not special-cased here; the sign
/// helper yields NaN for such input.
///
/// # Panics
/// Panics if `x` cannot be converted to `f64`, or if a constant cannot be
/// converted to `F`.
#[allow(dead_code)]
pub(super) fn asymptotic_gamma_large_negative<F: Float + FromPrimitive + std::ops::AddAssign>(
    x: F,
) -> F {
    let x_f64 = x.to_f64().expect("Operation failed");
    let pi = F::from(std::f64::consts::PI).expect("Failed to convert to float");

    // Gamma(x) = pi / (sin(pi * x) * Gamma(1 - x)), evaluated on logarithms.
    let log_gamma_pos = stirling_approximation_ln(F::one() - x);
    let log_sin_pi_x = enhanced_log_sin_pi_x(x);
    let sign: F = enhanced_reflection_sign(x_f64);
    let log_result = pi.ln() - log_sin_pi_x - log_gamma_pos;

    // Stay clear of the overflow limit before calling `exp`.
    let overflow_guard = F::from(f64::MAX.ln() * 0.9).expect("Operation failed");
    if log_result < overflow_guard {
        sign * log_result.exp()
    } else if sign > F::zero() {
        F::infinity()
    } else {
        F::neg_infinity()
    }
}
/// Approximates the gamma function close to its pole at `-n` with the leading
/// Laurent term `(-1)^n / (n! * epsilon)`, where `epsilon = x + n`.
///
/// The magnitude `1 / (n! * |epsilon|)` is built in log space. `ln(n!)` is
/// taken from `stirling_approximation_ln(n + 1)` (presumably an approximation
/// of `ln Gamma(n + 1)`; confirm against that helper). The sign is the product
/// of `(-1)^n` and the sign of `epsilon`.
///
/// If the log-magnitude is not below 90% of `ln(f64::MAX)` (this includes
/// `epsilon == 0`), a signed infinity is returned; a non-positive `epsilon`
/// is treated as negative in that case.
///
/// # Panics
/// Panics if `n` or a constant cannot be converted to `F`.
#[allow(dead_code)]
pub(super) fn stable_gamma_near_large_negative_integer<
    F: Float + FromPrimitive + std::ops::AddAssign,
>(
    x: F,
    n: i32,
) -> F {
    let epsilon = x + F::from(n).expect("Failed to convert to float");
    let n_f = F::from(n as f64).expect("Failed to convert to float");
    let log_n_factorial = stirling_approximation_ln(n_f + F::one());
    let pole_sign = if n % 2 == 0 { F::one() } else { -F::one() };

    // ln(1 / (n! * |epsilon|)): the full magnitude of the leading term.
    let log_magnitude = -log_n_factorial - epsilon.abs().ln();
    let overflow_guard = F::from(f64::MAX.ln() * 0.9).expect("Operation failed");

    // The result is positive exactly when `pole_sign` and `epsilon` agree in sign.
    let is_positive = (epsilon > F::zero()) == (pole_sign > F::zero());

    if log_magnitude < overflow_guard {
        // `log_magnitude` already contains the 1/|epsilon| factor, so only the
        // sign of `epsilon` may be applied here. Dividing by `epsilon` again
        // would scale the result by an extra 1/|epsilon|.
        let magnitude = log_magnitude.exp();
        if is_positive {
            magnitude
        } else {
            -magnitude
        }
    } else if is_positive {
        F::infinity()
    } else {
        F::neg_infinity()
    }
}
/// Computes `ln|sin(pi * x)|` after reducing the argument.
///
/// The fractional part `x - floor(x)` is computed in `f64`. A fractional part
/// of `0.5` or more is reflected to `1 - frac`, so the sine is always
/// evaluated on an argument in `[0, pi/2]`.
///
/// # Panics
/// Panics if `x` cannot be converted to `f64`, or if a value cannot be
/// converted back to `F`.
#[allow(dead_code)]
pub(super) fn enhanced_log_sin_pi_x<F: Float + FromPrimitive>(x: F) -> F {
    let pi = F::from(std::f64::consts::PI).expect("Failed to convert to float");
    let value = x.to_f64().expect("Operation failed");
    let frac = value - value.floor();
    let frac_f = F::from(frac).expect("Failed to convert to float");

    // sin(pi * t) == sin(pi * (1 - t)): use whichever argument is smaller.
    let folded = if frac < 0.5 {
        frac_f
    } else {
        F::one() - frac_f
    };
    (pi * folded).sin().abs().ln()
}
/// Returns the sign of `sin(pi * x)` as `+1` or `-1`.
///
/// Writing `x = floor(x) + frac` with `0 < frac < 1` gives
/// `sin(pi * x) = (-1)^floor(x) * sin(pi * frac)`, and `sin(pi * frac)` is
/// positive. The sign is therefore `+1` when `floor(x)` is even and `-1` when
/// it is odd.
///
/// Returns NaN when `x` is an integer (the sine is zero, so there is no sign)
/// and when `x` is NaN or infinite.
#[allow(dead_code)]
pub(super) fn enhanced_reflection_sign<F: Float + FromPrimitive>(xf64: f64) -> F {
    let x_floor = xf64.floor();
    let fractional_part = xf64 - x_floor;

    // A NaN fractional part arises from NaN or infinite input.
    if fractional_part == 0.0 || fractional_part.is_nan() {
        return F::nan();
    }

    // Parity is tested in floating point so that large |x| cannot overflow an
    // integer cast; `rem_euclid` keeps the remainder non-negative for x < 0.
    if x_floor.rem_euclid(2.0) == 0.0 {
        F::one()
    } else {
        -F::one()
    }
}
/// Sanity-checks a computed gamma value `result` for the argument `x`.
///
/// Checks are applied in this order:
/// 1. NaN `x` is rejected with `DomainError`.
/// 2. Negative `x` within `1e-14` of an integer is rejected with `DomainError`
///    (pole).
/// 3. NaN `result` is rejected with `ComputationError`.
/// 4. Infinite `result` is accepted only when `x > 171.5`, or when `x` is
///    negative and within `1e-12` of an integer; otherwise it is rejected
///    with `ComputationError`.
/// 5. Zero `result` for `0 < x < 171` is rejected with `ComputationError`.
///
/// When the value returned by `estimate_gamma_condition_number` exceeds `1e12`
/// a warning is logged, but only in builds with the `gpu` feature enabled.
///
/// NOTE(review): `x == 0` is not covered by the pole check in step 2; an
/// infinite result there is reported as an unexpected overflow.
///
/// # Errors
/// See the list above.
///
/// # Panics
/// Panics if `x` cannot be converted to `f64`.
#[allow(dead_code)]
pub(super) fn validate_gamma_computation<
    F: Float
        + FromPrimitive
        + Debug
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign,
>(
    x: F,
    result: F,
) -> SpecialResult<F> {
    let x_f64 = x.to_f64().expect("Operation failed");
    if x.is_nan() {
        return Err(SpecialError::DomainError("Input x is NaN".to_string()));
    }

    // Measured entirely in f64: casting the rounded value to `i32` would
    // saturate for |x| beyond the i32 range and hide poles there.
    let distance_to_integer = (x_f64 - x_f64.round()).abs();

    if x < F::zero() && distance_to_integer < 1e-14 {
        return Err(SpecialError::DomainError(format!(
            "Gamma function has a pole at x = {x_f64}"
        )));
    }

    // `x` is known not to be NaN here, so a NaN result is a genuine failure.
    if result.is_nan() {
        return Err(SpecialError::ComputationError(format!(
            "Gamma computation failed for x = {x_f64}, result is NaN"
        )));
    }

    if result.is_infinite() {
        let large_argument = x_f64 > 171.5;
        let near_negative_integer = x_f64 < 0.0 && distance_to_integer < 1e-12;
        if large_argument || near_negative_integer {
            return Ok(result);
        }
        return Err(SpecialError::ComputationError(format!(
            "Unexpected overflow in gamma computation for x = {x_f64}"
        )));
    }

    if result.is_zero() && x_f64 > 0.0 && x_f64 < 171.0 {
        return Err(SpecialError::ComputationError(format!(
            "Unexpected underflow in gamma computation for x = {x_f64}"
        )));
    }

    let condition_estimate = estimate_gamma_condition_number(x);
    if condition_estimate > 1e12 {
        #[cfg(feature = "gpu")]
        log::warn!(
            "High condition number ({:.2e}) for gamma({}), result may be inaccurate",
            condition_estimate,
            x_f64
        );
    }
    Ok(result)
}
/// Heuristic estimate of how ill-conditioned the gamma function is at `x`.
///
/// * `0 < x < 100`: `|x * digamma(x)|`, with the digamma value treated as `0`
///   if it cannot be converted to `f64`.
/// * `x >= 100`: `ln(x)`.
/// * otherwise (non-positive or NaN `x`): `1 / |x|`.
///
/// NOTE(review): the large-`x` branch returns `ln(x)` rather than
/// `x * ln(x)`, so the estimate is discontinuous at `x = 100`. This is kept
/// as-is for compatibility; confirm whether it is intended.
///
/// # Panics
/// Panics if `x` cannot be converted to `f64`.
#[allow(dead_code)]
pub(super) fn estimate_gamma_condition_number<
    F: Float
        + FromPrimitive
        + std::fmt::Debug
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign,
>(
    x: F,
) -> f64 {
    let x_f64 = x.to_f64().expect("Operation failed");
    if x_f64 > 0.0 && x_f64 < 100.0 {
        let psi_x = digamma(x).to_f64().unwrap_or(0.0);
        (x_f64 * psi_x).abs()
    } else if x_f64 >= 100.0 {
        // `>=` so that x == 100 is handled with the other large arguments
        // instead of falling through to the non-positive branch.
        x_f64.ln()
    } else {
        1.0 / x_f64.abs()
    }
}
/// Polygamma function `psi^(n)(x)`, the `n`-th derivative of the digamma
/// function, for positive real `x`.
///
/// Uses the series
/// `psi^(n)(x) = (-1)^(n+1) * n! * sum_{k>=0} (x + k)^-(n+1)` for `n >= 1`.
/// Leading summands are added directly. Once the shifted argument `x + k`
/// exceeds `n + 20`, the remaining infinite tail is replaced by its
/// Euler–Maclaurin asymptotic expansion. This keeps the slowly converging
/// low orders (such as the trigamma function) accurate.
///
/// # Arguments
/// * `n` - Order of the derivative; `0` delegates to `digamma`.
/// * `x` - Evaluation point.
///
/// # Returns
/// The polygamma value, or NaN when `x` is NaN or `x <= 0`.
/// For `|x - 1| < 1e-10` and `n` in `1..=3`, the closed form
/// `(-1)^(n+1) * n! * zeta(n + 1)` is returned.
/// The result is scaled by `n!` computed in `f64`, so it stops being
/// meaningful once `n!` overflows.
///
/// # Panics
/// Panics if a constant or intermediate value cannot be converted to `F`.
#[allow(dead_code)]
pub fn polygamma<
    F: Float
        + FromPrimitive
        + std::fmt::Debug
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign,
>(
    n: u32,
    x: F,
) -> F {
    if x.is_nan() || x <= F::zero() {
        return F::nan();
    }
    if n == 0 {
        return digamma(x);
    }

    // (-1)^(n+1): negative for even orders, positive for odd orders.
    let sign = if n.is_multiple_of(2) {
        -F::one()
    } else {
        F::one()
    };
    let n_factorial = factorial_f(n);

    // Closed forms at x = 1: psi^(n)(1) = (-1)^(n+1) * n! * zeta(n + 1).
    if (x - F::one()).abs() < F::from(1e-10).expect("Failed to convert constant to float") {
        let zeta_value = match n {
            1 => Some(std::f64::consts::PI.powi(2) / 6.0),
            2 => Some(1.2020569031595942),
            3 => Some(std::f64::consts::PI.powi(4) / 90.0),
            _ => None,
        };
        if let Some(zeta) = zeta_value {
            return sign * F::from(n_factorial * zeta).expect("Failed to convert to float");
        }
    }

    let exponent = n as i32 + 1;
    // The asymptotic tail needs the shifted argument to be large relative to n.
    let asymptotic_start = F::from(n).expect("Failed to convert to float")
        + F::from(20.0).expect("Failed to convert constant to float");
    let tolerance = F::from(1e-16).expect("Failed to convert constant to float");

    let mut sum = F::zero();
    for k in 0..10000 {
        let shifted = x + F::from(k).expect("Failed to convert to float");
        if shifted > asymptotic_start {
            // Everything from this summand on is covered by the expansion.
            sum += polygamma_series_tail(n, shifted);
            break;
        }
        let term = shifted.powi(-exponent);
        sum += term;
        // For high orders the summands decay geometrically and the direct
        // series converges long before the asymptotic region is reached.
        if k > 10 && term.abs() < tolerance * sum.abs() {
            break;
        }
    }
    sign * F::from(n_factorial).expect("Failed to convert to float") * sum
}

/// Euler–Maclaurin estimate of `sum_{k>=0} (y + k)^-(n+1)` for `n >= 1`.
///
/// Evaluates `y^-n / n + y^-(n+1) / 2
/// + sum_j B_{2j}/(2j)! * (n+1)(n+2)...(n+2j-1) * y^-(n+2j)` with seven
/// Bernoulli terms. Intended for `y` that is large compared with `n`.
fn polygamma_series_tail<F: Float>(n: u32, y: F) -> F {
    // B_{2j} / (2j)! for j = 1..=7.
    const EULER_MACLAURIN: [f64; 7] = [
        1.0 / 12.0,
        -1.0 / 720.0,
        1.0 / 30_240.0,
        -1.0 / 1_209_600.0,
        1.0 / 47_900_160.0,
        -691.0 / 1_307_674_368_000.0,
        1.0 / 74_724_249_600.0,
    ];

    let cast = |value: f64| F::from(value).expect("Failed to convert to float");
    let two = cast(2.0);
    let n_f = cast(f64::from(n));
    let inv_y = y.recip();
    let inv_y_sq = inv_y * inv_y;
    let base = y.powi(-(n as i32)); // y^-n

    // Integral of the summand plus half of the first summand.
    let mut total = base / n_f + base * inv_y / two;

    let mut rising = n_f + F::one(); // (n+1)(n+2)...(n+2j-1), starting at j = 1
    let mut power = base * inv_y_sq; // y^-(n+2j), starting at j = 1
    let mut next_factor = n_f + two; // next factor to append to `rising`
    for &coefficient in EULER_MACLAURIN.iter() {
        total = total + cast(coefficient) * rising * power;
        rising = rising * next_factor * (next_factor + F::one());
        next_factor = next_factor + two;
        power = power * inv_y_sq;
    }
    total
}
/// Returns `n!` as an `f64`.
///
/// The product is accumulated left to right starting from `1.0`, so `0!` and
/// `1!` are both `1.0`. Values too large for `f64` come out as infinity.
#[allow(dead_code)]
fn factorial_f(n: u32) -> f64 {
    (1..=n).map(f64::from).product()
}