// burn_core/data/dataloader/multithread.rs

1use burn_dataset::Dataset;
2use burn_dataset::transform::PartialDataset;
3use burn_tensor::backend::Backend;
4use rand::SeedableRng;
5use rand::distr::{Distribution, StandardUniform};
6use rand::rngs::StdRng;
7
8use super::batcher::Batcher;
9use super::{BatchDataLoader, BatchStrategy, DataLoader, DataLoaderIterator, Progress};
10use core::cell::OnceCell;
11use std::sync::{Arc, mpsc};
12use std::thread;
13
14const MAX_QUEUED_ITEMS: usize = 100;
15
/// A multi-threaded data loader that can be used to iterate over a dataset.
///
/// The dataset is split into `num_threads` partial datasets, each served by its
/// own [`BatchDataLoader`] on a dedicated worker thread (see `iter`).
pub struct MultiThreadDataLoader<B: Backend, I, O> {
    // Configuration parameters needed for initialization
    /// Strategy deciding how items are grouped into batches; cloned per worker.
    strategy: Box<dyn BatchStrategy<I>>,
    /// The full dataset; split across workers on first use.
    dataset: Arc<dyn Dataset<I>>,
    /// Converts a batch of `I` items into an `O` on the given device.
    batcher: Arc<dyn Batcher<B, I, O>>,
    /// Device used when loading a batch.
    device: B::Device,
    /// Optional rng used to derive one seed per worker; `None` disables shuffling.
    rng: Option<rand::rngs::StdRng>,
    /// Number of worker threads (and dataset splits).
    num_threads: usize,

    // The lazily initialized data loaders (one per thread), built on first `iter()`.
    dataloaders: OnceCell<Vec<BatchDataLoader<B, I, O>>>,
}
29
/// A message that can be sent between threads.
///
/// Workers send these over an mpsc channel back to the consuming iterator.
#[derive(Debug)]
pub enum Message<O> {
    /// A batch of items: the sending worker's index, the batch itself, and
    /// that worker's progress at the time the batch was produced.
    Batch(usize, O, Progress),

    /// The thread is done.
    Done,
}
39
/// Iterator that merges batches produced by several worker threads.
struct MultiThreadsDataloaderIterator<O> {
    // Number of workers that have sent `Message::Done` so far.
    num_done: usize,
    // Join handles of the worker threads; drained once all workers are done.
    workers: Vec<thread::JoinHandle<()>>,
    // Receiving end of the channel the workers send `Message`s on.
    receiver: mpsc::Receiver<Message<O>>,
    // Latest progress reported by each worker, indexed by worker id.
    progresses: Vec<Progress>,
}
46
47impl<B: Backend, I, O> MultiThreadDataLoader<B, I, O>
48where
49    I: Send + Sync + Clone + 'static,
50    O: Send + 'static,
51{
52    /// Creates a new multi-threaded batch data loader.
53    ///
54    /// # Arguments
55    ///
56    /// * `strategy` - The batch strategy.
57    /// * `dataset` - The dataset.
58    /// * `batcher` - The batcher.
59    /// * `num_threads` - The number of threads.
60    /// * `device`  - The device to use when loading a batch.
61    /// * `rng`     - The rng determining if the dataset is shuffled each time a dataloader
62    ///   iterator is created.
63    ///
64    /// # Returns
65    ///
66    /// The multi-threaded batch data loader.
67    pub fn new(
68        strategy: Box<dyn BatchStrategy<I>>,
69        dataset: Arc<dyn Dataset<I>>,
70        batcher: Arc<dyn Batcher<B, I, O>>,
71        num_threads: usize,
72        device: B::Device,
73        rng: Option<rand::rngs::StdRng>,
74    ) -> Self {
75        Self {
76            strategy,
77            dataset,
78            batcher,
79            num_threads,
80            device,
81            rng,
82            dataloaders: OnceCell::new(),
83        }
84    }
85
86    /// Force initialization if needed.
87    fn initialize(&self) -> &[BatchDataLoader<B, I, O>] {
88        self.dataloaders
89            .get_or_init(|| {
90                let datasets = PartialDataset::split(self.dataset.clone(), self.num_threads);
91
92                // Create more rngs from the first one, one for each new dataloader.
93                let mut rng = self.rng.clone();
94                let rngs = (0..self.num_threads).map(|_| {
95                    rng.as_mut().map(|rng| {
96                        StdRng::seed_from_u64(Distribution::sample(&StandardUniform, rng))
97                    })
98                });
99
100                datasets
101                    .into_iter()
102                    .zip(rngs)
103                    .map(|(dataset, rng)| {
104                        let strategy = self.strategy.clone_dyn();
105                        BatchDataLoader::new(
106                            strategy,
107                            Arc::new(dataset),
108                            self.batcher.clone(),
109                            self.device.clone(),
110                            rng,
111                        )
112                    })
113                    .collect()
114            })
115            .as_ref()
116    }
117}
118
119impl<B: Backend, I, O> DataLoader<B, O> for MultiThreadDataLoader<B, I, O>
120where
121    I: Send + Sync + Clone + 'static,
122    O: Send + 'static + std::fmt::Debug,
123{
124    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
125        // This will initialize the loader if it hasn't been initialized yet
126        let dataloaders = self.initialize();
127
128        let (sender, receiver) = mpsc::sync_channel::<Message<O>>(MAX_QUEUED_ITEMS);
129
130        let mut progresses = Vec::with_capacity(dataloaders.len());
131
132        let handlers: Vec<_> = dataloaders
133            .iter()
134            .enumerate()
135            .map(|(index, dataloader)| {
136                let dataloader_cloned = dataloader.clone();
137                let sender_cloned = sender.clone();
138                progresses.push(Progress::new(0, dataloader_cloned.num_items()));
139
140                thread::spawn(move || {
141                    let mut iterator = dataloader_cloned.iter();
142                    while let Some(item) = iterator.next() {
143                        let progress = iterator.progress();
144
145                        match sender_cloned.send(Message::Batch(index, item, progress)) {
146                            Ok(_) => {}
147                            // The receiver is probably gone, no need to panic, just need to stop
148                            // iterating.
149                            Err(_) => return,
150                        };
151                    }
152                    // Same thing.
153                    sender_cloned.send(Message::Done).ok();
154                })
155            })
156            .collect();
157
158        Box::new(MultiThreadsDataloaderIterator::new(
159            receiver, handlers, progresses,
160        ))
161    }
162
163    fn num_items(&self) -> usize {
164        // For num_items, we can directly use the dataset size without
165        // necessarily initializing the full loader
166        self.dataset.len()
167    }
168
169    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
170        Arc::new(Self::new(
171            self.strategy.clone_dyn(),
172            self.dataset.clone(),
173            self.batcher.clone(),
174            self.num_threads,
175            device.clone(),
176            self.rng.clone(),
177        ))
178    }
179
180    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
181        let dataloader = Self::new(
182            self.strategy.clone_dyn(),
183            Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
184            self.batcher.clone(),
185            self.num_threads,
186            self.device.clone(),
187            self.rng.clone(),
188        );
189        Arc::new(dataloader)
190    }
191}
192
193impl<O> MultiThreadsDataloaderIterator<O> {
194    pub fn new(
195        receiver: mpsc::Receiver<Message<O>>,
196        workers: Vec<thread::JoinHandle<()>>,
197        progresses: Vec<Progress>,
198    ) -> Self {
199        MultiThreadsDataloaderIterator {
200            num_done: 0,
201            workers,
202            receiver,
203            progresses,
204        }
205    }
206}
207impl<O: std::fmt::Debug> DataLoaderIterator<O> for MultiThreadsDataloaderIterator<O> {
208    fn progress(&self) -> Progress {
209        let mut items_total = 0;
210        let mut items_processed = 0;
211
212        for progress in self.progresses.iter() {
213            items_total += progress.items_total;
214            items_processed += progress.items_processed;
215        }
216
217        Progress::new(items_processed, items_total)
218    }
219}
220
221impl<O: std::fmt::Debug> Iterator for MultiThreadsDataloaderIterator<O> {
222    type Item = O;
223
224    fn next(&mut self) -> Option<O> {
225        if self.workers.is_empty() {
226            return None;
227        }
228
229        loop {
230            let item = self.receiver.recv();
231            let item = item.unwrap();
232
233            match item {
234                Message::Batch(index, item, progress) => {
235                    if let Some(current) = self.progresses.get_mut(index) {
236                        *current = progress;
237                    }
238                    return Some(item);
239                }
240                Message::Done => {
241                    self.num_done += 1;
242                }
243            };
244
245            if self.num_done == self.workers.len() {
246                while let Some(worker) = self.workers.pop() {
247                    worker.join().unwrap();
248                }
249                return None;
250            }
251        }
252    }
253}
254
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;

    /// The multi-threaded loader must yield exactly the same set of items as
    /// a single-threaded loader over the same dataset, even though thread
    /// scheduling may reorder the batches.
    #[test]
    fn test_multi_thread_batch_dataloader() {
        let batcher = Arc::new(TestBatcher::new());
        // 27 items with batch size 5 exercises a ragged final batch.
        let dataset = Arc::new(FakeDataset::<String>::new(27));

        let single_threaded = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher.clone(),
            Default::default(),
            None,
        );
        let multi_threaded = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset,
            batcher,
            4,
            Default::default(),
            None,
        );

        let mut seen_single = HashSet::new();
        for batch in single_threaded.iter() {
            seen_single.extend(batch);
        }

        let mut seen_multi = HashSet::new();
        for batch in multi_threaded.iter() {
            seen_multi.extend(batch);
        }

        // Compare as sets: ordering across threads is nondeterministic.
        assert_eq!(seen_single, seen_multi);
    }
}