Skip to main content

hickory_resolver/
cache.rs

1//! A cache for DNS responses.
2
3use std::{
4    collections::HashMap,
5    ops::RangeInclusive,
6    sync::Arc,
7    time::{Duration, Instant},
8};
9
10use moka::{Expiry, sync::Cache};
11#[cfg(feature = "serde")]
12use serde::Deserialize;
13
14use crate::{
15    config,
16    net::{DnsError, NetError, NoRecords},
17    proto::{
18        op::{Message, Query},
19        rr::RecordType,
20    },
21};
22
23/// A cache for DNS responses.
24#[derive(Clone, Debug)]
25pub struct ResponseCache {
26    cache: Cache<Query, Entry>,
27    ttl_config: Arc<TtlConfig>,
28}
29
30impl ResponseCache {
31    /// Construct a new response cache.
32    ///
33    /// # Arguments
34    ///
35    /// * `capacity` - size in number of cached responses
36    /// * `ttl_config` - minimum and maximum TTLs for cached records
37    pub fn new(capacity: u64, ttl_config: TtlConfig) -> Self {
38        Self {
39            cache: Cache::builder()
40                .max_capacity(capacity)
41                .expire_after(EntryExpiry)
42                .build(),
43            ttl_config: Arc::new(ttl_config),
44        }
45    }
46
47    /// Insert a response into the cache.
48    pub fn insert(&self, query: Query, result: Result<Message, NetError>, now: Instant) {
49        let (ttl, result) = match result {
50            Ok(mut message) => {
51                let ttl = self.clamp_positive_ttls(query.query_type(), &mut message);
52                (ttl, Ok(message))
53            }
54            Err(NetError::Dns(DnsError::NoRecordsFound(no_records))) => {
55                let (negative_min_ttl, negative_max_ttl) = self
56                    .ttl_config
57                    .negative_response_ttl_bounds(query.query_type())
58                    .into_inner();
59                let ttl = if let Some(ttl) = no_records.negative_ttl {
60                    Duration::from_secs(u64::from(ttl)).clamp(negative_min_ttl, negative_max_ttl)
61                } else {
62                    negative_min_ttl
63                };
64                (
65                    ttl,
66                    Err(NetError::Dns(DnsError::NoRecordsFound(no_records))),
67                )
68            }
69            Err(_) => return,
70        };
71        let valid_until = now + ttl;
72        self.cache.insert(
73            query,
74            Entry {
75                result: Arc::new(result),
76                original_time: now,
77                valid_until,
78            },
79        );
80    }
81
82    /// Try to retrieve a cached response with the given query.
83    pub fn get(&self, query: &Query, now: Instant) -> Option<Result<Message, NetError>> {
84        let entry = self.cache.get(query)?;
85        if !entry.is_current(now) {
86            return None;
87        }
88        Some(entry.updated_ttl(now))
89    }
90
91    /// Clamp all record TTLs to `[positive_min_ttl, positive_max_ttl]` and return
92    /// the cache duration derived from the minimum TTL of records matching
93    /// `query_type` across all sections.
94    ///
95    /// Each record is clamped according to the TTL bounds configured for its own
96    /// record type, so that per-type overrides are respected even for authority
97    /// and additional section records.
98    pub(crate) fn clamp_positive_ttls(
99        &self,
100        query_type: RecordType,
101        message: &mut Message,
102    ) -> Duration {
103        for record in message
104            .answers
105            .iter_mut()
106            .chain(message.authorities.iter_mut())
107            .chain(message.additionals.iter_mut())
108        {
109            let (min_secs, max_secs) = self
110                .ttl_config
111                .positive_ttl_bounds_secs(record.record_type());
112            record.ttl = record.ttl.clamp(min_secs, max_secs);
113        }
114
115        let (positive_min_ttl, positive_max_ttl) = self
116            .ttl_config
117            .positive_response_ttl_bounds(query_type)
118            .into_inner();
119
120        // Derive cache duration from the minimum TTL of records whose type
121        // matches the query, across all sections.  This avoids letting
122        // unrelated authority/additional records skew the cache lifetime.
123        let min_ttl = message
124            .all_sections()
125            .filter(|r| r.record_type() == query_type)
126            .map(|r| Duration::from_secs(r.ttl.into()))
127            .min();
128
129        min_ttl
130            .unwrap_or(positive_min_ttl)
131            .clamp(positive_min_ttl, positive_max_ttl)
132    }
133
134    pub(crate) fn clear(&self) {
135        self.cache.invalidate_all();
136    }
137
138    pub(crate) fn clear_query(&self, query: &Query) {
139        self.cache.invalidate(query);
140    }
141
142    /// Returns the approximate number of entries in the cache.
143    #[cfg(feature = "metrics")]
144    pub(crate) fn entry_count(&self) -> u64 {
145        #[cfg(test)]
146        {
147            // For tests, ensure pending tasks are processed before getting the count.
148            // This allows unit tests of the respective cache size metrics to be
149            // written without flakyness. In a production context, we're happier
150            // to defer background work and to return an approximate count.
151            self.cache.run_pending_tasks();
152        }
153
154        self.cache.entry_count()
155    }
156}
157
158/// An entry in the response cache.
159///
160/// This contains the response itself (or an error), the time it was received, and the time at which
161/// it expires.
162#[derive(Debug, Clone)]
163struct Entry {
164    result: Arc<Result<Message, NetError>>,
165    original_time: Instant,
166    valid_until: Instant,
167}
168
169impl Entry {
170    /// Return the `Result` stored in this entry, with modified TTLs, subtracting the elapsed time
171    /// since the response was received.
172    fn updated_ttl(&self, now: Instant) -> Result<Message, NetError> {
173        let elapsed = u32::try_from(now.saturating_duration_since(self.original_time).as_secs())
174            .unwrap_or(u32::MAX);
175        match &*self.result {
176            Ok(response) => {
177                let mut response = response.clone();
178                for records in [
179                    &mut response.answers,
180                    &mut response.authorities,
181                    &mut response.additionals,
182                ] {
183                    for record in records {
184                        record.decrement_ttl(elapsed);
185                    }
186                }
187                Ok(response)
188            }
189            Err(e) => {
190                let mut e = e.clone();
191
192                // The NoRecords error may contain up to four fields with TTL values present: negative_ttl, soa, authorities, and ns.
193                // For completeness, we update each field, if present.
194                if let NetError::Dns(DnsError::NoRecordsFound(NoRecords {
195                    negative_ttl,
196                    soa,
197                    authorities,
198                    ns,
199                    ..
200                })) = &mut e
201                {
202                    if let Some(ttl) = negative_ttl {
203                        *ttl = ttl.saturating_sub(elapsed);
204                    }
205
206                    if let Some(soa) = soa {
207                        soa.decrement_ttl(elapsed);
208                    }
209
210                    if let Some(recs) = authorities.take() {
211                        authorities.replace(Arc::from(
212                            recs.iter()
213                                .cloned()
214                                .map(|mut rec| {
215                                    rec.decrement_ttl(elapsed);
216                                    rec
217                                })
218                                .collect::<Vec<_>>(),
219                        ));
220                    }
221
222                    if let Some(ns_recs) = ns.take() {
223                        ns.replace(Arc::from(
224                            ns_recs
225                                .iter()
226                                .cloned()
227                                .map(|mut ns| {
228                                    ns.ns.decrement_ttl(elapsed);
229                                    ns.glue = Arc::from(
230                                        ns.glue
231                                            .iter()
232                                            .cloned()
233                                            .map(|mut glue| {
234                                                glue.decrement_ttl(elapsed);
235                                                glue
236                                            })
237                                            .collect::<Vec<_>>(),
238                                    );
239
240                                    ns
241                                })
242                                .collect::<Vec<_>>(),
243                        ));
244                    }
245                }
246                Err(e)
247            }
248        }
249    }
250
251    /// Returns whether this cache entry is still valid.
252    fn is_current(&self, now: Instant) -> bool {
253        now <= self.valid_until
254    }
255
256    /// Returns the remaining time that this cache entry is valid for.
257    fn ttl(&self, now: Instant) -> Duration {
258        self.valid_until.saturating_duration_since(now)
259    }
260}
261
262struct EntryExpiry;
263
264impl Expiry<Query, Entry> for EntryExpiry {
265    fn expire_after_create(
266        &self,
267        _key: &Query,
268        value: &Entry,
269        created_at: Instant,
270    ) -> Option<Duration> {
271        Some(value.ttl(created_at))
272    }
273
274    fn expire_after_update(
275        &self,
276        _key: &Query,
277        value: &Entry,
278        updated_at: Instant,
279        _duration_until_expiry: Option<Duration>,
280    ) -> Option<Duration> {
281        Some(value.ttl(updated_at))
282    }
283}
284
285/// The time-to-live (TTL) configuration used by the cache.
286///
287/// Minimum and maximum TTLs can be set for both positive responses and negative responses. Separate
288/// limits may be set depending on the query type. If a minimum value is not provided, it will
289/// default to 0 seconds. If a maximum value is not provided, it will default to one day.
290///
291/// Note that TTLs in DNS are represented as a number of seconds stored in a 32-bit unsigned
292/// integer. We use `Duration` here, instead of `u32`, which can express larger values than the DNS
293/// standard. Generally, a `Duration` greater than `u32::MAX_VALUE` shouldn't cause any issue, as
294/// this will never be used in serialization, but note that this would be outside the standard
295/// range.
296#[derive(Clone, Debug, Default, PartialEq, Eq)]
297#[cfg_attr(feature = "serde", derive(Deserialize))]
298#[cfg_attr(
299    feature = "serde",
300    serde(from = "ttl_config_deserialize::TtlConfigMap")
301)]
302pub struct TtlConfig {
303    /// TTL limits applied to all queries.
304    default: TtlBounds,
305
306    /// TTL limits applied to queries with specific query types.
307    by_query_type: HashMap<RecordType, TtlBounds>,
308}
309
310impl TtlConfig {
311    /// Construct the LRU's TTL configuration based on the ResolverOpts configuration.
312    pub fn from_opts(opts: &config::ResolverOpts) -> Self {
313        Self::from(TtlBounds {
314            positive_min_ttl: opts.positive_min_ttl,
315            negative_min_ttl: opts.negative_min_ttl,
316            positive_max_ttl: opts.positive_max_ttl,
317            negative_max_ttl: opts.negative_max_ttl,
318        })
319    }
320
321    /// Override the minimum and maximum TTL values for a specific query type.
322    ///
323    /// If a minimum value is not provided, it will default to 0 seconds. If a maximum value is not
324    /// provided, it will default to one day.
325    pub fn with_query_type_ttl_bounds(
326        &mut self,
327        query_type: RecordType,
328        bounds: TtlBounds,
329    ) -> &mut Self {
330        self.by_query_type.insert(query_type, bounds);
331        self
332    }
333
334    /// Returns the positive-response TTL bounds as `(min_secs, max_secs)` clamped to `u32`.
335    ///
336    /// This is a convenience wrapper around [`positive_response_ttl_bounds`](Self::positive_response_ttl_bounds)
337    /// for use when clamping individual record TTLs.
338    fn positive_ttl_bounds_secs(&self, record_type: RecordType) -> (u32, u32) {
339        let (min, max) = self.positive_response_ttl_bounds(record_type).into_inner();
340        (
341            u32::try_from(min.as_secs()).unwrap_or(MAX_TTL),
342            u32::try_from(max.as_secs()).unwrap_or(MAX_TTL),
343        )
344    }
345
346    /// Retrieves the minimum and maximum TTL values for positive responses.
347    pub fn positive_response_ttl_bounds(&self, query_type: RecordType) -> RangeInclusive<Duration> {
348        let bounds = self.by_query_type.get(&query_type).unwrap_or(&self.default);
349        let min = bounds
350            .positive_min_ttl
351            .unwrap_or_else(|| Duration::from_secs(0));
352        let max = bounds
353            .positive_max_ttl
354            .unwrap_or_else(|| Duration::from_secs(u64::from(MAX_TTL)));
355        min..=max
356    }
357
358    /// Retrieves the minimum and maximum TTL values for negative responses.
359    pub fn negative_response_ttl_bounds(&self, query_type: RecordType) -> RangeInclusive<Duration> {
360        let bounds = self.by_query_type.get(&query_type).unwrap_or(&self.default);
361        let min = bounds
362            .negative_min_ttl
363            .unwrap_or_else(|| Duration::from_secs(0));
364        let max = bounds
365            .negative_max_ttl
366            .unwrap_or_else(|| Duration::from_secs(u64::from(MAX_TTL)));
367        min..=max
368    }
369}
370
371impl From<TtlBounds> for TtlConfig {
372    fn from(default: TtlBounds) -> Self {
373        Self {
374            default,
375            by_query_type: HashMap::default(),
376        }
377    }
378}
379
/// Minimum and maximum TTL values for positive and negative responses.
///
/// Every bound is optional: an unset minimum behaves as 0 seconds and an unset maximum as
/// [`MAX_TTL`] (one day).
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Deserialize))]
#[cfg_attr(feature = "serde", serde(deny_unknown_fields))]
pub struct TtlBounds {
    /// An optional minimum TTL value for positive responses.
    ///
    /// Records in positive responses with TTLs under `positive_min_ttl` will use
    /// `positive_min_ttl` instead, and the response is cached for at least this long.
    #[cfg_attr(
        feature = "serde",
        serde(default, deserialize_with = "config::duration_opt::deserialize")
    )]
    positive_min_ttl: Option<Duration>,

    /// An optional minimum TTL value for negative responses.
    ///
    /// Negative responses are cached "no records found" results. This covers `NXDOMAIN` as well
    /// as `NODATA` (a `NOERROR` response without matching records). Negative responses whose
    /// negative TTL is under `negative_min_ttl`, or which carry no negative TTL at all, will use
    /// `negative_min_ttl` instead.
    #[cfg_attr(
        feature = "serde",
        serde(default, deserialize_with = "config::duration_opt::deserialize")
    )]
    negative_min_ttl: Option<Duration>,

    /// An optional maximum TTL value for positive responses.
    ///
    /// Records in positive responses with TTLs over `positive_max_ttl` will use
    /// `positive_max_ttl` instead, and the response is cached for at most this long.
    #[cfg_attr(
        feature = "serde",
        serde(default, deserialize_with = "config::duration_opt::deserialize")
    )]
    positive_max_ttl: Option<Duration>,

    /// An optional maximum TTL value for negative responses (`NXDOMAIN` or `NODATA`).
    ///
    /// Negative responses whose negative TTL is over `negative_max_ttl` will use
    /// `negative_max_ttl` instead.
    #[cfg_attr(
        feature = "serde",
        serde(default, deserialize_with = "config::duration_opt::deserialize")
    )]
    negative_max_ttl: Option<Duration>,
}
425
/// Deserialization support for [`TtlConfig`].
///
/// The configuration is read as a map whose keys are either the literal `default` or a record
/// type, and whose values are [`TtlBounds`].
#[cfg(feature = "serde")]
mod ttl_config_deserialize {
    use std::collections::HashMap;

    use serde::Deserialize;

    use super::{TtlBounds, TtlConfig};
    use crate::proto::rr::RecordType;

    /// Raw map form of a [`TtlConfig`], as it appears in serialized configuration.
    #[derive(Deserialize)]
    pub(super) struct TtlConfigMap(HashMap<TtlConfigField, TtlBounds>);

    impl From<TtlConfigMap> for TtlConfig {
        fn from(TtlConfigMap(entries): TtlConfigMap) -> Self {
            let mut config = Self {
                default: TtlBounds::default(),
                by_query_type: HashMap::new(),
            };
            for (key, bounds) in entries {
                match key {
                    TtlConfigField::Default => config.default = bounds,
                    TtlConfigField::RecordType(record_type) => {
                        config.by_query_type.insert(record_type, bounds);
                    }
                }
            }
            config
        }
    }

    /// A key of the TTL configuration map.
    #[derive(PartialEq, Eq, Hash, Deserialize)]
    enum TtlConfigField {
        /// The `default` key: bounds used for query types without their own entry.
        #[serde(rename = "default")]
        Default,
        /// Any other key, parsed as a record type whose bounds override the defaults.
        #[serde(untagged)]
        RecordType(RecordType),
    }
}
465
/// Default upper bound on TTLs when no maximum is configured: one day, expressed in seconds.
///
/// [RFC 2181, section 8](https://tools.ietf.org/html/rfc2181#section-8) allows TTL values of up
/// to 2147483647 seconds, but lets implementations place an upper bound on the TTLs they accept.
pub const MAX_TTL: u32 = 24 * 60 * 60;
472
473#[cfg(test)]
474mod tests {
475    use std::{
476        str::FromStr,
477        time::{Duration, Instant},
478    };
479
480    #[cfg(feature = "serde")]
481    use serde::Deserialize;
482
483    use super::*;
484    use crate::{
485        net::{ForwardNSData, NetError},
486        proto::{
487            op::{Message, OpCode, Query, ResponseCode},
488            rr::{
489                Name, RData, Record, RecordType,
490                rdata::{A, AAAA, NS, SOA, TXT},
491            },
492        },
493    };
494    use test_support::subscribe;
495
496    #[test]
497    fn test_is_current() {
498        let now = Instant::now();
499        let not_the_future = now + Duration::from_secs(4);
500        let future = now + Duration::from_secs(5);
501        let past_the_future = now + Duration::from_secs(6);
502
503        let entry = Entry {
504            result: Err(NetError::Message("test error")).into(),
505            original_time: now,
506            valid_until: future,
507        };
508
509        assert!(entry.is_current(now));
510        assert!(entry.is_current(not_the_future));
511        assert!(entry.is_current(future));
512        assert!(!entry.is_current(past_the_future));
513    }
514
515    #[test]
516    fn test_positive_min_ttl() {
517        let now = Instant::now();
518
519        let name = Name::from_str("www.example.com.").unwrap();
520        let query = Query::query(name.clone(), RecordType::A);
521        // Record should have TTL of 1 second.
522        let mut message = Message::response(0, OpCode::Query);
523        message.add_answer(Record::from_rdata(
524            name.clone(),
525            1,
526            RData::A(A::new(127, 0, 0, 1)),
527        ));
528
529        // Configure the cache with a minimum TTL of 2 seconds.
530        let ttls = TtlConfig::from(TtlBounds {
531            positive_min_ttl: Some(Duration::from_secs(2)),
532            ..TtlBounds::default()
533        });
534        let cache = ResponseCache::new(1, ttls);
535
536        cache.insert(query.clone(), Ok(message), now);
537        let valid_until = cache.cache.get(&query).unwrap().valid_until;
538        // The returned lookup should use the cache's minimum TTL, since the
539        // query's TTL was below the minimum.
540        assert_eq!(valid_until, now + Duration::from_secs(2));
541
542        // Record should have TTL of 3 seconds.
543        let mut message = Message::response(0, OpCode::Query);
544        message.add_answer(Record::from_rdata(
545            name.clone(),
546            3,
547            RData::A(A::new(127, 0, 0, 1)),
548        ));
549
550        cache.insert(query.clone(), Ok(message), now);
551        let valid_until = cache.cache.get(&query).unwrap().valid_until;
552        // The returned lookup should use the record's TTL, since it's
553        // greater than the cache's minimum.
554        assert_eq!(valid_until, now + Duration::from_secs(3));
555    }
556
557    #[test]
558    fn test_positive_min_ttl_clamps_record_ttls() {
559        // Regression test: records with TTLs below positive_min_ttl must have their
560        // TTLs raised in the cached message. Otherwise `updated_ttl()` subtracts
561        // elapsed time from the original (low) TTL, which saturates to 0 long before
562        // the cache entry expires.
563        let now = Instant::now();
564
565        let name = Name::from_str("www.example.com.").unwrap();
566        let query = Query::query(name.clone(), RecordType::A);
567
568        // Upstream record has TTL=60, but positive_min_ttl is 3600.
569        let mut message = Message::response(0, OpCode::Query);
570        message.add_answer(Record::from_rdata(
571            name.clone(),
572            60,
573            RData::A(A::new(93, 184, 216, 34)),
574        ));
575
576        let ttls = TtlConfig::from(TtlBounds {
577            positive_min_ttl: Some(Duration::from_secs(3600)),
578            ..TtlBounds::default()
579        });
580        let cache = ResponseCache::new(1, ttls);
581
582        cache.insert(query.clone(), Ok(message), now);
583
584        // The cache stores the record with the clamped TTL (3600). At t=0 that is
585        // what clients receive.
586        let result = cache.get(&query, now).unwrap().unwrap();
587        assert_eq!(result.answers.first().unwrap().ttl, 3600);
588
589        // At t=61 the returned TTL counts down from the cached 3600, not the
590        // upstream 60.
591        let result = cache
592            .get(&query, now + Duration::from_secs(61))
593            .unwrap()
594            .unwrap();
595        assert_eq!(result.answers.first().unwrap().ttl, 3539);
596
597        // At t=3599: still valid, TTL=1.
598        let result = cache
599            .get(&query, now + Duration::from_secs(3599))
600            .unwrap()
601            .unwrap();
602        assert_eq!(result.answers.first().unwrap().ttl, 1);
603
604        // At t=3601: cache miss — a new upstream lookup will be issued.
605        assert!(cache.get(&query, now + Duration::from_secs(3601)).is_none());
606    }
607
608    #[test]
609    fn test_positive_max_ttl_clamps_record_ttls() {
610        // Regression test: records with TTLs above positive_max_ttl must have their
611        // TTLs lowered in the cached message. Otherwise clients see the original high
612        // TTL while the cache entry expires at positive_max_ttl, causing the TTL to
613        // appear to "reset" after every max_ttl interval.
614        let now = Instant::now();
615
616        let name = Name::from_str("www.example.com.").unwrap();
617        let query = Query::query(name.clone(), RecordType::A);
618
619        // Upstream record has TTL=3600, but positive_max_ttl is 120.
620        let mut message = Message::response(0, OpCode::Query);
621        message.add_answer(Record::from_rdata(
622            name.clone(),
623            3600,
624            RData::A(A::new(93, 184, 216, 34)),
625        ));
626
627        let ttls = TtlConfig::from(TtlBounds {
628            positive_max_ttl: Some(Duration::from_secs(120)),
629            ..TtlBounds::default()
630        });
631        let cache = ResponseCache::new(1, ttls);
632
633        cache.insert(query.clone(), Ok(message), now);
634
635        // The cache stores the record with the clamped TTL (120), not the
636        // upstream 3600.
637        let result = cache.get(&query, now).unwrap().unwrap();
638        assert_eq!(result.answers.first().unwrap().ttl, 120);
639
640        // At t=60 the returned TTL counts down from the cached 120.
641        let result = cache
642            .get(&query, now + Duration::from_secs(60))
643            .unwrap()
644            .unwrap();
645        assert_eq!(result.answers.first().unwrap().ttl, 60);
646
647        // At t=121: cache miss.
648        assert!(cache.get(&query, now + Duration::from_secs(121)).is_none());
649    }
650
651    #[test]
652    fn test_authority_ttl_does_not_shorten_answer_cache() {
653        // Regression test: authority section records (NS, SOA) with short TTLs must
654        // not reduce the cache lifetime of positive answers when answer records
655        // are present.
656        let now = Instant::now();
657
658        let name = Name::from_str("api.example.com.").unwrap();
659        let query = Query::query(name.clone(), RecordType::AAAA);
660
661        let mut message = Message::response(0, OpCode::Query);
662        // Answer: AAAA record with TTL=120
663        message.add_answer(Record::from_rdata(
664            name.clone(),
665            120,
666            RData::AAAA(AAAA::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1)),
667        ));
668        // Authority: NS record with TTL=30 (much shorter)
669        message.add_authority(Record::from_rdata(
670            Name::from_str("example.com.").unwrap(),
671            30,
672            RData::NS(NS(Name::from_str("ns1.example.com.").unwrap())),
673        ));
674
675        let ttls = TtlConfig::from(TtlBounds {
676            positive_min_ttl: Some(Duration::from_secs(3600)),
677            positive_max_ttl: Some(Duration::from_secs(28800)),
678            ..TtlBounds::default()
679        });
680        let cache = ResponseCache::new(1, ttls);
681
682        cache.insert(query.clone(), Ok(message), now);
683
684        // Cache should be valid for 3600s (answer TTL=120 raised to min_ttl=3600),
685        // NOT 30s from the authority NS record.
686        let valid_until = cache.cache.get(&query).unwrap().valid_until;
687        assert_eq!(valid_until, now + Duration::from_secs(3600));
688
689        // At t=130 (past the original authority TTL of 30): still a cache hit.
690        let result = cache
691            .get(&query, now + Duration::from_secs(130))
692            .unwrap()
693            .unwrap();
694        // AAAA answer TTL counts down from 3600.
695        assert_eq!(result.answers.first().unwrap().ttl, 3470);
696
697        // At t=3601: cache miss.
698        assert!(cache.get(&query, now + Duration::from_secs(3601)).is_none());
699    }
700
701    #[test]
702    fn test_negative_min_ttl() {
703        let now = Instant::now();
704
705        let name = Name::from_str("www.example.com.").unwrap();
706        let query = Query::query(name.clone(), RecordType::A);
707
708        // Configure the cache with a minimum TTL of 2 seconds.
709        let ttls = TtlConfig::from(TtlBounds {
710            negative_min_ttl: Some(Duration::from_secs(2)),
711            ..TtlBounds::default()
712        });
713        let cache = ResponseCache::new(1, ttls);
714
715        // Negative response should have TTL of 1 second.
716        let mut no_records = NoRecords::new(query.clone(), ResponseCode::NoError);
717        no_records.negative_ttl = Some(1);
718        cache.insert(query.clone(), Err(no_records.into()), now);
719        let valid_until = cache.cache.get(&query).unwrap().valid_until;
720        // The error's `valid_until` field should have been limited to 2 seconds.
721        assert_eq!(valid_until, now + Duration::from_secs(2));
722
723        // Negative response should have TTL of 3 seconds.
724        let mut no_records = NoRecords::new(query.clone(), ResponseCode::NoError);
725        no_records.negative_ttl = Some(3);
726        cache.insert(query.clone(), Err(no_records.into()), now);
727        let valid_until = cache.cache.get(&query).unwrap().valid_until;
728        // The error's `valid_until` field should not have been limited, as it was over the minimum
729        // TTL.
730        assert_eq!(valid_until, now + Duration::from_secs(3));
731    }
732
733    #[test]
734    fn test_positive_max_ttl() {
735        let now = Instant::now();
736
737        let name = Name::from_str("www.example.com.").unwrap();
738        let query = Query::query(name.clone(), RecordType::A);
739        // Record should have TTL of 62 seconds.
740        let mut message = Message::response(0, OpCode::Query);
741        message.add_answer(Record::from_rdata(
742            name.clone(),
743            62,
744            RData::A(A::new(127, 0, 0, 1)),
745        ));
746
747        // Configure the cache with a maximum TTL of 60 seconds.
748        let ttls = TtlConfig::from(TtlBounds {
749            positive_max_ttl: Some(Duration::from_secs(60)),
750            ..Default::default()
751        });
752        let cache = ResponseCache::new(1, ttls);
753
754        cache.insert(query.clone(), Ok(message), now);
755        let valid_until = cache.cache.get(&query).unwrap().valid_until;
756        // The returned lookup should use the cache's minimum TTL, since the
757        // query's TTL was above the maximum.
758        assert_eq!(valid_until, now + Duration::from_secs(60));
759
760        // Record should have TTL of 59 seconds.
761        let mut message = Message::response(0, OpCode::Query);
762        message.add_answer(Record::from_rdata(
763            name.clone(),
764            59,
765            RData::A(A::new(127, 0, 0, 1)),
766        ));
767
768        cache.insert(query.clone(), Ok(message), now);
769        let valid_until = cache.cache.get(&query).unwrap().valid_until;
770        // The returned lookup should use the record's TTL, since it's
771        // below than the cache's maximum.
772        assert_eq!(valid_until, now + Duration::from_secs(59));
773    }
774
775    #[test]
776    fn test_negative_max_ttl() {
777        let now = Instant::now();
778
779        let name = Name::from_str("www.example.com.").unwrap();
780        let query = Query::query(name.clone(), RecordType::A);
781
782        // Configure the cache with a maximum TTL of 60 seconds.
783        let ttls = TtlConfig::from(TtlBounds {
784            negative_max_ttl: Some(Duration::from_secs(60)),
785            ..TtlBounds::default()
786        });
787        let cache = ResponseCache::new(1, ttls);
788
789        // Negative response should have TTL of 62 seconds.
790        let mut no_records = NoRecords::new(query.clone(), ResponseCode::NoError);
791        no_records.negative_ttl = Some(62);
792        cache.insert(query.clone(), Err(no_records.into()), now);
793        let valid_until = cache.cache.get(&query).unwrap().valid_until;
794        // The error's `valid_until` field should have been limited to 60 seconds.
795        assert_eq!(valid_until, now + Duration::from_secs(60));
796
797        // Negative response should have TTL of 59 seconds.
798        let mut no_records = NoRecords::new(query.clone(), ResponseCode::NoError);
799        no_records.negative_ttl = Some(59);
800        cache.insert(query.clone(), Err(no_records.into()), now);
801        let valid_until = cache.cache.get(&query).unwrap().valid_until;
802        // The error's `valid_until` field should not have been limited, as it was under the maximum
803        // TTL.
804        assert_eq!(valid_until, now + Duration::from_secs(59));
805    }
806
807    #[test]
808    fn test_insert() {
809        let now = Instant::now();
810
811        let name = Name::from_str("www.example.com.").unwrap();
812        let query = Query::query(name.clone(), RecordType::A);
813        let mut message = Message::response(0, OpCode::Query);
814        message.add_answer(Record::from_rdata(
815            name.clone(),
816            1,
817            RData::A(A::new(127, 0, 0, 1)),
818        ));
819        let cache = ResponseCache::new(1, TtlConfig::default());
820        cache.insert(query.clone(), Ok(message.clone()), now);
821
822        let result = cache.get(&query, now).unwrap();
823        let cache_message = result.unwrap();
824        assert_eq!(cache_message.answers, message.answers);
825    }
826
827    #[test]
828    fn test_insert_negative() {
829        subscribe();
830        let now = Instant::now();
831
832        let query = Query::query(
833            Name::from_str("www.example.com.").unwrap(),
834            RecordType::AAAA,
835        );
836
837        let mut norecs = NoRecords::new(query.clone(), ResponseCode::NXDomain);
838        norecs.negative_ttl = Some(10);
839        let error = NetError::from(norecs);
840        let cache = ResponseCache::new(1, TtlConfig::default());
841
842        cache.insert(query.clone(), Err(error), now);
843
844        let cache_err = cache.get(&query, now).unwrap().unwrap_err();
845        let NetError::Dns(DnsError::NoRecordsFound(_no_records)) = &cache_err else {
846            panic!("expected NoRecordsFound");
847        };
848
849        // Cache should be expired
850        assert!(cache.get(&query, now + Duration::from_secs(11)).is_none());
851    }
852
853    #[test]
854    fn test_update_ttl() {
855        let now = Instant::now();
856
857        let name = Name::from_str("www.example.com.").unwrap();
858        let query = Query::query(name.clone(), RecordType::A);
859        let mut message = Message::response(0, OpCode::Query);
860        message.add_answer(Record::from_rdata(
861            name.clone(),
862            10,
863            RData::A(A::new(127, 0, 0, 1)),
864        ));
865        let cache = ResponseCache::new(1, TtlConfig::default());
866        cache.insert(query.clone(), Ok(message), now);
867
868        let result = cache.get(&query, now + Duration::from_secs(2)).unwrap();
869        let cache_message = result.unwrap();
870        let record = cache_message.answers.first().unwrap();
871        assert_eq!(record.ttl, 8);
872    }
873
874    #[test]
875    fn test_update_ttl_negative() -> Result<(), NetError> {
876        subscribe();
877        let now = Instant::now();
878        let name = Name::from_str("www.example.com.")?;
879        let ns_name = Name::from_str("ns1.example.com")?;
880        let zone_name = name.base_name();
881        let query = Query::query(name.clone(), RecordType::AAAA);
882
883        let mut norecs = NoRecords::new(query.clone(), ResponseCode::NXDomain);
884        norecs.negative_ttl = Some(10);
885        norecs.soa = Some(Box::new(Record::from_rdata(
886            zone_name.clone(),
887            10,
888            SOA::new(name.base_name(), name.clone(), 1, 1, 1, 1, 1),
889        )));
890        norecs.authorities = Some(Arc::new([Record::from_rdata(
891            zone_name.clone(),
892            10,
893            RData::NS(NS(ns_name.clone())),
894        )]));
895        norecs.ns = Some(Arc::new([ForwardNSData {
896            ns: Record::from_rdata(zone_name.clone(), 10, RData::NS(NS(ns_name.clone()))),
897            glue: Arc::new([Record::from_rdata(
898                ns_name.clone(),
899                10,
900                RData::A(A([192, 0, 2, 1].into())),
901            )]),
902        }]));
903
904        let error = NetError::from(norecs);
905
906        let cache = ResponseCache::new(1, TtlConfig::default());
907        cache.insert(query.clone(), Err(error), now);
908
909        let cache_err = cache.get(&query, now).unwrap().unwrap_err();
910        let NetError::Dns(DnsError::NoRecordsFound(no_records)) = &cache_err else {
911            panic!("expected NoRecordsFound");
912        };
913
914        let Some(soa) = no_records.soa.clone() else {
915            panic!("no SOA in NoRecordsFound");
916        };
917        assert_eq!(soa.ttl, 10);
918
919        let cache_err = cache
920            .get(&query, now + Duration::from_secs(2))
921            .unwrap()
922            .unwrap_err();
923        let NetError::Dns(DnsError::NoRecordsFound(NoRecords {
924            negative_ttl: Some(negative_ttl),
925            soa: Some(soa),
926            authorities: Some(authorities),
927            ns: Some(ns),
928            ..
929        })) = &cache_err
930        else {
931            panic!("expected NoRecordsFound with negative_ttl, soa, authorities, and ns");
932        };
933
934        assert_eq!(*negative_ttl, 8);
935        assert_eq!(soa.ttl, 8);
936        assert_eq!(authorities[0].ttl, 8);
937        assert_eq!(ns[0].ns.ttl, 8);
938
939        // Cache should be expired
940        assert!(cache.get(&query, now + Duration::from_secs(11)).is_none());
941        Ok(())
942    }
943
944    #[test]
945    fn test_insert_ttl() {
946        let now = Instant::now();
947
948        let name = Name::from_str("www.example.com.").unwrap();
949        let query = Query::query(name.clone(), RecordType::A);
950
951        // TTL of entry should be 1.
952        let mut message = Message::response(0, OpCode::Query);
953        message.add_answer(Record::from_rdata(
954            name.clone(),
955            1,
956            RData::A(A::new(127, 0, 0, 1)),
957        ));
958        message.add_answer(Record::from_rdata(name, 2, RData::A(A::new(127, 0, 0, 2))));
959
960        let cache = ResponseCache::new(1, TtlConfig::default());
961        cache.insert(query.clone(), Ok(message), now);
962
963        // Entry is still valid.
964        cache
965            .get(&query, now + Duration::from_secs(1))
966            .unwrap()
967            .unwrap();
968
969        // Entry is expired.
970        let option = cache.get(&query, now + Duration::from_secs(2));
971        assert!(option.is_none());
972    }
973
974    #[test]
975    fn test_ttl_different_query_types() {
976        let now = Instant::now();
977        let name = Name::from_str("www.example.com.").unwrap();
978
979        // Store records with a TTL of 1 second.
980        let query_a = Query::query(name.clone(), RecordType::A);
981        let rdata_a = RData::A(A::new(127, 0, 0, 1));
982        let mut message_a = Message::response(0, OpCode::Query);
983        message_a.add_answer(Record::from_rdata(name.clone(), 1, rdata_a.clone()));
984
985        let query_txt = Query::query(name.clone(), RecordType::TXT);
986        let rdata_txt = RData::TXT(TXT::new(vec!["data".to_string()]));
987        let mut message_txt = Message::response(0, OpCode::Query);
988        message_txt.add_answer(Record::from_rdata(name.clone(), 1, rdata_txt.clone()));
989
990        // Set separate positive_min_ttl limits for TXT queries and all others.
991        let mut ttl_config = TtlConfig::from(TtlBounds {
992            positive_min_ttl: Some(Duration::from_secs(2)),
993            ..TtlBounds::default()
994        });
995        ttl_config.with_query_type_ttl_bounds(
996            RecordType::TXT,
997            TtlBounds {
998                positive_min_ttl: Some(Duration::from_secs(5)),
999                ..TtlBounds::default()
1000            },
1001        );
1002        let cache = ResponseCache::new(2, ttl_config);
1003
1004        cache.insert(query_a.clone(), Ok(message_a), now);
1005        // This should use the cache's default minimum TTL, since the record's TTL was below the
1006        // minimum.
1007        assert_eq!(
1008            cache.cache.get(&query_a).unwrap().valid_until,
1009            now + Duration::from_secs(2)
1010        );
1011
1012        cache.insert(query_txt.clone(), Ok(message_txt), now);
1013        // This should use the minimum for TTL records, since the record's TTL was below the
1014        // minimum.
1015        assert_eq!(
1016            cache.cache.get(&query_txt).unwrap().valid_until,
1017            now + Duration::from_secs(5)
1018        );
1019
1020        // store records with a TTL of 7 seconds.
1021        let mut message_a = Message::response(0, OpCode::Query);
1022        message_a.add_answer(Record::from_rdata(name.clone(), 7, rdata_a));
1023
1024        let mut message_txt = Message::response(0, OpCode::Query);
1025        message_txt.add_answer(Record::from_rdata(name.clone(), 7, rdata_txt));
1026
1027        cache.insert(query_a.clone(), Ok(message_a), now);
1028        // This should use the record's TTL, since it's greater than the default minimum TTL.
1029        assert_eq!(
1030            cache.cache.get(&query_a).unwrap().valid_until,
1031            now + Duration::from_secs(7)
1032        );
1033
1034        cache.insert(query_txt.clone(), Ok(message_txt), now);
1035        // This should use the record's TTL, since it's greater than the minimum TTL for TXT records.
1036        assert_eq!(
1037            cache.cache.get(&query_txt).unwrap().valid_until,
1038            now + Duration::from_secs(7)
1039        );
1040    }
1041
1042    #[cfg(feature = "serde")]
1043    #[test]
1044    fn ttl_config_deserialize_errors() {
1045        // Duplicate of "default"
1046        let input = r#"[default]
1047positive_max_ttl = 3600
1048[default]
1049positive_max_ttl = 3599"#;
1050        let error = toml::from_str::<TtlConfig>(input).unwrap_err();
1051        assert!(
1052            error.message().contains("duplicate key"),
1053            "wrong error message: {error}"
1054        );
1055
1056        // Duplicate of a record type
1057        let input = r#"[default]
1058positive_max_ttl = 86400
1059[OPENPGPKEY]
1060positive_max_ttl = 3600
1061[OPENPGPKEY]
1062negative_min_ttl = 60"#;
1063        let error = toml::from_str::<TtlConfig>(input).unwrap_err();
1064        assert!(
1065            error.message().contains("duplicate key"),
1066            "wrong error message: {error}"
1067        );
1068
1069        // Neither "default" nor a record type
1070        let input = r#"[not_a_record_type]
1071positive_max_ttl = 3600"#;
1072        let error = toml::from_str::<TtlConfig>(input).unwrap_err();
1073        assert!(
1074            error.message().contains("data did not match any variant"),
1075            "wrong error message: {error}"
1076        );
1077
1078        // Array instead of table
1079        #[derive(Debug, Deserialize)]
1080        struct Wrapper {
1081            #[allow(unused)]
1082            cache_policy: TtlConfig,
1083        }
1084        let input = r#"cache_policy = []"#;
1085        let error = toml::from_str::<Wrapper>(input).unwrap_err();
1086        assert!(
1087            error.message().contains("invalid type: sequence"),
1088            "wrong error message: {error}"
1089        );
1090
1091        // String instead of table
1092        let input = r#"cache_policy = "yes""#;
1093        let error = toml::from_str::<Wrapper>(input).unwrap_err();
1094        assert!(
1095            error.message().contains("invalid type: string"),
1096            "wrong error message: {error}"
1097        );
1098    }
1099}