1use core::{
2 fmt,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use futures_util::{
8 stream::{
9 abortable, select_with_strategy, AbortHandle, Abortable, FusedStream, PollNext,
10 SelectWithStrategy,
11 },
12 Stream,
13};
14use pin_project_lite::pin_project;
15
16type Inner<St1, St2, State> =
18 SelectWithStrategy<St1, Abortable<St2>, fn(&mut State) -> PollNext, State>;
19
20pin_project! {
21 #[must_use = "streams do nothing unless polled"]
23 pub struct SelectUntilLeftIsDoneWithStrategy<St1, St2, Clos, State> {
24 #[pin]
25 inner: Inner<St1, St2, PollNext>,
26 abort_handle: AbortHandle,
27 state: State,
28 clos: Clos,
29 }
30}
31
32pub fn select_until_left_is_done_with_strategy<St1, St2, Clos, State>(
34 stream1: St1,
35 stream2: St2,
36 which: Clos,
37) -> SelectUntilLeftIsDoneWithStrategy<St1, St2, Clos, State>
38where
39 St1: Stream,
40 St2: Stream<Item = St1::Item>,
41 Clos: FnMut(&mut State) -> PollNext,
42 State: Default,
43{
44 let (stream2, abort_handle) = abortable(stream2);
45
46 SelectUntilLeftIsDoneWithStrategy {
47 inner: select_with_strategy(stream1, stream2, |last| last.toggle()),
48 abort_handle,
49 state: Default::default(),
50 clos: which,
51 }
52}
53
54impl<St1, St2, Clos, State> SelectUntilLeftIsDoneWithStrategy<St1, St2, Clos, State> {
56 pub fn get_ref(&self) -> (&St1, &Abortable<St2>) {
57 self.inner.get_ref()
58 }
59
60 pub fn get_mut(&mut self) -> (&mut St1, &mut Abortable<St2>) {
61 self.inner.get_mut()
62 }
63
64 pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut St1>, Pin<&mut Abortable<St2>>) {
65 let this = self.project();
66 this.inner.get_pin_mut()
67 }
68
69 pub fn into_inner(self) -> (St1, Abortable<St2>) {
70 self.inner.into_inner()
71 }
72}
73
74impl<St1, St2, Clos, State> FusedStream for SelectUntilLeftIsDoneWithStrategy<St1, St2, Clos, State>
76where
77 St1: Stream,
78 St2: Stream<Item = St1::Item>,
79 Clos: FnMut(&mut State) -> PollNext,
80{
81 fn is_terminated(&self) -> bool {
82 self.inner.is_terminated()
83 }
84}
85
86impl<St1, St2, Clos, State> Stream for SelectUntilLeftIsDoneWithStrategy<St1, St2, Clos, State>
88where
89 St1: Stream,
90 St2: Stream<Item = St1::Item>,
91 Clos: FnMut(&mut State) -> PollNext,
92{
93 type Item = St1::Item;
94
95 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St1::Item>> {
96 let this = self.project();
97 let (left, right) = this.inner.get_pin_mut();
98
99 match (this.clos)(this.state) {
100 PollNext::Left => {
101 let left_done = match left.poll_next(cx) {
102 Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
103 Poll::Ready(None) => {
104 this.abort_handle.abort();
105 true
106 }
107 Poll::Pending => false,
108 };
109
110 match right.poll_next(cx) {
111 Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
112 Poll::Ready(None) if left_done => Poll::Ready(None),
113 Poll::Ready(None) | Poll::Pending => Poll::Pending,
114 }
115 }
116 PollNext::Right => {
117 let right_done = match right.poll_next(cx) {
118 Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
119 Poll::Ready(None) => true,
120 Poll::Pending => false,
121 };
122
123 match left.poll_next(cx) {
124 Poll::Ready(Some(item)) => Poll::Ready(Some(item)),
125 Poll::Ready(None) if right_done => Poll::Ready(None),
126 Poll::Ready(None) => {
127 this.abort_handle.abort();
128 Poll::Pending
129 }
130 Poll::Pending => Poll::Pending,
131 }
132 }
133 }
134 }
135}
136
137impl<St1, St2, Clos, State> fmt::Debug for SelectUntilLeftIsDoneWithStrategy<St1, St2, Clos, State>
139where
140 St1: fmt::Debug,
141 St2: fmt::Debug,
142 State: fmt::Debug,
143{
144 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145 let (stream1, stream2) = self.get_ref();
146
147 f.debug_struct("SelectUntilLeftIsDoneWithStrategy")
148 .field("stream1", &stream1)
149 .field("stream2", &stream2)
150 .field("state", &self.state)
151 .finish()
152 }
153}
154
#[cfg(test)]
mod tests {
    //! Behavioural tests for `select_until_left_is_done_with_strategy`.
    //!
    //! The synchronous tests pin the exact interleaving produced by a strategy. The tokio
    //! tests add per-item delays to one or both sides and check that the combined stream ends
    //! once the left side ends.

    use super::*;

    use alloc::{vec, vec::Vec};

    use futures_util::{stream, StreamExt as _};

    /// Strategy that switches sides on every call (see `PollNext::toggle`).
    fn round_robin(last: &mut PollNext) -> PollNext {
        last.toggle()
    }

    /// Strategy that yields `Right, Right, Left` repeatedly, using `i` as a call counter.
    fn right_right_left(i: &mut usize) -> PollNext {
        let poll_next = if *i % 3 == 2 {
            PollNext::Left
        } else {
            PollNext::Right
        };

        *i += 1;
        poll_next
    }

    // Left is a finite range and right repeats `0` forever. Items alternate left/right and
    // the stream ends as soon as the left side is exhausted.
    #[test]
    fn test_with_round_robin() {
        futures_executor::block_on(async {
            for (range, ret) in vec![
                (1..=1, vec![1, 0]),
                (1..=2, vec![1, 0, 2, 0]),
                (1..=3, vec![1, 0, 2, 0, 3, 0]),
                (1..=4, vec![1, 0, 2, 0, 3, 0, 4, 0]),
                (1..=5, vec![1, 0, 2, 0, 3, 0, 4, 0, 5, 0]),
            ] {
                let st1 = stream::iter(range).boxed();
                let st2 = stream::repeat(0);

                let st = select_until_left_is_done_with_strategy(st1, st2, round_robin);

                assert_eq!(st.collect::<Vec<_>>().await, ret);
            }
        })
    }

    // Same inputs with a `Right, Right, Left` strategy: two right items precede each left
    // item, and the stream ends when the left side is exhausted.
    #[test]
    fn test_with_right_right_left() {
        futures_executor::block_on(async {
            for (range, ret) in vec![
                (1..=1, vec![0, 0, 1, 0, 0]),
                (1..=2, vec![0, 0, 1, 0, 0, 2, 0, 0]),
                (1..=3, vec![0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0]),
                (1..=4, vec![0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0]),
                (
                    1..=5,
                    vec![0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0],
                ),
            ] {
                let st1 = stream::iter(range).boxed();
                let st2 = stream::repeat(0);

                let st = select_until_left_is_done_with_strategy(st1, st2, right_right_left);

                assert_eq!(st.collect::<Vec<_>>().await, ret);
            }
        })
    }

    // The right side sleeps 5s before each item, so it never yields within the test. The
    // combined stream must still finish right after the left side does. With the `std`
    // feature this is checked as elapsed < 1s.
    #[tokio::test]
    async fn test_with_round_robin_and_right_long_sleep() {
        for (range, ret) in vec![
            (1..=1, vec![1]),
            (1..=2, vec![1, 2]),
            (1..=3, vec![1, 2, 3]),
            (1..=4, vec![1, 2, 3, 4]),
            (1..=5, vec![1, 2, 3, 4, 5]),
        ] {
            let st1 = stream::iter(range).boxed();
            let st2 = stream::repeat(0)
                .then(|n| async move {
                    tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
                    n
                })
                .boxed();

            let st = select_until_left_is_done_with_strategy(st1, st2, round_robin);

            #[cfg(feature = "std")]
            let now = std::time::Instant::now();

            assert_eq!(st.collect::<Vec<_>>().await, ret);

            #[cfg(feature = "std")]
            assert!(now.elapsed() < core::time::Duration::from_secs(1));
        }
    }

    // Both sides are delayed (left 100ms, right 160ms per item), so the expected output
    // reflects which timer completes first, not just the strategy order.
    #[tokio::test]
    async fn test_with_round_robin_and_both_sleep() {
        for (range, ret_vec) in vec![
            (1..=1, vec![vec![1]]),
            (1..=2, vec![vec![1, 0, 2]]),
            (1..=3, vec![vec![1, 0, 2, 3]]),
            (1..=4, vec![vec![1, 0, 2, 3, 0, 4]]),
            (1..=5, vec![vec![1, 0, 2, 3, 0, 4, 0, 5]]),
        ] {
            let st1 = stream::iter(range)
                .then(|n| async move {
                    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
                    n
                })
                .boxed();
            let st2 = stream::repeat(0)
                .then(|n| async move {
                    tokio::time::sleep(tokio::time::Duration::from_millis(160)).await;
                    n
                })
                .boxed();

            let st = select_until_left_is_done_with_strategy(st1, st2, round_robin);

            #[cfg(feature = "std")]
            let now = std::time::Instant::now();

            let ret = st.collect::<Vec<_>>().await;
            #[cfg(feature = "std")]
            println!("ret {:?}", ret);
            assert!(ret_vec.contains(&ret));

            #[cfg(feature = "std")]
            assert!(now.elapsed() < core::time::Duration::from_secs(1));
        }
    }

    // Left is slower than right here (140ms vs 100ms per item). For `1..=5` two outputs are
    // accepted. NOTE(review): presumably because the timers can complete at nearly the same
    // time around 700ms (5 x 140ms = 7 x 100ms) — confirm.
    #[tokio::test]
    async fn test_with_round_robin_and_both_sleep_2() {
        for (range, ret_vec) in vec![
            (1..=1, vec![vec![0, 1]]),
            (1..=2, vec![vec![0, 1, 0, 2]]),
            (1..=3, vec![vec![0, 1, 0, 2, 0, 0, 3]]),
            (1..=4, vec![vec![0, 1, 0, 2, 0, 0, 3, 0, 4]]),
            (
                1..=5,
                vec![
                    vec![0, 1, 0, 2, 0, 0, 3, 0, 4, 0, 0, 5],
                    vec![0, 1, 0, 2, 0, 0, 3, 0, 4, 0, 5],
                ],
            ),
        ] {
            let st1 = stream::iter(range)
                .then(|n| async move {
                    tokio::time::sleep(tokio::time::Duration::from_millis(140)).await;
                    n
                })
                .boxed();
            let st2 = stream::repeat(0)
                .then(|n| async move {
                    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
                    n
                })
                .boxed();

            let st = select_until_left_is_done_with_strategy(st1, st2, round_robin);

            #[cfg(feature = "std")]
            let now = std::time::Instant::now();

            let ret = st.collect::<Vec<_>>().await;
            #[cfg(feature = "std")]
            println!("ret {:?}", ret);
            assert!(ret_vec.contains(&ret));

            #[cfg(feature = "std")]
            assert!(now.elapsed() < core::time::Duration::from_secs(1));
        }
    }

    // The right side is finite (three `0`s, 35ms each) and the left side is 60ms per item.
    // After the right side is exhausted, only left items remain until the left side ends.
    #[tokio::test]
    async fn test_with_right_right_left_and_both_sleep() {
        for (range, ret_vec) in vec![
            (1..=1, vec![vec![0, 1]]),
            (1..=2, vec![vec![0, 1, 0, 0, 2]]),
            (1..=3, vec![vec![0, 1, 0, 0, 2, 3]]),
            (1..=4, vec![vec![0, 1, 0, 0, 2, 3, 4]]),
            (1..=5, vec![vec![0, 1, 0, 0, 2, 3, 4, 5]]),
        ] {
            let st1 = stream::iter(range)
                .then(|n| async move {
                    tokio::time::sleep(tokio::time::Duration::from_millis(60)).await;
                    n
                })
                .boxed();
            let st2 = stream::iter(vec![0, 0, 0])
                .then(|n| async move {
                    tokio::time::sleep(tokio::time::Duration::from_millis(35)).await;
                    n
                })
                .boxed();

            let st = select_until_left_is_done_with_strategy(st1, st2, right_right_left);

            #[cfg(feature = "std")]
            let now = std::time::Instant::now();

            let ret = st.collect::<Vec<_>>().await;
            #[cfg(feature = "std")]
            println!("ret {:?}", ret);
            assert!(ret_vec.contains(&ret));

            #[cfg(feature = "std")]
            assert!(now.elapsed() < core::time::Duration::from_secs(1));
        }
    }
}