futures_timeout/
lib.rs

1use std::io;
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use futures::future::FusedFuture;
8use futures::stream::FusedStream;
9use futures::{Future, FutureExt, Stream};
10use futures_timer::Delay;
11use pin_project::pin_project;
12
13pub trait TimeoutExt: Sized {
14    /// Requires a [`Future`] or [`Stream`] to complete before the specific duration has elapsed.
15    ///
16    /// **Note: If a [`Stream`] returns an item, the timer will reset until `Poll::Ready(None)` is returned**
17    fn timeout(self, duration: Duration) -> Timeout<Self> {
18        Timeout {
19            inner: self,
20            timer: Some(Delay::new(duration)),
21            duration,
22        }
23    }
24}
25
26impl<T: Sized> TimeoutExt for T {}
27
28#[derive(Debug)]
29#[pin_project]
30pub struct Timeout<T> {
31    #[pin]
32    inner: T,
33    timer: Option<Delay>,
34    duration: Duration,
35}
36
37impl<T> Timeout<T> {
38    /// Consumes Timeout and returns the inner value
39    pub fn into_inner(self) -> T {
40        self.inner
41    }
42}
43
44impl<T> Deref for Timeout<T> {
45    type Target = T;
46    fn deref(&self) -> &Self::Target {
47        &self.inner
48    }
49}
50
51impl<T> DerefMut for Timeout<T> {
52    fn deref_mut(&mut self) -> &mut Self::Target {
53        &mut self.inner
54    }
55}
56
57impl<T: Future> Future for Timeout<T> {
58    type Output = io::Result<T::Output>;
59
60    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
61        let this = self.project();
62
63        let Some(timer) = this.timer.as_mut() else {
64            return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
65        };
66
67        match this.inner.poll(cx) {
68            Poll::Ready(value) => return Poll::Ready(Ok(value)),
69            Poll::Pending => {}
70        }
71
72        futures::ready!(timer.poll_unpin(cx));
73        this.timer.take();
74        Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
75    }
76}
77
78impl<T: Future> FusedFuture for Timeout<T> {
79    fn is_terminated(&self) -> bool {
80        self.timer.is_none()
81    }
82}
83
84impl<T: Stream> Stream for Timeout<T> {
85    type Item = io::Result<T::Item>;
86
87    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
88        let this = self.project();
89
90        let Some(timer) = this.timer.as_mut() else {
91            return Poll::Ready(None);
92        };
93
94        match this.inner.poll_next(cx) {
95            Poll::Ready(Some(value)) => {
96                timer.reset(*this.duration);
97                return Poll::Ready(Some(Ok(value)));
98            }
99            Poll::Ready(None) => {
100                this.timer.take();
101                return Poll::Ready(None);
102            }
103            Poll::Pending => {}
104        }
105
106        futures::ready!(timer.poll_unpin(cx));
107        this.timer.take();
108        Poll::Ready(Some(Err(io::ErrorKind::TimedOut.into())))
109    }
110
111    fn size_hint(&self) -> (usize, Option<usize>) {
112        self.inner.size_hint()
113    }
114}
115
116impl<T: Stream> FusedStream for Timeout<T> {
117    fn is_terminated(&self) -> bool {
118        self.timer.is_none()
119    }
120}
121
#[cfg(test)]
mod test {
    use std::time::Duration;

    use futures::{StreamExt, TryStreamExt};

    use crate::TimeoutExt;

    /// Deadline used by the tests; short so the suite stays fast.
    const TIMEOUT: Duration = Duration::from_millis(20);
    /// Work that outlasts `TIMEOUT` by a wide margin.
    const SLOW: Duration = Duration::from_secs(1);

    #[test]
    fn fut_timeout() {
        futures::executor::block_on(futures_timer::Delay::new(SLOW).timeout(TIMEOUT))
            .expect_err("timeout after timer elapsed");
    }

    #[test]
    fn fut_completes_before_timeout() {
        let value = futures::executor::block_on(futures::future::ready(7).timeout(TIMEOUT))
            .expect("ready future must beat the timer");
        assert_eq!(value, 7);
    }

    #[test]
    fn stream_timeout() {
        futures::executor::block_on(async move {
            let mut st = futures::stream::once(async move {
                futures_timer::Delay::new(SLOW).await;
                0
            })
            .timeout(TIMEOUT)
            .boxed();

            st.try_next()
                .await
                .expect_err("timeout after timer elapsed");
            // Once a timeout has been reported the stream is finished.
            assert!(st.next().await.is_none());
        });
    }

    #[test]
    fn stream_items_pass_through() {
        let items = futures::executor::block_on(
            futures::stream::iter(vec![1, 2, 3])
                .timeout(TIMEOUT)
                .try_collect::<Vec<_>>(),
        )
        .expect("ready items must beat the timer");
        assert_eq!(items, vec![1, 2, 3]);
    }
}