use ferray_core::{Array, FerrayError, Ix1};
use crate::bitgen::BitGenerator;
use crate::distributions::normal::standard_normal_pair;
use crate::generator::Generator;
impl<B: BitGenerator + Clone> Generator<B> {
    /// Draws `size` samples via [`standard_normal_pair`] and returns them as a
    /// one-dimensional array.
    ///
    /// Despite the name, every sample is drawn one pair at a time from this
    /// generator's own bit generator, in order. The tests in this file rely on
    /// that: for the same seed the output must equal `standard_normal`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-value error when `size` is zero, and forwards any
    /// error reported by `Array::from_vec`.
    pub fn standard_normal_parallel(
        &mut self,
        size: usize,
    ) -> Result<Array<f64, Ix1>, FerrayError> {
        if size == 0 {
            return Err(FerrayError::invalid_value("size must be > 0"));
        }

        let mut samples = Vec::with_capacity(size);

        // Every full pair contributes both of its values.
        for _ in 0..size / 2 {
            let (first, second) = standard_normal_pair(&mut self.bg);
            samples.extend([first, second]);
        }

        // An odd request still consumes one more pair, keeping only its first
        // value; the second value is discarded.
        if size % 2 == 1 {
            let (last, _unused) = standard_normal_pair(&mut self.bg);
            samples.push(last);
        }

        let shape = Ix1::new([samples.len()]);
        Array::<f64, Ix1>::from_vec(shape, samples)
    }

    /// Spawns `n` child generators from this one.
    ///
    /// This is a thin wrapper that forwards `self` and `n` to
    /// `crate::generator::spawn_generators` and returns its result unchanged.
    ///
    /// # Errors
    ///
    /// Returns whatever error `spawn_generators` reports.
    pub fn spawn(&mut self, n: usize) -> Result<Vec<Generator<B>>, FerrayError> {
        let parent = self;
        crate::generator::spawn_generators(parent, n)
    }
}
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    // Same seed, two entry points: the outputs must be identical.
    #[test]
    fn parallel_matches_sequential() {
        let mut sequential_rng = default_rng_seeded(42);
        let mut parallel_rng = default_rng_seeded(42);

        let expected = sequential_rng.standard_normal(10_000).unwrap();
        let actual = parallel_rng.standard_normal_parallel(10_000).unwrap();

        assert_eq!(
            expected.as_slice().unwrap(),
            actual.as_slice().unwrap(),
            "parallel and sequential outputs differ"
        );
    }

    // Two generators built from the same seed must produce the same samples.
    #[test]
    fn parallel_deterministic() {
        let mut first_rng = default_rng_seeded(42);
        let mut second_rng = default_rng_seeded(42);

        let first_run = first_rng.standard_normal_parallel(50_000).unwrap();
        let second_run = second_rng.standard_normal_parallel(50_000).unwrap();

        assert_eq!(
            first_run.as_slice().unwrap(),
            second_run.as_slice().unwrap()
        );
    }

    // A large request has the requested shape and a sample mean near zero.
    #[test]
    fn parallel_large() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.standard_normal_parallel(1_000_000).unwrap();
        assert_eq!(samples.shape(), &[1_000_000]);

        let values = samples.as_slice().unwrap();
        let total: f64 = values.iter().sum();
        let mean = total / values.len() as f64;
        assert!(mean.abs() < 0.01, "parallel mean {mean} too far from 0");
    }

    // Every pair of spawned children must differ in its first output.
    #[test]
    fn spawn_creates_independent_generators() {
        let mut parent = default_rng_seeded(42);
        let mut children = parent.spawn(4).unwrap();
        assert_eq!(children.len(), 4);

        let first_values: Vec<u64> = children
            .iter_mut()
            .map(|child| child.next_u64())
            .collect();

        for (i, lhs) in first_values.iter().enumerate() {
            // `skip(i + 1)` visits each unordered pair exactly once.
            for (j, rhs) in first_values.iter().enumerate().skip(i + 1) {
                assert_ne!(
                    lhs, rhs,
                    "children {i} and {j} produced same first value"
                );
            }
        }
    }

    // Spawning from identically seeded parents yields identical child streams.
    #[test]
    fn spawn_deterministic() {
        let mut parent_a = default_rng_seeded(42);
        let mut parent_b = default_rng_seeded(42);
        let mut family_a = parent_a.spawn(4).unwrap();
        let mut family_b = parent_b.spawn(4).unwrap();

        for (child_a, child_b) in family_a.iter_mut().zip(family_b.iter_mut()) {
            for _draw in 0..100 {
                assert_eq!(child_a.next_u64(), child_b.next_u64());
            }
        }
    }

    // A request for zero samples is rejected with an error.
    #[test]
    fn parallel_zero_size_error() {
        let mut rng = default_rng_seeded(42);
        let outcome = rng.standard_normal_parallel(0);
        assert!(outcome.is_err());
    }
}