futures_bounded/
stream_map.rs

1use std::mem;
2use std::pin::Pin;
3use std::task::{Context, Poll, Waker};
4use std::time::Duration;
5
6use futures_util::stream::{BoxStream, SelectAll};
7use futures_util::{stream, FutureExt, Stream, StreamExt};
8
9use crate::{Delay, PushError, Timeout};
10
11/// Represents a map of [`Stream`]s.
12///
13/// Each stream must finish within the specified time and the map never outgrows its capacity.
14pub struct StreamMap<ID, O> {
15    make_delay: Box<dyn Fn() -> Delay + Send + Sync>,
16    capacity: usize,
17    inner: SelectAll<TaggedStream<ID, TimeoutStream<BoxStream<'static, O>>>>,
18    empty_waker: Option<Waker>,
19    full_waker: Option<Waker>,
20}
21
22impl<ID, O> StreamMap<ID, O>
23where
24    ID: Clone + Unpin,
25{
26    pub fn new(make_delay: impl Fn() -> Delay + Send + Sync + 'static, capacity: usize) -> Self {
27        Self {
28            make_delay: Box::new(make_delay),
29            capacity,
30            inner: Default::default(),
31            empty_waker: None,
32            full_waker: None,
33        }
34    }
35}
36
37impl<ID, O> StreamMap<ID, O>
38where
39    ID: Clone + PartialEq + Send + Unpin + 'static,
40    O: Send + 'static,
41{
42    /// Push a stream into the map.
43    pub fn try_push<F>(&mut self, id: ID, stream: F) -> Result<(), PushError<BoxStream<O>>>
44    where
45        F: Stream<Item = O> + Send + 'static,
46    {
47        if self.inner.len() >= self.capacity {
48            return Err(PushError::BeyondCapacity(stream.boxed()));
49        }
50
51        if let Some(waker) = self.empty_waker.take() {
52            waker.wake();
53        }
54
55        let old = self.remove(id.clone());
56        self.inner.push(TaggedStream::new(
57            id,
58            TimeoutStream {
59                inner: stream.boxed(),
60                timeout: (self.make_delay)(),
61            },
62        ));
63
64        match old {
65            None => Ok(()),
66            Some(old) => Err(PushError::Replaced(old)),
67        }
68    }
69
70    pub fn remove(&mut self, id: ID) -> Option<BoxStream<'static, O>> {
71        let tagged = self.inner.iter_mut().find(|s| s.key == id)?;
72
73        let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed());
74        tagged.exhausted = true; // Setting this will emit `None` on the next poll and ensure `SelectAll` cleans up the resources.
75
76        Some(inner)
77    }
78
79    pub fn len(&self) -> usize {
80        self.inner.len()
81    }
82
83    pub fn is_empty(&self) -> bool {
84        self.inner.is_empty()
85    }
86
87    #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] // &mut Context is idiomatic.
88    pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> {
89        if self.inner.len() < self.capacity {
90            return Poll::Ready(());
91        }
92
93        self.full_waker = Some(cx.waker().clone());
94
95        Poll::Pending
96    }
97
98    pub fn poll_next_unpin(
99        &mut self,
100        cx: &mut Context<'_>,
101    ) -> Poll<(ID, Option<Result<O, Timeout>>)> {
102        match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
103            None => {
104                self.empty_waker = Some(cx.waker().clone());
105                Poll::Pending
106            }
107            Some((id, Some(Ok(output)))) => Poll::Ready((id, Some(Ok(output)))),
108            Some((id, Some(Err(dur)))) => {
109                self.remove(id.clone()); // Remove stream, otherwise we keep reporting the timeout.
110
111                Poll::Ready((id, Some(Err(Timeout::new(dur)))))
112            }
113            Some((id, None)) => Poll::Ready((id, None)),
114        }
115    }
116}
117
118struct TimeoutStream<S> {
119    inner: S,
120    timeout: Delay,
121}
122
123impl<F> Stream for TimeoutStream<F>
124where
125    F: Stream + Unpin,
126{
127    type Item = Result<F::Item, Duration>;
128
129    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
130        if let Poll::Ready(dur) = self.timeout.poll_unpin(cx) {
131            return Poll::Ready(Some(Err(dur)));
132        }
133
134        self.inner.poll_next_unpin(cx).map(|a| a.map(Ok))
135    }
136}
137
/// A stream whose items are labelled with the key it was registered under.
struct TaggedStream<K, S> {
    /// Key attached to every item yielded by this stream.
    key: K,
    /// The wrapped stream.
    inner: S,

    /// Set once the inner stream has ended or the entry was removed. After that the stream only
    /// yields `None`.
    exhausted: bool,
}
144
145impl<K, S> TaggedStream<K, S> {
146    fn new(key: K, inner: S) -> Self {
147        Self {
148            key,
149            inner,
150            exhausted: false,
151        }
152    }
153}
154
155impl<K, S> Stream for TaggedStream<K, S>
156where
157    K: Clone + Unpin,
158    S: Stream + Unpin,
159{
160    type Item = (K, Option<S::Item>);
161
162    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
163        if self.exhausted {
164            return Poll::Ready(None);
165        }
166
167        match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
168            Some(item) => Poll::Ready(Some((self.key.clone(), Some(item)))),
169            None => {
170                self.exhausted = true;
171
172                Poll::Ready(Some((self.key.clone(), None)))
173            }
174        }
175    }
176}
177
#[cfg(all(test, feature = "futures-timer"))]
mod tests {
    use futures::channel::mpsc;
    use futures_util::stream::{once, pending};
    use futures_util::SinkExt;
    use std::future::{poll_fn, ready, Future};
    use std::pin::Pin;
    use std::time::Instant;

    use super::*;

    #[test]
    fn cannot_push_more_than_capacity_tasks() {
        let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_secs(10)), 1);

        assert!(streams.try_push("ID_1", once(ready(()))).is_ok());
        // `matches!` only evaluates to a `bool`; it must be asserted to actually check anything.
        assert!(matches!(
            streams.try_push("ID_2", once(ready(()))),
            Err(PushError::BeyondCapacity(_))
        ));
    }

    #[test]
    fn cannot_push_the_same_id_few_times() {
        let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_secs(10)), 5);

        assert!(streams.try_push("ID", once(ready(()))).is_ok());
        // `matches!` only evaluates to a `bool`; it must be asserted to actually check anything.
        assert!(matches!(
            streams.try_push("ID", once(ready(()))),
            Err(PushError::Replaced(_))
        ));
    }

    #[tokio::test]
    async fn streams_timeout() {
        let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 1);

        let _ = streams.try_push("ID", pending::<()>());
        futures_timer::Delay::new(Duration::from_millis(150)).await;
        let (_, result) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;

        assert!(result.unwrap().is_err())
    }

    #[tokio::test]
    async fn timed_out_stream_gets_removed() {
        let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 1);

        let _ = streams.try_push("ID", pending::<()>());
        futures_timer::Delay::new(Duration::from_millis(150)).await;
        poll_fn(|cx| streams.poll_next_unpin(cx)).await;

        let poll = streams.poll_next_unpin(&mut Context::from_waker(
            futures_util::task::noop_waker_ref(),
        ));
        assert!(poll.is_pending())
    }

    #[test]
    fn removing_stream() {
        let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 1);

        let _ = streams.try_push("ID", stream::once(ready(())));

        {
            let cancelled_stream = streams.remove("ID");
            assert!(cancelled_stream.is_some());
        }

        let poll = streams.poll_next_unpin(&mut Context::from_waker(
            futures_util::task::noop_waker_ref(),
        ));

        assert!(poll.is_pending());
        assert_eq!(
            streams.len(),
            0,
            "resources of cancelled streams are cleaned up properly"
        );
    }

    #[tokio::test]
    async fn replaced_stream_is_still_registered() {
        let mut streams = StreamMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 3);

        let (mut tx1, rx1) = mpsc::channel(5);
        let (mut tx2, rx2) = mpsc::channel(5);

        let _ = streams.try_push("ID1", rx1);
        let _ = streams.try_push("ID2", rx2);

        let _ = tx2.send(2).await;
        let _ = tx1.send(1).await;
        let _ = tx2.send(3).await;
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
        assert_eq!(id, "ID1");
        assert_eq!(res.unwrap().unwrap(), 1);
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
        assert_eq!(id, "ID2");
        assert_eq!(res.unwrap().unwrap(), 2);
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
        assert_eq!(id, "ID2");
        assert_eq!(res.unwrap().unwrap(), 3);

        let (mut new_tx1, new_rx1) = mpsc::channel(5);
        let replaced = streams.try_push("ID1", new_rx1);
        assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_)));

        let _ = new_tx1.send(4).await;
        let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;

        assert_eq!(id, "ID1");
        assert_eq!(res.unwrap().unwrap(), 4);
    }

    // Each stream emits 1 item with delay, `Task` only has a capacity of 1, meaning they must be processed in sequence.
    // We stop after NUM_STREAMS tasks, meaning the overall execution must at least take DELAY * NUM_STREAMS.
    #[tokio::test]
    async fn backpressure() {
        const DELAY: Duration = Duration::from_millis(100);
        const NUM_STREAMS: u32 = 10;

        let start = Instant::now();
        Task::new(DELAY, NUM_STREAMS, 1).await;
        let duration = start.elapsed();

        assert!(duration >= DELAY * NUM_STREAMS);
    }

    struct Task {
        item_delay: Duration,
        num_streams: usize,
        num_processed: usize,
        inner: StreamMap<u8, ()>,
    }

    impl Task {
        fn new(item_delay: Duration, num_streams: u32, capacity: usize) -> Self {
            Self {
                item_delay,
                num_streams: num_streams as usize,
                num_processed: 0,
                inner: StreamMap::new(|| Delay::futures_timer(Duration::from_secs(60)), capacity),
            }
        }
    }

    impl Future for Task {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.get_mut();

            while this.num_processed < this.num_streams {
                match this.inner.poll_next_unpin(cx) {
                    Poll::Ready((_, Some(result))) => {
                        // The 60s timeout dwarfs the item delay, so no stream may time out here.
                        assert!(result.is_ok(), "Timeout is greater than item delay");

                        this.num_processed += 1;
                        continue;
                    }
                    Poll::Ready((_, None)) => {
                        continue;
                    }
                    _ => {}
                }

                if let Poll::Ready(()) = this.inner.poll_ready_unpin(cx) {
                    // We push the constant ID to prove that user can use the same ID if the stream was finished
                    let maybe_future = this
                        .inner
                        .try_push(1u8, once(futures_timer::Delay::new(this.item_delay)));
                    assert!(maybe_future.is_ok(), "we polled for readiness");

                    continue;
                }

                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }
}