1use ferray_core::dimension::broadcast::broadcast_shapes;
4use ferray_core::{Array, FerrayError, IxDyn};
5
6use crate::bitgen::BitGenerator;
7use crate::generator::{
8 Generator, generate_vec, generate_vec_f32, generate_vec_i64, shape_size, vec_to_array_f32,
9 vec_to_array_f64, vec_to_array_i64,
10};
11use crate::shape::IntoShape;
12
13impl<B: BitGenerator> Generator<B> {
14 pub fn random(&mut self, size: impl IntoShape) -> Result<Array<f64, IxDyn>, FerrayError> {
32 let shape = size.into_shape()?;
33 let n = shape_size(&shape);
34 let data = generate_vec(self, n, super::super::bitgen::BitGenerator::next_f64);
35 vec_to_array_f64(data, &shape)
36 }
37
38 pub fn random_into(&mut self, out: &mut Array<f64, IxDyn>) -> Result<(), FerrayError> {
48 let slice = out.as_slice_mut().ok_or_else(|| {
49 FerrayError::invalid_value("random_into requires a contiguous out buffer")
50 })?;
51 for v in slice.iter_mut() {
52 *v = self.bg.next_f64();
53 }
54 Ok(())
55 }
56
57 pub fn uniform(
64 &mut self,
65 low: f64,
66 high: f64,
67 size: impl IntoShape,
68 ) -> Result<Array<f64, IxDyn>, FerrayError> {
69 if low >= high {
70 return Err(FerrayError::invalid_value(format!(
71 "low ({low}) must be less than high ({high})"
72 )));
73 }
74 let shape = size.into_shape()?;
75 let n = shape_size(&shape);
76 let range = high - low;
77 let data = generate_vec(self, n, |bg| bg.next_f64().mul_add(range, low));
78 vec_to_array_f64(data, &shape)
79 }
80
81 pub fn uniform_array(
93 &mut self,
94 low: &Array<f64, IxDyn>,
95 high: &Array<f64, IxDyn>,
96 ) -> Result<Array<f64, IxDyn>, FerrayError> {
97 let target = broadcast_shapes(low.shape(), high.shape())?;
98 let lo_v = low.broadcast_to(&target)?;
99 let hi_v = high.broadcast_to(&target)?;
100 let total: usize = target.iter().product();
101 let mut out: Vec<f64> = Vec::with_capacity(total);
102 for (&l, &h) in lo_v.iter().zip(hi_v.iter()) {
103 if l >= h {
104 return Err(FerrayError::invalid_value(format!(
105 "low ({l}) must be less than high ({h})"
106 )));
107 }
108 out.push(self.bg.next_f64().mul_add(h - l, l));
109 }
110 Array::<f64, IxDyn>::from_vec(IxDyn::new(&target), out)
111 }
112
113 pub fn random_f32(&mut self, size: impl IntoShape) -> Result<Array<f32, IxDyn>, FerrayError> {
131 let shape = size.into_shape()?;
132 let n = shape_size(&shape);
133 let data = generate_vec_f32(self, n, super::super::bitgen::BitGenerator::next_f32);
134 vec_to_array_f32(data, &shape)
135 }
136
137 pub fn uniform_f32(
144 &mut self,
145 low: f32,
146 high: f32,
147 size: impl IntoShape,
148 ) -> Result<Array<f32, IxDyn>, FerrayError> {
149 if low >= high {
150 return Err(FerrayError::invalid_value(format!(
151 "low ({low}) must be less than high ({high})"
152 )));
153 }
154 let shape = size.into_shape()?;
155 let n = shape_size(&shape);
156 let range = high - low;
157 let data = generate_vec_f32(self, n, |bg| bg.next_f32().mul_add(range, low));
158 vec_to_array_f32(data, &shape)
159 }
160
161 pub fn integers(
168 &mut self,
169 low: i64,
170 high: i64,
171 size: impl IntoShape,
172 ) -> Result<Array<i64, IxDyn>, FerrayError> {
173 if low >= high {
174 return Err(FerrayError::invalid_value(format!(
175 "low ({low}) must be less than high ({high})"
176 )));
177 }
178 let shape = size.into_shape()?;
179 let n = shape_size(&shape);
180 let range = (high - low) as u64;
181 let data = generate_vec_i64(self, n, |bg| low + bg.next_u64_bounded(range) as i64);
182 vec_to_array_i64(data, &shape)
183 }
184}
185
// Generates one dtype-specific bounded-integer sampler on `Generator`.
//
// `typed_integers!(name, ty, doc)` expands to an `impl` block that adds
// `Generator::name(low, high, size)`, which returns an `Array<ty, IxDyn>` of draws from
// `[low, high)`. All arithmetic is done in `i128`, which can represent every value and
// every span of the integer types instantiated below (up to `u64`). Neither the span
// nor `low + offset` can therefore overflow.
macro_rules! typed_integers {
    ($name:ident, $ty:ty, $doc:literal) => {
        impl<B: BitGenerator> Generator<B> {
            #[doc = $doc]
            pub fn $name(
                &mut self,
                low: $ty,
                high: $ty,
                size: impl IntoShape,
            ) -> Result<Array<$ty, IxDyn>, FerrayError> {
                // Empty or inverted interval: nothing can be sampled.
                if high <= low {
                    return Err(FerrayError::invalid_value(format!(
                        "low ({low}) must be less than high ({high})"
                    )));
                }
                let shape = size.into_shape()?;
                let count = shape_size(&shape);
                let base = i128::from(low);
                // Width of the interval; always in 1..=u64::MAX here.
                let span = (i128::from(high) - base) as u64;
                let values: Vec<$ty> = (0..count)
                    .map(|_| (base + self.bg.next_u64_bounded(span) as i128) as $ty)
                    .collect();
                Array::<$ty, IxDyn>::from_vec(IxDyn::new(&shape), values)
            }
        }
    };
}
229
230typed_integers!(
231 integers_u8,
232 u8,
233 "Generate u8 integers in [low, high), matching `numpy.random.Generator.integers(..., dtype=np.uint8)`."
234);
235typed_integers!(
236 integers_i8,
237 i8,
238 "Generate i8 integers in [low, high), matching `numpy.random.Generator.integers(..., dtype=np.int8)`."
239);
240typed_integers!(
241 integers_u16,
242 u16,
243 "Generate u16 integers in [low, high), matching `numpy.random.Generator.integers(..., dtype=np.uint16)`."
244);
245typed_integers!(
246 integers_i16,
247 i16,
248 "Generate i16 integers in [low, high), matching `numpy.random.Generator.integers(..., dtype=np.int16)`."
249);
250typed_integers!(
251 integers_u32,
252 u32,
253 "Generate u32 integers in [low, high), matching `numpy.random.Generator.integers(..., dtype=np.uint32)`."
254);
255typed_integers!(
256 integers_i32,
257 i32,
258 "Generate i32 integers in [low, high), matching `numpy.random.Generator.integers(..., dtype=np.int32)`."
259);
260typed_integers!(
261 integers_u64,
262 u64,
263 "Generate u64 integers in [low, high), matching `numpy.random.Generator.integers(..., dtype=np.uint64)`."
264);
265
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;
    use ferray_core::{Array, IxDyn};

    /// Build an owned `f64` array of the given shape, panicking if `data` does not fit it.
    fn f64_array(shape: &[usize], data: Vec<f64>) -> Array<f64, IxDyn> {
        Array::<f64, IxDyn>::from_vec(IxDyn::new(shape), data).unwrap()
    }

    // ---- dtype-specific integer samplers -------------------------------------------

    #[test]
    fn integers_u8_in_range() {
        let mut g = default_rng_seeded(42);
        let drawn = g.integers_u8(0, 200, 10_000).unwrap();
        drawn.as_slice().unwrap().iter().for_each(|&v| assert!(v < 200));
    }

    #[test]
    fn integers_i8_in_range_with_negatives() {
        let mut g = default_rng_seeded(42);
        let drawn = g.integers_i8(-50, 50, 10_000).unwrap();
        for v in drawn.as_slice().unwrap().iter().copied() {
            assert!((-50..50).contains(&v));
        }
    }

    #[test]
    fn integers_u16_in_range() {
        let mut g = default_rng_seeded(42);
        let drawn = g.integers_u16(1000, 5000, 5_000).unwrap();
        for v in drawn.as_slice().unwrap().iter().copied() {
            assert!((1000..5000).contains(&v));
        }
    }

    #[test]
    fn integers_i32_in_range_full_span() {
        let mut g = default_rng_seeded(42);
        let drawn = g.integers_i32(i32::MIN, i32::MAX, 1_000).unwrap();
        drawn.as_slice().unwrap().iter().for_each(|&v| assert!(v < i32::MAX));
    }

    #[test]
    fn integers_u64_full_range() {
        let mut g = default_rng_seeded(42);
        let drawn = g.integers_u64(0, u64::MAX, 100).unwrap();
        assert_eq!(drawn.shape(), &[100]);
    }

    #[test]
    fn integers_typed_low_ge_high_errors() {
        let mut g = default_rng_seeded(0);
        assert!(g.integers_u8(10, 5, 5).is_err());
        assert!(g.integers_i16(0, 0, 5).is_err());
        assert!(g.integers_u32(7, 7, 5).is_err());
    }

    // ---- random_into -----------------------------------------------------------------

    #[test]
    fn random_into_fills_buffer_in_place() {
        let mut g = default_rng_seeded(42);
        let mut buf = f64_array(&[8], vec![-1.0; 8]);
        g.random_into(&mut buf).unwrap();
        for v in buf.as_slice().unwrap().iter().copied() {
            assert!((0.0..1.0).contains(&v));
        }
    }

    #[test]
    fn random_into_matches_random_for_same_seed() {
        let (mut fresh, mut in_place) = (default_rng_seeded(7), default_rng_seeded(7));
        let allocated = fresh.random([3, 4]).unwrap();
        let mut buf = f64_array(&[3, 4], vec![0.0; 12]);
        in_place.random_into(&mut buf).unwrap();
        assert_eq!(allocated.as_slice().unwrap(), buf.as_slice().unwrap());
    }

    // ---- uniform_array ---------------------------------------------------------------

    #[test]
    fn uniform_array_per_element_bounds() {
        let mut g = default_rng_seeded(42);
        let low = f64_array(&[3], vec![0.0, 100.0, -10.0]);
        let high = f64_array(&[3], vec![1.0, 200.0, 0.0]);
        let out = g.uniform_array(&low, &high).unwrap();
        let samples = out.as_slice().unwrap();
        assert!((0.0..1.0).contains(&samples[0]));
        assert!((100.0..200.0).contains(&samples[1]));
        assert!((-10.0..0.0).contains(&samples[2]));
    }

    #[test]
    fn uniform_array_broadcast() {
        let mut g = default_rng_seeded(42);
        let highs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let low = f64_array(&[1], vec![0.0]);
        let high = f64_array(&[2, 3], highs.to_vec());
        let out = g.uniform_array(&low, &high).unwrap();
        assert_eq!(out.shape(), &[2, 3]);
        // The scalar low of 0.0 broadcasts against every per-element upper bound.
        for (i, &v) in out.as_slice().unwrap().iter().enumerate() {
            assert!(v >= 0.0 && v < highs[i]);
        }
    }

    #[test]
    fn uniform_array_low_ge_high_errors() {
        let mut g = default_rng_seeded(0);
        let low = f64_array(&[2], vec![0.0, 5.0]);
        let high = f64_array(&[2], vec![1.0, 5.0]);
        assert!(g.uniform_array(&low, &high).is_err());
    }

    // ---- scalar-bound samplers -------------------------------------------------------

    #[test]
    fn random_in_range() {
        let mut g = default_rng_seeded(42);
        let drawn = g.random(10_000).unwrap();
        for v in drawn.as_slice().unwrap().iter().copied() {
            assert!((0.0..1.0).contains(&v));
        }
    }

    #[test]
    fn random_deterministic() {
        let a = default_rng_seeded(42).random(100).unwrap();
        let b = default_rng_seeded(42).random(100).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn uniform_in_range() {
        let mut g = default_rng_seeded(42);
        let drawn = g.uniform(5.0, 10.0, 10_000).unwrap();
        for v in drawn.as_slice().unwrap().iter().copied() {
            assert!((5.0..10.0).contains(&v), "value {v} out of range");
        }
    }

    #[test]
    fn uniform_bad_range() {
        let mut g = default_rng_seeded(42);
        assert!(g.uniform(10.0, 5.0, 100).is_err());
        assert!(g.uniform(5.0, 5.0, 100).is_err());
    }

    #[test]
    fn integers_in_range() {
        let mut g = default_rng_seeded(42);
        let drawn = g.integers(0, 10, 10_000).unwrap();
        for v in drawn.as_slice().unwrap().iter().copied() {
            assert!((0..10).contains(&v), "value {v} out of range");
        }
    }

    #[test]
    fn integers_negative_range() {
        let mut g = default_rng_seeded(42);
        let drawn = g.integers(-5, 5, 1000).unwrap();
        for v in drawn.as_slice().unwrap().iter().copied() {
            assert!((-5..5).contains(&v), "value {v} out of range");
        }
    }

    #[test]
    fn integers_bad_range() {
        let mut g = default_rng_seeded(42);
        assert!(g.integers(10, 5, 100).is_err());
    }

    #[test]
    fn uniform_mean_variance() {
        let mut g = default_rng_seeded(42);
        let n = 100_000;
        let arr = g.uniform(2.0, 8.0, n).unwrap();
        let samples = arr.as_slice().unwrap();
        let count = n as f64;
        let mean: f64 = samples.iter().sum::<f64>() / count;
        let var: f64 = samples.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / count;
        // U(2, 8): mean = (2 + 8) / 2 = 5, variance = (8 - 2)^2 / 12 = 3.
        let (expected_mean, expected_var) = (5.0, 3.0);
        let se_mean = (expected_var / count).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se_mean,
            "mean {mean} too far from {expected_mean}"
        );
        assert!(
            (var - expected_var).abs() < 0.1,
            "variance {var} too far from {expected_var}"
        );
    }

    /// Bit-exact reproducibility check. The reference values are not hard-coded
    /// constants: the test replays seed 42 and compares raw bit patterns.
    #[test]
    fn reproducibility_golden_values() {
        let first = default_rng_seeded(42).random(5).unwrap();
        let golden = first.as_slice().unwrap();
        let second = default_rng_seeded(42).random(5).unwrap();
        let replay = second.as_slice().unwrap();
        for i in 0..5 {
            assert_eq!(
                replay[i].to_bits(),
                golden[i].to_bits(),
                "golden value mismatch at index {i}"
            );
        }
    }

    #[test]
    fn different_seeds_different_values() {
        let a = default_rng_seeded(42).random(100).unwrap();
        let b = default_rng_seeded(123).random(100).unwrap();
        let (lhs, rhs) = (a.as_slice().unwrap(), b.as_slice().unwrap());
        let diffs = lhs.iter().zip(rhs).filter(|&(x, y)| x != y).count();
        assert!(diffs > 50, "seeds 42 and 123 produced too-similar output");
    }

    // ---- N-dimensional shapes --------------------------------------------------------

    #[test]
    fn random_nd_shape_from_array() {
        let mut g = default_rng_seeded(42);
        let drawn = g.random([3, 4]).unwrap();
        assert_eq!(drawn.shape(), &[3, 4]);
        assert_eq!(drawn.size(), 12);
    }

    #[test]
    fn random_nd_shape_from_slice() {
        let mut g = default_rng_seeded(42);
        let dims: &[usize] = &[2, 3, 4];
        let drawn = g.random(dims).unwrap();
        assert_eq!(drawn.shape(), &[2, 3, 4]);
        assert_eq!(drawn.size(), 24);
    }

    #[test]
    fn random_nd_shape_from_vec() {
        let mut g = default_rng_seeded(42);
        let dims = vec![5, 5];
        let drawn = g.random(dims).unwrap();
        assert_eq!(drawn.shape(), &[5, 5]);
    }

    #[test]
    fn random_nd_zero_axis_returns_empty() {
        let mut g = default_rng_seeded(42);
        let two_d = g.random([3, 0]).unwrap();
        assert_eq!(two_d.shape(), &[3, 0]);
        assert_eq!(two_d.size(), 0);
        let one_d = g.random(0usize).unwrap();
        assert_eq!(one_d.shape(), &[0]);
        assert_eq!(one_d.size(), 0);
    }

    #[test]
    fn random_nd_equivalent_to_reshape() {
        let flat = default_rng_seeded(42).random(12).unwrap();
        let grid = default_rng_seeded(42).random([3, 4]).unwrap();
        assert_eq!(flat.size(), grid.size());
        assert_eq!(
            flat.iter().copied().collect::<Vec<f64>>(),
            grid.iter().copied().collect::<Vec<f64>>()
        );
    }

    #[test]
    fn uniform_nd_shape() {
        let mut g = default_rng_seeded(42);
        let drawn = g.uniform(0.0, 10.0, [2, 5]).unwrap();
        assert_eq!(drawn.shape(), &[2, 5]);
        for &v in drawn.iter() {
            assert!((0.0..10.0).contains(&v));
        }
    }

    #[test]
    fn integers_nd_shape() {
        let mut g = default_rng_seeded(42);
        let drawn = g.integers(0, 100, [4, 3]).unwrap();
        assert_eq!(drawn.shape(), &[4, 3]);
        for &v in drawn.iter() {
            assert!((0..100).contains(&v));
        }
    }

    // ---- single precision ------------------------------------------------------------

    #[test]
    fn random_f32_in_range() {
        let mut g = default_rng_seeded(42);
        let drawn = g.random_f32(10_000).unwrap();
        for v in drawn.as_slice().unwrap().iter().copied() {
            assert!((0.0..1.0).contains(&v), "f32 value out of range: {v}");
        }
    }

    #[test]
    fn random_f32_deterministic() {
        let a = default_rng_seeded(42).random_f32(100).unwrap();
        let b = default_rng_seeded(42).random_f32(100).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn random_f32_nd_shape() {
        let mut g = default_rng_seeded(42);
        let drawn = g.random_f32([3, 4]).unwrap();
        assert_eq!(drawn.shape(), &[3, 4]);
    }

    #[test]
    fn random_f32_mean() {
        let mut g = default_rng_seeded(42);
        let n = 100_000usize;
        let drawn = g.random_f32(n).unwrap();
        // Accumulate in f64 so summation error does not dominate the tolerance.
        let total: f64 = drawn
            .as_slice()
            .unwrap()
            .iter()
            .map(|&v| f64::from(v))
            .sum();
        let mean = total / n as f64;
        assert!(
            (mean - 0.5).abs() < 0.01,
            "f32 random mean {mean} too far from 0.5"
        );
    }

    #[test]
    fn uniform_f32_in_range() {
        let mut g = default_rng_seeded(42);
        let drawn = g.uniform_f32(5.0, 10.0, 10_000).unwrap();
        for v in drawn.as_slice().unwrap().iter().copied() {
            assert!(
                (5.0..10.0).contains(&v),
                "f32 uniform value out of range: {v}"
            );
        }
    }

    #[test]
    fn uniform_f32_bad_range() {
        let mut g = default_rng_seeded(42);
        assert!(g.uniform_f32(10.0, 5.0, 100).is_err());
        assert!(g.uniform_f32(5.0, 5.0, 100).is_err());
    }

    #[test]
    fn uniform_f32_nd_shape() {
        let mut g = default_rng_seeded(42);
        let drawn = g.uniform_f32(-1.0, 1.0, [2, 5]).unwrap();
        assert_eq!(drawn.shape(), &[2, 5]);
        for &v in drawn.iter() {
            assert!((-1.0..1.0).contains(&v));
        }
    }

    #[test]
    fn random_f32_zero_axis_returns_empty() {
        let mut g = default_rng_seeded(42);
        let drawn = g.random_f32([3, 0]).unwrap();
        assert_eq!(drawn.shape(), &[3, 0]);
        assert_eq!(drawn.size(), 0);
    }
}