use std::{borrow::Borrow, collections::HashMap, future::Future, io};
use cfg_if::cfg_if;
use tracing::{debug, error, info, trace, warn};
#[cfg(feature = "dnssec")]
use crate::proto::rr::{
dnssec::{Algorithm, SupportedAlgorithms},
rdata::opt::{EdnsCode, EdnsOption},
};
use crate::{
authority::{
AuthLookup, AuthorityObject, EmptyLookup, LookupError, LookupObject, LookupOptions,
MessageResponse, MessageResponseBuilder, ZoneType,
},
proto::op::{Edns, Header, LowerQuery, MessageType, OpCode, ResponseCode},
proto::rr::{LowerName, Record, RecordType},
server::{Request, RequestHandler, RequestInfo, ResponseHandler, ResponseInfo},
};
/// A collection of authorities, keyed by the zone name each one serves.
///
/// Requests are matched against these keys (falling back to successively
/// shorter parent names, see `Catalog::find`) to select the answering authority.
#[derive(Default)]
pub struct Catalog {
    /// Registered authorities, keyed by the name passed to `Catalog::upsert`.
    authorities: HashMap<LowerName, Box<dyn AuthorityObject>>,
}
/// Attaches the negotiated EDNS record (if any) to `response` and sends it.
///
/// With the `dnssec` feature compiled in, the EDNS record is first extended with
/// `DAU` and `DHU` options advertising RSASHA256, ECDSAP256SHA256,
/// ECDSAP384SHA384 and ED25519.
///
/// # Errors
///
/// Returns whatever I/O error `response_handle.send_response` reports.
// `unused_mut` / `unused_variables` are allowed because the EDNS record is only
// mutated when the `dnssec` feature is enabled.
#[allow(unused_mut, unused_variables)]
async fn send_response<'a, R: ResponseHandler>(
    response_edns: Option<Edns>,
    mut response: MessageResponse<
        '_,
        'a,
        impl Iterator<Item = &'a Record> + Send + 'a,
        impl Iterator<Item = &'a Record> + Send + 'a,
        impl Iterator<Item = &'a Record> + Send + 'a,
        impl Iterator<Item = &'a Record> + Send + 'a,
    >,
    mut response_handle: R,
) -> io::Result<ResponseInfo> {
    if let Some(mut edns) = response_edns {
        #[cfg(feature = "dnssec")]
        {
            let mut supported = SupportedAlgorithms::default();
            for algorithm in &[
                Algorithm::RSASHA256,
                Algorithm::ECDSAP256SHA256,
                Algorithm::ECDSAP384SHA384,
                Algorithm::ED25519,
            ] {
                supported.set(*algorithm);
            }

            let options = edns.options_mut();
            options.insert(EdnsOption::DAU(supported));
            options.insert(EdnsOption::DHU(supported));
        }
        response.set_edns(edns);
    }

    response_handle.send_response(response).await
}
#[async_trait::async_trait]
impl RequestHandler for Catalog {
    /// Handles a single request routed to this catalog.
    ///
    /// If the request carries EDNS, a response EDNS record is prepared (version 0,
    /// DO bit set, max payload = the requester's value but at least 512). A request
    /// advertising an EDNS version newer than 0 is answered immediately with
    /// `BADVERS`.
    ///
    /// Otherwise `Query` messages are dispatched on their op code: `Query` goes to
    /// `Catalog::lookup`, `Update` goes to `Catalog::update`, and anything else gets
    /// `NotImp`. `Response` messages are rejected with `FormErr`. Send failures are
    /// logged and reported as `ResponseInfo::serve_failed()`.
    async fn handle_request<R: ResponseHandler>(
        &self,
        request: &Request,
        mut response_handle: R,
    ) -> ResponseInfo {
        trace!("request: {:?}", request);

        let response_edns: Option<Edns> = match request.edns() {
            None => None,
            Some(req_edns) => {
                let our_version = 0;
                let mut resp_edns: Edns = Edns::new();
                resp_edns.set_dnssec_ok(true);
                resp_edns.set_max_payload(req_edns.max_payload().max(512));
                resp_edns.set_version(our_version);

                if req_edns.version() > our_version {
                    warn!(
                        "request edns version greater than {}: {}",
                        our_version,
                        req_edns.version()
                    );

                    // BADVERS is an extended rcode: the low bits travel in the
                    // header and the high bits in the EDNS record.
                    let mut builder = MessageResponseBuilder::new(Some(request.raw_query()));
                    let mut header = Header::response_from_request(request.header());
                    header.set_response_code(ResponseCode::BADVERS);
                    resp_edns.set_rcode_high(ResponseCode::BADVERS.high());
                    builder.edns(resp_edns);

                    let sent = response_handle
                        .send_response(builder.build_no_records(header))
                        .await;
                    return sent.unwrap_or_else(|e| {
                        error!("request error: {}", e);
                        ResponseInfo::serve_failed()
                    });
                }

                Some(resp_edns)
            }
        };

        let outcome = match request.message_type() {
            MessageType::Response => {
                warn!("got a response as a request from id: {}", request.id());
                let builder = MessageResponseBuilder::new(Some(request.raw_query()));
                response_handle
                    .send_response(builder.error_msg(request.header(), ResponseCode::FormErr))
                    .await
            }
            MessageType::Query => match request.op_code() {
                OpCode::Query => {
                    debug!("query received: {}", request.id());
                    Ok(self.lookup(request, response_edns, response_handle).await)
                }
                OpCode::Update => {
                    debug!("update received: {}", request.id());
                    self.update(request, response_edns, response_handle).await
                }
                other => {
                    warn!("unimplemented op_code: {:?}", other);
                    let builder = MessageResponseBuilder::new(Some(request.raw_query()));
                    response_handle
                        .send_response(builder.error_msg(request.header(), ResponseCode::NotImp))
                        .await
                }
            },
        };

        outcome.unwrap_or_else(|e| {
            error!("request failed: {}", e);
            ResponseInfo::serve_failed()
        })
    }
}
impl Catalog {
pub fn new() -> Self {
Self {
authorities: HashMap::new(),
}
}
pub fn upsert(&mut self, name: LowerName, authority: Box<dyn AuthorityObject>) {
self.authorities.insert(name, authority);
}
pub fn remove(&mut self, name: &LowerName) -> Option<Box<dyn AuthorityObject>> {
self.authorities.remove(name)
}
pub async fn update<R: ResponseHandler>(
&self,
update: &Request,
response_edns: Option<Edns>,
response_handle: R,
) -> io::Result<ResponseInfo> {
let request_info = update.request_info();
let verify_request = move || -> Result<RequestInfo<'_>, ResponseCode> {
let ztype = request_info.query.query_type();
if ztype != RecordType::SOA {
warn!(
"invalid update request zone type must be SOA, ztype: {}",
ztype
);
return Err(ResponseCode::FormErr);
}
Ok(request_info)
};
let request_info = verify_request();
let authority = request_info.as_ref().map_err(|e| *e).and_then(|info| {
self.find(info.query.name())
.map(|a| a.box_clone())
.ok_or(ResponseCode::Refused)
});
let response_code = match authority {
Ok(authority) => {
#[allow(deprecated)]
match authority.zone_type() {
ZoneType::Secondary | ZoneType::Slave => {
error!("secondary forwarding for update not yet implemented");
ResponseCode::NotImp
}
ZoneType::Primary | ZoneType::Master => {
let update_result = authority.update(update).await;
match update_result {
Ok(..) => ResponseCode::NoError,
Err(response_code) => response_code,
}
}
_ => ResponseCode::NotAuth,
}
}
Err(response_code) => response_code,
};
let response = MessageResponseBuilder::new(Some(update.raw_query()));
let mut response_header = Header::default();
response_header.set_id(update.id());
response_header.set_op_code(OpCode::Update);
response_header.set_message_type(MessageType::Response);
response_header.set_response_code(response_code);
send_response(
response_edns,
response.build_no_records(response_header),
response_handle,
)
.await
}
pub fn contains(&self, name: &LowerName) -> bool {
self.authorities.contains_key(name)
}
pub async fn lookup<R: ResponseHandler>(
&self,
request: &Request,
response_edns: Option<Edns>,
response_handle: R,
) -> ResponseInfo {
let request_info = request.request_info();
let authority = self.find(request_info.query.name());
if let Some(authority) = authority {
lookup(
request_info,
authority,
request,
response_edns
.as_ref()
.map(|arc| Borrow::<Edns>::borrow(arc).clone()),
response_handle.clone(),
)
.await
} else {
let response = MessageResponseBuilder::new(Some(request.raw_query()));
let result = send_response(
response_edns,
response.error_msg(request.header(), ResponseCode::Refused),
response_handle,
)
.await;
match result {
Err(e) => {
error!("failed to send response: {}", e);
ResponseInfo::serve_failed()
}
Ok(r) => r,
}
}
}
pub fn find(&self, name: &LowerName) -> Option<&(dyn AuthorityObject + 'static)> {
debug!("searching authorities for: {}", name);
self.authorities
.get(name)
.map(|authority| &**authority)
.or_else(|| {
if !name.is_root() {
let name = name.base_name();
self.find(&name)
} else {
None
}
})
}
}
/// Answers `request` from `authority` and sends the response through `response_handle`.
///
/// The response header and sections are produced by `build_response`, which
/// derives lookup options from the request's own EDNS record. `response_edns` is
/// attached by `send_response`. Send failures are logged and reported as
/// `ResponseInfo::serve_failed()`.
async fn lookup<R: ResponseHandler + Unpin>(
    request_info: RequestInfo<'_>,
    authority: &dyn AuthorityObject,
    request: &Request,
    response_edns: Option<Edns>,
    response_handle: R,
) -> ResponseInfo {
    let query = request_info.query;
    debug!(
        "request: {} found authority: {}",
        request.id(),
        authority.origin()
    );

    let (response_header, sections) = build_response(
        authority,
        request_info,
        request.id(),
        request.header(),
        query,
        request.edns(),
    )
    .await;

    let response = MessageResponseBuilder::new(Some(request.raw_query())).build(
        response_header,
        sections.answers.iter(),
        sections.ns.iter(),
        sections.soa.iter(),
        sections.additionals.iter(),
    );

    // Both the EDNS record and the handle are owned and used exactly once here,
    // so they are moved instead of cloned.
    match send_response(response_edns, response, response_handle).await {
        Err(e) => {
            error!("error sending response: {}", e);
            ResponseInfo::serve_failed()
        }
        Ok(i) => i,
    }
}
/// Derives the [`LookupOptions`] for a query from the request's EDNS record.
///
/// Without EDNS, or when the `dnssec` feature is disabled, the defaults are used.
/// With `dnssec`, the DO bit is honoured together with the algorithms listed in
/// the request's `DAU` option, falling back to `SupportedAlgorithms::default()`
/// when that option is absent.
// `edns` is otherwise unused when the `dnssec` feature is off.
#[allow(unused_variables)]
fn lookup_options_for_edns(edns: Option<&Edns>) -> LookupOptions {
    if let Some(edns) = edns {
        cfg_if! {
            if #[cfg(feature = "dnssec")] {
                let supported_algorithms = match edns.option(EdnsCode::DAU) {
                    Some(&EdnsOption::DAU(algs)) => algs,
                    _ => {
                        debug!("no DAU in request, used default SupportAlgorithms");
                        SupportedAlgorithms::default()
                    }
                };
                LookupOptions::for_dnssec(edns.dnssec_ok(), supported_algorithms)
            } else {
                LookupOptions::default()
            }
        }
    } else {
        LookupOptions::default()
    }
}
/// Runs the search for `query` against `authority` and assembles the response
/// header together with its record sections.
///
/// The header starts as a response to `request_header`, and its AA bit follows
/// the zone type. Primary and secondary zones are answered by
/// `send_authoritative_response`; forward and hint zones are answered by
/// `send_forwarded_response`.
async fn build_response(
    authority: &dyn AuthorityObject,
    request_info: RequestInfo<'_>,
    request_id: u16,
    request_header: &Header,
    query: &LowerQuery,
    edns: Option<&Edns>,
) -> (Header, LookupSections) {
    let options = lookup_options_for_edns(edns);
    if options.is_dnssec() {
        info!("request: {} lookup_options: {:?}", request_id, options);
    }

    let mut header = Header::response_from_request(request_header);
    header.set_authoritative(authority.zone_type().is_authoritative());

    debug!("performing {} on {}", query, authority.origin());
    let search = authority.search(request_info, options);

    #[allow(deprecated)]
    let sections = match authority.zone_type() {
        ZoneType::Forward | ZoneType::Hint => {
            send_forwarded_response(search, request_header, &mut header).await
        }
        ZoneType::Primary | ZoneType::Secondary | ZoneType::Master | ZoneType::Slave => {
            send_authoritative_response(
                search,
                authority,
                &mut header,
                options,
                request_id,
                query,
            )
            .await
        }
    };

    (header, sections)
}
async fn send_authoritative_response(
future: impl Future<Output = Result<Box<dyn LookupObject>, LookupError>>,
authority: &dyn AuthorityObject,
response_header: &mut Header,
lookup_options: LookupOptions,
request_id: u16,
query: &LowerQuery,
) -> LookupSections {
let answers = match future.await {
Ok(records) => {
response_header.set_response_code(ResponseCode::NoError);
response_header.set_authoritative(true);
Some(records)
}
Err(LookupError::ResponseCode(ResponseCode::Refused)) => {
response_header.set_response_code(ResponseCode::Refused);
return LookupSections {
answers: Box::<AuthLookup>::default(),
ns: Box::<AuthLookup>::default(),
soa: Box::<AuthLookup>::default(),
additionals: Box::<AuthLookup>::default(),
};
}
Err(e) => {
if e.is_nx_domain() {
response_header.set_response_code(ResponseCode::NXDomain);
} else if e.is_name_exists() {
response_header.set_response_code(ResponseCode::NoError);
};
None
}
};
let (ns, soa) = if answers.is_some() {
if query.query_type().is_soa() {
match authority.ns(lookup_options).await {
Ok(ns) => (Some(ns), None),
Err(e) => {
warn!("ns_lookup errored: {}", e);
(None, None)
}
}
} else {
(None, None)
}
} else {
let nsecs = if lookup_options.is_dnssec() {
debug!("request: {} non-existent adding nsecs", request_id);
let future = authority.get_nsec_records(query.name(), lookup_options);
match future.await {
Ok(nsecs) => Some(nsecs),
Err(e) => {
warn!("failed to lookup nsecs: {}", e);
None
}
}
} else {
None
};
match authority.soa_secure(lookup_options).await {
Ok(soa) => (nsecs, Some(soa)),
Err(e) => {
warn!("failed to lookup soa: {}", e);
(nsecs, None)
}
}
};
let (answers, additionals) = match answers {
Some(mut answers) => match answers.take_additionals() {
Some(additionals) => (answers, additionals),
None => (
answers,
Box::<AuthLookup>::default() as Box<dyn LookupObject>,
),
},
None => (
Box::<AuthLookup>::default() as Box<dyn LookupObject>,
Box::<AuthLookup>::default() as Box<dyn LookupObject>,
),
};
LookupSections {
answers,
ns: ns.unwrap_or_else(|| Box::<AuthLookup>::default()),
soa: soa.unwrap_or_else(|| Box::<AuthLookup>::default()),
additionals,
}
}
/// Awaits a forwarded lookup and records its outcome in the response header.
///
/// The response is always marked recursion-available and non-authoritative.
/// If the client did not set RD, the lookup future is dropped without being
/// awaited and no answers are returned. A failed lookup yields no answers and
/// sets `NXDomain` when the error reports a non-existent name. The `ns`, `soa`
/// and `additionals` sections are always empty.
async fn send_forwarded_response(
    future: impl Future<Output = Result<Box<dyn LookupObject>, LookupError>>,
    request_header: &Header,
    response_header: &mut Header,
) -> LookupSections {
    response_header.set_recursion_available(true);
    response_header.set_authoritative(false);

    let answers: Box<dyn LookupObject> = if request_header.recursion_desired() {
        match future.await {
            Ok(rsp) => rsp,
            Err(e) => {
                if e.is_nx_domain() {
                    response_header.set_response_code(ResponseCode::NXDomain);
                }
                debug!("error resolving: {}", e);
                Box::new(EmptyLookup)
            }
        }
    } else {
        // Recursion was not requested: abandon the un-awaited lookup.
        drop(future);
        info!(
            "request disabled recursion, returning no records: {}",
            request_header.id()
        );
        Box::new(EmptyLookup)
    };

    LookupSections {
        answers,
        ns: Box::<AuthLookup>::default(),
        soa: Box::<AuthLookup>::default(),
        additionals: Box::<AuthLookup>::default(),
    }
}
/// The four record sections of a lookup response, each as a lookup object
/// that is iterated when the message is built.
struct LookupSections {
    /// Records for the answer section.
    answers: Box<dyn LookupObject>,
    /// NS records, or NSEC records for negative DNSSEC answers.
    ns: Box<dyn LookupObject>,
    /// The zone SOA, populated for negative answers.
    soa: Box<dyn LookupObject>,
    /// Additional records taken from the answer lookup.
    additionals: Box<dyn LookupObject>,
}