use thiserror::Error;
use crate::error::SearchError;
use crate::search::{search_monotone_with_atol, SEARCH_BOUND};
use crate::special::gamma_inc;
use crate::special::gamma_log;
use crate::traits::{ContinuousCdf, Mean, Variance};
/// Noncentral chi-squared distribution with `df` degrees of freedom and
/// noncentrality parameter `ncp`.
///
/// Instances are built through [`ChiSquaredNoncentral::new`] or
/// [`ChiSquaredNoncentral::try_new`], which validate both parameters.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ChiSquaredNoncentral {
    /// Degrees of freedom; `try_new` requires it to be finite and `> 0`.
    df: f64,
    /// Noncentrality parameter; `try_new` requires it to be finite and `>= 0`.
    ncp: f64,
}
#[derive(Debug, Clone, Copy, PartialEq, Error)]
pub enum ChiSquaredNoncentralError {
#[error("degrees of freedom must be positive, got {0}")]
DfNotPositive(f64),
#[error("degrees of freedom must be finite, got {0}")]
DfNotFinite(f64),
#[error("noncentrality parameter must be ≥ 0, got {0}")]
NcpNegative(f64),
#[error("noncentrality parameter must be finite, got {0}")]
NcpNotFinite(f64),
#[error("argument x must be positive, got {0}")]
XNotPositive(f64),
#[error("argument x must be finite, got {0}")]
XNotFinite(f64),
#[error("probability {0} outside [0..1]")]
PNotInRange(f64),
#[error("probability {0} outside [0..1]")]
QNotInRange(f64),
#[error(transparent)]
Search(#[from] SearchError),
}
impl ChiSquaredNoncentral {
    /// Creates the distribution, panicking if the parameters are invalid.
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) whenever [`ChiSquaredNoncentral::try_new`]
    /// would return an error.
    #[inline]
    pub fn new(df: f64, ncp: f64) -> Self {
        Self::try_new(df, ncp).unwrap()
    }

    /// Creates the distribution after validating `df` and `ncp`.
    ///
    /// # Errors
    ///
    /// The checks run in this order and the first failure is returned:
    /// `DfNotFinite`, `DfNotPositive` (`df <= 0`), `NcpNotFinite`,
    /// `NcpNegative` (`ncp < 0`).
    #[inline]
    pub fn try_new(df: f64, ncp: f64) -> Result<Self, ChiSquaredNoncentralError> {
        Self::check_df(df)?;
        Self::check_ncp(ncp)?;
        Ok(Self { df, ncp })
    }

    /// Degrees of freedom.
    #[inline]
    pub const fn df(&self) -> f64 {
        self.df
    }

    /// Noncentrality parameter.
    #[inline]
    pub const fn ncp(&self) -> f64 {
        self.ncp
    }

    /// Searches for the degrees of freedom at which the CDF evaluated at `x`
    /// (with noncentrality `ncp`) equals `p`.
    ///
    /// The residual `cdf(x; df, ncp) - p` is handed to
    /// `search_monotone_with_atol` together with `SEARCH_BOUND` and a starting
    /// value of `5.0`.
    ///
    /// # Errors
    ///
    /// `PNotInRange` for `p` outside `[0, 1]`; `XNotFinite` / `XNotPositive`
    /// for a bad `x`; `NcpNotFinite` / `NcpNegative` for a bad `ncp`;
    /// `Search` if the root search fails.
    #[inline]
    pub fn search_df(p: f64, x: f64, ncp: f64) -> Result<f64, ChiSquaredNoncentralError> {
        check_p(p)?;
        Self::check_x(x)?;
        Self::check_ncp(ncp)?;
        let residual = |df: f64| cumchn(x, df, ncp).0 - p;
        Ok(search_monotone_with_atol(
            0.0,
            SEARCH_BOUND,
            5.0,
            0.0,
            SEARCH_BOUND,
            1.0e-50,
            residual,
        )?)
    }

    /// Searches for the noncentrality parameter at which the CDF evaluated at
    /// `x` (with `df` degrees of freedom) equals `p`.
    ///
    /// Unlike [`ChiSquaredNoncentral::search_df`], this passes `1.0e4` where
    /// `SEARCH_BOUND` would go — presumably the upper bound for `ncp`
    /// (confirm against `crate::search`).
    ///
    /// # Errors
    ///
    /// `PNotInRange` for `p` outside `[0, 1]`; `XNotFinite` / `XNotPositive`
    /// for a bad `x`; `DfNotFinite` / `DfNotPositive` for a bad `df`;
    /// `Search` if the root search fails.
    #[inline]
    pub fn search_ncp(p: f64, x: f64, df: f64) -> Result<f64, ChiSquaredNoncentralError> {
        check_p(p)?;
        Self::check_x(x)?;
        Self::check_df(df)?;
        let residual = |ncp: f64| cumchn(x, df, ncp).0 - p;
        Ok(search_monotone_with_atol(
            0.0, 1.0e4, 5.0, 0.0, 1.0e4, 1.0e-50, residual,
        )?)
    }

    /// Rejects a `df` that is not finite, then one that is `<= 0`.
    #[inline]
    fn check_df(df: f64) -> Result<(), ChiSquaredNoncentralError> {
        if !df.is_finite() {
            Err(ChiSquaredNoncentralError::DfNotFinite(df))
        } else if df <= 0.0 {
            Err(ChiSquaredNoncentralError::DfNotPositive(df))
        } else {
            Ok(())
        }
    }

    /// Rejects an `ncp` that is not finite, then one that is `< 0`.
    #[inline]
    fn check_ncp(ncp: f64) -> Result<(), ChiSquaredNoncentralError> {
        if !ncp.is_finite() {
            Err(ChiSquaredNoncentralError::NcpNotFinite(ncp))
        } else if ncp < 0.0 {
            Err(ChiSquaredNoncentralError::NcpNegative(ncp))
        } else {
            Ok(())
        }
    }

    /// Rejects an `x` that is not finite, then one that is `<= 0`.
    #[inline]
    fn check_x(x: f64) -> Result<(), ChiSquaredNoncentralError> {
        if !x.is_finite() {
            Err(ChiSquaredNoncentralError::XNotFinite(x))
        } else if x <= 0.0 {
            Err(ChiSquaredNoncentralError::XNotPositive(x))
        } else {
            Ok(())
        }
    }
}
/// Validates that `p` lies in the closed interval `[0, 1]`.
///
/// NaN and infinities are never contained in that range, so they are
/// rejected as well.
///
/// # Errors
///
/// Returns [`ChiSquaredNoncentralError::PNotInRange`] carrying `p` when the
/// check fails.
#[inline]
fn check_p(p: f64) -> Result<(), ChiSquaredNoncentralError> {
    if (0.0..=1.0).contains(&p) {
        Ok(())
    } else {
        Err(ChiSquaredNoncentralError::PNotInRange(p))
    }
}
/// Computes the noncentral chi-squared CDF and its complement at `x`.
///
/// This is a port of CDFLIB's `cumchn`. The CDF is expanded as a mixture of
/// central chi-squared CDFs with `df + 2j` degrees of freedom, weighted by
/// `exp(-λ) λ^j / j!` with `λ = pnonc / 2`.
///
/// Summation starts at the central term `j = icent` (`trunc(λ)`, at least 1).
/// It walks backwards towards `j = 0` and then forwards. Each direction stops
/// as soon as any of these holds:
/// * a term drops below `EPS` times the running sum;
/// * the sum itself is below `1e-20`;
/// * more than `N_TIRED` terms have been added.
///
/// Because of the term limit, results for very large `pnonc` may be
/// inaccurate.
///
/// Returns `(cum, ccum)`. `cum` is clamped to `[0, 1]` so rounding cannot
/// yield a probability above 1 or a negative complement. `ccum` is formed as
/// `1 - cum`, so it carries no extra relative accuracy in the far upper tail.
///
/// Special cases:
/// * any NaN argument → `(NaN, NaN)`;
/// * `x <= 0` → `(0, 1)`;
/// * `x == +inf` → `(1, 0)`;
/// * `pnonc <= 1e-10` → the central result `gamma_inc(df / 2, x / 2)`.
fn cumchn(x: f64, df: f64, pnonc: f64) -> (f64, f64) {
    /// Relative size below which a term is considered negligible.
    const EPS: f64 = 1e-5;
    /// Maximum number of terms summed in each direction.
    const N_TIRED: i32 = 1000;

    if x.is_nan() || df.is_nan() || pnonc.is_nan() {
        return (f64::NAN, f64::NAN);
    }
    if x <= 0.0 {
        return (0.0, 1.0);
    }
    if x == f64::INFINITY {
        // The series below would evaluate `inf - inf` in the log adjustment
        // term and return NaN; the CDF at +inf is exactly 1.
        return (1.0, 0.0);
    }
    if pnonc <= 1e-10 {
        // Noncentrality is (essentially) zero: plain central chi-squared.
        let (p, q) = gamma_inc(df / 2.0, x / 2.0);
        return (p, q);
    }

    let xnonc = pnonc / 2.0;
    // Index of the central mixture term. It is kept as f64: casting to i32
    // saturates at i32::MAX for huge `pnonc`, and the later `+ 1` then
    // overflowed. For indices that fit in i32 this is numerically identical.
    let icent = xnonc.trunc().max(1.0);
    let chid2 = x / 2.0;
    // Degrees of freedom of the j-th mixture component.
    let dg = |j: f64| df + 2.0 * j;

    // Log of the central Poisson weight exp(-λ) λ^icent / icent!
    // (assumes `gamma_log` is ln Γ).
    let lcntwt = -xnonc + icent * xnonc.ln() - gamma_log(icent + 1.0);
    let centwt = lcntwt.exp();

    // Central chi-squared CDF for the central term.
    let (pcent, _) = gamma_inc(dg(icent) / 2.0, chid2);

    // Adjustment term chid2^a e^{-chid2} / Γ(a + 1) with a = dg(icent) / 2.
    // It is the difference between chi-squared CDFs two degrees of freedom
    // apart.
    let dfd2 = dg(icent) / 2.0;
    let lcntaj = dfd2 * chid2.ln() - chid2 - gamma_log(1.0 + dfd2);
    let centaj = lcntaj.exp();

    let mut sum = centwt * pcent;

    // Sum backwards from the central term towards j = 0.
    let mut iterb: i32 = 0;
    let mut sumadj = 0.0;
    let mut adj = centaj;
    let mut wt = centwt;
    let mut i = icent;
    loop {
        // Shift the chi-squared CDF down by two degrees of freedom.
        let dfd2 = dg(i) / 2.0;
        adj *= dfd2 / chid2;
        sumadj += adj;
        let pterm = pcent + sumadj;
        // Poisson weight for index i - 1.
        wt *= i / xnonc;
        let term = wt * pterm;
        sum += term;
        i -= 1.0;
        iterb += 1;
        let small = sum < 1e-20 || term < EPS * sum;
        if iterb > N_TIRED || small || i == 0.0 {
            break;
        }
    }

    // Sum forwards from the central term towards infinity.
    let mut iterf: i32 = 0;
    let mut adj = centaj;
    let mut sumadj = centaj;
    let mut wt = centwt;
    let mut i = icent;
    loop {
        // Poisson weight for index i + 1.
        wt *= xnonc / (i + 1.0);
        let pterm = pcent - sumadj;
        let term = wt * pterm;
        sum += term;
        i += 1.0;
        // Adjustment needed for the next step up in degrees of freedom.
        let dfd2 = dg(i) / 2.0;
        adj *= chid2 / dfd2;
        sumadj += adj;
        iterf += 1;
        let small = sum < 1e-20 || term < EPS * sum;
        if iterf > N_TIRED || small {
            break;
        }
    }

    // Rounding in the series can overshoot [0, 1] slightly; a NaN sum
    // passes through `clamp` unchanged.
    let cum = sum.clamp(0.0, 1.0);
    (cum, 0.5 + (0.5 - cum))
}
impl ContinuousCdf for ChiSquaredNoncentral {
    type Error = ChiSquaredNoncentralError;

    /// Lower-tail probability at `x`, i.e. the first component of `cumchn`.
    #[inline]
    fn cdf(&self, x: f64) -> f64 {
        let (lower, _) = cumchn(x, self.df, self.ncp);
        lower
    }

    /// Upper-tail probability at `x`, i.e. the second component of `cumchn`.
    #[inline]
    fn ccdf(&self, x: f64) -> f64 {
        let (_, upper) = cumchn(x, self.df, self.ncp);
        upper
    }

    /// Quantile function.
    ///
    /// `p == 0` maps to `0` and `p == 1` maps to `+inf`. Any other `p` is
    /// located by a monotone search on `cdf(x) - p` starting from `5.0`, with
    /// `SEARCH_BOUND` passed to the search routine.
    ///
    /// # Errors
    ///
    /// `PNotInRange` for `p` outside `[0, 1]`; `Search` if the root search
    /// fails.
    #[inline]
    fn inverse_cdf(&self, p: f64) -> Result<f64, ChiSquaredNoncentralError> {
        check_p(p)?;
        if p == 0.0 {
            Ok(0.0)
        } else if p == 1.0 {
            Ok(f64::INFINITY)
        } else {
            let Self { df, ncp } = *self;
            let residual = |x: f64| cumchn(x, df, ncp).0 - p;
            Ok(search_monotone_with_atol(
                0.0,
                SEARCH_BOUND,
                5.0,
                0.0,
                SEARCH_BOUND,
                1.0e-50,
                residual,
            )?)
        }
    }
}
impl Mean for ChiSquaredNoncentral {
    /// Mean of the distribution: `df + ncp`.
    #[inline]
    fn mean(&self) -> f64 {
        let Self { df, ncp } = *self;
        df + ncp
    }
}
impl Variance for ChiSquaredNoncentral {
    /// Variance of the distribution: `2 * (df + 2 * ncp)`.
    #[inline]
    fn variance(&self) -> f64 {
        let Self { df, ncp } = *self;
        2.0 * (df + 2.0 * ncp)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructors and searches report the specific validation failure.
    #[test]
    fn rejects_invalid_inputs() {
        let zero_df = ChiSquaredNoncentral::try_new(0.0, 1.0);
        assert!(matches!(
            zero_df,
            Err(ChiSquaredNoncentralError::DfNotPositive(v)) if v == 0.0
        ));

        let negative_ncp = ChiSquaredNoncentral::try_new(1.0, -1.0);
        assert!(matches!(
            negative_ncp,
            Err(ChiSquaredNoncentralError::NcpNegative(v)) if v == -1.0
        ));

        let bad_p = ChiSquaredNoncentral::search_df(-0.1, 1.0, 2.0);
        assert!(matches!(
            bad_p,
            Err(ChiSquaredNoncentralError::PNotInRange(v)) if v == -0.1
        ));
    }

    /// The quantile at 0 is exactly 0; interior quantiles and moments are finite.
    #[test]
    fn inverse_and_moment_edges() {
        let dist = ChiSquaredNoncentral::new(5.0, 2.0);
        let at_zero = dist.inverse_cdf(0.0).unwrap();
        assert_eq!(at_zero, 0.0);
        let quartile = dist.inverse_cdf(0.25).unwrap();
        assert!(quartile.is_finite());
        assert!(dist.mean().is_finite());
        assert!(dist.variance().is_finite());
    }

    /// With zero noncentrality, cdf and ccdf sum to one.
    #[test]
    fn central_limit_path_is_consistent() {
        let dist = ChiSquaredNoncentral::new(4.0, 0.0);
        let x = 3.0;
        let total = dist.cdf(x) + dist.ccdf(x);
        assert!((total - 1.0).abs() < 1e-12);
    }
}