1use scirs2_autograd::ndarray::{Array1, Array2};
8
9#[cfg(feature = "no-std")]
10use core::f32::consts::{SQRT_2, TAU};
11#[cfg(not(feature = "no-std"))]
12use std::f32::consts::{SQRT_2, TAU};
13
14#[cfg(feature = "no-std")]
15use alloc::vec;
16
/// Seedable pseudo-random number generator built on a 64-bit linear congruential recurrence.
///
/// The generator is deterministic: two instances created with the same seed produce the same
/// stream. It is a fast, statistically simple source for sampling and is not suitable for
/// cryptographic use.
#[derive(Debug, Clone)]
pub struct SimdRng {
    /// Current 64-bit LCG state, advanced on every draw.
    state: u64,
    /// Multiplier `a` of the recurrence `state = a * state + c (mod 2^64)`.
    multiplier: u64,
    /// Increment `c` of the recurrence `state = a * state + c (mod 2^64)`.
    increment: u64,
}
23
24impl SimdRng {
25 pub fn new(seed: u64) -> Self {
27 Self {
28 state: seed,
29 multiplier: 1103515245,
30 increment: 12345,
31 }
32 }
33
34 pub fn next_u32(&mut self) -> u32 {
36 self.state = self
37 .state
38 .wrapping_mul(self.multiplier)
39 .wrapping_add(self.increment);
40 (self.state >> 16) as u32
41 }
42
43 pub fn fill_u32(&mut self, output: &mut [u32]) {
45 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
46 {
47 if crate::simd_feature_detected!("avx2") {
48 unsafe { self.fill_u32_avx2(output) };
49 return;
50 } else if crate::simd_feature_detected!("sse2") {
51 unsafe { self.fill_u32_sse2(output) };
52 return;
53 }
54 }
55
56 for val in output.iter_mut() {
58 *val = self.next_u32();
59 }
60 }
61
62 pub fn uniform_f32(&mut self, output: &mut [f32]) {
64 let mut u32_buffer = vec![0u32; output.len()];
65 self.fill_u32(&mut u32_buffer);
66
67 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
68 {
69 if crate::simd_feature_detected!("avx2") {
70 unsafe { convert_u32_to_f32_avx2(&u32_buffer, output) };
71 return;
72 } else if crate::simd_feature_detected!("sse2") {
73 unsafe { convert_u32_to_f32_sse2(&u32_buffer, output) };
74 return;
75 }
76 }
77
78 for (i, &val) in u32_buffer.iter().enumerate() {
80 output[i] = (val as f32) / (u32::MAX as f32);
81 }
82 }
83}
84
/// Normal (Gaussian) distribution `N(mean, std_dev²)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Normal {
    /// Location parameter: the distribution's mean.
    mean: f32,
    /// Scale parameter; `Normal::new` requires it to be strictly positive.
    std_dev: f32,
}
90
91impl Normal {
92 pub fn new(mean: f32, std_dev: f32) -> Self {
94 assert!(std_dev > 0.0, "Standard deviation must be positive");
95 Self { mean, std_dev }
96 }
97
98 pub fn sample(&self, rng: &mut SimdRng, output: &mut [f32]) {
100 let mut uniform_samples = vec![0.0f32; output.len() * 2];
101 rng.uniform_f32(&mut uniform_samples);
102
103 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
104 {
105 if crate::simd_feature_detected!("avx2") {
106 unsafe { self.box_muller_avx2(&uniform_samples, output) };
107 return;
108 } else if crate::simd_feature_detected!("sse2") {
109 unsafe { self.box_muller_sse2(&uniform_samples, output) };
110 return;
111 }
112 }
113
114 self.box_muller_scalar(&uniform_samples, output);
116 }
117
118 pub fn pdf(&self, values: &[f32], output: &mut [f32]) {
120 assert_eq!(values.len(), output.len());
121
122 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
123 {
124 if crate::simd_feature_detected!("avx2") {
125 unsafe { self.pdf_avx2(values, output) };
126 return;
127 } else if crate::simd_feature_detected!("sse2") {
128 unsafe { self.pdf_sse2(values, output) };
129 return;
130 }
131 }
132
133 self.pdf_scalar(values, output);
135 }
136
137 pub fn cdf(&self, values: &[f32], output: &mut [f32]) {
139 assert_eq!(values.len(), output.len());
140
141 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
142 {
143 if crate::simd_feature_detected!("avx2") {
144 unsafe { self.cdf_avx2(values, output) };
145 return;
146 } else if crate::simd_feature_detected!("sse2") {
147 unsafe { self.cdf_sse2(values, output) };
148 return;
149 }
150 }
151
152 self.cdf_scalar(values, output);
154 }
155
156 fn box_muller_scalar(&self, uniform: &[f32], output: &mut [f32]) {
157 let mut i = 0;
158 let mut out_idx = 0;
159
160 while out_idx < output.len() && i + 1 < uniform.len() {
161 let u1 = uniform[i].max(1e-10); let u2 = uniform[i + 1];
163
164 let magnitude = (-2.0 * u1.ln()).sqrt() * self.std_dev;
165 let angle = TAU * u2;
166
167 let z0 = magnitude * angle.cos() + self.mean;
168 let z1 = magnitude * angle.sin() + self.mean;
169
170 output[out_idx] = z0;
171 if out_idx + 1 < output.len() {
172 output[out_idx + 1] = z1;
173 }
174
175 i += 2;
176 out_idx += 2;
177 }
178 }
179
180 fn pdf_scalar(&self, values: &[f32], output: &mut [f32]) {
181 let inv_sqrt_2pi = 1.0 / (TAU).sqrt();
182 let inv_std = 1.0 / self.std_dev;
183 let inv_var_2 = 1.0 / (2.0 * self.std_dev * self.std_dev);
184
185 for (i, &x) in values.iter().enumerate() {
186 let z = (x - self.mean) * inv_std;
187 output[i] = inv_sqrt_2pi * inv_std * (-z * z * inv_var_2).exp();
188 }
189 }
190
191 fn cdf_scalar(&self, values: &[f32], output: &mut [f32]) {
192 for (i, &x) in values.iter().enumerate() {
193 let z = (x - self.mean) / (self.std_dev * SQRT_2);
194 output[i] = 0.5 * (1.0 + erf_approximation(z));
195 }
196 }
197}
198
/// Exponential distribution with rate parameter `rate` (often written λ).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Exponential {
    /// Rate λ; `Exponential::new` requires it to be strictly positive.
    rate: f32,
}
203
204impl Exponential {
205 pub fn new(rate: f32) -> Self {
207 assert!(rate > 0.0, "Rate parameter must be positive");
208 Self { rate }
209 }
210
211 pub fn sample(&self, rng: &mut SimdRng, output: &mut [f32]) {
213 let mut uniform_samples = vec![0.0f32; output.len()];
214 rng.uniform_f32(&mut uniform_samples);
215
216 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
217 {
218 if crate::simd_feature_detected!("avx2") {
219 unsafe { self.inverse_transform_avx2(&uniform_samples, output) };
220 return;
221 } else if crate::simd_feature_detected!("sse2") {
222 unsafe { self.inverse_transform_sse2(&uniform_samples, output) };
223 return;
224 }
225 }
226
227 for (i, &u) in uniform_samples.iter().enumerate() {
229 output[i] = -(1.0 - u).ln() / self.rate;
230 }
231 }
232
233 pub fn pdf(&self, values: &[f32], output: &mut [f32]) {
235 for (i, &x) in values.iter().enumerate() {
236 if x >= 0.0 {
237 output[i] = self.rate * (-self.rate * x).exp();
238 } else {
239 output[i] = 0.0;
240 }
241 }
242 }
243}
244
/// Beta distribution on `[0, 1]` with shape parameters `alpha` and `beta`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Beta {
    /// First shape parameter; `Beta::new` requires it to be strictly positive.
    alpha: f32,
    /// Second shape parameter; `Beta::new` requires it to be strictly positive.
    beta: f32,
}
250
251impl Beta {
252 pub fn new(alpha: f32, beta: f32) -> Self {
254 assert!(alpha > 0.0 && beta > 0.0, "Alpha and beta must be positive");
255 Self { alpha, beta }
256 }
257
258 pub fn sample(&self, rng: &mut SimdRng, output: &mut [f32]) {
260 let mut uniform_samples = vec![0.0f32; output.len() * 2];
263 rng.uniform_f32(&mut uniform_samples);
264
265 for i in 0..output.len() {
266 let u1 = uniform_samples[i * 2];
267 let u2 = uniform_samples[i * 2 + 1];
268
269 let x = u1.powf(1.0 / self.alpha);
271 let y = u2.powf(1.0 / self.beta);
272
273 output[i] = x / (x + y);
274 }
275 }
276}
277
/// Approximates the error function `erf(x)`.
///
/// Implements Abramowitz & Stegun formula 7.1.26 (absolute error at most 1.5e-7 in exact
/// arithmetic):
/// `erf(|x|) ≈ 1 - (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵) · exp(-x²)`, with
/// `t = 1 / (1 + p·|x|)`. Negative inputs are handled through the odd symmetry
/// `erf(-x) = -erf(x)`.
fn erf_approximation(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    // Coefficients a5, a4, a3, a2, a1 — highest degree first, as consumed by Horner's rule.
    const COEFFS: [f32; 5] = [1.061_405_4, -1.453_152_1, 1.421_413_8, -0.284_496_72, 0.254_829_6];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let horner = COEFFS[1..]
        .iter()
        .fold(COEFFS[0], |acc, &coeff| acc * t + coeff);
    let erf_of_magnitude = 1.0 - horner * t * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -erf_of_magnitude
    } else {
        erf_of_magnitude
    }
}
296
297#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
300impl SimdRng {
301 #[target_feature(enable = "sse2")]
302 unsafe fn fill_u32_sse2(&mut self, output: &mut [u32]) {
303 for val in output.iter_mut() {
306 *val = self.next_u32();
307 }
308 }
309
310 #[target_feature(enable = "avx2")]
311 unsafe fn fill_u32_avx2(&mut self, output: &mut [u32]) {
312 for val in output.iter_mut() {
318 *val = self.next_u32();
319 }
320 }
321}
322
/// Maps each `u32` in `input` to an `f32` in `[0, 1)` using SSE2, writing into `output`.
///
/// Each value becomes `(bits >> 8) * 2^-24`. Keeping the top 24 bits makes the conversion
/// exact, because an `f32` significand holds 24 bits, and caps the result at `1 - 2^-24`.
/// Dividing by `u32::MAX as f32` would instead round the largest inputs up to `1.0`.
///
/// # Safety
///
/// The caller must have verified that the CPU supports SSE2.
///
/// # Panics
///
/// Panics if `output` is shorter than `input`. Extra trailing elements of `output` are left
/// untouched.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn convert_u32_to_f32_sse2(input: &[u32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const SCALE: f32 = 1.0 / 16_777_216.0; // 2^-24

    let output = &mut output[..input.len()];
    let scale = _mm_set1_ps(SCALE);

    let mut src_chunks = input.chunks_exact(4);
    let mut dst_chunks = output.chunks_exact_mut(4);
    for (src, dst) in (&mut src_chunks).zip(&mut dst_chunks) {
        let bits = _mm_loadu_si128(src.as_ptr() as *const __m128i);
        // The logical shift leaves values below 2^24, so the signed int -> float conversion
        // is exact; scaling by a power of two is exact too.
        let top24 = _mm_cvtepi32_ps(_mm_srli_epi32(bits, 8));
        _mm_storeu_ps(dst.as_mut_ptr(), _mm_mul_ps(top24, scale));
    }

    for (dst, &src) in dst_chunks
        .into_remainder()
        .iter_mut()
        .zip(src_chunks.remainder())
    {
        *dst = (src >> 8) as f32 * SCALE;
    }
}
333
/// Maps each `u32` in `input` to an `f32` in `[0, 1)` using AVX2, writing into `output`.
///
/// Each value becomes `(bits >> 8) * 2^-24`. Keeping the top 24 bits makes the conversion
/// exact, because an `f32` significand holds 24 bits, and caps the result at `1 - 2^-24`.
/// Dividing by `u32::MAX as f32` would instead round the largest inputs up to `1.0`.
///
/// # Safety
///
/// The caller must have verified that the CPU supports AVX2.
///
/// # Panics
///
/// Panics if `output` is shorter than `input`. Extra trailing elements of `output` are left
/// untouched.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn convert_u32_to_f32_avx2(input: &[u32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const SCALE: f32 = 1.0 / 16_777_216.0; // 2^-24

    let output = &mut output[..input.len()];
    let scale = _mm256_set1_ps(SCALE);

    let mut src_chunks = input.chunks_exact(8);
    let mut dst_chunks = output.chunks_exact_mut(8);
    for (src, dst) in (&mut src_chunks).zip(&mut dst_chunks) {
        let bits = _mm256_loadu_si256(src.as_ptr() as *const __m256i);
        // The logical shift leaves values below 2^24, so the signed int -> float conversion
        // is exact; scaling by a power of two is exact too.
        let top24 = _mm256_cvtepi32_ps(_mm256_srli_epi32(bits, 8));
        _mm256_storeu_ps(dst.as_mut_ptr(), _mm256_mul_ps(top24, scale));
    }

    for (dst, &src) in dst_chunks
        .into_remainder()
        .iter_mut()
        .zip(src_chunks.remainder())
    {
        *dst = (src >> 8) as f32 * SCALE;
    }
}
344
345#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
346impl Normal {
347 #[target_feature(enable = "sse2")]
348 unsafe fn box_muller_sse2(&self, uniform: &[f32], output: &mut [f32]) {
349 #[cfg(feature = "no-std")]
350 use core::arch::x86_64::*;
351 #[cfg(not(feature = "no-std"))]
352 use core::arch::x86_64::*;
353
354 let mut i = 0;
355 let mut out_idx = 0;
356
357 while out_idx + 4 <= output.len() && i + 8 <= uniform.len() {
358 let u1 = _mm_loadu_ps(&uniform[i]);
359 let u2 = _mm_loadu_ps(&uniform[i + 4]);
360
361 let mut u1_vals = [0.0f32; 4];
364 let mut u2_vals = [0.0f32; 4];
365 _mm_storeu_ps(u1_vals.as_mut_ptr(), u1);
366 _mm_storeu_ps(u2_vals.as_mut_ptr(), u2);
367
368 let mut z0_vals = [0.0f32; 4];
369 for k in 0..4 {
370 let magnitude = (-2.0 * u1_vals[k].ln()).sqrt() * self.std_dev;
371 let angle = TAU * u2_vals[k];
372 z0_vals[k] = magnitude * angle.cos() + self.mean;
373 }
374
375 let z0 = _mm_loadu_ps(z0_vals.as_ptr());
376
377 _mm_storeu_ps(&mut output[out_idx], z0);
378
379 i += 8;
380 out_idx += 4;
381 }
382
383 while out_idx < output.len() && i + 1 < uniform.len() {
385 let u1 = uniform[i].max(1e-10);
386 let u2 = uniform[i + 1];
387
388 let magnitude = (-2.0 * u1.ln()).sqrt() * self.std_dev;
389 let angle = TAU * u2;
390
391 output[out_idx] = magnitude * angle.cos() + self.mean;
392
393 i += 2;
394 out_idx += 1;
395 }
396 }
397
398 #[target_feature(enable = "avx2")]
399 unsafe fn box_muller_avx2(&self, uniform: &[f32], output: &mut [f32]) {
400 #[cfg(feature = "no-std")]
401 use core::arch::x86_64::*;
402 #[cfg(not(feature = "no-std"))]
403 use core::arch::x86_64::*;
404
405 let mut i = 0;
406 let mut out_idx = 0;
407
408 while out_idx + 8 <= output.len() && i + 16 <= uniform.len() {
409 let u1 = _mm256_loadu_ps(&uniform[i]);
410 let u2 = _mm256_loadu_ps(&uniform[i + 8]);
411
412 let mut u1_vals = [0.0f32; 8];
415 let mut u2_vals = [0.0f32; 8];
416 _mm256_storeu_ps(u1_vals.as_mut_ptr(), u1);
417 _mm256_storeu_ps(u2_vals.as_mut_ptr(), u2);
418
419 let mut z0_vals = [0.0f32; 8];
420 for k in 0..8 {
421 let magnitude = (-2.0 * u1_vals[k].ln()).sqrt() * self.std_dev;
422 let angle = TAU * u2_vals[k];
423 z0_vals[k] = magnitude * angle.cos() + self.mean;
424 }
425
426 let z0 = _mm256_loadu_ps(z0_vals.as_ptr());
427
428 _mm256_storeu_ps(&mut output[out_idx], z0);
429
430 i += 16;
431 out_idx += 8;
432 }
433
434 while out_idx < output.len() && i + 1 < uniform.len() {
436 let u1 = uniform[i].max(1e-10);
437 let u2 = uniform[i + 1];
438
439 let magnitude = (-2.0 * u1.ln()).sqrt() * self.std_dev;
440 let angle = TAU * u2;
441
442 output[out_idx] = magnitude * angle.cos() + self.mean;
443
444 i += 2;
445 out_idx += 1;
446 }
447 }
448
449 #[target_feature(enable = "sse2")]
450 unsafe fn pdf_sse2(&self, values: &[f32], output: &mut [f32]) {
451 #[cfg(feature = "no-std")]
452 use core::arch::x86_64::*;
453 #[cfg(not(feature = "no-std"))]
454 use core::arch::x86_64::*;
455
456 let inv_sqrt_2pi = _mm_set1_ps(1.0 / (TAU).sqrt());
457 let mean_vec = _mm_set1_ps(self.mean);
458 let inv_std = _mm_set1_ps(1.0 / self.std_dev);
459 let inv_var_2 = _mm_set1_ps(1.0 / (2.0 * self.std_dev * self.std_dev));
460
461 let mut i = 0;
462 while i + 4 <= values.len() {
463 let x = _mm_loadu_ps(&values[i]);
464 let z = _mm_mul_ps(_mm_sub_ps(x, mean_vec), inv_std);
465 let exp_arg = _mm_mul_ps(_mm_mul_ps(z, z), inv_var_2);
466 let mut exp_arg_vals = [0.0f32; 4];
467 _mm_storeu_ps(exp_arg_vals.as_mut_ptr(), exp_arg);
468 let mut exp_vals = [0.0f32; 4];
469 for k in 0..4 {
470 exp_vals[k] = (-exp_arg_vals[k]).exp();
471 }
472 let exp_result = _mm_loadu_ps(exp_vals.as_ptr());
473 let result = _mm_mul_ps(_mm_mul_ps(inv_sqrt_2pi, inv_std), exp_result);
474 _mm_storeu_ps(&mut output[i], result);
475 i += 4;
476 }
477
478 while i < values.len() {
480 let z = (values[i] - self.mean) / self.std_dev;
481 output[i] = (1.0 / (TAU).sqrt()) / self.std_dev * (-z * z / 2.0).exp();
482 i += 1;
483 }
484 }
485
486 #[target_feature(enable = "avx2")]
487 unsafe fn pdf_avx2(&self, values: &[f32], output: &mut [f32]) {
488 #[cfg(feature = "no-std")]
489 use core::arch::x86_64::*;
490 #[cfg(not(feature = "no-std"))]
491 use core::arch::x86_64::*;
492
493 let inv_sqrt_2pi = _mm256_set1_ps(1.0 / (TAU).sqrt());
494 let mean_vec = _mm256_set1_ps(self.mean);
495 let inv_std = _mm256_set1_ps(1.0 / self.std_dev);
496 let inv_var_2 = _mm256_set1_ps(1.0 / (2.0 * self.std_dev * self.std_dev));
497
498 let mut i = 0;
499 while i + 8 <= values.len() {
500 let x = _mm256_loadu_ps(&values[i]);
501 let z = _mm256_mul_ps(_mm256_sub_ps(x, mean_vec), inv_std);
502 let exp_arg = _mm256_mul_ps(_mm256_mul_ps(z, z), inv_var_2);
503 let mut exp_arg_vals = [0.0f32; 8];
504 _mm256_storeu_ps(exp_arg_vals.as_mut_ptr(), exp_arg);
505 let mut exp_vals = [0.0f32; 8];
506 for k in 0..8 {
507 exp_vals[k] = (-exp_arg_vals[k]).exp();
508 }
509 let exp_result = _mm256_loadu_ps(exp_vals.as_ptr());
510 let result = _mm256_mul_ps(_mm256_mul_ps(inv_sqrt_2pi, inv_std), exp_result);
511 _mm256_storeu_ps(&mut output[i], result);
512 i += 8;
513 }
514
515 while i < values.len() {
517 let z = (values[i] - self.mean) / self.std_dev;
518 output[i] = (1.0 / (TAU).sqrt()) / self.std_dev * (-z * z / 2.0).exp();
519 i += 1;
520 }
521 }
522
523 #[target_feature(enable = "sse2")]
524 unsafe fn cdf_sse2(&self, values: &[f32], output: &mut [f32]) {
525 self.cdf_scalar(values, output);
527 }
528
529 #[target_feature(enable = "avx2")]
530 unsafe fn cdf_avx2(&self, values: &[f32], output: &mut [f32]) {
531 self.cdf_scalar(values, output);
533 }
534}
535
536#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
537impl Exponential {
538 #[target_feature(enable = "sse2")]
539 unsafe fn inverse_transform_sse2(&self, uniform: &[f32], output: &mut [f32]) {
540 #[cfg(feature = "no-std")]
541 use core::arch::x86_64::*;
542 #[cfg(not(feature = "no-std"))]
543 use core::arch::x86_64::*;
544
545 let one = _mm_set1_ps(1.0);
546 let rate_vec = _mm_set1_ps(self.rate);
547
548 let mut i = 0;
549 while i + 4 <= uniform.len() {
550 let u = _mm_loadu_ps(&uniform[i]);
551 let one_minus_u = _mm_sub_ps(one, u);
552 let mut one_minus_u_vals = [0.0f32; 4];
553 _mm_storeu_ps(one_minus_u_vals.as_mut_ptr(), one_minus_u);
554 let mut ln_vals = [0.0f32; 4];
555 for k in 0..4 {
556 ln_vals[k] = one_minus_u_vals[k].ln();
557 }
558 let ln_result = _mm_loadu_ps(ln_vals.as_ptr());
559 let neg_ln = _mm_sub_ps(_mm_setzero_ps(), ln_result);
560 let result = _mm_div_ps(neg_ln, rate_vec);
561 _mm_storeu_ps(&mut output[i], result);
562 i += 4;
563 }
564
565 while i < uniform.len() {
567 output[i] = -(1.0 - uniform[i]).ln() / self.rate;
568 i += 1;
569 }
570 }
571
572 #[target_feature(enable = "avx2")]
573 unsafe fn inverse_transform_avx2(&self, uniform: &[f32], output: &mut [f32]) {
574 #[cfg(feature = "no-std")]
575 use core::arch::x86_64::*;
576 #[cfg(not(feature = "no-std"))]
577 use core::arch::x86_64::*;
578
579 let one = _mm256_set1_ps(1.0);
580 let rate_vec = _mm256_set1_ps(self.rate);
581
582 let mut i = 0;
583 while i + 8 <= uniform.len() {
584 let u = _mm256_loadu_ps(&uniform[i]);
585 let one_minus_u = _mm256_sub_ps(one, u);
586 let mut one_minus_u_vals = [0.0f32; 8];
587 _mm256_storeu_ps(one_minus_u_vals.as_mut_ptr(), one_minus_u);
588 let mut ln_vals = [0.0f32; 8];
589 for k in 0..8 {
590 ln_vals[k] = one_minus_u_vals[k].ln();
591 }
592 let ln_result = _mm256_loadu_ps(ln_vals.as_ptr());
593 let neg_ln = _mm256_sub_ps(_mm256_setzero_ps(), ln_result);
594 let result = _mm256_div_ps(neg_ln, rate_vec);
595 _mm256_storeu_ps(&mut output[i], result);
596 i += 8;
597 }
598
599 while i < uniform.len() {
601 output[i] = -(1.0 - uniform[i]).ln() / self.rate;
602 i += 1;
603 }
604 }
605}
606
607pub fn multivariate_normal_sample(
609 mean: &Array1<f32>,
610 covariance: &Array2<f32>,
611 rng: &mut SimdRng,
612 num_samples: usize,
613) -> Array2<f32> {
614 let dim = mean.len();
615 assert_eq!(covariance.shape(), &[dim, dim]);
616
617 let chol = cholesky_decomposition(covariance);
619
620 let mut samples = Array2::zeros((num_samples, dim));
621 let normal = Normal::new(0.0, 1.0);
622
623 for i in 0..num_samples {
624 let mut standard_normal = vec![0.0f32; dim];
625 normal.sample(rng, &mut standard_normal);
626
627 let z = Array1::from_vec(standard_normal);
629 let transformed = crate::matrix::matrix_vector_multiply_f32(&chol, &z);
630
631 for j in 0..dim {
632 samples[[i, j]] = transformed[j] + mean[j];
633 }
634 }
635
636 samples
637}
638
639fn cholesky_decomposition(matrix: &Array2<f32>) -> Array2<f32> {
641 let n = matrix.nrows();
642 let mut chol = Array2::zeros((n, n));
643
644 for i in 0..n {
645 for j in 0..=i {
646 if i == j {
647 let mut sum = 0.0;
648 for k in 0..j {
649 sum += chol[[j, k]] * chol[[j, k]];
650 }
651 chol[[j, j]] = (matrix[[j, j]] - sum).sqrt();
652 } else {
653 let mut sum = 0.0;
654 for k in 0..j {
655 sum += chol[[i, k]] * chol[[j, k]];
656 }
657 chol[[i, j]] = (matrix[[i, j]] - sum) / chol[[j, j]];
658 }
659 }
660 }
661
662 chol
663}
664
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_simd_rng() {
        let mut rng = SimdRng::new(12345);
        let mut output = vec![0u32; 16];
        rng.fill_u32(&mut output);

        // A working generator must not emit the same word sixteen times in a row.
        assert!(output.iter().any(|&x| x != output[0]));
    }

    #[test]
    fn test_uniform_f32() {
        let mut rng = SimdRng::new(12345);
        let mut output = vec![0.0f32; 100];
        rng.uniform_f32(&mut output);

        for &val in &output {
            assert!((0.0..1.0).contains(&val));
        }

        // Loose sanity bound on the sample mean of U(0, 1).
        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        assert!(mean > 0.4 && mean < 0.6);
    }

    #[test]
    fn test_normal_distribution() {
        let mut rng = SimdRng::new(42);
        let normal = Normal::new(5.0, 2.0);
        let mut samples = vec![0.0f32; 1000];
        normal.sample(&mut rng, &mut samples);

        let mean: f32 = samples.iter().sum::<f32>() / samples.len() as f32;
        assert_relative_eq!(mean, 5.0, epsilon = 0.2);
    }

    #[test]
    fn test_normal_pdf() {
        let normal = Normal::new(0.0, 1.0);
        let values = vec![0.0, 1.0, -1.0];
        let mut output = vec![0.0f32; 3];
        normal.pdf(&values, &mut output);

        // Peak of the standard normal density is 1 / sqrt(2π) ≈ 0.3989.
        assert_relative_eq!(output[0], 0.3989, epsilon = 0.01);
        // The density is symmetric about the mean.
        assert_relative_eq!(output[1], output[2], epsilon = 1e-6);
    }

    #[test]
    fn test_normal_cdf() {
        let normal = Normal::new(0.0, 1.0);
        let values = vec![0.0, 1.96, -1.96];
        let mut output = vec![0.0f32; 3];
        normal.cdf(&values, &mut output);

        assert_relative_eq!(output[0], 0.5, epsilon = 1e-4);
        assert_relative_eq!(output[1], 0.975, epsilon = 1e-3);
        // Φ(x) + Φ(-x) = 1.
        assert_relative_eq!(output[1] + output[2], 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_exponential_distribution() {
        let mut rng = SimdRng::new(123);
        let exp_dist = Exponential::new(2.0);
        let mut samples = vec![0.0f32; 1000];
        exp_dist.sample(&mut rng, &mut samples);

        for &sample in &samples {
            assert!(sample >= 0.0);
        }

        // Mean of Exp(rate) is 1 / rate.
        let mean: f32 = samples.iter().sum::<f32>() / samples.len() as f32;
        assert_relative_eq!(mean, 0.5, epsilon = 0.1);
    }

    #[test]
    fn test_exponential_pdf() {
        let exp_dist = Exponential::new(2.0);
        let values = vec![-1.0, 0.0, 1.0];
        let mut output = vec![0.0f32; 3];
        exp_dist.pdf(&values, &mut output);

        assert_eq!(output[0], 0.0);
        assert_relative_eq!(output[1], 2.0, epsilon = 1e-6);
        assert_relative_eq!(output[2], 2.0 * (-2.0f32).exp(), epsilon = 1e-6);
    }

    #[test]
    fn test_beta_distribution() {
        let mut rng = SimdRng::new(456);
        let beta = Beta::new(2.0, 3.0);
        let mut samples = vec![0.0f32; 100];
        beta.sample(&mut rng, &mut samples);

        for &sample in &samples {
            assert!((0.0..=1.0).contains(&sample));
        }
    }

    #[test]
    fn test_erf_approximation() {
        assert_relative_eq!(erf_approximation(0.0), 0.0, epsilon = 1e-4);
        assert_relative_eq!(erf_approximation(1.0), 0.8427, epsilon = 1e-3);
        assert_relative_eq!(erf_approximation(-1.0), -0.8427, epsilon = 1e-3);
    }

    #[test]
    fn test_rng_uniform() {
        let mut rng = SimdRng::new(123);
        let mut samples = vec![0.0f32; 10];
        rng.uniform_f32(&mut samples);

        let sum: f32 = samples.iter().sum();
        assert!(sum > 0.1, "uniform samples look degenerate: {:?}", samples);
    }

    #[test]
    fn test_cholesky_decomposition() {
        let cov = Array2::from_shape_vec((2, 2), vec![4.0, 2.0, 2.0, 3.0])
            .expect("shape and data length should match");
        let lower = cholesky_decomposition(&cov);

        // [[4, 2], [2, 3]] = L·Lᵀ with L = [[2, 0], [1, √2]].
        assert_relative_eq!(lower[[0, 0]], 2.0, epsilon = 1e-6);
        assert_relative_eq!(lower[[0, 1]], 0.0, epsilon = 1e-6);
        assert_relative_eq!(lower[[1, 0]], 1.0, epsilon = 1e-6);
        assert_relative_eq!(lower[[1, 1]], 2.0f32.sqrt(), epsilon = 1e-6);
    }

    #[test]
    fn test_multivariate_normal() {
        let mut rng = SimdRng::new(789);
        let mean = Array1::from_vec(vec![1.0, 2.0]);
        let cov = Array2::from_shape_vec((2, 2), vec![1.0, 0.5, 0.5, 1.0])
            .expect("shape and data length should match");

        let samples = multivariate_normal_sample(&mean, &cov, &mut rng, 10);
        assert_eq!(samples.shape(), &[10, 2]);
        assert!(samples.iter().all(|v| v.is_finite()));
    }
}
786}