use std::collections::HashMap;
use std::ops::RangeInclusive;
use std::sync::Arc;
use std::time::{Duration, Instant};
use moka::{Expiry, sync::Cache};
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer};
use crate::config;
use crate::lookup::Lookup;
#[cfg(feature = "__dnssec")]
use crate::proto::dnssec::rdata::RRSIG;
use crate::proto::op::Query;
#[cfg(feature = "__dnssec")]
use crate::proto::rr::RecordData;
use crate::proto::rr::{Record, RecordType};
use crate::proto::{ProtoError, ProtoErrorKind};
/// Ceiling, in seconds (24 hours), applied to cached TTLs when no maximum is configured.
pub(crate) const MAX_TTL: u32 = 86_400;
/// A cached answer for one query, paired with the instant at which it stops being valid.
#[derive(Clone, Debug)]
struct LruValue {
    /// The cached lookup on success, or the cached error (e.g. a negative response).
    lookup: Result<Lookup, ProtoError>,
    /// Deadline for this entry; it is current while `now <= valid_until`.
    valid_until: Instant,
}
impl LruValue {
    /// Returns `true` while this entry has not expired at `now`; the deadline itself is still
    /// considered current.
    fn is_current(&self, now: Instant) -> bool {
        now <= self.valid_until
    }

    /// Time remaining until the deadline at `now`, or zero once the deadline has passed.
    fn ttl(&self, now: Instant) -> Duration {
        self.valid_until.saturating_duration_since(now)
    }

    /// Returns a copy of this entry whose records carry the TTL remaining at `now`.
    ///
    /// A cached error is cloned unchanged, and the deadline is preserved.
    fn with_updated_ttl(&self, now: Instant) -> Self {
        // Every record gets the same remaining TTL, so compute it once. Saturate rather than
        // truncate with `as`: a large configured minimum TTL can push the deadline past what
        // fits in a `u32` of seconds, and truncation would wrap to a tiny TTL.
        let remaining = u32::try_from(self.ttl(now).as_secs()).unwrap_or(u32::MAX);
        let lookup = match &self.lookup {
            Ok(lookup) => {
                let records = lookup
                    .records()
                    .iter()
                    .map(|record| {
                        let mut record = record.clone();
                        record.set_ttl(remaining);
                        record
                    })
                    .collect::<Vec<Record>>();
                Ok(Lookup::new_with_deadline(
                    lookup.query().clone(),
                    Arc::from(records),
                    self.valid_until,
                ))
            }
            Err(e) => Err(e.clone()),
        };
        Self {
            lookup,
            valid_until: self.valid_until,
        }
    }
}
/// A bounded, TTL-aware cache of DNS answers and errors keyed by [`Query`].
#[derive(Debug, Clone)]
pub struct DnsLru {
    /// Cached answers/errors; each entry expires according to its own deadline.
    cache: Cache<Query, LruValue>,
    /// TTL bounds applied on insertion, shared between clones of the cache.
    ttl_config: Arc<TtlConfig>,
}
/// Minimum and maximum TTLs for cached positive and negative responses, with optional
/// overrides per query record type.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Deserialize),
    serde(from = "ttl_config_deserialize::TtlConfigMap")
)]
pub struct TtlConfig {
    /// Bounds used for any record type that has no entry in `by_query_type`.
    default: TtlBounds,
    /// Per-record-type bounds that take precedence over `default`.
    by_query_type: HashMap<RecordType, TtlBounds>,
}
impl TtlConfig {
    /// Builds a configuration whose default bounds are the min/max TTL options from `opts`,
    /// with no per-record-type overrides.
    pub fn from_opts(opts: &config::ResolverOpts) -> Self {
        Self::new(
            opts.positive_min_ttl,
            opts.negative_min_ttl,
            opts.positive_max_ttl,
            opts.negative_max_ttl,
        )
    }

    /// Builds a configuration with the given default bounds and no per-record-type overrides.
    pub fn new(
        positive_min_ttl: Option<Duration>,
        negative_min_ttl: Option<Duration>,
        positive_max_ttl: Option<Duration>,
        negative_max_ttl: Option<Duration>,
    ) -> Self {
        let default = TtlBounds {
            positive_min_ttl,
            negative_min_ttl,
            positive_max_ttl,
            negative_max_ttl,
        };
        Self {
            default,
            by_query_type: HashMap::new(),
        }
    }

    /// Sets (or replaces) the bounds used for queries of `query_type`, returning `self` so
    /// calls can be chained.
    pub fn with_query_type_ttl_bounds(
        &mut self,
        query_type: RecordType,
        positive_min_ttl: Option<Duration>,
        negative_min_ttl: Option<Duration>,
        positive_max_ttl: Option<Duration>,
        negative_max_ttl: Option<Duration>,
    ) -> &mut Self {
        let overrides = TtlBounds {
            positive_min_ttl,
            negative_min_ttl,
            positive_max_ttl,
            negative_max_ttl,
        };
        self.by_query_type.insert(query_type, overrides);
        self
    }

    /// Inclusive TTL range for positive responses to `query_type`.
    ///
    /// Uses the per-type override if present, otherwise the defaults; an unset minimum is
    /// zero and an unset maximum is [`MAX_TTL`] seconds.
    pub fn positive_response_ttl_bounds(&self, query_type: RecordType) -> RangeInclusive<Duration> {
        let TtlBounds {
            positive_min_ttl,
            positive_max_ttl,
            ..
        } = *self.by_query_type.get(&query_type).unwrap_or(&self.default);
        let floor = positive_min_ttl.unwrap_or(Duration::ZERO);
        let ceiling = positive_max_ttl.unwrap_or(Duration::from_secs(u64::from(MAX_TTL)));
        floor..=ceiling
    }

    /// Inclusive TTL range for negative responses to `query_type`.
    ///
    /// Uses the per-type override if present, otherwise the defaults; an unset minimum is
    /// zero and an unset maximum is [`MAX_TTL`] seconds.
    pub fn negative_response_ttl_bounds(&self, query_type: RecordType) -> RangeInclusive<Duration> {
        let TtlBounds {
            negative_min_ttl,
            negative_max_ttl,
            ..
        } = *self.by_query_type.get(&query_type).unwrap_or(&self.default);
        let floor = negative_min_ttl.unwrap_or(Duration::ZERO);
        let ceiling = negative_max_ttl.unwrap_or(Duration::from_secs(u64::from(MAX_TTL)));
        floor..=ceiling
    }
}
/// Optional TTL limits for positive and negative responses.
///
/// An unset (`None`) minimum behaves as zero and an unset maximum as `MAX_TTL` seconds when
/// `TtlConfig` resolves the effective range.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Deserialize), serde(deny_unknown_fields))]
pub struct TtlBounds {
    /// Lower bound for positive (record-bearing) responses.
    #[cfg_attr(feature = "serde", serde(default, deserialize_with = "duration_deserialize"))]
    positive_min_ttl: Option<Duration>,
    /// Lower bound for negative responses.
    #[cfg_attr(feature = "serde", serde(default, deserialize_with = "duration_deserialize"))]
    negative_min_ttl: Option<Duration>,
    /// Upper bound for positive (record-bearing) responses.
    #[cfg_attr(feature = "serde", serde(default, deserialize_with = "duration_deserialize"))]
    positive_max_ttl: Option<Duration>,
    /// Upper bound for negative responses.
    #[cfg_attr(feature = "serde", serde(default, deserialize_with = "duration_deserialize"))]
    negative_max_ttl: Option<Duration>,
}
impl DnsLru {
    /// Creates a cache holding at most `capacity` entries.
    ///
    /// `capacity` saturates at `u64::MAX` if it does not fit. Each entry expires at its own
    /// deadline via `LruValueExpiry`, and `ttl_config` bounds the TTLs of inserted entries.
    pub fn new(capacity: usize, ttl_config: TtlConfig) -> Self {
        let cache = Cache::builder()
            .max_capacity(capacity.try_into().unwrap_or(u64::MAX))
            .expire_after(LruValueExpiry)
            .build();
        Self {
            cache,
            ttl_config: Arc::new(ttl_config),
        }
    }

    /// Invalidates every cached entry.
    pub(crate) fn clear(&self) {
        self.cache.invalidate_all();
    }

    /// Caches `records_and_ttl` as the positive answer for `query` and returns the new lookup.
    ///
    /// The entry lives for the smallest TTL among the records. That TTL is first capped at the
    /// configured positive maximum and then raised to the positive minimum, so the minimum
    /// wins if the two bounds conflict. An empty record set is cached for the positive maximum.
    pub(crate) fn insert(
        &self,
        query: Query,
        records_and_ttl: Vec<(Record, u32)>,
        now: Instant,
    ) -> Lookup {
        let len = records_and_ttl.len();
        let (positive_min_ttl, positive_max_ttl) = self
            .ttl_config
            .positive_response_ttl_bounds(query.query_type())
            .into_inner();
        // Start the fold at the configured maximum so it doubles as the upper bound.
        let (records, ttl): (Vec<Record>, Duration) = records_and_ttl.into_iter().fold(
            (Vec::with_capacity(len), positive_max_ttl),
            |(mut records, mut min_ttl), (record, ttl)| {
                records.push(record);
                let ttl = Duration::from_secs(u64::from(ttl));
                min_ttl = min_ttl.min(ttl);
                (records, min_ttl)
            },
        );
        let ttl = positive_min_ttl.max(ttl);
        let valid_until = now + ttl;
        let lookup = Lookup::new_with_deadline(query.clone(), Arc::from(records), valid_until);
        self.cache.insert(
            query,
            LruValue {
                lookup: Ok(lookup.clone()),
                valid_until,
            },
        );
        lookup
    }

    /// Groups `records` into per-query record sets and caches each one with [`Self::insert`].
    ///
    /// Each record is keyed by its name, class and record type. CNAME records are keyed with
    /// `original_query`'s record type instead. With the `__dnssec` feature enabled, RRSIG
    /// records are keyed with the type they cover. Returns the lookup cached for
    /// `original_query`, or `None` if no record set matched it exactly.
    pub fn insert_records(
        &self,
        original_query: Query,
        records: impl Iterator<Item = Record>,
        now: Instant,
    ) -> Option<Lookup> {
        let records = records.fold(
            HashMap::<Query, Vec<(Record, u32)>>::new(),
            |mut map, record| {
                let rtype = match record.record_type() {
                    RecordType::CNAME => original_query.query_type(),
                    #[cfg(feature = "__dnssec")]
                    RecordType::RRSIG => match RRSIG::try_borrow(record.data()) {
                        Some(rrsig) => rrsig.type_covered(),
                        None => record.record_type(),
                    },
                    _ => record.record_type(),
                };
                let mut query = Query::query(record.name().clone(), rtype);
                query.set_query_class(record.dns_class());
                let ttl = record.ttl();
                map.entry(query).or_default().push((record, ttl));
                map
            },
        );
        let mut lookup = None;
        for (query, records_and_ttl) in records {
            let is_query = original_query == query;
            let inserted = self.insert(query, records_and_ttl, now);
            if is_query {
                lookup = Some(inserted);
            }
        }
        lookup
    }

    /// Caches `lookup` under `query` for exactly `ttl` seconds from `now` and returns it.
    ///
    /// Unlike [`Self::insert`], the configured TTL bounds are not applied.
    pub(crate) fn duplicate(&self, query: Query, lookup: Lookup, ttl: u32, now: Instant) -> Lookup {
        let ttl = Duration::from_secs(u64::from(ttl));
        let valid_until = now + ttl;
        self.cache.insert(
            query,
            LruValue {
                lookup: Ok(lookup.clone()),
                valid_until,
            },
        );
        lookup
    }

    /// If `error` is `NoRecordsFound`, overwrites its negative TTL with `new_ttl` in whole
    /// seconds, falling back to `MAX_TTL` if the value does not fit in a `u32`.
    fn nx_error_with_ttl(error: &mut ProtoError, new_ttl: Duration) {
        let ProtoError { kind, .. } = error;
        if let ProtoErrorKind::NoRecordsFound { negative_ttl, .. } = kind.as_mut() {
            *negative_ttl = Some(u32::try_from(new_ttl.as_secs()).unwrap_or(MAX_TTL));
        }
    }

    /// Caches a negative response and returns the error with its TTL adjusted.
    ///
    /// Only a `NoRecordsFound` error that carries a negative TTL is cached. That TTL is capped
    /// at the configured negative maximum and then raised to the negative minimum (the minimum
    /// wins if the bounds conflict). The returned error carries the bounded TTL. Any other
    /// error is returned unchanged and is not cached.
    pub(crate) fn negative(&self, query: Query, mut error: ProtoError, now: Instant) -> ProtoError {
        let ProtoError { kind, .. } = &error;
        if let ProtoErrorKind::NoRecordsFound {
            negative_ttl: Some(ttl),
            ..
        } = kind.as_ref()
        {
            let (negative_min_ttl, negative_max_ttl) = self
                .ttl_config
                .negative_response_ttl_bounds(query.query_type())
                .into_inner();
            // Cap at the maximum, then raise to the minimum: the same precedence `insert` uses
            // for positive responses. `Ord::clamp` would panic here whenever a configured
            // minimum exceeds the maximum (e.g. a minimum above `MAX_TTL` with no maximum set).
            let ttl_duration = Duration::from_secs(u64::from(*ttl))
                .min(negative_max_ttl)
                .max(negative_min_ttl);
            let valid_until = now + ttl_duration;
            self.cache.insert(
                query,
                LruValue {
                    lookup: Err(error.clone()),
                    valid_until,
                },
            );
            Self::nx_error_with_ttl(&mut error, ttl_duration);
        }
        error
    }

    /// Returns the cached answer for `query` if it is present and still current at `now`.
    ///
    /// Record TTLs are rewritten to the whole seconds remaining at `now`. So is the negative
    /// TTL of a cached `NoRecordsFound` error.
    pub fn get(&self, query: &Query, now: Instant) -> Option<Result<Lookup, ProtoError>> {
        let value = self.cache.get(query)?;
        if !value.is_current(now) {
            return None;
        }
        let mut result = value.with_updated_ttl(now).lookup;
        if let Err(err) = &mut result {
            Self::nx_error_with_ttl(err, value.ttl(now));
        }
        Some(result)
    }
}
/// Deserializes an optional whole number of seconds (as a `u32`) into an optional [`Duration`].
#[cfg(feature = "serde")]
fn duration_deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
    D: Deserializer<'de>,
{
    let seconds = Option::<u32>::deserialize(deserializer)?;
    Ok(seconds.map(u64::from).map(Duration::from_secs))
}
// Provides `TtlConfigMap`, the intermediate form that `TtlConfig` is deserialized from
// via its `serde(from = "ttl_config_deserialize::TtlConfigMap")` attribute.
#[cfg(feature = "serde")]
mod ttl_config_deserialize;
/// Expiration policy for the cache: an entry expires when its `valid_until` deadline is
/// reached.
struct LruValueExpiry;

impl Expiry<Query, LruValue> for LruValueExpiry {
    /// A new entry expires after the time remaining until its deadline at creation.
    fn expire_after_create(
        &self,
        _key: &Query,
        value: &LruValue,
        created_at: Instant,
    ) -> Option<Duration> {
        let remaining = value.ttl(created_at);
        Some(remaining)
    }

    /// A replaced entry's expiry is recomputed from the new value's deadline. The previous
    /// expiry is ignored.
    fn expire_after_update(
        &self,
        _key: &Query,
        value: &LruValue,
        updated_at: Instant,
        _duration_until_expiry: Option<Duration>,
    ) -> Option<Duration> {
        let remaining = value.ttl(updated_at);
        Some(remaining)
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;
    use std::time::*;

    use crate::proto::op::{Query, ResponseCode};
    use crate::proto::rr::rdata::{A, TXT};
    use crate::proto::rr::{Name, RData, RecordType};

    use super::*;

    #[test]
    fn test_is_current() {
        let now = Instant::now();
        let not_the_future = now + Duration::from_secs(4);
        let future = now + Duration::from_secs(5);
        let past_the_future = now + Duration::from_secs(6);
        let value = LruValue {
            lookup: Err(ProtoErrorKind::Message("test error").into()),
            valid_until: future,
        };
        assert!(value.is_current(now));
        assert!(value.is_current(not_the_future));
        assert!(value.is_current(future));
        assert!(!value.is_current(past_the_future));
    }

    #[test]
    fn test_lookup_uses_positive_min_ttl() {
        let now = Instant::now();
        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);
        let ips_ttl = vec![(
            Record::from_rdata(name.clone(), 1, RData::A(A::new(127, 0, 0, 1))),
            1,
        )];
        let ips = [RData::A(A::new(127, 0, 0, 1))];
        let ttls = TtlConfig {
            default: TtlBounds {
                positive_min_ttl: Some(Duration::from_secs(2)),
                ..TtlBounds::default()
            },
            ..TtlConfig::default()
        };
        let lru = DnsLru::new(1, ttls);
        let rc_ips = lru.insert(query.clone(), ips_ttl, now);
        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
        assert_eq!(rc_ips.valid_until(), now + Duration::from_secs(2));
        let ips_ttl = vec![(
            Record::from_rdata(name, 3, RData::A(A::new(127, 0, 0, 1))),
            3,
        )];
        let rc_ips = lru.insert(query, ips_ttl, now);
        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
        assert_eq!(rc_ips.valid_until(), now + Duration::from_secs(3));
    }

    #[test]
    fn test_error_uses_negative_min_ttl() {
        let now = Instant::now();
        let name = Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::A);
        let ttls = TtlConfig {
            default: TtlBounds {
                negative_min_ttl: Some(Duration::from_secs(2)),
                ..TtlBounds::default()
            },
            ..TtlConfig::default()
        };
        let lru = DnsLru::new(1, ttls);
        let err = ProtoErrorKind::NoRecordsFound {
            query: Box::new(name.clone()),
            soa: None,
            ns: None,
            negative_ttl: Some(1),
            response_code: ResponseCode::NoError,
            trusted: false,
            authorities: None,
        };
        let nx_error = lru.negative(name.clone(), err.into(), now);
        match nx_error.kind() {
            &ProtoErrorKind::NoRecordsFound { negative_ttl, .. } => {
                let negative_ttl = negative_ttl.expect("resolve error should have a deadline");
                assert_eq!(negative_ttl, 2);
            }
            other => panic!("expected ProtoErrorKind::NoRecordsFound, got {:?}", other),
        }
        let err = ProtoErrorKind::NoRecordsFound {
            query: Box::new(name.clone()),
            soa: None,
            ns: None,
            negative_ttl: Some(3),
            response_code: ResponseCode::NoError,
            trusted: false,
            authorities: None,
        };
        let nx_error = lru.negative(name, err.into(), now);
        match nx_error.kind() {
            &ProtoErrorKind::NoRecordsFound { negative_ttl, .. } => {
                let negative_ttl = negative_ttl.expect("ProtoError should have a deadline");
                assert_eq!(negative_ttl, 3);
            }
            other => panic!("expected ProtoErrorKind::NoRecordsFound, got {:?}", other),
        }
    }

    #[test]
    fn test_lookup_uses_positive_max_ttl() {
        let now = Instant::now();
        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);
        let ips_ttl = vec![(
            Record::from_rdata(name.clone(), 62, RData::A(A::new(127, 0, 0, 1))),
            62,
        )];
        let ips = [RData::A(A::new(127, 0, 0, 1))];
        let ttls = TtlConfig {
            default: TtlBounds {
                positive_max_ttl: Some(Duration::from_secs(60)),
                ..TtlBounds::default()
            },
            ..TtlConfig::default()
        };
        let lru = DnsLru::new(1, ttls);
        let rc_ips = lru.insert(query.clone(), ips_ttl, now);
        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
        assert_eq!(rc_ips.valid_until(), now + Duration::from_secs(60));
        let ips_ttl = vec![(
            Record::from_rdata(name, 59, RData::A(A::new(127, 0, 0, 1))),
            59,
        )];
        let rc_ips = lru.insert(query, ips_ttl, now);
        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
        assert_eq!(rc_ips.valid_until(), now + Duration::from_secs(59));
    }

    #[test]
    fn test_error_uses_negative_max_ttl() {
        let now = Instant::now();
        let name = Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::A);
        let ttls = TtlConfig {
            default: TtlBounds {
                negative_max_ttl: Some(Duration::from_secs(60)),
                ..TtlBounds::default()
            },
            ..TtlConfig::default()
        };
        let lru = DnsLru::new(1, ttls);
        let err: ProtoErrorKind = ProtoErrorKind::NoRecordsFound {
            query: Box::new(name.clone()),
            soa: None,
            ns: None,
            negative_ttl: Some(62),
            response_code: ResponseCode::NoError,
            trusted: false,
            authorities: None,
        };
        let nx_error = lru.negative(name.clone(), err.into(), now);
        match nx_error.kind() {
            &ProtoErrorKind::NoRecordsFound { negative_ttl, .. } => {
                let negative_ttl = negative_ttl.expect("resolve error should have a deadline");
                assert_eq!(negative_ttl, 60);
            }
            other => panic!("expected ProtoErrorKind::NoRecordsFound, got {:?}", other),
        }
        let err = ProtoErrorKind::NoRecordsFound {
            query: Box::new(name.clone()),
            soa: None,
            ns: None,
            negative_ttl: Some(59),
            response_code: ResponseCode::NoError,
            trusted: false,
            authorities: None,
        };
        let nx_error = lru.negative(name, err.into(), now);
        match nx_error.kind() {
            &ProtoErrorKind::NoRecordsFound { negative_ttl, .. } => {
                let negative_ttl = negative_ttl.expect("resolve error should have a deadline");
                assert_eq!(negative_ttl, 59);
            }
            other => panic!("expected ProtoErrorKind::NoRecordsFound, got {:?}", other),
        }
    }

    #[test]
    fn test_get_updates_negative_ttl() {
        let now = Instant::now();
        let query = Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::A);
        let lru = DnsLru::new(1, TtlConfig::default());
        let err = ProtoErrorKind::NoRecordsFound {
            query: Box::new(query.clone()),
            soa: None,
            ns: None,
            negative_ttl: Some(10),
            response_code: ResponseCode::NoError,
            trusted: false,
            authorities: None,
        };
        lru.negative(query.clone(), err.into(), now);
        // Three seconds after insertion, the cached error should report the remaining 7s.
        let cached = lru
            .get(&query, now + Duration::from_secs(3))
            .expect("negative response should be cached");
        match cached {
            Err(err) => match err.kind() {
                &ProtoErrorKind::NoRecordsFound { negative_ttl, .. } => {
                    assert_eq!(negative_ttl, Some(7));
                }
                other => panic!("expected ProtoErrorKind::NoRecordsFound, got {:?}", other),
            },
            Ok(lookup) => panic!("expected a cached error, got {:?}", lookup),
        }
    }

    #[test]
    fn test_insert() {
        let now = Instant::now();
        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);
        let ips_ttl = vec![(
            Record::from_rdata(name, 1, RData::A(A::new(127, 0, 0, 1))),
            1,
        )];
        let ips = [RData::A(A::new(127, 0, 0, 1))];
        let lru = DnsLru::new(1, TtlConfig::default());
        let rc_ips = lru.insert(query.clone(), ips_ttl, now);
        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
        let rc_ips = lru.get(&query, now).unwrap().expect("records should exist");
        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
    }

    #[test]
    fn test_update_ttl() {
        let now = Instant::now();
        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);
        let ips_ttl = vec![(
            Record::from_rdata(name, 10, RData::A(A::new(127, 0, 0, 1))),
            10,
        )];
        let ips = [RData::A(A::new(127, 0, 0, 1))];
        let lru = DnsLru::new(1, TtlConfig::default());
        let rc_ips = lru.insert(query.clone(), ips_ttl, now);
        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
        let ttl = lru
            .get(&query, now + Duration::from_secs(2))
            .unwrap()
            .expect("records should exist")
            .record_iter()
            .next()
            .unwrap()
            .ttl();
        // `now` is fixed, so the remaining TTL is exactly 10s - 2s.
        assert_eq!(ttl, 8);
    }

    #[test]
    fn test_insert_ttl() {
        let now = Instant::now();
        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);
        let ips_ttl = vec![
            (
                Record::from_rdata(name.clone(), 1, RData::A(A::new(127, 0, 0, 1))),
                1,
            ),
            (
                Record::from_rdata(name, 2, RData::A(A::new(127, 0, 0, 2))),
                2,
            ),
        ];
        let ips = [
            RData::A(A::new(127, 0, 0, 1)),
            RData::A(A::new(127, 0, 0, 2)),
        ];
        let lru = DnsLru::new(1, TtlConfig::default());
        lru.insert(query.clone(), ips_ttl, now);
        let rc_ips = lru
            .get(&query, now + Duration::from_secs(1))
            .unwrap()
            .expect("records should exist");
        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
        let rc_ips = lru.get(&query, now + Duration::from_secs(2));
        assert!(rc_ips.is_none());
    }

    #[test]
    fn test_insert_positive_min_ttl() {
        let now = Instant::now();
        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);
        let ips_ttl = vec![
            (
                Record::from_rdata(name.clone(), 1, RData::A(A::new(127, 0, 0, 1))),
                1,
            ),
            (
                Record::from_rdata(name, 2, RData::A(A::new(127, 0, 0, 2))),
                2,
            ),
        ];
        let ips = [
            RData::A(A::new(127, 0, 0, 1)),
            RData::A(A::new(127, 0, 0, 2)),
        ];
        let ttls = TtlConfig {
            default: TtlBounds {
                positive_min_ttl: Some(Duration::from_secs(3)),
                ..TtlBounds::default()
            },
            ..TtlConfig::default()
        };
        let lru = DnsLru::new(1, ttls);
        lru.insert(query.clone(), ips_ttl, now);
        let rc_ips = lru
            .get(&query, now + Duration::from_secs(1))
            .unwrap()
            .expect("records should exist");
        for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
            assert_eq!(rc_ip, ip, "after 1 second");
        }
        let rc_ips = lru
            .get(&query, now + Duration::from_secs(2))
            .unwrap()
            .expect("records should exist");
        for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
            assert_eq!(rc_ip, ip, "after 2 seconds");
        }
        let rc_ips = lru
            .get(&query, now + Duration::from_secs(3))
            .unwrap()
            .expect("records should exist");
        for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
            assert_eq!(rc_ip, ip, "after 3 seconds");
        }
        let rc_ips = lru.get(&query, now + Duration::from_secs(4));
        assert!(rc_ips.is_none());
    }

    #[test]
    fn test_insert_positive_max_ttl() {
        let now = Instant::now();
        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);
        let ips_ttl = vec![
            (
                Record::from_rdata(name.clone(), 400, RData::A(A::new(127, 0, 0, 1))),
                400,
            ),
            (
                Record::from_rdata(name, 500, RData::A(A::new(127, 0, 0, 2))),
                500,
            ),
        ];
        let ips = [
            RData::A(A::new(127, 0, 0, 1)),
            RData::A(A::new(127, 0, 0, 2)),
        ];
        let ttls = TtlConfig {
            default: TtlBounds {
                positive_max_ttl: Some(Duration::from_secs(2)),
                ..TtlBounds::default()
            },
            ..TtlConfig::default()
        };
        let lru = DnsLru::new(1, ttls);
        lru.insert(query.clone(), ips_ttl, now);
        let rc_ips = lru
            .get(&query, now + Duration::from_secs(1))
            .unwrap()
            .expect("records should exist");
        for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
            assert_eq!(rc_ip, ip, "after 1 second");
        }
        let rc_ips = lru
            .get(&query, now + Duration::from_secs(2))
            .unwrap()
            .expect("records should exist");
        for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
            assert_eq!(rc_ip, ip, "after 2 seconds");
        }
        let rc_ips = lru.get(&query, now + Duration::from_secs(3));
        assert!(rc_ips.is_none());
    }

    #[test]
    fn test_lookup_positive_min_ttl_different_query_types() {
        let now = Instant::now();
        let name = Name::from_str("www.example.com.").unwrap();
        let query_a = Query::query(name.clone(), RecordType::A);
        let query_txt = Query::query(name.clone(), RecordType::TXT);
        let rdata_a = RData::A(A::new(127, 0, 0, 1));
        let rdata_txt = RData::TXT(TXT::new(vec!["data".to_string()]));
        let records_ttl_a = vec![(Record::from_rdata(name.clone(), 1, rdata_a.clone()), 1)];
        let records_ttl_txt = vec![(Record::from_rdata(name.clone(), 1, rdata_txt.clone()), 1)];
        let mut ttl_config = TtlConfig::new(Some(Duration::from_secs(2)), None, None, None);
        ttl_config.with_query_type_ttl_bounds(
            RecordType::TXT,
            Some(Duration::from_secs(5)),
            None,
            None,
            None,
        );
        let lru = DnsLru::new(2, ttl_config);
        let rc_a = lru.insert(query_a.clone(), records_ttl_a, now);
        assert_eq!(*rc_a.iter().next().unwrap(), rdata_a);
        assert_eq!(rc_a.valid_until(), now + Duration::from_secs(2));
        let rc_txt = lru.insert(query_txt.clone(), records_ttl_txt, now);
        assert_eq!(*rc_txt.iter().next().unwrap(), rdata_txt);
        assert_eq!(rc_txt.valid_until(), now + Duration::from_secs(5));
        // The TTL used for caching is the tuple's second element, not the record's own TTL.
        let records_ttl_a = vec![(Record::from_rdata(name.clone(), 1, rdata_a.clone()), 7)];
        let records_ttl_txt = vec![(Record::from_rdata(name.clone(), 1, rdata_txt.clone()), 7)];
        let rc_a = lru.insert(query_a, records_ttl_a, now);
        assert_eq!(*rc_a.iter().next().unwrap(), rdata_a);
        assert_eq!(rc_a.valid_until(), now + Duration::from_secs(7));
        let rc_txt = lru.insert(query_txt, records_ttl_txt, now);
        assert_eq!(*rc_txt.iter().next().unwrap(), rdata_txt);
        assert_eq!(rc_txt.valid_until(), now + Duration::from_secs(7));
    }
}