Skip to main content

ferray_random/
permutations.rs

1// ferray-random: Permutations and sampling — shuffle, permutation, permuted, choice
2
3use ferray_core::{Array, FerrayError, Ix1, IxDyn};
4
5use crate::bitgen::BitGenerator;
6use crate::generator::Generator;
7
8impl<B: BitGenerator> Generator<B> {
9    /// Shuffle a 1-D array in-place using Fisher-Yates.
10    ///
11    /// # Errors
12    /// Returns `FerrayError::InvalidValue` if the array is not contiguous.
13    pub fn shuffle<T>(&mut self, arr: &mut Array<T, Ix1>) -> Result<(), FerrayError>
14    where
15        T: ferray_core::Element,
16    {
17        let n = arr.shape()[0];
18        if n <= 1 {
19            return Ok(());
20        }
21        let slice = arr
22            .as_slice_mut()
23            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for shuffle"))?;
24        // Fisher-Yates
25        for i in (1..n).rev() {
26            let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
27            slice.swap(i, j);
28        }
29        Ok(())
30    }
31
32    /// Return a new array with elements randomly permuted.
33    ///
34    /// If the input is 1-D, returns a shuffled copy. If an integer `n` is
35    /// given (via `permutation_range`), returns a permutation of `0..n`.
36    ///
37    /// # Errors
38    /// Returns `FerrayError::InvalidValue` if the array is empty.
39    pub fn permutation<T>(&mut self, arr: &Array<T, Ix1>) -> Result<Array<T, Ix1>, FerrayError>
40    where
41        T: ferray_core::Element,
42    {
43        let mut copy = arr.clone();
44        self.shuffle(&mut copy)?;
45        Ok(copy)
46    }
47
48    /// Return a permutation of `0..n` as an `Array1<i64>`.
49    ///
50    /// # Errors
51    /// Returns `FerrayError::InvalidValue` if `n` is zero.
52    pub fn permutation_range(&mut self, n: usize) -> Result<Array<i64, Ix1>, FerrayError> {
53        if n == 0 {
54            return Err(FerrayError::invalid_value("n must be > 0"));
55        }
56        let mut data: Vec<i64> = (0..n as i64).collect();
57        // Fisher-Yates
58        for i in (1..n).rev() {
59            let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
60            data.swap(i, j);
61        }
62        Array::<i64, Ix1>::from_vec(Ix1::new([n]), data)
63    }
64
65    /// Return an array with elements independently permuted along the given axis.
66    ///
67    /// For 1-D arrays, this is the same as `permutation`. This simplified
68    /// implementation operates on 1-D arrays along axis 0.
69    ///
70    /// # Errors
71    /// Returns `FerrayError::InvalidValue` if the array is empty.
72    pub fn permuted<T>(
73        &mut self,
74        arr: &Array<T, Ix1>,
75        _axis: usize,
76    ) -> Result<Array<T, Ix1>, FerrayError>
77    where
78        T: ferray_core::Element,
79    {
80        self.permutation(arr)
81    }
82
83    /// Shuffle an N-D array in place along `axis`, swapping whole
84    /// hyperslices (rows when `axis == 0` for a 2-D array).
85    ///
86    /// Equivalent to `numpy.random.Generator.shuffle(x, axis=axis)`.
87    /// Each pair `(i, j)` selected by Fisher-Yates exchanges *all*
88    /// elements with axis-coordinate `i` and `j` simultaneously, so
89    /// rows / columns / N-D slices keep their internal structure (#447).
90    ///
91    /// # Errors
92    /// - `FerrayError::AxisOutOfBounds` if `axis >= arr.ndim()`.
93    /// - `FerrayError::InvalidValue` if `arr` is non-contiguous.
94    pub fn shuffle_dyn<T>(
95        &mut self,
96        arr: &mut Array<T, IxDyn>,
97        axis: usize,
98    ) -> Result<(), FerrayError>
99    where
100        T: ferray_core::Element,
101    {
102        let shape = arr.shape().to_vec();
103        let ndim = shape.len();
104        if axis >= ndim {
105            return Err(FerrayError::axis_out_of_bounds(axis, ndim));
106        }
107        let n = shape[axis];
108        if n <= 1 {
109            return Ok(());
110        }
111        let inner_stride: usize = shape[axis + 1..].iter().product();
112        let block = n * inner_stride;
113        let outer_size: usize = shape[..axis].iter().product();
114        let slice = arr
115            .as_slice_mut()
116            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for shuffle"))?;
117        for i in (1..n).rev() {
118            let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
119            if i == j {
120                continue;
121            }
122            for o in 0..outer_size {
123                let base = o * block;
124                for k in 0..inner_stride {
125                    slice.swap(base + i * inner_stride + k, base + j * inner_stride + k);
126                }
127            }
128        }
129        Ok(())
130    }
131
132    /// Sample N-D hyperslices from `arr` along `axis` (#448).
133    ///
134    /// For each of `size` draws, picks an index along `axis`
135    /// (uniformly or weighted by `p`, with or without replacement)
136    /// and copies the corresponding (N-1)-D slice into the output.
137    /// The output has the same shape as `arr` with the `axis`-th
138    /// dimension replaced by `size`.
139    ///
140    /// Equivalent to `numpy.random.Generator.choice(arr, size, replace, p, axis)`
141    /// for N-D `arr`. The `shuffle` parameter (numpy 1.24+) controls
142    /// whether the indices are returned in selection order
143    /// (`shuffle = true`, default) or sorted (`shuffle = false`,
144    /// only meaningful when `replace = false`).
145    ///
146    /// # Errors
147    /// - `FerrayError::AxisOutOfBounds` if `axis >= arr.ndim()`.
148    /// - `FerrayError::InvalidValue` if the axis dimension is empty,
149    ///   `size > axis_len` with `replace = false`, `arr` is non-contiguous,
150    ///   or `p` is malformed.
151    pub fn choice_dyn<T>(
152        &mut self,
153        arr: &Array<T, IxDyn>,
154        size: usize,
155        replace: bool,
156        p: Option<&[f64]>,
157        axis: usize,
158        shuffle: bool,
159    ) -> Result<Array<T, IxDyn>, FerrayError>
160    where
161        T: ferray_core::Element,
162    {
163        let shape = arr.shape().to_vec();
164        let ndim = shape.len();
165        if axis >= ndim {
166            return Err(FerrayError::axis_out_of_bounds(axis, ndim));
167        }
168        let axis_len = shape[axis];
169        if size == 0 {
170            // numpy: empty sample → output shape with axis dimension = 0
171            let mut out_shape = shape;
172            out_shape[axis] = 0;
173            return Array::<T, IxDyn>::from_vec(IxDyn::new(&out_shape), Vec::new());
174        }
175        if axis_len == 0 {
176            return Err(FerrayError::invalid_value(
177                "choice_dyn: source array has zero length along axis",
178            ));
179        }
180        if !replace && size > axis_len {
181            return Err(FerrayError::invalid_value(format!(
182                "cannot choose {size} elements without replacement from axis of size {axis_len}"
183            )));
184        }
185        if let Some(probs) = p {
186            if probs.len() != axis_len {
187                return Err(FerrayError::invalid_value(format!(
188                    "p must have length {axis_len} (size of axis {axis}), got {}",
189                    probs.len()
190                )));
191            }
192            let psum: f64 = probs.iter().sum();
193            if (psum - 1.0).abs() > 1e-6 {
194                return Err(FerrayError::invalid_value(format!(
195                    "p must sum to 1.0, got {psum}"
196                )));
197            }
198            for (i, &pi) in probs.iter().enumerate() {
199                if pi < 0.0 {
200                    return Err(FerrayError::invalid_value(format!(
201                        "p[{i}] = {pi} is negative"
202                    )));
203                }
204            }
205        }
206
207        let src = arr
208            .as_slice()
209            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for choice_dyn"))?;
210
211        let mut indices = if let Some(probs) = p {
212            if replace {
213                weighted_sample_with_replacement(&mut self.bg, probs, size)
214            } else {
215                weighted_sample_without_replacement(&mut self.bg, probs, size)?
216            }
217        } else if replace {
218            (0..size)
219                .map(|_| self.bg.next_u64_bounded(axis_len as u64) as usize)
220                .collect()
221        } else {
222            sample_without_replacement(&mut self.bg, axis_len, size)
223        };
224        if !shuffle && !replace {
225            indices.sort_unstable();
226        }
227
228        let inner_stride: usize = shape[axis + 1..].iter().product();
229        let outer_size: usize = shape[..axis].iter().product();
230        let src_block = axis_len * inner_stride;
231        let out_block = size * inner_stride;
232        let total_out = outer_size * out_block;
233
234        let mut out_data: Vec<T> = Vec::with_capacity(total_out);
235        // Pre-fill with clones from index 0 so we can address slots by
236        // index. SAFETY: this avoids unsafe; the trait bound `Element`
237        // requires `Clone`. Cost is one clone per element which is what
238        // numpy does too.
239        let filler = src[0].clone();
240        out_data.resize(total_out, filler);
241        for o in 0..outer_size {
242            let src_base = o * src_block;
243            let out_base = o * out_block;
244            for (i, &idx) in indices.iter().enumerate() {
245                let src_off = src_base + idx * inner_stride;
246                let out_off = out_base + i * inner_stride;
247                out_data[out_off..out_off + inner_stride]
248                    .clone_from_slice(&src[src_off..src_off + inner_stride]);
249            }
250        }
251
252        let mut out_shape = shape;
253        out_shape[axis] = size;
254        Array::<T, IxDyn>::from_vec(IxDyn::new(&out_shape), out_data)
255    }
256
257    /// Independently permute the entries along `axis` of `arr`.
258    ///
259    /// Returns a new array. For each combination of "other" indices
260    /// (everything except `axis`) the values along `axis` are
261    /// shuffled with their own Fisher-Yates pass — so columns of a
262    /// 2-D array get independent permutations when `axis = 0`.
263    /// Equivalent to `numpy.random.Generator.permuted(x, axis=axis)`.
264    ///
265    /// # Errors
266    /// - `FerrayError::AxisOutOfBounds` if `axis >= arr.ndim()`.
267    /// - `FerrayError::InvalidValue` if `arr` is non-contiguous.
268    pub fn permuted_dyn<T>(
269        &mut self,
270        arr: &Array<T, IxDyn>,
271        axis: usize,
272    ) -> Result<Array<T, IxDyn>, FerrayError>
273    where
274        T: ferray_core::Element,
275    {
276        let shape = arr.shape().to_vec();
277        let ndim = shape.len();
278        if axis >= ndim {
279            return Err(FerrayError::axis_out_of_bounds(axis, ndim));
280        }
281        let mut out = arr.clone();
282        let n = shape[axis];
283        if n <= 1 {
284            return Ok(out);
285        }
286        let inner_stride: usize = shape[axis + 1..].iter().product();
287        let block = n * inner_stride;
288        let outer_size: usize = shape[..axis].iter().product();
289        let slice = out
290            .as_slice_mut()
291            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for permuted"))?;
292        for o in 0..outer_size {
293            let base = o * block;
294            for k in 0..inner_stride {
295                // Independent Fisher-Yates over the n axis positions
296                // at this (outer, inner) coordinate.
297                for i in (1..n).rev() {
298                    let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
299                    slice.swap(base + i * inner_stride + k, base + j * inner_stride + k);
300                }
301            }
302        }
303        Ok(out)
304    }
305
306    /// Randomly select elements from an array, with or without replacement.
307    ///
308    /// # Arguments
309    /// * `arr` - Source array to sample from.
310    /// * `size` - Number of elements to select.
311    /// * `replace` - If `true`, sample with replacement; if `false`, without.
312    /// * `p` - Optional probability weights (must sum to 1.0 and have same length as `arr`).
313    ///
314    /// # Errors
315    /// Returns `FerrayError::InvalidValue` if parameters are invalid (e.g.,
316    /// `size > arr.len()` when `replace=false`, or invalid probability weights).
317    pub fn choice<T>(
318        &mut self,
319        arr: &Array<T, Ix1>,
320        size: usize,
321        replace: bool,
322        p: Option<&[f64]>,
323    ) -> Result<Array<T, Ix1>, FerrayError>
324    where
325        T: ferray_core::Element,
326    {
327        let n = arr.shape()[0];
328        // size == 0 is valid: NumPy returns an empty array. Only the
329        // source-array-empty case (and only when we actually need a
330        // sample) is still an error (#264, #455).
331        if size == 0 {
332            return Array::from_vec(Ix1::new([0]), Vec::new());
333        }
334        if n == 0 {
335            return Err(FerrayError::invalid_value("source array must be non-empty"));
336        }
337        if !replace && size > n {
338            return Err(FerrayError::invalid_value(format!(
339                "cannot choose {size} elements without replacement from array of size {n}"
340            )));
341        }
342
343        if let Some(probs) = p {
344            if probs.len() != n {
345                return Err(FerrayError::invalid_value(format!(
346                    "p must have same length as array ({n}), got {}",
347                    probs.len()
348                )));
349            }
350            let psum: f64 = probs.iter().sum();
351            if (psum - 1.0).abs() > 1e-6 {
352                return Err(FerrayError::invalid_value(format!(
353                    "p must sum to 1.0, got {psum}"
354                )));
355            }
356            for (i, &pi) in probs.iter().enumerate() {
357                if pi < 0.0 {
358                    return Err(FerrayError::invalid_value(format!(
359                        "p[{i}] = {pi} is negative"
360                    )));
361                }
362            }
363        }
364
365        let src = arr
366            .as_slice()
367            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous"))?;
368
369        let indices = if let Some(probs) = p {
370            // Weighted sampling
371            if replace {
372                weighted_sample_with_replacement(&mut self.bg, probs, size)
373            } else {
374                weighted_sample_without_replacement(&mut self.bg, probs, size)?
375            }
376        } else if replace {
377            // Uniform with replacement
378            (0..size)
379                .map(|_| self.bg.next_u64_bounded(n as u64) as usize)
380                .collect()
381        } else {
382            // Uniform without replacement: partial Fisher-Yates
383            sample_without_replacement(&mut self.bg, n, size)
384        };
385
386        let data: Vec<T> = indices.iter().map(|&i| src[i].clone()).collect();
387        Array::<T, Ix1>::from_vec(Ix1::new([size]), data)
388    }
389}
390
391/// Sample `size` indices from `[0, n)` without replacement using partial Fisher-Yates.
392fn sample_without_replacement<B: BitGenerator>(bg: &mut B, n: usize, size: usize) -> Vec<usize> {
393    let mut pool: Vec<usize> = (0..n).collect();
394    for i in 0..size {
395        let j = i + bg.next_u64_bounded((n - i) as u64) as usize;
396        pool.swap(i, j);
397    }
398    pool[..size].to_vec()
399}
400
401/// Weighted sampling with replacement using Vose's alias method (#265).
402///
403/// Setup is O(n); each sample is O(1) — strictly faster than the
404/// O(log n) binary-search-on-CDF path we used to use, especially at
405/// large `size`. The alias table holds, for each bin `i`, a
406/// "secondary" choice `alias[i]` and a probability `prob[i]` of
407/// sticking with `i`. Sampling: pick `i` uniformly, draw `u ∈ [0, 1)`;
408/// if `u < prob[i]` return `i`, else return `alias[i]`.
409///
410/// Reference: M. D. Vose, "A linear algorithm for generating random
411/// numbers with a given distribution", IEEE TSE 17(9), 1991.
412fn weighted_sample_with_replacement<B: BitGenerator>(
413    bg: &mut B,
414    probs: &[f64],
415    size: usize,
416) -> Vec<usize> {
417    let n = probs.len();
418
419    // Normalize so the sum is exactly n. The alias method works on
420    // probabilities scaled by n: each bin "should" hold mass 1, and we
421    // shuffle excess from heavy bins into light bins.
422    let total: f64 = probs.iter().sum();
423    let mut scaled: Vec<f64> = probs.iter().map(|&p| p * n as f64 / total).collect();
424
425    let mut prob = vec![0.0_f64; n];
426    let mut alias = vec![0_usize; n];
427
428    // Two stacks: indices with mass < 1 vs. mass >= 1.
429    let mut small: Vec<usize> = Vec::with_capacity(n);
430    let mut large: Vec<usize> = Vec::with_capacity(n);
431    for (i, &m) in scaled.iter().enumerate() {
432        if m < 1.0 {
433            small.push(i);
434        } else {
435            large.push(i);
436        }
437    }
438
439    while !small.is_empty() && !large.is_empty() {
440        let s = small.pop().unwrap();
441        let l = large.pop().unwrap();
442        prob[s] = scaled[s];
443        alias[s] = l;
444        // Donate (1 - scaled[s]) of mass from l to fill s.
445        scaled[l] = (scaled[l] + scaled[s]) - 1.0;
446        if scaled[l] < 1.0 {
447            small.push(l);
448        } else {
449            large.push(l);
450        }
451    }
452    // Drain leftovers — these slots have mass exactly 1.0 (modulo
453    // floating-point drift); pin prob[i] = 1.0 so sampling always
454    // returns i for these.
455    for &i in large.iter().chain(small.iter()) {
456        prob[i] = 1.0;
457    }
458
459    (0..size)
460        .map(|_| {
461            let i = bg.next_u64_bounded(n as u64) as usize;
462            let u = bg.next_f64();
463            if u < prob[i] { i } else { alias[i] }
464        })
465        .collect()
466}
467
468/// Weighted sampling without replacement using a sequential elimination method.
469fn weighted_sample_without_replacement<B: BitGenerator>(
470    bg: &mut B,
471    probs: &[f64],
472    size: usize,
473) -> Result<Vec<usize>, FerrayError> {
474    let n = probs.len();
475    let mut weights: Vec<f64> = probs.to_vec();
476    let mut selected = Vec::with_capacity(size);
477
478    for _ in 0..size {
479        let total: f64 = weights.iter().sum();
480        if total <= 0.0 {
481            return Err(FerrayError::invalid_value(
482                "insufficient probability mass for sampling without replacement",
483            ));
484        }
485        let u = bg.next_f64() * total;
486        let mut cumsum = 0.0;
487        let mut chosen = n - 1;
488        for (i, &w) in weights.iter().enumerate() {
489            cumsum += w;
490            if cumsum > u {
491                chosen = i;
492                break;
493            }
494        }
495        selected.push(chosen);
496        weights[chosen] = 0.0;
497    }
498
499    Ok(selected)
500}
501
502#[cfg(test)]
503mod tests {
504    use crate::default_rng_seeded;
505    use ferray_core::{Array, Ix1};
506
507    #[test]
508    fn shuffle_preserves_elements() {
509        let mut rng = default_rng_seeded(42);
510        let mut arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
511        rng.shuffle(&mut arr).unwrap();
512        let mut sorted: Vec<i64> = arr.as_slice().unwrap().to_vec();
513        sorted.sort_unstable();
514        assert_eq!(sorted, vec![1, 2, 3, 4, 5]);
515    }
516
517    #[test]
518    fn permutation_preserves_elements() {
519        let mut rng = default_rng_seeded(42);
520        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
521        let perm = rng.permutation(&arr).unwrap();
522        let mut sorted: Vec<i64> = perm.as_slice().unwrap().to_vec();
523        sorted.sort_unstable();
524        assert_eq!(sorted, vec![10, 20, 30, 40, 50]);
525    }
526
527    #[test]
528    fn permutation_range_covers_all() {
529        let mut rng = default_rng_seeded(42);
530        let perm = rng.permutation_range(10).unwrap();
531        let mut sorted: Vec<i64> = perm.as_slice().unwrap().to_vec();
532        sorted.sort_unstable();
533        let expected: Vec<i64> = (0..10).collect();
534        assert_eq!(sorted, expected);
535    }
536
537    #[test]
538    fn shuffle_modifies_in_place() {
539        let mut rng = default_rng_seeded(42);
540        let original = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9, 10];
541        let mut arr = Array::<i64, Ix1>::from_vec(Ix1::new([10]), original.clone()).unwrap();
542        rng.shuffle(&mut arr).unwrap();
543        // Very unlikely (10! - 1 chance) that shuffle produces identity
544        let shuffled = arr.as_slice().unwrap().to_vec();
545        // Just verify it's a valid permutation
546        let mut sorted = shuffled;
547        sorted.sort_unstable();
548        assert_eq!(sorted, original);
549    }
550
551    #[test]
552    fn choice_with_replacement() {
553        let mut rng = default_rng_seeded(42);
554        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
555        let chosen = rng.choice(&arr, 10, true, None).unwrap();
556        assert_eq!(chosen.shape(), &[10]);
557        // All values should be from the original array
558        let src: Vec<i64> = vec![10, 20, 30, 40, 50];
559        for &v in chosen.as_slice().unwrap() {
560            assert!(src.contains(&v), "choice returned unexpected value {v}");
561        }
562    }
563
564    #[test]
565    fn choice_without_replacement_no_duplicates() {
566        let mut rng = default_rng_seeded(42);
567        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([10]), (0..10).collect()).unwrap();
568        let chosen = rng.choice(&arr, 5, false, None).unwrap();
569        let slice = chosen.as_slice().unwrap();
570        // No duplicates
571        let mut seen = std::collections::HashSet::new();
572        for &v in slice {
573            assert!(
574                seen.insert(v),
575                "duplicate value {v} in choice without replacement"
576            );
577        }
578    }
579
580    #[test]
581    fn choice_without_replacement_too_many() {
582        let mut rng = default_rng_seeded(42);
583        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
584        assert!(rng.choice(&arr, 10, false, None).is_err());
585    }
586
587    #[test]
588    fn choice_with_weights() {
589        let mut rng = default_rng_seeded(42);
590        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
591        let p = [0.0, 0.0, 1.0]; // Always pick the last element
592        let chosen = rng.choice(&arr, 10, true, Some(&p)).unwrap();
593        for &v in chosen.as_slice().unwrap() {
594            assert_eq!(v, 30);
595        }
596    }
597
598    #[test]
599    fn choice_without_replacement_with_weights() {
600        let mut rng = default_rng_seeded(42);
601        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
602        let p = [0.1, 0.2, 0.3, 0.2, 0.2];
603        let chosen = rng.choice(&arr, 3, false, Some(&p)).unwrap();
604        let slice = chosen.as_slice().unwrap();
605        // No duplicates
606        let mut seen = std::collections::HashSet::new();
607        for &v in slice {
608            assert!(seen.insert(v), "duplicate value {v}");
609        }
610    }
611
612    #[test]
613    fn choice_bad_weights() {
614        let mut rng = default_rng_seeded(42);
615        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
616        // Wrong length
617        assert!(rng.choice(&arr, 1, true, Some(&[0.5, 0.5])).is_err());
618        // Doesn't sum to 1
619        assert!(rng.choice(&arr, 1, true, Some(&[0.5, 0.5, 0.5])).is_err());
620        // Negative
621        assert!(rng.choice(&arr, 1, true, Some(&[-0.1, 0.6, 0.5])).is_err());
622    }
623
624    #[test]
625    fn permuted_1d() {
626        let mut rng = default_rng_seeded(42);
627        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
628        let result = rng.permuted(&arr, 0).unwrap();
629        let mut sorted: Vec<i64> = result.as_slice().unwrap().to_vec();
630        sorted.sort_unstable();
631        assert_eq!(sorted, vec![1, 2, 3, 4, 5]);
632    }
633
634    #[test]
635    fn weighted_with_replacement_alias_distribution_recovers_probs() {
636        // #265: Vose's alias method must produce empirical bin
637        // frequencies that match the input probability vector across a
638        // large sample. Use a deliberately uneven distribution that
639        // exercises the small/large stack rebalancing.
640        let mut rng = default_rng_seeded(42);
641        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 2, 3, 4]).unwrap();
642        let p = [0.05, 0.15, 0.30, 0.40, 0.10];
643        let n = 100_000;
644        let chosen = rng.choice(&arr, n, true, Some(&p)).unwrap();
645        let mut counts = [0_usize; 5];
646        for &v in chosen.as_slice().unwrap() {
647            counts[v as usize] += 1;
648        }
649        // Each empirical frequency must be within 1.5% absolute of
650        // its target — comfortably above the Monte Carlo noise of
651        // sqrt(p(1-p)/n) ~ 0.15% for the largest bin.
652        for (i, &c) in counts.iter().enumerate() {
653            let observed = c as f64 / n as f64;
654            assert!(
655                (observed - p[i]).abs() < 0.015,
656                "bin {i}: observed {observed}, expected {}",
657                p[i]
658            );
659        }
660    }
661
662    // ---- choice_dyn (#448) ---------------------------------------------
663
664    #[test]
665    fn choice_dyn_axis0_picks_whole_rows() {
666        use ferray_core::IxDyn;
667        let mut rng = default_rng_seeded(42);
668        let data: Vec<i64> = (0..5)
669            .flat_map(|i| (0..3).map(move |j| i * 100 + j))
670            .collect();
671        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[5, 3]), data).unwrap();
672        let chosen = rng.choice_dyn(&arr, 4, true, None, 0, true).unwrap();
673        assert_eq!(chosen.shape(), &[4, 3]);
674        let slice = chosen.as_slice().unwrap();
675        for row in 0..4 {
676            let v0 = slice[row * 3];
677            let id = v0 / 100;
678            assert!((0..5).contains(&id));
679            assert_eq!(slice[row * 3 + 1], id * 100 + 1);
680            assert_eq!(slice[row * 3 + 2], id * 100 + 2);
681        }
682    }
683
684    #[test]
685    fn choice_dyn_axis1_picks_whole_columns() {
686        use ferray_core::IxDyn;
687        let mut rng = default_rng_seeded(7);
688        let data: Vec<i64> = (0..3)
689            .flat_map(|i| (0..6).map(move |j| i * 10 + j))
690            .collect();
691        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 6]), data).unwrap();
692        let chosen = rng.choice_dyn(&arr, 2, false, None, 1, true).unwrap();
693        assert_eq!(chosen.shape(), &[3, 2]);
694        let slice = chosen.as_slice().unwrap();
695        // Each column in the output must equal one of the original columns
696        // (which all have the form [j, 10+j, 20+j]).
697        for col in 0..2 {
698            let v0 = slice[col];
699            let v1 = slice[2 + col];
700            let v2 = slice[4 + col];
701            assert!((0..6).contains(&v0));
702            assert_eq!(v1, v0 + 10);
703            assert_eq!(v2, v0 + 20);
704        }
705    }
706
707    #[test]
708    fn choice_dyn_without_replacement_no_duplicate_rows() {
709        use ferray_core::IxDyn;
710        let mut rng = default_rng_seeded(1);
711        let data: Vec<i64> = (0..10)
712            .flat_map(|i| (0..2).map(move |j| i * 100 + j))
713            .collect();
714        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[10, 2]), data).unwrap();
715        let chosen = rng.choice_dyn(&arr, 5, false, None, 0, true).unwrap();
716        let slice = chosen.as_slice().unwrap();
717        let mut ids = std::collections::HashSet::new();
718        for row in 0..5 {
719            let id = slice[row * 2] / 100;
720            assert!(ids.insert(id), "row id {id} repeated under replace=false");
721        }
722    }
723
724    #[test]
725    fn choice_dyn_shuffle_false_returns_sorted_indices() {
726        use ferray_core::IxDyn;
727        let mut rng = default_rng_seeded(3);
728        // Tag each row with its original axis index in column 0; with
729        // shuffle=false + replace=false, the chosen rows must appear
730        // in ascending index order.
731        let data: Vec<i64> = (0..12)
732            .flat_map(|i| (0..2).map(move |j| if j == 0 { i as i64 } else { i as i64 * 10 }))
733            .collect();
734        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[12, 2]), data).unwrap();
735        let chosen = rng.choice_dyn(&arr, 6, false, None, 0, false).unwrap();
736        let slice = chosen.as_slice().unwrap();
737        let mut last = -1i64;
738        for row in 0..6 {
739            let id = slice[row * 2];
740            assert!(
741                id > last,
742                "shuffle=false output not ascending: {id} after {last}"
743            );
744            last = id;
745        }
746    }
747
748    #[test]
749    fn choice_dyn_weighted_concentrates_on_high_p() {
750        use ferray_core::IxDyn;
751        let mut rng = default_rng_seeded(0);
752        let data: Vec<i64> = (0..4)
753            .flat_map(|i| (0..2).map(move |j| i * 100 + j))
754            .collect();
755        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[4, 2]), data).unwrap();
756        // All probability on row 2.
757        let p = [0.0, 0.0, 1.0, 0.0];
758        let chosen = rng.choice_dyn(&arr, 20, true, Some(&p), 0, true).unwrap();
759        let slice = chosen.as_slice().unwrap();
760        for row in 0..20 {
761            assert_eq!(slice[row * 2], 200, "weighted choice strayed from row 2");
762        }
763    }
764
765    #[test]
766    fn choice_dyn_size_zero_returns_empty_axis() {
767        use ferray_core::IxDyn;
768        let mut rng = default_rng_seeded(11);
769        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 4]), (0..12).collect()).unwrap();
770        let chosen = rng.choice_dyn(&arr, 0, true, None, 0, true).unwrap();
771        assert_eq!(chosen.shape(), &[0, 4]);
772    }
773
774    #[test]
775    fn choice_dyn_bad_axis() {
776        use ferray_core::IxDyn;
777        let mut rng = default_rng_seeded(0);
778        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), (0..6).collect()).unwrap();
779        assert!(rng.choice_dyn(&arr, 1, true, None, 5, true).is_err());
780    }
781
782    #[test]
783    fn choice_dyn_too_many_no_replace_errors() {
784        use ferray_core::IxDyn;
785        let mut rng = default_rng_seeded(0);
786        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 2]), (0..6).collect()).unwrap();
787        assert!(rng.choice_dyn(&arr, 5, false, None, 0, true).is_err());
788    }
789
790    // ---- shuffle_dyn / permuted_dyn (#447) -----------------------------
791
792    #[test]
793    fn shuffle_dyn_axis0_swaps_whole_rows() {
794        use ferray_core::IxDyn;
795        let mut rng = default_rng_seeded(42);
796        // 4×3: rows are [0,1,2], [10,11,12], [20,21,22], [30,31,32]
797        let data: Vec<i64> = (0..4)
798            .flat_map(|i| (0..3).map(move |j| i * 10 + j))
799            .collect();
800        let mut arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[4, 3]), data).unwrap();
801        rng.shuffle_dyn(&mut arr, 0).unwrap();
802        let slice = arr.as_slice().unwrap();
803        // Each row must still be one of the originals — internal layout preserved.
804        let mut seen = std::collections::HashSet::new();
805        for row in 0..4 {
806            let row_first = slice[row * 3];
807            let id = row_first / 10;
808            assert!(
809                (0..4).contains(&id),
810                "row {row} starts with unexpected value {row_first}"
811            );
812            assert_eq!(slice[row * 3 + 1], id * 10 + 1);
813            assert_eq!(slice[row * 3 + 2], id * 10 + 2);
814            assert!(
815                seen.insert(id),
816                "row id {id} duplicated — shuffle broke a row"
817            );
818        }
819    }
820
821    #[test]
822    fn shuffle_dyn_axis1_swaps_whole_columns() {
823        use ferray_core::IxDyn;
824        let mut rng = default_rng_seeded(7);
825        // 3×4: column j is [j, 10+j, 20+j].
826        let data: Vec<i64> = (0..3)
827            .flat_map(|i| (0..4).map(move |j| i * 10 + j))
828            .collect();
829        let mut arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 4]), data).unwrap();
830        rng.shuffle_dyn(&mut arr, 1).unwrap();
831        let slice = arr.as_slice().unwrap();
832        // Each column must still equal one of the original column patterns.
833        let mut col_ids = Vec::new();
834        for col in 0..4 {
835            let v0 = slice[col];
836            let v1 = slice[4 + col];
837            let v2 = slice[8 + col];
838            assert!((0..4).contains(&v0));
839            assert_eq!(v1, v0 + 10);
840            assert_eq!(v2, v0 + 20);
841            col_ids.push(v0);
842        }
843        col_ids.sort_unstable();
844        assert_eq!(col_ids, vec![0, 1, 2, 3]);
845    }
846
847    #[test]
848    fn shuffle_dyn_axis_out_of_bounds() {
849        use ferray_core::IxDyn;
850        let mut rng = default_rng_seeded(0);
851        let mut arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0; 6]).unwrap();
852        assert!(rng.shuffle_dyn(&mut arr, 2).is_err());
853    }
854
855    #[test]
856    fn permuted_dyn_axis0_each_column_independent() {
857        use ferray_core::IxDyn;
858        let mut rng = default_rng_seeded(99);
859        // 5×4 array; permuted along axis=0 → each column is independently
860        // shuffled, so a row is a *re-mix* of column-wise positions, not a
861        // whole-row swap.
862        let n_rows = 5;
863        let n_cols = 4;
864        let data: Vec<i64> = (0..n_rows * n_cols).map(|x| x as i64).collect();
865        let arr =
866            Array::<i64, IxDyn>::from_vec(IxDyn::new(&[n_rows, n_cols]), data.clone()).unwrap();
867        let result = rng.permuted_dyn(&arr, 0).unwrap();
868        let slice = result.as_slice().unwrap();
869        // Each column must be a permutation of the original column values.
870        for col in 0..n_cols {
871            let original_col: Vec<i64> = (0..n_rows).map(|r| (r * n_cols + col) as i64).collect();
872            let mut got_col: Vec<i64> = (0..n_rows).map(|r| slice[r * n_cols + col]).collect();
873            got_col.sort_unstable();
874            let mut want = original_col.clone();
875            want.sort_unstable();
876            assert_eq!(got_col, want, "col {col} lost values during permute");
877        }
878    }
879
880    #[test]
881    fn permuted_dyn_columns_can_diverge() {
882        use ferray_core::IxDyn;
883        // Permuted should produce different per-column orderings — across
884        // many trials the probability that all columns still match each
885        // other for a 5-row 4-column array is (1/120)^3 ≈ 1e-6.
886        let mut rng = default_rng_seeded(1234);
887        let n_rows = 5;
888        let n_cols = 4;
889        let data: Vec<i64> = (0..n_rows * n_cols)
890            .map(|x| x as i64 % n_rows as i64)
891            .collect();
892        let arr =
893            Array::<i64, IxDyn>::from_vec(IxDyn::new(&[n_rows, n_cols]), data.clone()).unwrap();
894        let result = rng.permuted_dyn(&arr, 0).unwrap();
895        let slice = result.as_slice().unwrap();
896        // Reference column 0 against each other column. At least one must differ.
897        let col0: Vec<i64> = (0..n_rows).map(|r| slice[r * n_cols]).collect();
898        let mut any_diff = false;
899        for col in 1..n_cols {
900            let coln: Vec<i64> = (0..n_rows).map(|r| slice[r * n_cols + col]).collect();
901            if col0 != coln {
902                any_diff = true;
903                break;
904            }
905        }
906        assert!(
907            any_diff,
908            "all columns matched — permuted didn't independently shuffle"
909        );
910    }
911
912    #[test]
913    fn permuted_dyn_seed_reproducible() {
914        use ferray_core::IxDyn;
915        let mut a = default_rng_seeded(31);
916        let mut b = default_rng_seeded(31);
917        let arr = Array::<i64, IxDyn>::from_vec(IxDyn::new(&[3, 3]), (0..9).collect()).unwrap();
918        let xa = a.permuted_dyn(&arr, 1).unwrap();
919        let xb = b.permuted_dyn(&arr, 1).unwrap();
920        assert_eq!(xa.as_slice().unwrap(), xb.as_slice().unwrap());
921    }
922
923    #[test]
924    fn weighted_with_replacement_unnormalized_probs() {
925        // The alias setup normalizes probs internally; a vector that
926        // sums to !=1 must produce the same empirical distribution as
927        // its normalized counterpart. (We bypass `choice`'s strict
928        // sum-to-1 validation by hitting the inner function path —
929        // here we test the user-facing path with an exact input.)
930        let mut rng = default_rng_seeded(42);
931        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), vec![0, 1, 2]).unwrap();
932        // Already-normalized comparison input.
933        let p = [0.2, 0.5, 0.3];
934        let n = 50_000;
935        let chosen = rng.choice(&arr, n, true, Some(&p)).unwrap();
936        let mut counts = [0_usize; 3];
937        for &v in chosen.as_slice().unwrap() {
938            counts[v as usize] += 1;
939        }
940        for (i, &c) in counts.iter().enumerate() {
941            let observed = c as f64 / n as f64;
942            assert!(
943                (observed - p[i]).abs() < 0.02,
944                "bin {i}: observed {observed}, expected {}",
945                p[i]
946            );
947        }
948    }
949}