burn_dataset/transform/
sampler.rs1use crate::Dataset;
2use rand::{Rng, SeedableRng, distr::Uniform, rngs::StdRng, seq::IteratorRandom};
3use std::{marker::PhantomData, ops::DerefMut, sync::Mutex};
4
5pub struct SamplerDataset<D, I> {
18 dataset: D,
19 size: usize,
20 state: Mutex<SamplerState>,
21 input: PhantomData<I>,
22}
23
24enum SamplerState {
25 WithReplacement(StdRng),
26 WithoutReplacement(StdRng, Vec<usize>),
27}
28
29impl<D, I> SamplerDataset<D, I>
30where
31 D: Dataset<I>,
32 I: Send + Sync,
33{
34 pub fn new(dataset: D, size: usize) -> Self {
36 Self {
37 dataset,
38 size,
39 state: Mutex::new(SamplerState::WithReplacement(StdRng::from_os_rng())),
40 input: PhantomData,
41 }
42 }
43
44 pub fn with_replacement(dataset: D, size: usize) -> Self {
46 Self::new(dataset, size)
47 }
48
49 pub fn without_replacement(dataset: D, size: usize) -> Self {
51 Self {
52 dataset,
53 size,
54 state: Mutex::new(SamplerState::WithoutReplacement(
55 StdRng::from_os_rng(),
56 Vec::new(),
57 )),
58 input: PhantomData,
59 }
60 }
61
62 fn index(&self) -> usize {
63 let mut state = self.state.lock().unwrap();
64
65 match state.deref_mut() {
66 SamplerState::WithReplacement(rng) => {
67 rng.sample(Uniform::new(0, self.dataset.len()).unwrap())
68 }
69 SamplerState::WithoutReplacement(rng, indices) => {
70 if indices.is_empty() {
71 *indices = (0..self.dataset.len()).choose_multiple(rng, self.dataset.len());
73 }
74
75 indices.pop().expect("Indices are refilled when empty.")
76 }
77 }
78 }
79}
80
81impl<D, I> Dataset<I> for SamplerDataset<D, I>
82where
83 D: Dataset<I>,
84 I: Send + Sync,
85{
86 fn get(&self, index: usize) -> Option<I> {
87 if index >= self.size {
88 return None;
89 }
90
91 self.dataset.get(self.index())
92 }
93
94 fn len(&self) -> usize {
95 self.size
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use std::collections::HashMap;

    /// Iterating a with-replacement sampler yields exactly the configured number of items.
    #[test]
    fn sampler_dataset_with_replacement_iter() {
        let len_original = 10;
        let factor = 3;
        let dataset_sampler = SamplerDataset::with_replacement(
            FakeDataset::<String>::new(len_original),
            len_original * factor,
        );

        let total = dataset_sampler.iter().count();

        assert_eq!(total, factor * len_original);
    }

    /// Without replacement, every distinct item shows up exactly `factor` times when
    /// the sampler is `factor` times as long as the wrapped dataset.
    #[test]
    fn sampler_dataset_without_replacement_bucket_test() {
        let len_original = 10;
        let factor = 3;
        let dataset_sampler = SamplerDataset::without_replacement(
            FakeDataset::<String>::new(len_original),
            len_original * factor,
        );

        // Count how often each item is produced.
        let mut buckets: HashMap<_, usize> = HashMap::new();
        for item in dataset_sampler.iter() {
            *buckets.entry(item).or_insert(0) += 1;
        }

        for occurrences in buckets.values() {
            assert_eq!(*occurrences, factor);
        }

        let total: usize = buckets.values().sum();
        assert_eq!(total, factor * len_original);
    }
}