use crate::special::{f_lsf, ln_gamma_lr, ln_gamma_ur, ln_norm_cdf, norm_isf_log};
use statrs::distribution::{Continuous, ContinuousCDF, StudentsT};
/// Method used by [`zscore_t`] to convert a Student-t statistic into a
/// standard-normal z-score.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ZscoreTMethod {
    /// Exact conversion: the normal deviate whose tail probability equals the
    /// Student-t tail probability of the statistic.
    Exact,
    /// Bailey's closed-form approximation.
    Bailey,
    /// Hill's closed-form approximation.
    Hill,
    /// Wallace's closed-form approximation.
    Wallace,
}
/// Sign of `x` as a float: `1.0` for positive, `-1.0` for negative, and `0.0`
/// for zero (of either sign) and for NaN.
///
/// Unlike `f64::signum`, zeros map to `0.0` and NaN does not propagate.
fn rsign(x: f64) -> f64 {
    match x.partial_cmp(&0.0) {
        Some(std::cmp::Ordering::Greater) => 1.0,
        Some(std::cmp::Ordering::Less) => -1.0,
        // Equal (+0.0 / -0.0) or unordered (NaN).
        _ => 0.0,
    }
}
pub fn zscore_t(x: f64, df: f64, method: ZscoreTMethod) -> f64 {
match method {
ZscoreTMethod::Exact => zscore_t_quantile(x, df),
ZscoreTMethod::Bailey => zscore_t_bailey(x, df),
ZscoreTMethod::Hill => zscore_t_hill(x, df),
ZscoreTMethod::Wallace => zscore_t_wallace(x, df),
}
}
/// Wallace's closed-form approximation to the normal deviate of a Student-t
/// statistic `x` with `df` degrees of freedom:
/// `z = (df + 1/8) / (df + 3/8) * sqrt(df * ln(1 + x^2 / df))`, signed like `x`.
fn zscore_t_wallace(x: f64, df: f64) -> f64 {
    let factor = (df + 0.125) / (df + 0.375);
    let root = (df * (x / df * x).ln_1p()).sqrt();
    factor * root * rsign(x)
}
/// Bailey's closed-form approximation to the normal deviate of a Student-t
/// statistic `x` with `df` degrees of freedom:
/// `z = (df + 1/8) / (df + 9/8) * sqrt((df + 19/12) * ln(1 + x^2 / (df + 1/12)))`,
/// signed like `x`.
fn zscore_t_bailey(x: f64, df: f64) -> f64 {
    let factor = (df + 0.125) / (df + 1.125);
    let log_term = (x / (df + 1.0 / 12.0) * x).ln_1p();
    let root = ((df + 19.0 / 12.0) * log_term).sqrt();
    factor * root * rsign(x)
}
fn zscore_t_hill(x: f64, df: f64) -> f64 {
let a = df - 0.5;
let b = 48.0 * a * a;
let mut z = a * (x / df * x).ln_1p();
z = (((((-0.4 * z - 3.3) * z - 24.0) * z - 85.5) / (0.8 * z * z + 100.0 + b) + z + 3.0) / b
+ 1.0)
* z.sqrt();
z * rsign(x)
}
/// Exact z-score of a Student-t statistic: the normal deviate whose upper-tail
/// log probability equals `ln P(T > |x|)`, given the sign of `x`.
fn zscore_t_quantile(x: f64, df: f64) -> f64 {
    let ln_tail = ln_t_sf_pos(x.abs(), df);
    let magnitude = norm_isf_log(ln_tail);
    magnitude * rsign(x)
}
/// Natural log of the Student-t upper tail `P(T > t)` for `t >= 0`.
///
/// Relies on `P(T > t) = P(F(1, df) > t^2) / 2`, with `f_lsf(x, d1, d2)` taken
/// to be the log upper tail of the F(d1, d2) distribution.
///
/// NOTE(review): `t * t` overflows to infinity once `t` exceeds ~1.34e154.
fn ln_t_sf_pos(t: f64, df: f64) -> f64 {
    let t_squared = t * t;
    f_lsf(t_squared, 1.0, df) - std::f64::consts::LN_2
}
/// Inverse of the exact [`zscore_t`] conversion: maps a standard-normal
/// deviate `z` to the Student-t statistic with `df` degrees of freedom that
/// has the same tail probability, carrying the sign of `z`.
pub fn t_zscore(z: f64, df: f64) -> f64 {
    let ln_tail = ln_norm_cdf(-z.abs());
    let magnitude = t_isf_log(ln_tail, df);
    magnitude * rsign(z)
}
/// Inverts the Student-t log upper tail: returns `t >= 0` such that
/// `ln P(T > t) == lp` for `T ~ t(df)`.
///
/// Special cases:
/// * `lp >= ln(1/2)` returns `0.0` (the tail is at least one half);
/// * `lp == -inf` returns `+inf`, as does a root that overflows `f64`;
/// * a NaN `lp`, or a `df` rejected by [`StudentsT::new`], returns NaN.
///
/// A starting point is taken from the `statrs` quantile function when it is
/// usable, otherwise from a doubling bracket. It is then refined with Newton's
/// method on `h(t) = ln P(T > t) - lp`, whose derivative is
/// `-pdf(t) / P(T > t)`.
fn t_isf_log(lp: f64, df: f64) -> f64 {
    if lp >= -std::f64::consts::LN_2 {
        return 0.0;
    }
    if lp.is_nan() {
        return f64::NAN;
    }
    if lp == f64::NEG_INFINITY {
        // Zero tail probability is only reached at t = +inf.
        return f64::INFINITY;
    }
    let dist = match StudentsT::new(0.0, 1.0, df) {
        Ok(dist) => dist,
        Err(_) => return f64::NAN,
    };

    // Library quantile as the seed. Once `exp(lp)` drops below ~1e-16,
    // `1 - exp(lp)` rounds to 1.0 and the quantile is infinite; the bracket
    // below takes over in that case.
    let seed = if lp > -700.0 {
        let g = dist.inverse_cdf(1.0 - lp.exp());
        if g.is_finite() && g > 0.0 {
            Some(g)
        } else {
            None
        }
    } else {
        None
    };

    let mut t = match seed {
        Some(g) => g,
        None => {
            // Smallest power of two at or beyond the root. This puts the seed
            // within a factor of two, so Newton does not have to crawl out of
            // a heavy tail from t = 1.
            let mut hi = 1.0_f64;
            while hi.is_finite() && ln_t_sf_pos(hi, df) > lp {
                hi *= 2.0;
            }
            if hi.is_infinite() {
                return f64::INFINITY;
            }
            hi
        }
    };
    // NOTE(review): `ln_t_sf_pos` squares `t`, so roots beyond ~1.34e154
    // cannot be resolved accurately here.

    for _ in 0..100 {
        let lsf = ln_t_sf_pos(t, df);
        // d/dt ln P(T > t) = -pdf(t) / P(T > t). Form the ratio in log space:
        // `pdf(t)` itself underflows to 0 far in the tail, which would make
        // the step infinite or NaN.
        let dlsf = -(dist.ln_pdf(t) - lsf).exp();
        let step = (lsf - lp) / dlsf;
        let mut next = t - step;
        if next <= 0.0 {
            // Stay on the positive half-line; the root is > 0 here.
            next = 0.5 * t;
        }
        let converged = (next - t).abs() < 1e-13 * (1.0 + next);
        t = next;
        if converged {
            break;
        }
    }
    t
}
/// Z-score of a statistic given the natural logs of its lower-tail
/// (`log_lower`) and upper-tail (`log_upper`) probabilities.
///
/// The smaller tail carries the most information in log space, so it is the
/// one used. A strictly smaller upper tail gives a z via `norm_isf_log`.
/// Otherwise, including for NaN inputs, the negated lower-tail deviate is
/// returned, so the result is non-positive.
pub fn zscore_from_log_tails(log_lower: f64, log_upper: f64) -> f64 {
    if log_upper < log_lower {
        return norm_isf_log(log_upper);
    }
    -norm_isf_log(log_lower)
}
/// Z-score of the value `q` under a gamma distribution with the given `shape`
/// and `scale`: the standard-normal deviate with the same tail probability.
///
/// The upper tail is used when `q` exceeds the mean `shape * scale`, and the
/// lower tail otherwise. `ln_gamma_ur` / `ln_gamma_lr` are taken to be the log
/// upper / lower regularized incomplete gamma functions.
pub fn zscore_gamma(q: f64, shape: f64, scale: f64) -> f64 {
    let standardized = q / scale;
    let mean = shape * scale;
    if q > mean {
        norm_isf_log(ln_gamma_ur(shape, standardized))
    } else {
        -norm_isf_log(ln_gamma_lr(shape, standardized))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared t statistics and degrees of freedom for the `zscore_t` checks.
    const X: [f64; 8] = [0.5, 1.2, -2.0, 3.5, 8.0, -12.0, 0.0, 25.0];
    const DF: [f64; 8] = [5.0, 10.0, 3.0, 20.0, 50.0, 8.0, 7.0, 4.0];

    /// Mixed absolute/relative tolerance: `|a - b| <= tol * (1 + |b|)`.
    fn close(a: f64, b: f64, tol: f64) -> bool {
        let diff = (a - b).abs();
        diff <= tol + tol * b.abs()
    }

    /// Asserts `zscore_t(X[i], DF[i], method)` is within `tol` of `want[i]`.
    fn check(method: ZscoreTMethod, want: &[f64], tol: f64) {
        for (i, (&x, &df)) in X.iter().zip(DF.iter()).enumerate() {
            let got = zscore_t(x, df, method);
            let expected = want[i];
            assert!(
                close(got, expected, tol),
                "{:?} i={i}: got {got}, want {}",
                method,
                expected
            );
        }
    }

    #[test]
    fn zscore_t_exact_matches_r() {
        const WANT: [f64; 8] = [
            0.470078587598904,
            1.1316150720248,
            -1.47830571549245,
            3.05439849233135,
            6.3895913409388,
            -4.73936721150427,
            0.0,
            4.32580582047361,
        ];
        check(ZscoreTMethod::Exact, &WANT, 1e-7);
    }

    #[test]
    fn zscore_t_approx_methods_match_r() {
        let cases: [(ZscoreTMethod, [f64; 8]); 3] = [
            (
                ZscoreTMethod::Bailey,
                [
                    0.47040618216154,
                    1.13171290668369,
                    -1.47913851858807,
                    3.05419638749544,
                    6.38909334210128,
                    -4.72198983061799,
                    0.0,
                    4.26852597651621,
                ],
            ),
            (
                ZscoreTMethod::Hill,
                [
                    0.470078899765056,
                    1.13161507280532,
                    -1.47838714521046,
                    3.0543984920903,
                    6.38959133924561,
                    -4.73937919940718,
                    0.0,
                    4.32687312869554,
                ],
            ),
            (
                ZscoreTMethod::Wallace,
                [
                    0.470941044862427,
                    1.13192574792552,
                    -1.4762330589056,
                    3.05330279152777,
                    6.38754780925417,
                    -4.70852441428392,
                    0.0,
                    4.24090262987019,
                ],
            ),
        ];
        for (method, want) in &cases {
            check(*method, want, 1e-9);
        }
    }

    #[test]
    fn t_zscore_matches_r() {
        // (z, df, expected t)
        let cases = [
            (0.3, 5.0, 0.316825549865956),
            (1.0, 10.0, 1.05256241525734),
            (-1.5, 3.0, -2.04335742531369),
            (2.5, 20.0, 2.74732044128031),
            (4.0, 50.0, 4.36716311562917),
            (-6.0, 8.0, -29.3392145892267),
            (0.0, 7.0, 0.0),
            (3.1, 4.0, 7.23627391940524),
        ];
        for (i, &(z, df, expected)) in cases.iter().enumerate() {
            let got = t_zscore(z, df);
            assert!(
                close(got, expected, 1e-7),
                "i={i}: got {got}, want {}",
                expected
            );
        }
    }

    #[test]
    fn zscore_t_round_trips_through_t_zscore() {
        let cases: &[(f64, f64)] = &[(1.3, 9.0), (-2.7, 14.0), (4.5, 6.0)];
        for &(x, df) in cases {
            let back = t_zscore(zscore_t(x, df, ZscoreTMethod::Exact), df);
            assert!(close(back, x, 1e-7), "x={x}: round-trip {back}");
        }
    }

    #[test]
    fn zscore_gamma_matches_r() {
        // (q, shape, scale, expected z)
        let cases = [
            (0.5, 2.0, 1.0, -1.33949979584716),
            (9.0, 5.0, 1.0, 1.59852006747439),
            (10.0, 3.0, 2.0, 1.15204145464159),
            (2.0, 1.0, 4.0, -0.270288020738736),
        ];
        for (i, &(q, shape, scale, expected)) in cases.iter().enumerate() {
            let got = zscore_gamma(q, shape, scale);
            assert!(
                close(got, expected, 1e-9),
                "i={i}: got {got}, want {}",
                expected
            );
        }
    }

    #[test]
    fn zscore_generic_recovers_normal() {
        for qi in [-2.5_f64, -0.3, 0.8, 3.1].iter().copied() {
            let got = zscore_from_log_tails(
                crate::special::ln_norm_cdf(qi),
                crate::special::ln_norm_cdf(-qi),
            );
            assert!(close(got, qi, 1e-9), "q={qi}: got {got}");
        }
    }
}