use async_trait::async_trait;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use parking_lot::Mutex;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tracing::{debug, error, info};
use crate::error::{Result, TransportError};
use crate::traits::{TransportEvent, TransportReceiver, TransportSender, TransportServer};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
/// Default upper bound, in bytes, on the payload of a single incoming frame (64 KiB).
const MAX_MESSAGE_SIZE: usize = 64 * 1024;
/// Capacity of each bounded mpsc channel created per connection.
const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 1000;

/// Tunables for TCP connections created by this module.
#[derive(Debug, Clone)]
pub struct TcpConfig {
    /// Largest incoming frame payload, in bytes, that a connection accepts
    /// before it is torn down.
    pub max_message_size: usize,
    /// Intended size of the per-connection read buffer.
    /// NOTE(review): the I/O loop in this file allocates a fixed 8192-byte
    /// buffer and does not consult this field — confirm intended use.
    pub read_buffer_size: usize,
    /// TCP keepalive idle time in seconds; `0` means keepalive is not configured.
    pub keepalive_secs: u64,
}

impl Default for TcpConfig {
    /// 64 KiB max message, 8 KiB read buffer, 30 s keepalive.
    fn default() -> Self {
        let max_message_size = MAX_MESSAGE_SIZE;
        let read_buffer_size = 8 * 1024;
        let keepalive_secs = 30;
        TcpConfig {
            max_message_size,
            read_buffer_size,
            keepalive_secs,
        }
    }
}
pub struct TcpTransport {
config: TcpConfig,
}
impl TcpTransport {
pub fn new() -> Self {
Self {
config: TcpConfig::default(),
}
}
pub fn with_config(config: TcpConfig) -> Self {
Self { config }
}
pub async fn connect(&self, addr: &str) -> Result<(TcpSender, TcpReceiver)> {
info!("Connecting to TCP: {}", addr);
let stream = TcpStream::connect(addr)
.await
.map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
if self.config.keepalive_secs > 0 {
let socket = socket2::SockRef::from(&stream);
let keepalive = socket2::TcpKeepalive::new()
.with_time(std::time::Duration::from_secs(self.config.keepalive_secs));
let _ = socket.set_tcp_keepalive(&keepalive);
}
let connected = Arc::new(Mutex::new(true));
let (outgoing_tx, outgoing_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_BUFFER_SIZE);
let (incoming_tx, incoming_rx) =
mpsc::channel::<TransportEvent>(DEFAULT_CHANNEL_BUFFER_SIZE);
let sender = TcpSender {
tx: outgoing_tx,
connected: connected.clone(),
};
let receiver = TcpReceiver { rx: incoming_rx };
let max_size = self.config.max_message_size;
let connected_clone = connected.clone();
tokio::spawn(async move {
let (reader, writer) = stream.into_split();
run_tcp_io_loop(
reader,
writer,
outgoing_rx,
incoming_tx,
max_size,
connected_clone,
)
.await;
});
info!("TCP connected to {}", addr);
Ok((sender, receiver))
}
}
impl Default for TcpTransport {
fn default() -> Self {
Self::new()
}
}
async fn run_tcp_io_loop(
mut reader: OwnedReadHalf,
mut writer: OwnedWriteHalf,
mut outgoing_rx: mpsc::Receiver<Bytes>,
incoming_tx: mpsc::Sender<TransportEvent>,
max_size: usize,
connected: Arc<Mutex<bool>>,
) {
let mut read_buf = BytesMut::with_capacity(8192);
loop {
tokio::select! {
Some(data) = outgoing_rx.recv() => {
let len = data.len() as u32;
let mut frame = BytesMut::with_capacity(4 + data.len());
frame.put_u32(len);
frame.extend_from_slice(&data);
if let Err(e) = writer.write_all(&frame).await {
error!("TCP write error: {}", e);
break;
}
}
result = reader.read_buf(&mut read_buf) => {
match result {
Ok(0) => {
debug!("TCP connection closed");
let _ = incoming_tx.send(TransportEvent::Disconnected { reason: None }).await;
break;
}
Ok(_) => {
while read_buf.len() >= 4 {
let len = (&read_buf[..4]).get_u32() as usize;
if len > max_size {
error!("Message too large: {} > {}", len, max_size);
let _ = incoming_tx.send(TransportEvent::Disconnected {
reason: Some(format!("Message too large: {}", len))
}).await;
break;
}
if read_buf.len() >= 4 + len {
read_buf.advance(4);
let data = read_buf.split_to(len).freeze();
if incoming_tx.send(TransportEvent::Data(data)).await.is_err() {
break;
}
} else {
break;
}
}
}
Err(e) => {
error!("TCP read error: {}", e);
let _ = incoming_tx.send(TransportEvent::Error(e.to_string())).await;
break;
}
}
}
}
}
*connected.lock() = false;
}
pub struct TcpSender {
tx: mpsc::Sender<Bytes>,
connected: Arc<Mutex<bool>>,
}
#[async_trait]
impl TransportSender for TcpSender {
async fn send(&self, data: Bytes) -> Result<()> {
if !*self.connected.lock() {
return Err(TransportError::NotConnected);
}
self.tx
.send(data)
.await
.map_err(|_| TransportError::SendFailed("Channel closed".into()))
}
fn try_send(&self, data: Bytes) -> Result<()> {
if !*self.connected.lock() {
return Err(TransportError::NotConnected);
}
self.tx.try_send(data).map_err(|e| match e {
mpsc::error::TrySendError::Full(_) => TransportError::BufferFull,
mpsc::error::TrySendError::Closed(_) => TransportError::ConnectionClosed,
})
}
fn is_connected(&self) -> bool {
*self.connected.lock()
}
async fn close(&self) -> Result<()> {
*self.connected.lock() = false;
Ok(())
}
}
/// Receiving half of a TCP transport connection.
pub struct TcpReceiver {
    /// Events produced by the connection's background I/O task.
    rx: mpsc::Receiver<TransportEvent>,
}

#[async_trait]
impl TransportReceiver for TcpReceiver {
    /// Waits for the next connection event.
    ///
    /// Yields `None` once every producer of events has been dropped and the
    /// queue is drained. In this module the only producer is the
    /// connection's I/O task.
    async fn recv(&mut self) -> Option<TransportEvent> {
        let TcpReceiver { rx } = self;
        rx.recv().await
    }
}
pub struct TcpServer {
listener: TcpListener,
config: TcpConfig,
}
impl TcpServer {
pub async fn bind(addr: &str) -> Result<Self> {
let listener = TcpListener::bind(addr)
.await
.map_err(|e| TransportError::BindFailed(e.to_string()))?;
info!("TCP server listening on {}", addr);
Ok(Self {
listener,
config: TcpConfig::default(),
})
}
pub async fn bind_with_config(addr: &str, config: TcpConfig) -> Result<Self> {
let listener = TcpListener::bind(addr)
.await
.map_err(|e| TransportError::BindFailed(e.to_string()))?;
info!("TCP server listening on {}", addr);
Ok(Self { listener, config })
}
}
#[async_trait]
impl TransportServer for TcpServer {
type Sender = TcpSender;
type Receiver = TcpReceiver;
async fn accept(&mut self) -> Result<(Self::Sender, Self::Receiver, SocketAddr)> {
let (stream, peer_addr) = self
.listener
.accept()
.await
.map_err(|e| TransportError::AcceptFailed(e.to_string()))?;
info!("TCP connection accepted from {}", peer_addr);
let connected = Arc::new(Mutex::new(true));
let (outgoing_tx, outgoing_rx) = mpsc::channel::<Bytes>(DEFAULT_CHANNEL_BUFFER_SIZE);
let (incoming_tx, incoming_rx) =
mpsc::channel::<TransportEvent>(DEFAULT_CHANNEL_BUFFER_SIZE);
let sender = TcpSender {
tx: outgoing_tx,
connected: connected.clone(),
};
let receiver = TcpReceiver { rx: incoming_rx };
let max_size = self.config.max_message_size;
let connected_clone = connected.clone();
let stream: TcpStream = stream; tokio::spawn(async move {
let (reader, writer) = stream.into_split();
run_tcp_io_loop(
reader,
writer,
outgoing_rx,
incoming_tx,
max_size,
connected_clone,
)
.await;
});
Ok((sender, receiver, peer_addr))
}
fn local_addr(&self) -> Result<SocketAddr> {
self.listener
.local_addr()
.map_err(|e| TransportError::Other(e.to_string()))
}
async fn close(&self) -> Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tcp_config_default() {
        let config = TcpConfig::default();
        assert_eq!(config.max_message_size, 64 * 1024);
        assert_eq!(config.read_buffer_size, 8192);
        assert_eq!(config.keepalive_secs, 30);
    }

    #[test]
    fn test_tcp_transport_creation() {
        let transport = TcpTransport::new();
        assert_eq!(transport.config.max_message_size, 64 * 1024);
    }

    /// Round-trips one message: client -> server -> (echo) -> client.
    #[tokio::test]
    async fn test_tcp_client_server_connection() {
        let mut server = TcpServer::bind("127.0.0.1:0").await.unwrap();
        let addr = server.local_addr().unwrap();
        // The listener is bound before the task is spawned, so the OS queues the
        // client's connect even if `accept()` has not run yet. No sleep is
        // needed to avoid a race.
        let accept_handle = tokio::spawn(async move {
            let (sender, mut receiver, peer) = server.accept().await.unwrap();
            info!("Server accepted connection from {}", peer);
            if let Some(TransportEvent::Data(data)) = receiver.recv().await {
                sender.send(data).await.unwrap();
            }
            (sender, receiver)
        });

        let transport = TcpTransport::new();
        let (client_sender, mut client_receiver) =
            transport.connect(&addr.to_string()).await.unwrap();
        let test_data = Bytes::from("hello tcp");
        client_sender.send(test_data.clone()).await.unwrap();

        match client_receiver.recv().await {
            Some(TransportEvent::Data(received)) => assert_eq!(received, test_data),
            _ => panic!("Expected Data event"),
        }

        client_sender.close().await.unwrap();
        // Surface any panic from the server task instead of discarding it.
        accept_handle.await.expect("server task panicked");
    }
}