par_stream/
try_stream.rs

1use crate::common::*;
2use tokio::sync::oneshot;
3
4/// The trait extends [TryStream](futures::stream::TryStream) types with combinators.
5pub trait TryStreamExt
6where
7    Self: TryStream,
8{
9    /// Create a fallible stream that gives the current iteration count.
10    ///
11    /// # Overflow Behavior
12    /// The method does no guarding against overflows, so enumerating more than `usize::MAX`
13    /// elements either produces the wrong result or panics. If debug assertions are enabled, a panic is guaranteed.
14    ///
15    /// Panics
16    /// The returned iterator might panic if the to-be-returned index would overflow a `usize`.
17    fn try_enumerate(self) -> TryEnumerate<Self, Self::Ok, Self::Error>;
18
19    /// Takes elements until an `Err(_)`.
20    fn take_until_error(self) -> TakeUntilError<Self, Self::Ok, Self::Error>;
21
22    /// Split the stream of `Result<T, E>` to a stream of `T` and a future of `Result<(), E>`.
23    ///
24    /// The method returns `(future, stream)`. If this combinator encoutners an `Err`,
25    /// `future.await` returns that error, and returned `stream` fuses. If the input
26    /// stream is depleted without error, `future.await` resolves to `Ok(())`.
27    fn catch_error(self) -> (ErrorNotify<Self::Error>, CatchError<Self>);
28
29    /// Similar to [and_then](futures::stream::TryStreamExt::and_then) but with a state.
30    fn try_stateful_then<B, U, F, Fut>(
31        self,
32        init: B,
33        f: F,
34    ) -> TryStatefulThen<Self, B, Self::Ok, U, Self::Error, F, Fut>
35    where
36        F: FnMut(B, Self::Ok) -> Fut,
37        Fut: Future<Output = Result<Option<(B, U)>, Self::Error>>;
38
39    /// Similar to [map](futures::stream::StreamExt::map) but with a state and is fallible.
40    fn try_stateful_map<B, U, F>(
41        self,
42        init: B,
43        f: F,
44    ) -> TryStatefulMap<Self, B, Self::Ok, U, Self::Error, F>
45    where
46        F: FnMut(B, Self::Ok) -> Result<Option<(B, U)>, Self::Error>;
47}
48
49impl<S, T, E> TryStreamExt for S
50where
51    S: Stream<Item = Result<T, E>>,
52{
53    fn try_enumerate(self) -> TryEnumerate<Self, T, E> {
54        TryEnumerate {
55            counter: 0,
56            fused: false,
57            _phantom: PhantomData,
58            stream: self,
59        }
60    }
61
62    fn take_until_error(self) -> TakeUntilError<Self, T, E> {
63        TakeUntilError {
64            _phantom: PhantomData,
65            is_terminated: false,
66            stream: self,
67        }
68    }
69
70    fn try_stateful_then<B, U, F, Fut>(
71        self,
72        init: B,
73        f: F,
74    ) -> TryStatefulThen<Self, B, T, U, E, F, Fut>
75    where
76        F: FnMut(B, T) -> Fut,
77        Fut: Future<Output = Result<Option<(B, U)>, E>>,
78    {
79        TryStatefulThen {
80            stream: self,
81            future: None,
82            state: Some(init),
83            f,
84            _phantom: PhantomData,
85        }
86    }
87
88    fn try_stateful_map<B, U, F>(self, init: B, f: F) -> TryStatefulMap<Self, B, T, U, E, F>
89    where
90        F: FnMut(B, T) -> Result<Option<(B, U)>, E>,
91    {
92        TryStatefulMap {
93            stream: self,
94            state: Some(init),
95            f,
96            _phantom: PhantomData,
97        }
98    }
99
100    fn catch_error(self) -> (ErrorNotify<E>, CatchError<S>) {
101        let (tx, rx) = oneshot::channel();
102        let stream = CatchError {
103            sender: Some(tx),
104            stream: self,
105        };
106        let notify = ErrorNotify { receiver: rx };
107
108        (notify, stream)
109    }
110}
111
112pub use take_until_error::*;
113mod take_until_error {
114    use super::*;
115
116    /// Stream for the [`take_until_error`](super::TryStreamExt::take_until_error) method.
117    #[pin_project]
118    pub struct TakeUntilError<St, T, E>
119    where
120        St: ?Sized,
121    {
122        pub(super) _phantom: PhantomData<(T, E)>,
123        pub(super) is_terminated: bool,
124        #[pin]
125        pub(super) stream: St,
126    }
127
128    impl<St, T, E> Stream for TakeUntilError<St, T, E>
129    where
130        St: Stream<Item = Result<T, E>>,
131    {
132        type Item = Result<T, E>;
133
134        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
135            let this = self.project();
136
137            Ready({
138                if *this.is_terminated {
139                    None
140                } else if let Some(result) = ready!(this.stream.poll_next(cx)) {
141                    if result.is_err() {
142                        *this.is_terminated = true;
143                    }
144                    Some(result)
145                } else {
146                    *this.is_terminated = true;
147                    None
148                }
149            })
150        }
151    }
152}
153
154pub use try_stateful_then::*;
155mod try_stateful_then {
156    use super::*;
157
158    /// Stream for the [`try_stateful_then`](super::TryStreamExt::try_stateful_then) method.
159    #[pin_project]
160    pub struct TryStatefulThen<St, B, T, U, E, F, Fut>
161    where
162        St: ?Sized,
163    {
164        #[pin]
165        pub(super) future: Option<Fut>,
166        pub(super) state: Option<B>,
167        pub(super) f: F,
168        pub(super) _phantom: PhantomData<(T, U, E)>,
169        #[pin]
170        pub(super) stream: St,
171    }
172
173    impl<St, B, T, U, E, F, Fut> Stream for TryStatefulThen<St, B, T, U, E, F, Fut>
174    where
175        St: Stream<Item = Result<T, E>>,
176        F: FnMut(B, T) -> Fut,
177        Fut: Future<Output = Result<Option<(B, U)>, E>>,
178    {
179        type Item = Result<U, E>;
180
181        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
182            let mut this = self.project();
183
184            Poll::Ready(loop {
185                if let Some(fut) = this.future.as_mut().as_pin_mut() {
186                    let result = ready!(fut.poll(cx));
187                    this.future.set(None);
188
189                    match result {
190                        Ok(Some((state, item))) => {
191                            *this.state = Some(state);
192                            break Some(Ok(item));
193                        }
194                        Ok(None) => {
195                            break None;
196                        }
197                        Err(err) => break Some(Err(err)),
198                    }
199                } else if let Some(state) = this.state.take() {
200                    match this.stream.as_mut().poll_next(cx) {
201                        Ready(Some(Ok(item))) => {
202                            this.future.set(Some((this.f)(state, item)));
203                        }
204                        Ready(Some(Err(err))) => break Some(Err(err)),
205                        Ready(None) => break None,
206                        Pending => {
207                            *this.state = Some(state);
208                            return Pending;
209                        }
210                    }
211                } else {
212                    break None;
213                }
214            })
215        }
216    }
217}
218
219pub use try_stateful_map::*;
220mod try_stateful_map {
221    use super::*;
222
223    /// Stream for the [`try_stateful_map`](super::TryStreamExt::try_stateful_map) method.
224    #[pin_project]
225    pub struct TryStatefulMap<St, B, T, U, E, F>
226    where
227        St: ?Sized,
228    {
229        pub(super) state: Option<B>,
230        pub(super) f: F,
231        pub(super) _phantom: PhantomData<(T, U, E)>,
232        #[pin]
233        pub(super) stream: St,
234    }
235
236    impl<St, B, T, U, E, F> Stream for TryStatefulMap<St, B, T, U, E, F>
237    where
238        St: Stream<Item = Result<T, E>>,
239        F: FnMut(B, T) -> Result<Option<(B, U)>, E>,
240    {
241        type Item = Result<U, E>;
242
243        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
244            let mut this = self.project();
245
246            Poll::Ready({
247                if let Some(state) = this.state.take() {
248                    match this.stream.as_mut().poll_next(cx) {
249                        Ready(Some(Ok(in_item))) => {
250                            let result = (this.f)(state, in_item);
251
252                            match result {
253                                Ok(Some((state, out_item))) => {
254                                    *this.state = Some(state);
255                                    Some(Ok(out_item))
256                                }
257                                Ok(None) => None,
258                                Err(err) => Some(Err(err)),
259                            }
260                        }
261                        Ready(Some(Err(err))) => Some(Err(err)),
262                        Ready(None) => None,
263                        Pending => {
264                            *this.state = Some(state);
265                            return Pending;
266                        }
267                    }
268                } else {
269                    None
270                }
271            })
272        }
273    }
274}
275
276pub use try_enumerate::*;
277mod try_enumerate {
278    use super::*;
279
280    /// Stream for the [try_enumerate()](crate::try_stream::TryStreamExt::try_enumerate) method.
281    #[derive(Derivative)]
282    #[derivative(Debug)]
283    #[pin_project]
284    pub struct TryEnumerate<S, T, E>
285    where
286        S: ?Sized,
287    {
288        pub(super) counter: usize,
289        pub(super) fused: bool,
290        pub(super) _phantom: PhantomData<(T, E)>,
291        #[pin]
292        #[derivative(Debug = "ignore")]
293        pub(super) stream: S,
294    }
295
296    impl<S, T, E> Stream for TryEnumerate<S, T, E>
297    where
298        S: Stream<Item = Result<T, E>>,
299    {
300        type Item = Result<(usize, T), E>;
301
302        fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
303            let mut this = self.project();
304
305            Ready({
306                if *this.fused {
307                    None
308                } else {
309                    match ready!(Pin::new(&mut this.stream).poll_next(cx)) {
310                        Some(Ok(item)) => {
311                            let index = *this.counter;
312                            *this.counter += 1;
313                            Some(Ok((index, item)))
314                        }
315                        Some(Err(err)) => {
316                            *this.fused = true;
317                            Some(Err(err))
318                        }
319                        None => None,
320                    }
321                }
322            })
323        }
324    }
325
326    impl<S, T, E> FusedStream for TryEnumerate<S, T, E>
327    where
328        S: Stream<Item = Result<T, E>>,
329    {
330        fn is_terminated(&self) -> bool {
331            self.fused
332        }
333    }
334}
335
336pub use catch_error::*;
337mod catch_error {
338    use super::*;
339
340    /// Stream for the [`catch_error`](super::TryStreamExt::catch_error) method.
341    #[pin_project]
342    pub struct CatchError<St>
343    where
344        St: ?Sized + TryStream,
345    {
346        pub(super) sender: Option<oneshot::Sender<St::Error>>,
347        #[pin]
348        pub(super) stream: St,
349    }
350
351    impl<St> Stream for CatchError<St>
352    where
353        St: TryStream,
354    {
355        type Item = St::Ok;
356
357        fn poll_next(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Option<Self::Item>> {
358            let this = self.project();
359
360            Ready({
361                if let Some(sender) = this.sender.take() {
362                    match this.stream.try_poll_next(ctx) {
363                        Ready(Some(Ok(item))) => {
364                            *this.sender = Some(sender);
365                            Some(item)
366                        }
367                        Ready(Some(Err(err))) => {
368                            let _ = sender.send(err);
369                            None
370                        }
371                        Ready(None) => {
372                            drop(sender);
373                            None
374                        }
375                        Pending => {
376                            *this.sender = Some(sender);
377                            return Pending;
378                        }
379                    }
380                } else {
381                    None
382                }
383            })
384        }
385    }
386
387    /// Future for the [`catch_error`](super::TryStreamExt::catch_error) method.
388    #[pin_project]
389    pub struct ErrorNotify<E> {
390        #[pin]
391        pub(super) receiver: oneshot::Receiver<E>,
392    }
393
394    impl<E> ErrorNotify<E> {
395        pub fn try_catch(mut self) -> ControlFlow<Result<(), E>, Self> {
396            use oneshot::error::TryRecvError::*;
397
398            match self.receiver.try_recv() {
399                Ok(err) => Break(Err(err)),
400                Err(Empty) => Continue(self),
401                Err(Closed) => Break(Ok(())),
402            }
403        }
404    }
405
406    impl<E> Future for ErrorNotify<E> {
407        type Output = Result<(), E>;
408
409        fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
410            let this = self.project();
411
412            Ready(match ready!(this.receiver.poll(ctx)) {
413                Ok(err) => Err(err),
414                Err(_) => Ok(()),
415            })
416        }
417    }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::async_test;

    async_test! {
        async fn take_until_error_test() {
            {
                let vec: Vec<Result<(), ()>> = stream::empty().take_until_error().collect().await;
                assert_eq!(vec, []);
            }

            {
                let vec: Vec<Result<_, ()>> = stream::iter([Ok(0), Ok(1), Ok(2), Ok(3)])
                    .take_until_error()
                    .collect()
                    .await;
                assert_eq!(vec, [Ok(0), Ok(1), Ok(2), Ok(3)]);
            }

            {
                let vec: Vec<Result<_, _>> = stream::iter([Ok(0), Ok(1), Err(2), Ok(3)])
                    .take_until_error()
                    .collect()
                    .await;
                assert_eq!(vec, [Ok(0), Ok(1), Err(2),]);
            }
        }

        async fn try_enumerate_test() {
            {
                let vec: Vec<Result<(usize, char), ()>> =
                    stream::iter([Ok('a'), Ok('b'), Ok('c')])
                        .try_enumerate()
                        .collect()
                        .await;
                assert_eq!(vec, [Ok((0, 'a')), Ok((1, 'b')), Ok((2, 'c'))]);
            }

            {
                // The first error is yielded and the stream ends right after it.
                let mut stream = stream::iter([Ok('a'), Err(()), Ok('b')])
                    .try_enumerate()
                    .boxed();

                assert_eq!(stream.next().await, Some(Ok((0, 'a'))));
                assert_eq!(stream.next().await, Some(Err(())));
                assert_eq!(stream.next().await, None);
            }
        }

        async fn try_stateful_then_test() {
            {
                let values: Result<Vec<_>, ()> = stream::iter([Ok(3), Ok(1), Ok(4), Ok(1)])
                    .try_stateful_then(0, |acc, val| async move {
                        let new_acc = acc + val;
                        Ok(Some((new_acc, new_acc)))
                    })
                    .try_collect()
                    .await;

                assert_eq!(values, Ok(vec![3, 4, 8, 9]));
            }

            {
                let mut stream = stream::iter([Ok(3), Ok(1), Err(()), Ok(1)])
                    .try_stateful_then(0, |acc, val| async move {
                        let new_acc = acc + val;
                        Ok(Some((new_acc, new_acc)))
                    })
                    .boxed();

                assert_eq!(stream.next().await, Some(Ok(3)));
                assert_eq!(stream.next().await, Some(Ok(4)));
                assert_eq!(stream.next().await, Some(Err(())));
                assert_eq!(stream.next().await, None);
            }

            {
                let mut stream = stream::iter([Ok(3), Ok(1), Ok(4), Ok(1), Err(())])
                    .try_stateful_then(0, |acc, val| async move {
                        let new_acc = acc + val;
                        if new_acc != 8 {
                            Ok(Some((new_acc, new_acc)))
                        } else {
                            Err(())
                        }
                    })
                    .boxed();

                assert_eq!(stream.next().await, Some(Ok(3)));
                assert_eq!(stream.next().await, Some(Ok(4)));
                assert_eq!(stream.next().await, Some(Err(())));
                assert_eq!(stream.next().await, None);
            }

            {
                let mut stream = stream::iter([Ok(3), Ok(1), Ok(4), Ok(1), Err(())])
                    .try_stateful_then(0, |acc, val| async move {
                        let new_acc = acc + val;
                        if new_acc != 8 {
                            Ok(Some((new_acc, new_acc)))
                        } else {
                            Ok(None)
                        }
                    })
                    .boxed();

                assert_eq!(stream.next().await, Some(Ok(3)));
                assert_eq!(stream.next().await, Some(Ok(4)));
                assert_eq!(stream.next().await, None);
            }
        }

        async fn try_stateful_map_test() {
            {
                let values: Result<Vec<_>, ()> = stream::iter([Ok(3), Ok(1), Ok(4), Ok(1)])
                    .try_stateful_map(0, |acc, val| {
                        let new_acc = acc + val;
                        Ok(Some((new_acc, new_acc)))
                    })
                    .try_collect()
                    .await;

                assert_eq!(values, Ok(vec![3, 4, 8, 9]));
            }

            {
                let mut stream = stream::iter([Ok(3), Ok(1), Err(()), Ok(1)])
                    .try_stateful_map(0, |acc, val| {
                        let new_acc = acc + val;
                        Ok(Some((new_acc, new_acc)))
                    })
                    .boxed();

                assert_eq!(stream.next().await, Some(Ok(3)));
                assert_eq!(stream.next().await, Some(Ok(4)));
                assert_eq!(stream.next().await, Some(Err(())));
                assert_eq!(stream.next().await, None);
            }

            {
                let mut stream = stream::iter([Ok(3), Ok(1), Ok(4), Ok(1), Err(())])
                    .try_stateful_map(0, |acc, val| {
                        let new_acc = acc + val;
                        if new_acc != 8 {
                            Ok(Some((new_acc, new_acc)))
                        } else {
                            Ok(None)
                        }
                    })
                    .boxed();

                assert_eq!(stream.next().await, Some(Ok(3)));
                assert_eq!(stream.next().await, Some(Ok(4)));
                assert_eq!(stream.next().await, None);
            }
        }

        async fn catch_error_test() {
            {
                let (notify, stream) = stream::empty::<Result<(), ()>>().catch_error();

                let vec: Vec<_> = stream.collect().await;
                let result = notify.await;

                assert_eq!(vec, []);
                assert_eq!(result, Ok(()));
            }

            {
                let (notify, stream) =
                    stream::iter([Result::<_, ()>::Ok(0), Ok(1), Ok(2), Ok(3)]).catch_error();

                let vec: Vec<_> = stream.collect().await;
                let result = notify.await;

                assert_eq!(vec, [0, 1, 2, 3]);
                assert_eq!(result, Ok(()));
            }

            {
                let (notify, stream) = stream::iter([Ok(0), Ok(1), Err(2), Ok(3)]).catch_error();

                let vec: Vec<_> = stream.collect().await;
                let result = notify.await;

                assert_eq!(vec, [0, 1]);
                assert_eq!(result, Err(2));
            }

            {
                let (notify, mut stream) = stream::empty::<Result<(), ()>>().catch_error();

                let notify = match notify.try_catch() {
                    Continue(notify) => notify,
                    _ => unreachable!(),
                };

                assert_eq!(stream.next().await, None);
                assert!(matches!(notify.try_catch(), Break(Ok(()))));
            }

            {
                let (notify, mut stream) = stream::iter([Result::<_, ()>::Ok(0)]).catch_error();

                let notify = match notify.try_catch() {
                    Continue(notify) => notify,
                    _ => unreachable!(),
                };

                assert_eq!(stream.next().await, Some(0));
                let notify = match notify.try_catch() {
                    Continue(notify) => notify,
                    _ => unreachable!(),
                };

                assert_eq!(stream.next().await, None);
                assert!(matches!(notify.try_catch(), Break(Ok(()))));
            }

            {
                let (notify, mut stream) = stream::iter([Ok(0), Err(2)]).catch_error();

                let notify = match notify.try_catch() {
                    Continue(notify) => notify,
                    _ => unreachable!(),
                };

                assert_eq!(stream.next().await, Some(0));
                let notify = match notify.try_catch() {
                    Continue(notify) => notify,
                    _ => unreachable!(),
                };

                assert_eq!(stream.next().await, None);
                assert!(matches!(notify.try_catch(), Break(Err(2))));
            }
        }
    }
}