1use ferray_core::{Array, FerrayError, Ix1};
4
5use crate::bitgen::BitGenerator;
6use crate::generator::Generator;
7
8impl<B: BitGenerator> Generator<B> {
9 pub fn shuffle<T>(&mut self, arr: &mut Array<T, Ix1>) -> Result<(), FerrayError>
14 where
15 T: ferray_core::Element,
16 {
17 let n = arr.shape()[0];
18 if n <= 1 {
19 return Ok(());
20 }
21 let slice = arr
22 .as_slice_mut()
23 .ok_or_else(|| FerrayError::invalid_value("array must be contiguous for shuffle"))?;
24 for i in (1..n).rev() {
26 let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
27 slice.swap(i, j);
28 }
29 Ok(())
30 }
31
32 pub fn permutation<T>(&mut self, arr: &Array<T, Ix1>) -> Result<Array<T, Ix1>, FerrayError>
40 where
41 T: ferray_core::Element,
42 {
43 let mut copy = arr.clone();
44 self.shuffle(&mut copy)?;
45 Ok(copy)
46 }
47
48 pub fn permutation_range(&mut self, n: usize) -> Result<Array<i64, Ix1>, FerrayError> {
53 if n == 0 {
54 return Err(FerrayError::invalid_value("n must be > 0"));
55 }
56 let mut data: Vec<i64> = (0..n as i64).collect();
57 for i in (1..n).rev() {
59 let j = self.bg.next_u64_bounded((i + 1) as u64) as usize;
60 data.swap(i, j);
61 }
62 Array::<i64, Ix1>::from_vec(Ix1::new([n]), data)
63 }
64
65 pub fn permuted<T>(
73 &mut self,
74 arr: &Array<T, Ix1>,
75 _axis: usize,
76 ) -> Result<Array<T, Ix1>, FerrayError>
77 where
78 T: ferray_core::Element,
79 {
80 self.permutation(arr)
81 }
82
83 pub fn choice<T>(
95 &mut self,
96 arr: &Array<T, Ix1>,
97 size: usize,
98 replace: bool,
99 p: Option<&[f64]>,
100 ) -> Result<Array<T, Ix1>, FerrayError>
101 where
102 T: ferray_core::Element,
103 {
104 let n = arr.shape()[0];
105 if size == 0 {
109 return Array::from_vec(Ix1::new([0]), Vec::new());
110 }
111 if n == 0 {
112 return Err(FerrayError::invalid_value("source array must be non-empty"));
113 }
114 if !replace && size > n {
115 return Err(FerrayError::invalid_value(format!(
116 "cannot choose {size} elements without replacement from array of size {n}"
117 )));
118 }
119
120 if let Some(probs) = p {
121 if probs.len() != n {
122 return Err(FerrayError::invalid_value(format!(
123 "p must have same length as array ({n}), got {}",
124 probs.len()
125 )));
126 }
127 let psum: f64 = probs.iter().sum();
128 if (psum - 1.0).abs() > 1e-6 {
129 return Err(FerrayError::invalid_value(format!(
130 "p must sum to 1.0, got {psum}"
131 )));
132 }
133 for (i, &pi) in probs.iter().enumerate() {
134 if pi < 0.0 {
135 return Err(FerrayError::invalid_value(format!(
136 "p[{i}] = {pi} is negative"
137 )));
138 }
139 }
140 }
141
142 let src = arr
143 .as_slice()
144 .ok_or_else(|| FerrayError::invalid_value("array must be contiguous"))?;
145
146 let indices = if let Some(probs) = p {
147 if replace {
149 weighted_sample_with_replacement(&mut self.bg, probs, size)
150 } else {
151 weighted_sample_without_replacement(&mut self.bg, probs, size)?
152 }
153 } else if replace {
154 (0..size)
156 .map(|_| self.bg.next_u64_bounded(n as u64) as usize)
157 .collect()
158 } else {
159 sample_without_replacement(&mut self.bg, n, size)
161 };
162
163 let data: Vec<T> = indices.iter().map(|&i| src[i].clone()).collect();
164 Array::<T, Ix1>::from_vec(Ix1::new([size]), data)
165 }
166}
167
168fn sample_without_replacement<B: BitGenerator>(bg: &mut B, n: usize, size: usize) -> Vec<usize> {
170 let mut pool: Vec<usize> = (0..n).collect();
171 for i in 0..size {
172 let j = i + bg.next_u64_bounded((n - i) as u64) as usize;
173 pool.swap(i, j);
174 }
175 pool[..size].to_vec()
176}
177
178fn weighted_sample_with_replacement<B: BitGenerator>(
180 bg: &mut B,
181 probs: &[f64],
182 size: usize,
183) -> Vec<usize> {
184 let mut cdf = Vec::with_capacity(probs.len());
186 let mut cumsum = 0.0;
187 for &p in probs {
188 cumsum += p;
189 cdf.push(cumsum);
190 }
191
192 (0..size)
193 .map(|_| {
194 let u = bg.next_f64();
195 match cdf.binary_search_by(|c| c.partial_cmp(&u).unwrap_or(std::cmp::Ordering::Equal)) {
197 Ok(i) => i,
198 Err(i) => i.min(probs.len() - 1),
199 }
200 })
201 .collect()
202}
203
204fn weighted_sample_without_replacement<B: BitGenerator>(
206 bg: &mut B,
207 probs: &[f64],
208 size: usize,
209) -> Result<Vec<usize>, FerrayError> {
210 let n = probs.len();
211 let mut weights: Vec<f64> = probs.to_vec();
212 let mut selected = Vec::with_capacity(size);
213
214 for _ in 0..size {
215 let total: f64 = weights.iter().sum();
216 if total <= 0.0 {
217 return Err(FerrayError::invalid_value(
218 "insufficient probability mass for sampling without replacement",
219 ));
220 }
221 let u = bg.next_f64() * total;
222 let mut cumsum = 0.0;
223 let mut chosen = n - 1;
224 for (i, &w) in weights.iter().enumerate() {
225 cumsum += w;
226 if cumsum > u {
227 chosen = i;
228 break;
229 }
230 }
231 selected.push(chosen);
232 weights[chosen] = 0.0;
233 }
234
235 Ok(selected)
236}
237
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;
    use ferray_core::{Array, Ix1};

    /// Builds a 1-D `i64` array holding exactly `values`.
    fn array_of(values: Vec<i64>) -> Array<i64, Ix1> {
        let len = values.len();
        Array::<i64, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Returns the array's elements in ascending order.
    fn sorted_values(arr: &Array<i64, Ix1>) -> Vec<i64> {
        let mut values: Vec<i64> = arr.as_slice().unwrap().to_vec();
        values.sort_unstable();
        values
    }

    #[test]
    fn shuffle_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let mut arr = array_of(vec![1, 2, 3, 4, 5]);
        rng.shuffle(&mut arr).unwrap();
        assert_eq!(sorted_values(&arr), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn permutation_preserves_elements() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![10, 20, 30, 40, 50]);
        let perm = rng.permutation(&arr).unwrap();
        assert_eq!(sorted_values(&perm), vec![10, 20, 30, 40, 50]);
    }

    #[test]
    fn permutation_range_covers_all() {
        let mut rng = default_rng_seeded(42);
        let perm = rng.permutation_range(10).unwrap();
        let expected: Vec<i64> = (0..10).collect();
        assert_eq!(sorted_values(&perm), expected);
    }

    #[test]
    fn shuffle_modifies_in_place() {
        let mut rng = default_rng_seeded(42);
        let original: Vec<i64> = (1..=10).collect();
        let mut arr = array_of(original.clone());
        rng.shuffle(&mut arr).unwrap();
        // The array handed to `shuffle` must still hold the same multiset.
        assert_eq!(sorted_values(&arr), original);
    }

    #[test]
    fn choice_with_replacement() {
        let mut rng = default_rng_seeded(42);
        let src: Vec<i64> = vec![10, 20, 30, 40, 50];
        let arr = array_of(src.clone());
        let chosen = rng.choice(&arr, 10, true, None).unwrap();
        assert_eq!(chosen.shape(), &[10]);
        for &v in chosen.as_slice().unwrap() {
            assert!(src.contains(&v), "choice returned unexpected value {v}");
        }
    }

    #[test]
    fn choice_without_replacement_no_duplicates() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of((0..10).collect());
        let chosen = rng.choice(&arr, 5, false, None).unwrap();
        let mut seen = std::collections::HashSet::new();
        for &v in chosen.as_slice().unwrap() {
            assert!(
                seen.insert(v),
                "duplicate value {v} in choice without replacement"
            );
        }
    }

    #[test]
    fn choice_without_replacement_too_many() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let outcome = rng.choice(&arr, 10, false, None);
        assert!(outcome.is_err());
    }

    #[test]
    fn choice_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![10, 20, 30]);
        // All probability mass sits on the last element.
        let p = [0.0, 0.0, 1.0];
        let chosen = rng.choice(&arr, 10, true, Some(&p)).unwrap();
        for &v in chosen.as_slice().unwrap() {
            assert_eq!(v, 30);
        }
    }

    #[test]
    fn choice_without_replacement_with_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let p = [0.1, 0.2, 0.3, 0.2, 0.2];
        let chosen = rng.choice(&arr, 3, false, Some(&p)).unwrap();
        let mut seen = std::collections::HashSet::new();
        for &v in chosen.as_slice().unwrap() {
            assert!(seen.insert(v), "duplicate value {v}");
        }
    }

    #[test]
    fn choice_bad_weights() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3]);
        // Wrong length.
        let wrong_len = rng.choice(&arr, 1, true, Some(&[0.5, 0.5]));
        assert!(wrong_len.is_err());
        // Does not sum to one.
        let wrong_sum = rng.choice(&arr, 1, true, Some(&[0.5, 0.5, 0.5]));
        assert!(wrong_sum.is_err());
        // Contains a negative entry.
        let negative = rng.choice(&arr, 1, true, Some(&[-0.1, 0.6, 0.5]));
        assert!(negative.is_err());
    }

    #[test]
    fn permuted_1d() {
        let mut rng = default_rng_seeded(42);
        let arr = array_of(vec![1, 2, 3, 4, 5]);
        let result = rng.permuted(&arr, 0).unwrap();
        assert_eq!(sorted_values(&result), vec![1, 2, 3, 4, 5]);
    }
}