Skip to main content

ferray_random/distributions/
gamma.rs

// ferray-random: Gamma family distributions — standard_gamma, gamma, beta, chisquare, f,
// student_t / standard_t, noncentral_chisquare, noncentral_f
//
// Gamma sampling uses Marsaglia & Tsang's method (shape >= 1) with
// the Ahrens-Dieter transformation for shape < 1.
5
6use ferray_core::{Array, FerrayError, IxDyn};
7
8use crate::bitgen::BitGenerator;
9use crate::distributions::normal::standard_normal_single;
10use crate::generator::{Generator, generate_vec, shape_size, vec_to_array_f64};
11use crate::shape::IntoShape;
12
13/// Generate a single standard gamma variate with shape parameter `alpha`.
14///
15/// Uses Marsaglia & Tsang's method for alpha >= 1, and
16/// the Ahrens-Dieter boost for alpha < 1.
17pub(crate) fn standard_gamma_single<B: BitGenerator>(bg: &mut B, alpha: f64) -> f64 {
18    if alpha < 1.0 {
19        // Ahrens-Dieter: if X ~ Gamma(alpha+1), then X * U^(1/alpha) ~ Gamma(alpha)
20        if alpha <= 0.0 {
21            return 0.0;
22        }
23        loop {
24            let u = bg.next_f64();
25            if u > f64::EPSILON {
26                let x = standard_gamma_ge1(bg, alpha + 1.0);
27                return x * u.powf(1.0 / alpha);
28            }
29        }
30    } else {
31        standard_gamma_ge1(bg, alpha)
32    }
33}
34
35/// Marsaglia & Tsang's method for Gamma(alpha) with alpha >= 1.
36fn standard_gamma_ge1<B: BitGenerator>(bg: &mut B, alpha: f64) -> f64 {
37    let d = alpha - 1.0 / 3.0;
38    let c = 1.0 / (9.0 * d).sqrt();
39
40    loop {
41        let x = standard_normal_single(bg);
42        let v_base = 1.0 + c * x;
43        if v_base <= 0.0 {
44            continue;
45        }
46        let v = v_base * v_base * v_base;
47        let u = bg.next_f64();
48        // Squeeze test
49        if u < (0.0331 * (x * x)).mul_add(-(x * x), 1.0) {
50            return d * v;
51        }
52        if u.ln() < (0.5 * x).mul_add(x, d * (1.0 - v + v.ln())) {
53            return d * v;
54        }
55    }
56}
57
58impl<B: BitGenerator> Generator<B> {
59    /// Generate an array of standard gamma variates with shape parameter `alpha`.
60    ///
61    /// # Errors
62    /// Returns `FerrayError::InvalidValue` if `alpha <= 0` or `size` is invalid.
63    pub fn standard_gamma(
64        &mut self,
65        alpha: f64,
66        size: impl IntoShape,
67    ) -> Result<Array<f64, IxDyn>, FerrayError> {
68        if alpha <= 0.0 {
69            return Err(FerrayError::invalid_value(format!(
70                "alpha must be positive, got {alpha}"
71            )));
72        }
73        let shape_vec = size.into_shape()?;
74        let n = shape_size(&shape_vec);
75        let data = generate_vec(self, n, |bg| standard_gamma_single(bg, alpha));
76        vec_to_array_f64(data, &shape_vec)
77    }
78
79    /// Generate an array of gamma-distributed variates.
80    ///
81    /// The gamma distribution with shape `alpha` and scale `scale` has
82    /// PDF: f(x) = x^(alpha-1) * exp(-x/scale) / (scale^alpha * Gamma(alpha)).
83    ///
84    /// # Errors
85    /// Returns `FerrayError::InvalidValue` if `alpha <= 0`, `scale <= 0`, or `size` is invalid.
86    pub fn gamma(
87        &mut self,
88        alpha: f64,
89        scale: f64,
90        size: impl IntoShape,
91    ) -> Result<Array<f64, IxDyn>, FerrayError> {
92        if alpha <= 0.0 {
93            return Err(FerrayError::invalid_value(format!(
94                "alpha must be positive, got {alpha}"
95            )));
96        }
97        if scale <= 0.0 {
98            return Err(FerrayError::invalid_value(format!(
99                "scale must be positive, got {scale}"
100            )));
101        }
102        let shape_vec = size.into_shape()?;
103        let n = shape_size(&shape_vec);
104        let data = generate_vec(self, n, |bg| scale * standard_gamma_single(bg, alpha));
105        vec_to_array_f64(data, &shape_vec)
106    }
107
108    /// Generate an array of beta-distributed variates in (0, 1).
109    ///
110    /// Uses the relationship: if X ~ Gamma(a), Y ~ Gamma(b), then X/(X+Y) ~ Beta(a,b).
111    ///
112    /// # Errors
113    /// Returns `FerrayError::InvalidValue` if `a <= 0`, `b <= 0`, or `size` is invalid.
114    pub fn beta(
115        &mut self,
116        a: f64,
117        b: f64,
118        size: impl IntoShape,
119    ) -> Result<Array<f64, IxDyn>, FerrayError> {
120        if a <= 0.0 {
121            return Err(FerrayError::invalid_value(format!(
122                "a must be positive, got {a}"
123            )));
124        }
125        if b <= 0.0 {
126            return Err(FerrayError::invalid_value(format!(
127                "b must be positive, got {b}"
128            )));
129        }
130        let shape_vec = size.into_shape()?;
131        let n = shape_size(&shape_vec);
132        let data = generate_vec(self, n, |bg| {
133            let x = standard_gamma_single(bg, a);
134            let y = standard_gamma_single(bg, b);
135            if x + y == 0.0 {
136                0.5 // Degenerate case
137            } else {
138                x / (x + y)
139            }
140        });
141        vec_to_array_f64(data, &shape_vec)
142    }
143
144    /// Generate an array of chi-squared distributed variates.
145    ///
146    /// Chi-squared(df) = Gamma(df/2, 2).
147    ///
148    /// # Errors
149    /// Returns `FerrayError::InvalidValue` if `df <= 0` or `size` is invalid.
150    pub fn chisquare(
151        &mut self,
152        df: f64,
153        size: impl IntoShape,
154    ) -> Result<Array<f64, IxDyn>, FerrayError> {
155        if df <= 0.0 {
156            return Err(FerrayError::invalid_value(format!(
157                "df must be positive, got {df}"
158            )));
159        }
160        let shape_vec = size.into_shape()?;
161        let n = shape_size(&shape_vec);
162        let data = generate_vec(self, n, |bg| 2.0 * standard_gamma_single(bg, df / 2.0));
163        vec_to_array_f64(data, &shape_vec)
164    }
165
166    /// Generate an array of F-distributed variates.
167    ///
168    /// F(d1, d2) = (Chi2(d1)/d1) / (Chi2(d2)/d2).
169    ///
170    /// # Errors
171    /// Returns `FerrayError::InvalidValue` if either df is non-positive or `size` is invalid.
172    pub fn f(
173        &mut self,
174        dfnum: f64,
175        dfden: f64,
176        size: impl IntoShape,
177    ) -> Result<Array<f64, IxDyn>, FerrayError> {
178        if dfnum <= 0.0 {
179            return Err(FerrayError::invalid_value(format!(
180                "dfnum must be positive, got {dfnum}"
181            )));
182        }
183        if dfden <= 0.0 {
184            return Err(FerrayError::invalid_value(format!(
185                "dfden must be positive, got {dfden}"
186            )));
187        }
188        let shape_vec = size.into_shape()?;
189        let n = shape_size(&shape_vec);
190        let data = generate_vec(self, n, |bg| {
191            let x1 = standard_gamma_single(bg, dfnum / 2.0);
192            let x2 = standard_gamma_single(bg, dfden / 2.0);
193            if x2 == 0.0 {
194                f64::INFINITY
195            } else {
196                (x1 / dfnum) / (x2 / dfden)
197            }
198        });
199        vec_to_array_f64(data, &shape_vec)
200    }
201
202    /// Generate an array of Student's t-distributed variates.
203    ///
204    /// t(df) = Normal(0,1) / sqrt(Chi2(df)/df).
205    ///
206    /// # Errors
207    /// Returns `FerrayError::InvalidValue` if `df <= 0` or `size` is invalid.
208    pub fn student_t(
209        &mut self,
210        df: f64,
211        size: impl IntoShape,
212    ) -> Result<Array<f64, IxDyn>, FerrayError> {
213        if df <= 0.0 {
214            return Err(FerrayError::invalid_value(format!(
215                "df must be positive, got {df}"
216            )));
217        }
218        let shape_vec = size.into_shape()?;
219        let n = shape_size(&shape_vec);
220        let data = generate_vec(self, n, |bg| {
221            let z = standard_normal_single(bg);
222            let chi2 = 2.0 * standard_gamma_single(bg, df / 2.0);
223            z / (chi2 / df).sqrt()
224        });
225        vec_to_array_f64(data, &shape_vec)
226    }
227
228    /// Alias for [`student_t`](Self::student_t) using NumPy's `standard_t` spelling.
229    ///
230    /// # Errors
231    /// Returns `FerrayError::InvalidValue` if `df <= 0` or `size` is invalid.
232    pub fn standard_t(
233        &mut self,
234        df: f64,
235        size: impl IntoShape,
236    ) -> Result<Array<f64, IxDyn>, FerrayError> {
237        self.student_t(df, size)
238    }
239
240    /// Generate an array of noncentral chi-squared variates.
241    ///
242    /// `noncentral_chisquare(df, nonc)` is the distribution of
243    /// `sum((Z_i + mu_i)^2)` where `Z_i ~ N(0,1)` and the sum of the
244    /// `mu_i^2` equals `nonc`. Implemented via the Poisson-mixed central
245    /// chi-squared method: `N ~ Poisson(nonc/2)`, `X = 2 * Gamma((df + 2N)/2)`.
246    ///
247    /// Equivalent to `numpy.random.Generator.noncentral_chisquare`.
248    ///
249    /// # Errors
250    /// - `FerrayError::InvalidValue` if `df <= 0`, `nonc < 0`, or `size` is invalid.
251    pub fn noncentral_chisquare(
252        &mut self,
253        df: f64,
254        nonc: f64,
255        size: impl IntoShape,
256    ) -> Result<Array<f64, IxDyn>, FerrayError> {
257        if df <= 0.0 {
258            return Err(FerrayError::invalid_value(format!(
259                "df must be positive, got {df}"
260            )));
261        }
262        if nonc < 0.0 {
263            return Err(FerrayError::invalid_value(format!(
264                "nonc must be non-negative, got {nonc}"
265            )));
266        }
267        let shape_vec = size.into_shape()?;
268        let n = shape_size(&shape_vec);
269        let data = generate_vec(self, n, |bg| {
270            // Sample N ~ Poisson(nonc / 2) inline via Knuth (small lambda).
271            // For nonc/2 in moderate range this is fast; for large nonc the
272            // call still terminates quickly in practice.
273            let lam = nonc / 2.0;
274            let n_pois: u64 = if lam == 0.0 {
275                0
276            } else {
277                let l = (-lam).exp();
278                let mut k: u64 = 0;
279                let mut p = 1.0;
280                loop {
281                    k += 1;
282                    p *= bg.next_f64();
283                    if p <= l {
284                        break k - 1;
285                    }
286                }
287            };
288            let total_df = df + 2.0 * (n_pois as f64);
289            2.0 * standard_gamma_single(bg, total_df / 2.0)
290        });
291        vec_to_array_f64(data, &shape_vec)
292    }
293
294    /// Generate an array of noncentral F-distributed variates.
295    ///
296    /// `noncentral_f(dfnum, dfden, nonc) = (Chi2_nc(dfnum, nonc)/dfnum) /
297    /// (Chi2(dfden)/dfden)`.
298    ///
299    /// Equivalent to `numpy.random.Generator.noncentral_f`.
300    ///
301    /// # Errors
302    /// - `FerrayError::InvalidValue` if any df is non-positive, `nonc < 0`,
303    ///   or `size` is invalid.
304    pub fn noncentral_f(
305        &mut self,
306        dfnum: f64,
307        dfden: f64,
308        nonc: f64,
309        size: impl IntoShape,
310    ) -> Result<Array<f64, IxDyn>, FerrayError> {
311        if dfnum <= 0.0 {
312            return Err(FerrayError::invalid_value(format!(
313                "dfnum must be positive, got {dfnum}"
314            )));
315        }
316        if dfden <= 0.0 {
317            return Err(FerrayError::invalid_value(format!(
318                "dfden must be positive, got {dfden}"
319            )));
320        }
321        if nonc < 0.0 {
322            return Err(FerrayError::invalid_value(format!(
323                "nonc must be non-negative, got {nonc}"
324            )));
325        }
326        let shape_vec = size.into_shape()?;
327        let n = shape_size(&shape_vec);
328        let data = generate_vec(self, n, |bg| {
329            // Numerator: noncentral chi-squared sample (Poisson-mixed).
330            let lam = nonc / 2.0;
331            let n_pois: u64 = if lam == 0.0 {
332                0
333            } else {
334                let l = (-lam).exp();
335                let mut k: u64 = 0;
336                let mut p = 1.0;
337                loop {
338                    k += 1;
339                    p *= bg.next_f64();
340                    if p <= l {
341                        break k - 1;
342                    }
343                }
344            };
345            let total_dfnum = dfnum + 2.0 * (n_pois as f64);
346            let chi2_num = 2.0 * standard_gamma_single(bg, total_dfnum / 2.0);
347            let chi2_den = 2.0 * standard_gamma_single(bg, dfden / 2.0);
348            if chi2_den == 0.0 {
349                f64::INFINITY
350            } else {
351                (chi2_num / dfnum) / (chi2_den / dfden)
352            }
353        });
354        vec_to_array_f64(data, &shape_vec)
355    }
356}
357
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Sum of `xs` divided by the caller-supplied sample count.
    fn mean_over(xs: &[f64], count: f64) -> f64 {
        xs.iter().copied().sum::<f64>() / count
    }

    /// Panics unless every element of `values` is strictly positive.
    fn assert_all_positive(values: &[f64]) {
        for &v in values {
            assert!(v > 0.0);
        }
    }

    // Every gamma(2, 1) draw is strictly positive.
    #[test]
    fn gamma_positive() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.gamma(2.0, 1.0, 10_000).unwrap();
        assert_all_positive(draws.as_slice().unwrap());
    }

    // Sample mean within 3 standard errors and sample variance within 5%
    // of the theoretical Gamma(k, theta) moments.
    #[test]
    fn gamma_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let shape = 3.0;
        let scale = 2.0;
        let draws = rng.gamma(shape, scale, n).unwrap();
        let xs = draws.as_slice().unwrap();

        let mean: f64 = mean_over(xs, n as f64);
        let sq_dev_total = xs.iter().map(|&x| (x - mean).powi(2)).sum::<f64>();
        let var: f64 = sq_dev_total / n as f64;

        // Gamma(k, theta): mean = k*theta, var = k*theta^2
        let expected_mean = shape * scale;
        let expected_var = shape * scale * scale;
        let se = (expected_var / n as f64).sqrt();

        let mean_err = (mean - expected_mean).abs();
        assert!(
            mean_err < 3.0 * se,
            "gamma mean {mean} too far from {expected_mean}"
        );
        let var_rel_err = (var - expected_var).abs() / expected_var;
        assert!(
            var_rel_err < 0.05,
            "gamma variance {var} too far from {expected_var}"
        );
    }

    // Shape below 1 exercises the boosted sampling path.
    #[test]
    fn gamma_small_shape() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.gamma(0.5, 1.0, 10_000).unwrap();
        assert_all_positive(draws.as_slice().unwrap());
    }

    // Beta draws stay strictly inside the open unit interval.
    #[test]
    fn beta_in_range() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.beta(2.0, 5.0, 10_000).unwrap();
        for &v in draws.as_slice().unwrap() {
            assert!(v > 0.0 && v < 1.0, "beta value {v} out of (0,1)");
        }
    }

    // Sample mean within 3 standard errors of a / (a + b).
    #[test]
    fn beta_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let a = 2.0;
        let b = 5.0;
        let draws = rng.beta(a, b, n).unwrap();
        let mean: f64 = mean_over(draws.as_slice().unwrap(), n as f64);

        // Beta(a,b): mean = a/(a+b), var = ab / ((a+b)^2 (a+b+1))
        let expected_mean = a / (a + b);
        let expected_var = (a * b) / ((a + b).powi(2) * (a + b + 1.0));
        let se = (expected_var / n as f64).sqrt();

        let mean_err = (mean - expected_mean).abs();
        assert!(
            mean_err < 3.0 * se,
            "beta mean {mean} too far from {expected_mean}"
        );
    }

    // Every chi-squared(5) draw is strictly positive.
    #[test]
    fn chisquare_positive() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.chisquare(5.0, 10_000).unwrap();
        assert_all_positive(draws.as_slice().unwrap());
    }

    // Sample mean within 3 standard errors of df.
    #[test]
    fn chisquare_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let df = 10.0;
        let draws = rng.chisquare(df, n).unwrap();
        let mean: f64 = mean_over(draws.as_slice().unwrap(), n as f64);

        // Chi2(df): mean = df, var = 2*df
        let expected_var = 2.0 * df;
        let se = (expected_var / n as f64).sqrt();

        let mean_err = (mean - df).abs();
        assert!(
            mean_err < 3.0 * se,
            "chisquare mean {mean} too far from {df}"
        );
    }

    // Every F(5, 10) draw is strictly positive.
    #[test]
    fn f_positive() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.f(5.0, 10.0, 10_000).unwrap();
        assert_all_positive(draws.as_slice().unwrap());
    }

    // The t distribution is centred on zero, so the sample mean should be too.
    #[test]
    fn student_t_symmetric() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let df = 10.0;
        let draws = rng.student_t(df, n).unwrap();
        let mean: f64 = mean_over(draws.as_slice().unwrap(), n as f64);
        // t(df) with df > 1: mean = 0
        assert!(mean.abs() < 0.05, "student_t mean {mean} too far from 0");
    }

    // Standard gamma has mean = variance = shape.
    #[test]
    fn standard_gamma_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let shape = 5.0;
        let draws = rng.standard_gamma(shape, n).unwrap();
        let mean: f64 = mean_over(draws.as_slice().unwrap(), n as f64);
        let se = (shape / n as f64).sqrt();

        let mean_err = (mean - shape).abs();
        assert!(
            mean_err < 3.0 * se,
            "standard_gamma mean {mean} too far from {shape}"
        );
    }

    // Zero or negative shape, and zero scale, are all rejected.
    #[test]
    fn gamma_bad_params() {
        let mut rng = default_rng_seeded(42);
        let zero_shape = rng.gamma(0.0, 1.0, 100);
        assert!(zero_shape.is_err());
        let zero_scale = rng.gamma(1.0, 0.0, 100);
        assert!(zero_scale.is_err());
        let negative_shape = rng.gamma(-1.0, 1.0, 100);
        assert!(negative_shape.is_err());
    }

    // Identically seeded generators must agree across the two spellings.
    #[test]
    fn standard_t_alias_matches_student_t() {
        let mut rng_a = default_rng_seeded(7);
        let mut rng_b = default_rng_seeded(7);
        let a = rng_a.student_t(5.0, 100).unwrap();
        let b = rng_b.standard_t(5.0, 100).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn noncentral_chisquare_mean_approx() {
        // E[noncentral chi^2(df, lam)] = df + lam = 5 + 3 = 8.
        let mut rng = default_rng_seeded(42);
        let n = 50_000;
        let draws = rng.noncentral_chisquare(5.0, 3.0, n).unwrap();
        let mean: f64 = mean_over(draws.as_slice().unwrap(), n as f64);
        // Fixed absolute tolerance of 0.5 — deliberately loose for a Monte Carlo check.
        assert!((mean - 8.0).abs() < 0.5, "noncentral_chisquare mean {mean}");
    }

    #[test]
    fn noncentral_chisquare_zero_lambda_matches_chisquare() {
        let mut rng_a = default_rng_seeded(11);
        let mut rng_b = default_rng_seeded(11);
        let a = rng_a.noncentral_chisquare(4.0, 0.0, 1000).unwrap();
        let b = rng_b.chisquare(4.0, 1000).unwrap();
        // With nonc = 0 both paths reduce to 2 * Gamma(df/2) on the same RNG
        // sequence, so the draws should agree element by element.
        let noncentral = a.as_slice().unwrap();
        let central = b.as_slice().unwrap();
        for (x, y) in noncentral.iter().zip(central) {
            assert!((x - y).abs() < 1e-12);
        }
    }

    // Zero df and negative noncentrality are both rejected.
    #[test]
    fn noncentral_chisquare_bad_params() {
        let mut rng = default_rng_seeded(0);
        let zero_df = rng.noncentral_chisquare(0.0, 1.0, 10);
        assert!(zero_df.is_err());
        let negative_nonc = rng.noncentral_chisquare(1.0, -1.0, 10);
        assert!(negative_nonc.is_err());
    }

    // Noncentral F draws are never negative.
    #[test]
    fn noncentral_f_positive() {
        let mut rng = default_rng_seeded(100);
        let draws = rng.noncentral_f(5.0, 7.0, 2.0, 1000).unwrap();
        for &v in draws.as_slice().unwrap() {
            assert!(v >= 0.0);
        }
    }

    // Zero numerator df, zero denominator df and negative noncentrality
    // are each rejected.
    #[test]
    fn noncentral_f_bad_params() {
        let mut rng = default_rng_seeded(0);
        let zero_dfnum = rng.noncentral_f(0.0, 1.0, 1.0, 10);
        assert!(zero_dfnum.is_err());
        let zero_dfden = rng.noncentral_f(1.0, 0.0, 1.0, 10);
        assert!(zero_dfden.is_err());
        let negative_nonc = rng.noncentral_f(1.0, 1.0, -1.0, 10);
        assert!(negative_nonc.is_err());
    }
}