numra_stats/distributions/
gamma_dist.rs1use numra_core::Scalar;
8use numra_special::{gammainc, lgamma};
9use rand::RngCore;
10
11use super::normal::random_uniform_01;
12use super::ContinuousDistribution;
13
14#[derive(Clone, Debug)]
18pub struct GammaDist<S: Scalar> {
19 pub shape: S,
20 pub rate: S,
21}
22
23impl<S: Scalar> GammaDist<S> {
24 pub fn new(shape: S, rate: S) -> Self {
25 Self { shape, rate }
26 }
27}
28
29impl<S: Scalar> ContinuousDistribution<S> for GammaDist<S> {
30 fn pdf(&self, x: S) -> S {
31 if x < S::ZERO {
32 return S::ZERO;
33 }
34 if x == S::ZERO {
35 return if self.shape == S::ONE {
36 self.rate
37 } else if self.shape > S::ONE {
38 S::ZERO
39 } else {
40 S::INFINITY
41 };
42 }
43 let log_pdf = self.shape * self.rate.ln() + (self.shape - S::ONE) * x.ln()
44 - self.rate * x
45 - lgamma(self.shape);
46 log_pdf.exp()
47 }
48
49 fn cdf(&self, x: S) -> S {
50 if x <= S::ZERO {
51 return S::ZERO;
52 }
53 gammainc(self.shape, self.rate * x)
55 }
56
57 fn quantile(&self, p: S) -> S {
58 if p <= S::ZERO {
59 return S::ZERO;
60 }
61 if p >= S::ONE {
62 return S::INFINITY;
63 }
64 let mu = self.shape / self.rate;
67 let sig = (self.shape / (self.rate * self.rate)).sqrt();
68 let mut x = mu + sig * normal_quantile_approx(p);
69 if x <= S::ZERO {
70 x = mu * S::from_f64(0.01);
71 }
72 for _ in 0..50 {
73 let f_val = self.cdf(x) - p;
74 let f_prime = self.pdf(x);
75 if f_prime.to_f64().abs() < 1e-300 {
76 break;
77 }
78 let step = f_val / f_prime;
79 x -= step;
80 if x <= S::ZERO {
81 x = S::from_f64(1e-10);
82 }
83 if step.to_f64().abs() < 1e-12 * x.to_f64().abs() {
84 break;
85 }
86 }
87 x
88 }
89
90 fn mean(&self) -> S {
91 self.shape / self.rate
92 }
93
94 fn variance(&self) -> S {
95 self.shape / (self.rate * self.rate)
96 }
97
98 fn sample(&self, rng: &mut dyn RngCore) -> S {
99 let one = S::ONE;
101 let shape = if self.shape < one {
102 self.shape + one
103 } else {
104 self.shape
105 };
106
107 let d = shape - S::from_f64(1.0 / 3.0);
108 let c = S::ONE / (S::from_f64(9.0) * d).sqrt();
109
110 loop {
111 let x = sample_standard_normal::<S>(rng);
112 let v = S::ONE + c * x;
113 if v <= S::ZERO {
114 continue;
115 }
116 let v = v * v * v;
117 let u = random_uniform_01::<S>(rng);
118 let x2 = x * x;
119 if u < S::ONE - S::from_f64(0.0331) * x2 * x2 {
120 let result = d * v / self.rate;
121 if self.shape < one {
122 let u2 = random_uniform_01::<S>(rng);
123 return result * u2.ln().exp() / self.shape.ln().exp()
124 * (S::ONE / self.shape).ln().exp();
125 }
126 return result;
127 }
128 if u.ln() < S::HALF * x2 + d * (S::ONE - v + v.ln()) {
129 let result = d * v / self.rate;
130 if self.shape < one {
131 let u2 = random_uniform_01::<S>(rng);
132 return result * u2.powf(S::ONE / self.shape);
133 }
134 return result;
135 }
136 }
137 }
138}
139
140pub(crate) fn normal_quantile_approx<S: Scalar>(p: S) -> S {
142 let p_f64 = p.to_f64();
143 let t = if p_f64 < 0.5 {
145 (-2.0 * p_f64.ln()).sqrt()
146 } else {
147 (-2.0 * (1.0 - p_f64).ln()).sqrt()
148 };
149 let c0 = 2.515517;
150 let c1 = 0.802853;
151 let c2 = 0.010328;
152 let d1 = 1.432788;
153 let d2 = 0.189269;
154 let d3 = 0.001308;
155 let val = t - (c0 + c1 * t + c2 * t * t) / (1.0 + d1 * t + d2 * t * t + d3 * t * t * t);
156 if p_f64 < 0.5 {
157 S::from_f64(-val)
158 } else {
159 S::from_f64(val)
160 }
161}
162
163pub(crate) fn sample_standard_normal<S: Scalar>(rng: &mut dyn RngCore) -> S {
165 let u1 = random_uniform_01::<S>(rng);
166 let u2 = random_uniform_01::<S>(rng);
167 let two = S::TWO;
168 let pi2 = S::from_f64(core::f64::consts::TAU);
169 (S::ZERO - two * u1.ln()).sqrt() * (pi2 * u2).cos()
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// With shape 1 the gamma distribution is exponential with the same rate.
    #[test]
    fn test_gamma_pdf_exponential_case() {
        let dist = GammaDist::new(1.0_f64, 2.0);

        let at_zero = dist.pdf(0.0);
        assert!((at_zero - 2.0).abs() < 1e-12);

        let at_one = dist.pdf(1.0);
        let expected = 2.0 * (-2.0_f64).exp();
        assert!((at_one - expected).abs() < 1e-12);
    }

    /// Gamma(1, 1) has CDF `1 - exp(-x)`.
    #[test]
    fn test_gamma_cdf() {
        let dist = GammaDist::new(1.0_f64, 1.0);
        let expected = 1.0 - (-1.0_f64).exp();
        assert!((dist.cdf(1.0) - expected).abs() < 1e-8);
    }

    /// `cdf(quantile(p))` should reproduce `p`.
    #[test]
    fn test_gamma_quantile_roundtrip() {
        let dist = GammaDist::new(3.0_f64, 2.0);
        for p in [0.1, 0.5, 0.9].iter().copied() {
            let x = dist.quantile(p);
            let p2 = dist.cdf(x);
            assert!((p - p2).abs() < 1e-6, "p={}, p2={}", p, p2);
        }
    }

    /// Closed-form mean and variance for shape 5, rate 2.
    #[test]
    fn test_gamma_mean_variance() {
        let dist = GammaDist::new(5.0_f64, 2.0);
        let mean_err = (dist.mean() - 2.5).abs();
        let var_err = (dist.variance() - 1.25).abs();
        assert!(mean_err < 1e-14);
        assert!(var_err < 1e-14);
    }

    /// The empirical mean of many draws should be close to `shape / rate`.
    #[test]
    fn test_gamma_sample_mean() {
        use rand::SeedableRng;

        let dist = GammaDist::new(3.0_f64, 1.0);
        let mut rng = rand::rngs::StdRng::seed_from_u64(45);
        let samples = dist.sample_n(&mut rng, 10000);

        let total: f64 = samples.iter().sum();
        let mean = total / samples.len() as f64;
        assert!((mean - 3.0).abs() < 0.2, "sample mean = {}", mean);
    }
}