use std::sync::{Mutex, PoisonError};
use std::time::Duration;
use futures::{future::Either, pin_mut};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio::time::Instant;
use crate::helpers::{IntoError, Res};
pub struct CopyPump {
client_socket: TcpStream,
endpoint_socket: TcpStream,
read_timeout: u64,
}
impl CopyPump {
pub fn from(client_socket: TcpStream, endpoint_socket: TcpStream, read_timeout: u64) -> Self {
CopyPump {
client_socket,
endpoint_socket,
read_timeout,
}
}
pub async fn start(self) -> Res<()> {
self.run_pumps_as_copy().await
}
async fn run_pumps_as_copy(self) -> Res<()> {
let (mut client_socket_read, mut client_socket_write) = self.client_socket.into_split();
let (mut endpoint_socket_read, mut endpoint_socket_write) = self.endpoint_socket.into_split();
let idle = match self.read_timeout {
0 => None,
ms => Some(Duration::from_millis(ms)),
};
let pump_up = Self::pump(&mut client_socket_read, &mut endpoint_socket_write, idle);
let pump_down = Self::pump(&mut endpoint_socket_read, &mut client_socket_write, idle);
pin_mut!(pump_up);
pin_mut!(pump_down);
match futures::future::select(pump_up, pump_down).await {
Either::Left((result, _)) | Either::Right((result, _)) => result,
}
}
async fn pump<R, W>(from: &mut R, to: &mut W, idle: Option<Duration>) -> Res<()>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let mut buffer = [0u8; 16 * 1024];
loop {
let read = match idle {
Some(duration) => match timeout(duration, from.read(&mut buffer)).await {
Ok(result) => result?,
Err(_) => return "Idle timeout.".into_error(),
},
None => from.read(&mut buffer).await?,
};
if read == 0 {
return Ok(());
}
to.write_all(&buffer[..read]).await?;
to.flush().await?;
}
}
}
#[cfg(test)]
mod tests {
    use super::CopyPump;
    use std::time::Duration;
    use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
    use tokio::time::{sleep, timeout};

    /// A trickle of one byte every 50 ms must keep resetting a 250 ms idle window, so the pump
    /// ends cleanly at EOF with every byte delivered.
    #[tokio::test]
    async fn idle_timeout_resets_on_activity() {
        const BYTES: usize = 10;
        let window = Duration::from_millis(250);
        let (mut producer, mut pump_in) = duplex(256);
        let (mut pump_out, mut consumer) = duplex(256);

        let feed = async move {
            for _ in 0..BYTES {
                producer.write_all(b"x").await.unwrap();
                producer.flush().await.unwrap();
                sleep(Duration::from_millis(50)).await;
            }
            // Closing the producer is what lets the pump observe EOF.
            drop(producer);
        };
        let collect = async move {
            let mut scratch = [0u8; 16];
            let mut seen = 0;
            while seen < BYTES {
                match consumer.read(&mut scratch).await {
                    Ok(n) if n > 0 => seen += n,
                    _ => break,
                }
            }
            seen
        };

        let driven = async {
            tokio::join!(
                CopyPump::pump(&mut pump_in, &mut pump_out, Some(window)),
                feed,
                collect
            )
        };
        let (outcome, (), delivered) = timeout(Duration::from_secs(5), driven)
            .await
            .expect("pump + driver should finish well within 5s");
        assert!(outcome.is_ok(), "active connection was killed: {:?}", outcome.err());
        assert_eq!(delivered, BYTES, "all bytes should have been pumped through");
    }

    /// With no data and no EOF, the pump must fail shortly after the 100 ms idle window.
    #[tokio::test]
    async fn idle_timeout_fires_when_silent() {
        // Hold both far ends so the pump sees neither bytes nor EOF.
        let (_writer_end, mut pump_in) = duplex(64);
        let (mut pump_out, _reader_end) = duplex(64);

        let silent = CopyPump::pump(&mut pump_in, &mut pump_out, Some(Duration::from_millis(100)));
        let outcome = timeout(Duration::from_secs(2), silent)
            .await
            .expect("pump should give up around the idle window, well before 2s");
        assert!(outcome.is_err(), "silent connection should have hit the idle timeout");
    }

    /// With the idle window disabled, a silent pump is still waiting when the outer 300 ms
    /// guard expires.
    #[tokio::test]
    async fn disabled_idle_timeout_never_fires() {
        let (_writer_end, mut pump_in) = duplex(64);
        let (mut pump_out, _reader_end) = duplex(64);

        let still_waiting = timeout(
            Duration::from_millis(300),
            CopyPump::pump(&mut pump_in, &mut pump_out, None),
        )
        .await
        .is_err();
        assert!(still_waiting, "with idle disabled the pump must keep waiting, not return");
    }
}