futures_timeout/
lib.rs

1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use std::time::Duration;
5
6use futures::future::FusedFuture;
7use futures::stream::FusedStream;
8use futures::{Future, FutureExt, Stream};
9use futures_timer::Delay;
10use pin_project::pin_project;
11
12pub trait TimeoutExt: Sized {
13    /// Requires a [`Future`] or [`Stream`] to complete before the specific duration has elapsed.
14    ///
15    /// **Note: If a [`Stream`] returns an item, the timer will reset until `Poll::Ready(None)` is returned**
16    fn timeout(self, duration: Duration) -> Timeout<Self> {
17        Timeout {
18            inner: self,
19            timer: Some(Delay::new(duration)),
20            duration,
21        }
22    }
23}
24
25impl<T: Sized> TimeoutExt for T {}
26
27#[derive(Debug)]
28#[pin_project]
29pub struct Timeout<T> {
30    #[pin]
31    inner: T,
32    timer: Option<Delay>,
33    duration: Duration,
34}
35
36impl<T: Future> Future for Timeout<T> {
37    type Output = io::Result<T::Output>;
38
39    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40        let this = self.project();
41
42        let Some(timer) = this.timer.as_mut() else {
43            return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
44        };
45
46        match this.inner.poll(cx) {
47            Poll::Ready(value) => return Poll::Ready(Ok(value)),
48            Poll::Pending => {}
49        }
50
51        futures::ready!(timer.poll_unpin(cx));
52        this.timer.take();
53        return Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
54    }
55}
56
57impl<T: Future> FusedFuture for Timeout<T> {
58    fn is_terminated(&self) -> bool {
59        self.timer.is_none()
60    }
61}
62
63impl<T: Stream> Stream for Timeout<T> {
64    type Item = io::Result<T::Item>;
65
66    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67        let this = self.project();
68
69        let Some(timer) = this.timer.as_mut() else {
70            return Poll::Ready(None);
71        };
72
73        match this.inner.poll_next(cx) {
74            Poll::Ready(Some(value)) => {
75                timer.reset(*this.duration);
76                return Poll::Ready(Some(Ok(value)));
77            }
78            Poll::Ready(None) => {
79                this.timer.take();
80                return Poll::Ready(None);
81            }
82            Poll::Pending => {}
83        }
84
85        futures::ready!(timer.poll_unpin(cx));
86        this.timer.take();
87        return Poll::Ready(Some(Err(io::ErrorKind::TimedOut.into())));
88    }
89
90    fn size_hint(&self) -> (usize, Option<usize>) {
91        self.inner.size_hint()
92    }
93}
94
95impl<T: Stream> FusedStream for Timeout<T> {
96    fn is_terminated(&self) -> bool {
97        self.timer.is_none()
98    }
99}
100
#[cfg(test)]
mod test {
    use std::io;
    use std::time::Duration;

    use futures::{StreamExt, TryStreamExt};

    use crate::TimeoutExt;

    // Short durations keep the suite fast; the 10x gap between the slow inner work and the
    // timeout leaves ample room for scheduling jitter.
    const TIMEOUT: Duration = Duration::from_millis(50);
    const SLOW: Duration = Duration::from_millis(500);

    #[test]
    fn fut_timeout() {
        let err = futures::executor::block_on(futures_timer::Delay::new(SLOW).timeout(TIMEOUT))
            .expect_err("timeout after timer elapsed");
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    #[test]
    fn fut_completes_before_timeout() {
        let value = futures::executor::block_on(async { 7 }.timeout(TIMEOUT))
            .expect("future finished before the deadline");
        assert_eq!(value, 7);
    }

    #[test]
    fn stream_timeout() {
        futures::executor::block_on(async move {
            let mut st = futures::stream::once(async move {
                futures_timer::Delay::new(SLOW).await;
                0
            })
            .timeout(TIMEOUT)
            .boxed();

            let err = st
                .try_next()
                .await
                .expect_err("timeout after timer elapsed");
            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        });
    }

    #[test]
    fn stream_items_pass_through() {
        let items = futures::executor::block_on(
            futures::stream::iter(vec![1, 2, 3])
                .timeout(TIMEOUT)
                .try_collect::<Vec<_>>(),
        )
        .expect("items are ready immediately");
        assert_eq!(items, vec![1, 2, 3]);
    }

    #[test]
    fn stream_times_out_after_last_item_then_ends() {
        futures::executor::block_on(async move {
            let mut st = futures::stream::iter(vec![1])
                .chain(futures::stream::pending())
                .timeout(TIMEOUT)
                .boxed();

            assert_eq!(st.next().await.unwrap().unwrap(), 1);
            let err = st
                .next()
                .await
                .unwrap()
                .expect_err("timeout after timer elapsed");
            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
            assert!(st.next().await.is_none());
        });
    }
}
132}