streamtools/
fast_forward.rs1use std::{pin::Pin, task::Poll};
2
3use futures::{stream::FusedStream, Stream, StreamExt};
4use pin_project_lite::pin_project;
5
6pin_project! {
7 #[must_use = "streams do nothing unless polled"]
9 pub struct FastForward<S> {
10 #[pin]
11 inner: Option<S>
12 }
13}
14
15impl<S> FastForward<S> {
16 pub(super) fn new(stream: S) -> Self {
17 Self {
18 inner: Some(stream),
19 }
20 }
21}
22
23impl<S> Stream for FastForward<S>
24where
25 S: Stream,
26{
27 type Item = S::Item;
28
29 fn poll_next(
30 self: Pin<&mut Self>,
31 cx: &mut std::task::Context<'_>,
32 ) -> std::task::Poll<Option<Self::Item>> {
33 let mut this = self.project();
34
35 let Some(mut inner) = this.inner.as_mut().as_pin_mut() else {
36 return Poll::Ready(None)
39 };
40
41 let mut last_value = None;
42
43 while let Poll::Ready(ready) = inner.poll_next_unpin(cx) {
44 match ready {
45 Some(value) => {
46 last_value = Some(value);
47 }
48 None => {
49 this.inner.set(None);
51 break;
52 }
53 }
54 }
55
56 match last_value {
57 Some(value) => Poll::Ready(Some(value)),
58 None => match this.inner.as_pin_mut() {
59 Some(_) => Poll::Pending, None => Poll::Ready(None), },
62 }
63 }
64}
65
66impl<S> FusedStream for FastForward<S>
67where
68 S: Stream,
69{
70 fn is_terminated(&self) -> bool {
71 self.inner.is_none()
72 }
73}
74
75impl<S> std::fmt::Debug for FastForward<S>
76where
77 S: std::fmt::Debug,
78{
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.debug_struct("FlattenSwitch")
81 .field("inner", &self.inner)
82 .finish()
83 }
84}
85
#[cfg(test)]
mod tests {
    use futures::{stream, SinkExt};
    use tokio_test::{assert_pending, assert_ready_eq};

    use super::*;

    /// Builds a task context around the shared no-op waker so the tests can
    /// drive `poll_next_unpin` by hand.
    fn noop_context() -> std::task::Context<'static> {
        std::task::Context::from_waker(futures::task::noop_waker_ref())
    }

    #[tokio::test]
    async fn test_fast_forward() {
        let mut cx = noop_context();
        let (mut sender, receiver) = futures::channel::mpsc::unbounded();
        let mut fast_forward = FastForward::new(receiver);

        // Nothing has been sent yet.
        assert_pending!(fast_forward.poll_next_unpin(&mut cx));

        // A single item passes straight through.
        sender.send(1).await.unwrap();
        assert_ready_eq!(fast_forward.poll_next_unpin(&mut cx), Some(1));
        assert_pending!(fast_forward.poll_next_unpin(&mut cx));

        // Two queued items: only the newest one is yielded.
        sender.send(2).await.unwrap();
        sender.send(3).await.unwrap();

        assert_ready_eq!(fast_forward.poll_next_unpin(&mut cx), Some(3));
        assert_pending!(fast_forward.poll_next_unpin(&mut cx));

        // An item followed by the sender going away: the item is still
        // delivered, and the stream reports its end afterwards.
        sender.send(4).await.unwrap();
        drop(sender);

        assert_ready_eq!(fast_forward.poll_next_unpin(&mut cx), Some(4));
        assert_ready_eq!(fast_forward.poll_next_unpin(&mut cx), None);
        // Polling after termination keeps returning `None`.
        assert_ready_eq!(fast_forward.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_fast_forward_empty_stream() {
        let mut cx = noop_context();

        let mut fast_forward = FastForward::new(stream::empty::<()>());
        assert_ready_eq!(fast_forward.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_fast_forward_drop_before_polled() {
        let mut cx = noop_context();
        let (mut sender, receiver) = futures::channel::mpsc::unbounded();
        let mut fast_forward = FastForward::new(receiver);

        sender.send(1).await.unwrap();
        assert_ready_eq!(fast_forward.poll_next_unpin(&mut cx), Some(1));
        assert_pending!(fast_forward.poll_next_unpin(&mut cx));

        // With nothing queued, dropping the sender ends the stream directly.
        drop(sender);
        assert_ready_eq!(fast_forward.poll_next_unpin(&mut cx), None);
    }
}