1use rand::Rng;
18use rand::seq::SliceRandom;
19
/// A source of `usize` indices.
///
/// Implementors must be `Send + Sync`. Each call to [`Sampler::iter`] hands
/// back a fresh boxed iterator that may borrow from the sampler.
pub trait Sampler: Send + Sync {
    /// Produces a boxed iterator over the indices of one pass.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;

    /// Reports how many indices the sampler is configured to produce.
    fn len(&self) -> usize;

    /// Returns `true` exactly when [`Sampler::len`] is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
39
/// Sampler over the index range `0..len`.
///
/// `Debug` and `Clone` are derived so the sampler can be logged and duplicated
/// like any other plain-data configuration value.
#[derive(Debug, Clone)]
pub struct SequentialSampler {
    /// Exclusive upper bound of the index range.
    len: usize,
}

impl SequentialSampler {
    /// Creates a sampler covering `len` items.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self { len }
    }
}
56
57impl Sampler for SequentialSampler {
58 fn len(&self) -> usize {
59 self.len
60 }
61
62 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
63 Box::new(0..self.len)
64 }
65}
66
/// Sampler that yields indices from `0..len` in random order.
///
/// Two modes exist:
/// * [`RandomSampler::new`] — a shuffled pass over every index, without
///   replacement.
/// * [`RandomSampler::with_replacement`] — a fixed number of draws where the
///   same index may appear more than once.
///
/// `Debug` and `Clone` are derived so the configuration can be logged and
/// duplicated.
#[derive(Debug, Clone)]
pub struct RandomSampler {
    /// Size of the population; indices are taken from `0..len`.
    len: usize,
    /// Whether an index may be drawn more than once.
    replacement: bool,
    /// Number of draws; `None` means "as many as `len`".
    num_samples: Option<usize>,
}

impl RandomSampler {
    /// Creates a sampler that visits every index of `0..len` once, shuffled.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self {
            len,
            replacement: false,
            num_samples: None,
        }
    }

    /// Creates a sampler that draws `num_samples` indices from `0..len`,
    /// allowing repeats.
    #[must_use]
    pub fn with_replacement(len: usize, num_samples: usize) -> Self {
        Self {
            len,
            replacement: true,
            num_samples: Some(num_samples),
        }
    }
}
99
100impl Sampler for RandomSampler {
101 fn len(&self) -> usize {
102 self.num_samples.unwrap_or(self.len)
103 }
104
105 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
106 if self.replacement {
107 let len = self.len;
109 let num = self.num_samples.unwrap_or(len);
110 Box::new(RandomWithReplacementIter::new(len, num))
111 } else {
112 let mut indices: Vec<usize> = (0..self.len).collect();
114 indices.shuffle(&mut rand::thread_rng());
115 Box::new(indices.into_iter())
116 }
117 }
118}
119
120struct RandomWithReplacementIter {
122 len: usize,
123 remaining: usize,
124}
125
126impl RandomWithReplacementIter {
127 fn new(len: usize, num_samples: usize) -> Self {
128 Self {
129 len,
130 remaining: num_samples,
131 }
132 }
133}
134
135impl Iterator for RandomWithReplacementIter {
136 type Item = usize;
137
138 fn next(&mut self) -> Option<Self::Item> {
139 if self.remaining == 0 {
140 return None;
141 }
142 self.remaining -= 1;
143 Some(rand::thread_rng().gen_range(0..self.len))
144 }
145}
146
/// Sampler restricted to a caller-supplied list of indices.
///
/// `Debug` and `Clone` are derived so the sampler can be logged and duplicated.
#[derive(Debug, Clone)]
pub struct SubsetRandomSampler {
    /// The indices this sampler draws from, stored as given.
    indices: Vec<usize>,
}

impl SubsetRandomSampler {
    /// Creates a sampler over exactly the given `indices`.
    #[must_use]
    pub fn new(indices: Vec<usize>) -> Self {
        Self { indices }
    }
}
163
164impl Sampler for SubsetRandomSampler {
165 fn len(&self) -> usize {
166 self.indices.len()
167 }
168
169 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
170 let mut shuffled = self.indices.clone();
171 shuffled.shuffle(&mut rand::thread_rng());
172 Box::new(shuffled.into_iter())
173 }
174}
175
/// Sampler that draws indices with probability proportional to a weight.
///
/// `Debug` and `Clone` are derived so the sampler can be logged and duplicated.
#[derive(Debug, Clone)]
pub struct WeightedRandomSampler {
    /// Weight of each index; position `i` is the weight of index `i`.
    weights: Vec<f64>,
    /// Running sums of `weights`; the last entry is the total weight.
    cumulative: Vec<f64>,
    /// Number of indices requested per pass.
    num_samples: usize,
    /// Whether an index may be drawn more than once.
    replacement: bool,
}
191
192impl WeightedRandomSampler {
193 #[must_use]
195 pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
196 let cumulative = Self::build_cumulative(&weights);
197 Self {
198 weights,
199 cumulative,
200 num_samples,
201 replacement,
202 }
203 }
204
205 fn build_cumulative(weights: &[f64]) -> Vec<f64> {
207 let mut cum = Vec::with_capacity(weights.len());
208 let mut total = 0.0;
209 for &w in weights {
210 total += w;
211 cum.push(total);
212 }
213 cum
214 }
215
216 fn sample_index(&self) -> usize {
218 let total = *self.cumulative.last().unwrap_or(&1.0);
219 let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
220
221 match self
223 .cumulative
224 .binary_search_by(|c| c.partial_cmp(&threshold).unwrap())
225 {
226 Ok(i) => i,
227 Err(i) => i.min(self.cumulative.len() - 1),
228 }
229 }
230}
231
232impl Sampler for WeightedRandomSampler {
233 fn len(&self) -> usize {
234 self.num_samples
235 }
236
237 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
238 if self.replacement {
239 Box::new(WeightedIter::new(self))
240 } else {
241 let mut indices = Vec::with_capacity(self.num_samples);
243 let mut available: Vec<usize> = (0..self.weights.len()).collect();
244 let mut weights = self.weights.clone();
245 let mut cumulative = self.cumulative.clone();
246
247 while indices.len() < self.num_samples && !available.is_empty() {
248 let total = *cumulative.last().unwrap_or(&0.0);
249 if total <= 0.0 {
250 break;
251 }
252
253 let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
254 let selected =
255 match cumulative.binary_search_by(|c| c.partial_cmp(&threshold).unwrap()) {
256 Ok(i) => i,
257 Err(i) => i.min(cumulative.len() - 1),
258 };
259
260 indices.push(available[selected]);
261
262 available.swap_remove(selected);
264 weights.swap_remove(selected);
265
266 cumulative = Self::build_cumulative(&weights);
268 }
269
270 Box::new(indices.into_iter())
271 }
272 }
273}
274
275struct WeightedIter<'a> {
277 sampler: &'a WeightedRandomSampler,
278 remaining: usize,
279}
280
281impl<'a> WeightedIter<'a> {
282 fn new(sampler: &'a WeightedRandomSampler) -> Self {
283 Self {
284 sampler,
285 remaining: sampler.num_samples,
286 }
287 }
288}
289
290impl Iterator for WeightedIter<'_> {
291 type Item = usize;
292
293 fn next(&mut self) -> Option<Self::Item> {
294 if self.remaining == 0 {
295 return None;
296 }
297 self.remaining -= 1;
298 Some(self.sampler.sample_index())
299 }
300}
301
302pub struct BatchSampler<S: Sampler> {
308 sampler: S,
309 batch_size: usize,
310 drop_last: bool,
311}
312
313impl<S: Sampler> BatchSampler<S> {
314 pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
316 Self {
317 sampler,
318 batch_size,
319 drop_last,
320 }
321 }
322
323 pub fn iter(&self) -> BatchIter {
325 let indices: Vec<usize> = self.sampler.iter().collect();
326 BatchIter {
327 indices,
328 batch_size: self.batch_size,
329 drop_last: self.drop_last,
330 position: 0,
331 }
332 }
333
334 pub fn len(&self) -> usize {
336 let total = self.sampler.len();
337 if self.drop_last {
338 total / self.batch_size
339 } else {
340 total.div_ceil(self.batch_size)
341 }
342 }
343
344 pub fn is_empty(&self) -> bool {
346 self.len() == 0
347 }
348}
349
/// Iterator that hands out consecutive batches from a list of indices.
#[derive(Debug, Clone)]
pub struct BatchIter {
    /// All indices of the pass, in delivery order.
    indices: Vec<usize>,
    /// Number of indices per batch.
    batch_size: usize,
    /// Whether a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
    /// Offset into `indices` of the first index not yet delivered.
    position: usize,
}

impl Iterator for BatchIter {
    type Item = Vec<usize>;

    /// Returns the next batch of up to `batch_size` indices.
    ///
    /// Yields `None` when the indices are used up, when only a short trailing
    /// batch is left and `drop_last` is set, or when `batch_size` is zero.
    fn next(&mut self) -> Option<Self::Item> {
        if self.batch_size == 0 {
            // A zero-sized batch never advances `position`; without this guard
            // the iterator would return empty batches forever.
            return None;
        }

        let rest = self.indices.get(self.position..)?;
        if rest.is_empty() {
            return None;
        }

        let take = rest.len().min(self.batch_size);
        if self.drop_last && take < self.batch_size {
            return None;
        }

        let batch = rest[..take].to_vec();
        self.position = self.position.saturating_add(take);
        Some(batch)
    }
}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs one pass of `sampler` and gathers every index it yields.
    fn drain<S: Sampler>(sampler: &S) -> Vec<usize> {
        sampler.iter().collect()
    }

    /// Sequential sampling yields `0..len` in order.
    #[test]
    fn test_sequential_sampler() {
        let drawn = drain(&SequentialSampler::new(5));
        assert_eq!(drawn, vec![0, 1, 2, 3, 4]);
    }

    /// Without replacement, a pass has `len` entries and none repeats.
    #[test]
    fn test_random_sampler() {
        let drawn = drain(&RandomSampler::new(10));
        assert_eq!(drawn.len(), 10);

        let distinct: std::collections::BTreeSet<usize> = drawn.iter().copied().collect();
        assert_eq!(distinct.len(), 10);
    }

    /// With replacement, a pass has `num_samples` entries, all within range.
    #[test]
    fn test_random_sampler_with_replacement() {
        let drawn = drain(&RandomSampler::with_replacement(5, 20));

        assert_eq!(drawn.len(), 20);
        assert!(!drawn.iter().any(|&i| i >= 5));
    }

    /// A subset sampler yields exactly the indices it was given, in some order.
    #[test]
    fn test_subset_random_sampler() {
        let mut drawn = drain(&SubsetRandomSampler::new(vec![0, 5, 10, 15]));
        assert_eq!(drawn.len(), 4);

        drawn.sort_unstable();
        assert_eq!(drawn, vec![0, 5, 10, 15]);
    }

    /// A dominant weight should account for most of the draws.
    #[test]
    fn test_weighted_random_sampler() {
        let sampler = WeightedRandomSampler::new(vec![100.0, 1.0, 1.0, 1.0], 100, true);
        let drawn = drain(&sampler);
        assert_eq!(drawn.len(), 100);

        let zeros = drawn.iter().filter(|&&i| i == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    /// The short trailing batch is kept when `drop_last` is off.
    #[test]
    fn test_batch_sampler() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, false);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();

        let expected: Vec<Vec<usize>> = vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8], vec![9]];
        assert_eq!(batches, expected);
    }

    /// The short trailing batch is discarded when `drop_last` is on.
    #[test]
    fn test_batch_sampler_drop_last() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, true);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();

        let expected: Vec<Vec<usize>> = vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]];
        assert_eq!(batches, expected);
    }
}