burn_core/data/dataloader/multithread.rs

use burn_dataset::Dataset;
use burn_dataset::transform::PartialDataset;
use burn_tensor::backend::Backend;
use rand::SeedableRng;
use rand::distr::{Distribution, StandardUniform};
use rand::rngs::StdRng;

use super::batcher::Batcher;
use super::{BatchDataLoader, BatchStrategy, DataLoader, DataLoaderIterator, Progress};
use std::sync::{Arc, OnceLock, mpsc};
use std::thread;

const MAX_QUEUED_ITEMS: usize = 100;

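/// A [`DataLoader`] that splits its dataset across several worker threads, each running its
/// own [`BatchDataLoader`] and streaming finished batches back to the consumer over a
/// bounded channel.
///
/// A minimal usage sketch, mirroring the test module at the bottom of this file
/// (`TestBatcher` and `FakeDataset` are test-only helpers, not part of the public API):
///
/// ```ignore
/// let dataloader = MultiThreadDataLoader::new(
///     Box::new(FixBatchStrategy::new(5)),       // batches of 5 items
///     Arc::new(FakeDataset::<String>::new(27)), // 27 fake items
///     Arc::new(TestBatcher::new()),
///     4,                  // number of worker threads
///     Default::default(), // target device
///     None,               // no RNG: no shuffling
/// );
///
/// for batch in dataloader.iter() {
///     // Batches arrive in whatever order the workers finish them.
/// }
/// ```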
pub struct MultiThreadDataLoader<B: Backend, I, O> {
    strategy: Box<dyn BatchStrategy<I>>,
    dataset: Arc<dyn Dataset<I>>,
    batcher: Arc<dyn Batcher<B, I, O>>,
    device: B::Device,
    rng: Option<rand::rngs::StdRng>,
    num_threads: usize,

    /// Per-worker loaders, built lazily on the first call to `iter`.
    dataloaders: OnceLock<Vec<BatchDataLoader<B, I, O>>>,
}

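/// Message sent from a worker thread to the consuming iterator: either a batch produced by
/// the worker at the given index (together with that worker's progress), or a signal that
/// the worker has exhausted its partition.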
#[derive(Debug)]
pub enum Message<O> {
    Batch(usize, O, Progress),

    Done,
}

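/// Iterator that merges the batches produced by all worker threads, yielding them in the
/// order they arrive on the channel.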
struct MultiThreadsDataloaderIterator<O> {
    num_done: usize,
    workers: Vec<thread::JoinHandle<()>>,
    receiver: mpsc::Receiver<Message<O>>,
    progresses: Vec<Progress>,
}

impl<B: Backend, I, O> MultiThreadDataLoader<B, I, O>
where
    I: Send + Sync + Clone + 'static,
    O: Send + 'static,
{
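    /// Creates a new multi-threaded batch data loader from a batching strategy, a dataset,
    /// a batcher, the number of worker threads, the target device, and an optional RNG used
    /// for shuffling and for seeding the per-worker loaders.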
    pub fn new(
        strategy: Box<dyn BatchStrategy<I>>,
        dataset: Arc<dyn Dataset<I>>,
        batcher: Arc<dyn Batcher<B, I, O>>,
        num_threads: usize,
        device: B::Device,
        rng: Option<rand::rngs::StdRng>,
    ) -> Self {
        Self {
            strategy,
            dataset,
            batcher,
            num_threads,
            device,
            rng,
            dataloaders: OnceLock::new(),
        }
    }

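    /// Lazily builds one [`BatchDataLoader`] per worker thread: the dataset is shuffled when
    /// an RNG is provided, split into `num_threads` partitions, and each partition receives
    /// its own RNG derived from the parent one so that runs stay reproducible.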
    fn initialize(&self) -> &[BatchDataLoader<B, I, O>] {
        self.dataloaders
            .get_or_init(|| {
                let mut dataset = self.dataset.clone();
                // Shuffle the whole dataset once before splitting it across workers.
                if let Some(rng) = self.rng.as_ref() {
                    let mut rng = rng.clone();
                    dataset = Arc::new(burn_dataset::transform::ShuffledDataset::new(
                        dataset, &mut rng,
                    ));
                }

                // Split the dataset so that every worker thread owns its own partition.
                let datasets = match self.strategy.batch_size() {
                    Some(batch_size) => {
                        PartialDataset::split_chunks(dataset, self.num_threads, batch_size)
                    }
                    None => PartialDataset::split(dataset, self.num_threads),
                };

                // Derive one RNG per worker from the parent RNG.
                let mut rng = self.rng.clone();
                let rngs = (0..self.num_threads).map(|_| {
                    rng.as_mut().map(|rng| {
                        StdRng::seed_from_u64(Distribution::sample(&StandardUniform, rng))
                    })
                });

                datasets
                    .into_iter()
                    .zip(rngs)
                    .map(|(dataset, rng)| {
                        let strategy = self.strategy.clone_dyn();
                        BatchDataLoader::new(
                            strategy,
                            Arc::new(dataset),
                            self.batcher.clone(),
                            self.device.clone(),
                            rng,
                        )
                    })
                    .collect()
            })
            .as_ref()
    }
}

impl<B: Backend, I, O> DataLoader<B, O> for MultiThreadDataLoader<B, I, O>
where
    I: Send + Sync + Clone + 'static,
    O: Send + 'static + std::fmt::Debug,
{
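    /// Spawns one worker thread per underlying loader; every worker streams its batches and
    /// progress through a bounded channel of capacity `MAX_QUEUED_ITEMS`, and the returned
    /// iterator drains that channel.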
    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
        let dataloaders = self.initialize();

        let (sender, receiver) = mpsc::sync_channel::<Message<O>>(MAX_QUEUED_ITEMS);

        let mut progresses = Vec::with_capacity(dataloaders.len());

        let handlers: Vec<_> = dataloaders
            .iter()
            .enumerate()
            .map(|(index, dataloader)| {
                let dataloader_cloned = dataloader.clone();
                let sender_cloned = sender.clone();
                progresses.push(Progress::new(0, dataloader_cloned.num_items()));

                thread::spawn(move || {
                    let mut iterator = dataloader_cloned.iter();
                    while let Some(item) = iterator.next() {
                        let progress = iterator.progress();

                        match sender_cloned.send(Message::Batch(index, item, progress)) {
                            Ok(_) => {}
                            // The receiver was dropped: stop producing batches.
                            Err(_) => return,
                        };
                    }
                    // Signal that this worker has produced all of its batches.
                    sender_cloned.send(Message::Done).ok();
                })
            })
            .collect();

        Box::new(MultiThreadsDataloaderIterator::new(
            receiver, handlers, progresses,
        ))
    }

    fn num_items(&self) -> usize {
        self.dataset.len()
    }

    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
        Arc::new(Self::new(
            self.strategy.clone_dyn(),
            self.dataset.clone(),
            self.batcher.clone(),
            self.num_threads,
            device.clone(),
            self.rng.clone(),
        ))
    }

    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
        let dataloader = Self::new(
            self.strategy.clone_dyn(),
            Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
            self.batcher.clone(),
            self.num_threads,
            self.device.clone(),
            self.rng.clone(),
        );
        Arc::new(dataloader)
    }
}

impl<O> MultiThreadsDataloaderIterator<O> {
    pub fn new(
        receiver: mpsc::Receiver<Message<O>>,
        workers: Vec<thread::JoinHandle<()>>,
        progresses: Vec<Progress>,
    ) -> Self {
        MultiThreadsDataloaderIterator {
            num_done: 0,
            workers,
            receiver,
            progresses,
        }
    }
}

impl<O: std::fmt::Debug> DataLoaderIterator<O> for MultiThreadsDataloaderIterator<O> {
    fn progress(&self) -> Progress {
        let mut items_total = 0;
        let mut items_processed = 0;

        for progress in self.progresses.iter() {
            items_total += progress.items_total;
            items_processed += progress.items_processed;
        }

        Progress::new(items_processed, items_total)
    }
}

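// The iterator blocks on the channel until a batch arrives and only terminates after every
// worker has sent `Message::Done` and been joined.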
impl<O: std::fmt::Debug> Iterator for MultiThreadsDataloaderIterator<O> {
    type Item = O;

    fn next(&mut self) -> Option<O> {
        if self.workers.is_empty() {
            return None;
        }

        loop {
            let item = self.receiver.recv();
            let item = item.unwrap();

            match item {
                Message::Batch(index, item, progress) => {
                    if let Some(current) = self.progresses.get_mut(index) {
                        *current = progress;
                    }
                    return Some(item);
                }
                Message::Done => {
                    self.num_done += 1;
                }
            };

            if self.num_done == self.workers.len() {
                while let Some(worker) = self.workers.pop() {
                    worker.join().unwrap();
                }
                return None;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;
    use burn_dataset::InMemDataset;
    use std::collections::HashSet;

    #[test]
    fn test_multi_thread_batch_dataloader() {
        // The multi-threaded loader must yield exactly the same items as the
        // single-threaded one, regardless of the order in which batches arrive.
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader_single_thread = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher.clone(),
            Default::default(),
            None,
        );
        let dataloader_multi_thread = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset,
            batcher,
            4,
            Default::default(),
            None,
        );

        let mut items_single_thread = HashSet::new();
        let mut items_multi_thread = HashSet::new();

        for items in dataloader_single_thread.iter() {
            for item in items {
                items_single_thread.insert(item);
            }
        }

        for items in dataloader_multi_thread.iter() {
            for item in items {
                items_multi_thread.insert(item);
            }
        }

        assert_eq!(items_single_thread, items_multi_thread);
    }

    #[test]
    fn test_multi_thread_batch_dataloader_shuffle() {
        let num_classes = 2;
        let class_size = 100;
        let batch_size = 10;

        // Build a dataset sorted by class: `class_size` items of class 0 followed by
        // `class_size` items of class 1.
        let mut items = Vec::new();
        for class in 0..num_classes {
            items.extend(vec![class; class_size]);
        }

        // Without an RNG the sorted dataset is split in order, so every batch contains a
        // single class.
        {
            let dataset = Arc::new(InMemDataset::new(items.clone()));
            let batcher = Arc::new(TestBatcher::new());

            let loader = MultiThreadDataLoader::new(
                Box::new(FixBatchStrategy::new(batch_size)),
                dataset,
                batcher,
                num_classes,
                Default::default(),
                None,
            );

            for batch in loader.iter() {
                let mut batch_items = HashSet::new();
                for item in batch {
                    batch_items.insert(item);
                }

                assert_eq!(batch_items.len(), 1);
            }
        }

        // With a seeded RNG the dataset is shuffled before being split, so every batch mixes
        // both classes.
        {
            let dataset = Arc::new(InMemDataset::new(items.clone()));
            let batcher = Arc::new(TestBatcher::new());

            let loader = MultiThreadDataLoader::new(
                Box::new(FixBatchStrategy::new(batch_size)),
                dataset.clone(),
                batcher.clone(),
                num_classes,
                Default::default(),
                Some(StdRng::seed_from_u64(42)),
            );

            for batch in loader.iter() {
                let mut batch_items = HashSet::new();
                for item in batch {
                    batch_items.insert(item);
                }

                assert_eq!(batch_items.len(), num_classes);
            }
        }
    }

    #[test]
    fn test_multi_thread_batch_dataloader_incomplete_batches() {
        // 27 items with a batch size of 5 leaves a trailing incomplete batch; both loaders
        // must produce the same batches and the same number of them.
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader_single_thread = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher.clone(),
            Default::default(),
            None,
        );
        let dataloader_multi_thread = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset,
            batcher,
            4,
            Default::default(),
            None,
        );

        let mut items_single_thread = HashSet::new();
        let mut items_multi_thread = HashSet::new();

        let mut single_thread_cnt = 0;
        let mut multi_thread_cnt = 0;
        for items in dataloader_single_thread.iter() {
            items_single_thread.insert(items);
            single_thread_cnt += 1;
        }

        for items in dataloader_multi_thread.iter() {
            items_multi_thread.insert(items);
            multi_thread_cnt += 1;
        }

        assert_eq!(single_thread_cnt, multi_thread_cnt);
        assert_eq!(items_single_thread, items_multi_thread);
    }
}