1use rand::seq::SliceRandom;
9use rand::Rng;
10
/// Strategy for producing dataset indices.
///
/// Implementors must be `Send + Sync`.
pub trait Sampler: Send + Sync {
    /// Number of indices this sampler reports for one pass.
    fn len(&self) -> usize;

    /// Returns `true` when [`Sampler::len`] reports zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Returns a boxed iterator over the sampled indices; it may borrow from `self`.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;
}
30
31pub struct SequentialSampler {
37 len: usize,
38}
39
40impl SequentialSampler {
41 #[must_use] pub fn new(len: usize) -> Self {
43 Self { len }
44 }
45}
46
47impl Sampler for SequentialSampler {
48 fn len(&self) -> usize {
49 self.len
50 }
51
52 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
53 Box::new(0..self.len)
54 }
55}
56
/// Sampler that yields indices from `0..len` in random order, either as a permutation or as
/// independent draws with replacement.
pub struct RandomSampler {
    /// Size of the index population (`0..len`).
    len: usize,
    /// Whether indices are drawn with replacement.
    replacement: bool,
    /// Requested number of draws, if one was given.
    num_samples: Option<usize>,
}

impl RandomSampler {
    /// Creates a sampler without replacement and without an explicit sample count.
    #[must_use]
    pub fn new(len: usize) -> Self {
        let (replacement, num_samples) = (false, None);
        RandomSampler {
            len,
            replacement,
            num_samples,
        }
    }

    /// Creates a sampler that draws `num_samples` indices from `0..len` with replacement.
    #[must_use]
    pub fn with_replacement(len: usize, num_samples: usize) -> Self {
        let num_samples = Some(num_samples);
        RandomSampler {
            num_samples,
            replacement: true,
            len,
        }
    }
}
87
88impl Sampler for RandomSampler {
89 fn len(&self) -> usize {
90 self.num_samples.unwrap_or(self.len)
91 }
92
93 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
94 if self.replacement {
95 let len = self.len;
97 let num = self.num_samples.unwrap_or(len);
98 Box::new(RandomWithReplacementIter::new(len, num))
99 } else {
100 let mut indices: Vec<usize> = (0..self.len).collect();
102 indices.shuffle(&mut rand::thread_rng());
103 Box::new(indices.into_iter())
104 }
105 }
106}
107
108struct RandomWithReplacementIter {
110 len: usize,
111 remaining: usize,
112}
113
114impl RandomWithReplacementIter {
115 fn new(len: usize, num_samples: usize) -> Self {
116 Self {
117 len,
118 remaining: num_samples,
119 }
120 }
121}
122
123impl Iterator for RandomWithReplacementIter {
124 type Item = usize;
125
126 fn next(&mut self) -> Option<Self::Item> {
127 if self.remaining == 0 {
128 return None;
129 }
130 self.remaining -= 1;
131 Some(rand::thread_rng().gen_range(0..self.len))
132 }
133}
134
135pub struct SubsetRandomSampler {
141 indices: Vec<usize>,
142}
143
144impl SubsetRandomSampler {
145 #[must_use] pub fn new(indices: Vec<usize>) -> Self {
147 Self { indices }
148 }
149}
150
151impl Sampler for SubsetRandomSampler {
152 fn len(&self) -> usize {
153 self.indices.len()
154 }
155
156 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
157 let mut shuffled = self.indices.clone();
158 shuffled.shuffle(&mut rand::thread_rng());
159 Box::new(shuffled.into_iter())
160 }
161}
162
163pub struct WeightedRandomSampler {
169 weights: Vec<f64>,
170 num_samples: usize,
171 replacement: bool,
172}
173
174impl WeightedRandomSampler {
175 #[must_use] pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
177 Self {
178 weights,
179 num_samples,
180 replacement,
181 }
182 }
183
184 fn sample_index(&self) -> usize {
186 let total: f64 = self.weights.iter().sum();
187 let mut cumulative = 0.0;
188 let threshold: f64 = rand::thread_rng().gen::<f64>() * total;
189
190 for (i, &weight) in self.weights.iter().enumerate() {
191 cumulative += weight;
192 if cumulative > threshold {
193 return i;
194 }
195 }
196 self.weights.len() - 1
197 }
198}
199
200impl Sampler for WeightedRandomSampler {
201 fn len(&self) -> usize {
202 self.num_samples
203 }
204
205 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
206 if self.replacement {
207 Box::new(WeightedIter::new(self))
208 } else {
209 let mut indices = Vec::with_capacity(self.num_samples);
211 let mut available: Vec<usize> = (0..self.weights.len()).collect();
212 let mut weights = self.weights.clone();
213
214 while indices.len() < self.num_samples && !available.is_empty() {
215 let total: f64 = weights.iter().sum();
216 if total <= 0.0 {
217 break;
218 }
219
220 let threshold: f64 = rand::thread_rng().gen::<f64>() * total;
221 let mut cumulative = 0.0;
222 let mut selected = 0;
223
224 for (i, &weight) in weights.iter().enumerate() {
225 cumulative += weight;
226 if cumulative > threshold {
227 selected = i;
228 break;
229 }
230 }
231
232 indices.push(available[selected]);
233 available.remove(selected);
234 weights.remove(selected);
235 }
236
237 Box::new(indices.into_iter())
238 }
239 }
240}
241
242struct WeightedIter<'a> {
244 sampler: &'a WeightedRandomSampler,
245 remaining: usize,
246}
247
248impl<'a> WeightedIter<'a> {
249 fn new(sampler: &'a WeightedRandomSampler) -> Self {
250 Self {
251 sampler,
252 remaining: sampler.num_samples,
253 }
254 }
255}
256
257impl Iterator for WeightedIter<'_> {
258 type Item = usize;
259
260 fn next(&mut self) -> Option<Self::Item> {
261 if self.remaining == 0 {
262 return None;
263 }
264 self.remaining -= 1;
265 Some(self.sampler.sample_index())
266 }
267}
268
269pub struct BatchSampler<S: Sampler> {
275 sampler: S,
276 batch_size: usize,
277 drop_last: bool,
278}
279
280impl<S: Sampler> BatchSampler<S> {
281 pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
283 Self {
284 sampler,
285 batch_size,
286 drop_last,
287 }
288 }
289
290 pub fn iter(&self) -> BatchIter {
292 let indices: Vec<usize> = self.sampler.iter().collect();
293 BatchIter {
294 indices,
295 batch_size: self.batch_size,
296 drop_last: self.drop_last,
297 position: 0,
298 }
299 }
300
301 pub fn len(&self) -> usize {
303 let total = self.sampler.len();
304 if self.drop_last {
305 total / self.batch_size
306 } else {
307 total.div_ceil(self.batch_size)
308 }
309 }
310
311 pub fn is_empty(&self) -> bool {
313 self.len() == 0
314 }
315}
316
/// Iterator that slices a pre-collected list of indices into consecutive batches.
pub struct BatchIter {
    /// All indices of the current pass, in the order they will be batched.
    indices: Vec<usize>,
    /// Maximum number of indices per batch.
    batch_size: usize,
    /// When `true`, a trailing batch shorter than `batch_size` is not yielded.
    drop_last: bool,
    /// Offset into `indices` of the first index not yet handed out.
    position: usize,
}

impl Iterator for BatchIter {
    type Item = Vec<usize>;

    /// Returns the next batch, or `None` when the indices are used up (or only a short
    /// trailing batch remains and `drop_last` is set).
    fn next(&mut self) -> Option<Self::Item> {
        // A zero batch size can never advance `position`; stop instead of yielding empty
        // batches forever.
        if self.batch_size == 0 {
            return None;
        }

        // `get` returns `None` if `position` is past the end, so no index arithmetic can
        // overflow or slice out of bounds.
        let rest = self.indices.get(self.position..)?;
        if rest.is_empty() {
            return None;
        }

        let take = rest.len().min(self.batch_size);
        // Decide about dropping before copying, so a discarded batch is never allocated.
        if self.drop_last && take < self.batch_size {
            return None;
        }

        let batch = rest[..take].to_vec();
        self.position += take;
        Some(batch)
    }
}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs one full pass over `sampler` and gathers every yielded index.
    fn collect_pass(sampler: &impl Sampler) -> Vec<usize> {
        sampler.iter().collect()
    }

    /// Returns `values` sorted in ascending order.
    fn sorted(mut values: Vec<usize>) -> Vec<usize> {
        values.sort_unstable();
        values
    }

    #[test]
    fn test_sequential_sampler() {
        let drawn = collect_pass(&SequentialSampler::new(5));
        assert_eq!(drawn, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_random_sampler() {
        let drawn = collect_pass(&RandomSampler::new(10));
        assert_eq!(drawn.len(), 10);

        // A permutation has no repeats, so deduplicating must not shrink it.
        let mut unique = sorted(drawn);
        unique.dedup();
        assert_eq!(unique.len(), 10);
    }

    #[test]
    fn test_random_sampler_with_replacement() {
        let drawn = collect_pass(&RandomSampler::with_replacement(5, 20));

        assert_eq!(drawn.len(), 20);
        assert!(drawn.iter().all(|&i| i < 5));
    }

    #[test]
    fn test_subset_random_sampler() {
        let drawn = collect_pass(&SubsetRandomSampler::new(vec![0, 5, 10, 15]));

        assert_eq!(drawn.len(), 4);
        assert_eq!(sorted(drawn), vec![0, 5, 10, 15]);
    }

    #[test]
    fn test_weighted_random_sampler() {
        let sampler = WeightedRandomSampler::new(vec![100.0, 1.0, 1.0, 1.0], 100, true);
        let drawn = collect_pass(&sampler);

        assert_eq!(drawn.len(), 100);
        // Index 0 carries almost all of the weight, so it should dominate the draws.
        let zeros = drawn.iter().filter(|i| **i == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    #[test]
    fn test_batch_sampler() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, false);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();

        // The final, shorter batch is kept.
        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
        assert_eq!(batches[3], vec![9]);
    }

    #[test]
    fn test_batch_sampler_drop_last() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, true);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();

        // The final, shorter batch is dropped.
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
    }
}