use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::time::Instant as TokioInstant;
/// Sink for relay byte counters.
///
/// `relay_bidirectional` invokes these methods every time a batch of bytes has
/// been flushed in the corresponding direction.
pub trait RelayMetrics {
    /// Records `bytes` flushed from the inbound stream to the outbound stream.
    fn record_inbound(&self, bytes: u64);
    /// Records `bytes` flushed from the outbound stream back to the inbound stream.
    fn record_outbound(&self, bytes: u64);
}

/// A [`RelayMetrics`] implementation that discards every sample.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpMetrics;

impl RelayMetrics for NoOpMetrics {
    #[inline]
    fn record_outbound(&self, _: u64) {}

    #[inline]
    fn record_inbound(&self, _: u64) {}
}
/// Per-direction state machine driven by `poll_copy_direction`.
///
/// The running byte counters count bytes that have been written to the
/// destination but not flushed yet.
enum CopyState {
    /// Waiting on the source; carries the unflushed byte count.
    Reading(usize),
    /// Writing `buf[pos..len]`: `(pos, len, unflushed bytes before this chunk)`.
    Writing(usize, usize, usize),
    /// Flushing the destination: `(bytes covered by this flush, source hit EOF)`.
    Flushing(usize, bool),
    /// Source hit EOF and nothing is left to flush; shutting the destination down.
    ShuttingDown,
    /// Terminal state; destination shutdown has been attempted.
    Done,
}
/// Progress reported by `poll_copy_direction` when it returns `Ready(Ok(_))`.
enum CopyPoll {
    /// A flush completed; carries the number of bytes that flush covered.
    Flushed(usize),
    /// The source reached EOF and the destination has been shut down.
    Finished,
}
/// Drives one relay direction (`reader` -> `writer`) as far as it can without
/// blocking.
///
/// Writes are batched: each chunk read is written to `writer` immediately, but
/// `poll_flush` is only issued when the reader has nothing more right now
/// (`Pending`) or has reached EOF. The counter carried in
/// `CopyState::Reading` is the number of bytes written since the last flush.
///
/// Returns:
/// - `Ready(Ok(CopyPoll::Flushed(n)))` after a successful flush covering `n`
///   bytes; poll again to keep copying.
/// - `Ready(Ok(CopyPoll::Finished))` once EOF was seen, pending bytes were
///   flushed and `writer` was shut down (and on every later call).
/// - `Ready(Err(_))` on a read, write or flush error. A writer that accepts
///   zero bytes of a non-empty chunk is reported as `ErrorKind::WriteZero`.
/// - `Pending` when no progress is possible; the reader/writer registered
///   the waker.
///
/// Errors from `poll_shutdown` are deliberately ignored: the half-close is
/// best-effort and the direction counts as finished either way.
///
/// `buf` must be non-empty: a zero-length read buffer is indistinguishable
/// from EOF and would end the direction immediately.
///
/// NOTE(review): if the reader never returns `Pending` or EOF, the
/// Reading/Writing loop below never flushes or returns; this relies on the
/// reader eventually yielding (e.g. tokio's cooperative budget).
fn poll_copy_direction<R, W>(
    cx: &mut Context<'_>,
    reader: &mut R,
    writer: &mut W,
    buf: &mut [u8],
    state: &mut CopyState,
) -> Poll<io::Result<CopyPoll>>
where
    R: AsyncRead + Unpin + ?Sized,
    W: AsyncWrite + Unpin + ?Sized,
{
    loop {
        match state {
            CopyState::Reading(unflushed) => {
                let unflushed = *unflushed;
                let mut read_buf = ReadBuf::new(buf);
                match Pin::new(&mut *reader).poll_read(cx, &mut read_buf) {
                    Poll::Ready(Ok(())) => {
                        let n = read_buf.filled().len();
                        *state = if n > 0 {
                            CopyState::Writing(0, n, unflushed)
                        } else if unflushed > 0 {
                            // EOF with data still buffered in the writer:
                            // flush it before shutting down.
                            CopyState::Flushing(unflushed, true)
                        } else {
                            CopyState::ShuttingDown
                        };
                    }
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    Poll::Pending => {
                        if unflushed == 0 {
                            return Poll::Pending;
                        }
                        // The source went idle: push out what has been written
                        // so far instead of letting it sit in the writer.
                        *state = CopyState::Flushing(unflushed, false);
                    }
                }
            }
            CopyState::Writing(pos, len, unflushed) => {
                match Pin::new(&mut *writer).poll_write(cx, &buf[*pos..*len]) {
                    Poll::Ready(Ok(0)) => {
                        // The slice is non-empty here, so 0 means the writer
                        // can no longer accept data; retrying would spin forever.
                        return Poll::Ready(Err(io::Error::new(
                            io::ErrorKind::WriteZero,
                            "write zero byte into writer",
                        )));
                    }
                    Poll::Ready(Ok(n)) => {
                        *pos += n;
                        if *pos >= *len {
                            let total = *unflushed + *len;
                            *state = CopyState::Reading(total);
                        }
                    }
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    Poll::Pending => return Poll::Pending,
                }
            }
            CopyState::Flushing(bytes, is_eof) => {
                let bytes = *bytes;
                let eof = *is_eof;
                match Pin::new(&mut *writer).poll_flush(cx) {
                    Poll::Ready(Ok(())) => {
                        *state = if eof {
                            CopyState::ShuttingDown
                        } else {
                            CopyState::Reading(0)
                        };
                        return Poll::Ready(Ok(CopyPoll::Flushed(bytes)));
                    }
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    Poll::Pending => return Poll::Pending,
                }
            }
            CopyState::ShuttingDown => match Pin::new(&mut *writer).poll_shutdown(cx) {
                // Best-effort half-close: the result is intentionally ignored.
                Poll::Ready(_) => {
                    *state = CopyState::Done;
                    return Poll::Ready(Ok(CopyPoll::Finished));
                }
                Poll::Pending => return Poll::Pending,
            },
            CopyState::Done => return Poll::Ready(Ok(CopyPoll::Finished)),
        }
    }
}
/// Byte totals produced by `relay_bidirectional`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RelayStats {
    /// Bytes copied from the `inbound` stream to the `outbound` stream.
    pub inbound: u64,
    /// Bytes copied from the `outbound` stream back to the `inbound` stream.
    pub outbound: u64,
}

impl RelayStats {
    /// Total bytes relayed in both directions.
    ///
    /// Saturates at `u64::MAX` instead of overflowing, which would panic in
    /// debug builds and silently wrap in release builds.
    #[inline]
    #[must_use]
    pub fn total(self) -> u64 {
        self.inbound.saturating_add(self.outbound)
    }
}
/// Relays bytes in both directions between `inbound` and `outbound`.
///
/// Bytes read from `inbound` are written to `outbound` (counted as *inbound*
/// traffic); bytes read from `outbound` are written to `inbound` (counted as
/// *outbound* traffic). Each direction uses its own `buffer_size`-byte buffer,
/// and `metrics` is updated every time a batch of bytes is flushed.
///
/// When one direction reaches EOF its destination is shut down (half-close)
/// while the other direction keeps running. The relay ends when:
/// - both directions have finished, or
/// - nothing has been flushed in either direction for `idle_timeout`.
///
/// Both cases return `Ok` with the totals so far, so an idle timeout is not
/// distinguishable from a clean close through the return value.
///
/// # Errors
///
/// - `io::ErrorKind::InvalidInput` if `buffer_size` is zero.
/// - Otherwise, the first I/O error observed in either direction. Bytes
///   flushed before the error were reported to `metrics` but are not returned.
pub async fn relay_bidirectional<A, B, M>(
    mut inbound: A,
    mut outbound: B,
    idle_timeout: Duration,
    buffer_size: usize,
    metrics: &M,
) -> io::Result<RelayStats>
where
    A: AsyncRead + AsyncWrite + Unpin,
    B: AsyncRead + AsyncWrite + Unpin,
    M: RelayMetrics,
{
    if buffer_size == 0 {
        // A zero-length read buffer makes every read look like EOF, which would
        // silently shut down both peers.
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "relay buffer_size must be greater than zero",
        ));
    }

    // No `tokio::io::split`: both directions are polled one after the other
    // from this task, so sequential `&mut` borrows of each stream suffice and
    // the lock that `split` wraps around the stream is avoided.
    let mut buf_a = vec![0u8; buffer_size];
    let mut buf_b = vec![0u8; buffer_size];
    let mut state_a = CopyState::Reading(0);
    let mut state_b = CopyState::Reading(0);
    let mut a_done = false;
    let mut b_done = false;
    let mut total_inbound: u64 = 0;
    let mut total_outbound: u64 = 0;

    let idle_sleep = tokio::time::sleep(idle_timeout);
    tokio::pin!(idle_sleep);

    loop {
        if a_done && b_done {
            return Ok(RelayStats {
                inbound: total_inbound,
                outbound: total_outbound,
            });
        }

        // Polls both live directions once; resolves with `true` if any bytes
        // were flushed (which counts as activity for the idle timer).
        let step = std::future::poll_fn(|cx| {
            let mut any_ready = false;
            let mut activity = false;
            let mut error: Option<io::Error> = None;

            if !a_done {
                match poll_copy_direction(cx, &mut inbound, &mut outbound, &mut buf_a, &mut state_a)
                {
                    Poll::Ready(Ok(CopyPoll::Flushed(n))) => {
                        let bytes = n as u64;
                        metrics.record_inbound(bytes);
                        total_inbound += bytes;
                        activity = true;
                        any_ready = true;
                    }
                    Poll::Ready(Ok(CopyPoll::Finished)) => {
                        a_done = true;
                        any_ready = true;
                    }
                    Poll::Ready(Err(e)) => {
                        error.get_or_insert(e);
                        any_ready = true;
                    }
                    Poll::Pending => {}
                }
            }

            if !b_done {
                match poll_copy_direction(cx, &mut outbound, &mut inbound, &mut buf_b, &mut state_b)
                {
                    Poll::Ready(Ok(CopyPoll::Flushed(n))) => {
                        let bytes = n as u64;
                        metrics.record_outbound(bytes);
                        total_outbound += bytes;
                        activity = true;
                        any_ready = true;
                    }
                    Poll::Ready(Ok(CopyPoll::Finished)) => {
                        b_done = true;
                        any_ready = true;
                    }
                    Poll::Ready(Err(e)) => {
                        // Keep the first error if direction A also failed.
                        error.get_or_insert(e);
                        any_ready = true;
                    }
                    Poll::Pending => {}
                }
            }

            if let Some(e) = error {
                return Poll::Ready(Err(e));
            }
            if any_ready {
                Poll::Ready(Ok(activity))
            } else {
                Poll::Pending
            }
        });

        tokio::select! {
            result = step => {
                if result? {
                    idle_sleep.as_mut().reset(TokioInstant::now() + idle_timeout);
                }
            }
            _ = &mut idle_sleep => {
                // Idle timeout: report what was relayed so far rather than an error.
                return Ok(RelayStats {
                    inbound: total_inbound,
                    outbound: total_outbound,
                });
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicU64, Ordering};
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    /// Metrics sink that accumulates every reported byte count.
    struct TestMetrics {
        inbound: AtomicU64,
        outbound: AtomicU64,
    }

    impl TestMetrics {
        fn new() -> Self {
            Self {
                inbound: AtomicU64::new(0),
                outbound: AtomicU64::new(0),
            }
        }
    }

    impl RelayMetrics for TestMetrics {
        fn record_inbound(&self, bytes: u64) {
            self.inbound.fetch_add(bytes, Ordering::Relaxed);
        }
        fn record_outbound(&self, bytes: u64) {
            self.outbound.fetch_add(bytes, Ordering::Relaxed);
        }
    }

    #[tokio::test]
    async fn test_relay_basic() {
        let (client, server_side) = duplex(1024);
        let (target_side, target) = duplex(1024);
        let relay_handle = tokio::spawn(async move {
            let metrics = TestMetrics::new();
            let result = relay_bidirectional(
                server_side,
                target_side,
                Duration::from_secs(5),
                1024,
                &metrics,
            )
            .await;
            (result, metrics)
        });
        let (mut client_r, mut client_w) = tokio::io::split(client);
        let (mut target_r, mut target_w) = tokio::io::split(target);

        // Dropping a split write half does not close the underlying duplex, so
        // shut each side down explicitly; otherwise the relay would only end via
        // the idle timeout.
        client_w.write_all(b"hello").await.unwrap();
        client_w.shutdown().await.unwrap();
        let mut received = Vec::new();
        target_r.read_to_end(&mut received).await.unwrap();
        assert_eq!(received, b"hello");

        target_w.write_all(b"world").await.unwrap();
        target_w.shutdown().await.unwrap();
        received.clear();
        client_r.read_to_end(&mut received).await.unwrap();
        assert_eq!(received, b"world");

        let (result, metrics) = relay_handle.await.unwrap();
        let stats = result.unwrap();
        assert_eq!(stats.inbound, 5);
        assert_eq!(stats.outbound, 5);
        assert_eq!(metrics.inbound.load(Ordering::Relaxed), 5);
        assert_eq!(metrics.outbound.load(Ordering::Relaxed), 5);
    }

    #[tokio::test]
    async fn test_relay_idle_timeout() {
        let (client, server_side) = duplex(1024);
        let (target_side, _target) = duplex(1024);
        let start = TokioInstant::now();
        let result = relay_bidirectional(
            server_side,
            target_side,
            Duration::from_millis(50),
            1024,
            &NoOpMetrics,
        )
        .await;
        result.unwrap();
        assert!(start.elapsed() >= Duration::from_millis(50));
        // Keep the client end alive for the whole relay so it never sees EOF.
        drop(client);
    }

    /// Reader that replays a script: `Some(bytes)` yields those bytes, `None`
    /// returns `Pending` once (after waking itself so the task is re-polled),
    /// and an exhausted script reports EOF.
    struct MockReader {
        chunks: VecDeque<Option<Vec<u8>>>,
    }

    impl MockReader {
        fn new(chunks: Vec<Option<Vec<u8>>>) -> Self {
            Self {
                chunks: chunks.into(),
            }
        }
    }

    impl AsyncRead for MockReader {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            match self.chunks.pop_front() {
                Some(Some(data)) => {
                    buf.put_slice(&data);
                    Poll::Ready(Ok(()))
                }
                Some(None) => {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                None => Poll::Ready(Ok(())),
            }
        }
    }

    /// Writer that records everything written and counts `poll_flush` calls.
    struct FlushCountingWriter {
        written: Vec<u8>,
        flush_count: usize,
    }

    impl FlushCountingWriter {
        fn new() -> Self {
            Self {
                written: Vec::new(),
                flush_count: 0,
            }
        }
    }

    impl AsyncWrite for FlushCountingWriter {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.written.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }
        fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.flush_count += 1;
            Poll::Ready(Ok(()))
        }
        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Drives one copy direction from `reader` into `writer` until it reports
    /// `Finished`, returning the sum of all `Flushed` byte counts.
    async fn copy_to_completion(
        reader: &mut MockReader,
        writer: &mut FlushCountingWriter,
    ) -> usize {
        let mut buf = vec![0u8; 64];
        let mut state = CopyState::Reading(0);
        let mut total_bytes = 0;
        loop {
            let step = std::future::poll_fn(|cx| {
                poll_copy_direction(cx, &mut *reader, &mut *writer, &mut buf, &mut state)
            })
            .await
            .unwrap();
            match step {
                CopyPoll::Flushed(n) => total_bytes += n,
                CopyPoll::Finished => return total_bytes,
            }
        }
    }

    #[tokio::test]
    async fn test_flush_batching_consecutive_reads() {
        let mut reader = MockReader::new(vec![
            Some(b"aaa".to_vec()),
            Some(b"bbb".to_vec()),
            Some(b"ccc".to_vec()),
        ]);
        let mut writer = FlushCountingWriter::new();
        let total_bytes = copy_to_completion(&mut reader, &mut writer).await;
        assert_eq!(writer.written, b"aaabbbccc");
        assert_eq!(total_bytes, 9);
        assert_eq!(
            writer.flush_count, 1,
            "consecutive reads should batch flushes"
        );
    }

    #[tokio::test]
    async fn test_flush_on_pending() {
        let mut reader = MockReader::new(vec![
            Some(b"aaa".to_vec()),
            None,
            Some(b"bbb".to_vec()),
            None,
        ]);
        let mut writer = FlushCountingWriter::new();
        let total_bytes = copy_to_completion(&mut reader, &mut writer).await;
        assert_eq!(writer.written, b"aaabbb");
        assert_eq!(total_bytes, 6);
        assert_eq!(writer.flush_count, 2, "should flush once per Pending gap");
    }

    #[tokio::test]
    async fn test_flush_batching_burst_then_pending() {
        let mut reader = MockReader::new(vec![
            Some(b"a".to_vec()),
            Some(b"b".to_vec()),
            Some(b"c".to_vec()),
            None,
            Some(b"d".to_vec()),
        ]);
        let mut writer = FlushCountingWriter::new();
        let total_bytes = copy_to_completion(&mut reader, &mut writer).await;
        assert_eq!(writer.written, b"abcd");
        assert_eq!(total_bytes, 4);
        assert_eq!(
            writer.flush_count, 2,
            "burst then pending then EOF = 2 flushes"
        );
    }
}