1use rand::Rng;
18use rand::seq::SliceRandom;
19
/// Source of dataset indices that can be iterated any number of times.
///
/// The `Send + Sync` supertraits require every implementor to be shareable
/// across threads.
pub trait Sampler: Send + Sync {
    /// Reports how many indices the sampler is configured to produce.
    fn len(&self) -> usize;

    /// Returns `true` when [`Sampler::len`] reports zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Starts a fresh pass and returns a boxed iterator over its indices.
    ///
    /// The returned iterator is allowed to borrow from `self`.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;
}
39
/// Sampler over the index range `0..len`.
#[derive(Debug, Clone)]
pub struct SequentialSampler {
    /// Exclusive upper bound of the index range.
    len: usize,
}

impl SequentialSampler {
    /// Creates a sampler covering the indices `0..len`.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self { len }
    }
}
56
57impl Sampler for SequentialSampler {
58 fn len(&self) -> usize {
59 self.len
60 }
61
62 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
63 Box::new(0..self.len)
64 }
65}
66
/// Sampler that emits indices from `0..len` in random order.
///
/// Two modes exist: a permutation of every index (see [`RandomSampler::new`])
/// or a fixed number of independent draws that may repeat
/// (see [`RandomSampler::with_replacement`]).
#[derive(Debug, Clone)]
pub struct RandomSampler {
    /// Size of the population; indices are taken from `0..len`.
    len: usize,
    /// `true` when indices are drawn independently and may repeat.
    replacement: bool,
    /// Number of draws per pass; `None` means "use `len`".
    num_samples: Option<usize>,
}

impl RandomSampler {
    /// Creates a sampler without replacement over `0..len`.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self {
            len,
            replacement: false,
            num_samples: None,
        }
    }

    /// Creates a sampler that makes `num_samples` draws from `0..len`,
    /// allowing the same index to appear more than once.
    #[must_use]
    pub fn with_replacement(len: usize, num_samples: usize) -> Self {
        Self {
            len,
            replacement: true,
            num_samples: Some(num_samples),
        }
    }
}
99
100impl Sampler for RandomSampler {
101 fn len(&self) -> usize {
102 self.num_samples.unwrap_or(self.len)
103 }
104
105 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
106 if self.replacement {
107 let len = self.len;
109 let num = self.num_samples.unwrap_or(len);
110 Box::new(RandomWithReplacementIter::new(len, num))
111 } else {
112 let mut indices: Vec<usize> = (0..self.len).collect();
114 indices.shuffle(&mut rand::thread_rng());
115 Box::new(indices.into_iter())
116 }
117 }
118}
119
/// Iterator state for drawing random indices with replacement.
struct RandomWithReplacementIter {
    /// Exclusive upper bound of the indices that may be drawn.
    len: usize,
    /// Number of draws that have not been emitted yet.
    remaining: usize,
}

impl RandomWithReplacementIter {
    /// Builds an iterator that will make `num_samples` draws from `0..len`.
    fn new(len: usize, num_samples: usize) -> Self {
        let remaining = num_samples;
        RandomWithReplacementIter { len, remaining }
    }
}
134
135impl Iterator for RandomWithReplacementIter {
136 type Item = usize;
137
138 fn next(&mut self) -> Option<Self::Item> {
139 if self.remaining == 0 {
140 return None;
141 }
142 self.remaining -= 1;
143 Some(rand::thread_rng().gen_range(0..self.len))
144 }
145}
146
/// Sampler that emits a caller-supplied list of indices in random order.
#[derive(Debug, Clone)]
pub struct SubsetRandomSampler {
    /// The indices to emit; duplicates, if any, are kept as given.
    indices: Vec<usize>,
}

impl SubsetRandomSampler {
    /// Creates a sampler over exactly the given `indices`.
    #[must_use]
    pub fn new(indices: Vec<usize>) -> Self {
        Self { indices }
    }
}
163
164impl Sampler for SubsetRandomSampler {
165 fn len(&self) -> usize {
166 self.indices.len()
167 }
168
169 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
170 let mut shuffled = self.indices.clone();
171 shuffled.shuffle(&mut rand::thread_rng());
172 Box::new(shuffled.into_iter())
173 }
174}
175
176pub struct WeightedRandomSampler {
185 weights: Vec<f64>,
186 cumulative: Vec<f64>,
188 num_samples: usize,
189 replacement: bool,
190}
191
192impl WeightedRandomSampler {
193 #[must_use]
195 pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
196 let cumulative = Self::build_cumulative(&weights);
197 Self {
198 weights,
199 cumulative,
200 num_samples,
201 replacement,
202 }
203 }
204
205 fn build_cumulative(weights: &[f64]) -> Vec<f64> {
207 let mut cum = Vec::with_capacity(weights.len());
208 let mut total = 0.0;
209 for &w in weights {
210 total += w;
211 cum.push(total);
212 }
213 cum
214 }
215
216 fn sample_index(&self) -> usize {
218 let total = *self.cumulative.last().unwrap_or(&1.0);
219 let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
220
221 match self.cumulative.binary_search_by(|c| c.partial_cmp(&threshold).unwrap()) {
223 Ok(i) => i,
224 Err(i) => i.min(self.cumulative.len() - 1),
225 }
226 }
227}
228
229impl Sampler for WeightedRandomSampler {
230 fn len(&self) -> usize {
231 self.num_samples
232 }
233
234 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
235 if self.replacement {
236 Box::new(WeightedIter::new(self))
237 } else {
238 let mut indices = Vec::with_capacity(self.num_samples);
240 let mut available: Vec<usize> = (0..self.weights.len()).collect();
241 let mut weights = self.weights.clone();
242 let mut cumulative = self.cumulative.clone();
243
244 while indices.len() < self.num_samples && !available.is_empty() {
245 let total = *cumulative.last().unwrap_or(&0.0);
246 if total <= 0.0 {
247 break;
248 }
249
250 let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
251 let selected = match cumulative.binary_search_by(|c| {
252 c.partial_cmp(&threshold).unwrap()
253 }) {
254 Ok(i) => i,
255 Err(i) => i.min(cumulative.len() - 1),
256 };
257
258 indices.push(available[selected]);
259
260 available.swap_remove(selected);
262 weights.swap_remove(selected);
263
264 cumulative = Self::build_cumulative(&weights);
266 }
267
268 Box::new(indices.into_iter())
269 }
270 }
271}
272
273struct WeightedIter<'a> {
275 sampler: &'a WeightedRandomSampler,
276 remaining: usize,
277}
278
279impl<'a> WeightedIter<'a> {
280 fn new(sampler: &'a WeightedRandomSampler) -> Self {
281 Self {
282 sampler,
283 remaining: sampler.num_samples,
284 }
285 }
286}
287
288impl Iterator for WeightedIter<'_> {
289 type Item = usize;
290
291 fn next(&mut self) -> Option<Self::Item> {
292 if self.remaining == 0 {
293 return None;
294 }
295 self.remaining -= 1;
296 Some(self.sampler.sample_index())
297 }
298}
299
300pub struct BatchSampler<S: Sampler> {
306 sampler: S,
307 batch_size: usize,
308 drop_last: bool,
309}
310
311impl<S: Sampler> BatchSampler<S> {
312 pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
314 Self {
315 sampler,
316 batch_size,
317 drop_last,
318 }
319 }
320
321 pub fn iter(&self) -> BatchIter {
323 let indices: Vec<usize> = self.sampler.iter().collect();
324 BatchIter {
325 indices,
326 batch_size: self.batch_size,
327 drop_last: self.drop_last,
328 position: 0,
329 }
330 }
331
332 pub fn len(&self) -> usize {
334 let total = self.sampler.len();
335 if self.drop_last {
336 total / self.batch_size
337 } else {
338 total.div_ceil(self.batch_size)
339 }
340 }
341
342 pub fn is_empty(&self) -> bool {
344 self.len() == 0
345 }
346}
347
/// Iterator that hands out a pre-collected list of indices in batches.
#[derive(Debug, Clone)]
pub struct BatchIter {
    /// All indices of the pass, in emission order.
    indices: Vec<usize>,
    /// Number of indices per batch; `0` produces no batches.
    batch_size: usize,
    /// When `true`, a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
    /// Offset into `indices` of the first index not yet emitted.
    position: usize,
}

impl Iterator for BatchIter {
    type Item = Vec<usize>;

    /// Returns the next batch of up to `batch_size` indices.
    ///
    /// Iteration ends when the indices are exhausted, when the remaining tail
    /// is shorter than `batch_size` and `drop_last` is set, or immediately if
    /// `batch_size` is zero.
    fn next(&mut self) -> Option<Self::Item> {
        // A zero batch size would otherwise emit empty batches forever,
        // because `position` could never advance.
        if self.batch_size == 0 {
            return None;
        }

        let rest = self.indices.get(self.position..)?;
        if rest.is_empty() {
            return None;
        }

        // Size the batch from the remaining length rather than from
        // `position + batch_size`, which overflows for very large batch sizes.
        let take = rest.len().min(self.batch_size);
        if take < self.batch_size && self.drop_last {
            return None;
        }

        let batch = rest[..take].to_vec();
        self.position += take;
        Some(batch)
    }
}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects one full pass of `sampler` into a vector.
    fn collect_pass<S: Sampler>(sampler: &S) -> Vec<usize> {
        sampler.iter().collect()
    }

    /// A sequential sampler yields `0..len` in ascending order.
    #[test]
    fn test_sequential_sampler() {
        let produced = collect_pass(&SequentialSampler::new(5));
        assert_eq!(produced, (0..5).collect::<Vec<usize>>());
    }

    /// Without replacement, every index appears exactly once.
    #[test]
    fn test_random_sampler() {
        let produced = collect_pass(&RandomSampler::new(10));
        assert_eq!(produced.len(), 10);

        let distinct: std::collections::BTreeSet<usize> = produced.iter().copied().collect();
        assert_eq!(distinct.len(), 10);
    }

    /// With replacement, the requested count is produced and stays in range.
    #[test]
    fn test_random_sampler_with_replacement() {
        let produced = collect_pass(&RandomSampler::with_replacement(5, 20));
        assert_eq!(produced.len(), 20);
        assert!(!produced.iter().any(|&i| i >= 5));
    }

    /// A subset sampler yields exactly the supplied indices, in some order.
    #[test]
    fn test_subset_random_sampler() {
        let mut produced = collect_pass(&SubsetRandomSampler::new(vec![0, 5, 10, 15]));
        assert_eq!(produced.len(), 4);

        produced.sort_unstable();
        assert_eq!(produced, vec![0, 5, 10, 15]);
    }

    /// A dominant weight should account for most of the draws.
    #[test]
    fn test_weighted_random_sampler() {
        let weights = vec![100.0, 1.0, 1.0, 1.0];
        let produced = collect_pass(&WeightedRandomSampler::new(weights, 100, true));
        assert_eq!(produced.len(), 100);

        let zeros = produced.iter().filter(|&&i| i == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    /// With `drop_last` off, the short trailing batch is kept.
    #[test]
    fn test_batch_sampler() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, false);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();

        let expected = vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8], vec![9]];
        assert_eq!(batches, expected);
    }

    /// With `drop_last` on, the short trailing batch is discarded.
    #[test]
    fn test_batch_sampler_drop_last() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, true);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();

        let expected = vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]];
        assert_eq!(batches, expected);
    }
}