ferray_random/distributions/
discrete.rs1use ferray_core::{Array, FerrayError, IxDyn};
6
7use crate::bitgen::BitGenerator;
8use crate::distributions::gamma::standard_gamma_single;
9use crate::generator::{Generator, generate_vec_i64, shape_size, vec_to_array_i64};
10use crate::shape::IntoShape;
11
12fn poisson_single<B: BitGenerator>(bg: &mut B, lam: f64) -> i64 {
15 if lam < 30.0 {
16 let l = (-lam).exp();
18 let mut k: i64 = 0;
19 let mut p = 1.0;
20 loop {
21 k += 1;
22 p *= bg.next_f64();
23 if p <= l {
24 return k - 1;
25 }
26 }
27 } else {
28 let slam = lam.sqrt();
30 let loglam = lam.ln();
31 let b = 0.931 + 2.53 * slam;
32 let a = -0.059 + 0.02483 * b;
33 let inv_alpha = 1.1239 + 1.1328 / (b - 3.4);
34 let vr = 0.9277 - 3.6224 / (b - 2.0);
35
36 loop {
37 let u = bg.next_f64() - 0.5;
38 let v = bg.next_f64();
39 let us = 0.5 - u.abs();
40 let k = ((2.0 * a / us + b) * u + lam + 0.43).floor() as i64;
41 if k < 0 {
42 continue;
43 }
44 if us >= 0.07 && v <= vr {
45 return k;
46 }
47 if k > 0
48 && us >= 0.013
49 && v <= (k as f64)
50 .ln()
51 .mul_add(-0.5, (k as f64) * loglam - lam - ln_factorial(k as u64))
52 .exp()
53 * inv_alpha
54 {
55 return k;
56 }
57 if us < 0.013 && v > us {
58 continue;
59 }
60 let kf = k as f64;
62 let log_accept = -lam + kf * loglam - ln_factorial(k as u64);
63 if v.ln() + inv_alpha.ln() - (a / (us * us) + b).ln() <= log_accept {
64 return k;
65 }
66 }
67 }
68}
69
/// Returns `ln(n!)`.
///
/// Arguments up to 20 are handled by summing `ln(i)` for `i` in `2..=n`
/// (so `0` and `1` give `0.0`). Larger arguments use Stirling's series with
/// the `1/(12n)` and `1/(360n^3)` correction terms.
fn ln_factorial(n: u64) -> f64 {
    const DIRECT_SUM_LIMIT: u64 = 20;

    if n > DIRECT_SUM_LIMIT {
        let x = n as f64;
        // ln(n!) ~ ln(2*pi)/2 + (n + 1/2) ln(n) - n + 1/(12n) - 1/(360n^3)
        return 0.5 * (std::f64::consts::TAU).ln() + (x + 0.5) * x.ln() - x + 1.0 / (12.0 * x)
            - 1.0 / (360.0 * x * x * x);
    }

    // Explicit fold from +0.0 so the empty range yields exactly 0.0.
    (2..=n).fold(0.0_f64, |acc, i| acc + (i as f64).ln())
}
86
87fn binomial_single<B: BitGenerator>(bg: &mut B, n: u64, p: f64) -> i64 {
90 if n == 0 || p == 0.0 {
91 return 0;
92 }
93 if p == 1.0 {
94 return n as i64;
95 }
96
97 let (pp, flipped) = if p > 0.5 { (1.0 - p, true) } else { (p, false) };
99
100 let np = n as f64 * pp;
101
102 let result = if np < 30.0 {
103 let q = 1.0 - pp;
105 let s = pp / q;
106 let a = (n as f64 + 1.0) * s;
107 let mut r = q.powf(n as f64);
108 let mut u = bg.next_f64();
109 let mut x: i64 = 0;
110 while u > r {
111 u -= r;
112 x += 1;
113 r *= a / (x as f64) - s;
114 if r < 0.0 {
115 break;
116 }
117 }
118 x.min(n as i64)
119 } else {
120 let q = 1.0 - pp;
124 let fm = np + pp;
125 let m = fm.floor() as i64;
126 let mf = m as f64;
127 let p1 = (2.195 * (np * q).sqrt() - 4.6 * q).floor() + 0.5;
128 let xm = mf + 0.5;
129 let xl = xm - p1;
130 let xr = xm + p1;
131 let c = 0.134 + 20.5 / (15.3 + mf);
132 let a = (fm - xl) / (fm - xl * pp);
133 let lambda_l = a * (1.0 + 0.5 * a);
134 let a2 = (xr - fm) / (xr * q);
135 let lambda_r = a2 * (1.0 + 0.5 * a2);
136 let p2 = p1 * (1.0 + 2.0 * c);
137 let p3 = p2 + c / lambda_l;
138 let p4 = p3 + c / lambda_r;
139
140 loop {
141 let u = bg.next_f64() * p4;
142 let v = bg.next_f64();
143 let y: i64;
144
145 if u <= p1 {
146 y = (xm - p1 * v + u).floor() as i64;
148 } else if u <= p2 {
149 let x = xl + (u - p1) / c;
151 let w = v + (x - xm) * (x - xm) / (p1 * p1);
152 if w > 1.0 {
153 continue;
154 }
155 y = x.floor() as i64;
156 } else if u <= p3 {
157 y = (xl + v.ln() / lambda_l).floor() as i64;
159 if y < 0 {
160 continue;
161 }
162 } else {
163 y = (xr - v.ln() / lambda_r).floor() as i64;
165 if y > n as i64 {
166 continue;
167 }
168 }
169
170 let k = (y - m).abs();
172 if k <= 20 || k as f64 >= 0.5 * np * q - 1.0 {
173 let kf = k as f64;
175 let yf = y as f64;
176 let rho =
177 (kf / (np * q)) * ((kf * (kf / 3.0 + 0.625) + 1.0 / 6.0) / (np * q) + 0.5);
178 let t = -kf * kf / (2.0 * np * q);
179 let log_a = t - rho;
180 if v.ln() <= log_a {
181 break y;
182 }
183 let log_v = v.ln();
185 let log_accept =
186 ln_factorial(m as u64) - ln_factorial(y as u64) - ln_factorial(n - y as u64)
187 + ln_factorial(n - m as u64)
188 + (yf - mf) * (pp / q).ln();
189 if log_v <= log_accept {
190 break y;
191 }
192 } else {
193 break y;
194 }
195 }
196 };
197
198 if flipped { n as i64 - result } else { result }
199}
200
201impl<B: BitGenerator> Generator<B> {
202 pub fn binomial(
215 &mut self,
216 n: u64,
217 p: f64,
218 size: impl IntoShape,
219 ) -> Result<Array<i64, IxDyn>, FerrayError> {
220 if !(0.0..=1.0).contains(&p) {
221 return Err(FerrayError::invalid_value(format!(
222 "p must be in [0, 1], got {p}"
223 )));
224 }
225 let shape_vec = size.into_shape()?;
226 let total = shape_size(&shape_vec);
227 let data = generate_vec_i64(self, total, |bg| binomial_single(bg, n, p));
228 vec_to_array_i64(data, &shape_vec)
229 }
230
231 pub fn negative_binomial(
244 &mut self,
245 n: f64,
246 p: f64,
247 size: impl IntoShape,
248 ) -> Result<Array<i64, IxDyn>, FerrayError> {
249 if n <= 0.0 {
250 return Err(FerrayError::invalid_value(format!(
251 "n must be positive, got {n}"
252 )));
253 }
254 if p <= 0.0 || p > 1.0 {
255 return Err(FerrayError::invalid_value(format!(
256 "p must be in (0, 1], got {p}"
257 )));
258 }
259 let shape_vec = size.into_shape()?;
260 let total = shape_size(&shape_vec);
261 let data = generate_vec_i64(self, total, |bg| {
262 let y = standard_gamma_single(bg, n) * (1.0 - p) / p;
265 poisson_single(bg, y)
266 });
267 vec_to_array_i64(data, &shape_vec)
268 }
269
270 pub fn poisson(
279 &mut self,
280 lam: f64,
281 size: impl IntoShape,
282 ) -> Result<Array<i64, IxDyn>, FerrayError> {
283 if lam < 0.0 {
284 return Err(FerrayError::invalid_value(format!(
285 "lam must be non-negative, got {lam}"
286 )));
287 }
288 let shape_vec = size.into_shape()?;
289 let total = shape_size(&shape_vec);
290 if lam == 0.0 {
291 let data = vec![0i64; total];
292 return vec_to_array_i64(data, &shape_vec);
293 }
294 let data = generate_vec_i64(self, total, |bg| poisson_single(bg, lam));
295 vec_to_array_i64(data, &shape_vec)
296 }
297
298 pub fn geometric(
309 &mut self,
310 p: f64,
311 size: impl IntoShape,
312 ) -> Result<Array<i64, IxDyn>, FerrayError> {
313 if p <= 0.0 || p > 1.0 {
314 return Err(FerrayError::invalid_value(format!(
315 "p must be in (0, 1], got {p}"
316 )));
317 }
318 let shape_vec = size.into_shape()?;
319 let total = shape_size(&shape_vec);
320 if (p - 1.0).abs() < f64::EPSILON {
321 let data = vec![1i64; total];
322 return vec_to_array_i64(data, &shape_vec);
323 }
324 let log_q = (1.0 - p).ln();
325 let data = generate_vec_i64(self, total, |bg| {
326 loop {
327 let u = bg.next_f64();
328 if u > f64::EPSILON {
329 return (u.ln() / log_q).floor() as i64 + 1;
330 }
331 }
332 });
333 vec_to_array_i64(data, &shape_vec)
334 }
335
336 pub fn hypergeometric(
350 &mut self,
351 ngood: u64,
352 nbad: u64,
353 nsample: u64,
354 size: impl IntoShape,
355 ) -> Result<Array<i64, IxDyn>, FerrayError> {
356 let total = ngood + nbad;
357 if nsample > total {
358 return Err(FerrayError::invalid_value(format!(
359 "nsample ({nsample}) > ngood + nbad ({total})"
360 )));
361 }
362 let shape_vec = size.into_shape()?;
363 let total_n = shape_size(&shape_vec);
364 let data = generate_vec_i64(self, total_n, |bg| {
365 hypergeometric_single(bg, ngood, nbad, nsample)
366 });
367 vec_to_array_i64(data, &shape_vec)
368 }
369
370 pub fn logseries(
379 &mut self,
380 p: f64,
381 size: impl IntoShape,
382 ) -> Result<Array<i64, IxDyn>, FerrayError> {
383 if p <= 0.0 || p >= 1.0 {
384 return Err(FerrayError::invalid_value(format!(
385 "p must be in (0, 1), got {p}"
386 )));
387 }
388 let r = (-(-p).ln_1p()).recip();
389 let shape_vec = size.into_shape()?;
390 let total = shape_size(&shape_vec);
391 let data = generate_vec_i64(self, total, |bg| {
392 loop {
395 let u = bg.next_f64();
396 if u <= f64::EPSILON || u >= 1.0 - f64::EPSILON {
397 continue;
398 }
399 let v = bg.next_f64();
400 let q = 1.0 - (-r.recip() * u.ln()).exp();
401 if q <= 0.0 {
402 return 1;
403 }
404 if v < q * q {
405 let k = (1.0 + v.ln() / q.ln()).floor() as i64;
406 return k.max(1);
407 }
408 if v < q {
409 return 2;
410 }
411 return 1;
412 }
413 });
414 vec_to_array_i64(data, &shape_vec)
415 }
416}
417
418fn hypergeometric_single<B: BitGenerator>(bg: &mut B, ngood: u64, nbad: u64, nsample: u64) -> i64 {
420 let mut good_remaining = ngood;
422 let mut total_remaining = ngood + nbad;
423 let mut successes: i64 = 0;
424
425 for _ in 0..nsample {
426 if total_remaining == 0 {
427 break;
428 }
429 let u = bg.next_f64();
430 if u < (good_remaining as f64) / (total_remaining as f64) {
431 successes += 1;
432 good_remaining -= 1;
433 }
434 total_remaining -= 1;
435 }
436 successes
437}
438
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Arithmetic mean of integer samples, accumulated in `f64`.
    fn sample_mean(samples: &[i64]) -> f64 {
        samples.iter().map(|&x| x as f64).sum::<f64>() / samples.len() as f64
    }

    // Sample mean of Poisson(5) must sit within three standard errors of 5.
    #[test]
    fn poisson_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let lam = 5.0;
        let arr = rng.poisson(lam, n).unwrap();
        let mean = sample_mean(arr.as_slice().unwrap());
        // Var(Poisson(lam)) = lam
        let se = (lam / n as f64).sqrt();
        assert!(
            (mean - lam).abs() < 3.0 * se,
            "poisson mean {mean} too far from {lam}"
        );
    }

    // Same check for a rate on the rejection-sampling side of the threshold.
    #[test]
    fn poisson_large_lambda() {
        let mut rng = default_rng_seeded(42);
        let n = 50_000;
        let lam = 100.0;
        let arr = rng.poisson(lam, n).unwrap();
        let mean = sample_mean(arr.as_slice().unwrap());
        let se = (lam / n as f64).sqrt();
        assert!(
            (mean - lam).abs() < 3.0 * se,
            "poisson mean {mean} too far from {lam}"
        );
    }

    // A zero rate produces only zeros.
    #[test]
    fn poisson_zero() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.poisson(0.0, 100).unwrap();
        arr.as_slice()
            .unwrap()
            .iter()
            .for_each(|&v| assert_eq!(v, 0));
    }

    // Binomial(20, 0.3): mean within three standard errors, values in [0, n].
    #[test]
    fn binomial_mean() {
        let mut rng = default_rng_seeded(42);
        let size = 100_000;
        let n = 20u64;
        let p = 0.3;
        let arr = rng.binomial(n, p, size).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean = sample_mean(slice);
        let expected_mean = n as f64 * p;
        // Var = n p (1 - p)
        let expected_var = n as f64 * p * (1.0 - p);
        let se = (expected_var / size as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "binomial mean {mean} too far from {expected_mean}"
        );
        for &v in slice {
            assert!(
                (0..=n as i64).contains(&v),
                "binomial value {v} out of [0, {n}]"
            );
        }
    }

    // p = 0 always gives 0 successes, p = 1 always gives n.
    #[test]
    fn binomial_edge_cases() {
        let mut rng = default_rng_seeded(42);

        let never = rng.binomial(10, 0.0, 100).unwrap();
        never
            .as_slice()
            .unwrap()
            .iter()
            .for_each(|&v| assert_eq!(v, 0));

        let always = rng.binomial(10, 1.0, 100).unwrap();
        always
            .as_slice()
            .unwrap()
            .iter()
            .for_each(|&v| assert_eq!(v, 10));
    }

    // Negative binomial samples are counts, so never negative.
    #[test]
    fn negative_binomial_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.negative_binomial(5.0, 0.5, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 0, "negative_binomial value {v} must be >= 0");
        }
    }

    // Geometric(0.3): mean within three standard errors of 1/p, values >= 1.
    #[test]
    fn geometric_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let p = 0.3;
        let arr = rng.geometric(p, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean = sample_mean(slice);
        let expected_mean = 1.0 / p;
        // Var = (1 - p) / p^2
        let expected_var = (1.0 - p) / (p * p);
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "geometric mean {mean} too far from {expected_mean}"
        );
        for &v in slice {
            assert!(v >= 1, "geometric value {v} must be >= 1");
        }
    }

    // Every hypergeometric sample lies in [0, min(nsample, ngood)].
    #[test]
    fn hypergeometric_range() {
        let mut rng = default_rng_seeded(42);
        let ngood = 20u64;
        let nbad = 30u64;
        let nsample = 10u64;
        let arr = rng.hypergeometric(ngood, nbad, nsample, 10_000).unwrap();
        let upper = nsample.min(ngood) as i64;
        for &v in arr.as_slice().unwrap() {
            assert!(
                (0..=upper).contains(&v),
                "hypergeometric value {v} out of range"
            );
        }
    }

    // Hypergeometric mean within three standard errors of nsample * ngood / N.
    #[test]
    fn hypergeometric_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let ngood = 20u64;
        let nbad = 30u64;
        let nsample = 10u64;
        let arr = rng.hypergeometric(ngood, nbad, nsample, n).unwrap();
        let mean = sample_mean(arr.as_slice().unwrap());
        let total = (ngood + nbad) as f64;
        let expected_mean = nsample as f64 * ngood as f64 / total;
        // Variance includes the finite-population correction (N - n) / (N - 1).
        let expected_var = nsample as f64
            * (ngood as f64 / total)
            * (nbad as f64 / total)
            * (total - nsample as f64)
            / (total - 1.0);
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "hypergeometric mean {mean} too far from {expected_mean}"
        );
    }

    // Log-series samples start at 1.
    #[test]
    fn logseries_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.logseries(0.5, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 1, "logseries value {v} must be >= 1");
        }
    }

    // Out-of-domain parameters are reported as errors, not sampled.
    #[test]
    fn bad_params() {
        let mut rng = default_rng_seeded(42);

        // binomial: p outside [0, 1]
        assert!(rng.binomial(10, -0.1, 10).is_err());
        assert!(rng.binomial(10, 1.5, 10).is_err());

        // poisson: negative rate
        assert!(rng.poisson(-1.0, 10).is_err());

        // geometric: p outside (0, 1]
        assert!(rng.geometric(0.0, 10).is_err());
        assert!(rng.geometric(1.5, 10).is_err());

        // hypergeometric: more draws than items
        assert!(rng.hypergeometric(5, 5, 20, 10).is_err());

        // logseries: p outside (0, 1)
        assert!(rng.logseries(0.0, 10).is_err());
        assert!(rng.logseries(1.0, 10).is_err());

        // negative_binomial: non-positive n or p
        assert!(rng.negative_binomial(0.0, 0.5, 10).is_err());
        assert!(rng.negative_binomial(5.0, 0.0, 10).is_err());
    }
}