Skip to main content

mdns_sd/
dns_parser.rs

//! DNS parsing utility.
//!
//! [DnsIncoming] is the logical representation of an incoming DNS packet.
//! [DnsOutgoing] is the logical representation of an outgoing DNS message of one or more packets.
//! [DnsOutPacket] is a single encoded packet of a [DnsOutgoing] message.
6
7#[cfg(feature = "logging")]
8use crate::log::{debug, trace};
9
10use crate::current_time_millis;
11use crate::error::{e_fmt, Error, Result};
12use crate::service_info::{is_unicast_link_local, DnsRegistry, MyIntf, ServiceInfo};
13
14use if_addrs::Interface;
15
16#[cfg(feature = "serde")]
17use serde::{Deserialize, Serialize};
18
19use std::{
20    any::Any,
21    cmp,
22    collections::HashMap,
23    convert::TryInto,
24    fmt,
25    hash::Hash,
26    net::{IpAddr, Ipv4Addr, Ipv6Addr},
27    str,
28};
29
/// A network interface identifier as reported by the operating system.
///
/// `InterfaceId::default()` has an empty `name` and an `index` of `0`.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct InterfaceId {
    /// OS-level interface name, such as `"en0"` or `"wlan0"`.
    pub name: String,

    /// OS-assigned interface index, such as `1` or `2`.
    pub index: u32,
}
40
41impl InterfaceId {
42    /// Returns all IP addresses associated with this interface by querying the OS.
43    pub fn get_addrs(&self) -> Vec<IpAddr> {
44        if_addrs::get_if_addrs()
45            .unwrap_or_default()
46            .into_iter()
47            .filter(|iface| iface.index == Some(self.index))
48            .map(|iface| iface.ip())
49            .collect()
50    }
51}
52
53impl fmt::Display for InterfaceId {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        write!(f, "{}('{}')", self.index, self.name)
56    }
57}
58
59impl From<&Interface> for InterfaceId {
60    fn from(interface: &Interface) -> Self {
61        InterfaceId {
62            name: interface.name.clone(),
63            index: interface.index.unwrap_or_default(),
64        }
65    }
66}
67
68/// An IPv4 address with interface identifiers indicating which interfaces discovered it.
69#[derive(Debug, Clone, Eq, PartialEq, Hash)]
70#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
71pub struct ScopedIpV4 {
72    addr: Ipv4Addr,
73    /// The interfaces this address was discovered on.
74    interface_ids: Vec<InterfaceId>,
75}
76
77impl ScopedIpV4 {
78    /// Creates a new `ScopedIpV4` with a single interface identifier.
79    pub fn new(addr: Ipv4Addr, interface_id: InterfaceId) -> Self {
80        Self {
81            addr,
82            interface_ids: vec![interface_id],
83        }
84    }
85
86    /// Returns the IPv4 address.
87    pub const fn addr(&self) -> &Ipv4Addr {
88        &self.addr
89    }
90
91    /// Returns the interfaces this address was discovered on.
92    pub fn interface_ids(&self) -> &[InterfaceId] {
93        &self.interface_ids
94    }
95
96    /// Adds an interface identifier if not already present.
97    pub(crate) fn add_interface_id(&mut self, id: InterfaceId) {
98        if !self.interface_ids.contains(&id) {
99            self.interface_ids.push(id);
100        }
101    }
102}
103
104/// An IPv6 address with scope_id (interface identifier).
105#[derive(Debug, Clone, Eq, PartialEq, Hash)]
106#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
107pub struct ScopedIpV6 {
108    addr: Ipv6Addr,
109    scope_id: InterfaceId,
110}
111
112impl ScopedIpV6 {
113    /// Returns the IPv6 address.
114    pub const fn addr(&self) -> &Ipv6Addr {
115        &self.addr
116    }
117
118    /// Returns the scope_id for this IPv6 address.
119    pub const fn scope_id(&self) -> &InterfaceId {
120        &self.scope_id
121    }
122}
123
124/// An IP address, either IPv4 or IPv6, that supports scope_id for IPv6.
125#[derive(Debug, Clone, Eq, PartialEq, Hash)]
126#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
127#[non_exhaustive]
128pub enum ScopedIp {
129    V4(ScopedIpV4),
130    V6(ScopedIpV6),
131}
132
133impl ScopedIp {
134    pub const fn to_ip_addr(&self) -> IpAddr {
135        match self {
136            ScopedIp::V4(v4) => IpAddr::V4(v4.addr),
137            ScopedIp::V6(v6) => IpAddr::V6(v6.addr),
138        }
139    }
140
141    pub const fn is_ipv4(&self) -> bool {
142        matches!(self, ScopedIp::V4(_))
143    }
144
145    pub const fn is_ipv6(&self) -> bool {
146        matches!(self, ScopedIp::V6(_))
147    }
148
149    pub const fn is_loopback(&self) -> bool {
150        match self {
151            ScopedIp::V4(v4) => v4.addr.is_loopback(),
152            ScopedIp::V6(v6) => v6.addr.is_loopback(),
153        }
154    }
155}
156
157impl From<IpAddr> for ScopedIp {
158    fn from(ip: IpAddr) -> Self {
159        match ip {
160            IpAddr::V4(v4) => ScopedIp::V4(ScopedIpV4 {
161                addr: v4,
162                interface_ids: vec![],
163            }),
164            IpAddr::V6(v6) => ScopedIp::V6(ScopedIpV6 {
165                addr: v6,
166                scope_id: InterfaceId::default(),
167            }),
168        }
169    }
170}
171
172impl From<&Interface> for ScopedIp {
173    fn from(interface: &Interface) -> Self {
174        match interface.ip() {
175            IpAddr::V4(v4) => ScopedIp::V4(ScopedIpV4 {
176                addr: v4,
177                interface_ids: vec![InterfaceId::from(interface)],
178            }),
179            IpAddr::V6(v6) => ScopedIp::V6(ScopedIpV6 {
180                addr: v6,
181                scope_id: InterfaceId::from(interface),
182            }),
183        }
184    }
185}
186
187impl fmt::Display for ScopedIp {
188    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189        match self {
190            ScopedIp::V4(v4) => write!(f, "{}", v4.addr),
191            ScopedIp::V6(v6) => {
192                if v6.scope_id.index != 0 && is_unicast_link_local(&v6.addr) {
193                    #[cfg(windows)]
194                    {
195                        write!(f, "{}%{}", v6.addr, v6.scope_id.index)
196                    }
197                    #[cfg(not(windows))]
198                    {
199                        write!(f, "{}%{}", v6.addr, v6.scope_id.name)
200                    }
201                } else {
202                    write!(f, "{}", v6.addr)
203                }
204            }
205        }
206    }
207}
208
/// DNS resource record types, stored as `u16`; cast with `as u16` when needed.
///
/// See [RFC 1035 section 3.2.2](https://datatracker.ietf.org/doc/html/rfc1035#section-3.2.2)
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
#[repr(u16)]
pub enum RRType {
    /// IPv4 address record.
    A = 1,

    /// Canonical name record.
    CNAME = 5,

    /// Pointer record.
    PTR = 12,

    /// Host information record.
    HINFO = 13,

    /// Text record (properties).
    TXT = 16,

    /// IPv6 address record.
    AAAA = 28,

    /// Service record.
    SRV = 33,

    /// Negative-response record.
    NSEC = 47,

    /// Wildcard matching any record type.
    ANY = 255,
}

impl RRType {
    /// Maps a wire-format type code to an `RRType`.
    ///
    /// Returns `None` for codes this enum does not model.
    pub const fn from_u16(value: u16) -> Option<Self> {
        let ty = match value {
            1 => Self::A,
            5 => Self::CNAME,
            12 => Self::PTR,
            13 => Self::HINFO,
            16 => Self::TXT,
            28 => Self::AAAA,
            33 => Self::SRV,
            47 => Self::NSEC,
            255 => Self::ANY,
            _ => return None,
        };
        Some(ty)
    }
}

impl fmt::Display for RRType {
    /// Renders the type as `TYPE_<NAME>`, e.g. `TYPE_AAAA`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::A => "TYPE_A",
            Self::CNAME => "TYPE_CNAME",
            Self::PTR => "TYPE_PTR",
            Self::HINFO => "TYPE_HINFO",
            Self::TXT => "TYPE_TXT",
            Self::AAAA => "TYPE_AAAA",
            Self::SRV => "TYPE_SRV",
            Self::NSEC => "TYPE_NSEC",
            Self::ANY => "TYPE_ANY",
        };
        f.write_str(label)
    }
}
277
/// The class value for the Internet.
pub const CLASS_IN: u16 = 1;

/// Mask that clears the cache-flush bit of an rrclass, leaving only the class value.
pub const CLASS_MASK: u16 = !CLASS_CACHE_FLUSH;

/// Cache-flush bit: the most significant bit of the rrclass field of a resource record.
///
/// See [RFC 6762 section 10.2](https://datatracker.ietf.org/doc/html/rfc6762#section-10.2).
pub const CLASS_CACHE_FLUSH: u16 = 0x8000;
284
/// Max size of a UDP datagram payload.
///
/// Computed as the 9000-byte limit minus the 20-byte IP header and the
/// 8-byte UDP header.
/// Reference: [RFC 6762 section 17](https://datatracker.ietf.org/doc/html/rfc6762#section-17)
pub const MAX_MSG_ABSOLUTE: usize = 9000 - 20 - 8;

/// Length in bytes of the fixed DNS message header (RFC 1035 section 4.1.1).
const MSG_HEADER_LEN: usize = 12;
292
// Definitions for the DNS message header "flags" field.
//
// The "flags" field is 16 bits long, laid out as follows
// (RFC 1035 section 4.1.1), with bit 0 being the most significant bit:
//
//   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
//
// A single-bit flag at diagram position `p` therefore has the value `1 << (15 - p)`.

/// Mask for the QR (query/response) bit.
pub const FLAGS_QR_MASK: u16 = 1 << 15;

/// Value of the QR bit that marks a query.
pub const FLAGS_QR_QUERY: u16 = 0x0000;

/// Value of the QR bit that marks a response.
pub const FLAGS_QR_RESPONSE: u16 = 1 << 15;

/// Flag bit for Authoritative Answer (AA).
pub const FLAGS_AA: u16 = 1 << 10;

/// Mask for the TC (Truncated) bit.
///
/// 2024-08-10: currently this flag is only supported on the querier side,
///             not on the responder side. The responder only handles the
///             first packet and ignores this bit. Because the additional
///             packets have 0 questions, processing them is a no-op.
///             In practice, this means the responder supports Known-Answer
///             only with a single packet, not multi-packet. The querier
///             supports both single packet and multi-packet.
pub const FLAGS_TC: u16 = 1 << 9;
323
324/// A convenience type alias for DNS record trait objects.
325pub type DnsRecordBox = Box<dyn DnsRecordExt>;
326
327impl Clone for DnsRecordBox {
328    fn clone(&self) -> Self {
329        self.clone_box()
330    }
331}
332
/// Size in bytes of a `u16` field on the wire.
const U16_SIZE: usize = std::mem::size_of::<u16>();
334
335/// Returns `RRType` for a given IP address.
336#[inline]
337pub const fn ip_address_rr_type(address: &IpAddr) -> RRType {
338    match address {
339        IpAddr::V4(_) => RRType::A,
340        IpAddr::V6(_) => RRType::AAAA,
341    }
342}
343
344#[derive(Eq, PartialEq, Debug, Clone)]
345pub struct DnsEntry {
346    pub(crate) name: String, // always lower case.
347    pub(crate) ty: RRType,
348    class: u16,
349    cache_flush: bool,
350}
351
352impl DnsEntry {
353    const fn new(name: String, ty: RRType, class: u16) -> Self {
354        Self {
355            name,
356            ty,
357            class: class & CLASS_MASK,
358            cache_flush: (class & CLASS_CACHE_FLUSH) != 0,
359        }
360    }
361}
362
363/// Common methods for all DNS entries:  questions and resource records.
364pub trait DnsEntryExt: fmt::Debug {
365    fn entry_name(&self) -> &str;
366
367    fn entry_type(&self) -> RRType;
368}
369
370/// A DNS question entry
371#[derive(Debug)]
372pub struct DnsQuestion {
373    pub(crate) entry: DnsEntry,
374}
375
376impl DnsEntryExt for DnsQuestion {
377    fn entry_name(&self) -> &str {
378        &self.entry.name
379    }
380
381    fn entry_type(&self) -> RRType {
382        self.entry.ty
383    }
384}
385
386/// A DNS Resource Record - like a DNS entry, but has a TTL.
387/// RFC: https://www.rfc-editor.org/rfc/rfc1035#section-3.2.1
388///      https://www.rfc-editor.org/rfc/rfc1035#section-4.1.3
389#[derive(Debug, Clone)]
390pub struct DnsRecord {
391    pub(crate) entry: DnsEntry,
392    ttl: u32,     // in seconds, 0 means this record should not be cached
393    created: u64, // UNIX time in millis
394    expires: u64, // expires at this UNIX time in millis
395
396    /// Support re-query an instance before its PTR record expires.
397    /// See https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
398    refresh: u64, // UNIX time in millis
399
400    /// If conflict resolution decides to change the name, this is the new one.
401    new_name: Option<String>,
402}
403
404impl DnsRecord {
405    fn new(name: &str, ty: RRType, class: u16, ttl: u32) -> Self {
406        let created = current_time_millis();
407
408        // From RFC 6762 section 5.2:
409        // "... The querier should plan to issue a query at 80% of the record
410        // lifetime, and then if no answer is received, at 85%, 90%, and 95%."
411        let refresh = get_expiration_time(created, ttl, 80);
412
413        let expires = get_expiration_time(created, ttl, 100);
414
415        Self {
416            entry: DnsEntry::new(name.to_string(), ty, class),
417            ttl,
418            created,
419            expires,
420            refresh,
421            new_name: None,
422        }
423    }
424
425    pub const fn get_ttl(&self) -> u32 {
426        self.ttl
427    }
428
429    pub const fn get_expire_time(&self) -> u64 {
430        self.expires
431    }
432
433    pub const fn get_refresh_time(&self) -> u64 {
434        self.refresh
435    }
436
437    pub const fn is_expired(&self, now: u64) -> bool {
438        now >= self.expires
439    }
440
441    /// Returns whether record expires in 1 second.
442    ///
443    /// This is useful because mDNS sets TTL to 1 (not 0) for expiring records.
444    pub const fn expires_soon(&self, now: u64) -> bool {
445        now + 1000 >= self.expires
446    }
447
448    pub const fn refresh_due(&self, now: u64) -> bool {
449        now >= self.refresh
450    }
451
452    /// Returns whether `now` (in millis) has passed half of TTL.
453    pub fn halflife_passed(&self, now: u64) -> bool {
454        let halflife = get_expiration_time(self.created, self.ttl, 50);
455        now > halflife
456    }
457
458    pub fn is_unique(&self) -> bool {
459        self.entry.cache_flush
460    }
461
462    /// Updates the refresh time to be the same as the expire time so that
463    /// this record will not refresh again and will just expire.
464    pub fn refresh_no_more(&mut self) {
465        self.refresh = get_expiration_time(self.created, self.ttl, 100);
466    }
467
468    /// Returns if this record is due for refresh. If yes, `refresh` time is updated.
469    pub fn refresh_maybe(&mut self, now: u64) -> bool {
470        if self.is_expired(now) || !self.refresh_due(now) {
471            return false;
472        }
473
474        trace!(
475            "{} qtype {} is due to refresh",
476            &self.entry.name,
477            self.entry.ty
478        );
479
480        // From RFC 6762 section 5.2:
481        // "... The querier should plan to issue a query at 80% of the record
482        // lifetime, and then if no answer is received, at 85%, 90%, and 95%."
483        //
484        // If the answer is received in time, 'refresh' will be reset outside
485        // this function, back to 80% of the new TTL.
486        if self.refresh == get_expiration_time(self.created, self.ttl, 80) {
487            self.refresh = get_expiration_time(self.created, self.ttl, 85);
488        } else if self.refresh == get_expiration_time(self.created, self.ttl, 85) {
489            self.refresh = get_expiration_time(self.created, self.ttl, 90);
490        } else if self.refresh == get_expiration_time(self.created, self.ttl, 90) {
491            self.refresh = get_expiration_time(self.created, self.ttl, 95);
492        } else {
493            self.refresh_no_more();
494        }
495
496        true
497    }
498
499    /// Returns the remaining TTL in seconds
500    fn get_remaining_ttl(&self, now: u64) -> u32 {
501        let remaining_millis = get_expiration_time(self.created, self.ttl, 100) - now;
502        cmp::max(0, remaining_millis / 1000) as u32
503    }
504
505    /// Return the absolute time for this record being created
506    pub const fn get_created(&self) -> u64 {
507        self.created
508    }
509
510    /// Set the absolute expiration time in millis
511    fn set_expire(&mut self, expire_at: u64) {
512        self.expires = expire_at;
513    }
514
515    fn reset_ttl(&mut self, other: &Self) {
516        self.ttl = other.ttl;
517        self.created = other.created;
518        self.expires = get_expiration_time(self.created, self.ttl, 100);
519        self.refresh = if self.ttl > 1 {
520            get_expiration_time(self.created, self.ttl, 80)
521        } else {
522            // If TTL is 1, it means this record is expiring,
523            // then we set refresh to the same time as expires.
524            self.expires
525        };
526    }
527
528    /// Modify TTL to reflect the remaining life time from `now`.
529    pub fn update_ttl(&mut self, now: u64) {
530        if now > self.created {
531            let elapsed = now - self.created;
532            self.ttl -= (elapsed / 1000) as u32;
533        }
534    }
535
536    pub fn set_new_name(&mut self, new_name: String) {
537        if new_name == self.entry.name {
538            self.new_name = None;
539        } else {
540            self.new_name = Some(new_name);
541        }
542    }
543
544    pub fn get_new_name(&self) -> Option<&str> {
545        self.new_name.as_deref()
546    }
547
548    /// Return the new name if exists, otherwise the regular name in DnsEntry.
549    pub(crate) fn get_name(&self) -> &str {
550        self.new_name.as_deref().unwrap_or(&self.entry.name)
551    }
552
553    pub fn get_original_name(&self) -> &str {
554        &self.entry.name
555    }
556}
557
558impl PartialEq for DnsRecord {
559    fn eq(&self, other: &Self) -> bool {
560        self.entry == other.entry
561    }
562}
563
564/// Common methods for DNS resource records.
565pub trait DnsRecordExt: fmt::Debug {
566    fn get_record(&self) -> &DnsRecord;
567    fn get_record_mut(&mut self) -> &mut DnsRecord;
568    fn write(&self, packet: &mut DnsOutPacket);
569    fn any(&self) -> &dyn Any;
570
571    /// Returns whether `other` record is considered the same except TTL.
572    fn matches(&self, other: &dyn DnsRecordExt) -> bool;
573
574    /// Returns whether `other` record has the same rdata.
575    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool;
576
577    /// Returns the result based on a byte-level comparison of `rdata`.
578    /// If `other` is not valid, returns `Greater`.
579    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering;
580
581    /// Returns the result based on "lexicographically later" defined below.
582    fn compare(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
583        /*
584        RFC 6762: https://datatracker.ietf.org/doc/html/rfc6762#section-8.2
585
586        ... The determination of "lexicographically later" is performed by first
587        comparing the record class (excluding the cache-flush bit described
588        in Section 10.2), then the record type, then raw comparison of the
589        binary content of the rdata without regard for meaning or structure.
590        If the record classes differ, then the numerically greater class is
591        considered "lexicographically later".  Otherwise, if the record types
592        differ, then the numerically greater type is considered
593        "lexicographically later".  If the rrtype and rrclass both match,
594        then the rdata is compared. ...
595        */
596        match self.get_class().cmp(&other.get_class()) {
597            cmp::Ordering::Equal => match self.get_type().cmp(&other.get_type()) {
598                cmp::Ordering::Equal => self.compare_rdata(other),
599                not_equal => not_equal,
600            },
601            not_equal => not_equal,
602        }
603    }
604
605    /// Returns a human-readable string of rdata.
606    fn rdata_print(&self) -> String;
607
608    /// Returns the class only, excluding class_flush / unique bit.
609    fn get_class(&self) -> u16 {
610        self.get_record().entry.class
611    }
612
613    fn get_cache_flush(&self) -> bool {
614        self.get_record().entry.cache_flush
615    }
616
617    /// Return the new name if exists, otherwise the regular name in DnsEntry.
618    fn get_name(&self) -> &str {
619        self.get_record().get_name()
620    }
621
622    fn get_type(&self) -> RRType {
623        self.get_record().entry.ty
624    }
625
626    /// Resets TTL using `other` record.
627    /// `self.refresh` and `self.expires` are also reset.
628    fn reset_ttl(&mut self, other: &dyn DnsRecordExt) {
629        self.get_record_mut().reset_ttl(other.get_record());
630    }
631
632    fn get_created(&self) -> u64 {
633        self.get_record().get_created()
634    }
635
636    fn get_expire(&self) -> u64 {
637        self.get_record().get_expire_time()
638    }
639
640    fn set_expire(&mut self, expire_at: u64) {
641        self.get_record_mut().set_expire(expire_at);
642    }
643
644    /// Set expire as `expire_at` if it is sooner than the current `expire`.
645    fn set_expire_sooner(&mut self, expire_at: u64) {
646        if expire_at < self.get_expire() {
647            self.get_record_mut().set_expire(expire_at);
648        }
649    }
650
651    /// Returns true if the record expires in 1 second from `now`.
652    fn expires_soon(&self, now: u64) -> bool {
653        self.get_record().expires_soon(now)
654    }
655
656    /// Given `now`, if the record is due to refresh, this method updates the refresh time
657    /// and returns the new refresh time. Otherwise, returns None.
658    fn updated_refresh_time(&mut self, now: u64) -> Option<u64> {
659        if self.get_record_mut().refresh_maybe(now) {
660            Some(self.get_record().get_refresh_time())
661        } else {
662            None
663        }
664    }
665
666    /// Returns true if another record has matched content,
667    /// and if its TTL is at least half of this record's.
668    fn suppressed_by_answer(&self, other: &dyn DnsRecordExt) -> bool {
669        self.matches(other) && (other.get_record().ttl > self.get_record().ttl / 2)
670    }
671
672    /// Required by RFC 6762 Section 7.1: Known-Answer Suppression.
673    fn suppressed_by(&self, msg: &DnsIncoming) -> bool {
674        for answer in msg.answers.iter() {
675            if self.suppressed_by_answer(answer.as_ref()) {
676                return true;
677            }
678        }
679        false
680    }
681
682    fn clone_box(&self) -> DnsRecordBox;
683
684    fn boxed(self) -> DnsRecordBox;
685}
686
687/// Resource Record for IPv4 address or IPv6 address.
688#[derive(Debug, Clone)]
689pub(crate) struct DnsAddress {
690    pub(crate) record: DnsRecord,
691    address: IpAddr,
692    pub(crate) interface_id: InterfaceId,
693}
694
695impl DnsAddress {
696    pub fn new(
697        name: &str,
698        ty: RRType,
699        class: u16,
700        ttl: u32,
701        address: IpAddr,
702        interface_id: InterfaceId,
703    ) -> Self {
704        let record = DnsRecord::new(name, ty, class, ttl);
705        Self {
706            record,
707            address,
708            interface_id,
709        }
710    }
711
712    pub fn address(&self) -> ScopedIp {
713        match self.address {
714            IpAddr::V4(v4) => ScopedIp::V4(ScopedIpV4 {
715                addr: v4,
716                interface_ids: vec![self.interface_id.clone()],
717            }),
718            IpAddr::V6(v6) => ScopedIp::V6(ScopedIpV6 {
719                addr: v6,
720                scope_id: self.interface_id.clone(),
721            }),
722        }
723    }
724}
725
726impl DnsRecordExt for DnsAddress {
727    fn get_record(&self) -> &DnsRecord {
728        &self.record
729    }
730
731    fn get_record_mut(&mut self) -> &mut DnsRecord {
732        &mut self.record
733    }
734
735    fn write(&self, packet: &mut DnsOutPacket) {
736        match self.address {
737            IpAddr::V4(addr) => packet.write_bytes(addr.octets().as_ref()),
738            IpAddr::V6(addr) => packet.write_bytes(addr.octets().as_ref()),
739        };
740    }
741
742    fn any(&self) -> &dyn Any {
743        self
744    }
745
746    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
747        if let Some(other_a) = other.any().downcast_ref::<Self>() {
748            return self.address == other_a.address
749                && self.record.entry == other_a.record.entry
750                && self.interface_id == other_a.interface_id;
751        }
752        false
753    }
754
755    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
756        if let Some(other_a) = other.any().downcast_ref::<Self>() {
757            return self.address == other_a.address;
758        }
759        false
760    }
761
762    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
763        if let Some(other_a) = other.any().downcast_ref::<Self>() {
764            self.address.cmp(&other_a.address)
765        } else {
766            cmp::Ordering::Greater
767        }
768    }
769
770    fn rdata_print(&self) -> String {
771        format!("{}", self.address)
772    }
773
774    fn clone_box(&self) -> DnsRecordBox {
775        Box::new(self.clone())
776    }
777
778    fn boxed(self) -> DnsRecordBox {
779        Box::new(self)
780    }
781}
782
783/// Resource Record for a DNS pointer
784#[derive(Debug, Clone)]
785pub struct DnsPointer {
786    record: DnsRecord,
787    alias: String, // the full name of Service Instance
788}
789
790impl DnsPointer {
791    pub fn new(name: &str, ty: RRType, class: u16, ttl: u32, alias: String) -> Self {
792        let record = DnsRecord::new(name, ty, class, ttl);
793        Self { record, alias }
794    }
795
796    pub fn alias(&self) -> &str {
797        &self.alias
798    }
799}
800
801impl DnsRecordExt for DnsPointer {
802    fn get_record(&self) -> &DnsRecord {
803        &self.record
804    }
805
806    fn get_record_mut(&mut self) -> &mut DnsRecord {
807        &mut self.record
808    }
809
810    fn write(&self, packet: &mut DnsOutPacket) {
811        packet.write_name(&self.alias);
812    }
813
814    fn any(&self) -> &dyn Any {
815        self
816    }
817
818    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
819        if let Some(other_ptr) = other.any().downcast_ref::<Self>() {
820            return self.alias == other_ptr.alias && self.record.entry == other_ptr.record.entry;
821        }
822        false
823    }
824
825    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
826        if let Some(other_ptr) = other.any().downcast_ref::<Self>() {
827            return self.alias == other_ptr.alias;
828        }
829        false
830    }
831
832    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
833        if let Some(other_ptr) = other.any().downcast_ref::<Self>() {
834            self.alias.cmp(&other_ptr.alias)
835        } else {
836            cmp::Ordering::Greater
837        }
838    }
839
840    fn rdata_print(&self) -> String {
841        self.alias.clone()
842    }
843
844    fn clone_box(&self) -> DnsRecordBox {
845        Box::new(self.clone())
846    }
847
848    fn boxed(self) -> DnsRecordBox {
849        Box::new(self)
850    }
851}
852
853/// Resource Record for a DNS service.
854#[derive(Debug, Clone)]
855pub struct DnsSrv {
856    pub(crate) record: DnsRecord,
857    pub(crate) priority: u16, // lower number means higher priority. Should be 0 in common cases.
858    pub(crate) weight: u16,   // Should be 0 in common cases
859    host: String,
860    port: u16,
861}
862
863impl DnsSrv {
864    pub fn new(
865        name: &str,
866        class: u16,
867        ttl: u32,
868        priority: u16,
869        weight: u16,
870        port: u16,
871        host: String,
872    ) -> Self {
873        let record = DnsRecord::new(name, RRType::SRV, class, ttl);
874        Self {
875            record,
876            priority,
877            weight,
878            host,
879            port,
880        }
881    }
882
883    pub fn host(&self) -> &str {
884        &self.host
885    }
886
887    pub fn port(&self) -> u16 {
888        self.port
889    }
890
891    pub fn set_host(&mut self, host: String) {
892        self.host = host;
893    }
894}
895
896impl DnsRecordExt for DnsSrv {
897    fn get_record(&self) -> &DnsRecord {
898        &self.record
899    }
900
901    fn get_record_mut(&mut self) -> &mut DnsRecord {
902        &mut self.record
903    }
904
905    fn write(&self, packet: &mut DnsOutPacket) {
906        packet.write_short(self.priority);
907        packet.write_short(self.weight);
908        packet.write_short(self.port);
909        packet.write_name(&self.host);
910    }
911
912    fn any(&self) -> &dyn Any {
913        self
914    }
915
916    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
917        if let Some(other_svc) = other.any().downcast_ref::<Self>() {
918            return self.host == other_svc.host
919                && self.port == other_svc.port
920                && self.weight == other_svc.weight
921                && self.priority == other_svc.priority
922                && self.record.entry == other_svc.record.entry;
923        }
924        false
925    }
926
927    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
928        if let Some(other_srv) = other.any().downcast_ref::<Self>() {
929            return self.host == other_srv.host
930                && self.port == other_srv.port
931                && self.weight == other_srv.weight
932                && self.priority == other_srv.priority;
933        }
934        false
935    }
936
937    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
938        let Some(other_srv) = other.any().downcast_ref::<Self>() else {
939            return cmp::Ordering::Greater;
940        };
941
942        // 1. compare `priority`
943        match self
944            .priority
945            .to_be_bytes()
946            .cmp(&other_srv.priority.to_be_bytes())
947        {
948            cmp::Ordering::Equal => {
949                // 2. compare `weight`
950                match self
951                    .weight
952                    .to_be_bytes()
953                    .cmp(&other_srv.weight.to_be_bytes())
954                {
955                    cmp::Ordering::Equal => {
956                        // 3. compare `port`.
957                        match self.port.to_be_bytes().cmp(&other_srv.port.to_be_bytes()) {
958                            cmp::Ordering::Equal => self.host.cmp(&other_srv.host),
959                            not_equal => not_equal,
960                        }
961                    }
962                    not_equal => not_equal,
963                }
964            }
965            not_equal => not_equal,
966        }
967    }
968
969    fn rdata_print(&self) -> String {
970        format!(
971            "priority: {}, weight: {}, port: {}, host: {}",
972            self.priority, self.weight, self.port, self.host
973        )
974    }
975
976    fn clone_box(&self) -> DnsRecordBox {
977        Box::new(self.clone())
978    }
979
980    fn boxed(self) -> DnsRecordBox {
981        Box::new(self)
982    }
983}
984
985/// Resource Record for a DNS TXT record.
986///
987/// From [RFC 6763 section 6]:
988///
989/// The format of each constituent string within the DNS TXT record is a
990/// single length byte, followed by 0-255 bytes of text data.
991///
992/// DNS-SD uses DNS TXT records to store arbitrary key/value pairs
993///    conveying additional information about the named service.  Each
994///    key/value pair is encoded as its own constituent string within the
995///    DNS TXT record, in the form "key=value" (without the quotation
996///    marks).  Everything up to the first '=' character is the key (Section
997///    6.4).  Everything after the first '=' character to the end of the
998///    string (including subsequent '=' characters, if any) is the value
999#[derive(Clone)]
1000pub struct DnsTxt {
1001    pub(crate) record: DnsRecord,
1002    text: Vec<u8>,
1003}
1004
1005impl DnsTxt {
1006    pub fn new(name: &str, class: u16, ttl: u32, text: Vec<u8>) -> Self {
1007        let record = DnsRecord::new(name, RRType::TXT, class, ttl);
1008        Self { record, text }
1009    }
1010
1011    pub fn text(&self) -> &[u8] {
1012        &self.text
1013    }
1014}
1015
1016impl DnsRecordExt for DnsTxt {
1017    fn get_record(&self) -> &DnsRecord {
1018        &self.record
1019    }
1020
1021    fn get_record_mut(&mut self) -> &mut DnsRecord {
1022        &mut self.record
1023    }
1024
1025    fn write(&self, packet: &mut DnsOutPacket) {
1026        packet.write_bytes(&self.text);
1027    }
1028
1029    fn any(&self) -> &dyn Any {
1030        self
1031    }
1032
1033    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
1034        if let Some(other_txt) = other.any().downcast_ref::<Self>() {
1035            return self.text == other_txt.text && self.record.entry == other_txt.record.entry;
1036        }
1037        false
1038    }
1039
1040    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
1041        if let Some(other_txt) = other.any().downcast_ref::<Self>() {
1042            return self.text == other_txt.text;
1043        }
1044        false
1045    }
1046
1047    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
1048        if let Some(other_txt) = other.any().downcast_ref::<Self>() {
1049            self.text.cmp(&other_txt.text)
1050        } else {
1051            cmp::Ordering::Greater
1052        }
1053    }
1054
1055    fn rdata_print(&self) -> String {
1056        format!("{:?}", decode_txt(&self.text))
1057    }
1058
1059    fn clone_box(&self) -> DnsRecordBox {
1060        Box::new(self.clone())
1061    }
1062
1063    fn boxed(self) -> DnsRecordBox {
1064        Box::new(self)
1065    }
1066}
1067
1068impl fmt::Debug for DnsTxt {
1069    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1070        let properties = decode_txt(&self.text);
1071        write!(
1072            f,
1073            "DnsTxt {{ record: {:?}, text: {:?} }}",
1074            self.record, properties
1075        )
1076    }
1077}
1078
1079// Convert from DNS TXT record content to key/value pairs
1080fn decode_txt(txt: &[u8]) -> Vec<TxtProperty> {
1081    let mut properties = Vec::new();
1082    let mut offset = 0;
1083    while offset < txt.len() {
1084        let length = txt[offset] as usize;
1085        if length == 0 {
1086            break; // reached the end
1087        }
1088        offset += 1; // move over the length byte
1089
1090        let offset_end = offset + length;
1091        if offset_end > txt.len() {
1092            trace!("ERROR: DNS TXT: size given for property is out of range. (offset={}, length={}, offset_end={}, record length={})", offset, length, offset_end, txt.len());
1093            break; // Skipping the rest of the record content, as the size for this property would already be out of range.
1094        }
1095        let kv_bytes = &txt[offset..offset_end];
1096
1097        // split key and val using the first `=`
1098        let (k, v) = kv_bytes.iter().position(|&x| x == b'=').map_or_else(
1099            || (kv_bytes.to_vec(), None),
1100            |idx| (kv_bytes[..idx].to_vec(), Some(kv_bytes[idx + 1..].to_vec())),
1101        );
1102
1103        // Make sure the key can be stored in UTF-8.
1104        match String::from_utf8(k) {
1105            Ok(k_string) => {
1106                properties.push(TxtProperty {
1107                    key: k_string,
1108                    val: v,
1109                });
1110            }
1111            Err(e) => trace!("ERROR: convert to String from key: {}", e),
1112        }
1113
1114        offset += length;
1115    }
1116
1117    properties
1118}
1119
/// A single key/value property carried in a DNS TXT record.
#[derive(Clone, PartialEq, Eq)]
pub struct TxtProperty {
    /// Property name, stored with its original letter case.
    key: String,

    /// Property value. RFC 6763 allows arbitrary bytes here, not necessarily
    /// UTF-8. `None` means the key had no value, i.e. it is a boolean key.
    val: Option<Vec<u8>>,
}

impl TxtProperty {
    /// Returns the value as a string slice.
    ///
    /// Yields `""` both when there is no value and when the value is not
    /// valid UTF-8.
    pub fn val_str(&self) -> &str {
        match self.val.as_deref() {
            Some(bytes) => std::str::from_utf8(bytes).unwrap_or(""),
            None => "",
        }
    }
}

/// Supports constructing from a borrowed `(key, value)` tuple; both parts are
/// converted with `ToString`.
impl<K, V> From<&(K, V)> for TxtProperty
where
    K: ToString,
    V: ToString,
{
    fn from(prop: &(K, V)) -> Self {
        let (key, val) = prop;
        let key = key.to_string();
        let val = val.to_string().into_bytes();
        Self {
            key,
            val: Some(val),
        }
    }
}

/// Supports constructing from an owned `(key, value)` tuple whose value is any
/// byte-like type (`&str`, `String`, `Vec<u8>`, `&[u8]`, ...).
impl<K, V> From<(K, V)> for TxtProperty
where
    K: ToString,
    V: AsRef<[u8]>,
{
    fn from((key, val): (K, V)) -> Self {
        let key = key.to_string();
        let val = val.as_ref().to_vec();
        Self {
            key,
            val: Some(val),
        }
    }
}

/// Supports a property that has no value (a boolean key).
impl From<&str> for TxtProperty {
    fn from(key: &str) -> Self {
        Self {
            key: String::from(key),
            val: None,
        }
    }
}

impl fmt::Display for TxtProperty {
    /// Formats as `key=value`, rendering the value as [`TxtProperty::val_str`] does.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.val_str();
        write!(f, "{}={}", self.key, value)
    }
}

/// Mimics the default debug output for a struct, with a twist:
/// - if `val` is UTF-8, it is shown as a string in double quotes;
/// - if `val` is not UTF-8, its bytes are shown in hex.
impl fmt::Debug for TxtProperty {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let val_repr = match self.val.as_deref() {
            None => "None".to_string(),
            Some(bytes) => match std::str::from_utf8(bytes) {
                Ok(s) => format!("Some(\"{s}\")"),
                Err(_) => format!("Some({})", u8_slice_to_hex(bytes)),
            },
        };

        write!(
            f,
            "TxtProperty {{key: \"{}\", val: {}}}",
            &self.key, &val_repr,
        )
    }
}

/// Lower-case hexadecimal digits, indexed by nibble value.
const HEX_TABLE: [char; 16] = [
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
];

/// Renders `slice` as a lower-case hex string with a "0x" prefix.
///
/// For example, `[1u8, 2u8]` -> `"0x0102"`; an empty slice gives `"0x"`.
fn u8_slice_to_hex(slice: &[u8]) -> String {
    let digits = slice.iter().flat_map(|&byte| {
        [
            HEX_TABLE[usize::from(byte >> 4)],
            HEX_TABLE[usize::from(byte & 0x0F)],
        ]
    });
    let mut hex = String::with_capacity(2 + 2 * slice.len());
    hex.push_str("0x");
    hex.extend(digits);
    hex
}
1223
1224/// A DNS host information record
1225#[derive(Debug, Clone)]
1226struct DnsHostInfo {
1227    record: DnsRecord,
1228    cpu: String,
1229    os: String,
1230}
1231
1232impl DnsHostInfo {
1233    fn new(name: &str, ty: RRType, class: u16, ttl: u32, cpu: String, os: String) -> Self {
1234        let record = DnsRecord::new(name, ty, class, ttl);
1235        Self { record, cpu, os }
1236    }
1237}
1238
1239impl DnsRecordExt for DnsHostInfo {
1240    fn get_record(&self) -> &DnsRecord {
1241        &self.record
1242    }
1243
1244    fn get_record_mut(&mut self) -> &mut DnsRecord {
1245        &mut self.record
1246    }
1247
1248    fn write(&self, packet: &mut DnsOutPacket) {
1249        debug!("Writing HInfo: cpu {} os {}", &self.cpu, &self.os);
1250        packet.write_bytes(self.cpu.as_bytes());
1251        packet.write_bytes(self.os.as_bytes());
1252    }
1253
1254    fn any(&self) -> &dyn Any {
1255        self
1256    }
1257
1258    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
1259        if let Some(other_hinfo) = other.any().downcast_ref::<Self>() {
1260            return self.cpu == other_hinfo.cpu
1261                && self.os == other_hinfo.os
1262                && self.record.entry == other_hinfo.record.entry;
1263        }
1264        false
1265    }
1266
1267    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
1268        if let Some(other_hinfo) = other.any().downcast_ref::<Self>() {
1269            return self.cpu == other_hinfo.cpu && self.os == other_hinfo.os;
1270        }
1271        false
1272    }
1273
1274    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
1275        if let Some(other_hinfo) = other.any().downcast_ref::<Self>() {
1276            match self.cpu.cmp(&other_hinfo.cpu) {
1277                cmp::Ordering::Equal => self.os.cmp(&other_hinfo.os),
1278                ordering => ordering,
1279            }
1280        } else {
1281            cmp::Ordering::Greater
1282        }
1283    }
1284
1285    fn rdata_print(&self) -> String {
1286        format!("cpu: {}, os: {}", self.cpu, self.os)
1287    }
1288
1289    fn clone_box(&self) -> DnsRecordBox {
1290        Box::new(self.clone())
1291    }
1292
1293    fn boxed(self) -> DnsRecordBox {
1294        Box::new(self)
1295    }
1296}
1297
1298/// Resource Record for negative responses
1299///
1300/// [RFC4034 section 4.1](https://datatracker.ietf.org/doc/html/rfc4034#section-4.1)
1301/// and
1302/// [RFC6762 section 6.1](https://datatracker.ietf.org/doc/html/rfc6762#section-6.1)
1303#[derive(Debug, Clone)]
1304pub struct DnsNSec {
1305    record: DnsRecord,
1306    next_domain: String,
1307    type_bitmap: Vec<u8>,
1308}
1309
1310impl DnsNSec {
1311    pub fn new(
1312        name: &str,
1313        class: u16,
1314        ttl: u32,
1315        next_domain: String,
1316        type_bitmap: Vec<u8>,
1317    ) -> Self {
1318        let record = DnsRecord::new(name, RRType::NSEC, class, ttl);
1319        Self {
1320            record,
1321            next_domain,
1322            type_bitmap,
1323        }
1324    }
1325
1326    /// Returns the types marked by `type_bitmap`
1327    pub fn _types(&self) -> Vec<u16> {
1328        // From RFC 4034: 4.1.2 The Type Bit Maps Field
1329        // https://datatracker.ietf.org/doc/html/rfc4034#section-4.1.2
1330        //
1331        // Each bitmap encodes the low-order 8 bits of RR types within the
1332        // window block, in network bit order.  The first bit is bit 0.  For
1333        // window block 0, bit 1 corresponds to RR type 1 (A), bit 2 corresponds
1334        // to RR type 2 (NS), and so forth.
1335
1336        let mut bit_num = 0;
1337        let mut results = Vec::new();
1338
1339        for byte in self.type_bitmap.iter() {
1340            let mut bit_mask: u8 = 0x80; // for bit 0 in network bit order
1341
1342            // check every bit in this byte, one by one.
1343            for _ in 0..8 {
1344                if (byte & bit_mask) != 0 {
1345                    results.push(bit_num);
1346                }
1347                bit_num += 1;
1348                bit_mask >>= 1; // mask for the next bit
1349            }
1350        }
1351        results
1352    }
1353}
1354
1355impl DnsRecordExt for DnsNSec {
1356    fn get_record(&self) -> &DnsRecord {
1357        &self.record
1358    }
1359
1360    fn get_record_mut(&mut self) -> &mut DnsRecord {
1361        &mut self.record
1362    }
1363
1364    fn write(&self, packet: &mut DnsOutPacket) {
1365        packet.write_bytes(self.next_domain.as_bytes());
1366        packet.write_bytes(&self.type_bitmap);
1367    }
1368
1369    fn any(&self) -> &dyn Any {
1370        self
1371    }
1372
1373    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
1374        if let Some(other_record) = other.any().downcast_ref::<Self>() {
1375            return self.next_domain == other_record.next_domain
1376                && self.type_bitmap == other_record.type_bitmap
1377                && self.record.entry == other_record.record.entry;
1378        }
1379        false
1380    }
1381
1382    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
1383        if let Some(other_record) = other.any().downcast_ref::<Self>() {
1384            return self.next_domain == other_record.next_domain
1385                && self.type_bitmap == other_record.type_bitmap;
1386        }
1387        false
1388    }
1389
1390    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
1391        if let Some(other_nsec) = other.any().downcast_ref::<Self>() {
1392            match self.next_domain.cmp(&other_nsec.next_domain) {
1393                cmp::Ordering::Equal => self.type_bitmap.cmp(&other_nsec.type_bitmap),
1394                ordering => ordering,
1395            }
1396        } else {
1397            cmp::Ordering::Greater
1398        }
1399    }
1400
1401    fn rdata_print(&self) -> String {
1402        format!(
1403            "next_domain: {}, type_bitmap len: {}",
1404            self.next_domain,
1405            self.type_bitmap.len()
1406        )
1407    }
1408
1409    fn clone_box(&self) -> DnsRecordBox {
1410        Box::new(self.clone())
1411    }
1412
1413    fn boxed(self) -> DnsRecordBox {
1414        Box::new(self)
1415    }
1416}
1417
/// Internal build state of a [`DnsOutPacket`]; not part of the DNS format.
#[derive(PartialEq)]
enum PacketState {
    /// Still accepting content.
    Init = 0,
    /// The header has been written, or a record did not fit.
    Finished = 1,
}

/// One outgoing DNS packet under construction.
pub struct DnsOutPacket {
    /// The packet exactly as it goes on the wire, header first.
    data: Vec<u8>,

    /// Internal bookkeeping only, not defined by DNS.
    state: PacketState,

    /// Name-compression table: maps a name (or name suffix) already written
    /// to the byte offset where it starts in `data`.
    names: HashMap<String, u16>,
}
1435
1436impl DnsOutPacket {
1437    fn new() -> Self {
1438        Self {
1439            data: vec![0; MSG_HEADER_LEN],
1440            state: PacketState::Init,
1441            names: HashMap::new(),
1442        }
1443    }
1444
1445    pub fn size(&self) -> usize {
1446        self.data.len()
1447    }
1448
1449    pub fn as_bytes(&self) -> &[u8] {
1450        &self.data
1451    }
1452
1453    fn write_question(&mut self, question: &DnsQuestion) {
1454        self.write_name(&question.entry.name);
1455        self.write_short(question.entry.ty as u16);
1456        self.write_short(question.entry.class);
1457    }
1458
1459    /// Writes a record (answer, authoritative answer, additional)
1460    /// Returns false if the packet exceeds the max size with this record, nothing is written to the packet.
1461    /// otherwise returns true.
1462    fn write_record(&mut self, record_ext: &dyn DnsRecordExt, now: u64) -> bool {
1463        let start_size = self.size();
1464
1465        let record = record_ext.get_record();
1466        self.write_name(record.get_name());
1467        self.write_short(record.entry.ty as u16);
1468        if record.entry.cache_flush {
1469            // check "multicast"
1470            self.write_short(record.entry.class | CLASS_CACHE_FLUSH);
1471        } else {
1472            self.write_short(record.entry.class);
1473        }
1474
1475        if now == 0 {
1476            self.write_u32(record.ttl);
1477        } else {
1478            self.write_u32(record.get_remaining_ttl(now));
1479        }
1480
1481        // Placeholder for record size
1482        self.write_short(0);
1483        let record_offset = self.size();
1484        record_ext.write(self);
1485        self.insert_short(record_offset - 2, (self.size() - record_offset) as u16);
1486
1487        if self.size() > MAX_MSG_ABSOLUTE {
1488            self.data.truncate(start_size);
1489            self.state = PacketState::Finished;
1490            return false;
1491        }
1492
1493        true
1494    }
1495
1496    pub(crate) fn insert_short(&mut self, index: usize, value: u16) {
1497        self.data[index..index + 2].copy_from_slice(&value.to_be_bytes());
1498    }
1499
1500    /// Parses a DNS name that may contain escaped characters according to RFC 6763 Section 4.3.
1501    /// Returns a vector of labels where each label is the unescaped content.
1502    ///
1503    /// Escape sequences:
1504    /// - \\. becomes . (literal dot)
1505    /// - \\\\ becomes \\ (literal backslash)
1506    fn parse_escaped_name(name: &str) -> Vec<String> {
1507        let mut labels = Vec::new();
1508        let mut current_label = String::new();
1509        let mut chars = name.chars().peekable();
1510
1511        while let Some(ch) = chars.next() {
1512            match ch {
1513                '\\' => {
1514                    // Backslash escape sequence
1515                    if let Some(&next_ch) = chars.peek() {
1516                        match next_ch {
1517                            '.' | '\\' => {
1518                                // \\. or \\\\ - consume the backslash and add the escaped char
1519                                chars.next();
1520                                current_label.push(next_ch);
1521                            }
1522                            _ => {
1523                                // Not a recognized escape - treat backslash literally
1524                                current_label.push(ch);
1525                            }
1526                        }
1527                    } else {
1528                        // Trailing backslash - add it literally
1529                        current_label.push(ch);
1530                    }
1531                }
1532                '.' => {
1533                    // Unescaped dot - label separator
1534                    if !current_label.is_empty() {
1535                        labels.push(current_label.clone());
1536                        current_label.clear();
1537                    }
1538                }
1539                _ => {
1540                    current_label.push(ch);
1541                }
1542            }
1543        }
1544
1545        // Add the last label if not empty
1546        if !current_label.is_empty() {
1547            labels.push(current_label);
1548        }
1549
1550        labels
1551    }
1552
1553    // Write name to packet
1554    //
1555    // [RFC1035]
1556    // 4.1.4. Message compression
1557    //
1558    // In order to reduce the size of messages, the domain system utilizes a
1559    // compression scheme which eliminates the repetition of domain names in a
1560    // message.  In this scheme, an entire domain name or a list of labels at
1561    // the end of a domain name is replaced with a pointer to a prior occurrence
1562    // of the same name.
1563    // The pointer takes the form of a two octet sequence:
1564    //     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1565    //     | 1  1|                OFFSET                   |
1566    //     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1567    // The first two bits are ones.  This allows a pointer to be distinguished
1568    // from a label, since the label must begin with two zero bits because
1569    // labels are restricted to 63 octets or less.  (The 10 and 01 combinations
1570    // are reserved for future use.)  The OFFSET field specifies an offset from
1571    // the start of the message (i.e., the first octet of the ID field in the
1572    // domain header).  A zero offset specifies the first byte of the ID field,
1573    // etc.
1574    //
1575    // This function also handles RFC 6763 Section 4.3 escaping where dots and backslashes
1576    // in instance names are escaped (e.g., "My\\.Service" represents a single label "My.Service").
1577    // The actual name sent over the wire is the unescaped version.
1578    fn write_name(&mut self, name: &str) {
1579        // Remove trailing dot if present
1580        let name_to_parse = name.strip_suffix('.').unwrap_or(name);
1581
1582        // Parse the name considering escape sequences
1583        let labels = Self::parse_escaped_name(name_to_parse);
1584
1585        if labels.is_empty() {
1586            self.write_byte(0);
1587            return;
1588        }
1589
1590        // Write each label
1591        for (i, label) in labels.iter().enumerate() {
1592            // Build the remaining name for compression (with dots as separators)
1593            let remaining: String = labels[i..].join(".");
1594
1595            // Check if we can use compression for the remaining part
1596            const POINTER_MASK: u16 = 0xC000;
1597            if let Some(&offset) = self.names.get(&remaining) {
1598                let pointer = offset | POINTER_MASK;
1599                self.write_short(pointer);
1600                return;
1601            }
1602
1603            // Store this position for potential future compression
1604            self.names.insert(remaining, self.size() as u16);
1605
1606            // Write the label
1607            self.write_utf8(label);
1608        }
1609
1610        // Write terminating zero byte
1611        self.write_byte(0);
1612    }
1613
1614    fn write_byte(&mut self, v: u8) {
1615        self.data.push(v);
1616    }
1617
1618    fn write_bytes(&mut self, s: &[u8]) {
1619        self.data.extend(s);
1620    }
1621
1622    fn write_utf8(&mut self, s: &str) {
1623        assert!(s.len() < 64);
1624        self.write_byte(s.len() as u8);
1625        self.write_bytes(s.as_bytes());
1626    }
1627
1628    fn write_u32(&mut self, v: u32) {
1629        self.data.extend(&v.to_be_bytes());
1630    }
1631
1632    fn write_short(&mut self, v: u16) {
1633        self.data.extend(&v.to_be_bytes());
1634    }
1635
1636    /// Writes the header fields and finish the packet.
1637    /// This function should be only called when finishing a packet.
1638    ///
1639    /// The header format is based on RFC 1035 section 4.1.1:
1640    /// https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1
1641    //
1642    //                                  1  1  1  1  1  1
1643    //    0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
1644    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1645    //    |                      ID                       |
1646    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1647    //    |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
1648    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1649    //    |                    QDCOUNT                    |
1650    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1651    //    |                    ANCOUNT                    |
1652    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1653    //    |                    NSCOUNT                    |
1654    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1655    //    |                    ARCOUNT                    |
1656    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1657    //
1658    fn write_header(
1659        &mut self,
1660        id: u16,
1661        flags: u16,
1662        q_count: u16,
1663        a_count: u16,
1664        auth_count: u16,
1665        addi_count: u16,
1666    ) {
1667        self.insert_short(0, id);
1668        self.insert_short(2, flags);
1669        self.insert_short(4, q_count);
1670        self.insert_short(6, a_count);
1671        self.insert_short(8, auth_count);
1672        self.insert_short(10, addi_count);
1673
1674        self.state = PacketState::Finished;
1675    }
1676}
1677
1678/// Representation of one outgoing DNS message that could be sent in one or more packet(s).
1679#[derive(Debug)]
1680pub struct DnsOutgoing {
1681    flags: u16,
1682    id: u16,
1683    multicast: bool,
1684    questions: Vec<DnsQuestion>,
1685    answers: Vec<(DnsRecordBox, u64)>,
1686    authorities: Vec<DnsRecordBox>,
1687    additionals: Vec<DnsRecordBox>,
1688    known_answer_count: i64, // for internal maintenance only
1689}
1690
1691impl DnsOutgoing {
1692    pub fn new(flags: u16) -> Self {
1693        Self {
1694            flags,
1695            id: 0,
1696            multicast: true,
1697            questions: Vec::new(),
1698            answers: Vec::new(),
1699            authorities: Vec::new(),
1700            additionals: Vec::new(),
1701            known_answer_count: 0,
1702        }
1703    }
1704
1705    pub fn questions(&self) -> &[DnsQuestion] {
1706        &self.questions
1707    }
1708
1709    /// For testing purposes only.
1710    pub(crate) fn _answers(&self) -> &[(DnsRecordBox, u64)] {
1711        &self.answers
1712    }
1713
1714    pub fn answers_count(&self) -> usize {
1715        self.answers.len()
1716    }
1717
1718    pub fn authorities(&self) -> &[DnsRecordBox] {
1719        &self.authorities
1720    }
1721
1722    pub fn additionals(&self) -> &[DnsRecordBox] {
1723        &self.additionals
1724    }
1725
1726    pub fn known_answer_count(&self) -> i64 {
1727        self.known_answer_count
1728    }
1729
1730    pub fn set_id(&mut self, id: u16) {
1731        self.id = id;
1732    }
1733
1734    pub const fn is_query(&self) -> bool {
1735        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_QUERY
1736    }
1737
1738    const fn is_response(&self) -> bool {
1739        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_RESPONSE
1740    }
1741
1742    // Adds an additional answer
1743
1744    // From: RFC 6763, DNS-Based Service Discovery, February 2013
1745
1746    // 12.  DNS Additional Record Generation
1747
1748    //    DNS has an efficiency feature whereby a DNS server may place
1749    //    additional records in the additional section of the DNS message.
1750    //    These additional records are records that the client did not
1751    //    explicitly request, but the server has reasonable grounds to expect
1752    //    that the client might request them shortly, so including them can
1753    //    save the client from having to issue additional queries.
1754
1755    //    This section recommends which additional records SHOULD be generated
1756    //    to improve network efficiency, for both Unicast and Multicast DNS-SD
1757    //    responses.
1758
1759    // 12.1.  PTR Records
1760
1761    //    When including a DNS-SD Service Instance Enumeration or Selective
1762    //    Instance Enumeration (subtype) PTR record in a response packet, the
1763    //    server/responder SHOULD include the following additional records:
1764
1765    //    o  The SRV record(s) named in the PTR rdata.
1766    //    o  The TXT record(s) named in the PTR rdata.
1767    //    o  All address records (type "A" and "AAAA") named in the SRV rdata.
1768
1769    // 12.2.  SRV Records
1770
1771    //    When including an SRV record in a response packet, the
1772    //    server/responder SHOULD include the following additional records:
1773
1774    //    o  All address records (type "A" and "AAAA") named in the SRV rdata.
1775    pub fn add_additional_answer(&mut self, answer: impl DnsRecordExt + 'static) {
1776        trace!("add_additional_answer: {:?}", &answer);
1777        self.additionals.push(answer.boxed());
1778    }
1779
1780    /// A workaround as Rust doesn't allow us to pass DnsRecordBox in as `impl DnsRecordExt`
1781    pub fn add_answer_box(&mut self, answer_box: DnsRecordBox) {
1782        self.answers.push((answer_box, 0));
1783    }
1784
1785    pub fn add_authority(&mut self, record: DnsRecordBox) {
1786        self.authorities.push(record);
1787    }
1788
1789    /// Returns true if `answer` is added to the outgoing msg.
1790    /// Returns false if `answer` was not added as it expired or suppressed by the incoming `msg`.
1791    pub fn add_answer(
1792        &mut self,
1793        msg: &DnsIncoming,
1794        answer: impl DnsRecordExt + Send + 'static,
1795    ) -> bool {
1796        trace!("Check for add_answer");
1797        if answer.suppressed_by(msg) {
1798            trace!("my answer is suppressed by incoming msg");
1799            self.known_answer_count += 1;
1800            return false;
1801        }
1802
1803        self.add_answer_at_time(answer, 0)
1804    }
1805
1806    /// Returns true if `answer` is added to the outgoing msg.
1807    /// Returns false if the answer is expired `now` hence not added.
1808    /// If `now` is 0, do not check if the answer expires.
1809    pub fn add_answer_at_time(
1810        &mut self,
1811        answer: impl DnsRecordExt + Send + 'static,
1812        now: u64,
1813    ) -> bool {
1814        if now == 0 || !answer.get_record().is_expired(now) {
1815            trace!("add_answer push: {:?}", &answer);
1816            self.answers.push((answer.boxed(), now));
1817            return true;
1818        }
1819        false
1820    }
1821
1822    /// Adds a PTR answer for `service` along with recommended additional records
1823    /// (SRV, TXT, and address records) per [RFC 6763 Section 12.1].
1824    ///
1825    /// Resolves any name conflicts via `dns_registry` and selects addresses
1826    /// matching the given interface. Does nothing if no addresses are available
1827    /// on `intf` or if the PTR answer is suppressed by known-answer entries in `msg`.
1828    ///
1829    /// [RFC 6763 Section 12.1]: https://tools.ietf.org/html/rfc6763#section-12.1
1830    pub(crate) fn add_answer_with_additionals(
1831        &mut self,
1832        msg: &DnsIncoming,
1833        service: &ServiceInfo,
1834        intf: &MyIntf,
1835        dns_registry: &DnsRegistry,
1836        is_ipv4: bool,
1837    ) {
1838        let intf_addrs = if is_ipv4 {
1839            service.get_addrs_on_my_intf_v4(intf)
1840        } else {
1841            service.get_addrs_on_my_intf_v6(intf)
1842        };
1843        if intf_addrs.is_empty() {
1844            trace!("No addrs on LAN of intf {:?}", intf);
1845            return;
1846        }
1847
1848        // check if we changed our name due to conflicts.
1849        let service_fullname = dns_registry.resolve_name(service.get_fullname());
1850        let hostname = dns_registry.resolve_name(service.get_hostname());
1851
1852        let ptr_added = self.add_answer(
1853            msg,
1854            DnsPointer::new(
1855                service.get_type(),
1856                RRType::PTR,
1857                CLASS_IN,
1858                service.get_other_ttl(),
1859                service_fullname.to_string(),
1860            ),
1861        );
1862
1863        if !ptr_added {
1864            trace!("answer was not added for msg {:?}", msg);
1865            return;
1866        }
1867
1868        if let Some(sub) = service.get_subtype() {
1869            trace!("Adding subdomain {}", sub);
1870            self.add_additional_answer(DnsPointer::new(
1871                sub,
1872                RRType::PTR,
1873                CLASS_IN,
1874                service.get_other_ttl(),
1875                service_fullname.to_string(),
1876            ));
1877        }
1878
1879        // Add recommended additional answers according to
1880        // https://tools.ietf.org/html/rfc6763#section-12.1.
1881        self.add_additional_answer(DnsSrv::new(
1882            service_fullname,
1883            CLASS_IN | CLASS_CACHE_FLUSH,
1884            service.get_host_ttl(),
1885            service.get_priority(),
1886            service.get_weight(),
1887            service.get_port(),
1888            hostname.to_string(),
1889        ));
1890
1891        self.add_additional_answer(DnsTxt::new(
1892            service_fullname,
1893            CLASS_IN | CLASS_CACHE_FLUSH,
1894            service.get_other_ttl(),
1895            service.generate_txt(),
1896        ));
1897
1898        for address in intf_addrs {
1899            self.add_additional_answer(DnsAddress::new(
1900                hostname,
1901                ip_address_rr_type(&address),
1902                CLASS_IN | CLASS_CACHE_FLUSH,
1903                service.get_host_ttl(),
1904                address,
1905                intf.into(),
1906            ));
1907        }
1908    }
1909
1910    pub fn add_question(&mut self, name: &str, qtype: RRType) {
1911        let q = DnsQuestion {
1912            entry: DnsEntry::new(name.to_string(), qtype, CLASS_IN),
1913        };
1914        self.questions.push(q);
1915    }
1916
1917    /// Clear the cache-flush (unique) bit on every answer and additional
1918    /// record. Required for RFC 6762 §6.7 (Legacy Unicast Responses) and
1919    /// §10.2 — a legacy resolver doesn't know about the cache-flush bit
1920    /// and may misinterpret responses where it is set.
1921    pub fn clear_cache_flush_bits(&mut self) {
1922        for (rec, _) in &mut self.answers {
1923            rec.get_record_mut().entry.cache_flush = false;
1924        }
1925        for rec in &mut self.additionals {
1926            rec.get_record_mut().entry.cache_flush = false;
1927        }
1928        for rec in &mut self.authorities {
1929            rec.get_record_mut().entry.cache_flush = false;
1930        }
1931    }
1932
1933    /// Returns a list of actual DNS packet data to be sent on the wire.
1934    pub fn to_data_on_wire(&self) -> Vec<Vec<u8>> {
1935        let packet_list = self.to_packets();
1936        packet_list.into_iter().map(|p| p.data).collect()
1937    }
1938
1939    /// Encode self into one or more packets.
1940    pub fn to_packets(&self) -> Vec<DnsOutPacket> {
1941        let mut packet_list = Vec::new();
1942        let mut packet = DnsOutPacket::new();
1943
1944        let mut question_count = self.questions.len() as u16;
1945        let mut answer_count = 0;
1946        let mut auth_count = 0;
1947        let mut addi_count = 0;
1948        let id = if self.multicast { 0 } else { self.id };
1949
1950        for question in self.questions.iter() {
1951            packet.write_question(question);
1952        }
1953
1954        for (answer, time) in self.answers.iter() {
1955            if packet.write_record(answer.as_ref(), *time) {
1956                answer_count += 1;
1957            }
1958        }
1959
1960        for auth in self.authorities.iter() {
1961            auth_count += u16::from(packet.write_record(auth.as_ref(), 0));
1962        }
1963
1964        for addi in self.additionals.iter() {
1965            if packet.write_record(addi.as_ref(), 0) {
1966                addi_count += 1;
1967                continue;
1968            }
1969
1970            // No more processing for response packets.
1971            if self.is_response() {
1972                break;
1973            }
1974
1975            // For query, the current packet exceeds its max size due to known answers,
1976            // need to truncate.
1977
1978            // finish the current packet first.
1979            packet.write_header(
1980                id,
1981                self.flags | FLAGS_TC,
1982                question_count,
1983                answer_count,
1984                auth_count,
1985                addi_count,
1986            );
1987
1988            packet_list.push(packet);
1989
1990            // create a new packet and reset counts.
1991            packet = DnsOutPacket::new();
1992            packet.write_record(addi.as_ref(), 0);
1993
1994            question_count = 0;
1995            answer_count = 0;
1996            auth_count = 0;
1997            addi_count = 1;
1998        }
1999
2000        packet.write_header(
2001            id,
2002            self.flags,
2003            question_count,
2004            answer_count,
2005            auth_count,
2006            addi_count,
2007        );
2008
2009        packet_list.push(packet);
2010        packet_list
2011    }
2012}
2013
/// An incoming DNS message. It could be a query or a response.
///
/// Constructed and fully decoded by [`DnsIncoming::new`].
#[derive(Debug)]
pub struct DnsIncoming {
    /// Read cursor into `data`: index of the next byte to decode.
    offset: usize,
    /// The raw message bytes being decoded.
    data: Vec<u8>,
    /// Decoded question section.
    questions: Vec<DnsQuestion>,
    /// Decoded answer section; records of unsupported types are skipped.
    answers: Vec<DnsRecordBox>,
    /// Decoded authority section; records of unsupported types are skipped.
    authorities: Vec<DnsRecordBox>,
    /// Decoded additional section; records of unsupported types are skipped.
    additional: Vec<DnsRecordBox>,
    /// Message ID from the header.
    id: u16,
    /// Header flags word (includes the QR query/response bit).
    flags: u16,
    /// Question count from the header.
    num_questions: u16,
    /// Answer count from the header; may exceed `answers.len()` when records were skipped.
    num_answers: u16,
    /// Authority count from the header; may exceed `authorities.len()`.
    num_authorities: u16,
    /// Additional count from the header; may exceed `additional.len()`.
    num_additionals: u16,
    /// Interface id passed to `new` (presumably the receiving interface — confirm
    /// with callers); copied into every decoded A/AAAA record.
    interface_id: InterfaceId,
}
2031
2032impl DnsIncoming {
2033    pub fn new(data: Vec<u8>, interface_id: InterfaceId) -> Result<Self> {
2034        let mut incoming = Self {
2035            offset: 0,
2036            data,
2037            questions: Vec::new(),
2038            answers: Vec::new(),
2039            authorities: Vec::new(),
2040            additional: Vec::new(),
2041            id: 0,
2042            flags: 0,
2043            num_questions: 0,
2044            num_answers: 0,
2045            num_authorities: 0,
2046            num_additionals: 0,
2047            interface_id,
2048        };
2049
2050        /*
2051        RFC 1035 section 4.1: https://datatracker.ietf.org/doc/html/rfc1035#section-4.1
2052        ...
2053        All communications inside of the domain protocol are carried in a single
2054        format called a message.  The top level format of message is divided
2055        into 5 sections (some of which are empty in certain cases) shown below:
2056
2057            +---------------------+
2058            |        Header       |
2059            +---------------------+
2060            |       Question      | the question for the name server
2061            +---------------------+
2062            |        Answer       | RRs answering the question
2063            +---------------------+
2064            |      Authority      | RRs pointing toward an authority
2065            +---------------------+
2066            |      Additional     | RRs holding additional information
2067            +---------------------+
2068         */
2069        incoming.read_header()?;
2070        incoming.read_questions()?;
2071        incoming.read_answers()?;
2072        incoming.read_authorities()?;
2073        incoming.read_additional()?;
2074
2075        Ok(incoming)
2076    }
2077
2078    pub fn id(&self) -> u16 {
2079        self.id
2080    }
2081
2082    pub fn questions(&self) -> &[DnsQuestion] {
2083        &self.questions
2084    }
2085
2086    pub fn answers(&self) -> &[DnsRecordBox] {
2087        &self.answers
2088    }
2089
2090    pub fn authorities(&self) -> &[DnsRecordBox] {
2091        &self.authorities
2092    }
2093
2094    pub fn additionals(&self) -> &[DnsRecordBox] {
2095        &self.additional
2096    }
2097
2098    pub fn answers_mut(&mut self) -> &mut Vec<DnsRecordBox> {
2099        &mut self.answers
2100    }
2101
2102    pub fn authorities_mut(&mut self) -> &mut Vec<DnsRecordBox> {
2103        &mut self.authorities
2104    }
2105
2106    pub fn additionals_mut(&mut self) -> &mut Vec<DnsRecordBox> {
2107        &mut self.additional
2108    }
2109
2110    pub fn all_records(self) -> impl Iterator<Item = DnsRecordBox> {
2111        self.answers
2112            .into_iter()
2113            .chain(self.authorities)
2114            .chain(self.additional)
2115    }
2116
2117    pub fn num_additionals(&self) -> u16 {
2118        self.num_additionals
2119    }
2120
2121    pub fn num_authorities(&self) -> u16 {
2122        self.num_authorities
2123    }
2124
2125    pub fn num_questions(&self) -> u16 {
2126        self.num_questions
2127    }
2128
2129    pub const fn is_query(&self) -> bool {
2130        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_QUERY
2131    }
2132
2133    pub const fn is_response(&self) -> bool {
2134        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_RESPONSE
2135    }
2136
2137    fn read_header(&mut self) -> Result<()> {
2138        if self.data.len() < MSG_HEADER_LEN {
2139            return Err(e_fmt!(
2140                "DNS incoming: header is too short: {} bytes",
2141                self.data.len()
2142            ));
2143        }
2144
2145        let data = &self.data[0..];
2146        self.id = u16_from_be_slice(&data[..2]);
2147        self.flags = u16_from_be_slice(&data[2..4]);
2148        self.num_questions = u16_from_be_slice(&data[4..6]);
2149        self.num_answers = u16_from_be_slice(&data[6..8]);
2150        self.num_authorities = u16_from_be_slice(&data[8..10]);
2151        self.num_additionals = u16_from_be_slice(&data[10..12]);
2152
2153        self.offset = MSG_HEADER_LEN;
2154
2155        trace!(
2156            "read_header: id {}, {} questions {} answers {} authorities {} additionals",
2157            self.id,
2158            self.num_questions,
2159            self.num_answers,
2160            self.num_authorities,
2161            self.num_additionals
2162        );
2163        Ok(())
2164    }
2165
2166    fn read_questions(&mut self) -> Result<()> {
2167        trace!("read_questions: {}", &self.num_questions);
2168        for i in 0..self.num_questions {
2169            let name = self.read_name()?;
2170
2171            let data = &self.data[self.offset..];
2172            if data.len() < 4 {
2173                return Err(Error::Msg(format!(
2174                    "DNS incoming: question idx {} too short: {}",
2175                    i,
2176                    data.len()
2177                )));
2178            }
2179            let ty = u16_from_be_slice(&data[..2]);
2180            let class = u16_from_be_slice(&data[2..4]);
2181            self.offset += 4;
2182
2183            let Some(rr_type) = RRType::from_u16(ty) else {
2184                return Err(Error::Msg(format!(
2185                    "DNS incoming: question idx {i} qtype unknown: {ty}",
2186                )));
2187            };
2188
2189            self.questions.push(DnsQuestion {
2190                entry: DnsEntry::new(name, rr_type, class),
2191            });
2192        }
2193        Ok(())
2194    }
2195
2196    fn read_answers(&mut self) -> Result<()> {
2197        self.answers = self.read_rr_records(self.num_answers)?;
2198        Ok(())
2199    }
2200
2201    fn read_authorities(&mut self) -> Result<()> {
2202        self.authorities = self.read_rr_records(self.num_authorities)?;
2203        Ok(())
2204    }
2205
2206    fn read_additional(&mut self) -> Result<()> {
2207        self.additional = self.read_rr_records(self.num_additionals)?;
2208        Ok(())
2209    }
2210
2211    /// Decodes a sequence of RR records (in answers, authorities and additionals).
2212    fn read_rr_records(&mut self, count: u16) -> Result<Vec<DnsRecordBox>> {
2213        trace!("read_rr_records: {}", count);
2214        let mut rr_records = Vec::new();
2215
2216        // RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035#section-3.2.1
2217        //
2218        // All RRs have the same top level format shown below:
2219        //                               1  1  1  1  1  1
2220        // 0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
2221        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2222        // |                                               |
2223        // /                                               /
2224        // /                      NAME                     /
2225        // |                                               |
2226        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2227        // |                      TYPE                     |
2228        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2229        // |                     CLASS                     |
2230        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2231        // |                      TTL                      |
2232        // |                                               |
2233        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2234        // |                   RDLENGTH                    |
2235        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
2236        // /                     RDATA                     /
2237        // /                                               /
2238        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2239
2240        // Muse have at least TYPE, CLASS, TTL, RDLENGTH fields: 10 bytes.
2241        const RR_HEADER_REMAIN: usize = 10;
2242
2243        for _ in 0..count {
2244            let name = self.read_name()?;
2245            let slice = &self.data[self.offset..];
2246
2247            if slice.len() < RR_HEADER_REMAIN {
2248                return Err(Error::Msg(format!(
2249                    "read_others: RR '{}' is too short after name: {} bytes",
2250                    &name,
2251                    slice.len()
2252                )));
2253            }
2254
2255            let ty = u16_from_be_slice(&slice[..2]);
2256            let class = u16_from_be_slice(&slice[2..4]);
2257            let mut ttl = u32_from_be_slice(&slice[4..8]);
2258            if ttl == 0 && self.is_response() {
2259                // RFC 6762 section 10.1:
2260                // "...Queriers receiving a Multicast DNS response with a TTL of zero SHOULD
2261                // NOT immediately delete the record from the cache, but instead record
2262                // a TTL of 1 and then delete the record one second later."
2263                // See https://datatracker.ietf.org/doc/html/rfc6762#section-10.1
2264
2265                ttl = 1;
2266            }
2267            let rdata_len = u16_from_be_slice(&slice[8..10]) as usize;
2268            self.offset += RR_HEADER_REMAIN;
2269            let next_offset = self.offset + rdata_len;
2270
2271            // Sanity check for RDATA length.
2272            if next_offset > self.data.len() {
2273                return Err(Error::Msg(format!(
2274                    "RR {name} RDATA length {rdata_len} is invalid: remain data len: {}",
2275                    self.data.len() - self.offset
2276                )));
2277            }
2278
2279            // decode RDATA based on the record type.
2280            let rec: Option<DnsRecordBox> = match RRType::from_u16(ty) {
2281                None => None,
2282
2283                Some(rr_type) => match rr_type {
2284                    RRType::CNAME | RRType::PTR => {
2285                        Some(DnsPointer::new(&name, rr_type, class, ttl, self.read_name()?).boxed())
2286                    }
2287                    RRType::TXT => {
2288                        Some(DnsTxt::new(&name, class, ttl, self.read_vec(rdata_len)?).boxed())
2289                    }
2290                    RRType::SRV => Some(
2291                        DnsSrv::new(
2292                            &name,
2293                            class,
2294                            ttl,
2295                            self.read_u16()?,
2296                            self.read_u16()?,
2297                            self.read_u16()?,
2298                            self.read_name()?,
2299                        )
2300                        .boxed(),
2301                    ),
2302                    RRType::HINFO => Some(
2303                        DnsHostInfo::new(
2304                            &name,
2305                            rr_type,
2306                            class,
2307                            ttl,
2308                            self.read_char_string()?,
2309                            self.read_char_string()?,
2310                        )
2311                        .boxed(),
2312                    ),
2313                    RRType::A => Some(
2314                        DnsAddress::new(
2315                            &name,
2316                            rr_type,
2317                            class,
2318                            ttl,
2319                            self.read_ipv4()?.into(),
2320                            self.interface_id.clone(),
2321                        )
2322                        .boxed(),
2323                    ),
2324                    RRType::AAAA => Some(
2325                        DnsAddress::new(
2326                            &name,
2327                            rr_type,
2328                            class,
2329                            ttl,
2330                            self.read_ipv6()?.into(),
2331                            self.interface_id.clone(),
2332                        )
2333                        .boxed(),
2334                    ),
2335                    RRType::NSEC => Some(
2336                        DnsNSec::new(
2337                            &name,
2338                            class,
2339                            ttl,
2340                            self.read_name()?,
2341                            self.read_type_bitmap()?,
2342                        )
2343                        .boxed(),
2344                    ),
2345                    _ => None,
2346                },
2347            };
2348
2349            if let Some(record) = rec {
2350                trace!("read_rr_records: {:?}", &record);
2351                rr_records.push(record);
2352            } else {
2353                trace!("Unsupported DNS record type: {} name: {}", ty, &name);
2354                self.offset += rdata_len;
2355            }
2356
2357            // sanity check.
2358            if self.offset != next_offset {
2359                return Err(Error::Msg(format!(
2360                    "read_rr_records: decode offset error for RData type {} offset: {} expected offset: {}",
2361                    ty, self.offset, next_offset,
2362                )));
2363            }
2364        }
2365
2366        Ok(rr_records)
2367    }
2368
2369    fn read_char_string(&mut self) -> Result<String> {
2370        let length = self.data[self.offset];
2371        self.offset += 1;
2372        self.read_string(length as usize)
2373    }
2374
2375    fn read_u16(&mut self) -> Result<u16> {
2376        let slice = &self.data[self.offset..];
2377        if slice.len() < U16_SIZE {
2378            return Err(Error::Msg(format!(
2379                "read_u16: slice len is only {}",
2380                slice.len()
2381            )));
2382        }
2383        let num = u16_from_be_slice(&slice[..U16_SIZE]);
2384        self.offset += U16_SIZE;
2385        Ok(num)
2386    }
2387
2388    /// Reads the "Type Bit Map" block for a DNS NSEC record.
2389    fn read_type_bitmap(&mut self) -> Result<Vec<u8>> {
2390        // From RFC 6762: 6.1.  Negative Responses
2391        // https://datatracker.ietf.org/doc/html/rfc6762#section-6.1
2392        //   o The Type Bit Map block number is 0.
2393        //   o The Type Bit Map block length byte is a value in the range 1-32.
2394        //   o The Type Bit Map data is 1-32 bytes, as indicated by length
2395        //     byte.
2396
2397        // Sanity check: at least 2 bytes to read.
2398        if self.data.len() < self.offset + 2 {
2399            return Err(Error::Msg(format!(
2400                "DnsIncoming is too short: {} at NSEC Type Bit Map offset {}",
2401                self.data.len(),
2402                self.offset
2403            )));
2404        }
2405
2406        let block_num = self.data[self.offset];
2407        self.offset += 1;
2408        if block_num != 0 {
2409            return Err(Error::Msg(format!(
2410                "NSEC block number is not 0: {block_num}"
2411            )));
2412        }
2413
2414        let block_len = self.data[self.offset] as usize;
2415        if !(1..=32).contains(&block_len) {
2416            return Err(Error::Msg(format!(
2417                "NSEC block length must be in the range 1-32: {block_len}"
2418            )));
2419        }
2420        self.offset += 1;
2421
2422        let end = self.offset + block_len;
2423        if end > self.data.len() {
2424            return Err(Error::Msg(format!(
2425                "NSEC block overflow: {} over RData len {}",
2426                end,
2427                self.data.len()
2428            )));
2429        }
2430        let bitmap = self.data[self.offset..end].to_vec();
2431        self.offset += block_len;
2432
2433        Ok(bitmap)
2434    }
2435
2436    fn read_vec(&mut self, length: usize) -> Result<Vec<u8>> {
2437        if self.data.len() < self.offset + length {
2438            return Err(e_fmt!(
2439                "DNS Incoming: not enough data to read a chunk of data"
2440            ));
2441        }
2442
2443        let v = self.data[self.offset..self.offset + length].to_vec();
2444        self.offset += length;
2445        Ok(v)
2446    }
2447
2448    fn read_ipv4(&mut self) -> Result<Ipv4Addr> {
2449        if self.data.len() < self.offset + 4 {
2450            return Err(e_fmt!("DNS Incoming: not enough data to read an IPV4"));
2451        }
2452
2453        let bytes: [u8; 4] = self.data[self.offset..self.offset + 4]
2454            .try_into()
2455            .map_err(|_| e_fmt!("DNS incoming: Not enough bytes for reading an IPV4"))?;
2456        self.offset += bytes.len();
2457        Ok(Ipv4Addr::from(bytes))
2458    }
2459
2460    fn read_ipv6(&mut self) -> Result<Ipv6Addr> {
2461        if self.data.len() < self.offset + 16 {
2462            return Err(e_fmt!("DNS Incoming: not enough data to read an IPV6"));
2463        }
2464
2465        let bytes: [u8; 16] = self.data[self.offset..self.offset + 16]
2466            .try_into()
2467            .map_err(|_| e_fmt!("DNS incoming: Not enough bytes for reading an IPV6"))?;
2468        self.offset += bytes.len();
2469        Ok(Ipv6Addr::from(bytes))
2470    }
2471
2472    fn read_string(&mut self, length: usize) -> Result<String> {
2473        if self.data.len() < self.offset + length {
2474            return Err(e_fmt!("DNS Incoming: not enough data to read a string"));
2475        }
2476
2477        let s = str::from_utf8(&self.data[self.offset..self.offset + length])
2478            .map_err(|e| Error::Msg(e.to_string()))?;
2479        self.offset += length;
2480        Ok(s.to_string())
2481    }
2482
2483    /// Reads a domain name at the current location of `self.data`.
2484    ///
2485    /// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.1 for
2486    /// domain name encoding.
2487    fn read_name(&mut self) -> Result<String> {
2488        let data = &self.data[..];
2489        let start_offset = self.offset;
2490        let mut offset = start_offset;
2491        let mut name = "".to_string();
2492        let mut at_end = false;
2493
2494        // From RFC1035:
2495        // "...Domain names in messages are expressed in terms of a sequence of labels.
2496        // Each label is represented as a one octet length field followed by that
2497        // number of octets."
2498        //
2499        // "...The compression scheme allows a domain name in a message to be
2500        // represented as either:
2501        // - a sequence of labels ending in a zero octet
2502        // - a pointer
2503        // - a sequence of labels ending with a pointer"
2504        loop {
2505            if offset >= data.len() {
2506                return Err(Error::Msg(format!(
2507                    "read_name: offset: {} data len {}. DnsIncoming: {:?}",
2508                    offset,
2509                    data.len(),
2510                    self
2511                )));
2512            }
2513            let length = data[offset];
2514
2515            // From RFC1035:
2516            // "...Since every domain name ends with the null label of
2517            // the root, a domain name is terminated by a length byte of zero."
2518            if length == 0 {
2519                if !at_end {
2520                    self.offset = offset + 1;
2521                }
2522                break; // The end of the name
2523            }
2524
2525            // Check the first 2 bits for possible "Message compression".
2526            match length & 0xC0 {
2527                0x00 => {
2528                    // regular utf8 string with length
2529                    offset += 1;
2530                    let ending = offset + length as usize;
2531
2532                    // Never read beyond the whole data length.
2533                    if ending > data.len() {
2534                        return Err(Error::Msg(format!(
2535                            "read_name: ending {} exceeds data length {}",
2536                            ending,
2537                            data.len()
2538                        )));
2539                    }
2540
2541                    name += str::from_utf8(&data[offset..ending])
2542                        .map_err(|e| Error::Msg(format!("read_name: from_utf8: {e}")))?;
2543                    name += ".";
2544                    offset += length as usize;
2545                }
2546                0xC0 => {
2547                    // Message compression.
2548                    // See https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.4
2549                    let slice = &data[offset..];
2550                    if slice.len() < U16_SIZE {
2551                        return Err(Error::Msg(format!(
2552                            "read_name: u16 slice len is only {}",
2553                            slice.len()
2554                        )));
2555                    }
2556                    let pointer = (u16_from_be_slice(slice) ^ 0xC000) as usize;
2557                    if pointer >= start_offset {
2558                        // Error: could trigger an infinite loop.
2559                        return Err(Error::Msg(format!(
2560                            "Invalid name compression: pointer {} must be less than the start offset {}",
2561                            &pointer, &start_offset
2562                        )));
2563                    }
2564
2565                    // A pointer marks the end of a domain name.
2566                    if !at_end {
2567                        self.offset = offset + U16_SIZE;
2568                        at_end = true;
2569                    }
2570                    offset = pointer;
2571                }
2572                _ => {
2573                    return Err(Error::Msg(format!(
2574                        "Bad name with invalid length: 0x{:x} offset {}, data (so far): {:x?}",
2575                        length,
2576                        offset,
2577                        &data[..offset]
2578                    )));
2579                }
2580            };
2581        }
2582
2583        Ok(name)
2584    }
2585}
2586
/// Interprets the first two bytes of `bytes` as a big-endian (network order) `u16`.
///
/// Panics if `bytes` holds fewer than two bytes.
const fn u16_from_be_slice(bytes: &[u8]) -> u16 {
    let high = bytes[0] as u16;
    let low = bytes[1] as u16;
    (high << 8) | low
}
2591
/// Interprets the first four bytes of `s` as a big-endian (network order) `u32`.
///
/// Panics if `s` holds fewer than four bytes.
const fn u32_from_be_slice(s: &[u8]) -> u32 {
    let b0 = s[0] as u32;
    let b1 = s[1] as u32;
    let b2 = s[2] as u32;
    let b3 = s[3] as u32;
    (b0 << 24) | (b1 << 16) | (b2 << 8) | b3
}
2596
/// Returns the UNIX time in millis at which `percent` percent of a record's TTL
/// will have elapsed.
///
/// `created` is in millis and `ttl` in seconds, so the elapsed part is
/// `ttl * 1000 * (percent / 100)`, computed exactly as `ttl * percent * 10` in `u64`.
const fn get_expiration_time(created: u64, ttl: u32, percent: u32) -> u64 {
    let elapsed_millis = (ttl as u64) * (percent as u64) * 10;
    created + elapsed_millis
}
2604
// Byte-level serialization tests for `DnsOutgoing::to_packets`. Each test checks the
// exact encoded packet and, where asserted, the `names` table of the packet, which
// maps each written name (and its suffixes) to the byte offset where it was written.
#[cfg(test)]
mod tests {
    use super::{
        DnsAddress, DnsHostInfo, DnsOutgoing, DnsPointer, DnsTxt, RRType, CLASS_CACHE_FLUSH,
        CLASS_IN,
    };
    use crate::InterfaceId;
    use std::collections::HashMap;
    use std::net::{IpAddr, Ipv4Addr};

    // A message with no questions or records encodes to a single all-zero
    // 12-byte header and records no names.
    #[test]
    fn test_dns_outgoing_serialization_empty() {
        let out = DnsOutgoing::new(0);
        let packets = out.to_packets();
        assert_eq!(packets.len(), 1);
        assert_eq!(packets[0].as_bytes(), &[0; 12]);
        let expected_names = HashMap::new();
        assert_eq!(&packets[0].names, &expected_names);
    }

    // One A/IN question: QDCOUNT = 1, then labels "123" and "test", the zero
    // terminator, QTYPE 1 and QCLASS 1. Both the full name and its suffix "test"
    // are recorded with their offsets.
    #[test]
    fn test_dns_outgoing_serialization_question() {
        let mut out = DnsOutgoing::new(0);
        out.add_question("123.test", RRType::A);
        let packets = out.to_packets();
        assert_eq!(packets.len(), 1);
        assert_eq!(
            packets[0].as_bytes(),
            &[
                0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, // Header
                // Payload
                3, 49, 50, 51, 4, 116, 101, 115, 116, 0, 0, 1, 0, 1,
            ]
        );
        let mut expected_names = HashMap::new();
        expected_names.insert("123.test".to_string(), 12);
        expected_names.insert("test".to_string(), 16);
        assert_eq!(&packets[0].names, &expected_names);
    }

    // Authority records use name compression: `192, 16` (0xC010) points back to
    // "test" at offset 16, and `192, 26` points to "124.test" at offset 26.
    // DnsHostInfo is built with RRType::CNAME here, which shows up as type 5.
    #[test]
    fn test_dns_outgoing_serialization_question_with_authority() {
        let mut out = DnsOutgoing::new(0);
        out.add_question("123.test", RRType::ANY);
        out.add_authority(Box::new(DnsTxt::new(
            "124.test",
            CLASS_IN,
            0x00112233,
            b"help".to_vec(),
        )));
        out.add_authority(Box::new(DnsHostInfo::new(
            "124.test",
            RRType::CNAME,
            CLASS_IN,
            0x00112233,
            "arm".to_string(),
            "linux".to_string(),
        )));
        let packets = out.to_packets();
        assert_eq!(packets.len(), 1);
        assert_eq!(
            packets[0].as_bytes(),
            &[
                0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, // Header
                // Payload
                3, 49, 50, 51, 4, 116, 101, 115, 116, 0, 0, 255, 0, 1, 3, 49, 50, 52, 192, 16, 0,
                16, 0, 1, 0, 17, 34, 51, 0, 4, 104, 101, 108, 112, 192, 26, 0, 5, 0, 1, 0, 17, 34,
                51, 0, 8, 97, 114, 109, 108, 105, 110, 117, 120,
            ]
        );
        let mut expected_names = HashMap::new();
        expected_names.insert("123.test".to_string(), 12);
        expected_names.insert("test".to_string(), 16);
        expected_names.insert("124.test".to_string(), 26);
        assert_eq!(&packets[0].names, &expected_names);
    }

    // An additional A record: ARCOUNT = 1, class bytes `128, 1` are IN with the
    // cache-flush bit set, and the TTL 0xdeadbeef is encoded big-endian.
    #[test]
    fn test_dns_outgoing_serialization_additional_answer() {
        let mut out = DnsOutgoing::new(0);
        out.add_additional_answer(DnsAddress::new(
            "test.local",
            RRType::A,
            CLASS_IN | CLASS_CACHE_FLUSH,
            0xdead_beef,
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            InterfaceId::default(),
        ));
        let packets = out.to_packets();
        assert_eq!(packets.len(), 1);
        assert_eq!(
            packets[0].as_bytes(),
            &[
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, // Header
                // Payload
                4, 116, 101, 115, 116, 5, 108, 111, 99, 97, 108, 0, 0, 1, 128, 1, 222, 173, 190,
                239, 0, 4, 127, 0, 0, 1,
            ]
        );
        let mut expected_names = HashMap::new();
        expected_names.insert("test.local".to_string(), 12);
        expected_names.insert("local".to_string(), 17);
        assert_eq!(&packets[0].names, &expected_names);
    }

    // Answers added with `add_answer_at_time`. First a single PTR record (type 12).
    // Then two pointer records in one packet: the second reuses "test" via `192, 12`
    // and its RDATA is just `192, 28`, a pointer to "test-service.local".
    // (DnsPointer is constructed with RRType::AAAA there, which only affects the
    // type code written, 28.)
    #[test]
    fn test_dns_outgoing_serialization_answer_at_time() {
        let mut out = DnsOutgoing::new(0);
        out.add_answer_at_time(
            DnsPointer::new(
                "test",
                RRType::PTR,
                CLASS_IN,
                0xaaaa5555,
                "test-service".to_string(),
            ),
            0,
        );
        let packets = out.to_packets();
        assert_eq!(packets.len(), 1);
        assert_eq!(
            packets[0].as_bytes(),
            &[
                0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, // Header
                // Payload
                4, 116, 101, 115, 116, 0, 0, 12, 0, 1, 170, 170, 85, 85, 0, 14, 12, 116, 101, 115,
                116, 45, 115, 101, 114, 118, 105, 99, 101, 0,
            ]
        );

        let mut out = DnsOutgoing::new(0);
        out.add_answer_at_time(
            DnsPointer::new(
                "test",
                RRType::CNAME,
                CLASS_IN,
                0xaaaa5555,
                "test-service.local".to_string(),
            ),
            0,
        );
        out.add_answer_at_time(
            DnsPointer::new(
                "test",
                RRType::AAAA,
                CLASS_IN,
                0xffffffff,
                "test-service.local".to_string(),
            ),
            0,
        );
        let packets = out.to_packets();
        assert_eq!(packets.len(), 1);
        assert_eq!(
            packets[0].as_bytes(),
            &[
                0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, // Header
                // Payload
                4, 116, 101, 115, 116, 0, 0, 5, 0, 1, 170, 170, 85, 85, 0, 20, 12, 116, 101, 115,
                116, 45, 115, 101, 114, 118, 105, 99, 101, 5, 108, 111, 99, 97, 108, 0, 192, 12, 0,
                28, 0, 1, 255, 255, 255, 255, 0, 2, 192, 28,
            ]
        );
        let mut expected_names = HashMap::new();
        expected_names.insert("test".to_string(), 12);
        expected_names.insert("test-service.local".to_string(), 28);
        expected_names.insert("local".to_string(), 41);
        assert_eq!(&packets[0].names, &expected_names);
    }
}