ferray_random/distributions/
discrete.rs1use ferray_core::{Array, FerrayError, Ix1};
6
7use crate::bitgen::BitGenerator;
8use crate::distributions::gamma::standard_gamma_single;
9use crate::distributions::normal::standard_normal_single;
10use crate::generator::{Generator, generate_vec_i64, vec_to_array1_i64};
11
12fn poisson_single<B: BitGenerator>(bg: &mut B, lam: f64) -> i64 {
15 if lam < 30.0 {
16 let l = (-lam).exp();
18 let mut k: i64 = 0;
19 let mut p = 1.0;
20 loop {
21 k += 1;
22 p *= bg.next_f64();
23 if p <= l {
24 return k - 1;
25 }
26 }
27 } else {
28 let slam = lam.sqrt();
30 let loglam = lam.ln();
31 let b = 0.931 + 2.53 * slam;
32 let a = -0.059 + 0.02483 * b;
33 let inv_alpha = 1.1239 + 1.1328 / (b - 3.4);
34 let vr = 0.9277 - 3.6224 / (b - 2.0);
35
36 loop {
37 let u = bg.next_f64() - 0.5;
38 let v = bg.next_f64();
39 let us = 0.5 - u.abs();
40 let k = ((2.0 * a / us + b) * u + lam + 0.43).floor() as i64;
41 if k < 0 {
42 continue;
43 }
44 if us >= 0.07 && v <= vr {
45 return k;
46 }
47 if k > 0
48 && us >= 0.013
49 && v <= (k as f64)
50 .ln()
51 .mul_add(-0.5, (k as f64) * loglam - lam - ln_factorial(k as u64))
52 .exp()
53 * inv_alpha
54 {
55 }
57 if us < 0.013 && v > us {
58 continue;
59 }
60 let kf = k as f64;
62 let log_accept = -lam + kf * loglam - ln_factorial(k as u64);
63 if v.ln() + inv_alpha.ln() - (a / (us * us) + b).ln() <= log_accept {
64 return k;
65 }
66 }
67 }
68}
69
/// Natural logarithm of `n!`.
///
/// For `n <= 20` this is the exact sum `ln 2 + ln 3 + … + ln n`. Above that it
/// uses Stirling's series truncated after the `1/(360 n^3)` term:
/// `½ ln(2π) + (n + ½) ln n − n + 1/(12n) − 1/(360 n³)`.
fn ln_factorial(n: u64) -> f64 {
    match n {
        0..=20 => (2..=n).fold(0.0_f64, |acc, i| acc + (i as f64).ln()),
        _ => {
            let x = n as f64;
            let half_ln_two_pi = 0.5 * (std::f64::consts::TAU).ln();
            half_ln_two_pi + (x + 0.5) * x.ln() - x + 1.0 / (12.0 * x)
                - 1.0 / (360.0 * x * x * x)
        }
    }
}
86
87fn binomial_single<B: BitGenerator>(bg: &mut B, n: u64, p: f64) -> i64 {
90 if n == 0 || p == 0.0 {
91 return 0;
92 }
93 if p == 1.0 {
94 return n as i64;
95 }
96
97 let (pp, flipped) = if p > 0.5 { (1.0 - p, true) } else { (p, false) };
99
100 let np = n as f64 * pp;
101
102 let result = if np < 30.0 {
103 let q = 1.0 - pp;
105 let s = pp / q;
106 let a = (n as f64 + 1.0) * s;
107 let mut r = q.powi(n as i32);
108 let mut u = bg.next_f64();
109 let mut x: i64 = 0;
110 while u > r {
111 u -= r;
112 x += 1;
113 r *= a / (x as f64) - s;
114 if r < 0.0 {
115 break;
116 }
117 }
118 x.min(n as i64)
119 } else {
120 loop {
123 let z = standard_normal_single(bg);
124 let sigma = (np * (1.0 - pp)).sqrt();
125 let x = (np + sigma * z + 0.5).floor() as i64;
126 if x >= 0 && x <= n as i64 {
127 break x;
128 }
129 }
130 };
131
132 if flipped { n as i64 - result } else { result }
133}
134
135impl<B: BitGenerator> Generator<B> {
136 pub fn binomial(
149 &mut self,
150 n: u64,
151 p: f64,
152 size: usize,
153 ) -> Result<Array<i64, Ix1>, FerrayError> {
154 if size == 0 {
155 return Err(FerrayError::invalid_value("size must be > 0"));
156 }
157 if !(0.0..=1.0).contains(&p) {
158 return Err(FerrayError::invalid_value(format!(
159 "p must be in [0, 1], got {p}"
160 )));
161 }
162 let data = generate_vec_i64(self, size, |bg| binomial_single(bg, n, p));
163 vec_to_array1_i64(data)
164 }
165
166 pub fn negative_binomial(
179 &mut self,
180 n: f64,
181 p: f64,
182 size: usize,
183 ) -> Result<Array<i64, Ix1>, FerrayError> {
184 if size == 0 {
185 return Err(FerrayError::invalid_value("size must be > 0"));
186 }
187 if n <= 0.0 {
188 return Err(FerrayError::invalid_value(format!(
189 "n must be positive, got {n}"
190 )));
191 }
192 if p <= 0.0 || p > 1.0 {
193 return Err(FerrayError::invalid_value(format!(
194 "p must be in (0, 1], got {p}"
195 )));
196 }
197 let data = generate_vec_i64(self, size, |bg| {
198 let y = standard_gamma_single(bg, n) * (1.0 - p) / p;
201 poisson_single(bg, y)
202 });
203 vec_to_array1_i64(data)
204 }
205
206 pub fn poisson(&mut self, lam: f64, size: usize) -> Result<Array<i64, Ix1>, FerrayError> {
215 if size == 0 {
216 return Err(FerrayError::invalid_value("size must be > 0"));
217 }
218 if lam < 0.0 {
219 return Err(FerrayError::invalid_value(format!(
220 "lam must be non-negative, got {lam}"
221 )));
222 }
223 if lam == 0.0 {
224 let data = vec![0i64; size];
225 return vec_to_array1_i64(data);
226 }
227 let data = generate_vec_i64(self, size, |bg| poisson_single(bg, lam));
228 vec_to_array1_i64(data)
229 }
230
231 pub fn geometric(&mut self, p: f64, size: usize) -> Result<Array<i64, Ix1>, FerrayError> {
242 if size == 0 {
243 return Err(FerrayError::invalid_value("size must be > 0"));
244 }
245 if p <= 0.0 || p > 1.0 {
246 return Err(FerrayError::invalid_value(format!(
247 "p must be in (0, 1], got {p}"
248 )));
249 }
250 if (p - 1.0).abs() < f64::EPSILON {
251 let data = vec![1i64; size];
252 return vec_to_array1_i64(data);
253 }
254 let log_q = (1.0 - p).ln();
255 let data = generate_vec_i64(self, size, |bg| {
256 loop {
257 let u = bg.next_f64();
258 if u > f64::EPSILON {
259 return (u.ln() / log_q).floor() as i64 + 1;
260 }
261 }
262 });
263 vec_to_array1_i64(data)
264 }
265
266 pub fn hypergeometric(
280 &mut self,
281 ngood: u64,
282 nbad: u64,
283 nsample: u64,
284 size: usize,
285 ) -> Result<Array<i64, Ix1>, FerrayError> {
286 if size == 0 {
287 return Err(FerrayError::invalid_value("size must be > 0"));
288 }
289 let total = ngood + nbad;
290 if nsample > total {
291 return Err(FerrayError::invalid_value(format!(
292 "nsample ({nsample}) > ngood + nbad ({total})"
293 )));
294 }
295 let data = generate_vec_i64(self, size, |bg| {
296 hypergeometric_single(bg, ngood, nbad, nsample)
297 });
298 vec_to_array1_i64(data)
299 }
300
301 pub fn logseries(&mut self, p: f64, size: usize) -> Result<Array<i64, Ix1>, FerrayError> {
310 if size == 0 {
311 return Err(FerrayError::invalid_value("size must be > 0"));
312 }
313 if p <= 0.0 || p >= 1.0 {
314 return Err(FerrayError::invalid_value(format!(
315 "p must be in (0, 1), got {p}"
316 )));
317 }
318 let r = (-(-p).ln_1p()).recip();
319 let data = generate_vec_i64(self, size, |bg| {
320 loop {
323 let u = bg.next_f64();
324 if u <= f64::EPSILON || u >= 1.0 - f64::EPSILON {
325 continue;
326 }
327 let v = bg.next_f64();
328 let q = 1.0 - (-r.recip() * u.ln()).exp();
329 if q <= 0.0 {
330 return 1;
331 }
332 if v < q * q {
333 let k = (1.0 + v.ln() / q.ln()).floor() as i64;
334 return k.max(1);
335 }
336 if v < q {
337 return 2;
338 }
339 return 1;
340 }
341 });
342 vec_to_array1_i64(data)
343 }
344}
345
346fn hypergeometric_single<B: BitGenerator>(bg: &mut B, ngood: u64, nbad: u64, nsample: u64) -> i64 {
348 let mut good_remaining = ngood;
350 let mut total_remaining = ngood + nbad;
351 let mut successes: i64 = 0;
352
353 for _ in 0..nsample {
354 if total_remaining == 0 {
355 break;
356 }
357 let u = bg.next_f64();
358 if u < (good_remaining as f64) / (total_remaining as f64) {
359 successes += 1;
360 good_remaining -= 1;
361 }
362 total_remaining -= 1;
363 }
364 successes
365}
366
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Arithmetic mean of integer draws, accumulated in `f64`.
    fn sample_mean(draws: &[i64]) -> f64 {
        let total: f64 = draws.iter().map(|&d| d as f64).sum();
        total / draws.len() as f64
    }

    /// Checks that `count` Poisson(`lam`) draws (seed 42) have a mean within
    /// three standard errors of `lam`.
    fn assert_poisson_mean(lam: f64, count: usize) {
        let mut rng = default_rng_seeded(42);
        let draws = rng.poisson(lam, count).unwrap();
        let mean = sample_mean(draws.as_slice().unwrap());
        let se = (lam / count as f64).sqrt();
        assert!(
            (mean - lam).abs() < 3.0 * se,
            "poisson mean {mean} too far from {lam}"
        );
    }

    #[test]
    fn poisson_mean() {
        assert_poisson_mean(5.0, 100_000);
    }

    #[test]
    fn poisson_large_lambda() {
        assert_poisson_mean(100.0, 50_000);
    }

    #[test]
    fn poisson_zero() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.poisson(0.0, 100).unwrap();
        draws
            .as_slice()
            .unwrap()
            .iter()
            .for_each(|&v| assert_eq!(v, 0));
    }

    #[test]
    fn binomial_mean() {
        let mut rng = default_rng_seeded(42);
        let size = 100_000;
        let (n, p) = (20u64, 0.3);
        let draws = rng.binomial(n, p, size).unwrap();
        let slice = draws.as_slice().unwrap();
        let mean = sample_mean(slice);
        let expected_mean = n as f64 * p;
        let expected_var = expected_mean * (1.0 - p);
        let se = (expected_var / size as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "binomial mean {mean} too far from {expected_mean}"
        );
        let upper = n as i64;
        for &v in slice {
            assert!(
                (0..=upper).contains(&v),
                "binomial value {v} out of [0, {n}]"
            );
        }
    }

    #[test]
    fn binomial_edge_cases() {
        let mut rng = default_rng_seeded(42);
        for (p, want) in [(0.0, 0i64), (1.0, 10)] {
            let draws = rng.binomial(10, p, 100).unwrap();
            for &v in draws.as_slice().unwrap() {
                assert_eq!(v, want);
            }
        }
    }

    #[test]
    fn negative_binomial_positive() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.negative_binomial(5.0, 0.5, 10_000).unwrap();
        for &v in draws.as_slice().unwrap() {
            assert!(v >= 0, "negative_binomial value {v} must be >= 0");
        }
    }

    #[test]
    fn geometric_mean() {
        let mut rng = default_rng_seeded(42);
        let count = 100_000;
        let p = 0.3;
        let draws = rng.geometric(p, count).unwrap();
        let slice = draws.as_slice().unwrap();
        let mean = sample_mean(slice);
        let expected_mean = 1.0 / p;
        let expected_var = (1.0 - p) / (p * p);
        let se = (expected_var / count as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "geometric mean {mean} too far from {expected_mean}"
        );
        for &v in slice {
            assert!(v >= 1, "geometric value {v} must be >= 1");
        }
    }

    #[test]
    fn hypergeometric_range() {
        let mut rng = default_rng_seeded(42);
        let (ngood, nbad, nsample) = (20u64, 30u64, 10u64);
        let draws = rng.hypergeometric(ngood, nbad, nsample, 10_000).unwrap();
        let upper = nsample.min(ngood) as i64;
        for &v in draws.as_slice().unwrap() {
            assert!(
                v >= 0 && v <= upper,
                "hypergeometric value {v} out of range"
            );
        }
    }

    #[test]
    fn hypergeometric_mean() {
        let mut rng = default_rng_seeded(42);
        let count = 100_000;
        let (ngood, nbad, nsample) = (20u64, 30u64, 10u64);
        let draws = rng.hypergeometric(ngood, nbad, nsample, count).unwrap();
        let mean = sample_mean(draws.as_slice().unwrap());
        let total = (ngood + nbad) as f64;
        let k = nsample as f64;
        let frac_good = ngood as f64 / total;
        let frac_bad = nbad as f64 / total;
        let expected_mean = k * ngood as f64 / total;
        let expected_var = k * frac_good * frac_bad * (total - k) / (total - 1.0);
        let se = (expected_var / count as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "hypergeometric mean {mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn logseries_positive() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.logseries(0.5, 10_000).unwrap();
        for &v in draws.as_slice().unwrap() {
            assert!(v >= 1, "logseries value {v} must be >= 1");
        }
    }

    #[test]
    fn bad_params() {
        let mut rng = default_rng_seeded(42);

        // binomial: p outside [0, 1]
        assert!(rng.binomial(10, -0.1, 10).is_err());
        assert!(rng.binomial(10, 1.5, 10).is_err());

        // poisson: negative rate
        assert!(rng.poisson(-1.0, 10).is_err());

        // geometric: p outside (0, 1]
        assert!(rng.geometric(0.0, 10).is_err());
        assert!(rng.geometric(1.5, 10).is_err());

        // hypergeometric: sample larger than the population
        assert!(rng.hypergeometric(5, 5, 20, 10).is_err());

        // logseries: p outside (0, 1)
        assert!(rng.logseries(0.0, 10).is_err());
        assert!(rng.logseries(1.0, 10).is_err());

        // negative_binomial: non-positive n, p outside (0, 1]
        assert!(rng.negative_binomial(0.0, 0.5, 10).is_err());
        assert!(rng.negative_binomial(5.0, 0.0, 10).is_err());
    }
}