1use rand::seq::SliceRandom;
9use rand::Rng;
10
/// A source of dataset indices.
///
/// Implementors decide which indices of a dataset are visited and in what
/// order. [`Sampler::iter`] produces that sequence, and [`Sampler::len`]
/// reports how many indices a pass is expected to contain.
pub trait Sampler
where
    Self: Send + Sync,
{
    /// Returns a new iterator over one pass of sampled indices.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;

    /// Number of indices one pass of [`Sampler::iter`] is expected to yield.
    fn len(&self) -> usize;

    /// Returns `true` when [`Sampler::len`] reports zero indices.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
30
31pub struct SequentialSampler {
37 len: usize,
38}
39
40impl SequentialSampler {
41 #[must_use]
43 pub fn new(len: usize) -> Self {
44 Self { len }
45 }
46}
47
48impl Sampler for SequentialSampler {
49 fn len(&self) -> usize {
50 self.len
51 }
52
53 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
54 Box::new(0..self.len)
55 }
56}
57
58pub struct RandomSampler {
64 len: usize,
65 replacement: bool,
66 num_samples: Option<usize>,
67}
68
69impl RandomSampler {
70 #[must_use]
72 pub fn new(len: usize) -> Self {
73 Self {
74 len,
75 replacement: false,
76 num_samples: None,
77 }
78 }
79
80 #[must_use]
82 pub fn with_replacement(len: usize, num_samples: usize) -> Self {
83 Self {
84 len,
85 replacement: true,
86 num_samples: Some(num_samples),
87 }
88 }
89}
90
91impl Sampler for RandomSampler {
92 fn len(&self) -> usize {
93 self.num_samples.unwrap_or(self.len)
94 }
95
96 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
97 if self.replacement {
98 let len = self.len;
100 let num = self.num_samples.unwrap_or(len);
101 Box::new(RandomWithReplacementIter::new(len, num))
102 } else {
103 let mut indices: Vec<usize> = (0..self.len).collect();
105 indices.shuffle(&mut rand::thread_rng());
106 Box::new(indices.into_iter())
107 }
108 }
109}
110
111struct RandomWithReplacementIter {
113 len: usize,
114 remaining: usize,
115}
116
117impl RandomWithReplacementIter {
118 fn new(len: usize, num_samples: usize) -> Self {
119 Self {
120 len,
121 remaining: num_samples,
122 }
123 }
124}
125
126impl Iterator for RandomWithReplacementIter {
127 type Item = usize;
128
129 fn next(&mut self) -> Option<Self::Item> {
130 if self.remaining == 0 {
131 return None;
132 }
133 self.remaining -= 1;
134 Some(rand::thread_rng().gen_range(0..self.len))
135 }
136}
137
138pub struct SubsetRandomSampler {
144 indices: Vec<usize>,
145}
146
147impl SubsetRandomSampler {
148 #[must_use]
150 pub fn new(indices: Vec<usize>) -> Self {
151 Self { indices }
152 }
153}
154
155impl Sampler for SubsetRandomSampler {
156 fn len(&self) -> usize {
157 self.indices.len()
158 }
159
160 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
161 let mut shuffled = self.indices.clone();
162 shuffled.shuffle(&mut rand::thread_rng());
163 Box::new(shuffled.into_iter())
164 }
165}
166
167pub struct WeightedRandomSampler {
173 weights: Vec<f64>,
174 num_samples: usize,
175 replacement: bool,
176}
177
178impl WeightedRandomSampler {
179 #[must_use]
181 pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
182 Self {
183 weights,
184 num_samples,
185 replacement,
186 }
187 }
188
189 fn sample_index(&self) -> usize {
191 let total: f64 = self.weights.iter().sum();
192 let mut cumulative = 0.0;
193 let threshold: f64 = rand::thread_rng().gen::<f64>() * total;
194
195 for (i, &weight) in self.weights.iter().enumerate() {
196 cumulative += weight;
197 if cumulative > threshold {
198 return i;
199 }
200 }
201 self.weights.len() - 1
202 }
203}
204
205impl Sampler for WeightedRandomSampler {
206 fn len(&self) -> usize {
207 self.num_samples
208 }
209
210 fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
211 if self.replacement {
212 Box::new(WeightedIter::new(self))
213 } else {
214 let mut indices = Vec::with_capacity(self.num_samples);
216 let mut available: Vec<usize> = (0..self.weights.len()).collect();
217 let mut weights = self.weights.clone();
218
219 while indices.len() < self.num_samples && !available.is_empty() {
220 let total: f64 = weights.iter().sum();
221 if total <= 0.0 {
222 break;
223 }
224
225 let threshold: f64 = rand::thread_rng().gen::<f64>() * total;
226 let mut cumulative = 0.0;
227 let mut selected = 0;
228
229 for (i, &weight) in weights.iter().enumerate() {
230 cumulative += weight;
231 if cumulative > threshold {
232 selected = i;
233 break;
234 }
235 }
236
237 indices.push(available[selected]);
238 available.remove(selected);
239 weights.remove(selected);
240 }
241
242 Box::new(indices.into_iter())
243 }
244 }
245}
246
247struct WeightedIter<'a> {
249 sampler: &'a WeightedRandomSampler,
250 remaining: usize,
251}
252
253impl<'a> WeightedIter<'a> {
254 fn new(sampler: &'a WeightedRandomSampler) -> Self {
255 Self {
256 sampler,
257 remaining: sampler.num_samples,
258 }
259 }
260}
261
262impl Iterator for WeightedIter<'_> {
263 type Item = usize;
264
265 fn next(&mut self) -> Option<Self::Item> {
266 if self.remaining == 0 {
267 return None;
268 }
269 self.remaining -= 1;
270 Some(self.sampler.sample_index())
271 }
272}
273
274pub struct BatchSampler<S: Sampler> {
280 sampler: S,
281 batch_size: usize,
282 drop_last: bool,
283}
284
285impl<S: Sampler> BatchSampler<S> {
286 pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
288 Self {
289 sampler,
290 batch_size,
291 drop_last,
292 }
293 }
294
295 pub fn iter(&self) -> BatchIter {
297 let indices: Vec<usize> = self.sampler.iter().collect();
298 BatchIter {
299 indices,
300 batch_size: self.batch_size,
301 drop_last: self.drop_last,
302 position: 0,
303 }
304 }
305
306 pub fn len(&self) -> usize {
308 let total = self.sampler.len();
309 if self.drop_last {
310 total / self.batch_size
311 } else {
312 total.div_ceil(self.batch_size)
313 }
314 }
315
316 pub fn is_empty(&self) -> bool {
318 self.len() == 0
319 }
320}
321
322pub struct BatchIter {
324 indices: Vec<usize>,
325 batch_size: usize,
326 drop_last: bool,
327 position: usize,
328}
329
330impl Iterator for BatchIter {
331 type Item = Vec<usize>;
332
333 fn next(&mut self) -> Option<Self::Item> {
334 if self.position >= self.indices.len() {
335 return None;
336 }
337
338 let end = (self.position + self.batch_size).min(self.indices.len());
339 let batch: Vec<usize> = self.indices[self.position..end].to_vec();
340
341 if batch.len() < self.batch_size && self.drop_last {
342 return None;
343 }
344
345 self.position = end;
346 Some(batch)
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects one full pass of any sampler into a vector.
    fn collect_pass(sampler: &dyn Sampler) -> Vec<usize> {
        sampler.iter().collect()
    }

    /// Returns an ascending copy of `values`.
    fn ascending(values: &[usize]) -> Vec<usize> {
        let mut out = values.to_vec();
        out.sort_unstable();
        out
    }

    #[test]
    fn test_sequential_sampler() {
        let drawn = collect_pass(&SequentialSampler::new(5));
        assert_eq!(drawn, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_random_sampler() {
        let drawn = collect_pass(&RandomSampler::new(10));
        assert_eq!(drawn.len(), 10);
        // A permutation contains no repeated index.
        let mut unique = ascending(&drawn);
        unique.dedup();
        assert_eq!(unique.len(), 10);
    }

    #[test]
    fn test_random_sampler_with_replacement() {
        let drawn = collect_pass(&RandomSampler::with_replacement(5, 20));
        assert_eq!(drawn.len(), 20);
        assert!(drawn.iter().all(|&i| i < 5));
    }

    #[test]
    fn test_subset_random_sampler() {
        let drawn = collect_pass(&SubsetRandomSampler::new(vec![0, 5, 10, 15]));
        assert_eq!(drawn.len(), 4);
        assert_eq!(ascending(&drawn), vec![0, 5, 10, 15]);
    }

    #[test]
    fn test_weighted_random_sampler() {
        // Index 0 carries about 97% of the total weight, so it should dominate.
        let sampler = WeightedRandomSampler::new(vec![100.0, 1.0, 1.0, 1.0], 100, true);
        let drawn = collect_pass(&sampler);
        assert_eq!(drawn.len(), 100);
        let zeros = drawn.iter().filter(|&&i| i == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    #[test]
    fn test_batch_sampler() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, false);
        let expected: Vec<Vec<usize>> =
            vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8], vec![9]];
        assert_eq!(sampler.iter().collect::<Vec<_>>(), expected);
    }

    #[test]
    fn test_batch_sampler_drop_last() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, true);
        let expected: Vec<Vec<usize>> = vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]];
        assert_eq!(sampler.iter().collect::<Vec<_>>(), expected);
    }
}