// ngboost_rs/dist/gamma.rs

1use crate::dist::{Distribution, DistributionMethods, RegressionDistn};
2use crate::scores::{CRPScore, LogScore, Scorable};
3use ndarray::{array, Array1, Array2, Array3};
4use rand::prelude::*;
5use statrs::distribution::{Continuous, ContinuousCDF, Gamma as GammaDist};
6use statrs::function::gamma::digamma;
7
8/// The Gamma distribution.
9#[derive(Debug, Clone)]
10pub struct Gamma {
11    pub shape: Array1<f64>, // alpha
12    pub rate: Array1<f64>,  // beta
13    _params: Array2<f64>,
14}
15
16impl Distribution for Gamma {
17    fn from_params(params: &Array2<f64>) -> Self {
18        let shape = params.column(0).mapv(f64::exp);
19        let rate = params.column(1).mapv(f64::exp);
20        Gamma {
21            shape,
22            rate,
23            _params: params.clone(),
24        }
25    }
26
27    fn fit(y: &Array1<f64>) -> Array1<f64> {
28        // This is a simplification, MLE for Gamma is complex.
29        // Using method of moments.
30        let mean = y.mean().unwrap_or(1.0);
31        let var = y.var(0.0);
32        let shape = mean * mean / var.max(1e-9);
33        let scale = var / mean.max(1e-9);
34        let rate: f64 = 1.0 / scale;
35        array![shape.ln(), rate.ln()]
36    }
37
38    fn n_params(&self) -> usize {
39        2
40    }
41
42    fn predict(&self) -> Array1<f64> {
43        // Mean is shape / rate
44        &self.shape / &self.rate
45    }
46
47    fn params(&self) -> &Array2<f64> {
48        &self._params
49    }
50}
51
52impl RegressionDistn for Gamma {}
53
54impl DistributionMethods for Gamma {
55    fn mean(&self) -> Array1<f64> {
56        // Mean of Gamma is shape / rate
57        &self.shape / &self.rate
58    }
59
60    fn variance(&self) -> Array1<f64> {
61        // Variance of Gamma is shape / rate^2
62        &self.shape / (&self.rate * &self.rate)
63    }
64
65    fn std(&self) -> Array1<f64> {
66        self.variance().mapv(f64::sqrt)
67    }
68
69    fn pdf(&self, y: &Array1<f64>) -> Array1<f64> {
70        let mut result = Array1::zeros(y.len());
71        for i in 0..y.len() {
72            if let Ok(d) = GammaDist::new(self.shape[i], self.rate[i]) {
73                result[i] = d.pdf(y[i]);
74            }
75        }
76        result
77    }
78
79    fn logpdf(&self, y: &Array1<f64>) -> Array1<f64> {
80        let mut result = Array1::zeros(y.len());
81        for i in 0..y.len() {
82            if let Ok(d) = GammaDist::new(self.shape[i], self.rate[i]) {
83                result[i] = d.ln_pdf(y[i]);
84            }
85        }
86        result
87    }
88
89    fn cdf(&self, y: &Array1<f64>) -> Array1<f64> {
90        let mut result = Array1::zeros(y.len());
91        for i in 0..y.len() {
92            if let Ok(d) = GammaDist::new(self.shape[i], self.rate[i]) {
93                result[i] = d.cdf(y[i]);
94            }
95        }
96        result
97    }
98
99    fn ppf(&self, q: &Array1<f64>) -> Array1<f64> {
100        let mut result = Array1::zeros(q.len());
101        for i in 0..q.len() {
102            if let Ok(d) = GammaDist::new(self.shape[i], self.rate[i]) {
103                let q_clamped = q[i].clamp(1e-15, 1.0 - 1e-15);
104                result[i] = d.inverse_cdf(q_clamped);
105            }
106        }
107        result
108    }
109
110    fn sample(&self, n_samples: usize) -> Array2<f64> {
111        let n_obs = self.shape.len();
112        let mut samples = Array2::zeros((n_samples, n_obs));
113        let mut rng = rand::rng();
114
115        for i in 0..n_obs {
116            if let Ok(d) = GammaDist::new(self.shape[i], self.rate[i]) {
117                for s in 0..n_samples {
118                    let u: f64 = rng.random();
119                    samples[[s, i]] = d.inverse_cdf(u);
120                }
121            }
122        }
123        samples
124    }
125
126    fn median(&self) -> Array1<f64> {
127        // For Gamma, median has no closed form. Use ppf(0.5)
128        let q = Array1::from_elem(self.shape.len(), 0.5);
129        self.ppf(&q)
130    }
131
132    fn mode(&self) -> Array1<f64> {
133        // Mode of Gamma is (shape - 1) / rate for shape >= 1, else 0
134        let mut result = Array1::zeros(self.shape.len());
135        for i in 0..self.shape.len() {
136            if self.shape[i] >= 1.0 {
137                result[i] = (self.shape[i] - 1.0) / self.rate[i];
138            }
139        }
140        result
141    }
142}
143
144impl Scorable<LogScore> for Gamma {
145    fn score(&self, y: &Array1<f64>) -> Array1<f64> {
146        let mut scores = Array1::zeros(y.len());
147        for (i, &y_i) in y.iter().enumerate() {
148            let d = GammaDist::new(self.shape[i], self.rate[i]).unwrap();
149            scores[i] = -d.ln_pdf(y_i);
150        }
151        scores
152    }
153
154    fn d_score(&self, y: &Array1<f64>) -> Array2<f64> {
155        let n_obs = y.len();
156        let mut d_params = Array2::zeros((n_obs, 2));
157
158        for i in 0..n_obs {
159            let shape_i = self.shape[i];
160            let rate_i = self.rate[i];
161
162            // d/d(log(shape))
163            let d_log_shape = shape_i * (digamma(shape_i) - (y[i] * rate_i).max(1e-9).ln());
164            d_params[[i, 0]] = d_log_shape;
165
166            // d/d(log(rate))
167            let d_log_rate = y[i] * rate_i - shape_i;
168            d_params[[i, 1]] = d_log_rate;
169        }
170
171        d_params
172    }
173
174    fn metric(&self) -> Array3<f64> {
175        let n_obs = self.shape.len();
176        let mut fi = Array3::zeros((n_obs, 2, 2));
177
178        for i in 0..n_obs {
179            let shape_i = self.shape[i];
180
181            // We use our local helper function for trigamma
182            fi[[i, 0, 0]] = shape_i * shape_i * trigamma(shape_i);
183            fi[[i, 1, 1]] = shape_i;
184            fi[[i, 0, 1]] = -shape_i;
185            fi[[i, 1, 0]] = -shape_i;
186        }
187
188        fi
189    }
190}
191
192impl Scorable<CRPScore> for Gamma {
193    fn score(&self, y: &Array1<f64>) -> Array1<f64> {
194        // CRPS for Gamma distribution
195        // Based on: Gneiting, T. and Raftery, A.E. (2007)
196        // CRPS(F, y) = y * (2*F(y) - 1) - shape/rate * (2*F_alpha+1(y) - 1)
197        //              + shape / (rate * Beta(0.5, shape))
198        // where F is the CDF and F_alpha+1 is the CDF of Gamma(shape+1, rate)
199        let mut scores = Array1::zeros(y.len());
200
201        for i in 0..y.len() {
202            let shape = self.shape[i];
203            let rate = self.rate[i];
204            let y_i = y[i];
205
206            // CDF of Gamma(shape, rate) at y
207            let f_y = if let Ok(d) = GammaDist::new(shape, rate) {
208                d.cdf(y_i)
209            } else {
210                0.5
211            };
212
213            // CDF of Gamma(shape+1, rate) at y
214            let f_alpha1_y = if let Ok(d) = GammaDist::new(shape + 1.0, rate) {
215                d.cdf(y_i)
216            } else {
217                0.5
218            };
219
220            // Beta(0.5, shape) term using gamma functions
221            // Beta(0.5, a) = Gamma(0.5) * Gamma(a) / Gamma(a + 0.5)
222            // = sqrt(pi) * Gamma(a) / Gamma(a + 0.5)
223            let beta_term = beta(0.5, shape);
224
225            // CRPS formula
226            let mean = shape / rate;
227            scores[i] = y_i * (2.0 * f_y - 1.0) - mean * (2.0 * f_alpha1_y - 1.0)
228                + mean / (std::f64::consts::PI.sqrt() * beta_term);
229        }
230        scores
231    }
232
233    fn d_score(&self, y: &Array1<f64>) -> Array2<f64> {
234        // Numerical gradient for CRPS (analytical form is complex)
235        let n_obs = y.len();
236        let mut d_params = Array2::zeros((n_obs, 2));
237        let eps = 1e-6;
238
239        for i in 0..n_obs {
240            let shape_i = self.shape[i];
241            let rate_i = self.rate[i];
242            let y_i = y[i];
243
244            // Compute score at current params
245            let score_center = self.crps_single(y_i, shape_i, rate_i);
246
247            // Derivative w.r.t. log(shape) via finite difference
248            let shape_plus = shape_i * (1.0 + eps);
249            let score_shape_plus = self.crps_single(y_i, shape_plus, rate_i);
250            d_params[[i, 0]] = (score_shape_plus - score_center) / (shape_i * eps);
251
252            // Derivative w.r.t. log(rate) via finite difference
253            let rate_plus = rate_i * (1.0 + eps);
254            let score_rate_plus = self.crps_single(y_i, shape_i, rate_plus);
255            d_params[[i, 1]] = (score_rate_plus - score_center) / (rate_i * eps);
256        }
257
258        d_params
259    }
260
261    fn metric(&self) -> Array3<f64> {
262        // Use identity matrix scaled by estimated variance as a simple metric
263        let n_obs = self.shape.len();
264        let mut fi = Array3::zeros((n_obs, 2, 2));
265
266        for i in 0..n_obs {
267            // Use a simple diagonal metric
268            let mean = self.shape[i] / self.rate[i];
269            fi[[i, 0, 0]] = mean;
270            fi[[i, 1, 1]] = mean;
271        }
272
273        fi
274    }
275}
276
277impl Gamma {
278    /// Helper function to compute CRPS for a single observation.
279    fn crps_single(&self, y: f64, shape: f64, rate: f64) -> f64 {
280        // CDF of Gamma(shape, rate) at y
281        let f_y = if let Ok(d) = GammaDist::new(shape, rate) {
282            d.cdf(y)
283        } else {
284            0.5
285        };
286
287        // CDF of Gamma(shape+1, rate) at y
288        let f_alpha1_y = if let Ok(d) = GammaDist::new(shape + 1.0, rate) {
289            d.cdf(y)
290        } else {
291            0.5
292        };
293
294        let beta_term = beta(0.5, shape);
295        let mean = shape / rate;
296
297        y * (2.0 * f_y - 1.0) - mean * (2.0 * f_alpha1_y - 1.0)
298            + mean / (std::f64::consts::PI.sqrt() * beta_term)
299    }
300}
301
/// Trigamma function: the second derivative of `ln Gamma(x)`.
///
/// * `x > 0`: the recurrence `trigamma(x) = trigamma(x + 1) + 1 / x^2` shifts
///   the argument to at least 10, then the asymptotic series
///   `1/x + 1/(2x^2) + 1/(6x^3) - 1/(30x^5) + 1/(42x^7) - 1/(30x^9) + 5/(66x^11)`
///   is evaluated. The first omitted term is below `3e-14` at `x = 10`.
/// * `x <= 0`: returns `+inf` at the poles (zero and the negative integers);
///   otherwise uses the reflection formula
///   `trigamma(x) = pi^2 / sin^2(pi x) - trigamma(1 - x)`, so the cost does not
///   grow with `|x|` and the function always terminates.
/// * `NaN` and `-inf` return `NaN`; `+inf` returns `0`.
fn trigamma(x: f64) -> f64 {
    if x.is_nan() || x == f64::NEG_INFINITY {
        return f64::NAN;
    }

    if x <= 0.0 {
        // Every float with |x| >= 2^52 is an integer, so huge negative inputs
        // land here instead of spinning in the recurrence loop.
        if x == x.floor() {
            return f64::INFINITY;
        }
        let pi = std::f64::consts::PI;
        let s = (pi * x).sin();
        return pi * pi / (s * s) - trigamma(1.0 - x);
    }

    let mut x = x;
    let mut result = 0.0;

    // Shift the argument upward until the asymptotic series is accurate.
    while x < 10.0 {
        result += 1.0 / (x * x);
        x += 1.0;
    }

    // Asymptotic series in Horner form on 1/x and 1/x^2.
    let inv = 1.0 / x;
    let inv2 = inv * inv;
    let tail = 1.0 / 6.0
        + inv2 * (-1.0 / 30.0 + inv2 * (1.0 / 42.0 + inv2 * (-1.0 / 30.0 + inv2 * (5.0 / 66.0))));

    result + inv * (1.0 + inv * (0.5 + inv * tail))
}
324
325/// Beta function B(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b)
326fn beta(a: f64, b: f64) -> f64 {
327    use statrs::function::gamma::ln_gamma;
328    (ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b)).exp()
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds a single-observation `Gamma` from natural (not log) parameters.
    fn single(shape: f64, rate: f64) -> Gamma {
        let params = Array2::from_shape_vec((1, 2), vec![shape.ln(), rate.ln()]).unwrap();
        Gamma::from_params(&params)
    }

    #[test]
    fn test_gamma_distribution_methods() {
        // shape=2, rate=1 -> mean=2, var=2, mode=1
        let dist = single(2.0, 1.0);

        // mean: shape / rate
        assert_relative_eq!(dist.mean()[0], 2.0, epsilon = 1e-10);

        // variance: shape / rate^2
        assert_relative_eq!(dist.variance()[0], 2.0, epsilon = 1e-10);

        // mode: (shape - 1) / rate
        assert_relative_eq!(dist.mode()[0], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gamma_cdf_ppf() {
        // Gamma(1, 1) is Exponential(1), so CDF(1) = 1 - exp(-1) ≈ 0.632
        let dist = single(1.0, 1.0);

        let y = Array1::from_vec(vec![1.0]);
        let cdf = dist.cdf(&y);
        assert_relative_eq!(cdf[0], 1.0 - (-1.0_f64).exp(), epsilon = 1e-6);

        // ppf followed by cdf must give the level back
        let q = Array1::from_vec(vec![0.5]);
        let ppf = dist.ppf(&q);
        let cdf_of_ppf = dist.cdf(&ppf);
        assert_relative_eq!(cdf_of_ppf[0], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_gamma_sample() {
        let dist = single(2.0, 0.5);

        let samples = dist.sample(1000);
        assert_eq!(samples.shape(), &[1000, 1]);

        // Gamma support is non-negative
        assert!(samples.iter().all(|&x| x >= 0.0));

        // Sample mean should be near shape / rate = 2 / 0.5 = 4
        let sample_mean: f64 = samples.column(0).mean().unwrap();
        assert!((sample_mean - 4.0).abs() < 0.5);
    }

    #[test]
    fn test_gamma_fit() {
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let params = Gamma::fit(&y);
        assert_eq!(params.len(), 2);

        // mean = 3, population variance = 2
        //   shape = mean^2 / var = 4.5
        //   rate  = mean / var   = 1.5
        // fit returns the logs of both.
        assert_relative_eq!(params[0], 4.5_f64.ln(), epsilon = 1e-10);
        assert_relative_eq!(params[1], 1.5_f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_gamma_logscore() {
        let dist = single(2.0, 1.0);

        let y = Array1::from_vec(vec![2.0]);
        let score = Scorable::<LogScore>::score(&dist, &y);

        assert!(score[0].is_finite());
        assert!(score[0] > 0.0);

        // Gamma(2, 1) has pdf y * exp(-y), so -ln pdf(2) = 2 - ln 2
        assert_relative_eq!(score[0], 2.0 - 2.0_f64.ln(), epsilon = 1e-9);
    }

    #[test]
    fn test_gamma_crps() {
        let dist = single(2.0, 1.0);

        let y = Array1::from_vec(vec![2.0]);
        let score = Scorable::<CRPScore>::score(&dist, &y);

        // CRPS is finite and non-negative
        assert!(score[0].is_finite());
        assert!(score[0] >= 0.0);
    }

    #[test]
    fn test_gamma_crps_d_score() {
        let dist = single(2.0, 1.0);

        let y = Array1::from_vec(vec![2.0]);
        let d_score = Scorable::<CRPScore>::d_score(&dist, &y);

        // Both gradient components must be finite
        assert!(d_score[[0, 0]].is_finite());
        assert!(d_score[[0, 1]].is_finite());
    }

    #[test]
    fn test_trigamma() {
        // trigamma(1) = pi^2 / 6 ≈ 1.6449
        assert_relative_eq!(
            trigamma(1.0),
            std::f64::consts::PI.powi(2) / 6.0,
            epsilon = 1e-6
        );

        // trigamma(2) = pi^2 / 6 - 1 ≈ 0.6449
        assert_relative_eq!(
            trigamma(2.0),
            std::f64::consts::PI.powi(2) / 6.0 - 1.0,
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_beta_function() {
        // Beta(1, 1) = 1
        assert_relative_eq!(beta(1.0, 1.0), 1.0, epsilon = 1e-10);

        // Beta(0.5, 0.5) = pi
        assert_relative_eq!(beta(0.5, 0.5), std::f64::consts::PI, epsilon = 1e-10);
    }
}