use std::clone::Clone;
use std::collections::HashSet;
use std::error::Error;
use std::pin::Pin;
use std::sync::Arc;
use futures_util::future;
use futures_util::future::{Future, FutureExt, TryFutureExt};
use futures_util::stream;
use futures_util::stream::{Stream, TryStreamExt};
use tracing::{debug, trace};
use crate::op::{OpCode, Query};
use crate::rr::dnssec::rdata::{DNSSECRData, DNSKEY, SIG};
#[cfg(feature = "dnssec")]
use crate::rr::dnssec::Verifier;
use crate::rr::dnssec::{Algorithm, SupportedAlgorithms, TrustAnchor};
use crate::rr::rdata::opt::EdnsOption;
use crate::rr::{DNSClass, Name, RData, Record, RecordType};
use crate::xfer::dns_handle::DnsHandle;
use crate::xfer::{DnsRequest, DnsRequestOptions, DnsResponse, FirstAnswer};
use crate::{error::*, op::Edns};
/// A group of records that share an owner name, record type and class, bundled together
/// so they can be validated as a unit.
#[derive(Debug)]
struct Rrset {
/// Owner name shared by every record in `records`.
pub(crate) name: Name,
/// Record type shared by every record in `records`.
pub(crate) record_type: RecordType,
/// Class handed to signature verification; filled from the class of the original query.
pub(crate) record_class: DNSClass,
/// The records that make up this set.
pub(crate) records: Vec<Record>,
}
/// A [`DnsHandle`] wrapper that validates the responses of the wrapped handle with DNSSEC
/// before handing them back to the caller.
#[derive(Clone)]
#[must_use = "queries can only be sent through a DnsHandle"]
pub struct DnssecDnsHandle<H>
where
H: DnsHandle + Unpin + 'static,
{
/// The wrapped handle that requests are actually sent through.
handle: H,
/// Trusted keys; a DNSKEY whose public key bytes are found here is accepted without a DS lookup.
trust_anchor: Arc<TrustAnchor>,
/// Nesting level of this handle; incremented by `clone_with_context` and compared against
/// the request's `max_request_depth` in `send`.
request_depth: usize,
/// Initialised to 0 and copied between clones.
/// NOTE(review): not read anywhere in the visible code — presumably a minimum accepted key length; confirm.
minimum_key_len: usize,
/// Initialised to `Algorithm::RSASHA256` and copied between clones.
/// NOTE(review): not read anywhere in the visible code — presumably the weakest accepted algorithm; confirm.
minimum_algorithm: Algorithm, }
impl<H> DnssecDnsHandle<H>
where
    H: DnsHandle + Unpin + 'static,
{
    /// Wraps `handle` in a validating handle that uses the default [`TrustAnchor`].
    pub fn new(handle: H) -> Self {
        let default_anchor = TrustAnchor::default();
        Self::with_trust_anchor(handle, default_anchor)
    }

    /// Wraps `handle` in a validating handle that trusts the keys in `trust_anchor`.
    ///
    /// The new handle starts at request depth 0.
    pub fn with_trust_anchor(handle: H, trust_anchor: TrustAnchor) -> Self {
        let trust_anchor = Arc::new(trust_anchor);

        Self {
            handle,
            trust_anchor,
            request_depth: 0,
            minimum_key_len: 0,
            minimum_algorithm: Algorithm::RSASHA256,
        }
    }

    /// Produces a copy of this handle for a nested validation step.
    ///
    /// The trust anchor is shared and the request depth is one greater than in `self`.
    fn clone_with_context(&self) -> Self {
        let Self {
            handle,
            trust_anchor,
            request_depth,
            minimum_key_len,
            minimum_algorithm,
        } = self;

        Self {
            handle: handle.clone(),
            trust_anchor: Arc::clone(trust_anchor),
            request_depth: request_depth + 1,
            minimum_key_len: *minimum_key_len,
            minimum_algorithm: *minimum_algorithm,
        }
    }
}
impl<H> DnsHandle for DnssecDnsHandle<H>
where
H: DnsHandle + Sync + Unpin,
{
type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, Self::Error>> + Send>>;
type Error = <H as DnsHandle>::Error;
fn is_verifying_dnssec(&self) -> bool {
true
}
fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
let mut request = request.into();
if self.request_depth > request.options().max_request_depth {
return Box::pin(stream::once(future::err(Self::Error::from(
ProtoError::from("exceeded max validation depth"),
))));
}
if let OpCode::Query = request.op_code() {
let query = request
.queries()
.first()
.cloned()
.expect("no queries in request");
let handle: Self = self.clone_with_context();
#[cfg(feature = "dnssec")]
{
let edns = request.extensions_mut().get_or_insert_with(Edns::new);
edns.set_dnssec_ok(true);
let mut algorithms = SupportedAlgorithms::new();
#[cfg(feature = "ring")]
{
algorithms.set(Algorithm::ED25519);
}
algorithms.set(Algorithm::ECDSAP256SHA256);
algorithms.set(Algorithm::ECDSAP384SHA384);
algorithms.set(Algorithm::RSASHA256);
let dau = EdnsOption::DAU(algorithms);
let dhu = EdnsOption::DHU(algorithms);
edns.options_mut().insert(dau);
edns.options_mut().insert(dhu);
}
request.set_authentic_data(true);
request.set_checking_disabled(false);
let dns_class = request
.queries()
.first()
.map_or(DNSClass::IN, Query::query_class);
let options = *request.options();
return Box::pin(
self.handle
.send(request)
.and_then(move |message_response| {
debug!(
"validating message_response: {}, with {} trust_anchors",
message_response.id(),
handle.trust_anchor.len(),
);
verify_rrsets(handle.clone(), message_response, dns_class, options)
})
.and_then(move |verified_message| {
if verified_message.answers().is_empty() {
let soa_name = if let Some(soa_name) = verified_message
.name_servers()
.iter()
.find(|rr| rr.record_type() == RecordType::SOA)
.map(Record::name)
{
soa_name
} else {
return future::err(Self::Error::from(ProtoError::from(
"could not validate negative response missing SOA",
)));
};
let nsecs = verified_message
.name_servers()
.iter()
.filter(|rr| is_dnssec(rr, RecordType::NSEC))
.collect::<Vec<_>>();
if !verify_nsec(&query, soa_name, nsecs.as_slice()) {
return future::err(Self::Error::from(ProtoError::from(
"could not validate negative response with NSEC",
)));
}
}
future::ok(verified_message)
}),
);
}
Box::pin(self.handle.send(request))
}
}
/// Groups the records of `message_result` into RRsets and verifies each of them.
///
/// Record sets are collected from the answer and name server sections, skipping RRSIG records.
/// When `handle.request_depth` is greater than 1 only DNSKEY and DS sets are considered.
/// Each set is paired with the RRSIG records (from any section) that have the same owner name
/// and cover the same record type, then handed to `verify_rrset`. The per-set results are
/// combined by `verify_all_rrsets`.
///
/// # Errors
///
/// Returns "no results to verify" when no record set qualifies, otherwise whatever
/// `verify_all_rrsets` reports.
#[allow(clippy::type_complexity)]
async fn verify_rrsets<H, E>(
    handle: DnssecDnsHandle<H>,
    message_result: DnsResponse,
    dns_class: DNSClass,
    options: DnsRequestOptions,
) -> Result<DnsResponse, E>
where
    H: DnsHandle<Error = E> + Sync + Unpin,
    E: From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    // Unique (owner name, type) pairs that need a signature check.
    let rrset_types: HashSet<(Name, RecordType)> = message_result
        .answers()
        .iter()
        .chain(message_result.name_servers())
        .filter(|rr| {
            !is_dnssec(rr, RecordType::RRSIG)
                && (handle.request_depth <= 1
                    || is_dnssec(rr, RecordType::DNSKEY)
                    || is_dnssec(rr, RecordType::DS))
        })
        .map(|rr| (rr.name().clone(), rr.rr_type()))
        .collect();

    if rrset_types.is_empty() {
        return Err(E::from(ProtoError::from(ProtoErrorKind::Message(
            "no results to verify",
        ))));
    }

    let mut rrsets_to_verify = Vec::with_capacity(rrset_types.len());
    for (name, record_type) in rrset_types {
        let records: Vec<Record> = message_result
            .answers()
            .iter()
            .chain(message_result.name_servers())
            .chain(message_result.additionals())
            .filter(|rr| rr.rr_type() == record_type && rr.name() == &name)
            .cloned()
            .collect();

        // An RRSIG can only validate an RRset with the same owner name, so signatures that
        // belong to other names are left out instead of triggering useless key lookups.
        let rrsigs: Vec<Record> = message_result
            .answers()
            .iter()
            .chain(message_result.name_servers())
            .chain(message_result.additionals())
            .filter(|rr| is_dnssec(rr, RecordType::RRSIG))
            .filter(|rr| rr.name() == &name)
            .filter(|rr| {
                if let Some(RData::DNSSEC(DNSSECRData::SIG(ref rrsig))) = rr.data() {
                    rrsig.type_covered() == record_type
                } else {
                    false
                }
            })
            .cloned()
            .collect();

        let rrset = Rrset {
            name,
            record_type,
            record_class: dns_class,
            records,
        };

        debug!(
            "verifying: {}, record_type: {:?}, rrsigs: {}",
            rrset.name,
            record_type,
            rrsigs.len()
        );

        rrsets_to_verify
            .push(verify_rrset(handle.clone_with_context(), rrset, rrsigs, options).boxed());
    }

    verify_all_rrsets(message_result, rrsets_to_verify).await
}
/// Returns `true` when `rr` has the record type `dnssec_type` and both types are DNSSEC types.
fn is_dnssec(rr: &Record, dnssec_type: RecordType) -> bool {
    let both_are_dnssec = rr.rr_type().is_dnssec() && dnssec_type.is_dnssec();
    both_are_dnssec && rr.record_type() == dnssec_type
}
/// Drives every RRset verification future to completion and strips unverified records
/// from `message_result`.
///
/// Records whose `(name, type)` pair was verified are kept; all others are removed from the
/// message that is returned.
///
/// # Errors
///
/// Returns the last verification error when no RRset at all could be verified.
///
/// # Panics
///
/// `rrsets` must not be empty: `future::select_all` panics on an empty collection.
async fn verify_all_rrsets<F, E>(
    message_result: DnsResponse,
    rrsets: Vec<F>,
) -> Result<DnsResponse, E>
where
    F: Future<Output = Result<Rrset, E>> + Send + Unpin,
    E: From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    let mut verified_rrsets: HashSet<(Name, RecordType)> = HashSet::new();
    let mut rrsets = future::select_all(rrsets);
    let mut last_validation_err: Option<E> = None;

    // Collect results in completion order until every future has resolved.
    loop {
        let (rrset, _, remaining) = rrsets.await;
        match rrset {
            Ok(rrset) => {
                debug!(
                    "an rrset was verified: {}, {:?}",
                    rrset.name, rrset.record_type
                );
                verified_rrsets.insert((rrset.name, rrset.record_type));
            }
            Err(e) => {
                if tracing::enabled!(tracing::Level::DEBUG) {
                    // Joined in reverse order to keep the log text identical to the previous
                    // fold-based formatting; an empty question section yields an empty string
                    // instead of underflowing.
                    let query = message_result
                        .queries()
                        .iter()
                        .rev()
                        .map(|q| q.to_string())
                        .collect::<Vec<_>>()
                        .join(",");
                    debug!("an rrset failed to verify ({}): {:?}", query, e);
                }
                last_validation_err = Some(e);
            }
        };

        if remaining.is_empty() {
            break;
        }
        rrsets = future::select_all(remaining);
    }

    // Partial success is accepted; only a complete failure is reported as an error.
    if verified_rrsets.is_empty() {
        if let Some(e) = last_validation_err {
            return Err(e);
        }
    }

    let mut message_result = message_result;

    // NOTE(review): the additionals are drained into the answer section here, so the
    // `take_additionals` call further down presumably has nothing left to filter — confirm
    // whether merging the two sections is intended.
    let answers = message_result
        .take_answers()
        .into_iter()
        .chain(message_result.take_additionals().into_iter())
        .filter(|record| verified_rrsets.contains(&(record.name().clone(), record.rr_type())))
        .collect::<Vec<Record>>();

    let name_servers = message_result
        .take_name_servers()
        .into_iter()
        .filter(|record| verified_rrsets.contains(&(record.name().clone(), record.rr_type())))
        .collect::<Vec<Record>>();

    let additionals = message_result
        .take_additionals()
        .into_iter()
        .filter(|record| verified_rrsets.contains(&(record.name().clone(), record.rr_type())))
        .collect::<Vec<Record>>();

    message_result.insert_answers(answers);
    message_result.insert_name_servers(name_servers);
    message_result.insert_additionals(additionals);

    Ok(message_result)
}
/// Verifies a single RRset against its RRSIG records.
///
/// A DNSKEY set without any signature goes straight to `verify_dnskey_rrset`. Every other
/// set is first checked by `verify_default_rrset`; a DNSKEY set that passes that check is
/// then additionally run through `verify_dnskey_rrset`.
async fn verify_rrset<H, E>(
    handle: DnssecDnsHandle<H>,
    rrset: Rrset,
    rrsigs: Vec<Record>,
    options: DnsRequestOptions,
) -> Result<Rrset, E>
where
    H: DnsHandle<Error = E> + Sync + Unpin,
    E: From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    let unsigned_dnskey = matches!(rrset.record_type, RecordType::DNSKEY) && rrsigs.is_empty();
    if unsigned_dnskey {
        debug!("unsigned key: {}, {:?}", rrset.name, rrset.record_type);
        return verify_dnskey_rrset(handle.clone_with_context(), rrset, options).await;
    }

    let verified = verify_default_rrset(&handle.clone_with_context(), rrset, rrsigs, options).await?;

    if matches!(verified.record_type, RecordType::DNSKEY) {
        verify_dnskey_rrset(handle, verified, options).await
    } else {
        Ok(verified)
    }
}
async fn verify_dnskey_rrset<H, E>(
mut handle: DnssecDnsHandle<H>,
rrset: Rrset,
options: DnsRequestOptions,
) -> Result<Rrset, E>
where
H: DnsHandle<Error = E> + Sync + Unpin,
E: From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
trace!(
"dnskey validation {}, record_type: {:?}",
rrset.name,
rrset.record_type
);
{
let anchored_keys = rrset
.records
.iter()
.enumerate()
.filter(|&(_, rr)| is_dnssec(rr, RecordType::DNSKEY))
.filter_map(|(i, rr)| {
if let Some(RData::DNSSEC(DNSSECRData::DNSKEY(ref rdata))) = rr.data() {
Some((i, rdata))
} else {
None
}
})
.filter_map(|(i, rdata)| {
if handle
.trust_anchor
.contains_dnskey_bytes(rdata.public_key())
{
debug!(
"validated dnskey with trust_anchor: {}, {}",
rrset.name, rdata
);
Some(i)
} else {
None
}
})
.collect::<Vec<usize>>();
if !anchored_keys.is_empty() {
let mut rrset = rrset;
preserve(&mut rrset.records, anchored_keys);
return Ok(rrset);
}
}
let ds_message = handle
.lookup(Query::query(rrset.name.clone(), RecordType::DS), options)
.first_answer()
.await?;
let valid_keys = rrset
.records
.iter()
.enumerate()
.filter(|&(_, rr)| is_dnssec(rr, RecordType::DNSKEY))
.filter_map(|(i, rr)| {
if let Some(RData::DNSSEC(DNSSECRData::DNSKEY(ref rdata))) = rr.data() {
Some((i, rdata))
} else {
None
}
})
.filter(|&(_, key_rdata)| {
ds_message
.answers()
.iter()
.filter(|ds| is_dnssec(ds, RecordType::DS))
.filter_map(|ds| {
if let Some(RData::DNSSEC(DNSSECRData::DS(ref ds_rdata))) = ds.data() {
Some((ds.name(), ds_rdata))
} else {
None
}
})
.any(|(ds_name, ds_rdata)| {
if ds_rdata.covers(&rrset.name, key_rdata).unwrap_or(false) {
debug!(
"validated dnskey ({}, {}) with {} {}",
rrset.name, key_rdata, ds_name, ds_rdata
);
true
} else {
false
}
})
})
.map(|(i, _)| i)
.collect::<Vec<usize>>();
if !valid_keys.is_empty() {
let mut rrset = rrset;
preserve(&mut rrset.records, valid_keys);
trace!("validated dnskey: {}", rrset.name);
Ok(rrset)
} else {
Err(E::from(ProtoError::from(ProtoErrorKind::Message(
"Could not validate all DNSKEYs",
))))
}
}
/// Keeps only the elements of `vec` whose positions are listed in `indexes`, preserving
/// their relative order.
///
/// `indexes` must be sorted in ascending order. Duplicate entries and entries beyond the
/// end of `vec` are ignored. The vector is filtered in a single pass.
///
/// The `DoubleEndedIterator` bound is not needed by this implementation and is kept only
/// so the signature stays unchanged for existing callers.
fn preserve<T, I>(vec: &mut Vec<T>, indexes: I)
where
    I: IntoIterator<Item = usize>,
    <I as IntoIterator>::IntoIter: DoubleEndedIterator,
{
    let mut wanted = indexes.into_iter().peekable();
    let mut position = 0usize;

    // `retain` visits every element exactly once, in order, so `position` tracks the
    // original index of the element being inspected.
    vec.retain(|_| {
        // Discard indexes that are already behind the cursor (duplicates).
        while wanted.peek().map_or(false, |&index| index < position) {
            wanted.next();
        }
        let keep = wanted.peek() == Some(&position);
        position += 1;
        keep
    });
}
/// Checks `preserve` against a table of index lists applied to `[1, 2, 3]`.
#[test]
fn test_preserve() {
    // (indexes to keep, expected remaining elements)
    let cases: Vec<(Vec<usize>, Vec<i32>)> = vec![
        (vec![], vec![]),
        (vec![0], vec![1]),
        (vec![1], vec![2]),
        (vec![2], vec![3]),
        (vec![0, 2], vec![1, 3]),
        (vec![0, 1, 2], vec![1, 2, 3]),
    ];

    for (indexes, expected) in cases {
        let mut values = vec![1, 2, 3];
        preserve(&mut values, indexes);
        assert_eq!(values, expected);
    }
}
/// Verifies `rrset` against the signatures in `rrsigs`.
///
/// * If `rrset` is a DNSKEY set and one of the signatures names the set's own owner as signer,
///   the set is treated as self-signed: it is accepted when any signature verifies with any
///   key contained in the set itself.
/// * Otherwise, for every signature the signer's DNSKEY records are looked up through a clone
///   of `handle`, and the set is accepted as soon as one signature verifies with one of the
///   returned keys.
///
/// Records that do not carry the expected SIG or DNSKEY payload are skipped; since they
/// originate from the network they fail validation instead of aborting the task.
///
/// # Errors
///
/// * "self-signed dnskey is invalid" when a self-signed set does not verify,
/// * `RrsigsNotPresent` when there is no usable signature,
/// * the error of the last failed attempt when no signature verifies.
async fn verify_default_rrset<H, E>(
    handle: &DnssecDnsHandle<H>,
    rrset: Rrset,
    rrsigs: Vec<Record>,
    options: DnsRequestOptions,
) -> Result<Rrset, E>
where
    H: DnsHandle<Error = E> + Sync + Unpin,
    E: From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    // Shared with the per-signature lookup futures below; unwrapped again once they are done.
    let rrset = Arc::new(rrset);

    trace!(
        "default validation {}, record_type: {:?}",
        rrset.name,
        rrset.record_type
    );

    let self_signed = rrsigs
        .iter()
        .filter(|rrsig| is_dnssec(rrsig, RecordType::RRSIG))
        .any(|rrsig| match rrsig.data() {
            Some(RData::DNSSEC(DNSSECRData::SIG(ref sig))) => {
                RecordType::DNSKEY == rrset.record_type && sig.signer_name() == &rrset.name
            }
            _ => false,
        });

    if self_signed {
        // The signing key is part of the set itself, so no lookup is required.
        let verified = rrsigs
            .into_iter()
            .filter(|rrsig| is_dnssec(rrsig, RecordType::RRSIG))
            .filter_map(|rrsig| match rrsig.into_data() {
                Some(RData::DNSSEC(DNSSECRData::SIG(sig))) => Some(sig),
                _ => None,
            })
            .any(|sig| {
                rrset.records.iter().any(|r| match r.data() {
                    Some(RData::DNSSEC(DNSSECRData::DNSKEY(ref dnskey))) => {
                        verify_rrset_with_dnskey(r.name(), dnskey, &sig, &rrset).is_ok()
                    }
                    _ => false,
                })
            });

        return if verified {
            Ok(Arc::try_unwrap(rrset).expect("unable to unwrap Arc"))
        } else {
            Err(E::from(ProtoError::from(ProtoErrorKind::Message(
                "self-signed dnskey is invalid",
            ))))
        };
    }

    // One lookup-and-verify future per signature; the first success wins.
    let verifications = rrsigs
        .into_iter()
        .filter(|rrsig| is_dnssec(rrsig, RecordType::RRSIG))
        .filter_map(|rrsig| match rrsig.into_data() {
            Some(RData::DNSSEC(DNSSECRData::SIG(sig))) => Some(sig),
            _ => None,
        })
        .map(|sig| {
            let rrset = Arc::clone(&rrset);
            let mut handle = handle.clone_with_context();

            handle
                .lookup(
                    Query::query(sig.signer_name().clone(), RecordType::DNSKEY),
                    options,
                )
                .first_answer()
                .and_then(move |message| {
                    future::ready(
                        message
                            .answers()
                            .iter()
                            .filter(|r| is_dnssec(r, RecordType::DNSKEY))
                            .find(|r| match r.data() {
                                Some(RData::DNSSEC(DNSSECRData::DNSKEY(ref dnskey))) => {
                                    verify_rrset_with_dnskey(r.name(), dnskey, &sig, &rrset)
                                        .is_ok()
                                }
                                _ => false,
                            })
                            .map(|_| ())
                            .ok_or_else(|| {
                                E::from(ProtoError::from(ProtoErrorKind::Message(
                                    "validation failed",
                                )))
                            }),
                    )
                })
        })
        .collect::<Vec<_>>();

    if verifications.is_empty() {
        return Err(E::from(ProtoError::from(
            ProtoErrorKind::RrsigsNotPresent {
                name: rrset.name.clone(),
                record_type: rrset.record_type,
            },
        )));
    }

    let select = future::select_ok(verifications).map_ok(move |((), rest)| {
        // The pending futures still hold clones of the Arc; drop them before unwrapping.
        drop(rest);
        Arc::try_unwrap(rrset).expect("unable to unwrap Arc")
    });

    select.await
}
/// Checks the signature `sig` over `rrset` with the key `dnskey`.
///
/// The key is rejected up front when it is revoked, is not a zone key, or uses a different
/// algorithm than the signature. `dnskey_name` is only used for log output.
#[cfg(feature = "dnssec")]
fn verify_rrset_with_dnskey(
    dnskey_name: &Name,
    dnskey: &DNSKEY,
    sig: &SIG,
    rrset: &Rrset,
) -> ProtoResult<()> {
    if dnskey.revoke() {
        debug!("revoked");
        return Err(ProtoErrorKind::Message("revoked").into());
    }
    if !dnskey.zone_key() {
        return Err(ProtoErrorKind::Message("is not a zone key").into());
    }
    if dnskey.algorithm() != sig.algorithm() {
        return Err(ProtoErrorKind::Message("mismatched algorithm").into());
    }

    let outcome: ProtoResult<()> = dnskey
        .verify_rrsig(&rrset.name, rrset.record_class, sig, &rrset.records)
        .map_err(Into::into);

    match outcome {
        Ok(verified) => {
            debug!(
                "validated ({}, {:?}) with ({}, {})",
                rrset.name, rrset.record_type, dnskey_name, dnskey
            );
            Ok(verified)
        }
        Err(error) => {
            debug!(
                "failed validation of ({}, {:?}) with ({}, {})",
                rrset.name, rrset.record_type, dnskey_name, dnskey
            );
            Err(error)
        }
    }
}
#[cfg(not(feature = "dnssec"))]
fn verify_rrset_with_dnskey(_: &DNSKEY, _: &SIG, _: &Rrset) -> ProtoResult<()> {
Err(ProtoErrorKind::Message("openssl or ring feature(s) not enabled").into())
}
/// Checks whether the NSEC records in `nsecs` prove that `query` has no answer.
///
/// * If an NSEC record is owned by the query name itself, the proof holds exactly when the
///   queried type is absent from that record's type bit map.
/// * Otherwise the query name must fall into the span of one of the NSEC records, and so must
///   the wildcard candidate: the query name's base name if `soa_name` is a zone of it, else
///   `soa_name` itself. The second check is skipped when the candidate equals the query name.
#[allow(clippy::blocks_in_if_conditions)]
#[doc(hidden)]
pub fn verify_nsec(query: &Query, soa_name: &Name, nsecs: &[&Record]) -> bool {
    /// True when some NSEC record's span (owner name up to its next domain name) contains `name`.
    fn covered_by_any(nsecs: &[&Record], name: &Name) -> bool {
        nsecs.iter().any(|nsec| {
            name >= nsec.name()
                && match nsec
                    .data()
                    .and_then(RData::as_dnssec)
                    .and_then(DNSSECRData::as_nsec)
                {
                    // Either `name` sorts before the next name, or the next name sorts
                    // before this record's owner.
                    Some(rdata) => {
                        name < rdata.next_domain_name() || rdata.next_domain_name() < nsec.name()
                    }
                    None => false,
                }
        })
    }

    // An NSEC record at the query name itself decides the outcome directly.
    let exact_match = nsecs.iter().find(|nsec| query.name() == nsec.name());
    if let Some(nsec) = exact_match {
        return match nsec
            .data()
            .and_then(RData::as_dnssec)
            .and_then(DNSSECRData::as_nsec)
        {
            Some(rdata) => !rdata.type_bit_maps().contains(&query.query_type()),
            None => false,
        };
    }

    if !covered_by_any(nsecs, query.name()) {
        return false;
    }

    let base_name = query.name().base_name();
    let wildcard = if soa_name.zone_of(&base_name) {
        base_name
    } else {
        soa_name.clone()
    };

    // The query name was already checked above, so an identical candidate needs no second pass.
    wildcard == *query.name() || covered_by_any(nsecs, &wildcard)
}