use crate::{digamma, gammaln};
use crate::consts::BERNOULLI_EVEN;
/// Polygamma function of order `k`: the `k`-th derivative of the digamma function at `x`.
///
/// `psi(0, x)` is digamma itself. For `k >= 1`, the non-positive integers (the poles)
/// and `-inf` yield `NaN`, while `+inf` yields `0.0`.
pub fn psi(k: usize, x: f64) -> f64 {
    polygamma(k, x)
}
/// Evaluates psi^(n)(x), the `n`-th derivative of the digamma function.
///
/// Dispatch:
/// - `n == 0` is delegated to [`digamma`].
/// - Non-positive integers (the poles, including `-inf`) return `NaN`.
/// - `+inf` returns `0.0`.
/// - Above [`asymptotic_threshold`] the asymptotic series is used directly;
///   otherwise the argument is first shifted upward by recurrence.
fn polygamma(n: usize, x: f64) -> f64 {
    if n == 0 {
        return digamma(x);
    }
    // `-inf` also lands here, since `(-inf).floor() == -inf`.
    if x <= 0.0 && x == x.floor() {
        return f64::NAN;
    }
    // Only `+inf` can reach this check.
    if x.is_infinite() {
        return 0.0;
    }
    // NaN fails this comparison and is propagated through the transition path.
    if x > asymptotic_threshold(n) {
        polygamma_at_infinity(n, x)
    } else {
        polygamma_at_transition(n, x)
    }
}

/// Argument above which the asymptotic expansion is used without any shifting.
///
/// Shared by [`polygamma`] (dispatch) and [`polygamma_at_transition`] (how far to
/// shift) so the two cannot drift apart. The value is `0.4 * 15.0 + 4n`, i.e. `6 + 4n`.
/// The origin of these tuning constants is not documented here.
fn asymptotic_threshold(n: usize) -> f64 {
    0.4 * 15.0 + 4.0 * (n as f64)
}

/// Evaluates psi^(n)(x) for `n >= 1` and `x` at or below [`asymptotic_threshold`],
/// including negative non-integer `x`.
///
/// Applies the recurrence psi^(n)(z) = psi^(n)(z + 1) - (-1)^n n! / z^(n + 1) to step
/// the argument up by whole units until it lies in `(threshold - 1, threshold]`,
/// accumulating the correction terms. It then finishes with [`polygamma_at_infinity`].
fn polygamma_at_transition(n: usize, x: f64) -> f64 {
    let n_f64 = n as f64;
    // ln(n!) is the same for every step, so compute it once.
    let ln_n_factorial = gammaln(n_f64 + 1.0);
    // A NaN `x` gives 0 iterations, because float-to-int `as` casts map NaN to 0.
    // NOTE(review): for very negative `x` this loop runs about |x| times. The
    // saturating cast caps it at i32::MAX, so `z` may still be negative afterwards.
    let iterations = (asymptotic_threshold(n) - x).floor() as i32;
    let mut z = x;
    let mut sum = 0.0;
    for _ in 0..iterations {
        // |n! / z^(n + 1)| computed in log space. Using |z| keeps the log finite for
        // negative non-integer z, where `z.ln()` would be NaN and poison the sum.
        let magnitude = (ln_n_factorial - (n_f64 + 1.0) * z.abs().ln()).exp();
        // Each step adds -(-1)^n * n! / z^(n + 1). For odd n, z^(n + 1) is an even
        // power, so the term is positive. For even n, the term's sign is -sign(z).
        if n % 2 == 1 || z < 0.0 {
            sum += magnitude;
        } else {
            sum -= magnitude;
        }
        z += 1.0;
    }
    sum + polygamma_at_infinity(n, z)
}
/// Asymptotic expansion of psi^(n)(x) for large positive `x`, intended for `n >= 1`:
///
/// psi^(n)(x) ~ (-1)^(n+1) [ (n-1)!/x^n + n!/(2 x^(n+1))
///                           + sum_{k>=1} B_{2k} (2k+n-1)! / ((2k)! x^(2k+n)) ]
///
/// `BERNOULLI_EVEN[k]` is presumably B_{2k}; index 0 (B_0) is skipped. Verify this
/// against `consts`. The series stops once a term is below `f64::EPSILON` relative to
/// the running sum, or when the Bernoulli table is exhausted.
fn polygamma_at_infinity(n: usize, x: f64) -> f64 {
    let n_f64 = n as f64;
    let x_sq = x * x;
    // (n-1)! / x^(n+1), computed in log space to delay overflow.
    let mut part_term = (gammaln(n_f64) - (n_f64 + 1.0) * x.ln()).exp();
    // Leading terms: (n-1)!/x^n + n!/(2 x^(n+1)) == part_term * (n + 2x) / 2.
    let mut sum = part_term * (n_f64 + 2.0 * x) / 2.0;
    // Advance to (n+1)! / (2! x^(n+2)), the coefficient multiplying B_2.
    part_term *= (n_f64 * (n_f64 + 1.0)) / (2.0 * x);
    for (k, &b2k) in BERNOULLI_EVEN.iter().enumerate().skip(1) {
        let term = part_term * b2k;
        sum += term;
        if (term / sum).abs() < f64::EPSILON {
            break;
        }
        // Ratio of successive coefficients: (n+2k)(n+2k+1) / ((2k+1)(2k+2) x^2).
        let k_f64 = k as f64;
        part_term *= (n_f64 + 2.0 * k_f64) * (n_f64 + 2.0 * k_f64 + 1.0);
        part_term /= (2.0 * k_f64 + 1.0) * (2.0 * k_f64 + 2.0) * x_sq;
    }
    // Overall sign (-1)^(n+1), which is negative for even n. Written as `n % 2` rather
    // than `(n - 1) % 2` so that `n == 0` cannot underflow the `usize`.
    if n % 2 == 0 {
        -sum
    } else {
        sum
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless `actual` lies strictly within `eps` of `expected`.
    fn assert_approx_eq(actual: f64, expected: f64, eps: f64) {
        let gap = (actual - expected).abs();
        assert!(
            gap < eps,
            "actual={} expected={} diff={} eps={}",
            actual,
            expected,
            gap,
            eps
        );
    }

    /// Inputs with exact answers: NaN propagation, poles at non-positive integers,
    /// and the limits at +infinity.
    #[test]
    fn test_special_cases() {
        let nan_cases = [(3, f64::NAN), (4, 0.0), (4, -2.0)];
        for &(order, x) in nan_cases.iter() {
            assert!(psi(order, x).is_nan());
        }
        assert_eq!(psi(0, f64::INFINITY), f64::INFINITY);
        assert_eq!(psi(3, f64::INFINITY), 0.0);
    }

    /// psi^(3)(1) = 6 zeta(4) = pi^4 / 15 and psi^(4)(1) = -24 zeta(5).
    #[test]
    fn test_known_high_order_values() {
        let pi_fourth = std::f64::consts::PI.powi(4);
        assert_approx_eq(psi(3, 1.0), pi_fourth / 15.0, 1e-12);
        assert_approx_eq(psi(4, 1.0), -24.88626612344088, 1e-11);
    }

    /// Checks psi^(3)(x + 1) = psi^(3)(x) - 3! / x^4 at a non-integer point.
    #[test]
    fn test_recurrence_high_order() {
        let x: f64 = 2.75;
        let shifted = psi(3, x + 1.0);
        let expected = psi(3, x) - 6.0 / x.powi(4);
        assert_approx_eq(shifted, expected, 1e-12);
    }
}