1use ferray_core::{Array, FerrayError, IxDyn};
6
7use crate::bitgen::BitGenerator;
8use crate::distributions::gamma::standard_gamma_single;
9use crate::generator::{Generator, generate_vec_i64, shape_size, vec_to_array_i64};
10use crate::shape::IntoShape;
11
12fn poisson_single<B: BitGenerator>(bg: &mut B, lam: f64) -> i64 {
15 if lam < 30.0 {
16 let l = (-lam).exp();
18 let mut k: i64 = 0;
19 let mut p = 1.0;
20 loop {
21 k += 1;
22 p *= bg.next_f64();
23 if p <= l {
24 return k - 1;
25 }
26 }
27 } else {
28 let slam = lam.sqrt();
30 let loglam = lam.ln();
31 let b = 2.53f64.mul_add(slam, 0.931);
32 let a = 0.02483f64.mul_add(b, -0.059);
33 let inv_alpha = 1.1239 + 1.1328 / (b - 3.4);
34 let vr = 0.9277 - 3.6224 / (b - 2.0);
35
36 loop {
37 let u = bg.next_f64() - 0.5;
38 let v = bg.next_f64();
39 let us = 0.5 - u.abs();
40 let k = ((2.0 * a / us + b).mul_add(u, lam) + 0.43).floor() as i64;
41 if k < 0 {
42 continue;
43 }
44 if us >= 0.07 && v <= vr {
45 return k;
46 }
47 if k > 0
48 && us >= 0.013
49 && v <= (k as f64)
50 .ln()
51 .mul_add(
52 -0.5,
53 (k as f64).mul_add(loglam, -lam) - ln_factorial(k as u64),
54 )
55 .exp()
56 * inv_alpha
57 {
58 return k;
59 }
60 if us < 0.013 && v > us {
61 continue;
62 }
63 let kf = k as f64;
65 let log_accept = -lam + kf * loglam - ln_factorial(k as u64);
66 if v.ln() + inv_alpha.ln() - (a / (us * us) + b).ln() <= log_accept {
67 return k;
68 }
69 }
70 }
71}
72
/// Natural logarithm of `n!`.
///
/// For `n <= 20` the logs of `2..=n` are summed directly. Beyond that, the
/// Stirling series is used:
/// `0.5 ln(2π) + (n + 1/2) ln n - n + 1/(12n) - 1/(360 n³)`.
fn ln_factorial(n: u64) -> f64 {
    const DIRECT_SUM_LIMIT: u64 = 20;

    if n > DIRECT_SUM_LIMIT {
        let x = n as f64;
        let leading = 0.5f64.mul_add(std::f64::consts::TAU.ln(), (x + 0.5) * x.ln());
        let first_correction = (12.0 * x).recip();
        let second_correction = (360.0 * x * x * x).recip();
        return leading - x + first_correction - second_correction;
    }

    (2..=n)
        .map(|i| (i as f64).ln())
        .fold(0.0_f64, |acc, term| acc + term)
}
89
90fn binomial_single<B: BitGenerator>(bg: &mut B, n: u64, p: f64) -> i64 {
93 if n == 0 || p == 0.0 {
94 return 0;
95 }
96 if p == 1.0 {
97 return n as i64;
98 }
99
100 let (pp, flipped) = if p > 0.5 { (1.0 - p, true) } else { (p, false) };
102
103 let np = n as f64 * pp;
104 let q = 1.0 - pp;
105
106 let result = if np < 30.0 {
107 let s = pp / q;
109 let a = (n as f64 + 1.0) * s;
110 let mut r = q.powf(n as f64);
111 let mut u = bg.next_f64();
112 let mut x: i64 = 0;
113 while u > r {
114 u -= r;
115 x += 1;
116 r *= a / (x as f64) - s;
117 if r < 0.0 {
118 break;
119 }
120 }
121 x.min(n as i64)
122 } else {
123 let fm = np + pp;
127 let m = fm.floor() as i64;
128 let mf = m as f64;
129 let p1 = 2.195f64.mul_add((np * q).sqrt(), -(4.6 * q)).floor() + 0.5;
130 let xm = mf + 0.5;
131 let xl = xm - p1;
132 let xr = xm + p1;
133 let c = 0.134 + 20.5 / (15.3 + mf);
134 let a = (fm - xl) / (fm - xl * pp);
135 let lambda_l = a * 0.5f64.mul_add(a, 1.0);
136 let a2 = (xr - fm) / (xr * q);
137 let lambda_r = a2 * 0.5f64.mul_add(a2, 1.0);
138 let p2 = p1 * 2.0f64.mul_add(c, 1.0);
139 let p3 = p2 + c / lambda_l;
140 let p4 = p3 + c / lambda_r;
141
142 loop {
143 let u = bg.next_f64() * p4;
144 let v = bg.next_f64();
145 let y: i64;
146
147 if u <= p1 {
148 y = (xm - p1 * v + u).floor() as i64;
150 } else if u <= p2 {
151 let x = xl + (u - p1) / c;
153 #[allow(clippy::suspicious_operation_groupings)]
158 let w = v + (x - xm) * (x - xm) / (p1 * p1);
159 if w > 1.0 {
160 continue;
161 }
162 y = x.floor() as i64;
163 } else if u <= p3 {
164 y = (xl + v.ln() / lambda_l).floor() as i64;
166 if y < 0 {
167 continue;
168 }
169 } else {
170 y = (xr - v.ln() / lambda_r).floor() as i64;
172 if y > n as i64 {
173 continue;
174 }
175 }
176
177 let k = (y - m).abs();
179 if k <= 20 || k as f64 >= (0.5 * np).mul_add(q, -1.0) {
180 let kf = k as f64;
182 let yf = y as f64;
183 let rho =
184 (kf / (np * q)) * (kf.mul_add(kf / 3.0 + 0.625, 1.0 / 6.0) / (np * q) + 0.5);
185 let t = -kf * kf / (2.0 * np * q);
186 let log_a = t - rho;
187 if v.ln() <= log_a {
188 break y;
189 }
190 let log_v = v.ln();
192 let log_accept = (yf - mf).mul_add(
193 (pp / q).ln(),
194 ln_factorial(m as u64) - ln_factorial(y as u64) - ln_factorial(n - y as u64)
195 + ln_factorial(n - m as u64),
196 );
197 if log_v <= log_accept {
198 break y;
199 }
200 } else {
201 break y;
202 }
203 }
204 };
205
206 if flipped { n as i64 - result } else { result }
207}
208
209impl<B: BitGenerator> Generator<B> {
210 pub fn binomial(
223 &mut self,
224 n: u64,
225 p: f64,
226 size: impl IntoShape,
227 ) -> Result<Array<i64, IxDyn>, FerrayError> {
228 if !(0.0..=1.0).contains(&p) {
229 return Err(FerrayError::invalid_value(format!(
230 "p must be in [0, 1], got {p}"
231 )));
232 }
233 let shape_vec = size.into_shape()?;
234 let total = shape_size(&shape_vec);
235 let data = generate_vec_i64(self, total, |bg| binomial_single(bg, n, p));
236 vec_to_array_i64(data, &shape_vec)
237 }
238
239 pub fn negative_binomial(
252 &mut self,
253 n: f64,
254 p: f64,
255 size: impl IntoShape,
256 ) -> Result<Array<i64, IxDyn>, FerrayError> {
257 if n <= 0.0 {
258 return Err(FerrayError::invalid_value(format!(
259 "n must be positive, got {n}"
260 )));
261 }
262 if p <= 0.0 || p > 1.0 {
263 return Err(FerrayError::invalid_value(format!(
264 "p must be in (0, 1], got {p}"
265 )));
266 }
267 let shape_vec = size.into_shape()?;
268 let total = shape_size(&shape_vec);
269 let data = generate_vec_i64(self, total, |bg| {
270 let y = standard_gamma_single(bg, n) * (1.0 - p) / p;
273 poisson_single(bg, y)
274 });
275 vec_to_array_i64(data, &shape_vec)
276 }
277
278 pub fn poisson(
287 &mut self,
288 lam: f64,
289 size: impl IntoShape,
290 ) -> Result<Array<i64, IxDyn>, FerrayError> {
291 if lam < 0.0 {
292 return Err(FerrayError::invalid_value(format!(
293 "lam must be non-negative, got {lam}"
294 )));
295 }
296 let shape_vec = size.into_shape()?;
297 let total = shape_size(&shape_vec);
298 if lam == 0.0 {
299 let data = vec![0i64; total];
300 return vec_to_array_i64(data, &shape_vec);
301 }
302 let data = generate_vec_i64(self, total, |bg| poisson_single(bg, lam));
303 vec_to_array_i64(data, &shape_vec)
304 }
305
306 pub fn poisson_array(
316 &mut self,
317 lam: &Array<f64, IxDyn>,
318 ) -> Result<Array<i64, IxDyn>, FerrayError> {
319 let shape = lam.shape().to_vec();
320 let total: usize = shape.iter().product();
321 let mut out: Vec<i64> = Vec::with_capacity(total);
322 for &l in lam.iter() {
323 if l < 0.0 {
324 return Err(FerrayError::invalid_value(format!(
325 "lam must be non-negative, got {l}"
326 )));
327 }
328 if l == 0.0 {
329 out.push(0);
330 } else {
331 out.push(poisson_single(&mut self.bg, l));
332 }
333 }
334 Array::<i64, IxDyn>::from_vec(IxDyn::new(&shape), out)
335 }
336
337 pub fn geometric(
348 &mut self,
349 p: f64,
350 size: impl IntoShape,
351 ) -> Result<Array<i64, IxDyn>, FerrayError> {
352 if p <= 0.0 || p > 1.0 {
353 return Err(FerrayError::invalid_value(format!(
354 "p must be in (0, 1], got {p}"
355 )));
356 }
357 let shape_vec = size.into_shape()?;
358 let total = shape_size(&shape_vec);
359 if (p - 1.0).abs() < f64::EPSILON {
360 let data = vec![1i64; total];
361 return vec_to_array_i64(data, &shape_vec);
362 }
363 let log_q = (1.0 - p).ln();
364 let data = generate_vec_i64(self, total, |bg| {
365 loop {
366 let u = bg.next_f64();
367 if u > f64::EPSILON {
368 return (u.ln() / log_q).floor() as i64 + 1;
369 }
370 }
371 });
372 vec_to_array_i64(data, &shape_vec)
373 }
374
375 pub fn hypergeometric(
389 &mut self,
390 ngood: u64,
391 nbad: u64,
392 nsample: u64,
393 size: impl IntoShape,
394 ) -> Result<Array<i64, IxDyn>, FerrayError> {
395 let total = ngood + nbad;
396 if nsample > total {
397 return Err(FerrayError::invalid_value(format!(
398 "nsample ({nsample}) > ngood + nbad ({total})"
399 )));
400 }
401 let shape_vec = size.into_shape()?;
402 let total_n = shape_size(&shape_vec);
403 let data = generate_vec_i64(self, total_n, |bg| {
404 hypergeometric_single(bg, ngood, nbad, nsample)
405 });
406 vec_to_array_i64(data, &shape_vec)
407 }
408
409 pub fn logseries(
418 &mut self,
419 p: f64,
420 size: impl IntoShape,
421 ) -> Result<Array<i64, IxDyn>, FerrayError> {
422 if p <= 0.0 || p >= 1.0 {
423 return Err(FerrayError::invalid_value(format!(
424 "p must be in (0, 1), got {p}"
425 )));
426 }
427 let r = (-(-p).ln_1p()).recip();
428 let shape_vec = size.into_shape()?;
429 let total = shape_size(&shape_vec);
430 let data = generate_vec_i64(self, total, |bg| {
431 loop {
434 let u = bg.next_f64();
435 if u <= f64::EPSILON || u >= 1.0 - f64::EPSILON {
436 continue;
437 }
438 let v = bg.next_f64();
439 let q = 1.0 - (-r.recip() * u.ln()).exp();
440 if q <= 0.0 {
441 return 1;
442 }
443 if v < q * q {
444 let k = (1.0 + v.log(q)).floor() as i64;
445 return k.max(1);
446 }
447 if v < q {
448 return 2;
449 }
450 return 1;
451 }
452 });
453 vec_to_array_i64(data, &shape_vec)
454 }
455
456 pub fn zipf(&mut self, a: f64, size: impl IntoShape) -> Result<Array<i64, IxDyn>, FerrayError> {
468 if a <= 1.0 {
469 return Err(FerrayError::invalid_value(format!(
470 "a must be > 1 for Zipf, got {a}"
471 )));
472 }
473 let am1 = a - 1.0;
474 let b = 2.0_f64.powf(am1);
475 let shape_vec = size.into_shape()?;
476 let total = shape_size(&shape_vec);
477 let data = generate_vec_i64(self, total, |bg| {
478 loop {
479 let u = 1.0 - bg.next_f64();
480 let v = bg.next_f64();
481 let x = u.powf(-1.0 / am1).floor();
482 if !x.is_finite() || x < 1.0 {
484 continue;
485 }
486 let t = (1.0 + 1.0 / x).powf(am1);
487 if v * x * (t - 1.0) / (b - 1.0) <= t / b {
489 if x > i64::MAX as f64 {
490 continue;
491 }
492 return x as i64;
493 }
494 }
495 });
496 vec_to_array_i64(data, &shape_vec)
497 }
498}
499
500fn hypergeometric_single<B: BitGenerator>(bg: &mut B, ngood: u64, nbad: u64, nsample: u64) -> i64 {
502 let mut good_remaining = ngood;
504 let mut total_remaining = ngood + nbad;
505 let mut successes: i64 = 0;
506
507 for _ in 0..nsample {
508 if total_remaining == 0 {
509 break;
510 }
511 let u = bg.next_f64();
512 if u < (good_remaining as f64) / (total_remaining as f64) {
513 successes += 1;
514 good_remaining -= 1;
515 }
516 total_remaining -= 1;
517 }
518 successes
519}
520
#[cfg(test)]
mod tests {
    // Shared imports for every test; individual tests no longer re-import
    // these names locally.
    use super::ln_factorial;
    use crate::default_rng_seeded;
    use ferray_core::{Array, IxDyn};

    #[test]
    fn poisson_array_shape_matches_lam() {
        let mut rng = default_rng_seeded(42);
        let lam =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 2]), vec![1.0, 5.0, 50.0, 0.0]).unwrap();
        let out = rng.poisson_array(&lam).unwrap();
        assert_eq!(out.shape(), &[2, 2]);
        let s = out.as_slice().unwrap();
        // A zero rate must always yield zero.
        assert_eq!(s[3], 0);
        for &v in s {
            assert!(v >= 0);
        }
    }

    #[test]
    fn poisson_array_per_element_mean() {
        let mut rng = default_rng_seeded(11);
        let lams = [3.0_f64, 50.0];
        let lam = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), lams.to_vec()).unwrap();
        let n_trials = 5_000;
        let mut sums = [0.0_f64; 2];
        for _ in 0..n_trials {
            let out = rng.poisson_array(&lam).unwrap();
            let s = out.as_slice().unwrap();
            for j in 0..2 {
                sums[j] += s[j] as f64;
            }
        }
        for j in 0..2 {
            let mean = sums[j] / n_trials as f64;
            // Standard error of the mean for a Poisson variable: sqrt(lam / n).
            let se = (lams[j] / n_trials as f64).sqrt();
            assert!(
                (mean - lams[j]).abs() < 4.0 * se,
                "elt {j}: mean {mean} too far from {}",
                lams[j]
            );
        }
    }

    #[test]
    fn poisson_array_negative_lam_errors() {
        let mut rng = default_rng_seeded(0);
        let lam = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, -2.0]).unwrap();
        assert!(rng.poisson_array(&lam).is_err());
    }

    #[test]
    fn poisson_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let lam = 5.0;
        let arr = rng.poisson(lam, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().map(|&x| x as f64).sum::<f64>() / n as f64;
        let se = (lam / n as f64).sqrt();
        assert!(
            (mean - lam).abs() < 3.0 * se,
            "poisson mean {mean} too far from {lam}"
        );
    }

    #[test]
    fn poisson_large_lambda() {
        let mut rng = default_rng_seeded(42);
        let n = 50_000;
        let lam = 100.0;
        let arr = rng.poisson(lam, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().map(|&x| x as f64).sum::<f64>() / n as f64;
        let se = (lam / n as f64).sqrt();
        assert!(
            (mean - lam).abs() < 3.0 * se,
            "poisson mean {mean} too far from {lam}"
        );
    }

    #[test]
    fn poisson_zero() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.poisson(0.0, 100).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert_eq!(v, 0);
        }
    }

    #[test]
    fn binomial_mean() {
        let mut rng = default_rng_seeded(42);
        let size = 100_000;
        let n = 20u64;
        let p = 0.3;
        let arr = rng.binomial(n, p, size).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().map(|&x| x as f64).sum::<f64>() / size as f64;
        let expected_mean = n as f64 * p;
        let expected_var = n as f64 * p * (1.0 - p);
        let se = (expected_var / size as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "binomial mean {mean} too far from {expected_mean}"
        );
        for &v in slice {
            assert!(
                v >= 0 && v <= n as i64,
                "binomial value {v} out of [0, {n}]"
            );
        }
    }

    #[test]
    fn binomial_edge_cases() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.binomial(10, 0.0, 100).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert_eq!(v, 0);
        }
        let arr = rng.binomial(10, 1.0, 100).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert_eq!(v, 10);
        }
        // Zero trials can only produce zero successes.
        let arr = rng.binomial(0, 0.5, 100).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert_eq!(v, 0);
        }
    }

    #[test]
    fn negative_binomial_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.negative_binomial(5.0, 0.5, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 0, "negative_binomial value {v} must be >= 0");
        }
    }

    #[test]
    fn geometric_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let p = 0.3;
        let arr = rng.geometric(p, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().map(|&x| x as f64).sum::<f64>() / n as f64;
        let expected_mean = 1.0 / p;
        let expected_var = (1.0 - p) / (p * p);
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "geometric mean {mean} too far from {expected_mean}"
        );
        for &v in slice {
            assert!(v >= 1, "geometric value {v} must be >= 1");
        }
    }

    #[test]
    fn hypergeometric_range() {
        let mut rng = default_rng_seeded(42);
        let ngood = 20u64;
        let nbad = 30u64;
        let nsample = 10u64;
        let arr = rng.hypergeometric(ngood, nbad, nsample, 10_000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!(
                v >= 0 && v <= nsample.min(ngood) as i64,
                "hypergeometric value {v} out of range"
            );
        }
    }

    #[test]
    fn hypergeometric_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let ngood = 20u64;
        let nbad = 30u64;
        let nsample = 10u64;
        let arr = rng.hypergeometric(ngood, nbad, nsample, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().map(|&x| x as f64).sum::<f64>() / n as f64;
        let total = (ngood + nbad) as f64;
        let expected_mean = nsample as f64 * ngood as f64 / total;
        // Hypergeometric variance includes the finite-population correction.
        let expected_var = nsample as f64
            * (ngood as f64 / total)
            * (nbad as f64 / total)
            * (total - nsample as f64)
            / (total - 1.0);
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "hypergeometric mean {mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn logseries_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.logseries(0.5, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 1, "logseries value {v} must be >= 1");
        }
    }

    #[test]
    fn bad_params() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.binomial(10, -0.1, 10).is_err());
        assert!(rng.binomial(10, 1.5, 10).is_err());
        assert!(rng.poisson(-1.0, 10).is_err());
        assert!(rng.geometric(0.0, 10).is_err());
        assert!(rng.geometric(1.5, 10).is_err());
        assert!(rng.hypergeometric(5, 5, 20, 10).is_err());
        assert!(rng.logseries(0.0, 10).is_err());
        assert!(rng.logseries(1.0, 10).is_err());
        assert!(rng.negative_binomial(0.0, 0.5, 10).is_err());
        assert!(rng.negative_binomial(5.0, 0.0, 10).is_err());
    }

    #[test]
    fn zipf_positive_integers() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.zipf(2.5, 1000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 1, "zipf output must be >= 1, got {v}");
        }
    }

    #[test]
    fn zipf_seed_reproducible() {
        let mut a = default_rng_seeded(7);
        let mut b = default_rng_seeded(7);
        let xs = a.zipf(3.0, 200).unwrap();
        let ys = b.zipf(3.0, 200).unwrap();
        assert_eq!(xs.as_slice().unwrap(), ys.as_slice().unwrap());
    }

    #[test]
    fn zipf_bad_a_errs() {
        let mut rng = default_rng_seeded(0);
        assert!(rng.zipf(1.0, 10).is_err());
        assert!(rng.zipf(0.5, 10).is_err());
        assert!(rng.zipf(-2.0, 10).is_err());
    }

    #[test]
    fn ln_factorial_exact_and_stirling_agree() {
        assert_eq!(ln_factorial(0), 0.0);
        assert_eq!(ln_factorial(1), 0.0);
        let exact_20 = (2_432_902_008_176_640_000_u64 as f64).ln();
        assert!((ln_factorial(20) - exact_20).abs() < 1e-10);
        // n = 21 is the first value computed by the Stirling series; it must
        // continue the exact running sum.
        let expected_21 = ln_factorial(20) + 21f64.ln();
        assert!((ln_factorial(21) - expected_21).abs() < 1e-9);
    }
}