use core::future::Future;
use core::ops::{ControlFlow, Deref};
use core::pin::Pin;
use core::sync::atomic::{AtomicBool, Ordering};
use core::time::Duration;
use std::boxed::Box;
use std::fmt::Display;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use arc_swap::ArcSwap;
use log::{log_enabled, Level};
use octseq::Octets;
use tokio::io::{
AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf,
};
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::{mpsc, watch};
use tokio::time::Instant;
use tokio::time::{sleep_until, timeout};
use tracing::{debug, error, trace, warn};
use crate::base::iana::OptRcode;
use crate::base::message_builder::AdditionalBuilder;
use crate::base::{Message, StreamTarget};
use crate::net::server::buf::BufSource;
use crate::net::server::message::Request;
use crate::net::server::metrics::ServerMetrics;
use crate::net::server::service::Service;
use crate::net::server::util::{mk_error_response, to_pcap_text};
use crate::utils::config::DefMinMax;
use super::invoker::{InvokerStatus, ServiceInvoker};
use super::message::{NonUdpTransportContext, TransportSpecificContext};
use super::stream::Config as ServerConfig;
use super::ServerCommand;
/// Default, minimum and maximum allowed idle timeout: how long a connection
/// may go without receiving a complete message before it is disconnected.
const IDLE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(30),
    Duration::from_millis(200),
    Duration::from_secs(30 * 24 * 60 * 60),
);

/// Default, minimum and maximum allowed time for a single response write to
/// the stream to complete before the connection is dropped.
const RESPONSE_WRITE_TIMEOUT: DefMinMax<Duration> = DefMinMax::new(
    Duration::from_secs(30),
    Duration::from_millis(1),
    Duration::from_secs(60 * 60),
);

/// Default, minimum and maximum allowed number of responses that may be
/// queued awaiting writing to the stream.
const MAX_QUEUED_RESPONSES: DefMinMax<usize> = DefMinMax::new(10, 0, 1024);
/// Configuration for a single stream server connection.
///
/// Note: `Clone` is implemented manually further down in this file, which
/// satisfies the `Copy` derive's `Clone` supertrait requirement.
#[derive(Copy, Debug)]
pub struct Config {
    /// How long the connection may be idle before it is timed out.
    idle_timeout: Duration,

    /// How long a single response write may take before it is aborted.
    response_write_timeout: Duration,

    /// Maximum number of responses that may be queued for writing.
    max_queued_responses: usize,
}
impl Config {
    /// Creates a new config with default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets how long the connection may sit idle before being timed out.
    pub fn set_idle_timeout(&mut self, value: Duration) {
        self.idle_timeout = value;
    }

    /// Sets how long a single response write may take before being aborted.
    pub fn set_response_write_timeout(&mut self, value: Duration) {
        self.response_write_timeout = value;
    }

    /// Sets the maximum number of responses that may be queued for writing.
    pub fn set_max_queued_responses(&mut self, value: usize) {
        self.max_queued_responses = value;
    }
}
impl Default for Config {
fn default() -> Self {
Self {
idle_timeout: IDLE_TIMEOUT.default(),
response_write_timeout: RESPONSE_WRITE_TIMEOUT.default(),
max_queued_responses: MAX_QUEUED_RESPONSES.default(),
}
}
}
impl Clone for Config {
    /// `Config` is `Copy`, so cloning is a plain bitwise copy.
    fn clone(&self) -> Self {
        Self { ..*self }
    }
}
/// State for handling a single stream-based (e.g. TCP) server connection.
pub struct Connection<Stream, Buf, Svc, RequestMeta>
where
    RequestMeta: Default + Clone + Send + 'static,
    Buf: BufSource,
    Buf::Output: Send + Sync + Unpin,
    Svc: Service<Buf::Output, RequestMeta> + Clone,
{
    /// Whether this connection is counted in the connection metrics.
    /// Set by `run()`; checked and cleared in `Drop` so the metric is
    /// decremented exactly once.
    active: bool,

    /// Source of buffers into which request bytes are read.
    buf: Buf,

    /// Hot-swappable connection configuration.
    config: Arc<ArcSwap<Config>>,

    /// The peer address of this connection.
    addr: SocketAddr,

    /// Read half of the stream; taken (exactly once) by `run_until_error()`.
    stream_rx: Option<ReadHalf<Stream>>,

    /// Write half of the stream, used to send responses.
    stream_tx: WriteHalf<Stream>,

    /// Receiving end of the queue of responses awaiting writing.
    result_q_rx: mpsc::Receiver<AdditionalBuilder<StreamTarget<Svc::Target>>>,

    /// Sending end of the response queue; cloned into request dispatchers.
    result_q_tx: mpsc::Sender<AdditionalBuilder<StreamTarget<Svc::Target>>>,

    /// The service that produces responses to requests.
    service: Svc,

    /// Tracks the last activity time for idle timeout enforcement.
    idle_timer: IdleTimer,

    /// While true, idle timeout expiry does not disconnect the connection
    /// (see `process_dns_idle_timeout()`).
    // NOTE(review): only loaded in this chunk; presumably set elsewhere when
    // a multi-response transaction is in progress — confirm.
    in_transaction: Arc<AtomicBool>,

    /// Server-wide metrics.
    metrics: Arc<ServerMetrics>,

    /// Handler that invokes the service for requests and queues responses.
    request_dispatcher: ServiceResponseHandler<Buf::Output, Svc, RequestMeta>,
}
impl<Stream, Buf, Svc, RequestMeta> Connection<Stream, Buf, Svc, RequestMeta>
where
    RequestMeta: Default + Clone + Send + 'static,
    Stream: AsyncRead + AsyncWrite,
    Buf: BufSource + Clone + Send + Sync,
    Buf::Output: Octets + Send + Sync + Unpin,
    Svc: Service<Buf::Output, RequestMeta> + Clone,
{
    /// Creates a connection handler with the default [`Config`].
    #[must_use]
    #[allow(dead_code)]
    pub fn new(
        service: Svc,
        buf: Buf,
        metrics: Arc<ServerMetrics>,
        stream: Stream,
        addr: SocketAddr,
    ) -> Self {
        Self::with_config(
            service,
            buf,
            metrics,
            stream,
            addr,
            Config::default(),
        )
    }

    /// Creates a connection handler with the given [`Config`].
    #[must_use]
    pub fn with_config(
        service: Svc,
        buf: Buf,
        metrics: Arc<ServerMetrics>,
        stream: Stream,
        addr: SocketAddr,
        config: Config,
    ) -> Self {
        // Split the stream so reading and writing can proceed independently.
        let (read_half, write_half) = tokio::io::split(stream);

        // The queue over which responses travel from dispatcher tasks to the
        // connection's write loop.
        let (result_q_tx, result_q_rx) =
            mpsc::channel(config.max_queued_responses);

        // Wrap the config so it can be hot-swapped on reconfiguration.
        let config = Arc::new(ArcSwap::from_pointee(config));

        let request_dispatcher = ServiceResponseHandler::new(
            config.clone(),
            result_q_tx.clone(),
            metrics.clone(),
        );

        Self {
            active: false,
            buf,
            config,
            addr,
            stream_rx: Some(read_half),
            stream_tx: write_half,
            result_q_rx,
            result_q_tx,
            service,
            idle_timer: IdleTimer::new(),
            in_transaction: Arc::new(AtomicBool::new(false)),
            metrics,
            request_dispatcher,
        }
    }
}
impl<Stream, Buf, Svc, RequestMeta> Connection<Stream, Buf, Svc, RequestMeta>
where
    RequestMeta: Default + Clone + Send + 'static,
    Stream: AsyncRead + AsyncWrite + Send + Sync + 'static,
    Buf: BufSource + Send + Sync + Clone + 'static,
    Buf::Output: Octets + Send + Sync + Unpin,
    Svc: Service<Buf::Output, RequestMeta> + Clone,
{
    /// Runs the connection until it terminates.
    ///
    /// Registers the connection in the metrics and drives the event loop.
    /// The metric is decremented again by `Drop`, guarded by `active`.
    pub async fn run(
        mut self,
        command_rx: watch::Receiver<ServerCommand<ServerConfig>>,
    ) {
        self.metrics.inc_num_connections();
        // Mark active BEFORE the loop so Drop decrements the metric
        // exactly once, however the loop exits.
        self.active = true;
        self.run_until_error(command_rx).await;
    }
}
impl<Stream, Buf, Svc, RequestMeta> Connection<Stream, Buf, Svc, RequestMeta>
where
RequestMeta: Default + Clone + Send + 'static,
Stream: AsyncRead + AsyncWrite + Send + Sync + 'static,
Buf: BufSource + Send + Sync + Clone + 'static,
Buf::Output: Octets + Send + Sync + Unpin,
Svc: Service<Buf::Output, RequestMeta> + Clone,
{
async fn run_until_error(
mut self,
mut command_rx: watch::Receiver<ServerCommand<ServerConfig>>,
) {
let stream_rx = self.stream_rx.take().unwrap();
let mut dns_msg_receiver =
DnsMessageReceiver::new(self.buf.clone(), stream_rx);
'outer: loop {
let msg_recv = dns_msg_receiver.recv();
tokio::pin!(msg_recv);
'inner: loop {
let res = tokio::select! {
biased;
res = command_rx.changed() => {
self.process_server_command(res, &mut command_rx)
}
res = self.result_q_rx.recv() => {
self.process_queued_result(res).await
}
_ = sleep_until(self.idle_timer.idle_timeout_deadline(self.config.load().idle_timeout)) => {
self.process_dns_idle_timeout(self.config.load().idle_timeout)
}
res = &mut msg_recv => {
let res = self.process_read_request(res).await;
if res.is_ok() {
break 'inner;
} else {
res
}
}
};
if let Err(err) = res {
match err {
ConnectionEvent::DisconnectWithoutFlush => {
break 'outer;
}
ConnectionEvent::DisconnectWithFlush => {
self.flush_write_queue().await;
break 'outer;
}
}
}
}
}
trace!("Shutting down the write stream.");
if let Err(err) = self.stream_tx.shutdown().await {
warn!("Error while shutting down the write stream: {err}");
}
trace!("Connection terminated.");
#[cfg(test)]
if dns_msg_receiver.cancelled() {
panic!("Async not-cancel-safe code was cancelled");
}
}
fn process_server_command(
&mut self,
res: Result<(), watch::error::RecvError>,
command_rx: &mut watch::Receiver<ServerCommand<ServerConfig>>,
) -> Result<(), ConnectionEvent> {
res.map_err(|_err| ConnectionEvent::DisconnectWithFlush)?;
let lock = command_rx.borrow_and_update();
let command = lock.deref();
match command {
ServerCommand::Init => {
unreachable!()
}
ServerCommand::CloseConnection => {
return Err(ConnectionEvent::DisconnectWithFlush);
}
ServerCommand::Reconfigure(ServerConfig {
connection_config,
.. }) => {
self.config.store(Arc::new(*connection_config));
}
ServerCommand::Shutdown => {
return Err(ConnectionEvent::DisconnectWithFlush);
}
}
Ok(())
}
async fn flush_write_queue(&mut self) {
debug!("Flushing connection write queue.");
trace!("Stop queueing up new results.");
self.result_q_rx.close();
trace!("Process already queued results.");
while let Some(response) = self.result_q_rx.recv().await {
trace!("Processing queued result.");
if let Err(err) = self.process_queued_result(Some(response)).await
{
warn!("Error while processing queued result: {err}");
} else {
trace!("Result processed");
}
}
debug!("Connection write queue flush complete.");
}
async fn process_queued_result(
&mut self,
response: Option<AdditionalBuilder<StreamTarget<Svc::Target>>>,
) -> Result<(), ConnectionEvent> {
let Some(response) = response else {
trace!("Disconnecting due to failed response queue read.");
return Err(ConnectionEvent::DisconnectWithFlush);
};
trace!(
"Writing queued response with id {} to stream",
response.header().id()
);
self.write_response_to_stream(response.finish()).await
}
async fn write_response_to_stream(
&mut self,
msg: StreamTarget<Svc::Target>,
) -> Result<(), ConnectionEvent> {
if log_enabled!(Level::Trace) {
let bytes = msg.as_dgram_slice();
let pcap_text = to_pcap_text(bytes, bytes.len());
trace!(addr = %self.addr, pcap_text, "Sending {} bytes of response tp {}", self.addr, bytes.len());
}
match timeout(
self.config.load().response_write_timeout,
self.stream_tx.write_all(msg.as_stream_slice()),
)
.await
{
Err(_) => {
error!(
"Write timed out (>{:?})",
self.config.load().response_write_timeout
);
return Err(ConnectionEvent::DisconnectWithoutFlush);
}
Ok(Err(err)) => {
error!("Write error: {err}");
return Err(ConnectionEvent::DisconnectWithoutFlush);
}
Ok(Ok(_)) => {
self.metrics.inc_num_sent_responses();
}
}
self.metrics.dec_num_pending_writes();
if self.result_q_tx.capacity() == self.result_q_tx.max_capacity() {
self.idle_timer.response_queue_emptied();
}
Ok(())
}
fn process_dns_idle_timeout(
&self,
timeout: Duration,
) -> Result<(), ConnectionEvent> {
if self.idle_timer.idle_timeout_expired(timeout)
&& !self.in_transaction.load(Ordering::SeqCst)
{
trace!("Timing out idle connection");
Err(ConnectionEvent::DisconnectWithoutFlush)
} else {
Ok(())
}
}
async fn process_read_request(
&mut self,
res: Result<Buf::Output, ConnectionEvent>,
) -> Result<(), ConnectionEvent>
where
Svc::Stream: Send,
Svc::Target: Default,
{
match res {
Ok(buf) => {
let received_at = Instant::now();
if log_enabled!(Level::Trace) {
let pcap_text = to_pcap_text(&buf, buf.as_ref().len());
trace!(addr = %self.addr, pcap_text, "Received message");
}
self.metrics.inc_num_received_requests();
self.idle_timer.full_msg_received();
match Message::from_octets(buf) {
Err(err) => {
tracing::warn!(
"Failed while parsing request message: {err}"
);
return Err(ConnectionEvent::DisconnectWithoutFlush);
}
Ok(msg) if msg.header().qr() => {
trace!("Ignoring received message because it is a reply, not a query.");
let response =
mk_error_response::<Buf::Output, Svc::Target>(
&msg,
OptRcode::FORMERR,
);
let dispatcher = self.request_dispatcher.clone();
tokio::spawn(async move {
dispatcher.do_enqueue_response(response).await;
});
}
Ok(msg) => {
let ctx = NonUdpTransportContext::new(Some(
self.config.load().idle_timeout,
));
let request = Request::new(
self.addr,
received_at,
msg,
TransportSpecificContext::NonUdp(ctx),
Default::default(),
);
trace!(
"Spawning task to handle new message with id {}",
request.message().header().id()
);
let mut dispatcher = self.request_dispatcher.clone();
let service = self.service.clone();
tokio::spawn(async move {
dispatcher.dispatch(request, service, ()).await
});
}
}
Ok(())
}
Err(err) => Err(err),
}
}
}
impl<Stream, Buf, Svc, RequestMeta> Drop
    for Connection<Stream, Buf, Svc, RequestMeta>
where
    RequestMeta: Default + Clone + Send + 'static,
    Buf: BufSource,
    Buf::Output: Send + Sync + Unpin,
    Svc: Service<Buf::Output, RequestMeta> + Clone,
{
    /// Removes the connection from the metrics, at most once.
    fn drop(&mut self) {
        if !self.active {
            return;
        }
        self.active = false;
        self.metrics.dec_num_connections();
    }
}
/// Receive-progress states of [`DnsMessageReceiver`], used (in test builds)
/// to detect cancellation of non-cancel-safe code.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum Status {
    /// No receive attempted yet.
    New,
    /// Waiting for the 2-byte message length prefix.
    WaitingForMessageHeader,
    /// Length known; waiting for the message body bytes.
    WaitingForMessageBody,
    /// A complete message has been received.
    MessageReceived,
}
/// Reads length-prefixed DNS messages from the read half of a stream.
struct DnsMessageReceiver<Stream, Buf> {
    /// Buffer for the 2-byte big-endian message length prefix.
    msg_size_buf: [u8; 2],
    /// Source of appropriately sized buffers for message bodies.
    buf: Buf,
    /// The read half of the connection stream.
    stream_rx: ReadHalf<Stream>,
    /// Current receive progress.
    status: Status,
    /// Set when `recv()` was cancelled mid-message (test builds only).
    #[cfg(test)]
    cancelled: bool,
}
impl<Stream, Buf> DnsMessageReceiver<Stream, Buf>
where
    Stream: AsyncRead + AsyncWrite + Send + Sync + 'static,
    Buf: BufSource + Send + Sync + 'static + Clone,
    Buf::Output: Send + Sync + 'static,
{
    /// Creates a receiver reading from the given stream half, allocating
    /// message buffers from `buf`.
    fn new(buf: Buf, stream_rx: ReadHalf<Stream>) -> Self {
        Self {
            msg_size_buf: [0; 2],
            buf,
            stream_rx,
            status: Status::New,
            #[cfg(test)]
            cancelled: false,
        }
    }

    /// Whether a previous `recv()` was cancelled mid-message
    /// (test builds only).
    #[cfg(test)]
    pub fn cancelled(&self) -> bool {
        self.cancelled
    }

    /// Receives a single length-prefixed DNS message from the stream.
    ///
    /// NOT cancellation safe: if the future is dropped between reading the
    /// length prefix and the body, part of the message is lost. The caller
    /// must keep polling the same future to completion (see the event loop).
    pub async fn recv(&mut self) -> Result<Buf::Output, ConnectionEvent> {
        // If we re-enter while a body read was in flight, a previous call
        // was cancelled at a non-cancel-safe point; record it so tests can
        // detect the misuse.
        #[cfg(test)]
        if self.status == Status::WaitingForMessageBody {
            self.cancelled = true;
        }

        self.status = Status::WaitingForMessageHeader;
        // Read the 2-byte big-endian length prefix.
        Self::recv_n_bytes(&mut self.stream_rx, &mut self.msg_size_buf)
            .await?;

        let msg_len = u16::from_be_bytes(self.msg_size_buf) as usize;
        let mut msg_buf = self.buf.create_sized(msg_len);

        self.status = Status::WaitingForMessageBody;
        // Read exactly `msg_len` bytes of message body.
        Self::recv_n_bytes(&mut self.stream_rx, &mut msg_buf).await?;

        self.status = Status::MessageReceived;
        Ok(msg_buf)
    }

    /// Fills `buf` completely from the stream, retrying on transient
    /// (timeout/interrupt) I/O errors.
    async fn recv_n_bytes<T: AsMut<[u8]>>(
        stream_rx: &mut ReadHalf<Stream>,
        buf: &mut T,
    ) -> Result<(), ConnectionEvent> {
        loop {
            match stream_rx.read_exact(buf.as_mut()).await {
                Ok(_size) => return Ok(()),
                Err(err) => match Self::process_io_error(err) {
                    ControlFlow::Continue(_) => continue,
                    ControlFlow::Break(err) => return Err(err),
                },
            }
        }
    }

    /// Classifies an I/O error as retryable (`Continue`) or fatal (`Break`).
    fn process_io_error(err: io::Error) -> ControlFlow<ConnectionEvent> {
        match err.kind() {
            io::ErrorKind::UnexpectedEof => {
                // EOF: the peer closed the connection.
                ControlFlow::Break(ConnectionEvent::DisconnectWithoutFlush)
            }
            io::ErrorKind::TimedOut | io::ErrorKind::Interrupted => {
                // Transient; retry the read.
                ControlFlow::Continue(())
            }
            _ => {
                error!("I/O error: {}", err);
                ControlFlow::Break(ConnectionEvent::DisconnectWithoutFlush)
            }
        }
    }
}
/// Events that terminate the connection event loop.
enum ConnectionEvent {
    /// Disconnect immediately, discarding queued responses.
    DisconnectWithoutFlush,
    /// Write any queued responses to the stream before disconnecting.
    DisconnectWithFlush,
}
impl Display for ConnectionEvent {
    /// Renders the event as human-readable text.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let text = match self {
            ConnectionEvent::DisconnectWithoutFlush => {
                "Disconnect without flush"
            }
            ConnectionEvent::DisconnectWithFlush => "Disconnect with flush",
        };
        f.write_str(text)
    }
}
/// Tracks connection activity for idle timeout enforcement.
pub struct IdleTimer {
    /// The instant at which the timer was last reset.
    idle_timer_reset_at: Instant,
}
impl IdleTimer {
    /// Creates a timer whose last-activity instant is "now".
    #[must_use]
    fn new() -> Self {
        Self {
            idle_timer_reset_at: Instant::now(),
        }
    }

    /// The instant at which the connection becomes idle: the last reset
    /// plus `timeout`, saturating to "now" on overflow.
    #[must_use]
    pub fn idle_timeout_deadline(&self, timeout: Duration) -> Instant {
        match self.idle_timer_reset_at.checked_add(timeout) {
            Some(deadline) => deadline,
            None => {
                warn!("Unable to reset idle timer: value out of bounds");
                Instant::now()
            }
        }
    }

    /// Whether the idle deadline has already passed.
    #[must_use]
    pub fn idle_timeout_expired(&self, timeout: Duration) -> bool {
        Instant::now() >= self.idle_timeout_deadline(timeout)
    }

    /// Restarts the idle countdown from "now".
    fn reset_idle_timer(&mut self) {
        self.idle_timer_reset_at = Instant::now();
    }

    /// Activity marker: a complete request message arrived.
    fn full_msg_received(&mut self) {
        self.reset_idle_timer()
    }

    /// Activity marker: the response write queue drained completely.
    fn response_queue_emptied(&mut self) {
        self.reset_idle_timer()
    }
}
/// Invokes the service for a request and queues the resulting responses
/// onto the connection's write queue.
struct ServiceResponseHandler<RequestOctets, Svc, RequestMeta>
where
    RequestOctets: AsRef<[u8]> + Send + Sync,
    RequestMeta: Clone + Default,
    Svc: Service<RequestOctets, RequestMeta> + Clone,
{
    /// Shared, hot-swappable connection configuration.
    config: Arc<ArcSwap<Config>>,
    /// Sending end of the connection's response write queue.
    result_q_tx: mpsc::Sender<AdditionalBuilder<StreamTarget<Svc::Target>>>,
    /// Server-wide metrics.
    metrics: Arc<ServerMetrics>,
    /// Current invoker status; affects queue-full handling in
    /// `do_enqueue_response()`. Reset to `Normal` on clone.
    status: InvokerStatus,
}
impl<RequestOctets, Svc, RequestMeta>
    ServiceResponseHandler<RequestOctets, Svc, RequestMeta>
where
    RequestOctets: AsRef<[u8]> + Send + Sync,
    RequestMeta: Clone + Default,
    Svc: Service<RequestOctets, RequestMeta> + Clone,
{
    /// Creates a handler feeding responses into the given queue, starting
    /// in `Normal` status.
    fn new(
        config: Arc<ArcSwap<Config>>,
        result_q_tx: mpsc::Sender<
            AdditionalBuilder<StreamTarget<Svc::Target>>,
        >,
        metrics: Arc<ServerMetrics>,
    ) -> Self {
        Self {
            config,
            result_q_tx,
            metrics,
            status: InvokerStatus::Normal,
        }
    }

    /// Applies a service-requested idle timeout change to the shared config,
    /// if one was given.
    fn update_config(&self, idle_timeout: Option<Duration>) {
        if let Some(idle_timeout) = idle_timeout {
            debug!("Reconfigured connection timeout to {idle_timeout:?}");
            // Copy-modify-store: `Config` is `Copy`, so this swaps in a new
            // config value atomically without locking.
            let guard = self.config.load();
            let mut new_config = **guard;
            new_config.idle_timeout = idle_timeout;
            self.config.store(Arc::new(new_config));
        }
    }

    /// Queues a response for writing to the stream.
    ///
    /// Uses non-blocking `try_send` so a full queue can be handled
    /// explicitly: while in a transaction the send is retried (yielding to
    /// the runtime between attempts, so the write loop can drain the
    /// queue); otherwise the response is dropped with an error log.
    async fn do_enqueue_response(
        &self,
        mut response: AdditionalBuilder<StreamTarget<Svc::Target>>,
    ) {
        loop {
            match self.result_q_tx.try_send(response) {
                Ok(()) => {
                    // capacity() is free slots, so occupied = max - free.
                    let pending_writes = self.result_q_tx.max_capacity()
                        - self.result_q_tx.capacity();
                    trace!("Queued message for sending: # pending writes={pending_writes}");
                    self.metrics.set_num_pending_writes(pending_writes);
                    break;
                }
                Err(TrySendError::Closed(_)) => {
                    error!("Unable to queue message for sending: connection is shutting down.");
                    break;
                }
                Err(TrySendError::Full(unused_response)) => {
                    if matches!(self.status, InvokerStatus::InTransaction) {
                        // Give the write loop a chance to drain the queue,
                        // then retry with the returned response.
                        tokio::task::yield_now().await;
                        response = unused_response;
                    } else {
                        error!("Unable to queue message for sending: queue is full.");
                        break;
                    }
                }
            }
        }
    }
}
impl<RequestOctets, Svc, RequestMeta> Clone
for ServiceResponseHandler<RequestOctets, Svc, RequestMeta>
where
RequestOctets: AsRef<[u8]> + Send + Sync,
RequestMeta: Clone + Default,
Svc: Service<RequestOctets, RequestMeta> + Clone,
{
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
result_q_tx: self.result_q_tx.clone(),
metrics: self.metrics.clone(),
status: InvokerStatus::Normal,
}
}
}
impl<RequestOctets, Svc, RequestMeta>
    ServiceInvoker<RequestOctets, Svc, RequestMeta, ()>
    for ServiceResponseHandler<RequestOctets, Svc, RequestMeta>
where
    RequestOctets: Octets + Send + Sync + 'static,
    RequestMeta: Clone + Default + Send + 'static,
    Svc: Service<RequestOctets, RequestMeta> + Clone,
{
    /// Returns the current invoker status.
    fn status(&self) -> InvokerStatus {
        self.status
    }

    /// Records a new invoker status.
    fn set_status(&mut self, status: InvokerStatus) {
        self.status = status;
    }

    /// Applies a service-requested idle timeout change.
    fn reconfigure(&self, idle_timeout: Option<Duration>) {
        self.update_config(idle_timeout);
    }

    /// Returns a boxed future that enqueues `response` for writing.
    fn enqueue_response(
        &self,
        response: AdditionalBuilder<StreamTarget<Svc::Target>>,
        _meta: &(),
    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        // Calling the async fn only constructs the future; boxing it
        // directly is equivalent to wrapping it in an `async move` block.
        Box::pin(self.do_enqueue_response(response))
    }
}