use crate::function::gamma;
use crate::prec;
use std::f64;
/// Errors reported by the checked beta-function routines in this module
/// when an argument violates a precondition.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum BetaFuncError {
/// The parameter `a` is zero or less than zero.
ANotGreaterThanZero,
/// The parameter `b` is zero or less than zero.
BNotGreaterThanZero,
/// The argument `x` is not in `[0, 1]`.
XOutOfRange,
}
impl std::fmt::Display for BetaFuncError {
    /// Writes the fixed, human-readable description of the violated
    /// precondition to the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Select the message first, then emit it with a single write.
        let message = match *self {
            BetaFuncError::ANotGreaterThanZero => "a is zero or less than zero",
            BetaFuncError::BNotGreaterThanZero => "b is zero or less than zero",
            BetaFuncError::XOutOfRange => "x is not in [0, 1]",
        };
        f.write_str(message)
    }
}
// Marker impl with no overridden methods, so `BetaFuncError` can be used
// wherever a `std::error::Error` is expected.
impl std::error::Error for BetaFuncError {}
/// Computes the natural logarithm of the beta function for the
/// parameters `a` and `b`.
///
/// This is the panicking counterpart of [`checked_ln_beta`].
///
/// # Panics
///
/// Panics if `a <= 0.0` or `b <= 0.0`, because the error returned by
/// [`checked_ln_beta`] is unwrapped.
pub fn ln_beta(a: f64, b: f64) -> f64 {
checked_ln_beta(a, b).unwrap()
}
/// Computes the natural logarithm of the beta function for the
/// parameters `a` and `b`, evaluated as
/// `ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b)`.
///
/// # Errors
///
/// * [`BetaFuncError::ANotGreaterThanZero`] if `a <= 0.0`
/// * [`BetaFuncError::BNotGreaterThanZero`] if `b <= 0.0`
///
/// `a` is validated before `b`, so when both are invalid the error
/// refers to `a`.
pub fn checked_ln_beta(a: f64, b: f64) -> Result<f64, BetaFuncError> {
    if a <= 0.0 {
        return Err(BetaFuncError::ANotGreaterThanZero);
    }
    if b <= 0.0 {
        return Err(BetaFuncError::BNotGreaterThanZero);
    }

    let ln_numerator = gamma::ln_gamma(a) + gamma::ln_gamma(b);
    Ok(ln_numerator - gamma::ln_gamma(a + b))
}
/// Computes the beta function for the parameters `a` and `b`.
///
/// This is the panicking counterpart of [`checked_beta`].
///
/// # Panics
///
/// Panics if `a <= 0.0` or `b <= 0.0`, because the error returned by
/// [`checked_beta`] is unwrapped.
pub fn beta(a: f64, b: f64) -> f64 {
checked_beta(a, b).unwrap()
}
/// Computes the beta function for the parameters `a` and `b` by
/// exponentiating the result of [`checked_ln_beta`].
///
/// # Errors
///
/// Forwards the error from [`checked_ln_beta`]: `a <= 0.0` or `b <= 0.0`.
pub fn checked_beta(a: f64, b: f64) -> Result<f64, BetaFuncError> {
    let log_value = checked_ln_beta(a, b)?;
    Ok(log_value.exp())
}
/// Computes the (unregularized) lower incomplete beta function for the
/// parameters `a`, `b` and the upper integration limit `x`.
///
/// This is the panicking counterpart of [`checked_beta_inc`].
///
/// # Panics
///
/// Panics if `a <= 0.0`, `b <= 0.0`, or `x` is not in `[0, 1]`, because
/// the error returned by [`checked_beta_inc`] is unwrapped.
pub fn beta_inc(a: f64, b: f64, x: f64) -> f64 {
checked_beta_inc(a, b, x).unwrap()
}
/// Computes the (unregularized) lower incomplete beta function as the
/// product of the regularized value from [`checked_beta_reg`] and the
/// complete beta function from [`checked_beta`].
///
/// # Errors
///
/// Returns the first error encountered: `a <= 0.0`, `b <= 0.0`, or `x`
/// outside `[0, 1]` (all of which are detected by [`checked_beta_reg`]).
pub fn checked_beta_inc(a: f64, b: f64, x: f64) -> Result<f64, BetaFuncError> {
    // The regularized value is computed first so argument errors surface
    // in the same order as before.
    let regularized = checked_beta_reg(a, b, x)?;
    let complete = checked_beta(a, b)?;
    Ok(regularized * complete)
}
/// Computes the regularized lower incomplete beta function for the
/// parameters `a`, `b` and the upper integration limit `x`.
///
/// This is the panicking counterpart of [`checked_beta_reg`].
///
/// # Panics
///
/// Panics if `a <= 0.0`, `b <= 0.0`, or `x` is not in `[0, 1]`, because
/// the error returned by [`checked_beta_reg`] is unwrapped.
pub fn beta_reg(a: f64, b: f64, x: f64) -> f64 {
checked_beta_reg(a, b, x).unwrap()
}
/// Computes the regularized lower incomplete beta function for the
/// parameters `a`, `b` and the upper integration limit `x`.
///
/// The value is obtained from an iterative two-step recurrence that runs
/// for at most 140 iterations (this appears to be a modified-Lentz
/// continued-fraction evaluation; confirm against the reference the
/// implementation was derived from). When `x >= (a + 1) / (a + b + 2)` the
/// problem is mirrored (`a` and `b` swapped, `x` replaced by `1 - x`) and
/// the result is returned as one minus the mirrored value.
///
/// If the recurrence has not met the tolerance `prec::F64_PREC` after the
/// last iteration, the current estimate is returned without any signal.
///
/// # Errors
///
/// * [`BetaFuncError::ANotGreaterThanZero`] if `a <= 0.0`
/// * [`BetaFuncError::BNotGreaterThanZero`] if `b <= 0.0`
/// * [`BetaFuncError::XOutOfRange`] if `x` is not in `[0, 1]`
pub fn checked_beta_reg(a: f64, b: f64, x: f64) -> Result<f64, BetaFuncError> {
    if a <= 0.0 {
        return Err(BetaFuncError::ANotGreaterThanZero);
    }
    if b <= 0.0 {
        return Err(BetaFuncError::BNotGreaterThanZero);
    }
    if !(0.0..=1.0).contains(&x) {
        return Err(BetaFuncError::XOutOfRange);
    }

    // Multiplicative prefactor; it is forced to zero at the end points so
    // the logarithms below are never evaluated at 0.
    let at_endpoint = x == 0.0 || ulps_eq!(x, 1.0);
    let prefactor = if at_endpoint {
        0.0
    } else {
        let ln_norm = gamma::ln_gamma(a + b) - gamma::ln_gamma(a) - gamma::ln_gamma(b);
        (ln_norm + a * x.ln() + b * (1.0 - x).ln()).exp()
    };

    // The mirroring decision is taken on the caller's arguments, before
    // they are swapped.
    let mirrored = x >= (a + 1.0) / (a + b + 2.0);
    let eps = prec::F64_PREC;
    let tiny = f64::MIN_POSITIVE / eps;
    // Keeps a term away from zero so that it can safely be used as a divisor.
    let floor_tiny = |v: f64| if v.abs() < tiny { tiny } else { v };

    let (a, b, x) = if mirrored {
        (b, a, 1.0 - x)
    } else {
        (a, b, x)
    };

    let qab = a + b;
    let qap = a + 1.0;
    let qam = a - 1.0;

    let mut c = 1.0;
    let mut d = 1.0 / floor_tiny(1.0 - qab * x / qap);
    let mut h = d;

    for step in 1..141 {
        let m = f64::from(step);
        let m2 = m * 2.0;

        // First half-step of the recurrence.
        let first = m * (b - m) * x / ((qam + m2) * (a + m2));
        d = floor_tiny(1.0 + first * d);
        c = floor_tiny(1.0 + first / c);
        d = 1.0 / d;
        h = h * d * c;

        // Second half-step of the recurrence.
        let second = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2));
        d = floor_tiny(1.0 + second * d);
        c = floor_tiny(1.0 + second / c);
        d = 1.0 / d;
        let del = d * c;
        h *= del;

        // Stop as soon as the latest multiplicative update is within
        // tolerance of 1.
        if (del - 1.0).abs() <= eps {
            break;
        }
    }

    // Converged or not, the estimate is assembled the same way.
    let tail = prefactor * h / a;
    Ok(if mirrored { 1.0 - tail } else { tail })
}
/// Computes the inverse of the regularized incomplete beta function:
/// searches for `p` in `[0, 1]` such that `beta_reg(a, b, p)` matches `x`.
///
/// An initial guess is built from closed-form approximations and then
/// refined iteratively using `beta_reg`.
///
/// NOTE(review): the constants and overall structure look like the
/// Applied Statistics algorithms AS 64 / AS 109; confirm against the
/// reference this implementation was derived from.
///
/// # Panics
///
/// Panics if `a <= 0.0` or `b <= 0.0` (through `ln_beta`, which is called
/// before the debug assertion). When debug assertions are enabled it also
/// panics if `x` is not in `[0, 1]`; without them `x` is not validated.
pub fn inv_beta_reg(mut a: f64, mut b: f64, mut x: f64) -> f64 {
// Log of the complete beta function, reused by the initial guess and by
// every refinement step.
let ln_beta = ln_beta(a, b);
// SAE is the smallest decimal exponent accepted for the accuracy target;
// FPU is the fallback accuracy and the floor applied to `prev`.
const SAE: i32 = -30;
const FPU: f64 = 1e-30;
debug_assert!((0.0..=1.0).contains(&x) && a > 0.0 && b > 0.0);
// The end points map to themselves.
if x == 0.0 {
return 0.0;
}
if x == 1.0 {
return 1.0;
}
let mut p;
let mut q;
// For x above one half, solve the mirrored problem (a and b swapped,
// x replaced by 1 - x); the answer is mirrored back at the end.
let flip = 0.5 < x;
if flip {
p = a;
a = b;
b = p;
x = 1.0 - x;
}
// Initial approximation. NOTE(review): `q` is presumably a rational
// approximation to a normal quantile; confirm.
p = (-(x * x).ln()).sqrt();
q = p - (2.30753 + 0.27061 * p) / (1.0 + (0.99229 + 0.04481 * p) * p);
if 1.0 < a && 1.0 < b {
// Both parameters exceed 1: closed-form starting value built from q.
let r = (q * q - 3.0) / 6.0;
let s = 1.0 / (2.0 * a - 1.0);
let t = 1.0 / (2.0 * b - 1.0);
let h = 2.0 / (s + t);
let w = q * (h + r).sqrt() / h - (t - s) * (r + 5.0 / 6.0 - 2.0 / (3.0 * h));
p = a / (a + b * (2.0 * w).exp());
} else {
// At least one parameter is <= 1: pick one of three starting values
// depending on the sign and size of the intermediate quantity `t`.
let mut t = 1.0 / (9.0 * b);
t = 2.0 * b * (1.0 - t + q * t.sqrt()).powf(3.0);
if t <= 0.0 {
p = 1.0 - ((((1.0 - x) * b).ln() + ln_beta) / b).exp();
} else {
t = 2.0 * (2.0 * a + b - 1.0) / t;
if t <= 1.0 {
p = (((x * a).ln() + ln_beta) / a).exp();
} else {
p = 1.0 - 2.0 / (t + 1.0);
}
}
}
// Keep the starting value strictly inside (0, 1) so the logarithms in the
// refinement loop are finite.
p = p.clamp(0.0001, 0.9999);
// Accuracy target: 10^e, unless e is not above SAE, in which case FPU.
// The `as i32` cast truncates toward zero.
let e = (-5.0 / a / a - 1.0 / x.powf(0.2) - 13.0) as i32;
let acu = if e > SAE { f64::powi(10.0, e) } else { FPU };
let mut pnext;
let mut qprev = 0.0;
let mut sq = 1.0;
let mut prev = 1.0;
'outer: loop {
// Residual beta_reg(a, b, p) - x, scaled by
// exp(ln_beta + (1 - a) ln p + (1 - b) ln(1 - p)); `q` is the proposed
// adjustment to `p`.
q = beta_reg(a, b, p);
q = (q - x) * (ln_beta + (1.0 - a) * p.ln() + (1.0 - b) * (1.0 - p).ln()).exp();
// When the adjustment changes sign (or is zero), reset the bound on the
// squared step to the last squared step, floored at FPU.
if q * qprev <= 0.0 {
prev = if sq > FPU { sq } else { FPU };
}
// `g` is the fraction of the adjustment actually applied; it is divided
// by 3 until the step is acceptable.
let mut g = 1.0;
loop {
loop {
let adj = g * q;
sq = adj * adj;
// Accept only steps whose square is below `prev` and that keep the
// candidate inside [0, 1].
if sq < prev {
pnext = p - adj;
if (0.0..=1.0).contains(&pnext) {
break;
}
}
g /= 3.0;
}
// Converged: take the candidate and leave the refinement entirely.
if prev <= acu || q * q <= acu {
p = pnext;
break 'outer;
}
// A candidate exactly on a boundary is rejected and the step is shrunk
// further; anything strictly inside (0, 1) is accepted.
if pnext != 0.0 && pnext != 1.0 {
break;
}
g /= 3.0;
}
// No change in p: further iterations cannot make progress.
if pnext == p {
break;
}
p = pnext;
qprev = q;
}
// Undo the mirroring applied for x > 0.5.
if flip {
1.0 - p
} else {
p
}
}
// Unit tests for the beta-function family: hard-coded reference values for
// the panicking wrappers, plus precondition checks for both the panicking
// and the `checked_*` variants.
// NOTE(review): `assert_almost_eq!` is not defined in this file; presumably
// it is a crate-level test macro taking (actual, expected, tolerance).
#[rustfmt::skip]
#[cfg(test)]
mod tests {
use super::*;
// Reference values for ln_beta over the grid {0.5, 1.0, 2.5} x {0.5, 1.0, 2.5}.
#[test]
fn test_ln_beta() {
assert_almost_eq!(super::ln_beta(0.5, 0.5), 1.144729885849400174144, 1e-15);
assert_almost_eq!(super::ln_beta(1.0, 0.5), 0.6931471805599453094172, 1e-14);
assert_almost_eq!(super::ln_beta(2.5, 0.5), 0.163900632837673937284, 1e-15);
assert_almost_eq!(super::ln_beta(0.5, 1.0), 0.6931471805599453094172, 1e-14);
assert_almost_eq!(super::ln_beta(1.0, 1.0), 0.0, 1e-15);
assert_almost_eq!(super::ln_beta(2.5, 1.0), -0.9162907318741550651835, 1e-14);
assert_almost_eq!(super::ln_beta(0.5, 2.5), 0.163900632837673937284, 1e-15);
assert_almost_eq!(super::ln_beta(1.0, 2.5), -0.9162907318741550651835, 1e-14);
assert_almost_eq!(super::ln_beta(2.5, 2.5), -2.608688089402107300388, 1e-14);
}
// A zero parameter must make the panicking wrapper panic...
#[test]
#[should_panic]
fn test_ln_beta_a_lte_0() {
super::ln_beta(0.0, 0.5);
}
#[test]
#[should_panic]
fn test_ln_beta_b_lte_0() {
super::ln_beta(0.5, 0.0);
}
// ...and the checked variant return an error instead.
#[test]
fn test_checked_ln_beta_a_lte_0() {
assert!(super::checked_ln_beta(0.0, 0.5).is_err());
}
#[test]
fn test_checked_ln_beta_b_lte_0() {
assert!(super::checked_ln_beta(0.5, 0.0).is_err());
}
#[test]
#[should_panic]
fn test_beta_a_lte_0() {
super::beta(0.0, 0.5);
}
#[test]
#[should_panic]
fn test_beta_b_lte_0() {
super::beta(0.5, 0.0);
}
#[test]
fn test_checked_beta_a_lte_0() {
assert!(super::checked_beta(0.0, 0.5).is_err());
}
#[test]
fn test_checked_beta_b_lte_0() {
assert!(super::checked_beta(0.5, 0.0).is_err());
}
// Reference values for beta over the same parameter grid as test_ln_beta.
#[test]
fn test_beta() {
assert_almost_eq!(super::beta(0.5, 0.5), 3.141592653589793238463, 1e-15);
assert_almost_eq!(super::beta(1.0, 0.5), 2.0, 1e-14);
assert_almost_eq!(super::beta(2.5, 0.5), 1.17809724509617246442, 1e-15);
assert_almost_eq!(super::beta(0.5, 1.0), 2.0, 1e-14);
assert_almost_eq!(super::beta(1.0, 1.0), 1.0, 1e-15);
assert_almost_eq!(super::beta(2.5, 1.0), 0.4, 1e-14);
assert_almost_eq!(super::beta(0.5, 2.5), 1.17809724509617246442, 1e-15);
assert_almost_eq!(super::beta(1.0, 2.5), 0.4, 1e-14);
assert_almost_eq!(super::beta(2.5, 2.5), 0.073631077818510779026, 1e-15);
}
// Reference values for beta_inc at x = 0.5 and x = 1.0. The x = 1.0 cases
// use the same expected values as test_beta.
#[test]
fn test_beta_inc() {
assert_almost_eq!(super::beta_inc(0.5, 0.5, 0.5), 1.570796326794896619231, 1e-14);
assert_almost_eq!(super::beta_inc(0.5, 0.5, 1.0), 3.141592653589793238463, 1e-15);
assert_almost_eq!(super::beta_inc(1.0, 0.5, 0.5), 0.5857864376269049511983, 1e-15);
assert_almost_eq!(super::beta_inc(1.0, 0.5, 1.0), 2.0, 1e-14);
assert_almost_eq!(super::beta_inc(2.5, 0.5, 0.5), 0.0890486225480862322117, 1e-16);
assert_almost_eq!(super::beta_inc(2.5, 0.5, 1.0), 1.17809724509617246442, 1e-15);
assert_almost_eq!(super::beta_inc(0.5, 1.0, 0.5), 1.414213562373095048802, 1e-14);
assert_almost_eq!(super::beta_inc(0.5, 1.0, 1.0), 2.0, 1e-14);
assert_almost_eq!(super::beta_inc(1.0, 1.0, 0.5), 0.5, 1e-15);
assert_almost_eq!(super::beta_inc(1.0, 1.0, 1.0), 1.0, 1e-15);
// Exact comparison: this case is expected to match bit for bit.
assert_eq!(super::beta_inc(2.5, 1.0, 0.5), 0.0707106781186547524401);
assert_almost_eq!(super::beta_inc(2.5, 1.0, 1.0), 0.4, 1e-14);
assert_almost_eq!(super::beta_inc(0.5, 2.5, 0.5), 1.08904862254808623221, 1e-15);
assert_almost_eq!(super::beta_inc(0.5, 2.5, 1.0), 1.17809724509617246442, 1e-15);
assert_almost_eq!(super::beta_inc(1.0, 2.5, 0.5), 0.32928932188134524756, 1e-14);
assert_almost_eq!(super::beta_inc(1.0, 2.5, 1.0), 0.4, 1e-14);
assert_almost_eq!(super::beta_inc(2.5, 2.5, 0.5), 0.03681553890925538951323, 1e-15);
assert_almost_eq!(super::beta_inc(2.5, 2.5, 1.0), 0.073631077818510779026, 1e-15);
}
// Precondition violations for beta_inc: panicking wrapper first, then the
// checked variant.
#[test]
#[should_panic]
fn test_beta_inc_a_lte_0() {
super::beta_inc(0.0, 1.0, 1.0);
}
#[test]
#[should_panic]
fn test_beta_inc_b_lte_0() {
super::beta_inc(1.0, 0.0, 1.0);
}
#[test]
#[should_panic]
fn test_beta_inc_x_lt_0() {
super::beta_inc(1.0, 1.0, -1.0);
}
#[test]
#[should_panic]
fn test_beta_inc_x_gt_1() {
super::beta_inc(1.0, 1.0, 2.0);
}
#[test]
fn test_checked_beta_inc_a_lte_0() {
assert!(super::checked_beta_inc(0.0, 1.0, 1.0).is_err());
}
#[test]
fn test_checked_beta_inc_b_lte_0() {
assert!(super::checked_beta_inc(1.0, 0.0, 1.0).is_err());
}
#[test]
fn test_checked_beta_inc_x_lt_0() {
assert!(super::checked_beta_inc(1.0, 1.0, -1.0).is_err());
}
#[test]
fn test_checked_beta_inc_x_gt_1() {
assert!(super::checked_beta_inc(1.0, 1.0, 2.0).is_err());
}
// Reference values for beta_reg at x = 0.5 and x = 1.0. Every x = 1.0 case
// must be exactly 1.0.
#[test]
fn test_beta_reg() {
assert_almost_eq!(super::beta_reg(0.5, 0.5, 0.5), 0.5, 1e-15);
assert_eq!(super::beta_reg(0.5, 0.5, 1.0), 1.0);
assert_almost_eq!(super::beta_reg(1.0, 0.5, 0.5), 0.292893218813452475599, 1e-15);
assert_eq!(super::beta_reg(1.0, 0.5, 1.0), 1.0);
assert_almost_eq!(super::beta_reg(2.5, 0.5, 0.5), 0.07558681842161243795, 1e-16);
assert_eq!(super::beta_reg(2.5, 0.5, 1.0), 1.0);
assert_almost_eq!(super::beta_reg(0.5, 1.0, 0.5), 0.7071067811865475244, 1e-15);
assert_eq!(super::beta_reg(0.5, 1.0, 1.0), 1.0);
assert_almost_eq!(super::beta_reg(1.0, 1.0, 0.5), 0.5, 1e-15);
assert_eq!(super::beta_reg(1.0, 1.0, 1.0), 1.0);
assert_almost_eq!(super::beta_reg(2.5, 1.0, 0.5), 0.1767766952966368811, 1e-15);
assert_eq!(super::beta_reg(2.5, 1.0, 1.0), 1.0);
assert_eq!(super::beta_reg(0.5, 2.5, 0.5), 0.92441318157838756205);
assert_eq!(super::beta_reg(0.5, 2.5, 1.0), 1.0);
assert_almost_eq!(super::beta_reg(1.0, 2.5, 0.5), 0.8232233047033631189, 1e-15);
assert_eq!(super::beta_reg(1.0, 2.5, 1.0), 1.0);
assert_almost_eq!(super::beta_reg(2.5, 2.5, 0.5), 0.5, 1e-15);
assert_eq!(super::beta_reg(2.5, 2.5, 1.0), 1.0);
}
// Precondition violations for beta_reg: panicking wrapper first, then the
// checked variant.
#[test]
#[should_panic]
fn test_beta_reg_a_lte_0() {
super::beta_reg(0.0, 1.0, 1.0);
}
#[test]
#[should_panic]
fn test_beta_reg_b_lte_0() {
super::beta_reg(1.0, 0.0, 1.0);
}
#[test]
#[should_panic]
fn test_beta_reg_x_lt_0() {
super::beta_reg(1.0, 1.0, -1.0);
}
#[test]
#[should_panic]
fn test_beta_reg_x_gt_1() {
super::beta_reg(1.0, 1.0, 2.0);
}
#[test]
fn test_checked_beta_reg_a_lte_0() {
assert!(super::checked_beta_reg(0.0, 1.0, 1.0).is_err());
}
#[test]
fn test_checked_beta_reg_b_lte_0() {
assert!(super::checked_beta_reg(1.0, 0.0, 1.0).is_err());
}
#[test]
fn test_checked_beta_reg_x_lt_0() {
assert!(super::checked_beta_reg(1.0, 1.0, -1.0).is_err());
}
#[test]
fn test_checked_beta_reg_x_gt_1() {
assert!(super::checked_beta_reg(1.0, 1.0, 2.0).is_err());
}
// Compile-time check: the error type can be shared and sent across threads.
#[test]
fn test_error_is_sync_send() {
fn assert_sync_send<T: Sync + Send>() {}
assert_sync_send::<BetaFuncError>();
}
}