use statrs::distribution::{
ChiSquared, Continuous, ContinuousCDF, FisherSnedecor, Normal, StudentsT,
};
use statrs::function::gamma::gamma_ur;
pub use statrs::function::gamma::{digamma, ln_gamma};
/// Natural logarithm of 2π; appears as the `-0.5 * LN_2PI` term of the standard
/// normal log-density in this module.
pub(crate) const LN_2PI: f64 = 1.837_877_066_409_345_3;
/// Trigamma function ψ₁(x), the second derivative of `ln Γ(x)`.
///
/// Returns `NaN` for a `NaN` argument and `+∞` at the poles (zero, the negative
/// integers, and `-∞`). Other arguments are raised with the recurrence
/// `ψ₁(x) = ψ₁(x + 1) + 1/x²` until they reach 20, after which the asymptotic series
/// `1/x + 1/(2x²) + 1/(6x³) − 1/(30x⁵) + 1/(42x⁷) − 1/(30x⁹)` is evaluated.
pub fn trigamma(x: f64) -> f64 {
    if x.is_nan() {
        return f64::NAN;
    }
    let on_pole = x <= 0.0 && x.floor() == x;
    if on_pole {
        return f64::INFINITY;
    }

    // Upward recurrence: collect the 1/z² terms peeled off while shifting z.
    let mut z = x;
    let mut shifted = 0.0;
    while z < 20.0 {
        shifted += 1.0 / (z * z);
        z += 1.0;
    }

    // Asymptotic expansion in Horner form; the odd-power tail is nested in r².
    let r = 1.0 / z;
    let r2 = r * r;
    let odd_tail = -1.0 / 30.0 + r2 * (1.0 / 42.0 + r2 * (-1.0 / 30.0));
    shifted + r * (1.0 + r * (0.5 + r * (1.0 / 6.0 + r2 * odd_tail)))
}
/// Tetragamma function ψ₂(x), the third derivative of `ln Γ(x)`.
///
/// Returns `NaN` for a `NaN` argument. The argument is raised with
/// `ψ₂(x) = ψ₂(x + 1) − 2/x³` until it reaches 20, then the asymptotic series
/// `−(1/x² + 1/x³ + 1/(2x⁴) − 1/(6x⁶) + 1/(6x⁸))` is applied. Poles (zero and the
/// negative integers) receive no special treatment.
pub fn tetragamma(x: f64) -> f64 {
    if x.is_nan() {
        return f64::NAN;
    }

    let mut z = x;
    let mut shifted = 0.0;
    while z < 20.0 {
        shifted -= 2.0 / (z * z * z);
        z += 1.0;
    }

    let r = 1.0 / z;
    let r2 = r * r;
    let even_tail = -1.0 / 6.0 + r2 * (1.0 / 6.0);
    let series = 1.0 + r * (1.0 + r * (0.5 + r2 * even_tail));
    shifted - r2 * series
}
/// Inverse of `trigamma`: returns `y > 0` with `trigamma(y) ≈ x`.
///
/// Returns `NaN` for `NaN` or negative `x`. For `x > 1e7` it returns `1/√x`, and for
/// `x < 1e-6` (including 0) it returns `1/x`; these use `trigamma(y) ≈ 1/y²` near 0 and
/// `≈ 1/y` for large `y`. Otherwise it applies Newton's method to
/// `1/trigamma(y) = 1/x`, starting at `y = 0.5 + 1/x`. It takes at most 51 steps and
/// stops as soon as the relative step `−Δ/y` falls below `1e-8`.
pub fn trigamma_inverse(x: f64) -> f64 {
    if x.is_nan() || x < 0.0 {
        return f64::NAN;
    }
    if x > 1e7 {
        return 1.0 / x.sqrt();
    }
    if x < 1e-6 {
        return 1.0 / x;
    }

    let mut y = 0.5 + 1.0 / x;
    for _ in 0..51 {
        let psi1 = trigamma(y);
        // Newton step for f(y) = 1/ψ₁(y) − 1/x, where f'(y) = −ψ₂(y)/ψ₁(y)².
        let step = psi1 * (1.0 - psi1 / x) / tetragamma(y);
        y += step;
        if -step / y < 1e-8 {
            break;
        }
    }
    y
}
/// Computes `ln(x) − ψ(x)` (log minus digamma) without cancellation for large `x`.
///
/// Returns `NaN` when `x` is `NaN` or not strictly positive. For `x < 5` the argument
/// is lifted to `x + 5` using `ψ(x + 5) = ψ(x) + Σ_{k=0}^{4} 1/(x + k)`. For `x ≥ 5`
/// the asymptotic expansion `1/(2x) + Σ_k B_{2k} / (2k·x^{2k})` is summed through the
/// `B₁₆` term.
pub fn logmdigamma(x: f64) -> f64 {
    // Asymptotic series for z ≥ 5 with t = 1/z². The Horner chain starts at the
    // innermost (B₁₄, B₁₆) coefficients and works outward to the B₂ term (−1/12).
    fn asymptotic(z: f64) -> f64 {
        let t = 1.0 / (z * z);
        let mut acc = -1.0 / 12.0 + (3617.0 * t) / 8160.0;
        for coef in [
            691.0 / 32760.0,
            -1.0 / 132.0,
            1.0 / 240.0,
            -1.0 / 252.0,
            1.0 / 120.0,
            -1.0 / 12.0,
        ] {
            acc = coef + t * acc;
        }
        1.0 / (2.0 * z) - t * acc
    }

    if x.is_nan() || x <= 0.0 {
        return f64::NAN;
    }
    if x >= 5.0 {
        return asymptotic(x);
    }
    // For 0 < x < 5, x + 5 ≥ 5, so a single lift reaches the asymptotic regime.
    let lifted = x + 5.0;
    let base = (x / lifted).ln() + asymptotic(lifted);
    [0.0, 1.0, 2.0, 3.0, 4.0]
        .iter()
        .fold(base, |sum, &k| sum + 1.0 / (x + k))
}
/// Two-sided p-value `2·P(T ≤ −|t|)` for a Student-t statistic with `df` degrees of
/// freedom.
///
/// Returns `NaN` when `t` is `NaN` or when `df` is `NaN` or not positive, rather than
/// panicking in the distribution constructor.
pub fn t_two_sided_pvalue(t: f64, df: f64) -> f64 {
    // `df <= 0.0` alone is false for NaN, which would let NaN reach `expect` and panic.
    if t.is_nan() || df.is_nan() || df <= 0.0 {
        return f64::NAN;
    }
    let dist = StudentsT::new(0.0, 1.0, df).expect("valid df");
    2.0 * dist.cdf(-t.abs())
}
/// Upper-tail probability `P(T > x)` of a standard Student-t with `df` degrees of
/// freedom, evaluated as `P(T ≤ −x)` using the distribution's symmetry.
///
/// # Panics
/// Panics (via `expect`) if `StudentsT::new` rejects `df`.
pub fn t_sf(x: f64, df: f64) -> f64 {
    StudentsT::new(0.0, 1.0, df)
        .expect("valid df")
        .cdf(-x)
}
/// Inverse survival function: the `x` with `P(T > x) = p` for a standard Student-t with
/// `df` degrees of freedom.
///
/// Uses the symmetry `isf(p) = −ppf(p)`, so an upper-tail `p` goes to the quantile
/// function unchanged. Forming `1 − p` first would discard most of the relative
/// precision of a small `p`. Below about 1.1e-16 it would round to exactly 1 and lose
/// the tail entirely.
///
/// # Panics
/// Panics (via `expect`) if `StudentsT::new` rejects `df`.
pub fn t_isf(p: f64, df: f64) -> f64 {
    let dist = StudentsT::new(0.0, 1.0, df).expect("valid df");
    -dist.inverse_cdf(p)
}
/// Quantile function: the `x` with `P(T ≤ x) = p` for a standard Student-t with `df`
/// degrees of freedom.
///
/// # Panics
/// Panics (via `expect`) if `StudentsT::new` rejects `df`.
pub fn t_ppf(p: f64, df: f64) -> f64 {
    StudentsT::new(0.0, 1.0, df)
        .expect("valid df")
        .inverse_cdf(p)
}
/// Natural log of the standard normal CDF, `ln Φ(z)`.
///
/// For `z > −10` the CDF is evaluated directly and its log taken. Further into the lower
/// tail it uses the asymptotic (Mills-ratio) expansion
/// `Φ(z) = φ(z)/|z| · Σ_k (−1)^k (2k−1)!! / z^{2k}`, entirely in log space, so the
/// result stays finite where `Φ(z)` itself would underflow.
///
/// The series is summed until its terms drop below double precision. A fixed four-term
/// truncation would leave a relative error near `105/z⁸ ≈ 1e-6` at the `z = −10` switch
/// point.
pub(crate) fn ln_norm_cdf(z: f64) -> f64 {
    if z > -10.0 {
        return Normal::new(0.0, 1.0).unwrap().cdf(z).ln();
    }
    let z2 = z * z;
    let log_phi = -0.5 * LN_2PI - 0.5 * z2;
    // Alternating asymptotic series; the error is bounded by the first omitted term.
    // Terms shrink while (2k − 1) < z². That holds for every k ≤ 50 once |z| ≥ 10, so
    // the iteration cap never enters the divergent part of the series.
    let mut term = 1.0;
    let mut series = 1.0;
    let mut odd = 1.0; // 2k − 1 for the k-th term
    for _ in 0..50 {
        term *= -odd / z2;
        series += term;
        if term.abs() <= 1e-17 * series {
            break;
        }
        odd += 2.0;
    }
    log_phi - (-z).ln() + series.ln()
}
pub(crate) fn norm_isf_log(lp: f64) -> f64 {
if lp >= 0.0 {
return f64::NEG_INFINITY;
}
let mut z = if lp > -700.0 {
-Normal::new(0.0, 1.0).unwrap().inverse_cdf(lp.exp())
} else {
(-2.0 * lp).sqrt()
};
for _ in 0..80 {
let lc = ln_norm_cdf(-z);
let lpdf = -0.5 * LN_2PI - 0.5 * z * z;
let h = lc - lp;
let dh = -(lpdf - lc).exp();
let step = h / dh;
z -= step;
if step.abs() < 1e-13 * (1.0 + z.abs()) {
break;
}
}
z
}
/// Natural log of the complete beta function,
/// `ln B(a, b) = ln Γ(a) + ln Γ(b) − ln Γ(a + b)`.
pub fn ln_beta(a: f64, b: f64) -> f64 {
    let numerator = ln_gamma(a) + ln_gamma(b);
    numerator - ln_gamma(a + b)
}
/// Continued-fraction factor of the regularized incomplete beta function, evaluated
/// with the modified Lentz algorithm.
///
/// `I_x(a, b)` is obtained by multiplying the result by `x^a (1 − x)^b / (a·B(a, b))`.
/// Iteration stops once a full step changes the estimate by a factor within `3e-16`
/// of 1, or after 400 steps. Near-zero intermediates are clamped to `1e-30`.
fn betacf(x: f64, a: f64, b: f64) -> f64 {
    const TINY: f64 = 1e-30;
    const EPS: f64 = 3e-16;
    const MAXIT: usize = 400;
    // Lentz guard: a vanishing intermediate is replaced by TINY so divisions stay finite.
    let guard = |v: f64| if v.abs() < TINY { TINY } else { v };

    let (qab, qap, qam) = (a + b, a + 1.0, a - 1.0);
    let mut c = 1.0_f64;
    let mut d = 1.0 / guard(1.0 - qab * x / qap);
    let mut h = d;
    for step in 1..=MAXIT {
        let m = step as f64;
        let m2 = 2.0 * m;

        // Even-indexed coefficient d_{2m} = m(b − m)x / ((a + 2m − 1)(a + 2m)).
        let even = m * (b - m) * x / ((qam + m2) * (a + m2));
        d = 1.0 / guard(1.0 + even * d);
        c = guard(1.0 + even / c);
        h *= d * c;

        // Odd-indexed coefficient d_{2m+1} = −(a + m)(a + b + m)x / ((a + 2m)(a + 2m + 1)).
        let odd = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2));
        d = 1.0 / guard(1.0 + odd * d);
        c = guard(1.0 + odd / c);
        let delta = d * c;
        h *= delta;

        if (delta - 1.0).abs() <= EPS {
            break;
        }
    }
    h
}
/// Computes `ln(1 − e^z)` accurately for `z < 0`.
///
/// For `z > −ln 2` it uses `ln(−expm1(z))`; otherwise it uses `ln1p(−e^z)`. Each form
/// avoids cancellation in its own range.
fn ln_1m_exp(z: f64) -> f64 {
    let near_zero = z > -std::f64::consts::LN_2;
    if !near_zero {
        return (-z.exp()).ln_1p();
    }
    (-z.exp_m1()).ln()
}
/// Regularized incomplete beta function `I_x(a, b)`.
///
/// Returns `0` for `x ≤ 0` and `1` for `x ≥ 1`. Inside (0, 1) it evaluates the
/// continued fraction directly below the usual split point `(a + 1)/(a + b + 2)`. Above
/// that point it uses the reflection `I_x(a, b) = 1 − I_{1−x}(b, a)`.
pub fn betai(x: f64, a: f64, b: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    if x >= 1.0 {
        return 1.0;
    }
    // ln(x^a (1 − x)^b / B(a, b)). `ln_1p(−x)` keeps ln(1 − x) accurate for small x,
    // where forming 1 − x first would round away the low-order bits (amplified by b).
    let lfront = a * x.ln() + b * (-x).ln_1p() - ln_beta(a, b);
    let front = lfront.exp();
    if x < (a + 1.0) / (a + b + 2.0) {
        front * betacf(x, a, b) / a
    } else {
        1.0 - front * betacf(1.0 - x, b, a) / b
    }
}
/// Natural log of the regularized incomplete beta function, `ln I_x(a, b)`.
///
/// Returns `−∞` for `x ≤ 0` and `0` for `x ≥ 1`. Below the split point
/// `(a + 1)/(a + b + 2)` the logarithm is assembled directly, so very small results do
/// not underflow. Above it, the complement is computed in log space and mapped through
/// `ln(1 − e^v)`.
pub fn ln_betai(x: f64, a: f64, b: f64) -> f64 {
    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if x >= 1.0 {
        return 0.0;
    }
    // ln(x^a (1 − x)^b / B(a, b)); `ln_1p(−x)` keeps ln(1 − x) accurate for small x.
    let lfront = a * x.ln() + b * (-x).ln_1p() - ln_beta(a, b);
    if x < (a + 1.0) / (a + b + 2.0) {
        lfront + (betacf(x, a, b) / a).ln()
    } else {
        let ln_complement = lfront + (betacf(1.0 - x, b, a) / b).ln();
        ln_1m_exp(ln_complement)
    }
}
/// Survival function `P(F > x)` of the F distribution with `dfn` and `dfd` degrees of
/// freedom.
///
/// Uses `P(F > x) = I_y(dfd/2, dfn/2)` with `y = dfd / (dfd + dfn·x)`. Returns 1 for
/// `x ≤ 0`.
pub fn f_sf(x: f64, dfn: f64, dfd: f64) -> f64 {
    if x <= 0.0 {
        return 1.0;
    }
    let beta_arg = dfd / (dfd + dfn * x);
    let (half_dfd, half_dfn) = (dfd / 2.0, dfn / 2.0);
    betai(beta_arg, half_dfd, half_dfn)
}
/// Natural log of the F-distribution survival function, `ln P(F > x)`.
///
/// Evaluates `ln I_y(dfd/2, dfn/2)` with `y = dfd / (dfd + dfn·x)` in log space, so
/// far-tail values remain finite. Returns 0 (= ln 1) for `x ≤ 0`.
pub fn f_lsf(x: f64, dfn: f64, dfd: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    let beta_arg = dfd / (dfd + dfn * x);
    let (half_dfd, half_dfn) = (dfd / 2.0, dfn / 2.0);
    ln_betai(beta_arg, half_dfd, half_dfn)
}
/// Density of the F distribution with `dfn` and `dfd` degrees of freedom.
///
/// Returns 0 for `x < 0`. The closed-form density carries an `x^{dfn/2 − 1}` factor, so
/// its value at `x = 0` is the limit. That limit is 0 for `dfn > 2`, exactly 1 for
/// `dfn = 2` (for any `dfd`), and `+∞` for `0 < dfn < 2`. `NaN` is returned for a
/// `dfn` that is `NaN` or not positive.
///
/// # Panics
/// Panics (via `expect`) when `x` is not negative or zero and `FisherSnedecor::new`
/// rejects the degrees of freedom.
pub fn f_pdf(x: f64, dfn: f64, dfd: f64) -> f64 {
    if x < 0.0 {
        return 0.0;
    }
    if x == 0.0 {
        return if dfn > 2.0 {
            0.0
        } else if dfn == 2.0 {
            1.0
        } else if dfn > 0.0 {
            f64::INFINITY
        } else {
            // dfn ≤ 0 or NaN does not describe an F distribution.
            f64::NAN
        };
    }
    FisherSnedecor::new(dfn, dfd).expect("valid df").pdf(x)
}
/// Quantile function of the F distribution: the `x` with `P(F ≤ x) = p`.
///
/// # Panics
/// Panics (via `expect`) if `FisherSnedecor::new` rejects the degrees of freedom.
pub fn f_qf(p: f64, dfn: f64, dfd: f64) -> f64 {
    let dist = FisherSnedecor::new(dfn, dfd).expect("valid df");
    dist.inverse_cdf(p)
}
/// Survival function `P(X > x)` of the chi-squared distribution with `df` degrees of
/// freedom, computed as the regularized upper incomplete gamma `Q(df/2, x/2)`.
///
/// Returns 1 for `x ≤ 0` and 0 for `x = +∞`. The infinite case is handled here rather
/// than passed to statrs' `gamma_ur`, whose accepted domain excludes an infinite `x`.
pub fn chi2_sf(x: f64, df: f64) -> f64 {
    if x <= 0.0 {
        return 1.0;
    }
    if x == f64::INFINITY {
        return 0.0;
    }
    gamma_ur(df / 2.0, x / 2.0)
}
/// Density of the chi-squared distribution with `df` degrees of freedom; 0 for `x < 0`.
///
/// # Panics
/// Panics (via `expect`) for `x ≥ 0` if `ChiSquared::new` rejects `df`.
pub fn chisq_pdf(x: f64, df: f64) -> f64 {
    if x < 0.0 {
        return 0.0;
    }
    let dist = ChiSquared::new(df).expect("valid df");
    dist.pdf(x)
}
/// Quantile function of the chi-squared distribution: the `x` with `P(X ≤ x) = p`.
///
/// # Panics
/// Panics (via `expect`) if `ChiSquared::new` rejects `df`.
pub fn chisq_qf(p: f64, df: f64) -> f64 {
    let dist = ChiSquared::new(df).expect("valid df");
    dist.inverse_cdf(p)
}
/// Power series `Σ_{n≥0} x^n / (a (a+1) ⋯ (a+n))` for the lower incomplete gamma.
///
/// Multiplying the sum by `x^a e^{−x} / Γ(a)` gives `P(a, x)`. Summation stops when a
/// term falls below `3e-16` of the running total, or after 1000 additional terms.
fn gamma_series_sum(a: f64, x: f64) -> f64 {
    const EPS: f64 = 3e-16;
    let mut denom = a;
    let mut term = 1.0 / a;
    let mut total = term;
    for _ in 0..1000 {
        denom += 1.0;
        term *= x / denom;
        total += term;
        let converged = term.abs() < total.abs() * EPS;
        if converged {
            break;
        }
    }
    total
}
/// Continued fraction for the upper incomplete gamma, evaluated with modified Lentz.
///
/// Multiplying the result by `x^a e^{−x} / Γ(a)` gives `Q(a, x)`. Near-zero
/// intermediates are clamped to `1e-300`. Iteration stops when a step's factor is
/// within `3e-16` of 1, or after 999 steps.
fn gamma_cf(a: f64, x: f64) -> f64 {
    const EPS: f64 = 3e-16;
    const FPMIN: f64 = 1e-300;
    // Lentz guard: keeps the next reciprocal finite when an intermediate vanishes.
    let clamp = |v: f64| if v.abs() < FPMIN { FPMIN } else { v };

    let mut b = x + 1.0 - a;
    let mut c = 1.0 / FPMIN;
    let mut d = 1.0 / b;
    let mut h = d;
    for i in 1..1000 {
        let k = i as f64;
        let an = -k * (k - a);
        b += 2.0;
        d = 1.0 / clamp(an * d + b);
        c = clamp(b + an / c);
        let delta = d * c;
        h *= delta;
        if (delta - 1.0).abs() <= EPS {
            break;
        }
    }
    h
}
/// Natural log of the regularized lower incomplete gamma function, `ln P(a, x)`.
///
/// Returns `−∞` for `x ≤ 0` and `0` for `x = +∞`. Otherwise it uses the power series
/// for `x < a + 1` and, beyond that, the complement of the continued fraction via
/// `ln(1 − e^v)`. Everything is assembled in log space.
pub fn ln_gamma_lr(a: f64, x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if x == f64::INFINITY {
        // P(a, ∞) = 1. The prefactor a·ln(x) − x below would be ∞ − ∞ = NaN.
        return 0.0;
    }
    let lpref = a * x.ln() - x - ln_gamma(a);
    if x < a + 1.0 {
        lpref + gamma_series_sum(a, x).ln()
    } else {
        ln_1m_exp(lpref + gamma_cf(a, x).ln())
    }
}
/// Natural log of the regularized upper incomplete gamma function, `ln Q(a, x)`.
///
/// Returns `0` for `x ≤ 0` and `−∞` for `x = +∞`. For `x ≥ a + 1` it uses the continued
/// fraction directly; below that, it uses the complement of the power series via
/// `ln(1 − e^v)`. Deep upper-tail values therefore stay finite in log space.
pub fn ln_gamma_ur(a: f64, x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    if x == f64::INFINITY {
        // Q(a, ∞) = 0. The prefactor a·ln(x) − x below would be ∞ − ∞ = NaN.
        return f64::NEG_INFINITY;
    }
    let lpref = a * x.ln() - x - ln_gamma(a);
    if x < a + 1.0 {
        ln_1m_exp(lpref + gamma_series_sum(a, x).ln())
    } else {
        lpref + gamma_cf(a, x).ln()
    }
}
/// Gauss–Legendre quadrature rule with `n` points on the unit interval [0, 1].
///
/// Returns `(nodes, weights)` with nodes in ascending order, so that
/// `Σ w_i f(x_i) ≈ ∫₀¹ f(x) dx`. The rule is exact for polynomials of degree up to
/// `2n − 1`. Roots of the Legendre polynomial `P_n` on [−1, 1] are found by Newton's
/// method from cosine initial guesses, then mapped to [0, 1]. Only the first half of the
/// roots is solved; the rest follow by symmetry. `n = 0` yields two empty vectors.
pub fn gauss_legendre_01(n: usize) -> (Vec<f64>, Vec<f64>) {
    let order = n as f64;
    let mut nodes = vec![0.0_f64; n];
    let mut weights = vec![0.0_f64; n];
    for i in 0..n.div_ceil(2) {
        // Initial guess for the i-th root of P_n counted down from +1.
        let mut z = (std::f64::consts::PI * (i as f64 + 0.75) / (order + 0.5)).cos();
        let (root, slope) = loop {
            let (p_n, p_prev) = legendre_pair(n, z);
            // P_n'(z) from the identity (z² − 1)·P_n'(z) = n·(z·P_n(z) − P_{n−1}(z)).
            let slope = order * (z * p_n - p_prev) / (z * z - 1.0);
            let next = z - p_n / slope;
            if (next - z).abs() <= 1e-15 {
                break (next, slope);
            }
            z = next;
        };
        // Weight on [−1, 1] is 2 / ((1 − z²)·P_n'(z)²); halved for the unit interval.
        let weight = 1.0 / ((1.0 - root * root) * slope * slope);
        let mirror = n - 1 - i;
        nodes[i] = (1.0 - root) / 2.0;
        nodes[mirror] = (1.0 + root) / 2.0;
        weights[i] = weight;
        weights[mirror] = weight;
    }
    (nodes, weights)
}

/// Evaluates `(P_n(z), P_{n−1}(z))` with the three-term Legendre recurrence
/// `(j + 1)·P_{j+1} = (2j + 1)·z·P_j − j·P_{j−1}`.
fn legendre_pair(n: usize, z: f64) -> (f64, f64) {
    (0..n).fold((1.0, 0.0), |(p, p_prev), j| {
        let jj = j as f64;
        (((2.0 * jj + 1.0) * z * p - jj * p_prev) / (jj + 1.0), p)
    })
}
// Unit tests comparing the numerical helpers against fixed reference values. Per the
// test names, the references come from R and the statmod package; the constants are
// not re-derived here.
#[cfg(test)]
mod tests {
use super::*;
// Mixed tolerance check: |a − b| ≤ tol · (1 + |b|), i.e. relative for large |b| and
// absolute near zero.
fn close(a: f64, b: f64, tol: f64) -> bool {
(a - b).abs() <= tol * (1.0 + b.abs())
}
// Spot values of ψ₁, including ψ₁(1) = π²/6 and ψ₁(1/2) = π²/2.
#[test]
fn trigamma_matches_reference() {
assert!(close(trigamma(1.0), 1.6449340668482264, 1e-12));
assert!(close(trigamma(2.5), 0.49035775610023737, 1e-12));
assert!(close(trigamma(10.0), 0.10516633568168574, 1e-12));
assert!(close(trigamma(0.5), 4.934802200544679, 1e-12));
}
// Spot values of ψ₂, including ψ₂(1) = −2ζ(3).
#[test]
fn tetragamma_matches_reference() {
assert!(close(tetragamma(1.0), -2.404113806319188, 1e-10));
assert!(close(tetragamma(2.5), -0.236_204_051_641_727_4, 1e-10));
assert!(close(tetragamma(10.0), -0.011049834970802069, 1e-10));
}
// trigamma_inverse must undo trigamma to ~1e-7 across small and large arguments.
#[test]
fn trigamma_inverse_roundtrip() {
for &y in &[0.3_f64, 1.0, 2.7, 5.0, 20.0] {
let x = trigamma(y);
let yy = trigamma_inverse(x);
assert!(close(yy, y, 1e-7), "y={} got={}", y, yy);
}
}
// Cases straddle the x = 5 switch between the recurrence and the asymptotic series
// (4.9999 vs 5.0), cross-check against ln(x) − digamma(x), and cover the NaN guards.
#[test]
fn logmdigamma_matches_statmod() {
let cases = [
(0.25_f64, 2.841_159_172_256_119),
(0.5, 1.270_362_845_461_365),
(1.0, 0.577_215_664_901_508_4),
(2.0, 0.270_362_845_461_476_4),
(3.0, 0.17582795356964234),
(4.0, 0.1301766926880901),
(4.9999, 0.10332237634182319),
(5.0, 0.10332024400169718),
(7.5, 0.068_145_536_296_177_49),
(50.0, 0.010033332000253862),
(500.0, 0.0010003333332000003),
(4999.0, 0.0001000233386678536),
];
for (x, want) in cases {
assert!(
close(logmdigamma(x), want, 1e-13),
"x={x} got={}",
logmdigamma(x)
);
}
assert!(close(logmdigamma(3.0), 3.0_f64.ln() - digamma(3.0), 1e-12));
assert!(logmdigamma(0.0).is_nan());
assert!(logmdigamma(-1.0).is_nan());
assert!(logmdigamma(f64::NAN).is_nan());
}
// Two direct values plus the reflection identity I_x(a, b) + I_{1−x}(b, a) = 1.
#[test]
fn betai_matches_r_pbeta() {
assert!(close(betai(0.3, 2.5, 4.0), 0.352_197_585_906_767_14, 1e-13));
assert!(close(betai(0.85, 3.0, 2.0), 0.890_481_249_999_999_9, 1e-13));
let s = betai(0.62, 3.3, 7.1) + betai(0.38, 7.1, 3.3);
assert!(close(s, 1.0, 1e-13), "reflection sum={}", s);
}
// F survival probabilities in the bulk and in the far tail.
#[test]
fn f_sf_matches_r() {
assert!(close(
f_sf(5.0, 3.0, 10.0),
0.022_613_922_751_096_315,
1e-12
));
assert!(close(
f_sf(50.0, 3.0, 10.0),
2.513_470_418_384_048_3e-6,
1e-12
));
}
// Log survival deep in the tail, and agreement with ln(f_sf) in the bulk.
#[test]
fn f_lsf_matches_r_in_tail() {
assert!(close(
f_lsf(50.0, 3.0, 10.0),
-12.893_846_122_976_313,
1e-11
));
assert!(close(
f_lsf(200.0, 4.0, 8.0),
-16.858_996_483_121_434,
1e-11
));
assert!(close(f_lsf(1e4, 2.0, 3.0), -13.207_537_878_928_715, 1e-11));
assert!(close(
f_lsf(5.0, 3.0, 10.0),
f_sf(5.0, 3.0, 10.0).ln(),
1e-12
));
}
// Monomial moments up to x^5 must be exact; the extreme nodes and first weight are
// also pinned for a 128-point rule.
#[test]
fn gauss_legendre_quadrature() {
let (nodes, weights) = gauss_legendre_01(128);
let sw: f64 = weights.iter().sum();
let swx: f64 = nodes.iter().zip(&weights).map(|(x, w)| w * x).sum();
let swx2: f64 = nodes.iter().zip(&weights).map(|(x, w)| w * x * x).sum();
let swx5: f64 = nodes.iter().zip(&weights).map(|(x, w)| w * x.powi(5)).sum();
assert!(close(sw, 1.0, 1e-13), "sum w={}", sw);
assert!(close(swx, 0.5, 1e-13), "sum wx={}", swx);
assert!(close(swx2, 1.0 / 3.0, 1e-13), "sum wx2={}", swx2);
assert!(close(swx5, 1.0 / 6.0, 1e-13), "sum wx5={}", swx5);
assert!(close(nodes[0], 8.755_602_643_395_477e-5, 1e-9));
assert!(close(weights[0], 0.000_224_690_480_146_078_16, 1e-9));
assert!(close(nodes[127], 0.999_912_443_973_565_8, 1e-12));
}
// Log lower/upper regularized gamma at points on both sides of the x < a + 1 branch
// switch, compared by relative error.
#[test]
fn ln_gamma_incomplete_matches_r_pgamma() {
let x = [0.5, 2.0, 8.0, 0.01, 50.0, 100.0, 3.0];
let shape = [2.0, 2.0, 5.0, 5.0, 2.0, 3.0, 3.0];
let lower = [
-2.40568139136037,
-0.520885807664344,
-0.104952155145049,
-27.821675013441,
-9.83662422461599e-21,
-1.89761075536825e-40,
-0.550242496777221,
];
let upper = [
-0.0945348918918356,
-0.90138771133189,
-2.3062678611973,
-8.26418564180996e-13,
-46.0681743672757,
-91.4628081220771,
-0.859933836503729,
];
let rel = |a: f64, b: f64| ((a - b) / b).abs();
for i in 0..x.len() {
assert!(
rel(ln_gamma_lr(shape[i], x[i]), lower[i]) < 1e-9,
"lower i={i}: got {}",
ln_gamma_lr(shape[i], x[i])
);
assert!(
rel(ln_gamma_ur(shape[i], x[i]), upper[i]) < 1e-9,
"upper i={i}: got {}",
ln_gamma_ur(shape[i], x[i])
);
}
}
// Chi-squared survival probabilities from the bulk down to ~1e-248.
#[test]
fn chi2_sf_matches_r_into_tail() {
let rel = |a: f64, b: f64| ((a - b) / b).abs();
assert!(rel(chi2_sf(20.0, 13.0), 0.095_210_256_078_091_53) < 1e-12);
assert!(rel(chi2_sf(100.0, 13.0), 1.659_026_080_708_588_6e-15) < 1e-12);
assert!(rel(chi2_sf(500.0, 13.0), 1.463_698_528_480_679_2e-98) < 1e-11);
assert!(rel(chi2_sf(1200.0, 13.0), 1.769_766_357_320_996_9e-248) < 1e-10);
}
}