1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
use super::{DataLoader, DataLoaderIterator, Progress};
use std::collections::HashMap;
use std::sync::{mpsc, Arc};
use std::thread;
/// Capacity of the bounded channel between the worker threads and the
/// consuming iterator: a worker blocks once this many items are waiting to be
/// received.
const MAX_QUEUED_ITEMS: usize = 100;
/// A data loader that runs several inner loaders concurrently, one worker
/// thread per loader, and merges their items into a single stream.
pub struct MultiThreadDataLoader<O> {
    /// The wrapped loaders. A loader's position in this vector is the index
    /// reported alongside its items in `Message::Batch`.
    dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>>,
}
/// A message sent from a worker thread to the consuming iterator.
#[derive(Debug)]
pub enum Message<O> {
    /// One item from the loader at the given index, together with that
    /// loader's progress as reported right after the item was produced.
    Batch(usize, O, Progress),
    /// The sending worker's loader is exhausted.
    Done,
}
/// Iterator handed out by `MultiThreadDataLoader::iter`. It receives items from
/// the worker threads over a channel and tracks the progress of each loader.
struct MultiThreadsDataloaderIterator<O> {
    /// How many workers have reported `Message::Done` so far.
    num_done: usize,
    /// Handles of the worker threads; emptied once they have been joined.
    workers: Vec<thread::JoinHandle<()>>,
    /// Receiving end of the channel that every worker sends into.
    receiver: mpsc::Receiver<Message<O>>,
    /// Most recent progress reported by each loader, keyed by loader index.
    progresses: HashMap<usize, Progress>,
}
impl<O> MultiThreadDataLoader<O> {
pub fn new(dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>>) -> Self {
Self { dataloaders }
}
}
impl<O> DataLoader<O> for MultiThreadDataLoader<O>
where
    O: Send + 'static + std::fmt::Debug,
{
    /// Spawns one worker thread per wrapped loader and returns an iterator
    /// that yields their items in whatever order they arrive.
    ///
    /// Each worker drives its own loader's iterator. It forwards every item,
    /// together with the loader's index and the progress reported right after
    /// that item, over a bounded channel of capacity `MAX_QUEUED_ITEMS`. A
    /// worker blocks while the channel is full. Once its loader is exhausted,
    /// the worker sends `Message::Done`.
    ///
    /// If the returned iterator is dropped before it is exhausted, the channel
    /// disconnects and each worker stops at its next send instead of
    /// panicking.
    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
        let (sender, receiver) = mpsc::sync_channel::<Message<O>>(MAX_QUEUED_ITEMS);
        let handlers: Vec<_> = self
            .dataloaders
            .iter()
            .cloned()
            .enumerate()
            .map(|(index, dataloader)| {
                let sender = sender.clone();
                thread::spawn(move || {
                    let mut iterator = dataloader.iter();
                    // `while let` instead of `for`: a `for` loop would hold a
                    // mutable borrow of `iterator` for the whole loop, so
                    // `progress()` could not be read between items.
                    while let Some(item) = iterator.next() {
                        let progress = iterator.progress();
                        if sender.send(Message::Batch(index, item, progress)).is_err() {
                            // The receiving iterator was dropped; nobody wants
                            // further items, so stop quietly.
                            return;
                        }
                    }
                    // Failure here also only means the receiver is gone.
                    let _ = sender.send(Message::Done);
                })
            })
            .collect();
        // Drop the original sender so that only the workers' clones keep the
        // channel open; it then disconnects once every worker has exited.
        drop(sender);
        Box::new(MultiThreadsDataloaderIterator::new(receiver, handlers))
    }
}
impl<O> MultiThreadsDataloaderIterator<O> {
pub fn new(receiver: mpsc::Receiver<Message<O>>, workers: Vec<thread::JoinHandle<()>>) -> Self {
MultiThreadsDataloaderIterator {
num_done: 0,
workers,
receiver,
progresses: HashMap::new(),
}
}
}
impl<O: std::fmt::Debug> DataLoaderIterator<O> for MultiThreadsDataloaderIterator<O> {
    /// Combined progress of all loaders: the sum of the most recent progress
    /// each loader has reported. A loader that has not yet produced an item
    /// has no entry and contributes nothing.
    fn progress(&self) -> Progress {
        let (items_processed, items_total) = self
            .progresses
            .values()
            .fold((0, 0), |(processed, total), latest| {
                (processed + latest.items_processed, total + latest.items_total)
            });
        Progress {
            items_processed,
            items_total,
        }
    }
}
impl<O: std::fmt::Debug> Iterator for MultiThreadsDataloaderIterator<O> {
type Item = O;
fn next(&mut self) -> Option<O> {
if self.workers.is_empty() {
return None;
}
loop {
let item = self.receiver.recv();
let item = item.unwrap();
match item {
Message::Batch(index, item, progress) => {
self.progresses.insert(index, progress);
return Some(item);
}
Message::Done => {
self.num_done += 1;
}
};
if self.num_done == self.workers.len() {
while let Some(worker) = self.workers.pop() {
worker.join().unwrap();
}
return None;
}
}
}
}