hickory_resolver/
dns_lru.rs

1// Copyright 2015-2017 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! An LRU cache designed for work with DNS lookups
9
10use std::collections::HashMap;
11use std::ops::RangeInclusive;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14
15use moka::{Expiry, sync::Cache};
16#[cfg(feature = "serde")]
17use serde::{Deserialize, Deserializer};
18
19use crate::config;
20use crate::lookup::Lookup;
21#[cfg(feature = "__dnssec")]
22use crate::proto::dnssec::rdata::RRSIG;
23use crate::proto::op::Query;
24#[cfg(feature = "__dnssec")]
25use crate::proto::rr::RecordData;
26use crate::proto::rr::{Record, RecordType};
27use crate::proto::{ProtoError, ProtoErrorKind};
28
/// Upper bound applied to TTLs when no maximum is configured: one day, expressed in seconds.
///
/// [RFC 2181, section 8](https://tools.ietf.org/html/rfc2181#section-8) allows TTL values up to
/// 2147483647, but implementations may place an upper bound on received TTLs.
pub(crate) const MAX_TTL: u32 = 24 * 60 * 60;
35
/// A single cache entry: the outcome of a lookup together with the instant at which that
/// outcome stops being valid.
#[derive(Debug, Clone)]
struct LruValue {
    /// The cached outcome. In the `Err` case, this represents an NXDomain.
    lookup: Result<Lookup, ProtoError>,
    /// The deadline for this entry; it is considered current up to and including this instant.
    valid_until: Instant,
}
42
43impl LruValue {
44    /// Returns true if this set of ips is still valid
45    fn is_current(&self, now: Instant) -> bool {
46        now <= self.valid_until
47    }
48
49    /// Returns the ttl as a Duration of time remaining.
50    fn ttl(&self, now: Instant) -> Duration {
51        self.valid_until.saturating_duration_since(now)
52    }
53
54    fn with_updated_ttl(&self, now: Instant) -> Self {
55        let lookup = match &self.lookup {
56            Ok(lookup) => {
57                let records = lookup
58                    .records()
59                    .iter()
60                    .map(|record| {
61                        let mut record = record.clone();
62                        record.set_ttl(self.ttl(now).as_secs() as u32);
63                        record
64                    })
65                    .collect::<Vec<Record>>();
66                Ok(Lookup::new_with_deadline(
67                    lookup.query().clone(),
68                    Arc::from(records),
69                    self.valid_until,
70                ))
71            }
72            Err(e) => Err(e.clone()),
73        };
74        Self {
75            lookup,
76            valid_until: self.valid_until,
77        }
78    }
79}
80
/// A cache specifically for storing DNS records.
///
/// This is named `DnsLru` for historical reasons. It currently uses a "TinyLFU" policy, implemented
/// in the `moka` library.
#[derive(Clone, Debug)]
pub struct DnsLru {
    /// Cached lookup outcomes (positive or negative), keyed by query.
    cache: Cache<Query, LruValue>,
    /// Minimum and maximum TTLs applied when entries are inserted.
    ttl_config: Arc<TtlConfig>,
}
90
/// The time-to-live (TTL) configuration used by the cache.
///
/// Minimum and maximum TTLs can be set for both positive responses and negative responses. Separate
/// limits may be set depending on the query type.
///
/// Note that TTLs in DNS are represented as a number of seconds stored in a 32-bit unsigned
/// integer. We use `Duration` here, instead of `u32`, which can express larger values than the DNS
/// standard. Generally, a `Duration` greater than `u32::MAX` seconds shouldn't cause any issue, as
/// this will never be used in serialization, but note that this would be outside the standard
/// range.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(from = "ttl_config_deserialize::TtlConfigMap")
)]
pub struct TtlConfig {
    /// TTL limits applied to every query whose type has no entry in `by_query_type`.
    default: TtlBounds,

    /// TTL limits applied to queries with specific query types, replacing `default` for them.
    by_query_type: HashMap<RecordType, TtlBounds>,
}
114
115impl TtlConfig {
116    /// Construct the LRU's TTL configuration based on the ResolverOpts configuration.
117    pub fn from_opts(opts: &config::ResolverOpts) -> Self {
118        Self {
119            default: TtlBounds {
120                positive_min_ttl: opts.positive_min_ttl,
121                negative_min_ttl: opts.negative_min_ttl,
122                positive_max_ttl: opts.positive_max_ttl,
123                negative_max_ttl: opts.negative_max_ttl,
124            },
125            by_query_type: HashMap::new(),
126        }
127    }
128
129    /// Creates a new cache TTL configuration.
130    ///
131    /// The provided minimum and maximum TTLs will be applied to all queries unless otherwise
132    /// specified via [`Self::with_query_type_ttl_bounds`].
133    ///
134    /// If a minimum value is not provided, it will default to 0 seconds. If a maximum value is not
135    /// provided, it will default to one day.
136    pub fn new(
137        positive_min_ttl: Option<Duration>,
138        negative_min_ttl: Option<Duration>,
139        positive_max_ttl: Option<Duration>,
140        negative_max_ttl: Option<Duration>,
141    ) -> Self {
142        Self {
143            default: TtlBounds {
144                positive_min_ttl,
145                negative_min_ttl,
146                positive_max_ttl,
147                negative_max_ttl,
148            },
149            by_query_type: HashMap::new(),
150        }
151    }
152
153    /// Override the minimum and maximum TTL values for a specific query type.
154    ///
155    /// If a minimum value is not provided, it will default to 0 seconds. If a maximum value is not
156    /// provided, it will default to one day.
157    pub fn with_query_type_ttl_bounds(
158        &mut self,
159        query_type: RecordType,
160        positive_min_ttl: Option<Duration>,
161        negative_min_ttl: Option<Duration>,
162        positive_max_ttl: Option<Duration>,
163        negative_max_ttl: Option<Duration>,
164    ) -> &mut Self {
165        self.by_query_type.insert(
166            query_type,
167            TtlBounds {
168                positive_min_ttl,
169                negative_min_ttl,
170                positive_max_ttl,
171                negative_max_ttl,
172            },
173        );
174        self
175    }
176
177    /// Retrieves the minimum and maximum TTL values for positive responses.
178    pub fn positive_response_ttl_bounds(&self, query_type: RecordType) -> RangeInclusive<Duration> {
179        let bounds = self.by_query_type.get(&query_type).unwrap_or(&self.default);
180        let min = bounds
181            .positive_min_ttl
182            .unwrap_or_else(|| Duration::from_secs(0));
183        let max = bounds
184            .positive_max_ttl
185            .unwrap_or_else(|| Duration::from_secs(u64::from(MAX_TTL)));
186        min..=max
187    }
188
189    /// Retrieves the minimum and maximum TTL values for negative responses.
190    pub fn negative_response_ttl_bounds(&self, query_type: RecordType) -> RangeInclusive<Duration> {
191        let bounds = self.by_query_type.get(&query_type).unwrap_or(&self.default);
192        let min = bounds
193            .negative_min_ttl
194            .unwrap_or_else(|| Duration::from_secs(0));
195        let max = bounds
196            .negative_max_ttl
197            .unwrap_or_else(|| Duration::from_secs(u64::from(MAX_TTL)));
198        min..=max
199    }
200}
201
/// Minimum and maximum TTL values for positive and negative responses.
///
/// Every bound is optional. When the `serde` feature is enabled, each bound is deserialized
/// from a single number of seconds and unknown fields are rejected.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Deserialize))]
#[cfg_attr(feature = "serde", serde(deny_unknown_fields))]
pub struct TtlBounds {
    /// An optional minimum TTL value for positive responses.
    ///
    /// Positive responses with TTLs under `positive_min_ttl` will use
    /// `positive_min_ttl` instead.
    #[cfg_attr(
        feature = "serde",
        serde(default, deserialize_with = "duration_deserialize")
    )]
    positive_min_ttl: Option<Duration>,

    /// An optional minimum TTL value for negative (`NXDOMAIN`) responses.
    ///
    /// `NXDOMAIN` responses with TTLs under `negative_min_ttl` will use
    /// `negative_min_ttl` instead.
    #[cfg_attr(
        feature = "serde",
        serde(default, deserialize_with = "duration_deserialize")
    )]
    negative_min_ttl: Option<Duration>,

    /// An optional maximum TTL value for positive responses.
    ///
    /// Positive responses with TTLs over `positive_max_ttl` will use
    /// `positive_max_ttl` instead.
    #[cfg_attr(
        feature = "serde",
        serde(default, deserialize_with = "duration_deserialize")
    )]
    positive_max_ttl: Option<Duration>,

    /// An optional maximum TTL value for negative (`NXDOMAIN`) responses.
    ///
    /// `NXDOMAIN` responses with TTLs over `negative_max_ttl` will use
    /// `negative_max_ttl` instead.
    #[cfg_attr(
        feature = "serde",
        serde(default, deserialize_with = "duration_deserialize")
    )]
    negative_max_ttl: Option<Duration>,
}
247
248impl DnsLru {
249    /// Construct a new cache
250    ///
251    /// # Arguments
252    ///
253    /// * `capacity` - size in number of cached queries
254    /// * `ttl_config` - minimum and maximum TTLs for cached records
255    pub fn new(capacity: usize, ttl_config: TtlConfig) -> Self {
256        let cache = Cache::builder()
257            .max_capacity(capacity.try_into().unwrap_or(u64::MAX))
258            .expire_after(LruValueExpiry)
259            .build();
260        Self {
261            cache,
262            ttl_config: Arc::new(ttl_config),
263        }
264    }
265
    /// Invalidates every entry currently stored in the cache.
    pub(crate) fn clear(&self) {
        self.cache.invalidate_all();
    }
269
270    pub(crate) fn insert(
271        &self,
272        query: Query,
273        records_and_ttl: Vec<(Record, u32)>,
274        now: Instant,
275    ) -> Lookup {
276        let len = records_and_ttl.len();
277        let (positive_min_ttl, positive_max_ttl) = self
278            .ttl_config
279            .positive_response_ttl_bounds(query.query_type())
280            .into_inner();
281
282        // collapse the values, we're going to take the Minimum TTL as the correct one
283        let (records, ttl): (Vec<Record>, Duration) = records_and_ttl.into_iter().fold(
284            (Vec::with_capacity(len), positive_max_ttl),
285            |(mut records, mut min_ttl), (record, ttl)| {
286                records.push(record);
287                let ttl = Duration::from_secs(u64::from(ttl));
288                min_ttl = min_ttl.min(ttl);
289                (records, min_ttl)
290            },
291        );
292
293        // If the cache was configured with a minimum TTL, and that value is higher
294        // than the minimum TTL in the values, use it instead.
295        let ttl = positive_min_ttl.max(ttl);
296        let valid_until = now + ttl;
297
298        // insert into the LRU
299        let lookup = Lookup::new_with_deadline(query.clone(), Arc::from(records), valid_until);
300        self.cache.insert(
301            query,
302            LruValue {
303                lookup: Ok(lookup.clone()),
304                valid_until,
305            },
306        );
307
308        lookup
309    }
310
311    /// inserts a record based on the name and type.
312    ///
313    /// # Arguments
314    ///
315    /// * `original_query` - is used for matching the records that should be returned
316    /// * `records` - the records will be partitioned by type and name for storage in the cache
317    /// * `now` - current time for use in associating TTLs
318    ///
319    /// # Return
320    ///
321    /// This should always return some records, but will be None if there are no records or the original_query matches none
322    pub fn insert_records(
323        &self,
324        original_query: Query,
325        records: impl Iterator<Item = Record>,
326        now: Instant,
327    ) -> Option<Lookup> {
328        // collect all records by name
329        let records = records.fold(
330            HashMap::<Query, Vec<(Record, u32)>>::new(),
331            |mut map, record| {
332                // it's not useful to cache RRSIGs on their own using `name()` as a key because
333                // there can be multiple RRSIG associated to the same domain name where each
334                // RRSIG is *covering* a different record type
335                //
336                // an example of this is shown below
337                //
338                // ``` console
339                // $ dig @a.iana-servers.net. +norecurse +dnssec A example.com.
340                // example.com.     3600    IN  A   93.184.215.14
341                // example.com.     3600    IN  RRSIG   A 13 2 3600 20240705065834 (..)
342                //
343                // $ dig @a.iana-servers.net. +norecurse +dnssec A example.com.
344                // example.com.     86400   IN  NS  a.iana-servers.net.
345                // example.com.     86400   IN  NS  b.iana-servers.net.
346                // example.com.     86400   IN  RRSIG   NS 13 2 86400 20240705060635 (..)
347                // ```
348                //
349                // note that there are two RRSIG records associated to `example.com.` but they are
350                // covering different record types. the first RRSIG covers the
351                // `A example.com.` record. the second RRSIG covers two `NS example.com.` records
352                //
353                // if we use ("example.com.", RecordType::RRSIG) as a key in our cache these two
354                // consecutive queries will cause the entry to be overwritten, losing the RRSIG
355                // covering the A record
356                //
357                // to avoid this problem, we'll cache the RRSIG along the record it covers using
358                // the record's type along the record's `name()` as the key in the cache
359                //
360                // For CNAME records, we want to preserve the original request query type, since
361                // that's what would be used to retrieve the cached query.
362                let rtype = match record.record_type() {
363                    RecordType::CNAME => original_query.query_type(),
364                    #[cfg(feature = "__dnssec")]
365                    RecordType::RRSIG => match RRSIG::try_borrow(record.data()) {
366                        Some(rrsig) => rrsig.type_covered(),
367                        None => record.record_type(),
368                    },
369                    _ => record.record_type(),
370                };
371
372                let mut query = Query::query(record.name().clone(), rtype);
373                query.set_query_class(record.dns_class());
374
375                let ttl = record.ttl();
376
377                map.entry(query).or_default().push((record, ttl));
378
379                map
380            },
381        );
382
383        // now insert by record type and name
384        let mut lookup = None;
385        for (query, records_and_ttl) in records {
386            let is_query = original_query == query;
387            let inserted = self.insert(query, records_and_ttl, now);
388
389            if is_query {
390                lookup = Some(inserted)
391            }
392        }
393
394        lookup
395    }
396
397    /// Generally for inserting a set of records that have already been cached, but with a different Query.
398    pub(crate) fn duplicate(&self, query: Query, lookup: Lookup, ttl: u32, now: Instant) -> Lookup {
399        let ttl = Duration::from_secs(u64::from(ttl));
400        let valid_until = now + ttl;
401
402        self.cache.insert(
403            query,
404            LruValue {
405                lookup: Ok(lookup.clone()),
406                valid_until,
407            },
408        );
409
410        lookup
411    }
412
413    /// This converts the Error to set the inner negative_ttl value to be the
414    ///  current expiration ttl.
415    fn nx_error_with_ttl(error: &mut ProtoError, new_ttl: Duration) {
416        let ProtoError { kind, .. } = error;
417
418        if let ProtoErrorKind::NoRecordsFound { negative_ttl, .. } = kind.as_mut() {
419            *negative_ttl = Some(u32::try_from(new_ttl.as_secs()).unwrap_or(MAX_TTL));
420        }
421    }
422
423    pub(crate) fn negative(&self, query: Query, mut error: ProtoError, now: Instant) -> ProtoError {
424        let ProtoError { kind, .. } = &error;
425
426        // TODO: if we are getting a negative response, should we instead fallback to cache?
427        //   this would cache indefinitely, probably not correct
428        if let ProtoErrorKind::NoRecordsFound {
429            negative_ttl: Some(ttl),
430            ..
431        } = kind.as_ref()
432        {
433            let (negative_min_ttl, negative_max_ttl) = self
434                .ttl_config
435                .negative_response_ttl_bounds(query.query_type())
436                .into_inner();
437
438            let ttl_duration = Duration::from_secs(u64::from(*ttl))
439                // Clamp the TTL so that it's between the cache's configured
440                // minimum and maximum TTLs for negative responses.
441                .clamp(negative_min_ttl, negative_max_ttl);
442            let valid_until = now + ttl_duration;
443
444            {
445                let error = error.clone();
446
447                self.cache.insert(
448                    query,
449                    LruValue {
450                        lookup: Err(error),
451                        valid_until,
452                    },
453                );
454            }
455
456            Self::nx_error_with_ttl(&mut error, ttl_duration);
457        }
458
459        error
460    }
461
462    /// Based on the query, see if there are any records available
463    pub fn get(&self, query: &Query, now: Instant) -> Option<Result<Lookup, ProtoError>> {
464        let value = self.cache.get(query)?;
465        if !value.is_current(now) {
466            return None;
467        }
468        let mut result = value.with_updated_ttl(now).lookup;
469        if let Err(err) = &mut result {
470            Self::nx_error_with_ttl(err, value.ttl(now));
471        }
472        Some(result)
473    }
474}
475
/// Deserializes an optional [`Duration`] from a single number of seconds, rather than from the
/// default representation of a struct with `secs` and `nanos` fields.
#[cfg(feature = "serde")]
fn duration_deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
    D: Deserializer<'de>,
{
    let seconds = Option::<u32>::deserialize(deserializer)?;
    Ok(seconds.map(|secs| Duration::from_secs(u64::from(secs))))
}
488
489#[cfg(feature = "serde")]
490mod ttl_config_deserialize;
491
492struct LruValueExpiry;
493
494impl Expiry<Query, LruValue> for LruValueExpiry {
495    fn expire_after_create(
496        &self,
497        _key: &Query,
498        value: &LruValue,
499        created_at: Instant,
500    ) -> Option<Duration> {
501        Some(value.ttl(created_at))
502    }
503
504    fn expire_after_update(
505        &self,
506        _key: &Query,
507        value: &LruValue,
508        updated_at: Instant,
509        _duration_until_expiry: Option<Duration>,
510    ) -> Option<Duration> {
511        Some(value.ttl(updated_at))
512    }
513}
514
515// see also the lookup_tests.rs in integration-tests crate
516#[cfg(test)]
517mod tests {
518    use std::str::FromStr;
519    use std::time::*;
520
521    use hickory_proto::rr::rdata::TXT;
522
523    use crate::proto::op::{Query, ResponseCode};
524    use crate::proto::rr::rdata::A;
525    use crate::proto::rr::{Name, RData, RecordType};
526
527    use super::*;
528
529    #[test]
530    fn test_is_current() {
531        let now = Instant::now();
532        let not_the_future = now + Duration::from_secs(4);
533        let future = now + Duration::from_secs(5);
534        let past_the_future = now + Duration::from_secs(6);
535
536        let value = LruValue {
537            lookup: Err(ProtoErrorKind::Message("test error").into()),
538            valid_until: future,
539        };
540
541        assert!(value.is_current(now));
542        assert!(value.is_current(not_the_future));
543        assert!(value.is_current(future));
544        assert!(!value.is_current(past_the_future));
545    }
546
547    #[test]
548    fn test_lookup_uses_positive_min_ttl() {
549        let now = Instant::now();
550
551        let name = Name::from_str("www.example.com.").unwrap();
552        let query = Query::query(name.clone(), RecordType::A);
553        // record should have TTL of 1 second.
554        let ips_ttl = vec![(
555            Record::from_rdata(name.clone(), 1, RData::A(A::new(127, 0, 0, 1))),
556            1,
557        )];
558        let ips = [RData::A(A::new(127, 0, 0, 1))];
559
560        // configure the cache with a minimum TTL of 2 seconds.
561        let ttls = TtlConfig {
562            default: TtlBounds {
563                positive_min_ttl: Some(Duration::from_secs(2)),
564                ..TtlBounds::default()
565            },
566            ..TtlConfig::default()
567        };
568        let lru = DnsLru::new(1, ttls);
569
570        let rc_ips = lru.insert(query.clone(), ips_ttl, now);
571        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
572        // the returned lookup should use the cache's min TTL, since the
573        // query's TTL was below the minimum.
574        assert_eq!(rc_ips.valid_until(), now + Duration::from_secs(2));
575
576        // record should have TTL of 3 seconds.
577        let ips_ttl = vec![(
578            Record::from_rdata(name, 3, RData::A(A::new(127, 0, 0, 1))),
579            3,
580        )];
581
582        let rc_ips = lru.insert(query, ips_ttl, now);
583        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
584        // the returned lookup should use the record's TTL, since it's
585        // greater than the cache's minimum.
586        assert_eq!(rc_ips.valid_until(), now + Duration::from_secs(3));
587    }
588
589    #[test]
590    fn test_error_uses_negative_min_ttl() {
591        let now = Instant::now();
592
593        let name = Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::A);
594
595        // configure the cache with a maximum TTL of 2 seconds.
596        let ttls = TtlConfig {
597            default: TtlBounds {
598                negative_min_ttl: Some(Duration::from_secs(2)),
599                ..TtlBounds::default()
600            },
601            ..TtlConfig::default()
602        };
603        let lru = DnsLru::new(1, ttls);
604
605        // neg response should have TTL of 1 seconds.
606        let err = ProtoErrorKind::NoRecordsFound {
607            query: Box::new(name.clone()),
608            soa: None,
609            ns: None,
610            negative_ttl: Some(1),
611            response_code: ResponseCode::NoError,
612            trusted: false,
613            authorities: None,
614        };
615        let nx_error = lru.negative(name.clone(), err.into(), now);
616        match nx_error.kind() {
617            &ProtoErrorKind::NoRecordsFound { negative_ttl, .. } => {
618                let valid_until = negative_ttl.expect("resolve error should have a deadline");
619                // the error's `valid_until` field should have been limited to 2 seconds.
620                assert_eq!(valid_until, 2);
621            }
622            other => panic!("expected ProtoErrorKind::NoRecordsFound, got {:?}", other),
623        }
624
625        // neg response should have TTL of 3 seconds.
626        let err = ProtoErrorKind::NoRecordsFound {
627            query: Box::new(name.clone()),
628            soa: None,
629            ns: None,
630            negative_ttl: Some(3),
631            response_code: ResponseCode::NoError,
632            trusted: false,
633            authorities: None,
634        };
635        let nx_error = lru.negative(name, err.into(), now);
636        match nx_error.kind() {
637            &ProtoErrorKind::NoRecordsFound { negative_ttl, .. } => {
638                let negative_ttl = negative_ttl.expect("ProtoError should have a deadline");
639                // the error's `valid_until` field should not have been limited, as it was
640                // over the min TTL.
641                assert_eq!(negative_ttl, 3);
642            }
643            other => panic!("expected ProtoErrorKind::NoRecordsFound, got {:?}", other),
644        }
645    }
646
647    #[test]
648    fn test_lookup_uses_positive_max_ttl() {
649        let now = Instant::now();
650
651        let name = Name::from_str("www.example.com.").unwrap();
652        let query = Query::query(name.clone(), RecordType::A);
653        // record should have TTL of 62 seconds.
654        let ips_ttl = vec![(
655            Record::from_rdata(name.clone(), 62, RData::A(A::new(127, 0, 0, 1))),
656            62,
657        )];
658        let ips = [RData::A(A::new(127, 0, 0, 1))];
659
660        // configure the cache with a maximum TTL of 60 seconds.
661        let ttls = TtlConfig {
662            default: TtlBounds {
663                positive_max_ttl: Some(Duration::from_secs(60)),
664                ..TtlBounds::default()
665            },
666            ..TtlConfig::default()
667        };
668        let lru = DnsLru::new(1, ttls);
669
670        let rc_ips = lru.insert(query.clone(), ips_ttl, now);
671        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
672        // the returned lookup should use the cache's min TTL, since the
673        // query's TTL was above the maximum.
674        assert_eq!(rc_ips.valid_until(), now + Duration::from_secs(60));
675
676        // record should have TTL of 59 seconds.
677        let ips_ttl = vec![(
678            Record::from_rdata(name, 59, RData::A(A::new(127, 0, 0, 1))),
679            59,
680        )];
681
682        let rc_ips = lru.insert(query, ips_ttl, now);
683        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
684        // the returned lookup should use the record's TTL, since it's
685        // below than the cache's maximum.
686        assert_eq!(rc_ips.valid_until(), now + Duration::from_secs(59));
687    }
688
689    #[test]
690    fn test_error_uses_negative_max_ttl() {
691        let now = Instant::now();
692
693        let name = Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::A);
694
695        // configure the cache with a maximum TTL of 60 seconds.
696        let ttls = TtlConfig {
697            default: TtlBounds {
698                negative_max_ttl: Some(Duration::from_secs(60)),
699                ..TtlBounds::default()
700            },
701            ..TtlConfig::default()
702        };
703        let lru = DnsLru::new(1, ttls);
704
705        // neg response should have TTL of 62 seconds.
706        let err: ProtoErrorKind = ProtoErrorKind::NoRecordsFound {
707            query: Box::new(name.clone()),
708            soa: None,
709            ns: None,
710            negative_ttl: Some(62),
711            response_code: ResponseCode::NoError,
712            trusted: false,
713            authorities: None,
714        };
715        let nx_error = lru.negative(name.clone(), err.into(), now);
716        match nx_error.kind() {
717            &ProtoErrorKind::NoRecordsFound { negative_ttl, .. } => {
718                let negative_ttl = negative_ttl.expect("resolve error should have a deadline");
719                // the error's `valid_until` field should have been limited to 60 seconds.
720                assert_eq!(negative_ttl, 60);
721            }
722            other => panic!("expected ProtoErrorKind::NoRecordsFound, got {:?}", other),
723        }
724
725        // neg response should have TTL of 59 seconds.
726        let err = ProtoErrorKind::NoRecordsFound {
727            query: Box::new(name.clone()),
728            soa: None,
729            ns: None,
730            negative_ttl: Some(59),
731            response_code: ResponseCode::NoError,
732            trusted: false,
733            authorities: None,
734        };
735        let nx_error = lru.negative(name, err.into(), now);
736        match nx_error.kind() {
737            &ProtoErrorKind::NoRecordsFound { negative_ttl, .. } => {
738                let negative_ttl = negative_ttl.expect("resolve error should have a deadline");
739                // the error's `valid_until` field should not have been limited, as it was
740                // under the max TTL.
741                assert_eq!(negative_ttl, 59);
742            }
743            other => panic!("expected ProtoErrorKind::NoRecordsFound, got {:?}", other),
744        }
745    }
746
747    #[test]
748    fn test_insert() {
749        let now = Instant::now();
750
751        let name = Name::from_str("www.example.com.").unwrap();
752        let query = Query::query(name.clone(), RecordType::A);
753        let ips_ttl = vec![(
754            Record::from_rdata(name, 1, RData::A(A::new(127, 0, 0, 1))),
755            1,
756        )];
757        let ips = [RData::A(A::new(127, 0, 0, 1))];
758        let lru = DnsLru::new(1, TtlConfig::default());
759
760        let rc_ips = lru.insert(query.clone(), ips_ttl, now);
761        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
762
763        let rc_ips = lru.get(&query, now).unwrap().expect("records should exist");
764        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
765    }
766
767    #[test]
768    fn test_update_ttl() {
769        let now = Instant::now();
770
771        let name = Name::from_str("www.example.com.").unwrap();
772        let query = Query::query(name.clone(), RecordType::A);
773        let ips_ttl = vec![(
774            Record::from_rdata(name, 10, RData::A(A::new(127, 0, 0, 1))),
775            10,
776        )];
777        let ips = [RData::A(A::new(127, 0, 0, 1))];
778        let lru = DnsLru::new(1, TtlConfig::default());
779
780        let rc_ips = lru.insert(query.clone(), ips_ttl, now);
781        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
782
783        let ttl = lru
784            .get(&query, now + Duration::from_secs(2))
785            .unwrap()
786            .expect("records should exist")
787            .record_iter()
788            .next()
789            .unwrap()
790            .ttl();
791        assert!(ttl <= 8);
792    }
793
794    #[test]
795    fn test_insert_ttl() {
796        let now = Instant::now();
797        let name = Name::from_str("www.example.com.").unwrap();
798        let query = Query::query(name.clone(), RecordType::A);
799        // TTL should be 1
800        let ips_ttl = vec![
801            (
802                Record::from_rdata(name.clone(), 1, RData::A(A::new(127, 0, 0, 1))),
803                1,
804            ),
805            (
806                Record::from_rdata(name, 2, RData::A(A::new(127, 0, 0, 2))),
807                2,
808            ),
809        ];
810        let ips = vec![
811            RData::A(A::new(127, 0, 0, 1)),
812            RData::A(A::new(127, 0, 0, 2)),
813        ];
814        let lru = DnsLru::new(1, TtlConfig::default());
815
816        lru.insert(query.clone(), ips_ttl, now);
817
818        // still valid
819        let rc_ips = lru
820            .get(&query, now + Duration::from_secs(1))
821            .unwrap()
822            .expect("records should exist");
823        assert_eq!(*rc_ips.iter().next().unwrap(), ips[0]);
824
825        // 2 should be one too far
826        let rc_ips = lru.get(&query, now + Duration::from_secs(2));
827        assert!(rc_ips.is_none());
828    }
829
830    #[test]
831    fn test_insert_positive_min_ttl() {
832        let now = Instant::now();
833        let name = Name::from_str("www.example.com.").unwrap();
834        let query = Query::query(name.clone(), RecordType::A);
835        // TTL should be 1
836        let ips_ttl = vec![
837            (
838                Record::from_rdata(name.clone(), 1, RData::A(A::new(127, 0, 0, 1))),
839                1,
840            ),
841            (
842                Record::from_rdata(name, 2, RData::A(A::new(127, 0, 0, 2))),
843                2,
844            ),
845        ];
846        let ips = vec![
847            RData::A(A::new(127, 0, 0, 1)),
848            RData::A(A::new(127, 0, 0, 2)),
849        ];
850
851        // this cache should override the TTL of 1 seconds with the configured
852        // minimum TTL of 3 seconds.
853        let ttls = TtlConfig {
854            default: TtlBounds {
855                positive_min_ttl: Some(Duration::from_secs(3)),
856                ..TtlBounds::default()
857            },
858            ..TtlConfig::default()
859        };
860        let lru = DnsLru::new(1, ttls);
861        lru.insert(query.clone(), ips_ttl, now);
862
863        // still valid
864        let rc_ips = lru
865            .get(&query, now + Duration::from_secs(1))
866            .unwrap()
867            .expect("records should exist");
868        for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
869            assert_eq!(rc_ip, ip, "after 1 second");
870        }
871
872        let rc_ips = lru
873            .get(&query, now + Duration::from_secs(2))
874            .unwrap()
875            .expect("records should exist");
876        for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
877            assert_eq!(rc_ip, ip, "after 2 seconds");
878        }
879
880        let rc_ips = lru
881            .get(&query, now + Duration::from_secs(3))
882            .unwrap()
883            .expect("records should exist");
884        for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
885            assert_eq!(rc_ip, ip, "after 3 seconds");
886        }
887
888        // after 4 seconds, the records should be invalid.
889        let rc_ips = lru.get(&query, now + Duration::from_secs(4));
890        assert!(rc_ips.is_none());
891    }
892
893    #[test]
894    fn test_insert_positive_max_ttl() {
895        let now = Instant::now();
896        let name = Name::from_str("www.example.com.").unwrap();
897        let query = Query::query(name.clone(), RecordType::A);
898        // TTL should be 500
899        let ips_ttl = vec![
900            (
901                Record::from_rdata(name.clone(), 400, RData::A(A::new(127, 0, 0, 1))),
902                400,
903            ),
904            (
905                Record::from_rdata(name, 500, RData::A(A::new(127, 0, 0, 2))),
906                500,
907            ),
908        ];
909        let ips = vec![
910            RData::A(A::new(127, 0, 0, 1)),
911            RData::A(A::new(127, 0, 0, 2)),
912        ];
913
914        // this cache should override the TTL of 500 seconds with the configured
915        // minimum TTL of 2 seconds.
916        let ttls = TtlConfig {
917            default: TtlBounds {
918                positive_max_ttl: Some(Duration::from_secs(2)),
919                ..TtlBounds::default()
920            },
921            ..TtlConfig::default()
922        };
923        let lru = DnsLru::new(1, ttls);
924        lru.insert(query.clone(), ips_ttl, now);
925
926        // still valid
927        let rc_ips = lru
928            .get(&query, now + Duration::from_secs(1))
929            .unwrap()
930            .expect("records should exist");
931        for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
932            assert_eq!(rc_ip, ip, "after 1 second");
933        }
934
935        let rc_ips = lru
936            .get(&query, now + Duration::from_secs(2))
937            .unwrap()
938            .expect("records should exist");
939        for (rc_ip, ip) in rc_ips.iter().zip(ips.iter()) {
940            assert_eq!(rc_ip, ip, "after 2 seconds");
941        }
942
943        // after 3 seconds, the records should be invalid.
944        let rc_ips = lru.get(&query, now + Duration::from_secs(3));
945        assert!(rc_ips.is_none());
946    }
947
948    #[test]
949    fn test_lookup_positive_min_ttl_different_query_types() {
950        let now = Instant::now();
951
952        let name = Name::from_str("www.example.com.").unwrap();
953        let query_a = Query::query(name.clone(), RecordType::A);
954        let query_txt = Query::query(name.clone(), RecordType::TXT);
955        let rdata_a = RData::A(A::new(127, 0, 0, 1));
956        let rdata_txt = RData::TXT(TXT::new(vec!["data".to_string()]));
957        // store records with a TTL of 1 second.
958        let records_ttl_a = vec![(Record::from_rdata(name.clone(), 1, rdata_a.clone()), 1)];
959        let records_ttl_txt = vec![(Record::from_rdata(name.clone(), 1, rdata_txt.clone()), 1)];
960
961        // set separate positive_min_ttl limits for TXT queries and all others
962        let mut ttl_config = TtlConfig::new(Some(Duration::from_secs(2)), None, None, None);
963        ttl_config.with_query_type_ttl_bounds(
964            RecordType::TXT,
965            Some(Duration::from_secs(5)),
966            None,
967            None,
968            None,
969        );
970        let lru = DnsLru::new(2, ttl_config);
971
972        let rc_a = lru.insert(query_a.clone(), records_ttl_a, now);
973        assert_eq!(*rc_a.iter().next().unwrap(), rdata_a);
974        // the returned lookup should use the cache's default min TTL, since the
975        // response's TTL was below the minimum.
976        assert_eq!(rc_a.valid_until(), now + Duration::from_secs(2));
977
978        let rc_txt = lru.insert(query_txt.clone(), records_ttl_txt, now);
979        assert_eq!(*rc_txt.iter().next().unwrap(), rdata_txt);
980        // the returned lookup should use the min TTL for TXT records, since the
981        // response's TTL was below the minimum.
982        assert_eq!(rc_txt.valid_until(), now + Duration::from_secs(5));
983
984        // store records with a TTL of 7 seconds.
985        let records_ttl_a = vec![(Record::from_rdata(name.clone(), 1, rdata_a.clone()), 7)];
986        let records_ttl_txt = vec![(Record::from_rdata(name.clone(), 1, rdata_txt.clone()), 7)];
987
988        let rc_a = lru.insert(query_a, records_ttl_a, now);
989        assert_eq!(*rc_a.iter().next().unwrap(), rdata_a);
990        // the returned lookup should use the record's TTL, since it's
991        // greater than the default min TTL.
992        assert_eq!(rc_a.valid_until(), now + Duration::from_secs(7));
993
994        let rc_txt = lru.insert(query_txt, records_ttl_txt, now);
995        assert_eq!(*rc_txt.iter().next().unwrap(), rdata_txt);
996        // the returned lookup should use the record's TTL, since it's
997        // greater than the min TTL for TXT records.
998        assert_eq!(rc_txt.valid_until(), now + Duration::from_secs(7));
999    }
1000}