futures_timeout/
lib.rs

1use std::io;
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use futures::future::FusedFuture;
8use futures::stream::FusedStream;
9use futures::{Future, FutureExt, Stream};
10use futures_timer::Delay;
11use pin_project::pin_project;
12
13pub trait TimeoutExt: Sized {
14    /// Requires a [`Future`] or [`Stream`] to complete before the specific duration has elapsed.
15    ///
16    /// **Note: If a [`Stream`] returns an item, the timer will reset until `Poll::Ready(None)` is returned**
17    fn timeout(self, duration: Duration) -> Timeout<Self> {
18        Timeout {
19            inner: self,
20            timer: Some(Delay::new(duration)),
21            duration,
22        }
23    }
24}
25
26impl<T: Sized> TimeoutExt for T {}
27
28#[derive(Debug)]
29#[pin_project]
30pub struct Timeout<T> {
31    #[pin]
32    inner: T,
33    timer: Option<Delay>,
34    duration: Duration,
35}
36
37impl<T> Deref for Timeout<T> {
38    type Target = T;
39    fn deref(&self) -> &Self::Target {
40        &self.inner
41    }
42}
43
44impl<T> DerefMut for Timeout<T> {
45    fn deref_mut(&mut self) -> &mut Self::Target {
46        &mut self.inner
47    }
48}
49
50impl<T: Future> Future for Timeout<T> {
51    type Output = io::Result<T::Output>;
52
53    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
54        let this = self.project();
55
56        let Some(timer) = this.timer.as_mut() else {
57            return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
58        };
59
60        match this.inner.poll(cx) {
61            Poll::Ready(value) => return Poll::Ready(Ok(value)),
62            Poll::Pending => {}
63        }
64
65        futures::ready!(timer.poll_unpin(cx));
66        this.timer.take();
67        Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
68    }
69}
70
71impl<T: Future> FusedFuture for Timeout<T> {
72    fn is_terminated(&self) -> bool {
73        self.timer.is_none()
74    }
75}
76
77impl<T: Stream> Stream for Timeout<T> {
78    type Item = io::Result<T::Item>;
79
80    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81        let this = self.project();
82
83        let Some(timer) = this.timer.as_mut() else {
84            return Poll::Ready(None);
85        };
86
87        match this.inner.poll_next(cx) {
88            Poll::Ready(Some(value)) => {
89                timer.reset(*this.duration);
90                return Poll::Ready(Some(Ok(value)));
91            }
92            Poll::Ready(None) => {
93                this.timer.take();
94                return Poll::Ready(None);
95            }
96            Poll::Pending => {}
97        }
98
99        futures::ready!(timer.poll_unpin(cx));
100        this.timer.take();
101        Poll::Ready(Some(Err(io::ErrorKind::TimedOut.into())))
102    }
103
104    fn size_hint(&self) -> (usize, Option<usize>) {
105        self.inner.size_hint()
106    }
107}
108
109impl<T: Stream> FusedStream for Timeout<T> {
110    fn is_terminated(&self) -> bool {
111        self.timer.is_none()
112    }
113}
114
#[cfg(test)]
mod test {
    use std::io;
    use std::time::Duration;

    use futures::{StreamExt, TryStreamExt};

    use crate::TimeoutExt;

    // Short deadlines keep the suite fast. The "slow" work is always far longer
    // than the deadline, so outcomes do not depend on scheduling jitter.
    const DEADLINE: Duration = Duration::from_millis(50);
    const SLOW: Duration = Duration::from_secs(10);

    #[test]
    fn fut_timeout() {
        let err = futures::executor::block_on(
            futures_timer::Delay::new(SLOW).timeout(DEADLINE),
        )
        .expect_err("timeout after timer elapsed");
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    #[test]
    fn fut_completes_before_timeout() {
        let value = futures::executor::block_on(async { 7 }.timeout(SLOW))
            .expect("ready future must not time out");
        assert_eq!(value, 7);
    }

    #[test]
    fn stream_timeout() {
        futures::executor::block_on(async move {
            let mut st = futures::stream::once(async move {
                futures_timer::Delay::new(SLOW).await;
                0
            })
            .timeout(DEADLINE)
            .boxed();

            let err = st
                .try_next()
                .await
                .expect_err("timeout after timer elapsed");
            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        });
    }

    #[test]
    fn stream_passes_items_through() {
        let items = futures::executor::block_on(
            futures::stream::iter(vec![1, 2, 3])
                .timeout(SLOW)
                .try_collect::<Vec<_>>(),
        )
        .expect("ready items must not time out");
        assert_eq!(items, vec![1, 2, 3]);
    }
}