1use rand::Rng;
19use rand::seq::SliceRandom;
20
/// A source of dataset indices.
///
/// An implementor reports a length via [`Sampler::len`] and produces a pass of indices via
/// [`Sampler::iter`]. The `Send + Sync` supertraits require implementors to be safe to move to
/// and share between threads.
pub trait Sampler: Send + Sync {
    /// The number of indices this sampler reports for one pass.
    fn len(&self) -> usize;

    /// Returns `true` when [`Sampler::len`] reports zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Begins a new pass, returning an iterator over indices that may borrow `self`.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;
}
40
/// Yields the indices `0..len` in ascending order.
#[derive(Debug, Clone)]
pub struct SequentialSampler {
    /// Number of indices produced per pass.
    len: usize,
}

impl SequentialSampler {
    /// Creates a sampler over the indices `0..len`.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self { len }
    }
}
57
58impl Sampler for SequentialSampler {
59 fn len(&self) -> usize {
60 self.len
61 }
62
63 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
64 Box::new(0..self.len)
65 }
66}
67
/// Yields indices from the population `0..len` in random order.
///
/// Built with [`RandomSampler::new`], it yields a random permutation of `0..len`. Built with
/// [`RandomSampler::with_replacement`], it yields `num_samples` independent uniform draws from
/// `0..len`, so an index may repeat.
#[derive(Debug, Clone)]
pub struct RandomSampler {
    /// Size of the index population `0..len`.
    len: usize,
    /// `true` for independent draws, `false` for a permutation.
    replacement: bool,
    /// Draws per pass; `None` means one per population element.
    num_samples: Option<usize>,
}

impl RandomSampler {
    /// Creates a sampler that yields a random permutation of `0..len`.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self {
            len,
            replacement: false,
            num_samples: None,
        }
    }

    /// Creates a sampler that yields `num_samples` uniform draws from `0..len`, with replacement.
    #[must_use]
    pub fn with_replacement(len: usize, num_samples: usize) -> Self {
        Self {
            len,
            replacement: true,
            num_samples: Some(num_samples),
        }
    }
}
100
101impl Sampler for RandomSampler {
102 fn len(&self) -> usize {
103 self.num_samples.unwrap_or(self.len)
104 }
105
106 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
107 if self.replacement {
108 let len = self.len;
110 let num = self.num_samples.unwrap_or(len);
111 Box::new(RandomWithReplacementIter::new(len, num))
112 } else {
113 let mut indices: Vec<usize> = (0..self.len).collect();
115 indices.shuffle(&mut rand::thread_rng());
116 Box::new(indices.into_iter())
117 }
118 }
119}
120
121struct RandomWithReplacementIter {
123 len: usize,
124 remaining: usize,
125}
126
127impl RandomWithReplacementIter {
128 fn new(len: usize, num_samples: usize) -> Self {
129 Self {
130 len,
131 remaining: num_samples,
132 }
133 }
134}
135
136impl Iterator for RandomWithReplacementIter {
137 type Item = usize;
138
139 fn next(&mut self) -> Option<Self::Item> {
140 if self.remaining == 0 {
141 return None;
142 }
143 self.remaining -= 1;
144 Some(rand::thread_rng().gen_range(0..self.len))
145 }
146}
147
/// Yields a caller-supplied list of indices in a fresh random order on every pass.
#[derive(Debug, Clone)]
pub struct SubsetRandomSampler {
    /// Indices to shuffle and yield; duplicates, if any, are kept as-is.
    indices: Vec<usize>,
}

impl SubsetRandomSampler {
    /// Creates a sampler over exactly the given `indices`.
    #[must_use]
    pub fn new(indices: Vec<usize>) -> Self {
        Self { indices }
    }
}
164
165impl Sampler for SubsetRandomSampler {
166 fn len(&self) -> usize {
167 self.indices.len()
168 }
169
170 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
171 let mut shuffled = self.indices.clone();
172 shuffled.shuffle(&mut rand::thread_rng());
173 Box::new(shuffled.into_iter())
174 }
175}
176
177pub struct WeightedRandomSampler {
186 weights: Vec<f64>,
187 cumulative: Vec<f64>,
189 num_samples: usize,
190 replacement: bool,
191}
192
193impl WeightedRandomSampler {
194 #[must_use]
196 pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
197 let cumulative = Self::build_cumulative(&weights);
198 Self {
199 weights,
200 cumulative,
201 num_samples,
202 replacement,
203 }
204 }
205
206 fn build_cumulative(weights: &[f64]) -> Vec<f64> {
208 let mut cum = Vec::with_capacity(weights.len());
209 let mut total = 0.0;
210 for &w in weights {
211 total += w;
212 cum.push(total);
213 }
214 cum
215 }
216
217 fn sample_index(&self) -> usize {
219 let total = *self.cumulative.last().unwrap_or(&1.0);
220 let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
221
222 match self
224 .cumulative
225 .binary_search_by(|c| c.partial_cmp(&threshold).unwrap())
226 {
227 Ok(i) => i,
228 Err(i) => i.min(self.cumulative.len() - 1),
229 }
230 }
231}
232
233impl Sampler for WeightedRandomSampler {
234 fn len(&self) -> usize {
235 self.num_samples
236 }
237
238 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
239 if self.replacement {
240 Box::new(WeightedIter::new(self))
241 } else {
242 let mut indices = Vec::with_capacity(self.num_samples);
244 let mut available: Vec<usize> = (0..self.weights.len()).collect();
245 let mut weights = self.weights.clone();
246 let mut cumulative = self.cumulative.clone();
247
248 while indices.len() < self.num_samples && !available.is_empty() {
249 let total = *cumulative.last().unwrap_or(&0.0);
250 if total <= 0.0 {
251 break;
252 }
253
254 let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
255 let selected =
256 match cumulative.binary_search_by(|c| c.partial_cmp(&threshold).unwrap()) {
257 Ok(i) => i,
258 Err(i) => i.min(cumulative.len() - 1),
259 };
260
261 indices.push(available[selected]);
262
263 available.swap_remove(selected);
265 weights.swap_remove(selected);
266
267 cumulative = Self::build_cumulative(&weights);
269 }
270
271 Box::new(indices.into_iter())
272 }
273 }
274}
275
276struct WeightedIter<'a> {
278 sampler: &'a WeightedRandomSampler,
279 remaining: usize,
280}
281
282impl<'a> WeightedIter<'a> {
283 fn new(sampler: &'a WeightedRandomSampler) -> Self {
284 Self {
285 sampler,
286 remaining: sampler.num_samples,
287 }
288 }
289}
290
291impl Iterator for WeightedIter<'_> {
292 type Item = usize;
293
294 fn next(&mut self) -> Option<Self::Item> {
295 if self.remaining == 0 {
296 return None;
297 }
298 self.remaining -= 1;
299 Some(self.sampler.sample_index())
300 }
301}
302
303pub struct BatchSampler<S: Sampler> {
309 sampler: S,
310 batch_size: usize,
311 drop_last: bool,
312}
313
314impl<S: Sampler> BatchSampler<S> {
315 pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
317 Self {
318 sampler,
319 batch_size,
320 drop_last,
321 }
322 }
323
324 pub fn iter(&self) -> BatchIter {
326 let indices: Vec<usize> = self.sampler.iter().collect();
327 BatchIter {
328 indices,
329 batch_size: self.batch_size,
330 drop_last: self.drop_last,
331 position: 0,
332 }
333 }
334
335 pub fn len(&self) -> usize {
337 let total = self.sampler.len();
338 if self.drop_last {
339 total / self.batch_size
340 } else {
341 total.div_ceil(self.batch_size)
342 }
343 }
344
345 pub fn is_empty(&self) -> bool {
347 self.len() == 0
348 }
349}
350
/// Iterator over consecutive batches of a pre-collected list of indices.
#[derive(Debug, Clone)]
pub struct BatchIter {
    indices: Vec<usize>,
    batch_size: usize,
    drop_last: bool,
    /// Offset into `indices` of the next batch's first element.
    position: usize,
}

impl BatchIter {
    /// Number of batches still to be yielded.
    ///
    /// A zero `batch_size` yields nothing rather than an endless stream of empty batches.
    fn remaining_batches(&self) -> usize {
        if self.batch_size == 0 {
            return 0;
        }
        let left = self.indices.len().saturating_sub(self.position);
        if self.drop_last {
            left / self.batch_size
        } else {
            left.div_ceil(self.batch_size)
        }
    }
}

impl Iterator for BatchIter {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        // Deciding up front avoids copying a short final batch only to drop it.
        if self.remaining_batches() == 0 {
            return None;
        }
        // `saturating_add` keeps a huge `batch_size` from overflowing.
        let end = self
            .position
            .saturating_add(self.batch_size)
            .min(self.indices.len());
        let batch = self.indices[self.position..end].to_vec();
        self.position = end;
        Some(batch)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let n = self.remaining_batches();
        (n, Some(n))
    }
}

impl ExactSizeIterator for BatchIter {}
378
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Collects one full pass of any sampler.
    fn drain(sampler: &impl Sampler) -> Vec<usize> {
        sampler.iter().collect()
    }

    #[test]
    fn test_sequential_sampler() {
        assert_eq!(drain(&SequentialSampler::new(5)), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_random_sampler() {
        let drawn = drain(&RandomSampler::new(10));
        assert_eq!(drawn.len(), 10);
        // A permutation never repeats an index.
        let distinct: HashSet<usize> = drawn.iter().copied().collect();
        assert_eq!(distinct.len(), 10);
    }

    #[test]
    fn test_random_sampler_with_replacement() {
        let drawn = drain(&RandomSampler::with_replacement(5, 20));
        assert_eq!(drawn.len(), 20);
        assert!(drawn.iter().all(|&i| i < 5));
    }

    #[test]
    fn test_subset_random_sampler() {
        let mut drawn = drain(&SubsetRandomSampler::new(vec![0, 5, 10, 15]));
        assert_eq!(drawn.len(), 4);
        drawn.sort_unstable();
        assert_eq!(drawn, vec![0, 5, 10, 15]);
    }

    #[test]
    fn test_weighted_random_sampler() {
        // Index 0 carries about 97% of the total weight, so it should dominate 100 draws.
        let sampler = WeightedRandomSampler::new(vec![100.0, 1.0, 1.0, 1.0], 100, true);
        let drawn = drain(&sampler);
        assert_eq!(drawn.len(), 100);
        let zeros = drawn.iter().filter(|&&i| i == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    #[test]
    fn test_batch_sampler() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, false);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();
        assert_eq!(
            batches,
            vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8], vec![9]]
        );
    }

    #[test]
    fn test_batch_sampler_drop_last() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, true);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();
        assert_eq!(batches, vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]]);
    }
}