1use std::{
4 collections::HashMap,
5 ops::RangeInclusive,
6 sync::Arc,
7 time::{Duration, Instant},
8};
9
10use moka::{Expiry, sync::Cache};
11#[cfg(feature = "serde")]
12use serde::Deserialize;
13
14use crate::{
15 config,
16 net::{DnsError, NetError, NoRecords},
17 proto::{
18 op::{Message, Query},
19 rr::RecordType,
20 },
21};
22
23#[derive(Clone, Debug)]
25pub struct ResponseCache {
26 cache: Cache<Query, Entry>,
27 ttl_config: Arc<TtlConfig>,
28}
29
30impl ResponseCache {
31 pub fn new(capacity: u64, ttl_config: TtlConfig) -> Self {
38 Self {
39 cache: Cache::builder()
40 .max_capacity(capacity)
41 .expire_after(EntryExpiry)
42 .build(),
43 ttl_config: Arc::new(ttl_config),
44 }
45 }
46
47 pub fn insert(&self, query: Query, result: Result<Message, NetError>, now: Instant) {
49 let (ttl, result) = match result {
50 Ok(mut message) => {
51 let ttl = self.clamp_positive_ttls(query.query_type(), &mut message);
52 (ttl, Ok(message))
53 }
54 Err(NetError::Dns(DnsError::NoRecordsFound(no_records))) => {
55 let (negative_min_ttl, negative_max_ttl) = self
56 .ttl_config
57 .negative_response_ttl_bounds(query.query_type())
58 .into_inner();
59 let ttl = if let Some(ttl) = no_records.negative_ttl {
60 Duration::from_secs(u64::from(ttl)).clamp(negative_min_ttl, negative_max_ttl)
61 } else {
62 negative_min_ttl
63 };
64 (
65 ttl,
66 Err(NetError::Dns(DnsError::NoRecordsFound(no_records))),
67 )
68 }
69 Err(_) => return,
70 };
71 let valid_until = now + ttl;
72 self.cache.insert(
73 query,
74 Entry {
75 result: Arc::new(result),
76 original_time: now,
77 valid_until,
78 },
79 );
80 }
81
82 pub fn get(&self, query: &Query, now: Instant) -> Option<Result<Message, NetError>> {
84 let entry = self.cache.get(query)?;
85 if !entry.is_current(now) {
86 return None;
87 }
88 Some(entry.updated_ttl(now))
89 }
90
91 pub(crate) fn clamp_positive_ttls(
99 &self,
100 query_type: RecordType,
101 message: &mut Message,
102 ) -> Duration {
103 for record in message
104 .answers
105 .iter_mut()
106 .chain(message.authorities.iter_mut())
107 .chain(message.additionals.iter_mut())
108 {
109 let (min_secs, max_secs) = self
110 .ttl_config
111 .positive_ttl_bounds_secs(record.record_type());
112 record.ttl = record.ttl.clamp(min_secs, max_secs);
113 }
114
115 let (positive_min_ttl, positive_max_ttl) = self
116 .ttl_config
117 .positive_response_ttl_bounds(query_type)
118 .into_inner();
119
120 let min_ttl = message
124 .all_sections()
125 .filter(|r| r.record_type() == query_type)
126 .map(|r| Duration::from_secs(r.ttl.into()))
127 .min();
128
129 min_ttl
130 .unwrap_or(positive_min_ttl)
131 .clamp(positive_min_ttl, positive_max_ttl)
132 }
133
134 pub(crate) fn clear(&self) {
135 self.cache.invalidate_all();
136 }
137
138 pub(crate) fn clear_query(&self, query: &Query) {
139 self.cache.invalidate(query);
140 }
141
142 #[cfg(feature = "metrics")]
144 pub(crate) fn entry_count(&self) -> u64 {
145 #[cfg(test)]
146 {
147 self.cache.run_pending_tasks();
152 }
153
154 self.cache.entry_count()
155 }
156}
157
158#[derive(Debug, Clone)]
163struct Entry {
164 result: Arc<Result<Message, NetError>>,
165 original_time: Instant,
166 valid_until: Instant,
167}
168
169impl Entry {
170 fn updated_ttl(&self, now: Instant) -> Result<Message, NetError> {
173 let elapsed = u32::try_from(now.saturating_duration_since(self.original_time).as_secs())
174 .unwrap_or(u32::MAX);
175 match &*self.result {
176 Ok(response) => {
177 let mut response = response.clone();
178 for records in [
179 &mut response.answers,
180 &mut response.authorities,
181 &mut response.additionals,
182 ] {
183 for record in records {
184 record.decrement_ttl(elapsed);
185 }
186 }
187 Ok(response)
188 }
189 Err(e) => {
190 let mut e = e.clone();
191
192 if let NetError::Dns(DnsError::NoRecordsFound(NoRecords {
195 negative_ttl,
196 soa,
197 authorities,
198 ns,
199 ..
200 })) = &mut e
201 {
202 if let Some(ttl) = negative_ttl {
203 *ttl = ttl.saturating_sub(elapsed);
204 }
205
206 if let Some(soa) = soa {
207 soa.decrement_ttl(elapsed);
208 }
209
210 if let Some(recs) = authorities.take() {
211 authorities.replace(Arc::from(
212 recs.iter()
213 .cloned()
214 .map(|mut rec| {
215 rec.decrement_ttl(elapsed);
216 rec
217 })
218 .collect::<Vec<_>>(),
219 ));
220 }
221
222 if let Some(ns_recs) = ns.take() {
223 ns.replace(Arc::from(
224 ns_recs
225 .iter()
226 .cloned()
227 .map(|mut ns| {
228 ns.ns.decrement_ttl(elapsed);
229 ns.glue = Arc::from(
230 ns.glue
231 .iter()
232 .cloned()
233 .map(|mut glue| {
234 glue.decrement_ttl(elapsed);
235 glue
236 })
237 .collect::<Vec<_>>(),
238 );
239
240 ns
241 })
242 .collect::<Vec<_>>(),
243 ));
244 }
245 }
246 Err(e)
247 }
248 }
249 }
250
251 fn is_current(&self, now: Instant) -> bool {
253 now <= self.valid_until
254 }
255
256 fn ttl(&self, now: Instant) -> Duration {
258 self.valid_until.saturating_duration_since(now)
259 }
260}
261
262struct EntryExpiry;
263
264impl Expiry<Query, Entry> for EntryExpiry {
265 fn expire_after_create(
266 &self,
267 _key: &Query,
268 value: &Entry,
269 created_at: Instant,
270 ) -> Option<Duration> {
271 Some(value.ttl(created_at))
272 }
273
274 fn expire_after_update(
275 &self,
276 _key: &Query,
277 value: &Entry,
278 updated_at: Instant,
279 _duration_until_expiry: Option<Duration>,
280 ) -> Option<Duration> {
281 Some(value.ttl(updated_at))
282 }
283}
284
285#[derive(Clone, Debug, Default, PartialEq, Eq)]
297#[cfg_attr(feature = "serde", derive(Deserialize))]
298#[cfg_attr(
299 feature = "serde",
300 serde(from = "ttl_config_deserialize::TtlConfigMap")
301)]
302pub struct TtlConfig {
303 default: TtlBounds,
305
306 by_query_type: HashMap<RecordType, TtlBounds>,
308}
309
310impl TtlConfig {
311 pub fn from_opts(opts: &config::ResolverOpts) -> Self {
313 Self::from(TtlBounds {
314 positive_min_ttl: opts.positive_min_ttl,
315 negative_min_ttl: opts.negative_min_ttl,
316 positive_max_ttl: opts.positive_max_ttl,
317 negative_max_ttl: opts.negative_max_ttl,
318 })
319 }
320
321 pub fn with_query_type_ttl_bounds(
326 &mut self,
327 query_type: RecordType,
328 bounds: TtlBounds,
329 ) -> &mut Self {
330 self.by_query_type.insert(query_type, bounds);
331 self
332 }
333
334 fn positive_ttl_bounds_secs(&self, record_type: RecordType) -> (u32, u32) {
339 let (min, max) = self.positive_response_ttl_bounds(record_type).into_inner();
340 (
341 u32::try_from(min.as_secs()).unwrap_or(MAX_TTL),
342 u32::try_from(max.as_secs()).unwrap_or(MAX_TTL),
343 )
344 }
345
346 pub fn positive_response_ttl_bounds(&self, query_type: RecordType) -> RangeInclusive<Duration> {
348 let bounds = self.by_query_type.get(&query_type).unwrap_or(&self.default);
349 let min = bounds
350 .positive_min_ttl
351 .unwrap_or_else(|| Duration::from_secs(0));
352 let max = bounds
353 .positive_max_ttl
354 .unwrap_or_else(|| Duration::from_secs(u64::from(MAX_TTL)));
355 min..=max
356 }
357
358 pub fn negative_response_ttl_bounds(&self, query_type: RecordType) -> RangeInclusive<Duration> {
360 let bounds = self.by_query_type.get(&query_type).unwrap_or(&self.default);
361 let min = bounds
362 .negative_min_ttl
363 .unwrap_or_else(|| Duration::from_secs(0));
364 let max = bounds
365 .negative_max_ttl
366 .unwrap_or_else(|| Duration::from_secs(u64::from(MAX_TTL)));
367 min..=max
368 }
369}
370
371impl From<TtlBounds> for TtlConfig {
372 fn from(default: TtlBounds) -> Self {
373 Self {
374 default,
375 by_query_type: HashMap::default(),
376 }
377 }
378}
379
/// Minimum and maximum cache TTLs for positive and negative responses.
///
/// An unset (`None`) minimum falls back to zero. An unset maximum falls back to [`MAX_TTL`]
/// seconds when [`TtlConfig`] turns these into ranges.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Deserialize), serde(deny_unknown_fields))]
pub struct TtlBounds {
    /// Floor for positive-response TTLs; shorter TTLs are raised to this value.
    #[cfg_attr(
        feature = "serde",
        serde(default, deserialize_with = "config::duration_opt::deserialize")
    )]
    positive_min_ttl: Option<Duration>,

    /// Floor for negative-response TTLs. Also used when a negative response carries no TTL.
    #[cfg_attr(
        feature = "serde",
        serde(default, deserialize_with = "config::duration_opt::deserialize")
    )]
    negative_min_ttl: Option<Duration>,

    /// Ceiling for positive-response TTLs; longer TTLs are lowered to this value.
    #[cfg_attr(
        feature = "serde",
        serde(default, deserialize_with = "config::duration_opt::deserialize")
    )]
    positive_max_ttl: Option<Duration>,

    /// Ceiling for negative-response TTLs.
    #[cfg_attr(
        feature = "serde",
        serde(default, deserialize_with = "config::duration_opt::deserialize")
    )]
    negative_max_ttl: Option<Duration>,
}
425
#[cfg(feature = "serde")]
mod ttl_config_deserialize {
    //! Intermediate map form used to deserialize a [`TtlConfig`].

    use std::collections::HashMap;

    use serde::Deserialize;

    use super::{TtlBounds, TtlConfig};
    use crate::proto::rr::RecordType;

    /// Raw configuration map. Each key is `default` or a record type name.
    #[derive(Deserialize)]
    pub(super) struct TtlConfigMap(HashMap<TtlConfigField, TtlBounds>);

    impl From<TtlConfigMap> for TtlConfig {
        fn from(TtlConfigMap(entries): TtlConfigMap) -> Self {
            // Without a `default` key, the default bounds stay fully unset.
            let mut config = Self {
                default: TtlBounds::default(),
                by_query_type: HashMap::new(),
            };

            for (key, bounds) in entries {
                match key {
                    TtlConfigField::Default => config.default = bounds,
                    TtlConfigField::RecordType(record_type) => {
                        config.by_query_type.insert(record_type, bounds);
                    }
                }
            }

            config
        }
    }

    /// A key of the configuration map.
    #[derive(Deserialize, Eq, Hash, PartialEq)]
    enum TtlConfigField {
        /// The literal key `default`.
        #[serde(rename = "default")]
        Default,
        /// Any other key, parsed as a record type.
        #[serde(untagged)]
        RecordType(RecordType),
    }
}
465
/// Default ceiling, in seconds (one day), for cached TTLs when no maximum is configured.
pub const MAX_TTL: u32 = 24 * 60 * 60;
472
#[cfg(test)]
mod tests {
    use std::{
        str::FromStr,
        time::{Duration, Instant},
    };

    #[cfg(feature = "serde")]
    use serde::Deserialize;

    use super::*;
    use crate::{
        net::{ForwardNSData, NetError},
        proto::{
            op::{Message, OpCode, Query, ResponseCode},
            rr::{
                Name, RData, Record, RecordType,
                rdata::{A, AAAA, NS, SOA, TXT},
            },
        },
    };
    use test_support::subscribe;

    #[test]
    fn test_is_current() {
        let now = Instant::now();
        let not_the_future = now + Duration::from_secs(4);
        let future = now + Duration::from_secs(5);
        let past_the_future = now + Duration::from_secs(6);

        let entry = Entry {
            result: Err(NetError::Message("test error")).into(),
            original_time: now,
            valid_until: future,
        };

        assert!(entry.is_current(now));
        assert!(entry.is_current(not_the_future));
        // The deadline itself is still current.
        assert!(entry.is_current(future));
        assert!(!entry.is_current(past_the_future));
    }

    #[test]
    fn test_positive_min_ttl() {
        let now = Instant::now();

        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);
        // Record TTL (1s) is below the configured minimum (2s).
        let mut message = Message::response(0, OpCode::Query);
        message.add_answer(Record::from_rdata(
            name.clone(),
            1,
            RData::A(A::new(127, 0, 0, 1)),
        ));

        let ttls = TtlConfig::from(TtlBounds {
            positive_min_ttl: Some(Duration::from_secs(2)),
            ..TtlBounds::default()
        });
        let cache = ResponseCache::new(1, ttls);

        cache.insert(query.clone(), Ok(message), now);
        let valid_until = cache.cache.get(&query).unwrap().valid_until;
        assert_eq!(valid_until, now + Duration::from_secs(2));

        // Record TTL (3s) is above the minimum, so it is used as-is.
        let mut message = Message::response(0, OpCode::Query);
        message.add_answer(Record::from_rdata(
            name.clone(),
            3,
            RData::A(A::new(127, 0, 0, 1)),
        ));

        cache.insert(query.clone(), Ok(message), now);
        let valid_until = cache.cache.get(&query).unwrap().valid_until;
        assert_eq!(valid_until, now + Duration::from_secs(3));
    }

    #[test]
    fn test_positive_min_ttl_clamps_record_ttls() {
        let now = Instant::now();

        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);

        let mut message = Message::response(0, OpCode::Query);
        message.add_answer(Record::from_rdata(
            name.clone(),
            60,
            RData::A(A::new(93, 184, 216, 34)),
        ));

        let ttls = TtlConfig::from(TtlBounds {
            positive_min_ttl: Some(Duration::from_secs(3600)),
            ..TtlBounds::default()
        });
        let cache = ResponseCache::new(1, ttls);

        cache.insert(query.clone(), Ok(message), now);

        // The 60s record TTL is raised to the 3600s minimum in the cached copy.
        let result = cache.get(&query, now).unwrap().unwrap();
        assert_eq!(result.answers.first().unwrap().ttl, 3600);

        // Served TTLs count down from the clamped value, even past the original 60s.
        let result = cache
            .get(&query, now + Duration::from_secs(61))
            .unwrap()
            .unwrap();
        assert_eq!(result.answers.first().unwrap().ttl, 3539);

        let result = cache
            .get(&query, now + Duration::from_secs(3599))
            .unwrap()
            .unwrap();
        assert_eq!(result.answers.first().unwrap().ttl, 1);

        assert!(cache.get(&query, now + Duration::from_secs(3601)).is_none());
    }

    #[test]
    fn test_positive_max_ttl_clamps_record_ttls() {
        let now = Instant::now();

        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);

        let mut message = Message::response(0, OpCode::Query);
        message.add_answer(Record::from_rdata(
            name.clone(),
            3600,
            RData::A(A::new(93, 184, 216, 34)),
        ));

        let ttls = TtlConfig::from(TtlBounds {
            positive_max_ttl: Some(Duration::from_secs(120)),
            ..TtlBounds::default()
        });
        let cache = ResponseCache::new(1, ttls);

        cache.insert(query.clone(), Ok(message), now);

        // The 3600s record TTL is lowered to the 120s maximum in the cached copy.
        let result = cache.get(&query, now).unwrap().unwrap();
        assert_eq!(result.answers.first().unwrap().ttl, 120);

        let result = cache
            .get(&query, now + Duration::from_secs(60))
            .unwrap()
            .unwrap();
        assert_eq!(result.answers.first().unwrap().ttl, 60);

        assert!(cache.get(&query, now + Duration::from_secs(121)).is_none());
    }

    #[test]
    fn test_authority_ttl_does_not_shorten_answer_cache() {
        let now = Instant::now();

        let name = Name::from_str("api.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::AAAA);

        let mut message = Message::response(0, OpCode::Query);
        message.add_answer(Record::from_rdata(
            name.clone(),
            120,
            RData::AAAA(AAAA::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1)),
        ));
        // A shorter-lived authority record of a different type than the query.
        message.add_authority(Record::from_rdata(
            Name::from_str("example.com.").unwrap(),
            30,
            RData::NS(NS(Name::from_str("ns1.example.com.").unwrap())),
        ));

        let ttls = TtlConfig::from(TtlBounds {
            positive_min_ttl: Some(Duration::from_secs(3600)),
            positive_max_ttl: Some(Duration::from_secs(28800)),
            ..TtlBounds::default()
        });
        let cache = ResponseCache::new(1, ttls);

        cache.insert(query.clone(), Ok(message), now);

        // Cache lifetime is driven by the AAAA answer, clamped up to the minimum.
        let valid_until = cache.cache.get(&query).unwrap().valid_until;
        assert_eq!(valid_until, now + Duration::from_secs(3600));

        let result = cache
            .get(&query, now + Duration::from_secs(130))
            .unwrap()
            .unwrap();
        assert_eq!(result.answers.first().unwrap().ttl, 3470);

        assert!(cache.get(&query, now + Duration::from_secs(3601)).is_none());
    }

    #[test]
    fn test_negative_min_ttl() {
        let now = Instant::now();

        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);

        let ttls = TtlConfig::from(TtlBounds {
            negative_min_ttl: Some(Duration::from_secs(2)),
            ..TtlBounds::default()
        });
        let cache = ResponseCache::new(1, ttls);

        // Negative TTL below the minimum is raised to it.
        let mut no_records = NoRecords::new(query.clone(), ResponseCode::NoError);
        no_records.negative_ttl = Some(1);
        cache.insert(query.clone(), Err(no_records.into()), now);
        let valid_until = cache.cache.get(&query).unwrap().valid_until;
        assert_eq!(valid_until, now + Duration::from_secs(2));

        // Negative TTL above the minimum is used as-is.
        let mut no_records = NoRecords::new(query.clone(), ResponseCode::NoError);
        no_records.negative_ttl = Some(3);
        cache.insert(query.clone(), Err(no_records.into()), now);
        let valid_until = cache.cache.get(&query).unwrap().valid_until;
        assert_eq!(valid_until, now + Duration::from_secs(3));
    }

    #[test]
    fn test_positive_max_ttl() {
        let now = Instant::now();

        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);
        // Record TTL (62s) is above the configured maximum (60s).
        let mut message = Message::response(0, OpCode::Query);
        message.add_answer(Record::from_rdata(
            name.clone(),
            62,
            RData::A(A::new(127, 0, 0, 1)),
        ));

        let ttls = TtlConfig::from(TtlBounds {
            positive_max_ttl: Some(Duration::from_secs(60)),
            ..Default::default()
        });
        let cache = ResponseCache::new(1, ttls);

        cache.insert(query.clone(), Ok(message), now);
        let valid_until = cache.cache.get(&query).unwrap().valid_until;
        assert_eq!(valid_until, now + Duration::from_secs(60));

        // Record TTL (59s) is below the maximum, so it is used as-is.
        let mut message = Message::response(0, OpCode::Query);
        message.add_answer(Record::from_rdata(
            name.clone(),
            59,
            RData::A(A::new(127, 0, 0, 1)),
        ));

        cache.insert(query.clone(), Ok(message), now);
        let valid_until = cache.cache.get(&query).unwrap().valid_until;
        assert_eq!(valid_until, now + Duration::from_secs(59));
    }

    #[test]
    fn test_negative_max_ttl() {
        let now = Instant::now();

        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);

        let ttls = TtlConfig::from(TtlBounds {
            negative_max_ttl: Some(Duration::from_secs(60)),
            ..TtlBounds::default()
        });
        let cache = ResponseCache::new(1, ttls);

        // Negative TTL above the maximum is lowered to it.
        let mut no_records = NoRecords::new(query.clone(), ResponseCode::NoError);
        no_records.negative_ttl = Some(62);
        cache.insert(query.clone(), Err(no_records.into()), now);
        let valid_until = cache.cache.get(&query).unwrap().valid_until;
        assert_eq!(valid_until, now + Duration::from_secs(60));

        // Negative TTL below the maximum is used as-is.
        let mut no_records = NoRecords::new(query.clone(), ResponseCode::NoError);
        no_records.negative_ttl = Some(59);
        cache.insert(query.clone(), Err(no_records.into()), now);
        let valid_until = cache.cache.get(&query).unwrap().valid_until;
        assert_eq!(valid_until, now + Duration::from_secs(59));
    }

    #[test]
    fn test_insert() {
        let now = Instant::now();

        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);
        let mut message = Message::response(0, OpCode::Query);
        message.add_answer(Record::from_rdata(
            name.clone(),
            1,
            RData::A(A::new(127, 0, 0, 1)),
        ));
        let cache = ResponseCache::new(1, TtlConfig::default());
        cache.insert(query.clone(), Ok(message.clone()), now);

        let result = cache.get(&query, now).unwrap();
        let cache_message = result.unwrap();
        assert_eq!(cache_message.answers, message.answers);
    }

    #[test]
    fn test_insert_negative() {
        subscribe();
        let now = Instant::now();

        let query = Query::query(
            Name::from_str("www.example.com.").unwrap(),
            RecordType::AAAA,
        );

        let mut norecs = NoRecords::new(query.clone(), ResponseCode::NXDomain);
        norecs.negative_ttl = Some(10);
        let error = NetError::from(norecs);
        let cache = ResponseCache::new(1, TtlConfig::default());

        cache.insert(query.clone(), Err(error), now);

        let cache_err = cache.get(&query, now).unwrap().unwrap_err();
        let NetError::Dns(DnsError::NoRecordsFound(no_records)) = &cache_err else {
            panic!("expected NoRecordsFound");
        };
        // Nothing has elapsed yet, so the negative TTL comes back unchanged.
        assert_eq!(no_records.negative_ttl, Some(10));

        assert!(cache.get(&query, now + Duration::from_secs(11)).is_none());
    }

    #[test]
    fn test_update_ttl() {
        let now = Instant::now();

        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);
        let mut message = Message::response(0, OpCode::Query);
        message.add_answer(Record::from_rdata(
            name.clone(),
            10,
            RData::A(A::new(127, 0, 0, 1)),
        ));
        let cache = ResponseCache::new(1, TtlConfig::default());
        cache.insert(query.clone(), Ok(message), now);

        let result = cache.get(&query, now + Duration::from_secs(2)).unwrap();
        let cache_message = result.unwrap();
        let record = cache_message.answers.first().unwrap();
        assert_eq!(record.ttl, 8);
    }

    #[test]
    fn test_update_ttl_negative() -> Result<(), NetError> {
        subscribe();
        let now = Instant::now();
        let name = Name::from_str("www.example.com.")?;
        let ns_name = Name::from_str("ns1.example.com.")?;
        let zone_name = name.base_name();
        let query = Query::query(name.clone(), RecordType::AAAA);

        let mut norecs = NoRecords::new(query.clone(), ResponseCode::NXDomain);
        norecs.negative_ttl = Some(10);
        norecs.soa = Some(Box::new(Record::from_rdata(
            zone_name.clone(),
            10,
            SOA::new(name.base_name(), name.clone(), 1, 1, 1, 1, 1),
        )));
        norecs.authorities = Some(Arc::new([Record::from_rdata(
            zone_name.clone(),
            10,
            RData::NS(NS(ns_name.clone())),
        )]));
        norecs.ns = Some(Arc::new([ForwardNSData {
            ns: Record::from_rdata(zone_name.clone(), 10, RData::NS(NS(ns_name.clone()))),
            glue: Arc::new([Record::from_rdata(
                ns_name.clone(),
                10,
                RData::A(A([192, 0, 2, 1].into())),
            )]),
        }]));

        let error = NetError::from(norecs);

        let cache = ResponseCache::new(1, TtlConfig::default());
        cache.insert(query.clone(), Err(error), now);

        let cache_err = cache.get(&query, now).unwrap().unwrap_err();
        let NetError::Dns(DnsError::NoRecordsFound(no_records)) = &cache_err else {
            panic!("expected NoRecordsFound");
        };

        let Some(soa) = no_records.soa.clone() else {
            panic!("no SOA in NoRecordsFound");
        };
        assert_eq!(soa.ttl, 10);

        // After two seconds every TTL carried by the error has been aged by two.
        let cache_err = cache
            .get(&query, now + Duration::from_secs(2))
            .unwrap()
            .unwrap_err();
        let NetError::Dns(DnsError::NoRecordsFound(NoRecords {
            negative_ttl: Some(negative_ttl),
            soa: Some(soa),
            authorities: Some(authorities),
            ns: Some(ns),
            ..
        })) = &cache_err
        else {
            panic!("expected NoRecordsFound with negative_ttl, soa, authorities, and ns");
        };

        assert_eq!(*negative_ttl, 8);
        assert_eq!(soa.ttl, 8);
        assert_eq!(authorities[0].ttl, 8);
        assert_eq!(ns[0].ns.ttl, 8);
        assert_eq!(ns[0].glue[0].ttl, 8);

        assert!(cache.get(&query, now + Duration::from_secs(11)).is_none());
        Ok(())
    }

    #[test]
    fn test_insert_ttl() {
        let now = Instant::now();

        let name = Name::from_str("www.example.com.").unwrap();
        let query = Query::query(name.clone(), RecordType::A);

        // Two answers with different TTLs: the shorter one (1s) bounds the entry.
        let mut message = Message::response(0, OpCode::Query);
        message.add_answer(Record::from_rdata(
            name.clone(),
            1,
            RData::A(A::new(127, 0, 0, 1)),
        ));
        message.add_answer(Record::from_rdata(name, 2, RData::A(A::new(127, 0, 0, 2))));

        let cache = ResponseCache::new(1, TtlConfig::default());
        cache.insert(query.clone(), Ok(message), now);

        cache
            .get(&query, now + Duration::from_secs(1))
            .unwrap()
            .unwrap();

        let option = cache.get(&query, now + Duration::from_secs(2));
        assert!(option.is_none());
    }

    #[test]
    fn test_ttl_different_query_types() {
        let now = Instant::now();
        let name = Name::from_str("www.example.com.").unwrap();

        let query_a = Query::query(name.clone(), RecordType::A);
        let rdata_a = RData::A(A::new(127, 0, 0, 1));
        let mut message_a = Message::response(0, OpCode::Query);
        message_a.add_answer(Record::from_rdata(name.clone(), 1, rdata_a.clone()));

        let query_txt = Query::query(name.clone(), RecordType::TXT);
        let rdata_txt = RData::TXT(TXT::new(vec!["data".to_string()]));
        let mut message_txt = Message::response(0, OpCode::Query);
        message_txt.add_answer(Record::from_rdata(name.clone(), 1, rdata_txt.clone()));

        // Default minimum of 2s, overridden to 5s for TXT queries.
        let mut ttl_config = TtlConfig::from(TtlBounds {
            positive_min_ttl: Some(Duration::from_secs(2)),
            ..TtlBounds::default()
        });
        ttl_config.with_query_type_ttl_bounds(
            RecordType::TXT,
            TtlBounds {
                positive_min_ttl: Some(Duration::from_secs(5)),
                ..TtlBounds::default()
            },
        );
        let cache = ResponseCache::new(2, ttl_config);

        cache.insert(query_a.clone(), Ok(message_a), now);
        assert_eq!(
            cache.cache.get(&query_a).unwrap().valid_until,
            now + Duration::from_secs(2)
        );

        cache.insert(query_txt.clone(), Ok(message_txt), now);
        assert_eq!(
            cache.cache.get(&query_txt).unwrap().valid_until,
            now + Duration::from_secs(5)
        );

        // TTLs above both minimums are used as-is.
        let mut message_a = Message::response(0, OpCode::Query);
        message_a.add_answer(Record::from_rdata(name.clone(), 7, rdata_a));

        let mut message_txt = Message::response(0, OpCode::Query);
        message_txt.add_answer(Record::from_rdata(name.clone(), 7, rdata_txt));

        cache.insert(query_a.clone(), Ok(message_a), now);
        assert_eq!(
            cache.cache.get(&query_a).unwrap().valid_until,
            now + Duration::from_secs(7)
        );

        cache.insert(query_txt.clone(), Ok(message_txt), now);
        assert_eq!(
            cache.cache.get(&query_txt).unwrap().valid_until,
            now + Duration::from_secs(7)
        );
    }

    #[cfg(feature = "serde")]
    #[test]
    fn ttl_config_deserialize_errors() {
        // Repeated `default` table.
        let input = r#"[default]
positive_max_ttl = 3600
[default]
positive_max_ttl = 3599"#;
        let error = toml::from_str::<TtlConfig>(input).unwrap_err();
        assert!(
            error.message().contains("duplicate key"),
            "wrong error message: {error}"
        );

        // Repeated record-type table.
        let input = r#"[default]
positive_max_ttl = 86400
[OPENPGPKEY]
positive_max_ttl = 3600
[OPENPGPKEY]
negative_min_ttl = 60"#;
        let error = toml::from_str::<TtlConfig>(input).unwrap_err();
        assert!(
            error.message().contains("duplicate key"),
            "wrong error message: {error}"
        );

        // Key that is neither `default` nor a record type.
        let input = r#"[not_a_record_type]
positive_max_ttl = 3600"#;
        let error = toml::from_str::<TtlConfig>(input).unwrap_err();
        assert!(
            error.message().contains("data did not match any variant"),
            "wrong error message: {error}"
        );

        #[derive(Debug, Deserialize)]
        struct Wrapper {
            // Never read; the field exists only so deserialization has a target.
            #[allow(unused)]
            cache_policy: TtlConfig,
        }

        // Wrong value shapes for the whole config.
        let input = r#"cache_policy = []"#;
        let error = toml::from_str::<Wrapper>(input).unwrap_err();
        assert!(
            error.message().contains("invalid type: sequence"),
            "wrong error message: {error}"
        );

        let input = r#"cache_policy = "yes""#;
        let error = toml::from_str::<Wrapper>(input).unwrap_err();
        assert!(
            error.message().contains("invalid type: string"),
            "wrong error message: {error}"
        );
    }
}