#![warn(missing_docs)]
#![warn(clippy::missing_docs_in_private_items)]
use bytes::Bytes;
use core::time::Duration;
use std::fmt::Debug;
use std::sync::{Arc, Mutex};
use std::vec::Vec;
use tokio::time::Instant;
use crate::base::opt::AllOptData;
use crate::base::{Message, Name};
use crate::dep::octseq::Octets;
use crate::net::client::request;
use crate::net::client::request::{ComposeRequest, RequestMessage};
/// Transport context attached to requests received over UDP.
///
/// Carries an optional maximum response size hint. The hint lives behind an
/// `Arc<Mutex<_>>`, so every clone of a context (the derived `Clone` clones
/// the `Arc`) reads and writes the same shared value.
#[derive(Clone, Debug, Default)]
pub struct UdpTransportContext {
    /// Shared, mutable maximum response size hint (`None` when unset).
    max_response_size_hint: Arc<Mutex<Option<u16>>>,
}

impl UdpTransportContext {
    /// Creates a context whose maximum response size hint starts out as
    /// `max_response_size_hint`.
    pub fn new(max_response_size_hint: Option<u16>) -> Self {
        let max_response_size_hint =
            Arc::new(Mutex::new(max_response_size_hint));
        Self {
            max_response_size_hint,
        }
    }

    /// Returns the current maximum response size hint, if one is set.
    pub fn max_response_size_hint(&self) -> Option<u16> {
        *self.lock_hint()
    }

    /// Replaces the maximum response size hint.
    ///
    /// Because the hint is shared, the new value is visible through every
    /// clone of this context.
    pub fn set_max_response_size_hint(
        &self,
        max_response_size_hint: Option<u16>,
    ) {
        *self.lock_hint() = max_response_size_hint;
    }

    /// Locks the shared hint and recovers the guard if the mutex is poisoned.
    ///
    /// Within this type the protected `Option<u16>` is only ever copied out
    /// or overwritten as a whole. A panic in another lock holder therefore
    /// cannot leave it half-updated, and continuing with the stored value is
    /// safe. The previous `unwrap()` would have panicked instead.
    fn lock_hint(&self) -> std::sync::MutexGuard<'_, Option<u16>> {
        self.max_response_size_hint
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// Transport context attached to requests received over a non-UDP transport.
///
/// It only holds the optional idle timeout given at construction.
#[derive(Clone, Copy, Debug)]
pub struct NonUdpTransportContext {
    /// Idle timeout passed to [`NonUdpTransportContext::new`], if any.
    idle_timeout: Option<Duration>,
}

impl NonUdpTransportContext {
    /// Builds a context that carries `idle_timeout`.
    pub fn new(idle_timeout: Option<Duration>) -> Self {
        NonUdpTransportContext { idle_timeout }
    }

    /// Returns the idle timeout this context was created with.
    pub fn idle_timeout(&self) -> Option<Duration> {
        let NonUdpTransportContext { idle_timeout } = *self;
        idle_timeout
    }
}
/// Per-transport details attached to a received request.
///
/// The variant records which kind of transport context was supplied.
#[derive(Debug, Clone)]
pub enum TransportSpecificContext {
    /// Holds a [`UdpTransportContext`].
    Udp(UdpTransportContext),

    /// Holds a [`NonUdpTransportContext`].
    NonUdp(NonUdpTransportContext),
}

impl TransportSpecificContext {
    /// Returns `true` when this is the [`TransportSpecificContext::Udp`]
    /// variant.
    pub fn is_udp(&self) -> bool {
        match self {
            TransportSpecificContext::Udp(_) => true,
            TransportSpecificContext::NonUdp(_) => false,
        }
    }

    /// Returns `true` when this is the [`TransportSpecificContext::NonUdp`]
    /// variant.
    pub fn is_non_udp(&self) -> bool {
        match self {
            TransportSpecificContext::Udp(_) => false,
            TransportSpecificContext::NonUdp(_) => true,
        }
    }
}

impl From<UdpTransportContext> for TransportSpecificContext {
    /// Wraps a UDP context in the [`TransportSpecificContext::Udp`] variant.
    fn from(udp_ctx: UdpTransportContext) -> Self {
        TransportSpecificContext::Udp(udp_ctx)
    }
}

impl From<NonUdpTransportContext> for TransportSpecificContext {
    /// Wraps a non-UDP context in the
    /// [`TransportSpecificContext::NonUdp`] variant.
    fn from(non_udp_ctx: NonUdpTransportContext) -> Self {
        TransportSpecificContext::NonUdp(non_udp_ctx)
    }
}
/// A DNS request as received by a server, together with the receive-side
/// details supplied when it was constructed.
///
/// `Metadata` is an arbitrary caller-chosen value carried alongside the
/// request. It can be swapped for a value of another type with
/// [`Request::with_new_metadata`].
#[derive(Debug)]
pub struct Request<Octs, Metadata>
where
    Octs: AsRef<[u8]> + Send + Sync,
{
    /// Client address passed to [`Request::new`].
    client_addr: std::net::SocketAddr,

    /// Receive time passed to [`Request::new`].
    received_at: Instant,

    /// The request message, behind an `Arc` so that copies of the request
    /// share it instead of duplicating the octets.
    message: Arc<Message<Octs>>,

    /// Transport-specific context passed to [`Request::new`].
    transport_specific: TransportSpecificContext,

    /// Running total of bytes reserved via [`Request::reserve_bytes`].
    num_reserved_bytes: u16,

    /// Caller-supplied metadata.
    metadata: Metadata,
}

impl<Octs, Metadata> Request<Octs, Metadata>
where
    Octs: AsRef<[u8]> + Send + Sync,
{
    /// Creates a request with no bytes reserved.
    ///
    /// The message is moved into a new `Arc`.
    pub fn new(
        client_addr: std::net::SocketAddr,
        received_at: Instant,
        message: Message<Octs>,
        transport_specific: TransportSpecificContext,
        metadata: Metadata,
    ) -> Self {
        Self {
            client_addr,
            received_at,
            message: Arc::new(message),
            transport_specific,
            num_reserved_bytes: 0,
            metadata,
        }
    }

    /// Returns the receive time given at construction.
    pub fn received_at(&self) -> Instant {
        self.received_at
    }

    /// Returns the transport-specific context of this request.
    pub fn transport_ctx(&self) -> &TransportSpecificContext {
        &self.transport_specific
    }

    /// Returns the client address given at construction.
    pub fn client_addr(&self) -> std::net::SocketAddr {
        self.client_addr
    }

    /// Returns the shared request message.
    pub fn message(&self) -> &Arc<Message<Octs>> {
        &self.message
    }

    /// Adds `len` to the number of reserved bytes.
    ///
    /// The running total saturates at `u16::MAX` instead of overflowing. A
    /// plain `+=` would panic in debug builds and silently wrap in release
    /// builds, and wrapping would under-report how much space is reserved.
    pub fn reserve_bytes(&mut self, len: u16) {
        self.num_reserved_bytes = self.num_reserved_bytes.saturating_add(len);
        tracing::trace!(
            "Reserved {len} bytes: total now = {}",
            self.num_reserved_bytes
        );
    }

    /// Returns the total number of bytes reserved so far (0 for a new
    /// request).
    pub fn num_reserved_bytes(&self) -> u16 {
        self.num_reserved_bytes
    }

    /// Consumes the request and returns an identical one whose metadata is
    /// replaced by `new_metadata`.
    ///
    /// The message `Arc` is moved, not cloned.
    pub fn with_new_metadata<T>(self, new_metadata: T) -> Request<Octs, T> {
        // Functional record update (`..self`) cannot change a generic
        // parameter on stable Rust, so every field is moved explicitly.
        Request::<Octs, T> {
            client_addr: self.client_addr,
            received_at: self.received_at,
            message: self.message,
            transport_specific: self.transport_specific,
            num_reserved_bytes: self.num_reserved_bytes,
            metadata: new_metadata,
        }
    }

    /// Returns the metadata attached to this request.
    pub fn metadata(&self) -> &Metadata {
        &self.metadata
    }
}
/// Clones a request by sharing its message rather than copying it.
///
/// This is written by hand instead of derived. `#[derive(Clone)]` would add
/// an `Octs: Clone` bound, which is unnecessary because the message is held
/// in an `Arc` and only the reference count needs bumping.
impl<Octs, Metadata> Clone for Request<Octs, Metadata>
where
    Octs: AsRef<[u8]> + Send + Sync,
    Metadata: Clone,
{
    fn clone(&self) -> Self {
        // Destructuring makes adding a field to `Request` a compile error
        // here until the clone is updated to match.
        let Self {
            client_addr,
            received_at,
            message,
            transport_specific,
            num_reserved_bytes,
            metadata,
        } = self;

        Self {
            client_addr: *client_addr,
            received_at: *received_at,
            message: Arc::clone(message),
            transport_specific: transport_specific.clone(),
            num_reserved_bytes: *num_reserved_bytes,
            metadata: metadata.clone(),
        }
    }
}
/// Converts a received server [`Request`] into a client [`RequestMessage`].
///
/// The request's metadata is dropped. The outgoing message is built from the
/// same octets. The DO bit is then set when the incoming OPT record has it,
/// and every EDNS Client Subnet option found in the incoming OPT record is
/// added.
///
/// # Errors
///
/// Returns an error if re-reading the message octets fails, if an option in
/// the OPT record fails to parse, or if creating the `RequestMessage` or
/// adding an option to it fails.
impl<Octs: Octets + Send + Sync + Debug + Clone, Meta>
    TryFrom<Request<Octs, Meta>> for RequestMessage<Octs>
{
    type Error = request::Error;

    fn try_from(req: Request<Octs, Meta>) -> Result<Self, Self::Error> {
        // Re-read the message from an owned `Bytes` copy so the collected
        // options are of type `AllOptData<Bytes, Name<Bytes>>`.
        let owned_copy = Bytes::copy_from_slice(req.message.as_slice());
        let owned_msg = Message::from_octets(owned_copy)?;

        // Collect the Client Subnet options; they are re-added below.
        let mut ecs_opts: Vec<AllOptData<Bytes, Name<Bytes>>> = Vec::new();
        if let Some(optrec) = owned_msg.opt() {
            for parsed in optrec.opt().iter::<AllOptData<_, _>>() {
                let parsed = parsed?;
                if matches!(parsed, AllOptData::ClientSubnet(_)) {
                    ecs_opts.push(parsed);
                }
            }
        }

        let wants_dnssec = dnssec_ok(&req.message);

        // Build the outgoing message from a clone of the original octets.
        let outgoing = Message::from_octets(req.message.as_octets().clone())?;
        let mut reqmsg = RequestMessage::new(outgoing)?;

        if wants_dnssec {
            reqmsg.set_dnssec_ok(true);
        }

        for ecs in &ecs_opts {
            reqmsg.add_opt(ecs)?;
        }

        Ok(reqmsg)
    }
}
/// Reports whether `msg` has an OPT record with the DO (DNSSEC OK) flag set.
///
/// Returns `false` when the message has no OPT record at all.
fn dnssec_ok<Octs: Octets>(msg: &Message<Octs>) -> bool {
    match msg.opt() {
        Some(optrec) => optrec.dnssec_ok(),
        None => false,
    }
}