ferray_random/distributions/
uniform.rs1use ferray_core::{Array, FerrayError, IxDyn};
4
5use crate::bitgen::BitGenerator;
6use crate::generator::{
7 Generator, generate_vec, generate_vec_f32, generate_vec_i64, shape_size, vec_to_array_f32,
8 vec_to_array_f64, vec_to_array_i64,
9};
10use crate::shape::IntoShape;
11
12impl<B: BitGenerator> Generator<B> {
13 pub fn random(&mut self, size: impl IntoShape) -> Result<Array<f64, IxDyn>, FerrayError> {
31 let shape = size.into_shape()?;
32 let n = shape_size(&shape);
33 let data = generate_vec(self, n, super::super::bitgen::BitGenerator::next_f64);
34 vec_to_array_f64(data, &shape)
35 }
36
37 pub fn uniform(
44 &mut self,
45 low: f64,
46 high: f64,
47 size: impl IntoShape,
48 ) -> Result<Array<f64, IxDyn>, FerrayError> {
49 if low >= high {
50 return Err(FerrayError::invalid_value(format!(
51 "low ({low}) must be less than high ({high})"
52 )));
53 }
54 let shape = size.into_shape()?;
55 let n = shape_size(&shape);
56 let range = high - low;
57 let data = generate_vec(self, n, |bg| bg.next_f64().mul_add(range, low));
58 vec_to_array_f64(data, &shape)
59 }
60
61 pub fn random_f32(&mut self, size: impl IntoShape) -> Result<Array<f32, IxDyn>, FerrayError> {
79 let shape = size.into_shape()?;
80 let n = shape_size(&shape);
81 let data = generate_vec_f32(self, n, super::super::bitgen::BitGenerator::next_f32);
82 vec_to_array_f32(data, &shape)
83 }
84
85 pub fn uniform_f32(
92 &mut self,
93 low: f32,
94 high: f32,
95 size: impl IntoShape,
96 ) -> Result<Array<f32, IxDyn>, FerrayError> {
97 if low >= high {
98 return Err(FerrayError::invalid_value(format!(
99 "low ({low}) must be less than high ({high})"
100 )));
101 }
102 let shape = size.into_shape()?;
103 let n = shape_size(&shape);
104 let range = high - low;
105 let data = generate_vec_f32(self, n, |bg| bg.next_f32().mul_add(range, low));
106 vec_to_array_f32(data, &shape)
107 }
108
109 pub fn integers(
116 &mut self,
117 low: i64,
118 high: i64,
119 size: impl IntoShape,
120 ) -> Result<Array<i64, IxDyn>, FerrayError> {
121 if low >= high {
122 return Err(FerrayError::invalid_value(format!(
123 "low ({low}) must be less than high ({high})"
124 )));
125 }
126 let shape = size.into_shape()?;
127 let n = shape_size(&shape);
128 let range = (high - low) as u64;
129 let data = generate_vec_i64(self, n, |bg| low + bg.next_u64_bounded(range) as i64);
130 vec_to_array_i64(data, &shape)
131 }
132}
133
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Every `random` sample must fall inside the half-open unit interval.
    #[test]
    fn random_in_range() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.random(10_000).unwrap();
        samples
            .as_slice()
            .unwrap()
            .iter()
            .for_each(|&v| assert!((0.0..1.0).contains(&v)));
    }

    /// Two generators built from the same seed must emit the same `f64` stream.
    #[test]
    fn random_deterministic() {
        let first = default_rng_seeded(42).random(100).unwrap();
        let second = default_rng_seeded(42).random(100).unwrap();
        assert_eq!(first.as_slice().unwrap(), second.as_slice().unwrap());
    }

    /// `uniform(5, 10)` samples must stay inside `[5, 10)`.
    #[test]
    fn uniform_in_range() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.uniform(5.0, 10.0, 10_000).unwrap();
        for v in samples.as_slice().unwrap().iter().copied() {
            assert!((5.0..10.0).contains(&v), "value {v} out of range");
        }
    }

    /// Reversed and empty `f64` intervals are both rejected.
    #[test]
    fn uniform_bad_range() {
        let mut rng = default_rng_seeded(42);
        for (low, high) in [(10.0, 5.0), (5.0, 5.0)] {
            assert!(rng.uniform(low, high, 100).is_err());
        }
    }

    /// `integers(0, 10)` samples must stay inside `[0, 10)`.
    #[test]
    fn integers_in_range() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.integers(0, 10, 10_000).unwrap();
        for v in samples.as_slice().unwrap().iter().copied() {
            assert!((0..10).contains(&v), "value {v} out of range");
        }
    }

    /// A range that straddles zero is honoured on both sides.
    #[test]
    fn integers_negative_range() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.integers(-5, 5, 1000).unwrap();
        for v in samples.as_slice().unwrap().iter().copied() {
            assert!((-5..5).contains(&v), "value {v} out of range");
        }
    }

    /// A reversed integer interval is rejected.
    #[test]
    fn integers_bad_range() {
        let outcome = default_rng_seeded(42).integers(10, 5, 100);
        assert!(outcome.is_err());
    }

    /// Sample mean and variance of `uniform(2, 8)` must be close to the
    /// expected values used below (5 and 3).
    #[test]
    fn uniform_mean_variance() {
        let n = 100_000;
        let mut rng = default_rng_seeded(42);
        let arr = rng.uniform(2.0, 8.0, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let count = n as f64;

        let total: f64 = slice.iter().sum();
        let mean: f64 = total / count;
        let squared_deviations: f64 = slice.iter().map(|&x| (x - mean).powi(2)).sum();
        let var: f64 = squared_deviations / count;

        let expected_mean = 5.0;
        let expected_var = 3.0;
        // Standard error of the sample mean for `n` draws.
        let se_mean = (expected_var / count).sqrt();

        assert!(
            (mean - expected_mean).abs() < 3.0 * se_mean,
            "mean {mean} too far from {expected_mean}"
        );
        assert!(
            (var - expected_var).abs() < 0.1,
            "variance {var} too far from {expected_var}"
        );
    }

    /// The first five values for seed 42 must be bit-identical across runs.
    #[test]
    fn reproducibility_golden_values() {
        let reference = default_rng_seeded(42).random(5).unwrap();
        let golden: Vec<u64> = reference.as_slice().unwrap()[..5]
            .iter()
            .map(|value| value.to_bits())
            .collect();

        let replay = default_rng_seeded(42).random(5).unwrap();
        let replayed = &replay.as_slice().unwrap()[..5];
        for (i, (got, want)) in replayed.iter().zip(&golden).enumerate() {
            assert_eq!(got.to_bits(), *want, "golden value mismatch at index {i}");
        }
    }

    /// Different seeds must produce streams that mostly disagree.
    #[test]
    fn different_seeds_different_values() {
        let from_42 = default_rng_seeded(42).random(100).unwrap();
        let from_123 = default_rng_seeded(123).random(100).unwrap();
        let lhs = from_42.as_slice().unwrap();
        let rhs = from_123.as_slice().unwrap();
        let diffs = lhs.iter().zip(rhs).filter(|(x, y)| x != y).count();
        assert!(diffs > 50, "seeds 42 and 123 produced too-similar output");
    }

    /// A fixed-size array is accepted as a shape.
    #[test]
    fn random_nd_shape_from_array() {
        let grid = default_rng_seeded(42).random([3, 4]).unwrap();
        assert_eq!(grid.shape(), &[3, 4]);
        assert_eq!(grid.size(), 12);
    }

    /// A borrowed slice is accepted as a shape.
    #[test]
    fn random_nd_shape_from_slice() {
        let dims = [2usize, 3, 4];
        let cube = default_rng_seeded(42).random(&dims[..]).unwrap();
        assert_eq!(cube.shape(), &[2, 3, 4]);
        assert_eq!(cube.size(), 24);
    }

    /// An owned `Vec` is accepted as a shape.
    #[test]
    fn random_nd_shape_from_vec() {
        let square = default_rng_seeded(42).random(vec![5, 5]).unwrap();
        assert_eq!(square.shape(), &[5, 5]);
    }

    /// A zero-length axis yields an empty array that keeps the requested shape.
    #[test]
    fn random_nd_zero_axis_returns_empty() {
        let mut rng = default_rng_seeded(42);

        let two_d = rng.random([3, 0]).unwrap();
        assert_eq!(two_d.shape(), &[3, 0]);
        assert_eq!(two_d.size(), 0);

        let one_d = rng.random(0usize).unwrap();
        assert_eq!(one_d.shape(), &[0]);
        assert_eq!(one_d.size(), 0);
    }

    /// Drawing 12 values flat or as a 3x4 grid must yield the same sequence.
    #[test]
    fn random_nd_equivalent_to_reshape() {
        let flat = default_rng_seeded(42).random(12).unwrap();
        let grid = default_rng_seeded(42).random([3, 4]).unwrap();
        assert_eq!(flat.size(), grid.size());

        let flat_values = flat.iter().copied().collect::<Vec<f64>>();
        let grid_values = grid.iter().copied().collect::<Vec<f64>>();
        assert_eq!(flat_values, grid_values);
    }

    /// `uniform` honours an N-d shape and its bounds together.
    #[test]
    fn uniform_nd_shape() {
        let grid = default_rng_seeded(42).uniform(0.0, 10.0, [2, 5]).unwrap();
        assert_eq!(grid.shape(), &[2, 5]);
        grid.iter()
            .for_each(|&v| assert!((0.0..10.0).contains(&v)));
    }

    /// `integers` honours an N-d shape and its bounds together.
    #[test]
    fn integers_nd_shape() {
        let grid = default_rng_seeded(42).integers(0, 100, [4, 3]).unwrap();
        assert_eq!(grid.shape(), &[4, 3]);
        grid.iter().for_each(|&v| assert!((0..100).contains(&v)));
    }

    /// Every `random_f32` sample must fall inside the half-open unit interval.
    #[test]
    fn random_f32_in_range() {
        let samples = default_rng_seeded(42).random_f32(10_000).unwrap();
        samples.as_slice().unwrap().iter().for_each(|&v| {
            assert!((0.0..1.0).contains(&v), "f32 value out of range: {v}");
        });
    }

    /// Two generators built from the same seed must emit the same `f32` stream.
    #[test]
    fn random_f32_deterministic() {
        let first = default_rng_seeded(42).random_f32(100).unwrap();
        let second = default_rng_seeded(42).random_f32(100).unwrap();
        assert_eq!(first.as_slice().unwrap(), second.as_slice().unwrap());
    }

    /// `random_f32` honours an N-d shape.
    #[test]
    fn random_f32_nd_shape() {
        let grid = default_rng_seeded(42).random_f32([3, 4]).unwrap();
        assert_eq!(grid.shape(), &[3, 4]);
    }

    /// The `f32` sample mean, accumulated in `f64`, must be close to 0.5.
    #[test]
    fn random_f32_mean() {
        let n = 100_000usize;
        let samples = default_rng_seeded(42).random_f32(n).unwrap();
        let sum: f64 = samples
            .as_slice()
            .unwrap()
            .iter()
            .map(|&v| f64::from(v))
            .sum();
        let mean = sum / n as f64;
        assert!(
            (mean - 0.5).abs() < 0.01,
            "f32 random mean {mean} too far from 0.5"
        );
    }

    /// `uniform_f32(5, 10)` samples must stay inside `[5, 10)`.
    #[test]
    fn uniform_f32_in_range() {
        let samples = default_rng_seeded(42)
            .uniform_f32(5.0, 10.0, 10_000)
            .unwrap();
        samples.as_slice().unwrap().iter().for_each(|&v| {
            assert!(
                (5.0..10.0).contains(&v),
                "f32 uniform value out of range: {v}"
            );
        });
    }

    /// Reversed and empty `f32` intervals are both rejected.
    #[test]
    fn uniform_f32_bad_range() {
        let mut rng = default_rng_seeded(42);
        for (low, high) in [(10.0, 5.0), (5.0, 5.0)] {
            assert!(rng.uniform_f32(low, high, 100).is_err());
        }
    }

    /// `uniform_f32` honours an N-d shape and its bounds together.
    #[test]
    fn uniform_f32_nd_shape() {
        let grid = default_rng_seeded(42)
            .uniform_f32(-1.0, 1.0, [2, 5])
            .unwrap();
        assert_eq!(grid.shape(), &[2, 5]);
        grid.iter()
            .for_each(|&v| assert!((-1.0..1.0).contains(&v)));
    }

    /// A zero-length axis yields an empty `f32` array that keeps its shape.
    #[test]
    fn random_f32_zero_axis_returns_empty() {
        let empty = default_rng_seeded(42).random_f32([3, 0]).unwrap();
        assert_eq!(empty.shape(), &[3, 0]);
        assert_eq!(empty.size(), 0);
    }
}