par_stream/
try_index_stream.rs

1use crate::common::*;
2
/// The trait extends [TryStream](futures::stream::TryStream) types with ordering manipulation combinators.
///
/// It applies to streams whose items are `Result<(usize, Ok), Error>`, where the `usize`
/// is the index number attached to each successful item.
pub trait TryIndexStreamExt
where
    Self: Stream<Item = Result<(usize, Self::Ok), Self::Error>>,
{
    /// The payload type carried next to the index number in `Ok((index, item))`.
    type Ok;
    /// The error type carried by `Err` items.
    type Error;

    /// Reorders the input items `Ok((index, item))` according to the index number and returns `item`.
    ///
    /// It can be combined with [try_enumerate](crate::try_stream::TryStreamExt::try_enumerate) and
    /// unordered parallel tasks.
    ///
    /// If an `Err` item is received, it stops receiving future items. Buffered items whose
    /// index numbers continue consecutively from the next expected index are still yielded,
    /// then the `Err` is sent at the end and the stream terminates. Buffered items that sit
    /// behind a gap are never yielded.
    ///
    /// # Panics
    ///
    /// Polling the returned stream panics if an index number appears more than once, or if
    /// the input stream ends while items are still buffered behind an index number that
    /// never arrived.
    ///
    /// ```rust
    /// # par_stream::rt::block_on_executor(async move {
    /// use futures::prelude::*;
    /// use par_stream::prelude::*;
    ///
    /// let result: Result<Vec<_>, _> = stream::iter(0..100)
    ///     .map(|val| if val < 50 { Ok(val) } else { Err(val) })
    ///     // add index number
    ///     .try_enumerate()
    ///     // double the values in parallel
    ///     .try_par_then_unordered(None, move |(index, value)| {
    ///         // the closure is sent to parallel worker
    ///         async move { Ok((index, value * 2)) }
    ///     })
    ///     // add values by one in parallel
    ///     .try_par_then_unordered(None, move |(index, value)| {
    ///         // the closure is sent to parallel worker
    ///         async move {
    ///             let new_val = value + 1;
    ///             if new_val < 50 {
    ///                 Ok((index, new_val))
    ///             } else {
    ///                 Err(value)
    ///             }
    ///         }
    ///     })
    ///     // reorder the values according to index number
    ///     .try_reorder_enumerated()
    ///     .try_collect()
    ///     .await;
    /// # })
    /// ```
    fn try_reorder_enumerated(self) -> TryReorderEnumerated<Self, Self::Ok, Self::Error>;
}
53
54impl<S, T, E> TryIndexStreamExt for S
55where
56    S: Stream<Item = Result<(usize, T), E>>,
57{
58    type Ok = T;
59    type Error = E;
60
61    fn try_reorder_enumerated(self) -> TryReorderEnumerated<Self, T, E> {
62        TryReorderEnumerated {
63            stream: self,
64            commit: 0,
65            pending_error: None,
66            is_terminated: false,
67            buffer: HashMap::new(),
68            _phantom: PhantomData,
69        }
70    }
71}
72
73// try_reorder_enumerated
74
75pub use try_reorder_enumerated::*;
76
77mod try_reorder_enumerated {
78    use super::*;
79
80    /// Stream for the [try_reorder_enumerated()](TryIndexStreamExt::try_reorder_enumerated) method.
81    #[derive(Derivative)]
82    #[derivative(Debug)]
83    #[pin_project]
84    pub struct TryReorderEnumerated<S, T, E>
85    where
86        S: ?Sized,
87    {
88        pub(super) commit: usize,
89        pub(super) is_terminated: bool,
90        pub(super) pending_error: Option<E>,
91        pub(super) buffer: HashMap<usize, T>,
92        pub(super) _phantom: PhantomData<E>,
93        #[pin]
94        #[derivative(Debug = "ignore")]
95        pub(super) stream: S,
96    }
97
98    impl<S, T, E> Stream for TryReorderEnumerated<S, T, E>
99    where
100        S: Stream<Item = Result<(usize, T), E>>,
101    {
102        type Item = Result<T, E>;
103
104        fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
105            let mut this = self.project();
106
107            Ready(loop {
108                if *this.is_terminated {
109                    break None;
110                } else if let Some(err) = this.pending_error.take() {
111                    if let Some(item) = this.buffer.remove(this.commit) {
112                        *this.pending_error = Some(err);
113                        *this.commit += 1;
114                        break Some(Ok(item));
115                    } else {
116                        *this.is_terminated = true;
117                        break Some(Err(err));
118                    }
119                } else if let Some(item) = this.buffer.remove(this.commit) {
120                    *this.commit += 1;
121                    break Some(Ok(item));
122                } else {
123                    match ready!(Pin::new(&mut this.stream).poll_next(cx)) {
124                        Some(Ok((index, item))) => match (*this.commit).cmp(&index) {
125                            Less => {
126                                let prev = this.buffer.insert(index, item);
127                                assert!(
128                                    prev.is_none(),
129                                    "the index number {} appears more than once",
130                                    index
131                                );
132                            }
133                            Equal => {
134                                *this.commit += 1;
135                                break Some(Ok(item));
136                            }
137                            Greater => {
138                                panic!("the index number {} appears more than once", index);
139                            }
140                        },
141                        Some(Err(err)) => {
142                            *this.pending_error = Some(err);
143                        }
144                        None => {
145                            assert!(
146                                this.buffer.is_empty(),
147                                "the item for index number {} is missing",
148                                this.commit
149                            );
150                            break None;
151                        }
152                    }
153                }
154            })
155        }
156    }
157
158    impl<S, T, E> FusedStream for TryReorderEnumerated<S, T, E>
159    where
160        Self: Stream,
161    {
162        fn is_terminated(&self) -> bool {
163            self.is_terminated
164        }
165    }
166}