hickory_server/server/
timeout_stream.rs1use std::io;
2use std::mem;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use futures_util::FutureExt;
8use futures_util::stream::{Stream, StreamExt};
9use tokio::time::Sleep;
10use tracing::{debug, warn};
11
12pub struct TimeoutStream<S> {
16 stream: S,
17 timeout_duration: Duration,
18 timeout: Option<Pin<Box<Sleep>>>,
19}
20
21impl<S> TimeoutStream<S> {
22 pub fn new(stream: S, timeout_duration: Duration) -> Self {
30 Self {
31 stream,
32 timeout_duration,
33 timeout: None,
34 }
35 }
36
37 fn timeout(timeout_duration: Duration) -> Option<Pin<Box<Sleep>>> {
38 if timeout_duration > Duration::from_millis(0) {
39 Some(Box::pin(tokio::time::sleep(timeout_duration)))
40 } else {
41 None
42 }
43 }
44}
45
46impl<S, I> Stream for TimeoutStream<S>
47where
48 S: Stream<Item = Result<I, io::Error>> + Unpin,
49{
50 type Item = Result<I, io::Error>;
51
52 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
54 if self.timeout.is_none() {
56 let timeout = Self::timeout(self.timeout_duration);
57 self.as_mut().timeout = timeout;
58 }
59
60 match self.stream.poll_next_unpin(cx) {
61 r @ Poll::Ready(_) => {
62 let timeout = if let Some(mut timeout) = Self::timeout(self.timeout_duration) {
64 match timeout.poll_unpin(cx) {
66 Poll::Ready(_) => {
67 warn!("timeout fired immediately!");
68 return Poll::Ready(Some(Err(io::Error::new(
69 io::ErrorKind::TimedOut,
70 "timeout fired immediately!",
71 ))));
72 }
73 Poll::Pending => (), }
75
76 Some(timeout)
77 } else {
78 None
79 };
80
81 drop(mem::replace(&mut self.timeout, timeout));
82
83 r
84 }
85 Poll::Pending => {
86 if let Some(timeout) = &mut self.timeout {
87 match timeout.poll_unpin(cx) {
88 Poll::Pending => Poll::Pending,
89 Poll::Ready(()) => {
90 debug!("timeout on stream");
91 Poll::Ready(Some(Err(io::Error::new(
92 io::ErrorKind::TimedOut,
93 format!("nothing ready in {:?}", self.timeout_duration),
94 ))))
95 }
96 }
97 } else {
98 Poll::Pending
99 }
100 }
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use futures_util::stream::{TryStreamExt, iter};
    use test_support::subscribe;
    use tokio::runtime::Runtime;

    use super::*;

    /// Results from an always-ready inner stream (including errors) pass through in
    /// order, and end-of-stream is propagated.
    #[test]
    fn test_no_timeout() {
        subscribe();

        // NOTE(review): it is not clear from this file what this `allow(deprecated)`
        // suppresses; confirm whether it is still needed.
        #[allow(deprecated)]
        let sequence = iter(vec![Ok(1), Err("error"), Ok(2)])
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e));
        let core = Runtime::new().expect("could not get core");

        let timeout_stream = TimeoutStream::new(sequence, Duration::from_secs(360));

        let (val, timeout_stream) = core.block_on(timeout_stream.into_future());
        assert_eq!(val.expect("nothing in stream").ok(), Some(1));

        let (error, timeout_stream) = core.block_on(timeout_stream.into_future());
        assert!(error.expect("nothing in stream").is_err());

        let (val, timeout_stream) = core.block_on(timeout_stream.into_future());
        assert_eq!(val.expect("nothing in stream").ok(), Some(2));

        let (val, _) = core.block_on(timeout_stream.into_future());
        assert!(val.is_none())
    }

    /// A stream that is never ready.
    struct NeverStream {}

    impl Stream for NeverStream {
        type Item = Result<(), io::Error>;

        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Pending
        }
    }

    /// A stream that never becomes ready yields a timeout error.
    #[test]
    fn test_timeout() {
        subscribe();

        let core = Runtime::new().expect("could not get core");
        let timeout_stream = TimeoutStream::new(NeverStream {}, Duration::from_millis(1));

        assert!(
            core.block_on(timeout_stream.into_future())
                .0
                .expect("nothing in stream")
                .is_err()
        );
    }

    /// A stream that yields exactly one item and then stays pending forever.
    struct OnceThenNever {
        yielded: bool,
    }

    impl Stream for OnceThenNever {
        type Item = Result<(), io::Error>;

        fn poll_next(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            if self.yielded {
                Poll::Pending
            } else {
                self.yielded = true;
                Poll::Ready(Some(Ok(())))
            }
        }
    }

    /// The item produced before going idle is delivered, and the restarted timer
    /// then reports the timeout.
    #[test]
    fn test_timeout_after_item() {
        subscribe();

        let core = Runtime::new().expect("could not get core");
        let timeout_stream =
            TimeoutStream::new(OnceThenNever { yielded: false }, Duration::from_millis(1));

        let (first, timeout_stream) = core.block_on(timeout_stream.into_future());
        assert!(first.expect("nothing in stream").is_ok());

        let (second, _) = core.block_on(timeout_stream.into_future());
        assert!(second.expect("nothing in stream").is_err());
    }
}