1use ferray_core::{Array, FerrayError, Ix1, IxDyn};
4
5use crate::bitgen::BitGenerator;
6use crate::generator::Generator;
7
8impl<B: BitGenerator> Generator<B> {
9 pub fn shuffle<T>(&mut self, arr: &mut Array<T, Ix1>) -> Result<(), FerrayError>
14 where
15 T: ferray_core::Element,
16 {
17 let n = arr.shape()[0];
18 if n <= 1 {
19 return Ok(());
20 }
21 let slice = arr
22 .as_slice_mut()
23 .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for shuffle"))?;
24 for i in (1..n).rev() {
26 let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
27 slice.swap(i, j);
28 }
29 Ok(())
30 }
31
32 pub fn permutation<T>(&mut self, arr: &Array<T, Ix1>) -> Result<Array<T, Ix1>, FerrayError>
40 where
41 T: ferray_core::Element,
42 {
43 let mut copy = arr.clone();
44 self.shuffle(&mut copy)?;
45 Ok(copy)
46 }
47
48 pub fn permutation_range(&mut self, n: usize) -> Result<Array<i64, Ix1>, FerrayError> {
53 if n == 0 {
54 return Err(FerrayError::invalid_value("n must be > 0"));
55 }
56 let mut data: Vec<i64> = (0..n as i64).collect();
57 for i in (1..n).rev() {
59 let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
60 data.swap(i, j);
61 }
62 Array::<i64, Ix1>::from_vec(Ix1::new([n]), data)
63 }
64
65 pub fn permuted<T>(
73 &mut self,
74 arr: &Array<T, Ix1>,
75 _axis: usize,
76 ) -> Result<Array<T, Ix1>, FerrayError>
77 where
78 T: ferray_core::Element,
79 {
80 self.permutation(arr)
81 }
82
83 pub fn shuffle_dyn<T>(
95 &mut self,
96 arr: &mut Array<T, IxDyn>,
97 axis: usize,
98 ) -> Result<(), FerrayError>
99 where
100 T: ferray_core::Element,
101 {
102 let shape = arr.shape().to_vec();
103 let ndim = shape.len();
104 if axis >= ndim {
105 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
106 }
107 let n = shape[axis];
108 if n <= 1 {
109 return Ok(());
110 }
111 let inner_stride: usize = shape[axis + 1..].iter().product();
112 let block = n * inner_stride;
113 let outer_size: usize = shape[..axis].iter().product();
114 let slice = arr
115 .as_slice_mut()
116 .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for shuffle"))?;
117 for i in (1..n).rev() {
118 let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
119 if i == j {
120 continue;
121 }
122 for o in 0..outer_size {
123 let base = o * block;
124 for k in 0..inner_stride {
125 slice.swap(base + i * inner_stride + k, base + j * inner_stride + k);
126 }
127 }
128 }
129 Ok(())
130 }
131
132 pub fn choice_dyn<T>(
152 &mut self,
153 arr: &Array<T, IxDyn>,
154 size: usize,
155 replace: bool,
156 p: Option<&[f64]>,
157 axis: usize,
158 shuffle: bool,
159 ) -> Result<Array<T, IxDyn>, FerrayError>
160 where
161 T: ferray_core::Element,
162 {
163 let shape = arr.shape().to_vec();
164 let ndim = shape.len();
165 if axis >= ndim {
166 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
167 }
168 let axis_len = shape[axis];
169 if size == 0 {
170 let mut out_shape = shape;
172 out_shape[axis] = 0;
173 return Array::<T, IxDyn>::from_vec(IxDyn::new(&out_shape), Vec::new());
174 }
175 if axis_len == 0 {
176 return Err(FerrayError::invalid_value(
177 "choice_dyn: source array has zero length along axis",
178 ));
179 }
180 if !replace && size > axis_len {
181 return Err(FerrayError::invalid_value(format!(
182 "cannot choose {size} elements without replacement from axis of size {axis_len}"
183 )));
184 }
185 if let Some(probs) = p {
186 if probs.len() != axis_len {
187 return Err(FerrayError::invalid_value(format!(
188 "p must have length {axis_len} (size of axis {axis}), got {}",
189 probs.len()
190 )));
191 }
192 let psum: f64 = probs.iter().sum();
193 if (psum - 1.0).abs() > 1e-6 {
194 return Err(FerrayError::invalid_value(format!(
195 "p must sum to 1.0, got {psum}"
196 )));
197 }
198 for (i, &pi) in probs.iter().enumerate() {
199 if pi < 0.0 {
200 return Err(FerrayError::invalid_value(format!(
201 "p[{i}] = {pi} is negative"
202 )));
203 }
204 }
205 }
206
207 let src = arr
208 .as_slice()
209 .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for choice_dyn"))?;
210
211 let mut indices = if let Some(probs) = p {
212 if replace {
213 weighted_sample_with_replacement(&mut self.bg, probs, size)
214 } else {
215 weighted_sample_without_replacement(&mut self.bg, probs, size)?
216 }
217 } else if replace {
218 (0..size)
219 .map(|_| self.bg.next_u64_bounded(axis_len as u64) as usize)
220 .collect()
221 } else {
222 sample_without_replacement(&mut self.bg, axis_len, size)
223 };
224 if !shuffle && !replace {
225 indices.sort_unstable();
226 }
227
228 let inner_stride: usize = shape[axis + 1..].iter().product();
229 let outer_size: usize = shape[..axis].iter().product();
230 let src_block = axis_len * inner_stride;
231 let out_block = size * inner_stride;
232 let total_out = outer_size * out_block;
233
234 let mut out_data: Vec<T> = Vec::with_capacity(total_out);
235 let filler = src[0].clone();
240 out_data.resize(total_out, filler);
241 for o in 0..outer_size {
242 let src_base = o * src_block;
243 let out_base = o * out_block;
244 for (i, &idx) in indices.iter().enumerate() {
245 let src_off = src_base + idx * inner_stride;
246 let out_off = out_base + i * inner_stride;
247 out_data[out_off..out_off + inner_stride]
248 .clone_from_slice(&src[src_off..src_off + inner_stride]);
249 }
250 }
251
252 let mut out_shape = shape;
253 out_shape[axis] = size;
254 Array::<T, IxDyn>::from_vec(IxDyn::new(&out_shape), out_data)
255 }
256
257 pub fn permuted_dyn<T>(
269 &mut self,
270 arr: &Array<T, IxDyn>,
271 axis: usize,
272 ) -> Result<Array<T, IxDyn>, FerrayError>
273 where
274 T: ferray_core::Element,
275 {
276 let shape = arr.shape().to_vec();
277 let ndim = shape.len();
278 if axis >= ndim {
279 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
280 }
281 let mut out = arr.clone();
282 let n = shape[axis];
283 if n <= 1 {
284 return Ok(out);
285 }
286 let inner_stride: usize = shape[axis + 1..].iter().product();
287 let block = n * inner_stride;
288 let outer_size: usize = shape[..axis].iter().product();
289 let slice = out
290 .as_slice_mut()
291 .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for permuted"))?;
292 for o in 0..outer_size {
293 let base = o * block;
294 for k in 0..inner_stride {
295 for i in (1..n).rev() {
298 let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
299 slice.swap(base + i * inner_stride + k, base + j * inner_stride + k);
300 }
301 }
302 }
303 Ok(out)
304 }
305
306 pub fn choice<T>(
318 &mut self,
319 arr: &Array<T, Ix1>,
320 size: usize,
321 replace: bool,
322 p: Option<&[f64]>,
323 ) -> Result<Array<T, Ix1>, FerrayError>
324 where
325 T: ferray_core::Element,
326 {
327 let n = arr.shape()[0];
328 if size == 0 {
332 return Array::from_vec(Ix1::new([0]), Vec::new());
333 }
334 if n == 0 {
335 return Err(FerrayError::invalid_value("source array must be non-empty"));
336 }
337 if !replace && size > n {
338 return Err(FerrayError::invalid_value(format!(
339 "cannot choose {size} elements without replacement from array of size {n}"
340 )));
341 }
342
343 if let Some(probs) = p {
344 if probs.len() != n {
345 return Err(FerrayError::invalid_value(format!(
346 "p must have same length as array ({n}), got {}",
347 probs.len()
348 )));
349 }
350 let psum: f64 = probs.iter().sum();
351 if (psum - 1.0).abs() > 1e-6 {
352 return Err(FerrayError::invalid_value(format!(
353 "p must sum to 1.0, got {psum}"
354 )));
355 }
356 for (i, &pi) in probs.iter().enumerate() {
357 if pi < 0.0 {
358 return Err(FerrayError::invalid_value(format!(
359 "p[{i}] = {pi} is negative"
360 )));
361 }
362 }
363 }
364
365 let src = arr
366 .as_slice()
367 .ok_or_else(|| FerrayError::invalid_value("array must be contiguous"))?;
368
369 let indices = if let Some(probs) = p {
370 if replace {
372 weighted_sample_with_replacement(&mut self.bg, probs, size)
373 } else {
374 weighted_sample_without_replacement(&mut self.bg, probs, size)?
375 }
376 } else if replace {
377 (0..size)
379 .map(|_| self.bg.next_u64_bounded(n as u64) as usize)
380 .collect()
381 } else {
382 sample_without_replacement(&mut self.bg, n, size)
384 };
385
386 let data: Vec<T> = indices.iter().map(|&i| src[i].clone()).collect();
387 Array::<T, Ix1>::from_vec(Ix1::new([size]), data)
388 }
389}
390
391fn sample_without_replacement<B: BitGenerator>(bg: &mut B, n: usize, size: usize) -> Vec<usize> {
393 let mut pool: Vec<usize> = (0..n).collect();
394 for i in 0..size {
395 let j = i + bg.next_u64_bounded((n - i) as u64) as usize;
396 pool.swap(i, j);
397 }
398 pool[..size].to_vec()
399}
400
401fn weighted_sample_with_replacement<B: BitGenerator>(
413 bg: &mut B,
414 probs: &[f64],
415 size: usize,
416) -> Vec<usize> {
417 let n = probs.len();
418
419 let total: f64 = probs.iter().sum();
423 let mut scaled: Vec<f64> = probs.iter().map(|&p| p * n as f64 / total).collect();
424
425 let mut prob = vec![0.0_f64; n];
426 let mut alias = vec![0_usize; n];
427
428 let mut small: Vec<usize> = Vec::with_capacity(n);
430 let mut large: Vec<usize> = Vec::with_capacity(n);
431 for (i, &m) in scaled.iter().enumerate() {
432 if m < 1.0 {
433 small.push(i);
434 } else {
435 large.push(i);
436 }
437 }
438
439 while !small.is_empty() && !large.is_empty() {
440 let s = small.pop().unwrap();
441 let l = large.pop().unwrap();
442 prob[s] = scaled[s];
443 alias[s] = l;
444 scaled[l] = (scaled[l] + scaled[s]) - 1.0;
446 if scaled[l] < 1.0 {
447 small.push(l);
448 } else {
449 large.push(l);
450 }
451 }
452 for &i in large.iter().chain(small.iter()) {
456 prob[i] = 1.0;
457 }
458
459 (0..size)
460 .map(|_| {
461 let i = bg.next_u64_bounded(n as u64) as usize;
462 let u = bg.next_f64();
463 if u < prob[i] { i } else { alias[i] }
464 })
465 .collect()
466}
467
468fn weighted_sample_without_replacement<B: BitGenerator>(
470 bg: &mut B,
471 probs: &[f64],
472 size: usize,
473) -> Result<Vec<usize>, FerrayError> {
474 let n = probs.len();
475 let mut weights: Vec<f64> = probs.to_vec();
476 let mut selected = Vec::with_capacity(size);
477
478 for _ in 0..size {
479 let total: f64 = weights.iter().sum();
480 if total <= 0.0 {
481 return Err(FerrayError::invalid_value(
482 "insufficient probability mass for sampling without replacement",
483 ));
484 }
485 let u = bg.next_f64() * total;
486 let mut cumsum = 0.0;
487 let mut chosen = n - 1;
488 for (i, &w) in weights.iter().enumerate() {
489 cumsum += w;
490 if cumsum > u {
491 chosen = i;
492 break;
493 }
494 }
495 selected.push(chosen);
496 weights[chosen] = 0.0;
497 }
498
499 Ok(selected)
500}
501
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;
    use ferray_core::{Array, Ix1};

    /// `shuffle` only reorders elements; the multiset of values is unchanged.
    #[test]
    fn shuffle_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let mut arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        rng.shuffle(&mut arr).unwrap();
        let mut sorted: Vec<i64> = arr.as_slice().unwrap().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![1, 2, 3, 4, 5]);
    }

    /// `permutation` returns a reordering of the input's values.
    #[test]
    fn permutation_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let perm = rng.permutation(&arr).unwrap();
        let mut sorted: Vec<i64> = perm.as_slice().unwrap().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![10, 20, 30, 40, 50]);
    }

    /// `permutation_range(n)` contains every value of `0..n` exactly once.
    #[test]
    fn permutation_range_covers_all() {
        let mut rng = default_rng_seeded(42);
        let perm = rng.permutation_range(10).unwrap();
        let mut sorted: Vec<i64> = perm.as_slice().unwrap().to_vec();
        sorted.sort_unstable();
        let expected: Vec<i64> = (0..10).collect();
        assert_eq!(sorted, expected);
    }

    /// After an in-place shuffle, the caller's array still holds the original values.
    #[test]
    fn shuffle_modifies_in_place() {
        let mut rng = default_rng_seeded(42);
        let original = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut arr = Array::<i64, Ix1>::from_vec(Ix1::new([10]), original.clone()).unwrap();
        rng.shuffle(&mut arr).unwrap();
        let mut sorted = arr.as_slice().unwrap().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, original);
    }

    /// Sampling with replacement yields the requested length, drawn from the source values.
    #[test]
    fn choice_with_replacement() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let chosen = rng.choice(&arr, 10, true, None).unwrap();
        assert_eq!(chosen.shape(), &[10]);
        let src = [10i64, 20, 30, 40, 50];
        for &v in chosen.as_slice().unwrap() {
            assert!(src.contains(&v), "choice returned unexpected value {v}");
        }
    }

    /// Sampling without replacement never repeats a value.
    #[test]
    fn choice_without_replacement_no_duplicates() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([10]), (0..10).collect()).unwrap();
        let chosen = rng.choice(&arr, 5, false, None).unwrap();
        let slice = chosen.as_slice().unwrap();
        let mut seen = std::collections::HashSet::new();
        for &v in slice {
            assert!(
                seen.insert(v),
                "duplicate value {v} in choice without replacement"
            );
        }
    }

    /// Asking for more distinct samples than elements is an error.
    #[test]
    fn choice_without_replacement_too_many() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        assert!(rng.choice(&arr, 10, false, None).is_err());
    }

    /// All probability mass on one element means only that element is returned.
    #[test]
    fn choice_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let p = [0.0, 0.0, 1.0];
        let chosen = rng.choice(&arr, 10, true, Some(&p)).unwrap();
        for &v in chosen.as_slice().unwrap() {
            assert_eq!(v, 30);
        }
    }

    /// Weighted sampling without replacement still never repeats a value.
    #[test]
    fn choice_without_replacement_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let p = [0.1, 0.2, 0.3, 0.2, 0.2];
        let chosen = rng.choice(&arr, 3, false, Some(&p)).unwrap();
        let slice = chosen.as_slice().unwrap();
        let mut seen = std::collections::HashSet::new();
        for &v in slice {
            assert!(seen.insert(v), "duplicate value {v}");
        }
    }

    /// Invalid weight vectors are rejected.
    #[test]
    fn choice_bad_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        // Wrong length.
        assert!(rng.choice(&arr, 1, true, Some(&[0.5, 0.5])).is_err());
        // Does not sum to 1.
        assert!(rng.choice(&arr, 1, true, Some(&[0.5, 0.5, 0.5])).is_err());
        // Sums to 1 but contains a negative entry.
        assert!(rng.choice(&arr, 1, true, Some(&[-0.1, 0.6, 0.5])).is_err());
    }

    /// 1-D `permuted` along axis 0 is a reordering of the input.
    #[test]
    fn permuted_1d() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let result = rng.permuted(&arr, 0).unwrap();
        let mut sorted: Vec<i64> = result.as_slice().unwrap().to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![1, 2, 3, 4, 5]);
    }

    /// Empirical frequencies from the alias-table sampler track the requested `p`.
    #[test]
    fn weighted_with_replacement_alias_distribution_recovers_probs() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 2, 3, 4]).unwrap();
        let p = [0.05, 0.15, 0.30, 0.40, 0.10];
        let n = 100_000;
        let chosen = rng.choice(&arr, n, true, Some(&p)).unwrap();
        let mut counts = [0_usize; 5];
        for &v in chosen.as_slice().unwrap() {
            counts[v as usize] += 1;
        }
        for (i, &c) in counts.iter().enumerate() {
            let observed = c as f64 / n as f64;
            assert!(
                (observed - p[i]).abs() < 0.015,
                "bin {i}: observed {observed}, expected {}",
                p[i]
            );
        }
    }

    /// `choice_dyn` along axis 0 copies whole rows intact.
    #[test]
    fn choice_dyn_axis0_picks_whole_rows() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(42);
        let data: Vec<i64> = (0..5)
            .flat_map(|i| (0..3).map(move |j| i * 100 + j))
            .collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[5, 3]), data).unwrap();
        let chosen = rng.choice_dyn(&arr, 4, true, None, 0, true).unwrap();
        assert_eq!(chosen.shape(), &[4, 3]);
        let slice = chosen.as_slice().unwrap();
        for row in 0..4 {
            let v0 = slice[row * 3];
            let id = v0 / 100;
            assert!((0..5).contains(&id));
            assert_eq!(slice[row * 3 + 1], id * 100 + 1);
            assert_eq!(slice[row * 3 + 2], id * 100 + 2);
        }
    }

    /// `choice_dyn` along axis 1 copies whole columns intact.
    #[test]
    fn choice_dyn_axis1_picks_whole_columns() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(7);
        let data: Vec<i64> = (0..3)
            .flat_map(|i| (0..6).map(move |j| i * 10 + j))
            .collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 6]), data).unwrap();
        let chosen = rng.choice_dyn(&arr, 2, false, None, 1, true).unwrap();
        assert_eq!(chosen.shape(), &[3, 2]);
        let slice = chosen.as_slice().unwrap();
        for col in 0..2 {
            let v0 = slice[col];
            let v1 = slice[2 + col];
            let v2 = slice[4 + col];
            assert!((0..6).contains(&v0));
            assert_eq!(v1, v0 + 10);
            assert_eq!(v2, v0 + 20);
        }
    }

    /// Without replacement, `choice_dyn` never picks the same row twice.
    #[test]
    fn choice_dyn_without_replacement_no_duplicate_rows() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(1);
        let data: Vec<i64> = (0..10)
            .flat_map(|i| (0..2).map(move |j| i * 100 + j))
            .collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[10, 2]), data).unwrap();
        let chosen = rng.choice_dyn(&arr, 5, false, None, 0, true).unwrap();
        let slice = chosen.as_slice().unwrap();
        let mut ids = std::collections::HashSet::new();
        for row in 0..5 {
            let id = slice[row * 2] / 100;
            assert!(ids.insert(id), "row id {id} repeated under replace=false");
        }
    }

    /// With `shuffle = false` (and no replacement), selected rows come out in index order.
    #[test]
    fn choice_dyn_shuffle_false_returns_sorted_indices() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(3);
        let data: Vec<i64> = (0..12)
            .flat_map(|i| (0..2).map(move |j| if j == 0 { i as i64 } else { i as i64 * 10 }))
            .collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[12, 2]), data).unwrap();
        let chosen = rng.choice_dyn(&arr, 6, false, None, 0, false).unwrap();
        let slice = chosen.as_slice().unwrap();
        let mut last = -1i64;
        for row in 0..6 {
            let id = slice[row * 2];
            assert!(
                id > last,
                "shuffle=false output not ascending: {id} after {last}"
            );
            last = id;
        }
    }

    /// All weight on row 2 means every sampled row is row 2.
    #[test]
    fn choice_dyn_weighted_concentrates_on_high_p() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(0);
        let data: Vec<i64> = (0..4)
            .flat_map(|i| (0..2).map(move |j| i * 100 + j))
            .collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[4, 2]), data).unwrap();
        let p = [0.0, 0.0, 1.0, 0.0];
        let chosen = rng.choice_dyn(&arr, 20, true, Some(&p), 0, true).unwrap();
        let slice = chosen.as_slice().unwrap();
        for row in 0..20 {
            assert_eq!(slice[row * 2], 200, "weighted choice strayed from row 2");
        }
    }

    /// `size = 0` yields an array whose chosen axis has length 0.
    #[test]
    fn choice_dyn_size_zero_returns_empty_axis() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(11);
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 4]), (0..12).collect()).unwrap();
        let chosen = rng.choice_dyn(&arr, 0, true, None, 0, true).unwrap();
        assert_eq!(chosen.shape(), &[0, 4]);
    }

    /// An axis beyond the array's rank is rejected.
    #[test]
    fn choice_dyn_bad_axis() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(0);
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), (0..6).collect()).unwrap();
        assert!(rng.choice_dyn(&arr, 1, true, None, 5, true).is_err());
    }

    /// Requesting more distinct rows than exist is rejected.
    #[test]
    fn choice_dyn_too_many_no_replace_errors() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(0);
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 2]), (0..6).collect()).unwrap();
        assert!(rng.choice_dyn(&arr, 5, false, None, 0, true).is_err());
    }

    /// `shuffle_dyn` along axis 0 moves whole rows: each row stays internally
    /// consistent and every original row appears exactly once.
    #[test]
    fn shuffle_dyn_axis0_swaps_whole_rows() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(42);
        let data: Vec<i64> = (0..4)
            .flat_map(|i| (0..3).map(move |j| i * 10 + j))
            .collect();
        let mut arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[4, 3]), data).unwrap();
        rng.shuffle_dyn(&mut arr, 0).unwrap();
        let slice = arr.as_slice().unwrap();
        let mut seen = std::collections::HashSet::new();
        for row in 0..4 {
            let row_first = slice[row * 3];
            let id = row_first / 10;
            assert!(
                (0..4).contains(&id),
                "row {row} starts with unexpected value {row_first}"
            );
            assert_eq!(slice[row * 3 + 1], id * 10 + 1);
            assert_eq!(slice[row * 3 + 2], id * 10 + 2);
            assert!(
                seen.insert(id),
                "row id {id} duplicated — shuffle broke a row"
            );
        }
    }

    /// `shuffle_dyn` along axis 1 moves whole columns and keeps each exactly once.
    #[test]
    fn shuffle_dyn_axis1_swaps_whole_columns() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(7);
        let data: Vec<i64> = (0..3)
            .flat_map(|i| (0..4).map(move |j| i * 10 + j))
            .collect();
        let mut arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 4]), data).unwrap();
        rng.shuffle_dyn(&mut arr, 1).unwrap();
        let slice = arr.as_slice().unwrap();
        let mut col_ids = Vec::new();
        for col in 0..4 {
            let v0 = slice[col];
            let v1 = slice[4 + col];
            let v2 = slice[8 + col];
            assert!((0..4).contains(&v0));
            assert_eq!(v1, v0 + 10);
            assert_eq!(v2, v0 + 20);
            col_ids.push(v0);
        }
        col_ids.sort_unstable();
        assert_eq!(col_ids, vec![0, 1, 2, 3]);
    }

    /// An axis beyond the array's rank is rejected by `shuffle_dyn`.
    #[test]
    fn shuffle_dyn_axis_out_of_bounds() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(0);
        let mut arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0; 6]).unwrap();
        assert!(rng.shuffle_dyn(&mut arr, 2).is_err());
    }

    /// `permuted_dyn` along axis 0 keeps each column's values within that column.
    #[test]
    fn permuted_dyn_axis0_each_column_independent() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(99);
        let n_rows = 5;
        let n_cols = 4;
        let data: Vec<i64> = (0..n_rows * n_cols).map(|x| x as i64).collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[n_rows, n_cols]), data).unwrap();
        let result = rng.permuted_dyn(&arr, 0).unwrap();
        let slice = result.as_slice().unwrap();
        for col in 0..n_cols {
            let mut want: Vec<i64> = (0..n_rows).map(|r| (r * n_cols + col) as i64).collect();
            want.sort_unstable();
            let mut got_col: Vec<i64> = (0..n_rows).map(|r| slice[r * n_cols + col]).collect();
            got_col.sort_unstable();
            assert_eq!(got_col, want, "col {col} lost values during permute");
        }
    }

    /// Lanes are permuted independently, so at least one column ends up ordered
    /// differently from column 0.
    #[test]
    fn permuted_dyn_columns_can_diverge() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(1234);
        let n_rows = 5;
        let n_cols = 4;
        let data: Vec<i64> = (0..n_rows * n_cols)
            .map(|x| x as i64 % n_rows as i64)
            .collect();
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[n_rows, n_cols]), data).unwrap();
        let result = rng.permuted_dyn(&arr, 0).unwrap();
        let slice = result.as_slice().unwrap();
        let col0: Vec<i64> = (0..n_rows).map(|r| slice[r * n_cols]).collect();
        let any_diff = (1..n_cols).any(|col| {
            let coln: Vec<i64> = (0..n_rows).map(|r| slice[r * n_cols + col]).collect();
            col0 != coln
        });
        assert!(
            any_diff,
            "all columns matched — permuted didn't independently shuffle"
        );
    }

    /// Two generators with the same seed produce the same `permuted_dyn` output.
    #[test]
    fn permuted_dyn_seed_reproducible() {
        use ferray_core::IxDyn;
        let mut a = default_rng_seeded(31);
        let mut b = default_rng_seeded(31);
        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 3]), (0..9).collect()).unwrap();
        let xa = a.permuted_dyn(&arr, 1).unwrap();
        let xb = b.permuted_dyn(&arr, 1).unwrap();
        assert_eq!(xa.as_slice().unwrap(), xb.as_slice().unwrap());
    }

    /// Empirical frequencies for a second, three-bin weight vector track `p`.
    /// (`choice` requires `p` to sum to 1, so these weights are normalized.)
    #[test]
    fn weighted_with_replacement_three_bins_matches_probs() {
        let mut rng = default_rng_seeded(42);
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), vec![0, 1, 2]).unwrap();
        let p = [0.2, 0.5, 0.3];
        let n = 50_000;
        let chosen = rng.choice(&arr, n, true, Some(&p)).unwrap();
        let mut counts = [0_usize; 3];
        for &v in chosen.as_slice().unwrap() {
            counts[v as usize] += 1;
        }
        for (i, &c) in counts.iter().enumerate() {
            let observed = c as f64 / n as f64;
            assert!(
                (observed - p[i]).abs() < 0.02,
                "bin {i}: observed {observed}, expected {}",
                p[i]
            );
        }
    }
}