1use crate::{
2 common::*,
3 config::{BufSize, ParParams},
4 par_stream::ParStreamExt as _,
5 rt,
6 stream::StreamExt as _,
7 try_index_stream::TryIndexStreamExt as _,
8 try_stream::{TakeUntilError, TryStreamExt as _},
9 utils,
10};
11use flume::r#async::RecvStream;
12use tokio::sync::broadcast;
13
/// Output stream of [`TryParStreamExt::try_par_batching`]: the workers' shared
/// `flume` output channel, wrapped in `TakeUntilError` so that the stream ends
/// right after the first `Err` item it yields.
pub type TryParBatching<T, E> = TakeUntilError<RecvStream<'static, Result<T, E>>, T, E>;
16
/// Parallel combinators for fallible streams (streams of `Result<Ok, Error>`).
///
/// Each combinator propagates the first error it observes — coming either
/// from the input stream or from the user-supplied closure — and stops
/// producing further output after it.
pub trait TryParStreamExt
where
    Self: 'static + Send + TryStream,
    Self::Ok: 'static + Send,
    Self::Error: 'static + Send,
{
    /// Maps every `Ok` item with `f` on a single blocking thread, in input order.
    ///
    /// `buf_size` sizes the output channel. The output ends right after the
    /// first `Err`, whether it came from the input stream or was returned by
    /// `f`; no further input is read after that.
    fn try_map_blocking<B, T, F>(
        self,
        buf_size: B,
        f: F,
    ) -> RecvStream<'static, Result<T, Self::Error>>
    where
        B: Into<BufSize>,
        T: Send,
        F: 'static + Send + FnMut(Self::Ok) -> Result<T, Self::Error>;

    /// Groups input items into batches using `num_workers` concurrent workers.
    ///
    /// Each worker repeatedly calls `f(worker_index, rx)`, where `rx` receives
    /// the raw input items (including any input `Err`, which `f` may propagate
    /// with `?`). `f` returns:
    /// - `Ok(Some((batch, rx)))` to emit `batch` and be called again with `rx`;
    /// - `Ok(None)` to stop that worker;
    /// - `Err(e)` to emit `e` and signal every worker to stop.
    ///
    /// The returned stream ends after the first `Err` it yields.
    fn try_par_batching<U, P, F, Fut>(self, params: P, f: F) -> TryParBatching<U, Self::Error>
    where
        Self: Sized,
        P: Into<ParParams>,
        F: 'static
            + Clone
            + Send
            + FnMut(usize, flume::Receiver<Result<Self::Ok, Self::Error>>) -> Fut,
        Fut: 'static
            + Future<
                Output = Result<
                    Option<(U, flume::Receiver<Result<Self::Ok, Self::Error>>)>,
                    Self::Error,
                >,
            >
            + Send,
        U: 'static + Send;

    /// Runs the async `f` on each `Ok` item concurrently and yields the
    /// results in input order.
    ///
    /// Input is consumed only up to (and including) its first `Err`. When an
    /// error occurs, the stream yields it and ends; results that precede the
    /// error may or may not have been yielded by then.
    fn try_par_then<U, P, F, Fut>(
        self,
        params: P,
        f: F,
    ) -> BoxStream<'static, Result<U, Self::Error>>
    where
        P: Into<ParParams>,
        U: 'static + Send,
        F: 'static + FnMut(Self::Ok) -> Fut + Send,
        Fut: 'static + Future<Output = Result<U, Self::Error>> + Send;

    /// Like [`try_par_then`](Self::try_par_then), but yields results in
    /// completion order.
    ///
    /// An input error is merged into the output stream, and the stream ends
    /// after the first `Err` it yields.
    fn try_par_then_unordered<U, P, F, Fut>(
        self,
        params: P,
        f: F,
    ) -> BoxStream<'static, Result<U, Self::Error>>
    where
        U: 'static + Send,
        F: 'static + FnMut(Self::Ok) -> Fut + Send,
        Fut: 'static + Future<Output = Result<U, Self::Error>> + Send,
        P: Into<ParParams>;

    /// Calls `f` on each `Ok` item to obtain a closure, runs the closures in
    /// parallel, and yields their results in input order.
    ///
    /// Error handling matches [`try_par_then`](Self::try_par_then).
    fn try_par_map<U, P, F, Func>(
        self,
        params: P,
        f: F,
    ) -> BoxStream<'static, Result<U, Self::Error>>
    where
        P: Into<ParParams>,
        U: 'static + Send,
        F: 'static + FnMut(Self::Ok) -> Func + Send,
        Func: 'static + FnOnce() -> Result<U, Self::Error> + Send;

    /// Like [`try_par_map`](Self::try_par_map), but yields results in
    /// completion order.
    ///
    /// The stream ends after the first `Err` it yields.
    fn try_par_map_unordered<U, P, F, Func>(
        self,
        params: P,
        f: F,
    ) -> BoxStream<'static, Result<U, Self::Error>>
    where
        P: Into<ParParams>,
        U: 'static + Send,
        F: 'static + FnMut(Self::Ok) -> Func + Send,
        Func: 'static + FnOnce() -> Result<U, Self::Error> + Send;

    /// Runs the async `f` on every `Ok` item using `num_workers` spawned workers.
    ///
    /// Resolves to `Ok(())` when all input has been processed. Otherwise it
    /// resolves to an error from the input or from `f`. The first failing
    /// worker signals the others to stop taking new input.
    fn try_par_for_each<P, F, Fut>(
        self,
        params: P,
        f: F,
    ) -> BoxFuture<'static, Result<(), Self::Error>>
    where
        P: Into<ParParams>,
        F: 'static + FnMut(Self::Ok) -> Fut + Send,
        Fut: 'static + Future<Output = Result<(), Self::Error>> + Send;

    /// Like [`try_par_for_each`](Self::try_par_for_each), but each of the
    /// `num_workers` workers runs on a blocking thread and executes the
    /// closures returned by `f` synchronously.
    fn try_par_for_each_blocking<P, F, Func>(
        self,
        params: P,
        f: F,
    ) -> BoxFuture<'static, Result<(), Self::Error>>
    where
        P: Into<ParParams>,
        F: 'static + FnMut(Self::Ok) -> Func + Send,
        Func: 'static + FnOnce() -> Result<(), Self::Error> + Send;
}
124
125impl<S, T, E> TryParStreamExt for S
126where
127 Self: 'static + Send + Stream<Item = Result<T, E>>,
128 T: 'static + Send,
129 E: 'static + Send,
130{
131 fn try_map_blocking<B, U, F>(self, buf_size: B, mut f: F) -> RecvStream<'static, Result<U, E>>
132 where
133 B: Into<BufSize>,
134 U: Send,
135 F: 'static + Send + FnMut(T) -> Result<U, E>,
136 {
137 let buf_size = buf_size.into().get();
138 let mut stream = self.boxed();
139 let (output_tx, output_rx) = utils::channel(buf_size);
140
141 rt::spawn_blocking(move || loop {
142 match rt::block_on(stream.next()) {
143 Some(Ok(input)) => {
144 let result = f(input);
145 let is_err = result.is_err();
146
147 if output_tx.send(result).is_err() {
148 break;
149 }
150
151 if is_err {
152 break;
153 }
154 }
155 Some(Err(err)) => {
156 let _ = output_tx.send(Err(err));
157 break;
158 }
159 None => break,
160 }
161 });
162
163 output_rx.into_stream()
164 }
165
166 fn try_par_batching<U, P, F, Fut>(self, params: P, f: F) -> TryParBatching<U, E>
167 where
168 P: Into<ParParams>,
169 U: 'static + Send,
170 F: 'static
171 + Clone
172 + Send
173 + FnMut(usize, flume::Receiver<Result<Self::Ok, Self::Error>>) -> Fut,
174 Fut: 'static
175 + Future<
176 Output = Result<
177 Option<(U, flume::Receiver<Result<Self::Ok, Self::Error>>)>,
178 Self::Error,
179 >,
180 >
181 + Send,
182 {
183 let ParParams {
184 num_workers,
185 buf_size,
186 } = params.into();
187
188 let (input_tx, input_rx) = utils::channel(buf_size);
189 let (output_tx, output_rx) = utils::channel(buf_size);
190 let (terminate_tx, _) = broadcast::channel(1);
191
192 rt::spawn(async move {
193 let _ = self.map(Ok).forward(input_tx.into_sink()).await;
194 });
195
196 (0..num_workers).for_each(move |worker_index| {
197 let input_rx = input_rx.clone();
198 let output_tx = output_tx.clone();
199 let mut terminate_rx = terminate_tx.subscribe();
200 let terminate_tx = terminate_tx.clone();
201 let f = f.clone();
202
203 rt::spawn(async move {
204 let _ = stream::repeat(())
205 .take_until(async move {
206 let _ = terminate_rx.recv().await;
207 })
208 .stateful_then(
209 Some((f, terminate_tx, input_rx)),
210 move |state, ()| async move {
211 let (mut f, terminate_tx, input_rx) = state.unwrap();
212 let result = f(worker_index, input_rx).await;
213
214 if result.is_err() {
215 let _ = terminate_tx.send(());
216 }
217
218 match result {
219 Ok(Some((item, input_rx))) => {
220 Some((Some((f, terminate_tx, input_rx)), Ok(item)))
221 }
222 Ok(None) => None,
223 Err(err) => Some((None, Err(err))),
224 }
225 },
226 )
227 .take_until_error()
228 .map(Ok)
229 .forward(output_tx.into_sink())
230 .await;
231 });
232 });
233
234 output_rx.into_stream().take_until_error()
235 }
236
237 fn try_par_then<U, P, F, Fut>(self, params: P, mut f: F) -> BoxStream<'static, Result<U, E>>
238 where
239 P: Into<ParParams>,
240 U: 'static + Send,
241 F: 'static + FnMut(T) -> Fut + Send,
242 Fut: 'static + Future<Output = Result<U, E>> + Send,
243 {
244 self.take_until_error()
245 .enumerate()
246 .par_then_unordered(params, move |(index, input)| {
247 let fut = input.map(|input| f(input));
248
249 async move {
250 let output = fut?.await?;
251 Ok((index, output))
252 }
253 })
254 .try_reorder_enumerated()
255 .boxed()
256 }
257
258 fn try_par_then_unordered<U, P, F, Fut>(
259 self,
260 params: P,
261 f: F,
262 ) -> BoxStream<'static, Result<U, E>>
263 where
264 U: 'static + Send,
265 F: 'static + FnMut(T) -> Fut + Send,
266 Fut: 'static + Future<Output = Result<U, E>> + Send,
267 P: Into<ParParams>,
268 {
269 let (input_error, input_stream) = self.catch_error();
270 let output_stream = input_stream.par_then_unordered(params, f);
271
272 stream::select(
273 input_error
274 .map(|result| result.map(|()| None))
275 .into_stream(),
276 output_stream.map(|result| result.map(Some)),
277 )
278 .try_filter_map(|item| future::ok(item))
279 .take_until_error()
280 .boxed()
281 }
282
283 fn try_par_map<U, P, F, Func>(self, params: P, mut f: F) -> BoxStream<'static, Result<U, E>>
284 where
285 P: Into<ParParams>,
286 U: 'static + Send,
287 F: 'static + FnMut(T) -> Func + Send,
288 Func: 'static + FnOnce() -> Result<U, E> + Send,
289 {
290 self.take_until_error()
291 .enumerate()
292 .par_map_unordered(params, move |(index, input)| {
293 let func = input.map(|input| f(input));
294
295 move || {
296 let output = (func?)()?;
297 Ok((index, output))
298 }
299 })
300 .try_reorder_enumerated()
301 .boxed()
302 }
303
304 fn try_par_map_unordered<U, P, F, Func>(
305 self,
306 params: P,
307 f: F,
308 ) -> BoxStream<'static, Result<U, E>>
309 where
310 P: Into<ParParams>,
311 U: 'static + Send,
312 F: 'static + FnMut(T) -> Func + Send,
313 Func: 'static + FnOnce() -> Result<U, E> + Send,
314 {
315 let (input_error, input_stream) = self.catch_error();
316 let output_stream = input_stream.par_map_unordered(params, f);
317
318 stream::select(
319 input_error
320 .map(|result| result.map(|()| None))
321 .into_stream(),
322 output_stream.map(|result| result.map(Some)),
323 )
324 .try_filter_map(|item| future::ok(item))
325 .take_until_error()
326 .boxed()
327 }
328
329 fn try_par_for_each<P, F, Fut>(self, params: P, f: F) -> BoxFuture<'static, Result<(), E>>
330 where
331 P: Into<ParParams>,
332 F: 'static + FnMut(T) -> Fut + Send,
333 Fut: 'static + Future<Output = Result<(), E>> + Send,
334 {
335 let ParParams {
336 num_workers,
337 buf_size,
338 } = params.into();
339 let (terminate_tx, mut terminate_rx) = broadcast::channel(1);
340 let input_stream = self
341 .take_until_error()
342 .take_until(async move {
343 let _ = terminate_rx.recv().await;
344 })
345 .stateful_map(f, |mut f, item| {
346 let fut = item.map(|item| f(item));
347 Some((f, fut))
348 })
349 .spawned(buf_size);
350
351 let worker_futures = (0..num_workers).map(move |_| {
352 let terminate_tx = terminate_tx.clone();
353
354 rt::spawn(
355 input_stream
356 .clone()
357 .stateful_then(terminate_tx, |terminate_tx, fut| async move {
358 let result = async move {
359 fut?.await?;
360 Ok(())
361 }
362 .await;
363
364 if result.is_err() {
365 let _ = terminate_tx.send(());
366 }
367
368 Some((terminate_tx, result))
369 })
370 .try_for_each(|()| future::ok(())),
371 )
372 });
373
374 future::try_join_all(worker_futures)
375 .map(|result| result.map(|_| ()))
376 .boxed()
377 }
378
379 fn try_par_for_each_blocking<P, F, Func>(
380 self,
381 params: P,
382 f: F,
383 ) -> BoxFuture<'static, Result<(), E>>
384 where
385 P: Into<ParParams>,
386 F: 'static + FnMut(T) -> Func + Send,
387 Func: 'static + FnOnce() -> Result<(), E> + Send,
388 {
389 let ParParams {
390 num_workers,
391 buf_size,
392 } = params.into();
393 let (terminate_tx, mut terminate_rx) = broadcast::channel(1);
394 let stream = self
395 .take_until_error()
396 .take_until(async move {
397 let _ = terminate_rx.recv().await;
398 })
399 .stateful_map(f, |mut f, item| {
400 let fut = item.map(|item| f(item));
401 Some((f, fut))
402 })
403 .spawned(buf_size);
404
405 let worker_futures = (0..num_workers).map(|_| {
406 let mut stream = stream.clone();
407 let terminate_tx = terminate_tx.clone();
408
409 rt::spawn_blocking(move || {
410 while let Some(func) = rt::block_on(stream.next()) {
411 let result = (move || {
412 (func?)()?;
413 Ok(())
414 })();
415 if let Err(err) = result {
416 let _result = terminate_tx.send(()); return Err(err); }
419 }
420
421 Ok(())
422 })
423 });
424
425 future::try_join_all(worker_futures)
426 .map(|result| result.map(|_| ()))
427 .boxed()
428 }
429}
430
// Tests for the `TryParStreamExt` combinators. Each test fn is wrapped by the
// crate's `async_test!` macro; only `//` comments are used inside it so the
// macro input is unchanged.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::async_test;
    use rand::prelude::*;

    async_test! {
        async fn try_par_batching_test() {
            // Every call to `f` fails immediately: exactly one error is
            // observed, then the stream ends.
            {
                let mut stream = stream::iter(iter::repeat(1).take(10))
                    .map(Ok)
                    .try_par_batching(None, |_, _| async move {
                        Result::<Option<((), _)>, _>::Err("init error")
                    });

                assert_eq!(stream.next().await, Some(Err("init error")));
                assert!(stream.next().await.is_none());
            }

            // Batch ten ones into sums of at most 3; a worker flushes its
            // partial sum when input runs out. All batches together sum to 10.
            {
                let mut stream = stream::repeat(1)
                    .take(10)
                    .map(Result::<_, ()>::Ok)
                    .try_par_batching(None, |_, rx| async move {
                        let mut sum = 0;

                        while let Ok(val) = rx.recv_async().await {
                            sum += val?;
                            if sum >= 3 {
                                return Ok(Some((sum, rx)));
                            }
                        }

                        if sum > 0 {
                            return Ok(Some((sum, rx)));
                        }

                        Ok(None)
                    });

                let mut total = 0;
                while total < 10 {
                    let sum = stream.next().await.unwrap().unwrap();
                    assert!(sum <= 3);
                    total += sum;
                }
                assert!(stream.next().await.is_none());
            }

            // Same batching, but a leftover partial sum is reported as an
            // error: full batches are exactly 3, and the stream ends after
            // the first error.
            {
                let mut stream = stream::repeat(1).take(10).map(Ok).try_par_batching(
                    None,
                    |_, rx| async move {
                        let mut sum = 0;

                        while let Ok(val) = rx.recv_async().await {
                            sum += val?;
                            if sum >= 3 {
                                return Ok(Some((sum, rx)));
                            }
                        }

                        if sum == 0 {
                            Ok(None)
                        } else {
                            Err(sum)
                        }
                    },
                );

                let mut total = 0;
                while total < 10 {
                    let result = stream.next().await.unwrap();
                    match result {
                        Ok(sum) => {
                            assert!(sum == 3);
                            total += sum;
                        }
                        Err(sum) => {
                            assert!(sum < 3);
                            break;
                        }
                    }
                }
                assert!(stream.next().await.is_none());
            }
        }


        async fn try_par_for_each_test() {
            // All items succeed.
            {
                let result = stream::iter(vec![Ok(1usize), Ok(2), Ok(6), Ok(4)].into_iter())
                    .try_par_for_each(None, |_| async move { Result::<_, ()>::Ok(()) })
                    .await;

                assert_eq!(result, Ok(()));
            }

            // An error in the input stream is returned as the overall result.
            {
                let result = stream::iter(vec![Ok(1usize), Ok(2), Err(-3isize), Ok(4)].into_iter())
                    .try_par_for_each(None, |_| async move { Ok(()) })
                    .await;

                assert_eq!(result, Err(-3));
            }
        }


        async fn try_par_for_each_blocking_test() {
            // All items succeed.
            {
                let result = stream::iter(vec![Ok(1usize), Ok(2), Ok(6), Ok(4)])
                    .try_par_for_each_blocking(None, |_| || Result::<_, ()>::Ok(()))
                    .await;

                assert_eq!(result, Ok(()));
            }

            // An error from an infinite input stream ends processing and is
            // returned.
            {
                let result = stream::iter(0..)
                    .then(|val| async move {
                        if val == 3 {
                            Err(val)
                        } else {
                            Ok(val)
                        }
                    })
                    .try_par_for_each_blocking(8, |_| || Ok(()))
                    .await;

                assert_eq!(result, Err(3));
            }

            // A slow failing job on an infinite input must still stop all
            // workers (via the termination signal) and surface its error.
            {
                let result = stream::iter(0..)
                    .map(Ok)
                    .try_par_for_each_blocking(None, |val| {
                        move || {
                            if val == 3 {
                                std::thread::sleep(Duration::from_millis(100));
                                Err(val)
                            } else {
                                Ok(())
                            }
                        }
                    })
                    .await;

                assert_eq!(result, Err(3));
            }
        }


        async fn try_par_then_test() {
            // An input error ends the stream; results before it may or may
            // not have been yielded, but nothing after it is.
            {
                let vec: Vec<Result<_, _>> =
                    stream::iter(vec![Ok(1usize), Ok(2), Err(-3isize), Ok(4)].into_iter())
                        .try_par_then(None, |value| future::ok(value))
                        .collect()
                        .await;

                assert!(matches!(
                    *vec,
                    [Err(-3)] | [Ok(1), Err(-3)] | [Ok(2), Err(-3)] | [Ok(1), Ok(2), Err(-3)],
                ));
            }

            // Empty input yields an empty output.
            {
                let vec: Result<Vec<()>, ()> = stream::iter(vec![])
                    .try_par_then(None, |()| async move { Ok(()) })
                    .try_collect()
                    .await;

                assert!(matches!(vec, Ok(vec) if vec.is_empty()));
            }

            // An error returned by `f` on an infinite input ends the stream at
            // the first failing item (3), never reaching the second (6).
            {
                let vec: Vec<Result<_, _>> = stream::iter(1..)
                    .map(Ok)
                    .try_par_then(3, |index| async move {
                        match index {
                            3 | 6 => Err(index),
                            index => Ok(index),
                        }
                    })
                    .collect()
                    .await;

                assert!(matches!(
                    *vec,
                    [Err(3)] | [Ok(1), Err(3)] | [Ok(2), Err(3)] | [Ok(1), Ok(2), Err(3)],
                ));
            }
        }


        async fn try_reorder_enumerated_test() {
            // Inject two errors at random positions, shuffle completion order
            // with short sleeps, then check that reordering yields items in
            // index order and stops exactly at the earlier error.
            let len: usize = 1000;
            let mut rng = rand::thread_rng();

            for _ in 0..10 {
                let err_index_1 = rng.gen_range(0..len);
                let err_index_2 = rng.gen_range(0..len);
                let min_err_index = err_index_1.min(err_index_2);

                let results: Vec<_> = stream::iter(0..len)
                    .map(move |value| {
                        if value == err_index_1 || value == err_index_2 {
                            Err(-(value as isize))
                        } else {
                            Ok(value)
                        }
                    })
                    .try_enumerate()
                    .try_par_then_unordered(None, |(index, value)| async move {
                        rt::sleep(Duration::from_millis(value as u64 % 10)).await;
                        Ok((index, value))
                    })
                    .try_reorder_enumerated()
                    .collect()
                    .await;
                assert!(results.len() <= min_err_index + 1);

                // Fold state: (all checks passed so far, error already seen,
                // next expected value).
                let (is_fused_at_error, _, _) = results.iter().cloned().fold(
                    (true, false, 0),
                    |(is_correct, found_err, expect), result| {
                        if !is_correct {
                            return (false, found_err, expect);
                        }

                        match result {
                            Ok(value) => {
                                let is_correct = value < min_err_index && value == expect && !found_err;
                                (is_correct, found_err, expect + 1)
                            }
                            Err(value) => {
                                let is_correct = (-value) as usize == min_err_index && !found_err;
                                let found_err = true;
                                (is_correct, found_err, expect + 1)
                            }
                        }
                    },
                );
                assert!(is_fused_at_error);
            }
        }


        async fn try_map_blocking_test() {
            // An input error is forwarded and ends the output.
            {
                let vec: Vec<_> = stream::iter(vec![Ok(1u64), Ok(2), Err(-3i64), Ok(4)])
                    .try_map_blocking(None, |val| Ok(val.pow(10)))
                    .collect()
                    .await;

                assert_eq!(vec, [Ok(1), Ok(1024), Err(-3)]);
            }

            // An error returned by the mapping closure also ends the output,
            // before the later input error is reached.
            {
                let vec: Vec<_> = stream::iter(vec![Ok(1i64), Ok(2), Err(-3i64), Ok(4)])
                    .try_map_blocking(None, |val| if val >= 2 { Err(-val) } else { Ok(val) })
                    .collect()
                    .await;

                assert_eq!(vec, [Ok(1), Err(-2)]);
            }
        }
    }
}