// par_iter_sync/iter_async.rs

1use crate::MAX_SIZE_FOR_THREAD;
2use crossbeam::channel;
3use crossbeam::channel::Receiver;
4use num_cpus;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::Arc;
7use std::thread;
8use std::thread::JoinHandle;
9
10///
11/// This trait implement the async version of `IntoParallelIteratorSync`
12///
13pub trait IntoParallelIteratorAsync<R, T, TL, F>
14where
15    F: Send + Clone + 'static + Fn(T) -> Result<R, ()>,
16    T: Send + 'static,
17    TL: Send + IntoIterator<Item = T> + 'static,
18    R: Send,
19{
20    ///
21    /// An asynchronous equivalent of into_par_iter_sync
22    ///
23    fn into_par_iter_async(self, func: F) -> ParIterAsync<R>;
24}
25
26impl<R, T, TL, F> IntoParallelIteratorAsync<R, T, TL, F> for TL
27where
28    F: Send + Clone + 'static + Fn(T) -> Result<R, ()>,
29    T: Send + 'static,
30    TL: Send + IntoIterator<Item = T> + 'static,
31    R: Send + 'static,
32{
33    fn into_par_iter_async(self, func: F) -> ParIterAsync<R> {
34        ParIterAsync::new(self, func)
35    }
36}
37
38/// iterate through blocks according to array index.
39pub struct ParIterAsync<R> {
40    /// this receiver receives results produced by workers
41    output_receiver: Receiver<R>,
42
43    /// handles to join worker threads
44    worker_thread: Option<Vec<JoinHandle<()>>>,
45
46    /// atomic flag to stop workers from fetching new tasks
47    iterator_stopper: Arc<AtomicBool>,
48
49    /// if this is `true`, it must guarantee that all worker threads have stopped
50    is_killed: bool,
51
52    /// number of worker threads
53    worker_count: usize,
54}
55
56impl<R> ParIterAsync<R>
57where
58    R: Send + 'static,
59{
60    ///
61    /// the worker threads are dispatched in this `new` constructor!
62    ///
63    pub fn new<T, TL, F>(tasks: TL, task_executor: F) -> Self
64    where
65        F: Send + Clone + 'static + Fn(T) -> Result<R, ()>,
66        T: Send + 'static,
67        TL: Send + IntoIterator<Item = T> + 'static,
68    {
69        let cpus = num_cpus::get();
70        let iterator_stopper = Arc::new(AtomicBool::new(false));
71        let stopper_clone = iterator_stopper.clone();
72
73        // this thread dispatches tasks to worker threads
74        let (dispatcher, task_receiver) = channel::bounded(MAX_SIZE_FOR_THREAD * cpus);
75        let work_dispatcher = thread::spawn(move || {
76            for t in tasks {
77                if dispatcher.send(t).is_err() {
78                    break;
79                }
80            }
81        });
82
83        // output senders for worker threads, and output receiver for user thread
84        let (output_sender, output_receiver) = channel::bounded(MAX_SIZE_FOR_THREAD * cpus);
85
86        // this is what each worker do
87        let worker_task = move || {
88            loop {
89                // check stopper flag, stop if `true`
90                if iterator_stopper.load(Ordering::SeqCst) {
91                    break;
92                }
93
94                // fetch next task
95                match get_task(&task_receiver) {
96                    // break if no more task
97                    None => break,
98                    Some(task) => match task_executor(task) {
99                        Ok(blk) => {
100                            // send output
101                            output_sender.send(blk).unwrap();
102                        }
103                        Err(_) => {
104                            // stop other workers if error is returned
105                            iterator_stopper.fetch_or(true, Ordering::SeqCst);
106                            break;
107                        }
108                    },
109                }
110            }
111        };
112
113        // spawn worker threads
114        let mut worker_handles = Vec::with_capacity(cpus + 1);
115        for _ in 0..cpus {
116            worker_handles.push(thread::spawn(worker_task.clone()));
117        }
118        worker_handles.push(work_dispatcher);
119
120        ParIterAsync {
121            output_receiver,
122            worker_thread: Some(worker_handles),
123            iterator_stopper: stopper_clone,
124            is_killed: false,
125            worker_count: cpus,
126        }
127    }
128}
129
130impl<R> ParIterAsync<R> {
131    ///
132    /// - stop workers from fetching new tasks
133    /// - pull one result from each worker to prevent `send` blocking
134    ///
135    pub fn kill(&mut self) {
136        if !self.is_killed {
137            // stop threads from getting new tasks
138            self.iterator_stopper.fetch_or(true, Ordering::SeqCst);
139            // receive one for each channel to prevent blocking
140            for _ in 0..self.worker_count {
141                let _ = self.output_receiver.try_recv();
142            }
143            // all workers should reasonably stopped by now
144            self.is_killed = true;
145        }
146    }
147}
148
149///
150/// A helper function to receive task from task receiver.
151///
152/// It guarantees to return None if and only if there is no more new task.
153///
154#[inline(always)]
155fn get_task<T>(tasks: &channel::Receiver<T>) -> Option<T>
156where
157    T: Send,
158{
159    // lock task list
160    tasks.recv().ok()
161}
162
163impl<R> Iterator for ParIterAsync<R> {
164    type Item = R;
165
166    ///
167    /// The output API, use next to fetch result from the iterator.
168    ///
169    fn next(&mut self) -> Option<Self::Item> {
170        if self.is_killed {
171            return None;
172        }
173        match self.output_receiver.recv() {
174            Ok(block) => Some(block),
175            // all workers have stopped
176            Err(_) => {
177                self.kill();
178                None
179            }
180        }
181    }
182}
183
184impl<R> ParIterAsync<R> {
185    ///
186    /// Join worker threads. This can be only called only once.
187    /// Otherwise it will panic.
188    /// This is automatically called in `drop()`
189    ///
190    fn join(&mut self) {
191        for handle in self.worker_thread.take().unwrap() {
192            handle.join().unwrap()
193        }
194    }
195}
196
197impl<R> Drop for ParIterAsync<R> {
198    ///
199    /// Stop worker threads, join the threads.
200    ///
201    fn drop(&mut self) {
202        self.kill();
203        self.join();
204    }
205}
206
#[cfg(test)]
mod test_par_iter_async {
    #[cfg(feature = "bench")]
    extern crate test;
    use crate::IntoParallelIteratorAsync;
    use std::collections::HashSet;
    #[cfg(feature = "bench")]
    use test::Bencher;

    /// Returning `Err(())` from the closure stops the iterator early, and the
    /// value that caused the error never appears in the output.
    #[test]
    fn par_iter_test_exception() {
        for _ in 0..100 {
            let data = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3];
            let len = data.len();

            let results: HashSet<i32> = (0..len)
                .into_par_iter_async(move |idx| match data.get(idx).unwrap().to_owned() {
                    5 => Err(()),
                    n => Ok(n),
                })
                .collect();

            assert!(!results.contains(&5))
        }
    }

    /// Three iterators chained together, each one owning the previous one:
    ///
    /// par_iter_0 -> owned by -> par_iter_1 -> owned by -> par_iter_2
    ///
    /// `par_iter_1` fails on the value 1000. From then on its workers stop
    /// taking new tasks, so an arbitrary subset of the other values may be
    /// missing as well. Only the absence of 1000 is guaranteed and asserted.
    #[test]
    fn par_iter_chained_exception() {
        for _ in 0..100 {
            let stage0: Vec<i32> = (0..10000).collect();
            let stage1 = stage0.clone();
            let stage2 = stage0.clone();
            let len = stage0.len();

            let results: HashSet<i32> = (0..len)
                .into_par_iter_async(move |a| Ok(stage0.get(a).unwrap().to_owned()))
                .into_par_iter_async(move |a| match stage1.get(a as usize).unwrap().to_owned() {
                    1000 => Err(()),
                    n => Ok(n),
                })
                .into_par_iter_async(move |a| Ok(stage2.get(a as usize).unwrap().to_owned()))
                .collect();

            assert!(!results.contains(&1000))
        }
    }

    /// Dropping the iterator early (while workers may still be busy) must not
    /// deadlock.
    #[test]
    fn test_break() {
        for _ in 0..10000 {
            // consume results until 1000 shows up, then drop the iterator
            let _found = (0..2000).into_par_iter_async(|a| Ok(a)).any(|i| i == 1000);
        }
    }

    #[cfg(feature = "bench")]
    #[bench]
    fn bench_into_par_iter_async(b: &mut Bencher) {
        b.iter(|| (0..1_000_000).into_par_iter_async(|a| Ok(a)).for_each(drop));
    }
}