// Source: burn_core/data/dataloader/batch.rs

1use super::{BatchStrategy, DataLoader, DataLoaderIterator, Progress, batcher::Batcher};
2use burn_dataset::{
3    Dataset,
4    transform::{PartialDataset, ShuffledDataset},
5};
6use burn_tensor::backend::Backend;
7use rand::SeedableRng;
8use std::ops::DerefMut;
9use std::sync::Arc;
10
11/// A data loader that can be used to iterate over a dataset in batches.
12pub struct BatchDataLoader<B: Backend, I, O> {
13    strategy: Box<dyn BatchStrategy<I>>,
14    dataset: Arc<dyn Dataset<I>>,
15    batcher: Arc<dyn Batcher<B, I, O>>,
16    device: B::Device,
17    rng: Option<Arc<spin::Mutex<rand::rngs::StdRng>>>,
18}
19
20impl<B: Backend, I, O> Clone for BatchDataLoader<B, I, O> {
21    fn clone(&self) -> Self {
22        Self {
23            strategy: self.strategy.clone_dyn(),
24            dataset: self.dataset.clone(),
25            batcher: self.batcher.clone(),
26            device: self.device.clone(),
27            rng: self.rng.clone(),
28        }
29    }
30}
31
32impl<B: Backend, I, O> BatchDataLoader<B, I, O> {
33    /// Creates a new batch data loader.
34    ///
35    /// # Arguments
36    ///
37    /// * `strategy` - The batch strategy.
38    /// * `dataset` - The dataset.
39    /// * `batcher` - The batcher.
40    /// * `device`  - The device to use when loading a batch.
41    /// * `rng`     - The rng determining if the dataset is shuffled each time a dataloader
42    ///   iterator is created.
43    ///
44    /// # Returns
45    ///
46    /// The batch data loader.
47    pub fn new(
48        strategy: Box<dyn BatchStrategy<I>>,
49        dataset: Arc<dyn Dataset<I>>,
50        batcher: Arc<dyn Batcher<B, I, O>>,
51        device: B::Device,
52        rng: Option<rand::rngs::StdRng>,
53    ) -> Self {
54        Self {
55            strategy,
56            dataset,
57            batcher,
58            device,
59            rng: rng.map(|rng| Arc::new(spin::Mutex::new(rng))),
60        }
61    }
62}
63
64/// A data loader iterator that can be used to iterate over a data loader.
65struct BatchDataloaderIterator<B: Backend, I, O> {
66    current_index: usize,
67    strategy: Box<dyn BatchStrategy<I>>,
68    dataset: Arc<dyn Dataset<I>>,
69    batcher: Arc<dyn Batcher<B, I, O>>,
70    device: B::Device,
71}
72
73impl<B, I, O> DataLoader<B, O> for BatchDataLoader<B, I, O>
74where
75    B: Backend,
76    I: Send + Sync + Clone + 'static,
77    O: Send + 'static,
78{
79    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
80        // When starting a new iteration, we first check if the dataloader was created with an rng,
81        // implying that we should shuffle the dataset beforehand, while advancing the current
82        // rng to ensure that each new iteration shuffles the dataset differently.
83        let dataset = match &self.rng {
84            Some(rng) => Arc::new(ShuffledDataset::new(
85                self.dataset.clone(),
86                rng.lock().deref_mut(),
87            )),
88            None => self.dataset.clone(),
89        };
90        Box::new(BatchDataloaderIterator::new(
91            self.strategy.clone_dyn(),
92            dataset,
93            self.batcher.clone(),
94            self.device.clone(),
95        ))
96    }
97
98    fn num_items(&self) -> usize {
99        self.dataset.len()
100    }
101
102    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
103        let rng = self.rng.as_ref().map(|rng| {
104            let mut rng = rng.lock();
105            rng.fork()
106        });
107        Arc::new(Self::new(
108            self.strategy.clone_dyn(),
109            self.dataset.clone(),
110            self.batcher.clone(),
111            device.clone(),
112            rng,
113        ))
114    }
115
116    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
117        let rng = self.rng.as_ref().map(|rng| {
118            let mut rng = rng.lock();
119            rng.fork()
120        });
121        let dataloader = Self::new(
122            self.strategy.clone_dyn(),
123            Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
124            self.batcher.clone(),
125            self.device.clone(),
126            rng,
127        );
128        Arc::new(dataloader)
129    }
130}
131
132impl<B: Backend, I, O> BatchDataloaderIterator<B, I, O> {
133    /// Creates a new batch data loader iterator.
134    ///
135    /// # Arguments
136    ///
137    /// * `strategy` - The batch strategy.
138    /// * `dataset` - The dataset.
139    /// * `batcher` - The batcher.
140    /// * `device`  - The device to use when loading a batch.
141    ///
142    /// # Returns
143    ///
144    /// The batch data loader iterator.
145    pub fn new(
146        strategy: Box<dyn BatchStrategy<I>>,
147        dataset: Arc<dyn Dataset<I>>,
148        batcher: Arc<dyn Batcher<B, I, O>>,
149        device: B::Device,
150    ) -> Self {
151        BatchDataloaderIterator {
152            current_index: 0,
153            strategy,
154            dataset,
155            batcher,
156            device,
157        }
158    }
159}
160
161impl<B: Backend, I, O> Iterator for BatchDataloaderIterator<B, I, O> {
162    type Item = O;
163
164    fn next(&mut self) -> Option<O> {
165        while let Some(item) = self.dataset.get(self.current_index) {
166            self.current_index += 1;
167            self.strategy.add(item);
168
169            if let Some(items) = self.strategy.batch(false) {
170                return Some(self.batcher.batch(items, &self.device));
171            }
172        }
173
174        if let Some(items) = self.strategy.batch(true) {
175            return Some(self.batcher.batch(items, &self.device));
176        }
177
178        None
179    }
180}
181
182impl<B: Backend, I, O> DataLoaderIterator<O> for BatchDataloaderIterator<B, I, O> {
183    fn progress(&self) -> Progress {
184        Progress::new(self.current_index, self.dataset.len())
185    }
186}
187
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;

    /// Every item of the dataset must come out of the dataloader, and nothing else.
    #[test]
    fn test_batch_dataloader() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        );

        let items_dataset: HashSet<_> = dataset.iter().collect();
        let items_dataloader: HashSet<_> = dataloader.iter().flatten().collect();

        assert_eq!(items_dataset, items_dataloader);
    }

    /// A `5..15` slice must yield exactly the items found at positions 5 through 14
    /// of the full (unshuffled) loader.
    #[test]
    fn test_batch_dataloader_slice() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        );
        let dataloader_slice = dataloader.slice(5, 15);

        // Flatten the batches and keep only the items whose global position is in range.
        let items_dataloader: HashSet<_> = dataloader
            .iter()
            .flatten()
            .enumerate()
            .filter_map(|(position, item)| (5..15).contains(&position).then_some(item))
            .collect();

        let items_dataloader_slice: HashSet<_> = dataloader_slice.iter().flatten().collect();

        assert_eq!(items_dataloader, items_dataloader_slice);
    }
}