#![deny(missing_docs)]
use std::{fmt, io, sync::Arc};
use enum_as_inner::EnumAsInner;
use thiserror::Error;
use tracing::warn;
#[cfg(feature = "backtrace")]
use crate::proto::{ExtBacktrace, trace};
use crate::proto::{
ForwardNSData, ProtoErrorKind,
op::ResponseCode,
rr::{Name, Record, rdata::SOA},
{ForwardData, ProtoError},
};
use crate::resolver::ResolveError;
#[derive(Debug, EnumAsInner, Error)]
#[non_exhaustive]
pub enum ErrorKind {
#[error("{0}")]
Message(&'static str),
#[error("{0}")]
Msg(String),
#[error("forward response")]
Forward(ForwardData),
#[error("forward NS Response")]
ForwardNS(Arc<[ForwardNSData]>),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("proto error: {0}")]
Proto(#[from] ProtoError),
#[error("proto error: {0}")]
Resolve(ResolveError),
#[error("request timed out")]
Timeout,
#[error("maximum recursion limit exceeded: {count} queries")]
RecursionLimitExceeded {
count: usize,
},
}
/// The error type returned by recursor operations.
///
/// The concrete category lives in a boxed [`ErrorKind`], which keeps `Error`
/// (and every `Result` carrying it) small. `Display` is implemented manually
/// rather than through a `#[error]` attribute.
#[derive(Error, Clone, Debug)]
#[non_exhaustive]
pub struct Error {
    /// The category of this error.
    pub kind: Box<ErrorKind>,
    /// Backtrace recorded by `trace!()` when this error was built through
    /// `From`; may be `None`.
    // NOTE: the field name is a long-standing misspelling of "backtrace"; it
    // is public, so it is kept as-is for compatibility.
    #[cfg(feature = "backtrace")]
    pub backtrack: Option<ExtBacktrace>,
}
impl Error {
    /// Returns the [`ErrorKind`] describing this error.
    pub fn kind(&self) -> &ErrorKind {
        &self.kind
    }

    /// Consumes the error and returns its [`ErrorKind`].
    pub fn into_kind(self) -> ErrorKind {
        *self.kind
    }

    /// Returns `true` if this error represents an NXDOMAIN outcome.
    ///
    /// Delegates to the wrapped protocol error, resolver error, or forwarded
    /// data; every other kind returns `false`.
    pub fn is_nx_domain(&self) -> bool {
        match &*self.kind {
            ErrorKind::Proto(proto) => proto.is_nx_domain(),
            ErrorKind::Resolve(err) => err.is_nx_domain(),
            ErrorKind::Forward(fwd) => fwd.is_nx_domain(),
            _ => false,
        }
    }

    /// Returns `true` if this error represents a "no records found" outcome.
    ///
    /// Delegates to the wrapped protocol error, resolver error, or forwarded
    /// data; every other kind returns `false`.
    pub fn is_no_records_found(&self) -> bool {
        match &*self.kind {
            ErrorKind::Proto(proto) => proto.is_no_records_found(),
            ErrorKind::Resolve(err) => err.is_no_records_found(),
            ErrorKind::Forward(fwd) => fwd.is_no_records_found(),
            _ => false,
        }
    }

    /// Returns `true` if this error represents a timeout.
    ///
    /// That is the case for [`ErrorKind::Timeout`] and for protocol errors —
    /// held directly or inside a resolver error — whose kind is
    /// `ProtoErrorKind::Timeout`.
    pub fn is_timeout(&self) -> bool {
        let proto_error = match &*self.kind {
            // The recursor's own timeout variant previously fell through to
            // `false`, contradicting its meaning and the `io::Error`
            // conversion, which maps it to `TimedOut`.
            ErrorKind::Timeout => return true,
            ErrorKind::Proto(proto) => proto,
            ErrorKind::Resolve(err) => match err.kind() {
                hickory_resolver::ResolveErrorKind::Proto(proto) => proto,
                _ => return false,
            },
            _ => return false,
        };

        matches!(proto_error.kind(), ProtoErrorKind::Timeout)
    }

    /// Consumes the error and returns the SOA record it carries, if any.
    pub fn into_soa(self) -> Option<Box<Record<SOA>>> {
        match *self.kind {
            ErrorKind::Proto(proto) => proto.into_soa(),
            ErrorKind::Resolve(err) => err.into_soa(),
            ErrorKind::Forward(fwd) => Some(fwd.soa),
            _ => None,
        }
    }

    /// Consumes the error and returns the authority records of an
    /// [`ErrorKind::Forward`] error; every other kind yields `None`.
    pub fn authorities(self) -> Option<Arc<[Record]>> {
        match *self.kind {
            ErrorKind::Forward(fwd) => fwd.authorities,
            _ => None,
        }
    }

    /// Checks a recursion `depth` against an optional `limit`.
    ///
    /// A `None` limit never fails.
    ///
    /// # Errors
    ///
    /// When `limit` is set and `depth` is strictly greater than it, logs a
    /// warning naming `name` and returns
    /// [`ErrorKind::RecursionLimitExceeded`] with `count` set to `depth`.
    pub fn recursion_exceeded(limit: Option<u8>, depth: u8, name: &Name) -> Result<(), Error> {
        match limit {
            Some(limit) if depth > limit => {
                warn!("recursion depth exceeded for {name}");
                Err(ErrorKind::RecursionLimitExceeded {
                    // Lossless widening; no `as` cast needed.
                    count: usize::from(depth),
                }
                .into())
            }
            _ => Ok(()),
        }
    }
}
impl fmt::Display for Error {
    /// Writes the kind's message; with the `backtrace` feature enabled and a
    /// trace present, the trace's `Debug` output is appended directly after.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The kind's message is always emitted first.
        fmt::Display::fmt(&self.kind, f)?;

        cfg_if::cfg_if! {
            if #[cfg(feature = "backtrace")] {
                match &self.backtrack {
                    Some(trace) => fmt::Debug::fmt(trace, f),
                    None => Ok(()),
                }
            } else {
                Ok(())
            }
        }
    }
}
impl<E> From<E> for Error
where
    E: Into<ErrorKind>,
{
    /// Wraps anything convertible into an [`ErrorKind`], boxing the kind and,
    /// with the `backtrace` feature, recording a trace via `trace!()`.
    fn from(error: E) -> Self {
        Self {
            kind: Box::new(error.into()),
            #[cfg(feature = "backtrace")]
            backtrack: trace!(),
        }
    }
}
impl From<&'static str> for Error {
    /// Wraps a static message as [`ErrorKind::Message`].
    fn from(msg: &'static str) -> Self {
        Self::from(ErrorKind::Message(msg))
    }
}
impl From<String> for Error {
    /// Wraps an owned message as [`ErrorKind::Msg`].
    fn from(msg: String) -> Self {
        Self::from(ErrorKind::Msg(msg))
    }
}
impl From<Error> for io::Error {
    /// Converts into an `io::Error` that wraps the recursor error.
    ///
    /// [`ErrorKind::Timeout`] maps to `io::ErrorKind::TimedOut`; everything
    /// else maps to `io::ErrorKind::Other`.
    fn from(e: Error) -> Self {
        let io_kind = if matches!(e.kind(), ErrorKind::Timeout) {
            io::ErrorKind::TimedOut
        } else {
            io::ErrorKind::Other
        };
        Self::new(io_kind, e)
    }
}
impl From<Error> for String {
    /// Renders the error through its `Display` implementation.
    fn from(e: Error) -> Self {
        ToString::to_string(&e)
    }
}
impl From<ResolveError> for Error {
fn from(e: ResolveError) -> Self {
let nx_domain = e.is_nx_domain();
let no_records_found = e.is_no_records_found();
let proto_err = match ProtoErrorKind::try_from(e) {
Ok(res) => res,
Err(e) => return ErrorKind::Resolve(e).into(),
};
let ProtoErrorKind::NoRecordsFound {
query,
soa,
ns,
authorities,
..
} = proto_err
else {
return ErrorKind::Proto(proto_err.into()).into();
};
if let Some(ns) = ns {
ErrorKind::ForwardNS(ns).into()
} else if let Some(soa) = soa {
ErrorKind::Forward(ForwardData::new(
query,
soa.name().clone(),
soa,
no_records_found,
nx_domain,
authorities,
))
.into()
} else {
ErrorKind::Message("proto error missing ns and soa").into()
}
}
}
impl Clone for ErrorKind {
fn clone(&self) -> Self {
use self::ErrorKind::*;
match self {
Message(msg) => Message(msg),
Msg(msg) => Msg(msg.clone()),
Forward(ns) => Forward(ns.clone()),
ForwardNS(ns) => ForwardNS(ns.clone()),
Io(io) => Io(std::io::Error::from(io.kind())),
Proto(proto) => Proto(proto.clone()),
Resolve(resolve) => Resolve(resolve.clone()),
Timeout => Self::Timeout,
RecursionLimitExceeded { count } => RecursionLimitExceeded { count: *count },
}
}
}
impl From<Error> for ProtoError {
    /// Converts a recursor error into a [`ProtoError`].
    ///
    /// An [`ErrorKind::Forward`] error becomes `ProtoError::nx_error` with its
    /// query, SOA and authorities. The response code is `NXDomain` when
    /// `is_nx_domain()` holds and `NoError` otherwise. Every other kind becomes
    /// a `ProtoError` built from the error's `Display` text.
    fn from(e: Error) -> Self {
        // Read before `e.kind` is moved out of in the match below.
        let nx_domain = e.is_nx_domain();

        match *e.kind {
            ErrorKind::Forward(fwd) => {
                let response_code = if nx_domain {
                    ResponseCode::NXDomain
                } else {
                    ResponseCode::NoError
                };
                ProtoError::nx_error(
                    fwd.query,
                    Some(fwd.soa),
                    None,
                    None,
                    response_code,
                    true,
                    fwd.authorities,
                )
            }
            _ => ProtoError::from(e.to_string()),
        }
    }
}