use super::gamma_functions::{
EPSILON_F64, MAX_ITER, TINY, betainc_scalar, gammainc_scalar, lgamma_scalar,
};
/// Inverse of the regularized lower incomplete gamma function.
///
/// Finds `x >= 0` such that `gammainc_scalar(a, x) == p`.
///
/// # Arguments
/// * `a` - shape parameter; must be positive.
/// * `p` - target probability.
///
/// # Returns
/// * `NaN` when either argument is NaN.
/// * `0.0` when `p <= 0` and `+inf` when `p >= 1`.
/// * `NaN` when `a <= 0` (for `0 < p < 1`).
/// * Otherwise the root, refined with Halley's method from an analytic
///   initial guess.
pub fn gammaincinv_scalar(a: f64, p: f64) -> f64 {
    // Propagate NaN explicitly: every comparison below is false for NaN, and
    // the `max` calls would otherwise turn a NaN into a finite value.
    if a.is_nan() || p.is_nan() {
        return f64::NAN;
    }
    if p <= 0.0 {
        return 0.0;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }
    if a <= 0.0 {
        return f64::NAN;
    }
    let gln = lgamma_scalar(a);
    let mut x = if a > 1.0 {
        // Normal approximation: Gamma(a, 1) has mean a and standard deviation
        // sqrt(a). `z` is the standard-normal quantile of `p`, obtained from the
        // rational upper-tail approximation (Abramowitz & Stegun 26.2.22).
        let pp = if p < 0.5 { p } else { 1.0 - p };
        let t = (-2.0 * pp.ln()).sqrt();
        let z = t - (2.30753 + t * 0.27061) / (1.0 + t * (0.99229 + t * 0.04481));
        let z = if p < 0.5 { -z } else { z };
        (a + z * a.sqrt()).max(0.001)
    } else {
        let t = 1.0 - a * (0.253 + a * 0.12);
        if p < t {
            (p / t).powf(1.0 / a) * a
        } else {
            1.0 - (1.0 - p).ln() + (1.0 - a) * (1.0 - p).ln().abs().ln()
        }
    };
    for _ in 0..MAX_ITER {
        if x <= 0.0 {
            return 0.0;
        }
        let err = gammainc_scalar(a, x) - p;
        if err.abs() < EPSILON_F64 {
            break;
        }
        // ln of the density dP/dx = x^(a-1) e^(-x) / Gamma(a).
        let log_pdf = (a - 1.0) * x.ln() - x - gln;
        if log_pdf < -700.0 {
            break;
        }
        let pdf = log_pdf.exp();
        if pdf.abs() < TINY {
            break;
        }
        // Halley's method: `u` is the Newton step and `curvature` = P''/P'.
        // The correction term is capped at 1 (as in Numerical Recipes'
        // `invgammp`) so the denominator stays >= 0.5. Without the cap it can
        // reach zero or go negative for small `a`, which reverses the step
        // direction and drives the iterate towards 0 instead of the root.
        let u = err / pdf;
        let curvature = (a - 1.0) / x - 1.0;
        let dx = u / (1.0 - 0.5 * (u * curvature).min(1.0));
        x -= dx;
        if x <= 0.0 {
            // Stepped past zero: fall back to half of the previous iterate.
            x = 0.5 * (x + dx);
        }
        if dx.abs() < x * EPSILON_F64 {
            break;
        }
    }
    x.max(0.0)
}
/// Inverse of the regularized incomplete beta function.
///
/// Finds `x` in `[0, 1]` such that `betainc_scalar(a, b, x) == p`.
///
/// # Arguments
/// * `a`, `b` - shape parameters; both must be positive.
/// * `p` - target probability.
///
/// # Returns
/// * `NaN` when any argument is NaN.
/// * `0.0` when `p <= 0` and `1.0` when `p >= 1`.
/// * `NaN` when `a <= 0` or `b <= 0` (for `0 < p < 1`).
/// * Otherwise the root, found by damped Newton iteration from an analytic
///   initial guess, with a bisection fallback when that guess is poor.
///
/// For `p > 0.5` the symmetry `I_x(a, b) = 1 - I_{1-x}(b, a)` is used, so the
/// search below always targets `p <= 0.5`.
pub fn betaincinv_scalar(a: f64, b: f64, p: f64) -> f64 {
    if a.is_nan() || b.is_nan() || p.is_nan() {
        return f64::NAN;
    }
    if p <= 0.0 {
        return 0.0;
    }
    if p >= 1.0 {
        return 1.0;
    }
    if a <= 0.0 || b <= 0.0 {
        return f64::NAN;
    }
    if p > 0.5 {
        return 1.0 - betaincinv_scalar(b, a, 1.0 - p);
    }
    // From here on 0 < p <= 0.5.
    let lnbeta = lgamma_scalar(a) + lgamma_scalar(b) - lgamma_scalar(a + b);
    let mut x = if a >= 1.0 && b >= 1.0 {
        // Abramowitz & Stegun 26.5.22, as used by Numerical Recipes' `invbetai`.
        // `y` is the upper-tail standard-normal deviate of `p` (A&S 26.2.22),
        // which is >= 0 for p <= 0.5. x = a / (a + b e^(2w)) decreases as `y`
        // grows, so a small `p` gives a small `x`. Feeding the lower-tail
        // quantile (-y) here would mirror the guess into the upper tail.
        let t = (-2.0 * p.ln()).sqrt();
        let y = t - (2.30753 + t * 0.27061) / (1.0 + t * (0.99229 + t * 0.04481));
        let lam = (y * y - 3.0) / 6.0;
        let h = 2.0 / (1.0 / (2.0 * a - 1.0) + 1.0 / (2.0 * b - 1.0));
        let skew = 1.0 / (2.0 * b - 1.0) - 1.0 / (2.0 * a - 1.0);
        let w = y * (h + lam).sqrt() / h - skew * (lam + 5.0 / 6.0 - 2.0 / (3.0 * h));
        a / (a + b * (2.0 * w).exp())
    } else {
        // Crude guess scaled from the mean; the bisection fallback below
        // takes over when it is far from the target.
        let mean = a / (a + b);
        if p < 0.5 {
            mean * (2.0 * p).powf(1.0 / a.max(1.0))
        } else {
            mean
        }
    };
    x = x.clamp(1e-10, 1.0 - 1e-10);
    // If the guess is far off, narrow the root down by bisection on [0, 1]
    // first; betainc is increasing in x, so `< p` means the root is above `mid`.
    if (betainc_scalar(a, b, x) - p).abs() > 0.3 {
        let (mut lo, mut hi) = (0.0, 1.0);
        for _ in 0..20 {
            let mid = (lo + hi) / 2.0;
            if betainc_scalar(a, b, mid) < p {
                lo = mid;
            } else {
                hi = mid;
            }
        }
        x = (lo + hi) / 2.0;
    }
    for _ in 0..MAX_ITER {
        if x <= 0.0 || x >= 1.0 {
            break;
        }
        let err = betainc_scalar(a, b, x) - p;
        if err.abs() < EPSILON_F64 {
            break;
        }
        // ln of the density x^(a-1) (1-x)^(b-1) / B(a, b).
        let log_pdf = (a - 1.0) * x.ln() + (b - 1.0) * (1.0 - x).ln() - lnbeta;
        if log_pdf < -700.0 {
            break;
        }
        let pdf = log_pdf.exp();
        if pdf.abs() < TINY {
            break;
        }
        // Newton step, damped so one step covers at most half the distance to
        // either end of the interval.
        let mut dx = err / pdf;
        if dx > x / 2.0 {
            dx = x / 2.0;
        }
        if dx < -(1.0 - x) / 2.0 {
            dx = -(1.0 - x) / 2.0;
        }
        x -= dx;
        x = x.clamp(1e-15, 1.0 - 1e-15);
        if dx.abs() < x * EPSILON_F64 {
            break;
        }
    }
    x.clamp(0.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by the round-trip checks.
    const TOL: f64 = 1e-6;

    /// Panics with a diagnostic message unless `|a - b| < tol`.
    fn assert_close(a: f64, b: f64, tol: f64) {
        let diff = (a - b).abs();
        assert!(
            diff < tol,
            "expected {} to be close to {}, diff = {}",
            a,
            b,
            diff
        );
    }

    /// Feeding the inverse back through `gammainc_scalar` should recover `p`.
    #[test]
    fn test_gammaincinv_roundtrip() {
        let cases: [(f64, f64); 5] = [(1.0, 0.5), (2.0, 0.3), (5.0, 0.7), (10.0, 0.9), (0.5, 0.4)];
        cases.iter().for_each(|&(shape, prob)| {
            let recovered = gammainc_scalar(shape, gammaincinv_scalar(shape, prob));
            assert_close(recovered, prob, TOL);
        });
    }

    /// The probability endpoints map to `0` and `+inf`.
    #[test]
    fn test_gammaincinv_bounds() {
        let at_zero = gammaincinv_scalar(2.0, 0.0);
        assert_close(at_zero, 0.0, 1e-10);
        let at_one = gammaincinv_scalar(2.0, 1.0);
        assert!(at_one.is_infinite());
    }

    /// Feeding the inverse back through `betainc_scalar` should recover `p`.
    #[test]
    fn test_betaincinv_roundtrip() {
        let cases: [(f64, f64, f64); 7] = [
            (1.0, 1.0, 0.5),
            (2.0, 2.0, 0.3),
            (2.0, 5.0, 0.4),
            (5.0, 2.0, 0.6),
            (5.0, 0.5, 0.2),
            (0.5, 5.0, 0.8),
            (10.0, 10.0, 0.5),
        ];
        cases.iter().for_each(|&(alpha, beta, prob)| {
            let recovered = betainc_scalar(alpha, beta, betaincinv_scalar(alpha, beta, prob));
            assert_close(recovered, prob, TOL);
        });
    }

    /// The probability endpoints map to the ends of the unit interval.
    #[test]
    fn test_betaincinv_bounds() {
        let at_zero = betaincinv_scalar(2.0, 3.0, 0.0);
        assert_close(at_zero, 0.0, 1e-10);
        let at_one = betaincinv_scalar(2.0, 3.0, 1.0);
        assert_close(at_one, 1.0, 1e-10);
    }
}