use std::sync::Arc;
use std::time::{Duration, Instant};
use proto::op::Query;
use proto::rr::Record;
use config;
use error::*;
use lookup::Lookup;
use lru_cache::LruCache;
/// Default upper bound, in seconds, on how long a cache entry may live (one day).
///
/// `DnsLru::new` falls back to this value for both the positive and negative
/// maximum TTL when the corresponding `TtlConfig` field is unset.
pub const MAX_TTL: u32 = 86_400;
/// One cache slot: the (optional) cached lookup together with its expiry instant.
#[derive(Debug)]
struct LruValue {
    /// `Some` for a positive entry; `None` for a negative entry
    /// (the form `DnsLru::negative` stores).
    lookup: Option<Lookup>,
    /// Last instant at which this entry is still considered current (inclusive).
    valid_until: Instant,
}
impl LruValue {
    /// Reports whether this entry may still be served at `now`.
    ///
    /// The deadline is inclusive: an entry checked at exactly `valid_until`
    /// is still current; only strictly later instants count as expired.
    fn is_current(&self, now: Instant) -> bool {
        let expired = now > self.valid_until;
        !expired
    }
}
/// An LRU cache of DNS query results with configurable TTL bounds.
///
/// Positive entries (records found) and negative entries (no records) are
/// clamped against separate minimum/maximum TTLs.
#[derive(Debug)]
pub(crate) struct DnsLru {
    /// Backing LRU store, keyed by the query that produced each entry.
    cache: LruCache<Query, LruValue>,
    /// Lower bound applied to the TTL of positive entries.
    positive_min_ttl: Duration,
    /// Lower bound applied to the TTL of negative entries.
    negative_min_ttl: Duration,
    /// Upper bound applied to the TTL of positive entries.
    positive_max_ttl: Duration,
    /// Upper bound applied to the TTL of negative entries.
    negative_max_ttl: Duration,
}
/// Optional TTL bounds used to build a `DnsLru`.
///
/// Every field defaults to `None`, meaning "use the cache's built-in default"
/// for that bound.
#[derive(Copy, Clone, Debug, Default)]
pub(crate) struct TtlConfig {
    /// Minimum TTL for positive (records found) entries.
    pub positive_min_ttl: Option<Duration>,
    /// Minimum TTL for negative (no records) entries.
    pub negative_min_ttl: Option<Duration>,
    /// Maximum TTL for positive (records found) entries.
    pub positive_max_ttl: Option<Duration>,
    /// Maximum TTL for negative (no records) entries.
    pub negative_max_ttl: Option<Duration>,
}
impl TtlConfig {
    /// Extracts the four TTL bounds from resolver options.
    ///
    /// Values are copied verbatim; unset (`None`) options stay unset here.
    pub(crate) fn from_opts(opts: &config::ResolverOpts) -> TtlConfig {
        let positive_min_ttl = opts.positive_min_ttl;
        let negative_min_ttl = opts.negative_min_ttl;
        let positive_max_ttl = opts.positive_max_ttl;
        let negative_max_ttl = opts.negative_max_ttl;

        TtlConfig {
            positive_min_ttl,
            negative_min_ttl,
            positive_max_ttl,
            negative_max_ttl,
        }
    }
}
impl DnsLru {
    /// Creates a cache holding at most `capacity` entries.
    ///
    /// Bounds left unset in `ttl_cfg` default to zero for the minimums and to
    /// `MAX_TTL` seconds for the maximums.
    pub(crate) fn new(capacity: usize, ttl_cfg: TtlConfig) -> Self {
        let TtlConfig {
            positive_min_ttl,
            negative_min_ttl,
            positive_max_ttl,
            negative_max_ttl,
        } = ttl_cfg;
        let cache = LruCache::new(capacity);
        Self {
            cache,
            positive_min_ttl: positive_min_ttl.unwrap_or_else(|| Duration::from_secs(0)),
            negative_min_ttl: negative_min_ttl.unwrap_or_else(|| Duration::from_secs(0)),
            positive_max_ttl: positive_max_ttl
                .unwrap_or_else(|| Duration::from_secs(u64::from(MAX_TTL))),
            negative_max_ttl: negative_max_ttl
                .unwrap_or_else(|| Duration::from_secs(u64::from(MAX_TTL))),
        }
    }

    /// Clamps a positive-entry TTL into `[positive_min_ttl, positive_max_ttl]`.
    ///
    /// The cap is applied first and the floor second. If the configured minimum
    /// exceeds the maximum, the minimum wins; this precedence matches what
    /// `insert` has always produced.
    fn clamp_positive_ttl(&self, ttl: Duration) -> Duration {
        self.positive_min_ttl.max(ttl.min(self.positive_max_ttl))
    }

    /// Caches `records_and_ttl` under `query` and returns the resulting `Lookup`.
    ///
    /// The entry's lifetime is the smallest record TTL, clamped to the positive
    /// TTL bounds. An empty record set gets `positive_max_ttl` (still raised to
    /// `positive_min_ttl` if that is larger). The returned `Lookup` carries the
    /// same deadline as the cache entry.
    pub(crate) fn insert(
        &mut self,
        query: Query,
        records_and_ttl: Vec<(Record, u32)>,
        now: Instant,
    ) -> Lookup {
        let mut records: Vec<Record> = Vec::with_capacity(records_and_ttl.len());
        // Seeding with the maximum both caps the result and gives empty input a value.
        let mut min_ttl = self.positive_max_ttl;
        for (record, ttl) in records_and_ttl {
            records.push(record);
            min_ttl = min_ttl.min(Duration::from_secs(u64::from(ttl)));
        }

        let valid_until = now + self.clamp_positive_ttl(min_ttl);
        let lookup = Lookup::new_with_deadline(query.clone(), Arc::new(records), valid_until);
        self.cache.insert(
            query,
            LruValue {
                lookup: Some(lookup.clone()),
                valid_until,
            },
        );
        lookup
    }

    /// Caches an already-built `lookup` under `query` for `ttl` seconds.
    ///
    /// The TTL is clamped to the positive TTL bounds, like `insert`, so the
    /// configured minimum/maximum cannot be bypassed this way. The stored
    /// `lookup` is not modified; its own deadline is whatever it already
    /// carried. The lookup is returned unchanged.
    pub(crate) fn duplicate(
        &mut self,
        query: Query,
        lookup: Lookup,
        ttl: u32,
        now: Instant,
    ) -> Lookup {
        let ttl = self.clamp_positive_ttl(Duration::from_secs(u64::from(ttl)));
        let valid_until = now + ttl;
        self.cache.insert(
            query,
            LruValue {
                lookup: Some(lookup.clone()),
                valid_until,
            },
        );
        lookup
    }

    /// Builds a `NoRecordsFound` error for `query` with the given deadline.
    pub(crate) fn nx_error(query: Query, valid_until: Option<Instant>) -> ResolveError {
        ResolveErrorKind::NoRecordsFound { query, valid_until }.into()
    }

    /// Caches a negative (no records) entry for `query` and returns the matching error.
    ///
    /// The TTL is raised to `negative_min_ttl` and then capped at
    /// `negative_max_ttl`. If the minimum exceeds the maximum, the maximum wins.
    pub(crate) fn negative(&mut self, query: Query, ttl: u32, now: Instant) -> ResolveError {
        let ttl = Duration::from_secs(u64::from(ttl))
            .max(self.negative_min_ttl)
            .min(self.negative_max_ttl);
        let valid_until = now + ttl;
        self.cache.insert(
            query.clone(),
            LruValue {
                lookup: None,
                valid_until,
            },
        );
        Self::nx_error(query, Some(valid_until))
    }

    /// Returns the cached lookup for `query` if its entry is current at `now`.
    ///
    /// An entry is current while `now <= valid_until`. An expired entry is
    /// removed from the cache on access. This assumes `now` does not move
    /// backwards between calls.
    ///
    /// NOTE(review): a current negative entry (stored with `lookup: None` by
    /// `negative`) also yields `None` here. Callers cannot tell it apart from
    /// a cache miss through this method.
    pub(crate) fn get(&mut self, query: &Query, now: Instant) -> Option<Lookup> {
        let (lookup, expired) = match self.cache.get_mut(query) {
            Some(value) => {
                if value.is_current(now) {
                    (value.lookup.clone(), false)
                } else {
                    (None, true)
                }
            }
            None => (None, false),
        };

        if expired {
            self.cache.remove(query);
        }

        lookup
    }
}
// Unit tests for `LruValue::is_current` and the TTL clamping / expiry behavior of `DnsLru`.
#[cfg(test)]
mod tests {
use std::net::*;
use std::str::FromStr;
use std::time::*;
use proto::op::Query;
use proto::rr::{Name, RData, RecordType};
use super::*;
// `valid_until` is inclusive: the entry is current at and before `future`, stale after it.
#[test]
fn test_is_current() {
let now = Instant::now();
let not_the_future = now + Duration::from_secs(4);
let future = now + Duration::from_secs(5);
let past_the_future = now + Duration::from_secs(6);
let value = LruValue {
lookup: None,
valid_until: future,
};
assert!(value.is_current(now));
assert!(value.is_current(not_the_future));
assert!(value.is_current(future));
assert!(!value.is_current(past_the_future));
}
// A record TTL (1s) below `positive_min_ttl` (2s) is raised to the minimum;
// a TTL above it (3s) is kept as-is.
#[test]
fn test_lookup_uses_positive_min_ttl() {
let now = Instant::now();
let name = Name::from_str("www.example.com.").unwrap();
let query = Query::query(name.clone(), RecordType::A);
let ips_ttl = vec![(
Record::from_rdata(name.clone(), 1, RData::A(Ipv4Addr::new(127, 0, 0, 1))),
1,
)];
let ips = vec![RData::A(Ipv4Addr::new(127, 0, 0, 1))];
let ttls = TtlConfig {
positive_min_ttl: Some(Duration::from_secs(2)),
..Default::default()
};
let mut lru = DnsLru::new(1, ttls);
let rc_ips = lru.insert(query.clone(), ips_ttl, now);
assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
assert_eq!(rc_ips.valid_until(), now + Duration::from_secs(2));
let ips_ttl = vec![(
Record::from_rdata(name.clone(), 3, RData::A(Ipv4Addr::new(127, 0, 0, 1))),
3,
)];
let rc_ips = lru.insert(query.clone(), ips_ttl, now);
assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
assert_eq!(rc_ips.valid_until(), now + Duration::from_secs(3));
}
// Negative-entry counterpart: a 1s TTL is raised to `negative_min_ttl` (2s),
// while a 3s TTL is kept.
#[test]
fn test_error_uses_negative_min_ttl() {
let now = Instant::now();
let name = Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::A);
let ttls = TtlConfig {
negative_min_ttl: Some(Duration::from_secs(2)),
..Default::default()
};
let mut lru = DnsLru::new(1, ttls);
let nx_error = lru.negative(name.clone(), 1, now);
match nx_error.kind() {
&ResolveErrorKind::NoRecordsFound { valid_until, .. } => {
let valid_until = valid_until.expect("resolve error should have a deadline");
assert_eq!(valid_until, now + Duration::from_secs(2));
}
other => panic!("expected ResolveErrorKind::NoRecordsFound, got {:?}", other),
}
let nx_error = lru.negative(name.clone(), 3, now);
match nx_error.kind() {
&ResolveErrorKind::NoRecordsFound { valid_until, .. } => {
let valid_until = valid_until.expect("ResolveError should have a deadline");
assert_eq!(valid_until, now + Duration::from_secs(3));
}
other => panic!("expected ResolveErrorKind::NoRecordsFound, got {:?}", other),
}
}
// A record TTL (62s) above `positive_max_ttl` (60s) is capped;
// one below it (59s) is kept.
#[test]
fn test_lookup_uses_positive_max_ttl() {
let now = Instant::now();
let name = Name::from_str("www.example.com.").unwrap();
let query = Query::query(name.clone(), RecordType::A);
let ips_ttl = vec![(
Record::from_rdata(name.clone(), 62, RData::A(Ipv4Addr::new(127, 0, 0, 1))),
62,
)];
let ips = vec![RData::A(Ipv4Addr::new(127, 0, 0, 1))];
let ttls = TtlConfig {
positive_max_ttl: Some(Duration::from_secs(60)),
..Default::default()
};
let mut lru = DnsLru::new(1, ttls);
let rc_ips = lru.insert(query.clone(), ips_ttl, now);
assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
assert_eq!(rc_ips.valid_until(), now + Duration::from_secs(60));
let ips_ttl = vec![(
Record::from_rdata(name.clone(), 59, RData::A(Ipv4Addr::new(127, 0, 0, 1))),
59,
)];
let rc_ips = lru.insert(query.clone(), ips_ttl, now);
assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
assert_eq!(rc_ips.valid_until(), now + Duration::from_secs(59));
}
// Negative-entry counterpart: a 62s TTL is capped at `negative_max_ttl` (60s),
// while a 59s TTL is kept.
#[test]
fn test_error_uses_negative_max_ttl() {
let now = Instant::now();
let name = Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::A);
let ttls = TtlConfig {
negative_max_ttl: Some(Duration::from_secs(60)),
..Default::default()
};
let mut lru = DnsLru::new(1, ttls);
let nx_error = lru.negative(name.clone(), 62, now);
match nx_error.kind() {
&ResolveErrorKind::NoRecordsFound { valid_until, .. } => {
let valid_until = valid_until.expect("resolve error should have a deadline");
assert_eq!(valid_until, now + Duration::from_secs(60));
}
other => panic!("expected ResolveErrorKind::NoRecordsFound, got {:?}", other),
}
let nx_error = lru.negative(name.clone(), 59, now);
match nx_error.kind() {
&ResolveErrorKind::NoRecordsFound { valid_until, .. } => {
let valid_until = valid_until.expect("resolve error should have a deadline");
assert_eq!(valid_until, now + Duration::from_secs(59));
}
other => panic!("expected ResolveErrorKind::NoRecordsFound, got {:?}", other),
}
}
// An inserted lookup is returned by `insert` and retrievable via `get` at the same instant.
#[test]
fn test_insert() {
let now = Instant::now();
let name = Name::from_str("www.example.com.").unwrap();
let query = Query::query(name.clone(), RecordType::A);
let ips_ttl = vec![(
Record::from_rdata(name.clone(), 1, RData::A(Ipv4Addr::new(127, 0, 0, 1))),
1,
)];
let ips = vec![RData::A(Ipv4Addr::new(127, 0, 0, 1))];
let mut lru = DnsLru::new(1, TtlConfig::default());
let rc_ips = lru.insert(query.clone(), ips_ttl, now);
assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
let rc_ips = lru.get(&query, now).unwrap();
assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
}
// The smallest record TTL (1s) governs the whole entry: still current at +1s
// (inclusive deadline), gone at +2s.
#[test]
fn test_insert_ttl() {
let now = Instant::now();
let name = Name::from_str("www.example.com.").unwrap();
let query = Query::query(name.clone(), RecordType::A);
let ips_ttl = vec![
(
Record::from_rdata(name.clone(), 1, RData::A(Ipv4Addr::new(127, 0, 0, 1))),
1,
),
(
Record::from_rdata(name.clone(), 2, RData::A(Ipv4Addr::new(127, 0, 0, 2))),
2,
),
];
let ips = vec![
RData::A(Ipv4Addr::new(127, 0, 0, 1)),
RData::A(Ipv4Addr::new(127, 0, 0, 2)),
];
let mut lru = DnsLru::new(1, TtlConfig::default());
lru.insert(query.clone(), ips_ttl, now);
let rc_ips = lru.get(&query, now + Duration::from_secs(1)).unwrap();
assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
let rc_ips = lru.get(&query, now + Duration::from_secs(2));
assert!(rc_ips.is_none());
}
// `positive_min_ttl` (3s) overrides the 1s/2s record TTLs: the entry is
// current through +3s and gone at +4s.
#[test]
fn test_insert_positive_min_ttl() {
let now = Instant::now();
let name = Name::from_str("www.example.com.").unwrap();
let query = Query::query(name.clone(), RecordType::A);
let ips_ttl = vec![
(
Record::from_rdata(name.clone(), 1, RData::A(Ipv4Addr::new(127, 0, 0, 1))),
1,
),
(
Record::from_rdata(name.clone(), 2, RData::A(Ipv4Addr::new(127, 0, 0, 2))),
2,
),
];
let ips = vec![
RData::A(Ipv4Addr::new(127, 0, 0, 1)),
RData::A(Ipv4Addr::new(127, 0, 0, 2)),
];
let ttls = TtlConfig {
positive_min_ttl: Some(Duration::from_secs(3)),
..Default::default()
};
let mut lru = DnsLru::new(1, ttls);
lru.insert(query.clone(), ips_ttl, now);
let rc_ips = lru.get(&query, now + Duration::from_secs(1)).unwrap();
for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
assert_eq!(rc_ip, ip, "after 1 second");
}
let rc_ips = lru.get(&query, now + Duration::from_secs(2)).unwrap();
for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
assert_eq!(rc_ip, ip, "after 2 seconds");
}
let rc_ips = lru.get(&query, now + Duration::from_secs(3)).unwrap();
for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
assert_eq!(rc_ip, ip, "after 3 seconds");
}
let rc_ips = lru.get(&query, now + Duration::from_secs(4));
assert!(rc_ips.is_none());
}
// `positive_max_ttl` (2s) caps the 400s/500s record TTLs: the entry is
// current through +2s and gone at +3s.
#[test]
fn test_insert_positive_max_ttl() {
let now = Instant::now();
let name = Name::from_str("www.example.com.").unwrap();
let query = Query::query(name.clone(), RecordType::A);
let ips_ttl = vec![
(
Record::from_rdata(name.clone(), 400, RData::A(Ipv4Addr::new(127, 0, 0, 1))),
400,
),
(
Record::from_rdata(name.clone(), 500, RData::A(Ipv4Addr::new(127, 0, 0, 2))),
500,
),
];
let ips = vec![
RData::A(Ipv4Addr::new(127, 0, 0, 1)),
RData::A(Ipv4Addr::new(127, 0, 0, 2)),
];
let ttls = TtlConfig {
positive_max_ttl: Some(Duration::from_secs(2)),
..Default::default()
};
let mut lru = DnsLru::new(1, ttls);
lru.insert(query.clone(), ips_ttl, now);
let rc_ips = lru.get(&query, now + Duration::from_secs(1)).unwrap();
for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
assert_eq!(rc_ip, ip, "after 1 second");
}
let rc_ips = lru.get(&query, now + Duration::from_secs(2)).unwrap();
for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
assert_eq!(rc_ip, ip, "after 2 seconds");
}
let rc_ips = lru.get(&query, now + Duration::from_secs(3));
assert!(rc_ips.is_none());
}
}