// hickory_server/server/timeout_stream.rs
use std::io;
use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use futures_util::FutureExt;
use futures_util::stream::{Stream, StreamExt};
use tokio::time::Sleep;
use tracing::{debug, warn};

12pub struct TimeoutStream<S> {
16 stream: S,
17 timeout_duration: Duration,
18 timeout: Option<Pin<Box<Sleep>>>,
19}
20
21impl<S> TimeoutStream<S> {
22 pub fn new(stream: S, timeout_duration: Duration) -> Self {
29 Self {
30 stream,
31 timeout_duration,
32 timeout: None,
33 }
34 }
35
36 fn timeout(timeout_duration: Duration) -> Option<Pin<Box<Sleep>>> {
37 if timeout_duration > Duration::from_millis(0) {
38 Some(Box::pin(tokio::time::sleep(timeout_duration)))
39 } else {
40 None
41 }
42 }
43}
44
45impl<S, I> Stream for TimeoutStream<S>
46where
47 S: Stream<Item = Result<I, io::Error>> + Unpin,
48{
49 type Item = Result<I, io::Error>;
50
51 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
53 if self.timeout.is_none() {
55 let timeout = Self::timeout(self.timeout_duration);
56 self.as_mut().timeout = timeout;
57 }
58
59 match self.stream.poll_next_unpin(cx) {
60 r @ Poll::Ready(_) => {
61 let timeout = if let Some(mut timeout) = Self::timeout(self.timeout_duration) {
63 match timeout.poll_unpin(cx) {
65 Poll::Ready(_) => {
66 warn!("timeout fired immediately!");
67 return Poll::Ready(Some(Err(io::Error::new(
68 io::ErrorKind::TimedOut,
69 "timeout fired immediately!",
70 ))));
71 }
72 Poll::Pending => (), }
74
75 Some(timeout)
76 } else {
77 None
78 };
79
80 drop(mem::replace(&mut self.timeout, timeout));
81
82 r
83 }
84 Poll::Pending => {
85 if let Some(timeout) = &mut self.timeout {
86 match timeout.poll_unpin(cx) {
87 Poll::Pending => Poll::Pending,
88 Poll::Ready(()) => {
89 debug!("timeout on stream");
90 Poll::Ready(Some(Err(io::Error::new(
91 io::ErrorKind::TimedOut,
92 format!("nothing ready in {:?}", self.timeout_duration),
93 ))))
94 }
95 }
96 } else {
97 Poll::Pending
98 }
99 }
100 }
101 }
102}
103
#[cfg(test)]
mod tests {
    use futures_util::stream::{TryStreamExt, iter};
    use test_support::subscribe;
    use tokio::runtime::Runtime;

    use super::*;

    #[test]
    fn test_no_timeout() {
        subscribe();

        // NOTE(review): it is not clear from this file which deprecated item this allow
        // covers; confirm it is still needed and remove it if stale.
        #[allow(deprecated)]
        let sequence = iter(vec![Ok(1), Err("error"), Ok(2)]).map_err(io::Error::other);
        let core = Runtime::new().expect("could not get core");

        let timeout_stream = TimeoutStream::new(sequence, Duration::from_secs(360));

        let (val, timeout_stream) = core.block_on(timeout_stream.into_future());
        assert_eq!(val.expect("nothing in stream").ok(), Some(1));

        let (error, timeout_stream) = core.block_on(timeout_stream.into_future());
        assert!(error.expect("nothing in stream").is_err());

        let (val, timeout_stream) = core.block_on(timeout_stream.into_future());
        assert_eq!(val.expect("nothing in stream").ok(), Some(2));

        let (val, _) = core.block_on(timeout_stream.into_future());
        assert!(val.is_none())
    }

    /// A stream that is always pending, used to drive `TimeoutStream` into its timeout path.
    struct NeverStream {}

    impl Stream for NeverStream {
        type Item = Result<(), io::Error>;

        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Pending
        }
    }

    #[test]
    fn test_timeout() {
        subscribe();

        let core = Runtime::new().expect("could not get core");
        let timeout_stream = TimeoutStream::new(NeverStream {}, Duration::from_millis(1));

        // The error must be the timeout itself, not just any error.
        let error = core
            .block_on(timeout_stream.into_future())
            .0
            .expect("nothing in stream")
            .expect_err("expected a timeout error");
        assert_eq!(error.kind(), io::ErrorKind::TimedOut);
    }

    #[test]
    fn test_zero_duration_disables_timeout() {
        subscribe();

        let core = Runtime::new().expect("could not get core");
        let mut timeout_stream = TimeoutStream::new(NeverStream {}, Duration::ZERO);

        // With the timeout disabled the wrapped stream's `Pending` is passed straight through,
        // so only this outer deadline can end the wait. It is created inside the runtime
        // because tokio timers need a runtime context.
        let outcome = core.block_on(async {
            tokio::time::timeout(Duration::from_millis(20), timeout_stream.next()).await
        });
        assert!(outcome.is_err(), "a zero duration should never time out");
    }
}