use mbus_core::transport::UnitIdOrSlaveAddr;
use mbus_network::TokioTcpTransport;
use std::convert::Infallible;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::{TcpListener, ToSocketAddrs};
use tokio::sync::Mutex;
use super::app_handler::{AsyncAppHandler, AsyncServerError};
use super::session::AsyncServerSession;
/// TCP front end that accepts connections and wraps each one in an
/// [`AsyncServerSession`] over a [`TokioTcpTransport`].
///
/// Build one with [`AsyncTcpServer::bind`] and pull sessions out with
/// [`AsyncTcpServer::accept`], or use one of the `serve*` helpers, which
/// run the accept loop for you.
pub struct AsyncTcpServer {
    /// Listening socket that incoming connections are accepted from.
    listener: TcpListener,
    /// Unit id / slave address value handed unchanged to every session
    /// created by this server.
    unit: UnitIdOrSlaveAddr,
}
impl AsyncTcpServer {
    /// Binds to `addr` and serves connections until accepting one fails.
    ///
    /// Each accepted connection receives its own clone of `app` and is run
    /// on a separate detached Tokio task, so the accept loop never waits for
    /// a session to finish.
    ///
    /// # Errors
    ///
    /// Returns the error from [`Self::bind`] if the listener cannot be
    /// created, or the error from [`Self::accept`] the first time accepting
    /// a connection fails. The `Ok` type is [`Infallible`]: the only way this
    /// function returns is with an error.
    pub async fn serve<APP, A>(
        addr: A,
        app: APP,
        unit: UnitIdOrSlaveAddr,
    ) -> Result<Infallible, AsyncServerError>
    where
        A: ToSocketAddrs,
        APP: AsyncAppHandler + Clone,
    {
        let server = Self::bind(addr, unit).await?;
        loop {
            // NOTE(review): any accept failure ends the loop and therefore
            // the whole server; errors of already-running sessions do not.
            let (session, _) = server.accept().await?;
            Self::spawn_session(session, app.clone());
        }
    }

    /// Same as [`Self::serve`], but with a handler shared between sessions.
    ///
    /// The per-connection "clone" made by [`Self::serve`] is a clone of the
    /// `Arc`, so every session ends up with a handle to the same `APP`
    /// behind the mutex. This relies on `Arc<Mutex<APP>>` itself
    /// implementing [`AsyncAppHandler`]; that implementation lives outside
    /// this block.
    ///
    /// # Errors
    ///
    /// Exactly the errors of [`Self::serve`].
    pub async fn serve_shared<APP, A>(
        addr: A,
        app: Arc<Mutex<APP>>,
        unit: UnitIdOrSlaveAddr,
    ) -> Result<Infallible, AsyncServerError>
    where
        A: ToSocketAddrs,
        APP: AsyncAppHandler,
    {
        // Spell out the handler type: the shared pointer, not `APP` itself.
        Self::serve::<Arc<Mutex<APP>>, A>(addr, app, unit).await
    }

    /// Binds to `addr` and serves connections until `shutdown` completes.
    ///
    /// Works like [`Self::serve`], except that the accept loop also watches
    /// the `shutdown` future and returns `Ok(())` once it resolves. Session
    /// tasks that were already spawned are detached; this function neither
    /// awaits nor cancels them.
    ///
    /// # Errors
    ///
    /// Returns the error from [`Self::bind`] if the listener cannot be
    /// created, or the error from [`Self::accept`] the first time accepting
    /// a connection fails.
    pub async fn serve_with_shutdown<APP, A, F>(
        addr: A,
        app: APP,
        unit: UnitIdOrSlaveAddr,
        shutdown: F,
    ) -> Result<(), AsyncServerError>
    where
        A: ToSocketAddrs,
        APP: AsyncAppHandler + Clone,
        F: Future<Output = ()>,
    {
        let server = Self::bind(addr, unit).await?;
        // Pinned once so the same future can be polled on every iteration.
        tokio::pin!(shutdown);
        loop {
            let accepted = tokio::select! {
                // `biased` polls the branches top to bottom, so a finished
                // shutdown future always wins over a ready connection.
                biased;
                _ = &mut shutdown => return Ok(()),
                outcome = server.accept() => outcome,
            };
            let (session, _) = accepted?;
            Self::spawn_session(session, app.clone());
        }
    }

    /// Creates a server listening on `addr` whose sessions use `unit`.
    ///
    /// # Errors
    ///
    /// Returns [`AsyncServerError::BindFailed`] carrying the I/O error when
    /// the listener cannot be bound.
    pub async fn bind<A>(addr: A, unit: UnitIdOrSlaveAddr) -> Result<Self, AsyncServerError>
    where
        A: ToSocketAddrs,
    {
        match TcpListener::bind(addr).await {
            Ok(listener) => Ok(Self { listener, unit }),
            Err(io_err) => Err(AsyncServerError::BindFailed(io_err)),
        }
    }

    /// Waits for the next incoming connection and wraps it in a session.
    ///
    /// Returns the new session together with the peer's socket address.
    /// `TCP_NODELAY` is requested on the accepted stream on a best-effort
    /// basis: a failure to set it is ignored.
    ///
    /// # Errors
    ///
    /// Returns [`AsyncServerError::ConnectionClosed`] for every error the
    /// listener reports; the underlying I/O error is discarded.
    pub async fn accept(
        &self,
    ) -> Result<(AsyncServerSession<TokioTcpTransport>, SocketAddr), AsyncServerError> {
        let (stream, peer) = match self.listener.accept().await {
            Ok(connection) => connection,
            Err(_) => return Err(AsyncServerError::ConnectionClosed),
        };
        // Best effort only; the connection is still usable if this fails.
        let _ = stream.set_nodelay(true);
        let session = AsyncServerSession::new(TokioTcpTransport::from_stream(stream), self.unit);
        Ok((session, peer))
    }

    /// Returns the local address the listener is bound to.
    ///
    /// # Errors
    ///
    /// Returns [`AsyncServerError::Transport`] wrapping `MbusError::IoError`
    /// when the address cannot be queried; the underlying I/O error is
    /// discarded.
    pub fn local_addr(&self) -> Result<SocketAddr, AsyncServerError> {
        match self.listener.local_addr() {
            Ok(addr) => Ok(addr),
            Err(_) => Err(AsyncServerError::Transport(
                mbus_core::errors::MbusError::IoError,
            )),
        }
    }

    /// Runs `session` against `app` on a detached Tokio task.
    ///
    /// The task's join handle is dropped and the result of the session run
    /// is discarded, so a failing session never reaches the accept loop.
    fn spawn_session<APP>(mut session: AsyncServerSession<TokioTcpTransport>, app: APP)
    where
        APP: AsyncAppHandler,
    {
        tokio::spawn(async move {
            let mut handler = app;
            let _ = session.run(&mut handler).await;
        });
    }
}