
burn_core/data/dataloader/multithread.rs

use burn_dataset::Dataset;
use burn_dataset::transform::PartialDataset;
use burn_tensor::backend::Backend;
use rand::SeedableRng;
use rand::distr::{Distribution, StandardUniform};
use rand::rngs::StdRng;

use super::batcher::Batcher;
use super::{BatchDataLoader, BatchStrategy, DataLoader, DataLoaderIterator, Progress};
use std::sync::{Arc, OnceLock, mpsc};
use std::thread;

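/// The maximum number of batches that can sit in the channel between the worker threads
/// and the consuming iterator before the workers block.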
const MAX_QUEUED_ITEMS: usize = 100;

/// A multi-threaded data loader that can be used to iterate over a dataset.
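///
/// The dataset is split into `num_threads` partial datasets, each served by its own
/// `BatchDataLoader` running on a dedicated thread; the resulting batches are merged
/// back into a single iterator through a bounded channel.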
pub struct MultiThreadDataLoader<B: Backend, I, O> {
    // Configuration parameters needed for initialization
    strategy: Box<dyn BatchStrategy<I>>,
    dataset: Arc<dyn Dataset<I>>,
    batcher: Arc<dyn Batcher<B, I, O>>,
    device: B::Device,
    rng: Option<rand::rngs::StdRng>,
    num_threads: usize,

    // The lazily initialized data loaders
    dataloaders: OnceLock<Vec<BatchDataLoader<B, I, O>>>,
}

/// A message sent from a worker thread to the consuming iterator.
///
/// Each worker sends `Message::Batch` for every batch it produces and a single
/// `Message::Done` once its partial dataset is exhausted.
#[derive(Debug)]
pub enum Message<O> {
    /// A batch of items.
    Batch(usize, O, Progress),

    /// The thread is done.
    Done,
}

/// The iterator returned by `MultiThreadDataLoader::iter`, which merges the batches
/// received from the worker threads.
struct MultiThreadsDataloaderIterator<O> {
    num_done: usize,
    workers: Vec<thread::JoinHandle<()>>,
    receiver: mpsc::Receiver<Message<O>>,
    progresses: Vec<Progress>,
}

impl<B: Backend, I, O> MultiThreadDataLoader<B, I, O>
where
    I: Send + Sync + Clone + 'static,
    O: Send + 'static,
{
    /// Creates a new multi-threaded batch data loader.
    ///
    /// # Arguments
    ///
    /// * `strategy` - The batch strategy.
    /// * `dataset` - The dataset.
    /// * `batcher` - The batcher.
    /// * `num_threads` - The number of threads.
    /// * `device` - The device to use when loading a batch.
    /// * `rng` - An optional RNG; when provided, the dataset is shuffled each time a
    ///   dataloader iterator is created.
    ///
    /// # Returns
    ///
    /// The multi-threaded batch data loader.
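    ///
    /// # Examples
    ///
    /// A minimal sketch (not compiled as a doc test); `MyBackend` and `MyBatcher` are
    /// placeholders for a concrete `Backend` and a `Batcher` implementation.
    ///
    /// ```rust,ignore
    /// let strategy = Box::new(FixBatchStrategy::new(32));
    /// let dataset = Arc::new(InMemDataset::new(vec![0, 1, 2, 3]));
    /// let batcher = Arc::new(MyBatcher::default());
    /// let device = Default::default();
    ///
    /// // Spread the work across 4 threads; `Some(rng)` enables shuffling on every `iter()` call.
    /// let loader = MultiThreadDataLoader::<MyBackend, _, _>::new(
    ///     strategy,
    ///     dataset,
    ///     batcher,
    ///     4,
    ///     device,
    ///     Some(StdRng::seed_from_u64(42)),
    /// );
    ///
    /// for batch in loader.iter() {
    ///     // Consume the batches produced by the worker threads.
    /// }
    /// ```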
    pub fn new(
        strategy: Box<dyn BatchStrategy<I>>,
        dataset: Arc<dyn Dataset<I>>,
        batcher: Arc<dyn Batcher<B, I, O>>,
        num_threads: usize,
        device: B::Device,
        rng: Option<rand::rngs::StdRng>,
    ) -> Self {
        Self {
            strategy,
            dataset,
            batcher,
            num_threads,
            device,
            rng,
            dataloaders: OnceLock::new(),
        }
    }

    /// Lazily initializes the per-thread data loaders on first use and returns them.
    fn initialize(&self) -> &[BatchDataLoader<B, I, O>] {
        self.dataloaders
            .get_or_init(|| {
                let mut dataset = self.dataset.clone();
                if let Some(rng) = self.rng.as_ref() {
                    // Pre-shuffle the dataset before split if shuffle is enabled.
                    // This ensures that each thread gets a uniform random sample of the dataset.
                    let mut rng = rng.clone();
                    dataset = Arc::new(burn_dataset::transform::ShuffledDataset::new(
                        dataset, &mut rng,
                    ));
                }

                let datasets = match self.strategy.batch_size() {
                    Some(batch_size) => {
                        PartialDataset::split_chunks(dataset, self.num_threads, batch_size)
                    }
                    None => PartialDataset::split(dataset, self.num_threads),
                };

                // Create more rngs from the first one, one for each new dataloader.
                let mut rng = self.rng.clone();
                let rngs = (0..self.num_threads).map(|_| {
                    rng.as_mut().map(|rng| {
                        StdRng::seed_from_u64(Distribution::sample(&StandardUniform, rng))
                    })
                });

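                // Pair each partial dataset with its own rng and wrap it in a single-threaded
                // `BatchDataLoader`; `iter()` later runs each of these loaders on its own thread.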
                datasets
                    .into_iter()
                    .zip(rngs)
                    .map(|(dataset, rng)| {
                        let strategy = self.strategy.clone_dyn();
                        BatchDataLoader::new(
                            strategy,
                            Arc::new(dataset),
                            self.batcher.clone(),
                            self.device.clone(),
                            rng,
                        )
                    })
                    .collect()
            })
            .as_ref()
    }
}

impl<B: Backend, I, O> DataLoader<B, O> for MultiThreadDataLoader<B, I, O>
where
    I: Send + Sync + Clone + 'static,
    O: Send + 'static + std::fmt::Debug,
{
    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
        // This will initialize the loader if it hasn't been initialized yet
        let dataloaders = self.initialize();

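        // A bounded channel: once MAX_QUEUED_ITEMS batches are waiting, workers block on
        // `send`, which applies backpressure instead of buffering the whole dataset in memory.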
        let (sender, receiver) = mpsc::sync_channel::<Message<O>>(MAX_QUEUED_ITEMS);

        let mut progresses = Vec::with_capacity(dataloaders.len());

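        // Spawn one worker thread per single-threaded data loader; each worker tags its
        // batches with its index so per-worker progress can be tracked.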
        let handlers: Vec<_> = dataloaders
            .iter()
            .enumerate()
            .map(|(index, dataloader)| {
                let dataloader_cloned = dataloader.clone();
                let sender_cloned = sender.clone();
                progresses.push(Progress::new(0, dataloader_cloned.num_items()));

                thread::spawn(move || {
                    let mut iterator = dataloader_cloned.iter();
                    while let Some(item) = iterator.next() {
                        let progress = iterator.progress();

                        match sender_cloned.send(Message::Batch(index, item, progress)) {
                            Ok(_) => {}
                            // The receiver is probably gone, no need to panic, just need to stop
                            // iterating.
                            Err(_) => return,
                        };
                    }
                    // As above: if the receiver is gone, ignore the error and let the thread end.
                    sender_cloned.send(Message::Done).ok();
                })
            })
            .collect();

        Box::new(MultiThreadsDataloaderIterator::new(
            receiver, handlers, progresses,
        ))
    }

    fn num_items(&self) -> usize {
        // For num_items, we can directly use the dataset size without
        // necessarily initializing the full loader
        self.dataset.len()
    }

    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
        Arc::new(Self::new(
            self.strategy.clone_dyn(),
            self.dataset.clone(),
            self.batcher.clone(),
            self.num_threads,
            device.clone(),
            self.rng.clone(),
        ))
    }

    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
        let dataloader = Self::new(
            self.strategy.clone_dyn(),
            Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
            self.batcher.clone(),
            self.num_threads,
            self.device.clone(),
            self.rng.clone(),
        );
        Arc::new(dataloader)
    }
}

impl<O> MultiThreadsDataloaderIterator<O> {
    pub fn new(
        receiver: mpsc::Receiver<Message<O>>,
        workers: Vec<thread::JoinHandle<()>>,
        progresses: Vec<Progress>,
    ) -> Self {
        MultiThreadsDataloaderIterator {
            num_done: 0,
            workers,
            receiver,
            progresses,
        }
    }
}
impl<O: std::fmt::Debug> DataLoaderIterator<O> for MultiThreadsDataloaderIterator<O> {
    fn progress(&self) -> Progress {
        let mut items_total = 0;
        let mut items_processed = 0;

        for progress in self.progresses.iter() {
            items_total += progress.items_total;
            items_processed += progress.items_processed;
        }

        Progress::new(items_processed, items_total)
    }
}

impl<O: std::fmt::Debug> Iterator for MultiThreadsDataloaderIterator<O> {
    type Item = O;

    fn next(&mut self) -> Option<O> {
        if self.workers.is_empty() {
            return None;
        }

        loop {
            // Block until a worker produces a message. `recv` only fails if every worker
            // hung up without sending `Done` (which normally means a worker panicked), so
            // unwrapping surfaces that failure instead of silently ending the iteration.
            let item = self.receiver.recv();
            let item = item.unwrap();

            match item {
                Message::Batch(index, item, progress) => {
                    // Record the latest progress reported by this worker, then yield the batch.
                    if let Some(current) = self.progresses.get_mut(index) {
                        *current = progress;
                    }
                    return Some(item);
                }
                Message::Done => {
                    self.num_done += 1;
                }
            };

            // Once every worker has reported `Done`, join the threads and end the iteration.
            if self.num_done == self.workers.len() {
                while let Some(worker) = self.workers.pop() {
                    worker.join().unwrap();
                }
                return None;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;
    use burn_dataset::InMemDataset;
    use std::collections::HashSet;

    #[test]
    fn test_multi_thread_batch_dataloader() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader_single_thread = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher.clone(),
            Default::default(),
            None,
        );
        let dataloader_multi_thread = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset,
            batcher,
            4,
            Default::default(),
            None,
        );

        let mut items_single_thread = HashSet::new();
        let mut items_multi_thread = HashSet::new();

        for items in dataloader_single_thread.iter() {
            for item in items {
                items_single_thread.insert(item);
            }
        }

        for items in dataloader_multi_thread.iter() {
            for item in items {
                items_multi_thread.insert(item);
            }
        }

        assert_eq!(items_single_thread, items_multi_thread);
    }

    #[test]
    fn test_multi_thread_batch_dataloader_shuffle() {
        let num_classes = 2;
        let class_size = 100;
        let batch_size = 10;

        // `items` is a deliberately ordered dataset: all items of a class are grouped together.
        let mut items = Vec::new();
        for class in 0..num_classes {
            items.extend(vec![class; class_size]);
        }

        {
            // Unshuffled multithreaded loader
            let dataset = Arc::new(InMemDataset::new(items.clone()));
            let batcher = Arc::new(TestBatcher::new());

            let loader = MultiThreadDataLoader::new(
                Box::new(FixBatchStrategy::new(batch_size)),
                dataset,
                batcher,
                num_classes,
                Default::default(),
                // No rng means no shuffling.
                None,
            );

            for batch in loader.iter() {
                let mut batch_items = HashSet::new();
                for item in batch {
                    batch_items.insert(item);
                }

                // Since the dataset is not shuffled, we expect each batch to contain the same item.
                assert_eq!(batch_items.len(), 1);
            }
        }

        {
            // Shuffled multithreaded loader
            let dataset = Arc::new(InMemDataset::new(items.clone()));
            let batcher = Arc::new(TestBatcher::new());

            let loader = MultiThreadDataLoader::new(
                Box::new(FixBatchStrategy::new(batch_size)),
                dataset.clone(),
                batcher.clone(),
                num_classes,
                Default::default(),
                // The rng enables shuffling.
                Some(StdRng::seed_from_u64(42)),
            );

            for batch in loader.iter() {
                let mut batch_items = HashSet::new();
                for item in batch {
                    batch_items.insert(item);
                }

                // Since the dataset is shuffled, we expect each batch to contain items from every class.
                assert_eq!(batch_items.len(), num_classes);
            }
        }
    }

    #[test]
    fn test_multi_thread_batch_dataloader_incomplete_batches() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader_single_thread = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher.clone(),
            Default::default(),
            None,
        );
        let dataloader_multi_thread = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset,
            batcher,
            4,
            Default::default(),
            None,
        );

        let mut items_single_thread = HashSet::new();
        let mut items_multi_thread = HashSet::new();

        let mut single_thread_cnt = 0;
        let mut multi_thread_cnt = 0;
        for items in dataloader_single_thread.iter() {
            items_single_thread.insert(items);
            single_thread_cnt += 1;
        }

        for items in dataloader_multi_thread.iter() {
            items_multi_thread.insert(items);
            multi_thread_cnt += 1;
        }

        assert_eq!(single_thread_cnt, multi_thread_cnt);
        assert_eq!(items_single_thread, items_multi_thread);
    }
}