pub mod stats;
pub mod stream;
use super::resolve_socket_addr;
use super::{
ConnectionState, DiscoveredPeer, PacketTx, ReceivedPacket, Transport, TransportAddr,
TransportError, TransportId, TransportState, TransportType,
};
use crate::config::TcpConfig;
use stats::TcpStats;
use stream::read_fmp_packet;
use futures::FutureExt;
use socket2::TcpKeepalive;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::net::tcp::OwnedWriteHalf;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio::time::Instant;
use tracing::{debug, info, trace, warn};
struct TcpConnection {
writer: Arc<Mutex<OwnedWriteHalf>>,
recv_task: JoinHandle<()>,
#[allow(dead_code)]
mtu: u16,
#[allow(dead_code)]
established_at: Instant,
}
type ConnectionPool = Arc<Mutex<HashMap<TransportAddr, TcpConnection>>>;
struct ConnectingEntry {
task: JoinHandle<Result<(TcpStream, u16), TransportError>>,
}
type ConnectingPool = Arc<Mutex<HashMap<TransportAddr, ConnectingEntry>>>;
pub struct TcpTransport {
transport_id: TransportId,
name: Option<String>,
config: TcpConfig,
state: TransportState,
pool: ConnectionPool,
connecting: ConnectingPool,
packet_tx: PacketTx,
accept_task: Option<JoinHandle<()>>,
local_addr: Option<SocketAddr>,
stats: Arc<TcpStats>,
}
impl TcpTransport {
pub fn new(
transport_id: TransportId,
name: Option<String>,
config: TcpConfig,
packet_tx: PacketTx,
) -> Self {
Self {
transport_id,
name,
config,
state: TransportState::Configured,
pool: Arc::new(Mutex::new(HashMap::new())),
connecting: Arc::new(Mutex::new(HashMap::new())),
packet_tx,
accept_task: None,
local_addr: None,
stats: Arc::new(TcpStats::new()),
}
}
pub fn name(&self) -> Option<&str> {
self.name.as_deref()
}
pub fn local_addr(&self) -> Option<SocketAddr> {
self.local_addr
}
pub fn stats(&self) -> &Arc<TcpStats> {
&self.stats
}
pub async fn start_async(&mut self) -> Result<(), TransportError> {
if !self.state.can_start() {
return Err(TransportError::AlreadyStarted);
}
self.state = TransportState::Starting;
if let Some(ref bind_addr) = self.config.bind_addr {
let addr: SocketAddr = bind_addr
.parse()
.map_err(|e| TransportError::StartFailed(format!("invalid bind address: {}", e)))?;
let listener = TcpListener::bind(addr)
.await
.map_err(|e| TransportError::StartFailed(format!("bind failed: {}", e)))?;
self.local_addr = Some(
listener
.local_addr()
.map_err(|e| TransportError::StartFailed(format!("get local addr: {}", e)))?,
);
let transport_id = self.transport_id;
let packet_tx = self.packet_tx.clone();
let pool = self.pool.clone();
let stats = self.stats.clone();
let cfg = AcceptConfig {
mtu: self.config.mtu(),
max_inbound: self.config.max_inbound_connections(),
nodelay: self.config.nodelay(),
keepalive_secs: self.config.keepalive_secs(),
recv_buf: self.config.recv_buf_size(),
send_buf: self.config.send_buf_size(),
};
let accept_task = tokio::spawn(async move {
accept_loop(listener, transport_id, packet_tx, pool, cfg, stats).await;
});
self.accept_task = Some(accept_task);
}
self.state = TransportState::Up;
if let Some(ref name) = self.name {
info!(
name = %name,
local_addr = ?self.local_addr,
mtu = self.config.mtu(),
"TCP transport started"
);
} else {
info!(
local_addr = ?self.local_addr,
mtu = self.config.mtu(),
"TCP transport started"
);
}
Ok(())
}
pub async fn stop_async(&mut self) -> Result<(), TransportError> {
if !self.state.is_operational() {
return Err(TransportError::NotStarted);
}
if let Some(task) = self.accept_task.take() {
task.abort();
let _ = task.await;
}
let mut connecting = self.connecting.lock().await;
for (addr, entry) in connecting.drain() {
entry.task.abort();
debug!(
transport_id = %self.transport_id,
remote_addr = %addr,
"TCP connect aborted (transport stopping)"
);
}
drop(connecting);
let mut pool = self.pool.lock().await;
for (addr, conn) in pool.drain() {
conn.recv_task.abort();
let _ = conn.recv_task.await;
debug!(
transport_id = %self.transport_id,
remote_addr = %addr,
"TCP connection closed (transport stopping)"
);
}
drop(pool);
self.local_addr = None;
self.state = TransportState::Down;
info!(
transport_id = %self.transport_id,
"TCP transport stopped"
);
Ok(())
}
pub async fn send_async(
&self,
addr: &TransportAddr,
data: &[u8],
) -> Result<usize, TransportError> {
if !self.state.is_operational() {
return Err(TransportError::NotStarted);
}
let mtu = self.config.mtu() as usize;
if data.len() > mtu {
self.stats.record_mtu_exceeded();
return Err(TransportError::MtuExceeded {
packet_size: data.len(),
mtu: self.config.mtu(),
});
}
let writer = {
let pool = self.pool.lock().await;
pool.get(addr).map(|c| c.writer.clone())
};
let writer = match writer {
Some(w) => w,
None => {
self.connect(addr).await?
}
};
let mut w = writer.lock().await;
match w.write_all(data).await {
Ok(()) => {
self.stats.record_send(data.len());
trace!(
transport_id = %self.transport_id,
remote_addr = %addr,
bytes = data.len(),
"TCP packet sent"
);
Ok(data.len())
}
Err(e) => {
self.stats.record_send_error();
drop(w);
let mut pool = self.pool.lock().await;
if let Some(conn) = pool.remove(addr) {
conn.recv_task.abort();
}
Err(TransportError::SendFailed(format!("{}", e)))
}
}
}
async fn connect(
&self,
addr: &TransportAddr,
) -> Result<Arc<Mutex<OwnedWriteHalf>>, TransportError> {
let socket_addr = resolve_socket_addr(addr).await?;
let timeout_ms = self.config.connect_timeout_ms();
let stream = match tokio::time::timeout(
Duration::from_millis(timeout_ms),
TcpStream::connect(socket_addr),
)
.await
{
Ok(Ok(stream)) => stream,
Ok(Err(_)) => {
self.stats.record_connect_refused();
return Err(TransportError::ConnectionRefused);
}
Err(_) => {
self.stats.record_connect_timeout();
return Err(TransportError::Timeout);
}
};
let std_stream = stream
.into_std()
.map_err(|e| TransportError::StartFailed(format!("into_std: {}", e)))?;
configure_socket(&std_stream, &self.config)?;
let mss_mtu = read_mss_mtu(&std_stream, self.config.mtu());
let stream = TcpStream::from_std(std_stream)
.map_err(|e| TransportError::StartFailed(format!("from_std: {}", e)))?;
let (read_half, write_half) = stream.into_split();
let writer = Arc::new(Mutex::new(write_half));
let transport_id = self.transport_id;
let packet_tx = self.packet_tx.clone();
let pool = self.pool.clone();
let recv_stats = self.stats.clone();
let remote_addr = addr.clone();
let mtu = mss_mtu;
let recv_task = tokio::spawn(async move {
tcp_receive_loop(
read_half,
transport_id,
remote_addr.clone(),
packet_tx,
pool,
mtu,
recv_stats,
)
.await;
});
let conn = TcpConnection {
writer: writer.clone(),
recv_task,
mtu: mss_mtu,
established_at: Instant::now(),
};
let mut pool = self.pool.lock().await;
pool.insert(addr.clone(), conn);
self.stats.record_connection_established();
debug!(
transport_id = %self.transport_id,
remote_addr = %addr,
mtu = mss_mtu,
"TCP connection established (connect-on-send)"
);
Ok(writer)
}
pub async fn close_connection_async(&self, addr: &TransportAddr) {
let mut pool = self.pool.lock().await;
if let Some(conn) = pool.remove(addr) {
conn.recv_task.abort();
debug!(
transport_id = %self.transport_id,
remote_addr = %addr,
"TCP connection closed (close_connection)"
);
}
}
pub async fn connect_async(&self, addr: &TransportAddr) -> Result<(), TransportError> {
if !self.state.is_operational() {
return Err(TransportError::NotStarted);
}
{
let pool = self.pool.lock().await;
if pool.contains_key(addr) {
return Ok(());
}
}
{
let connecting = self.connecting.lock().await;
if connecting.contains_key(addr) {
return Ok(());
}
}
let addr_string = addr
.as_str()
.ok_or_else(|| TransportError::InvalidAddress("not valid UTF-8".into()))?
.to_string();
let timeout_ms = self.config.connect_timeout_ms();
let config = self.config.clone();
let transport_id = self.transport_id;
let remote_addr = addr.clone();
debug!(
transport_id = %transport_id,
remote_addr = %remote_addr,
timeout_ms,
"Initiating background TCP connect"
);
let task = tokio::spawn(async move {
let socket_addr: SocketAddr = if let Ok(sa) = addr_string.parse() {
sa
} else {
tokio::net::lookup_host(&addr_string)
.await
.map_err(|e| {
TransportError::InvalidAddress(format!(
"DNS resolution failed for {}: {}",
addr_string, e
))
})?
.next()
.ok_or_else(|| {
TransportError::InvalidAddress(format!(
"DNS resolution returned no addresses for {}",
addr_string
))
})?
};
let stream = match tokio::time::timeout(
Duration::from_millis(timeout_ms),
TcpStream::connect(socket_addr),
)
.await
{
Ok(Ok(stream)) => stream,
Ok(Err(e)) => {
debug!(
transport_id = %transport_id,
remote_addr = %remote_addr,
error = %e,
"Background TCP connect refused"
);
return Err(TransportError::ConnectionRefused);
}
Err(_) => {
debug!(
transport_id = %transport_id,
remote_addr = %remote_addr,
"Background TCP connect timed out"
);
return Err(TransportError::Timeout);
}
};
let std_stream = stream
.into_std()
.map_err(|e| TransportError::StartFailed(format!("into_std: {}", e)))?;
configure_socket(&std_stream, &config)?;
let mss_mtu = read_mss_mtu(&std_stream, config.mtu());
let stream = TcpStream::from_std(std_stream)
.map_err(|e| TransportError::StartFailed(format!("from_std: {}", e)))?;
Ok((stream, mss_mtu))
});
let mut connecting = self.connecting.lock().await;
connecting.insert(addr.clone(), ConnectingEntry { task });
Ok(())
}
pub fn connection_state_sync(&self, addr: &TransportAddr) -> ConnectionState {
if let Ok(pool) = self.pool.try_lock() {
if pool.contains_key(addr) {
return ConnectionState::Connected;
}
} else {
return ConnectionState::Connecting; }
let mut connecting = match self.connecting.try_lock() {
Ok(c) => c,
Err(_) => return ConnectionState::Connecting,
};
let entry = match connecting.get_mut(addr) {
Some(e) => e,
None => return ConnectionState::None,
};
if !entry.task.is_finished() {
return ConnectionState::Connecting;
}
let addr_clone = addr.clone();
let task = connecting.remove(&addr_clone).unwrap().task;
match task.now_or_never() {
Some(Ok(Ok((stream, mss_mtu)))) => {
self.promote_connection(addr, stream, mss_mtu);
ConnectionState::Connected
}
Some(Ok(Err(e))) => ConnectionState::Failed(format!("{}", e)),
Some(Err(e)) => {
ConnectionState::Failed(format!("task failed: {}", e))
}
None => {
ConnectionState::Connecting
}
}
}
fn promote_connection(&self, addr: &TransportAddr, stream: TcpStream, mss_mtu: u16) {
let (read_half, write_half) = stream.into_split();
let writer = Arc::new(Mutex::new(write_half));
let transport_id = self.transport_id;
let packet_tx = self.packet_tx.clone();
let pool = self.pool.clone();
let recv_stats = self.stats.clone();
let remote_addr = addr.clone();
let recv_task = tokio::spawn(async move {
tcp_receive_loop(
read_half,
transport_id,
remote_addr.clone(),
packet_tx,
pool,
mss_mtu,
recv_stats,
)
.await;
});
let conn = TcpConnection {
writer,
recv_task,
mtu: mss_mtu,
established_at: Instant::now(),
};
if let Ok(mut pool) = self.pool.try_lock() {
pool.insert(addr.clone(), conn);
self.stats.record_connection_established();
debug!(
transport_id = %self.transport_id,
remote_addr = %addr,
mtu = mss_mtu,
"TCP connection established (background connect)"
);
} else {
conn.recv_task.abort();
warn!(
transport_id = %self.transport_id,
remote_addr = %addr,
"Failed to promote connection (pool locked)"
);
}
}
}
impl Transport for TcpTransport {
fn transport_id(&self) -> TransportId {
self.transport_id
}
fn transport_type(&self) -> &TransportType {
&TransportType::TCP
}
fn state(&self) -> TransportState {
self.state
}
fn mtu(&self) -> u16 {
self.config.mtu()
}
fn link_mtu(&self, _addr: &TransportAddr) -> u16 {
self.config.mtu()
}
fn start(&mut self) -> Result<(), TransportError> {
Err(TransportError::NotSupported(
"use start_async() for TCP transport".into(),
))
}
fn stop(&mut self) -> Result<(), TransportError> {
Err(TransportError::NotSupported(
"use stop_async() for TCP transport".into(),
))
}
fn send(&self, _addr: &TransportAddr, _data: &[u8]) -> Result<(), TransportError> {
Err(TransportError::NotSupported(
"use send_async() for TCP transport".into(),
))
}
fn discover(&self) -> Result<Vec<DiscoveredPeer>, TransportError> {
Ok(Vec::new())
}
fn accept_connections(&self) -> bool {
self.config.bind_addr.is_some()
}
}
struct AcceptConfig {
mtu: u16,
max_inbound: usize,
nodelay: bool,
keepalive_secs: u64,
recv_buf: usize,
send_buf: usize,
}
#[allow(clippy::too_many_arguments)]
async fn accept_loop(
listener: TcpListener,
transport_id: TransportId,
packet_tx: PacketTx,
pool: ConnectionPool,
cfg: AcceptConfig,
stats: Arc<TcpStats>,
) {
let AcceptConfig {
mtu,
max_inbound,
nodelay,
keepalive_secs,
recv_buf,
send_buf,
} = cfg;
debug!(transport_id = %transport_id, "TCP accept loop starting");
loop {
match listener.accept().await {
Ok((stream, peer_addr)) => {
{
let pool_guard = pool.lock().await;
if pool_guard.len() >= max_inbound {
stats.record_connection_rejected();
warn!(
transport_id = %transport_id,
peer_addr = %peer_addr,
max = max_inbound,
"Rejecting inbound TCP connection (max_inbound_connections reached)"
);
continue;
}
}
let std_stream = match stream.into_std() {
Ok(s) => s,
Err(e) => {
warn!(
transport_id = %transport_id,
error = %e,
"Failed to convert accepted stream to std"
);
continue;
}
};
if let Err(e) = configure_accepted_socket(
&std_stream,
nodelay,
keepalive_secs,
recv_buf,
send_buf,
) {
warn!(
transport_id = %transport_id,
peer_addr = %peer_addr,
error = %e,
"Failed to configure accepted socket"
);
continue;
}
let conn_mtu = read_mss_mtu(&std_stream, mtu);
let stream = match TcpStream::from_std(std_stream) {
Ok(s) => s,
Err(e) => {
warn!(
transport_id = %transport_id,
error = %e,
"Failed to convert accepted stream back to tokio"
);
continue;
}
};
let remote_addr = TransportAddr::from_string(&peer_addr.to_string());
let (read_half, write_half) = stream.into_split();
let writer = Arc::new(Mutex::new(write_half));
let recv_pool = pool.clone();
let recv_packet_tx = packet_tx.clone();
let recv_stats = stats.clone();
let recv_addr = remote_addr.clone();
let recv_task = tokio::spawn(async move {
tcp_receive_loop(
read_half,
transport_id,
recv_addr,
recv_packet_tx,
recv_pool,
conn_mtu,
recv_stats,
)
.await;
});
let conn = TcpConnection {
writer,
recv_task,
mtu: conn_mtu,
established_at: Instant::now(),
};
let mut pool_guard = pool.lock().await;
pool_guard.insert(remote_addr.clone(), conn);
stats.record_connection_accepted();
debug!(
transport_id = %transport_id,
remote_addr = %remote_addr,
mtu = conn_mtu,
"Accepted inbound TCP connection"
);
}
Err(e) => {
warn!(
transport_id = %transport_id,
error = %e,
"TCP accept error"
);
}
}
}
}
async fn tcp_receive_loop(
mut reader: tokio::net::tcp::OwnedReadHalf,
transport_id: TransportId,
remote_addr: TransportAddr,
packet_tx: PacketTx,
pool: ConnectionPool,
mtu: u16,
stats: Arc<TcpStats>,
) {
debug!(
transport_id = %transport_id,
remote_addr = %remote_addr,
"TCP receive loop starting"
);
loop {
match read_fmp_packet(&mut reader, mtu).await {
Ok(data) => {
stats.record_recv(data.len());
trace!(
transport_id = %transport_id,
remote_addr = %remote_addr,
bytes = data.len(),
"TCP packet received"
);
let packet = ReceivedPacket::new(transport_id, remote_addr.clone(), data);
if packet_tx.send(packet).is_err() {
debug!(
transport_id = %transport_id,
"Packet channel closed, stopping TCP receive loop"
);
break;
}
}
Err(e) => {
stats.record_recv_error();
debug!(
transport_id = %transport_id,
remote_addr = %remote_addr,
error = %e,
"TCP receive error, removing connection"
);
break;
}
}
}
let mut pool_guard = pool.lock().await;
pool_guard.remove(&remote_addr);
debug!(
transport_id = %transport_id,
remote_addr = %remote_addr,
"TCP receive loop stopped"
);
}
fn configure_socket(
stream: &std::net::TcpStream,
config: &TcpConfig,
) -> Result<(), TransportError> {
let socket = socket2::SockRef::from(stream)
.try_clone()
.map_err(|e| TransportError::StartFailed(format!("clone socket: {}", e)))?;
socket
.set_tcp_nodelay(config.nodelay())
.map_err(|e| TransportError::StartFailed(format!("set nodelay: {}", e)))?;
let keepalive_secs = config.keepalive_secs();
if keepalive_secs > 0 {
let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(keepalive_secs));
socket
.set_tcp_keepalive(&keepalive)
.map_err(|e| TransportError::StartFailed(format!("set keepalive: {}", e)))?;
}
socket
.set_recv_buffer_size(config.recv_buf_size())
.map_err(|e| TransportError::StartFailed(format!("set recv buffer: {}", e)))?;
socket
.set_send_buffer_size(config.send_buf_size())
.map_err(|e| TransportError::StartFailed(format!("set send buffer: {}", e)))?;
Ok(())
}
fn configure_accepted_socket(
stream: &std::net::TcpStream,
nodelay: bool,
keepalive_secs: u64,
recv_buf: usize,
send_buf: usize,
) -> Result<(), TransportError> {
let socket = socket2::SockRef::from(stream)
.try_clone()
.map_err(|e| TransportError::StartFailed(format!("clone socket: {}", e)))?;
socket
.set_tcp_nodelay(nodelay)
.map_err(|e| TransportError::StartFailed(format!("set nodelay: {}", e)))?;
if keepalive_secs > 0 {
let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(keepalive_secs));
socket
.set_tcp_keepalive(&keepalive)
.map_err(|e| TransportError::StartFailed(format!("set keepalive: {}", e)))?;
}
socket
.set_recv_buffer_size(recv_buf)
.map_err(|e| TransportError::StartFailed(format!("set recv buffer: {}", e)))?;
socket
.set_send_buffer_size(send_buf)
.map_err(|e| TransportError::StartFailed(format!("set send buffer: {}", e)))?;
Ok(())
}
fn read_mss_mtu(stream: &std::net::TcpStream, default_mtu: u16) -> u16 {
#[cfg(target_os = "linux")]
{
use std::os::unix::io::AsRawFd;
unsafe {
let mut mss: libc::c_int = 0;
let mut len: libc::socklen_t = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
let fd = stream.as_raw_fd();
let ret = libc::getsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_MAXSEG,
&mut mss as *mut libc::c_int as *mut libc::c_void,
&mut len,
);
if ret == 0 && mss > 0 {
let mss_mtu = (mss as u32).min(u16::MAX as u32) as u16;
return mss_mtu.min(default_mtu);
}
}
}
#[cfg(not(target_os = "linux"))]
let _ = stream;
default_mtu
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::packet_channel;
use tokio::time::{Duration, timeout};
fn make_config() -> TcpConfig {
TcpConfig {
bind_addr: Some("127.0.0.1:0".to_string()),
mtu: Some(1400),
..Default::default()
}
}
fn make_outbound_config() -> TcpConfig {
TcpConfig {
bind_addr: None,
mtu: Some(1400),
..Default::default()
}
}
#[tokio::test]
async fn test_start_stop() {
let (tx, _rx) = packet_channel(100);
let mut transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
assert_eq!(transport.state(), TransportState::Configured);
transport.start_async().await.unwrap();
assert_eq!(transport.state(), TransportState::Up);
assert!(transport.local_addr().is_some());
transport.stop_async().await.unwrap();
assert_eq!(transport.state(), TransportState::Down);
}
#[tokio::test]
async fn test_start_outbound_only() {
let (tx, _rx) = packet_channel(100);
let mut transport =
TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx);
transport.start_async().await.unwrap();
assert_eq!(transport.state(), TransportState::Up);
assert!(transport.local_addr().is_none());
transport.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_double_start_fails() {
let (tx, _rx) = packet_channel(100);
let mut transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
transport.start_async().await.unwrap();
let result = transport.start_async().await;
assert!(matches!(result, Err(TransportError::AlreadyStarted)));
transport.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_stop_not_started_fails() {
let (tx, _rx) = packet_channel(100);
let mut transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
let result = transport.stop_async().await;
assert!(matches!(result, Err(TransportError::NotStarted)));
}
#[tokio::test]
async fn test_send_not_started() {
let (tx, _rx) = packet_channel(100);
let transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
let result = transport
.send_async(&TransportAddr::from_string("127.0.0.1:9999"), b"test")
.await;
assert!(matches!(result, Err(TransportError::NotStarted)));
}
#[tokio::test]
async fn test_send_recv() {
let (tx1, _rx1) = packet_channel(100);
let (tx2, mut rx2) = packet_channel(100);
let mut t1 = TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx1);
let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
t1.start_async().await.unwrap();
t2.start_async().await.unwrap();
let addr2 = t2.local_addr().unwrap();
let payload_len = 4u16;
let total = 4 + 12 + payload_len as usize + 16;
let mut frame = vec![0u8; total];
frame[0] = 0x00; frame[1] = 0x00; frame[2..4].copy_from_slice(&payload_len.to_le_bytes());
for (i, byte) in frame[4..total].iter_mut().enumerate() {
*byte = ((4 + i) & 0xFF) as u8;
}
let bytes_sent = t1
.send_async(&TransportAddr::from_string(&addr2.to_string()), &frame)
.await
.unwrap();
assert_eq!(bytes_sent, frame.len());
let packet = timeout(Duration::from_secs(2), rx2.recv())
.await
.expect("timeout")
.expect("channel closed");
assert_eq!(packet.data, frame);
t1.stop_async().await.unwrap();
t2.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_bidirectional() {
let (tx1, mut rx1) = packet_channel(100);
let (tx2, mut rx2) = packet_channel(100);
let mut t1 = TcpTransport::new(TransportId::new(1), None, make_config(), tx1);
let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
t1.start_async().await.unwrap();
t2.start_async().await.unwrap();
let addr1 = t1.local_addr().unwrap();
let addr2 = t2.local_addr().unwrap();
let mut msg1_frame = vec![0xAA; 114];
msg1_frame[0] = 0x01; msg1_frame[1] = 0x00;
msg1_frame[2..4].copy_from_slice(&110u16.to_le_bytes());
t1.send_async(&TransportAddr::from_string(&addr2.to_string()), &msg1_frame)
.await
.unwrap();
let packet = timeout(Duration::from_secs(2), rx2.recv())
.await
.expect("timeout")
.expect("channel closed");
assert_eq!(packet.data, msg1_frame);
let mut msg2_frame = vec![0xBB; 69];
msg2_frame[0] = 0x02; msg2_frame[1] = 0x00;
msg2_frame[2..4].copy_from_slice(&65u16.to_le_bytes());
t2.send_async(&TransportAddr::from_string(&addr1.to_string()), &msg2_frame)
.await
.unwrap();
let packet = timeout(Duration::from_secs(2), rx1.recv())
.await
.expect("timeout")
.expect("channel closed");
assert_eq!(packet.data, msg2_frame);
t1.stop_async().await.unwrap();
t2.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_connect_timeout() {
let (tx, _rx) = packet_channel(100);
let config = TcpConfig {
bind_addr: None,
connect_timeout_ms: Some(100), ..Default::default()
};
let mut transport = TcpTransport::new(TransportId::new(1), None, config, tx);
transport.start_async().await.unwrap();
let result = transport
.send_async(
&TransportAddr::from_string("192.0.2.1:2121"),
b"\x00\x00\x04\x00test1234567890123456789012345678",
)
.await;
assert!(result.is_err());
transport.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_close_connection() {
let (tx1, _rx1) = packet_channel(100);
let (tx2, _rx2) = packet_channel(100);
let mut t1 = TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx1);
let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
t1.start_async().await.unwrap();
t2.start_async().await.unwrap();
let addr2 = t2.local_addr().unwrap();
let remote = TransportAddr::from_string(&addr2.to_string());
let mut msg1 = vec![0xAA; 114];
msg1[0] = 0x01;
msg1[1] = 0x00;
msg1[2..4].copy_from_slice(&110u16.to_le_bytes());
t1.send_async(&remote, &msg1).await.unwrap();
{
let pool = t1.pool.lock().await;
assert!(pool.contains_key(&remote));
}
t1.close_connection_async(&remote).await;
{
let pool = t1.pool.lock().await;
assert!(!pool.contains_key(&remote));
}
t1.stop_async().await.unwrap();
t2.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_discover_returns_empty() {
let (tx, _rx) = packet_channel(100);
let transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
let peers = transport.discover().unwrap();
assert!(peers.is_empty());
}
#[test]
fn test_transport_type() {
let (tx, _rx) = packet_channel(100);
let transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
assert_eq!(transport.transport_type().name, "tcp");
assert!(transport.transport_type().connection_oriented);
assert!(transport.transport_type().reliable);
}
#[test]
fn test_sync_methods_return_not_supported() {
let (tx, _rx) = packet_channel(100);
let mut transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
assert!(matches!(
transport.start(),
Err(TransportError::NotSupported(_))
));
assert!(matches!(
transport.stop(),
Err(TransportError::NotSupported(_))
));
assert!(matches!(
transport.send(&TransportAddr::from_string("test"), b"data"),
Err(TransportError::NotSupported(_))
));
}
#[test]
fn test_accept_connections_with_bind() {
let (tx, _rx) = packet_channel(100);
let config = TcpConfig {
bind_addr: Some("0.0.0.0:0".to_string()),
..Default::default()
};
let transport = TcpTransport::new(TransportId::new(1), None, config, tx);
assert!(transport.accept_connections());
}
#[test]
fn test_accept_connections_without_bind() {
let (tx, _rx) = packet_channel(100);
let config = TcpConfig {
bind_addr: None,
..Default::default()
};
let transport = TcpTransport::new(TransportId::new(1), None, config, tx);
assert!(!transport.accept_connections());
}
#[tokio::test]
async fn test_connection_drop_and_reconnect() {
let (tx1, _rx1) = packet_channel(100);
let (tx2, mut rx2) = packet_channel(100);
let mut t1 = TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx1);
let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
t1.start_async().await.unwrap();
t2.start_async().await.unwrap();
let addr2 = t2.local_addr().unwrap();
let remote = TransportAddr::from_string(&addr2.to_string());
let mut msg1 = vec![0xAA; 114];
msg1[0] = 0x01;
msg1[1] = 0x00;
msg1[2..4].copy_from_slice(&110u16.to_le_bytes());
t1.send_async(&remote, &msg1).await.unwrap();
let _ = timeout(Duration::from_secs(1), rx2.recv()).await;
t1.close_connection_async(&remote).await;
t1.send_async(&remote, &msg1).await.unwrap();
let packet = timeout(Duration::from_secs(2), rx2.recv())
.await
.expect("timeout")
.expect("channel closed");
assert_eq!(packet.data, msg1);
t1.stop_async().await.unwrap();
t2.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_connect_async_success() {
let (tx1, mut rx1) = packet_channel(100);
let (tx2, _rx2) = packet_channel(100);
let mut t1 = TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx1);
let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
t1.start_async().await.unwrap();
t2.start_async().await.unwrap();
let addr2 = t2.local_addr().unwrap();
let remote = TransportAddr::from_string(&addr2.to_string());
assert_eq!(t1.connection_state_sync(&remote), ConnectionState::None);
t1.connect_async(&remote).await.unwrap();
tokio::time::sleep(Duration::from_millis(200)).await;
let state = t1.connection_state_sync(&remote);
assert_eq!(state, ConnectionState::Connected);
let mut msg1 = vec![0xAA; 114];
msg1[0] = 0x01;
msg1[1] = 0x00;
msg1[2..4].copy_from_slice(&110u16.to_le_bytes());
t1.send_async(&remote, &msg1).await.unwrap();
let packet = timeout(Duration::from_secs(2), rx1.recv()).await;
drop(packet);
t1.stop_async().await.unwrap();
t2.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_connect_async_timeout() {
let (tx, _rx) = packet_channel(100);
let config = TcpConfig {
bind_addr: None,
connect_timeout_ms: Some(100), ..Default::default()
};
let mut transport = TcpTransport::new(TransportId::new(1), None, config, tx);
transport.start_async().await.unwrap();
let remote = TransportAddr::from_string("192.0.2.1:2121");
transport.connect_async(&remote).await.unwrap();
tokio::time::sleep(Duration::from_millis(500)).await;
let state = transport.connection_state_sync(&remote);
assert!(matches!(state, ConnectionState::Failed(_)));
transport.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_connect_async_not_started() {
let (tx, _rx) = packet_channel(100);
let transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
let result = transport
.connect_async(&TransportAddr::from_string("127.0.0.1:9999"))
.await;
assert!(matches!(result, Err(TransportError::NotStarted)));
}
#[tokio::test]
async fn test_connect_async_already_connected() {
let (tx1, _rx1) = packet_channel(100);
let (tx2, _rx2) = packet_channel(100);
let mut t1 = TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx1);
let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
t1.start_async().await.unwrap();
t2.start_async().await.unwrap();
let addr2 = t2.local_addr().unwrap();
let remote = TransportAddr::from_string(&addr2.to_string());
t1.connect_async(&remote).await.unwrap();
tokio::time::sleep(Duration::from_millis(200)).await;
assert_eq!(
t1.connection_state_sync(&remote),
ConnectionState::Connected
);
t1.connect_async(&remote).await.unwrap();
t1.stop_async().await.unwrap();
t2.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_connect_async_then_send_recv() {
let (tx1, _rx1) = packet_channel(100);
let (tx2, mut rx2) = packet_channel(100);
let mut t1 = TcpTransport::new(TransportId::new(1), None, make_outbound_config(), tx1);
let mut t2 = TcpTransport::new(TransportId::new(2), None, make_config(), tx2);
t1.start_async().await.unwrap();
t2.start_async().await.unwrap();
let addr2 = t2.local_addr().unwrap();
let remote = TransportAddr::from_string(&addr2.to_string());
t1.connect_async(&remote).await.unwrap();
tokio::time::sleep(Duration::from_millis(200)).await;
assert_eq!(
t1.connection_state_sync(&remote),
ConnectionState::Connected
);
let mut msg1 = vec![0xAA; 114];
msg1[0] = 0x01;
msg1[1] = 0x00;
msg1[2..4].copy_from_slice(&110u16.to_le_bytes());
t1.send_async(&remote, &msg1).await.unwrap();
let packet = timeout(Duration::from_secs(2), rx2.recv())
.await
.expect("timeout")
.expect("channel closed");
assert_eq!(packet.data, msg1);
t1.stop_async().await.unwrap();
t2.stop_async().await.unwrap();
}
#[test]
fn test_connection_state_none_for_unknown() {
let (tx, _rx) = packet_channel(100);
let transport = TcpTransport::new(TransportId::new(1), None, make_config(), tx);
let state = transport.connection_state_sync(&TransportAddr::from_string("unknown:1234"));
assert_eq!(state, ConnectionState::None);
}
#[tokio::test]
async fn test_connect_ip_string() {
let (tx1, _rx1) = packet_channel(100);
let (tx2, mut rx2) = packet_channel(100);
let mut t1 = TcpTransport::new(TransportId::new(1), None, make_config(), tx1);
let mut t2 = TcpTransport::new(
TransportId::new(2),
None,
TcpConfig {
bind_addr: Some("127.0.0.1:0".to_string()),
..Default::default()
},
tx2,
);
t1.start_async().await.unwrap();
t2.start_async().await.unwrap();
let port2 = t2.local_addr().unwrap().port();
let addr = TransportAddr::from_string(&format!("127.0.0.1:{}", port2));
let mut frame = vec![0xAA; 114];
frame[0] = 0x01; frame[1] = 0x00; frame[2..4].copy_from_slice(&110u16.to_le_bytes()); t1.send_async(&addr, &frame).await.unwrap();
let packet = tokio::time::timeout(Duration::from_secs(5), rx2.recv())
.await
.expect("timeout")
.expect("channel closed");
assert_eq!(packet.data, frame);
t1.stop_async().await.unwrap();
t2.stop_async().await.unwrap();
}
#[tokio::test]
async fn test_connect_async_ip_string() {
let (tx1, _rx1) = packet_channel(100);
let (tx2, _rx2) = packet_channel(100);
let mut t1 = TcpTransport::new(TransportId::new(1), None, make_config(), tx1);
let mut t2 = TcpTransport::new(
TransportId::new(2),
None,
TcpConfig {
bind_addr: Some("127.0.0.1:0".to_string()),
..Default::default()
},
tx2,
);
t1.start_async().await.unwrap();
t2.start_async().await.unwrap();
let port2 = t2.local_addr().unwrap().port();
let addr = TransportAddr::from_string(&format!("127.0.0.1:{}", port2));
t1.connect_async(&addr).await.unwrap();
for _ in 0..50 {
let state = t1.connection_state_sync(&addr);
if state == ConnectionState::Connected {
break;
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
assert_eq!(t1.connection_state_sync(&addr), ConnectionState::Connected,);
t1.stop_async().await.unwrap();
t2.stop_async().await.unwrap();
}
}