ferray_random/distributions/
gamma.rs1use ferray_core::{Array, FerrayError, Ix1};
7
8use crate::bitgen::BitGenerator;
9use crate::distributions::normal::standard_normal_single;
10use crate::generator::{Generator, generate_vec, vec_to_array1};
11
12pub(crate) fn standard_gamma_single<B: BitGenerator>(bg: &mut B, alpha: f64) -> f64 {
17 if alpha < 1.0 {
18 if alpha <= 0.0 {
20 return 0.0;
21 }
22 loop {
23 let u = bg.next_f64();
24 if u > f64::EPSILON {
25 let x = standard_gamma_ge1(bg, alpha + 1.0);
26 return x * u.powf(1.0 / alpha);
27 }
28 }
29 } else {
30 standard_gamma_ge1(bg, alpha)
31 }
32}
33
34fn standard_gamma_ge1<B: BitGenerator>(bg: &mut B, alpha: f64) -> f64 {
36 let d = alpha - 1.0 / 3.0;
37 let c = 1.0 / (9.0 * d).sqrt();
38
39 loop {
40 let x = standard_normal_single(bg);
41 let v_base = 1.0 + c * x;
42 if v_base <= 0.0 {
43 continue;
44 }
45 let v = v_base * v_base * v_base;
46 let u = bg.next_f64();
47 if u < 1.0 - 0.0331 * (x * x) * (x * x) {
49 return d * v;
50 }
51 if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
52 return d * v;
53 }
54 }
55}
56
57impl<B: BitGenerator> Generator<B> {
58 pub fn standard_gamma(
67 &mut self,
68 shape: f64,
69 size: usize,
70 ) -> Result<Array<f64, Ix1>, FerrayError> {
71 if size == 0 {
72 return Err(FerrayError::invalid_value("size must be > 0"));
73 }
74 if shape <= 0.0 {
75 return Err(FerrayError::invalid_value(format!(
76 "shape must be positive, got {shape}"
77 )));
78 }
79 let data = generate_vec(self, size, |bg| standard_gamma_single(bg, shape));
80 vec_to_array1(data)
81 }
82
83 pub fn gamma(
96 &mut self,
97 shape: f64,
98 scale: f64,
99 size: usize,
100 ) -> Result<Array<f64, Ix1>, FerrayError> {
101 if size == 0 {
102 return Err(FerrayError::invalid_value("size must be > 0"));
103 }
104 if shape <= 0.0 {
105 return Err(FerrayError::invalid_value(format!(
106 "shape must be positive, got {shape}"
107 )));
108 }
109 if scale <= 0.0 {
110 return Err(FerrayError::invalid_value(format!(
111 "scale must be positive, got {scale}"
112 )));
113 }
114 let data = generate_vec(self, size, |bg| scale * standard_gamma_single(bg, shape));
115 vec_to_array1(data)
116 }
117
118 pub fn beta(&mut self, a: f64, b: f64, size: usize) -> Result<Array<f64, Ix1>, FerrayError> {
130 if size == 0 {
131 return Err(FerrayError::invalid_value("size must be > 0"));
132 }
133 if a <= 0.0 {
134 return Err(FerrayError::invalid_value(format!(
135 "a must be positive, got {a}"
136 )));
137 }
138 if b <= 0.0 {
139 return Err(FerrayError::invalid_value(format!(
140 "b must be positive, got {b}"
141 )));
142 }
143 let data = generate_vec(self, size, |bg| {
144 let x = standard_gamma_single(bg, a);
145 let y = standard_gamma_single(bg, b);
146 if x + y == 0.0 {
147 0.5 } else {
149 x / (x + y)
150 }
151 });
152 vec_to_array1(data)
153 }
154
155 pub fn chisquare(&mut self, df: f64, size: usize) -> Result<Array<f64, Ix1>, FerrayError> {
166 if size == 0 {
167 return Err(FerrayError::invalid_value("size must be > 0"));
168 }
169 if df <= 0.0 {
170 return Err(FerrayError::invalid_value(format!(
171 "df must be positive, got {df}"
172 )));
173 }
174 let data = generate_vec(self, size, |bg| 2.0 * standard_gamma_single(bg, df / 2.0));
175 vec_to_array1(data)
176 }
177
178 pub fn f(
190 &mut self,
191 dfnum: f64,
192 dfden: f64,
193 size: usize,
194 ) -> Result<Array<f64, Ix1>, FerrayError> {
195 if size == 0 {
196 return Err(FerrayError::invalid_value("size must be > 0"));
197 }
198 if dfnum <= 0.0 {
199 return Err(FerrayError::invalid_value(format!(
200 "dfnum must be positive, got {dfnum}"
201 )));
202 }
203 if dfden <= 0.0 {
204 return Err(FerrayError::invalid_value(format!(
205 "dfden must be positive, got {dfden}"
206 )));
207 }
208 let data = generate_vec(self, size, |bg| {
209 let x1 = standard_gamma_single(bg, dfnum / 2.0);
210 let x2 = standard_gamma_single(bg, dfden / 2.0);
211 if x2 == 0.0 {
212 f64::INFINITY
213 } else {
214 (x1 / dfnum) / (x2 / dfden)
215 }
216 });
217 vec_to_array1(data)
218 }
219
220 pub fn student_t(&mut self, df: f64, size: usize) -> Result<Array<f64, Ix1>, FerrayError> {
231 if size == 0 {
232 return Err(FerrayError::invalid_value("size must be > 0"));
233 }
234 if df <= 0.0 {
235 return Err(FerrayError::invalid_value(format!(
236 "df must be positive, got {df}"
237 )));
238 }
239 let data = generate_vec(self, size, |bg| {
240 let z = standard_normal_single(bg);
241 let chi2 = 2.0 * standard_gamma_single(bg, df / 2.0);
242 z / (chi2 / df).sqrt()
243 });
244 vec_to_array1(data)
245 }
246}
247
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Arithmetic mean of `xs`, summed left to right.
    fn sample_mean(xs: &[f64]) -> f64 {
        xs.iter().sum::<f64>() / xs.len() as f64
    }

    /// Panics on the first sample that is not strictly positive.
    fn assert_all_positive(xs: &[f64]) {
        for &v in xs {
            assert!(v > 0.0);
        }
    }

    #[test]
    fn gamma_positive() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.gamma(2.0, 1.0, 10_000).unwrap();
        assert_all_positive(samples.as_slice().unwrap());
    }

    #[test]
    fn gamma_mean_variance() {
        let (n, shape, scale) = (100_000, 3.0, 2.0);
        let mut rng = default_rng_seeded(42);
        let samples = rng.gamma(shape, scale, n).unwrap();
        let xs = samples.as_slice().unwrap();
        let mean = sample_mean(xs);
        let var: f64 = xs.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        // Gamma(k, theta) has mean k*theta and variance k*theta^2.
        let expected_mean = shape * scale;
        let expected_var = shape * scale * scale;
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "gamma mean {mean} too far from {expected_mean}"
        );
        assert!(
            (var - expected_var).abs() / expected_var < 0.05,
            "gamma variance {var} too far from {expected_var}"
        );
    }

    #[test]
    fn gamma_small_shape() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.gamma(0.5, 1.0, 10_000).unwrap();
        assert_all_positive(samples.as_slice().unwrap());
    }

    #[test]
    fn beta_in_range() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.beta(2.0, 5.0, 10_000).unwrap();
        for &v in samples.as_slice().unwrap() {
            assert!(v > 0.0 && v < 1.0, "beta value {v} out of (0,1)");
        }
    }

    #[test]
    fn beta_mean() {
        let (n, a, b) = (100_000, 2.0, 5.0);
        let mut rng = default_rng_seeded(42);
        let samples = rng.beta(a, b, n).unwrap();
        let mean = sample_mean(samples.as_slice().unwrap());
        // Beta(a, b) has mean a/(a+b) and variance ab / ((a+b)^2 (a+b+1)).
        let expected_mean = a / (a + b);
        let expected_var = (a * b) / ((a + b).powi(2) * (a + b + 1.0));
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "beta mean {mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn chisquare_positive() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.chisquare(5.0, 10_000).unwrap();
        assert_all_positive(samples.as_slice().unwrap());
    }

    #[test]
    fn chisquare_mean() {
        let (n, df) = (100_000, 10.0);
        let mut rng = default_rng_seeded(42);
        let samples = rng.chisquare(df, n).unwrap();
        let mean = sample_mean(samples.as_slice().unwrap());
        // Chi-square(k) has mean k and variance 2k.
        let expected_var = 2.0 * df;
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - df).abs() < 3.0 * se,
            "chisquare mean {mean} too far from {df}"
        );
    }

    #[test]
    fn f_positive() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.f(5.0, 10.0, 10_000).unwrap();
        assert_all_positive(samples.as_slice().unwrap());
    }

    #[test]
    fn student_t_symmetric() {
        let (n, df) = (100_000, 10.0);
        let mut rng = default_rng_seeded(42);
        let samples = rng.student_t(df, n).unwrap();
        let mean = sample_mean(samples.as_slice().unwrap());
        assert!(mean.abs() < 0.05, "student_t mean {mean} too far from 0");
    }

    #[test]
    fn standard_gamma_mean() {
        let (n, shape) = (100_000, 5.0);
        let mut rng = default_rng_seeded(42);
        let samples = rng.standard_gamma(shape, n).unwrap();
        let mean = sample_mean(samples.as_slice().unwrap());
        // Gamma(k, 1) has mean k and variance k.
        let se = (shape / n as f64).sqrt();
        assert!(
            (mean - shape).abs() < 3.0 * se,
            "standard_gamma mean {mean} too far from {shape}"
        );
    }

    #[test]
    fn gamma_bad_params() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.gamma(0.0, 1.0, 100).is_err());
        assert!(rng.gamma(1.0, 0.0, 100).is_err());
        assert!(rng.gamma(-1.0, 1.0, 100).is_err());
    }
}