use std::io;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::time::{timeout, timeout_at, Instant};
/// Default idle limit: five minutes (presumably meant as the `idle` argument of
/// `relay_with_timeouts` — confirm against callers).
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Default lifetime cap: one hour (presumably meant as the `total` argument of
/// `relay_with_timeouts` — confirm against callers).
pub const DEFAULT_TOTAL_TIMEOUT: Duration = Duration::from_secs(60 * 60);
/// Size in bytes (8 KiB) of the buffer each copy direction allocates for reads.
const BUF_SIZE: usize = 8 << 10;
pub async fn relay_with_timeouts<A, B>(
a: A,
b: B,
idle: Duration,
total: Duration,
) -> io::Result<(u64, u64)>
where
A: AsyncRead + AsyncWrite + Unpin,
B: AsyncRead + AsyncWrite + Unpin,
{
let (mut ar, mut aw) = tokio::io::split(a);
let (mut br, mut bw) = tokio::io::split(b);
let work = async {
tokio::try_join!(
copy_one_direction(&mut ar, &mut bw, idle),
copy_one_direction(&mut br, &mut aw, idle),
)
};
match timeout(total, work).await {
Ok(Ok((a_b, b_a))) => Ok((a_b, b_a)),
Ok(Err(e)) => Err(e),
Err(_) => Err(io::Error::new(
io::ErrorKind::TimedOut,
format!("tunnel exceeded total cap of {total:?}"),
)),
}
}
async fn copy_one_direction<R, W>(r: &mut R, w: &mut W, idle: Duration) -> io::Result<u64>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let mut buf = vec![0u8; BUF_SIZE];
let mut total = 0u64;
loop {
let n = match timeout(idle, r.read(&mut buf)).await {
Ok(Ok(0)) => break,
Ok(Ok(n)) => n,
Ok(Err(e)) => return Err(e),
Err(_) => {
let _ = w.shutdown().await;
return Err(io::Error::new(
io::ErrorKind::TimedOut,
format!("relay idle for {idle:?}"),
));
}
};
w.write_all(&buf[..n]).await?;
total = total.saturating_add(n as u64);
}
let _ = w.shutdown().await;
Ok(total)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, TcpStream};
    use tokio::task::JoinHandle;

    /// Binds an ephemeral loopback listener and reports the address it landed on.
    async fn bound_listener() -> (SocketAddr, TcpListener) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        (addr, listener)
    }

    /// Spawns the relay under test: accepts one client on `front`, dials `upstream`,
    /// and relays between the two with the given limits.
    fn spawn_relay(
        front: TcpListener,
        upstream: SocketAddr,
        idle: Duration,
        total: Duration,
    ) -> JoinHandle<io::Result<(u64, u64)>> {
        tokio::spawn(async move {
            let (client_side, _) = front.accept().await.unwrap();
            let upstream_side = TcpStream::connect(upstream).await.unwrap();
            relay_with_timeouts(client_side, upstream_side, idle, total).await
        })
    }

    #[tokio::test]
    async fn relays_bytes_in_both_directions_until_eof() {
        let (upstream_addr, upstream_listener) = bound_listener().await;
        let upstream = tokio::spawn(async move {
            let (mut sock, _) = upstream_listener.accept().await.unwrap();
            let mut request = [0u8; 4];
            sock.read_exact(&mut request).await.unwrap();
            assert_eq!(&request, b"ping");
            sock.write_all(b"pong-ping").await.unwrap();
            sock.shutdown().await.unwrap();
        });

        let (front_addr, front_listener) = bound_listener().await;
        let relay = spawn_relay(
            front_listener,
            upstream_addr,
            Duration::from_secs(2),
            Duration::from_secs(5),
        );

        let mut client = TcpStream::connect(front_addr).await.unwrap();
        client.write_all(b"ping").await.unwrap();
        client.shutdown().await.unwrap();
        let mut reply = Vec::new();
        client.read_to_end(&mut reply).await.unwrap();
        assert_eq!(reply, b"pong-ping");

        let (a_b, b_a) = relay.await.unwrap().unwrap();
        assert_eq!(a_b, 4, "client→server byte count");
        assert_eq!(b_a, 9, "server→client byte count");
        upstream.await.unwrap();
    }

    #[tokio::test]
    async fn idle_timeout_fires_when_both_sides_silent() {
        let (upstream_addr, upstream_listener) = bound_listener().await;
        let _upstream = tokio::spawn(async move {
            let (sock, _) = upstream_listener.accept().await.unwrap();
            // Hold the connection open without ever writing.
            tokio::time::sleep(Duration::from_secs(10)).await;
            drop(sock);
        });

        let (front_addr, front_listener) = bound_listener().await;
        let started = std::time::Instant::now();
        let relay = spawn_relay(
            front_listener,
            upstream_addr,
            Duration::from_millis(150),
            Duration::from_secs(5),
        );

        let _client = TcpStream::connect(front_addr).await.unwrap();
        let outcome = relay.await.unwrap();
        let elapsed = started.elapsed();

        let err = outcome.expect_err("expected idle timeout error");
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        assert!(
            elapsed < Duration::from_secs(2),
            "idle should fire well before total cap; took {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn total_timeout_fires_even_when_traffic_is_active() {
        let (upstream_addr, upstream_listener) = bound_listener().await;
        let _upstream = tokio::spawn(async move {
            let (mut sock, _) = upstream_listener.accept().await.unwrap();
            // Keep a byte trickling every 20ms, far past the 300ms total cap.
            for _ in 0..1000 {
                if sock.write_all(b".").await.is_err() {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(20)).await;
            }
        });

        let (front_addr, front_listener) = bound_listener().await;
        let started = std::time::Instant::now();
        let relay = spawn_relay(
            front_listener,
            upstream_addr,
            Duration::from_secs(1),
            Duration::from_millis(300),
        );

        let mut client = TcpStream::connect(front_addr).await.unwrap();
        let drain = tokio::spawn(async move {
            let mut scratch = [0u8; 64];
            while client.read(&mut scratch).await.unwrap_or(0) != 0 {}
        });

        let err = relay.await.unwrap().expect_err("expected total cap");
        let elapsed = started.elapsed();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        assert!(
            elapsed >= Duration::from_millis(250) && elapsed < Duration::from_secs(2),
            "total cap timing out of bounds: {elapsed:?}"
        );
        let _ = drain.await;
    }

    #[tokio::test]
    async fn clean_half_close_propagates_without_idle_fire() {
        let (upstream_addr, upstream_listener) = bound_listener().await;
        let _upstream = tokio::spawn(async move {
            let (mut sock, _) = upstream_listener.accept().await.unwrap();
            // Only start replying once the client's half-close has arrived.
            let mut request = Vec::new();
            let _ = sock.read_to_end(&mut request).await;
            for _ in 0..5 {
                if sock.write_all(b"chunk").await.is_err() {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(20)).await;
            }
            let _ = sock.shutdown().await;
        });

        let (front_addr, front_listener) = bound_listener().await;
        let relay = spawn_relay(
            front_listener,
            upstream_addr,
            Duration::from_secs(2),
            Duration::from_secs(5),
        );

        let mut client = TcpStream::connect(front_addr).await.unwrap();
        client.write_all(b"hello").await.unwrap();
        client.shutdown().await.unwrap();
        let mut reply = Vec::new();
        client.read_to_end(&mut reply).await.unwrap();
        assert_eq!(reply, b"chunkchunkchunkchunkchunk");

        let (a_b, b_a) = relay.await.unwrap().unwrap();
        assert_eq!(a_b, 5);
        assert_eq!(b_a, 25);
    }
}