// futures_ext/stream/stream_with_timeout.rs
use std::pin::Pin;
use std::time::Duration;

use futures::future::Future;
use futures::stream::FusedStream;
use futures::stream::Stream;
use futures::task::Context;
use futures::task::Poll;
use pin_project::pin_project;
use thiserror::Error;
use tokio::time::Sleep;
20
21#[derive(Debug, Error)]
23#[error("Stream timeout with duration {:?} was exceeded", .0)]
24pub struct StreamTimeoutError(Duration);
25
26#[pin_project]
29pub struct StreamWithTimeout<S> {
30 #[pin]
31 inner: S,
32 duration: Duration,
33 done: bool,
34 #[pin]
35 deadline: Option<Sleep>,
36}
37
38impl<S> StreamWithTimeout<S> {
39 pub fn new(inner: S, duration: Duration) -> Self {
41 Self {
42 inner,
43 duration,
44 done: false,
45 deadline: None,
46 }
47 }
48}
49
50impl<S: Stream> Stream for StreamWithTimeout<S> {
51 type Item = Result<<S as Stream>::Item, StreamTimeoutError>;
52
53 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
54 let mut this = self.project();
55
56 if *this.done {
57 return Poll::Ready(None);
58 }
59
60 let duration = *this.duration;
61
62 if this.deadline.is_none() {
63 this.deadline.set(Some(tokio::time::sleep(duration)));
64 }
65
66 match this.deadline.as_pin_mut().unwrap().poll(cx) {
68 Poll::Ready(()) => {
69 *this.done = true;
70 return Poll::Ready(Some(Err(StreamTimeoutError(duration))));
71 }
72 Poll::Pending => {
73 }
75 }
76
77 let res = futures::ready!(this.inner.poll_next(cx));
80 if res.is_none() {
81 *this.done = true;
82 }
83
84 Poll::Ready(Ok(res).transpose())
85 }
86}
87
#[cfg(test)]
mod test {
    use anyhow::Error;
    use futures::stream::StreamExt;
    use futures::stream::TryStreamExt;

    use super::*;

    // An inner stream that stalls past the deadline: the wrapper yields the first
    // item, then a timeout error, then ends.
    #[tokio::test]
    async fn test_stream_timeout() -> Result<(), Error> {
        // Pause tokio's clock so the test controls when time moves.
        tokio::time::pause();

        let s = async_stream::stream! {
            yield Result::<(), Error>::Ok(());
            // Move the clock past the 1s timeout between the two items.
            tokio::time::advance(Duration::from_secs(2)).await;
            yield Result::<(), Error>::Ok(());
        };

        let mut s = StreamWithTimeout::new(s.boxed(), Duration::from_secs(1)).boxed();

        assert!(s.try_next().await?.is_some());
        // The second item is never seen: the timeout error is reported instead.
        assert!(s.try_next().await.is_err());
        // After the timeout the wrapper is terminated.
        assert!(s.try_next().await?.is_none());

        Ok(())
    }

    // An inner stream that finishes before the deadline ends normally, and time
    // passing afterwards does not produce a late timeout error.
    #[tokio::test]
    async fn test_stream_done_before_timeout() -> Result<(), Error> {
        tokio::time::pause();

        let s = async_stream::stream! {
            yield Result::<(), Error>::Ok(());
            yield Result::<(), Error>::Ok(());
        };

        let mut s = StreamWithTimeout::new(s.boxed(), Duration::from_secs(1)).boxed();

        assert!(s.try_next().await?.is_some());
        assert!(s.try_next().await?.is_some());
        assert!(s.try_next().await?.is_none());

        // Move past the deadline after completion; the wrapper must stay ended.
        tokio::time::advance(Duration::from_secs(2)).await;

        assert!(s.try_next().await?.is_none());

        Ok(())
    }

    // Time that elapses between constructing the wrapper and first polling it
    // does not count toward the timeout.
    #[tokio::test]
    async fn test_clock_starts_at_poll() -> Result<(), Error> {
        tokio::time::pause();

        let s = async_stream::stream! {
            yield Result::<(), Error>::Ok(());
            yield Result::<(), Error>::Ok(());
        };
        let mut s = StreamWithTimeout::new(s.boxed(), Duration::from_secs(1)).boxed();

        // Longer than the timeout, but before the first poll.
        tokio::time::advance(Duration::from_secs(2)).await;

        assert!(s.try_next().await?.is_some());
        assert!(s.try_next().await?.is_some());
        assert!(s.try_next().await?.is_none());

        Ok(())
    }
}