1use std::{
11 collections::BTreeMap,
12 fs,
13 marker::PhantomData,
14 ops::{Deref, DerefMut},
15 path::Path,
16 sync::Arc,
17};
18
19#[cfg(feature = "__dnssec")]
20use crate::{
21 dnssec::NxProofKind,
22 net::runtime::Time,
23 proto::dnssec::{
24 DnsSecResult, DnssecSigner,
25 rdata::{DNSKEY, DNSSECRData},
26 },
27 zone_handler::{DnssecZoneHandler, Nsec3QueryInfo},
28};
29use crate::{
30 net::runtime::{RuntimeProvider, TokioRuntimeProvider},
31 proto::{
32 op::ResponseCode,
33 rr::{DNSClass, LowerName, Name, RData, Record, RecordSet, RecordType, RrKey},
34 serialize::txt::Parser,
35 },
36 server::{Request, RequestInfo},
37 zone_handler::{
38 AuthLookup, AxfrPolicy, AxfrRecords, LookupControlFlow, LookupError, LookupOptions,
39 LookupRecords, ZoneHandler, ZoneTransfer, ZoneType,
40 },
41};
42use hickory_proto::rr::TSigResponseContext;
43#[cfg(feature = "__dnssec")]
44use time::OffsetDateTime;
45use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
46#[cfg(feature = "__dnssec")]
47use tracing::warn;
48use tracing::{debug, info};
49
50mod inner;
51use inner::InnerInMemory;
52
53pub struct InMemoryZoneHandler<P = TokioRuntimeProvider> {
58 origin: LowerName,
59 class: DNSClass,
60 zone_type: ZoneType,
61 axfr_policy: AxfrPolicy,
62 inner: RwLock<InnerInMemory>,
63 #[cfg(feature = "__dnssec")]
64 nx_proof_kind: Option<NxProofKind>,
65 _phantom: PhantomData<P>,
66}
67
68impl<P: RuntimeProvider + Send + Sync> InMemoryZoneHandler<P> {
69 pub fn new(
84 origin: Name,
85 records: BTreeMap<RrKey, RecordSet>,
86 zone_type: ZoneType,
87 axfr_policy: AxfrPolicy,
88 #[cfg(feature = "__dnssec")] nx_proof_kind: Option<NxProofKind>,
89 ) -> Result<Self, String> {
90 let mut this = Self::empty(
91 origin.clone(),
92 zone_type,
93 axfr_policy,
94 #[cfg(feature = "__dnssec")]
95 nx_proof_kind,
96 );
97 let inner = this.inner.get_mut();
98
99 let soa = records
101 .get(&RrKey::new(origin.clone().into(), RecordType::SOA))
102 .and_then(|rrset| match &rrset.records_without_rrsigs().next()?.data {
103 RData::SOA(soa) => Some(soa),
104 _ => None,
105 })
106 .ok_or_else(|| format!("SOA record must be present: {origin}"))?;
107 let serial = soa.serial;
108
109 let iter = records.into_values();
110
111 for rrset in iter {
113 let name = rrset.name().clone();
114 let rr_type = rrset.record_type();
115
116 for record in rrset.records_without_rrsigs() {
117 if !inner.upsert(record.clone(), serial, this.class) {
118 return Err(format!(
119 "Failed to insert {name} {rr_type} to zone: {origin}"
120 ));
121 };
122 }
123 }
124
125 Ok(this)
126 }
127
128 pub fn empty(
134 origin: Name,
135 zone_type: ZoneType,
136 axfr_policy: AxfrPolicy,
137 #[cfg(feature = "__dnssec")] nx_proof_kind: Option<NxProofKind>,
138 ) -> Self {
139 Self {
140 origin: LowerName::new(&origin),
141 class: DNSClass::IN,
142 zone_type,
143 axfr_policy,
144 inner: RwLock::new(InnerInMemory::default()),
145
146 #[cfg(feature = "__dnssec")]
147 nx_proof_kind,
148
149 _phantom: PhantomData,
150 }
151 }
152
153 pub fn class(&self) -> DNSClass {
155 self.class
156 }
157
158 #[cfg(any(test, feature = "testing"))]
160 pub fn set_axfr_policy(&mut self, policy: AxfrPolicy) {
161 self.axfr_policy = policy;
162 }
163
164 pub fn clear(&mut self) {
166 self.inner.get_mut().records.clear()
167 }
168
169 #[cfg(all(feature = "__dnssec", feature = "testing"))]
171 pub async fn secure_keys(&self) -> impl Deref<Target = [DnssecSigner]> + '_ {
172 RwLockWriteGuard::map(self.inner.write().await, |i| i.secure_keys.as_mut_slice())
173 }
174
175 pub async fn records(&self) -> impl Deref<Target = BTreeMap<RrKey, Arc<RecordSet>>> + '_ {
177 RwLockReadGuard::map(self.inner.read().await, |i| &i.records)
178 }
179
180 pub async fn records_mut(
182 &self,
183 ) -> impl DerefMut<Target = BTreeMap<RrKey, Arc<RecordSet>>> + '_ {
184 RwLockWriteGuard::map(self.inner.write().await, |i| &mut i.records)
185 }
186
187 pub fn records_get_mut(&mut self) -> &mut BTreeMap<RrKey, Arc<RecordSet>> {
189 &mut self.inner.get_mut().records
190 }
191
192 pub async fn minimum_ttl(&self) -> u32 {
194 self.inner.read().await.minimum_ttl(self.origin())
195 }
196
197 pub async fn serial(&self) -> u32 {
199 self.inner.read().await.serial(self.origin())
200 }
201
202 #[cfg(feature = "sqlite")]
203 pub(crate) async fn increment_soa_serial(&self) -> u32 {
204 self.inner
205 .write()
206 .await
207 .increment_soa_serial(self.origin(), self.class)
208 }
209
210 pub async fn upsert(&self, record: Record, serial: u32) -> bool {
223 self.inner.write().await.upsert(record, serial, self.class)
224 }
225
226 pub fn upsert_mut(&mut self, record: Record, serial: u32) -> bool {
228 self.inner.get_mut().upsert(record, serial, self.class)
229 }
230
231 #[cfg(feature = "__dnssec")]
239 fn inner_add_zone_signing_key(
240 inner: &mut InnerInMemory,
241 signer: DnssecSigner,
242 origin: &LowerName,
243 dns_class: DNSClass,
244 ) -> DnsSecResult<()> {
245 let zone_ttl = inner.minimum_ttl(origin);
247 let dnskey = DNSKEY::from_key(&signer.key().to_public_key()?);
248 let dnskey = Record::from_rdata(
249 origin.clone().into(),
250 zone_ttl,
251 RData::DNSSEC(DNSSECRData::DNSKEY(dnskey)),
252 );
253
254 let serial = inner.serial(origin);
256 inner.upsert(dnskey, serial, dns_class);
257 inner.secure_keys.push(signer);
258 Ok(())
259 }
260
261 #[cfg(feature = "__dnssec")]
263 pub fn add_zone_signing_key_mut(&mut self, signer: DnssecSigner) -> DnsSecResult<()> {
264 let Self {
265 origin,
266 inner,
267 class,
268 ..
269 } = self;
270
271 Self::inner_add_zone_signing_key(inner.get_mut(), signer, origin, *class)
272 }
273
274 #[cfg(feature = "__dnssec")]
276 pub fn secure_zone_mut(&mut self) -> DnsSecResult<()> {
277 let Self { origin, inner, .. } = self;
278 inner.get_mut().secure_zone_mut(
279 origin,
280 self.class,
281 self.nx_proof_kind.as_ref(),
282 Self::current_time()?,
283 )
284 }
285
286 #[cfg(not(feature = "__dnssec"))]
288 pub fn secure_zone_mut(&mut self) -> Result<(), &str> {
289 Err("DNSSEC was not enabled during compilation.")
290 }
291
292 #[cfg(feature = "__dnssec")]
293 fn current_time() -> DnsSecResult<OffsetDateTime> {
294 let timestamp_unsigned = P::Timer::current_time();
295 let timestamp_signed = timestamp_unsigned
296 .try_into()
297 .map_err(|_| "current time is out of range")?;
298 OffsetDateTime::from_unix_timestamp(timestamp_signed)
299 .map_err(|_| "current time is out of range".into())
300 }
301}
302
303#[async_trait::async_trait]
304impl<P: RuntimeProvider + Send + Sync> ZoneHandler for InMemoryZoneHandler<P> {
305 fn zone_type(&self) -> ZoneType {
307 self.zone_type
308 }
309
310 fn axfr_policy(&self) -> AxfrPolicy {
312 self.axfr_policy
313 }
314
315 fn origin(&self) -> &LowerName {
317 &self.origin
318 }
319
320 async fn lookup(
336 &self,
337 name: &LowerName,
338 mut query_type: RecordType,
339 _request_info: Option<&RequestInfo<'_>>,
340 lookup_options: LookupOptions,
341 ) -> LookupControlFlow<AuthLookup> {
342 let inner = self.inner.read().await;
343
344 if query_type == RecordType::AXFR {
345 return Break(Err(LookupError::NetError(
346 "AXFR must be handled with ZoneHandler::zone_transfer()".into(),
347 )));
348 }
349
350 if query_type == RecordType::ANY {
351 query_type = inner.replace_any(name);
352 }
353
354 let answer = inner.inner_lookup(name, query_type, lookup_options);
355
356 let additionals_root_chain_type: Option<(_, _)> = answer
358 .as_ref()
359 .and_then(|a| maybe_next_name(a, query_type))
360 .and_then(|(search_name, search_type)| {
361 inner
362 .additional_search(name, query_type, search_name, search_type, lookup_options)
363 .map(|adds| (adds, search_type))
364 });
365
366 let (additionals, answer) = match (additionals_root_chain_type, answer, query_type) {
368 (Some((additionals, RecordType::ANAME)), Some(answer), RecordType::A)
369 | (Some((additionals, RecordType::ANAME)), Some(answer), RecordType::AAAA) => {
370 debug_assert_eq!(answer.record_type(), RecordType::ANAME);
372
373 let (rdatas, a_aaaa_ttl) = {
375 let last_record = additionals.last();
376 let a_aaaa_ttl = last_record.map_or(u32::MAX, |r| r.ttl());
377
378 let rdatas: Option<Vec<RData>> = last_record
380 .and_then(|record| match record.record_type() {
381 RecordType::A | RecordType::AAAA => {
382 Some(record.records_without_rrsigs())
384 }
385 _ => None,
386 })
387 .map(|records| records.map(|r| &r.data).cloned().collect::<Vec<_>>());
388
389 (rdatas, a_aaaa_ttl)
390 };
391
392 let ttl = answer.ttl().min(a_aaaa_ttl);
397 let mut new_answer = RecordSet::new(answer.name().clone(), query_type, ttl);
398
399 for rdata in rdatas.into_iter().flatten() {
400 new_answer.add_rdata(rdata);
401 }
402
403 #[cfg(feature = "__dnssec")]
405 if lookup_options.dnssec_ok {
407 let result = Self::current_time().and_then(|time| {
408 InnerInMemory::sign_rrset(
409 &mut new_answer,
410 &inner.secure_keys,
411 self.class(),
412 time,
413 )
414 });
415 if let Err(error) = result {
416 warn!(%error, "failed to sign ANAME record")
418 }
419 }
420
421 let additionals = std::iter::once(answer).chain(additionals).collect();
423
424 (Some(additionals), Some(Arc::new(new_answer)))
427 }
428 (Some((additionals, _)), answer, _) => (Some(additionals), answer),
429 (None, answer, _) => (None, answer),
430 };
431
432 use LookupControlFlow::*;
438 let answers = match answer {
439 Some(rr_set) => LookupRecords::new(lookup_options, rr_set),
440 None => {
441 return Continue(Err(
442 if inner
443 .records
444 .keys()
445 .any(|key| key.name() == name || name.zone_of(key.name()))
446 {
447 LookupError::NameExists
448 } else {
449 LookupError::from(match self.origin().zone_of(name) {
450 true => ResponseCode::NXDomain,
451 false => ResponseCode::Refused,
452 })
453 },
454 ));
455 }
456 };
457
458 Continue(Ok(AuthLookup::answers(
459 answers,
460 additionals.map(|a| LookupRecords::many(lookup_options, a)),
461 )))
462 }
463
464 async fn search(
465 &self,
466 request: &Request,
467 lookup_options: LookupOptions,
468 ) -> (LookupControlFlow<AuthLookup>, Option<TSigResponseContext>) {
469 let request_info = match request.request_info() {
470 Ok(info) => info,
471 Err(e) => return (LookupControlFlow::Break(Err(e)), None),
472 };
473 debug!("searching InMemoryZoneHandler for: {}", request_info.query);
474
475 let lookup_name = request_info.query.name();
476 let record_type: RecordType = request_info.query.query_type();
477
478 match record_type {
480 RecordType::SOA => (
481 self.lookup(
482 self.origin(),
483 record_type,
484 Some(&request_info),
485 lookup_options,
486 )
487 .await,
488 None,
489 ),
490 RecordType::AXFR => (
491 LookupControlFlow::Break(Err(LookupError::NetError(
492 "AXFR must be handled with ZoneHandler::zone_transfer()".into(),
493 ))),
494 None,
495 ),
496 _ => (
498 self.lookup(
499 lookup_name,
500 record_type,
501 Some(&request_info),
502 lookup_options,
503 )
504 .await,
505 None,
506 ),
507 }
508 }
509
510 async fn zone_transfer(
511 &self,
512 request: &Request,
513 lookup_options: LookupOptions,
514 _now: u64,
515 ) -> Option<(
516 Result<ZoneTransfer, LookupError>,
517 Option<TSigResponseContext>,
518 )> {
519 let request_info = match request.request_info() {
520 Ok(info) => info,
521 Err(e) => return Some((Err(e), None)),
522 };
523
524 if request_info.query.query_type() == RecordType::AXFR {
525 if !matches!(self.axfr_policy, AxfrPolicy::AllowAll) {
527 return Some((Err(LookupError::from(ResponseCode::Refused)), None));
528 }
529 }
530
531 let future = self.lookup(self.origin(), RecordType::SOA, None, lookup_options);
532 let start_soa = if let LookupControlFlow::Continue(Ok(res)) = future.await {
533 res.unwrap_records()
534 } else {
535 LookupRecords::Empty
536 };
537
538 let future = self.lookup(
539 self.origin(),
540 RecordType::SOA,
541 None,
542 LookupOptions::default(),
543 );
544 let end_soa = if let LookupControlFlow::Continue(Ok(res)) = future.await {
545 res.unwrap_records()
546 } else {
547 LookupRecords::Empty
548 };
549
550 let records = AxfrRecords::new(
551 lookup_options.dnssec_ok,
552 self.inner.read().await.records.values().cloned().collect(),
553 );
554
555 Some((
556 Ok(ZoneTransfer {
557 start_soa,
558 records,
559 end_soa,
560 }),
561 None,
562 ))
563 }
564
565 #[cfg(feature = "__dnssec")]
574 async fn nsec_records(
575 &self,
576 name: &LowerName,
577 lookup_options: LookupOptions,
578 ) -> LookupControlFlow<AuthLookup> {
579 let inner = self.inner.read().await;
580
581 let rr_key = RrKey::new(name.clone(), RecordType::NSEC);
583 let no_data = inner
584 .records
585 .get(&rr_key)
586 .map(|rr_set| LookupRecords::new(lookup_options, rr_set.clone()));
587
588 if let Some(no_data) = no_data {
589 return LookupControlFlow::Continue(Ok(no_data.into()));
590 }
591
592 let closest_proof = inner.closest_nsec(name);
593
594 let wildcard = name.base_name();
596 let origin = self.origin();
597 let wildcard = if origin.zone_of(&wildcard) {
598 wildcard
599 } else {
600 origin.clone()
601 };
602
603 let wildcard_proof = if wildcard != *name {
605 inner.closest_nsec(&wildcard)
606 } else {
607 None
608 };
609
610 let proofs = match (closest_proof, wildcard_proof) {
611 (Some(closest_proof), Some(wildcard_proof)) => {
612 if wildcard_proof != closest_proof {
614 vec![wildcard_proof, closest_proof]
615 } else {
616 vec![closest_proof]
617 }
618 }
619 (None, Some(proof)) | (Some(proof), None) => vec![proof],
620 (None, None) => vec![],
621 };
622
623 LookupControlFlow::Continue(Ok(LookupRecords::many(lookup_options, proofs).into()))
624 }
625
626 #[cfg(not(feature = "__dnssec"))]
627 async fn nsec_records(
628 &self,
629 _name: &LowerName,
630 _lookup_options: LookupOptions,
631 ) -> LookupControlFlow<AuthLookup> {
632 LookupControlFlow::Continue(Ok(AuthLookup::default()))
633 }
634
635 #[cfg(feature = "__dnssec")]
636 async fn nsec3_records(
637 &self,
638 info: Nsec3QueryInfo<'_>,
639 lookup_options: LookupOptions,
640 ) -> LookupControlFlow<AuthLookup> {
641 let inner = self.inner.read().await;
642 LookupControlFlow::Continue(
643 inner
644 .proof(info, self.origin())
645 .map(|proof| LookupRecords::many(lookup_options, proof).into()),
646 )
647 }
648
649 #[cfg(feature = "__dnssec")]
650 fn nx_proof_kind(&self) -> Option<&NxProofKind> {
651 self.nx_proof_kind.as_ref()
652 }
653
654 #[cfg(feature = "metrics")]
655 fn metrics_label(&self) -> &'static str {
656 "in-memory"
657 }
658}
659
#[cfg(feature = "__dnssec")]
#[async_trait::async_trait]
impl<P: RuntimeProvider + Send + Sync> DnssecZoneHandler for InMemoryZoneHandler<P> {
    /// Registers `signer` and publishes its DNSKEY record, holding the write lock throughout.
    async fn add_zone_signing_key(&self, signer: DnssecSigner) -> DnsSecResult<()> {
        let (origin, class) = (self.origin(), self.class);
        let mut guard = self.inner.write().await;
        Self::inner_add_zone_signing_key(&mut guard, signer, origin, class)
    }

    /// Signs the whole zone at the current time, holding the write lock throughout.
    ///
    /// Fails if the current time cannot be determined or if signing fails.
    async fn secure_zone(&self) -> DnsSecResult<()> {
        let mut guard = self.inner.write().await;
        let now = Self::current_time()?;
        guard.secure_zone_mut(self.origin(), self.class, self.nx_proof_kind.as_ref(), now)
    }
}
686
687fn maybe_next_name(
689 record_set: &RecordSet,
690 query_type: RecordType,
691) -> Option<(LowerName, RecordType)> {
692 let t = match (record_set.record_type(), query_type) {
693 (t @ RecordType::ANAME, RecordType::A)
697 | (t @ RecordType::ANAME, RecordType::AAAA)
698 | (t @ RecordType::ANAME, RecordType::ANAME) => t,
699 (t @ RecordType::NS, RecordType::NS) => t,
700 (t @ RecordType::CNAME, _) => t,
702 (t @ RecordType::MX, RecordType::MX) => t,
703 (t @ RecordType::SRV, RecordType::SRV) => t,
704 _ => return None,
706 };
707
708 let name = match (&record_set.records_without_rrsigs().next()?.data, t) {
709 (RData::ANAME(name), RecordType::ANAME) => name,
710 (RData::NS(ns), RecordType::NS) => &ns.0,
711 (RData::CNAME(name), RecordType::CNAME) => name,
712 (RData::MX(mx), RecordType::MX) => &mx.exchange,
713 (RData::SRV(srv), RecordType::SRV) => &srv.target,
714 _ => return None,
715 };
716
717 Some((LowerName::from(name), t))
718}
719
720pub(crate) fn zone_from_path(
722 zone_path: &Path,
723 origin: Name,
724) -> Result<BTreeMap<RrKey, RecordSet>, String> {
725 info!("loading zone file: {zone_path:?}");
726
727 let buf = fs::read_to_string(zone_path)
730 .map_err(|e| format!("failed to read {}: {e:?}", zone_path.display()))?;
731
732 let (origin, records) = Parser::new(buf, Some(zone_path.to_owned()), Some(origin))
733 .parse()
734 .map_err(|e| format!("failed to parse {}: {e:?}", zone_path.display()))?;
735
736 info!("zone file loaded: {origin} with {} records", records.len());
737 debug!("zone: {records:#?}");
738 Ok(records)
739}