par_stream/
try_par_stream.rs

1use crate::{
2    common::*,
3    config::{BufSize, ParParams},
4    par_stream::ParStreamExt as _,
5    rt,
6    stream::StreamExt as _,
7    try_index_stream::TryIndexStreamExt as _,
8    try_stream::{TakeUntilError, TryStreamExt as _},
9    utils,
10};
11use flume::r#async::RecvStream;
12use tokio::sync::broadcast;
13
14/// Stream for the [try_par_batching()](TryParStreamExt::try_par_batching) method.
15pub type TryParBatching<T, E> = TakeUntilError<RecvStream<'static, Result<T, E>>, T, E>;
16
17/// The trait extends [TryStream](futures::stream::TryStream) types with parallel processing combinators.
18pub trait TryParStreamExt
19where
20    Self: 'static + Send + TryStream,
21    Self::Ok: 'static + Send,
22    Self::Error: 'static + Send,
23{
24    /// Fallible stream combinator for [map_blocking](crate::ParStreamExt::map_blocking).
25    fn try_map_blocking<B, T, F>(
26        self,
27        buf_size: B,
28        f: F,
29    ) -> RecvStream<'static, Result<T, Self::Error>>
30    where
31        B: Into<BufSize>,
32        T: Send,
33        F: 'static + Send + FnMut(Self::Ok) -> Result<T, Self::Error>;
34
35    /// Fallible stream combinator for [par_batching](crate::ParStreamExt::par_batching).
36    fn try_par_batching<U, P, F, Fut>(self, params: P, f: F) -> TryParBatching<U, Self::Error>
37    where
38        Self: Sized,
39        P: Into<ParParams>,
40        F: 'static
41            + Clone
42            + Send
43            + FnMut(usize, flume::Receiver<Result<Self::Ok, Self::Error>>) -> Fut,
44        Fut: 'static
45            + Future<
46                Output = Result<
47                    Option<(U, flume::Receiver<Result<Self::Ok, Self::Error>>)>,
48                    Self::Error,
49                >,
50            >
51            + Send,
52        U: 'static + Send;
53
54    /// Fallible stream combinator for [par_then](crate::ParStreamExt::par_then).
55    fn try_par_then<U, P, F, Fut>(
56        self,
57        params: P,
58        f: F,
59    ) -> BoxStream<'static, Result<U, Self::Error>>
60    where
61        P: Into<ParParams>,
62        U: 'static + Send,
63        F: 'static + FnMut(Self::Ok) -> Fut + Send,
64        Fut: 'static + Future<Output = Result<U, Self::Error>> + Send;
65
66    /// Fallible stream combinator for [par_then_unordered](crate::ParStreamExt::par_then_unordered).
67    fn try_par_then_unordered<U, P, F, Fut>(
68        self,
69        params: P,
70        f: F,
71    ) -> BoxStream<'static, Result<U, Self::Error>>
72    where
73        U: 'static + Send,
74        F: 'static + FnMut(Self::Ok) -> Fut + Send,
75        Fut: 'static + Future<Output = Result<U, Self::Error>> + Send,
76        P: Into<ParParams>;
77
78    /// Fallible stream combinator for [par_map](crate::ParStreamExt::par_map).
79    fn try_par_map<U, P, F, Func>(
80        self,
81        params: P,
82        f: F,
83    ) -> BoxStream<'static, Result<U, Self::Error>>
84    where
85        P: Into<ParParams>,
86        U: 'static + Send,
87        F: 'static + FnMut(Self::Ok) -> Func + Send,
88        Func: 'static + FnOnce() -> Result<U, Self::Error> + Send;
89
90    /// Fallible stream combinator for [par_map_unordered](crate::ParStreamExt::par_map_unordered).
91    fn try_par_map_unordered<U, P, F, Func>(
92        self,
93        params: P,
94        f: F,
95    ) -> BoxStream<'static, Result<U, Self::Error>>
96    where
97        P: Into<ParParams>,
98        U: 'static + Send,
99        F: 'static + FnMut(Self::Ok) -> Func + Send,
100        Func: 'static + FnOnce() -> Result<U, Self::Error> + Send;
101
102    /// Fallible stream combinator for [par_for_each](crate::par_stream::ParStreamExt::par_for_each).
103    fn try_par_for_each<P, F, Fut>(
104        self,
105        params: P,
106        f: F,
107    ) -> BoxFuture<'static, Result<(), Self::Error>>
108    where
109        P: Into<ParParams>,
110        F: 'static + FnMut(Self::Ok) -> Fut + Send,
111        Fut: 'static + Future<Output = Result<(), Self::Error>> + Send;
112
113    /// Fallible stream combinator for [par_for_each_blocking](crate::par_stream::ParStreamExt::par_for_each_blocking).
114    fn try_par_for_each_blocking<P, F, Func>(
115        self,
116        params: P,
117        f: F,
118    ) -> BoxFuture<'static, Result<(), Self::Error>>
119    where
120        P: Into<ParParams>,
121        F: 'static + FnMut(Self::Ok) -> Func + Send,
122        Func: 'static + FnOnce() -> Result<(), Self::Error> + Send;
123}
124
125impl<S, T, E> TryParStreamExt for S
126where
127    Self: 'static + Send + Stream<Item = Result<T, E>>,
128    T: 'static + Send,
129    E: 'static + Send,
130{
131    fn try_map_blocking<B, U, F>(self, buf_size: B, mut f: F) -> RecvStream<'static, Result<U, E>>
132    where
133        B: Into<BufSize>,
134        U: Send,
135        F: 'static + Send + FnMut(T) -> Result<U, E>,
136    {
137        let buf_size = buf_size.into().get();
138        let mut stream = self.boxed();
139        let (output_tx, output_rx) = utils::channel(buf_size);
140
141        rt::spawn_blocking(move || loop {
142            match rt::block_on(stream.next()) {
143                Some(Ok(input)) => {
144                    let result = f(input);
145                    let is_err = result.is_err();
146
147                    if output_tx.send(result).is_err() {
148                        break;
149                    }
150
151                    if is_err {
152                        break;
153                    }
154                }
155                Some(Err(err)) => {
156                    let _ = output_tx.send(Err(err));
157                    break;
158                }
159                None => break,
160            }
161        });
162
163        output_rx.into_stream()
164    }
165
166    fn try_par_batching<U, P, F, Fut>(self, params: P, f: F) -> TryParBatching<U, E>
167    where
168        P: Into<ParParams>,
169        U: 'static + Send,
170        F: 'static
171            + Clone
172            + Send
173            + FnMut(usize, flume::Receiver<Result<Self::Ok, Self::Error>>) -> Fut,
174        Fut: 'static
175            + Future<
176                Output = Result<
177                    Option<(U, flume::Receiver<Result<Self::Ok, Self::Error>>)>,
178                    Self::Error,
179                >,
180            >
181            + Send,
182    {
183        let ParParams {
184            num_workers,
185            buf_size,
186        } = params.into();
187
188        let (input_tx, input_rx) = utils::channel(buf_size);
189        let (output_tx, output_rx) = utils::channel(buf_size);
190        let (terminate_tx, _) = broadcast::channel(1);
191
192        rt::spawn(async move {
193            let _ = self.map(Ok).forward(input_tx.into_sink()).await;
194        });
195
196        (0..num_workers).for_each(move |worker_index| {
197            let input_rx = input_rx.clone();
198            let output_tx = output_tx.clone();
199            let mut terminate_rx = terminate_tx.subscribe();
200            let terminate_tx = terminate_tx.clone();
201            let f = f.clone();
202
203            rt::spawn(async move {
204                let _ = stream::repeat(())
205                    .take_until(async move {
206                        let _ = terminate_rx.recv().await;
207                    })
208                    .stateful_then(
209                        Some((f, terminate_tx, input_rx)),
210                        move |state, ()| async move {
211                            let (mut f, terminate_tx, input_rx) = state.unwrap();
212                            let result = f(worker_index, input_rx).await;
213
214                            if result.is_err() {
215                                let _ = terminate_tx.send(());
216                            }
217
218                            match result {
219                                Ok(Some((item, input_rx))) => {
220                                    Some((Some((f, terminate_tx, input_rx)), Ok(item)))
221                                }
222                                Ok(None) => None,
223                                Err(err) => Some((None, Err(err))),
224                            }
225                        },
226                    )
227                    .take_until_error()
228                    .map(Ok)
229                    .forward(output_tx.into_sink())
230                    .await;
231            });
232        });
233
234        output_rx.into_stream().take_until_error()
235    }
236
237    fn try_par_then<U, P, F, Fut>(self, params: P, mut f: F) -> BoxStream<'static, Result<U, E>>
238    where
239        P: Into<ParParams>,
240        U: 'static + Send,
241        F: 'static + FnMut(T) -> Fut + Send,
242        Fut: 'static + Future<Output = Result<U, E>> + Send,
243    {
244        self.take_until_error()
245            .enumerate()
246            .par_then_unordered(params, move |(index, input)| {
247                let fut = input.map(|input| f(input));
248
249                async move {
250                    let output = fut?.await?;
251                    Ok((index, output))
252                }
253            })
254            .try_reorder_enumerated()
255            .boxed()
256    }
257
258    fn try_par_then_unordered<U, P, F, Fut>(
259        self,
260        params: P,
261        f: F,
262    ) -> BoxStream<'static, Result<U, E>>
263    where
264        U: 'static + Send,
265        F: 'static + FnMut(T) -> Fut + Send,
266        Fut: 'static + Future<Output = Result<U, E>> + Send,
267        P: Into<ParParams>,
268    {
269        let (input_error, input_stream) = self.catch_error();
270        let output_stream = input_stream.par_then_unordered(params, f);
271
272        stream::select(
273            input_error
274                .map(|result| result.map(|()| None))
275                .into_stream(),
276            output_stream.map(|result| result.map(Some)),
277        )
278        .try_filter_map(|item| future::ok(item))
279        .take_until_error()
280        .boxed()
281    }
282
283    fn try_par_map<U, P, F, Func>(self, params: P, mut f: F) -> BoxStream<'static, Result<U, E>>
284    where
285        P: Into<ParParams>,
286        U: 'static + Send,
287        F: 'static + FnMut(T) -> Func + Send,
288        Func: 'static + FnOnce() -> Result<U, E> + Send,
289    {
290        self.take_until_error()
291            .enumerate()
292            .par_map_unordered(params, move |(index, input)| {
293                let func = input.map(|input| f(input));
294
295                move || {
296                    let output = (func?)()?;
297                    Ok((index, output))
298                }
299            })
300            .try_reorder_enumerated()
301            .boxed()
302    }
303
304    fn try_par_map_unordered<U, P, F, Func>(
305        self,
306        params: P,
307        f: F,
308    ) -> BoxStream<'static, Result<U, E>>
309    where
310        P: Into<ParParams>,
311        U: 'static + Send,
312        F: 'static + FnMut(T) -> Func + Send,
313        Func: 'static + FnOnce() -> Result<U, E> + Send,
314    {
315        let (input_error, input_stream) = self.catch_error();
316        let output_stream = input_stream.par_map_unordered(params, f);
317
318        stream::select(
319            input_error
320                .map(|result| result.map(|()| None))
321                .into_stream(),
322            output_stream.map(|result| result.map(Some)),
323        )
324        .try_filter_map(|item| future::ok(item))
325        .take_until_error()
326        .boxed()
327    }
328
329    fn try_par_for_each<P, F, Fut>(self, params: P, f: F) -> BoxFuture<'static, Result<(), E>>
330    where
331        P: Into<ParParams>,
332        F: 'static + FnMut(T) -> Fut + Send,
333        Fut: 'static + Future<Output = Result<(), E>> + Send,
334    {
335        let ParParams {
336            num_workers,
337            buf_size,
338        } = params.into();
339        let (terminate_tx, mut terminate_rx) = broadcast::channel(1);
340        let input_stream = self
341            .take_until_error()
342            .take_until(async move {
343                let _ = terminate_rx.recv().await;
344            })
345            .stateful_map(f, |mut f, item| {
346                let fut = item.map(|item| f(item));
347                Some((f, fut))
348            })
349            .spawned(buf_size);
350
351        let worker_futures = (0..num_workers).map(move |_| {
352            let terminate_tx = terminate_tx.clone();
353
354            rt::spawn(
355                input_stream
356                    .clone()
357                    .stateful_then(terminate_tx, |terminate_tx, fut| async move {
358                        let result = async move {
359                            fut?.await?;
360                            Ok(())
361                        }
362                        .await;
363
364                        if result.is_err() {
365                            let _ = terminate_tx.send(());
366                        }
367
368                        Some((terminate_tx, result))
369                    })
370                    .try_for_each(|()| future::ok(())),
371            )
372        });
373
374        future::try_join_all(worker_futures)
375            .map(|result| result.map(|_| ()))
376            .boxed()
377    }
378
379    fn try_par_for_each_blocking<P, F, Func>(
380        self,
381        params: P,
382        f: F,
383    ) -> BoxFuture<'static, Result<(), E>>
384    where
385        P: Into<ParParams>,
386        F: 'static + FnMut(T) -> Func + Send,
387        Func: 'static + FnOnce() -> Result<(), E> + Send,
388    {
389        let ParParams {
390            num_workers,
391            buf_size,
392        } = params.into();
393        let (terminate_tx, mut terminate_rx) = broadcast::channel(1);
394        let stream = self
395            .take_until_error()
396            .take_until(async move {
397                let _ = terminate_rx.recv().await;
398            })
399            .stateful_map(f, |mut f, item| {
400                let fut = item.map(|item| f(item));
401                Some((f, fut))
402            })
403            .spawned(buf_size);
404
405        let worker_futures = (0..num_workers).map(|_| {
406            let mut stream = stream.clone();
407            let terminate_tx = terminate_tx.clone();
408
409            rt::spawn_blocking(move || {
410                while let Some(func) = rt::block_on(stream.next()) {
411                    let result = (move || {
412                        (func?)()?;
413                        Ok(())
414                    })();
415                    if let Err(err) = result {
416                        let _result = terminate_tx.send(()); // shutdown workers
417                        return Err(err); // return error
418                    }
419                }
420
421                Ok(())
422            })
423        });
424
425        future::try_join_all(worker_futures)
426            .map(|result| result.map(|_| ()))
427            .boxed()
428    }
429}
430
// Unit tests for the fallible parallel stream combinators.

#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::async_test;
    use rand::prelude::*;

    async_test! {
        async fn try_par_batching_test() {
            // The batching function fails immediately: the error is yielded
            // once and then the stream ends.
            {
                let mut stream = stream::iter(iter::repeat(1).take(10))
                    .map(Ok)
                    .try_par_batching(None, |_, _| async move {
                        Result::<Option<((), _)>, _>::Err("init error")
                    });

                assert_eq!(stream.next().await, Some(Err("init error")));
                assert!(stream.next().await.is_none());
            }

            // Batches hold at most 3 items, and every input item ends up in
            // some batch.
            {
                let mut stream = stream::repeat(1)
                    .take(10)
                    .map(Result::<_, ()>::Ok)
                    .try_par_batching(None, |_, rx| async move {
                        let mut sum = 0;

                        while let Ok(val) = rx.recv_async().await {
                            sum += val?;
                            if sum >= 3 {
                                return Ok(Some((sum, rx)));
                            }
                        }

                        if sum > 0 {
                            return Ok(Some((sum, rx)));
                        }

                        Ok(None)
                    });

                let mut total = 0;
                while total < 10 {
                    let sum = stream.next().await.unwrap().unwrap();
                    assert!(sum <= 3);
                    total += sum;
                }
                assert!(stream.next().await.is_none());
            }

            // Full batches succeed. A partial batch left at the end is reported
            // as an error, after which the stream ends.
            {
                let mut stream = stream::repeat(1).take(10).map(Ok).try_par_batching(
                    None,
                    |_, rx| async move {
                        let mut sum = 0;

                        while let Ok(val) = rx.recv_async().await {
                            sum += val?;
                            if sum >= 3 {
                                return Ok(Some((sum, rx)));
                            }
                        }

                        if sum == 0 {
                            Ok(None)
                        } else {
                            Err(sum)
                        }
                    },
                );

                let mut total = 0;
                while total < 10 {
                    let result = stream.next().await.unwrap();
                    match result {
                        Ok(sum) => {
                            assert_eq!(sum, 3);
                            total += sum;
                        }
                        Err(sum) => {
                            assert!(sum < 3);
                            break;
                        }
                    }
                }
                assert!(stream.next().await.is_none());
            }
        }

        async fn try_par_for_each_test() {
            // All items succeed.
            {
                let result = stream::iter(vec![Ok(1usize), Ok(2), Ok(6), Ok(4)])
                    .try_par_for_each(None, |_| async move { Result::<_, ()>::Ok(()) })
                    .await;

                assert_eq!(result, Ok(()));
            }

            // An input error is reported as the overall result.
            {
                let result = stream::iter(vec![Ok(1usize), Ok(2), Err(-3isize), Ok(4)])
                    .try_par_for_each(None, |_| async move { Ok(()) })
                    .await;

                assert_eq!(result, Err(-3));
            }
        }

        async fn try_par_for_each_blocking_test() {
            // All items succeed.
            {
                let result = stream::iter(vec![Ok(1usize), Ok(2), Ok(6), Ok(4)])
                    .try_par_for_each_blocking(None, |_| || Result::<_, ()>::Ok(()))
                    .await;

                assert_eq!(result, Ok(()));
            }

            // An input error stops an otherwise infinite stream.
            {
                let result = stream::iter(0..)
                    .then(|val| async move {
                        if val == 3 {
                            Err(val)
                        } else {
                            Ok(val)
                        }
                    })
                    .try_par_for_each_blocking(8, |_| || Ok(()))
                    .await;

                assert_eq!(result, Err(3));
            }

            // An error returned by a (slow) worker closure stops an otherwise
            // infinite stream.
            {
                let result = stream::iter(0..)
                    .map(Ok)
                    .try_par_for_each_blocking(None, |val| {
                        move || {
                            if val == 3 {
                                std::thread::sleep(Duration::from_millis(100));
                                Err(val)
                            } else {
                                Ok(())
                            }
                        }
                    })
                    .await;

                assert_eq!(result, Err(3));
            }
        }

        async fn try_par_then_test() {
            // An input error ends the output.
            {
                let vec: Vec<Result<_, _>> =
                    stream::iter(vec![Ok(1usize), Ok(2), Err(-3isize), Ok(4)])
                    .try_par_then(None, future::ok)
                    .collect()
                    .await;

                assert!(matches!(
                    *vec,
                    [Err(-3)] | [Ok(1), Err(-3)] | [Ok(2), Err(-3)] | [Ok(1), Ok(2), Err(-3)],
                ));
            }

            // An empty input yields an empty output.
            {
                let vec: Result<Vec<()>, ()> = stream::iter(vec![])
                    .try_par_then(None, |()| async move { Ok(()) })
                    .try_collect()
                    .await;

                assert!(matches!(vec, Ok(vec) if vec.is_empty()));
            }

            // The first error returned by `f` ends an otherwise infinite stream.
            {
                let vec: Vec<Result<_, _>> = stream::iter(1..)
                    .map(Ok)
                    .try_par_then(3, |index| async move {
                        match index {
                            3 | 6 => Err(index),
                            index => Ok(index),
                        }
                    })
                    .collect()
                    .await;

                assert!(matches!(
                    *vec,
                    [Err(3)] | [Ok(1), Err(3)] | [Ok(2), Err(3)] | [Ok(1), Ok(2), Err(3)],
                ));
            }
        }

        async fn try_reorder_enumerated_test() {
            let len: usize = 1000;
            let mut rng = rand::thread_rng();

            for _ in 0..10 {
                let err_index_1 = rng.gen_range(0..len);
                let err_index_2 = rng.gen_range(0..len);
                let min_err_index = err_index_1.min(err_index_2);

                let results: Vec<_> = stream::iter(0..len)
                    .map(move |value| {
                        if value == err_index_1 || value == err_index_2 {
                            Err(-(value as isize))
                        } else {
                            Ok(value)
                        }
                    })
                    .try_enumerate()
                    .try_par_then_unordered(None, |(index, value)| async move {
                        rt::sleep(Duration::from_millis(value as u64 % 10)).await;
                        Ok((index, value))
                    })
                    .try_reorder_enumerated()
                    .collect()
                    .await;
                assert!(results.len() <= min_err_index + 1);

                // The output must be the in-order prefix before the earliest
                // error, optionally followed by exactly that error.
                let (is_fused_at_error, _, _) = results.iter().cloned().fold(
                    (true, false, 0),
                    |(is_correct, found_err, expect), result| {
                        if !is_correct {
                            return (false, found_err, expect);
                        }

                        match result {
                            Ok(value) => {
                                let is_correct = value < min_err_index && value == expect && !found_err;
                                (is_correct, found_err, expect + 1)
                            }
                            Err(value) => {
                                let is_correct = (-value) as usize == min_err_index && !found_err;
                                let found_err = true;
                                (is_correct, found_err, expect + 1)
                            }
                        }
                    },
                );
                assert!(is_fused_at_error);
            }
        }

        async fn try_map_blocking_test() {
            // An input error is forwarded and ends the output.
            {
                let vec: Vec<_> = stream::iter(vec![Ok(1u64), Ok(2), Err(-3i64), Ok(4)])
                    .try_map_blocking(None, |val| Ok(val.pow(10)))
                    .collect()
                    .await;

                assert_eq!(vec, [Ok(1), Ok(1024), Err(-3)]);
            }

            // An error returned by `f` ends the output before the input error.
            {
                let vec: Vec<_> = stream::iter(vec![Ok(1i64), Ok(2), Err(-3i64), Ok(4)])
                    .try_map_blocking(None, |val| if val >= 2 { Err(-val) } else { Ok(val) })
                    .collect()
                    .await;

                assert_eq!(vec, [Ok(1), Err(-2)]);
            }
        }
    }
}