Skip to main content

ferray_random/distributions/
gamma.rs

// ferray-random: Gamma family distributions — gamma, beta, chisquare, f, student_t, standard_gamma
//
// Gamma sampling uses Marsaglia & Tsang's method (shape >= 1) together with the
// boosting identity Gamma(alpha) = Gamma(alpha + 1) * U^(1/alpha) for shape < 1.
5
6use ferray_core::{Array, FerrayError, IxDyn};
7
8use crate::bitgen::BitGenerator;
9use crate::distributions::normal::standard_normal_single;
10use crate::generator::{Generator, generate_vec, shape_size, vec_to_array_f64};
11use crate::shape::IntoShape;
12
13/// Generate a single standard gamma variate with shape parameter `alpha`.
14///
15/// Uses Marsaglia & Tsang's method for alpha >= 1, and
16/// the Ahrens-Dieter boost for alpha < 1.
17pub(crate) fn standard_gamma_single<B: BitGenerator>(bg: &mut B, alpha: f64) -> f64 {
18    if alpha < 1.0 {
19        // Ahrens-Dieter: if X ~ Gamma(alpha+1), then X * U^(1/alpha) ~ Gamma(alpha)
20        if alpha <= 0.0 {
21            return 0.0;
22        }
23        loop {
24            let u = bg.next_f64();
25            if u > f64::EPSILON {
26                let x = standard_gamma_ge1(bg, alpha + 1.0);
27                return x * u.powf(1.0 / alpha);
28            }
29        }
30    } else {
31        standard_gamma_ge1(bg, alpha)
32    }
33}
34
35/// Marsaglia & Tsang's method for Gamma(alpha) with alpha >= 1.
36fn standard_gamma_ge1<B: BitGenerator>(bg: &mut B, alpha: f64) -> f64 {
37    let d = alpha - 1.0 / 3.0;
38    let c = 1.0 / (9.0 * d).sqrt();
39
40    loop {
41        let x = standard_normal_single(bg);
42        let v_base = 1.0 + c * x;
43        if v_base <= 0.0 {
44            continue;
45        }
46        let v = v_base * v_base * v_base;
47        let u = bg.next_f64();
48        // Squeeze test
49        if u < 1.0 - 0.0331 * (x * x) * (x * x) {
50            return d * v;
51        }
52        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
53            return d * v;
54        }
55    }
56}
57
58impl<B: BitGenerator> Generator<B> {
59    /// Generate an array of standard gamma variates with shape parameter `alpha`.
60    ///
61    /// # Errors
62    /// Returns `FerrayError::InvalidValue` if `alpha <= 0` or `size` is invalid.
63    pub fn standard_gamma(
64        &mut self,
65        alpha: f64,
66        size: impl IntoShape,
67    ) -> Result<Array<f64, IxDyn>, FerrayError> {
68        if alpha <= 0.0 {
69            return Err(FerrayError::invalid_value(format!(
70                "alpha must be positive, got {alpha}"
71            )));
72        }
73        let shape_vec = size.into_shape()?;
74        let n = shape_size(&shape_vec);
75        let data = generate_vec(self, n, |bg| standard_gamma_single(bg, alpha));
76        vec_to_array_f64(data, &shape_vec)
77    }
78
79    /// Generate an array of gamma-distributed variates.
80    ///
81    /// The gamma distribution with shape `alpha` and scale `scale` has
82    /// PDF: f(x) = x^(alpha-1) * exp(-x/scale) / (scale^alpha * Gamma(alpha)).
83    ///
84    /// # Errors
85    /// Returns `FerrayError::InvalidValue` if `alpha <= 0`, `scale <= 0`, or `size` is invalid.
86    pub fn gamma(
87        &mut self,
88        alpha: f64,
89        scale: f64,
90        size: impl IntoShape,
91    ) -> Result<Array<f64, IxDyn>, FerrayError> {
92        if alpha <= 0.0 {
93            return Err(FerrayError::invalid_value(format!(
94                "alpha must be positive, got {alpha}"
95            )));
96        }
97        if scale <= 0.0 {
98            return Err(FerrayError::invalid_value(format!(
99                "scale must be positive, got {scale}"
100            )));
101        }
102        let shape_vec = size.into_shape()?;
103        let n = shape_size(&shape_vec);
104        let data = generate_vec(self, n, |bg| scale * standard_gamma_single(bg, alpha));
105        vec_to_array_f64(data, &shape_vec)
106    }
107
108    /// Generate an array of beta-distributed variates in (0, 1).
109    ///
110    /// Uses the relationship: if X ~ Gamma(a), Y ~ Gamma(b), then X/(X+Y) ~ Beta(a,b).
111    ///
112    /// # Errors
113    /// Returns `FerrayError::InvalidValue` if `a <= 0`, `b <= 0`, or `size` is invalid.
114    pub fn beta(
115        &mut self,
116        a: f64,
117        b: f64,
118        size: impl IntoShape,
119    ) -> Result<Array<f64, IxDyn>, FerrayError> {
120        if a <= 0.0 {
121            return Err(FerrayError::invalid_value(format!(
122                "a must be positive, got {a}"
123            )));
124        }
125        if b <= 0.0 {
126            return Err(FerrayError::invalid_value(format!(
127                "b must be positive, got {b}"
128            )));
129        }
130        let shape_vec = size.into_shape()?;
131        let n = shape_size(&shape_vec);
132        let data = generate_vec(self, n, |bg| {
133            let x = standard_gamma_single(bg, a);
134            let y = standard_gamma_single(bg, b);
135            if x + y == 0.0 {
136                0.5 // Degenerate case
137            } else {
138                x / (x + y)
139            }
140        });
141        vec_to_array_f64(data, &shape_vec)
142    }
143
144    /// Generate an array of chi-squared distributed variates.
145    ///
146    /// Chi-squared(df) = Gamma(df/2, 2).
147    ///
148    /// # Errors
149    /// Returns `FerrayError::InvalidValue` if `df <= 0` or `size` is invalid.
150    pub fn chisquare(
151        &mut self,
152        df: f64,
153        size: impl IntoShape,
154    ) -> Result<Array<f64, IxDyn>, FerrayError> {
155        if df <= 0.0 {
156            return Err(FerrayError::invalid_value(format!(
157                "df must be positive, got {df}"
158            )));
159        }
160        let shape_vec = size.into_shape()?;
161        let n = shape_size(&shape_vec);
162        let data = generate_vec(self, n, |bg| 2.0 * standard_gamma_single(bg, df / 2.0));
163        vec_to_array_f64(data, &shape_vec)
164    }
165
166    /// Generate an array of F-distributed variates.
167    ///
168    /// F(d1, d2) = (Chi2(d1)/d1) / (Chi2(d2)/d2).
169    ///
170    /// # Errors
171    /// Returns `FerrayError::InvalidValue` if either df is non-positive or `size` is invalid.
172    pub fn f(
173        &mut self,
174        dfnum: f64,
175        dfden: f64,
176        size: impl IntoShape,
177    ) -> Result<Array<f64, IxDyn>, FerrayError> {
178        if dfnum <= 0.0 {
179            return Err(FerrayError::invalid_value(format!(
180                "dfnum must be positive, got {dfnum}"
181            )));
182        }
183        if dfden <= 0.0 {
184            return Err(FerrayError::invalid_value(format!(
185                "dfden must be positive, got {dfden}"
186            )));
187        }
188        let shape_vec = size.into_shape()?;
189        let n = shape_size(&shape_vec);
190        let data = generate_vec(self, n, |bg| {
191            let x1 = standard_gamma_single(bg, dfnum / 2.0);
192            let x2 = standard_gamma_single(bg, dfden / 2.0);
193            if x2 == 0.0 {
194                f64::INFINITY
195            } else {
196                (x1 / dfnum) / (x2 / dfden)
197            }
198        });
199        vec_to_array_f64(data, &shape_vec)
200    }
201
202    /// Generate an array of Student's t-distributed variates.
203    ///
204    /// t(df) = Normal(0,1) / sqrt(Chi2(df)/df).
205    ///
206    /// # Errors
207    /// Returns `FerrayError::InvalidValue` if `df <= 0` or `size` is invalid.
208    pub fn student_t(
209        &mut self,
210        df: f64,
211        size: impl IntoShape,
212    ) -> Result<Array<f64, IxDyn>, FerrayError> {
213        if df <= 0.0 {
214            return Err(FerrayError::invalid_value(format!(
215                "df must be positive, got {df}"
216            )));
217        }
218        let shape_vec = size.into_shape()?;
219        let n = shape_size(&shape_vec);
220        let data = generate_vec(self, n, |bg| {
221            let z = standard_normal_single(bg);
222            let chi2 = 2.0 * standard_gamma_single(bg, df / 2.0);
223            z / (chi2 / df).sqrt()
224        });
225        vec_to_array_f64(data, &shape_vec)
226    }
227}
228
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Arithmetic mean of a sample, summed front to back.
    fn sample_mean(xs: &[f64]) -> f64 {
        xs.iter().sum::<f64>() / xs.len() as f64
    }

    /// Population (biased, divide-by-n) variance of a sample about `center`.
    fn sample_variance(xs: &[f64], center: f64) -> f64 {
        xs.iter().map(|&x| (x - center).powi(2)).sum::<f64>() / xs.len() as f64
    }

    /// Panic unless every element of `xs` is strictly positive.
    fn assert_all_positive(xs: &[f64]) {
        for &v in xs {
            assert!(v > 0.0);
        }
    }

    #[test]
    fn gamma_positive() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.gamma(2.0, 1.0, 10_000).unwrap();
        assert_all_positive(draws.as_slice().unwrap());
    }

    #[test]
    fn gamma_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let shape = 3.0;
        let scale = 2.0;
        let draws = rng.gamma(shape, scale, n).unwrap();
        let xs = draws.as_slice().unwrap();

        let mean = sample_mean(xs);
        let var = sample_variance(xs, mean);

        // Gamma(k, theta): mean = k*theta, var = k*theta^2
        let expected_mean = shape * scale;
        let expected_var = shape * scale * scale;
        let se = (expected_var / n as f64).sqrt();

        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "gamma mean {mean} too far from {expected_mean}"
        );
        assert!(
            (var - expected_var).abs() / expected_var < 0.05,
            "gamma variance {var} too far from {expected_var}"
        );
    }

    #[test]
    fn gamma_small_shape() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.gamma(0.5, 1.0, 10_000).unwrap();
        assert_all_positive(draws.as_slice().unwrap());
    }

    #[test]
    fn beta_in_range() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.beta(2.0, 5.0, 10_000).unwrap();
        for &v in draws.as_slice().unwrap() {
            assert!(v > 0.0 && v < 1.0, "beta value {v} out of (0,1)");
        }
    }

    #[test]
    fn beta_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let a = 2.0;
        let b = 5.0;
        let draws = rng.beta(a, b, n).unwrap();
        let mean = sample_mean(draws.as_slice().unwrap());

        // Beta(a,b): mean = a/(a+b), var = ab / ((a+b)^2 (a+b+1))
        let expected_mean = a / (a + b);
        let expected_var = (a * b) / ((a + b).powi(2) * (a + b + 1.0));
        let se = (expected_var / n as f64).sqrt();

        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "beta mean {mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn chisquare_positive() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.chisquare(5.0, 10_000).unwrap();
        assert_all_positive(draws.as_slice().unwrap());
    }

    #[test]
    fn chisquare_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let df = 10.0;
        let draws = rng.chisquare(df, n).unwrap();
        let mean = sample_mean(draws.as_slice().unwrap());

        // Chi2(df): mean = df, var = 2*df
        let se = (2.0 * df / n as f64).sqrt();
        assert!(
            (mean - df).abs() < 3.0 * se,
            "chisquare mean {mean} too far from {df}"
        );
    }

    #[test]
    fn f_positive() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.f(5.0, 10.0, 10_000).unwrap();
        assert_all_positive(draws.as_slice().unwrap());
    }

    #[test]
    fn student_t_symmetric() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let df = 10.0;
        let draws = rng.student_t(df, n).unwrap();
        let mean = sample_mean(draws.as_slice().unwrap());
        // t(df) with df > 1 has mean 0.
        assert!(mean.abs() < 0.05, "student_t mean {mean} too far from 0");
    }

    #[test]
    fn standard_gamma_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let shape = 5.0;
        let draws = rng.standard_gamma(shape, n).unwrap();
        let mean = sample_mean(draws.as_slice().unwrap());
        // Standard Gamma(k): mean = var = k
        let se = (shape / n as f64).sqrt();
        assert!(
            (mean - shape).abs() < 3.0 * se,
            "standard_gamma mean {mean} too far from {shape}"
        );
    }

    #[test]
    fn gamma_bad_params() {
        let mut rng = default_rng_seeded(42);
        for (alpha, scale) in [(0.0, 1.0), (1.0, 0.0), (-1.0, 1.0)] {
            assert!(rng.gamma(alpha, scale, 100).is_err());
        }
    }
}