use crate::error::{SpecialError, SpecialResult};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::validation::check_finite;
use std::fmt::{Debug, Display};
use super::constants;
/// Digamma function ψ(x) = d/dx ln Γ(x).
///
/// Evaluation strategy by region of `x`:
///
/// * `x` equal to 1, 2 or 3: closed forms ψ(1) = −γ, ψ(2) = 1 − γ,
///   ψ(3) = 3/2 − γ, where γ is the Euler–Mascheroni constant.
/// * `x < 0`:
///   - exactly a negative integer (a pole of ψ): returns `+∞`;
///   - within `1e-8` of a pole `−n`: the leading Laurent terms
///     ψ(−n + ε) ≈ −1/ε + ψ(n + 1);
///   - otherwise the reflection formula ψ(x) = ψ(1 − x) − π·cot(πx).
/// * `0 ≤ x < 1e-6`: the series ψ(x) ≈ −1/x − γ + (π²/6)·x (infinite at `x = 0`).
/// * `x > 20`: `asymptotic_digamma`.
/// * otherwise: the recurrence ψ(x + 1) = ψ(x) + 1/x moves `x` into `[1, 2]`,
///   where `rational_digamma_1_to_2` evaluates ψ(1 + z).
///
/// # Panics
///
/// Only if one of the small `f64` constants used here cannot be converted to `F`.
pub fn digamma<
    F: Float
        + FromPrimitive
        + Debug
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign,
>(
    mut x: F,
) -> F {
    let c = |v: f64| F::from(v).expect("Failed to convert constant to float");
    let gamma = F::from(constants::EULER_MASCHERONI).expect("Failed to convert to float");
    let twenty = c(20.0);

    // Exact closed forms at small integers.
    if x == F::one() {
        return -gamma;
    }
    if x == c(2.0) {
        return F::one() - gamma;
    }
    if x == c(3.0) {
        return c(1.5) - gamma;
    }

    if x < F::zero() {
        let nearest = x.round();
        // Exact in floating point: x and its nearest integer differ by at most 1/2.
        let eps = x - nearest;
        if eps == F::zero() {
            // x is a negative integer: a pole of ψ. (Casting to i32 here would
            // saturate for |x| > 2^31 and miss these poles.)
            return F::infinity();
        }
        if eps.abs() < c(1e-8) {
            // Laurent expansion about the pole at −n:
            // ψ(−n + ε) = −1/ε + ψ(n + 1) + O(ε).
            let n = -nearest;
            let n_plus_1 = n + F::one();
            let psi_n_plus_1 = if n_plus_1 > twenty {
                // Avoid a harmonic-sum loop of length n for large poles.
                asymptotic_digamma(n_plus_1)
            } else {
                // ψ(n + 1) = −γ + H_n; n ≤ 19 here, so the loop is short and exact.
                let mut acc = -gamma;
                let mut k = F::one();
                while k <= n {
                    acc += F::one() / k;
                    k += F::one();
                }
                acc
            };
            return psi_n_plus_1 - F::one() / eps;
        }
        // Reflection: ψ(x) = ψ(1 − x) − π·cot(πx). cot has period π, so
        // cot(πx) = cot(π·eps); the reduced argument avoids the rounding error
        // of forming π·x when |x| is large. |eps| ≥ 1e-8 here, so sin ≠ 0.
        let pi = c(std::f64::consts::PI);
        let angle = pi * eps;
        let pi_cot = pi * angle.cos() / angle.sin();
        return digamma(F::one() - x) - pi_cot;
    }

    if x < c(1e-6) {
        // ψ(x) = −1/x − γ + ψ'(1)·x + O(x²), with ψ'(1) = π²/6.
        let pi_squared = c(std::f64::consts::PI).powi(2);
        return -F::one() / x - gamma + pi_squared / c(6.0) * x;
    }

    // Recurrence ψ(x) = ψ(x + 1) − 1/x until x ≥ 1.
    let mut shift = F::zero();
    while x < F::one() {
        shift -= F::one() / x;
        x += F::one();
    }
    if x > twenty {
        return asymptotic_digamma(x) + shift;
    }
    if x == F::one() {
        return -gamma + shift;
    }
    // Recurrence ψ(x) = ψ(x − 1) + 1/(x − 1) until x ≤ 2.
    let two = c(2.0);
    while x > two {
        x -= F::one();
        shift += F::one() / x;
    }
    rational_digamma_1_to_2(x - F::one()) + shift
}
/// Checked digamma: validates `x` and reports poles and failed evaluations as errors.
///
/// # Errors
///
/// * any error returned by `check_finite` for `x`;
/// * `SpecialError::DomainError` when `x ≤ 0` lies within `1e-14` of an
///   integer (a pole of ψ);
/// * `SpecialError::ComputationError` when `digamma` returns NaN for a
///   non-NaN input.
#[allow(dead_code)]
pub fn digamma_safe<F>(x: F) -> SpecialResult<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + Display
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign,
{
    check_finite(x, "x value")?;
    if x <= F::zero() {
        let x_f64 = x.to_f64().expect("Operation failed");
        // Compare with the nearest integer in f64 directly: casting it to i32
        // saturates for x < -2^31 and would let poles there slip through.
        if (x_f64 - x_f64.round()).abs() < 1e-14 {
            return Err(SpecialError::DomainError(format!(
                "Digamma function has a pole at x = {x}"
            )));
        }
    }
    let result = digamma(x);
    if result.is_nan() && !x.is_nan() {
        return Err(SpecialError::ComputationError(format!(
            "Digamma function computation failed for x = {x}"
        )));
    }
    Ok(result)
}
/// Evaluates ψ(1 + z), the digamma function on `[1, 2]` for `z` in `[0, 1]`.
///
/// Despite the name this is not a rational approximation. The argument
/// `x = 1 + z` is shifted upward with the recurrence ψ(x) = ψ(x + 1) − 1/x
/// until `x ≥ 10`, and ψ is then evaluated with its asymptotic expansion
///
/// ψ(x) ≈ ln x − 1/(2x) − Σ_{k=1}^{7} B_{2k} / (2k · x^{2k}),
///
/// whose truncation error at `x ≥ 10` (≈ 0.44 / x^16) is below f64 rounding.
/// The expansion is inlined rather than delegated so this helper's accuracy
/// does not depend on how many terms any other routine keeps.
///
/// Returns NaN when `1 + z` is NaN or not positive; the guard also keeps the
/// shift loop finite.
///
/// # Panics
///
/// Only if one of the small `f64` constants cannot be converted to `F`.
#[allow(dead_code)]
fn rational_digamma_1_to_2<F: Float + FromPrimitive>(z: F) -> F {
    let c = |v: f64| F::from(v).expect("Failed to convert constant to float");
    let mut x = F::one() + z;
    if x.is_nan() || x <= F::zero() {
        return F::nan();
    }

    // ψ(x) = ψ(x + 1) − 1/x, applied until the asymptotic series is accurate.
    // x > 0 here, so this runs at most 10 times.
    let threshold = c(10.0);
    let mut shift = F::zero();
    while x < threshold {
        shift = shift - F::one() / x;
        x = x + F::one();
    }

    let inv_x = F::one() / x;
    let t = inv_x * inv_x;
    // Σ B_{2k}/(2k·x^{2k}) for k = 1..=7, in Horner form in t = 1/x².
    let tail = t
        * (c(1.0 / 12.0)
            - t * (c(1.0 / 120.0)
                - t * (c(1.0 / 252.0)
                    - t * (c(1.0 / 240.0)
                        - t * (c(1.0 / 132.0)
                            - t * (c(691.0 / 32760.0) - t * c(1.0 / 12.0)))))));
    x.ln() - c(0.5) * inv_x - tail + shift
}
/// Asymptotic expansion of the digamma function for large positive `x`.
///
/// Evaluates
///
/// ψ(x) ≈ ln x − 1/(2x) − Σ_{k=1}^{7} B_{2k} / (2k · x^{2k})
///
/// where `B_{2k}` are the Bernoulli numbers 1/6, −1/30, 1/42, −1/30, 5/66,
/// −691/2730, 7/6. The series is asymptotic rather than convergent. With terms
/// through `B_14`, the truncation error (≈ 0.44 / x^16) is below f64 rounding
/// once `x ≥ 10`. In this file `digamma` calls it for `x > 20`.
///
/// # Panics
///
/// Only if one of the small `f64` constants cannot be converted to `F`.
#[allow(dead_code)]
fn asymptotic_digamma<F: Float + FromPrimitive>(x: F) -> F {
    let c = |v: f64| F::from(v).expect("Failed to convert constant to float");
    let inv_x = F::one() / x;
    let t = inv_x * inv_x;
    // Σ B_{2k}/(2k·x^{2k}) for k = 1..=7, in Horner form in t = 1/x².
    let tail = t
        * (c(1.0 / 12.0)
            - t * (c(1.0 / 120.0)
                - t * (c(1.0 / 252.0)
                    - t * (c(1.0 / 240.0)
                        - t * (c(1.0 / 132.0)
                            - t * (c(691.0 / 32760.0) - t * c(1.0 / 12.0)))))));
    x.ln() - c(0.5) * inv_x - tail
}