use crate::DType;
use super::special;
use crate::stats::distribution::{ContinuousDistribution, Distribution};
use crate::stats::error::{StatsError, StatsResult};
use numr::algorithm::special::SpecialFunctions;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use std::f64::consts::PI;
/// Student's t distribution, parameterized by its degrees of freedom `nu`.
///
/// The log of the density's normalizing constant is computed once at construction
/// so that density evaluations only pay for the `x`-dependent kernel.
#[derive(Clone, Copy, Debug)]
pub struct StudentT {
    /// Degrees of freedom (the constructor only accepts positive, finite values).
    nu: f64,
    /// Cached `ln Γ((ν+1)/2) − ½·ln(νπ) − ln Γ(ν/2)`: log normalizing constant of the pdf.
    log_norm: f64,
}
impl StudentT {
    /// Builds a Student's t distribution with `nu` degrees of freedom.
    ///
    /// # Errors
    ///
    /// Returns `StatsError::InvalidParameter` when `nu <= 0`, or when `nu` is not
    /// finite (this second check is also what rejects `NaN`, since `NaN <= 0.0` is false).
    pub fn new(nu: f64) -> StatsResult<Self> {
        let reject = |reason: &str| StatsError::InvalidParameter {
            name: "nu".to_string(),
            value: nu,
            reason: reason.to_string(),
        };
        if nu <= 0.0 {
            return Err(reject("degrees of freedom must be positive"));
        }
        if !nu.is_finite() {
            return Err(reject("must be finite"));
        }

        // log of Γ((ν+1)/2) / (√(νπ) · Γ(ν/2)), evaluated in the same order as
        // lgamma(upper) − ½·ln(νπ) − lgamma(lower).
        let lg_upper = special::lgamma(0.5 * (nu + 1.0));
        let half_log_nu_pi = 0.5 * (PI * nu).ln();
        let lg_lower = special::lgamma(0.5 * nu);
        let log_norm = lg_upper - half_log_nu_pi - lg_lower;

        Ok(Self { nu, log_norm })
    }

    /// Degrees of freedom `ν` this distribution was built with.
    pub fn df(&self) -> f64 {
        self.nu
    }
}
impl Distribution for StudentT {
    /// Mean: `0` when `ν > 1`; undefined (`NaN`) otherwise.
    fn mean(&self) -> f64 {
        if self.nu > 1.0 { 0.0 } else { f64::NAN }
    }

    /// Variance: `ν / (ν − 2)` for `ν > 2`, `+∞` for `1 < ν ≤ 2`, `NaN` for `ν ≤ 1`.
    fn var(&self) -> f64 {
        if self.nu > 2.0 {
            self.nu / (self.nu - 2.0)
        } else if self.nu > 1.0 {
            f64::INFINITY
        } else {
            f64::NAN
        }
    }

    /// Differential entropy (in nats):
    ///
    /// `H(ν) = (ν+1)/2 · [ψ((ν+1)/2) − ψ(ν/2)] + ln(√ν · B(ν/2, 1/2))`
    ///
    /// `B(ν/2, 1/2) = √π · Γ(ν/2) / Γ((ν+1)/2)` already carries the `√π` factor,
    /// so only `½·ln ν` is added outside the beta term. Using `½·ln(νπ)` here would
    /// count `√π` twice and overstate the entropy by `½·ln π`.
    /// Sanity checks: `ν = 1` (Cauchy) yields `ln(4π)`; as `ν → ∞` the value
    /// approaches the standard-normal entropy `½·ln(2πe)`.
    fn entropy(&self) -> f64 {
        let half_nu = self.nu / 2.0;
        let half_nu_plus_1 = (self.nu + 1.0) / 2.0;
        // Assumes `special::digamma` is ψ and `special::lbeta(a, b)` is ln B(a, b).
        half_nu_plus_1 * (special::digamma(half_nu_plus_1) - special::digamma(half_nu))
            + 0.5 * self.nu.ln()
            + special::lbeta(half_nu, 0.5)
    }

    /// Median: always `0` (the distribution is symmetric about zero).
    fn median(&self) -> f64 {
        0.0
    }

    /// Mode: always `0`.
    fn mode(&self) -> f64 {
        0.0
    }

    /// Skewness: `0` when `ν > 3`; undefined (`NaN`) otherwise.
    fn skewness(&self) -> f64 {
        if self.nu > 3.0 { 0.0 } else { f64::NAN }
    }

    /// Excess kurtosis: `6 / (ν − 4)` for `ν > 4`, `+∞` for `2 < ν ≤ 4`, `NaN` for `ν ≤ 2`.
    fn kurtosis(&self) -> f64 {
        if self.nu > 4.0 {
            6.0 / (self.nu - 4.0)
        } else if self.nu > 2.0 {
            f64::INFINITY
        } else {
            f64::NAN
        }
    }
}
impl ContinuousDistribution for StudentT {
    /// Probability density, evaluated as `exp(log_pdf(x))`.
    fn pdf(&self, x: f64) -> f64 {
        f64::exp(self.log_pdf(x))
    }

    /// Log density: `log_norm − (ν+1)/2 · ln(1 + x²/ν)`.
    fn log_pdf(&self, x: f64) -> f64 {
        let exponent = (self.nu + 1.0) / 2.0;
        let log_kernel = (1.0 + x * x / self.nu).ln();
        self.log_norm - exponent * log_kernel
    }

    /// Cumulative distribution function `P(T ≤ x)`.
    ///
    /// With `z = ν / (ν + x²)`, the code treats `betainc(ν/2, 1/2, z)` as the two-sided
    /// tail mass beyond `|x|` (assuming `special::betainc` is the regularized incomplete
    /// beta `I_z(a, b)`), half of which lies on each side of zero.
    fn cdf(&self, x: f64) -> f64 {
        if x == 0.0 {
            return 0.5;
        }
        let z = self.nu / (self.nu + x * x);
        let one_tail = 0.5 * special::betainc(self.nu / 2.0, 0.5, z);
        if x > 0.0 { 1.0 - one_tail } else { one_tail }
    }

    /// Survival function `P(T > x)`, via the symmetry `sf(x) = cdf(−x)`.
    fn sf(&self, x: f64) -> f64 {
        let mirrored = -x;
        self.cdf(mirrored)
    }

    /// Quantile function (inverse cdf).
    ///
    /// Returns `±∞` at `p = 0` / `p = 1` and `0` at `p = 0.5`.
    ///
    /// # Errors
    ///
    /// `StatsError::InvalidProbability` when `p` is outside `[0, 1]` (including `NaN`).
    fn ppf(&self, p: f64) -> StatsResult<f64> {
        if !(0.0..=1.0).contains(&p) {
            return Err(StatsError::InvalidProbability { value: p });
        }
        let quantile = if p == 0.0 {
            f64::NEG_INFINITY
        } else if p == 1.0 {
            f64::INFINITY
        } else if p == 0.5 {
            0.0
        } else {
            // Fold onto a single tail: q is the two-sided tail mass to invert.
            let (q, sign) = if p > 0.5 {
                (2.0 * (1.0 - p), 1.0)
            } else {
                (2.0 * p, -1.0)
            };
            let t = special::betaincinv(self.nu / 2.0, 0.5, q);
            // Invert t = ν / (ν + x²)  ⇒  |x| = √(ν · (1/t − 1)).
            // NOTE(review): `1/t − 1` cancels when p is very close to 0.5 (t ≈ 1),
            // so relative accuracy degrades there.
            sign * (self.nu * (1.0 / t - 1.0)).sqrt()
        };
        Ok(quantile)
    }

    /// Element-wise density: `exp` of [`log_pdf_tensor`](Self::log_pdf_tensor).
    fn pdf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        x: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    {
        let log_density = self.log_pdf_tensor(x, client)?;
        client.exp(&log_density)
    }

    /// Element-wise log density: `log_norm − (ν+1)/2 · ln(1 + x²/ν)`.
    fn log_pdf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        x: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    {
        let inv_nu = 1.0 / self.nu;
        let exponent = -(self.nu + 1.0) / 2.0;
        let scaled_sq = client.mul_scalar(&client.square(x)?, inv_nu)?;
        let log_base = client.log(&client.add_scalar(&scaled_sq, 1.0)?)?;
        let weighted = client.mul_scalar(&log_base, exponent)?;
        client.add_scalar(&weighted, self.log_norm)
    }

    /// Element-wise cdf, written branch-free as `0.5 + sign(x) · (0.5 − 0.5 · I_z(ν/2, 1/2))`
    /// with `z = ν / (ν + x²)`; `sign(0) = 0` yields exactly `0.5` at the origin.
    fn cdf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        x: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        // Constant tensor with the same shape/dtype as `x`.
        let filled =
            |value: f64| Tensor::<R>::full_scalar(x.shape(), x.dtype(), value, client.device());

        let denom = client.add_scalar(&client.square(x)?, self.nu)?;
        let z = client.div(&filled(self.nu), &denom)?;
        let ib = client.betainc(&filled(self.nu / 2.0), &filled(0.5), &z)?;

        // Distance of the cdf from 0.5 before applying the sign of x: 0.5 − 0.5·I.
        let neg_half_ib = client.mul_scalar(&client.mul_scalar(&ib, 0.5)?, -1.0)?;
        let offset = client.add_scalar(&neg_half_ib, 0.5)?;
        let signed = client.mul(&client.sign(x)?, &offset)?;
        client.add_scalar(&signed, 0.5)
    }

    /// Element-wise survival function via `sf(x) = cdf(−x)`.
    fn sf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        x: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let flipped = client.mul_scalar(x, -1.0)?;
        self.cdf_tensor(&flipped, client)
    }

    /// Element-wise log cdf, computed as `ln(cdf(x))`.
    fn log_cdf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        x: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        client.log(&self.cdf_tensor(x, client)?)
    }

    /// Element-wise quantile function.
    ///
    /// Uses `q = 1 − 2|p − 0.5|` (i.e. `2(1−p)` above the median, `2p` below),
    /// inverts `t = I⁻¹_q(ν/2, 1/2)`, then returns `sign(p − 0.5) · √(ν (1/t − 1))`.
    /// Unlike the scalar `ppf`, inputs outside `[0, 1]` are not rejected here.
    fn ppf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        p: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        // Constant tensor with the same shape/dtype as `p`.
        let filled =
            |value: f64| Tensor::<R>::full_scalar(p.shape(), p.dtype(), value, client.device());

        let centered = client.sub_scalar(p, 0.5)?;
        let two_abs = client.mul_scalar(&client.abs(&centered)?, 2.0)?;
        let q = client.add_scalar(&client.mul_scalar(&two_abs, -1.0)?, 1.0)?;
        let t = client.betaincinv(&filled(self.nu / 2.0), &filled(0.5), &q)?;

        // x² = ν · (1/t − 1)
        let recip_t = client.div(&filled(1.0), &t)?;
        let x_sq = client.mul_scalar(&client.sub_scalar(&recip_t, 1.0)?, self.nu)?;
        let magnitude = client.sqrt(&x_sq)?;
        client.mul(&client.sign(&centered)?, &magnitude)
    }

    /// Element-wise inverse survival function via `isf(p) = ppf(1 − p)`.
    fn isf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        p: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let complement = client.add_scalar(&client.mul_scalar(p, -1.0)?, 1.0)?;
        self.ppf_tensor(&complement, client)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_student_t_creation() {
        let t = StudentT::new(10.0).unwrap();
        assert!((t.df() - 10.0).abs() < 1e-10);
        assert!(StudentT::new(0.0).is_err());
        assert!(StudentT::new(-1.0).is_err());
        // Non-finite degrees of freedom must be rejected too.
        assert!(StudentT::new(f64::NAN).is_err());
        assert!(StudentT::new(f64::INFINITY).is_err());
        assert!(StudentT::new(f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn test_student_t_moments() {
        let t = StudentT::new(10.0).unwrap();
        assert!((t.mean() - 0.0).abs() < 1e-10);
        assert!((t.var() - 10.0 / 8.0).abs() < 1e-10);
        assert!((t.median() - 0.0).abs() < 1e-10);
        assert!((t.mode() - 0.0).abs() < 1e-10);
        assert!((t.skewness() - 0.0).abs() < 1e-10);
        assert!((t.kurtosis() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_student_t_higher_moments_low_df() {
        // Skewness needs nu > 3; kurtosis is infinite for 2 < nu <= 4 and undefined below.
        let t = StudentT::new(3.0).unwrap();
        assert!(t.skewness().is_nan());
        assert!(t.kurtosis().is_infinite());
        let t = StudentT::new(2.0).unwrap();
        assert!(t.kurtosis().is_nan());
    }

    #[test]
    fn test_student_t_pdf_symmetry() {
        let t = StudentT::new(5.0).unwrap();
        for x in [0.5, 1.0, 2.0, 3.0] {
            assert!((t.pdf(x) - t.pdf(-x)).abs() < 1e-10);
        }
        assert!(t.pdf(0.0) > t.pdf(1.0));
    }

    #[test]
    fn test_student_t_cauchy_special_case() {
        // nu = 1 is the standard Cauchy distribution: pdf(0) = 1/pi and cdf(1) = 3/4.
        let t = StudentT::new(1.0).unwrap();
        assert!((t.pdf(0.0) - 1.0 / PI).abs() < 1e-10);
        assert!((t.cdf(1.0) - 0.75).abs() < 1e-7);
    }

    #[test]
    fn test_student_t_cdf() {
        let t = StudentT::new(10.0).unwrap();
        assert!((t.cdf(0.0) - 0.5).abs() < 1e-10);
        for x in [0.5, 1.0, 2.0] {
            assert!((t.cdf(-x) + t.cdf(x) - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_student_t_sf_complements_cdf() {
        let t = StudentT::new(4.0).unwrap();
        for x in [-3.0, -0.5, 0.0, 0.5, 3.0] {
            assert!((t.sf(x) + t.cdf(x) - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn test_student_t_ppf() {
        let t = StudentT::new(10.0).unwrap();
        assert!((t.ppf(0.5).unwrap() - 0.0).abs() < 1e-10);
        for p in [0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99] {
            let x = t.ppf(p).unwrap();
            assert!(
                (t.cdf(x) - p).abs() < 1e-4,
                "Roundtrip failed for p={}: cdf(ppf(p)) = {}",
                p,
                t.cdf(x)
            );
        }
        assert!((t.ppf(0.975).unwrap() - 2.228).abs() < 0.01);
    }

    #[test]
    fn test_student_t_ppf_edges() {
        let t = StudentT::new(10.0).unwrap();
        assert_eq!(t.ppf(0.0).unwrap(), f64::NEG_INFINITY);
        assert_eq!(t.ppf(1.0).unwrap(), f64::INFINITY);
        assert!(t.ppf(-0.1).is_err());
        assert!(t.ppf(1.1).is_err());
        assert!(t.ppf(f64::NAN).is_err());
        // Quantiles are symmetric about the median.
        for p in [0.05, 0.2, 0.4] {
            assert!((t.ppf(p).unwrap() + t.ppf(1.0 - p).unwrap()).abs() < 1e-8);
        }
    }

    #[test]
    fn test_student_t_convergence_to_normal() {
        let t = StudentT::new(1000.0).unwrap();
        let normal_cdf_1 = 0.8413447460685429;
        assert!((t.cdf(1.0) - normal_cdf_1).abs() < 0.01);
    }

    #[test]
    fn test_student_t_low_df() {
        let t = StudentT::new(1.0).unwrap();
        assert!(t.mean().is_nan());
        assert!(t.var().is_nan());
        let t = StudentT::new(2.0).unwrap();
        assert!((t.mean() - 0.0).abs() < 1e-10);
        assert!(t.var().is_infinite());
    }
}