// futures_stream_select_ext/select_until_left_is_done_with_strategy.rs

1use core::{
2    fmt,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use futures_util::{
8    stream::{
9        abortable, select_with_strategy, AbortHandle, Abortable, FusedStream, PollNext,
10        SelectWithStrategy,
11    },
12    Stream,
13};
14use pin_project_lite::pin_project;
15
16//
17type Inner<St1, St2, State> =
18    SelectWithStrategy<St1, Abortable<St2>, fn(&mut State) -> PollNext, State>;
19
20pin_project! {
21    /// Stream for the [`select_until_left_is_done_with_strategy()`] function. See function docs for details.
22    #[must_use = "streams do nothing unless polled"]
23    pub struct SelectUntilLeftIsDoneWithStrategy<St1, St2, Clos, State> {
24        #[pin]
25        inner: Inner<St1, St2, PollNext>,
26        abort_handle: AbortHandle,
27        state: State,
28        clos: Clos,
29    }
30}
31
32//
33pub fn select_until_left_is_done_with_strategy<St1, St2, Clos, State>(
34    stream1: St1,
35    stream2: St2,
36    which: Clos,
37) -> SelectUntilLeftIsDoneWithStrategy<St1, St2, Clos, State>
38where
39    St1: Stream,
40    St2: Stream<Item = St1::Item>,
41    Clos: FnMut(&mut State) -> PollNext,
42    State: Default,
43{
44    let (stream2, abort_handle) = abortable(stream2);
45
46    SelectUntilLeftIsDoneWithStrategy {
47        inner: select_with_strategy(stream1, stream2, |last| last.toggle()),
48        abort_handle,
49        state: Default::default(),
50        clos: which,
51    }
52}
53
54//
55impl<St1, St2, Clos, State> SelectUntilLeftIsDoneWithStrategy<St1, St2, Clos, State> {
56    pub fn get_ref(&self) -> (&St1, &Abortable<St2>) {
57        self.inner.get_ref()
58    }
59
60    pub fn get_mut(&mut self) -> (&mut St1, &mut Abortable<St2>) {
61        self.inner.get_mut()
62    }
63
64    pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut St1>, Pin<&mut Abortable<St2>>) {
65        let this = self.project();
66        this.inner.get_pin_mut()
67    }
68
69    pub fn into_inner(self) -> (St1, Abortable<St2>) {
70        self.inner.into_inner()
71    }
72}
73
74//
75impl<St1, St2, Clos, State> FusedStream for SelectUntilLeftIsDoneWithStrategy<St1, St2, Clos, State>
76where
77    St1: Stream,
78    St2: Stream<Item = St1::Item>,
79    Clos: FnMut(&mut State) -> PollNext,
80{
81    fn is_terminated(&self) -> bool {
82        self.inner.is_terminated()
83    }
84}
85
86//
87impl<St1, St2, Clos, State> Stream for SelectUntilLeftIsDoneWithStrategy<St1, St2, Clos, State>
88where
89    St1: Stream,
90    St2: Stream<Item = St1::Item>,
91    Clos: FnMut(&mut State) -> PollNext,
92{
93    type Item = St1::Item;
94
95    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St1::Item>> {
96        let this = self.project();
97        let (left, right) = this.inner.get_pin_mut();
98
99        match (this.clos)(this.state) {
100            PollNext::Left => {
101                let left_done = match left.poll_next(cx) {
102                    Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
103                    Poll::Ready(None) => {
104                        this.abort_handle.abort();
105                        true
106                    }
107                    Poll::Pending => false,
108                };
109
110                match right.poll_next(cx) {
111                    Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
112                    Poll::Ready(None) if left_done => Poll::Ready(None),
113                    Poll::Ready(None) | Poll::Pending => Poll::Pending,
114                }
115            }
116            PollNext::Right => {
117                let right_done = match right.poll_next(cx) {
118                    Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
119                    Poll::Ready(None) => true,
120                    Poll::Pending => false,
121                };
122
123                match left.poll_next(cx) {
124                    Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
125                    Poll::Ready(None) if right_done => Poll::Ready(None),
126                    Poll::Ready(None) => {
127                        this.abort_handle.abort();
128                        Poll::Pending
129                    }
130                    Poll::Pending => Poll::Pending,
131                }
132            }
133        }
134    }
135}
136
137//
138impl<St1, St2, Clos, State> fmt::Debug for SelectUntilLeftIsDoneWithStrategy<St1, St2, Clos, State>
139where
140    St1: fmt::Debug,
141    St2: fmt::Debug,
142    State: fmt::Debug,
143{
144    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145        let (stream1, stream2) = self.get_ref();
146
147        f.debug_struct("SelectUntilLeftIsDoneWithStrategy")
148            .field("stream1", &stream1)
149            .field("stream2", &stream2)
150            .field("state", &self.state)
151            .finish()
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::{vec, vec::Vec};

    use futures_util::{
        stream::{self, BoxStream},
        StreamExt as _,
    };

    /// Strategy that alternates between the two sides on every call.
    fn round_robin(previous: &mut PollNext) -> PollNext {
        previous.toggle()
    }

    /// Strategy that prefers the right side twice, then the left side once, in a repeating
    /// cycle; `calls` counts how many decisions have been made so far.
    fn right_right_left(calls: &mut usize) -> PollNext {
        let current = *calls;
        *calls += 1;

        match current % 3 {
            2 => PollNext::Left,
            _ => PollNext::Right,
        }
    }

    /// Boxes `source`, delaying every item by `delay` on the tokio timer before yielding it.
    fn delayed<St>(source: St, delay: tokio::time::Duration) -> BoxStream<'static, St::Item>
    where
        St: Stream + Send + 'static,
        St::Item: Send + 'static,
    {
        source
            .then(move |item| async move {
                tokio::time::sleep(delay).await;
                item
            })
            .boxed()
    }

    /// Drains `merged`, checks the output is one of `accepted`, and (with `std`) that the
    /// whole run took less than one second.
    async fn collect_and_check<St>(merged: St, accepted: &[Vec<i32>])
    where
        St: Stream<Item = i32>,
    {
        #[cfg(feature = "std")]
        let started = std::time::Instant::now();

        let ret = merged.collect::<Vec<_>>().await;
        #[cfg(feature = "std")]
        println!("ret {:?}", ret);
        assert!(accepted.contains(&ret));

        #[cfg(feature = "std")]
        assert!(started.elapsed() < core::time::Duration::from_secs(1));
    }

    #[test]
    fn test_with_round_robin() {
        let cases = vec![
            (1..=1, vec![1, 0]),
            (1..=2, vec![1, 0, 2, 0]),
            (1..=3, vec![1, 0, 2, 0, 3, 0]),
            (1..=4, vec![1, 0, 2, 0, 3, 0, 4, 0]),
            (1..=5, vec![1, 0, 2, 0, 3, 0, 4, 0, 5, 0]),
        ];

        futures_executor::block_on(async move {
            for (range, expected) in cases {
                let left = stream::iter(range).boxed();
                let right = stream::repeat(0);

                let merged = select_until_left_is_done_with_strategy(left, right, round_robin);
                assert_eq!(merged.collect::<Vec<_>>().await, expected);
            }
        });
    }

    #[test]
    fn test_with_right_right_left() {
        let cases = vec![
            (1..=1, vec![0, 0, 1, 0, 0]),
            (1..=2, vec![0, 0, 1, 0, 0, 2, 0, 0]),
            (1..=3, vec![0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0]),
            (1..=4, vec![0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0]),
            (
                1..=5,
                vec![0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0],
            ),
        ];

        futures_executor::block_on(async move {
            for (range, expected) in cases {
                let left = stream::iter(range).boxed();
                let right = stream::repeat(0);

                let merged =
                    select_until_left_is_done_with_strategy(left, right, right_right_left);
                assert_eq!(merged.collect::<Vec<_>>().await, expected);
            }
        });
    }

    #[tokio::test]
    async fn test_with_round_robin_and_right_long_sleep() {
        let cases = vec![
            (1..=1, vec![1]),
            (1..=2, vec![1, 2]),
            (1..=3, vec![1, 2, 3]),
            (1..=4, vec![1, 2, 3, 4]),
            (1..=5, vec![1, 2, 3, 4, 5]),
        ];

        for (range, expected) in cases {
            let left = stream::iter(range).boxed();
            let right = delayed(stream::repeat(0), tokio::time::Duration::from_secs(5));

            let merged = select_until_left_is_done_with_strategy(left, right, round_robin);

            #[cfg(feature = "std")]
            let started = std::time::Instant::now();

            assert_eq!(merged.collect::<Vec<_>>().await, expected);

            #[cfg(feature = "std")]
            assert!(started.elapsed() < core::time::Duration::from_secs(1));
        }
    }

    #[tokio::test]
    async fn test_with_round_robin_and_both_sleep() {
        let cases = vec![
            (1..=1, vec![vec![1]]),
            (1..=2, vec![vec![1, 0, 2]]),
            (1..=3, vec![vec![1, 0, 2, 3]]),
            (1..=4, vec![vec![1, 0, 2, 3, 0, 4]]),
            (1..=5, vec![vec![1, 0, 2, 3, 0, 4, 0, 5]]),
        ];

        for (range, accepted) in cases {
            let left = delayed(stream::iter(range), tokio::time::Duration::from_millis(100));
            let right = delayed(stream::repeat(0), tokio::time::Duration::from_millis(160));

            let merged = select_until_left_is_done_with_strategy(left, right, round_robin);
            collect_and_check(merged, &accepted).await;
        }
    }

    #[tokio::test]
    async fn test_with_round_robin_and_both_sleep_2() {
        let cases = vec![
            (1..=1, vec![vec![0, 1]]),
            (1..=2, vec![vec![0, 1, 0, 2]]),
            (1..=3, vec![vec![0, 1, 0, 2, 0, 0, 3]]),
            (1..=4, vec![vec![0, 1, 0, 2, 0, 0, 3, 0, 4]]),
            (
                1..=5,
                vec![
                    vec![0, 1, 0, 2, 0, 0, 3, 0, 4, 0, 0, 5],
                    vec![0, 1, 0, 2, 0, 0, 3, 0, 4, 0, 5],
                ],
            ),
        ];

        for (range, accepted) in cases {
            let left = delayed(stream::iter(range), tokio::time::Duration::from_millis(140));
            let right = delayed(stream::repeat(0), tokio::time::Duration::from_millis(100));

            let merged = select_until_left_is_done_with_strategy(left, right, round_robin);
            collect_and_check(merged, &accepted).await;
        }
    }

    #[tokio::test]
    async fn test_with_right_right_left_and_both_sleep() {
        let cases = vec![
            (1..=1, vec![vec![0, 1]]),
            (1..=2, vec![vec![0, 1, 0, 0, 2]]),
            (1..=3, vec![vec![0, 1, 0, 0, 2, 3]]),
            (1..=4, vec![vec![0, 1, 0, 0, 2, 3, 4]]),
            (1..=5, vec![vec![0, 1, 0, 0, 2, 3, 4, 5]]),
        ];

        for (range, accepted) in cases {
            let left = delayed(stream::iter(range), tokio::time::Duration::from_millis(60));
            let right = delayed(
                stream::iter(vec![0, 0, 0]),
                tokio::time::Duration::from_millis(35),
            );

            let merged = select_until_left_is_done_with_strategy(left, right, right_right_left);
            collect_and_check(merged, &accepted).await;
        }
    }
}