// streamtools/fast_forward.rs

1use std::{pin::Pin, task::Poll};
2
3use futures::{stream::FusedStream, Stream, StreamExt};
4use pin_project_lite::pin_project;
5
6pin_project! {
7    /// Stream for the [`fast_forward`](crate::StreamTools::fast_forward) method.
8    #[must_use = "streams do nothing unless polled"]
9    pub struct FastForward<S> {
10        #[pin]
11        inner: Option<S>
12    }
13}
14
15impl<S> FastForward<S> {
16    pub(super) fn new(stream: S) -> Self {
17        Self {
18            inner: Some(stream),
19        }
20    }
21}
22
23impl<S> Stream for FastForward<S>
24where
25    S: Stream,
26{
27    type Item = S::Item;
28
29    fn poll_next(
30        self: Pin<&mut Self>,
31        cx: &mut std::task::Context<'_>,
32    ) -> std::task::Poll<Option<Self::Item>> {
33        let mut this = self.project();
34
35        let Some(mut inner) = this.inner.as_mut().as_pin_mut() else {
36            // Last time we polled, the inner stream terminated, but we yielded a value.
37            // If we are here then it's time to terminate.
38            return Poll::Ready(None)
39        };
40
41        let mut last_value = None;
42
43        while let Poll::Ready(ready) = inner.poll_next_unpin(cx) {
44            match ready {
45                Some(value) => {
46                    last_value = Some(value);
47                }
48                None => {
49                    // Clear inner so that if we poll again we will _definitely_ return Poll::Ready(None)
50                    this.inner.set(None);
51                    break;
52                }
53            }
54        }
55
56        match last_value {
57            Some(value) => Poll::Ready(Some(value)),
58            None => match this.inner.as_pin_mut() {
59                Some(_) => Poll::Pending, // The stream didn't terminate yet, so we must be pending
60                None => Poll::Ready(None), // The stream did terminate and there was no value seen, so we are done.
61            },
62        }
63    }
64}
65
66impl<S> FusedStream for FastForward<S>
67where
68    S: Stream,
69{
70    fn is_terminated(&self) -> bool {
71        self.inner.is_none()
72    }
73}
74
75impl<S> std::fmt::Debug for FastForward<S>
76where
77    S: std::fmt::Debug,
78{
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        f.debug_struct("FlattenSwitch")
81            .field("inner", &self.inner)
82            .finish()
83    }
84}
85
#[cfg(test)]
mod tests {
    use futures::{stream, SinkExt};
    use tokio_test::{assert_pending, assert_ready_eq};

    use super::*;

    /// Values that pile up between polls are collapsed to the newest one, and a final
    /// value is still delivered when the sender closes in the same poll.
    #[tokio::test]
    async fn test_fast_forward() {
        // `noop_waker_ref` already returns `&Waker`, so it is passed through directly.
        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        let (mut tx, rx) = futures::channel::mpsc::unbounded();

        let mut stream = FastForward::new(rx);

        assert_pending!(stream.poll_next_unpin(&mut cx));

        tx.send(1).await.unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));
        assert_pending!(stream.poll_next_unpin(&mut cx));

        tx.send(2).await.unwrap(); // This value gets skipped
        tx.send(3).await.unwrap();

        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(3));
        assert_pending!(stream.poll_next_unpin(&mut cx));

        // Send a value and then terminate the stream
        tx.send(4).await.unwrap();
        drop(tx);

        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(4)); // We still see the last value even though the stream was terminated
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None); // Should continue to return None if polled again
    }

    /// An inner stream with no items terminates immediately.
    #[tokio::test]
    async fn test_fast_forward_empty_stream() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        let mut stream = FastForward::new(stream::empty::<()>());
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    /// Closing the sender while the adapter is pending, with no new values, ends the stream.
    #[tokio::test]
    async fn test_fast_forward_sender_dropped_while_pending() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        let (mut tx, rx) = futures::channel::mpsc::unbounded();

        let mut stream = FastForward::new(rx);

        tx.send(1).await.unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));
        assert_pending!(stream.poll_next_unpin(&mut cx));

        drop(tx); // Terminate the stream without sending any more values
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    /// `is_terminated` flips to `true` on the poll where the inner stream ends, even
    /// when that poll still yields a final value.
    #[tokio::test]
    async fn test_fast_forward_is_terminated() {
        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        let (mut tx, rx) = futures::channel::mpsc::unbounded();

        let mut stream = FastForward::new(rx);
        assert!(!stream.is_terminated());

        tx.send(1).await.unwrap();
        drop(tx);

        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));
        assert!(stream.is_terminated());
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
        assert!(stream.is_terminated());
    }
}