use std::time::{Duration, Instant};
/// Types whose blocking reads can be bounded by a time limit.
pub trait ReadTimeout {
    /// Sets the time limit for subsequent reads; `None` means no limit.
    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()>;
}
/// Types whose blocking writes can be bounded by a time limit.
pub trait WriteTimeout {
    /// Sets the time limit for subsequent writes; `None` means no limit.
    fn set_write_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()>;
}
// In-memory readers never block, so there is nothing for a read time limit to bound:
// these implementations accept any value and always succeed.
impl ReadTimeout for std::io::Empty {
    fn set_read_timeout(&mut self, _timeout: Option<Duration>) -> std::io::Result<()> {
        Ok(())
    }
}
impl<T: AsRef<[u8]>> ReadTimeout for std::io::Cursor<T> {
    fn set_read_timeout(&mut self, _timeout: Option<Duration>) -> std::io::Result<()> {
        Ok(())
    }
}
impl ReadTimeout for std::collections::VecDeque<u8> {
    fn set_read_timeout(&mut self, _timeout: Option<Duration>) -> std::io::Result<()> {
        Ok(())
    }
}
impl<R: ReadTimeout, W> ReadTimeout for super::Bidir<R, W> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
self.0.set_read_timeout(timeout)
}
}
// In-memory writers (and the discarding `Empty`/`Sink`) never block, so there is
// nothing for a write time limit to bound: these implementations accept any value and
// always succeed.
impl WriteTimeout for std::io::Empty {
    fn set_write_timeout(&mut self, _timeout: Option<Duration>) -> std::io::Result<()> {
        Ok(())
    }
}
impl WriteTimeout for std::io::Sink {
    fn set_write_timeout(&mut self, _timeout: Option<Duration>) -> std::io::Result<()> {
        Ok(())
    }
}
impl WriteTimeout for Vec<u8> {
    fn set_write_timeout(&mut self, _timeout: Option<Duration>) -> std::io::Result<()> {
        Ok(())
    }
}
impl WriteTimeout for std::collections::VecDeque<u8> {
    fn set_write_timeout(&mut self, _timeout: Option<Duration>) -> std::io::Result<()> {
        Ok(())
    }
}
impl<R, W: WriteTimeout> WriteTimeout for super::Bidir<R, W> {
fn set_write_timeout(&mut self, timeout: Option<Duration>) -> std::io::Result<()> {
self.1.set_write_timeout(timeout)
}
}
/// Wrapper whose time-limit setters are no-ops that always succeed.
///
/// `#[repr(transparent)]` gives it the same layout as the wrapped `T`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Default)]
pub struct NoTimeout<T>(pub T);
impl<T> ReadTimeout for NoTimeout<T> {
    /// Ignores `_timeout` and reports success.
    fn set_read_timeout(&mut self, _timeout: Option<Duration>) -> std::io::Result<()> {
        Ok(())
    }
}
impl<T> WriteTimeout for NoTimeout<T> {
    /// Ignores `_timeout` and reports success.
    fn set_write_timeout(&mut self, _timeout: Option<Duration>) -> std::io::Result<()> {
        Ok(())
    }
}
/// Chooses which time limit applies to an operation.
///
/// Returns `(new_timeout, true)` when `new_timeout` is set, otherwise
/// `(old_timeout, false)`. The flag therefore reports whether the chosen limit came
/// from `new_timeout`.
fn timeout_fallback(
    new_timeout: Option<Duration>,
    old_timeout: Option<Duration>,
) -> (Option<Duration>, bool) {
    match new_timeout {
        Some(limit) => (Some(limit), true),
        None => (old_timeout, false),
    }
}
/// Turns timeout-like I/O failures into `Ok(None)`.
///
/// `Ok(v)` becomes `Ok(Some(v))`. Errors of kind `TimedOut` or `WouldBlock` become
/// `Ok(None)`. Every other error is returned unchanged.
pub(super) fn filter_time_error<T>(result: std::io::Result<T>) -> std::io::Result<Option<T>> {
    use std::io::ErrorKind::{TimedOut, WouldBlock};
    result.map(Some).or_else(|err| match err.kind() {
        TimedOut | WouldBlock => Ok(None),
        _ => Err(err),
    })
}
/// Time limits to apply to a connection's reads and writes.
///
/// The setters raise any limit shorter than `TimeLimits::MIN_TIMEOUT` (one second) to
/// that value.
#[derive(Default)]
pub(crate) struct TimeLimits {
    /// Read time limit; `None` means no read limit.
    read: Option<Duration>,
    /// Write time limit; `None` means no write limit.
    write: Option<Duration>,
    /// Whether `write` still has to be pushed down to the connection.
    update_write: bool,
}
impl TimeLimits {
    /// Shortest limit the setters will store.
    ///
    /// NOTE(review): the reason for a one-second floor is not visible here. It does keep
    /// zero durations away from std socket `set_*_timeout`, which rejects zero. Confirm
    /// the intended minimum.
    const MIN_TIMEOUT: Duration = Duration::from_secs(1);

    /// Raises a set limit to at least `MIN_TIMEOUT`; `None` is passed through.
    fn clamp_timeout(timeout: Option<Duration>) -> Option<Duration> {
        timeout.map(|t| t.max(Self::MIN_TIMEOUT))
    }
    /// Marks the write limit as needing to be re-applied to the connection.
    pub fn require_update(&mut self) {
        self.update_write = true;
    }
    /// Returns the stored read time limit.
    pub fn read_timeout(&self) -> Option<Duration> {
        self.read
    }
    /// Sets the read time limit, raised to at least one second; `None` clears it.
    pub fn set_read_timeout(&mut self, timeout: Option<Duration>) -> &mut Self {
        self.read = Self::clamp_timeout(timeout);
        self
    }
    /// Sets the write time limit (raised to at least one second; `None` clears it).
    ///
    /// Also marks it for re-application to the connection.
    pub fn set_write_timeout(&mut self, timeout: Option<Duration>) -> &mut Self {
        // Previously `self.update_write |= true`, which is just an obscured assignment.
        self.require_update();
        self.write = Self::clamp_timeout(timeout);
        self
    }
}
/// Synchronous view of a connection that enforces one deadline across all reads.
///
/// Writes go straight through to the connection. Before every read, the time left
/// until the deadline is applied to the connection as its read timeout.
pub(super) struct TimeLimitedSync<'a, C> {
    conn: &'a mut C,
    /// Instant after which reads fail with `TimedOut`; `None` means no deadline.
    read: Option<Instant>,
}
impl<'a, C: super::Connection> TimeLimitedSync<'a, C> {
    /// Wraps `conn` and applies any pending write-timeout change from `timeouts`.
    ///
    /// The read deadline is `write_after` from now if given, otherwise the read limit
    /// in `timeouts` from now (if any). The returned `bool` is `true` when the deadline
    /// came from `write_after`.
    ///
    /// # Errors
    /// Propagates errors from setting the connection's write or read timeout.
    pub fn new(
        conn: &'a mut C,
        timeouts: &mut TimeLimits,
        write_after: Option<Duration>,
    ) -> std::io::Result<(Self, bool)> {
        let TimeLimits { read, write, update_write, .. } = *timeouts;
        if update_write {
            conn.set_write_timeout(write)?;
            // Only cleared after success, so a failed update is retried next time.
            timeouts.update_write = false;
        }
        let (read, constrained) = timeout_fallback(write_after, read);
        // An unrepresentable deadline (overflow) is treated as "no deadline".
        let read = read.and_then(|dur| Instant::now().checked_add(dur));
        if read.is_none() {
            // `update_timeout` narrows the connection's read timeout to whatever time
            // was left on an earlier deadline, and nothing else resets it. Clear it here,
            // or reads meant to be unbounded would fail spuriously on that stale timeout.
            conn.set_read_timeout(None)?;
        }
        Ok((TimeLimitedSync { conn, read }, constrained))
    }
    /// Applies the time remaining until the deadline as the connection's read timeout.
    ///
    /// Does nothing when there is no deadline.
    ///
    /// # Errors
    /// Returns `TimedOut` if the deadline has already passed. Otherwise propagates
    /// errors from setting the read timeout.
    fn update_timeout(&mut self) -> std::io::Result<()> {
        if let Some(deadline) = self.read {
            let duration = deadline.saturating_duration_since(Instant::now());
            // A zero remaining time means the deadline has passed. Report it directly,
            // since a zero read timeout is not a usable "time left".
            if duration != Duration::ZERO {
                self.conn.set_read_timeout(Some(duration))?;
            } else {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    "collective read time limit elapsed",
                ));
            }
        }
        Ok(())
    }
}
// Forwards every write operation straight to the connection's writer. This impl does
// not touch the connection's write timeout; that is applied in `TimeLimitedSync::new`.
impl<'a, C: super::Connection> std::io::Write for TimeLimitedSync<'a, C> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.conn.as_write().write(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
self.conn.as_write().flush()
}
// `write_all` and `write_fmt` are forwarded explicitly rather than left to the
// provided defaults, so the underlying writer's own implementations run.
fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
self.conn.as_write().write_all(buf)
}
fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> {
self.conn.as_write().write_fmt(fmt)
}
}
// Every read first re-checks the shared deadline and narrows the connection's read
// timeout to the time remaining. Only then does it read from the connection's
// buffered reader.
impl<'a, C: super::Connection> std::io::Read for TimeLimitedSync<'a, C> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.update_timeout()
            .and_then(|()| self.conn.as_bufread().read(buf))
    }
}
// Buffered reads are subject to the same deadline as `Read::read`.
impl<'a, C: super::Connection> std::io::BufRead for TimeLimitedSync<'a, C> {
// Re-checks the deadline and narrows the read timeout before asking the connection
// for more buffered data.
fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
self.update_timeout()?;
self.conn.as_bufread().fill_buf()
}
// Only forwards to the connection's reader; no deadline check here.
fn consume(&mut self, amt: usize) {
self.conn.as_bufread().consume(amt);
}
}
/// Async view of a connection that applies the write time limit from `TimeLimits`.
///
/// Reads pass straight through. A write or flush fails with `TimedOut` once it has been
/// stalled (the underlying writer kept returning `Pending`) for the whole write limit.
#[cfg(feature = "tokio")]
pub(super) struct TimeLimitedTokio<'a, C> {
    conn: &'a mut C,
    /// Write time limit copied from `TimeLimits`; `None` disables the limit.
    write: Option<Duration>,
    /// Timer started when a write or flush first stalls; cleared when one completes.
    ///
    /// Boxed so `Self` stays `Unpin`, which the `Pin` accesses below rely on.
    write_timer: Option<std::pin::Pin<Box<tokio::time::Sleep>>>,
}
#[cfg(feature = "tokio")]
impl<'a, C: super::ConnectionTokio> TimeLimitedTokio<'a, C> {
    /// Wraps `conn`, taking the write time limit from `timeouts`.
    pub fn new(conn: &'a mut C, timeouts: &TimeLimits) -> Self {
        TimeLimitedTokio { conn, write: timeouts.write, write_timer: None }
    }
    /// Polls the write time limit for an operation whose last poll returned `Pending`.
    ///
    /// Returns `Ready(error)` once the operation has been stalled for the whole write
    /// limit. Otherwise returns `Pending`, including when no write limit is set.
    fn poll_write_timeout(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Error> {
        use std::future::Future;
        let this = std::pin::Pin::into_inner(self);
        if let Some(timeout) = this.write {
            // The timer must outlive this call. A `Sleep` created and dropped within one
            // poll is never ready, and dropping it cancels its wakeup, so the limit
            // would never fire. Keep it until the stalled operation finishes or times out.
            let timer = this
                .write_timer
                .get_or_insert_with(|| Box::pin(tokio::time::sleep(timeout)));
            if timer.as_mut().poll(cx).is_ready() {
                this.write_timer = None;
                return std::task::Poll::Ready(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    "async write time limit elapsed",
                ));
            }
        }
        std::task::Poll::Pending
    }
}
// Reads are forwarded to the connection's reader without any time limit.
#[cfg(feature = "tokio")]
impl<'a, C: super::ConnectionTokio> tokio::io::AsyncRead for TimeLimitedTokio<'a, C> {
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        self.conn.as_bufread().poll_read(cx, buf)
    }
}
#[cfg(feature = "tokio")]
impl<'a, C: super::ConnectionTokio> tokio::io::AsyncBufRead for TimeLimitedTokio<'a, C> {
    fn poll_fill_buf(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<&[u8]>> {
        // Unpin to a plain `&mut Self` so the returned slice can borrow for the
        // whole lifetime of `self`.
        std::pin::Pin::into_inner(self).conn.as_bufread().poll_fill_buf(cx)
    }
    fn consume(mut self: std::pin::Pin<&mut Self>, amt: usize) {
        self.conn.as_bufread().consume(amt);
    }
}
#[cfg(feature = "tokio")]
impl<'a, C: super::ConnectionTokio> tokio::io::AsyncWrite for TimeLimitedTokio<'a, C> {
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        let writer = std::pin::Pin::new(self.conn.as_write());
        if let std::task::Poll::Ready(v) = writer.poll_write(cx, buf) {
            // The write completed, so any running stall timer no longer applies.
            self.write_timer = None;
            return std::task::Poll::Ready(v);
        }
        self.poll_write_timeout(cx).map(Err)
    }
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        let writer = std::pin::Pin::new(self.conn.as_write());
        if let std::task::Poll::Ready(v) = writer.poll_flush(cx) {
            // The flush completed, so any running stall timer no longer applies.
            self.write_timer = None;
            return std::task::Poll::Ready(v);
        }
        self.poll_write_timeout(cx).map(Err)
    }
    // Shutdown is forwarded without a time limit.
    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        std::pin::Pin::new(self.conn.as_write()).poll_shutdown(cx)
    }
}
/// Awaits `fut` under the time limit chosen by `timeout_fallback(new_timeout, old_timeout)`.
///
/// Returns:
/// - `Ok(Ok(value))` when the future succeeds.
/// - `Ok(Err(flag))` when it timed out, or itself failed with `TimedOut`/`WouldBlock`.
///   `flag` is `true` iff the limit came from `new_timeout`.
/// - `Err(e)` for any other I/O error.
#[cfg(feature = "tokio")]
pub(super) async fn timed_io<T, F: std::future::Future<Output = std::io::Result<T>>>(
    fut: F,
    new_timeout: Option<Duration>,
    old_timeout: Option<Duration>,
) -> std::io::Result<Result<T, bool>> {
    let (limit, time_error_expected) = timeout_fallback(new_timeout, old_timeout);
    let outcome = match limit {
        // tokio converts its `Elapsed` error into an `io::Error`, which
        // `filter_time_error` then treats as a timeout.
        Some(dur) => tokio::time::timeout(dur, fut)
            .await
            .unwrap_or_else(|elapsed| Err(elapsed.into())),
        None => fut.await,
    };
    Ok(match filter_time_error(outcome)? {
        Some(value) => Ok(value),
        None => Err(time_error_expected),
    })
}