1use crate::Dataset;
2use crate::transform::{RngSource, SizeConfig};
3use rand::prelude::SliceRandom;
4use rand::{Rng, distr::Uniform, rngs::StdRng, seq::IteratorRandom};
5use std::{marker::PhantomData, ops::DerefMut, sync::Mutex};
6
7#[derive(Debug, Clone, PartialEq)]
9pub struct SamplerDatasetOptions {
10 pub replace_samples: bool,
12
13 pub size_config: SizeConfig,
15
16 pub rng_source: RngSource,
18}
19
20impl Default for SamplerDatasetOptions {
21 fn default() -> Self {
22 Self {
23 replace_samples: true,
24 size_config: SizeConfig::Default,
25 rng_source: RngSource::Default,
26 }
27 }
28}
29
30impl<T> From<Option<T>> for SamplerDatasetOptions
31where
32 T: Into<SamplerDatasetOptions>,
33{
34 fn from(option: Option<T>) -> Self {
35 match option {
36 Some(option) => option.into(),
37 None => Self::default(),
38 }
39 }
40}
41
42impl From<usize> for SamplerDatasetOptions {
43 fn from(size: usize) -> Self {
44 Self::default().with_replacement().with_fixed_size(size)
45 }
46}
47
48impl SamplerDatasetOptions {
49 pub fn with_replace_samples(self, replace_samples: bool) -> Self {
51 Self {
52 replace_samples,
53 ..self
54 }
55 }
56
57 pub fn with_replacement(self) -> Self {
59 self.with_replace_samples(true)
60 }
61
62 pub fn without_replacement(self) -> Self {
64 self.with_replace_samples(false)
65 }
66
67 pub fn with_size<S>(self, source: S) -> Self
69 where
70 S: Into<SizeConfig>,
71 {
72 Self {
73 size_config: source.into(),
74 ..self
75 }
76 }
77
78 pub fn with_source_size(self) -> Self {
80 self.with_size(SizeConfig::Default)
81 }
82
83 pub fn with_fixed_size(self, size: usize) -> Self {
85 self.with_size(size)
86 }
87
88 pub fn with_size_ratio(self, size_ratio: f64) -> Self {
90 self.with_size(size_ratio)
91 }
92
93 pub fn with_rng<R>(self, rng: R) -> Self
95 where
96 R: Into<RngSource>,
97 {
98 Self {
99 rng_source: rng.into(),
100 ..self
101 }
102 }
103
104 pub fn with_system_rng(self) -> Self {
106 self.with_rng(RngSource::Default)
107 }
108
109 pub fn with_seed(self, seed: u64) -> Self {
111 self.with_rng(seed)
112 }
113}
114
115pub struct SamplerDataset<D, I> {
128 dataset: D,
129 size: usize,
130 state: Mutex<SamplerState>,
131 input: PhantomData<I>,
132}
133enum SamplerState {
134 WithReplacement(StdRng),
135 WithoutReplacement(StdRng, Vec<usize>),
136}
137
138impl<D, I> SamplerDataset<D, I>
139where
140 D: Dataset<I>,
141 I: Send + Sync,
142{
143 pub fn new<O>(dataset: D, options: O) -> Self
190 where
191 O: Into<SamplerDatasetOptions>,
192 {
193 let options = options.into();
194 let size = options.size_config.resolve(dataset.len());
195 let rng = options.rng_source.into();
196 Self {
197 dataset,
198 size,
199 state: Mutex::new(match options.replace_samples {
200 true => SamplerState::WithReplacement(rng),
201 false => SamplerState::WithoutReplacement(rng, Vec::with_capacity(size)),
202 }),
203 input: PhantomData,
204 }
205 }
206
207 pub fn with_replacement(dataset: D, size: usize) -> Self {
214 Self::new(
215 dataset,
216 SamplerDatasetOptions::default()
217 .with_replacement()
218 .with_fixed_size(size),
219 )
220 }
221
222 pub fn without_replacement(dataset: D, size: usize) -> Self {
237 Self::new(
238 dataset,
239 SamplerDatasetOptions::default()
240 .without_replacement()
241 .with_fixed_size(size),
242 )
243 }
244
245 pub fn is_with_replacement(&self) -> bool {
251 match self.state.lock().unwrap().deref_mut() {
252 SamplerState::WithReplacement(_) => true,
253 SamplerState::WithoutReplacement(_, _) => false,
254 }
255 }
256
257 fn index(&self) -> usize {
258 match self.state.lock().unwrap().deref_mut() {
259 SamplerState::WithReplacement(rng) => {
260 rng.sample(Uniform::new(0, self.dataset.len()).unwrap())
261 }
262 SamplerState::WithoutReplacement(rng, indices) => {
263 if indices.is_empty() {
264 let idx_range = 0..self.dataset.len();
266 for _ in 0..(self.size / self.dataset.len()) {
267 indices.extend(idx_range.clone())
271 }
272
273 indices.extend(idx_range.choose_multiple(rng, self.size - indices.len()));
278
279 indices.shuffle(rng);
281 }
282
283 indices.pop().expect("Indices are refilled when empty.")
284 }
285 }
286 }
287}
288
289impl<D, I> Dataset<I> for SamplerDataset<D, I>
290where
291 D: Dataset<I>,
292 I: Send + Sync,
293{
294 fn get(&self, index: usize) -> Option<I> {
295 if index >= self.size {
296 return None;
297 }
298
299 self.dataset.get(self.index())
300 }
301
302 fn len(&self) -> usize {
303 self.size
304 }
305}
306
#[cfg(test)]
mod tests {
    // `assert_eq!(flag, true/false)` is kept on purpose so every builder step reads as
    // "field == expected value"; silence the clippy lint that would prefer `assert!`.
    #![allow(clippy::bool_assert_comparison)]

    use super::*;
    use crate::FakeDataset;
    use rand::SeedableRng;
    use std::collections::HashMap;

    /// Each builder method on `SamplerDatasetOptions` updates exactly its own field.
    #[test]
    fn test_samplerdataset_options() {
        let options = SamplerDatasetOptions::default();
        assert_eq!(options.replace_samples, true);
        assert_eq!(options.size_config, SizeConfig::Default);
        assert_eq!(options.rng_source, RngSource::Default);

        let options = options.with_replace_samples(false);
        assert_eq!(options.replace_samples, false);
        let options = options.with_replacement();
        assert_eq!(options.replace_samples, true);
        let options = options.without_replacement();
        assert_eq!(options.replace_samples, false);

        let options = options.with_size(SizeConfig::Default);
        assert_eq!(options.size_config, SizeConfig::Default);
        let options = options.with_source_size();
        assert_eq!(options.size_config, SizeConfig::Default);
        let options = options.with_fixed_size(10);
        assert_eq!(options.size_config, SizeConfig::Fixed(10));
        let options = options.with_size_ratio(1.5);
        assert_eq!(options.size_config, SizeConfig::Ratio(1.5));

        let options = options.with_system_rng();
        assert_eq!(options.rng_source, RngSource::Default);
        let options = options.with_seed(42);
        assert_eq!(options.rng_source, RngSource::Seed(42));
        let rng = StdRng::seed_from_u64(9);
        let options = options.with_rng(rng.clone());
        assert_eq!(options.rng_source, RngSource::Rng(rng.clone()));
    }

    /// Every constructor reports the requested size and replacement mode.
    #[test]
    fn sampler_dataset_constructors_test() {
        let ds = SamplerDataset::new(FakeDataset::<u32>::new(10), 15);
        assert_eq!(ds.len(), 15);
        assert_eq!(ds.dataset.len(), 10);
        assert!(ds.is_with_replacement());

        let ds = SamplerDataset::with_replacement(FakeDataset::<u32>::new(10), 15);
        assert_eq!(ds.len(), 15);
        assert_eq!(ds.dataset.len(), 10);
        assert!(ds.is_with_replacement());

        let ds = SamplerDataset::without_replacement(FakeDataset::<u32>::new(10), 15);
        assert_eq!(ds.len(), 15);
        assert_eq!(ds.dataset.len(), 10);
        assert!(!ds.is_with_replacement());
    }

    /// Iterating a sampler drawn with replacement yields exactly `size` items.
    #[test]
    fn sampler_dataset_with_replacement_iter() {
        let factor = 3;
        let len_original = 10;
        let dataset_sampler = SamplerDataset::with_replacement(
            FakeDataset::<String>::new(len_original),
            len_original * factor,
        );

        let total = dataset_sampler.iter().count();

        assert_eq!(total, factor * len_original);
    }

    /// Without replacement and an integer size ratio, every item appears exactly
    /// `factor` times.
    #[test]
    fn sampler_dataset_without_replacement_bucket_test() {
        let factor = 3;
        let len_original = 10;

        let dataset_sampler = SamplerDataset::new(
            FakeDataset::<String>::new(len_original),
            SamplerDatasetOptions::default()
                .without_replacement()
                .with_size_ratio(factor as f64),
        );

        let mut buckets = HashMap::new();
        for item in dataset_sampler.iter() {
            *buckets.entry(item).or_insert(0) += 1;
        }

        let mut total = 0;
        for count in buckets.into_values() {
            assert_eq!(count, factor);
            total += count;
        }
        assert_eq!(total, factor * len_original);
    }

    /// Consecutive indices drawn without replacement should look like a random
    /// permutation.
    ///
    /// For two independent uniform picks from `0..n`, the expected absolute difference is
    /// `(n + 2) / 3`; allow 25% slack.
    #[test]
    fn sampler_dataset_without_replacement_uniform_order_test() {
        let size = 1000;
        let dataset_sampler =
            SamplerDataset::without_replacement(FakeDataset::<i32>::new(size), size);

        let indices: Vec<_> = (0..size).map(|_| dataset_sampler.index()).collect();
        let mean_delta = indices
            .windows(2)
            .map(|pair| pair[1].abs_diff(pair[0]))
            .sum::<usize>() as f64
            / (size - 1) as f64;

        let expected = (size + 2) as f64 / 3.0;

        assert!(
            (mean_delta - expected).abs() <= 0.25 * expected,
            "Sampled indices are not uniformly distributed: mean_delta: {mean_delta}, expected: {expected}"
        );
    }
}