use cfg_if::cfg_if;
use std::fmt;
use crate::{
authority::{LookupError, LookupObject, MessageRequest, UpdateResult, ZoneType},
proto::rr::{LowerName, RecordSet, RecordType, RrsetRecords},
server::RequestInfo,
};
#[cfg(feature = "__dnssec")]
use crate::{
dnssec::NxProofKind,
proto::{
ProtoError,
dnssec::{DnsSecResult, Nsec3HashAlgorithm, SigSigner, crypto::Digest, rdata::key::KEY},
rr::Name,
},
};
/// Per-request options passed to the [`Authority`] lookup methods.
///
/// The only option today is the `dnssec_ok` flag; [`Default`] leaves it
/// `false`.
#[derive(Clone, Copy, Debug, Default)]
pub struct LookupOptions {
    // Whether DNSSEC records were requested; consumed by `rrset_with_rrigs`.
    // NOTE(review): presumably mirrors the EDNS DO bit of the request — confirm.
    dnssec_ok: bool,
}

impl LookupOptions {
    /// Creates options with `dnssec_ok` set to the given value.
    #[cfg(feature = "__dnssec")]
    pub fn for_dnssec(dnssec_ok: bool) -> Self {
        Self { dnssec_ok }
    }

    /// Returns a copy of these options with `dnssec_ok` set to `val`.
    ///
    /// `LookupOptions` is `Copy` and `self` is taken by value, so the caller's
    /// value is never modified. Discarding the return value is always a bug,
    /// which is why this is `#[must_use]`.
    #[must_use = "set_dnssec_ok returns updated options and does not modify the original"]
    // `..self` is redundant while `dnssec_ok` is the only field. It is kept so
    // that any fields added later are carried over without touching this method.
    #[allow(clippy::needless_update)]
    pub fn set_dnssec_ok(self, val: bool) -> Self {
        Self {
            dnssec_ok: val,
            ..self
        }
    }

    /// Returns the `dnssec_ok` flag.
    pub fn dnssec_ok(&self) -> bool {
        self.dnssec_ok
    }

    /// Returns the records of `record_set`, honoring the `dnssec_ok` flag.
    ///
    /// With the `__dnssec` feature, `dnssec_ok` is forwarded to
    /// `RecordSet::records` (presumably selecting whether RRSIGs are
    /// included). Without the feature, `RecordSet::records_without_rrsigs` is
    /// always used.
    ///
    /// Note: the method name misspells "rrsigs"; it is kept as-is for API
    /// compatibility.
    pub fn rrset_with_rrigs<'r>(&self, record_set: &'r RecordSet) -> RrsetRecords<'r> {
        cfg_if! {
            if #[cfg(feature = "__dnssec")] {
                record_set.records(self.dnssec_ok())
            } else {
                record_set.records_without_rrsigs()
            }
        }
    }
}
/// A source of records for a single zone, rooted at [`Authority::origin`].
///
/// Every lookup-style method returns a [`LookupControlFlow`] wrapping
/// `Self::Lookup`.
#[async_trait::async_trait]
pub trait Authority: Send + Sync {
    /// The result type produced by this authority's lookups.
    type Lookup: Send + Sync + Sized + 'static;

    /// Returns the [`ZoneType`] of the zone served by this authority.
    fn zone_type(&self) -> ZoneType;

    /// Returns whether zone transfers (AXFR) are permitted.
    ///
    /// The policy is implementor-defined.
    fn is_axfr_allowed(&self) -> bool;

    /// Returns whether this authority is able to validate DNSSEC.
    ///
    /// The default implementation returns `false`.
    fn can_validate_dnssec(&self) -> bool {
        false
    }

    /// Processes the dynamic-update message `update`.
    ///
    /// NOTE(review): the meaning of the `bool` in the `Ok` value is
    /// implementor-defined (presumably "the zone was modified") — confirm.
    async fn update(&self, update: &MessageRequest) -> UpdateResult<bool>;

    /// Returns the origin (apex) name of the zone.
    ///
    /// The default `ns`, `soa`, and `soa_secure` implementations look records
    /// up at this name.
    fn origin(&self) -> &LowerName;

    /// Looks up the records of type `rtype` owned by `name`.
    ///
    /// `lookup_options` carries per-request flags such as `dnssec_ok`.
    async fn lookup(
        &self,
        name: &LowerName,
        rtype: RecordType,
        lookup_options: LookupOptions,
    ) -> LookupControlFlow<Self::Lookup>;

    /// Gives this authority a chance to inspect or replace `last_result`, a
    /// result that was produced before this authority was consulted.
    ///
    /// The default implementation ignores the query parameters and returns
    /// `last_result` unchanged.
    ///
    /// NOTE(review): presumably used when several authorities are chained for
    /// the same zone — confirm with callers.
    async fn consult(
        &self,
        _name: &LowerName,
        _rtype: RecordType,
        _lookup_options: LookupOptions,
        last_result: LookupControlFlow<Box<dyn LookupObject>>,
    ) -> LookupControlFlow<Box<dyn LookupObject>> {
        last_result
    }

    /// Answers the query described by `request`.
    ///
    /// How this differs from [`Authority::lookup`] (e.g. special handling of
    /// particular query types) is implementor-defined.
    async fn search(
        &self,
        request: RequestInfo<'_>,
        lookup_options: LookupOptions,
    ) -> LookupControlFlow<Self::Lookup>;

    /// Looks up the NS records at [`Authority::origin`], passing
    /// `lookup_options` through to [`Authority::lookup`].
    async fn ns(&self, lookup_options: LookupOptions) -> LookupControlFlow<Self::Lookup> {
        self.lookup(self.origin(), RecordType::NS, lookup_options)
            .await
    }

    /// Returns the NSEC records relevant to `name`.
    ///
    /// NOTE(review): presumably the records matching or covering `name` for
    /// authenticated denial of existence — confirm in implementors.
    async fn get_nsec_records(
        &self,
        name: &LowerName,
        lookup_options: LookupOptions,
    ) -> LookupControlFlow<Self::Lookup>;

    /// NSEC3 counterpart of [`Authority::get_nsec_records`].
    ///
    /// `info` carries the queried name and type together with the NSEC3
    /// hashing parameters (algorithm, salt, iterations).
    #[cfg(feature = "__dnssec")]
    async fn get_nsec3_records(
        &self,
        info: Nsec3QueryInfo<'_>,
        lookup_options: LookupOptions,
    ) -> LookupControlFlow<Self::Lookup>;

    /// Looks up the SOA record at [`Authority::origin`] using
    /// `LookupOptions::default()` (so `dnssec_ok` is `false`).
    async fn soa(&self) -> LookupControlFlow<Self::Lookup> {
        self.lookup(self.origin(), RecordType::SOA, LookupOptions::default())
            .await
    }

    /// Like [`Authority::soa`], but with caller-supplied `lookup_options`, so
    /// the caller can request DNSSEC records.
    async fn soa_secure(&self, lookup_options: LookupOptions) -> LookupControlFlow<Self::Lookup> {
        self.lookup(self.origin(), RecordType::SOA, lookup_options)
            .await
    }

    /// Returns the kind of non-existence proof configured for this zone, if any.
    ///
    /// NOTE(review): presumably distinguishes NSEC from NSEC3 — confirm
    /// against `NxProofKind`.
    #[cfg(feature = "__dnssec")]
    fn nx_proof_kind(&self) -> Option<&NxProofKind>;
}
/// DNSSEC key-management operations for an [`Authority`].
///
/// Only compiled with the `__dnssec` feature.
#[cfg(feature = "__dnssec")]
#[async_trait::async_trait]
pub trait DnssecAuthority: Authority {
    /// Registers `key` under `name` as a key for authenticating updates.
    ///
    /// NOTE(review): presumably used to verify signed dynamic updates —
    /// confirm in implementors.
    async fn add_update_auth_key(&self, name: Name, key: KEY) -> DnsSecResult<()>;

    /// Adds `signer` as a zone-signing key.
    async fn add_zone_signing_key(&self, signer: SigSigner) -> DnsSecResult<()>;

    /// Secures the zone.
    ///
    /// NOTE(review): presumably (re)signs the zone with the configured
    /// signing keys — confirm in implementors.
    async fn secure_zone(&self) -> DnsSecResult<()>;
}
/// Outcome of an [`Authority`] lookup, combined with a hint about how lookup
/// processing should proceed.
///
/// `Continue` and `Break` both carry a lookup `Result` (with the error type
/// defaulting to `LookupError`); `Skip` carries none. When only the result is
/// wanted, `map_result` treats `Continue` and `Break` identically.
///
/// NOTE(review): presumably `Continue` lets later authorities in a chain
/// [`Authority::consult`] the result, `Break` stops the chain, and `Skip`
/// means this authority did not handle the query — confirm against the
/// chaining code.
pub enum LookupControlFlow<T, E = LookupError> {
    /// A result was produced and processing may continue (see the note above).
    Continue(Result<T, E>),
    /// A result was produced and processing should stop (see the note above).
    Break(Result<T, E>),
    /// No result was produced.
    Skip,
}
impl<T, E> fmt::Display for LookupControlFlow<T, E> {
    /// Writes a short, payload-free label for the variant, such as
    /// `LookupControlFlow::Break(Err)`.
    ///
    /// The wrapped values themselves are never formatted, so no bounds on
    /// `T` or `E` are needed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Continue(Ok(_)) => "LookupControlFlow::Continue(Ok)",
            Self::Continue(Err(_)) => "LookupControlFlow::Continue(Err)",
            Self::Break(Ok(_)) => "LookupControlFlow::Break(Ok)",
            Self::Break(Err(_)) => "LookupControlFlow::Break(Err)",
            Self::Skip => "LookupControlFlow::Skip",
        };
        f.write_str(label)
    }
}
impl<T, E> LookupControlFlow<T, E> {
    /// Returns `true` for `Continue`, whether it holds `Ok` or `Err`.
    pub fn is_continue(&self) -> bool {
        match self {
            Self::Continue(_) => true,
            Self::Break(_) | Self::Skip => false,
        }
    }

    /// Returns `true` for `Break`, whether it holds `Ok` or `Err`.
    pub fn is_break(&self) -> bool {
        match self {
            Self::Break(_) => true,
            Self::Continue(_) | Self::Skip => false,
        }
    }

    /// Discards the control-flow hint and keeps only the lookup result.
    ///
    /// `Continue(r)` and `Break(r)` both yield `Some(r)`; `Skip` yields `None`.
    pub fn map_result(self) -> Option<Result<T, E>> {
        match self {
            Self::Continue(result) | Self::Break(result) => Some(result),
            Self::Skip => None,
        }
    }
}
impl<T: LookupObject + 'static, E: std::fmt::Display> LookupControlFlow<T, E> {
pub fn expect(self, msg: &str) -> T {
match self {
Self::Continue(Ok(ok)) | Self::Break(Ok(ok)) => ok,
_ => {
panic!("lookupcontrolflow::expect() called on unexpected variant {self}: {msg}");
}
}
}
pub fn expect_err(self, msg: &str) -> E {
match self {
Self::Continue(Err(e)) | Self::Break(Err(e)) => e,
_ => {
panic!(
"lookupcontrolflow::expect_err() called on unexpected variant {self}: {msg}"
);
}
}
}
pub fn unwrap(self) -> T {
match self {
Self::Continue(Ok(ok)) | Self::Break(Ok(ok)) => ok,
Self::Continue(Err(e)) | Self::Break(Err(e)) => {
panic!("lookupcontrolflow::unwrap() called on unexpected variant _(Err(_)): {e}");
}
_ => {
panic!("lookupcontrolflow::unwrap() called on unexpected variant: {self}");
}
}
}
pub fn unwrap_err(self) -> E {
match self {
Self::Continue(Err(e)) | Self::Break(Err(e)) => e,
_ => {
panic!("lookupcontrolflow::unwrap_err() called on unexpected variant: {self}");
}
}
}
pub fn unwrap_or_default(self) -> T
where
T: Default,
{
match self {
Self::Continue(Ok(ok)) | Self::Break(Ok(ok)) => ok,
_ => T::default(),
}
}
pub fn map<U, F: FnOnce(T) -> U>(self, op: F) -> LookupControlFlow<U, E> {
match self {
Self::Continue(cont) => match cont {
Ok(t) => LookupControlFlow::Continue(Ok(op(t))),
Err(e) => LookupControlFlow::Continue(Err(e)),
},
Self::Break(b) => match b {
Ok(t) => LookupControlFlow::Break(Ok(op(t))),
Err(e) => LookupControlFlow::Break(Err(e)),
},
Self::Skip => LookupControlFlow::<U, E>::Skip,
}
}
pub fn map_dyn(self) -> LookupControlFlow<Box<dyn LookupObject>, E> {
match self {
Self::Continue(cont) => match cont {
Ok(lookup) => {
LookupControlFlow::Continue(Ok(Box::new(lookup) as Box<dyn LookupObject>))
}
Err(e) => LookupControlFlow::Continue(Err(e)),
},
Self::Break(b) => match b {
Ok(lookup) => {
LookupControlFlow::Break(Ok(Box::new(lookup) as Box<dyn LookupObject>))
}
Err(e) => LookupControlFlow::Break(Err(e)),
},
Self::Skip => LookupControlFlow::<Box<dyn LookupObject>, E>::Skip,
}
}
pub fn map_err<U, F: FnOnce(E) -> U>(self, op: F) -> LookupControlFlow<T, U> {
match self {
Self::Continue(cont) => match cont {
Ok(lookup) => LookupControlFlow::Continue(Ok(lookup)),
Err(e) => LookupControlFlow::Continue(Err(op(e))),
},
Self::Break(b) => match b {
Ok(lookup) => LookupControlFlow::Break(Ok(lookup)),
Err(e) => LookupControlFlow::Break(Err(op(e))),
},
Self::Skip => LookupControlFlow::Skip,
}
}
}
/// Query and hashing parameters passed to [`Authority::get_nsec3_records`].
///
/// `algorithm`, `salt`, and `iterations` are the NSEC3 hashing inputs used by
/// `hash_name` and `get_hashed_owner_name`.
#[cfg(feature = "__dnssec")]
pub struct Nsec3QueryInfo<'q> {
    /// The name being queried (presumably the request's QNAME — confirm).
    pub qname: &'q LowerName,
    /// The record type being queried (presumably the request's QTYPE — confirm).
    pub qtype: RecordType,
    /// Whether a wildcard matched the query.
    ///
    /// NOTE(review): set by the caller; confirm the exact semantics there.
    pub has_wildcard_match: bool,
    /// Hash algorithm applied to names.
    pub algorithm: Nsec3HashAlgorithm,
    /// Salt passed to the hash.
    pub salt: &'q [u8],
    /// Iteration count passed to the hash.
    pub iterations: u16,
}
#[cfg(feature = "__dnssec")]
impl Nsec3QueryInfo<'_> {
    /// Hashes `name` with this query's `algorithm`, `salt`, and `iterations`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the hash algorithm.
    pub(crate) fn hash_name(&self, name: &Name) -> Result<Digest, ProtoError> {
        let (salt, iterations) = (self.salt, self.iterations);
        self.algorithm.hash(salt, name, iterations)
    }

    /// Builds the hashed owner name for `name` inside `zone`.
    ///
    /// The hash of `name` is encoded with `BASE32_DNSSEC` and prepended to
    /// `zone` as a single label.
    ///
    /// # Errors
    ///
    /// Propagates failures from hashing and from prepending the label.
    pub(crate) fn get_hashed_owner_name(
        &self,
        name: &LowerName,
        zone: &Name,
    ) -> Result<LowerName, ProtoError> {
        let digest = self.hash_name(name)?;
        let encoded_label = data_encoding::BASE32_DNSSEC.encode(digest.as_ref());
        let hashed_owner = zone.prepend_label(encoded_label)?;
        Ok(LowerName::new(&hashed_owner))
    }
}