1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
use crate::{digamma, gammaln};
use crate::consts::BERNOULLI_EVEN;
/// Computes the polygamma function `psi(k, x)`, the `k`-th derivative of the digamma function
/// (`k = 0` is the digamma function itself).
///
/// Mathematically:
/// <math><msup><mi>ψ</mi><mo>(</mo><mi>k</mi><mo>)</mo></msup><mo>(</mo><mi>x</mi><mo>)</mo></math>.
///
/// # Algorithm
/// - `k = 0` delegates to `digamma`.
/// - For `k >= 1`, arguments at or below `6 + 4k` are first shifted upward with the recurrence
///   `ψ^(k)(x) = ψ^(k)(x + 1) - (-1)^k k! / x^(k+1)`. The shifted value is then evaluated with
///   the asymptotic expansion in even Bernoulli numbers. Factorials and powers are combined in
///   log space (via `gammaln`), so large `k` does not overflow the intermediate factorial.
///
/// # Special cases (`k >= 1`)
/// - Returns `NaN` for `NaN` inputs.
/// - Returns `NaN` at the poles `x = 0, -1, -2, …` and at `x = -∞`.
/// - Returns `0.0` for `x = +∞`.
///
/// For `k = 0` the special values are those of `digamma` (e.g. `+∞` for `x = +∞`).
///
/// # Examples
/// ```
/// use abax::psi;
///
/// assert!((psi(0, 1.0) + 0.5772156649015329).abs() < 1e-14);
/// assert!((psi(1, 1.0) - 1.6449340668482264).abs() < 1e-14);
/// assert!((psi(2, 1.0) + 2.4041138063191885).abs() < 1e-13);
/// ```
pub fn psi(k: usize, x: f64) -> f64 {
    polygamma(k, x)
}
/// Evaluates `ψ^(n)(x)`, dispatching on the order `n` and the size of `x`.
///
/// - `n == 0` delegates to `digamma` (including its NaN/infinity conventions).
/// - NaN returns NaN.
/// - Non-positive integers (and `-∞`, which compares equal to its own floor) are poles: NaN.
/// - `+∞` returns `0.0`.
/// - `x` above the threshold `6 + 4n` goes straight to the asymptotic expansion; smaller
///   arguments are first shifted upward by the recurrence in `polygamma_at_transition`.
fn polygamma(n: usize, x: f64) -> f64 {
    if n == 0 {
        return digamma(x);
    }
    // Handle NaN explicitly instead of relying on the NaN-to-int cast in the shifting path,
    // which would otherwise push NaN through gammaln/ln/exp and the whole Bernoulli table.
    if x.is_nan() {
        return f64::NAN;
    }
    // Poles of every polygamma function at 0, -1, -2, ... (also catches -∞).
    if x <= 0.0 && x == x.floor() {
        return f64::NAN;
    }
    // Only +∞ can reach here.
    if x.is_infinite() {
        return 0.0;
    }
    // Heuristic threshold 0.4 * digits + 4n, targeting ~15 significant digits.
    let limit = 0.4 * 15.0 + 4.0 * (n as f64);
    if x > limit {
        polygamma_at_infinity(n, x)
    } else {
        polygamma_at_transition(n, x)
    }
}
/// Evaluates `ψ^(n)(x)` (`n >= 1`) for arguments at or below the asymptotic threshold.
///
/// Applies the recurrence `ψ^(n)(x) = ψ^(n)(x + 1) + (-1)^(n+1) n! / x^(n+1)` once per unit
/// step. For moderate `x` this moves `x` to the interval `(target - 1, target]` with
/// `target = 6 + 4n`. The asymptotic value at the shifted point is then added.
///
/// Each `|n! / z^(n+1)|` is formed in log space so that `n!` does not overflow for large `n`.
/// The sign of `z^(n+1)` is tracked separately, so negative non-integer `x` is handled
/// correctly rather than producing `ln` of a negative number.
///
/// NOTE(review): the number of steps is linear in `target - x` and saturates at `i32::MAX`
/// (float-to-int `as` casts saturate). Very negative arguments are therefore slow, and
/// extremely negative ones are not shifted far enough.
fn polygamma_at_transition(n: usize, x: f64) -> f64 {
    let n_f64 = n as f64;
    // Same threshold as the dispatch in `polygamma` (0.4 * 15 digits + 4n).
    let target = (0.4 * 15.0) + (4.0 * n_f64);
    let iterations = (target - x).floor() as i32;

    // log(n!) is loop-invariant.
    let log_n_factorial = gammaln(n_f64 + 1.0);
    // (-1)^(n+1): the overall sign with which n!/z^(n+1) enters ψ^(n)(x).
    let recurrence_sign = if n.is_multiple_of(2) { -1.0 } else { 1.0 };
    // The exponent n + 1 is odd exactly when n is even; only then can z^(n+1) be negative.
    let odd_exponent = n.is_multiple_of(2);

    let mut z = x;
    let mut sum = 0.0;
    for _ in 0..iterations {
        // |n! / z^(n+1)|, computed via logs to avoid overflowing n!.
        let magnitude = (log_n_factorial - (n_f64 + 1.0) * z.abs().ln()).exp();
        // Sign of z^(n+1) (and hence of n!/z^(n+1)) for negative z.
        let power_sign = if odd_exponent && z < 0.0 { -1.0 } else { 1.0 };
        sum += recurrence_sign * power_sign * magnitude;
        z += 1.0;
    }
    sum + polygamma_at_infinity(n, z)
}
/// Asymptotic expansion of `ψ^(n)(x)` for large positive `x` and `n >= 1`:
///
/// `ψ^(n)(x) ~ (-1)^(n+1) [ (n-1)!/x^n + n!/(2 x^(n+1))
///                          + Σ_{k>=1} B_{2k} (2k+n-1)! / ((2k)! x^(2k+n)) ]`
///
/// Summation stops once a term changes the running sum by less than `f64::EPSILON`
/// (relative), or when the `BERNOULLI_EVEN` table is exhausted.
/// NOTE(review): assumes `BERNOULLI_EVEN[k]` holds `B_{2k}` (index 0 = `B_0`) — confirm
/// against `crate::consts`.
fn polygamma_at_infinity(n: usize, x: f64) -> f64 {
    debug_assert!(n >= 1, "asymptotic polygamma expansion requires n >= 1");
    let n_f64 = n as f64;
    let x_sq = x * x;

    // part_term = (n-1)! / x^(n+1), formed via logs so large n does not overflow.
    let mut part_term = (gammaln(n_f64) - (n_f64 + 1.0) * x.ln()).exp();

    // Leading terms: (n-1)!/x^n + n!/(2 x^(n+1)) = part_term * (n + 2x) / 2.
    let mut sum = part_term * (n_f64 + 2.0 * x) / 2.0;

    // Coefficient of B_2: (n+1)! / (2! x^(n+2)).
    part_term *= (n_f64 * (n_f64 + 1.0)) / (2.0 * x);
    for (k, &bernoulli) in BERNOULLI_EVEN.iter().enumerate().skip(1) {
        let term = part_term * bernoulli;
        sum += term;
        if (term / sum).abs() < f64::EPSILON {
            break;
        }
        // Coefficient ratio from B_{2k} to B_{2k+2}:
        // (n+2k)(n+2k+1) / ((2k+1)(2k+2) x^2).
        let k_f64 = k as f64;
        part_term *= (n_f64 + 2.0 * k_f64) * (n_f64 + 2.0 * k_f64 + 1.0);
        part_term /= (2.0 * k_f64 + 1.0) * (2.0 * k_f64 + 2.0) * x_sq;
    }

    // Overall factor (-1)^(n+1): negative for even n. Testing n directly avoids `n - 1`
    // underflowing usize.
    if n.is_multiple_of(2) {
        -sum
    } else {
        sum
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics with a diagnostic unless `actual` lies strictly within `eps` of `expected`.
    fn assert_close(actual: f64, expected: f64, eps: f64) {
        let diff = (actual - expected).abs();
        assert!(
            diff < eps,
            "actual={} expected={} diff={} eps={}",
            actual,
            expected,
            diff,
            eps
        );
    }

    #[test]
    fn test_special_cases() {
        // NaN propagates through the higher-order path.
        assert!(psi(3, f64::NAN).is_nan());

        // Limits at +∞: digamma diverges, higher derivatives vanish.
        assert_eq!(psi(0, f64::INFINITY), f64::INFINITY);
        assert_eq!(psi(3, f64::INFINITY), 0.0);

        // Non-positive integers are poles.
        for &pole in &[0.0, -2.0] {
            assert!(psi(4, pole).is_nan());
        }
    }

    #[test]
    fn test_known_high_order_values() {
        // ψ^(3)(1) = π^4 / 15
        let pi4_over_15 = std::f64::consts::PI.powi(4) / 15.0;
        assert_close(psi(3, 1.0), pi4_over_15, 1e-12);

        // ψ^(4)(1) = -24 * ζ(5)
        let minus_24_zeta5 = -24.88626612344088;
        assert_close(psi(4, 1.0), minus_24_zeta5, 1e-11);
    }

    #[test]
    fn test_recurrence_high_order() {
        // ψ^(3)(x + 1) = ψ^(3)(x) - 3! / x^4
        let x: f64 = 2.75;
        let shifted = psi(3, x + 1.0);
        let via_recurrence = psi(3, x) - 6.0 / x.powi(4);
        assert_close(shifted, via_recurrence, 1e-12);
    }
}