// burn_core/data/dataloader/multithread.rs
use burn_dataset::Dataset;
2use burn_dataset::transform::PartialDataset;
3use burn_tensor::backend::Backend;
4use rand::distr::{Distribution, StandardUniform};
5use rand::rngs::StdRng;
6use rand::{Rng, SeedableRng};
7
8use super::batcher::Batcher;
9use super::{BatchDataLoader, BatchStrategy, DataLoader, DataLoaderIterator, Progress};
10use std::sync::{Arc, OnceLock, mpsc};
11use std::thread;
12
/// Maximum number of batches buffered in the worker -> consumer channel;
/// once full, worker threads block on `send`, bounding memory use when the
/// consumer is slower than the producers.
const MAX_QUEUED_ITEMS: usize = 100;

/// Raw seed byte-array type accepted by [`StdRng::from_seed`].
type RngSeed = <StdRng as SeedableRng>::Seed;
16
/// A data loader that splits the dataset across `num_threads` workers, each
/// running its own [`BatchDataLoader`], and merges their batches through a
/// bounded channel.
pub struct MultiThreadDataLoader<B: Backend, I, O> {
    strategy: Box<dyn BatchStrategy<I>>,
    dataset: Arc<dyn Dataset<I>>,
    batcher: Arc<dyn Batcher<B, I, O>>,
    device: B::Device,
    // When `Some`, the dataset is shuffled and per-worker RNGs are derived
    // from this seed at initialization time (see `initialize`).
    seed: Option<RngSeed>,
    num_threads: usize,

    // Per-thread loaders, built lazily exactly once on first `iter()` call.
    dataloaders: OnceLock<Vec<BatchDataLoader<B, I, O>>>,
}
30
/// Message sent from a worker thread to the consuming iterator.
#[derive(Debug)]
pub enum Message<O> {
    /// A produced batch: the worker's index, the batch itself, and that
    /// worker's progress snapshot at the time the batch was produced.
    Batch(usize, O, Progress),

    /// Signals that a worker has exhausted its dataset partition.
    Done,
}
40
/// Iterator that drains batches from all worker threads until every worker
/// has reported [`Message::Done`].
struct MultiThreadsDataloaderIterator<O> {
    // Number of workers that have sent `Message::Done` so far.
    num_done: usize,
    workers: Vec<thread::JoinHandle<()>>,
    receiver: mpsc::Receiver<Message<O>>,
    // Last reported progress of each worker, indexed by worker id.
    progresses: Vec<Progress>,
}
47
impl<B: Backend, I, O> MultiThreadDataLoader<B, I, O>
where
    I: Send + Sync + Clone + 'static,
    O: Send + 'static,
{
    /// Creates a multi-threaded data loader.
    ///
    /// # Arguments
    ///
    /// * `strategy` - Controls how items are grouped into batches.
    /// * `dataset` - The dataset to load; split across workers at init time.
    /// * `batcher` - Converts collected items into a batch `O`.
    /// * `num_threads` - Number of worker threads (and dataset partitions).
    /// * `device` - Backend device on which batches are created.
    /// * `rng` - When `Some`, a seed is drawn from it so that shuffling and
    ///   per-worker randomness are deterministic; `None` disables shuffling.
    pub fn new(
        strategy: Box<dyn BatchStrategy<I>>,
        dataset: Arc<dyn Dataset<I>>,
        batcher: Arc<dyn Batcher<B, I, O>>,
        num_threads: usize,
        device: B::Device,
        rng: Option<rand::rngs::StdRng>,
    ) -> Self {
        // Capture a raw seed instead of keeping the RNG so the loader can be
        // cheaply re-created (see `to_device` / `slice`) with the same behavior.
        let mut seed = None;
        if let Some(mut rng) = rng {
            let mut s = RngSeed::default();
            rng.fill_bytes(&mut s);

            seed = Some(s);
        }
        Self::from_seed(strategy, dataset, batcher, num_threads, device, seed)
    }

    /// Internal constructor from an optional raw seed; per-thread loaders are
    /// not built until the first call to `initialize`.
    fn from_seed(
        strategy: Box<dyn BatchStrategy<I>>,
        dataset: Arc<dyn Dataset<I>>,
        batcher: Arc<dyn Batcher<B, I, O>>,
        num_threads: usize,
        device: B::Device,
        seed: Option<RngSeed>,
    ) -> Self {
        Self {
            strategy,
            dataset,
            batcher,
            num_threads,
            device,
            seed,
            dataloaders: OnceLock::new(),
        }
    }

    /// Lazily builds the per-thread [`BatchDataLoader`]s exactly once.
    ///
    /// When a seed is present, the dataset is shuffled first, then split into
    /// `num_threads` partitions, and each partition gets its own RNG derived
    /// from the same seed so results are reproducible.
    fn initialize(&self) -> &[BatchDataLoader<B, I, O>] {
        self.dataloaders
            .get_or_init(|| {
                let mut dataset = self.dataset.clone();
                if let Some(seed) = self.seed.as_ref() {
                    let mut rng = StdRng::from_seed(*seed);
                    dataset = Arc::new(burn_dataset::transform::ShuffledDataset::new(
                        dataset, &mut rng,
                    ));
                }

                // Partition with awareness of the batch size when the strategy
                // exposes one (see `PartialDataset::split_chunks`); otherwise
                // split evenly.
                let datasets = match self.strategy.batch_size() {
                    Some(batch_size) => {
                        PartialDataset::split_chunks(dataset, self.num_threads, batch_size)
                    }
                    None => PartialDataset::split(dataset, self.num_threads),
                };

                // Derive one independent RNG per worker from the shared seed;
                // draw order is fixed, so worker RNGs are reproducible.
                let mut rng = self.seed.map(StdRng::from_seed);
                let rngs = (0..self.num_threads).map(|_| {
                    rng.as_mut().map(|rng| {
                        StdRng::seed_from_u64(Distribution::sample(&StandardUniform, rng))
                    })
                });

                datasets
                    .into_iter()
                    .zip(rngs)
                    .map(|(dataset, rng)| {
                        let strategy = self.strategy.clone_dyn();
                        BatchDataLoader::new(
                            strategy,
                            Arc::new(dataset),
                            self.batcher.clone(),
                            self.device.clone(),
                            rng,
                        )
                    })
                    .collect()
            })
            .as_ref()
    }
}
154
impl<B: Backend, I, O> DataLoader<B, O> for MultiThreadDataLoader<B, I, O>
where
    I: Send + Sync + Clone + 'static,
    O: Send + 'static + std::fmt::Debug,
{
    /// Spawns one thread per inner data loader; each pushes its batches into
    /// a bounded channel that the returned iterator drains.
    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
        let dataloaders = self.initialize();

        // Bounded channel: workers block once MAX_QUEUED_ITEMS messages are
        // queued, providing backpressure against a slow consumer.
        let (sender, receiver) = mpsc::sync_channel::<Message<O>>(MAX_QUEUED_ITEMS);

        let mut progresses = Vec::with_capacity(dataloaders.len());

        let handlers: Vec<_> = dataloaders
            .iter()
            .enumerate()
            .map(|(index, dataloader)| {
                let dataloader_cloned = dataloader.clone();
                let sender_cloned = sender.clone();
                progresses.push(Progress::new(0, dataloader_cloned.num_items()));

                std::thread::Builder::new()
                    .name(std::format!("dataloader-{index}"))
                    .spawn(move || {
                        let mut iterator = dataloader_cloned.iter();
                        while let Some(item) = iterator.next() {
                            let progress = iterator.progress();

                            match sender_cloned.send(Message::Batch(index, item, progress)) {
                                Ok(_) => {}
                                // Send fails only when the receiver hung up,
                                // i.e. the consumer was dropped: stop producing.
                                Err(_) => return,
                            };
                        }
                        // Best-effort completion signal; a dropped receiver
                        // just means nobody is waiting for it.
                        sender_cloned.send(Message::Done).ok();
                    })
                    .unwrap()
            })
            .collect();

        Box::new(MultiThreadsDataloaderIterator::new(
            receiver, handlers, progresses,
        ))
    }

    /// Total number of items in the underlying (un-split) dataset.
    fn num_items(&self) -> usize {
        self.dataset.len()
    }

    /// Returns a copy of this loader targeting `device`, sharing the same
    /// dataset, batcher, strategy and seed (workers are rebuilt lazily).
    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
        Arc::new(Self::from_seed(
            self.strategy.clone_dyn(),
            self.dataset.clone(),
            self.batcher.clone(),
            self.num_threads,
            device.clone(),
            self.seed,
        ))
    }

    /// Returns a loader restricted to the `start..end` range of the dataset
    /// (range semantics per `PartialDataset::new`), keeping the same seed.
    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
        let dataloader = Self::from_seed(
            self.strategy.clone_dyn(),
            Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
            self.batcher.clone(),
            self.num_threads,
            self.device.clone(),
            self.seed,
        );
        Arc::new(dataloader)
    }
}
231
232impl<O> MultiThreadsDataloaderIterator<O> {
233 pub fn new(
234 receiver: mpsc::Receiver<Message<O>>,
235 workers: Vec<thread::JoinHandle<()>>,
236 progresses: Vec<Progress>,
237 ) -> Self {
238 MultiThreadsDataloaderIterator {
239 num_done: 0,
240 workers,
241 receiver,
242 progresses,
243 }
244 }
245}
246impl<O: std::fmt::Debug> DataLoaderIterator<O> for MultiThreadsDataloaderIterator<O> {
247 fn progress(&self) -> Progress {
248 let mut items_total = 0;
249 let mut items_processed = 0;
250
251 for progress in self.progresses.iter() {
252 items_total += progress.items_total;
253 items_processed += progress.items_processed;
254 }
255
256 Progress::new(items_processed, items_total)
257 }
258}
259
impl<O: std::fmt::Debug> Iterator for MultiThreadsDataloaderIterator<O> {
    type Item = O;

    /// Blocks until some worker delivers the next batch, or returns `None`
    /// once every worker has signaled completion.
    fn next(&mut self) -> Option<O> {
        if self.workers.is_empty() {
            return None;
        }

        loop {
            let item = self.receiver.recv();
            // Workers always send `Done` before dropping their sender, so in a
            // clean shutdown `num_done` reaches `workers.len()` before `recv`
            // can fail; a failure here indicates a worker panicked, and we
            // propagate by panicking as well.
            let item = item.unwrap();

            match item {
                Message::Batch(index, item, progress) => {
                    // Record the sending worker's latest progress snapshot.
                    if let Some(current) = self.progresses.get_mut(index) {
                        *current = progress;
                    }
                    return Some(item);
                }
                Message::Done => {
                    self.num_done += 1;
                }
            };

            // Once every worker has finished, join them so thread resources
            // are reclaimed before ending iteration.
            if self.num_done == self.workers.len() {
                while let Some(worker) = self.workers.pop() {
                    worker.join().unwrap();
                }
                return None;
            }
        }
    }
}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;
    use burn_dataset::InMemDataset;
    use std::collections::HashSet;

    // The multi-threaded loader must yield exactly the same set of items as a
    // single-threaded loader over the same dataset (order may differ).
    #[test]
    fn test_multi_thread_batch_dataloader() {
        let batcher = Arc::new(TestBatcher::new());
        // 27 items with batch size 5: includes an incomplete final batch.
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader_single_thread = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher.clone(),
            Default::default(),
            None,
        );
        let dataloader_multi_thread = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset,
            batcher,
            4,
            Default::default(),
            None,
        );

        let mut items_single_thread = HashSet::new();
        let mut items_multi_thread = HashSet::new();

        for items in dataloader_single_thread.iter() {
            for item in items {
                items_single_thread.insert(item);
            }
        }

        for items in dataloader_multi_thread.iter() {
            for item in items {
                items_multi_thread.insert(item);
            }
        }

        assert_eq!(items_single_thread, items_multi_thread);
    }

    // Shuffling: with a class-sorted dataset, an unshuffled loader produces
    // single-class batches, while a seeded (shuffled) loader mixes classes.
    #[test]
    fn test_multi_thread_batch_dataloader_shuffle() {
        let num_classes = 2;
        let class_size = 100;
        let batch_size = 10;

        // Dataset laid out as [0, 0, ..., 1, 1, ...]: contiguous class runs.
        let mut items = Vec::new();
        for class in 0..num_classes {
            items.extend(vec![class; class_size]);
        }

        // Without a seed: no shuffle, so every batch comes from one class run.
        {
            let dataset = Arc::new(InMemDataset::new(items.clone()));
            let batcher = Arc::new(TestBatcher::new());

            let loader = MultiThreadDataLoader::new(
                Box::new(FixBatchStrategy::new(batch_size)),
                dataset,
                batcher,
                num_classes,
                Default::default(),
                None,
            );

            for batch in loader.iter() {
                let mut batch_items = HashSet::new();
                for item in batch {
                    batch_items.insert(item);
                }

                // Unshuffled: each batch contains items of exactly one class.
                assert_eq!(batch_items.len(), 1);
            }
        }

        // With a seed: the dataset is shuffled, so batches mix both classes.
        {
            let dataset = Arc::new(InMemDataset::new(items.clone()));
            let batcher = Arc::new(TestBatcher::new());

            let loader = MultiThreadDataLoader::new(
                Box::new(FixBatchStrategy::new(batch_size)),
                dataset.clone(),
                batcher.clone(),
                num_classes,
                Default::default(),
                Some(StdRng::seed_from_u64(42)),
            );

            for batch in loader.iter() {
                let mut batch_items = HashSet::new();
                for item in batch {
                    batch_items.insert(item);
                }

                // Shuffled (with this seed): every batch contains both classes.
                assert_eq!(batch_items.len(), num_classes);
            }
        }
    }

    // The multi-threaded loader must produce the same batches (including the
    // same number of incomplete ones) as the single-threaded loader.
    #[test]
    fn test_multi_thread_batch_dataloader_incomplete_batches() {
        let batcher = Arc::new(TestBatcher::new());
        // 27 items / batch size 5 across 4 threads exercises chunked
        // splitting with a trailing incomplete batch.
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader_single_thread = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher.clone(),
            Default::default(),
            None,
        );
        let dataloader_multi_thread = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset,
            batcher,
            4,
            Default::default(),
            None,
        );

        let mut items_single_thread = HashSet::new();
        let mut items_multi_thread = HashSet::new();

        let mut single_thread_cnt = 0;
        let mut multi_thread_cnt = 0;
        for items in dataloader_single_thread.iter() {
            items_single_thread.insert(items);
            single_thread_cnt += 1;
        }

        for items in dataloader_multi_thread.iter() {
            items_multi_thread.insert(items);
            multi_thread_cnt += 1;
        }

        assert_eq!(single_thread_cnt, multi_thread_cnt);
        assert_eq!(items_single_thread, items_multi_thread);
    }
}