use std::{future::Future, io, net::SocketAddr};
use async_trait::async_trait;
use futures_util::{FutureExt as _, SinkExt as _, StreamExt as _};
use socket2::{Domain, Socket, Type};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream},
};
use tokio_util::codec::Framed;
use crate::{
codec::rtu::ServerCodec,
frame::{
rtu::{RequestAdu, ResponseAdu},
ExceptionResponse, OptionalResponsePdu,
},
server::service::Service,
};
use super::Terminated;
#[async_trait]
pub trait BindSocket {
type Error;
async fn bind_socket(addr: SocketAddr) -> Result<Socket, Self::Error>;
}
pub fn accept_tcp_connection<S, NewService>(
stream: TcpStream,
socket_addr: SocketAddr,
new_service: NewService,
) -> io::Result<Option<(S, TcpStream)>>
where
S: Service + Send + Sync + 'static,
S::Request: From<RequestAdu<'static>> + Send,
NewService: Fn(SocketAddr) -> io::Result<Option<S>>,
{
let service = new_service(socket_addr)?;
Ok(service.map(|service| (service, stream)))
}
#[derive(Debug)]
pub struct Server {
listener: TcpListener,
}
impl Server {
#[must_use]
pub fn new(listener: TcpListener) -> Self {
Self { listener }
}
pub async fn serve<S, T, F, OnConnected, OnProcessError>(
&self,
on_connected: &OnConnected,
on_process_error: OnProcessError,
) -> io::Result<()>
where
S: Service + Send + Sync + 'static,
S::Request: From<RequestAdu<'static>> + Send,
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
OnConnected: Fn(TcpStream, SocketAddr) -> F,
F: Future<Output = io::Result<Option<(S, T)>>>,
OnProcessError: FnOnce(io::Error) + Clone + Send + 'static,
{
loop {
let (stream, socket_addr) = self.listener.accept().await?;
log::debug!("Accepted connection from {socket_addr}");
let Some((service, transport)) = on_connected(stream, socket_addr).await? else {
log::debug!("No service for connection from {socket_addr}");
continue;
};
let on_process_error = on_process_error.clone();
let framed = Framed::new(transport, ServerCodec::default());
tokio::spawn(async move {
log::debug!("Processing requests from {socket_addr}");
if let Err(err) = process(framed, service).await {
on_process_error(err);
}
});
}
}
pub async fn serve_until<S, T, F, X, OnConnected, OnProcessError>(
self,
on_connected: &OnConnected,
on_process_error: OnProcessError,
abort_signal: X,
) -> io::Result<Terminated>
where
S: Service + Send + Sync + 'static,
S::Request: From<RequestAdu<'static>> + Send,
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
X: Future<Output = ()> + Sync + Send + Unpin + 'static,
OnConnected: Fn(TcpStream, SocketAddr) -> F,
F: Future<Output = io::Result<Option<(S, T)>>>,
OnProcessError: FnOnce(io::Error) + Clone + Send + 'static,
{
let abort_signal = abort_signal.fuse();
tokio::select! {
res = self.serve(on_connected, on_process_error) => {
res.map(|()| Terminated::Finished)
},
() = abort_signal => {
Ok(Terminated::Aborted)
}
}
}
}
async fn process<S, T, Req>(mut framed: Framed<T, ServerCodec>, service: S) -> io::Result<()>
where
S: Service<Request = Req> + Send + Sync + 'static,
S::Request: From<RequestAdu<'static>> + Send,
T: AsyncRead + AsyncWrite + Unpin,
{
loop {
let Some(request) = framed.next().await.transpose()? else {
log::debug!("TCP socket has been closed");
break;
};
let fc = request.pdu.0.function_code();
let hdr = request.hdr;
let OptionalResponsePdu(Some(response_pdu)) = service
.call(request.into())
.await
.map_err(|e| ExceptionResponse {
function: fc,
exception: e,
})
.into()
else {
log::trace!("Sending no response for request {hdr:?}");
continue;
};
framed
.send(ResponseAdu {
hdr,
pdu: response_pdu,
})
.await?;
}
Ok(())
}
#[allow(unused)]
fn listener(addr: SocketAddr, workers: usize) -> io::Result<TcpListener> {
let listener = match addr {
SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
};
configure_tcp(workers, &listener)?;
listener.reuse_address()?;
listener.bind(&addr.into())?;
listener.listen(1024)?;
TcpListener::from_std(listener.into())
}
#[cfg(unix)]
#[allow(unused)]
fn configure_tcp(workers: usize, tcp: &Socket) -> io::Result<()> {
if workers > 1 {
tcp.reuse_port()?;
}
Ok(())
}
#[cfg(windows)]
#[allow(unused)]
fn configure_tcp(_workers: usize, _tcp: &Socket) -> io::Result<()> {
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{prelude::*, server::Service};
use std::sync::Arc;
use futures::future;
#[tokio::test]
async fn delegate_service_through_deref_for_server() {
#[derive(Clone)]
struct DummyService {
response: Response,
}
impl Service for DummyService {
type Request = Request<'static>;
type Future = future::Ready<Result<Response, Exception>>;
fn call(&self, _: Self::Request) -> Self::Future {
future::ready(Ok(self.response.clone()))
}
}
let service = Arc::new(DummyService {
response: Response::ReadInputRegisters(vec![0x33]),
});
let svc = |_socket_addr| Ok(Some(Arc::clone(&service)));
let on_connected =
|stream, socket_addr| async move { accept_tcp_connection(stream, socket_addr, svc) };
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
let listener = TcpListener::bind(addr).await.unwrap();
let server = Server::new(listener);
std::mem::drop(server.serve(&on_connected, |_err| {}));
}
#[tokio::test]
async fn service_wrapper() {
#[derive(Clone)]
struct DummyService {
response: Response,
}
impl Service for DummyService {
type Request = Request<'static>;
type Future = future::Ready<Result<Response, Exception>>;
fn call(&self, _: Self::Request) -> Self::Future {
future::ready(Ok(self.response.clone()))
}
}
let service = DummyService {
response: Response::ReadInputRegisters(vec![0x33]),
};
let pdu = Request::ReadInputRegisters(0, 1);
let rsp_adu = service.call(pdu).await.unwrap();
assert_eq!(rsp_adu, service.response);
}
}