1use crate::dist::{Distribution, DistributionMethods, RegressionDistn};
2use crate::scores::{CRPScore, LogScore, Scorable};
3use ndarray::{array, Array1, Array2, Array3};
4use rand::prelude::*;
5use statrs::distribution::{Continuous, ContinuousCDF, Gamma as GammaDist};
6use statrs::function::gamma::digamma;
7
8#[derive(Debug, Clone)]
10pub struct Gamma {
11 pub shape: Array1<f64>, pub rate: Array1<f64>, _params: Array2<f64>,
14}
15
16impl Distribution for Gamma {
17 fn from_params(params: &Array2<f64>) -> Self {
18 let shape = params.column(0).mapv(f64::exp);
19 let rate = params.column(1).mapv(f64::exp);
20 Gamma {
21 shape,
22 rate,
23 _params: params.clone(),
24 }
25 }
26
27 fn fit(y: &Array1<f64>) -> Array1<f64> {
28 let mean = y.mean().unwrap_or(1.0);
31 let var = y.var(0.0);
32 let shape = mean * mean / var.max(1e-9);
33 let scale = var / mean.max(1e-9);
34 let rate: f64 = 1.0 / scale;
35 array![shape.ln(), rate.ln()]
36 }
37
38 fn n_params(&self) -> usize {
39 2
40 }
41
42 fn predict(&self) -> Array1<f64> {
43 &self.shape / &self.rate
45 }
46
47 fn params(&self) -> &Array2<f64> {
48 &self._params
49 }
50}
51
// `Gamma` implements the crate's `RegressionDistn` trait without overriding
// any items. NOTE(review): presumably this is what allows it to be used as a
// regression target distribution — confirm against the trait definition.
impl RegressionDistn for Gamma {}
53
54impl DistributionMethods for Gamma {
55 fn mean(&self) -> Array1<f64> {
56 &self.shape / &self.rate
58 }
59
60 fn variance(&self) -> Array1<f64> {
61 &self.shape / (&self.rate * &self.rate)
63 }
64
65 fn std(&self) -> Array1<f64> {
66 self.variance().mapv(f64::sqrt)
67 }
68
69 fn pdf(&self, y: &Array1<f64>) -> Array1<f64> {
70 let mut result = Array1::zeros(y.len());
71 for i in 0..y.len() {
72 if let Ok(d) = GammaDist::new(self.shape[i], self.rate[i]) {
73 result[i] = d.pdf(y[i]);
74 }
75 }
76 result
77 }
78
79 fn logpdf(&self, y: &Array1<f64>) -> Array1<f64> {
80 let mut result = Array1::zeros(y.len());
81 for i in 0..y.len() {
82 if let Ok(d) = GammaDist::new(self.shape[i], self.rate[i]) {
83 result[i] = d.ln_pdf(y[i]);
84 }
85 }
86 result
87 }
88
89 fn cdf(&self, y: &Array1<f64>) -> Array1<f64> {
90 let mut result = Array1::zeros(y.len());
91 for i in 0..y.len() {
92 if let Ok(d) = GammaDist::new(self.shape[i], self.rate[i]) {
93 result[i] = d.cdf(y[i]);
94 }
95 }
96 result
97 }
98
99 fn ppf(&self, q: &Array1<f64>) -> Array1<f64> {
100 let mut result = Array1::zeros(q.len());
101 for i in 0..q.len() {
102 if let Ok(d) = GammaDist::new(self.shape[i], self.rate[i]) {
103 let q_clamped = q[i].clamp(1e-15, 1.0 - 1e-15);
104 result[i] = d.inverse_cdf(q_clamped);
105 }
106 }
107 result
108 }
109
110 fn sample(&self, n_samples: usize) -> Array2<f64> {
111 let n_obs = self.shape.len();
112 let mut samples = Array2::zeros((n_samples, n_obs));
113 let mut rng = rand::rng();
114
115 for i in 0..n_obs {
116 if let Ok(d) = GammaDist::new(self.shape[i], self.rate[i]) {
117 for s in 0..n_samples {
118 let u: f64 = rng.random();
119 samples[[s, i]] = d.inverse_cdf(u);
120 }
121 }
122 }
123 samples
124 }
125
126 fn median(&self) -> Array1<f64> {
127 let q = Array1::from_elem(self.shape.len(), 0.5);
129 self.ppf(&q)
130 }
131
132 fn mode(&self) -> Array1<f64> {
133 let mut result = Array1::zeros(self.shape.len());
135 for i in 0..self.shape.len() {
136 if self.shape[i] >= 1.0 {
137 result[i] = (self.shape[i] - 1.0) / self.rate[i];
138 }
139 }
140 result
141 }
142}
143
144impl Scorable<LogScore> for Gamma {
145 fn score(&self, y: &Array1<f64>) -> Array1<f64> {
146 let mut scores = Array1::zeros(y.len());
147 for (i, &y_i) in y.iter().enumerate() {
148 let d = GammaDist::new(self.shape[i], self.rate[i]).unwrap();
149 scores[i] = -d.ln_pdf(y_i);
150 }
151 scores
152 }
153
154 fn d_score(&self, y: &Array1<f64>) -> Array2<f64> {
155 let n_obs = y.len();
156 let mut d_params = Array2::zeros((n_obs, 2));
157
158 for i in 0..n_obs {
159 let shape_i = self.shape[i];
160 let rate_i = self.rate[i];
161
162 let d_log_shape = shape_i * (digamma(shape_i) - (y[i] * rate_i).max(1e-9).ln());
164 d_params[[i, 0]] = d_log_shape;
165
166 let d_log_rate = y[i] * rate_i - shape_i;
168 d_params[[i, 1]] = d_log_rate;
169 }
170
171 d_params
172 }
173
174 fn metric(&self) -> Array3<f64> {
175 let n_obs = self.shape.len();
176 let mut fi = Array3::zeros((n_obs, 2, 2));
177
178 for i in 0..n_obs {
179 let shape_i = self.shape[i];
180
181 fi[[i, 0, 0]] = shape_i * shape_i * trigamma(shape_i);
183 fi[[i, 1, 1]] = shape_i;
184 fi[[i, 0, 1]] = -shape_i;
185 fi[[i, 1, 0]] = -shape_i;
186 }
187
188 fi
189 }
190}
191
192impl Scorable<CRPScore> for Gamma {
193 fn score(&self, y: &Array1<f64>) -> Array1<f64> {
194 let mut scores = Array1::zeros(y.len());
200
201 for i in 0..y.len() {
202 let shape = self.shape[i];
203 let rate = self.rate[i];
204 let y_i = y[i];
205
206 let f_y = if let Ok(d) = GammaDist::new(shape, rate) {
208 d.cdf(y_i)
209 } else {
210 0.5
211 };
212
213 let f_alpha1_y = if let Ok(d) = GammaDist::new(shape + 1.0, rate) {
215 d.cdf(y_i)
216 } else {
217 0.5
218 };
219
220 let beta_term = beta(0.5, shape);
224
225 let mean = shape / rate;
227 scores[i] = y_i * (2.0 * f_y - 1.0) - mean * (2.0 * f_alpha1_y - 1.0)
228 + mean / (std::f64::consts::PI.sqrt() * beta_term);
229 }
230 scores
231 }
232
233 fn d_score(&self, y: &Array1<f64>) -> Array2<f64> {
234 let n_obs = y.len();
236 let mut d_params = Array2::zeros((n_obs, 2));
237 let eps = 1e-6;
238
239 for i in 0..n_obs {
240 let shape_i = self.shape[i];
241 let rate_i = self.rate[i];
242 let y_i = y[i];
243
244 let score_center = self.crps_single(y_i, shape_i, rate_i);
246
247 let shape_plus = shape_i * (1.0 + eps);
249 let score_shape_plus = self.crps_single(y_i, shape_plus, rate_i);
250 d_params[[i, 0]] = (score_shape_plus - score_center) / (shape_i * eps);
251
252 let rate_plus = rate_i * (1.0 + eps);
254 let score_rate_plus = self.crps_single(y_i, shape_i, rate_plus);
255 d_params[[i, 1]] = (score_rate_plus - score_center) / (rate_i * eps);
256 }
257
258 d_params
259 }
260
261 fn metric(&self) -> Array3<f64> {
262 let n_obs = self.shape.len();
264 let mut fi = Array3::zeros((n_obs, 2, 2));
265
266 for i in 0..n_obs {
267 let mean = self.shape[i] / self.rate[i];
269 fi[[i, 0, 0]] = mean;
270 fi[[i, 1, 1]] = mean;
271 }
272
273 fi
274 }
275}
276
277impl Gamma {
278 fn crps_single(&self, y: f64, shape: f64, rate: f64) -> f64 {
280 let f_y = if let Ok(d) = GammaDist::new(shape, rate) {
282 d.cdf(y)
283 } else {
284 0.5
285 };
286
287 let f_alpha1_y = if let Ok(d) = GammaDist::new(shape + 1.0, rate) {
289 d.cdf(y)
290 } else {
291 0.5
292 };
293
294 let beta_term = beta(0.5, shape);
295 let mean = shape / rate;
296
297 y * (2.0 * f_y - 1.0) - mean * (2.0 * f_alpha1_y - 1.0)
298 + mean / (std::f64::consts::PI.sqrt() * beta_term)
299 }
300}
301
/// Trigamma function `ψ₁(x)`, the second derivative of `ln Γ(x)`.
///
/// * `x > 0`: the recurrence `ψ₁(x) = ψ₁(x + 1) + 1/x²` shifts `x` to at
///   least 10, then the asymptotic series
///   `1/x + 1/(2x²) + 1/(6x³) - 1/(30x⁵) + 1/(42x⁷)` is added.
/// * `x <= 0`: non-positive integers are poles and return `+inf`. This
///   includes every finite `x` with `|x| >= 2^52`, which `f64` can only hold
///   as an integer. Other negative values use the reflection formula
///   `ψ₁(1 - x) + ψ₁(x) = π² / sin²(πx)`, instead of stepping the recurrence
///   `|x|` times, which never terminates once `x + 1 == x`.
/// * `NaN` and `-inf` return `NaN`; `+inf` returns 0.
fn trigamma(x: f64) -> f64 {
    use std::f64::consts::PI;

    if x.is_nan() || x == f64::NEG_INFINITY {
        return f64::NAN;
    }
    if x <= 0.0 {
        if x == x.floor() {
            return f64::INFINITY;
        }
        let s = (PI * x).sin();
        // 1 - x > 1, so this recursion lands in the positive branch.
        return PI * PI / (s * s) - trigamma(1.0 - x);
    }

    let mut x = x;
    let mut result = 0.0;

    // Shift the argument up until the asymptotic series is accurate.
    while x < 10.0 {
        result += 1.0 / (x * x);
        x += 1.0;
    }

    let x2 = x * x;
    let x3 = x2 * x;
    let x5 = x2 * x3;
    let x7 = x2 * x5;

    result += 1.0 / x + 0.5 / x2 + 1.0 / (6.0 * x3) - 1.0 / (30.0 * x5) + 1.0 / (42.0 * x7);

    result
}
324
325fn beta(a: f64, b: f64) -> f64 {
327 use statrs::function::gamma::ln_gamma;
328 (ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b)).exp()
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds a single-observation `Gamma` from `ln(shape)` and `ln(rate)`.
    fn dist_from_log_params(log_shape: f64, log_rate: f64) -> Gamma {
        let params = Array2::from_shape_vec((1, 2), vec![log_shape, log_rate]).unwrap();
        Gamma::from_params(&params)
    }

    #[test]
    fn test_gamma_distribution_methods() {
        // shape = 2, rate = 1: mean 2, variance 2, mode 1.
        let dist = dist_from_log_params(2.0_f64.ln(), 0.0);

        let mean = dist.mean();
        assert_relative_eq!(mean[0], 2.0, epsilon = 1e-10);

        let var = dist.variance();
        assert_relative_eq!(var[0], 2.0, epsilon = 1e-10);

        let mode = dist.mode();
        assert_relative_eq!(mode[0], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gamma_cdf_ppf() {
        // shape = 1, rate = 1 is the unit exponential.
        let dist = dist_from_log_params(1.0_f64.ln(), 0.0);

        let y = Array1::from_vec(vec![1.0]);
        let cdf = dist.cdf(&y);
        assert_relative_eq!(cdf[0], 1.0 - (-1.0_f64).exp(), epsilon = 1e-6);

        let q = Array1::from_vec(vec![0.5]);
        let ppf = dist.ppf(&q);
        let cdf_of_ppf = dist.cdf(&ppf);
        assert_relative_eq!(cdf_of_ppf[0], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_gamma_sample() {
        // shape = 2, rate = 0.5: mean 4.
        let dist = dist_from_log_params(2.0_f64.ln(), 0.5_f64.ln());

        let samples = dist.sample(1000);
        assert_eq!(samples.shape(), &[1000, 1]);

        assert!(samples.iter().all(|&x| x >= 0.0));

        let sample_mean: f64 = samples.column(0).mean().unwrap();
        assert!((sample_mean - 4.0).abs() < 0.5);
    }

    #[test]
    fn test_gamma_fit() {
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let params = Gamma::fit(&y);
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_gamma_logscore() {
        let dist = dist_from_log_params(2.0_f64.ln(), 0.0);

        let y = Array1::from_vec(vec![2.0]);
        let score = Scorable::<LogScore>::score(&dist, &y);

        assert!(score[0].is_finite());
        assert!(score[0] > 0.0);
    }

    #[test]
    fn test_gamma_logscore_d_score_matches_finite_difference() {
        // The analytic gradient is w.r.t. (ln shape, ln rate); compare it with a
        // central difference of the score in those same coordinates.
        let base = [2.0_f64.ln(), 0.3];
        let y = Array1::from_vec(vec![1.7]);
        let h = 1e-6;

        let grad = Scorable::<LogScore>::d_score(&dist_from_log_params(base[0], base[1]), &y);

        for k in 0..2 {
            let mut up = base;
            up[k] += h;
            let mut down = base;
            down[k] -= h;

            let s_up = Scorable::<LogScore>::score(&dist_from_log_params(up[0], up[1]), &y)[0];
            let s_down =
                Scorable::<LogScore>::score(&dist_from_log_params(down[0], down[1]), &y)[0];
            let finite_diff = (s_up - s_down) / (2.0 * h);

            assert_relative_eq!(grad[[0, k]], finite_diff, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_gamma_crps() {
        let dist = dist_from_log_params(2.0_f64.ln(), 0.0);

        let y = Array1::from_vec(vec![2.0]);
        let score = Scorable::<CRPScore>::score(&dist, &y);

        assert!(score[0].is_finite());
        assert!(score[0] >= 0.0);
    }

    #[test]
    fn test_gamma_crps_d_score() {
        let dist = dist_from_log_params(2.0_f64.ln(), 0.0);

        let y = Array1::from_vec(vec![2.0]);
        let d_score = Scorable::<CRPScore>::d_score(&dist, &y);

        assert!(d_score[[0, 0]].is_finite());
        assert!(d_score[[0, 1]].is_finite());
    }

    #[test]
    fn test_trigamma() {
        assert_relative_eq!(
            trigamma(1.0),
            std::f64::consts::PI.powi(2) / 6.0,
            epsilon = 1e-6
        );

        assert_relative_eq!(
            trigamma(2.0),
            std::f64::consts::PI.powi(2) / 6.0 - 1.0,
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_beta_function() {
        assert_relative_eq!(beta(1.0, 1.0), 1.0, epsilon = 1e-10);

        assert_relative_eq!(beta(0.5, 0.5), std::f64::consts::PI, epsilon = 1e-10);
    }
}