burn_dataset/transform/
sampler.rs

1use crate::Dataset;
2use rand::{Rng, SeedableRng, distr::Uniform, rngs::StdRng, seq::IteratorRandom};
3use std::{marker::PhantomData, ops::DerefMut, sync::Mutex};
4
/// Sample items from a dataset.
///
/// This is a convenient way of modeling a dataset as a probability distribution of a fixed size.
/// You have multiple options to instantiate the dataset sampler.
///
/// * With replacement (Default): This is the most efficient way of using the sampler because no state is
///   required to keep indices that have been selected.
///
/// * Without replacement: This has a similar effect to using a
///   [shuffled dataset](crate::transform::ShuffledDataset), but with more flexibility since you can
///   set the dataset to an arbitrary size. Once every item has been used, a new cycle is
///   created with a new random shuffle.
pub struct SamplerDataset<D, I> {
    // The wrapped dataset that items are drawn from.
    dataset: D,
    // The length reported by the sampler; independent of the wrapped dataset's length.
    size: usize,
    // Sampling state (random number generator and, without replacement, the pending indices),
    // kept behind a mutex so it can be updated through `&self`.
    state: Mutex<SamplerState>,
    // Marker tying the sampler to its item type `I` without storing a value of that type.
    input: PhantomData<I>,
}
23
/// Internal sampling state of a [`SamplerDataset`].
enum SamplerState {
    /// Sampling with replacement: only the random number generator is needed, since every
    /// draw is made over the full range of the wrapped dataset.
    WithReplacement(StdRng),
    /// Sampling without replacement: the random number generator together with the indices
    /// that have not been handed out yet in the current cycle. Indices are popped from the
    /// end of the vector, which is refilled once it becomes empty.
    WithoutReplacement(StdRng, Vec<usize>),
}
28
29impl<D, I> SamplerDataset<D, I>
30where
31    D: Dataset<I>,
32    I: Send + Sync,
33{
34    /// Creates a new sampler dataset with replacement.
35    pub fn new(dataset: D, size: usize) -> Self {
36        Self {
37            dataset,
38            size,
39            state: Mutex::new(SamplerState::WithReplacement(StdRng::from_os_rng())),
40            input: PhantomData,
41        }
42    }
43
44    /// Creates a new sampler dataset with replacement.
45    pub fn with_replacement(dataset: D, size: usize) -> Self {
46        Self::new(dataset, size)
47    }
48
49    /// Creates a new sampler dataset without replacement.
50    pub fn without_replacement(dataset: D, size: usize) -> Self {
51        Self {
52            dataset,
53            size,
54            state: Mutex::new(SamplerState::WithoutReplacement(
55                StdRng::from_os_rng(),
56                Vec::new(),
57            )),
58            input: PhantomData,
59        }
60    }
61
62    fn index(&self) -> usize {
63        let mut state = self.state.lock().unwrap();
64
65        match state.deref_mut() {
66            SamplerState::WithReplacement(rng) => {
67                rng.sample(Uniform::new(0, self.dataset.len()).unwrap())
68            }
69            SamplerState::WithoutReplacement(rng, indices) => {
70                if indices.is_empty() {
71                    // Refill the state.
72                    *indices = (0..self.dataset.len()).choose_multiple(rng, self.dataset.len());
73                }
74
75                indices.pop().expect("Indices are refilled when empty.")
76            }
77        }
78    }
79}
80
81impl<D, I> Dataset<I> for SamplerDataset<D, I>
82where
83    D: Dataset<I>,
84    I: Send + Sync,
85{
86    fn get(&self, index: usize) -> Option<I> {
87        if index >= self.size {
88            return None;
89        }
90
91        self.dataset.get(self.index())
92    }
93
94    fn len(&self) -> usize {
95        self.size
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use std::collections::HashMap;

    /// Iterating a sampler with replacement yields exactly `size` items.
    #[test]
    fn sampler_dataset_with_replacement_iter() {
        let factor = 3;
        let len_original = 10;
        let sampler = SamplerDataset::with_replacement(
            FakeDataset::<String>::new(len_original),
            len_original * factor,
        );

        let total = sampler.iter().count();

        assert_eq!(total, factor * len_original);
    }

    /// Sampling without replacement over `factor` full cycles must return every item exactly
    /// `factor` times.
    #[test]
    fn sampler_dataset_without_replacement_bucket_test() {
        let factor = 3;
        let len_original = 10;
        let sampler = SamplerDataset::without_replacement(
            FakeDataset::<String>::new(len_original),
            len_original * factor,
        );

        // Count how many times each distinct item was produced.
        let mut buckets: HashMap<String, usize> = HashMap::new();
        for item in sampler.iter() {
            *buckets.entry(item).or_insert(0) += 1;
        }

        let mut total = 0;
        for &occurrences in buckets.values() {
            assert_eq!(occurrences, factor);
            total += occurrences;
        }
        assert_eq!(total, factor * len_original);
    }
}