use crate::DType;
use super::Gamma;
use crate::stats::distribution::{ContinuousDistribution, Distribution};
use crate::stats::error::{StatsError, StatsResult};
use numr::algorithm::special::SpecialFunctions;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Chi-squared distribution described by its degrees of freedom.
///
/// Density, distribution-function and quantile evaluations are forwarded to
/// an inner [`Gamma`] that is built from the arguments `(k / 2, 0.5)`
/// (presumably shape and rate -- confirm against `Gamma::new`).
#[derive(Clone, Copy, Debug)]
pub struct ChiSquared {
    /// Degrees of freedom, kept as `f64` so fractional values are representable.
    k: f64,
    /// Gamma distribution that backs the probability computations.
    gamma: Gamma,
}
impl ChiSquared {
pub fn new(k: u64) -> StatsResult<Self> {
if k == 0 {
return Err(StatsError::InvalidParameter {
name: "k".to_string(),
value: 0.0,
reason: "degrees of freedom must be positive".to_string(),
});
}
let k_f64 = k as f64;
let gamma = Gamma::new(k_f64 / 2.0, 0.5)?;
Ok(Self { k: k_f64, gamma })
}
pub fn new_f64(k: f64) -> StatsResult<Self> {
if k <= 0.0 {
return Err(StatsError::InvalidParameter {
name: "k".to_string(),
value: k,
reason: "degrees of freedom must be positive".to_string(),
});
}
let gamma = Gamma::new(k / 2.0, 0.5)?;
Ok(Self { k, gamma })
}
pub fn df(&self) -> f64 {
self.k
}
}
impl Distribution for ChiSquared {
    /// Mean of the distribution: `k`.
    fn mean(&self) -> f64 {
        self.k
    }

    /// Variance of the distribution: `2k`.
    fn var(&self) -> f64 {
        2.0 * self.k
    }

    /// Entropy, as reported by the backing gamma distribution.
    fn entropy(&self) -> f64 {
        self.gamma.entropy()
    }

    /// Approximate median, `k * (1 - 2 / (9k))^3` (Wilson-Hilferty).
    ///
    /// The result is an approximation, not the exact quantile; use `ppf(0.5)`
    /// when full accuracy is required.
    fn median(&self) -> f64 {
        let approx = self.k * (1.0 - 2.0 / (9.0 * self.k)).powi(3);
        // For 0 < k < 2/9 the cubed factor is negative, which would report a
        // negative median for a distribution supported on [0, inf). Clamp at
        // zero; the comparison form (rather than `f64::max`) keeps NaN intact.
        if approx < 0.0 { 0.0 } else { approx }
    }

    /// Mode of the distribution: `max(k - 2, 0)`.
    fn mode(&self) -> f64 {
        if self.k >= 2.0 { self.k - 2.0 } else { 0.0 }
    }

    /// Skewness of the distribution: `sqrt(8 / k)`.
    fn skewness(&self) -> f64 {
        (8.0 / self.k).sqrt()
    }

    /// Excess kurtosis of the distribution: `12 / k`.
    fn kurtosis(&self) -> f64 {
        12.0 / self.k
    }
}
impl ContinuousDistribution for ChiSquared {
fn pdf(&self, x: f64) -> f64 {
self.gamma.pdf(x)
}
fn log_pdf(&self, x: f64) -> f64 {
self.gamma.log_pdf(x)
}
fn cdf(&self, x: f64) -> f64 {
self.gamma.cdf(x)
}
fn sf(&self, x: f64) -> f64 {
self.gamma.sf(x)
}
fn ppf(&self, p: f64) -> StatsResult<f64> {
self.gamma.ppf(p)
}
fn pdf_tensor<R: Runtime<DType = DType>, C>(
&self,
x: &Tensor<R>,
client: &C,
) -> Result<Tensor<R>>
where
C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
self.gamma.pdf_tensor(x, client)
}
fn log_pdf_tensor<R: Runtime<DType = DType>, C>(
&self,
x: &Tensor<R>,
client: &C,
) -> Result<Tensor<R>>
where
C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
self.gamma.log_pdf_tensor(x, client)
}
fn cdf_tensor<R: Runtime<DType = DType>, C>(
&self,
x: &Tensor<R>,
client: &C,
) -> Result<Tensor<R>>
where
C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
{
self.gamma.cdf_tensor(x, client)
}
fn sf_tensor<R: Runtime<DType = DType>, C>(
&self,
x: &Tensor<R>,
client: &C,
) -> Result<Tensor<R>>
where
C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
{
self.gamma.sf_tensor(x, client)
}
fn log_cdf_tensor<R: Runtime<DType = DType>, C>(
&self,
x: &Tensor<R>,
client: &C,
) -> Result<Tensor<R>>
where
C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
{
self.gamma.log_cdf_tensor(x, client)
}
fn ppf_tensor<R: Runtime<DType = DType>, C>(
&self,
p: &Tensor<R>,
client: &C,
) -> Result<Tensor<R>>
where
C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
{
self.gamma.ppf_tensor(p, client)
}
fn isf_tensor<R: Runtime<DType = DType>, C>(
&self,
p: &Tensor<R>,
client: &C,
) -> Result<Tensor<R>>
where
C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
{
self.gamma.isf_tensor(p, client)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Construction succeeds for positive `k` and fails for zero.
    #[test]
    fn test_chi_squared_creation() {
        let dist = ChiSquared::new(5).unwrap();
        assert!(close(dist.df(), 5.0, 1e-10));

        assert!(ChiSquared::new(0).is_err());
    }

    /// Closed-form moments for ten degrees of freedom.
    #[test]
    fn test_chi_squared_moments() {
        let dist = ChiSquared::new(10).unwrap();

        assert!(close(dist.mean(), 10.0, 1e-10));
        assert!(close(dist.var(), 20.0, 1e-10));
        assert!(close(dist.mode(), 8.0, 1e-10));
        assert!(close(dist.skewness(), (0.8_f64).sqrt(), 1e-10));
        assert!(close(dist.kurtosis(), 1.2, 1e-10));
    }

    /// CDF against known reference values for one and two degrees of freedom.
    #[test]
    fn test_chi_squared_cdf() {
        let one_df = ChiSquared::new(1).unwrap();
        assert!(close(one_df.cdf(1.0), 0.6826894921370859, 1e-6));

        let two_df = ChiSquared::new(2).unwrap();
        let expected = 1.0 - (-1.0_f64).exp();
        assert!(close(two_df.cdf(2.0), expected, 1e-6));
    }

    /// `ppf` inverts `cdf`, and matches a tabulated 95% critical value.
    #[test]
    fn test_chi_squared_ppf() {
        let dist = ChiSquared::new(5).unwrap();

        for p in [0.1, 0.25, 0.5, 0.75, 0.9, 0.95] {
            let quantile = dist.ppf(p).unwrap();
            assert!(close(dist.cdf(quantile), p, 1e-6), "Failed for p={}", p);
        }

        let critical = dist.ppf(0.95).unwrap();
        assert!(close(critical, 11.0705, 0.01));
    }
}