fire_stream/util/
timeout.rs1use std::pin::Pin;
4use std::task::{ Context, Poll };
5use std::io::IoSlice;
6use std::future::Future;
7
8use tokio::net::{TcpStream, tcp};
9use tokio::time::{ Duration, Instant, sleep_until, Sleep };
10use tokio::io::{ self, AsyncRead, AsyncWrite, ReadBuf };
11
12#[derive(Debug)]
13pub struct TimeoutReader<S> {
14 stream: S,
15 timeout: Duration,
16 timer: Pin<Box<Sleep>>,
17 active: bool
18}
19
20impl<S> TimeoutReader<S>
21where S: AsyncRead {
22 pub fn new(stream: S, timeout: Duration) -> Self {
23 Self {
24 stream,
25 timeout,
26 timer: Box::pin(sleep_until(Instant::now())),
27 active: false
28 }
29 }
30
31 pub fn timeout(&self) -> Duration {
32 self.timeout
33 }
34
35 pub fn set_timeout(&mut self, timeout: Duration) {
36 self.timeout = timeout;
37 self.timer.as_mut().reset(Instant::now());
38 self.active = false;
39 }
40
41 pub fn poll_timeout(&mut self, cx: &mut Context) -> io::Result<()> {
42
43 if !self.active {
45 self.timer.as_mut().reset(Instant::now() + self.timeout);
46 self.active = true;
47 }
48
49 match self.timer.as_mut().poll(cx) {
50 Poll::Ready(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
52 Poll::Pending => Ok(())
54 }
55 }
56
57 #[allow(dead_code)]
58 pub fn inner_mut(&mut self) -> &mut S {
59 &mut self.stream
60 }
61}
62
63impl TimeoutReader<TcpStream> {
64 #[allow(dead_code)]
65 pub fn split<'a>(
66 &'a mut self
67 ) -> (TimeoutReader<tcp::ReadHalf<'a>>, tcp::WriteHalf<'a>) {
68 let (read, write) = self.stream.split();
69 (TimeoutReader::new(read, self.timeout), write)
70 }
71}
72
73impl<S> AsyncRead for TimeoutReader<S>
74where S: AsyncRead + Unpin {
75 fn poll_read(
76 mut self: Pin<&mut Self>,
77 cx: &mut Context,
78 buf: &mut ReadBuf<'_>
79 ) -> Poll< io::Result<()> > {
80
81 let r = Pin::new(&mut self.stream).poll_read(cx, buf);
83
84 match r {
85 Poll::Pending => self.get_mut().poll_timeout(cx)?,
86 _ => { self.active = false }
87 }
88
89 r
90 }
91}
92
93impl<S> AsyncWrite for TimeoutReader<S>
94where S: AsyncWrite + Unpin {
95 #[inline]
96 fn poll_write(
97 mut self: Pin<&mut Self>,
98 cx: &mut Context,
99 buf: &[u8]
100 ) -> Poll< io::Result<usize> > {
101 Pin::new(&mut self.stream).poll_write(cx, buf)
102 }
103
104 #[inline]
105 fn poll_flush(
106 mut self: Pin<&mut Self>,
107 cx: &mut Context
108 ) -> Poll< io::Result<()> > {
109 Pin::new(&mut self.stream).poll_flush(cx)
110 }
111
112 #[inline]
113 fn poll_shutdown(
114 mut self: Pin<&mut Self>,
115 cx: &mut Context
116 ) -> Poll< io::Result<()> > {
117 Pin::new(&mut self.stream).poll_shutdown(cx)
118 }
119
120 #[inline]
121 fn poll_write_vectored(
122 mut self: Pin<&mut Self>,
123 cx: &mut Context<'_>,
124 bufs: &[IoSlice<'_>]
125 ) -> Poll<io::Result<usize>> {
126 Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
127 }
128
129 #[inline]
130 fn is_write_vectored(&self) -> bool {
131 self.stream.is_write_vectored()
132 }
133}
134
135
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;
    use tokio::time::sleep;

    /// Test stream that stays pending until a fixed instant, then completes a
    /// read without filling the buffer.
    struct ReadyAt(Pin<Box<Sleep>>);

    impl ReadyAt {
        fn new(deadline: Instant) -> Self {
            ReadyAt(Box::pin(sleep_until(deadline)))
        }
    }

    impl AsyncRead for ReadyAt {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let timer = &mut self.get_mut().0;
            timer.as_mut().poll(cx).map(Ok)
        }
    }

    #[tokio::test]
    async fn read_timeout() {
        let start = Instant::now();
        let inner = ReadyAt::new(start + Duration::from_millis(500));
        let mut reader = TimeoutReader::new(inner, Duration::from_millis(100));
        let mut byte = [0u8; 1];

        // The first pending read runs into the 100ms deadline.
        let first = reader.read(&mut byte).await;
        assert_eq!(first.unwrap_err().kind(), io::ErrorKind::TimedOut);

        // The deadline is still armed, so a retry fails as well.
        let second = reader.read(&mut byte).await;
        assert_eq!(second.unwrap_err().kind(), io::ErrorKind::TimedOut);

        // After the inner stream becomes ready, reading succeeds.
        sleep(Duration::from_millis(400)).await;
        reader.read(&mut byte).await.unwrap();
    }

    #[tokio::test]
    async fn read_ok() {
        let start = Instant::now();
        let inner = ReadyAt::new(start + Duration::from_millis(100));
        let mut reader = TimeoutReader::new(inner, Duration::from_millis(500));
        let mut byte = [0u8; 1];

        // The inner stream is ready well before the 500ms timeout.
        reader.read(&mut byte).await.unwrap();
    }
}