ferray_random/distributions/
multivariate.rs1use ferray_core::{Array, FerrayError, Ix2};
4
5use crate::bitgen::BitGenerator;
6use crate::distributions::gamma::standard_gamma_single;
7use crate::distributions::normal::standard_normal_single;
8use crate::generator::Generator;
9
10impl<B: BitGenerator> Generator<B> {
11 pub fn multinomial(
27 &mut self,
28 n: u64,
29 pvals: &[f64],
30 size: usize,
31 ) -> Result<Array<i64, Ix2>, FerrayError> {
32 if size == 0 {
33 return Err(FerrayError::invalid_value("size must be > 0"));
34 }
35 if pvals.is_empty() {
36 return Err(FerrayError::invalid_value(
37 "pvals must have at least one element",
38 ));
39 }
40 let psum: f64 = pvals.iter().sum();
41 if (psum - 1.0).abs() > 1e-6 {
42 return Err(FerrayError::invalid_value(format!(
43 "pvals must sum to 1.0, got {psum}"
44 )));
45 }
46 for (i, &p) in pvals.iter().enumerate() {
47 if p < 0.0 {
48 return Err(FerrayError::invalid_value(format!(
49 "pvals[{i}] = {p} is negative"
50 )));
51 }
52 }
53
54 let k = pvals.len();
55 let mut data = Vec::with_capacity(size * k);
56
57 for _ in 0..size {
58 let mut remaining = n;
59 let mut psum_remaining = 1.0;
60 for (j, &pj) in pvals.iter().enumerate() {
61 if j == k - 1 {
62 data.push(remaining as i64);
64 } else if psum_remaining <= 0.0 || remaining == 0 {
65 data.push(0);
66 } else {
67 let p_cond = (pj / psum_remaining).clamp(0.0, 1.0);
68 let count = binomial_for_multinomial(&mut self.bg, remaining, p_cond);
69 data.push(count as i64);
70 remaining -= count;
71 psum_remaining -= pj;
72 }
73 }
74 }
75
76 Array::<i64, Ix2>::from_vec(Ix2::new([size, k]), data)
77 }
78
79 pub fn multivariate_normal(
95 &mut self,
96 mean: &[f64],
97 cov: &[f64],
98 size: usize,
99 ) -> Result<Array<f64, Ix2>, FerrayError> {
100 if size == 0 {
101 return Err(FerrayError::invalid_value("size must be > 0"));
102 }
103 let d = mean.len();
104 if d == 0 {
105 return Err(FerrayError::invalid_value("mean must be non-empty"));
106 }
107 if cov.len() != d * d {
108 return Err(FerrayError::invalid_value(format!(
109 "cov must have {} elements for mean of length {d}, got {}",
110 d * d,
111 cov.len()
112 )));
113 }
114
115 let l = cholesky_decompose(cov, d)?;
117
118 let mut data = Vec::with_capacity(size * d);
119 for _ in 0..size {
120 let mut z = Vec::with_capacity(d);
122 for _ in 0..d {
123 z.push(standard_normal_single(&mut self.bg));
124 }
125
126 for i in 0..d {
128 let mut val = mean[i];
129 for j in 0..=i {
130 val += l[i * d + j] * z[j];
131 }
132 data.push(val);
133 }
134 }
135
136 Array::<f64, Ix2>::from_vec(Ix2::new([size, d]), data)
137 }
138
139 pub fn dirichlet(
154 &mut self,
155 alpha: &[f64],
156 size: usize,
157 ) -> Result<Array<f64, Ix2>, FerrayError> {
158 if size == 0 {
159 return Err(FerrayError::invalid_value("size must be > 0"));
160 }
161 if alpha.is_empty() {
162 return Err(FerrayError::invalid_value(
163 "alpha must have at least one element",
164 ));
165 }
166 for (i, &a) in alpha.iter().enumerate() {
167 if a <= 0.0 {
168 return Err(FerrayError::invalid_value(format!(
169 "alpha[{i}] = {a} must be positive"
170 )));
171 }
172 }
173
174 let k = alpha.len();
175 let mut data = Vec::with_capacity(size * k);
176
177 for _ in 0..size {
178 let mut gammas = Vec::with_capacity(k);
179 let mut sum = 0.0;
180 for &a in alpha {
181 let g = standard_gamma_single(&mut self.bg, a);
182 gammas.push(g);
183 sum += g;
184 }
185 if sum > 0.0 {
187 for g in &gammas {
188 data.push(g / sum);
189 }
190 } else {
191 let val = 1.0 / k as f64;
193 for _ in 0..k {
194 data.push(val);
195 }
196 }
197 }
198
199 Array::<f64, Ix2>::from_vec(Ix2::new([size, k]), data)
200 }
201}
202
203fn binomial_for_multinomial<B: BitGenerator>(bg: &mut B, n: u64, p: f64) -> u64 {
205 if n == 0 || p <= 0.0 {
206 return 0;
207 }
208 if p >= 1.0 {
209 return n;
210 }
211
212 let (pp, flipped) = if p > 0.5 { (1.0 - p, true) } else { (p, false) };
213
214 let result = if (n as f64) * pp < 30.0 {
215 let q = 1.0 - pp;
217 let s = pp / q;
218 let a = (n as f64 + 1.0) * s;
219 let mut r = q.powi(n as i32);
220 let mut u = bg.next_f64();
221 let mut x: u64 = 0;
222 while u > r {
223 u -= r;
224 x += 1;
225 if x > n {
226 x = n;
227 break;
228 }
229 r *= a / (x as f64) - s;
230 if r < 0.0 {
231 break;
232 }
233 }
234 x.min(n)
235 } else {
236 loop {
238 let z = standard_normal_single(bg);
239 let sigma = ((n as f64) * pp * (1.0 - pp)).sqrt();
240 let x = ((n as f64) * pp + sigma * z + 0.5).floor() as i64;
241 if x >= 0 && x <= n as i64 {
242 break x as u64;
243 }
244 }
245 };
246
247 if flipped { n - result } else { result }
248}
249
250fn cholesky_decompose(a: &[f64], n: usize) -> Result<Vec<f64>, FerrayError> {
254 let mut l = vec![0.0; n * n];
255
256 for i in 0..n {
257 for j in 0..=i {
258 let mut sum = 0.0;
259 for k in 0..j {
260 sum += l[i * n + k] * l[j * n + k];
261 }
262 if i == j {
263 let diag = a[i * n + i] - sum;
264 if diag < -1e-10 {
265 return Err(FerrayError::invalid_value(
266 "covariance matrix is not positive semi-definite",
267 ));
268 }
269 l[i * n + j] = diag.max(0.0).sqrt();
270 } else {
271 let denom = l[j * n + j];
272 if denom.abs() < 1e-15 {
273 l[i * n + j] = 0.0;
274 } else {
275 l[i * n + j] = (a[i * n + j] - sum) / denom;
276 }
277 }
278 }
279 }
280
281 Ok(l)
282}
283
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    // ---- multinomial ------------------------------------------------------

    /// The result has one row per sample and one column per category.
    #[test]
    fn multinomial_shape() {
        let probs = [0.2, 0.3, 0.5];
        let mut rng = default_rng_seeded(42);
        let counts = rng.multinomial(100, &probs, 10).unwrap();
        assert_eq!(counts.shape(), &[10, 3]);
    }

    /// Every sampled row accounts for exactly `n` trials.
    #[test]
    fn multinomial_row_sums() {
        let probs = [0.2, 0.3, 0.5];
        let n = 100u64;
        let width = probs.len();

        let mut rng = default_rng_seeded(42);
        let counts = rng.multinomial(n, &probs, 50).unwrap();
        let flat = counts.as_slice().unwrap();

        for row in 0..50 {
            let cells = &flat[row * width..(row + 1) * width];
            let row_sum: i64 = cells.iter().sum();
            assert_eq!(
                row_sum, n as i64,
                "row {row} sum is {row_sum}, expected {n}"
            );
        }
    }

    /// Counts are never negative.
    #[test]
    fn multinomial_nonnegative() {
        let probs = [0.1, 0.2, 0.3, 0.4];
        let mut rng = default_rng_seeded(42);
        let counts = rng.multinomial(50, &probs, 100).unwrap();
        counts
            .as_slice()
            .unwrap()
            .iter()
            .for_each(|&v| assert!(v >= 0, "multinomial produced negative count: {v}"));
    }

    /// Probabilities that do not sum to one, are negative, or are missing
    /// altogether are rejected.
    #[test]
    fn multinomial_bad_pvals() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.multinomial(10, &[0.5, 0.6], 10).is_err());
        assert!(rng.multinomial(10, &[-0.1, 1.1], 10).is_err());
        assert!(rng.multinomial(10, &[], 10).is_err());
    }

    // ---- multivariate normal ----------------------------------------------

    /// The result has one row per sample and one column per dimension.
    #[test]
    fn multivariate_normal_shape() {
        let mean = [1.0, 2.0, 3.0];
        let identity = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let mut rng = default_rng_seeded(42);
        let samples = rng.multivariate_normal(&mean, &identity, 100).unwrap();
        assert_eq!(samples.shape(), &[100, 3]);
    }

    /// With unit variance, each column mean lies within three standard errors
    /// of the requested mean.
    #[test]
    fn multivariate_normal_mean() {
        let mean = [5.0, -3.0];
        let identity = [1.0, 0.0, 0.0, 1.0];
        let n = 100_000;
        let d = mean.len();

        let mut rng = default_rng_seeded(42);
        let samples = rng.multivariate_normal(&mean, &identity, n).unwrap();
        let flat = samples.as_slice().unwrap();

        // Standard error of a column mean; identical for every column.
        let se = (1.0 / n as f64).sqrt();
        for j in 0..d {
            let column_total = (0..n).map(|i| flat[i * d + j]).sum::<f64>();
            let col_mean: f64 = column_total / n as f64;
            assert!(
                (col_mean - mean[j]).abs() < 3.0 * se,
                "multivariate_normal mean[{j}] = {col_mean}, expected {}",
                mean[j]
            );
        }
    }

    /// A covariance with the wrong number of elements is rejected.
    #[test]
    fn multivariate_normal_bad_cov() {
        let mean = [0.0, 0.0];
        let mut rng = default_rng_seeded(42);
        let outcome = rng.multivariate_normal(&mean, &[1.0, 0.0, 0.0], 10);
        assert!(outcome.is_err());
    }

    // ---- dirichlet --------------------------------------------------------

    /// The result has one row per sample and one column per concentration.
    #[test]
    fn dirichlet_shape() {
        let alpha = [1.0, 2.0, 3.0];
        let mut rng = default_rng_seeded(42);
        let samples = rng.dirichlet(&alpha, 10).unwrap();
        assert_eq!(samples.shape(), &[10, 3]);
    }

    /// Every sampled row sums to one.
    #[test]
    fn dirichlet_sums_to_one() {
        let alpha = [0.5, 1.0, 2.0, 0.5];
        let width = alpha.len();

        let mut rng = default_rng_seeded(42);
        let samples = rng.dirichlet(&alpha, 100).unwrap();
        let flat = samples.as_slice().unwrap();

        for row in 0..100 {
            let cells = &flat[row * width..(row + 1) * width];
            let row_sum: f64 = cells.iter().sum();
            assert!(
                (row_sum - 1.0).abs() < 1e-10,
                "dirichlet row {row} sums to {row_sum}, expected 1.0"
            );
        }
    }

    /// Components are never negative.
    #[test]
    fn dirichlet_nonnegative() {
        let alpha = [0.5, 1.0, 2.0];
        let mut rng = default_rng_seeded(42);
        let samples = rng.dirichlet(&alpha, 100).unwrap();
        samples
            .as_slice()
            .unwrap()
            .iter()
            .for_each(|&v| assert!(v >= 0.0, "dirichlet produced negative value: {v}"));
    }

    /// Empty, zero and negative concentration parameters are rejected.
    #[test]
    fn dirichlet_bad_alpha() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.dirichlet(&[], 10).is_err());
        assert!(rng.dirichlet(&[1.0, 0.0], 10).is_err());
        assert!(rng.dirichlet(&[1.0, -1.0], 10).is_err());
    }
}