use core::future::poll_fn;
use core::ops::Deref;
use core::sync::atomic::{AtomicUsize, Ordering};
use core::time::Duration;
use std::fmt::Debug;
use std::io;
use std::net::SocketAddr;
use std::string::{String, ToString};
use std::sync::{Arc, Mutex};
use arc_swap::ArcSwap;
use octseq::Octets;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::sync::watch;
use tokio::time::{interval, timeout, MissedTickBehavior};
use tracing::{error, trace, trace_span, warn};
use crate::net::server::buf::BufSource;
use crate::net::server::error::Error;
use crate::net::server::metrics::ServerMetrics;
use crate::net::server::service::Service;
use crate::net::server::sock::AsyncAccept;
use crate::utils::config::DefMinMax;
use super::buf::VecBufSource;
use super::connection::{self, Connection};
use super::ServerCommand;
/// A [`StreamServer`] that accepts connections from a tokio [`TcpListener`]
/// and obtains its buffers from a [`VecBufSource`].
pub type TcpServer<Svc> = StreamServer<TcpListener, VecBufSource, Svc>;
/// Limits for [`Config::max_concurrent_connections`].
///
/// Per the `DefMinMax` name, the arguments are the default (100), the
/// minimum (1) and the maximum (100,000) — confirm against `DefMinMax::new`.
const MAX_CONCURRENT_TCP_CONNECTIONS: DefMinMax<usize> =
DefMinMax::new(100, 1, 100000);
/// Configuration for a [`StreamServer`].
///
/// A server's configuration can be replaced at runtime with
/// [`StreamServer::reconfigure`].
pub struct Config {
/// Number of concurrent connections at which the server considers itself
/// at its limit.
max_concurrent_connections: usize,
/// Whether new connections are still accepted (and then dropped) while the
/// server is at its connection limit.
accept_connections_at_max: bool,
/// Settings copied into every connection handler the server spawns.
pub(super) connection_config: connection::Config,
}
impl Config {
pub fn new() -> Self {
Default::default()
}
pub fn set_accept_connections_at_max(&mut self, value: bool) {
self.accept_connections_at_max = value;
}
pub fn accept_connections_at_max(&self) -> bool {
self.accept_connections_at_max
}
pub fn set_max_concurrent_connections(&mut self, value: usize) {
self.max_concurrent_connections = value;
}
pub fn max_concurrent_connections(&self) -> usize {
self.max_concurrent_connections
}
pub fn set_connection_config(
&mut self,
connection_config: connection::Config,
) {
self.connection_config = connection_config;
}
pub fn connection_config(&self) -> &connection::Config {
&self.connection_config
}
}
impl Default for Config {
    /// Builds the default configuration: the default connection limit from
    /// `MAX_CONCURRENT_TCP_CONNECTIONS`, connections still accepted (and
    /// dropped) at that limit, and the default per-connection settings.
    fn default() -> Self {
        let max_concurrent_connections =
            MAX_CONCURRENT_TCP_CONNECTIONS.default();
        let connection_config: connection::Config = Default::default();

        Self {
            max_concurrent_connections,
            accept_connections_at_max: true,
            connection_config,
        }
    }
}
impl Clone for Config {
    /// Returns an independent copy of every setting.
    fn clone(&self) -> Self {
        // All fields are `Copy`, so destructuring `*self` copies them out.
        let Self {
            max_concurrent_connections,
            accept_connections_at_max,
            connection_config,
        } = *self;

        Self {
            max_concurrent_connections,
            accept_connections_at_max,
            connection_config,
        }
    }
}
/// Command broadcast from the server to its accept loop and connections.
type ServerCommandType = ServerCommand<Config>;
/// Sending half of the command channel, shared behind a mutex so `&self`
/// methods such as `reconfigure` and `shutdown` can send.
type CommandSender = Arc<Mutex<watch::Sender<ServerCommandType>>>;
/// Receiving half of the command channel; cloned for the accept loop and
/// for every connection handler.
type CommandReceiver = watch::Receiver<ServerCommandType>;
/// A server that accepts stream connections from a `Listener` and runs a
/// connection handler for each one.
///
/// The trait bounds live on the struct because the `Drop` impl must repeat
/// exactly the struct's bounds.
pub struct StreamServer<Listener, Buf, Svc>
where
Listener: AsyncAccept + Send + Sync,
Buf: BufSource + Send + Sync + Clone,
Buf::Output: Octets + Send + Sync + Unpin,
Svc: Service<Buf::Output, ()> + Clone,
{
/// Current configuration; swapped atomically on `Reconfigure`.
config: Arc<ArcSwap<Config>>,
/// Command receiver that the accept loop and connection handlers clone.
command_rx: CommandReceiver,
/// Command sender used by `reconfigure` and `shutdown`.
command_tx: CommandSender,
/// The listener connections are accepted from.
listener: Arc<Listener>,
/// Buffer source cloned into every connection.
buf: Buf,
/// Service cloned into every connection.
service: Svc,
/// Optional hook run on each accepted stream before its connection
/// handler is created.
pre_connect_hook: Option<fn(&mut Listener::StreamType)>,
/// Counter used to number connections in their trace spans.
connection_idx: AtomicUsize,
/// Metrics shared with every connection.
metrics: Arc<ServerMetrics>,
}
impl<Listener, Buf, Svc> StreamServer<Listener, Buf, Svc>
where
    Listener: AsyncAccept + Send + Sync,
    Buf: BufSource + Send + Sync + Clone,
    Buf::Output: Octets + Send + Sync + Unpin,
    Svc: Service<Buf::Output, ()> + Clone,
{
    /// Creates a server with the default [`Config`].
    ///
    /// `buf` and `service` are cloned into each connection handler.
    #[must_use]
    pub fn new(listener: Listener, buf: Buf, service: Svc) -> Self {
        let default_config = Config::default();
        Self::with_config(listener, buf, service, default_config)
    }

    /// Creates a server with the given [`Config`].
    ///
    /// The command channel starts out holding `ServerCommand::Init`. Every
    /// receiver used later is cloned from the one stored here.
    #[must_use]
    pub fn with_config(
        listener: Listener,
        buf: Buf,
        service: Svc,
        config: Config,
    ) -> Self {
        let (sender, receiver) = watch::channel(ServerCommand::Init);

        Self {
            config: Arc::new(ArcSwap::from_pointee(config)),
            command_rx: receiver,
            command_tx: Arc::new(Mutex::new(sender)),
            listener: Arc::new(listener),
            buf,
            service,
            pre_connect_hook: None,
            connection_idx: AtomicUsize::new(0),
            metrics: Arc::new(ServerMetrics::connection_oriented()),
        }
    }

    /// Installs a hook that is called with each accepted stream before its
    /// connection handler is created.
    #[must_use]
    pub fn with_pre_connect_hook(
        self,
        pre_connect_hook: fn(&mut Listener::StreamType),
    ) -> Self {
        let mut server = self;
        server.pre_connect_hook = Some(pre_connect_hook);
        server
    }
}
// NOTE(review): neither accessor below needs `Buf::Output: Debug`; the bound
// only narrows when they can be called. Dropping it would loosen the API
// (backward-compatible) but would leave the `Debug` import unused.
impl<Listener, Buf, Svc> StreamServer<Listener, Buf, Svc>
where
    Listener: AsyncAccept + Send + Sync,
    Buf: BufSource + Send + Sync + Clone,
    Buf::Output: Octets + Debug + Send + Sync + Unpin,
    Svc: Service<Buf::Output, ()> + Clone,
{
    /// Returns a shared handle to the listener this server accepts from.
    #[must_use]
    pub fn source(&self) -> Arc<Listener> {
        Arc::clone(&self.listener)
    }

    /// Returns a shared handle to the server's metrics.
    #[must_use]
    pub fn metrics(&self) -> Arc<ServerMetrics> {
        Arc::clone(&self.metrics)
    }
}
impl<Listener, Buf, Svc> StreamServer<Listener, Buf, Svc>
where
Listener: AsyncAccept + Send + Sync,
Buf: BufSource + Send + Sync + Clone,
Buf::Output: Octets + Send + Sync + Unpin,
Svc: Service<Buf::Output, ()> + Clone,
{
pub async fn run(&self)
where
Buf: 'static,
Buf::Output: 'static,
Listener::Error: Send,
Listener::Future: Send + 'static,
Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static,
{
if let Err(err) = self.run_until_error().await {
error!("Server stopped due to error: {err}");
}
}
pub fn reconfigure(&self, config: Config) -> Result<(), Error> {
self.command_tx
.lock()
.map_err(|_| Error::CommandCouldNotBeSent)?
.send(ServerCommand::Reconfigure(config))
.map_err(|_| Error::CommandCouldNotBeSent)
}
pub fn shutdown(&self) -> Result<(), Error> {
self.command_tx
.lock()
.map_err(|_| Error::CommandCouldNotBeSent)?
.send(ServerCommand::Shutdown)
.map_err(|_| Error::CommandCouldNotBeSent)
}
pub fn is_shutdown(&self) -> bool {
self.metrics.num_inflight_requests() == 0
&& self.metrics.num_pending_writes() == 0
}
pub async fn await_shutdown(&self, duration: Duration) -> bool {
timeout(duration, async {
let mut interval = interval(Duration::from_millis(100));
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
while !self.is_shutdown() {
interval.tick().await;
}
})
.await
.is_ok()
}
}
impl<Listener, Buf, Svc> StreamServer<Listener, Buf, Svc>
where
Listener: AsyncAccept + Send + Sync,
Buf: BufSource + Send + Sync + Clone,
Buf::Output: Octets + Send + Sync + Unpin,
Svc: Service<Buf::Output, ()> + Clone,
{
async fn run_until_error(&self) -> Result<(), String>
where
Buf: 'static,
Buf::Output: 'static,
Listener::Error: Send,
Listener::Future: Send + 'static,
Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static,
{
let mut command_rx = self.command_rx.clone();
loop {
tokio::select! {
biased;
res = command_rx.changed() => {
self.process_server_command(res, &mut command_rx)?;
}
accept_res = self.accept(), if self.accepting_connections() => {
match accept_res {
Ok((stream, addr)) if !self.at_connection_limit() => {
self.spawn_connection_handler(stream, addr);
}
Ok(_) => {
warn!("Connection limit reached: dropping accepted connection");
}
Err(err) => {
error!("Error while accepting TCP connection: {err}");
}
}
}
}
}
}
fn at_connection_limit(&self) -> bool {
let config = ArcSwap::load(&self.config);
let num_conn = self.metrics.num_connections();
num_conn >= config.max_concurrent_connections()
}
fn accepting_connections(&self) -> bool {
if self.at_connection_limit() {
let config = ArcSwap::load(&self.config);
config.accept_connections_at_max
} else {
true
}
}
fn process_server_command(
&self,
res: Result<(), watch::error::RecvError>,
command_rx: &mut watch::Receiver<ServerCommand<Config>>,
) -> Result<(), String> {
res.map_err(|err| format!("Error while receiving command: {err}"))?;
let lock = command_rx.borrow_and_update();
let command = lock.deref();
match command {
ServerCommand::Reconfigure(new_config) => {
self.config.store(Arc::new(new_config.clone()));
}
ServerCommand::Shutdown => {
return Err("Shutdown command received".to_string());
}
ServerCommand::Init => {
unreachable!()
}
ServerCommand::CloseConnection => {
unreachable!()
}
}
Ok(())
}
fn spawn_connection_handler(
&self,
stream: Listener::Future,
addr: SocketAddr,
) where
Buf: 'static,
Buf::Output: Octets + 'static,
Listener::Error: Send,
Listener::Future: Send + 'static,
Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static,
{
let config = ArcSwap::load(&self.config);
let conn_config = config.connection_config;
let conn_command_rx = self.command_rx.clone();
let conn_service = self.service.clone();
let conn_buf = self.buf.clone();
let conn_metrics = self.metrics.clone();
let pre_connect_hook = self.pre_connect_hook;
let new_connection_idx =
self.connection_idx.fetch_add(1, Ordering::SeqCst);
trace!("Spawning new connection handler.");
tokio::spawn(async move {
let span = trace_span!("stream", conn = new_connection_idx);
let _guard = span.enter();
trace!("Accepting connection.");
if let Ok(mut stream) = stream.await {
trace!("Connection accepted.");
if let Some(hook) = pre_connect_hook {
trace!("Running pre-connect hook.");
hook(&mut stream);
}
let conn = Connection::with_config(
conn_service,
conn_buf,
conn_metrics,
stream,
addr,
conn_config,
);
trace!("Starting connection handler.");
conn.run(conn_command_rx).await;
trace!("Connection handler terminated.");
}
});
}
async fn accept(
&self,
) -> Result<(Listener::Future, SocketAddr), io::Error> {
poll_fn(|ctx| self.listener.poll_accept(ctx)).await
}
}
impl<Listener, Buf, Svc> Drop for StreamServer<Listener, Buf, Svc>
where
    Listener: AsyncAccept + Send + Sync,
    Buf: BufSource + Send + Sync + Clone,
    Buf::Output: Octets + Send + Sync + Unpin,
    Svc: Service<Buf::Output, ()> + Clone,
{
    /// Sends a shutdown command when the server is dropped.
    ///
    /// The accept loop and every connection handler receive commands through
    /// receivers cloned from this server's channel, so they all see it. A
    /// send failure is deliberately ignored because there is no caller left
    /// to report it to.
    fn drop(&mut self) {
        _ = self.shutdown();
    }
}