1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
use crate::MAX_SIZE_FOR_THREAD;
use crossbeam::channel;
use crossbeam::channel::Receiver;
use num_cpus;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::thread::JoinHandle;

///
/// Asynchronous counterpart of `IntoParallelIteratorSync`.
///
/// A type implementing this trait can be combined with a task function
/// `func` and converted into a [`ParIterAsync`] yielding values of type `R`.
///
/// Type parameters:
/// - `R`:  value produced by a successful task
/// - `T`:  a single task (one item of the input)
/// - `TL`: the collection of tasks, anything that is `IntoIterator<Item = T>`
/// - `F`:  the task function; returning `Err(())` signals failure
///
pub trait IntoParallelIteratorAsync<R, T, TL, F>
where
    R: Send,
    T: Send + 'static,
    TL: IntoIterator<Item = T> + Send + 'static,
    F: Fn(T) -> Result<R, ()> + Clone + Send + 'static,
{
    ///
    /// The asynchronous equivalent of `into_par_iter_sync`: consumes `self`
    /// together with `func` and returns a [`ParIterAsync`] over the results.
    ///
    fn into_par_iter_async(self, func: F) -> ParIterAsync<R>;
}

/// Blanket implementation: every sendable, `'static` `IntoIterator` whose
/// items are sendable can be turned into a [`ParIterAsync`].
impl<R, T, TL, F> IntoParallelIteratorAsync<R, T, TL, F> for TL
where
    R: Send + 'static,
    T: Send + 'static,
    TL: IntoIterator<Item = T> + Send + 'static,
    F: Fn(T) -> Result<R, ()> + Clone + Send + 'static,
{
    /// Forwards `self` (the tasks) and `func` (the task function) to
    /// [`ParIterAsync::new`].
    fn into_par_iter_async(self, func: F) -> ParIterAsync<R> {
        ParIterAsync::<R>::new(self, func)
    }
}

/// Iterator over the results produced by a pool of background worker threads.
///
/// All workers send into one shared channel, so results are yielded in the
/// order they arrive there, which is not necessarily the order of the input
/// tasks.
pub struct ParIterAsync<R> {
    /// receiving end of the bounded channel the workers send their results into
    output_receiver: Receiver<R>,

    /// join handles of the worker threads plus the task-dispatcher thread;
    /// wrapped in `Option` so that `join` can take ownership of them
    worker_thread: Option<Vec<JoinHandle<()>>>,

    /// atomic flag shared with the workers; once `true` they stop fetching new tasks
    iterator_stopper: Arc<AtomicBool>,

    /// set by `kill`; once `true`, `next` always returns `None`.
    /// NOTE(review): `kill` sets this right after signalling the workers and
    /// draining the output channel, without joining them, so worker threads
    /// may still be finishing their current task at that point
    is_killed: bool,

    /// number of worker threads (the dispatcher thread is not counted);
    /// `kill` makes this many receive attempts on the output channel
    worker_count: usize,
}

impl<R> ParIterAsync<R>
where
    R: Send + 'static,
{
    ///
    /// Build the iterator. The threads are spawned right here in `new`, so
    /// work starts before the first call to `next`.
    ///
    /// - `tasks`: the input; it is moved into a dispatcher thread which feeds
    ///   the items one by one into a bounded task channel
    /// - `task_executor`: run by the workers on every task; `Ok(value)` is
    ///   forwarded to the output channel, `Err(())` raises the shared stop flag
    ///   so that all workers stop fetching new tasks
    ///
    /// `num_cpus::get()` worker threads are started. Both the task channel and
    /// the output channel hold at most `MAX_SIZE_FOR_THREAD * cpus` items, so
    /// the dispatcher and the workers block instead of buffering without bound.
    ///
    pub fn new<T, TL, F>(tasks: TL, task_executor: F) -> Self
    where
        F: Send + Clone + 'static + Fn(T) -> Result<R, ()>,
        T: Send + 'static,
        TL: Send + IntoIterator<Item = T> + 'static,
    {
        let cpus = num_cpus::get();
        let iterator_stopper = Arc::new(AtomicBool::new(false));
        // the copy kept by the iterator itself, used by `kill`
        let stopper_clone = Arc::clone(&iterator_stopper);

        // this thread dispatches tasks to worker threads
        let (dispatcher, task_receiver) = channel::bounded(MAX_SIZE_FOR_THREAD * cpus);
        let work_dispatcher = thread::spawn(move || {
            for t in tasks {
                // `send` only fails once every worker has dropped its receiver,
                // i.e. nobody is left to run the remaining tasks
                if dispatcher.send(t).is_err() {
                    break;
                }
            }
        });

        // output senders for worker threads, and output receiver for user thread
        let (output_sender, output_receiver) = channel::bounded(MAX_SIZE_FOR_THREAD * cpus);

        // this is what each worker does
        let worker_task = move || {
            // the stop flag is checked before every fetch
            while !iterator_stopper.load(Ordering::SeqCst) {
                // `None`: the dispatcher is done and the task channel is drained
                let task = match get_task(&task_receiver) {
                    Some(task) => task,
                    None => break,
                };

                match task_executor(task) {
                    Ok(blk) => {
                        // a failed `send` means the receiving side is gone, so
                        // there is nobody left to consume results: shut this
                        // worker down quietly instead of panicking
                        if output_sender.send(blk).is_err() {
                            break;
                        }
                    }
                    Err(()) => {
                        // stop other workers if error is returned
                        iterator_stopper.store(true, Ordering::SeqCst);
                        break;
                    }
                }
            }
        };

        // spawn worker threads; every clone carries its own channel endpoints,
        // and the original closure is dropped when `new` returns
        let mut worker_handles = Vec::with_capacity(cpus + 1);
        for _ in 0..cpus {
            worker_handles.push(thread::spawn(worker_task.clone()));
        }
        // the dispatcher is joined together with the workers
        worker_handles.push(work_dispatcher);

        ParIterAsync {
            output_receiver,
            worker_thread: Some(worker_handles),
            iterator_stopper: stopper_clone,
            is_killed: false,
            worker_count: cpus,
        }
    }
}

impl<R> ParIterAsync<R> {
    ///
    /// Shut the iterator down. Only the first call has any effect.
    ///
    /// - raises the stop flag so that workers fetch no further tasks
    /// - makes one non-blocking receive attempt per worker, which frees room
    ///   in the output channel for workers blocked in `send`
    ///
    /// Afterwards `next` returns `None`.
    ///
    pub fn kill(&mut self) {
        if self.is_killed {
            return;
        }

        // tell the workers not to pick up any further task
        self.iterator_stopper.store(true, Ordering::SeqCst);

        // one receive attempt per worker; whatever is fetched here is discarded
        (0..self.worker_count).for_each(|_| {
            let _ = self.output_receiver.try_recv();
        });

        // the workers are expected to wind down from here on
        self.is_killed = true;
    }
}

///
/// Helper that fetches the next task from the task channel.
///
/// Blocks until a task is available. Returns `None` if and only if `recv`
/// fails, i.e. there is no more new task.
///
#[inline(always)]
fn get_task<T: Send>(tasks: &channel::Receiver<T>) -> Option<T> {
    // a receive error carries no information beyond "no more tasks"
    match tasks.recv() {
        Ok(task) => Some(task),
        Err(_) => None,
    }
}

impl<R> Iterator for ParIterAsync<R> {
    type Item = R;

    ///
    /// The output API: blocks until a worker delivers the next result.
    ///
    /// Returns `None` once the iterator has been killed, or when receiving
    /// fails because every worker has dropped its sender; in the latter case
    /// the iterator kills itself so that later calls return `None` right away.
    ///
    fn next(&mut self) -> Option<Self::Item> {
        if self.is_killed {
            return None;
        }

        let received = self.output_receiver.recv().ok();
        if received.is_none() {
            // all workers have stopped
            self.kill();
        }
        received
    }
}

impl<R> ParIterAsync<R> {
    ///
    /// Join worker threads. This can be only called only once.
    /// Otherwise it will panic.
    /// This is automatically called in `drop()`
    ///
    fn join(&mut self) {
        for handle in self.worker_thread.take().unwrap() {
            handle.join().unwrap()
        }
    }
}

impl<R> Drop for ParIterAsync<R> {
    ///
    /// Tear the iterator down: signal the workers to stop, then wait for the
    /// threads to finish.
    ///
    fn drop(&mut self) {
        // `kill` must come first: it raises the stop flag and drains the output
        // channel, so the threads joined below are able to terminate
        Self::kill(self);
        Self::join(self);
    }
}

#[cfg(test)]
mod test_par_iter_async {
    #[cfg(feature = "bench")]
    extern crate test;
    use crate::IntoParallelIteratorAsync;
    use std::collections::HashSet;
    #[cfg(feature = "bench")]
    use test::Bencher;

    /// A task returning `Err(())` stops the iterator early, and the failing
    /// value itself never appears in the output.
    #[test]
    fn par_iter_test_exception() {
        for _ in 0..100 {
            let digits: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3];
            let len = digits.len();

            let results: HashSet<i32> = (0..len)
                .into_par_iter_async(move |idx| match *digits.get(idx).unwrap() {
                    5 => Err(()),
                    n => Ok(n),
                })
                .collect();

            assert!(!results.contains(&5))
        }
    }

    ///
    /// The iterators can be chained.
    ///
    /// par_iter_0 -> owned by -> par_iter_1 -> owned by -> par_iter_2
    ///
    /// par_iter_1 fails on the value 1000,
    ///
    /// so 1000 must never reach the final output.
    ///
    #[test]
    fn par_iter_chained_exception() {
        for _ in 0..100 {
            let stage_0: Vec<i32> = (0..10000).collect();
            let stage_1 = stage_0.clone();
            let stage_2 = stage_0.clone();
            let len = stage_0.len();

            let results: HashSet<i32> = (0..len)
                .into_par_iter_async(move |idx| Ok(*stage_0.get(idx).unwrap()))
                .into_par_iter_async(move |idx| match *stage_1.get(idx as usize).unwrap() {
                    1000 => Err(()),
                    n => Ok(n),
                })
                .into_par_iter_async(move |idx| Ok(*stage_2.get(idx as usize).unwrap()))
                .collect();

            assert!(!results.contains(&1000))
        }
    }

    /// test that the iterator won't deadlock when it is dropped before
    /// all results have been consumed
    #[test]
    fn test_break() {
        for _ in 0..10000 {
            // `find` stops pulling at the first match; the iterator is a
            // temporary and is dropped at the end of the statement
            let _ = (0..2000)
                .into_par_iter_async(|a| Ok(a))
                .find(|&i| i == 1000);
        }
    }

    #[cfg(feature = "bench")]
    #[bench]
    fn bench_into_par_iter_async(b: &mut Bencher) {
        b.iter(|| {
            // identity task; every result is pulled and discarded
            (0..1_000_000)
                .into_par_iter_async(|a| Ok(a))
                .for_each(drop)
        });
    }
}