Skip to main content

burn_core/data/dataloader/
batch.rs

1use super::{BatchStrategy, DataLoader, DataLoaderIterator, Progress, batcher::Batcher};
2use burn_dataset::{
3    Dataset,
4    transform::{PartialDataset, ShuffledDataset},
5};
6use burn_tensor::backend::Backend;
7use std::ops::DerefMut;
8use std::sync::Arc;
9
10/// A data loader that can be used to iterate over a dataset in batches.
11pub struct BatchDataLoader<B: Backend, I, O> {
12    strategy: Box<dyn BatchStrategy<I>>,
13    dataset: Arc<dyn Dataset<I>>,
14    batcher: Arc<dyn Batcher<B, I, O>>,
15    device: B::Device,
16    rng: Option<Arc<spin::Mutex<rand::rngs::StdRng>>>,
17}
18
19impl<B: Backend, I, O> Clone for BatchDataLoader<B, I, O> {
20    fn clone(&self) -> Self {
21        Self {
22            strategy: self.strategy.clone_dyn(),
23            dataset: self.dataset.clone(),
24            batcher: self.batcher.clone(),
25            device: self.device.clone(),
26            rng: self.rng.clone(),
27        }
28    }
29}
30
31impl<B: Backend, I, O> BatchDataLoader<B, I, O> {
32    /// Creates a new batch data loader.
33    ///
34    /// # Arguments
35    ///
36    /// * `strategy` - The batch strategy.
37    /// * `dataset` - The dataset.
38    /// * `batcher` - The batcher.
39    /// * `device`  - The device to use when loading a batch.
40    /// * `rng`     - The rng determining if the dataset is shuffled each time a dataloader
41    ///   iterator is created.
42    ///
43    /// # Returns
44    ///
45    /// The batch data loader.
46    pub fn new(
47        strategy: Box<dyn BatchStrategy<I>>,
48        dataset: Arc<dyn Dataset<I>>,
49        batcher: Arc<dyn Batcher<B, I, O>>,
50        device: B::Device,
51        rng: Option<rand::rngs::StdRng>,
52    ) -> Self {
53        Self {
54            strategy,
55            dataset,
56            batcher,
57            device,
58            rng: rng.map(|rng| Arc::new(spin::Mutex::new(rng))),
59        }
60    }
61}
62
63/// A data loader iterator that can be used to iterate over a data loader.
64struct BatchDataloaderIterator<B: Backend, I, O> {
65    current_index: usize,
66    strategy: Box<dyn BatchStrategy<I>>,
67    dataset: Arc<dyn Dataset<I>>,
68    batcher: Arc<dyn Batcher<B, I, O>>,
69    device: B::Device,
70}
71
72impl<B, I, O> DataLoader<B, O> for BatchDataLoader<B, I, O>
73where
74    B: Backend,
75    I: Send + Sync + Clone + 'static,
76    O: Send + 'static,
77{
78    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
79        // When starting a new iteration, we first check if the dataloader was created with an rng,
80        // implying that we should shuffle the dataset beforehand, while advancing the current
81        // rng to ensure that each new iteration shuffles the dataset differently.
82        let dataset = match &self.rng {
83            Some(rng) => Arc::new(ShuffledDataset::new(
84                self.dataset.clone(),
85                rng.lock().deref_mut(),
86            )),
87            None => self.dataset.clone(),
88        };
89        Box::new(BatchDataloaderIterator::new(
90            self.strategy.clone_dyn(),
91            dataset,
92            self.batcher.clone(),
93            self.device.clone(),
94        ))
95    }
96
97    fn num_items(&self) -> usize {
98        self.dataset.len()
99    }
100
101    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
102        let rng = self.rng.as_ref().map(|rng| {
103            let rng = rng.lock();
104            rng.clone()
105        });
106        Arc::new(Self::new(
107            self.strategy.clone_dyn(),
108            self.dataset.clone(),
109            self.batcher.clone(),
110            device.clone(),
111            rng,
112        ))
113    }
114
115    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
116        let rng = self.rng.as_ref().map(|rng| {
117            let rng = rng.lock();
118            rng.clone()
119        });
120        let dataloader = Self::new(
121            self.strategy.clone_dyn(),
122            Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
123            self.batcher.clone(),
124            self.device.clone(),
125            rng,
126        );
127        Arc::new(dataloader)
128    }
129}
130
131impl<B: Backend, I, O> BatchDataloaderIterator<B, I, O> {
132    /// Creates a new batch data loader iterator.
133    ///
134    /// # Arguments
135    ///
136    /// * `strategy` - The batch strategy.
137    /// * `dataset` - The dataset.
138    /// * `batcher` - The batcher.
139    /// * `device`  - The device to use when loading a batch.
140    ///
141    /// # Returns
142    ///
143    /// The batch data loader iterator.
144    pub fn new(
145        strategy: Box<dyn BatchStrategy<I>>,
146        dataset: Arc<dyn Dataset<I>>,
147        batcher: Arc<dyn Batcher<B, I, O>>,
148        device: B::Device,
149    ) -> Self {
150        BatchDataloaderIterator {
151            current_index: 0,
152            strategy,
153            dataset,
154            batcher,
155            device,
156        }
157    }
158}
159
160impl<B: Backend, I, O> Iterator for BatchDataloaderIterator<B, I, O> {
161    type Item = O;
162
163    fn next(&mut self) -> Option<O> {
164        while let Some(item) = self.dataset.get(self.current_index) {
165            self.current_index += 1;
166            self.strategy.add(item);
167
168            if let Some(items) = self.strategy.batch(false) {
169                return Some(self.batcher.batch(items, &self.device));
170            }
171        }
172
173        if let Some(items) = self.strategy.batch(true) {
174            return Some(self.batcher.batch(items, &self.device));
175        }
176
177        None
178    }
179}
180
181impl<B: Backend, I, O> DataLoaderIterator<O> for BatchDataloaderIterator<B, I, O> {
182    fn progress(&self) -> Progress {
183        Progress::new(self.current_index, self.dataset.len())
184    }
185}
186
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;

    /// Every item of the dataset must come out of the dataloader, and nothing else.
    #[test]
    fn test_batch_dataloader() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        );

        let items_dataset: HashSet<_> = dataset.iter().collect();
        let items_dataloader: HashSet<_> = dataloader.iter().flatten().collect();

        assert_eq!(items_dataset, items_dataloader);
    }

    /// A sliced dataloader must yield exactly the items found at positions 5..15 of the
    /// full dataloader.
    #[test]
    fn test_batch_dataloader_slice() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        );
        let dataloader_slice = dataloader.slice(5, 15);

        // Flatten the batches and keep only the items whose overall position is in 5..15.
        let items_dataloader: HashSet<_> = dataloader
            .iter()
            .flatten()
            .enumerate()
            .filter(|(idx, _)| (5..15).contains(idx))
            .map(|(_, item)| item)
            .collect();

        let items_dataloader_slice: HashSet<_> = dataloader_slice.iter().flatten().collect();

        assert_eq!(items_dataloader, items_dataloader_slice);
    }
}