use core::fmt::Debug;
use core::future::poll_fn;
use core::future::Future;
use core::ops::Deref;
use core::pin::Pin;
use core::time::Duration;
use std::boxed::Box;
use std::io;
use std::net::SocketAddr;
use std::string::String;
use std::string::ToString;
use std::sync::{Arc, Mutex};
use arc_swap::ArcSwap;
use log::{log_enabled, Level};
use octseq::Octets;
use tokio::io::ReadBuf;
use tokio::net::UdpSocket;
use tokio::sync::watch;
use tokio::time::interval;
use tokio::time::timeout;
use tokio::time::Instant;
use tokio::time::MissedTickBehavior;
use tracing::{error, trace, warn};
use crate::base::iana::OptRcode;
use crate::base::message_builder::AdditionalBuilder;
use crate::base::wire::Composer;
use crate::base::Message;
use crate::base::StreamTarget;
use crate::net::server::buf::BufSource;
use crate::net::server::error::Error;
use crate::net::server::message::Request;
use crate::net::server::metrics::ServerMetrics;
use crate::net::server::service::Service;
use crate::net::server::sock::AsyncDgramSock;
use crate::net::server::util::mk_error_response;
use crate::net::server::util::to_pcap_text;
use crate::utils::config::DefMinMax;
use super::buf::VecBufSource;
use super::invoker::{InvokerStatus, ServiceInvoker};
use super::message::{TransportSpecificContext, UdpTransportContext};
use super::ServerCommand;
/// A [`DgramServer`] that talks over a Tokio [`UdpSocket`] and takes its
/// receive buffers from a [`VecBufSource`].
pub type UdpServer<Svc> = DgramServer<UdpSocket, VecBufSource, Svc>;
/// Write timeout settings; [`Config`] takes its default write timeout from
/// here.
///
/// NOTE(review): going by the `DefMinMax` name the three values are
/// presumably the default (5s), the minimum (1ms) and the maximum (60s), in
/// that order -- confirm against the `DefMinMax` definition.
const WRITE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
Duration::from_secs(5),
Duration::from_millis(1),
Duration::from_secs(60),
);
/// Response size settings; [`Config`] takes its default from here and passes
/// caller supplied values through `limit()`.
///
/// NOTE(review): presumably default 1232, minimum 512 and maximum 4096, in
/// bytes -- confirm against the `DefMinMax` definition.
const MAX_RESPONSE_SIZE: DefMinMax<u16> = DefMinMax::new(1232, 512, 4096);
/// Settings for a [`DgramServer`].
#[derive(Debug)]
pub struct Config {
/// Response size limit handed to the UDP transport context of every
/// request; `None` means no limit is supplied.
max_response_size: Option<u16>,
/// How long a single response write may take before it is given up on
/// with a timed-out error.
write_timeout: Duration,
}
impl Config {
pub fn new() -> Self {
Default::default()
}
pub fn set_max_response_size(&mut self, value: Option<u16>) {
self.max_response_size = value.map(|v| MAX_RESPONSE_SIZE.limit(v));
}
pub fn set_write_timeout(&mut self, value: Duration) {
self.write_timeout = value;
}
}
impl Default for Config {
fn default() -> Self {
Self {
max_response_size: Some(MAX_RESPONSE_SIZE.default()),
write_timeout: WRITE_TIMEOUT.default(),
}
}
}
impl Clone for Config {
    /// Returns a field-by-field copy of this configuration.
    fn clone(&self) -> Self {
        // Destructuring makes the compiler flag this impl if a field is ever
        // added to `Config` without being copied here.
        let Self {
            max_response_size,
            write_timeout,
        } = self;

        Self {
            max_response_size: *max_response_size,
            write_timeout: *write_timeout,
        }
    }
}
/// The server command type, specialised to this module's [`Config`].
type ServerCommandType = ServerCommand<Config>;
/// Sending half of the command watch channel, wrapped in `Arc<Mutex<..>>`.
type CommandSender = Arc<Mutex<watch::Sender<ServerCommandType>>>;
/// Receiving half of the command watch channel.
type CommandReceiver = watch::Receiver<ServerCommandType>;
/// A server that receives request messages from a datagram socket, hands
/// each one to a service on its own Tokio task and writes the responses
/// back to the sender's address.
///
/// `Sock` is the socket type, `Buf` supplies the receive buffers and `Svc`
/// is the service that handles the parsed requests.
pub struct DgramServer<Sock, Buf, Svc>
where
Sock: AsyncDgramSock + Send + Sync + 'static,
Buf: BufSource + Send + Sync,
<Buf as BufSource>::Output: Octets + Send + Sync + Unpin + 'static,
Svc: Service<<Buf as BufSource>::Output, ()> + Clone,
{
/// Current configuration; shared with the response handler and swapped
/// when a `Reconfigure` command is processed.
config: Arc<ArcSwap<Config>>,
/// Receiving half of the command channel; cloned by the run loop.
command_rx: CommandReceiver,
/// Sending half of the command channel, used by `reconfigure()` and
/// `shutdown()`.
command_tx: CommandSender,
/// The socket requests are read from and responses are written to.
sock: Arc<Sock>,
/// Source of a fresh buffer for every receive attempt.
buf: Buf,
/// The service; cloned for every request that is dispatched.
service: Svc,
/// Counters for received requests, pending writes and sent responses.
metrics: Arc<ServerMetrics>,
/// Prototype response handler; cloned for every spawned request task.
request_dispatcher: ServiceResponseHandler<Sock>,
}
impl<Sock, Buf, Svc> DgramServer<Sock, Buf, Svc>
where
    Sock: AsyncDgramSock + Send + Sync,
    Buf: BufSource + Send + Sync,
    <Buf as BufSource>::Output: Octets + Send + Sync + Unpin,
    Svc: Service<<Buf as BufSource>::Output, ()> + Clone,
{
    /// Creates a server around `sock`, `buf` and `service` using the default
    /// [`Config`].
    #[must_use]
    pub fn new(sock: Sock, buf: Buf, service: Svc) -> Self {
        let config = Config::default();
        Self::with_config(sock, buf, service, config)
    }

    /// Creates a server around `sock`, `buf` and `service` using the given
    /// `config`.
    ///
    /// The socket, the configuration and a fresh set of connection-less
    /// metrics are each wrapped in an `Arc` so that they can be shared with
    /// the response handler. The command channel starts out holding
    /// `ServerCommand::Init`.
    #[must_use]
    pub fn with_config(
        sock: Sock,
        buf: Buf,
        service: Svc,
        config: Config,
    ) -> Self {
        // State shared between the server and its response handler.
        let sock = Arc::new(sock);
        let config = Arc::new(ArcSwap::from_pointee(config));
        let metrics = Arc::new(ServerMetrics::connection_less());

        let request_dispatcher = ServiceResponseHandler::new(
            Arc::clone(&config),
            Arc::clone(&sock),
            Arc::clone(&metrics),
        );

        let (tx, command_rx) = watch::channel(ServerCommand::Init);

        Self {
            config,
            command_rx,
            command_tx: Arc::new(Mutex::new(tx)),
            sock,
            buf,
            service,
            metrics,
            request_dispatcher,
        }
    }
}
impl<Sock, Buf, Svc> DgramServer<Sock, Buf, Svc>
where
    Sock: AsyncDgramSock + Send + Sync,
    Buf: BufSource + Send + Sync,
    <Buf as BufSource>::Output: Octets + Send + Sync + Unpin,
    Svc: Service<<Buf as BufSource>::Output, ()> + Clone,
{
    /// Returns another handle to the socket this server reads from.
    #[must_use]
    pub fn source(&self) -> Arc<Sock> {
        Arc::clone(&self.sock)
    }

    /// Returns another handle to the metrics this server updates.
    #[must_use]
    pub fn metrics(&self) -> Arc<ServerMetrics> {
        Arc::clone(&self.metrics)
    }
}
impl<Sock, Buf, Svc> DgramServer<Sock, Buf, Svc>
where
Sock: AsyncDgramSock + Send + Sync + 'static,
Buf: BufSource + Send + Sync,
<Buf as BufSource>::Output: Octets + Send + Sync + 'static + Unpin,
Svc: Service<<Buf as BufSource>::Output, ()> + Clone,
{
pub async fn run(&self) {
if let Err(err) = self.run_until_error().await {
error!("Server stopped due to error: {err}");
}
}
pub fn reconfigure(&self, config: Config) -> Result<(), Error> {
self.command_tx
.lock()
.map_err(|_| Error::CommandCouldNotBeSent)?
.send(ServerCommand::Reconfigure(config))
.map_err(|_| Error::CommandCouldNotBeSent)
}
pub fn shutdown(&self) -> Result<(), Error> {
self.command_tx
.lock()
.map_err(|_| Error::CommandCouldNotBeSent)?
.send(ServerCommand::Shutdown)
.map_err(|_| Error::CommandCouldNotBeSent)
}
pub fn is_shutdown(&self) -> bool {
self.metrics.num_inflight_requests() == 0
&& self.metrics.num_pending_writes() == 0
}
pub async fn await_shutdown(&self, duration: Duration) -> bool {
timeout(duration, async {
let mut interval = interval(Duration::from_millis(100));
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
while !self.is_shutdown() {
interval.tick().await;
}
})
.await
.is_ok()
}
}
impl<Sock, Buf, Svc> DgramServer<Sock, Buf, Svc>
where
    Sock: AsyncDgramSock + Send + Sync,
    Buf: BufSource + Send + Sync,
    <Buf as BufSource>::Output: Octets + Send + Sync + Unpin,
    Svc: Service<<Buf as BufSource>::Output, ()> + Clone,
{
    /// Receives and processes server commands and incoming messages until
    /// something makes the loop stop.
    ///
    /// # Errors
    ///
    /// Returns a description of why the loop stopped: a failure of the
    /// command channel, a received shutdown command, or a socket receive
    /// error other than `WouldBlock`.
    async fn run_until_error(&self) -> Result<(), String> {
        let mut command_rx = self.command_rx.clone();

        loop {
            tokio::select! {
                // Poll the arms in the order written instead of randomly, so
                // that pending commands are handled before more messages
                // are read.
                biased;

                res = command_rx.changed() => {
                    self.process_server_command(res, &mut command_rx)?;
                }

                _ = self.sock.readable() => {
                    let (buf, addr, bytes_read) = match self.recv_from() {
                        Ok(res) => res,
                        // Nothing could be read after all; go back to
                        // waiting.
                        Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue,
                        Err(err) => return Err(format!("Error while receiving message: {err}")),
                    };
                    self.process_received_message(buf, addr, bytes_read);
                }
            }
        }
    }

    /// Acts on the command currently held by the command channel.
    ///
    /// `res` is the outcome of waiting for the channel to change and
    /// `command_rx` is the receiver that was waited on.
    ///
    /// # Errors
    ///
    /// Returns an error message if waiting on the channel failed or if the
    /// command is `Shutdown`; either makes the run loop stop.
    ///
    /// # Panics
    ///
    /// Panics if the command is `Init`, the value the channel is created
    /// with, which this code treats as never being observed here.
    fn process_server_command(
        &self,
        res: Result<(), watch::error::RecvError>,
        command_rx: &mut CommandReceiver,
    ) -> Result<(), String> {
        res.map_err(|err| format!("Error while receiving command: {err}"))?;

        // Mark the value as seen so that `changed()` waits for the next
        // command.
        let lock = command_rx.borrow_and_update();
        let command = lock.deref();

        match command {
            ServerCommand::Init => {
                unreachable!()
            }
            ServerCommand::CloseConnection => {
                // There is nothing for this server to do for this command.
            }
            ServerCommand::Reconfigure(new_config) => {
                self.config.store(Arc::new(new_config.clone()));
            }
            ServerCommand::Shutdown => {
                return Err("Shutdown command received".to_string());
            }
        }

        Ok(())
    }

    /// Parses a received datagram and, if it is a query, spawns a task that
    /// passes it to the service.
    ///
    /// `buf` holds the datagram, `addr` is the address it came from and
    /// `bytes_read` is the number of bytes that were received.
    ///
    /// Datagrams that fail to parse are logged and dropped. Messages that
    /// have the QR flag set are dropped without any answer.
    fn process_received_message(
        &self,
        buf: <Buf as BufSource>::Output,
        addr: SocketAddr,
        bytes_read: usize,
    ) {
        let received_at = Instant::now();
        self.metrics.inc_num_received_requests();

        if log_enabled!(Level::Trace) {
            let pcap_text = to_pcap_text(&buf, bytes_read);
            trace!(%addr, pcap_text, "Received message");
        }

        match Message::from_octets(buf) {
            Err(err) => {
                warn!("Failed while parsing request message: {err}");
            }

            // The QR flag marks the message as a response rather than a
            // query. Never answer a response, not even with an error: doing
            // so lets two servers bounce error messages off each other
            // indefinitely and turns this server into a reflector for
            // datagrams with a spoofed source address.
            Ok(msg) if msg.header().qr() => {
                trace!("Ignoring received message because it is a reply, not a query.");
            }

            Ok(msg) => {
                let ctx = UdpTransportContext::new(
                    self.config.load().max_response_size,
                );
                let ctx = TransportSpecificContext::Udp(ctx);
                let request = Request::new(addr, received_at, msg, ctx, ());

                trace!(
                    "Spawning task to handle new message with id {}",
                    request.message().header().id()
                );

                // Each request gets its own clone of the handler and of the
                // service so that the receive loop is not held up.
                let mut dispatcher = self.request_dispatcher.clone();
                let service = self.service.clone();
                tokio::spawn(async move {
                    dispatcher.dispatch(request, service, addr).await
                });
            }
        }
    }

    /// Makes one non-blocking attempt to receive a datagram into a fresh
    /// buffer.
    ///
    /// On success returns the buffer, the sender's address and the number
    /// of bytes received.
    ///
    /// # Errors
    ///
    /// Passes on whatever error the socket reports.
    fn recv_from(
        &self,
    ) -> Result<(Buf::Output, SocketAddr, usize), io::Error> {
        let mut msg = self.buf.create_buf();
        let mut buf = ReadBuf::new(msg.as_mut());
        self.sock
            .try_recv_buf_from(&mut buf)
            .map(|(bytes_read, addr)| (msg, addr, bytes_read))
    }
}
impl<Sock, Buf, Svc> Drop for DgramServer<Sock, Buf, Svc>
where
    Sock: AsyncDgramSock + Send + Sync + 'static,
    Buf: BufSource + Send + Sync,
    <Buf as BufSource>::Output: Octets + Send + Sync + Unpin + 'static,
    Svc: Service<<Buf as BufSource>::Output, ()> + Clone,
{
    /// Sends a shutdown command when the server is dropped.
    fn drop(&mut self) {
        // Best effort only: there is nobody left to report a failure to.
        match self.shutdown() {
            Ok(()) | Err(_) => {}
        }
    }
}
/// Writes responses to the datagram socket on behalf of the service
/// invoker.
struct ServiceResponseHandler<Sock> {
/// Configuration shared with the server; the write timeout is read from it
/// for every response.
config: Arc<ArcSwap<Config>>,
/// The socket responses are written to.
sock: Arc<Sock>,
/// Metrics updated around every response write.
metrics: Arc<ServerMetrics>,
/// Invoker status; every new handler, including each clone, starts out
/// as `Normal`.
status: InvokerStatus,
}
impl<Sock> ServiceResponseHandler<Sock>
where
    Sock: AsyncDgramSock + Send + Sync + 'static,
{
    /// Creates a handler that writes to `sock`, reads its write timeout from
    /// `config` and records its activity in `metrics`.
    ///
    /// The handler starts out with the `Normal` invoker status.
    fn new(
        config: Arc<ArcSwap<Config>>,
        sock: Arc<Sock>,
        metrics: Arc<ServerMetrics>,
    ) -> Self {
        Self {
            config,
            sock,
            metrics,
            status: InvokerStatus::Normal,
        }
    }

    /// Finishes `response` and writes it to `addr` as a single datagram.
    ///
    /// A failed write is logged as a warning and otherwise ignored. The
    /// pending-writes metric is raised for the duration of the write.
    async fn send_response<Target: Composer>(
        &self,
        addr: SocketAddr,
        response: AdditionalBuilder<StreamTarget<Target>>,
    ) {
        let target = response.finish();
        let bytes = target.as_dgram_slice();

        if log_enabled!(Level::Trace) {
            let pcap_text = to_pcap_text(bytes, bytes.len());
            trace!(%addr, pcap_text, "Sending {} bytes of response to {addr}", bytes.len());
        }

        self.metrics.inc_num_pending_writes();

        // Read the timeout for every write so that a reconfiguration takes
        // effect for responses that are sent after it.
        let write_timeout = self.config.load().write_timeout;
        if let Err(err) =
            Self::write_to_network(&self.sock, bytes, &addr, write_timeout)
                .await
        {
            warn!(%addr, "Failed to send response: {err}");
        }

        self.metrics.dec_num_pending_writes();
        // NOTE(review): this also counts responses whose write failed or
        // timed out -- confirm that this is the intended meaning of the
        // metric.
        self.metrics.inc_num_sent_responses();
    }

    /// Sends `data` to `dest` over `sock`, waiting at most `limit`.
    ///
    /// # Errors
    ///
    /// Returns a `TimedOut` error if the send did not complete within
    /// `limit`, the socket's own error if the send failed, and a
    /// "short send" error if fewer bytes than `data.len()` were sent.
    async fn write_to_network(
        sock: &Sock,
        data: &[u8],
        dest: &SocketAddr,
        limit: Duration,
    ) -> Result<(), io::Error> {
        let send_res =
            timeout(limit, poll_fn(|ctx| sock.poll_send_to(ctx, data, dest)))
                .await;

        let Ok(send_res) = send_res else {
            return Err(io::ErrorKind::TimedOut.into());
        };

        let sent = send_res?;

        if sent != data.len() {
            Err(io::Error::other("short send"))
        } else {
            Ok(())
        }
    }
}
impl<Sock> Clone for ServiceResponseHandler<Sock> {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
sock: self.sock.clone(),
metrics: self.metrics.clone(),
status: InvokerStatus::Normal,
}
}
}
impl<Sock, RequestOctets, Svc, RequestMeta>
    ServiceInvoker<RequestOctets, Svc, RequestMeta, SocketAddr>
    for ServiceResponseHandler<Sock>
where
    RequestOctets: Octets + Send + Sync + 'static,
    RequestMeta: Clone + Default + Send + 'static,
    Sock: AsyncDgramSock + Send + Sync + 'static,
    Svc: Service<RequestOctets, RequestMeta> + Clone,
    Svc::Target: 'static,
{
    /// Returns the current invoker status.
    fn status(&self) -> InvokerStatus {
        self.status
    }

    /// Replaces the current invoker status.
    fn set_status(&mut self, status: InvokerStatus) {
        self.status = status;
    }

    /// Does nothing: this handler makes no use of an idle timeout.
    fn reconfigure(&self, _idle_timeout: Option<Duration>) {}

    /// Returns a boxed future that sends `response` to `addr`.
    fn enqueue_response<'a>(
        &'a self,
        response: AdditionalBuilder<StreamTarget<Svc::Target>>,
        addr: &'a SocketAddr,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
        let dest = *addr;
        Box::pin(async move { self.send_response(dest, response).await })
    }
}