1use ferray_core::{Array, FerrayError, Ix1, Ix2};
4
5use crate::bitgen::BitGenerator;
6use crate::distributions::gamma::standard_gamma_single;
7use crate::distributions::normal::standard_normal_single;
8use crate::generator::Generator;
9
10impl<B: BitGenerator> Generator<B> {
11 pub fn multinomial(
27 &mut self,
28 n: u64,
29 pvals: &[f64],
30 size: usize,
31 ) -> Result<Array<i64, Ix2>, FerrayError> {
32 if size == 0 {
33 return Err(FerrayError::invalid_value("size must be > 0"));
34 }
35 if pvals.is_empty() {
36 return Err(FerrayError::invalid_value(
37 "pvals must have at least one element",
38 ));
39 }
40 let psum: f64 = pvals.iter().sum();
41 if (psum - 1.0).abs() > 1e-6 {
42 return Err(FerrayError::invalid_value(format!(
43 "pvals must sum to 1.0, got {psum}"
44 )));
45 }
46 for (i, &p) in pvals.iter().enumerate() {
47 if p < 0.0 {
48 return Err(FerrayError::invalid_value(format!(
49 "pvals[{i}] = {p} is negative"
50 )));
51 }
52 }
53
54 let k = pvals.len();
55 let mut data = Vec::with_capacity(size * k);
56
57 for _ in 0..size {
58 let mut remaining = n;
59 let mut psum_remaining = 1.0;
60 for (j, &pj) in pvals.iter().enumerate() {
61 if j == k - 1 {
62 data.push(remaining as i64);
64 } else if psum_remaining <= 0.0 || remaining == 0 {
65 data.push(0);
66 } else {
67 let p_cond = (pj / psum_remaining).clamp(0.0, 1.0);
68 let count = binomial_for_multinomial(&mut self.bg, remaining, p_cond);
69 data.push(count as i64);
70 remaining -= count;
71 psum_remaining -= pj;
72 }
73 }
74 }
75
76 Array::<i64, Ix2>::from_vec(Ix2::new([size, k]), data)
77 }
78
79 pub fn multivariate_normal(
95 &mut self,
96 mean: &[f64],
97 cov: &[f64],
98 size: usize,
99 ) -> Result<Array<f64, Ix2>, FerrayError> {
100 if size == 0 {
101 return Err(FerrayError::invalid_value("size must be > 0"));
102 }
103 let d = mean.len();
104 if d == 0 {
105 return Err(FerrayError::invalid_value("mean must be non-empty"));
106 }
107 if cov.len() != d * d {
108 return Err(FerrayError::invalid_value(format!(
109 "cov must have {} elements for mean of length {d}, got {}",
110 d * d,
111 cov.len()
112 )));
113 }
114
115 let l = cholesky_decompose(cov, d)?;
117
118 let mut data = Vec::with_capacity(size * d);
119 for _ in 0..size {
120 let mut z = Vec::with_capacity(d);
122 for _ in 0..d {
123 z.push(standard_normal_single(&mut self.bg));
124 }
125
126 for i in 0..d {
128 let mut val = mean[i];
129 for j in 0..=i {
130 val += l[i * d + j] * z[j];
131 }
132 data.push(val);
133 }
134 }
135
136 Array::<f64, Ix2>::from_vec(Ix2::new([size, d]), data)
137 }
138
139 pub fn multivariate_hypergeometric(
163 &mut self,
164 colors: &[u64],
165 nsample: u64,
166 size: usize,
167 ) -> Result<Array<i64, Ix2>, FerrayError> {
168 if size == 0 {
169 return Err(FerrayError::invalid_value("size must be > 0"));
170 }
171 if colors.is_empty() {
172 return Err(FerrayError::invalid_value(
173 "colors must have at least one element",
174 ));
175 }
176 let total: u64 = colors.iter().try_fold(0_u64, |acc, &c| {
177 acc.checked_add(c).ok_or_else(|| {
178 FerrayError::invalid_value("multivariate_hypergeometric: colors sum overflows u64")
179 })
180 })?;
181 if nsample > total {
182 return Err(FerrayError::invalid_value(format!(
183 "nsample ({nsample}) > sum of colors ({total})"
184 )));
185 }
186
187 let k = colors.len();
188 let mut data = Vec::with_capacity(size * k);
189
190 for _ in 0..size {
191 let mut remaining_pop: u64 = total;
196 let mut remaining_sample: u64 = nsample;
197
198 for &ngood in &colors[..k - 1] {
199 let nbad = remaining_pop - ngood;
200 let draw = if remaining_sample == 0 || ngood == 0 {
201 0
202 } else if remaining_sample >= remaining_pop {
203 ngood as i64
205 } else {
206 hypergeometric_for_multivariate(&mut self.bg, ngood, nbad, remaining_sample)
207 };
208 data.push(draw);
209 remaining_pop -= ngood;
210 remaining_sample -= draw as u64;
211 }
212 data.push(remaining_sample as i64);
214 }
215
216 Array::<i64, Ix2>::from_vec(Ix2::new([size, k]), data)
217 }
218
219 pub fn multivariate_normal_array(
237 &mut self,
238 mean: &Array<f64, Ix1>,
239 cov: &Array<f64, Ix2>,
240 size: usize,
241 ) -> Result<Array<f64, Ix2>, FerrayError> {
242 if size == 0 {
243 return Err(FerrayError::invalid_value("size must be > 0"));
244 }
245 let d = mean.shape()[0];
246 if d == 0 {
247 return Err(FerrayError::invalid_value("mean must be non-empty"));
248 }
249 let cov_shape = cov.shape();
250 if cov_shape[0] != d || cov_shape[1] != d {
251 return Err(FerrayError::shape_mismatch(format!(
252 "cov shape {cov_shape:?} does not match mean of length {d}"
253 )));
254 }
255
256 let l_arr = ferray_linalg::cholesky(cov)?;
257 let l_slice = l_arr
258 .as_slice()
259 .ok_or_else(|| FerrayError::invalid_value("cholesky returned non-contiguous L"))?;
260 let mean_slice = mean
261 .as_slice()
262 .ok_or_else(|| FerrayError::invalid_value("mean must be contiguous"))?;
263
264 let mut data = Vec::with_capacity(size * d);
265 let mut z = vec![0.0_f64; d];
266 for _ in 0..size {
267 for v in z.iter_mut() {
268 *v = standard_normal_single(&mut self.bg);
269 }
270 for i in 0..d {
271 let mut val = mean_slice[i];
272 for j in 0..=i {
273 val += l_slice[i * d + j] * z[j];
274 }
275 data.push(val);
276 }
277 }
278 Array::<f64, Ix2>::from_vec(Ix2::new([size, d]), data)
279 }
280
281 pub fn dirichlet(
296 &mut self,
297 alpha: &[f64],
298 size: usize,
299 ) -> Result<Array<f64, Ix2>, FerrayError> {
300 if size == 0 {
301 return Err(FerrayError::invalid_value("size must be > 0"));
302 }
303 if alpha.is_empty() {
304 return Err(FerrayError::invalid_value(
305 "alpha must have at least one element",
306 ));
307 }
308 for (i, &a) in alpha.iter().enumerate() {
309 if a <= 0.0 {
310 return Err(FerrayError::invalid_value(format!(
311 "alpha[{i}] = {a} must be positive"
312 )));
313 }
314 }
315
316 let k = alpha.len();
317 let mut data = Vec::with_capacity(size * k);
318
319 for _ in 0..size {
320 let mut gammas = Vec::with_capacity(k);
321 let mut sum = 0.0;
322 for &a in alpha {
323 let g = standard_gamma_single(&mut self.bg, a);
324 gammas.push(g);
325 sum += g;
326 }
327 if sum > 0.0 {
329 for g in &gammas {
330 data.push(g / sum);
331 }
332 } else {
333 let val = 1.0 / k as f64;
335 for _ in 0..k {
336 data.push(val);
337 }
338 }
339 }
340
341 Array::<f64, Ix2>::from_vec(Ix2::new([size, k]), data)
342 }
343}
344
345fn hypergeometric_for_multivariate<B: BitGenerator>(
352 bg: &mut B,
353 ngood: u64,
354 nbad: u64,
355 nsample: u64,
356) -> i64 {
357 let mut good_remaining = ngood;
358 let mut total_remaining = ngood + nbad;
359 let mut successes: i64 = 0;
360 for _ in 0..nsample {
361 if total_remaining == 0 {
362 break;
363 }
364 let u = bg.next_f64();
365 if u < (good_remaining as f64) / (total_remaining as f64) {
366 successes += 1;
367 good_remaining -= 1;
368 }
369 total_remaining -= 1;
370 }
371 successes
372}
373
/// Sample a Binomial(n, p) count for use in the multinomial decomposition.
///
/// Small means (n * p < 30) use the BINV inverse-transform algorithm; larger
/// means use a rounded normal approximation with rejection of out-of-range
/// values (approximate, not an exact binomial sampler). The result is always
/// in `[0, n]`.
fn binomial_for_multinomial<B: BitGenerator>(bg: &mut B, n: u64, p: f64) -> u64 {
    if n == 0 || p <= 0.0 {
        return 0;
    }
    if p >= 1.0 {
        return n;
    }

    // Work with pp <= 0.5 for numerical stability; flip the result back at
    // the end via the symmetry Bin(n, p) = n - Bin(n, 1 - p).
    let (pp, flipped) = if p > 0.5 { (1.0 - p, true) } else { (p, false) };

    let result = if (n as f64) * pp < 30.0 {
        // BINV: walk the pmf from x = 0, subtracting successive
        // probabilities r = P(X = x) from a single uniform u until it is
        // exhausted. Each pmf term is obtained from the previous one via the
        // recurrence r_{x+1} = r_x * ((n - x) / (x + 1)) * (pp / q),
        // rewritten below as r *= a / x - s.
        let q = 1.0 - pp;
        let s = pp / q;
        let a = (n as f64 + 1.0) * s;
        let mut r = q.powf(n as f64); // P(X = 0)
        let mut u = bg.next_f64();
        let mut x: u64 = 0;
        while u > r {
            u -= r;
            x += 1;
            // Guard against floating-point leakage past the support.
            if x > n {
                x = n;
                break;
            }
            r *= a / (x as f64) - s;
            if r < 0.0 {
                break;
            }
        }
        x.min(n)
    } else {
        // Normal approximation: round mean + sigma * z to the nearest
        // integer, rejecting draws outside [0, n].
        loop {
            let z = standard_normal_single(bg);
            let sigma = ((n as f64) * pp * (1.0 - pp)).sqrt();
            let x = ((n as f64).mul_add(pp, sigma * z) + 0.5).floor() as i64;
            if x >= 0 && x <= n as i64 {
                break x as u64;
            }
        }
    };

    if flipped { n - result } else { result }
}
420
421fn cholesky_decompose(a: &[f64], n: usize) -> Result<Vec<f64>, FerrayError> {
425 let mut l = vec![0.0; n * n];
426
427 for i in 0..n {
428 for j in 0..=i {
429 let mut sum = 0.0;
430 for k in 0..j {
431 sum += l[i * n + k] * l[j * n + k];
432 }
433 if i == j {
434 let diag = a[i * n + i] - sum;
435 if diag < -1e-10 {
436 return Err(FerrayError::invalid_value(
437 "covariance matrix is not positive semi-definite",
438 ));
439 }
440 l[i * n + j] = diag.max(0.0).sqrt();
441 } else {
442 let denom = l[j * n + j];
443 if denom.abs() < 1e-15 {
444 l[i * n + j] = 0.0;
445 } else {
446 l[i * n + j] = (a[i * n + j] - sum) / denom;
447 }
448 }
449 }
450 }
451
452 Ok(l)
453}
454
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    #[test]
    fn multinomial_shape() {
        let mut rng = default_rng_seeded(42);
        let pvals = [0.2, 0.3, 0.5];
        let arr = rng.multinomial(100, &pvals, 10).unwrap();
        assert_eq!(arr.shape(), &[10, 3]);
    }

    #[test]
    fn multinomial_row_sums() {
        // Every sampled row of counts must sum to n exactly.
        let mut rng = default_rng_seeded(42);
        let pvals = [0.2, 0.3, 0.5];
        let n = 100u64;
        let arr = rng.multinomial(n, &pvals, 50).unwrap();
        let slice = arr.as_slice().unwrap();
        let k = pvals.len();
        for row in 0..50 {
            let row_sum: i64 = (0..k).map(|j| slice[row * k + j]).sum();
            assert_eq!(
                row_sum, n as i64,
                "row {row} sum is {row_sum}, expected {n}"
            );
        }
    }

    #[test]
    fn multinomial_nonnegative() {
        let mut rng = default_rng_seeded(42);
        let pvals = [0.1, 0.2, 0.3, 0.4];
        let arr = rng.multinomial(50, &pvals, 100).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 0, "multinomial produced negative count: {v}");
        }
    }

    #[test]
    fn multinomial_bad_pvals() {
        // A sum != 1, a negative entry, and empty pvals must all be rejected.
        let mut rng = default_rng_seeded(42);
        assert!(rng.multinomial(10, &[0.5, 0.6], 10).is_err());
        assert!(rng.multinomial(10, &[-0.1, 1.1], 10).is_err());
        assert!(rng.multinomial(10, &[], 10).is_err());
    }

    #[test]
    fn multivariate_normal_shape() {
        let mut rng = default_rng_seeded(42);
        let mean = [1.0, 2.0, 3.0];
        let cov = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let arr = rng.multivariate_normal(&mean, &cov, 100).unwrap();
        assert_eq!(arr.shape(), &[100, 3]);
    }

    #[test]
    fn multivariate_normal_mean() {
        // With unit variance, sample means should land within 3 standard
        // errors of the true means.
        let mut rng = default_rng_seeded(42);
        let mean = [5.0, -3.0];
        let cov = [1.0, 0.0, 0.0, 1.0];
        let n = 100_000;
        let arr = rng.multivariate_normal(&mean, &cov, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let d = mean.len();

        for j in 0..d {
            let col_mean: f64 = (0..n).map(|i| slice[i * d + j]).sum::<f64>() / n as f64;
            let se = (1.0 / n as f64).sqrt();
            assert!(
                (col_mean - mean[j]).abs() < 3.0 * se,
                "multivariate_normal mean[{j}] = {col_mean}, expected {}",
                mean[j]
            );
        }
    }

    #[test]
    fn multivariate_normal_bad_cov() {
        // cov with the wrong number of elements must be rejected.
        let mut rng = default_rng_seeded(42);
        let mean = [0.0, 0.0];
        assert!(
            rng.multivariate_normal(&mean, &[1.0, 0.0, 0.0], 10)
                .is_err()
        );
    }

    #[test]
    fn dirichlet_shape() {
        let mut rng = default_rng_seeded(42);
        let alpha = [1.0, 2.0, 3.0];
        let arr = rng.dirichlet(&alpha, 10).unwrap();
        assert_eq!(arr.shape(), &[10, 3]);
    }

    #[test]
    fn dirichlet_sums_to_one() {
        let mut rng = default_rng_seeded(42);
        let alpha = [0.5, 1.0, 2.0, 0.5];
        let arr = rng.dirichlet(&alpha, 100).unwrap();
        let slice = arr.as_slice().unwrap();
        let k = alpha.len();
        for row in 0..100 {
            let row_sum: f64 = (0..k).map(|j| slice[row * k + j]).sum();
            assert!(
                (row_sum - 1.0).abs() < 1e-10,
                "dirichlet row {row} sums to {row_sum}, expected 1.0"
            );
        }
    }

    #[test]
    fn dirichlet_nonnegative() {
        let mut rng = default_rng_seeded(42);
        let alpha = [0.5, 1.0, 2.0];
        let arr = rng.dirichlet(&alpha, 100).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 0.0, "dirichlet produced negative value: {v}");
        }
    }

    #[test]
    fn dirichlet_bad_alpha() {
        // Empty, zero, and negative concentration parameters are invalid.
        let mut rng = default_rng_seeded(42);
        assert!(rng.dirichlet(&[], 10).is_err());
        assert!(rng.dirichlet(&[1.0, 0.0], 10).is_err());
        assert!(rng.dirichlet(&[1.0, -1.0], 10).is_err());
    }

    #[test]
    fn mvn_array_shape() {
        use ferray_core::{Array, Ix1, Ix2};
        let mut rng = default_rng_seeded(42);
        let mean = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let cov = Array::<f64, Ix2>::from_vec(
            Ix2::new([3, 3]),
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
        )
        .unwrap();
        let arr = rng.multivariate_normal_array(&mean, &cov, 100).unwrap();
        assert_eq!(arr.shape(), &[100, 3]);
    }

    #[test]
    fn mvn_array_means_match() {
        use ferray_core::{Array, Ix1, Ix2};
        let mut rng = default_rng_seeded(42);
        let mean = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![5.0, -3.0]).unwrap();
        let cov = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let n = 100_000;
        let arr = rng.multivariate_normal_array(&mean, &cov, n).unwrap();
        let slice = arr.as_slice().unwrap();
        for j in 0..2 {
            let m: f64 = (0..n).map(|i| slice[i * 2 + j]).sum::<f64>() / n as f64;
            let se = (1.0 / n as f64).sqrt();
            let want = mean.as_slice().unwrap()[j];
            assert!((m - want).abs() < 4.0 * se, "col {j} mean {m} ≠ {want}");
        }
    }

    #[test]
    fn mvn_array_rejects_non_square_cov() {
        use ferray_core::{Array, Ix1, Ix2};
        let mut rng = default_rng_seeded(0);
        let mean = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![0.0, 0.0]).unwrap();
        let cov = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
            .unwrap();
        assert!(rng.multivariate_normal_array(&mean, &cov, 5).is_err());
    }

    #[test]
    fn mvn_array_rejects_non_pd_cov() {
        // [[0, 1], [1, 0]] is indefinite; the strict linalg Cholesky should
        // report it as singular.
        use ferray_core::{Array, Ix1, Ix2};
        let mut rng = default_rng_seeded(0);
        let mean = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![0.0, 0.0]).unwrap();
        let cov = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![0.0, 1.0, 1.0, 0.0]).unwrap();
        let err = rng.multivariate_normal_array(&mean, &cov, 5).unwrap_err();
        assert!(matches!(
            err,
            ferray_core::FerrayError::SingularMatrix { .. }
        ));
    }

    #[test]
    fn mvn_array_correlated_recovers_cov() {
        use ferray_core::{Array, Ix1, Ix2};
        let mut rng = default_rng_seeded(11);
        let mean = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![0.0, 0.0]).unwrap();
        let cov = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 0.7, 0.7, 1.0]).unwrap();
        let n = 50_000;
        let arr = rng.multivariate_normal_array(&mean, &cov, n).unwrap();
        let s = arr.as_slice().unwrap();
        let mean0: f64 = (0..n).map(|i| s[i * 2]).sum::<f64>() / n as f64;
        let mean1: f64 = (0..n).map(|i| s[i * 2 + 1]).sum::<f64>() / n as f64;
        let cov01: f64 = (0..n)
            .map(|i| (s[i * 2] - mean0) * (s[i * 2 + 1] - mean1))
            .sum::<f64>()
            / n as f64;
        assert!((cov01 - 0.7).abs() < 0.05, "sample cov01 {cov01} ≠ 0.7");
    }

    #[test]
    fn mvhg_shape_and_row_sum() {
        let mut rng = default_rng_seeded(42);
        let colors = [10u64, 20, 30];
        let nsample = 15u64;
        let arr = rng
            .multivariate_hypergeometric(&colors, nsample, 50)
            .unwrap();
        assert_eq!(arr.shape(), &[50, 3]);
        let slice = arr.as_slice().unwrap();
        for row in 0..50 {
            let row_sum: i64 = (0..3).map(|j| slice[row * 3 + j]).sum();
            assert_eq!(row_sum, nsample as i64);
        }
    }

    #[test]
    fn mvhg_per_color_within_population() {
        // No color can be drawn more often than it exists in the population.
        let mut rng = default_rng_seeded(123);
        let colors = [5u64, 5, 5];
        let arr = rng.multivariate_hypergeometric(&colors, 10, 200).unwrap();
        let slice = arr.as_slice().unwrap();
        for row in 0..200 {
            for j in 0..3 {
                let v = slice[row * 3 + j];
                assert!(
                    v >= 0 && v <= colors[j] as i64,
                    "row {row} col {j}: count {v} out of [0, {}]",
                    colors[j]
                );
            }
        }
    }

    #[test]
    fn mvhg_marginal_means_match_theory() {
        // Each marginal is hypergeometric with known mean and variance.
        let mut rng = default_rng_seeded(7);
        let colors = [10u64, 20, 30, 40];
        let total: f64 = colors.iter().sum::<u64>() as f64;
        let nsample = 25u64;
        let n_draws = 10_000;
        let arr = rng
            .multivariate_hypergeometric(&colors, nsample, n_draws)
            .unwrap();
        let slice = arr.as_slice().unwrap();
        let k = colors.len();
        for j in 0..k {
            let observed: f64 =
                (0..n_draws).map(|i| slice[i * k + j] as f64).sum::<f64>() / n_draws as f64;
            let expected = nsample as f64 * colors[j] as f64 / total;
            let kj = colors[j] as f64;
            let var = nsample as f64
                * (kj / total)
                * ((total - kj) / total)
                * ((total - nsample as f64) / (total - 1.0));
            let se = (var / n_draws as f64).sqrt();
            assert!(
                (observed - expected).abs() < 4.0 * se,
                "color {j}: observed mean {observed}, expected {expected} ± {se}"
            );
        }
    }

    #[test]
    fn mvhg_take_all() {
        // Sampling the entire population must return the colors verbatim.
        let mut rng = default_rng_seeded(0);
        let colors = [3u64, 7, 0, 5];
        let total: u64 = colors.iter().sum();
        let arr = rng.multivariate_hypergeometric(&colors, total, 5).unwrap();
        let slice = arr.as_slice().unwrap();
        for row in 0..5 {
            for j in 0..colors.len() {
                assert_eq!(slice[row * colors.len() + j], colors[j] as i64);
            }
        }
    }

    #[test]
    fn mvhg_seed_reproducible() {
        let mut a = default_rng_seeded(99);
        let mut b = default_rng_seeded(99);
        let xa = a.multivariate_hypergeometric(&[5, 10, 15], 8, 30).unwrap();
        let xb = b.multivariate_hypergeometric(&[5, 10, 15], 8, 30).unwrap();
        assert_eq!(xa.as_slice().unwrap(), xb.as_slice().unwrap());
    }

    #[test]
    fn mvhg_bad_params() {
        // size == 0, empty colors, and nsample > population are invalid.
        let mut rng = default_rng_seeded(0);
        assert!(rng.multivariate_hypergeometric(&[1, 2], 1, 0).is_err());
        assert!(rng.multivariate_hypergeometric(&[], 0, 5).is_err());
        assert!(rng.multivariate_hypergeometric(&[3, 4], 10, 5).is_err());
    }
}