ferray_random/
parallel.rs1use ferray_core::{Array, FerrayError, IxDyn};
7
8use crate::bitgen::BitGenerator;
9use crate::distributions::normal::standard_normal_single;
10use crate::generator::{Generator, shape_size, vec_to_array_f64};
11use crate::shape::IntoShape;
12
13impl<B: BitGenerator + Clone> Generator<B> {
14 pub fn standard_normal_parallel(
28 &mut self,
29 size: impl IntoShape,
30 ) -> Result<Array<f64, IxDyn>, FerrayError> {
31 let shape_vec = size.into_shape()?;
32 let n = shape_size(&shape_vec);
33
34 let num_threads = rayon::current_num_threads().max(1);
36 if n < 10_000 || num_threads <= 1 {
37 let mut data = Vec::with_capacity(n);
39 for _ in 0..n {
40 data.push(standard_normal_single(&mut self.bg));
41 }
42 return vec_to_array_f64(data, &shape_vec);
43 }
44
45 let mut children = self.spawn(num_threads)?;
47 let chunk_size = n.div_ceil(num_threads);
48
49 use rayon::prelude::*;
51 let chunks: Vec<Vec<f64>> = children
52 .par_iter_mut()
53 .enumerate()
54 .map(|(i, child)| {
55 let start = i * chunk_size;
56 let end = (start + chunk_size).min(n);
57 let count = end - start;
58 let mut chunk = Vec::with_capacity(count);
59 for _ in 0..count {
60 chunk.push(standard_normal_single(&mut child.bg));
61 }
62 chunk
63 })
64 .collect();
65
66 let mut data = Vec::with_capacity(n);
68 for chunk in chunks {
69 data.extend_from_slice(&chunk);
70 }
71 data.truncate(n);
72
73 vec_to_array_f64(data, &shape_vec)
74 }
75
76 pub fn spawn(&mut self, n: usize) -> Result<Vec<Generator<B>>, FerrayError> {
86 crate::generator::spawn_generators(self, n)
87 }
88}
89
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Arithmetic mean of `values`, summed front to back.
    fn mean_of(values: &[f64]) -> f64 {
        values.iter().sum::<f64>() / values.len() as f64
    }

    /// 10 000 samples come back with the right shape and with sample mean
    /// and variance close to those of a standard normal.
    #[test]
    fn parallel_correct_length_and_stats() {
        let mut rng = default_rng_seeded(42);
        let par = rng.standard_normal_parallel(10_000).unwrap();
        assert_eq!(par.shape(), &[10_000]);

        let samples = par.as_slice().unwrap();
        let mean: f64 = mean_of(samples);
        let squared_deviation_sum = samples
            .iter()
            .map(|x| (x - mean) * (x - mean))
            .sum::<f64>();
        let var: f64 = squared_deviation_sum / samples.len() as f64;

        assert!(mean.abs() < 0.05, "mean = {mean}");
        assert!((var - 1.0).abs() < 0.1, "var = {var}");
    }

    /// Two generators built from the same seed yield identical output.
    #[test]
    fn parallel_deterministic() {
        let mut first_rng = default_rng_seeded(42);
        let mut second_rng = default_rng_seeded(42);

        let a = first_rng.standard_normal_parallel(50_000).unwrap();
        let b = second_rng.standard_normal_parallel(50_000).unwrap();

        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    /// A one-million-element request has the right shape and a mean near zero.
    #[test]
    fn parallel_large() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.standard_normal_parallel(1_000_000).unwrap();
        assert_eq!(arr.shape(), &[1_000_000]);

        let mean: f64 = mean_of(arr.as_slice().unwrap());
        assert!(mean.abs() < 0.01, "parallel mean {mean} too far from 0");
    }

    /// The first value drawn from each spawned child differs from the first
    /// value of every other child.
    #[test]
    fn spawn_creates_independent_generators() {
        let mut rng = default_rng_seeded(42);
        let mut children = rng.spawn(4).unwrap();
        assert_eq!(children.len(), 4);

        let outputs: Vec<u64> = children
            .iter_mut()
            .map(|child| child.next_u64())
            .collect();

        // Compare every unordered pair (i, j) with i < j.
        for (i, earlier) in outputs.iter().enumerate() {
            for (j, later) in outputs.iter().enumerate().skip(i + 1) {
                assert_ne!(
                    earlier, later,
                    "children {i} and {j} produced same first value"
                );
            }
        }
    }

    /// Children spawned from identically seeded parents emit matching streams.
    #[test]
    fn spawn_deterministic() {
        let mut first_rng = default_rng_seeded(42);
        let mut second_rng = default_rng_seeded(42);

        let mut first_children = first_rng.spawn(4).unwrap();
        let mut second_children = second_rng.spawn(4).unwrap();

        let paired = first_children.iter_mut().zip(second_children.iter_mut());
        for (left, right) in paired {
            for _ in 0..100 {
                assert_eq!(left.next_u64(), right.next_u64());
            }
        }
    }

    /// A zero-sized request succeeds and yields an empty array.
    #[test]
    fn parallel_zero_size_returns_empty() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.standard_normal_parallel(0).unwrap();
        assert_eq!(arr.shape(), &[0]);
        assert_eq!(arr.size(), 0);
    }
}