numra_stats/distributions/
student_t.rs1use numra_core::Scalar;
8use numra_special::{betainc, lgamma};
9use rand::RngCore;
10
11use super::gamma_dist::{sample_standard_normal, GammaDist};
12use super::ContinuousDistribution;
13
14#[derive(Clone, Debug)]
16pub struct StudentT<S: Scalar> {
17 pub df: S,
18}
19
20impl<S: Scalar> StudentT<S> {
21 pub fn new(df: S) -> Self {
22 Self { df }
23 }
24}
25
26impl<S: Scalar> ContinuousDistribution<S> for StudentT<S> {
27 fn pdf(&self, x: S) -> S {
28 let half = S::HALF;
29 let nu = self.df;
30 let log_pdf = lgamma((nu + S::ONE) * half)
31 - lgamma(nu * half)
32 - half * (nu * S::from_f64(core::f64::consts::PI)).ln()
33 - (nu + S::ONE) * half * (S::ONE + x * x / nu).ln();
34 log_pdf.exp()
35 }
36
37 fn cdf(&self, x: S) -> S {
38 let half = S::HALF;
39 let nu = self.df;
40 let t2 = x * x;
41 let p = betainc(nu * half, half, nu / (nu + t2));
42 if x >= S::ZERO {
43 S::ONE - half * p
44 } else {
45 half * p
46 }
47 }
48
49 fn quantile(&self, p: S) -> S {
50 if p <= S::ZERO {
51 return S::NEG_INFINITY;
52 }
53 if p >= S::ONE {
54 return S::INFINITY;
55 }
56 let mut x = super::gamma_dist::normal_quantile_approx(p);
58 for _ in 0..100 {
59 let f_val = self.cdf(x) - p;
60 let f_prime = self.pdf(x);
61 if f_prime.to_f64().abs() < 1e-300 {
62 break;
63 }
64 let step = f_val / f_prime;
65 x -= step;
66 if step.to_f64().abs() < 1e-12 * (S::ONE + x.abs()).to_f64() {
67 break;
68 }
69 }
70 x
71 }
72
73 fn mean(&self) -> S {
74 S::ZERO
76 }
77
78 fn variance(&self) -> S {
79 let two = S::TWO;
80 if self.df > two {
81 self.df / (self.df - two)
82 } else {
83 S::INFINITY
84 }
85 }
86
87 fn sample(&self, rng: &mut dyn RngCore) -> S {
88 let z = sample_standard_normal::<S>(rng);
90 let half = S::HALF;
91 let chi2 = GammaDist::new(self.df * half, half);
92 let v = chi2.sample(rng);
93 z / (v / self.df).sqrt()
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// The density is even and the CDF is exactly one half at the origin.
    #[test]
    fn test_student_t_symmetry() {
        let dist = StudentT::new(5.0_f64);
        let (right, left) = (dist.pdf(1.0), dist.pdf(-1.0));
        assert!((right - left).abs() < 1e-12);
        let centre = dist.cdf(0.0);
        assert!((centre - 0.5).abs() < 1e-10);
    }

    /// Far into either tail the CDF is close to 0 or 1.
    #[test]
    fn test_student_t_cdf_tails() {
        let dist = StudentT::new(10.0_f64);
        let lower = dist.cdf(-10.0);
        assert!(lower < 0.001);
        let upper = dist.cdf(10.0);
        assert!(upper > 0.999);
    }

    /// `cdf(quantile(p))` reproduces `p` across the body of the distribution.
    #[test]
    fn test_student_t_quantile_roundtrip() {
        let dist = StudentT::new(5.0_f64);
        for p in [0.05, 0.25, 0.5, 0.75, 0.95] {
            let p2 = dist.cdf(dist.quantile(p));
            assert!((p - p2).abs() < 1e-6, "p={}, p2={}", p, p2);
        }
    }

    /// For ν = 5 the variance is ν / (ν − 2) = 5/3.
    #[test]
    fn test_student_t_variance() {
        let dist = StudentT::new(5.0_f64);
        let expected = 5.0 / 3.0;
        assert!((dist.variance() - expected).abs() < 1e-14);
    }
}