Skip to main content

mdns_sd/
dns_parser.rs

//! DNS parsing utility.
//!
//! [DnsIncoming] is the logical representation of an incoming DNS packet.
//! [DnsOutgoing] is the logical representation of an outgoing DNS message, which may span one or more packets.
//! [DnsOutPacket] is one encoded packet of a [DnsOutgoing] message.
6
7#[cfg(feature = "logging")]
8use crate::log::trace;
9
10use crate::error::{e_fmt, Error, Result};
11use crate::service_info::is_unicast_link_local;
12
13use if_addrs::Interface;
14
15use std::{
16    any::Any,
17    cmp,
18    collections::HashMap,
19    convert::TryInto,
20    fmt,
21    net::{IpAddr, Ipv4Addr, Ipv6Addr},
22    str,
23    time::SystemTime,
24};
25
/// A network interface identifier as defined by the OS.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Default)]
pub struct InterfaceId {
    /// Interface name as reported by the OS, e.g. "en0" or "wlan0".
    pub name: String,

    /// Interface index assigned by the OS, e.g. 1, 2, etc.
    pub index: u32,
}

impl fmt::Display for InterfaceId {
    /// Formats as `index('name')`, for example `2('en0')`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { name, index } = self;
        write!(f, "{index}('{name}')")
    }
}
41
42impl From<&Interface> for InterfaceId {
43    fn from(interface: &Interface) -> Self {
44        InterfaceId {
45            name: interface.name.clone(),
46            index: interface.index.unwrap_or_default(),
47        }
48    }
49}
50
/// An IPv4 address used in `ScopedIp`.
///
/// IPv4 addresses carry no scope ID. The type exists only so that both
/// `ScopedIp` variants wrap a similarly named payload.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ScopedIpV4 {
    addr: Ipv4Addr,
}

impl ScopedIpV4 {
    /// Returns a reference to the wrapped IPv4 address.
    pub const fn addr(&self) -> &Ipv4Addr {
        let Self { addr } = self;
        addr
    }
}
66
67/// An IPv6 address with scope_id (interface identifier).
68#[derive(Debug, Clone, Eq, PartialEq, Hash)]
69pub struct ScopedIpV6 {
70    addr: Ipv6Addr,
71    scope_id: InterfaceId,
72}
73
74impl ScopedIpV6 {
75    /// Returns the IPv6 address.
76    pub const fn addr(&self) -> &Ipv6Addr {
77        &self.addr
78    }
79
80    /// Returns the scope_id for this IPv6 address.
81    pub const fn scope_id(&self) -> &InterfaceId {
82        &self.scope_id
83    }
84}
85
86/// An IP address, either IPv4 or IPv6, that supports scope_id for IPv6.
87#[derive(Debug, Clone, Eq, PartialEq, Hash)]
88#[non_exhaustive]
89pub enum ScopedIp {
90    V4(ScopedIpV4),
91    V6(ScopedIpV6),
92}
93
94impl ScopedIp {
95    pub const fn to_ip_addr(&self) -> IpAddr {
96        match self {
97            ScopedIp::V4(v4) => IpAddr::V4(v4.addr),
98            ScopedIp::V6(v6) => IpAddr::V6(v6.addr),
99        }
100    }
101
102    pub const fn is_ipv4(&self) -> bool {
103        matches!(self, ScopedIp::V4(_))
104    }
105
106    pub const fn is_ipv6(&self) -> bool {
107        matches!(self, ScopedIp::V6(_))
108    }
109
110    pub const fn is_loopback(&self) -> bool {
111        match self {
112            ScopedIp::V4(v4) => v4.addr.is_loopback(),
113            ScopedIp::V6(v6) => v6.addr.is_loopback(),
114        }
115    }
116}
117
118impl From<IpAddr> for ScopedIp {
119    fn from(ip: IpAddr) -> Self {
120        match ip {
121            IpAddr::V4(v4) => ScopedIp::V4(ScopedIpV4 { addr: v4 }),
122            IpAddr::V6(v6) => ScopedIp::V6(ScopedIpV6 {
123                addr: v6,
124                scope_id: InterfaceId::default(),
125            }),
126        }
127    }
128}
129
130impl From<&Interface> for ScopedIp {
131    fn from(interface: &Interface) -> Self {
132        match interface.ip() {
133            IpAddr::V4(v4) => ScopedIp::V4(ScopedIpV4 { addr: v4 }),
134            IpAddr::V6(v6) => ScopedIp::V6(ScopedIpV6 {
135                addr: v6,
136                scope_id: InterfaceId::from(interface),
137            }),
138        }
139    }
140}
141
142impl fmt::Display for ScopedIp {
143    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144        match self {
145            ScopedIp::V4(v4) => write!(f, "{}", v4.addr),
146            ScopedIp::V6(v6) => {
147                if v6.scope_id.index != 0 && is_unicast_link_local(&v6.addr) {
148                    #[cfg(windows)]
149                    {
150                        write!(f, "{}%{}", v6.addr, v6.scope_id.index)
151                    }
152                    #[cfg(not(windows))]
153                    {
154                        write!(f, "{}%{}", v6.addr, v6.scope_id.name)
155                    }
156                } else {
157                    write!(f, "{}", v6.addr)
158                }
159            }
160        }
161    }
162}
163
/// DNS resource record types, stored as `u16`. Can do `as u16` when needed.
///
/// See [RFC 1035 section 3.2.2](https://datatracker.ietf.org/doc/html/rfc1035#section-3.2.2)
#[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
#[non_exhaustive]
#[repr(u16)]
pub enum RRType {
    /// DNS record type for IPv4 address
    A = 1,

    /// DNS record type for Canonical Name
    CNAME = 5,

    /// DNS record type for Pointer
    PTR = 12,

    /// DNS record type for Host Info
    HINFO = 13,

    /// DNS record type for Text (properties)
    TXT = 16,

    /// DNS record type for IPv6 address
    AAAA = 28,

    /// DNS record type for Service
    SRV = 33,

    /// DNS record type for Negative Responses
    NSEC = 47,

    /// DNS record type for any records (wildcard)
    ANY = 255,
}

impl RRType {
    /// Converts a numeric type code into an `RRType`.
    ///
    /// Returns `None` for codes that have no variant in this enum.
    pub const fn from_u16(value: u16) -> Option<Self> {
        let ty = match value {
            1 => Self::A,
            5 => Self::CNAME,
            12 => Self::PTR,
            13 => Self::HINFO,
            16 => Self::TXT,
            28 => Self::AAAA,
            33 => Self::SRV,
            47 => Self::NSEC,
            255 => Self::ANY,
            _ => return None,
        };
        Some(ty)
    }
}

impl fmt::Display for RRType {
    /// Writes the `TYPE_<NAME>` label of the record type.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::A => "TYPE_A",
            Self::CNAME => "TYPE_CNAME",
            Self::PTR => "TYPE_PTR",
            Self::HINFO => "TYPE_HINFO",
            Self::TXT => "TYPE_TXT",
            Self::AAAA => "TYPE_AAAA",
            Self::SRV => "TYPE_SRV",
            Self::NSEC => "TYPE_NSEC",
            Self::ANY => "TYPE_ANY",
        };
        f.write_str(label)
    }
}
232
/// The class value for the Internet.
pub const CLASS_IN: u16 = 1;

/// Mask that removes the cache-flush bit from an rrclass value.
pub const CLASS_MASK: u16 = !CLASS_CACHE_FLUSH;

/// Cache-flush bit: the most significant bit of the rrclass field of the resource record.
pub const CLASS_CACHE_FLUSH: u16 = 1 << 15;

/// Max size of UDP datagram payload.
///
/// It is calculated as: 9000 bytes - IP header 20 bytes - UDP header 8 bytes.
/// Reference: [RFC6762 section 17](https://datatracker.ietf.org/doc/html/rfc6762#section-17)
pub const MAX_MSG_ABSOLUTE: usize = 9000 - 20 - 8;

/// Length in bytes of the fixed DNS message header.
const MSG_HEADER_LEN: usize = 12;

// Definitions for DNS message header "flags" field
//
// The "flags" field is 16-bit long, in this format:
// (RFC 1035 section 4.1.1)
//
//   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
//
// Bit 0 in this diagram is the most significant bit, so the bit at
// position `p` is `1 << (15 - p)`.

/// Mask for the query/response (QR) bit, diagram position 0.
pub const FLAGS_QR_MASK: u16 = 1 << 15;

/// Flag bit to indicate a query
pub const FLAGS_QR_QUERY: u16 = 0;

/// Flag bit to indicate a response
pub const FLAGS_QR_RESPONSE: u16 = 1 << 15;

/// Flag bit for Authoritative Answer (diagram position 5)
pub const FLAGS_AA: u16 = 1 << 10;

/// Mask for the TC (Truncated) bit (diagram position 6)
///
/// 2024-08-10: currently this flag is only supported on the querier side,
///             not on the responder side. I.e. the responder only
///             handles the first packet and ignores this bit. Since the
///             additional packets have 0 questions, processing them
///             is a no-op.
///             In practice, this means the responder supports Known-Answer
///             only with a single packet, not multi-packet. The querier supports
///             both single packet and multi-packet.
pub const FLAGS_TC: u16 = 1 << 9;
278
279/// A convenience type alias for DNS record trait objects.
280pub type DnsRecordBox = Box<dyn DnsRecordExt>;
281
282impl Clone for DnsRecordBox {
283    fn clone(&self) -> Self {
284        self.clone_box()
285    }
286}
287
/// Size of a `u16` in bytes.
const U16_SIZE: usize = std::mem::size_of::<u16>();
289
290/// Returns `RRType` for a given IP address.
291#[inline]
292pub const fn ip_address_rr_type(address: &IpAddr) -> RRType {
293    match address {
294        IpAddr::V4(_) => RRType::A,
295        IpAddr::V6(_) => RRType::AAAA,
296    }
297}
298
299#[derive(Eq, PartialEq, Debug, Clone)]
300pub struct DnsEntry {
301    pub(crate) name: String, // always lower case.
302    pub(crate) ty: RRType,
303    class: u16,
304    cache_flush: bool,
305}
306
307impl DnsEntry {
308    const fn new(name: String, ty: RRType, class: u16) -> Self {
309        Self {
310            name,
311            ty,
312            class: class & CLASS_MASK,
313            cache_flush: (class & CLASS_CACHE_FLUSH) != 0,
314        }
315    }
316}
317
318/// Common methods for all DNS entries:  questions and resource records.
319pub trait DnsEntryExt: fmt::Debug {
320    fn entry_name(&self) -> &str;
321
322    fn entry_type(&self) -> RRType;
323}
324
325/// A DNS question entry
326#[derive(Debug)]
327pub struct DnsQuestion {
328    pub(crate) entry: DnsEntry,
329}
330
331impl DnsEntryExt for DnsQuestion {
332    fn entry_name(&self) -> &str {
333        &self.entry.name
334    }
335
336    fn entry_type(&self) -> RRType {
337        self.entry.ty
338    }
339}
340
341/// A DNS Resource Record - like a DNS entry, but has a TTL.
342/// RFC: https://www.rfc-editor.org/rfc/rfc1035#section-3.2.1
343///      https://www.rfc-editor.org/rfc/rfc1035#section-4.1.3
344#[derive(Debug, Clone)]
345pub struct DnsRecord {
346    pub(crate) entry: DnsEntry,
347    ttl: u32,     // in seconds, 0 means this record should not be cached
348    created: u64, // UNIX time in millis
349    expires: u64, // expires at this UNIX time in millis
350
351    /// Support re-query an instance before its PTR record expires.
352    /// See https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
353    refresh: u64, // UNIX time in millis
354
355    /// If conflict resolution decides to change the name, this is the new one.
356    new_name: Option<String>,
357}
358
359impl DnsRecord {
360    fn new(name: &str, ty: RRType, class: u16, ttl: u32) -> Self {
361        let created = current_time_millis();
362
363        // From RFC 6762 section 5.2:
364        // "... The querier should plan to issue a query at 80% of the record
365        // lifetime, and then if no answer is received, at 85%, 90%, and 95%."
366        let refresh = get_expiration_time(created, ttl, 80);
367
368        let expires = get_expiration_time(created, ttl, 100);
369
370        Self {
371            entry: DnsEntry::new(name.to_string(), ty, class),
372            ttl,
373            created,
374            expires,
375            refresh,
376            new_name: None,
377        }
378    }
379
380    pub const fn get_ttl(&self) -> u32 {
381        self.ttl
382    }
383
384    pub const fn get_expire_time(&self) -> u64 {
385        self.expires
386    }
387
388    pub const fn get_refresh_time(&self) -> u64 {
389        self.refresh
390    }
391
392    pub const fn is_expired(&self, now: u64) -> bool {
393        now >= self.expires
394    }
395
396    /// Returns whether record expires in 1 second.
397    ///
398    /// This is useful because mDNS sets TTL to 1 (not 0) for expiring records.
399    pub const fn expires_soon(&self, now: u64) -> bool {
400        now + 1000 >= self.expires
401    }
402
403    pub const fn refresh_due(&self, now: u64) -> bool {
404        now >= self.refresh
405    }
406
407    /// Returns whether `now` (in millis) has passed half of TTL.
408    pub fn halflife_passed(&self, now: u64) -> bool {
409        let halflife = get_expiration_time(self.created, self.ttl, 50);
410        now > halflife
411    }
412
413    pub fn is_unique(&self) -> bool {
414        self.entry.cache_flush
415    }
416
417    /// Updates the refresh time to be the same as the expire time so that
418    /// this record will not refresh again and will just expire.
419    pub fn refresh_no_more(&mut self) {
420        self.refresh = get_expiration_time(self.created, self.ttl, 100);
421    }
422
423    /// Returns if this record is due for refresh. If yes, `refresh` time is updated.
424    pub fn refresh_maybe(&mut self, now: u64) -> bool {
425        if self.is_expired(now) || !self.refresh_due(now) {
426            return false;
427        }
428
429        trace!(
430            "{} qtype {} is due to refresh",
431            &self.entry.name,
432            self.entry.ty
433        );
434
435        // From RFC 6762 section 5.2:
436        // "... The querier should plan to issue a query at 80% of the record
437        // lifetime, and then if no answer is received, at 85%, 90%, and 95%."
438        //
439        // If the answer is received in time, 'refresh' will be reset outside
440        // this function, back to 80% of the new TTL.
441        if self.refresh == get_expiration_time(self.created, self.ttl, 80) {
442            self.refresh = get_expiration_time(self.created, self.ttl, 85);
443        } else if self.refresh == get_expiration_time(self.created, self.ttl, 85) {
444            self.refresh = get_expiration_time(self.created, self.ttl, 90);
445        } else if self.refresh == get_expiration_time(self.created, self.ttl, 90) {
446            self.refresh = get_expiration_time(self.created, self.ttl, 95);
447        } else {
448            self.refresh_no_more();
449        }
450
451        true
452    }
453
454    /// Returns the remaining TTL in seconds
455    fn get_remaining_ttl(&self, now: u64) -> u32 {
456        let remaining_millis = get_expiration_time(self.created, self.ttl, 100) - now;
457        cmp::max(0, remaining_millis / 1000) as u32
458    }
459
460    /// Return the absolute time for this record being created
461    pub const fn get_created(&self) -> u64 {
462        self.created
463    }
464
465    /// Set the absolute expiration time in millis
466    fn set_expire(&mut self, expire_at: u64) {
467        self.expires = expire_at;
468    }
469
470    fn reset_ttl(&mut self, other: &Self) {
471        self.ttl = other.ttl;
472        self.created = other.created;
473        self.expires = get_expiration_time(self.created, self.ttl, 100);
474        self.refresh = if self.ttl > 1 {
475            get_expiration_time(self.created, self.ttl, 80)
476        } else {
477            // If TTL is 1, it means this record is expiring,
478            // then we set refresh to the same time as expires.
479            self.expires
480        };
481    }
482
483    /// Modify TTL to reflect the remaining life time from `now`.
484    pub fn update_ttl(&mut self, now: u64) {
485        if now > self.created {
486            let elapsed = now - self.created;
487            self.ttl -= (elapsed / 1000) as u32;
488        }
489    }
490
491    pub fn set_new_name(&mut self, new_name: String) {
492        if new_name == self.entry.name {
493            self.new_name = None;
494        } else {
495            self.new_name = Some(new_name);
496        }
497    }
498
499    pub fn get_new_name(&self) -> Option<&str> {
500        self.new_name.as_deref()
501    }
502
503    /// Return the new name if exists, otherwise the regular name in DnsEntry.
504    pub(crate) fn get_name(&self) -> &str {
505        self.new_name.as_deref().unwrap_or(&self.entry.name)
506    }
507
508    pub fn get_original_name(&self) -> &str {
509        &self.entry.name
510    }
511}
512
513impl PartialEq for DnsRecord {
514    fn eq(&self, other: &Self) -> bool {
515        self.entry == other.entry
516    }
517}
518
519/// Common methods for DNS resource records.
520pub trait DnsRecordExt: fmt::Debug {
521    fn get_record(&self) -> &DnsRecord;
522    fn get_record_mut(&mut self) -> &mut DnsRecord;
523    fn write(&self, packet: &mut DnsOutPacket);
524    fn any(&self) -> &dyn Any;
525
526    /// Returns whether `other` record is considered the same except TTL.
527    fn matches(&self, other: &dyn DnsRecordExt) -> bool;
528
529    /// Returns whether `other` record has the same rdata.
530    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool;
531
532    /// Returns the result based on a byte-level comparison of `rdata`.
533    /// If `other` is not valid, returns `Greater`.
534    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering;
535
536    /// Returns the result based on "lexicographically later" defined below.
537    fn compare(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
538        /*
539        RFC 6762: https://datatracker.ietf.org/doc/html/rfc6762#section-8.2
540
541        ... The determination of "lexicographically later" is performed by first
542        comparing the record class (excluding the cache-flush bit described
543        in Section 10.2), then the record type, then raw comparison of the
544        binary content of the rdata without regard for meaning or structure.
545        If the record classes differ, then the numerically greater class is
546        considered "lexicographically later".  Otherwise, if the record types
547        differ, then the numerically greater type is considered
548        "lexicographically later".  If the rrtype and rrclass both match,
549        then the rdata is compared. ...
550        */
551        match self.get_class().cmp(&other.get_class()) {
552            cmp::Ordering::Equal => match self.get_type().cmp(&other.get_type()) {
553                cmp::Ordering::Equal => self.compare_rdata(other),
554                not_equal => not_equal,
555            },
556            not_equal => not_equal,
557        }
558    }
559
560    /// Returns a human-readable string of rdata.
561    fn rdata_print(&self) -> String;
562
563    /// Returns the class only, excluding class_flush / unique bit.
564    fn get_class(&self) -> u16 {
565        self.get_record().entry.class
566    }
567
568    fn get_cache_flush(&self) -> bool {
569        self.get_record().entry.cache_flush
570    }
571
572    /// Return the new name if exists, otherwise the regular name in DnsEntry.
573    fn get_name(&self) -> &str {
574        self.get_record().get_name()
575    }
576
577    fn get_type(&self) -> RRType {
578        self.get_record().entry.ty
579    }
580
581    /// Resets TTL using `other` record.
582    /// `self.refresh` and `self.expires` are also reset.
583    fn reset_ttl(&mut self, other: &dyn DnsRecordExt) {
584        self.get_record_mut().reset_ttl(other.get_record());
585    }
586
587    fn get_created(&self) -> u64 {
588        self.get_record().get_created()
589    }
590
591    fn get_expire(&self) -> u64 {
592        self.get_record().get_expire_time()
593    }
594
595    fn set_expire(&mut self, expire_at: u64) {
596        self.get_record_mut().set_expire(expire_at);
597    }
598
599    /// Set expire as `expire_at` if it is sooner than the current `expire`.
600    fn set_expire_sooner(&mut self, expire_at: u64) {
601        if expire_at < self.get_expire() {
602            self.get_record_mut().set_expire(expire_at);
603        }
604    }
605
606    /// Returns true if the record expires in 1 second from `now`.
607    fn expires_soon(&self, now: u64) -> bool {
608        self.get_record().expires_soon(now)
609    }
610
611    /// Given `now`, if the record is due to refresh, this method updates the refresh time
612    /// and returns the new refresh time. Otherwise, returns None.
613    fn updated_refresh_time(&mut self, now: u64) -> Option<u64> {
614        if self.get_record_mut().refresh_maybe(now) {
615            Some(self.get_record().get_refresh_time())
616        } else {
617            None
618        }
619    }
620
621    /// Returns true if another record has matched content,
622    /// and if its TTL is at least half of this record's.
623    fn suppressed_by_answer(&self, other: &dyn DnsRecordExt) -> bool {
624        self.matches(other) && (other.get_record().ttl > self.get_record().ttl / 2)
625    }
626
627    /// Required by RFC 6762 Section 7.1: Known-Answer Suppression.
628    fn suppressed_by(&self, msg: &DnsIncoming) -> bool {
629        for answer in msg.answers.iter() {
630            if self.suppressed_by_answer(answer.as_ref()) {
631                return true;
632            }
633        }
634        false
635    }
636
637    fn clone_box(&self) -> DnsRecordBox;
638
639    fn boxed(self) -> DnsRecordBox;
640}
641
642/// Resource Record for IPv4 address or IPv6 address.
643#[derive(Debug, Clone)]
644pub(crate) struct DnsAddress {
645    pub(crate) record: DnsRecord,
646    address: IpAddr,
647    pub(crate) interface_id: InterfaceId,
648}
649
650impl DnsAddress {
651    pub fn new(
652        name: &str,
653        ty: RRType,
654        class: u16,
655        ttl: u32,
656        address: IpAddr,
657        interface_id: InterfaceId,
658    ) -> Self {
659        let record = DnsRecord::new(name, ty, class, ttl);
660        Self {
661            record,
662            address,
663            interface_id,
664        }
665    }
666
667    pub fn address(&self) -> ScopedIp {
668        match self.address {
669            IpAddr::V4(v4) => ScopedIp::V4(ScopedIpV4 { addr: v4 }),
670            IpAddr::V6(v6) => ScopedIp::V6(ScopedIpV6 {
671                addr: v6,
672                scope_id: self.interface_id.clone(),
673            }),
674        }
675    }
676}
677
678impl DnsRecordExt for DnsAddress {
679    fn get_record(&self) -> &DnsRecord {
680        &self.record
681    }
682
683    fn get_record_mut(&mut self) -> &mut DnsRecord {
684        &mut self.record
685    }
686
687    fn write(&self, packet: &mut DnsOutPacket) {
688        match self.address {
689            IpAddr::V4(addr) => packet.write_bytes(addr.octets().as_ref()),
690            IpAddr::V6(addr) => packet.write_bytes(addr.octets().as_ref()),
691        };
692    }
693
694    fn any(&self) -> &dyn Any {
695        self
696    }
697
698    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
699        if let Some(other_a) = other.any().downcast_ref::<Self>() {
700            return self.address == other_a.address
701                && self.record.entry == other_a.record.entry
702                && self.interface_id == other_a.interface_id;
703        }
704        false
705    }
706
707    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
708        if let Some(other_a) = other.any().downcast_ref::<Self>() {
709            return self.address == other_a.address;
710        }
711        false
712    }
713
714    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
715        if let Some(other_a) = other.any().downcast_ref::<Self>() {
716            self.address.cmp(&other_a.address)
717        } else {
718            cmp::Ordering::Greater
719        }
720    }
721
722    fn rdata_print(&self) -> String {
723        format!("{}", self.address)
724    }
725
726    fn clone_box(&self) -> DnsRecordBox {
727        Box::new(self.clone())
728    }
729
730    fn boxed(self) -> DnsRecordBox {
731        Box::new(self)
732    }
733}
734
735/// Resource Record for a DNS pointer
736#[derive(Debug, Clone)]
737pub struct DnsPointer {
738    record: DnsRecord,
739    alias: String, // the full name of Service Instance
740}
741
742impl DnsPointer {
743    pub fn new(name: &str, ty: RRType, class: u16, ttl: u32, alias: String) -> Self {
744        let record = DnsRecord::new(name, ty, class, ttl);
745        Self { record, alias }
746    }
747
748    pub fn alias(&self) -> &str {
749        &self.alias
750    }
751}
752
753impl DnsRecordExt for DnsPointer {
754    fn get_record(&self) -> &DnsRecord {
755        &self.record
756    }
757
758    fn get_record_mut(&mut self) -> &mut DnsRecord {
759        &mut self.record
760    }
761
762    fn write(&self, packet: &mut DnsOutPacket) {
763        packet.write_name(&self.alias);
764    }
765
766    fn any(&self) -> &dyn Any {
767        self
768    }
769
770    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
771        if let Some(other_ptr) = other.any().downcast_ref::<Self>() {
772            return self.alias == other_ptr.alias && self.record.entry == other_ptr.record.entry;
773        }
774        false
775    }
776
777    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
778        if let Some(other_ptr) = other.any().downcast_ref::<Self>() {
779            return self.alias == other_ptr.alias;
780        }
781        false
782    }
783
784    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
785        if let Some(other_ptr) = other.any().downcast_ref::<Self>() {
786            self.alias.cmp(&other_ptr.alias)
787        } else {
788            cmp::Ordering::Greater
789        }
790    }
791
792    fn rdata_print(&self) -> String {
793        self.alias.clone()
794    }
795
796    fn clone_box(&self) -> DnsRecordBox {
797        Box::new(self.clone())
798    }
799
800    fn boxed(self) -> DnsRecordBox {
801        Box::new(self)
802    }
803}
804
805/// Resource Record for a DNS service.
806#[derive(Debug, Clone)]
807pub struct DnsSrv {
808    pub(crate) record: DnsRecord,
809    pub(crate) priority: u16, // lower number means higher priority. Should be 0 in common cases.
810    pub(crate) weight: u16,   // Should be 0 in common cases
811    host: String,
812    port: u16,
813}
814
815impl DnsSrv {
816    pub fn new(
817        name: &str,
818        class: u16,
819        ttl: u32,
820        priority: u16,
821        weight: u16,
822        port: u16,
823        host: String,
824    ) -> Self {
825        let record = DnsRecord::new(name, RRType::SRV, class, ttl);
826        Self {
827            record,
828            priority,
829            weight,
830            host,
831            port,
832        }
833    }
834
835    pub fn host(&self) -> &str {
836        &self.host
837    }
838
839    pub fn port(&self) -> u16 {
840        self.port
841    }
842
843    pub fn set_host(&mut self, host: String) {
844        self.host = host;
845    }
846}
847
848impl DnsRecordExt for DnsSrv {
849    fn get_record(&self) -> &DnsRecord {
850        &self.record
851    }
852
853    fn get_record_mut(&mut self) -> &mut DnsRecord {
854        &mut self.record
855    }
856
857    fn write(&self, packet: &mut DnsOutPacket) {
858        packet.write_short(self.priority);
859        packet.write_short(self.weight);
860        packet.write_short(self.port);
861        packet.write_name(&self.host);
862    }
863
864    fn any(&self) -> &dyn Any {
865        self
866    }
867
868    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
869        if let Some(other_svc) = other.any().downcast_ref::<Self>() {
870            return self.host == other_svc.host
871                && self.port == other_svc.port
872                && self.weight == other_svc.weight
873                && self.priority == other_svc.priority
874                && self.record.entry == other_svc.record.entry;
875        }
876        false
877    }
878
879    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
880        if let Some(other_srv) = other.any().downcast_ref::<Self>() {
881            return self.host == other_srv.host
882                && self.port == other_srv.port
883                && self.weight == other_srv.weight
884                && self.priority == other_srv.priority;
885        }
886        false
887    }
888
889    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
890        let Some(other_srv) = other.any().downcast_ref::<Self>() else {
891            return cmp::Ordering::Greater;
892        };
893
894        // 1. compare `priority`
895        match self
896            .priority
897            .to_be_bytes()
898            .cmp(&other_srv.priority.to_be_bytes())
899        {
900            cmp::Ordering::Equal => {
901                // 2. compare `weight`
902                match self
903                    .weight
904                    .to_be_bytes()
905                    .cmp(&other_srv.weight.to_be_bytes())
906                {
907                    cmp::Ordering::Equal => {
908                        // 3. compare `port`.
909                        match self.port.to_be_bytes().cmp(&other_srv.port.to_be_bytes()) {
910                            cmp::Ordering::Equal => self.host.cmp(&other_srv.host),
911                            not_equal => not_equal,
912                        }
913                    }
914                    not_equal => not_equal,
915                }
916            }
917            not_equal => not_equal,
918        }
919    }
920
921    fn rdata_print(&self) -> String {
922        format!(
923            "priority: {}, weight: {}, port: {}, host: {}",
924            self.priority, self.weight, self.port, self.host
925        )
926    }
927
928    fn clone_box(&self) -> DnsRecordBox {
929        Box::new(self.clone())
930    }
931
932    fn boxed(self) -> DnsRecordBox {
933        Box::new(self)
934    }
935}
936
937/// Resource Record for a DNS TXT record.
938///
939/// From [RFC 6763 section 6]:
940///
941/// The format of each constituent string within the DNS TXT record is a
942/// single length byte, followed by 0-255 bytes of text data.
943///
944/// DNS-SD uses DNS TXT records to store arbitrary key/value pairs
945///    conveying additional information about the named service.  Each
946///    key/value pair is encoded as its own constituent string within the
947///    DNS TXT record, in the form "key=value" (without the quotation
948///    marks).  Everything up to the first '=' character is the key (Section
949///    6.4).  Everything after the first '=' character to the end of the
950///    string (including subsequent '=' characters, if any) is the value
951#[derive(Clone)]
952pub struct DnsTxt {
953    pub(crate) record: DnsRecord,
954    text: Vec<u8>,
955}
956
957impl DnsTxt {
958    pub fn new(name: &str, class: u16, ttl: u32, text: Vec<u8>) -> Self {
959        let record = DnsRecord::new(name, RRType::TXT, class, ttl);
960        Self { record, text }
961    }
962
963    pub fn text(&self) -> &[u8] {
964        &self.text
965    }
966}
967
968impl DnsRecordExt for DnsTxt {
969    fn get_record(&self) -> &DnsRecord {
970        &self.record
971    }
972
973    fn get_record_mut(&mut self) -> &mut DnsRecord {
974        &mut self.record
975    }
976
977    fn write(&self, packet: &mut DnsOutPacket) {
978        packet.write_bytes(&self.text);
979    }
980
981    fn any(&self) -> &dyn Any {
982        self
983    }
984
985    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
986        if let Some(other_txt) = other.any().downcast_ref::<Self>() {
987            return self.text == other_txt.text && self.record.entry == other_txt.record.entry;
988        }
989        false
990    }
991
992    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
993        if let Some(other_txt) = other.any().downcast_ref::<Self>() {
994            return self.text == other_txt.text;
995        }
996        false
997    }
998
999    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
1000        if let Some(other_txt) = other.any().downcast_ref::<Self>() {
1001            self.text.cmp(&other_txt.text)
1002        } else {
1003            cmp::Ordering::Greater
1004        }
1005    }
1006
1007    fn rdata_print(&self) -> String {
1008        format!("{:?}", decode_txt(&self.text))
1009    }
1010
1011    fn clone_box(&self) -> DnsRecordBox {
1012        Box::new(self.clone())
1013    }
1014
1015    fn boxed(self) -> DnsRecordBox {
1016        Box::new(self)
1017    }
1018}
1019
1020impl fmt::Debug for DnsTxt {
1021    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1022        let properties = decode_txt(&self.text);
1023        write!(
1024            f,
1025            "DnsTxt {{ record: {:?}, text: {:?} }}",
1026            self.record, properties
1027        )
1028    }
1029}
1030
1031// Convert from DNS TXT record content to key/value pairs
1032fn decode_txt(txt: &[u8]) -> Vec<TxtProperty> {
1033    let mut properties = Vec::new();
1034    let mut offset = 0;
1035    while offset < txt.len() {
1036        let length = txt[offset] as usize;
1037        if length == 0 {
1038            break; // reached the end
1039        }
1040        offset += 1; // move over the length byte
1041
1042        let offset_end = offset + length;
1043        if offset_end > txt.len() {
1044            trace!("ERROR: DNS TXT: size given for property is out of range. (offset={}, length={}, offset_end={}, record length={})", offset, length, offset_end, txt.len());
1045            break; // Skipping the rest of the record content, as the size for this property would already be out of range.
1046        }
1047        let kv_bytes = &txt[offset..offset_end];
1048
1049        // split key and val using the first `=`
1050        let (k, v) = kv_bytes.iter().position(|&x| x == b'=').map_or_else(
1051            || (kv_bytes.to_vec(), None),
1052            |idx| (kv_bytes[..idx].to_vec(), Some(kv_bytes[idx + 1..].to_vec())),
1053        );
1054
1055        // Make sure the key can be stored in UTF-8.
1056        match String::from_utf8(k) {
1057            Ok(k_string) => {
1058                properties.push(TxtProperty {
1059                    key: k_string,
1060                    val: v,
1061                });
1062            }
1063            Err(e) => trace!("ERROR: convert to String from key: {}", e),
1064        }
1065
1066        offset += length;
1067    }
1068
1069    properties
1070}
1071
1072/// Represents a property in a TXT record.
1073#[derive(Clone, PartialEq, Eq)]
1074pub struct TxtProperty {
1075    /// The name of the property. The original cases are kept.
1076    key: String,
1077
1078    /// RFC 6763 says values are bytes, not necessarily UTF-8.
1079    /// It is also possible that there is no value, in which case
1080    /// the key is a boolean key.
1081    val: Option<Vec<u8>>,
1082}
1083
1084impl TxtProperty {
1085    /// Returns the value of a property as str.
1086    pub fn val_str(&self) -> &str {
1087        self.val
1088            .as_ref()
1089            .map_or("", |v| std::str::from_utf8(&v[..]).unwrap_or_default())
1090    }
1091}
1092
1093/// Supports constructing from a tuple.
1094impl<K, V> From<&(K, V)> for TxtProperty
1095where
1096    K: ToString,
1097    V: ToString,
1098{
1099    fn from(prop: &(K, V)) -> Self {
1100        Self {
1101            key: prop.0.to_string(),
1102            val: Some(prop.1.to_string().into_bytes()),
1103        }
1104    }
1105}
1106
1107impl<K, V> From<(K, V)> for TxtProperty
1108where
1109    K: ToString,
1110    V: AsRef<[u8]>,
1111{
1112    fn from(prop: (K, V)) -> Self {
1113        Self {
1114            key: prop.0.to_string(),
1115            val: Some(prop.1.as_ref().into()),
1116        }
1117    }
1118}
1119
1120/// Support a property that has no value.
1121impl From<&str> for TxtProperty {
1122    fn from(key: &str) -> Self {
1123        Self {
1124            key: key.to_string(),
1125            val: None,
1126        }
1127    }
1128}
1129
1130impl fmt::Display for TxtProperty {
1131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1132        write!(f, "{}={}", self.key, self.val_str())
1133    }
1134}
1135
1136/// Mimic the default debug output for a struct, with a twist:
1137/// - If self.var is UTF-8, will output it as a string in double quotes.
1138/// - If self.var is not UTF-8, will output its bytes as in hex.
1139impl fmt::Debug for TxtProperty {
1140    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1141        let val_string = self.val.as_ref().map_or_else(
1142            || "None".to_string(),
1143            |v| {
1144                std::str::from_utf8(&v[..]).map_or_else(
1145                    |_| format!("Some({})", u8_slice_to_hex(&v[..])),
1146                    |s| format!("Some(\"{s}\")"),
1147                )
1148            },
1149        );
1150
1151        write!(
1152            f,
1153            "TxtProperty {{key: \"{}\", val: {}}}",
1154            &self.key, &val_string,
1155        )
1156    }
1157}
1158
1159const HEX_TABLE: [char; 16] = [
1160    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
1161];
1162
1163/// Create a hex string from `slice`, with a "0x" prefix.
1164///
1165/// For example, [1u8, 2u8] -> "0x0102"
1166fn u8_slice_to_hex(slice: &[u8]) -> String {
1167    let mut hex = String::with_capacity(slice.len() * 2 + 2);
1168    hex.push_str("0x");
1169    for b in slice {
1170        hex.push(HEX_TABLE[(b >> 4) as usize]);
1171        hex.push(HEX_TABLE[(b & 0x0F) as usize]);
1172    }
1173    hex
1174}
1175
1176/// A DNS host information record
1177#[derive(Debug, Clone)]
1178struct DnsHostInfo {
1179    record: DnsRecord,
1180    cpu: String,
1181    os: String,
1182}
1183
1184impl DnsHostInfo {
1185    fn new(name: &str, ty: RRType, class: u16, ttl: u32, cpu: String, os: String) -> Self {
1186        let record = DnsRecord::new(name, ty, class, ttl);
1187        Self { record, cpu, os }
1188    }
1189}
1190
1191impl DnsRecordExt for DnsHostInfo {
1192    fn get_record(&self) -> &DnsRecord {
1193        &self.record
1194    }
1195
1196    fn get_record_mut(&mut self) -> &mut DnsRecord {
1197        &mut self.record
1198    }
1199
1200    fn write(&self, packet: &mut DnsOutPacket) {
1201        println!("writing HInfo: cpu {} os {}", &self.cpu, &self.os);
1202        packet.write_bytes(self.cpu.as_bytes());
1203        packet.write_bytes(self.os.as_bytes());
1204    }
1205
1206    fn any(&self) -> &dyn Any {
1207        self
1208    }
1209
1210    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
1211        if let Some(other_hinfo) = other.any().downcast_ref::<Self>() {
1212            return self.cpu == other_hinfo.cpu
1213                && self.os == other_hinfo.os
1214                && self.record.entry == other_hinfo.record.entry;
1215        }
1216        false
1217    }
1218
1219    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
1220        if let Some(other_hinfo) = other.any().downcast_ref::<Self>() {
1221            return self.cpu == other_hinfo.cpu && self.os == other_hinfo.os;
1222        }
1223        false
1224    }
1225
1226    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
1227        if let Some(other_hinfo) = other.any().downcast_ref::<Self>() {
1228            match self.cpu.cmp(&other_hinfo.cpu) {
1229                cmp::Ordering::Equal => self.os.cmp(&other_hinfo.os),
1230                ordering => ordering,
1231            }
1232        } else {
1233            cmp::Ordering::Greater
1234        }
1235    }
1236
1237    fn rdata_print(&self) -> String {
1238        format!("cpu: {}, os: {}", self.cpu, self.os)
1239    }
1240
1241    fn clone_box(&self) -> DnsRecordBox {
1242        Box::new(self.clone())
1243    }
1244
1245    fn boxed(self) -> DnsRecordBox {
1246        Box::new(self)
1247    }
1248}
1249
1250/// Resource Record for negative responses
1251///
1252/// [RFC4034 section 4.1](https://datatracker.ietf.org/doc/html/rfc4034#section-4.1)
1253/// and
1254/// [RFC6762 section 6.1](https://datatracker.ietf.org/doc/html/rfc6762#section-6.1)
1255#[derive(Debug, Clone)]
1256pub struct DnsNSec {
1257    record: DnsRecord,
1258    next_domain: String,
1259    type_bitmap: Vec<u8>,
1260}
1261
1262impl DnsNSec {
1263    pub fn new(
1264        name: &str,
1265        class: u16,
1266        ttl: u32,
1267        next_domain: String,
1268        type_bitmap: Vec<u8>,
1269    ) -> Self {
1270        let record = DnsRecord::new(name, RRType::NSEC, class, ttl);
1271        Self {
1272            record,
1273            next_domain,
1274            type_bitmap,
1275        }
1276    }
1277
1278    /// Returns the types marked by `type_bitmap`
1279    pub fn _types(&self) -> Vec<u16> {
1280        // From RFC 4034: 4.1.2 The Type Bit Maps Field
1281        // https://datatracker.ietf.org/doc/html/rfc4034#section-4.1.2
1282        //
1283        // Each bitmap encodes the low-order 8 bits of RR types within the
1284        // window block, in network bit order.  The first bit is bit 0.  For
1285        // window block 0, bit 1 corresponds to RR type 1 (A), bit 2 corresponds
1286        // to RR type 2 (NS), and so forth.
1287
1288        let mut bit_num = 0;
1289        let mut results = Vec::new();
1290
1291        for byte in self.type_bitmap.iter() {
1292            let mut bit_mask: u8 = 0x80; // for bit 0 in network bit order
1293
1294            // check every bit in this byte, one by one.
1295            for _ in 0..8 {
1296                if (byte & bit_mask) != 0 {
1297                    results.push(bit_num);
1298                }
1299                bit_num += 1;
1300                bit_mask >>= 1; // mask for the next bit
1301            }
1302        }
1303        results
1304    }
1305}
1306
1307impl DnsRecordExt for DnsNSec {
1308    fn get_record(&self) -> &DnsRecord {
1309        &self.record
1310    }
1311
1312    fn get_record_mut(&mut self) -> &mut DnsRecord {
1313        &mut self.record
1314    }
1315
1316    fn write(&self, packet: &mut DnsOutPacket) {
1317        packet.write_bytes(self.next_domain.as_bytes());
1318        packet.write_bytes(&self.type_bitmap);
1319    }
1320
1321    fn any(&self) -> &dyn Any {
1322        self
1323    }
1324
1325    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
1326        if let Some(other_record) = other.any().downcast_ref::<Self>() {
1327            return self.next_domain == other_record.next_domain
1328                && self.type_bitmap == other_record.type_bitmap
1329                && self.record.entry == other_record.record.entry;
1330        }
1331        false
1332    }
1333
1334    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
1335        if let Some(other_record) = other.any().downcast_ref::<Self>() {
1336            return self.next_domain == other_record.next_domain
1337                && self.type_bitmap == other_record.type_bitmap;
1338        }
1339        false
1340    }
1341
1342    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
1343        if let Some(other_nsec) = other.any().downcast_ref::<Self>() {
1344            match self.next_domain.cmp(&other_nsec.next_domain) {
1345                cmp::Ordering::Equal => self.type_bitmap.cmp(&other_nsec.type_bitmap),
1346                ordering => ordering,
1347            }
1348        } else {
1349            cmp::Ordering::Greater
1350        }
1351    }
1352
1353    fn rdata_print(&self) -> String {
1354        format!(
1355            "next_domain: {}, type_bitmap len: {}",
1356            self.next_domain,
1357            self.type_bitmap.len()
1358        )
1359    }
1360
1361    fn clone_box(&self) -> DnsRecordBox {
1362        Box::new(self.clone())
1363    }
1364
1365    fn boxed(self) -> DnsRecordBox {
1366        Box::new(self)
1367    }
1368}
1369
1370#[derive(PartialEq)]
1371enum PacketState {
1372    Init = 0,
1373    Finished = 1,
1374}
1375
1376/// A single packet for outgoing DNS message.
1377pub struct DnsOutPacket {
1378    /// All bytes in `data` concatenated is the actual packet on the wire.
1379    data: Vec<Vec<u8>>,
1380
1381    /// Current logical size of the packet. It starts with the size of the mandatory header.
1382    size: usize,
1383
1384    /// An internal state, not defined by DNS.
1385    state: PacketState,
1386
1387    /// k: name, v: offset
1388    names: HashMap<String, u16>,
1389}
1390
1391impl DnsOutPacket {
1392    fn new() -> Self {
1393        Self {
1394            data: Vec::new(),
1395            size: MSG_HEADER_LEN, // Header is mandatory.
1396            state: PacketState::Init,
1397            names: HashMap::new(),
1398        }
1399    }
1400
1401    pub fn size(&self) -> usize {
1402        self.size
1403    }
1404
1405    pub fn to_bytes(&self) -> Vec<u8> {
1406        self.data.concat()
1407    }
1408
1409    fn write_question(&mut self, question: &DnsQuestion) {
1410        self.write_name(&question.entry.name);
1411        self.write_short(question.entry.ty as u16);
1412        self.write_short(question.entry.class);
1413    }
1414
1415    /// Writes a record (answer, authoritative answer, additional)
1416    /// Returns false if the packet exceeds the max size with this record, nothing is written to the packet.
1417    /// otherwise returns true.
1418    fn write_record(&mut self, record_ext: &dyn DnsRecordExt, now: u64) -> bool {
1419        let start_data_length = self.data.len();
1420        let start_size = self.size;
1421
1422        let record = record_ext.get_record();
1423        self.write_name(record.get_name());
1424        self.write_short(record.entry.ty as u16);
1425        if record.entry.cache_flush {
1426            // check "multicast"
1427            self.write_short(record.entry.class | CLASS_CACHE_FLUSH);
1428        } else {
1429            self.write_short(record.entry.class);
1430        }
1431
1432        if now == 0 {
1433            self.write_u32(record.ttl);
1434        } else {
1435            self.write_u32(record.get_remaining_ttl(now));
1436        }
1437
1438        let index = self.data.len();
1439
1440        // Adjust size for the short we will write before this record
1441        self.size += 2;
1442        record_ext.write(self);
1443        self.size -= 2;
1444
1445        let length: usize = self.data[index..].iter().map(|x| x.len()).sum();
1446        self.insert_short(index, length as u16);
1447
1448        if self.size > MAX_MSG_ABSOLUTE {
1449            self.data.truncate(start_data_length);
1450            self.size = start_size;
1451            self.state = PacketState::Finished;
1452            return false;
1453        }
1454
1455        true
1456    }
1457
1458    pub(crate) fn insert_short(&mut self, index: usize, value: u16) {
1459        self.data.insert(index, value.to_be_bytes().to_vec());
1460        self.size += 2;
1461    }
1462
1463    /// Parses a DNS name that may contain escaped characters according to RFC 6763 Section 4.3.
1464    /// Returns a vector of labels where each label is the unescaped content.
1465    ///
1466    /// Escape sequences:
1467    /// - \\. becomes . (literal dot)
1468    /// - \\\\ becomes \\ (literal backslash)
1469    fn parse_escaped_name(name: &str) -> Vec<String> {
1470        let mut labels = Vec::new();
1471        let mut current_label = String::new();
1472        let mut chars = name.chars().peekable();
1473
1474        while let Some(ch) = chars.next() {
1475            match ch {
1476                '\\' => {
1477                    // Backslash escape sequence
1478                    if let Some(&next_ch) = chars.peek() {
1479                        match next_ch {
1480                            '.' | '\\' => {
1481                                // \\. or \\\\ - consume the backslash and add the escaped char
1482                                chars.next();
1483                                current_label.push(next_ch);
1484                            }
1485                            _ => {
1486                                // Not a recognized escape - treat backslash literally
1487                                current_label.push(ch);
1488                            }
1489                        }
1490                    } else {
1491                        // Trailing backslash - add it literally
1492                        current_label.push(ch);
1493                    }
1494                }
1495                '.' => {
1496                    // Unescaped dot - label separator
1497                    if !current_label.is_empty() {
1498                        labels.push(current_label.clone());
1499                        current_label.clear();
1500                    }
1501                }
1502                _ => {
1503                    current_label.push(ch);
1504                }
1505            }
1506        }
1507
1508        // Add the last label if not empty
1509        if !current_label.is_empty() {
1510            labels.push(current_label);
1511        }
1512
1513        labels
1514    }
1515
1516    // Write name to packet
1517    //
1518    // [RFC1035]
1519    // 4.1.4. Message compression
1520    //
1521    // In order to reduce the size of messages, the domain system utilizes a
1522    // compression scheme which eliminates the repetition of domain names in a
1523    // message.  In this scheme, an entire domain name or a list of labels at
1524    // the end of a domain name is replaced with a pointer to a prior occurrence
1525    // of the same name.
1526    // The pointer takes the form of a two octet sequence:
1527    //     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1528    //     | 1  1|                OFFSET                   |
1529    //     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1530    // The first two bits are ones.  This allows a pointer to be distinguished
1531    // from a label, since the label must begin with two zero bits because
1532    // labels are restricted to 63 octets or less.  (The 10 and 01 combinations
1533    // are reserved for future use.)  The OFFSET field specifies an offset from
1534    // the start of the message (i.e., the first octet of the ID field in the
1535    // domain header).  A zero offset specifies the first byte of the ID field,
1536    // etc.
1537    //
1538    // This function also handles RFC 6763 Section 4.3 escaping where dots and backslashes
1539    // in instance names are escaped (e.g., "My\\.Service" represents a single label "My.Service").
1540    // The actual name sent over the wire is the unescaped version.
1541    fn write_name(&mut self, name: &str) {
1542        // Remove trailing dot if present
1543        let name_to_parse = name.strip_suffix('.').unwrap_or(name);
1544
1545        // Parse the name considering escape sequences
1546        let labels = Self::parse_escaped_name(name_to_parse);
1547
1548        if labels.is_empty() {
1549            self.write_byte(0);
1550            return;
1551        }
1552
1553        // Write each label
1554        for (i, label) in labels.iter().enumerate() {
1555            // Build the remaining name for compression (with dots as separators)
1556            let remaining: String = labels[i..].join(".");
1557
1558            // Check if we can use compression for the remaining part
1559            const POINTER_MASK: u16 = 0xC000;
1560            if let Some(&offset) = self.names.get(&remaining) {
1561                let pointer = offset | POINTER_MASK;
1562                self.write_short(pointer);
1563                return;
1564            }
1565
1566            // Store this position for potential future compression
1567            self.names.insert(remaining, self.size as u16);
1568
1569            // Write the label
1570            self.write_utf8(label);
1571        }
1572
1573        // Write terminating zero byte
1574        self.write_byte(0);
1575    }
1576
1577    fn write_utf8(&mut self, utf: &str) {
1578        assert!(utf.len() < 64);
1579        self.write_byte(utf.len() as u8);
1580        self.write_bytes(utf.as_bytes());
1581    }
1582
1583    fn write_bytes(&mut self, s: &[u8]) {
1584        self.data.push(s.to_vec());
1585        self.size += s.len();
1586    }
1587
1588    fn write_u32(&mut self, int: u32) {
1589        self.data.push(int.to_be_bytes().to_vec());
1590        self.size += 4;
1591    }
1592
1593    fn write_short(&mut self, short: u16) {
1594        self.data.push(short.to_be_bytes().to_vec());
1595        self.size += 2;
1596    }
1597
1598    fn write_byte(&mut self, byte: u8) {
1599        self.data.push(vec![byte]);
1600        self.size += 1;
1601    }
1602
1603    /// Writes the header fields and finish the packet.
1604    /// This function should be only called when finishing a packet.
1605    ///
1606    /// The header format is based on RFC 1035 section 4.1.1:
1607    /// https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1
1608    //
1609    //                                  1  1  1  1  1  1
1610    //    0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
1611    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1612    //    |                      ID                       |
1613    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1614    //    |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
1615    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1616    //    |                    QDCOUNT                    |
1617    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1618    //    |                    ANCOUNT                    |
1619    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1620    //    |                    NSCOUNT                    |
1621    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1622    //    |                    ARCOUNT                    |
1623    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1624    //
1625    fn write_header(
1626        &mut self,
1627        id: u16,
1628        flags: u16,
1629        q_count: u16,
1630        a_count: u16,
1631        auth_count: u16,
1632        addi_count: u16,
1633    ) {
1634        self.insert_short(0, addi_count);
1635        self.insert_short(0, auth_count);
1636        self.insert_short(0, a_count);
1637        self.insert_short(0, q_count);
1638        self.insert_short(0, flags);
1639        self.insert_short(0, id);
1640
1641        // Adjust the size as it was already initialized to include the header.
1642        self.size -= MSG_HEADER_LEN;
1643
1644        self.state = PacketState::Finished;
1645    }
1646}
1647
1648/// Representation of one outgoing DNS message that could be sent in one or more packet(s).
1649pub struct DnsOutgoing {
1650    flags: u16,
1651    id: u16,
1652    multicast: bool,
1653    questions: Vec<DnsQuestion>,
1654    answers: Vec<(DnsRecordBox, u64)>,
1655    authorities: Vec<DnsRecordBox>,
1656    additionals: Vec<DnsRecordBox>,
1657    known_answer_count: i64, // for internal maintenance only
1658}
1659
1660impl DnsOutgoing {
1661    pub fn new(flags: u16) -> Self {
1662        Self {
1663            flags,
1664            id: 0,
1665            multicast: true,
1666            questions: Vec::new(),
1667            answers: Vec::new(),
1668            authorities: Vec::new(),
1669            additionals: Vec::new(),
1670            known_answer_count: 0,
1671        }
1672    }
1673
1674    pub fn questions(&self) -> &[DnsQuestion] {
1675        &self.questions
1676    }
1677
1678    /// For testing purposes only.
1679    pub(crate) fn _answers(&self) -> &[(DnsRecordBox, u64)] {
1680        &self.answers
1681    }
1682
1683    pub fn answers_count(&self) -> usize {
1684        self.answers.len()
1685    }
1686
1687    pub fn authorities(&self) -> &[DnsRecordBox] {
1688        &self.authorities
1689    }
1690
1691    pub fn additionals(&self) -> &[DnsRecordBox] {
1692        &self.additionals
1693    }
1694
1695    pub fn known_answer_count(&self) -> i64 {
1696        self.known_answer_count
1697    }
1698
1699    pub fn set_id(&mut self, id: u16) {
1700        self.id = id;
1701    }
1702
1703    pub const fn is_query(&self) -> bool {
1704        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_QUERY
1705    }
1706
1707    const fn is_response(&self) -> bool {
1708        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_RESPONSE
1709    }
1710
1711    // Adds an additional answer
1712
1713    // From: RFC 6763, DNS-Based Service Discovery, February 2013
1714
1715    // 12.  DNS Additional Record Generation
1716
1717    //    DNS has an efficiency feature whereby a DNS server may place
1718    //    additional records in the additional section of the DNS message.
1719    //    These additional records are records that the client did not
1720    //    explicitly request, but the server has reasonable grounds to expect
1721    //    that the client might request them shortly, so including them can
1722    //    save the client from having to issue additional queries.
1723
1724    //    This section recommends which additional records SHOULD be generated
1725    //    to improve network efficiency, for both Unicast and Multicast DNS-SD
1726    //    responses.
1727
1728    // 12.1.  PTR Records
1729
1730    //    When including a DNS-SD Service Instance Enumeration or Selective
1731    //    Instance Enumeration (subtype) PTR record in a response packet, the
1732    //    server/responder SHOULD include the following additional records:
1733
1734    //    o  The SRV record(s) named in the PTR rdata.
1735    //    o  The TXT record(s) named in the PTR rdata.
1736    //    o  All address records (type "A" and "AAAA") named in the SRV rdata.
1737
1738    // 12.2.  SRV Records
1739
1740    //    When including an SRV record in a response packet, the
1741    //    server/responder SHOULD include the following additional records:
1742
1743    //    o  All address records (type "A" and "AAAA") named in the SRV rdata.
1744    pub fn add_additional_answer(&mut self, answer: impl DnsRecordExt + 'static) {
1745        trace!("add_additional_answer: {:?}", &answer);
1746        self.additionals.push(answer.boxed());
1747    }
1748
1749    /// A workaround as Rust doesn't allow us to pass DnsRecordBox in as `impl DnsRecordExt`
1750    pub fn add_answer_box(&mut self, answer_box: DnsRecordBox) {
1751        self.answers.push((answer_box, 0));
1752    }
1753
1754    pub fn add_authority(&mut self, record: DnsRecordBox) {
1755        self.authorities.push(record);
1756    }
1757
1758    /// Returns true if `answer` is added to the outgoing msg.
1759    /// Returns false if `answer` was not added as it expired or suppressed by the incoming `msg`.
1760    pub fn add_answer(
1761        &mut self,
1762        msg: &DnsIncoming,
1763        answer: impl DnsRecordExt + Send + 'static,
1764    ) -> bool {
1765        trace!("Check for add_answer");
1766        if answer.suppressed_by(msg) {
1767            trace!("my answer is suppressed by incoming msg");
1768            self.known_answer_count += 1;
1769            return false;
1770        }
1771
1772        self.add_answer_at_time(answer, 0)
1773    }
1774
1775    /// Returns true if `answer` is added to the outgoing msg.
1776    /// Returns false if the answer is expired `now` hence not added.
1777    /// If `now` is 0, do not check if the answer expires.
1778    pub fn add_answer_at_time(
1779        &mut self,
1780        answer: impl DnsRecordExt + Send + 'static,
1781        now: u64,
1782    ) -> bool {
1783        if now == 0 || !answer.get_record().is_expired(now) {
1784            trace!("add_answer push: {:?}", &answer);
1785            self.answers.push((answer.boxed(), now));
1786            return true;
1787        }
1788        false
1789    }
1790
1791    pub fn add_question(&mut self, name: &str, qtype: RRType) {
1792        let q = DnsQuestion {
1793            entry: DnsEntry::new(name.to_string(), qtype, CLASS_IN),
1794        };
1795        self.questions.push(q);
1796    }
1797
1798    /// Returns a list of actual DNS packet data to be sent on the wire.
1799    pub fn to_data_on_wire(&self) -> Vec<Vec<u8>> {
1800        let packet_list = self.to_packets();
1801        packet_list.iter().map(|p| p.data.concat()).collect()
1802    }
1803
1804    /// Encode self into one or more packets.
1805    pub fn to_packets(&self) -> Vec<DnsOutPacket> {
1806        let mut packet_list = Vec::new();
1807        let mut packet = DnsOutPacket::new();
1808
1809        let mut question_count = self.questions.len() as u16;
1810        let mut answer_count = 0;
1811        let mut auth_count = 0;
1812        let mut addi_count = 0;
1813        let id = if self.multicast { 0 } else { self.id };
1814
1815        for question in self.questions.iter() {
1816            packet.write_question(question);
1817        }
1818
1819        for (answer, time) in self.answers.iter() {
1820            if packet.write_record(answer.as_ref(), *time) {
1821                answer_count += 1;
1822            }
1823        }
1824
1825        for auth in self.authorities.iter() {
1826            auth_count += u16::from(packet.write_record(auth.as_ref(), 0));
1827        }
1828
1829        for addi in self.additionals.iter() {
1830            if packet.write_record(addi.as_ref(), 0) {
1831                addi_count += 1;
1832                continue;
1833            }
1834
1835            // No more processing for response packets.
1836            if self.is_response() {
1837                break;
1838            }
1839
1840            // For query, the current packet exceeds its max size due to known answers,
1841            // need to truncate.
1842
1843            // finish the current packet first.
1844            packet.write_header(
1845                id,
1846                self.flags | FLAGS_TC,
1847                question_count,
1848                answer_count,
1849                auth_count,
1850                addi_count,
1851            );
1852
1853            packet_list.push(packet);
1854
1855            // create a new packet and reset counts.
1856            packet = DnsOutPacket::new();
1857            packet.write_record(addi.as_ref(), 0);
1858
1859            question_count = 0;
1860            answer_count = 0;
1861            auth_count = 0;
1862            addi_count = 1;
1863        }
1864
1865        packet.write_header(
1866            id,
1867            self.flags,
1868            question_count,
1869            answer_count,
1870            auth_count,
1871            addi_count,
1872        );
1873
1874        packet_list.push(packet);
1875        packet_list
1876    }
1877}
1878
1879/// An incoming DNS message. It could be a query or a response.
1880#[derive(Debug)]
1881pub struct DnsIncoming {
1882    offset: usize,
1883    data: Vec<u8>,
1884    questions: Vec<DnsQuestion>,
1885    answers: Vec<DnsRecordBox>,
1886    authorities: Vec<DnsRecordBox>,
1887    additional: Vec<DnsRecordBox>,
1888    id: u16,
1889    flags: u16,
1890    num_questions: u16,
1891    num_answers: u16,
1892    num_authorities: u16,
1893    num_additionals: u16,
1894    interface_id: InterfaceId,
1895}
1896
1897impl DnsIncoming {
1898    pub fn new(data: Vec<u8>, interface_id: InterfaceId) -> Result<Self> {
1899        let mut incoming = Self {
1900            offset: 0,
1901            data,
1902            questions: Vec::new(),
1903            answers: Vec::new(),
1904            authorities: Vec::new(),
1905            additional: Vec::new(),
1906            id: 0,
1907            flags: 0,
1908            num_questions: 0,
1909            num_answers: 0,
1910            num_authorities: 0,
1911            num_additionals: 0,
1912            interface_id,
1913        };
1914
1915        /*
1916        RFC 1035 section 4.1: https://datatracker.ietf.org/doc/html/rfc1035#section-4.1
1917        ...
1918        All communications inside of the domain protocol are carried in a single
1919        format called a message.  The top level format of message is divided
1920        into 5 sections (some of which are empty in certain cases) shown below:
1921
1922            +---------------------+
1923            |        Header       |
1924            +---------------------+
1925            |       Question      | the question for the name server
1926            +---------------------+
1927            |        Answer       | RRs answering the question
1928            +---------------------+
1929            |      Authority      | RRs pointing toward an authority
1930            +---------------------+
1931            |      Additional     | RRs holding additional information
1932            +---------------------+
1933         */
1934        incoming.read_header()?;
1935        incoming.read_questions()?;
1936        incoming.read_answers()?;
1937        incoming.read_authorities()?;
1938        incoming.read_additional()?;
1939
1940        Ok(incoming)
1941    }
1942
1943    pub fn id(&self) -> u16 {
1944        self.id
1945    }
1946
1947    pub fn questions(&self) -> &[DnsQuestion] {
1948        &self.questions
1949    }
1950
1951    pub fn answers(&self) -> &[DnsRecordBox] {
1952        &self.answers
1953    }
1954
1955    pub fn authorities(&self) -> &[DnsRecordBox] {
1956        &self.authorities
1957    }
1958
1959    pub fn additionals(&self) -> &[DnsRecordBox] {
1960        &self.additional
1961    }
1962
1963    pub fn answers_mut(&mut self) -> &mut Vec<DnsRecordBox> {
1964        &mut self.answers
1965    }
1966
1967    pub fn authorities_mut(&mut self) -> &mut Vec<DnsRecordBox> {
1968        &mut self.authorities
1969    }
1970
1971    pub fn additionals_mut(&mut self) -> &mut Vec<DnsRecordBox> {
1972        &mut self.additional
1973    }
1974
1975    pub fn all_records(self) -> impl Iterator<Item = DnsRecordBox> {
1976        self.answers
1977            .into_iter()
1978            .chain(self.authorities)
1979            .chain(self.additional)
1980    }
1981
1982    pub fn num_additionals(&self) -> u16 {
1983        self.num_additionals
1984    }
1985
1986    pub fn num_authorities(&self) -> u16 {
1987        self.num_authorities
1988    }
1989
1990    pub fn num_questions(&self) -> u16 {
1991        self.num_questions
1992    }
1993
1994    pub const fn is_query(&self) -> bool {
1995        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_QUERY
1996    }
1997
1998    pub const fn is_response(&self) -> bool {
1999        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_RESPONSE
2000    }
2001
2002    fn read_header(&mut self) -> Result<()> {
2003        if self.data.len() < MSG_HEADER_LEN {
2004            return Err(e_fmt!(
2005                "DNS incoming: header is too short: {} bytes",
2006                self.data.len()
2007            ));
2008        }
2009
2010        let data = &self.data[0..];
2011        self.id = u16_from_be_slice(&data[..2]);
2012        self.flags = u16_from_be_slice(&data[2..4]);
2013        self.num_questions = u16_from_be_slice(&data[4..6]);
2014        self.num_answers = u16_from_be_slice(&data[6..8]);
2015        self.num_authorities = u16_from_be_slice(&data[8..10]);
2016        self.num_additionals = u16_from_be_slice(&data[10..12]);
2017
2018        self.offset = MSG_HEADER_LEN;
2019
2020        trace!(
2021            "read_header: id {}, {} questions {} answers {} authorities {} additionals",
2022            self.id,
2023            self.num_questions,
2024            self.num_answers,
2025            self.num_authorities,
2026            self.num_additionals
2027        );
2028        Ok(())
2029    }
2030
2031    fn read_questions(&mut self) -> Result<()> {
2032        trace!("read_questions: {}", &self.num_questions);
2033        for i in 0..self.num_questions {
2034            let name = self.read_name()?;
2035
2036            let data = &self.data[self.offset..];
2037            if data.len() < 4 {
2038                return Err(Error::Msg(format!(
2039                    "DNS incoming: question idx {} too short: {}",
2040                    i,
2041                    data.len()
2042                )));
2043            }
2044            let ty = u16_from_be_slice(&data[..2]);
2045            let class = u16_from_be_slice(&data[2..4]);
2046            self.offset += 4;
2047
2048            let Some(rr_type) = RRType::from_u16(ty) else {
2049                return Err(Error::Msg(format!(
2050                    "DNS incoming: question idx {i} qtype unknown: {ty}",
2051                )));
2052            };
2053
2054            self.questions.push(DnsQuestion {
2055                entry: DnsEntry::new(name, rr_type, class),
2056            });
2057        }
2058        Ok(())
2059    }
2060
2061    fn read_answers(&mut self) -> Result<()> {
2062        self.answers = self.read_rr_records(self.num_answers)?;
2063        Ok(())
2064    }
2065
2066    fn read_authorities(&mut self) -> Result<()> {
2067        self.authorities = self.read_rr_records(self.num_authorities)?;
2068        Ok(())
2069    }
2070
2071    fn read_additional(&mut self) -> Result<()> {
2072        self.additional = self.read_rr_records(self.num_additionals)?;
2073        Ok(())
2074    }
2075
2076    /// Decodes a sequence of RR records (in answers, authorities and additionals).
2077    fn read_rr_records(&mut self, count: u16) -> Result<Vec<DnsRecordBox>> {
2078        trace!("read_rr_records: {}", count);
2079        let mut rr_records = Vec::new();
2080
2081        // RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035#section-3.2.1
2082        //
2083        // All RRs have the same top level format shown below:
2084        //                               1  1  1  1  1  1
2085        // 0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
2086        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2087        // |                                               |
2088        // /                                               /
2089        // /                      NAME                     /
2090        // |                                               |
2091        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2092        // |                      TYPE                     |
2093        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2094        // |                     CLASS                     |
2095        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2096        // |                      TTL                      |
2097        // |                                               |
2098        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2099        // |                   RDLENGTH                    |
2100        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
2101        // /                     RDATA                     /
2102        // /                                               /
2103        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2104
2105        // Muse have at least TYPE, CLASS, TTL, RDLENGTH fields: 10 bytes.
2106        const RR_HEADER_REMAIN: usize = 10;
2107
2108        for _ in 0..count {
2109            let name = self.read_name()?;
2110            let slice = &self.data[self.offset..];
2111
2112            if slice.len() < RR_HEADER_REMAIN {
2113                return Err(Error::Msg(format!(
2114                    "read_others: RR '{}' is too short after name: {} bytes",
2115                    &name,
2116                    slice.len()
2117                )));
2118            }
2119
2120            let ty = u16_from_be_slice(&slice[..2]);
2121            let class = u16_from_be_slice(&slice[2..4]);
2122            let mut ttl = u32_from_be_slice(&slice[4..8]);
2123            if ttl == 0 && self.is_response() {
2124                // RFC 6762 section 10.1:
2125                // "...Queriers receiving a Multicast DNS response with a TTL of zero SHOULD
2126                // NOT immediately delete the record from the cache, but instead record
2127                // a TTL of 1 and then delete the record one second later."
2128                // See https://datatracker.ietf.org/doc/html/rfc6762#section-10.1
2129
2130                ttl = 1;
2131            }
2132            let rdata_len = u16_from_be_slice(&slice[8..10]) as usize;
2133            self.offset += RR_HEADER_REMAIN;
2134            let next_offset = self.offset + rdata_len;
2135
2136            // Sanity check for RDATA length.
2137            if next_offset > self.data.len() {
2138                return Err(Error::Msg(format!(
2139                    "RR {name} RDATA length {rdata_len} is invalid: remain data len: {}",
2140                    self.data.len() - self.offset
2141                )));
2142            }
2143
2144            // decode RDATA based on the record type.
2145            let rec: Option<DnsRecordBox> = match RRType::from_u16(ty) {
2146                None => None,
2147
2148                Some(rr_type) => match rr_type {
2149                    RRType::CNAME | RRType::PTR => {
2150                        Some(DnsPointer::new(&name, rr_type, class, ttl, self.read_name()?).boxed())
2151                    }
2152                    RRType::TXT => {
2153                        Some(DnsTxt::new(&name, class, ttl, self.read_vec(rdata_len)?).boxed())
2154                    }
2155                    RRType::SRV => Some(
2156                        DnsSrv::new(
2157                            &name,
2158                            class,
2159                            ttl,
2160                            self.read_u16()?,
2161                            self.read_u16()?,
2162                            self.read_u16()?,
2163                            self.read_name()?,
2164                        )
2165                        .boxed(),
2166                    ),
2167                    RRType::HINFO => Some(
2168                        DnsHostInfo::new(
2169                            &name,
2170                            rr_type,
2171                            class,
2172                            ttl,
2173                            self.read_char_string()?,
2174                            self.read_char_string()?,
2175                        )
2176                        .boxed(),
2177                    ),
2178                    RRType::A => Some(
2179                        DnsAddress::new(
2180                            &name,
2181                            rr_type,
2182                            class,
2183                            ttl,
2184                            self.read_ipv4()?.into(),
2185                            self.interface_id.clone(),
2186                        )
2187                        .boxed(),
2188                    ),
2189                    RRType::AAAA => Some(
2190                        DnsAddress::new(
2191                            &name,
2192                            rr_type,
2193                            class,
2194                            ttl,
2195                            self.read_ipv6()?.into(),
2196                            self.interface_id.clone(),
2197                        )
2198                        .boxed(),
2199                    ),
2200                    RRType::NSEC => Some(
2201                        DnsNSec::new(
2202                            &name,
2203                            class,
2204                            ttl,
2205                            self.read_name()?,
2206                            self.read_type_bitmap()?,
2207                        )
2208                        .boxed(),
2209                    ),
2210                    _ => None,
2211                },
2212            };
2213
2214            if let Some(record) = rec {
2215                trace!("read_rr_records: {:?}", &record);
2216                rr_records.push(record);
2217            } else {
2218                trace!("Unsupported DNS record type: {} name: {}", ty, &name);
2219                self.offset += rdata_len;
2220            }
2221
2222            // sanity check.
2223            if self.offset != next_offset {
2224                return Err(Error::Msg(format!(
2225                    "read_rr_records: decode offset error for RData type {} offset: {} expected offset: {}",
2226                    ty, self.offset, next_offset,
2227                )));
2228            }
2229        }
2230
2231        Ok(rr_records)
2232    }
2233
2234    fn read_char_string(&mut self) -> Result<String> {
2235        let length = self.data[self.offset];
2236        self.offset += 1;
2237        self.read_string(length as usize)
2238    }
2239
2240    fn read_u16(&mut self) -> Result<u16> {
2241        let slice = &self.data[self.offset..];
2242        if slice.len() < U16_SIZE {
2243            return Err(Error::Msg(format!(
2244                "read_u16: slice len is only {}",
2245                slice.len()
2246            )));
2247        }
2248        let num = u16_from_be_slice(&slice[..U16_SIZE]);
2249        self.offset += U16_SIZE;
2250        Ok(num)
2251    }
2252
2253    /// Reads the "Type Bit Map" block for a DNS NSEC record.
2254    fn read_type_bitmap(&mut self) -> Result<Vec<u8>> {
2255        // From RFC 6762: 6.1.  Negative Responses
2256        // https://datatracker.ietf.org/doc/html/rfc6762#section-6.1
2257        //   o The Type Bit Map block number is 0.
2258        //   o The Type Bit Map block length byte is a value in the range 1-32.
2259        //   o The Type Bit Map data is 1-32 bytes, as indicated by length
2260        //     byte.
2261
2262        // Sanity check: at least 2 bytes to read.
2263        if self.data.len() < self.offset + 2 {
2264            return Err(Error::Msg(format!(
2265                "DnsIncoming is too short: {} at NSEC Type Bit Map offset {}",
2266                self.data.len(),
2267                self.offset
2268            )));
2269        }
2270
2271        let block_num = self.data[self.offset];
2272        self.offset += 1;
2273        if block_num != 0 {
2274            return Err(Error::Msg(format!(
2275                "NSEC block number is not 0: {block_num}"
2276            )));
2277        }
2278
2279        let block_len = self.data[self.offset] as usize;
2280        if !(1..=32).contains(&block_len) {
2281            return Err(Error::Msg(format!(
2282                "NSEC block length must be in the range 1-32: {block_len}"
2283            )));
2284        }
2285        self.offset += 1;
2286
2287        let end = self.offset + block_len;
2288        if end > self.data.len() {
2289            return Err(Error::Msg(format!(
2290                "NSEC block overflow: {} over RData len {}",
2291                end,
2292                self.data.len()
2293            )));
2294        }
2295        let bitmap = self.data[self.offset..end].to_vec();
2296        self.offset += block_len;
2297
2298        Ok(bitmap)
2299    }
2300
2301    fn read_vec(&mut self, length: usize) -> Result<Vec<u8>> {
2302        if self.data.len() < self.offset + length {
2303            return Err(e_fmt!(
2304                "DNS Incoming: not enough data to read a chunk of data"
2305            ));
2306        }
2307
2308        let v = self.data[self.offset..self.offset + length].to_vec();
2309        self.offset += length;
2310        Ok(v)
2311    }
2312
2313    fn read_ipv4(&mut self) -> Result<Ipv4Addr> {
2314        if self.data.len() < self.offset + 4 {
2315            return Err(e_fmt!("DNS Incoming: not enough data to read an IPV4"));
2316        }
2317
2318        let bytes: [u8; 4] = self.data[self.offset..self.offset + 4]
2319            .try_into()
2320            .map_err(|_| e_fmt!("DNS incoming: Not enough bytes for reading an IPV4"))?;
2321        self.offset += bytes.len();
2322        Ok(Ipv4Addr::from(bytes))
2323    }
2324
2325    fn read_ipv6(&mut self) -> Result<Ipv6Addr> {
2326        if self.data.len() < self.offset + 16 {
2327            return Err(e_fmt!("DNS Incoming: not enough data to read an IPV6"));
2328        }
2329
2330        let bytes: [u8; 16] = self.data[self.offset..self.offset + 16]
2331            .try_into()
2332            .map_err(|_| e_fmt!("DNS incoming: Not enough bytes for reading an IPV6"))?;
2333        self.offset += bytes.len();
2334        Ok(Ipv6Addr::from(bytes))
2335    }
2336
2337    fn read_string(&mut self, length: usize) -> Result<String> {
2338        if self.data.len() < self.offset + length {
2339            return Err(e_fmt!("DNS Incoming: not enough data to read a string"));
2340        }
2341
2342        let s = str::from_utf8(&self.data[self.offset..self.offset + length])
2343            .map_err(|e| Error::Msg(e.to_string()))?;
2344        self.offset += length;
2345        Ok(s.to_string())
2346    }
2347
2348    /// Reads a domain name at the current location of `self.data`.
2349    ///
2350    /// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.1 for
2351    /// domain name encoding.
2352    fn read_name(&mut self) -> Result<String> {
2353        let data = &self.data[..];
2354        let start_offset = self.offset;
2355        let mut offset = start_offset;
2356        let mut name = "".to_string();
2357        let mut at_end = false;
2358
2359        // From RFC1035:
2360        // "...Domain names in messages are expressed in terms of a sequence of labels.
2361        // Each label is represented as a one octet length field followed by that
2362        // number of octets."
2363        //
2364        // "...The compression scheme allows a domain name in a message to be
2365        // represented as either:
2366        // - a sequence of labels ending in a zero octet
2367        // - a pointer
2368        // - a sequence of labels ending with a pointer"
2369        loop {
2370            if offset >= data.len() {
2371                return Err(Error::Msg(format!(
2372                    "read_name: offset: {} data len {}. DnsIncoming: {:?}",
2373                    offset,
2374                    data.len(),
2375                    self
2376                )));
2377            }
2378            let length = data[offset];
2379
2380            // From RFC1035:
2381            // "...Since every domain name ends with the null label of
2382            // the root, a domain name is terminated by a length byte of zero."
2383            if length == 0 {
2384                if !at_end {
2385                    self.offset = offset + 1;
2386                }
2387                break; // The end of the name
2388            }
2389
2390            // Check the first 2 bits for possible "Message compression".
2391            match length & 0xC0 {
2392                0x00 => {
2393                    // regular utf8 string with length
2394                    offset += 1;
2395                    let ending = offset + length as usize;
2396
2397                    // Never read beyond the whole data length.
2398                    if ending > data.len() {
2399                        return Err(Error::Msg(format!(
2400                            "read_name: ending {} exceeds data length {}",
2401                            ending,
2402                            data.len()
2403                        )));
2404                    }
2405
2406                    name += str::from_utf8(&data[offset..ending])
2407                        .map_err(|e| Error::Msg(format!("read_name: from_utf8: {e}")))?;
2408                    name += ".";
2409                    offset += length as usize;
2410                }
2411                0xC0 => {
2412                    // Message compression.
2413                    // See https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.4
2414                    let slice = &data[offset..];
2415                    if slice.len() < U16_SIZE {
2416                        return Err(Error::Msg(format!(
2417                            "read_name: u16 slice len is only {}",
2418                            slice.len()
2419                        )));
2420                    }
2421                    let pointer = (u16_from_be_slice(slice) ^ 0xC000) as usize;
2422                    if pointer >= start_offset {
2423                        // Error: could trigger an infinite loop.
2424                        return Err(Error::Msg(format!(
2425                            "Invalid name compression: pointer {} must be less than the start offset {}",
2426                            &pointer, &start_offset
2427                        )));
2428                    }
2429
2430                    // A pointer marks the end of a domain name.
2431                    if !at_end {
2432                        self.offset = offset + U16_SIZE;
2433                        at_end = true;
2434                    }
2435                    offset = pointer;
2436                }
2437                _ => {
2438                    return Err(Error::Msg(format!(
2439                        "Bad name with invalid length: 0x{:x} offset {}, data (so far): {:x?}",
2440                        length,
2441                        offset,
2442                        &data[..offset]
2443                    )));
2444                }
2445            };
2446        }
2447
2448        Ok(name)
2449    }
2450}
2451
2452/// Returns UNIX time in millis
2453fn current_time_millis() -> u64 {
2454    SystemTime::now()
2455        .duration_since(SystemTime::UNIX_EPOCH)
2456        .expect("failed to get current UNIX time")
2457        .as_millis() as u64
2458}
2459
2460const fn u16_from_be_slice(bytes: &[u8]) -> u16 {
2461    let u8_array: [u8; 2] = [bytes[0], bytes[1]];
2462    u16::from_be_bytes(u8_array)
2463}
2464
2465const fn u32_from_be_slice(s: &[u8]) -> u32 {
2466    let u8_array: [u8; 4] = [s[0], s[1], s[2], s[3]];
2467    u32::from_be_bytes(u8_array)
2468}
2469
2470/// Returns the UNIX time in millis at which this record will have expired
2471/// by a certain percentage.
2472const fn get_expiration_time(created: u64, ttl: u32, percent: u32) -> u64 {
2473    // 'created' is in millis, 'ttl' is in seconds, hence:
2474    // ttl * 1000 * (percent / 100) => ttl * percent * 10
2475    created + (ttl * percent * 10) as u64
2476}