burn_core/data/dataloader/
batch.rs1use super::{BatchStrategy, DataLoader, DataLoaderIterator, Progress, batcher::Batcher};
2use burn_dataset::{
3 Dataset,
4 transform::{PartialDataset, ShuffledDataset},
5};
6use burn_tensor::backend::Backend;
7use rand::{Rng, distr::StandardUniform};
8use std::sync::Arc;
9
10pub struct BatchDataLoader<B: Backend, I, O> {
12 strategy: Box<dyn BatchStrategy<I>>,
13 dataset: Arc<dyn Dataset<I>>,
14 batcher: Arc<dyn Batcher<B, I, O>>,
15 device: B::Device,
16 rng: Option<Arc<spin::Mutex<rand::rngs::StdRng>>>,
17}
18
19impl<B: Backend, I, O> Clone for BatchDataLoader<B, I, O> {
20 fn clone(&self) -> Self {
21 Self {
22 strategy: self.strategy.clone_dyn(),
23 dataset: self.dataset.clone(),
24 batcher: self.batcher.clone(),
25 device: self.device.clone(),
26 rng: self.rng.clone(),
27 }
28 }
29}
30
31impl<B: Backend, I, O> BatchDataLoader<B, I, O> {
32 pub fn new(
47 strategy: Box<dyn BatchStrategy<I>>,
48 dataset: Arc<dyn Dataset<I>>,
49 batcher: Arc<dyn Batcher<B, I, O>>,
50 device: B::Device,
51 rng: Option<rand::rngs::StdRng>,
52 ) -> Self {
53 Self {
54 strategy,
55 dataset,
56 batcher,
57 device,
58 rng: rng.map(|rng| Arc::new(spin::Mutex::new(rng))),
59 }
60 }
61}
62
63struct BatchDataloaderIterator<B: Backend, I, O> {
65 current_index: usize,
66 strategy: Box<dyn BatchStrategy<I>>,
67 dataset: Arc<dyn Dataset<I>>,
68 batcher: Arc<dyn Batcher<B, I, O>>,
69 device: B::Device,
70}
71
72impl<B, I, O> DataLoader<B, O> for BatchDataLoader<B, I, O>
73where
74 B: Backend,
75 I: Send + Sync + Clone + 'static,
76 O: Send + 'static,
77{
78 fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
79 let dataset = match &self.rng {
83 Some(rng) => {
84 let mut rng = rng.lock();
85
86 Arc::new(ShuffledDataset::with_seed(
87 self.dataset.clone(),
88 rng.sample(StandardUniform),
89 ))
90 }
91 None => self.dataset.clone(),
92 };
93 Box::new(BatchDataloaderIterator::new(
94 self.strategy.clone_dyn(),
95 dataset,
96 self.batcher.clone(),
97 self.device.clone(),
98 ))
99 }
100
101 fn num_items(&self) -> usize {
102 self.dataset.len()
103 }
104
105 fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
106 let rng = self.rng.as_ref().map(|rng| {
107 let rng = rng.lock();
108 rng.clone()
109 });
110 Arc::new(Self::new(
111 self.strategy.clone_dyn(),
112 self.dataset.clone(),
113 self.batcher.clone(),
114 device.clone(),
115 rng,
116 ))
117 }
118
119 fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
120 let rng = self.rng.as_ref().map(|rng| {
121 let rng = rng.lock();
122 rng.clone()
123 });
124 let dataloader = Self::new(
125 self.strategy.clone_dyn(),
126 Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
127 self.batcher.clone(),
128 self.device.clone(),
129 rng,
130 );
131 Arc::new(dataloader)
132 }
133}
134
135impl<B: Backend, I, O> BatchDataloaderIterator<B, I, O> {
136 pub fn new(
149 strategy: Box<dyn BatchStrategy<I>>,
150 dataset: Arc<dyn Dataset<I>>,
151 batcher: Arc<dyn Batcher<B, I, O>>,
152 device: B::Device,
153 ) -> Self {
154 BatchDataloaderIterator {
155 current_index: 0,
156 strategy,
157 dataset,
158 batcher,
159 device,
160 }
161 }
162}
163
164impl<B: Backend, I, O> Iterator for BatchDataloaderIterator<B, I, O> {
165 type Item = O;
166
167 fn next(&mut self) -> Option<O> {
168 while let Some(item) = self.dataset.get(self.current_index) {
169 self.current_index += 1;
170 self.strategy.add(item);
171
172 if let Some(items) = self.strategy.batch(false) {
173 return Some(self.batcher.batch(items, &self.device));
174 }
175 }
176
177 if let Some(items) = self.strategy.batch(true) {
178 return Some(self.batcher.batch(items, &self.device));
179 }
180
181 None
182 }
183}
184
185impl<B: Backend, I, O> DataLoaderIterator<O> for BatchDataloaderIterator<B, I, O> {
186 fn progress(&self) -> Progress {
187 Progress::new(self.current_index, self.dataset.len())
188 }
189}
190
191#[cfg(test)]
192mod tests {
193 use std::collections::HashSet;
194
195 use super::*;
196 use crate::data::dataloader::FixBatchStrategy;
197 use crate::data::dataloader::batcher::TestBatcher;
198 use crate::data::dataset::FakeDataset;
199
200 #[test]
201 fn test_batch_dataloader() {
202 let batcher = Arc::new(TestBatcher::new());
203 let dataset = Arc::new(FakeDataset::<String>::new(27));
204 let dataloader = BatchDataLoader::new(
205 Box::new(FixBatchStrategy::new(5)),
206 dataset.clone(),
207 batcher,
208 Default::default(),
209 None,
210 );
211
212 let mut items_dataset = HashSet::new();
213 let mut items_dataloader = HashSet::new();
214
215 for item in dataset.iter() {
216 items_dataset.insert(item);
217 }
218
219 for items in dataloader.iter() {
220 for item in items {
221 items_dataloader.insert(item);
222 }
223 }
224
225 assert_eq!(items_dataset, items_dataloader);
226 }
227
228 #[test]
229 fn test_batch_dataloader_slice() {
230 let batcher = Arc::new(TestBatcher::new());
231 let dataset = Arc::new(FakeDataset::<String>::new(27));
232 let dataloader = BatchDataLoader::new(
233 Box::new(FixBatchStrategy::new(5)),
234 dataset.clone(),
235 batcher,
236 Default::default(),
237 None,
238 );
239 let dataloader_slice = dataloader.slice(5, 15);
240
241 let mut items_dataloader = HashSet::new();
242 let mut items_dataloader_slice = HashSet::new();
243
244 let mut idx = 0;
245 for items in dataloader.iter() {
246 for item in items {
247 if (5..15).contains(&idx) {
248 items_dataloader.insert(item);
249 }
250 idx += 1;
251 }
252 }
253
254 for items in dataloader_slice.iter() {
255 for item in items {
256 items_dataloader_slice.insert(item);
257 }
258 }
259
260 assert_eq!(items_dataloader, items_dataloader_slice);
261 }
262}