use crate::network::transport::{
Transport, TransportAddr, TransportConnection, TransportListener, TransportType,
};
use anyhow::Result;
use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener as TokioTcpListener, TcpStream};
use tracing::{debug, error, info};
/// TCP implementation of the crate's `Transport` abstraction.
///
/// Every listener and connection produced by this transport inherits
/// `max_message_length`, the largest frame (header included) that `recv` will accept.
#[derive(Clone, Debug)]
pub struct TcpTransport {
    /// Upper bound, in bytes, on an inbound frame including its 24-byte header.
    max_message_length: usize,
}
impl TcpTransport {
pub fn new() -> Self {
Self {
max_message_length: crate::network::protocol::MAX_PROTOCOL_MESSAGE_LENGTH,
}
}
pub fn with_max_message_length(max_message_length: usize) -> Self {
Self { max_message_length }
}
}
impl Default for TcpTransport {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl Transport for TcpTransport {
    type Connection = TcpConnection;
    type Listener = TcpListener;

    /// Identifies this transport as [`TransportType::Tcp`].
    fn transport_type(&self) -> TransportType {
        TransportType::Tcp
    }

    /// Binds a TCP listener on `addr`; connections it accepts inherit this transport's
    /// `max_message_length`.
    async fn listen(&self, addr: SocketAddr) -> Result<Self::Listener> {
        let max_message_length = self.max_message_length;
        let listener = TokioTcpListener::bind(addr).await?;
        Ok(TcpListener {
            listener,
            max_message_length,
        })
    }

    /// Opens an outbound TCP connection to `addr`.
    ///
    /// # Errors
    /// Fails if `addr` is not a TCP address, if the connect fails, or if the peer
    /// address of the new socket cannot be read.
    async fn connect(&self, addr: TransportAddr) -> Result<Self::Connection> {
        // NOTE(review): `TransportAddr` presumably has only the `Tcp` variant unless an
        // optional transport (e.g. the `iroh` feature) is compiled in; the allow keeps
        // that single-variant build free of an irrefutable-pattern warning.
        #[allow(irrefutable_let_patterns)]
        let TransportAddr::Tcp(target) = addr else {
            return Err(anyhow::anyhow!(
                "TCP transport can only connect to TCP addresses"
            ));
        };
        let stream = TcpStream::connect(target).await?;
        let remote = TransportAddr::Tcp(stream.peer_addr()?);
        Ok(TcpConnection::new(stream, remote, self.max_message_length))
    }
}
impl TcpTransport {
    /// Opens a raw (unframed) TCP stream to `addr` with the default 10-second timeout.
    ///
    /// # Errors
    /// See [`TcpTransport::connect_stream_with_timeout`].
    pub async fn connect_stream(&self, addr: SocketAddr) -> Result<TcpStream> {
        const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
        self.connect_stream_with_timeout(addr, DEFAULT_CONNECT_TIMEOUT_SECS)
            .await
    }

    /// Opens a raw (unframed) TCP stream to `addr`, giving up after `timeout_secs` seconds.
    ///
    /// # Errors
    /// Returns a "Connection timeout" error if the deadline passes first, or a
    /// "Connection failed" error if the connect itself fails.
    pub async fn connect_stream_with_timeout(
        &self,
        addr: SocketAddr,
        timeout_secs: u64,
    ) -> Result<TcpStream> {
        info!("Connecting to peer at {} (timeout {}s)", addr, timeout_secs);
        let deadline = tokio::time::Duration::from_secs(timeout_secs);
        match tokio::time::timeout(deadline, TcpStream::connect(addr)).await {
            Err(_elapsed) => Err(anyhow::anyhow!(
                "Connection timeout ({}s) to {}",
                timeout_secs,
                addr
            )),
            Ok(Err(e)) => Err(anyhow::anyhow!("Connection failed to {}: {}", addr, e)),
            Ok(Ok(stream)) => Ok(stream),
        }
    }
}
pub struct TcpListener {
listener: TokioTcpListener,
max_message_length: usize,
}
impl TcpListener {
    /// Waits for the next inbound TCP connection and returns the raw, unframed stream
    /// together with the peer's address.
    ///
    /// # Errors
    /// Logs and returns an error if the underlying accept call fails.
    pub async fn accept_stream(&mut self) -> Result<(TcpStream, SocketAddr)> {
        let (stream, addr) = self.listener.accept().await.map_err(|e| {
            error!("Failed to accept TCP connection: {}", e);
            anyhow::anyhow!("Failed to accept connection: {}", e)
        })?;
        debug!("Accepted TCP connection from {}", addr);
        Ok((stream, addr))
    }
}
#[async_trait::async_trait]
impl TransportListener for TcpListener {
    type Connection = TcpConnection;

    /// Accepts the next inbound connection and wraps it in a framed [`TcpConnection`].
    ///
    /// Returns the connection together with the peer's address.
    ///
    /// # Errors
    /// Fails (after logging) if the underlying accept call fails.
    async fn accept(&mut self) -> Result<(Self::Connection, TransportAddr)> {
        // `accept_stream` logs the outcome and already yields the peer address reported
        // by the accept call, so a separate (and fallible) `peer_addr()` lookup, which
        // can fail if the peer resets right after the handshake, is unnecessary.
        let (stream, addr) = self.accept_stream().await?;
        let connection =
            TcpConnection::new(stream, TransportAddr::Tcp(addr), self.max_message_length);
        Ok((connection, TransportAddr::Tcp(addr)))
    }

    /// Returns the local address this listener is bound to.
    fn local_addr(&self) -> Result<SocketAddr> {
        self.listener
            .local_addr()
            .map_err(|e| anyhow::anyhow!("Failed to get local addr: {}", e))
    }
}
/// A framed TCP connection whose read and write halves can be locked independently.
pub struct TcpConnection {
    /// Read half of the stream, guarded by an async mutex.
    pub(crate) reader:
        std::sync::Arc<tokio::sync::Mutex<tokio::io::ReadHalf<tokio::net::TcpStream>>>,
    /// Write half of the stream, guarded by an async mutex.
    pub(crate) writer:
        std::sync::Arc<tokio::sync::Mutex<tokio::io::WriteHalf<tokio::net::TcpStream>>>,
    /// Address of the remote end.
    pub(crate) peer_addr: TransportAddr,
    /// `true` until the stream is closed locally or observed closed while reading.
    pub(crate) connected: std::sync::atomic::AtomicBool,
    /// Largest inbound frame (header included) that `recv` will accept.
    max_message_length: usize,
}
impl TcpConnection {
    /// Wraps an established stream, splitting it into separately lockable read and
    /// write halves. The connection starts out marked as connected.
    pub fn new(stream: TcpStream, peer_addr: TransportAddr, max_message_length: usize) -> Self {
        use std::sync::Arc;
        use tokio::sync::Mutex;

        let (read_half, write_half) = tokio::io::split(stream);
        Self {
            peer_addr,
            max_message_length,
            connected: std::sync::atomic::AtomicBool::new(true),
            reader: Arc::new(Mutex::new(read_half)),
            writer: Arc::new(Mutex::new(write_half)),
        }
    }
}
#[async_trait::async_trait]
impl TransportConnection for TcpConnection {
    /// Writes `data` to the peer verbatim and flushes it.
    ///
    /// No framing is added here. Callers presumably pass a complete header+payload frame
    /// in the layout `recv` parses.
    ///
    /// # Errors
    /// Fails if the connection is already closed, or if the write or flush fails. A failed
    /// write may have sent a partial frame, so the connection is then marked disconnected.
    async fn send(&mut self, data: &[u8]) -> Result<()> {
        use std::sync::atomic::Ordering;
        if !self.connected.load(Ordering::Relaxed) {
            return Err(anyhow::anyhow!("Connection closed"));
        }
        let mut writer = self.writer.lock().await;
        let written = match writer.write_all(data).await {
            Ok(()) => writer.flush().await,
            Err(e) => Err(e),
        };
        if let Err(e) = written {
            // A partially written frame leaves the peer's view of the stream out of sync.
            self.connected.store(false, Ordering::Relaxed);
            return Err(e.into());
        }
        Ok(())
    }

    /// Reads one frame: a 24-byte header followed by a payload whose length is stored
    /// little-endian in header bytes 16..20. Returns header and payload concatenated.
    ///
    /// Returns an empty `Vec` if the connection was already closed, or if the stream
    /// ended before a complete header arrived. In the second case the connection is then
    /// marked disconnected.
    ///
    /// # Errors
    /// Fails on I/O errors, or if the declared payload exceeds
    /// `max_message_length - 24` bytes. After any error the stream position is no longer
    /// at a frame boundary, so the connection is marked disconnected rather than risk
    /// misparsing leftover bytes as the next header.
    async fn recv(&mut self) -> Result<Vec<u8>> {
        use std::sync::atomic::Ordering;
        const HEADER_LEN: usize = 24;

        if !self.connected.load(Ordering::Relaxed) {
            return Ok(Vec::new());
        }
        let mut reader = self.reader.lock().await;

        let mut header = [0u8; HEADER_LEN];
        if let Err(e) = reader.read_exact(&mut header).await {
            self.connected.store(false, Ordering::Relaxed);
            if e.kind() == std::io::ErrorKind::UnexpectedEof {
                // Stream ended: report it as an empty message rather than an error.
                return Ok(Vec::new());
            }
            return Err(anyhow::anyhow!("Failed to read header: {}", e));
        }

        // `u32 -> usize` is lossless on 32- and 64-bit targets.
        let payload_len =
            u32::from_le_bytes([header[16], header[17], header[18], header[19]]) as usize;
        if payload_len == 0 {
            return Ok(header.to_vec());
        }

        let max_payload = self.max_message_length.saturating_sub(HEADER_LEN);
        if payload_len > max_payload {
            // The oversized payload is left unread, so the next "header" would really be
            // payload bytes; stop using this stream instead of misparsing it.
            self.connected.store(false, Ordering::Relaxed);
            return Err(anyhow::anyhow!(
                "Message payload too large: {} bytes (max: {} bytes)",
                payload_len,
                max_payload
            ));
        }

        // Read the payload straight into the output buffer: one allocation, no extra copy.
        // `payload_len <= max_message_length - HEADER_LEN`, so this sum cannot overflow.
        let frame_len = HEADER_LEN + payload_len;
        let mut message = Vec::with_capacity(frame_len);
        message.extend_from_slice(&header);
        message.resize(frame_len, 0);
        if let Err(e) = reader.read_exact(&mut message[HEADER_LEN..]).await {
            self.connected.store(false, Ordering::Relaxed);
            return Err(e.into());
        }
        Ok(message)
    }

    /// Returns the remote address this connection was created with.
    fn peer_addr(&self) -> TransportAddr {
        self.peer_addr.clone()
    }

    /// Reports whether the connection is still considered open.
    fn is_connected(&self) -> bool {
        use std::sync::atomic::Ordering;
        self.connected.load(Ordering::Relaxed)
    }

    /// Shuts down the write half of the stream and marks the connection closed.
    ///
    /// Calling it again after the connection is already closed is a no-op.
    ///
    /// # Errors
    /// Propagates a failure from `shutdown`; the connection is marked closed regardless.
    async fn close(&mut self) -> Result<()> {
        use std::sync::atomic::Ordering;
        // Clear the flag before shutting down so the connection is reported closed even if
        // `shutdown` fails; `swap` also makes repeated calls no-ops.
        if self.connected.swap(false, Ordering::Relaxed) {
            let mut writer = self.writer.lock().await;
            writer.shutdown().await?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame in the layout `TcpConnection::recv` parses: a 24-byte header whose
    /// bytes 16..20 hold the little-endian payload length, followed by the payload.
    fn frame(payload: &[u8]) -> Vec<u8> {
        let mut message = vec![0u8; 24];
        let len = u32::try_from(payload.len()).expect("test payload fits in u32");
        message[16..20].copy_from_slice(&len.to_le_bytes());
        message.extend_from_slice(payload);
        message
    }

    /// Binds an ephemeral listener and connects one client to it.
    /// Returns `(client, server)`.
    async fn connected_pair(transport: &TcpTransport) -> (TcpConnection, TcpConnection) {
        let mut listener = transport
            .listen("127.0.0.1:0".parse().unwrap())
            .await
            .expect("binding an ephemeral port should succeed");
        let local_addr = listener.local_addr().expect("listener has a local address");
        // The kernel completes the handshake from the listen backlog, so connecting before
        // calling `accept` does not deadlock.
        let client = transport
            .connect(TransportAddr::Tcp(local_addr))
            .await
            .expect("connect to local listener");
        let (server, _) = listener.accept().await.expect("accept local client");
        (client, server)
    }

    #[tokio::test]
    async fn test_tcp_transport_type() {
        let transport = TcpTransport::new();
        assert_eq!(transport.transport_type(), TransportType::Tcp);
    }

    #[tokio::test]
    async fn test_tcp_transport_listen() {
        let transport = TcpTransport::new();
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let listener = transport
            .listen(addr)
            .await
            .expect("binding an ephemeral port should succeed");
        let local_addr = listener.local_addr().expect("listener has a local address");
        assert_ne!(local_addr.port(), 0, "the OS should assign a concrete port");
    }

    #[tokio::test]
    async fn test_tcp_transport_connect_invalid_addr() {
        let transport = TcpTransport::new();
        #[cfg(feature = "iroh")]
        {
            let iroh_addr = TransportAddr::Iroh(vec![0u8; 32]);
            let result = transport.connect(iroh_addr).await;
            assert!(result.is_err());
        }
        // Without optional transports every `TransportAddr` is TCP, so there is no
        // invalid address to exercise.
        let _ = &transport;
    }

    #[tokio::test]
    async fn test_tcp_send_recv_roundtrip() {
        let transport = TcpTransport::with_max_message_length(1024);
        let (mut client, mut server) = connected_pair(&transport).await;

        let with_payload = frame(b"hello");
        client.send(&with_payload).await.expect("send framed message");
        assert_eq!(
            server.recv().await.expect("recv framed message"),
            with_payload
        );

        let header_only = frame(&[]);
        client.send(&header_only).await.expect("send header-only message");
        assert_eq!(
            server.recv().await.expect("recv header-only message"),
            header_only
        );

        client.close().await.expect("close client");
        assert!(!client.is_connected());
        assert!(server.recv().await.expect("recv after peer close").is_empty());
        assert!(!server.is_connected());
    }

    #[tokio::test]
    async fn test_tcp_recv_rejects_oversized_payload() {
        // 24-byte header + at most 8 payload bytes.
        let transport = TcpTransport::with_max_message_length(32);
        let (mut client, mut server) = connected_pair(&transport).await;

        client
            .send(&frame(b"oversized!"))
            .await
            .expect("send oversized message");
        assert!(server.recv().await.is_err());
    }
}