// ferray_random/distributions/multivariate.rs
1// ferray-random: Multivariate distributions — multinomial, multivariate_normal, dirichlet
2
3use ferray_core::{Array, FerrayError, Ix2};
4
5use crate::bitgen::BitGenerator;
6use crate::distributions::gamma::standard_gamma_single;
7use crate::distributions::normal::standard_normal_single;
8use crate::generator::Generator;
9
10impl<B: BitGenerator> Generator<B> {
11    /// Generate multinomial samples.
12    ///
13    /// Each row of the output is one draw of `n` items distributed across
14    /// `k` categories with probabilities `pvals`.
15    ///
16    /// # Arguments
17    /// * `n` - Number of trials per sample.
18    /// * `pvals` - Category probabilities (must sum to ~1.0, length k).
19    /// * `size` - Number of multinomial draws (rows in output).
20    ///
21    /// # Returns
22    /// An `Array<i64, Ix2>` with shape `[size, k]`.
23    ///
24    /// # Errors
25    /// Returns `FerrayError::InvalidValue` for invalid parameters.
26    pub fn multinomial(
27        &mut self,
28        n: u64,
29        pvals: &[f64],
30        size: usize,
31    ) -> Result<Array<i64, Ix2>, FerrayError> {
32        if size == 0 {
33            return Err(FerrayError::invalid_value("size must be > 0"));
34        }
35        if pvals.is_empty() {
36            return Err(FerrayError::invalid_value(
37                "pvals must have at least one element",
38            ));
39        }
40        let psum: f64 = pvals.iter().sum();
41        if (psum - 1.0).abs() > 1e-6 {
42            return Err(FerrayError::invalid_value(format!(
43                "pvals must sum to 1.0, got {psum}"
44            )));
45        }
46        for (i, &p) in pvals.iter().enumerate() {
47            if p < 0.0 {
48                return Err(FerrayError::invalid_value(format!(
49                    "pvals[{i}] = {p} is negative"
50                )));
51            }
52        }
53
54        let k = pvals.len();
55        let mut data = Vec::with_capacity(size * k);
56
57        for _ in 0..size {
58            let mut remaining = n;
59            let mut psum_remaining = 1.0;
60            for (j, &pj) in pvals.iter().enumerate() {
61                if j == k - 1 {
62                    // Last category gets all remaining
63                    data.push(remaining as i64);
64                } else if psum_remaining <= 0.0 || remaining == 0 {
65                    data.push(0);
66                } else {
67                    let p_cond = (pj / psum_remaining).clamp(0.0, 1.0);
68                    let count = binomial_for_multinomial(&mut self.bg, remaining, p_cond);
69                    data.push(count as i64);
70                    remaining -= count;
71                    psum_remaining -= pj;
72                }
73            }
74        }
75
76        Array::<i64, Ix2>::from_vec(Ix2::new([size, k]), data)
77    }
78
79    /// Generate multivariate normal samples.
80    ///
81    /// Uses the Cholesky decomposition of the covariance matrix.
82    ///
83    /// # Arguments
84    /// * `mean` - Mean vector of length `d`.
85    /// * `cov` - Covariance matrix, flattened in row-major order, shape `[d, d]`.
86    /// * `size` - Number of samples (rows in output).
87    ///
88    /// # Returns
89    /// An `Array<f64, Ix2>` with shape `[size, d]`.
90    ///
91    /// # Errors
92    /// Returns `FerrayError::InvalidValue` for invalid parameters or if
93    /// the covariance matrix is not positive semi-definite.
94    pub fn multivariate_normal(
95        &mut self,
96        mean: &[f64],
97        cov: &[f64],
98        size: usize,
99    ) -> Result<Array<f64, Ix2>, FerrayError> {
100        if size == 0 {
101            return Err(FerrayError::invalid_value("size must be > 0"));
102        }
103        let d = mean.len();
104        if d == 0 {
105            return Err(FerrayError::invalid_value("mean must be non-empty"));
106        }
107        if cov.len() != d * d {
108            return Err(FerrayError::invalid_value(format!(
109                "cov must have {} elements for mean of length {d}, got {}",
110                d * d,
111                cov.len()
112            )));
113        }
114
115        // Compute Cholesky decomposition L such that cov = L * L^T
116        let l = cholesky_decompose(cov, d)?;
117
118        let mut data = Vec::with_capacity(size * d);
119        for _ in 0..size {
120            // Generate d independent standard normals
121            let mut z = Vec::with_capacity(d);
122            for _ in 0..d {
123                z.push(standard_normal_single(&mut self.bg));
124            }
125
126            // x = mean + L * z
127            for i in 0..d {
128                let mut val = mean[i];
129                for j in 0..=i {
130                    val += l[i * d + j] * z[j];
131                }
132                data.push(val);
133            }
134        }
135
136        Array::<f64, Ix2>::from_vec(Ix2::new([size, d]), data)
137    }
138
139    /// Generate Dirichlet-distributed samples.
140    ///
141    /// Each row is a sample from the Dirichlet distribution parameterized
142    /// by `alpha`, producing vectors that sum to 1.
143    ///
144    /// # Arguments
145    /// * `alpha` - Concentration parameters (all must be positive).
146    /// * `size` - Number of samples (rows in output).
147    ///
148    /// # Returns
149    /// An `Array<f64, Ix2>` with shape `[size, k]` where k = alpha.len().
150    ///
151    /// # Errors
152    /// Returns `FerrayError::InvalidValue` for invalid parameters.
153    pub fn dirichlet(
154        &mut self,
155        alpha: &[f64],
156        size: usize,
157    ) -> Result<Array<f64, Ix2>, FerrayError> {
158        if size == 0 {
159            return Err(FerrayError::invalid_value("size must be > 0"));
160        }
161        if alpha.is_empty() {
162            return Err(FerrayError::invalid_value(
163                "alpha must have at least one element",
164            ));
165        }
166        for (i, &a) in alpha.iter().enumerate() {
167            if a <= 0.0 {
168                return Err(FerrayError::invalid_value(format!(
169                    "alpha[{i}] = {a} must be positive"
170                )));
171            }
172        }
173
174        let k = alpha.len();
175        let mut data = Vec::with_capacity(size * k);
176
177        for _ in 0..size {
178            let mut gammas = Vec::with_capacity(k);
179            let mut sum = 0.0;
180            for &a in alpha {
181                let g = standard_gamma_single(&mut self.bg, a);
182                gammas.push(g);
183                sum += g;
184            }
185            // Normalize
186            if sum > 0.0 {
187                for g in &gammas {
188                    data.push(g / sum);
189                }
190            } else {
191                // Degenerate: uniform
192                let val = 1.0 / k as f64;
193                for _ in 0..k {
194                    data.push(val);
195                }
196            }
197        }
198
199        Array::<f64, Ix2>::from_vec(Ix2::new([size, k]), data)
200    }
201}
202
203/// Simple binomial sampling for multinomial (avoids circular dependency).
204fn binomial_for_multinomial<B: BitGenerator>(bg: &mut B, n: u64, p: f64) -> u64 {
205    if n == 0 || p <= 0.0 {
206        return 0;
207    }
208    if p >= 1.0 {
209        return n;
210    }
211
212    let (pp, flipped) = if p > 0.5 { (1.0 - p, true) } else { (p, false) };
213
214    let result = if (n as f64) * pp < 30.0 {
215        // Inverse transform
216        let q = 1.0 - pp;
217        let s = pp / q;
218        let a = (n as f64 + 1.0) * s;
219        let mut r = q.powi(n as i32);
220        let mut u = bg.next_f64();
221        let mut x: u64 = 0;
222        while u > r {
223            u -= r;
224            x += 1;
225            if x > n {
226                x = n;
227                break;
228            }
229            r *= a / (x as f64) - s;
230            if r < 0.0 {
231                break;
232            }
233        }
234        x.min(n)
235    } else {
236        // Normal approximation for large n*p
237        loop {
238            let z = standard_normal_single(bg);
239            let sigma = ((n as f64) * pp * (1.0 - pp)).sqrt();
240            let x = ((n as f64) * pp + sigma * z + 0.5).floor() as i64;
241            if x >= 0 && x <= n as i64 {
242                break x as u64;
243            }
244        }
245    };
246
247    if flipped { n - result } else { result }
248}
249
250/// Cholesky decomposition of a symmetric positive-definite matrix.
251/// Input: flat row-major matrix `a` of size `n x n`.
252/// Output: lower-triangular matrix L such that A = L * L^T.
253fn cholesky_decompose(a: &[f64], n: usize) -> Result<Vec<f64>, FerrayError> {
254    let mut l = vec![0.0; n * n];
255
256    for i in 0..n {
257        for j in 0..=i {
258            let mut sum = 0.0;
259            for k in 0..j {
260                sum += l[i * n + k] * l[j * n + k];
261            }
262            if i == j {
263                let diag = a[i * n + i] - sum;
264                if diag < -1e-10 {
265                    return Err(FerrayError::invalid_value(
266                        "covariance matrix is not positive semi-definite",
267                    ));
268                }
269                l[i * n + j] = diag.max(0.0).sqrt();
270            } else {
271                let denom = l[j * n + j];
272                if denom.abs() < 1e-15 {
273                    l[i * n + j] = 0.0;
274                } else {
275                    l[i * n + j] = (a[i * n + j] - sum) / denom;
276                }
277            }
278        }
279    }
280
281    Ok(l)
282}
283
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    // Multinomial output must have one row per draw and one column per
    // category.
    #[test]
    fn multinomial_shape() {
        let mut rng = default_rng_seeded(42);
        let pvals = [0.2, 0.3, 0.5];
        let arr = rng.multinomial(100, &pvals, 10).unwrap();
        assert_eq!(arr.shape(), &[10, 3]);
    }

    // Each multinomial row distributes exactly n trials, so every row must
    // sum to n regardless of the random draws.
    #[test]
    fn multinomial_row_sums() {
        let mut rng = default_rng_seeded(42);
        let pvals = [0.2, 0.3, 0.5];
        let n = 100u64;
        let arr = rng.multinomial(n, &pvals, 50).unwrap();
        let slice = arr.as_slice().unwrap();
        let k = pvals.len();
        for row in 0..50 {
            let row_sum: i64 = (0..k).map(|j| slice[row * k + j]).sum();
            assert_eq!(
                row_sum, n as i64,
                "row {row} sum is {row_sum}, expected {n}"
            );
        }
    }

    // Counts are category occupancies and can never be negative.
    #[test]
    fn multinomial_nonnegative() {
        let mut rng = default_rng_seeded(42);
        let pvals = [0.1, 0.2, 0.3, 0.4];
        let arr = rng.multinomial(50, &pvals, 100).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 0, "multinomial produced negative count: {v}");
        }
    }

    // Invalid pvals (wrong sum, negative entry, empty slice) must all be
    // rejected with an error rather than producing samples.
    #[test]
    fn multinomial_bad_pvals() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.multinomial(10, &[0.5, 0.6], 10).is_err()); // sum > 1
        assert!(rng.multinomial(10, &[-0.1, 1.1], 10).is_err()); // negative
        assert!(rng.multinomial(10, &[], 10).is_err()); // empty
    }

    // Multivariate normal output is [size, d] where d = mean.len().
    #[test]
    fn multivariate_normal_shape() {
        let mut rng = default_rng_seeded(42);
        let mean = [1.0, 2.0, 3.0];
        let cov = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let arr = rng.multivariate_normal(&mean, &cov, 100).unwrap();
        assert_eq!(arr.shape(), &[100, 3]);
    }

    // Statistical check: with identity covariance each column mean should
    // land within 3 standard errors of the requested mean. Deterministic
    // because the RNG is seeded.
    #[test]
    fn multivariate_normal_mean() {
        let mut rng = default_rng_seeded(42);
        let mean = [5.0, -3.0];
        let cov = [1.0, 0.0, 0.0, 1.0];
        let n = 100_000;
        let arr = rng.multivariate_normal(&mean, &cov, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let d = mean.len();

        for j in 0..d {
            let col_mean: f64 = (0..n).map(|i| slice[i * d + j]).sum::<f64>() / n as f64;
            // Standard error of the mean for unit variance: sqrt(1/n).
            let se = (1.0 / n as f64).sqrt();
            assert!(
                (col_mean - mean[j]).abs() < 3.0 * se,
                "multivariate_normal mean[{j}] = {col_mean}, expected {}",
                mean[j]
            );
        }
    }

    // A covariance slice whose length is not d*d must be rejected.
    #[test]
    fn multivariate_normal_bad_cov() {
        let mut rng = default_rng_seeded(42);
        let mean = [0.0, 0.0];
        // Wrong size cov
        assert!(
            rng.multivariate_normal(&mean, &[1.0, 0.0, 0.0], 10)
                .is_err()
        );
    }

    // Dirichlet output is [size, k] where k = alpha.len().
    #[test]
    fn dirichlet_shape() {
        let mut rng = default_rng_seeded(42);
        let alpha = [1.0, 2.0, 3.0];
        let arr = rng.dirichlet(&alpha, 10).unwrap();
        assert_eq!(arr.shape(), &[10, 3]);
    }

    // Dirichlet samples are normalized gamma draws, so each row must sum
    // to 1 up to floating-point error.
    #[test]
    fn dirichlet_sums_to_one() {
        let mut rng = default_rng_seeded(42);
        let alpha = [0.5, 1.0, 2.0, 0.5];
        let arr = rng.dirichlet(&alpha, 100).unwrap();
        let slice = arr.as_slice().unwrap();
        let k = alpha.len();
        for row in 0..100 {
            let row_sum: f64 = (0..k).map(|j| slice[row * k + j]).sum();
            assert!(
                (row_sum - 1.0).abs() < 1e-10,
                "dirichlet row {row} sums to {row_sum}, expected 1.0"
            );
        }
    }

    // All dirichlet components are probabilities and must be non-negative.
    #[test]
    fn dirichlet_nonnegative() {
        let mut rng = default_rng_seeded(42);
        let alpha = [0.5, 1.0, 2.0];
        let arr = rng.dirichlet(&alpha, 100).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 0.0, "dirichlet produced negative value: {v}");
        }
    }

    // Empty, zero, and negative concentration parameters must all error.
    #[test]
    fn dirichlet_bad_alpha() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.dirichlet(&[], 10).is_err());
        assert!(rng.dirichlet(&[1.0, 0.0], 10).is_err());
        assert!(rng.dirichlet(&[1.0, -1.0], 10).is_err());
    }
}
413}