// fire_stream/util/timeout.rs

// Reference: https://docs.rs/tokio-io-timeout/0.4.0/src/tokio_io_timeout/lib.rs.html
2
3use std::pin::Pin;
4use std::task::{ Context, Poll };
5use std::io::IoSlice;
6use std::future::Future;
7
8use tokio::net::{TcpStream, tcp};
9use tokio::time::{ Duration, Instant, sleep_until, Sleep };
10use tokio::io::{ self, AsyncRead, AsyncWrite, ReadBuf };
11
12#[derive(Debug)]
13pub struct TimeoutReader<S> {
14	stream: S,
15	timeout: Duration,
16	timer: Pin<Box<Sleep>>,
17	active: bool
18}
19
20impl<S> TimeoutReader<S>
21where S: AsyncRead {
22	pub fn new(stream: S, timeout: Duration) -> Self {
23		Self {
24			stream,
25			timeout,
26			timer: Box::pin(sleep_until(Instant::now())),
27			active: false
28		}
29	}
30
31	pub fn timeout(&self) -> Duration {
32		self.timeout
33	}
34
35	pub fn set_timeout(&mut self, timeout: Duration) {
36		self.timeout = timeout;
37		self.timer.as_mut().reset(Instant::now());
38		self.active = false;
39	}
40
41	pub fn poll_timeout(&mut self, cx: &mut Context) -> io::Result<()> {
42
43		// activate if not activated
44		if !self.active {
45			self.timer.as_mut().reset(Instant::now() + self.timeout);
46			self.active = true;
47		}
48
49		match self.timer.as_mut().poll(cx) {
50			// timer has ended
51			Poll::Ready(_) => Err(io::Error::from(io::ErrorKind::TimedOut)),
52			// timer is still running
53			Poll::Pending => Ok(())
54		}
55	}
56
57	#[allow(dead_code)]
58	pub fn inner_mut(&mut self) -> &mut S {
59		&mut self.stream
60	}
61}
62
63impl TimeoutReader<TcpStream> {
64	#[allow(dead_code)]
65	pub fn split<'a>(
66		&'a mut self
67	) -> (TimeoutReader<tcp::ReadHalf<'a>>, tcp::WriteHalf<'a>) {
68		let (read, write) = self.stream.split();
69		(TimeoutReader::new(read, self.timeout), write)
70	}
71}
72
73impl<S> AsyncRead for TimeoutReader<S>
74where S: AsyncRead + Unpin {
75	fn poll_read(
76		mut self: Pin<&mut Self>,
77		cx: &mut Context,
78		buf: &mut ReadBuf<'_>
79	) -> Poll< io::Result<()> > {
80
81		// call poll read on stream
82		let r = Pin::new(&mut self.stream).poll_read(cx, buf);
83
84		match r {
85			Poll::Pending => self.get_mut().poll_timeout(cx)?,
86			_ => { self.active = false }
87		}
88
89		r
90	}
91}
92
93impl<S> AsyncWrite for TimeoutReader<S>
94where S: AsyncWrite + Unpin {
95	#[inline]
96	fn poll_write(
97		mut self: Pin<&mut Self>,
98		cx: &mut Context,
99		buf: &[u8]
100	) -> Poll< io::Result<usize> > {
101		Pin::new(&mut self.stream).poll_write(cx, buf)
102	}
103
104	#[inline]
105	fn poll_flush(
106		mut self: Pin<&mut Self>,
107		cx: &mut Context
108	) -> Poll< io::Result<()> > {
109		Pin::new(&mut self.stream).poll_flush(cx)
110	}
111
112	#[inline]
113	fn poll_shutdown(
114		mut self: Pin<&mut Self>,
115		cx: &mut Context
116	) -> Poll< io::Result<()> > {
117		Pin::new(&mut self.stream).poll_shutdown(cx)
118	}
119
120	#[inline]
121	fn poll_write_vectored(
122		mut self: Pin<&mut Self>,
123		cx: &mut Context<'_>,
124		bufs: &[IoSlice<'_>]
125	) -> Poll<io::Result<usize>> {
126		Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
127	}
128
129	#[inline]
130	fn is_write_vectored(&self) -> bool {
131		self.stream.is_write_vectored()
132	}
133}
134
135
#[cfg(test)]
mod tests {

	use super::*;
	use tokio::io::AsyncReadExt;

	/// Test stream whose reads stay pending until a fixed deadline and then
	/// complete with `Ok(())` without writing any bytes.
	struct DelayStream(Pin<Box<Sleep>>);

	impl DelayStream {
		fn new(until: Instant) -> Self {
			Self(Box::pin(sleep_until(until)))
		}
	}

	impl AsyncRead for DelayStream {
		fn poll_read(
			mut self: Pin<&mut Self>,
			cx: &mut Context,
			_: &mut ReadBuf<'_>
		) -> Poll<io::Result<()>> {
			self.0.as_mut().poll(cx)
				.map(|_| Ok(()))
		}
	}

	#[tokio::test]
	async fn read_timeout() {
		let start = Instant::now();
		let reader = DelayStream::new(start + Duration::from_millis(500));
		let mut reader = TimeoutReader::new(reader, Duration::from_millis(100));

		// The inner stream stays pending for 500ms, longer than the 100ms timeout.
		let r = reader.read(&mut [0]).await;
		assert_eq!(r.unwrap_err().kind(), io::ErrorKind::TimedOut);
		let r = reader.read(&mut [0]).await;
		assert_eq!(r.unwrap_err().kind(), io::ErrorKind::TimedOut);

		// Wait on an absolute deadline comfortably past the inner stream's 500ms
		// delay. The final read then finds the stream ready regardless of how
		// long the timed-out reads above actually took.
		sleep_until(start + Duration::from_millis(600)).await;
		reader.read(&mut [0]).await.unwrap();
	}

	#[tokio::test]
	async fn read_ok() {
		let reader = DelayStream::new(Instant::now() + Duration::from_millis(100));
		let mut reader = TimeoutReader::new(reader, Duration::from_millis(500));

		reader.read(&mut [0]).await.unwrap();
	}

}