use crate::transport::legs::TransportLeg;
use async_trait::async_trait;
use bytes::{Buf, Bytes, BytesMut};
use std::io;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU8, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
/// A transport leg that exchanges length-prefixed frames over one TCP connection.
///
/// All state sits behind async mutexes or atomics, so every method takes `&self`.
pub struct TcpLeg {
/// The underlying socket; `None` when the leg was built without a connection
/// or after `close` has taken the stream out.
stream: Mutex<Option<TcpStream>>,
/// Peer address supplied at construction; `None` for an unconnected leg.
remote_addr: Option<SocketAddr>,
/// Current round-trip-time estimate in milliseconds.
rtt_ms: AtomicU32,
// The field is initialised by every constructor but never read in the code
// visible here (`TransportLeg::loss_percent` returns a constant), hence the
// lint suppression.
#[allow(dead_code)]
loss_percent: AtomicU8,
/// Whether `send`/`recv` should be attempted on this leg.
available: AtomicBool,
/// Bytes received from the socket that have not yet been returned as a frame.
read_buf: Mutex<BytesMut>,
}
impl TcpLeg {
pub fn new() -> Self {
Self {
stream: Mutex::new(None),
remote_addr: None,
rtt_ms: AtomicU32::new(100), loss_percent: AtomicU8::new(0),
available: AtomicBool::new(false),
read_buf: Mutex::new(BytesMut::with_capacity(16384)),
}
}
pub async fn connect(addr: SocketAddr) -> io::Result<Self> {
let start = std::time::Instant::now();
let stream = TcpStream::connect(addr).await?;
let rtt = start.elapsed().as_millis() as u32;
stream.set_nodelay(true)?;
log::debug!("TCP connected to {} (RTT {}ms)", addr, rtt);
Ok(Self {
stream: Mutex::new(Some(stream)),
remote_addr: Some(addr),
rtt_ms: AtomicU32::new(rtt),
loss_percent: AtomicU8::new(0),
available: AtomicBool::new(true),
read_buf: Mutex::new(BytesMut::with_capacity(16384)),
})
}
pub fn from_stream(stream: TcpStream, addr: SocketAddr) -> Self {
let _ = stream.set_nodelay(true);
Self {
stream: Mutex::new(Some(stream)),
remote_addr: Some(addr),
rtt_ms: AtomicU32::new(100),
loss_percent: AtomicU8::new(0),
available: AtomicBool::new(true),
read_buf: Mutex::new(BytesMut::with_capacity(16384)),
}
}
pub fn update_rtt(&self, sample_ms: u32) {
let current = self.rtt_ms.load(Ordering::Relaxed);
let new_rtt = (7 * current + sample_ms) / 8;
self.rtt_ms.store(new_rtt, Ordering::Relaxed);
}
async fn read_framed(&self) -> io::Result<Bytes> {
let mut stream_guard = self.stream.lock().await;
let stream = stream_guard
.as_mut()
.ok_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Not connected"))?;
let mut read_buf = self.read_buf.lock().await;
while read_buf.len() < 4 {
let mut temp = [0u8; 4096];
let n = stream.read(&mut temp).await?;
if n == 0 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Connection closed",
));
}
read_buf.extend_from_slice(&temp[..n]);
}
let length =
u32::from_be_bytes([read_buf[0], read_buf[1], read_buf[2], read_buf[3]]) as usize;
if length > 10 * 1024 * 1024 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Message too large",
));
}
while read_buf.len() < 4 + length {
let mut temp = [0u8; 4096];
let n = stream.read(&mut temp).await?;
if n == 0 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Connection closed",
));
}
read_buf.extend_from_slice(&temp[..n]);
}
read_buf.advance(4);
let data = read_buf.split_to(length).freeze();
Ok(data)
}
async fn write_framed(&self, data: &[u8]) -> io::Result<()> {
let mut stream_guard = self.stream.lock().await;
let stream = stream_guard
.as_mut()
.ok_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Not connected"))?;
let length = data.len() as u32;
stream.write_all(&length.to_be_bytes()).await?;
stream.write_all(data).await?;
stream.flush().await?;
Ok(())
}
}
impl Default for TcpLeg {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl TransportLeg for TcpLeg {
    /// Sends `data` as a single frame, or fails with `NotConnected` when the
    /// leg is not marked available.
    async fn send(&self, data: Bytes) -> io::Result<()> {
        if self.is_available() {
            self.write_framed(&data).await
        } else {
            Err(io::Error::new(
                io::ErrorKind::NotConnected,
                "TCP not connected",
            ))
        }
    }

    /// Receives the next frame, or fails with `NotConnected` when the leg is
    /// not marked available.
    async fn recv(&self) -> io::Result<Bytes> {
        if self.is_available() {
            self.read_framed().await
        } else {
            Err(io::Error::new(
                io::ErrorKind::NotConnected,
                "TCP not connected",
            ))
        }
    }

    /// Reports the current value of the availability flag.
    fn is_available(&self) -> bool {
        self.available.load(Ordering::Relaxed)
    }

    /// Reports the current round-trip-time estimate in milliseconds.
    fn rtt_ms(&self) -> u32 {
        self.rtt_ms.load(Ordering::Relaxed)
    }

    /// Always reports zero; the stored `loss_percent` field is not consulted.
    fn loss_percent(&self) -> u8 {
        0
    }

    /// Returns the peer address recorded at construction, if any.
    fn remote_addr(&self) -> Option<SocketAddr> {
        self.remote_addr
    }

    /// Marks the leg unavailable and drops the stream, if one is present.
    ///
    /// Always returns `Ok(())`.
    async fn close(&self) -> io::Result<()> {
        self.available.store(false, Ordering::Relaxed);
        let mut slot = self.stream.lock().await;
        // Overwriting the slot drops any held `TcpStream` while the lock is held.
        *slot = None;
        drop(slot);
        log::info!("TCP closed");
        Ok(())
    }
}
/// Debug output shows the remote address, RTT estimate and availability flag;
/// the stream and read buffer are omitted.
impl std::fmt::Debug for TcpLeg {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rtt = self.rtt_ms.load(Ordering::Relaxed);
        let available = self.is_available();
        let mut out = f.debug_struct("TcpLeg");
        out.field("remote", &self.remote_addr);
        out.field("rtt_ms", &rtt);
        out.field("available", &available);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A leg built without a socket reports itself unavailable with zero loss.
    #[test]
    fn test_tcp_leg_creation() {
        let unconnected = TcpLeg::new();
        let available = unconnected.is_available();
        let loss = unconnected.loss_percent();
        assert!(!available);
        assert_eq!(loss, 0);
    }
}