//! Multi-threaded batch data loader (`burn_core/data/dataloader/multithread.rs`).

1use burn_dataset::Dataset;
2use burn_dataset::transform::PartialDataset;
3use burn_tensor::backend::Backend;
4use rand::distr::{Distribution, StandardUniform};
5use rand::rngs::StdRng;
6use rand::{Rng, SeedableRng};
7
8use super::batcher::Batcher;
9use super::{BatchDataLoader, BatchStrategy, DataLoader, DataLoaderIterator, Progress};
10use std::sync::{Arc, OnceLock, mpsc};
11use std::thread;
12
/// Maximum number of messages buffered in the channel between the worker
/// threads and the consuming iterator; bounds memory usage and applies
/// backpressure to workers that outpace the consumer.
const MAX_QUEUED_ITEMS: usize = 100;

/// Seed type of [`StdRng`]; stored instead of the RNG itself so the RNG can be
/// re-created deterministically during lazy initialization (see `initialize`).
type RngSeed = <StdRng as SeedableRng>::Seed;
16
/// A multi-threaded data loader that can be used to iterate over a dataset.
///
/// The dataset is split into `num_threads` parts, each loaded by its own
/// [`BatchDataLoader`] running on a dedicated worker thread.
pub struct MultiThreadDataLoader<B: Backend, I, O> {
    // Configuration parameters needed for initialization.
    strategy: Box<dyn BatchStrategy<I>>,
    dataset: Arc<dyn Dataset<I>>,
    batcher: Arc<dyn Batcher<B, I, O>>,
    device: B::Device,
    // Seed derived from the user-provided RNG; `Some` enables shuffling.
    seed: Option<RngSeed>,
    num_threads: usize,

    // The lazily initialized data loaders, one per worker thread.
    dataloaders: OnceLock<Vec<BatchDataLoader<B, I, O>>>,
}
30
/// A message that can be sent between threads.
#[derive(Debug)]
pub enum Message<O> {
    /// A batch of items, tagged with the index of the worker that produced it
    /// and that worker's progress at the time the batch was created.
    Batch(usize, O, Progress),

    /// The thread is done producing batches.
    Done,
}
40
/// Iterator that merges the batches produced by the worker threads.
struct MultiThreadsDataloaderIterator<O> {
    // Number of workers that have reported `Message::Done` so far.
    num_done: usize,
    // Handles of the spawned worker threads; joined when iteration completes.
    workers: Vec<thread::JoinHandle<()>>,
    // Receiving end of the channel the workers send their messages on.
    receiver: mpsc::Receiver<Message<O>>,
    // Latest progress reported by each worker, indexed by worker id.
    progresses: Vec<Progress>,
}
47
48impl<B: Backend, I, O> MultiThreadDataLoader<B, I, O>
49where
50    I: Send + Sync + Clone + 'static,
51    O: Send + 'static,
52{
53    /// Creates a new multi-threaded batch data loader.
54    ///
55    /// # Arguments
56    ///
57    /// * `strategy` - The batch strategy.
58    /// * `dataset` - The dataset.
59    /// * `batcher` - The batcher.
60    /// * `num_threads` - The number of threads.
61    /// * `device`  - The device to use when loading a batch.
62    /// * `rng`     - The rng determining if the dataset is shuffled each time a dataloader
63    ///   iterator is created.
64    ///
65    /// # Returns
66    ///
67    /// The multi-threaded batch data loader.
68    pub fn new(
69        strategy: Box<dyn BatchStrategy<I>>,
70        dataset: Arc<dyn Dataset<I>>,
71        batcher: Arc<dyn Batcher<B, I, O>>,
72        num_threads: usize,
73        device: B::Device,
74        rng: Option<rand::rngs::StdRng>,
75    ) -> Self {
76        let mut seed = None;
77        if let Some(mut rng) = rng {
78            // RNG stream splitting (not state cloning): derive a new seed from the RNG's output.
79            // This is exactly what `rng.fork()` does.
80            let mut s = RngSeed::default();
81            rng.fill_bytes(&mut s);
82
83            seed = Some(s);
84        }
85        Self::from_seed(strategy, dataset, batcher, num_threads, device, seed)
86    }
87
88    fn from_seed(
89        strategy: Box<dyn BatchStrategy<I>>,
90        dataset: Arc<dyn Dataset<I>>,
91        batcher: Arc<dyn Batcher<B, I, O>>,
92        num_threads: usize,
93        device: B::Device,
94        seed: Option<RngSeed>,
95    ) -> Self {
96        Self {
97            strategy,
98            dataset,
99            batcher,
100            num_threads,
101            device,
102            seed,
103            dataloaders: OnceLock::new(),
104        }
105    }
106
107    /// Force initialization if needed.
108    fn initialize(&self) -> &[BatchDataLoader<B, I, O>] {
109        self.dataloaders
110            .get_or_init(|| {
111                let mut dataset = self.dataset.clone();
112                if let Some(seed) = self.seed.as_ref() {
113                    // Pre-shuffle the dataset before split if shuffle is enabled.
114                    // This ensures that each thread gets a uniform random sample of the dataset.
115                    let mut rng = StdRng::from_seed(*seed);
116                    dataset = Arc::new(burn_dataset::transform::ShuffledDataset::new(
117                        dataset, &mut rng,
118                    ));
119                }
120
121                let datasets = match self.strategy.batch_size() {
122                    Some(batch_size) => {
123                        PartialDataset::split_chunks(dataset, self.num_threads, batch_size)
124                    }
125                    None => PartialDataset::split(dataset, self.num_threads),
126                };
127
128                // Create more rngs from the first one, one for each new dataloader.
129                let mut rng = self.seed.map(StdRng::from_seed);
130                let rngs = (0..self.num_threads).map(|_| {
131                    rng.as_mut().map(|rng| {
132                        StdRng::seed_from_u64(Distribution::sample(&StandardUniform, rng))
133                    })
134                });
135
136                datasets
137                    .into_iter()
138                    .zip(rngs)
139                    .map(|(dataset, rng)| {
140                        let strategy = self.strategy.clone_dyn();
141                        BatchDataLoader::new(
142                            strategy,
143                            Arc::new(dataset),
144                            self.batcher.clone(),
145                            self.device.clone(),
146                            rng,
147                        )
148                    })
149                    .collect()
150            })
151            .as_ref()
152    }
153}
154
155impl<B: Backend, I, O> DataLoader<B, O> for MultiThreadDataLoader<B, I, O>
156where
157    I: Send + Sync + Clone + 'static,
158    O: Send + 'static + std::fmt::Debug,
159{
160    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
161        // This will initialize the loader if it hasn't been initialized yet
162        let dataloaders = self.initialize();
163
164        let (sender, receiver) = mpsc::sync_channel::<Message<O>>(MAX_QUEUED_ITEMS);
165
166        let mut progresses = Vec::with_capacity(dataloaders.len());
167
168        let handlers: Vec<_> = dataloaders
169            .iter()
170            .enumerate()
171            .map(|(index, dataloader)| {
172                let dataloader_cloned = dataloader.clone();
173                let sender_cloned = sender.clone();
174                progresses.push(Progress::new(0, dataloader_cloned.num_items()));
175
176                std::thread::Builder::new()
177                    .name(std::format!("dataloader-{index}"))
178                    .spawn(move || {
179                        let mut iterator = dataloader_cloned.iter();
180                        while let Some(item) = iterator.next() {
181                            let progress = iterator.progress();
182
183                            match sender_cloned.send(Message::Batch(index, item, progress)) {
184                                Ok(_) => {}
185                                // The receiver is probably gone, no need to panic, just need to stop
186                                // iterating.
187                                Err(_) => return,
188                            };
189                        }
190                        // Same thing.
191                        sender_cloned.send(Message::Done).ok();
192                    })
193                    .unwrap()
194            })
195            .collect();
196
197        Box::new(MultiThreadsDataloaderIterator::new(
198            receiver, handlers, progresses,
199        ))
200    }
201
202    fn num_items(&self) -> usize {
203        // For num_items, we can directly use the dataset size without
204        // necessarily initializing the full loader
205        self.dataset.len()
206    }
207
208    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
209        Arc::new(Self::from_seed(
210            self.strategy.clone_dyn(),
211            self.dataset.clone(),
212            self.batcher.clone(),
213            self.num_threads,
214            device.clone(),
215            self.seed,
216        ))
217    }
218
219    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
220        let dataloader = Self::from_seed(
221            self.strategy.clone_dyn(),
222            Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
223            self.batcher.clone(),
224            self.num_threads,
225            self.device.clone(),
226            self.seed,
227        );
228        Arc::new(dataloader)
229    }
230}
231
232impl<O> MultiThreadsDataloaderIterator<O> {
233    pub fn new(
234        receiver: mpsc::Receiver<Message<O>>,
235        workers: Vec<thread::JoinHandle<()>>,
236        progresses: Vec<Progress>,
237    ) -> Self {
238        MultiThreadsDataloaderIterator {
239            num_done: 0,
240            workers,
241            receiver,
242            progresses,
243        }
244    }
245}
246impl<O: std::fmt::Debug> DataLoaderIterator<O> for MultiThreadsDataloaderIterator<O> {
247    fn progress(&self) -> Progress {
248        let mut items_total = 0;
249        let mut items_processed = 0;
250
251        for progress in self.progresses.iter() {
252            items_total += progress.items_total;
253            items_processed += progress.items_processed;
254        }
255
256        Progress::new(items_processed, items_total)
257    }
258}
259
260impl<O: std::fmt::Debug> Iterator for MultiThreadsDataloaderIterator<O> {
261    type Item = O;
262
263    fn next(&mut self) -> Option<O> {
264        if self.workers.is_empty() {
265            return None;
266        }
267
268        loop {
269            let item = self.receiver.recv();
270            let item = item.unwrap();
271
272            match item {
273                Message::Batch(index, item, progress) => {
274                    if let Some(current) = self.progresses.get_mut(index) {
275                        *current = progress;
276                    }
277                    return Some(item);
278                }
279                Message::Done => {
280                    self.num_done += 1;
281                }
282            };
283
284            if self.num_done == self.workers.len() {
285                while let Some(worker) = self.workers.pop() {
286                    worker.join().unwrap();
287                }
288                return None;
289            }
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;
    use burn_dataset::InMemDataset;
    use std::collections::HashSet;

    // The multi-threaded loader must yield exactly the same set of items as a
    // single-threaded loader over the same dataset (order may differ).
    #[test]
    fn test_multi_thread_batch_dataloader() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader_single_thread = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher.clone(),
            Default::default(),
            None,
        );
        let dataloader_multi_thread = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset,
            batcher,
            4,
            Default::default(),
            None,
        );

        let mut items_single_thread = HashSet::new();
        let mut items_multi_thread = HashSet::new();

        for items in dataloader_single_thread.iter() {
            for item in items {
                items_single_thread.insert(item);
            }
        }

        for items in dataloader_multi_thread.iter() {
            for item in items {
                items_multi_thread.insert(item);
            }
        }

        assert_eq!(items_single_thread, items_multi_thread);
    }

    // With a deliberately ordered dataset, shuffling (or its absence) is
    // observable through the composition of each batch.
    #[test]
    fn test_multi_thread_batch_dataloader_shuffle() {
        let num_classes = 2;
        let class_size = 100;
        let batch_size = 10;

        // Items is a deliberately ordered dataset.
        let mut items = Vec::new();
        for class in 0..num_classes {
            items.extend(vec![class; class_size]);
        }

        {
            // Unshuffled multithreaded loader
            let dataset = Arc::new(InMemDataset::new(items.clone()));
            let batcher = Arc::new(TestBatcher::new());

            let loader = MultiThreadDataLoader::new(
                Box::new(FixBatchStrategy::new(batch_size)),
                dataset,
                batcher,
                num_classes,
                Default::default(),
                // No rng means no shuffling.
                None,
            );

            for batch in loader.iter() {
                let mut batch_items = HashSet::new();
                for item in batch {
                    batch_items.insert(item);
                }

                // Since the dataset is not shuffled, we expect each batch to contain the same item.
                assert_eq!(batch_items.len(), 1);
            }
        }

        {
            // Shuffled multithreaded loader
            let dataset = Arc::new(InMemDataset::new(items.clone()));
            let batcher = Arc::new(TestBatcher::new());

            let loader = MultiThreadDataLoader::new(
                Box::new(FixBatchStrategy::new(batch_size)),
                dataset.clone(),
                batcher.clone(),
                num_classes,
                Default::default(),
                // The rng enables shuffling.
                Some(StdRng::seed_from_u64(42)),
            );

            for batch in loader.iter() {
                let mut batch_items = HashSet::new();
                for item in batch {
                    batch_items.insert(item);
                }

                // Since the dataset is shuffled, we expect to see all items.
                // NOTE(review): this assumes seed 42 mixes the two classes into
                // every batch; a different seed could make this flaky.
                assert_eq!(batch_items.len(), num_classes);
            }
        }
    }

    // 27 items with batch size 5 leaves an incomplete final batch per split;
    // the multi-threaded loader must produce the same number of batches and
    // the same batch contents as the single-threaded one.
    #[test]
    fn test_multi_thread_batch_dataloader_incomplete_batches() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader_single_thread = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher.clone(),
            Default::default(),
            None,
        );
        let dataloader_multi_thread = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset,
            batcher,
            4,
            Default::default(),
            None,
        );

        let mut items_single_thread = HashSet::new();
        let mut items_multi_thread = HashSet::new();

        let mut single_thread_cnt = 0;
        let mut multi_thread_cnt = 0;
        for items in dataloader_single_thread.iter() {
            items_single_thread.insert(items);
            single_thread_cnt += 1;
        }

        for items in dataloader_multi_thread.iter() {
            items_multi_thread.insert(items);
            multi_thread_cnt += 1;
        }

        assert_eq!(single_thread_cnt, multi_thread_cnt);
        assert_eq!(items_single_thread, items_multi_thread);
    }
}