Skip to main content

burn_dataset/transform/
sampler.rs

1use crate::Dataset;
2use crate::transform::{RngSource, SizeConfig};
3use rand::prelude::SliceRandom;
4use rand::{Rng, distr::Uniform, rngs::StdRng, seq::IteratorRandom};
5use std::{marker::PhantomData, ops::DerefMut, sync::Mutex};
6
7/// Options to configure a [SamplerDataset].
8#[derive(Debug, Clone, PartialEq)]
9pub struct SamplerDatasetOptions {
10    /// The sampling mode.
11    pub replace_samples: bool,
12
13    /// The size source of the wrapper relative to the dataset.
14    pub size_config: SizeConfig,
15
16    /// The source of the random number generator.
17    pub rng_source: RngSource,
18}
19
20impl Default for SamplerDatasetOptions {
21    fn default() -> Self {
22        Self {
23            replace_samples: true,
24            size_config: SizeConfig::Default,
25            rng_source: RngSource::Default,
26        }
27    }
28}
29
30impl<T> From<Option<T>> for SamplerDatasetOptions
31where
32    T: Into<SamplerDatasetOptions>,
33{
34    fn from(option: Option<T>) -> Self {
35        match option {
36            Some(option) => option.into(),
37            None => Self::default(),
38        }
39    }
40}
41
42impl From<usize> for SamplerDatasetOptions {
43    fn from(size: usize) -> Self {
44        Self::default().with_replacement().with_fixed_size(size)
45    }
46}
47
48impl SamplerDatasetOptions {
49    /// Set the replacement mode.
50    pub fn with_replace_samples(self, replace_samples: bool) -> Self {
51        Self {
52            replace_samples,
53            ..self
54        }
55    }
56
57    /// Set the replacement mode to WithReplacement.
58    pub fn with_replacement(self) -> Self {
59        self.with_replace_samples(true)
60    }
61
62    /// Set the replacement mode to WithoutReplacement.
63    pub fn without_replacement(self) -> Self {
64        self.with_replace_samples(false)
65    }
66
67    /// Set the size source.
68    pub fn with_size<S>(self, source: S) -> Self
69    where
70        S: Into<SizeConfig>,
71    {
72        Self {
73            size_config: source.into(),
74            ..self
75        }
76    }
77
78    /// Set the size to the size of the source.
79    pub fn with_source_size(self) -> Self {
80        self.with_size(SizeConfig::Default)
81    }
82
83    /// Set the size to a fixed size.
84    pub fn with_fixed_size(self, size: usize) -> Self {
85        self.with_size(size)
86    }
87
88    /// Set the size to be a multiple of the ration and the source size.
89    pub fn with_size_ratio(self, size_ratio: f64) -> Self {
90        self.with_size(size_ratio)
91    }
92
93    /// Set the `RngSource`.
94    pub fn with_rng<R>(self, rng: R) -> Self
95    where
96        R: Into<RngSource>,
97    {
98        Self {
99            rng_source: rng.into(),
100            ..self
101        }
102    }
103
104    /// Use the system rng.
105    pub fn with_system_rng(self) -> Self {
106        self.with_rng(RngSource::Default)
107    }
108
109    /// Use a rng, built from a seed.
110    pub fn with_seed(self, seed: u64) -> Self {
111        self.with_rng(seed)
112    }
113}
114
115/// Sample items from a dataset.
116///
117/// This is a convenient way of modeling a dataset as a probability distribution of a fixed size.
118/// You have multiple options to instantiate the dataset sampler.
119///
120/// * With replacement (Default): This is the most efficient way of using the sampler because no state is
121///   required to keep indices that have been selected.
122///
123/// * Without replacement: This has a similar effect to using a
124///   [shuffled dataset](crate::transform::ShuffledDataset), but with more flexibility since you can
125///   set the dataset to an arbitrary size. Once every item has been used, a new cycle is
126///   created with a new random suffle.
127pub struct SamplerDataset<D, I> {
128    dataset: D,
129    size: usize,
130    state: Mutex<SamplerState>,
131    input: PhantomData<I>,
132}
133enum SamplerState {
134    WithReplacement(StdRng),
135    WithoutReplacement(StdRng, Vec<usize>),
136}
137
138impl<D, I> SamplerDataset<D, I>
139where
140    D: Dataset<I>,
141    I: Send + Sync,
142{
143    /// Creates a new sampler dataset with replacement.
144    ///
145    /// When the sample size is less than or equal to the source dataset size,
146    /// data will be sampled without replacement from the source dataset in
147    /// a uniformly shuffled order.
148    ///
149    /// When the sample size is greater than the source dataset size,
150    /// the entire source dataset will be sampled once for every multiple
151    /// of the size ratios; with the remaining samples taken without replacement
152    /// uniformly from the source. All samples will be returned uniformly shuffled.
153    ///
154    /// ## Arguments
155    ///
156    /// * `dataset`: the dataset to wrap.
157    /// * `options`: the options to configure the sampler dataset.
158    ///
159    /// ## Examples
160    /// ```rust,ignore
161    /// use burn_dataset::transform::{
162    ///   SamplerDataset,
163    ///   SamplerDatasetOptions,
164    /// };
165    ///
166    /// // Examples below assuming `dataset.len()` = `10`.
167    ///
168    /// // sample size: 5
169    /// // WithReplacement
170    /// // rng: StdRng::from_os_rng()
171    /// SamplerDataset::new(dataset, 5);
172    ///
173    /// // sample size: 10 (source)
174    /// // WithReplacement
175    /// // rng: StdRng::from_os_rng()
176    /// SamplerDataset::new(dataset, SamplerDatasetOptions::default());
177    ///
178    /// // sample size: 15
179    /// // WithoutReplacement
180    /// // rng: StdRng::seed_from_u64(42)
181    /// SamplerDataset::new(
182    ///   dataset,
183    ///   SamplerDatasetOptions::default()
184    ///     .with_size(1.5)
185    ///     .without_replacement()
186    ///     .with_rng(42),
187    /// );
188    /// ```
189    pub fn new<O>(dataset: D, options: O) -> Self
190    where
191        O: Into<SamplerDatasetOptions>,
192    {
193        let options = options.into();
194        let size = options.size_config.resolve(dataset.len());
195        let rng = options.rng_source.into();
196        Self {
197            dataset,
198            size,
199            state: Mutex::new(match options.replace_samples {
200                true => SamplerState::WithReplacement(rng),
201                false => SamplerState::WithoutReplacement(rng, Vec::with_capacity(size)),
202            }),
203            input: PhantomData,
204        }
205    }
206
207    /// Creates a new sampler dataset with replacement.
208    ///
209    /// # Arguments
210    ///
211    /// - `dataset`: the dataset to wrap.
212    /// - `size`: the effective size of the sampled dataset.
213    pub fn with_replacement(dataset: D, size: usize) -> Self {
214        Self::new(
215            dataset,
216            SamplerDatasetOptions::default()
217                .with_replacement()
218                .with_fixed_size(size),
219        )
220    }
221
222    /// Creates a new sampler dataset without replacement.
223    ///
224    /// When the sample size is less than or equal to the source dataset size,
225    /// data will be sampled without replacement from the source dataset in
226    /// a uniformly shuffled order.
227    ///
228    /// When the sample size is greater than the source dataset size,
229    /// the entire source dataset will be sampled once for every multiple
230    /// of the size ratios; with the remaining samples taken without replacement
231    /// uniformly from the source. All samples will be returned uniformly shuffled.
232    ///
233    /// # Arguments
234    /// - `dataset`: the dataset to wrap.
235    /// - `size`: the effective size of the sampled dataset.
236    pub fn without_replacement(dataset: D, size: usize) -> Self {
237        Self::new(
238            dataset,
239            SamplerDatasetOptions::default()
240                .without_replacement()
241                .with_fixed_size(size),
242        )
243    }
244
245    /// Determines if the sampler is using the "with replacement" strategy.
246    ///
247    /// # Returns
248    /// - `true`: If the sampler is configured to sample with replacement.
249    /// - `false`: If the sampler is configured to sample without replacement.
250    pub fn is_with_replacement(&self) -> bool {
251        match self.state.lock().unwrap().deref_mut() {
252            SamplerState::WithReplacement(_) => true,
253            SamplerState::WithoutReplacement(_, _) => false,
254        }
255    }
256
257    fn index(&self) -> usize {
258        match self.state.lock().unwrap().deref_mut() {
259            SamplerState::WithReplacement(rng) => {
260                rng.sample(Uniform::new(0, self.dataset.len()).unwrap())
261            }
262            SamplerState::WithoutReplacement(rng, indices) => {
263                if indices.is_empty() {
264                    // Refill the state.
265                    let idx_range = 0..self.dataset.len();
266                    for _ in 0..(self.size / self.dataset.len()) {
267                        // No need to `.choose_multiple` here because we're using
268                        // the entire source range; and `.choose_multiple` will
269                        // not return a random sample anyway.
270                        indices.extend(idx_range.clone())
271                    }
272
273                    // From `choose_multiple` documentation:
274                    // > Although the elements are selected randomly, the order of elements in
275                    // > the buffer is neither stable nor fully random. If random ordering is
276                    // > desired, shuffle the result.
277                    indices.extend(idx_range.choose_multiple(rng, self.size - indices.len()));
278
279                    // The real shuffling is done here.
280                    indices.shuffle(rng);
281                }
282
283                indices.pop().expect("Indices are refilled when empty.")
284            }
285        }
286    }
287}
288
289impl<D, I> Dataset<I> for SamplerDataset<D, I>
290where
291    D: Dataset<I>,
292    I: Send + Sync,
293{
294    fn get(&self, index: usize) -> Option<I> {
295        if index >= self.size {
296            return None;
297        }
298
299        self.dataset.get(self.index())
300    }
301
302    fn len(&self) -> usize {
303        self.size
304    }
305}
306
#[cfg(test)]
mod tests {
    // `assert_eq!(flag, true)` mirrors the option values being checked more literally than
    // `assert!(flag)`, so the clippy lint against it is silenced for this module.
    #![allow(clippy::bool_assert_comparison)]

    use super::*;
    use crate::FakeDataset;
    use rand::SeedableRng;
    use std::collections::HashMap;

    #[test]
    fn test_samplerdataset_options() {
        let options = SamplerDatasetOptions::default();
        assert_eq!(options.replace_samples, true);
        assert_eq!(options.size_config, SizeConfig::Default);
        assert_eq!(options.rng_source, RngSource::Default);

        // ReplacementMode
        let options = options.with_replace_samples(false);
        assert_eq!(options.replace_samples, false);
        let options = options.with_replacement();
        assert_eq!(options.replace_samples, true);
        let options = options.without_replacement();
        assert_eq!(options.replace_samples, false);

        // SourceSize
        let options = options.with_size(SizeConfig::Default);
        assert_eq!(options.size_config, SizeConfig::Default);
        let options = options.with_source_size();
        assert_eq!(options.size_config, SizeConfig::Default);
        let options = options.with_fixed_size(10);
        assert_eq!(options.size_config, SizeConfig::Fixed(10));
        let options = options.with_size_ratio(1.5);
        assert_eq!(options.size_config, SizeConfig::Ratio(1.5));

        // RngSource
        let options = options.with_system_rng();
        assert_eq!(options.rng_source, RngSource::Default);
        let options = options.with_seed(42);
        assert_eq!(options.rng_source, RngSource::Seed(42));
        let rng = StdRng::seed_from_u64(9);
        let options = options.with_rng(rng.clone());
        assert_eq!(options.rng_source, RngSource::Rng(rng));
    }

    #[test]
    fn sampler_dataset_constructors_test() {
        let ds = SamplerDataset::new(FakeDataset::<u32>::new(10), 15);
        assert_eq!(ds.len(), 15);
        assert_eq!(ds.dataset.len(), 10);
        assert!(ds.is_with_replacement());

        let ds = SamplerDataset::with_replacement(FakeDataset::<u32>::new(10), 15);
        assert_eq!(ds.len(), 15);
        assert_eq!(ds.dataset.len(), 10);
        assert!(ds.is_with_replacement());

        let ds = SamplerDataset::without_replacement(FakeDataset::<u32>::new(10), 15);
        assert_eq!(ds.len(), 15);
        assert_eq!(ds.dataset.len(), 10);
        assert!(!ds.is_with_replacement());
    }

    #[test]
    fn sampler_dataset_with_replacement_iter() {
        let factor = 3;
        let len_original = 10;
        let dataset_sampler = SamplerDataset::with_replacement(
            FakeDataset::<String>::new(len_original),
            len_original * factor,
        );

        assert_eq!(dataset_sampler.iter().count(), factor * len_original);
    }

    #[test]
    fn sampler_dataset_without_replacement_bucket_test() {
        let factor = 3;
        let len_original = 10;

        let dataset_sampler = SamplerDataset::new(
            FakeDataset::<String>::new(len_original),
            SamplerDatasetOptions::default()
                .without_replacement()
                .with_size_ratio(factor as f64),
        );

        // Count how often each distinct item is produced.
        let mut buckets = HashMap::new();
        for item in dataset_sampler.iter() {
            *buckets.entry(item).or_insert(0) += 1;
        }

        // Without replacement and an integer ratio, every item appears exactly `factor` times.
        let mut total = 0;
        for count in buckets.into_values() {
            assert_eq!(count, factor);
            total += count;
        }
        assert_eq!(total, factor * len_original);
    }

    #[test]
    fn sampler_dataset_without_replacement_uniform_order_test() {
        // This is a regression test on the indices.shuffle(rng) call in SamplerDataset::index().
        let size = 1000;
        let dataset_sampler =
            SamplerDataset::without_replacement(FakeDataset::<i32>::new(size), size);

        let indices: Vec<_> = (0..size).map(|_| dataset_sampler.index()).collect();
        let mean_delta = indices
            .windows(2)
            .map(|pair| pair[1].abs_diff(pair[0]))
            .sum::<usize>() as f64
            / (size - 1) as f64;

        // For a uniformly random permutation of 0..n, adjacent entries are a uniformly random
        // pair of distinct values, whose expected absolute difference is (n + 1) / 3.
        let expected = (size + 1) as f64 / 3.0;

        assert!(
            (mean_delta - expected).abs() <= 0.25 * expected,
            "Sampled indices are not uniformly distributed: mean_delta: {mean_delta}, expected: {expected}"
        );
    }
}