1use std::io;
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use futures::future::FusedFuture;
8use futures::stream::FusedStream;
9use futures::{Future, FutureExt, Stream};
10use futures_timer::Delay;
11use pin_project::pin_project;
12
13pub trait TimeoutExt: Sized {
14 fn timeout(self, duration: Duration) -> Timeout<Self> {
18 Timeout {
19 inner: self,
20 timer: Some(Delay::new(duration)),
21 duration,
22 }
23 }
24}
25
26impl<T: Sized> TimeoutExt for T {}
27
/// A future or stream paired with a deadline timer; created by
/// [`TimeoutExt::timeout`].
///
/// As a future it resolves to `Ok(output)` if `inner` is ready when polled, and to
/// `Err(io::ErrorKind::TimedOut)` once the timer fires. As a stream it forwards
/// items as `Ok(item)`, restarting the timer after each one. If the timer fires
/// while it is waiting, it yields a single `Err(io::ErrorKind::TimedOut)` and then
/// ends.
///
/// Implements `Deref`/`DerefMut` to the wrapped value.
#[derive(Debug)]
#[pin_project]
pub struct Timeout<T> {
    /// The wrapped future or stream. It is structurally pinned (`#[pin]`) so that,
    /// after projection, it can be polled through `Pin<&mut T>`.
    #[pin]
    inner: T,
    /// The deadline timer. While `Some`, the wrapper is still active. `None` means
    /// it has finished (e.g. the timer fired): polls return immediately without
    /// touching `inner`, and `is_terminated` reports `true`.
    timer: Option<Delay>,
    /// The timeout period; the stream impl re-arms `timer` with it after every item.
    duration: Duration,
}
36
37impl<T> Deref for Timeout<T> {
38 type Target = T;
39 fn deref(&self) -> &Self::Target {
40 &self.inner
41 }
42}
43
44impl<T> DerefMut for Timeout<T> {
45 fn deref_mut(&mut self) -> &mut Self::Target {
46 &mut self.inner
47 }
48}
49
50impl<T: Future> Future for Timeout<T> {
51 type Output = io::Result<T::Output>;
52
53 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
54 let this = self.project();
55
56 let Some(timer) = this.timer.as_mut() else {
57 return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
58 };
59
60 match this.inner.poll(cx) {
61 Poll::Ready(value) => return Poll::Ready(Ok(value)),
62 Poll::Pending => {}
63 }
64
65 futures::ready!(timer.poll_unpin(cx));
66 this.timer.take();
67 Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
68 }
69}
70
71impl<T: Future> FusedFuture for Timeout<T> {
72 fn is_terminated(&self) -> bool {
73 self.timer.is_none()
74 }
75}
76
77impl<T: Stream> Stream for Timeout<T> {
78 type Item = io::Result<T::Item>;
79
80 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81 let this = self.project();
82
83 let Some(timer) = this.timer.as_mut() else {
84 return Poll::Ready(None);
85 };
86
87 match this.inner.poll_next(cx) {
88 Poll::Ready(Some(value)) => {
89 timer.reset(*this.duration);
90 return Poll::Ready(Some(Ok(value)));
91 }
92 Poll::Ready(None) => {
93 this.timer.take();
94 return Poll::Ready(None);
95 }
96 Poll::Pending => {}
97 }
98
99 futures::ready!(timer.poll_unpin(cx));
100 this.timer.take();
101 Poll::Ready(Some(Err(io::ErrorKind::TimedOut.into())))
102 }
103
104 fn size_hint(&self) -> (usize, Option<usize>) {
105 self.inner.size_hint()
106 }
107}
108
109impl<T: Stream> FusedStream for Timeout<T> {
110 fn is_terminated(&self) -> bool {
111 self.timer.is_none()
112 }
113}
114
#[cfg(test)]
mod test {
    use std::time::Duration;

    use futures::{StreamExt, TryStreamExt};

    use crate::TimeoutExt;

    /// Deadline for the "should time out" cases; short so the suite stays fast.
    const SHORT: Duration = Duration::from_millis(50);
    /// Delay that comfortably outlives `SHORT`; never awaited to completion.
    const LONG: Duration = Duration::from_secs(10);

    #[test]
    fn fut_timeout() {
        futures::executor::block_on(futures_timer::Delay::new(LONG).timeout(SHORT))
            .expect_err("timeout after timer elapsed");
    }

    #[test]
    fn fut_completes_before_timeout() {
        let value = futures::executor::block_on(futures::future::ready(7).timeout(LONG))
            .expect("inner future resolved before the deadline");
        assert_eq!(value, 7);
    }

    #[test]
    fn stream_timeout() {
        futures::executor::block_on(async move {
            let mut st = futures::stream::once(async move {
                futures_timer::Delay::new(LONG).await;
                0
            })
            .timeout(SHORT)
            .boxed();

            st.try_next()
                .await
                .expect_err("timeout after timer elapsed");
        });
    }

    #[test]
    fn stream_passes_items_through() {
        futures::executor::block_on(async move {
            let items = futures::stream::iter(vec![1, 2, 3])
                .timeout(LONG)
                .try_collect::<Vec<i32>>()
                .await
                .expect("items are ready immediately");
            assert_eq!(items, vec![1, 2, 3]);
        });
    }
}
146}