use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use dashmap::DashMap;
use futures::stream::StreamExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{debug, error, warn};
use crate::core::message_codec::MAX_FRAME_SIZE;
use crate::{Error, Message, MessageCodec, Peer, PeerId, RateLimiter, Result};
/// Number of times `Tcp::connect` looks for a freshly dialled peer in the peer
/// table before giving up with a `TimedOut` connection error.
const PEER_REGISTRATION_MAX_ATTEMPTS: u32 = 50;
/// Delay between two of those lookups, in milliseconds; together with
/// `PEER_REGISTRATION_MAX_ATTEMPTS` this bounds the wait to about 500 ms.
const PEER_REGISTRATION_CHECK_DELAY_MS: u64 = 10;
/// TCP transport that exchanges `Message`s with peers identified by socket address.
pub struct Tcp {
    /// Connected peers, shared with the accept loop and every connection task;
    /// an entry is added when a connection starts and removed when it ends.
    peers: Arc<DashMap<SocketAddr, Peer>>,
    /// Sending half of the inbound queue; each connection's reader pushes
    /// `(sender address, message)` pairs into it.
    message_tx: UnboundedSender<(SocketAddr, Message)>,
    /// Receiving half of the inbound queue, drained by `Tcp::recv`.
    message_rx: Arc<Mutex<UnboundedReceiver<(SocketAddr, Message)>>>,
    /// Address the listener is bound to; `None` until `Tcp::listen` succeeds.
    local_addr: Option<SocketAddr>,
    /// Optional limiter consulted for every inbound message.
    rate_limiter: Option<Arc<Mutex<RateLimiter>>>,
    /// Maximum frame size handed to `MessageCodec` for each connection.
    max_message_size: usize,
}
impl Tcp {
pub fn new() -> Self {
Self::with_max_message_size(MAX_FRAME_SIZE) }
pub fn with_max_message_size(max_message_size: usize) -> Self {
let (message_tx, message_rx) = mpsc::unbounded_channel();
Self {
local_addr: None,
peers: Arc::new(DashMap::new()),
message_rx: Arc::new(Mutex::new(message_rx)),
message_tx,
rate_limiter: None,
max_message_size,
}
}
pub fn set_rate_limit(mut self, capacity: u32, refill_rate: u32) -> Self {
self.rate_limiter = Some(Arc::new(Mutex::new(RateLimiter::with_params(
capacity,
refill_rate,
))));
self
}
pub async fn listen(&mut self, addr: SocketAddr) -> Result<()> {
let listener = TcpListener::bind(addr)
.await
.map_err(|e| Error::Connection { addr, source: e })?;
let local_addr = listener.local_addr().map_err(Error::Io)?;
self.local_addr = Some(local_addr);
debug!("TCP transport listening on {local_addr}");
let peers = Arc::clone(&self.peers);
let message_tx = self.message_tx.clone();
let rate_limiter = self.rate_limiter.clone();
let max_message_size = self.max_message_size;
tokio::spawn(async move {
loop {
match listener.accept().await {
Ok((stream, peer_addr)) => {
debug!("Accepted connection from {peer_addr}");
Self::handle_connection(
stream,
peer_addr,
Arc::clone(&peers),
message_tx.clone(),
rate_limiter.clone(),
max_message_size,
);
}
Err(e) => {
error!("Failed to accept connection: {e}");
}
}
}
});
Ok(())
}
pub async fn connect(&self, addr: SocketAddr) -> Result<()> {
if self.peers.contains_key(&addr) {
debug!("Already connected to {addr}");
return Ok(());
}
let stream = TcpStream::connect(addr)
.await
.map_err(|e| Error::Connection { addr, source: e })?;
debug!("TCP connection established to {addr}");
Self::handle_connection(
stream,
addr,
Arc::clone(&self.peers),
self.message_tx.clone(),
self.rate_limiter.clone(),
self.max_message_size,
);
for _ in 0..PEER_REGISTRATION_MAX_ATTEMPTS {
if self.peers.contains_key(&addr) {
return Ok(());
}
tokio::time::sleep(tokio::time::Duration::from_millis(
PEER_REGISTRATION_CHECK_DELAY_MS,
))
.await;
}
Err(Error::Connection {
addr,
source: std::io::Error::new(std::io::ErrorKind::TimedOut, "peer registration timeout"),
})
}
pub async fn send(&self, peer: SocketAddr, message: Message) -> Result<()> {
let data = bincode::serde::encode_to_vec(&message, bincode::config::standard())?;
if let Some(conn) = self.peers.get_mut(&peer).as_deref_mut() {
conn.send(Bytes::from(data))
.map_err(|err| Error::Channel(err.to_string()))?;
Ok(())
} else {
Err(Error::PeerNotFound(peer))
}
}
pub async fn recv(&self) -> Result<(SocketAddr, Message)> {
self.message_rx
.lock()
.await
.recv()
.await
.ok_or(Error::Channel("Channel recv error".to_string()))
}
pub fn local_addr(&self) -> Option<SocketAddr> {
self.local_addr
}
pub fn peers(&self) -> Vec<SocketAddr> {
self.peers.iter().map(|entry| *entry.key()).collect()
}
fn handle_connection(
stream: TcpStream,
peer_addr: SocketAddr,
peers: Arc<DashMap<SocketAddr, Peer>>,
message_tx: UnboundedSender<(SocketAddr, Message)>,
rate_limiter: Option<Arc<Mutex<RateLimiter>>>,
max_message_size: usize,
) {
tokio::spawn(async move {
let (reader, writer) = stream.into_split();
let (tx, mut rx) = mpsc::unbounded_channel::<Bytes>();
let peer = Peer::new(PeerId(peer_addr), tx);
peers.insert(peer_addr, peer);
let codec = MessageCodec::with_max_frame_size(max_message_size);
let write_task = {
let mut sink = FramedWrite::new(writer, codec.clone());
tokio::spawn(async move {
while let Some(data) = rx.recv().await {
match bincode::serde::decode_from_slice::<Message, _>(
&data,
bincode::config::standard(),
) {
Ok((message, _)) => {
if let Err(e) = futures::SinkExt::send(&mut sink, message).await {
error!("Failed to send to {peer_addr}: {e}");
break;
}
}
Err(e) => {
error!("Failed to deserialize outgoing message: {e}");
}
}
}
})
};
let read_task = {
let mut stream = FramedRead::new(reader, codec);
tokio::spawn(async move {
while let Some(result) = stream.next().await {
match result {
Ok(message) => {
if let Some(ref limiter) = rate_limiter {
let allowed = limiter.lock().await.allow_request(peer_addr);
if !allowed {
warn!(
"Rate limit exceeded for peer {peer_addr}, dropping message"
);
continue;
}
}
if message_tx.send((peer_addr, message)).is_err() {
warn!("Message channel closed");
break;
}
}
Err(e) => {
error!("Error reading from {peer_addr}: {e}");
break;
}
}
}
})
};
tokio::select! {
_ = read_task => {},
_ = write_task => {},
}
peers.remove(&peer_addr);
debug!("Connection closed: {peer_addr}");
});
}
}
impl Default for Tcp {
fn default() -> Self {
Self::new()
}
}