burn_core/data/dataloader/
batch.rs

1use super::{BatchStrategy, DataLoader, DataLoaderIterator, Progress, batcher::Batcher};
2use burn_dataset::{
3    Dataset,
4    transform::{PartialDataset, ShuffledDataset},
5};
6use burn_tensor::backend::Backend;
7use rand::{Rng, distr::StandardUniform};
8use std::sync::Arc;
9
10/// A data loader that can be used to iterate over a dataset in batches.
11pub struct BatchDataLoader<B: Backend, I, O> {
12    strategy: Box<dyn BatchStrategy<I>>,
13    dataset: Arc<dyn Dataset<I>>,
14    batcher: Arc<dyn Batcher<B, I, O>>,
15    device: B::Device,
16    rng: Option<Arc<spin::Mutex<rand::rngs::StdRng>>>,
17}
18
19impl<B: Backend, I, O> Clone for BatchDataLoader<B, I, O> {
20    fn clone(&self) -> Self {
21        Self {
22            strategy: self.strategy.clone_dyn(),
23            dataset: self.dataset.clone(),
24            batcher: self.batcher.clone(),
25            device: self.device.clone(),
26            rng: self.rng.clone(),
27        }
28    }
29}
30
31impl<B: Backend, I, O> BatchDataLoader<B, I, O> {
32    /// Creates a new batch data loader.
33    ///
34    /// # Arguments
35    ///
36    /// * `strategy` - The batch strategy.
37    /// * `dataset` - The dataset.
38    /// * `batcher` - The batcher.
39    /// * `device`  - The device to use when loading a batch.
40    /// * `rng`     - The rng determining if the dataset is shuffled each time a dataloader
41    ///   iterator is created.
42    ///
43    /// # Returns
44    ///
45    /// The batch data loader.
46    pub fn new(
47        strategy: Box<dyn BatchStrategy<I>>,
48        dataset: Arc<dyn Dataset<I>>,
49        batcher: Arc<dyn Batcher<B, I, O>>,
50        device: B::Device,
51        rng: Option<rand::rngs::StdRng>,
52    ) -> Self {
53        Self {
54            strategy,
55            dataset,
56            batcher,
57            device,
58            rng: rng.map(|rng| Arc::new(spin::Mutex::new(rng))),
59        }
60    }
61}
62
63/// A data loader iterator that can be used to iterate over a data loader.
64struct BatchDataloaderIterator<B: Backend, I, O> {
65    current_index: usize,
66    strategy: Box<dyn BatchStrategy<I>>,
67    dataset: Arc<dyn Dataset<I>>,
68    batcher: Arc<dyn Batcher<B, I, O>>,
69    device: B::Device,
70}
71
72impl<B, I, O> DataLoader<B, O> for BatchDataLoader<B, I, O>
73where
74    B: Backend,
75    I: Send + Sync + Clone + 'static,
76    O: Send + 'static,
77{
78    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
79        // When starting a new iteration, we first check if the dataloader was created with an rng,
80        // implying that we should shuffle the dataset beforehand, while advancing the current
81        // rng to ensure that each new iteration shuffles the dataset differently.
82        let dataset = match &self.rng {
83            Some(rng) => {
84                let mut rng = rng.lock();
85
86                Arc::new(ShuffledDataset::with_seed(
87                    self.dataset.clone(),
88                    rng.sample(StandardUniform),
89                ))
90            }
91            None => self.dataset.clone(),
92        };
93        Box::new(BatchDataloaderIterator::new(
94            self.strategy.clone_dyn(),
95            dataset,
96            self.batcher.clone(),
97            self.device.clone(),
98        ))
99    }
100
101    fn num_items(&self) -> usize {
102        self.dataset.len()
103    }
104
105    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
106        let rng = self.rng.as_ref().map(|rng| {
107            let rng = rng.lock();
108            rng.clone()
109        });
110        Arc::new(Self::new(
111            self.strategy.clone_dyn(),
112            self.dataset.clone(),
113            self.batcher.clone(),
114            device.clone(),
115            rng,
116        ))
117    }
118
119    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
120        let rng = self.rng.as_ref().map(|rng| {
121            let rng = rng.lock();
122            rng.clone()
123        });
124        let dataloader = Self::new(
125            self.strategy.clone_dyn(),
126            Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
127            self.batcher.clone(),
128            self.device.clone(),
129            rng,
130        );
131        Arc::new(dataloader)
132    }
133}
134
135impl<B: Backend, I, O> BatchDataloaderIterator<B, I, O> {
136    /// Creates a new batch data loader iterator.
137    ///
138    /// # Arguments
139    ///
140    /// * `strategy` - The batch strategy.
141    /// * `dataset` - The dataset.
142    /// * `batcher` - The batcher.
143    /// * `device`  - The device to use when loading a batch.
144    ///
145    /// # Returns
146    ///
147    /// The batch data loader iterator.
148    pub fn new(
149        strategy: Box<dyn BatchStrategy<I>>,
150        dataset: Arc<dyn Dataset<I>>,
151        batcher: Arc<dyn Batcher<B, I, O>>,
152        device: B::Device,
153    ) -> Self {
154        BatchDataloaderIterator {
155            current_index: 0,
156            strategy,
157            dataset,
158            batcher,
159            device,
160        }
161    }
162}
163
164impl<B: Backend, I, O> Iterator for BatchDataloaderIterator<B, I, O> {
165    type Item = O;
166
167    fn next(&mut self) -> Option<O> {
168        while let Some(item) = self.dataset.get(self.current_index) {
169            self.current_index += 1;
170            self.strategy.add(item);
171
172            if let Some(items) = self.strategy.batch(false) {
173                return Some(self.batcher.batch(items, &self.device));
174            }
175        }
176
177        if let Some(items) = self.strategy.batch(true) {
178            return Some(self.batcher.batch(items, &self.device));
179        }
180
181        None
182    }
183}
184
185impl<B: Backend, I, O> DataLoaderIterator<O> for BatchDataloaderIterator<B, I, O> {
186    fn progress(&self) -> Progress {
187        Progress::new(self.current_index, self.dataset.len())
188    }
189}
190
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;

    /// The set of items yielded through the batches must equal the set of items in the dataset.
    #[test]
    fn test_batch_dataloader() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        );

        let items_dataset: HashSet<_> = dataset.iter().collect();
        let items_dataloader: HashSet<_> = dataloader.iter().flatten().collect();

        assert_eq!(items_dataset, items_dataloader);
    }

    /// A sliced data loader must yield the same items as positions `5..15` of the full loader.
    #[test]
    fn test_batch_dataloader_slice() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        );
        let dataloader_slice = dataloader.slice(5, 15);

        // Keep only the items whose flattened position falls inside the sliced range.
        let items_dataloader: HashSet<_> = dataloader
            .iter()
            .flatten()
            .enumerate()
            .filter_map(|(idx, item)| (5..15).contains(&idx).then_some(item))
            .collect();

        let items_dataloader_slice: HashSet<_> = dataloader_slice.iter().flatten().collect();

        assert_eq!(items_dataloader, items_dataloader_slice);
    }
}