use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::time::{Sleep, sleep};
/// Async byte-stream wrapper that enforces an idle timeout.
///
/// Every successful read or write on the wrapper pushes the idle deadline back.
/// When a poll finds the inner stream not ready and the deadline has already
/// passed, the operation fails with `std::io::ErrorKind::TimedOut`. Without an
/// interval (`reaper == None`) the wrapper simply forwards I/O.
pub struct IdleTimeout<S> {
    // The wrapped stream that all I/O is forwarded to.
    inner: S,
    // Idle-deadline state; `None` disables the timeout entirely.
    reaper: Option<Reaper>,
}
/// Idle-deadline bookkeeping for `IdleTimeout`.
struct Reaper {
    // How long the stream may go without successful I/O before timing out.
    interval: Duration,
    // Timer that completes `interval` after it was last (re-)created. It is kept
    // heap-pinned so it can be polled and replaced in place via `Pin` methods.
    deadline: Pin<Box<Sleep>>,
}
impl<S> IdleTimeout<S> {
    /// Wraps `inner`, optionally enforcing an idle timeout.
    ///
    /// * `idle = Some(interval)` arms a deadline `interval` from now. The deadline
    ///   is pushed back after every successful read or write on the wrapper.
    /// * `idle = None` disables the timeout; the wrapper only forwards I/O.
    ///
    /// NOTE(review): arming calls `tokio::time::sleep`, which presumably requires
    /// an active Tokio runtime; confirm for callers that construct this eagerly.
    pub fn new(inner: S, idle: Option<Duration>) -> Self {
        let arm = |interval: Duration| Reaper {
            interval,
            deadline: Box::pin(sleep(interval)),
        };
        Self {
            reaper: idle.map(arm),
            inner,
        }
    }

    /// Pushes the idle deadline back to `interval` from now.
    ///
    /// This is called after observed I/O progress. The fresh timer is polled once
    /// with `cx` so the current task's waker is registered, and the poll outcome is
    /// ignored. It does nothing when the timeout is disabled.
    fn reset_timeout(&mut self, cx: &mut Context<'_>) {
        let Some(reaper) = self.reaper.as_mut() else {
            return;
        };
        let rearmed = sleep(reaper.interval);
        reaper.deadline.set(rearmed);
        // Only the waker-registration side effect of this poll matters here.
        let _ = reaper.deadline.as_mut().poll(cx);
    }

    /// Reports whether the idle deadline has passed.
    ///
    /// The timer is polled with `cx`, so if it has not fired yet the current task
    /// is woken when it does. Always returns `false` when the timeout is disabled.
    fn idle_expired(&mut self, cx: &mut Context<'_>) -> bool {
        match self.reaper.as_mut() {
            Some(reaper) => reaper.deadline.as_mut().poll(cx).is_ready(),
            None => false,
        }
    }
}
/// Builds the error reported when a stream has sat idle past its deadline.
///
/// The error kind is `TimedOut` and the message is "connection idle timeout".
fn idle_timeout_err() -> std::io::Error {
    use std::io::{Error, ErrorKind};
    Error::new(ErrorKind::TimedOut, "connection idle timeout")
}
impl<S: AsyncRead + Unpin> AsyncRead for IdleTimeout<S> {
    /// Reads from the inner stream and treats any successful read as activity.
    ///
    /// * `Ready(Ok)` re-arms the idle deadline.
    /// * `Ready(Err)` is passed through untouched.
    /// * `Pending` consults the deadline. If it has elapsed, the read fails with
    ///   `ErrorKind::TimedOut`; otherwise the task stays parked, and the deadline
    ///   also holds its waker.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let polled = Pin::new(&mut this.inner).poll_read(cx, buf);
        if polled.is_pending() {
            // The inner stream has nothing yet, so only now does the idle deadline matter.
            return if this.idle_expired(cx) {
                Poll::Ready(Err(idle_timeout_err()))
            } else {
                Poll::Pending
            };
        }
        if matches!(polled, Poll::Ready(Ok(()))) {
            this.reset_timeout(cx);
        }
        polled
    }
}
impl<S: AsyncWrite + Unpin> AsyncWrite for IdleTimeout<S> {
    /// Writes to the inner stream.
    ///
    /// Any accepted write re-arms the idle deadline. A write still pending after
    /// the deadline has passed fails with `ErrorKind::TimedOut`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        let polled = Pin::new(&mut this.inner).poll_write(cx, buf);
        track_write(this, cx, polled)
    }

    /// Vectored counterpart of `poll_write`, with the same idle policy.
    ///
    /// It is forwarded explicitly so the inner stream's scatter/gather support is
    /// not replaced by the trait's single-buffer fallback.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        let polled = Pin::new(&mut this.inner).poll_write_vectored(cx, bufs);
        track_write(this, cx, polled)
    }

    /// Reports the inner stream's vectored-write capability.
    ///
    /// Without this override the trait default (`false`) would be reported.
    /// Callers that pick a write strategy based on this flag would then see the
    /// wrong capability.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    /// Flushes the inner stream.
    ///
    /// A completed flush is deliberately not counted as activity, because a flush
    /// with nothing buffered can complete without moving any bytes. A flush that
    /// is still pending once the idle deadline has passed fails with
    /// `ErrorKind::TimedOut` instead of staying pending indefinitely.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let polled = Pin::new(&mut this.inner).poll_flush(cx);
        fail_if_idle(this, cx, polled)
    }

    /// Shuts down the inner stream, bounded by the idle deadline in the same way
    /// as `poll_flush`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let polled = Pin::new(&mut this.inner).poll_shutdown(cx);
        fail_if_idle(this, cx, polled)
    }
}

/// Re-arms the idle deadline when the inner stream accepted a write.
///
/// The outcome is then passed through `fail_if_idle`, so a pending write still
/// times out.
fn track_write<S>(
    this: &mut IdleTimeout<S>,
    cx: &mut Context<'_>,
    polled: Poll<std::io::Result<usize>>,
) -> Poll<std::io::Result<usize>> {
    if matches!(polled, Poll::Ready(Ok(_))) {
        this.reset_timeout(cx);
    }
    fail_if_idle(this, cx, polled)
}

/// Turns a `Pending` poll into `ErrorKind::TimedOut` once the idle deadline has
/// elapsed; every other outcome is returned unchanged.
///
/// Checking the deadline also registers `cx`'s waker with it, so a parked task is
/// woken when the deadline fires.
fn fail_if_idle<S, T>(
    this: &mut IdleTimeout<S>,
    cx: &mut Context<'_>,
    polled: Poll<std::io::Result<T>>,
) -> Poll<std::io::Result<T>> {
    if polled.is_pending() && this.idle_expired(cx) {
        return Poll::Ready(Err(idle_timeout_err()));
    }
    polled
}
#[cfg(test)]
// Tests unwrap freely: any unexpected error should fail the test loudly.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// End-to-end check with a real hyper server.
    ///
    /// After answering one request on a keep-alive connection, the server must
    /// close the connection once it has been idle for the configured interval.
    /// The client observes this as EOF.
    #[tokio::test]
    async fn hyper_keepalive_idle_connection_is_reaped() {
        use http_body_util::Full;
        use hyper::body::Bytes;
        use hyper::service::service_fn;
        use hyper_util::rt::{TokioExecutor, TokioIo};
        use hyper_util::server::conn::auto::Builder;
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let idle = Duration::from_millis(300);
        // Server: accept exactly one connection and serve it through IdleTimeout.
        tokio::spawn(async move {
            let (sock, _) = listener.accept().await.unwrap();
            let io = TokioIo::new(IdleTimeout::new(sock, Some(idle)));
            let svc = service_fn(|_req| async {
                Ok::<_, std::convert::Infallible>(hyper::Response::new(Full::new(Bytes::from(
                    "ok",
                ))))
            });
            // The connection is expected to end with the idle-timeout error, so the
            // result is ignored.
            let _ = Builder::new(TokioExecutor::new())
                .serve_connection(io, svc)
                .await;
        });
        let mut client = tokio::net::TcpStream::connect(addr).await.unwrap();
        // HTTP/1.1 is keep-alive by default, so the server keeps the socket open
        // after responding.
        client
            .write_all(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n")
            .await
            .unwrap();
        let mut buf = [0u8; 1024];
        // NOTE(review): assumes the (tiny) response arrives in a single read.
        let n = client.read(&mut buf).await.unwrap();
        assert!(n > 0, "expected a response");
        // The next read should only complete when the server drops the idle
        // connection. 3s leaves ample slack over the 300ms interval.
        let Ok(read) = tokio::time::timeout(Duration::from_secs(3), client.read(&mut buf)).await
        else {
            panic!("LEAK REPRODUCED: connection not reaped within 3s of a 300ms idle timeout");
        };
        match read {
            Ok(0) => {}
            Ok(n) => panic!("expected EOF, got {n} more bytes"),
            Err(e) => panic!("expected clean EOF, got error: {e}"),
        }
    }

    /// Reader that never produces data and never registers a waker; it simulates
    /// a peer that has gone silent.
    struct Quiet;
    impl AsyncRead for Quiet {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }
    }

    /// Reader that yields a single byte on its first poll and then stays silent.
    ///
    /// The flag records whether the byte has already been delivered.
    struct OnceThenQuiet(bool);
    impl AsyncRead for OnceThenQuiet {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let this = self.get_mut();
            if this.0 {
                Poll::Pending
            } else {
                this.0 = true;
                buf.put_slice(b"x");
                Poll::Ready(Ok(()))
            }
        }
    }

    /// The reaper exists only when an interval is supplied. This test runs inside a
    /// runtime because arming creates a timer.
    #[tokio::test]
    async fn reaper_armed_only_when_enabled() {
        assert!(IdleTimeout::new((), None).reaper.is_none());
        assert!(
            IdleTimeout::new((), Some(Duration::from_secs(60)))
                .reaper
                .is_some()
        );
    }

    /// A silent stream fails with `TimedOut`, and not before the interval elapses.
    #[tokio::test(start_paused = true)]
    async fn times_out_after_the_interval_when_idle() {
        let start = tokio::time::Instant::now();
        let mut s = IdleTimeout::new(Quiet, Some(Duration::from_secs(30)));
        let mut buf = [0u8; 8];
        let err = s.read(&mut buf).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::TimedOut);
        assert!(start.elapsed() >= Duration::from_secs(30));
    }

    /// Without an interval, a silent stream stays pending even across a full
    /// (virtual) day.
    #[tokio::test(start_paused = true)]
    async fn disabled_reaper_never_times_out() {
        let mut s = IdleTimeout::new(Quiet, None);
        let mut buf = [0u8; 8];
        tokio::select! {
            r = s.read(&mut buf) => panic!("disabled reaper should never resolve, got {r:?}"),
            () = tokio::time::sleep(Duration::from_secs(86_400)) => {}
        }
    }

    /// A successful read pushes the deadline a full interval past the data.
    #[tokio::test(start_paused = true)]
    async fn activity_resets_the_deadline() {
        let mut s = IdleTimeout::new(OnceThenQuiet(false), Some(Duration::from_secs(30)));
        let mut buf = [0u8; 8];
        // Let most of the original interval pass before the data arrives. If the
        // data arrived at t=0, the re-armed deadline would coincide with the
        // original one, and the test could not tell whether a reset happened.
        tokio::time::advance(Duration::from_secs(20)).await;
        assert_eq!(s.read(&mut buf).await.unwrap(), 1);
        let after_data = tokio::time::Instant::now();
        let err = s.read(&mut buf).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::TimedOut);
        // Without a reset, the original deadline would have fired at t=30, only
        // 10s after the data.
        assert!(after_data.elapsed() >= Duration::from_secs(30));
    }
}