par_iter_sync/iter_async.rs

use crate::MAX_SIZE_FOR_THREAD;
use crossbeam::channel;
use crossbeam::channel::Receiver;
use num_cpus;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::thread::JoinHandle;

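/// Convert an `IntoIterator` into a parallel iterator that runs `func` on a
/// set of worker threads (one per logical CPU) and yields results in
/// completion order rather than input order. Returning `Err(())` from
/// `func` signals every worker to stop.
///
/// A minimal usage sketch (assuming the crate is published as
/// `par_iter_sync`); because the output is unordered, reduce it with an
/// order-insensitive operation such as a sum:
///
/// ```
/// use par_iter_sync::IntoParallelIteratorAsync;
///
/// let sum: i32 = (0..100)
///     .into_par_iter_async(|i| Ok(i * 2))
///     .sum();
/// assert_eq!(sum, 9900);
/// ```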
pub trait IntoParallelIteratorAsync<R, T, TL, F>
where
    F: Send + Clone + 'static + Fn(T) -> Result<R, ()>,
    T: Send + 'static,
    TL: Send + IntoIterator<Item = T> + 'static,
    R: Send,
{
    fn into_par_iter_async(self, func: F) -> ParIterAsync<R>;
}

impl<R, T, TL, F> IntoParallelIteratorAsync<R, T, TL, F> for TL
where
    F: Send + Clone + 'static + Fn(T) -> Result<R, ()>,
    T: Send + 'static,
    TL: Send + IntoIterator<Item = T> + 'static,
    R: Send + 'static,
{
    fn into_par_iter_async(self, func: F) -> ParIterAsync<R> {
        ParIterAsync::new(self, func)
    }
}

/// Unordered parallel iterator returned by `into_par_iter_async`.
pub struct ParIterAsync<R> {
    /// Receives results produced by the worker threads.
    output_receiver: Receiver<R>,
    /// Join handles of the worker threads plus the task dispatcher.
    worker_thread: Option<Vec<JoinHandle<()>>>,
    /// Flag telling all workers to stop.
    iterator_stopper: Arc<AtomicBool>,
    /// Whether `kill()` has already run.
    is_killed: bool,
    /// Number of worker threads (excluding the dispatcher).
    worker_count: usize,
}

impl<R> ParIterAsync<R>
where
    R: Send + 'static,
{
    /// Spawn one dispatcher thread plus one worker thread per logical CPU.
    /// Tasks flow to the workers through a bounded channel; results flow
    /// back through another bounded channel in completion order.
    pub fn new<T, TL, F>(tasks: TL, task_executor: F) -> Self
    where
        F: Send + Clone + 'static + Fn(T) -> Result<R, ()>,
        T: Send + 'static,
        TL: Send + IntoIterator<Item = T> + 'static,
    {
        let cpus = num_cpus::get();
        let iterator_stopper = Arc::new(AtomicBool::new(false));
        let stopper_clone = iterator_stopper.clone();

        // Dispatcher: feed tasks into the bounded task channel until the
        // input is exhausted or every worker has dropped its receiver.
        let (dispatcher, task_receiver) = channel::bounded(MAX_SIZE_FOR_THREAD * cpus);
        let work_dispatcher = thread::spawn(move || {
            for t in tasks {
                if dispatcher.send(t).is_err() {
                    break;
                }
            }
        });

        let (output_sender, output_receiver) = channel::bounded(MAX_SIZE_FOR_THREAD * cpus);

        // Worker loop: pull a task, run the executor, forward the result.
        // The first `Err(())` raises the stop flag for all workers.
        let worker_task = move || {
            loop {
                if iterator_stopper.load(Ordering::SeqCst) {
                    break;
                }

                match get_task(&task_receiver) {
                    None => break,
                    Some(task) => match task_executor(task) {
                        Ok(blk) => {
                            output_sender.send(blk).unwrap();
                        }
                        Err(_) => {
                            iterator_stopper.fetch_or(true, Ordering::SeqCst);
                            break;
                        }
                    },
                }
            }
        };

        let mut worker_handles = Vec::with_capacity(cpus + 1);
        for _ in 0..cpus {
            worker_handles.push(thread::spawn(worker_task.clone()));
        }
        worker_handles.push(work_dispatcher);

        ParIterAsync {
            output_receiver,
            worker_thread: Some(worker_handles),
            iterator_stopper: stopper_clone,
            is_killed: false,
            worker_count: cpus,
        }
    }
}

impl<R> ParIterAsync<R> {
    /// Raise the stop flag and drain up to one pending result per worker
    /// so that workers blocked on a full output channel can wake up,
    /// observe the flag, and exit.
    pub fn kill(&mut self) {
        if !self.is_killed {
            self.iterator_stopper.fetch_or(true, Ordering::SeqCst);
            for _ in 0..self.worker_count {
                let _ = self.output_receiver.try_recv();
            }
            self.is_killed = true;
        }
    }
}

/// Block until a task is received; `None` once the dispatcher has finished
/// and the task channel is empty.
#[inline(always)]
fn get_task<T>(tasks: &channel::Receiver<T>) -> Option<T>
where
    T: Send,
{
    tasks.recv().ok()
}

impl<R> Iterator for ParIterAsync<R> {
    type Item = R;

    /// Block until the next result arrives. Returns `None` after `kill()`
    /// or once all workers have finished and the output channel is
    /// disconnected.
    fn next(&mut self) -> Option<Self::Item> {
        if self.is_killed {
            return None;
        }
        match self.output_receiver.recv() {
            Ok(block) => Some(block),
            Err(_) => {
                // All senders are gone: the workers have exited.
                self.kill();
                None
            }
        }
    }
}

impl<R> ParIterAsync<R> {
    /// Join the dispatcher and all worker threads.
    fn join(&mut self) {
        for handle in self.worker_thread.take().unwrap() {
            handle.join().unwrap()
        }
    }
}

impl<R> Drop for ParIterAsync<R> {
    /// Stop the workers and join every thread when the iterator goes out
    /// of scope, e.g. after a `break` in a `for` loop.
    fn drop(&mut self) {
        self.kill();
        self.join();
    }
}

#[cfg(test)]
mod test_par_iter_async {
    #[cfg(feature = "bench")]
    extern crate test;
    use crate::IntoParallelIteratorAsync;
    use std::collections::HashSet;
    #[cfg(feature = "bench")]
    use test::Bencher;

    // A task that returns `Err(())` is dropped and its value never
    // appears in the collected output.
    #[test]
    fn par_iter_test_exception() {
        for _ in 0..100 {
            let resource_captured = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3];

            let results: HashSet<i32> = (0..resource_captured.len())
                .into_par_iter_async(move |a| {
                    let n = resource_captured.get(a).unwrap().to_owned();
                    if n == 5 {
                        Err(())
                    } else {
                        Ok(n)
                    }
                })
                .collect();

            assert!(!results.contains(&5))
        }
    }

    // A value rejected with `Err(())` in a middle stage of a chained
    // pipeline never reaches the final output.
    #[test]
    fn par_iter_chained_exception() {
        for _ in 0..100 {
            let resource_captured: Vec<i32> = (0..10000).collect();
            let resource_captured_1 = resource_captured.clone();
            let resource_captured_2 = resource_captured.clone();

            let results: HashSet<i32> = (0..resource_captured.len())
                .into_par_iter_async(move |a| Ok(resource_captured.get(a).unwrap().to_owned()))
                .into_par_iter_async(move |a| {
                    let n = resource_captured_1.get(a as usize).unwrap().to_owned();
                    if n == 1000 {
                        Err(())
                    } else {
                        Ok(n)
                    }
                })
                .into_par_iter_async(move |a| {
                    Ok(resource_captured_2.get(a as usize).unwrap().to_owned())
                })
                .collect();

            assert!(!results.contains(&1000))
        }
    }

    // Breaking out of the `for` loop drops the iterator early; `Drop`
    // must stop and join all threads without hanging.
    #[test]
    fn test_break() {
        for _ in 0..10000 {
            for i in (0..2000).into_par_iter_async(|a| Ok(a)) {
                if i == 1000 {
                    break;
                }
            }
        }
    }
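
    // A sketch of plain, error-free usage (not from the original suite):
    // results arrive unordered, so compare against the expected values as sets.
    #[test]
    fn par_iter_unordered_collect_sketch() {
        let results: HashSet<i32> = (0..1000).into_par_iter_async(|a| Ok(a * 2)).collect();
        let expected: HashSet<i32> = (0..1000).map(|a| a * 2).collect();
        assert_eq!(results, expected);
    }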

    #[cfg(feature = "bench")]
    #[bench]
    fn bench_into_par_iter_async(b: &mut Bencher) {
        b.iter(|| {
            (0..1_000_000)
                .into_par_iter_async(|a| Ok(a))
                .for_each(|_| {})
        });
    }
}