use async_trait::async_trait;
use log::{debug, error, trace};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::sync::Mutex;
use crate::error::{Error, Result};
use crate::transport::{TransportReceive, TransportSend};
/// Largest payload, in bytes, that a single frame may carry.
///
/// The frame header stores the payload length in 24 bits, so this must not
/// exceed `0x00FF_FFFF` (16 MiB − 1). A larger length would not fit in the
/// three length bytes and would be silently truncated when encoded.
const MAX_FRAME_SIZE: usize = 0x00FF_FFFF;
/// Transport that exchanges length-prefixed frames over a single TCP connection.
///
/// The socket is split into a read half and a write half, each behind its own
/// async mutex. One task can therefore wait in `receive` while another task
/// calls `send` on the same transport.
#[derive(Debug)]
pub struct TcpTransport {
    /// Read half of the socket; locked for the duration of each `receive` call.
    reader: Mutex<OwnedReadHalf>,
    /// Write half of the socket; locked for the duration of each `send` call.
    writer: Mutex<OwnedWriteHalf>,
}
impl TcpTransport {
pub async fn connect(addr: impl ToSocketAddrs, timeout: Duration) -> Result<Self> {
let stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
.await
.map_err(|_| Error::Timeout)?
.map_err(Error::Io)?;
stream.set_nodelay(true).map_err(Error::Io)?;
debug!("tcp: connected, nodelay=true");
let (reader, writer) = stream.into_split();
Ok(Self {
reader: Mutex::new(reader),
writer: Mutex::new(writer),
})
}
}
#[async_trait]
impl TransportSend for TcpTransport {
    /// Sends `data` as one frame.
    ///
    /// The frame is a 4-byte header — a `0x00` byte followed by the payload
    /// length as a 24-bit big-endian integer — followed by the payload. Both
    /// parts are written and flushed while holding the writer lock, so frames
    /// from concurrent callers are not interleaved.
    ///
    /// # Errors
    ///
    /// - An invalid-data error if `data` is longer than `MAX_FRAME_SIZE` or
    ///   than the 24-bit length field can represent.
    /// - `Error::Io` if writing or flushing the socket fails.
    async fn send(&self, data: &[u8]) -> Result<()> {
        /// Largest length the three header length bytes can encode.
        const LENGTH_FIELD_MAX: usize = 0x00FF_FFFF;

        let len = data.len();
        // Never accept more than the length field can hold, whatever
        // MAX_FRAME_SIZE is set to: a longer payload would have its length
        // silently truncated in the header (exactly 16 MiB would encode as 0),
        // desynchronising the peer's framing.
        let limit = MAX_FRAME_SIZE.min(LENGTH_FIELD_MAX);
        if len > limit {
            return Err(Error::invalid_data(format!(
                "message size {} exceeds maximum frame size {}",
                len, limit
            )));
        }

        // `len <= 0x00FF_FFFF`, so the cast is lossless. The most significant
        // byte of the big-endian u32 is 0x00, which is exactly the required
        // leading header byte.
        let frame_header = (len as u32).to_be_bytes();

        let mut writer = self.writer.lock().await;
        writer.write_all(&frame_header).await.map_err(Error::Io)?;
        writer.write_all(data).await.map_err(Error::Io)?;
        writer.flush().await.map_err(Error::Io)?;
        trace!("tcp: sent frame, len={}", len);
        Ok(())
    }
}
#[async_trait]
impl TransportReceive for TcpTransport {
async fn receive(&self) -> Result<Vec<u8>> {
let mut reader = self.reader.lock().await;
let mut frame_header = [0u8; 4];
reader.read_exact(&mut frame_header).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
Error::Disconnected
} else {
Error::Io(e)
}
})?;
if frame_header[0] != 0x00 {
error!("tcp: invalid frame, first byte=0x{:02X}", frame_header[0]);
return Err(Error::invalid_data(format!(
"invalid transport frame: first byte must be 0x00, got 0x{:02X}",
frame_header[0]
)));
}
let msg_len = ((frame_header[1] as usize) << 16)
| ((frame_header[2] as usize) << 8)
| (frame_header[3] as usize);
if msg_len > MAX_FRAME_SIZE {
return Err(Error::invalid_data(format!(
"frame length {} exceeds maximum {}",
msg_len, MAX_FRAME_SIZE
)));
}
trace!("tcp: receiving frame, len={}", msg_len);
let mut buf = vec![0u8; msg_len];
reader.read_exact(&mut buf).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
Error::Disconnected
} else {
Error::Io(e)
}
})?;
trace!("tcp: received frame, len={}", msg_len);
Ok(buf)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the bytes a peer would put on the wire for `payload`.
    ///
    /// The layout is a `0x00` byte, the payload length as a 24-bit big-endian
    /// integer, then the payload. Used to feed hand-built frames to `receive`;
    /// `frame_message_matches_send_encoding` checks it against the real encoder.
    fn frame_message(payload: &[u8]) -> Vec<u8> {
        let len = payload.len();
        let mut frame = Vec::with_capacity(4 + len);
        frame.push(0x00);
        frame.push((len >> 16) as u8);
        frame.push((len >> 8) as u8);
        frame.push(len as u8);
        frame.extend_from_slice(payload);
        frame
    }

    /// Returns the 4-byte frame header at the start of `wire`.
    fn header_of(wire: &[u8]) -> [u8; 4] {
        [wire[0], wire[1], wire[2], wire[3]]
    }

    /// Wraps an already-connected socket in a `TcpTransport`.
    fn transport_from(stream: TcpStream) -> TcpTransport {
        let (reader, writer) = stream.into_split();
        TcpTransport {
            reader: Mutex::new(reader),
            writer: Mutex::new(writer),
        }
    }

    /// Sends `payload` through `TcpTransport::send` and returns every byte
    /// the peer received.
    ///
    /// This lets the header tests check the production encoder rather than a
    /// test-only reimplementation of it.
    async fn send_and_capture(payload: &[u8]) -> Vec<u8> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let payload = payload.to_vec();
        let send_task = tokio::spawn(async move {
            let transport = transport_from(TcpStream::connect(addr).await.unwrap());
            transport.send(&payload).await.unwrap();
            // `transport` is dropped here, closing the socket, so the
            // `read_to_end` below observes EOF.
        });
        let (mut stream, _) = listener.accept().await.unwrap();
        let mut wire = Vec::new();
        stream.read_to_end(&mut wire).await.unwrap();
        send_task.await.unwrap();
        wire
    }

    #[tokio::test]
    async fn send_header_small_message() {
        let payload = vec![0xFE, 0x53, 0x4D, 0x42];
        let wire = send_and_capture(&payload).await;
        assert_eq!(
            header_of(&wire),
            [0x00, 0x00, 0x00, 0x04],
            "0x00 marker followed by 24-bit length 4"
        );
        assert_eq!(&wire[4..], &payload[..]);
    }

    #[tokio::test]
    async fn send_header_medium_message() {
        let payload = vec![0xAA; 300];
        let wire = send_and_capture(&payload).await;
        // 300 == 0x00_01_2C
        assert_eq!(header_of(&wire), [0x00, 0x00, 0x01, 0x2C]);
        assert_eq!(wire.len(), 304);
    }

    #[tokio::test]
    async fn send_header_large_message() {
        let payload = vec![0xBB; 66051];
        let wire = send_and_capture(&payload).await;
        // 66051 == 0x01_02_03
        assert_eq!(header_of(&wire), [0x00, 0x01, 0x02, 0x03]);
        assert_eq!(wire.len(), 66055);
    }

    #[tokio::test]
    async fn send_header_empty_payload() {
        let wire = send_and_capture(&[]).await;
        assert_eq!(wire, vec![0x00, 0x00, 0x00, 0x00]);
    }

    #[tokio::test]
    async fn frame_message_matches_send_encoding() {
        let payload: Vec<u8> = (0..=255).cycle().take(1000).collect();
        assert_eq!(send_and_capture(&payload).await, frame_message(&payload));
    }

    /// Writes `data` raw from a peer socket, closes it, and returns the
    /// result of a single `receive` on the other end.
    async fn receive_from_bytes(data: &[u8]) -> Result<Vec<u8>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let data = data.to_vec();
        let writer_task = tokio::spawn(async move {
            let mut stream = TcpStream::connect(addr).await.unwrap();
            stream.write_all(&data).await.unwrap();
            stream.shutdown().await.unwrap();
        });
        let (stream, _) = listener.accept().await.unwrap();
        let transport = transport_from(stream);
        let result = transport.receive().await;
        writer_task.await.unwrap();
        result
    }

    #[tokio::test]
    async fn receive_valid_frame() {
        let payload = vec![0xFE, 0x53, 0x4D, 0x42, 0x01, 0x02];
        let framed = frame_message(&payload);
        let received = receive_from_bytes(&framed).await.unwrap();
        assert_eq!(received, payload);
    }

    #[tokio::test]
    async fn receive_empty_payload() {
        let framed = frame_message(&[]);
        let received = receive_from_bytes(&framed).await.unwrap();
        assert!(received.is_empty());
    }

    #[tokio::test]
    async fn receive_first_byte_not_zero_returns_error() {
        let data = vec![0x01, 0x00, 0x00, 0x04, 0xAA, 0xBB, 0xCC, 0xDD];
        let result = receive_from_bytes(&data).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("first byte must be 0x00"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn receive_largest_declared_length_then_eof_returns_disconnected() {
        // 0xFFFFFF is the largest length a 24-bit header can declare. It does
        // not exceed MAX_FRAME_SIZE, so the header is accepted and the receiver
        // then hits EOF while reading the missing payload.
        assert!(
            0x00FF_FFFF <= MAX_FRAME_SIZE,
            "test assumes the largest encodable length is an accepted length"
        );
        let data = vec![0x00, 0xFF, 0xFF, 0xFF];
        let result = receive_from_bytes(&data).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, Error::Disconnected),
            "expected Disconnected for truncated large frame, got: {err}"
        );
    }

    #[tokio::test]
    async fn receive_disconnected_on_eof() {
        let result = receive_from_bytes(&[]).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, Error::Disconnected),
            "expected Disconnected, got: {err}"
        );
    }

    #[tokio::test]
    async fn receive_partial_header_returns_disconnected() {
        let result = receive_from_bytes(&[0x00, 0x00]).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, Error::Disconnected),
            "expected Disconnected for partial header, got: {err}"
        );
    }

    #[tokio::test]
    async fn receive_partial_payload_returns_disconnected() {
        let data = vec![0x00, 0x00, 0x00, 0x0A, 0x01, 0x02, 0x03];
        let result = receive_from_bytes(&data).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, Error::Disconnected),
            "expected Disconnected for truncated payload, got: {err}"
        );
    }

    #[tokio::test]
    async fn send_and_receive_roundtrip() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let send_task = tokio::spawn(async move {
            let transport = transport_from(TcpStream::connect(addr).await.unwrap());
            let payload = vec![0xFE, 0x53, 0x4D, 0x42, 0xDE, 0xAD];
            transport.send(&payload).await.unwrap();
        });
        let (stream, _) = listener.accept().await.unwrap();
        let recv_transport = transport_from(stream);
        let received = recv_transport.receive().await.unwrap();
        assert_eq!(received, vec![0xFE, 0x53, 0x4D, 0x42, 0xDE, 0xAD]);
        send_task.await.unwrap();
    }

    #[tokio::test]
    async fn send_and_receive_multiple_messages() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let send_task = tokio::spawn(async move {
            let transport = transport_from(TcpStream::connect(addr).await.unwrap());
            transport.send(&[0x01, 0x02]).await.unwrap();
            transport.send(&[0x03, 0x04, 0x05]).await.unwrap();
            transport.send(&[0x06]).await.unwrap();
        });
        let (stream, _) = listener.accept().await.unwrap();
        let recv_transport = transport_from(stream);
        assert_eq!(recv_transport.receive().await.unwrap(), vec![0x01, 0x02]);
        assert_eq!(
            recv_transport.receive().await.unwrap(),
            vec![0x03, 0x04, 0x05]
        );
        assert_eq!(recv_transport.receive().await.unwrap(), vec![0x06]);
        send_task.await.unwrap();
    }

    #[tokio::test]
    async fn partial_reads_are_handled_by_read_exact() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let payload: Vec<u8> = (0..=255).cycle().take(8192).collect();
        let payload_clone = payload.clone();
        let send_task = tokio::spawn(async move {
            let transport = transport_from(TcpStream::connect(addr).await.unwrap());
            transport.send(&payload_clone).await.unwrap();
        });
        let (stream, _) = listener.accept().await.unwrap();
        let recv_transport = transport_from(stream);
        let received = recv_transport.receive().await.unwrap();
        assert_eq!(received.len(), payload.len());
        assert_eq!(received, payload);
        send_task.await.unwrap();
    }

    #[tokio::test]
    async fn connect_with_timeout() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let transport = TcpTransport::connect(addr, Duration::from_secs(5))
            .await
            .unwrap();
        let (server_stream, _) = listener.accept().await.unwrap();
        let (server_reader, mut server_writer) = server_stream.into_split();
        drop(server_reader);
        let payload = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let mut frame = vec![0x00, 0x00, 0x00, 0x04];
        frame.extend_from_slice(&payload);
        server_writer.write_all(&frame).await.unwrap();
        server_writer.flush().await.unwrap();
        let received = transport.receive().await.unwrap();
        assert_eq!(received, payload);
    }

    #[tokio::test]
    async fn connect_timeout_fires() {
        // 192.0.2.0/24 is TEST-NET-1 (non-routable). Depending on the host's
        // network, the attempt either times out or fails immediately with an
        // I/O error, so both are accepted.
        let result = TcpTransport::connect("192.0.2.1:445", Duration::from_millis(100)).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, Error::Timeout | Error::Io(_)),
            "expected Timeout or Io error, got: {err}"
        );
    }
}