1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
use super::{DataLoader, DataLoaderIterator, Progress};
use std::collections::HashMap;
use std::sync::{mpsc, Arc};
use std::thread;

/// Capacity of the bounded channel between the worker threads and the consuming iterator.
///
/// Once this many messages are queued, workers block on `send` until the consumer catches up.
const MAX_QUEUED_ITEMS: usize = 100;

/// A [`DataLoader`] that wraps several dataloaders and iterates over all of them concurrently.
///
/// Each call to `iter` spawns one thread per wrapped dataloader and funnels the produced items
/// through a single bounded channel.
pub struct MultiThreadDataLoader<O> {
    /// The wrapped dataloaders; each one is iterated on its own thread.
    dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>>,
}

/// Message sent from a worker thread to the consuming iterator.
#[derive(Debug)]
pub enum Message<O> {
    /// An item produced by a worker: the index of the dataloader it came from, the item itself,
    /// and that dataloader's progress as reported right after the item was produced.
    Batch(usize, O, Progress),
    /// Sent once by a worker after its dataloader iterator has been exhausted.
    Done,
}

/// Iterator that yields the items produced by the worker threads as they arrive on the channel.
struct MultiThreadsDataloaderIterator<O> {
    /// Number of `Message::Done` messages received so far.
    num_done: usize,
    /// Join handles of the worker threads; emptied once every worker has been joined.
    workers: Vec<thread::JoinHandle<()>>,
    /// Receiving end of the channel the workers send their messages on.
    receiver: mpsc::Receiver<Message<O>>,
    /// Most recent progress received from each dataloader, keyed by dataloader index.
    progresses: HashMap<usize, Progress>,
}

impl<O> MultiThreadDataLoader<O> {
    pub fn new(dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>>) -> Self {
        Self { dataloaders }
    }
}

impl<O> DataLoader<O> for MultiThreadDataLoader<O>
where
    O: Send + 'static + std::fmt::Debug,
{
    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
        let (sender, receiver) = mpsc::sync_channel::<Message<O>>(MAX_QUEUED_ITEMS);

        let handlers: Vec<_> = self
            .dataloaders
            .clone()
            .into_iter()
            .enumerate()
            .map(|(index, dataloader)| {
                let dataloader_cloned = dataloader;
                let sender_cloned = sender.clone();

                thread::spawn(move || {
                    let mut iterator = dataloader_cloned.iter();
                    while let Some(item) = iterator.next() {
                        let progress = iterator.progress();
                        sender_cloned
                            .send(Message::Batch(index, item, progress))
                            .unwrap();
                    }
                    sender_cloned.send(Message::Done).unwrap();
                })
            })
            .collect();

        Box::new(MultiThreadsDataloaderIterator::new(receiver, handlers))
    }
}

impl<O> MultiThreadsDataloaderIterator<O> {
    pub fn new(receiver: mpsc::Receiver<Message<O>>, workers: Vec<thread::JoinHandle<()>>) -> Self {
        MultiThreadsDataloaderIterator {
            num_done: 0,
            workers,
            receiver,
            progresses: HashMap::new(),
        }
    }
}
impl<O: std::fmt::Debug> DataLoaderIterator<O> for MultiThreadsDataloaderIterator<O> {
    /// Returns the combined progress of all dataloaders that have reported so far.
    ///
    /// The processed and total counts of the latest progress recorded for each dataloader are
    /// summed. Dataloaders that have not yet delivered an item contribute nothing.
    fn progress(&self) -> Progress {
        let mut processed = 0;
        let mut total = 0;

        self.progresses.values().for_each(|latest| {
            total += latest.items_total;
            processed += latest.items_processed;
        });

        Progress {
            items_processed: processed,
            items_total: total,
        }
    }
}

impl<O> MultiThreadsDataloaderIterator<O> {
    /// Joins every remaining worker thread, re-raising the first worker panic on this thread.
    fn join_workers(&mut self) {
        for worker in self.workers.drain(..) {
            if let Err(payload) = worker.join() {
                std::panic::resume_unwind(payload);
            }
        }
    }
}

impl<O: std::fmt::Debug> Iterator for MultiThreadsDataloaderIterator<O> {
    type Item = O;

    /// Returns the next item produced by any worker, blocking until one is available.
    ///
    /// Returns `None` once every worker has sent `Message::Done`; the workers are joined at
    /// that point, and every later call returns `None` immediately. `None` is also returned
    /// right away when there are no workers at all.
    ///
    /// # Panics
    ///
    /// If a worker thread panicked, its panic is re-raised here with the original payload.
    fn next(&mut self) -> Option<O> {
        // No workers: either there never were any, or they have all been joined already.
        if self.workers.is_empty() {
            return None;
        }

        loop {
            match self.receiver.recv() {
                Ok(Message::Batch(index, item, progress)) => {
                    // Remember the latest progress of this dataloader for `progress()`.
                    self.progresses.insert(index, progress);
                    return Some(item);
                }
                Ok(Message::Done) => {
                    self.num_done += 1;
                }
                Err(_) => {
                    // Every sender is gone although not all workers reported `Done`, which
                    // means at least one worker died. Surface its panic rather than an
                    // opaque `RecvError`.
                    self.join_workers();
                    panic!("dataloader workers disconnected before signalling completion");
                }
            }

            if self.num_done == self.workers.len() {
                self.join_workers();
                return None;
            }
        }
    }
}