use std::{fmt, io, sync::Arc};
use cfg_if::cfg_if;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[cfg(feature = "__dnssec")]
use crate::dnssec::NxProofKind;
use crate::net::{DnsError, NetError, NoRecords};
use crate::proto::ProtoError;
#[cfg(feature = "__dnssec")]
use crate::proto::dnssec::crypto::Digest;
#[cfg(feature = "__dnssec")]
use crate::proto::dnssec::{DnsSecResult, DnssecSigner, Nsec3HashAlgorithm};
use crate::proto::op::{Edns, ResponseCode};
#[cfg(feature = "__dnssec")]
use crate::proto::rr::Name;
use crate::proto::rr::{
LowerName, Record, RecordSet, RecordType, RrsetRecords, TSigResponseContext, rdata::SOA,
};
#[cfg(feature = "recursor")]
use crate::resolver::recursor::RecursorError;
use crate::server::{Request, RequestInfo};
mod auth_lookup;
mod catalog;
pub(crate) mod message_request;
mod message_response;
pub use self::auth_lookup::{
AuthLookup, AuthLookupIter, AxfrRecords, AxfrRecordsIter, LookupRecords, LookupRecordsIter,
ZoneTransfer,
};
pub use self::catalog::Catalog;
pub use self::message_request::{MessageRequest, Queries, UpdateRequest};
pub use self::message_response::{MessageResponse, MessageResponseBuilder};
/// Abstraction over a single DNS zone that can be queried and, optionally,
/// updated or transferred.
///
/// Implementors must be `Send + Sync`. The trait is processed by
/// `async_trait` so that it can declare `async fn` methods.
/// `can_validate_dnssec`, `update`, `consult` and `zone_transfer` ship with
/// default bodies; every other method must be supplied by the implementor.
#[async_trait::async_trait]
pub trait ZoneHandler: Send + Sync {
/// Returns the [`ZoneType`] of this zone.
fn zone_type(&self) -> ZoneType;
/// Returns the [`AxfrPolicy`] of this zone.
///
/// NOTE(review): presumably consulted when deciding whether a zone
/// transfer request is allowed; the caller is not visible here.
fn axfr_policy(&self) -> AxfrPolicy;
/// Reports whether this handler can validate DNSSEC.
///
/// The default implementation returns `false`.
fn can_validate_dnssec(&self) -> bool {
false
}
/// Handles an update request against this zone.
///
/// The default implementation ignores both arguments and returns
/// `(Err(ResponseCode::NotImp), None)`.
///
/// NOTE(review): the meaning of the `bool` in `Ok(bool)` and the unit of
/// `_now` are not established by this file; confirm against implementors.
async fn update(
&self,
_update: &Request,
_now: u64,
) -> (Result<bool, ResponseCode>, Option<TSigResponseContext>) {
(Err(ResponseCode::NotImp), None)
}
/// Returns the origin name of this zone.
fn origin(&self) -> &LowerName;
/// Looks up records for the given `name` and `rtype`.
///
/// `request_info` is optional; `lookup_options` carries per-lookup
/// settings such as `dnssec_ok`.
async fn lookup(
&self,
name: &LowerName,
rtype: RecordType,
request_info: Option<&RequestInfo<'_>>,
lookup_options: LookupOptions,
) -> LookupControlFlow<AuthLookup>;
/// Gives this handler a chance to act on a result that was already
/// produced (`last_result`).
///
/// The default implementation ignores the query parameters and returns
/// `last_result` unchanged, with no TSIG response context.
///
/// NOTE(review): presumably invoked on later handlers of a chain after an
/// earlier handler produced `last_result`; confirm against the caller.
async fn consult(
&self,
_name: &LowerName,
_rtype: RecordType,
_request_info: Option<&RequestInfo<'_>>,
_lookup_options: LookupOptions,
last_result: LookupControlFlow<AuthLookup>,
) -> (LookupControlFlow<AuthLookup>, Option<TSigResponseContext>) {
(last_result, None)
}
/// Performs a lookup for a full `request`.
///
/// Returns the lookup outcome together with an optional TSIG response
/// context.
async fn search(
&self,
request: &Request,
lookup_options: LookupOptions,
) -> (LookupControlFlow<AuthLookup>, Option<TSigResponseContext>);
/// Returns NSEC records for `name`.
async fn nsec_records(
&self,
name: &LowerName,
lookup_options: LookupOptions,
) -> LookupControlFlow<AuthLookup>;
/// Returns NSEC3 records for the query described by `info`.
///
/// Only part of the trait when the `__dnssec` feature is enabled.
#[cfg(feature = "__dnssec")]
async fn nsec3_records(
&self,
info: Nsec3QueryInfo<'_>,
lookup_options: LookupOptions,
) -> LookupControlFlow<AuthLookup>;
/// Handles a zone transfer request.
///
/// The default implementation ignores its arguments and returns
/// `Some((Err(..), None))`, where the error is a [`LookupError`] built
/// from `ResponseCode::NotImp`.
///
/// NOTE(review): what a `None` return means to the caller is not visible
/// in this file; confirm before relying on it.
async fn zone_transfer(
&self,
_request: &Request,
_lookup_options: LookupOptions,
_now: u64,
) -> Option<(
Result<ZoneTransfer, LookupError>,
Option<TSigResponseContext>,
)> {
Some((Err(LookupError::from(ResponseCode::NotImp)), None))
}
/// Returns the kind of non-existence proof used by this zone, if any.
///
/// Only part of the trait when the `__dnssec` feature is enabled.
#[cfg(feature = "__dnssec")]
fn nx_proof_kind(&self) -> Option<&NxProofKind>;
/// Returns a static label identifying this handler in metrics.
///
/// Only part of the trait when the `metrics` feature is enabled.
#[cfg(feature = "metrics")]
fn metrics_label(&self) -> &'static str;
}
/// Extension of [`ZoneHandler`] with DNSSEC signing operations.
///
/// Only compiled when the `__dnssec` feature is enabled.
#[cfg(feature = "__dnssec")]
#[async_trait::async_trait]
pub trait DnssecZoneHandler: ZoneHandler {
/// Adds `signer` as a zone signing key for this zone.
async fn add_zone_signing_key(&self, signer: DnssecSigner) -> DnsSecResult<()>;
/// Secures the zone.
///
/// NOTE(review): presumably (re)signs the zone with the registered keys;
/// confirm against implementors.
async fn secure_zone(&self) -> DnsSecResult<()>;
}
/// Result of a lookup, tagged with a control-flow marker.
///
/// The error type `E` defaults to [`LookupError`].
///
/// NOTE(review): how callers treat `Continue` versus `Break` is not visible
/// in this file; presumably `Continue` allows further handlers to be
/// consulted while `Break` ends processing. Confirm against the caller.
pub enum LookupControlFlow<T, E = LookupError> {
/// A lookup result (success or failure) marked as "continue".
Continue(Result<T, E>),
/// A lookup result (success or failure) marked as "break".
Break(Result<T, E>),
/// No result was produced; `map_result` maps this variant to `None`.
Skip,
}
impl<T, E> fmt::Display for LookupControlFlow<T, E> {
    /// Writes the variant name plus an `Ok`/`Err` marker.
    ///
    /// The wrapped payload is never formatted, which is why no bounds are
    /// placed on `T` or `E`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Continue(Ok(_)) => "LookupControlFlow::Continue(Ok)",
            Self::Continue(Err(_)) => "LookupControlFlow::Continue(Err)",
            Self::Break(Ok(_)) => "LookupControlFlow::Break(Ok)",
            Self::Break(Err(_)) => "LookupControlFlow::Break(Err)",
            Self::Skip => "LookupControlFlow::Skip",
        };
        f.write_str(label)
    }
}
impl<T, E> LookupControlFlow<T, E> {
    /// Returns `true` for the `Continue` variant, whether it holds `Ok` or `Err`.
    pub fn is_continue(&self) -> bool {
        matches!(self, Self::Continue(_))
    }

    /// Returns `true` for the `Break` variant, whether it holds `Ok` or `Err`.
    pub fn is_break(&self) -> bool {
        matches!(self, Self::Break(_))
    }

    /// Drops the control-flow marker and returns the wrapped result.
    ///
    /// `Continue` and `Break` both yield `Some(result)`; `Skip` yields `None`.
    pub fn map_result(self) -> Option<Result<T, E>> {
        match self {
            Self::Continue(result) | Self::Break(result) => Some(result),
            Self::Skip => None,
        }
    }
}
impl<E: fmt::Display> LookupControlFlow<AuthLookup, E> {
    /// Returns the wrapped `Ok` value of a `Continue` or `Break`.
    ///
    /// # Panics
    ///
    /// Panics with a message that includes `msg` when `self` holds an `Err`
    /// or is `Skip`.
    pub fn expect(self, msg: &str) -> AuthLookup {
        match self {
            Self::Continue(Ok(lookup)) | Self::Break(Ok(lookup)) => lookup,
            Self::Continue(Err(_)) | Self::Break(Err(_)) | Self::Skip => {
                panic!("lookupcontrolflow::expect() called on unexpected variant {self}: {msg}");
            }
        }
    }

    /// Returns the wrapped `Err` value of a `Continue` or `Break`.
    ///
    /// # Panics
    ///
    /// Panics with a message that includes `msg` when `self` holds an `Ok`
    /// or is `Skip`.
    pub fn expect_err(self, msg: &str) -> E {
        match self {
            Self::Continue(Err(error)) | Self::Break(Err(error)) => error,
            Self::Continue(Ok(_)) | Self::Break(Ok(_)) | Self::Skip => {
                panic!(
                    "lookupcontrolflow::expect_err() called on unexpected variant {self}: {msg}"
                );
            }
        }
    }

    /// Returns the wrapped `Ok` value of a `Continue` or `Break`.
    ///
    /// # Panics
    ///
    /// Panics when `self` holds an `Err` (the error is included in the
    /// message) or is `Skip`.
    pub fn unwrap(self) -> AuthLookup {
        match self {
            Self::Continue(Ok(lookup)) | Self::Break(Ok(lookup)) => lookup,
            Self::Continue(Err(e)) | Self::Break(Err(e)) => {
                panic!("lookupcontrolflow::unwrap() called on unexpected variant _(Err(_)): {e}");
            }
            Self::Skip => {
                panic!("lookupcontrolflow::unwrap() called on unexpected variant: {self}");
            }
        }
    }

    /// Returns the wrapped `Err` value of a `Continue` or `Break`.
    ///
    /// # Panics
    ///
    /// Panics when `self` holds an `Ok` or is `Skip`.
    pub fn unwrap_err(self) -> E {
        match self {
            Self::Continue(Err(error)) | Self::Break(Err(error)) => error,
            Self::Continue(Ok(_)) | Self::Break(Ok(_)) | Self::Skip => {
                panic!("lookupcontrolflow::unwrap_err() called on unexpected variant: {self}");
            }
        }
    }

    /// Returns the wrapped `Ok` value, or `AuthLookup::default()` when `self`
    /// holds an `Err` or is `Skip`.
    pub fn unwrap_or_default(self) -> AuthLookup {
        let (Self::Continue(outcome) | Self::Break(outcome)) = self else {
            return AuthLookup::default();
        };
        outcome.unwrap_or_else(|_| AuthLookup::default())
    }

    /// Applies `op` to a wrapped `Ok` value; the variant, any `Err` value and
    /// `Skip` are carried over unchanged.
    pub fn map<U, F: FnOnce(AuthLookup) -> U>(self, op: F) -> LookupControlFlow<U, E> {
        match self {
            Self::Continue(outcome) => LookupControlFlow::Continue(outcome.map(op)),
            Self::Break(outcome) => LookupControlFlow::Break(outcome.map(op)),
            Self::Skip => LookupControlFlow::Skip,
        }
    }

    /// Applies `op` to a wrapped `Err` value; the variant, any `Ok` value and
    /// `Skip` are carried over unchanged.
    pub fn map_err<U, F: FnOnce(E) -> U>(self, op: F) -> LookupControlFlow<AuthLookup, U> {
        match self {
            Self::Continue(outcome) => LookupControlFlow::Continue(outcome.map_err(op)),
            Self::Break(outcome) => LookupControlFlow::Break(outcome.map_err(op)),
            Self::Skip => LookupControlFlow::Skip,
        }
    }
}
/// Errors that a lookup can produce.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute. The enum is `#[non_exhaustive]`, so code outside this crate
/// must include a wildcard arm when matching on it.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum LookupError {
/// The request did not carry exactly one query; holds the count that was seen.
#[error("there should only be one query per request, got {0}")]
BadQueryCount(usize),
/// The name exists, but there is no record of the requested type.
#[error("The name exists, but not for the record requested")]
NameExists,
/// The lookup failed with the contained DNS response code.
#[error("Error performing lookup: {0}")]
ResponseCode(ResponseCode),
/// A network-layer error; `#[from]` also generates `From<NetError>`.
#[error("net error: {0}")]
NetError(#[from] NetError),
/// An error from recursive resolution; `#[from]` also generates
/// `From<RecursorError>`. Only present with the `recursor` feature.
#[cfg(feature = "recursor")]
#[error("Recursive resolution error: {0}")]
RecursiveError(#[from] RecursorError),
/// An I/O error. The `From<io::Error>` conversion is written by hand
/// elsewhere in this file rather than derived with `#[from]`.
#[error("io error: {0}")]
Io(io::Error),
}
impl LookupError {
    /// Builds the [`LookupError::NameExists`] error.
    pub fn for_name_exists() -> Self {
        Self::NameExists
    }

    /// Returns `true` if this error represents NXDOMAIN.
    ///
    /// That is the case for a `ResponseCode::NXDomain` response code, or when
    /// the wrapped network (or, with the `recursor` feature, recursor) error
    /// reports NXDOMAIN itself.
    pub fn is_nx_domain(&self) -> bool {
        match self {
            Self::ResponseCode(ResponseCode::NXDomain) => true,
            Self::NetError(net) => net.is_nx_domain(),
            #[cfg(feature = "recursor")]
            Self::RecursiveError(recursive) => recursive.is_nx_domain(),
            _ => false,
        }
    }

    /// Returns `true` if the wrapped network (or, with the `recursor`
    /// feature, recursor) error reports that no records were found.
    pub fn is_no_records_found(&self) -> bool {
        match self {
            Self::NetError(net) => net.is_no_records_found(),
            #[cfg(feature = "recursor")]
            Self::RecursiveError(recursive) => recursive.is_no_records_found(),
            _ => false,
        }
    }

    /// Consumes the error and returns the SOA record carried by the wrapped
    /// network (or recursor) error, if there is one.
    pub fn into_soa(self) -> Option<Box<Record<SOA>>> {
        match self {
            Self::NetError(net) => net.into_soa(),
            #[cfg(feature = "recursor")]
            Self::RecursiveError(recursive) => recursive.into_soa(),
            _ => None,
        }
    }

    /// Returns a clone of the authority records attached to a "no records
    /// found" network error or, with the `recursor` feature, to a negative or
    /// "no records found" recursor error. All other errors yield `None`.
    pub fn authorities(&self) -> Option<Arc<[Record]>> {
        match self {
            Self::NetError(NetError::Dns(DnsError::NoRecordsFound(no_records))) => {
                no_records.authorities.clone()
            }
            #[cfg(feature = "recursor")]
            Self::RecursiveError(RecursorError::Negative(negative)) => {
                negative.authorities.clone()
            }
            #[cfg(feature = "recursor")]
            Self::RecursiveError(RecursorError::Net(NetError::Dns(DnsError::NoRecordsFound(
                no_records,
            )))) => no_records.authorities.clone(),
            _ => None,
        }
    }
}
impl From<ResponseCode> for LookupError {
/// Wraps `code` in [`LookupError::ResponseCode`].
///
/// # Panics
///
/// In builds with debug assertions enabled, panics if `code` is
/// `ResponseCode::NoError`.
fn from(code: ResponseCode) -> Self {
// `NoError` does not describe a failure, so building an error from it
// is treated as a caller bug; checked in debug builds only.
debug_assert!(code != ResponseCode::NoError);
Self::ResponseCode(code)
}
}
impl From<io::Error> for LookupError {
fn from(e: io::Error) -> Self {
Self::Io(e)
}
}
impl From<LookupError> for io::Error {
    /// Converts a lookup error into an `io::Error` of kind `Other`.
    ///
    /// The lookup error is boxed before being handed over, so the payload
    /// stored in the resulting `io::Error` is a `Box<LookupError>`.
    fn from(source: LookupError) -> Self {
        let boxed = Box::new(source);
        io::Error::other(boxed)
    }
}
impl From<ProtoError> for LookupError {
    /// Converts a protocol error into a [`NetError`] and wraps it in
    /// [`LookupError::NetError`].
    fn from(source: ProtoError) -> Self {
        Self::NetError(NetError::from(source))
    }
}
/// Parameters of a query for which NSEC3 records are requested.
///
/// Only compiled when the `__dnssec` feature is enabled.
#[cfg(feature = "__dnssec")]
pub struct Nsec3QueryInfo<'q> {
/// The queried name.
pub qname: &'q LowerName,
/// The queried record type.
pub qtype: RecordType,
/// NOTE(review): presumably whether the query was answered through a
/// wildcard match; confirm against the code that fills this in.
pub has_wildcard_match: bool,
/// Hash algorithm used by `hash_name`.
pub algorithm: Nsec3HashAlgorithm,
/// Salt handed to the hash algorithm by `hash_name`.
pub salt: &'q [u8],
/// Iteration count handed to the hash algorithm by `hash_name`.
pub iterations: u16,
}
#[cfg(feature = "__dnssec")]
impl Nsec3QueryInfo<'_> {
    /// Hashes `name` using this query's algorithm, salt and iteration count.
    pub(crate) fn hash_name(&self, name: &Name) -> Result<Digest, ProtoError> {
        self.algorithm.hash(self.salt, name, self.iterations)
    }

    /// Computes the hashed owner name of `name` within `zone`.
    ///
    /// The digest of `name` is encoded with `BASE32_DNSSEC` and the resulting
    /// label is prepended to `zone`.
    ///
    /// # Errors
    ///
    /// Returns an error if hashing fails or if the label cannot be prepended
    /// to `zone`.
    pub(crate) fn hashed_owner_name(
        &self,
        name: &LowerName,
        zone: &Name,
    ) -> Result<LowerName, ProtoError> {
        let digest = self.hash_name(name)?;
        let hashed_label = data_encoding::BASE32_DNSSEC.encode(digest.as_ref());
        let owner = zone.prepend_label(hashed_label)?;
        Ok(LowerName::new(&owner))
    }
}
/// Per-lookup options passed to zone handler methods.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal and has to go through the provided constructors
/// or `Default`.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default)]
pub struct LookupOptions {
/// Copied from the EDNS `dnssec_ok` flag by `from_edns` when the
/// `__dnssec` feature is enabled; defaults to `false`.
pub dnssec_ok: bool,
}
impl LookupOptions {
    /// Builds lookup options from the optional EDNS section of a request.
    ///
    /// With the `__dnssec` feature enabled, `dnssec_ok` mirrors the
    /// `dnssec_ok` flag of `edns` and is `false` when `edns` is `None`.
    /// Without the feature the argument is ignored and `dnssec_ok` is `false`.
    #[cfg_attr(not(feature = "__dnssec"), allow(unused_variables))]
    pub fn from_edns(edns: Option<&Edns>) -> Self {
        #[cfg(feature = "__dnssec")]
        let dnssec_ok = edns.is_some_and(|edns| edns.flags().dnssec_ok);
        #[cfg(not(feature = "__dnssec"))]
        let dnssec_ok = false;

        Self { dnssec_ok }
    }

    /// Returns options with `dnssec_ok` set to `true`.
    ///
    /// Only available with the `__dnssec` feature.
    #[cfg(feature = "__dnssec")]
    pub fn for_dnssec() -> Self {
        Self { dnssec_ok: true }
    }

    /// Returns the records of `record_set` according to these options.
    ///
    /// With the `__dnssec` feature enabled, `dnssec_ok` is forwarded to
    /// `RecordSet::records`; otherwise the records are returned without
    /// RRSIGs.
    pub fn rrset_with_rrigs<'r>(&self, record_set: &'r RecordSet) -> RrsetRecords<'r> {
        cfg_if! {
            if #[cfg(feature = "__dnssec")] {
                record_set.records(self.dnssec_ok)
            } else {
                record_set.records_without_rrsigs()
            }
        }
    }
}
/// Policy returned by `ZoneHandler::axfr_policy`.
///
/// NOTE(review): the enforcement of each variant is not visible in this
/// file; the descriptions below follow the variant names only.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Deserialize)]
pub enum AxfrPolicy {
/// Deny zone transfers. This is the default.
#[default]
Deny,
/// Allow all zone transfers.
AllowAll,
/// Allow signed zone transfer requests only (presumably; confirm the
/// exact signature requirement). Only present with the `__dnssec`
/// feature.
#[cfg(feature = "__dnssec")]
AllowSigned,
}
/// The type of a zone, as returned by `ZoneHandler::zone_type`.
///
/// NOTE(review): the descriptions below follow the variant names; how each
/// type is treated by the server is not visible in this file.
#[derive(Serialize, Deserialize, Hash, PartialEq, Eq, Debug, Clone, Copy)]
pub enum ZoneType {
/// A primary zone.
Primary,
/// A secondary zone.
Secondary,
/// An external zone (presumably one whose data lives outside this
/// server; confirm).
External,
}