// ferray_random/distributions/multivariate.rs
// ferray-random: Multivariate distributions — multinomial, multivariate_normal, dirichlet

use ferray_core::{Array, FerrayError, Ix1, Ix2};

use crate::bitgen::BitGenerator;
use crate::distributions::gamma::standard_gamma_single;
use crate::distributions::normal::standard_normal_single;
use crate::generator::Generator;
10impl<B: BitGenerator> Generator<B> {
11    /// Generate multinomial samples.
12    ///
13    /// Each row of the output is one draw of `n` items distributed across
14    /// `k` categories with probabilities `pvals`.
15    ///
16    /// # Arguments
17    /// * `n` - Number of trials per sample.
18    /// * `pvals` - Category probabilities (must sum to ~1.0, length k).
19    /// * `size` - Number of multinomial draws (rows in output).
20    ///
21    /// # Returns
22    /// An `Array<i64, Ix2>` with shape `[size, k]`.
23    ///
24    /// # Errors
25    /// Returns `FerrayError::InvalidValue` for invalid parameters.
26    pub fn multinomial(
27        &mut self,
28        n: u64,
29        pvals: &[f64],
30        size: usize,
31    ) -> Result<Array<i64, Ix2>, FerrayError> {
32        if size == 0 {
33            return Err(FerrayError::invalid_value("size must be > 0"));
34        }
35        if pvals.is_empty() {
36            return Err(FerrayError::invalid_value(
37                "pvals must have at least one element",
38            ));
39        }
40        let psum: f64 = pvals.iter().sum();
41        if (psum - 1.0).abs() > 1e-6 {
42            return Err(FerrayError::invalid_value(format!(
43                "pvals must sum to 1.0, got {psum}"
44            )));
45        }
46        for (i, &p) in pvals.iter().enumerate() {
47            if p < 0.0 {
48                return Err(FerrayError::invalid_value(format!(
49                    "pvals[{i}] = {p} is negative"
50                )));
51            }
52        }
53
54        let k = pvals.len();
55        let mut data = Vec::with_capacity(size * k);
56
57        for _ in 0..size {
58            let mut remaining = n;
59            let mut psum_remaining = 1.0;
60            for (j, &pj) in pvals.iter().enumerate() {
61                if j == k - 1 {
62                    // Last category gets all remaining
63                    data.push(remaining as i64);
64                } else if psum_remaining <= 0.0 || remaining == 0 {
65                    data.push(0);
66                } else {
67                    let p_cond = (pj / psum_remaining).clamp(0.0, 1.0);
68                    let count = binomial_for_multinomial(&mut self.bg, remaining, p_cond);
69                    data.push(count as i64);
70                    remaining -= count;
71                    psum_remaining -= pj;
72                }
73            }
74        }
75
76        Array::<i64, Ix2>::from_vec(Ix2::new([size, k]), data)
77    }
78
79    /// Generate multivariate normal samples.
80    ///
81    /// Uses the Cholesky decomposition of the covariance matrix.
82    ///
83    /// # Arguments
84    /// * `mean` - Mean vector of length `d`.
85    /// * `cov` - Covariance matrix, flattened in row-major order, shape `[d, d]`.
86    /// * `size` - Number of samples (rows in output).
87    ///
88    /// # Returns
89    /// An `Array<f64, Ix2>` with shape `[size, d]`.
90    ///
91    /// # Errors
92    /// Returns `FerrayError::InvalidValue` for invalid parameters or if
93    /// the covariance matrix is not positive semi-definite.
94    pub fn multivariate_normal(
95        &mut self,
96        mean: &[f64],
97        cov: &[f64],
98        size: usize,
99    ) -> Result<Array<f64, Ix2>, FerrayError> {
100        if size == 0 {
101            return Err(FerrayError::invalid_value("size must be > 0"));
102        }
103        let d = mean.len();
104        if d == 0 {
105            return Err(FerrayError::invalid_value("mean must be non-empty"));
106        }
107        if cov.len() != d * d {
108            return Err(FerrayError::invalid_value(format!(
109                "cov must have {} elements for mean of length {d}, got {}",
110                d * d,
111                cov.len()
112            )));
113        }
114
115        // Compute Cholesky decomposition L such that cov = L * L^T
116        let l = cholesky_decompose(cov, d)?;
117
118        let mut data = Vec::with_capacity(size * d);
119        for _ in 0..size {
120            // Generate d independent standard normals
121            let mut z = Vec::with_capacity(d);
122            for _ in 0..d {
123                z.push(standard_normal_single(&mut self.bg));
124            }
125
126            // x = mean + L * z
127            for i in 0..d {
128                let mut val = mean[i];
129                for j in 0..=i {
130                    val += l[i * d + j] * z[j];
131                }
132                data.push(val);
133            }
134        }
135
136        Array::<f64, Ix2>::from_vec(Ix2::new([size, d]), data)
137    }
138
139    /// Generate samples from the multivariate hypergeometric distribution.
140    ///
141    /// Models drawing `nsample` items without replacement from a population
142    /// partitioned into `K` colors with `colors[k]` items of color `k`.
143    /// Each row of the output is one such draw — a vector of `K` non-negative
144    /// counts summing to `nsample`.
145    ///
146    /// Uses the marginals algorithm: a sequence of `K-1` univariate
147    /// hypergeometric draws, each picking the count of one color from the
148    /// remainder of the population. The final color is what's left. This
149    /// matches `numpy.random.Generator.multivariate_hypergeometric` (#445).
150    ///
151    /// # Arguments
152    /// * `colors` - Number of items of each color (length K, all non-negative).
153    /// * `nsample` - Number of items drawn per sample (must be ≤ sum of `colors`).
154    /// * `size` - Number of multivariate draws (rows in output).
155    ///
156    /// # Returns
157    /// An `Array<i64, Ix2>` with shape `[size, K]`.
158    ///
159    /// # Errors
160    /// Returns `FerrayError::InvalidValue` if `colors` is empty, `size` is 0,
161    /// or `nsample` exceeds the population total.
162    pub fn multivariate_hypergeometric(
163        &mut self,
164        colors: &[u64],
165        nsample: u64,
166        size: usize,
167    ) -> Result<Array<i64, Ix2>, FerrayError> {
168        if size == 0 {
169            return Err(FerrayError::invalid_value("size must be > 0"));
170        }
171        if colors.is_empty() {
172            return Err(FerrayError::invalid_value(
173                "colors must have at least one element",
174            ));
175        }
176        let total: u64 = colors.iter().try_fold(0_u64, |acc, &c| {
177            acc.checked_add(c).ok_or_else(|| {
178                FerrayError::invalid_value("multivariate_hypergeometric: colors sum overflows u64")
179            })
180        })?;
181        if nsample > total {
182            return Err(FerrayError::invalid_value(format!(
183                "nsample ({nsample}) > sum of colors ({total})"
184            )));
185        }
186
187        let k = colors.len();
188        let mut data = Vec::with_capacity(size * k);
189
190        for _ in 0..size {
191            // For each color j in 0..k-1, draw a univariate hypergeometric
192            // from (ngood = colors[j], nbad = remaining_total - colors[j]).
193            // Subtract the draw and proceed; the last color gets whatever
194            // is left of `nsample`.
195            let mut remaining_pop: u64 = total;
196            let mut remaining_sample: u64 = nsample;
197
198            for &ngood in &colors[..k - 1] {
199                let nbad = remaining_pop - ngood;
200                let draw = if remaining_sample == 0 || ngood == 0 {
201                    0
202                } else if remaining_sample >= remaining_pop {
203                    // Take everything that's left of this color.
204                    ngood as i64
205                } else {
206                    hypergeometric_for_multivariate(&mut self.bg, ngood, nbad, remaining_sample)
207                };
208                data.push(draw);
209                remaining_pop -= ngood;
210                remaining_sample -= draw as u64;
211            }
212            // Final color absorbs the remainder of the sample.
213            data.push(remaining_sample as i64);
214        }
215
216        Array::<i64, Ix2>::from_vec(Ix2::new([size, k]), data)
217    }
218
219    /// Generate multivariate normal samples taking `Array` parameters.
220    ///
221    /// Ergonomic counterpart to [`multivariate_normal`] (#451): accepts
222    /// the mean as `Array<f64, Ix1>` and the covariance as
223    /// `Array<f64, Ix2>` directly, no manual flattening required.
224    ///
225    /// Cholesky decomposition is delegated to `ferray_linalg::cholesky`
226    /// (#452) which is faer-backed and surfaces non-positive-definite
227    /// inputs as `FerrayError::SingularMatrix` instead of the
228    /// home-grown `cholesky_decompose` helper.
229    ///
230    /// # Errors
231    /// - `FerrayError::ShapeMismatch` if `cov` is not square or its
232    ///   side does not match `mean.len()`.
233    /// - `FerrayError::SingularMatrix` if `cov` is not positive
234    ///   definite (propagated from `ferray-linalg`).
235    /// - `FerrayError::InvalidValue` for size = 0 or empty mean.
236    pub fn multivariate_normal_array(
237        &mut self,
238        mean: &Array<f64, Ix1>,
239        cov: &Array<f64, Ix2>,
240        size: usize,
241    ) -> Result<Array<f64, Ix2>, FerrayError> {
242        if size == 0 {
243            return Err(FerrayError::invalid_value("size must be > 0"));
244        }
245        let d = mean.shape()[0];
246        if d == 0 {
247            return Err(FerrayError::invalid_value("mean must be non-empty"));
248        }
249        let cov_shape = cov.shape();
250        if cov_shape[0] != d || cov_shape[1] != d {
251            return Err(FerrayError::shape_mismatch(format!(
252                "cov shape {cov_shape:?} does not match mean of length {d}"
253            )));
254        }
255
256        let l_arr = ferray_linalg::cholesky(cov)?;
257        let l_slice = l_arr
258            .as_slice()
259            .ok_or_else(|| FerrayError::invalid_value("cholesky returned non-contiguous L"))?;
260        let mean_slice = mean
261            .as_slice()
262            .ok_or_else(|| FerrayError::invalid_value("mean must be contiguous"))?;
263
264        let mut data = Vec::with_capacity(size * d);
265        let mut z = vec![0.0_f64; d];
266        for _ in 0..size {
267            for v in z.iter_mut() {
268                *v = standard_normal_single(&mut self.bg);
269            }
270            for i in 0..d {
271                let mut val = mean_slice[i];
272                for j in 0..=i {
273                    val += l_slice[i * d + j] * z[j];
274                }
275                data.push(val);
276            }
277        }
278        Array::<f64, Ix2>::from_vec(Ix2::new([size, d]), data)
279    }
280
281    /// Generate Dirichlet-distributed samples.
282    ///
283    /// Each row is a sample from the Dirichlet distribution parameterized
284    /// by `alpha`, producing vectors that sum to 1.
285    ///
286    /// # Arguments
287    /// * `alpha` - Concentration parameters (all must be positive).
288    /// * `size` - Number of samples (rows in output).
289    ///
290    /// # Returns
291    /// An `Array<f64, Ix2>` with shape `[size, k]` where k = `alpha.len()`.
292    ///
293    /// # Errors
294    /// Returns `FerrayError::InvalidValue` for invalid parameters.
295    pub fn dirichlet(
296        &mut self,
297        alpha: &[f64],
298        size: usize,
299    ) -> Result<Array<f64, Ix2>, FerrayError> {
300        if size == 0 {
301            return Err(FerrayError::invalid_value("size must be > 0"));
302        }
303        if alpha.is_empty() {
304            return Err(FerrayError::invalid_value(
305                "alpha must have at least one element",
306            ));
307        }
308        for (i, &a) in alpha.iter().enumerate() {
309            if a <= 0.0 {
310                return Err(FerrayError::invalid_value(format!(
311                    "alpha[{i}] = {a} must be positive"
312                )));
313            }
314        }
315
316        let k = alpha.len();
317        let mut data = Vec::with_capacity(size * k);
318
319        for _ in 0..size {
320            let mut gammas = Vec::with_capacity(k);
321            let mut sum = 0.0;
322            for &a in alpha {
323                let g = standard_gamma_single(&mut self.bg, a);
324                gammas.push(g);
325                sum += g;
326            }
327            // Normalize
328            if sum > 0.0 {
329                for g in &gammas {
330                    data.push(g / sum);
331                }
332            } else {
333                // Degenerate: uniform
334                let val = 1.0 / k as f64;
335                for _ in 0..k {
336                    data.push(val);
337                }
338            }
339        }
340
341        Array::<f64, Ix2>::from_vec(Ix2::new([size, k]), data)
342    }
343}
344
345/// Univariate hypergeometric draw used by `multivariate_hypergeometric`.
346/// Direct simulation: draw `nsample` items from a population of
347/// `ngood + nbad` and count successes. Equivalent to the helper in
348/// `discrete.rs::hypergeometric_single` — kept private here to avoid
349/// the cross-module visibility wiring that would otherwise pull in
350/// generator internals.
351fn hypergeometric_for_multivariate<B: BitGenerator>(
352    bg: &mut B,
353    ngood: u64,
354    nbad: u64,
355    nsample: u64,
356) -> i64 {
357    let mut good_remaining = ngood;
358    let mut total_remaining = ngood + nbad;
359    let mut successes: i64 = 0;
360    for _ in 0..nsample {
361        if total_remaining == 0 {
362            break;
363        }
364        let u = bg.next_f64();
365        if u < (good_remaining as f64) / (total_remaining as f64) {
366            successes += 1;
367            good_remaining -= 1;
368        }
369        total_remaining -= 1;
370    }
371    successes
372}
373
374/// Simple binomial sampling for multinomial (avoids circular dependency).
375fn binomial_for_multinomial<B: BitGenerator>(bg: &mut B, n: u64, p: f64) -> u64 {
376    if n == 0 || p <= 0.0 {
377        return 0;
378    }
379    if p >= 1.0 {
380        return n;
381    }
382
383    let (pp, flipped) = if p > 0.5 { (1.0 - p, true) } else { (p, false) };
384
385    let result = if (n as f64) * pp < 30.0 {
386        // Inverse transform
387        let q = 1.0 - pp;
388        let s = pp / q;
389        let a = (n as f64 + 1.0) * s;
390        let mut r = q.powf(n as f64);
391        let mut u = bg.next_f64();
392        let mut x: u64 = 0;
393        while u > r {
394            u -= r;
395            x += 1;
396            if x > n {
397                x = n;
398                break;
399            }
400            r *= a / (x as f64) - s;
401            if r < 0.0 {
402                break;
403            }
404        }
405        x.min(n)
406    } else {
407        // Normal approximation for large n*p
408        loop {
409            let z = standard_normal_single(bg);
410            let sigma = ((n as f64) * pp * (1.0 - pp)).sqrt();
411            let x = ((n as f64).mul_add(pp, sigma * z) + 0.5).floor() as i64;
412            if x >= 0 && x <= n as i64 {
413                break x as u64;
414            }
415        }
416    };
417
418    if flipped { n - result } else { result }
419}
420
421/// Cholesky decomposition of a symmetric positive-definite matrix.
422/// Input: flat row-major matrix `a` of size `n x n`.
423/// Output: lower-triangular matrix L such that A = L * L^T.
424fn cholesky_decompose(a: &[f64], n: usize) -> Result<Vec<f64>, FerrayError> {
425    let mut l = vec![0.0; n * n];
426
427    for i in 0..n {
428        for j in 0..=i {
429            let mut sum = 0.0;
430            for k in 0..j {
431                sum += l[i * n + k] * l[j * n + k];
432            }
433            if i == j {
434                let diag = a[i * n + i] - sum;
435                if diag < -1e-10 {
436                    return Err(FerrayError::invalid_value(
437                        "covariance matrix is not positive semi-definite",
438                    ));
439                }
440                l[i * n + j] = diag.max(0.0).sqrt();
441            } else {
442                let denom = l[j * n + j];
443                if denom.abs() < 1e-15 {
444                    l[i * n + j] = 0.0;
445                } else {
446                    l[i * n + j] = (a[i * n + j] - sum) / denom;
447                }
448            }
449        }
450    }
451
452    Ok(l)
453}
454
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    // ---- multinomial ---------------------------------------------------

    #[test]
    fn multinomial_shape() {
        let mut rng = default_rng_seeded(42);
        let pvals = [0.2, 0.3, 0.5];
        let arr = rng.multinomial(100, &pvals, 10).unwrap();
        assert_eq!(arr.shape(), &[10, 3]);
    }

    #[test]
    fn multinomial_row_sums() {
        // Every row of a multinomial draw must distribute exactly n items.
        let mut rng = default_rng_seeded(42);
        let pvals = [0.2, 0.3, 0.5];
        let n = 100u64;
        let arr = rng.multinomial(n, &pvals, 50).unwrap();
        let slice = arr.as_slice().unwrap();
        let k = pvals.len();
        for row in 0..50 {
            let row_sum: i64 = (0..k).map(|j| slice[row * k + j]).sum();
            assert_eq!(
                row_sum, n as i64,
                "row {row} sum is {row_sum}, expected {n}"
            );
        }
    }

    #[test]
    fn multinomial_nonnegative() {
        let mut rng = default_rng_seeded(42);
        let pvals = [0.1, 0.2, 0.3, 0.4];
        let arr = rng.multinomial(50, &pvals, 100).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 0, "multinomial produced negative count: {v}");
        }
    }

    #[test]
    fn multinomial_bad_pvals() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.multinomial(10, &[0.5, 0.6], 10).is_err()); // sum > 1
        assert!(rng.multinomial(10, &[-0.1, 1.1], 10).is_err()); // negative
        assert!(rng.multinomial(10, &[], 10).is_err()); // empty
    }

    // ---- multivariate_normal (slice API) -------------------------------

    #[test]
    fn multivariate_normal_shape() {
        let mut rng = default_rng_seeded(42);
        let mean = [1.0, 2.0, 3.0];
        let cov = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let arr = rng.multivariate_normal(&mean, &cov, 100).unwrap();
        assert_eq!(arr.shape(), &[100, 3]);
    }

    #[test]
    fn multivariate_normal_mean() {
        // With identity covariance, each column mean should converge to
        // the requested mean within ~3 standard errors (se = 1/sqrt(n)).
        let mut rng = default_rng_seeded(42);
        let mean = [5.0, -3.0];
        let cov = [1.0, 0.0, 0.0, 1.0];
        let n = 100_000;
        let arr = rng.multivariate_normal(&mean, &cov, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let d = mean.len();

        for j in 0..d {
            let col_mean: f64 = (0..n).map(|i| slice[i * d + j]).sum::<f64>() / n as f64;
            let se = (1.0 / n as f64).sqrt();
            assert!(
                (col_mean - mean[j]).abs() < 3.0 * se,
                "multivariate_normal mean[{j}] = {col_mean}, expected {}",
                mean[j]
            );
        }
    }

    #[test]
    fn multivariate_normal_bad_cov() {
        let mut rng = default_rng_seeded(42);
        let mean = [0.0, 0.0];
        // Wrong size cov
        assert!(
            rng.multivariate_normal(&mean, &[1.0, 0.0, 0.0], 10)
                .is_err()
        );
    }

    // ---- dirichlet ------------------------------------------------------

    #[test]
    fn dirichlet_shape() {
        let mut rng = default_rng_seeded(42);
        let alpha = [1.0, 2.0, 3.0];
        let arr = rng.dirichlet(&alpha, 10).unwrap();
        assert_eq!(arr.shape(), &[10, 3]);
    }

    #[test]
    fn dirichlet_sums_to_one() {
        // Dirichlet samples live on the simplex: each row sums to 1.
        let mut rng = default_rng_seeded(42);
        let alpha = [0.5, 1.0, 2.0, 0.5];
        let arr = rng.dirichlet(&alpha, 100).unwrap();
        let slice = arr.as_slice().unwrap();
        let k = alpha.len();
        for row in 0..100 {
            let row_sum: f64 = (0..k).map(|j| slice[row * k + j]).sum();
            assert!(
                (row_sum - 1.0).abs() < 1e-10,
                "dirichlet row {row} sums to {row_sum}, expected 1.0"
            );
        }
    }

    #[test]
    fn dirichlet_nonnegative() {
        let mut rng = default_rng_seeded(42);
        let alpha = [0.5, 1.0, 2.0];
        let arr = rng.dirichlet(&alpha, 100).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 0.0, "dirichlet produced negative value: {v}");
        }
    }

    #[test]
    fn dirichlet_bad_alpha() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.dirichlet(&[], 10).is_err());
        assert!(rng.dirichlet(&[1.0, 0.0], 10).is_err());
        assert!(rng.dirichlet(&[1.0, -1.0], 10).is_err());
    }

    // ---- multivariate_normal_array (#451, #452) ------------------------

    #[test]
    fn mvn_array_shape() {
        use ferray_core::{Array, Ix1, Ix2};
        let mut rng = default_rng_seeded(42);
        let mean = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let cov = Array::<f64, Ix2>::from_vec(
            Ix2::new([3, 3]),
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
        )
        .unwrap();
        let arr = rng.multivariate_normal_array(&mean, &cov, 100).unwrap();
        assert_eq!(arr.shape(), &[100, 3]);
    }

    #[test]
    fn mvn_array_means_match() {
        use ferray_core::{Array, Ix1, Ix2};
        let mut rng = default_rng_seeded(42);
        let mean = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![5.0, -3.0]).unwrap();
        let cov = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let n = 100_000;
        let arr = rng.multivariate_normal_array(&mean, &cov, n).unwrap();
        let slice = arr.as_slice().unwrap();
        for j in 0..2 {
            let m: f64 = (0..n).map(|i| slice[i * 2 + j]).sum::<f64>() / n as f64;
            let se = (1.0 / n as f64).sqrt();
            let want = mean.as_slice().unwrap()[j];
            assert!((m - want).abs() < 4.0 * se, "col {j} mean {m} ≠ {want}");
        }
    }

    #[test]
    fn mvn_array_rejects_non_square_cov() {
        use ferray_core::{Array, Ix1, Ix2};
        let mut rng = default_rng_seeded(0);
        let mean = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![0.0, 0.0]).unwrap();
        // 2×3 cov — neither square nor matching mean length.
        let cov = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
            .unwrap();
        assert!(rng.multivariate_normal_array(&mean, &cov, 5).is_err());
    }

    #[test]
    fn mvn_array_rejects_non_pd_cov() {
        use ferray_core::{Array, Ix1, Ix2};
        let mut rng = default_rng_seeded(0);
        let mean = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![0.0, 0.0]).unwrap();
        // Indefinite (eigenvalues ±1).
        let cov = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![0.0, 1.0, 1.0, 0.0]).unwrap();
        let err = rng.multivariate_normal_array(&mean, &cov, 5).unwrap_err();
        // The array API delegates to ferray_linalg::cholesky, which reports
        // non-PD input as SingularMatrix (unlike the slice API's InvalidValue).
        assert!(matches!(
            err,
            ferray_core::FerrayError::SingularMatrix { .. }
        ));
    }

    #[test]
    fn mvn_array_correlates_recovers_cov() {
        use ferray_core::{Array, Ix1, Ix2};
        // Strongly correlated covariance — sample covariance should
        // approximate the input cov to within sampling error.
        let mut rng = default_rng_seeded(11);
        let mean = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![0.0, 0.0]).unwrap();
        let cov = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 0.7, 0.7, 1.0]).unwrap();
        let n = 50_000;
        let arr = rng.multivariate_normal_array(&mean, &cov, n).unwrap();
        let s = arr.as_slice().unwrap();
        let mean0: f64 = (0..n).map(|i| s[i * 2]).sum::<f64>() / n as f64;
        let mean1: f64 = (0..n).map(|i| s[i * 2 + 1]).sum::<f64>() / n as f64;
        let cov01: f64 = (0..n)
            .map(|i| (s[i * 2] - mean0) * (s[i * 2 + 1] - mean1))
            .sum::<f64>()
            / n as f64;
        assert!((cov01 - 0.7).abs() < 0.05, "sample cov01 {cov01} ≠ 0.7");
    }

    // ---- multivariate_hypergeometric (#445) ----------------------------

    #[test]
    fn mvhg_shape_and_row_sum() {
        let mut rng = default_rng_seeded(42);
        let colors = [10u64, 20, 30];
        let nsample = 15u64;
        let arr = rng
            .multivariate_hypergeometric(&colors, nsample, 50)
            .unwrap();
        assert_eq!(arr.shape(), &[50, 3]);
        let slice = arr.as_slice().unwrap();
        for row in 0..50 {
            let row_sum: i64 = (0..3).map(|j| slice[row * 3 + j]).sum();
            assert_eq!(row_sum, nsample as i64);
        }
    }

    #[test]
    fn mvhg_per_color_within_population() {
        // No color's count can exceed the number of items of that color.
        let mut rng = default_rng_seeded(123);
        let colors = [5u64, 5, 5];
        let arr = rng.multivariate_hypergeometric(&colors, 10, 200).unwrap();
        let slice = arr.as_slice().unwrap();
        for row in 0..200 {
            for j in 0..3 {
                let v = slice[row * 3 + j];
                assert!(
                    v >= 0 && v <= colors[j] as i64,
                    "row {row} col {j}: count {v} out of [0, {}]",
                    colors[j]
                );
            }
        }
    }

    #[test]
    fn mvhg_marginal_means_match_theory() {
        // E[X_j] = nsample * colors[j] / sum(colors)
        let mut rng = default_rng_seeded(7);
        let colors = [10u64, 20, 30, 40];
        let total: f64 = colors.iter().sum::<u64>() as f64;
        let nsample = 25u64;
        let n_draws = 10_000;
        let arr = rng
            .multivariate_hypergeometric(&colors, nsample, n_draws)
            .unwrap();
        let slice = arr.as_slice().unwrap();
        let k = colors.len();
        for j in 0..k {
            let observed: f64 =
                (0..n_draws).map(|i| slice[i * k + j] as f64).sum::<f64>() / n_draws as f64;
            let expected = nsample as f64 * colors[j] as f64 / total;
            // Marginal variance: nsample * (Kj/N) * (N-Kj)/N * (N-nsample)/(N-1)
            let kj = colors[j] as f64;
            let var = nsample as f64
                * (kj / total)
                * ((total - kj) / total)
                * ((total - nsample as f64) / (total - 1.0));
            let se = (var / n_draws as f64).sqrt();
            assert!(
                (observed - expected).abs() < 4.0 * se,
                "color {j}: observed mean {observed}, expected {expected} ± {se}"
            );
        }
    }

    #[test]
    fn mvhg_take_all() {
        // nsample == total: result is exactly the colors vector.
        let mut rng = default_rng_seeded(0);
        let colors = [3u64, 7, 0, 5];
        let total: u64 = colors.iter().sum();
        let arr = rng.multivariate_hypergeometric(&colors, total, 5).unwrap();
        let slice = arr.as_slice().unwrap();
        for row in 0..5 {
            for j in 0..colors.len() {
                assert_eq!(slice[row * colors.len() + j], colors[j] as i64);
            }
        }
    }

    #[test]
    fn mvhg_seed_reproducible() {
        // Same seed → identical output streams.
        let mut a = default_rng_seeded(99);
        let mut b = default_rng_seeded(99);
        let xa = a.multivariate_hypergeometric(&[5, 10, 15], 8, 30).unwrap();
        let xb = b.multivariate_hypergeometric(&[5, 10, 15], 8, 30).unwrap();
        assert_eq!(xa.as_slice().unwrap(), xb.as_slice().unwrap());
    }

    #[test]
    fn mvhg_bad_params() {
        let mut rng = default_rng_seeded(0);
        // size = 0
        assert!(rng.multivariate_hypergeometric(&[1, 2], 1, 0).is_err());
        // empty colors
        assert!(rng.multivariate_hypergeometric(&[], 0, 5).is_err());
        // nsample > total
        assert!(rng.multivariate_hypergeometric(&[3, 4], 10, 5).is_err());
    }
}
764}