1use ferray_core::{Array, FerrayError, Ix1};
4
5use crate::bitgen::BitGenerator;
6use crate::generator::Generator;
7
8impl<B: BitGenerator> Generator<B> {
9 pub fn shuffle<T>(&mut self, arr: &mut Array<T, Ix1>) -> Result<(), FerrayError>
14 where
15 T: ferray_core::Element,
16 {
17 let n = arr.shape()[0];
18 if n <= 1 {
19 return Ok(());
20 }
21 let slice = arr
22 .as_slice_mut()
23 .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for shuffle"))?;
24 for i in (1..n).rev() {
26 let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
27 slice.swap(i, j);
28 }
29 Ok(())
30 }
31
32 pub fn permutation<T>(&mut self, arr: &Array<T, Ix1>) -> Result<Array<T, Ix1>, FerrayError>
40 where
41 T: ferray_core::Element,
42 {
43 let mut copy = arr.clone();
44 self.shuffle(&mut copy)?;
45 Ok(copy)
46 }
47
48 pub fn permutation_range(&mut self, n: usize) -> Result<Array<i64, Ix1>, FerrayError> {
53 if n == 0 {
54 return Err(FerrayError::invalid_value("n must be > 0"));
55 }
56 let mut data: Vec<i64> = (0..n as i64).collect();
57 for i in (1..n).rev() {
59 let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
60 data.swap(i, j);
61 }
62 Array::<i64, Ix1>::from_vec(Ix1::new([n]), data)
63 }
64
65 pub fn permuted<T>(
73 &mut self,
74 arr: &Array<T, Ix1>,
75 _axis: usize,
76 ) -> Result<Array<T, Ix1>, FerrayError>
77 where
78 T: ferray_core::Element,
79 {
80 self.permutation(arr)
81 }
82
83 pub fn choice<T>(
95 &mut self,
96 arr: &Array<T, Ix1>,
97 size: usize,
98 replace: bool,
99 p: Option<&[f64]>,
100 ) -> Result<Array<T, Ix1>, FerrayError>
101 where
102 T: ferray_core::Element,
103 {
104 let n = arr.shape()[0];
105 if size == 0 {
106 return Err(FerrayError::invalid_value("size must be > 0"));
107 }
108 if n == 0 {
109 return Err(FerrayError::invalid_value("source array must be non-empty"));
110 }
111 if !replace && size > n {
112 return Err(FerrayError::invalid_value(format!(
113 "cannot choose {size} elements without replacement from array of size {n}"
114 )));
115 }
116
117 if let Some(probs) = p {
118 if probs.len() != n {
119 return Err(FerrayError::invalid_value(format!(
120 "p must have same length as array ({n}), got {}",
121 probs.len()
122 )));
123 }
124 let psum: f64 = probs.iter().sum();
125 if (psum - 1.0).abs() > 1e-6 {
126 return Err(FerrayError::invalid_value(format!(
127 "p must sum to 1.0, got {psum}"
128 )));
129 }
130 for (i, &pi) in probs.iter().enumerate() {
131 if pi < 0.0 {
132 return Err(FerrayError::invalid_value(format!(
133 "p[{i}] = {pi} is negative"
134 )));
135 }
136 }
137 }
138
139 let src = arr
140 .as_slice()
141 .ok_or_else(|| FerrayError::invalid_value("array must be contiguous"))?;
142
143 let indices = if let Some(probs) = p {
144 if replace {
146 weighted_sample_with_replacement(&mut self.bg, probs, size)
147 } else {
148 weighted_sample_without_replacement(&mut self.bg, probs, size)?
149 }
150 } else if replace {
151 (0..size)
153 .map(|_| self.bg.next_u64_bounded(n as u64) as usize)
154 .collect()
155 } else {
156 sample_without_replacement(&mut self.bg, n, size)
158 };
159
160 let data: Vec<T> = indices.iter().map(|&i| src[i].clone()).collect();
161 Array::<T, Ix1>::from_vec(Ix1::new([size]), data)
162 }
163}
164
165fn sample_without_replacement<B: BitGenerator>(bg: &mut B, n: usize, size: usize) -> Vec<usize> {
167 let mut pool: Vec<usize> = (0..n).collect();
168 for i in 0..size {
169 let j = i + bg.next_u64_bounded((n - i) as u64) as usize;
170 pool.swap(i, j);
171 }
172 pool[..size].to_vec()
173}
174
175fn weighted_sample_with_replacement<B: BitGenerator>(
177 bg: &mut B,
178 probs: &[f64],
179 size: usize,
180) -> Vec<usize> {
181 let mut cdf = Vec::with_capacity(probs.len());
183 let mut cumsum = 0.0;
184 for &p in probs {
185 cumsum += p;
186 cdf.push(cumsum);
187 }
188
189 (0..size)
190 .map(|_| {
191 let u = bg.next_f64();
192 match cdf.binary_search_by(|c| c.partial_cmp(&u).unwrap_or(std::cmp::Ordering::Equal)) {
194 Ok(i) => i,
195 Err(i) => i.min(probs.len() - 1),
196 }
197 })
198 .collect()
199}
200
201fn weighted_sample_without_replacement<B: BitGenerator>(
203 bg: &mut B,
204 probs: &[f64],
205 size: usize,
206) -> Result<Vec<usize>, FerrayError> {
207 let n = probs.len();
208 let mut weights: Vec<f64> = probs.to_vec();
209 let mut selected = Vec::with_capacity(size);
210
211 for _ in 0..size {
212 let total: f64 = weights.iter().sum();
213 if total <= 0.0 {
214 return Err(FerrayError::invalid_value(
215 "insufficient probability mass for sampling without replacement",
216 ));
217 }
218 let u = bg.next_f64() * total;
219 let mut cumsum = 0.0;
220 let mut chosen = n - 1;
221 for (i, &w) in weights.iter().enumerate() {
222 cumsum += w;
223 if cumsum > u {
224 chosen = i;
225 break;
226 }
227 }
228 selected.push(chosen);
229 weights[chosen] = 0.0;
230 }
231
232 Ok(selected)
233}
234
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;
    use ferray_core::{Array, Ix1};

    /// Build a 1-D `i64` array that holds exactly `values`.
    fn array_of(values: Vec<i64>) -> Array<i64, Ix1> {
        let len = values.len();
        Array::<i64, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Copy the array's elements out in ascending order.
    fn sorted_values(arr: &Array<i64, Ix1>) -> Vec<i64> {
        let mut values = arr.as_slice().unwrap().to_vec();
        values.sort();
        values
    }

    // Shuffling must keep the multiset of elements intact.
    #[test]
    fn shuffle_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let mut arr = array_of(vec![1, 2, 3, 4, 5]);
        rng.shuffle(&mut arr).unwrap();
        assert_eq!(sorted_values(&arr), vec![1, 2, 3, 4, 5]);
    }

    // A permutation is a reordering of the input's elements.
    #[test]
    fn permutation_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![10, 20, 30, 40, 50]);
        let perm = rng.permutation(&arr).unwrap();
        assert_eq!(sorted_values(&perm), vec![10, 20, 30, 40, 50]);
    }

    // permutation_range(n) must contain every integer in 0..n exactly once.
    #[test]
    fn permutation_range_covers_all() {
        let mut rng = default_rng_seeded(42);
        let perm = rng.permutation_range(10).unwrap();
        let expected: Vec<i64> = (0..10).collect();
        assert_eq!(sorted_values(&perm), expected);
    }

    // After an in-place shuffle the array still holds the original elements.
    #[test]
    fn shuffle_modifies_in_place() {
        let mut rng = default_rng_seeded(42);
        let original = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut arr = array_of(original.clone());
        rng.shuffle(&mut arr).unwrap();
        assert_eq!(sorted_values(&arr), original);
    }

    // Sampling with replacement may exceed the source length and only yields
    // values that exist in the source.
    #[test]
    fn choice_with_replacement() {
        let mut rng = default_rng_seeded(42);
        let src: Vec<i64> = vec![10, 20, 30, 40, 50];
        let arr = array_of(src.clone());
        let chosen = rng.choice(&arr, 10, true, None).unwrap();
        assert_eq!(chosen.shape(), &[10]);
        for v in chosen.as_slice().unwrap().iter().copied() {
            assert!(src.contains(&v), "choice returned unexpected value {v}");
        }
    }

    // Sampling without replacement never repeats a value.
    #[test]
    fn choice_without_replacement_no_duplicates() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of((0..10).collect());
        let chosen = rng.choice(&arr, 5, false, None).unwrap();
        let mut seen = std::collections::HashSet::new();
        for v in chosen.as_slice().unwrap().iter().copied() {
            assert!(
                seen.insert(v),
                "duplicate value {v} in choice without replacement"
            );
        }
    }

    // Requesting more elements than available without replacement is an error.
    #[test]
    fn choice_without_replacement_too_many() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let outcome = rng.choice(&arr, 10, false, None);
        assert!(outcome.is_err());
    }

    // With all probability on the last element, only that element is drawn.
    #[test]
    fn choice_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![10, 20, 30]);
        let p = [0.0, 0.0, 1.0];
        let chosen = rng.choice(&arr, 10, true, Some(&p)).unwrap();
        for v in chosen.as_slice().unwrap().iter().copied() {
            assert_eq!(v, 30);
        }
    }

    // Weighted sampling without replacement never repeats a value.
    #[test]
    fn choice_without_replacement_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let p = [0.1, 0.2, 0.3, 0.2, 0.2];
        let chosen = rng.choice(&arr, 3, false, Some(&p)).unwrap();
        let mut seen = std::collections::HashSet::new();
        for v in chosen.as_slice().unwrap().iter().copied() {
            assert!(seen.insert(v), "duplicate value {v}");
        }
    }

    // Wrong length, wrong sum, and negative entries must each be rejected.
    #[test]
    fn choice_bad_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3]);
        assert!(rng.choice(&arr, 1, true, Some(&[0.5, 0.5])).is_err());
        assert!(rng.choice(&arr, 1, true, Some(&[0.5, 0.5, 0.5])).is_err());
        assert!(rng.choice(&arr, 1, true, Some(&[-0.1, 0.6, 0.5])).is_err());
    }

    // permuted on a 1-D array returns a reordering of its elements.
    #[test]
    fn permuted_1d() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let result = rng.permuted(&arr, 0).unwrap();
        assert_eq!(sorted_values(&result), vec![1, 2, 3, 4, 5]);
    }
}