1use ferray_core::{Array, FerrayError, IxDyn};
7
8use crate::bitgen::BitGenerator;
9use crate::distributions::normal::standard_normal_single;
10use crate::generator::{Generator, generate_vec, shape_size, vec_to_array_f64};
11use crate::shape::IntoShape;
12
13pub(crate) fn standard_gamma_single<B: BitGenerator>(bg: &mut B, alpha: f64) -> f64 {
18 if alpha < 1.0 {
19 if alpha <= 0.0 {
21 return 0.0;
22 }
23 loop {
24 let u = bg.next_f64();
25 if u > f64::EPSILON {
26 let x = standard_gamma_ge1(bg, alpha + 1.0);
27 return x * u.powf(1.0 / alpha);
28 }
29 }
30 } else {
31 standard_gamma_ge1(bg, alpha)
32 }
33}
34
35fn standard_gamma_ge1<B: BitGenerator>(bg: &mut B, alpha: f64) -> f64 {
37 let d = alpha - 1.0 / 3.0;
38 let c = 1.0 / (9.0 * d).sqrt();
39
40 loop {
41 let x = standard_normal_single(bg);
42 let v_base = 1.0 + c * x;
43 if v_base <= 0.0 {
44 continue;
45 }
46 let v = v_base * v_base * v_base;
47 let u = bg.next_f64();
48 if u < (0.0331 * (x * x)).mul_add(-(x * x), 1.0) {
50 return d * v;
51 }
52 if u.ln() < (0.5 * x).mul_add(x, d * (1.0 - v + v.ln())) {
53 return d * v;
54 }
55 }
56}
57
58impl<B: BitGenerator> Generator<B> {
59 pub fn standard_gamma(
64 &mut self,
65 alpha: f64,
66 size: impl IntoShape,
67 ) -> Result<Array<f64, IxDyn>, FerrayError> {
68 if alpha <= 0.0 {
69 return Err(FerrayError::invalid_value(format!(
70 "alpha must be positive, got {alpha}"
71 )));
72 }
73 let shape_vec = size.into_shape()?;
74 let n = shape_size(&shape_vec);
75 let data = generate_vec(self, n, |bg| standard_gamma_single(bg, alpha));
76 vec_to_array_f64(data, &shape_vec)
77 }
78
79 pub fn gamma(
87 &mut self,
88 alpha: f64,
89 scale: f64,
90 size: impl IntoShape,
91 ) -> Result<Array<f64, IxDyn>, FerrayError> {
92 if alpha <= 0.0 {
93 return Err(FerrayError::invalid_value(format!(
94 "alpha must be positive, got {alpha}"
95 )));
96 }
97 if scale <= 0.0 {
98 return Err(FerrayError::invalid_value(format!(
99 "scale must be positive, got {scale}"
100 )));
101 }
102 let shape_vec = size.into_shape()?;
103 let n = shape_size(&shape_vec);
104 let data = generate_vec(self, n, |bg| scale * standard_gamma_single(bg, alpha));
105 vec_to_array_f64(data, &shape_vec)
106 }
107
108 pub fn beta(
115 &mut self,
116 a: f64,
117 b: f64,
118 size: impl IntoShape,
119 ) -> Result<Array<f64, IxDyn>, FerrayError> {
120 if a <= 0.0 {
121 return Err(FerrayError::invalid_value(format!(
122 "a must be positive, got {a}"
123 )));
124 }
125 if b <= 0.0 {
126 return Err(FerrayError::invalid_value(format!(
127 "b must be positive, got {b}"
128 )));
129 }
130 let shape_vec = size.into_shape()?;
131 let n = shape_size(&shape_vec);
132 let data = generate_vec(self, n, |bg| {
133 let x = standard_gamma_single(bg, a);
134 let y = standard_gamma_single(bg, b);
135 if x + y == 0.0 {
136 0.5 } else {
138 x / (x + y)
139 }
140 });
141 vec_to_array_f64(data, &shape_vec)
142 }
143
144 pub fn chisquare(
151 &mut self,
152 df: f64,
153 size: impl IntoShape,
154 ) -> Result<Array<f64, IxDyn>, FerrayError> {
155 if df <= 0.0 {
156 return Err(FerrayError::invalid_value(format!(
157 "df must be positive, got {df}"
158 )));
159 }
160 let shape_vec = size.into_shape()?;
161 let n = shape_size(&shape_vec);
162 let data = generate_vec(self, n, |bg| 2.0 * standard_gamma_single(bg, df / 2.0));
163 vec_to_array_f64(data, &shape_vec)
164 }
165
166 pub fn f(
173 &mut self,
174 dfnum: f64,
175 dfden: f64,
176 size: impl IntoShape,
177 ) -> Result<Array<f64, IxDyn>, FerrayError> {
178 if dfnum <= 0.0 {
179 return Err(FerrayError::invalid_value(format!(
180 "dfnum must be positive, got {dfnum}"
181 )));
182 }
183 if dfden <= 0.0 {
184 return Err(FerrayError::invalid_value(format!(
185 "dfden must be positive, got {dfden}"
186 )));
187 }
188 let shape_vec = size.into_shape()?;
189 let n = shape_size(&shape_vec);
190 let data = generate_vec(self, n, |bg| {
191 let x1 = standard_gamma_single(bg, dfnum / 2.0);
192 let x2 = standard_gamma_single(bg, dfden / 2.0);
193 if x2 == 0.0 {
194 f64::INFINITY
195 } else {
196 (x1 / dfnum) / (x2 / dfden)
197 }
198 });
199 vec_to_array_f64(data, &shape_vec)
200 }
201
202 pub fn student_t(
209 &mut self,
210 df: f64,
211 size: impl IntoShape,
212 ) -> Result<Array<f64, IxDyn>, FerrayError> {
213 if df <= 0.0 {
214 return Err(FerrayError::invalid_value(format!(
215 "df must be positive, got {df}"
216 )));
217 }
218 let shape_vec = size.into_shape()?;
219 let n = shape_size(&shape_vec);
220 let data = generate_vec(self, n, |bg| {
221 let z = standard_normal_single(bg);
222 let chi2 = 2.0 * standard_gamma_single(bg, df / 2.0);
223 z / (chi2 / df).sqrt()
224 });
225 vec_to_array_f64(data, &shape_vec)
226 }
227
228 pub fn standard_t(
233 &mut self,
234 df: f64,
235 size: impl IntoShape,
236 ) -> Result<Array<f64, IxDyn>, FerrayError> {
237 self.student_t(df, size)
238 }
239
240 pub fn noncentral_chisquare(
252 &mut self,
253 df: f64,
254 nonc: f64,
255 size: impl IntoShape,
256 ) -> Result<Array<f64, IxDyn>, FerrayError> {
257 if df <= 0.0 {
258 return Err(FerrayError::invalid_value(format!(
259 "df must be positive, got {df}"
260 )));
261 }
262 if nonc < 0.0 {
263 return Err(FerrayError::invalid_value(format!(
264 "nonc must be non-negative, got {nonc}"
265 )));
266 }
267 let shape_vec = size.into_shape()?;
268 let n = shape_size(&shape_vec);
269 let data = generate_vec(self, n, |bg| {
270 let lam = nonc / 2.0;
274 let n_pois: u64 = if lam == 0.0 {
275 0
276 } else {
277 let l = (-lam).exp();
278 let mut k: u64 = 0;
279 let mut p = 1.0;
280 loop {
281 k += 1;
282 p *= bg.next_f64();
283 if p <= l {
284 break k - 1;
285 }
286 }
287 };
288 let total_df = df + 2.0 * (n_pois as f64);
289 2.0 * standard_gamma_single(bg, total_df / 2.0)
290 });
291 vec_to_array_f64(data, &shape_vec)
292 }
293
294 pub fn noncentral_f(
305 &mut self,
306 dfnum: f64,
307 dfden: f64,
308 nonc: f64,
309 size: impl IntoShape,
310 ) -> Result<Array<f64, IxDyn>, FerrayError> {
311 if dfnum <= 0.0 {
312 return Err(FerrayError::invalid_value(format!(
313 "dfnum must be positive, got {dfnum}"
314 )));
315 }
316 if dfden <= 0.0 {
317 return Err(FerrayError::invalid_value(format!(
318 "dfden must be positive, got {dfden}"
319 )));
320 }
321 if nonc < 0.0 {
322 return Err(FerrayError::invalid_value(format!(
323 "nonc must be non-negative, got {nonc}"
324 )));
325 }
326 let shape_vec = size.into_shape()?;
327 let n = shape_size(&shape_vec);
328 let data = generate_vec(self, n, |bg| {
329 let lam = nonc / 2.0;
331 let n_pois: u64 = if lam == 0.0 {
332 0
333 } else {
334 let l = (-lam).exp();
335 let mut k: u64 = 0;
336 let mut p = 1.0;
337 loop {
338 k += 1;
339 p *= bg.next_f64();
340 if p <= l {
341 break k - 1;
342 }
343 }
344 };
345 let total_dfnum = dfnum + 2.0 * (n_pois as f64);
346 let chi2_num = 2.0 * standard_gamma_single(bg, total_dfnum / 2.0);
347 let chi2_den = 2.0 * standard_gamma_single(bg, dfden / 2.0);
348 if chi2_den == 0.0 {
349 f64::INFINITY
350 } else {
351 (chi2_num / dfnum) / (chi2_den / dfden)
352 }
353 });
354 vec_to_array_f64(data, &shape_vec)
355 }
356}
357
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Arithmetic mean of `values`.
    fn sample_mean(values: &[f64]) -> f64 {
        values.iter().sum::<f64>() / values.len() as f64
    }

    /// Population (divide-by-n) variance of `values` around `center`.
    fn sample_variance(values: &[f64], center: f64) -> f64 {
        values.iter().map(|&x| (x - center).powi(2)).sum::<f64>() / values.len() as f64
    }

    /// Asserts that every sample is strictly positive.
    fn assert_all_positive(values: &[f64]) {
        for &v in values {
            assert!(v > 0.0);
        }
    }

    #[test]
    fn gamma_positive() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.gamma(2.0, 1.0, 10_000).unwrap();
        assert_all_positive(samples.as_slice().unwrap());
    }

    #[test]
    fn gamma_mean_variance() {
        let n = 100_000;
        let shape = 3.0;
        let scale = 2.0;
        let mut rng = default_rng_seeded(42);
        let samples = rng.gamma(shape, scale, n).unwrap();
        let values = samples.as_slice().unwrap();

        let mean = sample_mean(values);
        let var = sample_variance(values, mean);
        // Gamma(k, θ) has mean kθ and variance kθ².
        let expected_mean = shape * scale;
        let expected_var = shape * scale * scale;
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "gamma mean {mean} too far from {expected_mean}"
        );
        assert!(
            (var - expected_var).abs() / expected_var < 0.05,
            "gamma variance {var} too far from {expected_var}"
        );
    }

    #[test]
    fn gamma_small_shape() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.gamma(0.5, 1.0, 10_000).unwrap();
        assert_all_positive(samples.as_slice().unwrap());
    }

    #[test]
    fn beta_in_range() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.beta(2.0, 5.0, 10_000).unwrap();
        samples
            .as_slice()
            .unwrap()
            .iter()
            .for_each(|&v| assert!(v > 0.0 && v < 1.0, "beta value {v} out of (0,1)"));
    }

    #[test]
    fn beta_mean() {
        let n = 100_000;
        let (a, b) = (2.0, 5.0);
        let mut rng = default_rng_seeded(42);
        let samples = rng.beta(a, b, n).unwrap();

        let mean = sample_mean(samples.as_slice().unwrap());
        // Beta(a, b) has mean a/(a+b) and variance ab / ((a+b)²(a+b+1)).
        let expected_mean = a / (a + b);
        let expected_var = (a * b) / ((a + b).powi(2) * (a + b + 1.0));
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "beta mean {mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn chisquare_positive() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.chisquare(5.0, 10_000).unwrap();
        assert_all_positive(samples.as_slice().unwrap());
    }

    #[test]
    fn chisquare_mean() {
        let n = 100_000;
        let df = 10.0;
        let mut rng = default_rng_seeded(42);
        let samples = rng.chisquare(df, n).unwrap();

        let mean = sample_mean(samples.as_slice().unwrap());
        // Chi-square(k) has mean k and variance 2k.
        let expected_var = 2.0 * df;
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - df).abs() < 3.0 * se,
            "chisquare mean {mean} too far from {df}"
        );
    }

    #[test]
    fn f_positive() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.f(5.0, 10.0, 10_000).unwrap();
        assert_all_positive(samples.as_slice().unwrap());
    }

    #[test]
    fn student_t_symmetric() {
        let n = 100_000;
        let df = 10.0;
        let mut rng = default_rng_seeded(42);
        let samples = rng.student_t(df, n).unwrap();
        let mean = sample_mean(samples.as_slice().unwrap());
        assert!(mean.abs() < 0.05, "student_t mean {mean} too far from 0");
    }

    #[test]
    fn standard_gamma_mean() {
        let n = 100_000;
        let shape = 5.0;
        let mut rng = default_rng_seeded(42);
        let samples = rng.standard_gamma(shape, n).unwrap();

        let mean = sample_mean(samples.as_slice().unwrap());
        // Gamma(k, 1) has variance k.
        let se = (shape / n as f64).sqrt();
        assert!(
            (mean - shape).abs() < 3.0 * se,
            "standard_gamma mean {mean} too far from {shape}"
        );
    }

    #[test]
    fn gamma_bad_params() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.gamma(0.0, 1.0, 100).is_err());
        assert!(rng.gamma(1.0, 0.0, 100).is_err());
        assert!(rng.gamma(-1.0, 1.0, 100).is_err());
    }

    #[test]
    fn standard_t_alias_matches_student_t() {
        let (mut rng_a, mut rng_b) = (default_rng_seeded(7), default_rng_seeded(7));
        let a = rng_a.student_t(5.0, 100).unwrap();
        let b = rng_b.standard_t(5.0, 100).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn noncentral_chisquare_mean_approx() {
        // Noncentral chi-square(k, λ) has mean k + λ = 5 + 3.
        let n = 50_000;
        let mut rng = default_rng_seeded(42);
        let samples = rng.noncentral_chisquare(5.0, 3.0, n).unwrap();
        let mean = sample_mean(samples.as_slice().unwrap());
        assert!((mean - 8.0).abs() < 0.5, "noncentral_chisquare mean {mean}");
    }

    #[test]
    fn noncentral_chisquare_zero_lambda_matches_chisquare() {
        let (mut rng_a, mut rng_b) = (default_rng_seeded(11), default_rng_seeded(11));
        let a = rng_a.noncentral_chisquare(4.0, 0.0, 1000).unwrap();
        let b = rng_b.chisquare(4.0, 1000).unwrap();
        // With nonc == 0 both samplers must consume the RNG identically.
        for (x, y) in a.as_slice().unwrap().iter().zip(b.as_slice().unwrap()) {
            assert!((x - y).abs() < 1e-12);
        }
    }

    #[test]
    fn noncentral_chisquare_bad_params() {
        let mut rng = default_rng_seeded(0);
        assert!(rng.noncentral_chisquare(0.0, 1.0, 10).is_err());
        assert!(rng.noncentral_chisquare(1.0, -1.0, 10).is_err());
    }

    #[test]
    fn noncentral_f_positive() {
        let mut rng = default_rng_seeded(100);
        let samples = rng.noncentral_f(5.0, 7.0, 2.0, 1000).unwrap();
        samples
            .as_slice()
            .unwrap()
            .iter()
            .for_each(|&v| assert!(v >= 0.0));
    }

    #[test]
    fn noncentral_f_bad_params() {
        let mut rng = default_rng_seeded(0);
        assert!(rng.noncentral_f(0.0, 1.0, 1.0, 10).is_err());
        assert!(rng.noncentral_f(1.0, 0.0, 1.0, 10).is_err());
        assert!(rng.noncentral_f(1.0, 1.0, -1.0, 10).is_err());
    }
}