use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use futures_util::StreamExt;
use rusty_modbus_frame::mbap::MbapCodec;
use tokio::net::TcpListener;
use tokio_util::codec::Framed;
use tracing::{debug, trace};
use crate::config::TcpServerConfig;
use crate::connect::{TcpRecvStream, TcpSink};
use crate::error::TransportError;
/// RAII handle for one slot of a [`TcpServerListener`]'s active-connection count.
///
/// Handed out by [`TcpServerListener::accept`] alongside the connection it
/// accounts for. Dropping the guard decrements the shared counter, freeing the
/// slot for a new connection, so it should live as long as the connection is
/// being served.
#[derive(Debug)]
#[must_use = "dropping a ConnectionGuard immediately releases its connection slot"]
pub struct ConnectionGuard {
    /// Shared counter that this guard decrements exactly once, on drop.
    counter: Arc<AtomicUsize>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        // Relaxed ordering: the counter is only a tally; no other data is
        // published or synchronized through it.
        self.counter.fetch_sub(1, Ordering::Relaxed);
    }
}
/// Modbus/TCP server socket.
///
/// Connections accepted through it are filtered by `config.access_control`,
/// capped at `config.max_connections`, and wrapped in an MBAP-framed
/// sink/stream pair.
pub struct TcpServerListener {
    /// Bound tokio listener that incoming connections are taken from.
    listener: TcpListener,
    /// Server settings: connection limit, access control, and per-socket TCP options.
    config: TcpServerConfig,
    /// Count of connections currently holding a slot. It is shared with every
    /// `ConnectionGuard` handed out, and each guard decrements it on drop.
    active_connections: Arc<AtomicUsize>,
}
impl TcpServerListener {
    /// Binds a Modbus/TCP listener on `addr` with the given server `config`.
    ///
    /// The active-connection counter starts at zero.
    ///
    /// # Errors
    ///
    /// Returns a [`TransportError`] if the socket cannot be bound or if its
    /// local address cannot be queried afterwards.
    pub async fn bind(addr: SocketAddr, config: TcpServerConfig) -> Result<Self, TransportError> {
        debug!(
            addr = %addr,
            max_connections = config.max_connections,
            "binding TCP Modbus listener"
        );
        let listener = TcpListener::bind(addr).await?;
        debug!(addr = %listener.local_addr()?, "TCP Modbus listener bound");
        Ok(Self {
            listener,
            config,
            active_connections: Arc::new(AtomicUsize::new(0)),
        })
    }

    /// Waits for the next admissible connection and returns it MBAP-framed.
    ///
    /// Some peers are dropped, which closes their socket, and the call keeps
    /// waiting for the next one. This happens when:
    /// - `config.access_control` denies the peer, or
    /// - `config.max_connections` slots are already in use.
    ///
    /// On success, this returns:
    /// - the write half, wrapped in a [`TcpSink`] with the configured write timeout;
    /// - the read half, wrapped in a [`TcpRecvStream`] with the configured read timeout;
    /// - the peer address;
    /// - a [`ConnectionGuard`] that holds one connection slot until it is dropped.
    ///
    /// # Errors
    ///
    /// Returns a [`TransportError`] in either of these cases:
    /// - accepting from the underlying listener fails;
    /// - `TCP_NODELAY` cannot be applied to the accepted socket.
    ///
    /// In the second case, the slot reserved for that connection is released
    /// before returning.
    pub async fn accept(
        &self,
    ) -> Result<(TcpSink, TcpRecvStream, SocketAddr, ConnectionGuard), TransportError> {
        let max_connections = self.config.max_connections;
        loop {
            let (stream, addr) = self.listener.accept().await?;
            trace!(peer_addr = %addr, "accepted TCP connection");
            if let Some(ref ac) = self.config.access_control
                && !ac.is_allowed(&addr.ip())
            {
                debug!(peer_addr = %addr, "dropping TCP connection denied by access control");
                continue;
            }

            // Check the limit and claim a slot in one atomic read-modify-write.
            // `accept` takes `&self`, so it may run concurrently. A separate
            // load + fetch_add would let two callers both pass the check and
            // overshoot `max_connections`.
            let reservation = self.active_connections.fetch_update(
                Ordering::Relaxed,
                Ordering::Relaxed,
                |n| if n < max_connections { Some(n + 1) } else { None },
            );
            let previous = match reservation {
                Ok(previous) => previous,
                Err(current) => {
                    debug!(
                        peer_addr = %addr,
                        active_connections = current,
                        max_connections = max_connections,
                        "dropping TCP connection over limit"
                    );
                    continue;
                }
            };

            // Create the guard right after claiming the slot. Every later early
            // return, such as `set_nodelay` failing, then releases the slot
            // instead of leaking it.
            let guard = ConnectionGuard {
                counter: Arc::clone(&self.active_connections),
            };
            trace!(
                peer_addr = %addr,
                active_connections = previous + 1,
                "tracking accepted TCP connection"
            );

            stream.set_nodelay(self.config.tcp.tcp_nodelay)?;
            let framed = Framed::new(stream, MbapCodec);
            let (sink, recv_stream) = framed.split();
            let sink = TcpSink::new(sink, self.config.tcp.write_timeout);
            let recv = TcpRecvStream::new(recv_stream, self.config.tcp.read_timeout);
            return Ok((sink, recv, addr, guard));
        }
    }

    /// Returns the local address the listener is bound to.
    ///
    /// This is useful after binding to port 0, to learn the assigned port.
    ///
    /// # Errors
    ///
    /// Returns a [`TransportError`] if the OS cannot report the socket's local address.
    pub fn local_addr(&self) -> Result<SocketAddr, TransportError> {
        Ok(self.listener.local_addr()?)
    }

    /// Returns a shared handle to the active-connection counter.
    ///
    /// This handle points at the same counter that [`ConnectionGuard`]s
    /// decrement on drop, so it reflects the current number of connections
    /// holding a slot.
    #[must_use]
    pub fn connection_counter(&self) -> Arc<AtomicUsize> {
        Arc::clone(&self.active_connections)
    }
}