use std::{
collections::BTreeMap,
fs,
marker::PhantomData,
ops::{Deref, DerefMut},
path::Path,
sync::Arc,
};
#[cfg(feature = "__dnssec")]
use crate::{
dnssec::NxProofKind,
net::runtime::Time,
proto::dnssec::{
DnsSecResult, DnssecSigner,
rdata::{DNSKEY, DNSSECRData},
},
zone_handler::{DnssecZoneHandler, Nsec3QueryInfo},
};
use crate::{
net::runtime::{RuntimeProvider, TokioRuntimeProvider},
proto::{
op::ResponseCode,
rr::{DNSClass, LowerName, Name, RData, Record, RecordSet, RecordType, RrKey},
serialize::txt::Parser,
},
server::{Request, RequestInfo},
zone_handler::{
AuthLookup, AxfrPolicy, AxfrRecords, LookupControlFlow, LookupError, LookupOptions,
LookupRecords, ZoneHandler, ZoneTransfer, ZoneType,
},
};
use hickory_proto::rr::TSigResponseContext;
#[cfg(feature = "__dnssec")]
use time::OffsetDateTime;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
#[cfg(feature = "__dnssec")]
use tracing::warn;
use tracing::{debug, info};
mod inner;
use inner::InnerInMemory;
/// A zone handler that keeps the whole zone in memory.
///
/// All record data lives in `inner`, behind a tokio `RwLock`. Lookups take the shared lock;
/// inserts and DNSSEC key/signing operations take the exclusive lock. `P` is the runtime
/// provider. It is only held at the type level (`PhantomData`) and is used to read the current
/// time (`P::Timer`) when DNSSEC signing is enabled.
pub struct InMemoryZoneHandler<P = TokioRuntimeProvider> {
    /// Origin (apex name) of the zone.
    origin: LowerName,
    /// DNS class of the zone; the constructors in this file always set `IN`.
    class: DNSClass,
    /// Role of the zone, reported through `ZoneHandler::zone_type`.
    zone_type: ZoneType,
    /// Policy consulted before answering zone transfer requests.
    axfr_policy: AxfrPolicy,
    /// Record store (and, with DNSSEC, the signing keys), guarded by an async RwLock.
    inner: RwLock<InnerInMemory>,
    /// Negative-answer proof style configured for this zone, if any; returned from
    /// `ZoneHandler::nx_proof_kind`.
    #[cfg(feature = "__dnssec")]
    nx_proof_kind: Option<NxProofKind>,
    /// Binds the handler to runtime provider `P` without storing a value of it.
    _phantom: PhantomData<P>,
}
impl<P: RuntimeProvider + Send + Sync> InMemoryZoneHandler<P> {
    /// Creates a zone handler for `origin`, pre-populated with `records`.
    ///
    /// The zone serial is read from the first non-RRSIG record of the SOA RRset stored at
    /// `origin`. Every non-RRSIG record in `records` is then upserted with that serial; RRSIGs
    /// from the input are not carried over. With the `__dnssec` feature, `nx_proof_kind` is
    /// stored and later reported by `ZoneHandler::nx_proof_kind`.
    ///
    /// # Errors
    ///
    /// Returns a message when no SOA record exists at `origin`, or when the inner store
    /// rejects any record.
    pub fn new(
        origin: Name,
        records: BTreeMap<RrKey, RecordSet>,
        zone_type: ZoneType,
        axfr_policy: AxfrPolicy,
        #[cfg(feature = "__dnssec")] nx_proof_kind: Option<NxProofKind>,
    ) -> Result<Self, String> {
        let mut this = Self::empty(
            origin.clone(),
            zone_type,
            axfr_policy,
            #[cfg(feature = "__dnssec")]
            nx_proof_kind,
        );
        let inner = this.inner.get_mut();

        // The apex SOA is mandatory: its serial stamps every record inserted below.
        let soa = records
            .get(&RrKey::new(origin.clone().into(), RecordType::SOA))
            .and_then(
                |rrset| match rrset.records_without_rrsigs().next()?.data() {
                    RData::SOA(soa) => Some(soa),
                    _ => None,
                },
            )
            .ok_or_else(|| format!("SOA record must be present: {origin}"))?;
        let serial = soa.serial;

        for rrset in records.into_values() {
            let name = rrset.name().clone();
            let rr_type = rrset.record_type();
            for record in rrset.records_without_rrsigs() {
                if !inner.upsert(record.clone(), serial, this.class) {
                    return Err(format!(
                        "Failed to insert {name} {rr_type} to zone: {origin}"
                    ));
                }
            }
        }

        Ok(this)
    }

    /// Creates a zone handler for `origin` with no records.
    ///
    /// The DNS class is always [`DNSClass::IN`].
    pub fn empty(
        origin: Name,
        zone_type: ZoneType,
        axfr_policy: AxfrPolicy,
        #[cfg(feature = "__dnssec")] nx_proof_kind: Option<NxProofKind>,
    ) -> Self {
        Self {
            origin: LowerName::new(&origin),
            class: DNSClass::IN,
            zone_type,
            axfr_policy,
            inner: RwLock::new(InnerInMemory::default()),
            #[cfg(feature = "__dnssec")]
            nx_proof_kind,
            _phantom: PhantomData,
        }
    }

    /// Returns the DNS class of this zone.
    pub fn class(&self) -> DNSClass {
        self.class
    }

    /// Replaces the zone transfer policy. Only available in tests or with the `testing` feature.
    #[cfg(any(test, feature = "testing"))]
    pub fn set_axfr_policy(&mut self, policy: AxfrPolicy) {
        self.axfr_policy = policy;
    }

    /// Removes every record from the zone, including the SOA.
    ///
    /// DNSSEC signing keys held by the handler are not touched.
    pub fn clear(&mut self) {
        self.inner.get_mut().records.clear()
    }

    /// Returns a view of the zone's DNSSEC signers (`__dnssec` + `testing` builds only).
    ///
    /// The view holds a *shared* read lock until it is dropped. Only `Deref` access is exposed,
    /// so the exclusive write lock is unnecessary, and taking it would block every concurrent
    /// lookup for the lifetime of the view.
    #[cfg(all(feature = "__dnssec", feature = "testing"))]
    pub async fn secure_keys(&self) -> impl Deref<Target = [DnssecSigner]> + '_ {
        RwLockReadGuard::map(self.inner.read().await, |i| i.secure_keys.as_slice())
    }

    /// Returns a read-locked view of all RRsets, keyed by (name, record type).
    pub async fn records(&self) -> impl Deref<Target = BTreeMap<RrKey, Arc<RecordSet>>> + '_ {
        RwLockReadGuard::map(self.inner.read().await, |i| &i.records)
    }

    /// Returns a write-locked, mutable view of all RRsets.
    pub async fn records_mut(
        &self,
    ) -> impl DerefMut<Target = BTreeMap<RrKey, Arc<RecordSet>>> + '_ {
        RwLockWriteGuard::map(self.inner.write().await, |i| &mut i.records)
    }

    /// Mutable access to all RRsets through an exclusive borrow, without locking.
    pub fn records_get_mut(&mut self) -> &mut BTreeMap<RrKey, Arc<RecordSet>> {
        &mut self.inner.get_mut().records
    }

    /// Returns the zone's minimum TTL, as computed by the inner store for the origin.
    pub async fn minimum_ttl(&self) -> u32 {
        self.inner.read().await.minimum_ttl(self.origin())
    }

    /// Returns the zone's current serial, as reported by the inner store for the origin.
    pub async fn serial(&self) -> u32 {
        self.inner.read().await.serial(self.origin())
    }

    /// Increments the SOA serial under the write lock and returns the value reported by the
    /// inner store (presumably the new serial).
    #[cfg(feature = "sqlite")]
    pub(crate) async fn increment_soa_serial(&self) -> u32 {
        self.inner
            .write()
            .await
            .increment_soa_serial(self.origin(), self.class)
    }

    /// Inserts or updates `record` using `serial`.
    ///
    /// Returns `false` when the inner store rejects the record.
    pub async fn upsert(&self, record: Record, serial: u32) -> bool {
        self.inner.write().await.upsert(record, serial, self.class)
    }

    /// Same as [`Self::upsert`], but through an exclusive borrow instead of the lock.
    pub fn upsert_mut(&mut self, record: Record, serial: u32) -> bool {
        self.inner.get_mut().upsert(record, serial, self.class)
    }

    /// Publishes `signer`'s public key as a DNSKEY at `origin` and registers the signer.
    ///
    /// The DNSKEY uses the zone's minimum TTL and current serial. Note that the boolean
    /// result of the upsert is not checked: the signer is registered even if the DNSKEY
    /// record was rejected.
    #[cfg(feature = "__dnssec")]
    fn inner_add_zone_signing_key(
        inner: &mut InnerInMemory,
        signer: DnssecSigner,
        origin: &LowerName,
        dns_class: DNSClass,
    ) -> DnsSecResult<()> {
        let zone_ttl = inner.minimum_ttl(origin);
        let dnskey = DNSKEY::from_key(&signer.key().to_public_key()?);
        let dnskey = Record::from_rdata(
            origin.clone().into(),
            zone_ttl,
            RData::DNSSEC(DNSSECRData::DNSKEY(dnskey)),
        );

        let serial = inner.serial(origin);
        inner.upsert(dnskey, serial, dns_class);
        inner.secure_keys.push(signer);
        Ok(())
    }

    /// Lock-free variant of `DnssecZoneHandler::add_zone_signing_key` for exclusive owners.
    #[cfg(feature = "__dnssec")]
    pub fn add_zone_signing_key_mut(&mut self, signer: DnssecSigner) -> DnsSecResult<()> {
        // Destructure to borrow the fields disjointly.
        let Self {
            origin,
            inner,
            class,
            ..
        } = self;
        Self::inner_add_zone_signing_key(inner.get_mut(), signer, origin, *class)
    }

    /// Signs the zone using the current time, through an exclusive borrow instead of the lock.
    #[cfg(feature = "__dnssec")]
    pub fn secure_zone_mut(&mut self) -> DnsSecResult<()> {
        let Self { origin, inner, .. } = self;
        inner.get_mut().secure_zone_mut(
            origin,
            self.class,
            self.nx_proof_kind.as_ref(),
            Self::current_time()?,
        )
    }

    /// Always fails: DNSSEC support was not compiled in.
    #[cfg(not(feature = "__dnssec"))]
    pub fn secure_zone_mut(&mut self) -> Result<(), &str> {
        Err("DNSSEC was not enabled during compilation.")
    }

    /// Reads the current Unix time from the runtime's timer (`P::Timer`).
    ///
    /// # Errors
    ///
    /// Fails when the timestamp does not fit in an `i64`, or is outside the range
    /// `OffsetDateTime` accepts.
    #[cfg(feature = "__dnssec")]
    fn current_time() -> DnsSecResult<OffsetDateTime> {
        let timestamp_unsigned = P::Timer::current_time();
        let timestamp_signed = timestamp_unsigned
            .try_into()
            .map_err(|_| "current time is out of range")?;
        OffsetDateTime::from_unix_timestamp(timestamp_signed)
            .map_err(|_| "current time is out of range".into())
    }
}
#[async_trait::async_trait]
impl<P: RuntimeProvider + Send + Sync> ZoneHandler for InMemoryZoneHandler<P> {
    /// Returns the role this zone was configured with.
    fn zone_type(&self) -> ZoneType {
        self.zone_type
    }

    /// Returns the policy applied to zone transfer requests.
    fn axfr_policy(&self) -> AxfrPolicy {
        self.axfr_policy
    }

    /// Returns the origin (apex name) of this zone.
    fn origin(&self) -> &LowerName {
        &self.origin
    }

    /// Looks up the RRset for `name`/`query_type`, plus any additional records it implies.
    ///
    /// - `AXFR` is rejected with `Break`; transfers go through `ZoneHandler::zone_transfer`.
    /// - `ANY` is narrowed to a concrete type by the inner store before the lookup.
    /// - When the answer is a CNAME/ANAME/NS/MX/SRV (see `maybe_next_name`), its target is
    ///   chased to fill the additional section.
    /// - For `A`/`AAAA` queries answered by an ANAME, the answer is replaced by an RRset of the
    ///   queried type. That RRset holds the resolved addresses, with TTL = min(ANAME TTL,
    ///   address TTL). The ANAME RRset moves to the front of the additional section. With
    ///   `__dnssec` and `dnssec_ok`, the synthesized RRset is signed; signing failures are
    ///   only logged.
    ///
    /// When nothing matches, returns `Continue(Err(..))`:
    /// - `NameExists` if any record lives at or below `name`;
    /// - otherwise `NXDomain` for names inside the zone, and `Refused` for names outside it.
    async fn lookup(
        &self,
        name: &LowerName,
        mut query_type: RecordType,
        _request_info: Option<&RequestInfo<'_>>,
        lookup_options: LookupOptions,
    ) -> LookupControlFlow<AuthLookup> {
        use LookupControlFlow::*;

        // Zone transfers are served by `zone_transfer()`; reject them before taking the lock.
        if query_type == RecordType::AXFR {
            return Break(Err(LookupError::NetError(
                "AXFR must be handled with ZoneHandler::zone_transfer()".into(),
            )));
        }

        let inner = self.inner.read().await;

        if query_type == RecordType::ANY {
            query_type = inner.replace_any(name);
        }

        let answer = inner.inner_lookup(name, query_type, lookup_options);

        // Chase the answer's target for the additional section. Remember which record type
        // triggered the chase so that ANAME answers can be rewritten below.
        let additionals_root_chain_type: Option<(_, _)> = answer
            .as_ref()
            .and_then(|a| maybe_next_name(a, query_type))
            .and_then(|(search_name, search_type)| {
                inner
                    .additional_search(name, query_type, search_name, search_type, lookup_options)
                    .map(|adds| (adds, search_type))
            });

        let (additionals, answer) = match (additionals_root_chain_type, answer, query_type) {
            (Some((additionals, RecordType::ANAME)), Some(answer), RecordType::A)
            | (Some((additionals, RecordType::ANAME)), Some(answer), RecordType::AAAA) => {
                debug_assert_eq!(answer.record_type(), RecordType::ANAME);

                // The last additional RRset is the end of the chase; only A/AAAA data is used.
                let (rdatas, a_aaaa_ttl) = {
                    let last_record = additionals.last();
                    let a_aaaa_ttl = last_record.map_or(u32::MAX, |r| r.ttl());
                    let rdatas: Option<Vec<RData>> = last_record
                        .and_then(|record| match record.record_type() {
                            RecordType::A | RecordType::AAAA => {
                                Some(record.records_without_rrsigs())
                            }
                            _ => None,
                        })
                        .map(|records| records.map(Record::data).cloned().collect::<Vec<_>>());
                    (rdatas, a_aaaa_ttl)
                };

                let ttl = answer.ttl().min(a_aaaa_ttl);
                let mut new_answer = RecordSet::new(answer.name().clone(), query_type, ttl);
                for rdata in rdatas.into_iter().flatten() {
                    new_answer.add_rdata(rdata);
                }

                #[cfg(feature = "__dnssec")]
                if lookup_options.dnssec_ok {
                    let result = Self::current_time().and_then(|time| {
                        InnerInMemory::sign_rrset(
                            &mut new_answer,
                            &inner.secure_keys,
                            self.class(),
                            time,
                        )
                    });
                    if let Err(error) = result {
                        warn!(%error, "failed to sign ANAME record")
                    }
                }

                let additionals = std::iter::once(answer).chain(additionals).collect();
                (Some(additionals), Some(Arc::new(new_answer)))
            }
            (Some((additionals, _)), answer, _) => (Some(additionals), answer),
            (None, answer, _) => (None, answer),
        };

        let answers = match answer {
            Some(rr_set) => LookupRecords::new(lookup_options, rr_set),
            None => {
                return Continue(Err(
                    // A record at or below `name` means the name exists but has no data
                    // of this type.
                    if inner
                        .records
                        .keys()
                        .any(|key| key.name() == name || name.zone_of(key.name()))
                    {
                        LookupError::NameExists
                    } else {
                        LookupError::from(match self.origin().zone_of(name) {
                            true => ResponseCode::NXDomain,
                            false => ResponseCode::Refused,
                        })
                    },
                ));
            }
        };

        Continue(Ok(AuthLookup::answers(
            answers,
            additionals.map(|a| LookupRecords::many(lookup_options, a)),
        )))
    }

    /// Answers the query carried by `request`.
    ///
    /// SOA queries are always answered from the zone origin, whatever the query name. AXFR is
    /// rejected (use `zone_transfer`). Every other type goes through `lookup`. No TSIG response
    /// context is ever produced.
    async fn search(
        &self,
        request: &Request,
        lookup_options: LookupOptions,
    ) -> (LookupControlFlow<AuthLookup>, Option<TSigResponseContext>) {
        let request_info = match request.request_info() {
            Ok(info) => info,
            Err(e) => return (LookupControlFlow::Break(Err(e)), None),
        };
        debug!("searching InMemoryZoneHandler for: {}", request_info.query);

        let lookup_name = request_info.query.name();
        let record_type: RecordType = request_info.query.query_type();

        match record_type {
            RecordType::SOA => (
                self.lookup(
                    self.origin(),
                    record_type,
                    Some(&request_info),
                    lookup_options,
                )
                .await,
                None,
            ),
            RecordType::AXFR => (
                LookupControlFlow::Break(Err(LookupError::NetError(
                    "AXFR must be handled with ZoneHandler::zone_transfer()".into(),
                ))),
                None,
            ),
            _ => (
                self.lookup(
                    lookup_name,
                    record_type,
                    Some(&request_info),
                    lookup_options,
                )
                .await,
                None,
            ),
        }
    }

    /// Produces a full copy of the zone for a transfer request.
    ///
    /// The request is refused unless the zone's policy is `AxfrPolicy::AllowAll`. On success,
    /// the result contains the leading SOA (looked up with the caller's `lookup_options`), every
    /// RRset in the zone, and a trailing SOA. The trailing SOA is looked up with default
    /// options, presumably so that RRSIGs accompany only the leading copy.
    async fn zone_transfer(
        &self,
        request: &Request,
        lookup_options: LookupOptions,
        _now: u64,
    ) -> Option<(
        Result<ZoneTransfer, LookupError>,
        Option<TSigResponseContext>,
    )> {
        let request_info = match request.request_info() {
            Ok(info) => info,
            Err(e) => return Some((Err(e), None)),
        };

        // The response below is always the whole zone, whatever transfer type was requested.
        // The policy therefore has to gate every request, not only queries typed exactly AXFR;
        // otherwise a `Deny` policy could be bypassed.
        if !matches!(self.axfr_policy, AxfrPolicy::AllowAll) {
            debug!("zone transfer refused by policy: {}", request_info.query);
            return Some((Err(LookupError::from(ResponseCode::Refused)), None));
        }

        let future = self.lookup(self.origin(), RecordType::SOA, None, lookup_options);
        let start_soa = if let LookupControlFlow::Continue(Ok(res)) = future.await {
            res.unwrap_records()
        } else {
            LookupRecords::Empty
        };

        let future = self.lookup(
            self.origin(),
            RecordType::SOA,
            None,
            LookupOptions::default(),
        );
        let end_soa = if let LookupControlFlow::Continue(Ok(res)) = future.await {
            res.unwrap_records()
        } else {
            LookupRecords::Empty
        };

        let records = AxfrRecords::new(
            lookup_options.dnssec_ok,
            self.inner.read().await.records.values().cloned().collect(),
        );

        Some((
            Ok(ZoneTransfer {
                start_soa,
                records,
                end_soa,
            }),
            None,
        ))
    }

    /// Returns NSEC records proving the (non-)existence of `name`.
    ///
    /// If an NSEC record exists exactly at `name` (a NODATA proof), only that is returned.
    /// Otherwise, the result holds the closest NSEC for `name`, plus the closest NSEC for the
    /// wildcard base `name.base_name()` (clamped to the origin when it falls outside the
    /// zone). The two are deduplicated, and both are chosen by the inner store's
    /// `closest_nsec`.
    #[cfg(feature = "__dnssec")]
    async fn nsec_records(
        &self,
        name: &LowerName,
        lookup_options: LookupOptions,
    ) -> LookupControlFlow<AuthLookup> {
        let inner = self.inner.read().await;
        let rr_key = RrKey::new(name.clone(), RecordType::NSEC);
        let no_data = inner
            .records
            .get(&rr_key)
            .map(|rr_set| LookupRecords::new(lookup_options, rr_set.clone()));

        if let Some(no_data) = no_data {
            return LookupControlFlow::Continue(Ok(no_data.into()));
        }

        let closest_proof = inner.closest_nsec(name);

        // The wildcard proof must stay inside the zone.
        let wildcard = name.base_name();
        let origin = self.origin();
        let wildcard = if origin.zone_of(&wildcard) {
            wildcard
        } else {
            origin.clone()
        };

        // Skip the wildcard proof when it would be the same lookup as the closest proof.
        let wildcard_proof = if wildcard != *name {
            inner.closest_nsec(&wildcard)
        } else {
            None
        };

        let proofs = match (closest_proof, wildcard_proof) {
            (Some(closest_proof), Some(wildcard_proof)) => {
                if wildcard_proof != closest_proof {
                    vec![wildcard_proof, closest_proof]
                } else {
                    vec![closest_proof]
                }
            }
            (None, Some(proof)) | (Some(proof), None) => vec![proof],
            (None, None) => vec![],
        };

        LookupControlFlow::Continue(Ok(LookupRecords::many(lookup_options, proofs).into()))
    }

    /// Without DNSSEC support, NSEC lookups always return an empty answer.
    #[cfg(not(feature = "__dnssec"))]
    async fn nsec_records(
        &self,
        _name: &LowerName,
        _lookup_options: LookupOptions,
    ) -> LookupControlFlow<AuthLookup> {
        LookupControlFlow::Continue(Ok(AuthLookup::default()))
    }

    /// Delegates NSEC3 proof construction to the inner store and wraps the records it returns.
    #[cfg(feature = "__dnssec")]
    async fn nsec3_records(
        &self,
        info: Nsec3QueryInfo<'_>,
        lookup_options: LookupOptions,
    ) -> LookupControlFlow<AuthLookup> {
        let inner = self.inner.read().await;
        LookupControlFlow::Continue(
            inner
                .proof(info, self.origin())
                .map(|proof| LookupRecords::many(lookup_options, proof).into()),
        )
    }

    /// Returns the negative-answer proof style configured for this zone, if any.
    #[cfg(feature = "__dnssec")]
    fn nx_proof_kind(&self) -> Option<&NxProofKind> {
        self.nx_proof_kind.as_ref()
    }

    /// Label used for this handler in metrics.
    #[cfg(feature = "metrics")]
    fn metrics_label(&self) -> &'static str {
        "in-memory"
    }
}
#[cfg(feature = "__dnssec")]
#[async_trait::async_trait]
impl<P: RuntimeProvider + Send + Sync> DnssecZoneHandler for InMemoryZoneHandler<P> {
    /// Publishes `signer`'s DNSKEY at the zone origin and registers the signer.
    ///
    /// The zone's write lock is held for the whole operation.
    async fn add_zone_signing_key(&self, signer: DnssecSigner) -> DnsSecResult<()> {
        let mut guard = self.inner.write().await;
        let (zone_origin, zone_class) = (self.origin(), self.class);
        Self::inner_add_zone_signing_key(&mut guard, signer, zone_origin, zone_class)
    }

    /// Signs the zone under the write lock, using the current time from the runtime timer.
    async fn secure_zone(&self) -> DnsSecResult<()> {
        let mut guard = self.inner.write().await;
        let zone_origin = self.origin();
        let proof_kind = self.nx_proof_kind.as_ref();
        let now = Self::current_time()?;
        guard.secure_zone_mut(zone_origin, self.class, proof_kind, now)
    }
}
/// Returns the target name to chase for the additional section, together with the record type
/// that triggered the chase.
///
/// A chase happens for:
/// - CNAME, with any query type;
/// - ANAME, when the query was for A, AAAA or ANAME;
/// - NS, MX and SRV, when the query asked for that same type.
///
/// Only the first non-RRSIG record of the RRset is inspected. Returns `None` when no chase
/// applies, when the RRset has no such record, or when its data does not match its type.
fn maybe_next_name(
    record_set: &RecordSet,
    query_type: RecordType,
) -> Option<(LowerName, RecordType)> {
    let rr_type = record_set.record_type();

    let follows = match rr_type {
        RecordType::ANAME => matches!(
            query_type,
            RecordType::A | RecordType::AAAA | RecordType::ANAME
        ),
        RecordType::CNAME => true,
        RecordType::NS | RecordType::MX | RecordType::SRV => query_type == rr_type,
        _ => false,
    };
    if !follows {
        return None;
    }

    let first = record_set.records_without_rrsigs().next()?;
    let target = match (first.data(), rr_type) {
        (RData::ANAME(name), RecordType::ANAME) => name,
        (RData::NS(ns), RecordType::NS) => &ns.0,
        (RData::CNAME(name), RecordType::CNAME) => name,
        (RData::MX(mx), RecordType::MX) => &mx.exchange,
        (RData::SRV(srv), RecordType::SRV) => &srv.target,
        _ => return None,
    };

    Some((LowerName::from(target), rr_type))
}
/// Reads and parses the zone file at `zone_path`, using `origin` as the initial origin.
///
/// The file is read synchronously (`std::fs`).
///
/// # Errors
///
/// Returns a message naming the file when it cannot be read or fails to parse.
pub(crate) fn zone_from_path(
    zone_path: &Path,
    origin: Name,
) -> Result<BTreeMap<RrKey, RecordSet>, String> {
    info!("loading zone file: {zone_path:?}");

    let buf = match fs::read_to_string(zone_path) {
        Ok(text) => text,
        Err(e) => return Err(format!("failed to read {}: {e:?}", zone_path.display())),
    };

    let (origin, records) =
        match Parser::new(buf, Some(zone_path.to_owned()), Some(origin)).parse() {
            Ok(parsed) => parsed,
            Err(e) => return Err(format!("failed to parse {}: {e:?}", zone_path.display())),
        };

    info!("zone file loaded: {origin} with {} records", records.len());
    debug!("zone: {records:#?}");
    Ok(records)
}