Skip to main content

ferray_random/
permutations.rs

1// ferray-random: Permutations and sampling — shuffle, permutation, permuted, choice
2
3use ferray_core::{Array, FerrayError, Ix1};
4
5use crate::bitgen::BitGenerator;
6use crate::generator::Generator;
7
8impl<B: BitGenerator> Generator<B> {
9    /// Shuffle a 1-D array in-place using Fisher-Yates.
10    ///
11    /// # Errors
12    /// Returns `FerrayError::InvalidValue` if the array is not contiguous.
13    pub fn shuffle<T>(&mut self, arr: &mut Array<T, Ix1>) -> Result<(), FerrayError>
14    where
15        T: ferray_core::Element,
16    {
17        let n = arr.shape()[0];
18        if n <= 1 {
19            return Ok(());
20        }
21        let slice = arr
22            .as_slice_mut()
23            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for shuffle"))?;
24        // Fisher-Yates
25        for i in (1..n).rev() {
26            let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
27            slice.swap(i, j);
28        }
29        Ok(())
30    }
31
32    /// Return a new array with elements randomly permuted.
33    ///
34    /// If the input is 1-D, returns a shuffled copy. If an integer `n` is
35    /// given (via `permutation_range`), returns a permutation of `0..n`.
36    ///
37    /// # Errors
38    /// Returns `FerrayError::InvalidValue` if the array is empty.
39    pub fn permutation<T>(&mut self, arr: &Array<T, Ix1>) -> Result<Array<T, Ix1>, FerrayError>
40    where
41        T: ferray_core::Element,
42    {
43        let mut copy = arr.clone();
44        self.shuffle(&mut copy)?;
45        Ok(copy)
46    }
47
48    /// Return a permutation of `0..n` as an `Array1<i64>`.
49    ///
50    /// # Errors
51    /// Returns `FerrayError::InvalidValue` if `n` is zero.
52    pub fn permutation_range(&mut self, n: usize) -> Result<Array<i64, Ix1>, FerrayError> {
53        if n == 0 {
54            return Err(FerrayError::invalid_value("n must be > 0"));
55        }
56        let mut data: Vec<i64> = (0..n as i64).collect();
57        // Fisher-Yates
58        for i in (1..n).rev() {
59            let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
60            data.swap(i, j);
61        }
62        Array::<i64, Ix1>::from_vec(Ix1::new([n]), data)
63    }
64
65    /// Return an array with elements independently permuted along the given axis.
66    ///
67    /// For 1-D arrays, this is the same as `permutation`. This simplified
68    /// implementation operates on 1-D arrays along axis 0.
69    ///
70    /// # Errors
71    /// Returns `FerrayError::InvalidValue` if the array is empty.
72    pub fn permuted<T>(
73        &mut self,
74        arr: &Array<T, Ix1>,
75        _axis: usize,
76    ) -> Result<Array<T, Ix1>, FerrayError>
77    where
78        T: ferray_core::Element,
79    {
80        self.permutation(arr)
81    }
82
83    /// Randomly select elements from an array, with or without replacement.
84    ///
85    /// # Arguments
86    /// * `arr` - Source array to sample from.
87    /// * `size` - Number of elements to select.
88    /// * `replace` - If `true`, sample with replacement; if `false`, without.
89    /// * `p` - Optional probability weights (must sum to 1.0 and have same length as `arr`).
90    ///
91    /// # Errors
92    /// Returns `FerrayError::InvalidValue` if parameters are invalid (e.g.,
93    /// `size > arr.len()` when `replace=false`, or invalid probability weights).
94    pub fn choice<T>(
95        &mut self,
96        arr: &Array<T, Ix1>,
97        size: usize,
98        replace: bool,
99        p: Option<&[f64]>,
100    ) -> Result<Array<T, Ix1>, FerrayError>
101    where
102        T: ferray_core::Element,
103    {
104        let n = arr.shape()[0];
105        if size == 0 {
106            return Err(FerrayError::invalid_value("size must be > 0"));
107        }
108        if n == 0 {
109            return Err(FerrayError::invalid_value("source array must be non-empty"));
110        }
111        if !replace && size > n {
112            return Err(FerrayError::invalid_value(format!(
113                "cannot choose {size} elements without replacement from array of size {n}"
114            )));
115        }
116
117        if let Some(probs) = p {
118            if probs.len() != n {
119                return Err(FerrayError::invalid_value(format!(
120                    "p must have same length as array ({n}), got {}",
121                    probs.len()
122                )));
123            }
124            let psum: f64 = probs.iter().sum();
125            if (psum - 1.0).abs() > 1e-6 {
126                return Err(FerrayError::invalid_value(format!(
127                    "p must sum to 1.0, got {psum}"
128                )));
129            }
130            for (i, &pi) in probs.iter().enumerate() {
131                if pi < 0.0 {
132                    return Err(FerrayError::invalid_value(format!(
133                        "p[{i}] = {pi} is negative"
134                    )));
135                }
136            }
137        }
138
139        let src = arr
140            .as_slice()
141            .ok_or_else(|| FerrayError::invalid_value("array must be contiguous"))?;
142
143        let indices = if let Some(probs) = p {
144            // Weighted sampling
145            if replace {
146                weighted_sample_with_replacement(&mut self.bg, probs, size)
147            } else {
148                weighted_sample_without_replacement(&mut self.bg, probs, size)?
149            }
150        } else if replace {
151            // Uniform with replacement
152            (0..size)
153                .map(|_| self.bg.next_u64_bounded(n as u64) as usize)
154                .collect()
155        } else {
156            // Uniform without replacement: partial Fisher-Yates
157            sample_without_replacement(&mut self.bg, n, size)
158        };
159
160        let data: Vec<T> = indices.iter().map(|&i| src[i].clone()).collect();
161        Array::<T, Ix1>::from_vec(Ix1::new([size]), data)
162    }
163}
164
165/// Sample `size` indices from `[0, n)` without replacement using partial Fisher-Yates.
166fn sample_without_replacement<B: BitGenerator>(bg: &mut B, n: usize, size: usize) -> Vec<usize> {
167    let mut pool: Vec<usize> = (0..n).collect();
168    for i in 0..size {
169        let j = i + bg.next_u64_bounded((n - i) as u64) as usize;
170        pool.swap(i, j);
171    }
172    pool[..size].to_vec()
173}
174
175/// Weighted sampling with replacement using the inverse CDF method.
176fn weighted_sample_with_replacement<B: BitGenerator>(
177    bg: &mut B,
178    probs: &[f64],
179    size: usize,
180) -> Vec<usize> {
181    // Build cumulative distribution
182    let mut cdf = Vec::with_capacity(probs.len());
183    let mut cumsum = 0.0;
184    for &p in probs {
185        cumsum += p;
186        cdf.push(cumsum);
187    }
188
189    (0..size)
190        .map(|_| {
191            let u = bg.next_f64();
192            // Binary search in CDF
193            match cdf.binary_search_by(|c| c.partial_cmp(&u).unwrap_or(std::cmp::Ordering::Equal)) {
194                Ok(i) => i,
195                Err(i) => i.min(probs.len() - 1),
196            }
197        })
198        .collect()
199}
200
201/// Weighted sampling without replacement using a sequential elimination method.
202fn weighted_sample_without_replacement<B: BitGenerator>(
203    bg: &mut B,
204    probs: &[f64],
205    size: usize,
206) -> Result<Vec<usize>, FerrayError> {
207    let n = probs.len();
208    let mut weights: Vec<f64> = probs.to_vec();
209    let mut selected = Vec::with_capacity(size);
210
211    for _ in 0..size {
212        let total: f64 = weights.iter().sum();
213        if total <= 0.0 {
214            return Err(FerrayError::invalid_value(
215                "insufficient probability mass for sampling without replacement",
216            ));
217        }
218        let u = bg.next_f64() * total;
219        let mut cumsum = 0.0;
220        let mut chosen = n - 1;
221        for (i, &w) in weights.iter().enumerate() {
222            cumsum += w;
223            if cumsum > u {
224                chosen = i;
225                break;
226            }
227        }
228        selected.push(chosen);
229        weights[chosen] = 0.0;
230    }
231
232    Ok(selected)
233}
234
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;
    use ferray_core::{Array, Ix1};

    /// Build a 1-D `i64` array holding `values`.
    fn array_of(values: Vec<i64>) -> Array<i64, Ix1> {
        Array::<i64, Ix1>::from_vec(Ix1::new([values.len()]), values).unwrap()
    }

    /// Copy the array's elements into a vector sorted in ascending order.
    fn sorted_values(arr: &Array<i64, Ix1>) -> Vec<i64> {
        let mut values = arr.as_slice().unwrap().to_vec();
        values.sort();
        values
    }

    /// Shuffling must keep exactly the same multiset of elements.
    #[test]
    fn shuffle_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let mut arr = array_of(vec![1, 2, 3, 4, 5]);
        rng.shuffle(&mut arr).unwrap();
        assert_eq!(sorted_values(&arr), vec![1, 2, 3, 4, 5]);
    }

    /// `permutation` returns a reordering of the input's elements.
    #[test]
    fn permutation_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![10, 20, 30, 40, 50]);
        let perm = rng.permutation(&arr).unwrap();
        assert_eq!(sorted_values(&perm), vec![10, 20, 30, 40, 50]);
    }

    /// `permutation_range(10)` contains each of `0..10` exactly once.
    #[test]
    fn permutation_range_covers_all() {
        let mut rng = default_rng_seeded(42);
        let perm = rng.permutation_range(10).unwrap();
        let expected: Vec<i64> = (0..10).collect();
        assert_eq!(sorted_values(&perm), expected);
    }

    /// The array shuffled in place is still a valid permutation of the input.
    #[test]
    fn shuffle_modifies_in_place() {
        let mut rng = default_rng_seeded(42);
        let original = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut arr = array_of(original.clone());
        rng.shuffle(&mut arr).unwrap();
        // Only the permutation property is checked here; the test does not
        // assert that the order actually changed.
        assert_eq!(sorted_values(&arr), original);
    }

    /// Sampling with replacement may exceed the source length and only ever
    /// returns values present in the source.
    #[test]
    fn choice_with_replacement() {
        let mut rng = default_rng_seeded(42);
        let src: Vec<i64> = vec![10, 20, 30, 40, 50];
        let arr = array_of(src.clone());
        let chosen = rng.choice(&arr, 10, true, None).unwrap();
        assert_eq!(chosen.shape(), &[10]);
        for v in chosen.as_slice().unwrap().iter().copied() {
            assert!(src.contains(&v), "choice returned unexpected value {v}");
        }
    }

    /// Uniform sampling without replacement never repeats a value.
    #[test]
    fn choice_without_replacement_no_duplicates() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of((0..10).collect());
        let chosen = rng.choice(&arr, 5, false, None).unwrap();
        let mut seen = std::collections::HashSet::new();
        for v in chosen.as_slice().unwrap().iter().copied() {
            assert!(
                seen.insert(v),
                "duplicate value {v} in choice without replacement"
            );
        }
    }

    /// Asking for more elements than exist without replacement is an error.
    #[test]
    fn choice_without_replacement_too_many() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let outcome = rng.choice(&arr, 10, false, None);
        assert!(outcome.is_err());
    }

    /// With all probability mass on the last element, only it is drawn.
    #[test]
    fn choice_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![10, 20, 30]);
        let p = [0.0, 0.0, 1.0];
        let chosen = rng.choice(&arr, 10, true, Some(&p)).unwrap();
        for v in chosen.as_slice().unwrap().iter().copied() {
            assert_eq!(v, 30);
        }
    }

    /// Weighted sampling without replacement never repeats a value.
    #[test]
    fn choice_without_replacement_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let p = [0.1, 0.2, 0.3, 0.2, 0.2];
        let chosen = rng.choice(&arr, 3, false, Some(&p)).unwrap();
        let mut seen = std::collections::HashSet::new();
        for v in chosen.as_slice().unwrap().iter().copied() {
            assert!(seen.insert(v), "duplicate value {v}");
        }
    }

    /// Invalid weight vectors are rejected.
    #[test]
    fn choice_bad_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3]);

        // Length differs from the source array.
        let wrong_length = rng.choice(&arr, 1, true, Some(&[0.5, 0.5]));
        assert!(wrong_length.is_err());

        // Sum is 1.5 rather than 1.0.
        let wrong_sum = rng.choice(&arr, 1, true, Some(&[0.5, 0.5, 0.5]));
        assert!(wrong_sum.is_err());

        // Sums to 1.0 but contains a negative entry.
        let negative = rng.choice(&arr, 1, true, Some(&[-0.1, 0.6, 0.5]));
        assert!(negative.is_err());
    }

    /// `permuted` along axis 0 returns a reordering of the input's elements.
    #[test]
    fn permuted_1d() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let result = rng.permuted(&arr, 0).unwrap();
        assert_eq!(sorted_values(&result), vec![1, 2, 3, 4, 5]);
    }
}