burn_core/data/dataloader/
multithread.rs1use burn_dataset::Dataset;
2use burn_dataset::transform::PartialDataset;
3use burn_tensor::backend::Backend;
4use rand::SeedableRng;
5use rand::distr::{Distribution, StandardUniform};
6use rand::rngs::StdRng;
7
8use super::batcher::Batcher;
9use super::{BatchDataLoader, BatchStrategy, DataLoader, DataLoaderIterator, Progress};
10use core::cell::OnceCell;
11use std::sync::{Arc, mpsc};
12use std::thread;
13
14const MAX_QUEUED_ITEMS: usize = 100;
15
16pub struct MultiThreadDataLoader<B: Backend, I, O> {
18 strategy: Box<dyn BatchStrategy<I>>,
20 dataset: Arc<dyn Dataset<I>>,
21 batcher: Arc<dyn Batcher<B, I, O>>,
22 device: B::Device,
23 rng: Option<rand::rngs::StdRng>,
24 num_threads: usize,
25
26 dataloaders: OnceCell<Vec<BatchDataLoader<B, I, O>>>,
28}
29
30#[derive(Debug)]
32pub enum Message<O> {
33 Batch(usize, O, Progress),
35
36 Done,
38}
39
40struct MultiThreadsDataloaderIterator<O> {
41 num_done: usize,
42 workers: Vec<thread::JoinHandle<()>>,
43 receiver: mpsc::Receiver<Message<O>>,
44 progresses: Vec<Progress>,
45}
46
47impl<B: Backend, I, O> MultiThreadDataLoader<B, I, O>
48where
49 I: Send + Sync + Clone + 'static,
50 O: Send + 'static,
51{
52 pub fn new(
68 strategy: Box<dyn BatchStrategy<I>>,
69 dataset: Arc<dyn Dataset<I>>,
70 batcher: Arc<dyn Batcher<B, I, O>>,
71 num_threads: usize,
72 device: B::Device,
73 rng: Option<rand::rngs::StdRng>,
74 ) -> Self {
75 Self {
76 strategy,
77 dataset,
78 batcher,
79 num_threads,
80 device,
81 rng,
82 dataloaders: OnceCell::new(),
83 }
84 }
85
86 fn initialize(&self) -> &[BatchDataLoader<B, I, O>] {
88 self.dataloaders
89 .get_or_init(|| {
90 let datasets = PartialDataset::split(self.dataset.clone(), self.num_threads);
91
92 let mut rng = self.rng.clone();
94 let rngs = (0..self.num_threads).map(|_| {
95 rng.as_mut().map(|rng| {
96 StdRng::seed_from_u64(Distribution::sample(&StandardUniform, rng))
97 })
98 });
99
100 datasets
101 .into_iter()
102 .zip(rngs)
103 .map(|(dataset, rng)| {
104 let strategy = self.strategy.clone_dyn();
105 BatchDataLoader::new(
106 strategy,
107 Arc::new(dataset),
108 self.batcher.clone(),
109 self.device.clone(),
110 rng,
111 )
112 })
113 .collect()
114 })
115 .as_ref()
116 }
117}
118
119impl<B: Backend, I, O> DataLoader<B, O> for MultiThreadDataLoader<B, I, O>
120where
121 I: Send + Sync + Clone + 'static,
122 O: Send + 'static + std::fmt::Debug,
123{
124 fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
125 let dataloaders = self.initialize();
127
128 let (sender, receiver) = mpsc::sync_channel::<Message<O>>(MAX_QUEUED_ITEMS);
129
130 let mut progresses = Vec::with_capacity(dataloaders.len());
131
132 let handlers: Vec<_> = dataloaders
133 .iter()
134 .enumerate()
135 .map(|(index, dataloader)| {
136 let dataloader_cloned = dataloader.clone();
137 let sender_cloned = sender.clone();
138 progresses.push(Progress::new(0, dataloader_cloned.num_items()));
139
140 thread::spawn(move || {
141 let mut iterator = dataloader_cloned.iter();
142 while let Some(item) = iterator.next() {
143 let progress = iterator.progress();
144
145 match sender_cloned.send(Message::Batch(index, item, progress)) {
146 Ok(_) => {}
147 Err(_) => return,
150 };
151 }
152 sender_cloned.send(Message::Done).ok();
154 })
155 })
156 .collect();
157
158 Box::new(MultiThreadsDataloaderIterator::new(
159 receiver, handlers, progresses,
160 ))
161 }
162
163 fn num_items(&self) -> usize {
164 self.dataset.len()
167 }
168
169 fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
170 Arc::new(Self::new(
171 self.strategy.clone_dyn(),
172 self.dataset.clone(),
173 self.batcher.clone(),
174 self.num_threads,
175 device.clone(),
176 self.rng.clone(),
177 ))
178 }
179
180 fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
181 let dataloader = Self::new(
182 self.strategy.clone_dyn(),
183 Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
184 self.batcher.clone(),
185 self.num_threads,
186 self.device.clone(),
187 self.rng.clone(),
188 );
189 Arc::new(dataloader)
190 }
191}
192
193impl<O> MultiThreadsDataloaderIterator<O> {
194 pub fn new(
195 receiver: mpsc::Receiver<Message<O>>,
196 workers: Vec<thread::JoinHandle<()>>,
197 progresses: Vec<Progress>,
198 ) -> Self {
199 MultiThreadsDataloaderIterator {
200 num_done: 0,
201 workers,
202 receiver,
203 progresses,
204 }
205 }
206}
207impl<O: std::fmt::Debug> DataLoaderIterator<O> for MultiThreadsDataloaderIterator<O> {
208 fn progress(&self) -> Progress {
209 let mut items_total = 0;
210 let mut items_processed = 0;
211
212 for progress in self.progresses.iter() {
213 items_total += progress.items_total;
214 items_processed += progress.items_processed;
215 }
216
217 Progress::new(items_processed, items_total)
218 }
219}
220
221impl<O: std::fmt::Debug> Iterator for MultiThreadsDataloaderIterator<O> {
222 type Item = O;
223
224 fn next(&mut self) -> Option<O> {
225 if self.workers.is_empty() {
226 return None;
227 }
228
229 loop {
230 let item = self.receiver.recv();
231 let item = item.unwrap();
232
233 match item {
234 Message::Batch(index, item, progress) => {
235 if let Some(current) = self.progresses.get_mut(index) {
236 *current = progress;
237 }
238 return Some(item);
239 }
240 Message::Done => {
241 self.num_done += 1;
242 }
243 };
244
245 if self.num_done == self.workers.len() {
246 while let Some(worker) = self.workers.pop() {
247 worker.join().unwrap();
248 }
249 return None;
250 }
251 }
252 }
253}
254
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;

    /// The multi-threaded loader must yield exactly the same items as the
    /// single-threaded one; only the order may differ, hence the set comparison.
    #[test]
    fn test_multi_thread_batch_dataloader() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));

        let single = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher.clone(),
            Default::default(),
            None,
        );
        let multi = MultiThreadDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset,
            batcher,
            4,
            Default::default(),
            None,
        );

        let expected: HashSet<_> = single.iter().flatten().collect();
        let actual: HashSet<_> = multi.iter().flatten().collect();

        assert_eq!(expected, actual);
    }
}