Skip to main content

ferray_random/
permutations.rs

1// ferray-random: Permutations and sampling — shuffle, permutation, permuted, choice
2
3use ferray_core::{Array, FerrayError, Ix1};
4
5use crate::bitgen::BitGenerator;
6use crate::generator::Generator;
7
8impl<B: BitGenerator> Generator<B> {
9    /// Shuffle a 1-D array in-place using Fisher-Yates.
10    ///
11    /// # Errors
12    /// Returns `FerrayError::InvalidValue` if the array is not contiguous.
13    pub fn shuffle<T>(&mut self, arr: &mut Array<T, Ix1>) -> Result<(), FerrayError>
14    where
15        T: ferray_core::Element,
16    {
17        let n = arr.shape()[0];
18        if n <= 1 {
19            return Ok(());
20        }
21        let slice = arr
22            .as_slice_mut()
23            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for shuffle"))?;
24        // Fisher-Yates
25        for i in (1..n).rev() {
26            let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
27            slice.swap(i, j);
28        }
29        Ok(())
30    }
31
32    /// Return a new array with elements randomly permuted.
33    ///
34    /// If the input is 1-D, returns a shuffled copy. If an integer `n` is
35    /// given (via `permutation_range`), returns a permutation of `0..n`.
36    ///
37    /// # Errors
38    /// Returns `FerrayError::InvalidValue` if the array is empty.
39    pub fn permutation<T>(&mut self, arr: &Array<T, Ix1>) -> Result<Array<T, Ix1>, FerrayError>
40    where
41        T: ferray_core::Element,
42    {
43        let mut copy = arr.clone();
44        self.shuffle(&mut copy)?;
45        Ok(copy)
46    }
47
48    /// Return a permutation of `0..n` as an `Array1<i64>`.
49    ///
50    /// # Errors
51    /// Returns `FerrayError::InvalidValue` if `n` is zero.
52    pub fn permutation_range(&mut self, n: usize) -> Result<Array<i64, Ix1>, FerrayError> {
53        if n == 0 {
54            return Err(FerrayError::invalid_value("n must be > 0"));
55        }
56        let mut data: Vec<i64> = (0..n as i64).collect();
57        // Fisher-Yates
58        for i in (1..n).rev() {
59            let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
60            data.swap(i, j);
61        }
62        Array::<i64, Ix1>::from_vec(Ix1::new([n]), data)
63    }
64
65    /// Return an array with elements independently permuted along the given axis.
66    ///
67    /// For 1-D arrays, this is the same as `permutation`. This simplified
68    /// implementation operates on 1-D arrays along axis 0.
69    ///
70    /// # Errors
71    /// Returns `FerrayError::InvalidValue` if the array is empty.
72    pub fn permuted<T>(
73        &mut self,
74        arr: &Array<T, Ix1>,
75        _axis: usize,
76    ) -> Result<Array<T, Ix1>, FerrayError>
77    where
78        T: ferray_core::Element,
79    {
80        self.permutation(arr)
81    }
82
83    /// Randomly select elements from an array, with or without replacement.
84    ///
85    /// # Arguments
86    /// * `arr` - Source array to sample from.
87    /// * `size` - Number of elements to select.
88    /// * `replace` - If `true`, sample with replacement; if `false`, without.
89    /// * `p` - Optional probability weights (must sum to 1.0 and have same length as `arr`).
90    ///
91    /// # Errors
92    /// Returns `FerrayError::InvalidValue` if parameters are invalid (e.g.,
93    /// `size > arr.len()` when `replace=false`, or invalid probability weights).
94    pub fn choice<T>(
95        &mut self,
96        arr: &Array<T, Ix1>,
97        size: usize,
98        replace: bool,
99        p: Option<&[f64]>,
100    ) -> Result<Array<T, Ix1>, FerrayError>
101    where
102        T: ferray_core::Element,
103    {
104        let n = arr.shape()[0];
105        // size == 0 is valid: NumPy returns an empty array. Only the
106        // source-array-empty case (and only when we actually need a
107        // sample) is still an error (#264, #455).
108        if size == 0 {
109            return Array::from_vec(Ix1::new([0]), Vec::new());
110        }
111        if n == 0 {
112            return Err(FerrayError::invalid_value("source array must be non-empty"));
113        }
114        if !replace && size > n {
115            return Err(FerrayError::invalid_value(format!(
116                "cannot choose {size} elements without replacement from array of size {n}"
117            )));
118        }
119
120        if let Some(probs) = p {
121            if probs.len() != n {
122                return Err(FerrayError::invalid_value(format!(
123                    "p must have same length as array ({n}), got {}",
124                    probs.len()
125                )));
126            }
127            let psum: f64 = probs.iter().sum();
128            if (psum - 1.0).abs() > 1e-6 {
129                return Err(FerrayError::invalid_value(format!(
130                    "p must sum to 1.0, got {psum}"
131                )));
132            }
133            for (i, &pi) in probs.iter().enumerate() {
134                if pi < 0.0 {
135                    return Err(FerrayError::invalid_value(format!(
136                        "p[{i}] = {pi} is negative"
137                    )));
138                }
139            }
140        }
141
142        let src = arr
143            .as_slice()
144            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous"))?;
145
146        let indices = if let Some(probs) = p {
147            // Weighted sampling
148            if replace {
149                weighted_sample_with_replacement(&mut self.bg, probs, size)
150            } else {
151                weighted_sample_without_replacement(&mut self.bg, probs, size)?
152            }
153        } else if replace {
154            // Uniform with replacement
155            (0..size)
156                .map(|_| self.bg.next_u64_bounded(n as u64) as usize)
157                .collect()
158        } else {
159            // Uniform without replacement: partial Fisher-Yates
160            sample_without_replacement(&mut self.bg, n, size)
161        };
162
163        let data: Vec<T> = indices.iter().map(|&i| src[i].clone()).collect();
164        Array::<T, Ix1>::from_vec(Ix1::new([size]), data)
165    }
166}
167
168/// Sample `size` indices from `[0, n)` without replacement using partial Fisher-Yates.
169fn sample_without_replacement<B: BitGenerator>(bg: &mut B, n: usize, size: usize) -> Vec<usize> {
170    let mut pool: Vec<usize> = (0..n).collect();
171    for i in 0..size {
172        let j = i + bg.next_u64_bounded((n - i) as u64) as usize;
173        pool.swap(i, j);
174    }
175    pool[..size].to_vec()
176}
177
178/// Weighted sampling with replacement using the inverse CDF method.
179fn weighted_sample_with_replacement<B: BitGenerator>(
180    bg: &mut B,
181    probs: &[f64],
182    size: usize,
183) -> Vec<usize> {
184    // Build cumulative distribution
185    let mut cdf = Vec::with_capacity(probs.len());
186    let mut cumsum = 0.0;
187    for &p in probs {
188        cumsum += p;
189        cdf.push(cumsum);
190    }
191
192    (0..size)
193        .map(|_| {
194            let u = bg.next_f64();
195            // Binary search in CDF
196            match cdf.binary_search_by(|c| c.partial_cmp(&u).unwrap_or(std::cmp::Ordering::Equal)) {
197                Ok(i) => i,
198                Err(i) => i.min(probs.len() - 1),
199            }
200        })
201        .collect()
202}
203
204/// Weighted sampling without replacement using a sequential elimination method.
205fn weighted_sample_without_replacement<B: BitGenerator>(
206    bg: &mut B,
207    probs: &[f64],
208    size: usize,
209) -> Result<Vec<usize>, FerrayError> {
210    let n = probs.len();
211    let mut weights: Vec<f64> = probs.to_vec();
212    let mut selected = Vec::with_capacity(size);
213
214    for _ in 0..size {
215        let total: f64 = weights.iter().sum();
216        if total <= 0.0 {
217            return Err(FerrayError::invalid_value(
218                "insufficient probability mass for sampling without replacement",
219            ));
220        }
221        let u = bg.next_f64() * total;
222        let mut cumsum = 0.0;
223        let mut chosen = n - 1;
224        for (i, &w) in weights.iter().enumerate() {
225            cumsum += w;
226            if cumsum > u {
227                chosen = i;
228                break;
229            }
230        }
231        selected.push(chosen);
232        weights[chosen] = 0.0;
233    }
234
235    Ok(selected)
236}
237
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::default_rng_seeded;
    use ferray_core::{Array, Ix1};

    /// Wrap `values` in a 1-D `i64` array whose length matches the vector.
    fn array_of(values: Vec<i64>) -> Array<i64, Ix1> {
        let len = values.len();
        Array::<i64, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Copy the array's elements into a vector sorted in ascending order.
    fn sorted_values(arr: &Array<i64, Ix1>) -> Vec<i64> {
        let mut values = arr.as_slice().unwrap().to_vec();
        values.sort();
        values
    }

    #[test]
    fn shuffle_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let mut arr = array_of(vec![1, 2, 3, 4, 5]);
        rng.shuffle(&mut arr).unwrap();
        assert_eq!(sorted_values(&arr), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn permutation_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![10, 20, 30, 40, 50]);
        let perm = rng.permutation(&arr).unwrap();
        assert_eq!(sorted_values(&perm), vec![10, 20, 30, 40, 50]);
    }

    #[test]
    fn permutation_range_covers_all() {
        let mut rng = default_rng_seeded(42);
        let perm = rng.permutation_range(10).unwrap();
        let expected: Vec<i64> = (0..10).collect();
        assert_eq!(sorted_values(&perm), expected);
    }

    #[test]
    fn shuffle_modifies_in_place() {
        let mut rng = default_rng_seeded(42);
        let original = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut arr = array_of(original.clone());
        rng.shuffle(&mut arr).unwrap();
        // A shuffle may legitimately return the identity ordering (a 1-in-10!
        // chance), so this only checks that the array passed in still holds a
        // valid permutation of the original elements.
        assert_eq!(sorted_values(&arr), original);
    }

    #[test]
    fn choice_with_replacement() {
        let mut rng = default_rng_seeded(42);
        let src: Vec<i64> = vec![10, 20, 30, 40, 50];
        let arr = array_of(src.clone());
        let chosen = rng.choice(&arr, 10, true, None).unwrap();
        assert_eq!(chosen.shape(), &[10]);
        // Every drawn value must come from the source array.
        for &v in chosen.as_slice().unwrap() {
            assert!(src.contains(&v), "choice returned unexpected value {v}");
        }
    }

    #[test]
    fn choice_without_replacement_no_duplicates() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of((0..10).collect());
        let chosen = rng.choice(&arr, 5, false, None).unwrap();
        // Each value may appear at most once.
        let mut seen = HashSet::new();
        for &v in chosen.as_slice().unwrap() {
            assert!(
                seen.insert(v),
                "duplicate value {v} in choice without replacement"
            );
        }
    }

    #[test]
    fn choice_without_replacement_too_many() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let outcome = rng.choice(&arr, 10, false, None);
        assert!(outcome.is_err());
    }

    #[test]
    fn choice_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![10, 20, 30]);
        // All probability mass sits on the last element.
        let p = [0.0, 0.0, 1.0];
        let chosen = rng.choice(&arr, 10, true, Some(&p)).unwrap();
        for &v in chosen.as_slice().unwrap() {
            assert_eq!(v, 30);
        }
    }

    #[test]
    fn choice_without_replacement_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let p = [0.1, 0.2, 0.3, 0.2, 0.2];
        let chosen = rng.choice(&arr, 3, false, Some(&p)).unwrap();
        // Each value may appear at most once.
        let mut seen = HashSet::new();
        for &v in chosen.as_slice().unwrap() {
            assert!(seen.insert(v), "duplicate value {v}");
        }
    }

    #[test]
    fn choice_bad_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3]);
        // Wrong length
        assert!(rng.choice(&arr, 1, true, Some(&[0.5, 0.5])).is_err());
        // Doesn't sum to 1
        assert!(rng.choice(&arr, 1, true, Some(&[0.5, 0.5, 0.5])).is_err());
        // Negative
        assert!(rng.choice(&arr, 1, true, Some(&[-0.1, 0.6, 0.5])).is_err());
    }

    #[test]
    fn permuted_1d() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let result = rng.permuted(&arr, 0).unwrap();
        assert_eq!(sorted_values(&result), vec![1, 2, 3, 4, 5]);
    }
}