Skip to main content

hickory_server/server/
timeout_stream.rs

1use std::io;
2use std::mem;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use futures_util::FutureExt;
8use futures_util::stream::{Stream, StreamExt};
9use tokio::time::Sleep;
10use tracing::{debug, warn};
11
12/// This wraps the underlying Stream in a timeout.
13///
14/// Any `Ok(Poll::Ready(_))` from the underlying Stream will reset the timeout.
15pub struct TimeoutStream<S> {
16    stream: S,
17    timeout_duration: Duration,
18    timeout: Option<Pin<Box<Sleep>>>,
19}
20
21impl<S> TimeoutStream<S> {
22    /// Returns a new TimeoutStream
23    ///
24    /// # Arguments
25    ///
26    /// * `stream` - stream to wrap
27    /// * `timeout_duration` - timeout between each request, once exceed the connection is killed
28    pub fn new(stream: S, timeout_duration: Duration) -> Self {
29        Self {
30            stream,
31            timeout_duration,
32            timeout: None,
33        }
34    }
35
36    fn timeout(timeout_duration: Duration) -> Option<Pin<Box<Sleep>>> {
37        if timeout_duration > Duration::from_millis(0) {
38            Some(Box::pin(tokio::time::sleep(timeout_duration)))
39        } else {
40            None
41        }
42    }
43}
44
45impl<S, I> Stream for TimeoutStream<S>
46where
47    S: Stream<Item = Result<I, io::Error>> + Unpin,
48{
49    type Item = Result<I, io::Error>;
50
51    // somehow insert a timeout here...
52    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
53        // if the timer isn't set, set one now
54        if self.timeout.is_none() {
55            let timeout = Self::timeout(self.timeout_duration);
56            self.as_mut().timeout = timeout;
57        }
58
59        match self.stream.poll_next_unpin(cx) {
60            r @ Poll::Ready(_) => {
61                // reset the timeout to wait for the next request...
62                let timeout = if let Some(mut timeout) = Self::timeout(self.timeout_duration) {
63                    // ensure that interest in the Timeout is registered
64                    match timeout.poll_unpin(cx) {
65                        Poll::Ready(_) => {
66                            warn!("timeout fired immediately!");
67                            return Poll::Ready(Some(Err(io::Error::new(
68                                io::ErrorKind::TimedOut,
69                                "timeout fired immediately!",
70                            ))));
71                        }
72                        Poll::Pending => (), // this is the expected state...
73                    }
74
75                    Some(timeout)
76                } else {
77                    None
78                };
79
80                drop(mem::replace(&mut self.timeout, timeout));
81
82                r
83            }
84            Poll::Pending => {
85                if let Some(timeout) = &mut self.timeout {
86                    match timeout.poll_unpin(cx) {
87                        Poll::Pending => Poll::Pending,
88                        Poll::Ready(()) => {
89                            debug!("timeout on stream");
90                            Poll::Ready(Some(Err(io::Error::new(
91                                io::ErrorKind::TimedOut,
92                                format!("nothing ready in {:?}", self.timeout_duration),
93                            ))))
94                        }
95                    }
96                } else {
97                    Poll::Pending
98                }
99            }
100        }
101    }
102}
103
#[cfg(test)]
mod tests {
    use futures_util::stream::{TryStreamExt, iter};
    use test_support::subscribe;
    use tokio::runtime::Runtime;

    use super::*;

    /// Items, including errors, pass through untouched when the stream is never idle.
    #[test]
    fn test_no_timeout() {
        subscribe();

        // NOTE(review): unclear which item here is deprecated; verify and drop if unneeded.
        #[allow(deprecated)]
        let sequence = iter(vec![Ok(1), Err("error"), Ok(2)]).map_err(io::Error::other);
        let core = Runtime::new().expect("could not get core");

        let timeout_stream = TimeoutStream::new(sequence, Duration::from_secs(360));

        let (val, timeout_stream) = core.block_on(timeout_stream.into_future());
        assert_eq!(val.expect("nothing in stream").ok(), Some(1));

        let (error, timeout_stream) = core.block_on(timeout_stream.into_future());
        assert!(error.expect("nothing in stream").is_err());

        let (val, timeout_stream) = core.block_on(timeout_stream.into_future());
        assert_eq!(val.expect("nothing in stream").ok(), Some(2));

        let (val, _) = core.block_on(timeout_stream.into_future());
        assert!(val.is_none())
    }

    /// A stream that is never ready, used to force the timeout to fire.
    struct NeverStream {}

    impl Stream for NeverStream {
        type Item = Result<(), io::Error>;

        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Pending
        }
    }

    /// An idle stream produces a `TimedOut` error once the duration elapses.
    #[test]
    fn test_timeout() {
        subscribe();

        let core = Runtime::new().expect("could not get core");
        let timeout_stream = TimeoutStream::new(NeverStream {}, Duration::from_millis(1));

        let error = core
            .block_on(timeout_stream.into_future())
            .0
            .expect("nothing in stream")
            .expect_err("expected the stream to time out");
        assert_eq!(error.kind(), io::ErrorKind::TimedOut);
    }

    /// A zero duration disables the timeout and forwards every item.
    #[test]
    fn test_zero_duration_disables_timeout() {
        subscribe();

        let core = Runtime::new().expect("could not get core");
        let sequence = iter([Ok::<_, io::Error>(1), Ok(2)]);
        let timeout_stream = TimeoutStream::new(sequence, Duration::ZERO);

        let values = core.block_on(
            timeout_stream
                .map(|item| item.expect("unexpected error"))
                .collect::<Vec<_>>(),
        );
        assert_eq!(values, vec![1, 2]);
    }
}