1use rand::Rng;
18use rand::seq::SliceRandom;
19
/// A source of dataset indices.
///
/// Every implementor must be `Send + Sync`, so a sampler can be shared
/// between threads.
pub trait Sampler: Send + Sync {
    /// Returns how many indices the sampler reports for one pass.
    fn len(&self) -> usize;

    /// Returns `true` when [`Sampler::len`] reports zero indices.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Returns a boxed iterator over the indices of one pass.
    ///
    /// The iterator is allowed to borrow from the sampler.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;
}
39
40pub struct SequentialSampler {
46 len: usize,
47}
48
49impl SequentialSampler {
50 #[must_use]
52 pub fn new(len: usize) -> Self {
53 Self { len }
54 }
55}
56
57impl Sampler for SequentialSampler {
58 fn len(&self) -> usize {
59 self.len
60 }
61
62 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
63 Box::new(0..self.len)
64 }
65}
66
67pub struct RandomSampler {
73 len: usize,
74 replacement: bool,
75 num_samples: Option<usize>,
76}
77
78impl RandomSampler {
79 #[must_use]
81 pub fn new(len: usize) -> Self {
82 Self {
83 len,
84 replacement: false,
85 num_samples: None,
86 }
87 }
88
89 #[must_use]
91 pub fn with_replacement(len: usize, num_samples: usize) -> Self {
92 Self {
93 len,
94 replacement: true,
95 num_samples: Some(num_samples),
96 }
97 }
98}
99
100impl Sampler for RandomSampler {
101 fn len(&self) -> usize {
102 self.num_samples.unwrap_or(self.len)
103 }
104
105 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
106 if self.replacement {
107 let len = self.len;
109 let num = self.num_samples.unwrap_or(len);
110 Box::new(RandomWithReplacementIter::new(len, num))
111 } else {
112 let mut indices: Vec<usize> = (0..self.len).collect();
114 indices.shuffle(&mut rand::thread_rng());
115 Box::new(indices.into_iter())
116 }
117 }
118}
119
120struct RandomWithReplacementIter {
122 len: usize,
123 remaining: usize,
124}
125
126impl RandomWithReplacementIter {
127 fn new(len: usize, num_samples: usize) -> Self {
128 Self {
129 len,
130 remaining: num_samples,
131 }
132 }
133}
134
135impl Iterator for RandomWithReplacementIter {
136 type Item = usize;
137
138 fn next(&mut self) -> Option<Self::Item> {
139 if self.remaining == 0 {
140 return None;
141 }
142 self.remaining -= 1;
143 Some(rand::thread_rng().gen_range(0..self.len))
144 }
145}
146
147pub struct SubsetRandomSampler {
153 indices: Vec<usize>,
154}
155
156impl SubsetRandomSampler {
157 #[must_use]
159 pub fn new(indices: Vec<usize>) -> Self {
160 Self { indices }
161 }
162}
163
164impl Sampler for SubsetRandomSampler {
165 fn len(&self) -> usize {
166 self.indices.len()
167 }
168
169 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
170 let mut shuffled = self.indices.clone();
171 shuffled.shuffle(&mut rand::thread_rng());
172 Box::new(shuffled.into_iter())
173 }
174}
175
176pub struct WeightedRandomSampler {
182 weights: Vec<f64>,
183 num_samples: usize,
184 replacement: bool,
185}
186
187impl WeightedRandomSampler {
188 #[must_use]
190 pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
191 Self {
192 weights,
193 num_samples,
194 replacement,
195 }
196 }
197
198 fn sample_index(&self) -> usize {
200 let total: f64 = self.weights.iter().sum();
201 let mut cumulative = 0.0;
202 let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
203
204 for (i, &weight) in self.weights.iter().enumerate() {
205 cumulative += weight;
206 if cumulative > threshold {
207 return i;
208 }
209 }
210 self.weights.len() - 1
211 }
212}
213
214impl Sampler for WeightedRandomSampler {
215 fn len(&self) -> usize {
216 self.num_samples
217 }
218
219 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
220 if self.replacement {
221 Box::new(WeightedIter::new(self))
222 } else {
223 let mut indices = Vec::with_capacity(self.num_samples);
225 let mut available: Vec<usize> = (0..self.weights.len()).collect();
226 let mut weights = self.weights.clone();
227
228 while indices.len() < self.num_samples && !available.is_empty() {
229 let total: f64 = weights.iter().sum();
230 if total <= 0.0 {
231 break;
232 }
233
234 let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
235 let mut cumulative = 0.0;
236 let mut selected = 0;
237
238 for (i, &weight) in weights.iter().enumerate() {
239 cumulative += weight;
240 if cumulative > threshold {
241 selected = i;
242 break;
243 }
244 }
245
246 indices.push(available[selected]);
247 available.remove(selected);
248 weights.remove(selected);
249 }
250
251 Box::new(indices.into_iter())
252 }
253 }
254}
255
256struct WeightedIter<'a> {
258 sampler: &'a WeightedRandomSampler,
259 remaining: usize,
260}
261
262impl<'a> WeightedIter<'a> {
263 fn new(sampler: &'a WeightedRandomSampler) -> Self {
264 Self {
265 sampler,
266 remaining: sampler.num_samples,
267 }
268 }
269}
270
271impl Iterator for WeightedIter<'_> {
272 type Item = usize;
273
274 fn next(&mut self) -> Option<Self::Item> {
275 if self.remaining == 0 {
276 return None;
277 }
278 self.remaining -= 1;
279 Some(self.sampler.sample_index())
280 }
281}
282
283pub struct BatchSampler<S: Sampler> {
289 sampler: S,
290 batch_size: usize,
291 drop_last: bool,
292}
293
294impl<S: Sampler> BatchSampler<S> {
295 pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
297 Self {
298 sampler,
299 batch_size,
300 drop_last,
301 }
302 }
303
304 pub fn iter(&self) -> BatchIter {
306 let indices: Vec<usize> = self.sampler.iter().collect();
307 BatchIter {
308 indices,
309 batch_size: self.batch_size,
310 drop_last: self.drop_last,
311 position: 0,
312 }
313 }
314
315 pub fn len(&self) -> usize {
317 let total = self.sampler.len();
318 if self.drop_last {
319 total / self.batch_size
320 } else {
321 total.div_ceil(self.batch_size)
322 }
323 }
324
325 pub fn is_empty(&self) -> bool {
327 self.len() == 0
328 }
329}
330
/// Iterator that slices a pre-collected list of indices into batches.
pub struct BatchIter {
    /// All indices of the pass, in the order they will be batched.
    indices: Vec<usize>,
    /// Number of indices per full batch.
    batch_size: usize,
    /// Whether a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
    /// Offset into `indices` of the next batch's first element.
    position: usize,
}

impl Iterator for BatchIter {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.batch_size == 0 {
            // A zero-sized batch never consumes an index, so yielding it
            // would repeat forever; treat the pass as empty instead.
            return None;
        }

        let rest = self.indices.get(self.position..)?;
        if rest.is_empty() {
            return None;
        }

        let take = self.batch_size.min(rest.len());
        if self.drop_last && take < self.batch_size {
            // Trailing partial batch: decided before copying anything.
            return None;
        }

        let batch = rest[..take].to_vec();
        self.position += take;
        Some(batch)
    }
}
358
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Collects every index a sampler yields in one pass.
    fn drain<S: Sampler>(sampler: &S) -> Vec<usize> {
        sampler.iter().collect()
    }

    #[test]
    fn test_sequential_sampler() {
        let drawn = drain(&SequentialSampler::new(5));
        assert_eq!(drawn, (0..5).collect::<Vec<usize>>());
    }

    #[test]
    fn test_random_sampler() {
        let drawn = drain(&RandomSampler::new(10));
        assert_eq!(drawn.len(), 10);

        // Ten draws with ten distinct values means no index was repeated.
        let unique: HashSet<usize> = drawn.iter().copied().collect();
        assert_eq!(unique.len(), 10);
    }

    #[test]
    fn test_random_sampler_with_replacement() {
        let drawn = drain(&RandomSampler::with_replacement(5, 20));
        assert_eq!(drawn.len(), 20);

        // Every draw must stay inside the index space.
        assert!(!drawn.iter().any(|&i| i >= 5));
    }

    #[test]
    fn test_subset_random_sampler() {
        let mut drawn = drain(&SubsetRandomSampler::new(vec![0, 5, 10, 15]));
        assert_eq!(drawn.len(), 4);

        // Order is random, so compare after sorting.
        drawn.sort_unstable();
        assert_eq!(drawn, vec![0, 5, 10, 15]);
    }

    #[test]
    fn test_weighted_random_sampler() {
        // Index 0 carries the overwhelming share of the weight.
        let weights = vec![100.0, 1.0, 1.0, 1.0];
        let drawn = drain(&WeightedRandomSampler::new(weights, 100, true));
        assert_eq!(drawn.len(), 100);

        let zeros = drawn.iter().filter(|&&i| i == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    #[test]
    fn test_batch_sampler() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, false);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();

        // Three full batches followed by the one-element remainder.
        let expected = vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8], vec![9]];
        assert_eq!(batches, expected);
    }

    #[test]
    fn test_batch_sampler_drop_last() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, true);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();

        // The one-element remainder is discarded.
        let expected = vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]];
        assert_eq!(batches, expected);
    }
}
447}