mdns_sd/
dns_parser.rs

1//! DNS parsing utility.
2//!
//! [DnsIncoming] is the logical representation of an incoming DNS packet.
//! [DnsOutgoing] is the logical representation of an outgoing DNS message of one or more packets.
//! [DnsOutPacket] is a single encoded packet of a [DnsOutgoing] message.
6
7#[cfg(feature = "logging")]
8use crate::log::trace;
9
10use crate::error::{e_fmt, Error, Result};
11
12use if_addrs::Interface;
13
14use std::{
15    any::Any,
16    cmp,
17    collections::HashMap,
18    convert::TryInto,
19    fmt,
20    net::{IpAddr, Ipv4Addr, Ipv6Addr},
21    str,
22    time::SystemTime,
23};
24
/// Identifies a network interface by the name and index the OS gave it.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Default)]
pub struct InterfaceId {
    /// Interface name, e.g. "en0", "wlan0", etc.
    pub name: String,

    /// Interface index assigned by the OS, e.g. 1, 2, etc.
    pub index: u32,
}

impl fmt::Display for InterfaceId {
    /// Renders the identifier as `<index>('<name>')`, for example `2('en0')`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { name, index } = self;
        write!(f, "{index}('{name}')")
    }
}
40
41impl From<&Interface> for InterfaceId {
42    fn from(interface: &Interface) -> Self {
43        InterfaceId {
44            name: interface.name.clone(),
45            index: interface.index.unwrap_or_default(),
46        }
47    }
48}
49
/// The IPv4 payload of `ScopedIp`.
///
/// Note: an IPv4 address has no scope ID; the "Scoped" prefix only keeps the
/// naming consistent with the IPv6 counterpart.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ScopedIpV4 {
    addr: Ipv4Addr,
}

impl ScopedIpV4 {
    /// Borrows the wrapped IPv4 address.
    pub const fn addr(&self) -> &Ipv4Addr {
        let Self { addr } = self;
        addr
    }
}
65
66/// An IPv6 address with scope_id (interface identifier).
67#[derive(Debug, Clone, Eq, PartialEq, Hash)]
68pub struct ScopedIpV6 {
69    addr: Ipv6Addr,
70    scope_id: InterfaceId,
71}
72
73impl ScopedIpV6 {
74    /// Returns the IPv6 address.
75    pub const fn addr(&self) -> &Ipv6Addr {
76        &self.addr
77    }
78
79    /// Returns the scope_id for this IPv6 address.
80    pub const fn scope_id(&self) -> &InterfaceId {
81        &self.scope_id
82    }
83}
84
85/// An IP address, either IPv4 or IPv6, that supports scope_id for IPv6.
86#[derive(Debug, Clone, Eq, PartialEq, Hash)]
87#[non_exhaustive]
88pub enum ScopedIp {
89    V4(ScopedIpV4),
90    V6(ScopedIpV6),
91}
92
93impl ScopedIp {
94    pub const fn to_ip_addr(&self) -> IpAddr {
95        match self {
96            ScopedIp::V4(v4) => IpAddr::V4(v4.addr),
97            ScopedIp::V6(v6) => IpAddr::V6(v6.addr),
98        }
99    }
100
101    pub const fn is_ipv4(&self) -> bool {
102        matches!(self, ScopedIp::V4(_))
103    }
104
105    pub const fn is_ipv6(&self) -> bool {
106        matches!(self, ScopedIp::V6(_))
107    }
108
109    pub const fn is_loopback(&self) -> bool {
110        match self {
111            ScopedIp::V4(v4) => v4.addr.is_loopback(),
112            ScopedIp::V6(v6) => v6.addr.is_loopback(),
113        }
114    }
115}
116
117impl From<IpAddr> for ScopedIp {
118    fn from(ip: IpAddr) -> Self {
119        match ip {
120            IpAddr::V4(v4) => ScopedIp::V4(ScopedIpV4 { addr: v4 }),
121            IpAddr::V6(v6) => ScopedIp::V6(ScopedIpV6 {
122                addr: v6,
123                scope_id: InterfaceId::default(),
124            }),
125        }
126    }
127}
128
129impl From<&Interface> for ScopedIp {
130    fn from(interface: &Interface) -> Self {
131        match interface.ip() {
132            IpAddr::V4(v4) => ScopedIp::V4(ScopedIpV4 { addr: v4 }),
133            IpAddr::V6(v6) => ScopedIp::V6(ScopedIpV6 {
134                addr: v6,
135                scope_id: InterfaceId::from(interface),
136            }),
137        }
138    }
139}
140
141impl fmt::Display for ScopedIp {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        match self {
144            ScopedIp::V4(v4) => write!(f, "{}", v4.addr),
145            ScopedIp::V6(v6) => {
146                if v6.scope_id.index != 0 {
147                    #[cfg(windows)]
148                    {
149                        write!(f, "{}%{}", v6.addr, v6.scope_id.index)
150                    }
151                    #[cfg(not(windows))]
152                    {
153                        write!(f, "{}%{}", v6.addr, v6.scope_id.name)
154                    }
155                } else {
156                    write!(f, "{}", v6.addr)
157                }
158            }
159        }
160    }
161}
162
/// DNS resource record types, stored as `u16`. Use `as u16` to get the numeric value.
///
/// See [RFC 1035 section 3.2.2](https://datatracker.ietf.org/doc/html/rfc1035#section-3.2.2)
#[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
#[non_exhaustive]
#[repr(u16)]
pub enum RRType {
    /// DNS record type for IPv4 address
    A = 1,

    /// DNS record type for Canonical Name
    CNAME = 5,

    /// DNS record type for Pointer
    PTR = 12,

    /// DNS record type for Host Info
    HINFO = 13,

    /// DNS record type for Text (properties)
    TXT = 16,

    /// DNS record type for IPv6 address
    AAAA = 28,

    /// DNS record type for Service
    SRV = 33,

    /// DNS record type for Negative Responses
    NSEC = 47,

    /// DNS record type for any records (wildcard)
    ANY = 255,
}

impl RRType {
    /// Maps a numeric record type to its `RRType`.
    ///
    /// Returns `None` for any value that has no variant in this enum.
    pub const fn from_u16(value: u16) -> Option<RRType> {
        let ty = match value {
            1 => Self::A,
            5 => Self::CNAME,
            12 => Self::PTR,
            13 => Self::HINFO,
            16 => Self::TXT,
            28 => Self::AAAA,
            33 => Self::SRV,
            47 => Self::NSEC,
            255 => Self::ANY,
            _ => return None,
        };
        Some(ty)
    }
}
215
216impl fmt::Display for RRType {
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        match self {
219            RRType::A => write!(f, "TYPE_A"),
220            RRType::CNAME => write!(f, "TYPE_CNAME"),
221            RRType::PTR => write!(f, "TYPE_PTR"),
222            RRType::HINFO => write!(f, "TYPE_HINFO"),
223            RRType::TXT => write!(f, "TYPE_TXT"),
224            RRType::AAAA => write!(f, "TYPE_AAAA"),
225            RRType::SRV => write!(f, "TYPE_SRV"),
226            RRType::NSEC => write!(f, "TYPE_NSEC"),
227            RRType::ANY => write!(f, "TYPE_ANY"),
228        }
229    }
230}
231
/// The class value for the Internet.
pub const CLASS_IN: u16 = 1;
/// Mask that keeps the low 15 bits of the rrclass field, i.e. the class value
/// without the cache-flush bit.
pub const CLASS_MASK: u16 = 0x7FFF;

/// Cache-flush bit: the most significant bit of the rrclass field of the resource record.
pub const CLASS_CACHE_FLUSH: u16 = 0x8000;

/// Max size of UDP datagram payload.
///
/// It is calculated as: 9000 bytes - IP header 20 bytes - UDP header 8 bytes.
/// Reference: [RFC6762 section 17](https://datatracker.ietf.org/doc/html/rfc6762#section-17)
pub const MAX_MSG_ABSOLUTE: usize = 8972;

// NOTE(review): presumably the length in bytes of the fixed DNS message header
// (RFC 1035 section 4.1.1); its use is outside this part of the file.
const MSG_HEADER_LEN: usize = 12;
246
// Definitions for the DNS message header "flags" field.
//
// The "flags" field is 16 bits long, in this format
// (RFC 1035 section 4.1.1):
//
//   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
//
/// Mask for the QR (query/response) bit.
pub const FLAGS_QR_MASK: u16 = 0x8000;

/// Flag bit to indicate a query
pub const FLAGS_QR_QUERY: u16 = 0x0000;

/// Flag bit to indicate a response
pub const FLAGS_QR_RESPONSE: u16 = 0x8000;

/// Flag bit for Authoritative Answer
pub const FLAGS_AA: u16 = 0x0400;

/// Mask for the TC (Truncated) bit.
///
/// 2024-08-10: currently this flag is only supported on the querier side,
///             not on the responder side. I.e. the responder only handles
///             the first packet and ignores this bit. Since the additional
///             packets have 0 questions, processing them is a no-op.
///             In practice, this means the responder supports Known-Answer
///             only with a single packet, not multi-packet. The querier
///             supports both single packet and multi-packet.
pub const FLAGS_TC: u16 = 0x0200;
277
278/// A convenience type alias for DNS record trait objects.
279pub type DnsRecordBox = Box<dyn DnsRecordExt>;
280
281impl Clone for DnsRecordBox {
282    fn clone(&self) -> Self {
283        self.clone_box()
284    }
285}
286
// Number of bytes occupied by a `u16`.
const U16_SIZE: usize = 2;
288
289/// Returns `RRType` for a given IP address.
290#[inline]
291pub const fn ip_address_rr_type(address: &IpAddr) -> RRType {
292    match address {
293        IpAddr::V4(_) => RRType::A,
294        IpAddr::V6(_) => RRType::AAAA,
295    }
296}
297
298#[derive(Eq, PartialEq, Debug, Clone)]
299pub struct DnsEntry {
300    pub(crate) name: String, // always lower case.
301    pub(crate) ty: RRType,
302    class: u16,
303    cache_flush: bool,
304}
305
306impl DnsEntry {
307    const fn new(name: String, ty: RRType, class: u16) -> Self {
308        Self {
309            name,
310            ty,
311            class: class & CLASS_MASK,
312            cache_flush: (class & CLASS_CACHE_FLUSH) != 0,
313        }
314    }
315}
316
/// Common methods for all DNS entries: questions and resource records.
pub trait DnsEntryExt: fmt::Debug {
    /// Returns the name of this entry.
    fn entry_name(&self) -> &str;

    /// Returns the record type of this entry.
    fn entry_type(&self) -> RRType;
}
323
324/// A DNS question entry
325#[derive(Debug)]
326pub struct DnsQuestion {
327    pub(crate) entry: DnsEntry,
328}
329
330impl DnsEntryExt for DnsQuestion {
331    fn entry_name(&self) -> &str {
332        &self.entry.name
333    }
334
335    fn entry_type(&self) -> RRType {
336        self.entry.ty
337    }
338}
339
/// A DNS Resource Record - like a DNS entry, but has a TTL.
///
/// Besides the TTL it tracks when the record was created, when it expires and
/// when it should next be refreshed, all as UNIX times in milliseconds.
///
/// RFC: https://www.rfc-editor.org/rfc/rfc1035#section-3.2.1
///      https://www.rfc-editor.org/rfc/rfc1035#section-4.1.3
#[derive(Debug, Clone)]
pub struct DnsRecord {
    /// Name, type, class and cache-flush bit of this record.
    pub(crate) entry: DnsEntry,
    ttl: u32,     // in seconds, 0 means this record should not be cached
    created: u64, // UNIX time in millis
    expires: u64, // expires at this UNIX time in millis

    /// Support re-query an instance before its PTR record expires.
    /// See https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
    refresh: u64, // UNIX time in millis

    /// If conflict resolution decides to change the name, this is the new one.
    new_name: Option<String>,
}
357
358impl DnsRecord {
359    fn new(name: &str, ty: RRType, class: u16, ttl: u32) -> Self {
360        let created = current_time_millis();
361
362        // From RFC 6762 section 5.2:
363        // "... The querier should plan to issue a query at 80% of the record
364        // lifetime, and then if no answer is received, at 85%, 90%, and 95%."
365        let refresh = get_expiration_time(created, ttl, 80);
366
367        let expires = get_expiration_time(created, ttl, 100);
368
369        Self {
370            entry: DnsEntry::new(name.to_string(), ty, class),
371            ttl,
372            created,
373            expires,
374            refresh,
375            new_name: None,
376        }
377    }
378
379    pub const fn get_ttl(&self) -> u32 {
380        self.ttl
381    }
382
383    pub const fn get_expire_time(&self) -> u64 {
384        self.expires
385    }
386
387    pub const fn get_refresh_time(&self) -> u64 {
388        self.refresh
389    }
390
391    pub const fn is_expired(&self, now: u64) -> bool {
392        now >= self.expires
393    }
394
395    /// Returns whether record expires in 1 second.
396    ///
397    /// This is useful because mDNS sets TTL to 1 (not 0) for expiring records.
398    pub const fn expires_soon(&self, now: u64) -> bool {
399        now + 1000 >= self.expires
400    }
401
402    pub const fn refresh_due(&self, now: u64) -> bool {
403        now >= self.refresh
404    }
405
406    /// Returns whether `now` (in millis) has passed half of TTL.
407    pub fn halflife_passed(&self, now: u64) -> bool {
408        let halflife = get_expiration_time(self.created, self.ttl, 50);
409        now > halflife
410    }
411
412    pub fn is_unique(&self) -> bool {
413        self.entry.cache_flush
414    }
415
416    /// Updates the refresh time to be the same as the expire time so that
417    /// this record will not refresh again and will just expire.
418    pub fn refresh_no_more(&mut self) {
419        self.refresh = get_expiration_time(self.created, self.ttl, 100);
420    }
421
422    /// Returns if this record is due for refresh. If yes, `refresh` time is updated.
423    pub fn refresh_maybe(&mut self, now: u64) -> bool {
424        if self.is_expired(now) || !self.refresh_due(now) {
425            return false;
426        }
427
428        trace!(
429            "{} qtype {} is due to refresh",
430            &self.entry.name,
431            self.entry.ty
432        );
433
434        // From RFC 6762 section 5.2:
435        // "... The querier should plan to issue a query at 80% of the record
436        // lifetime, and then if no answer is received, at 85%, 90%, and 95%."
437        //
438        // If the answer is received in time, 'refresh' will be reset outside
439        // this function, back to 80% of the new TTL.
440        if self.refresh == get_expiration_time(self.created, self.ttl, 80) {
441            self.refresh = get_expiration_time(self.created, self.ttl, 85);
442        } else if self.refresh == get_expiration_time(self.created, self.ttl, 85) {
443            self.refresh = get_expiration_time(self.created, self.ttl, 90);
444        } else if self.refresh == get_expiration_time(self.created, self.ttl, 90) {
445            self.refresh = get_expiration_time(self.created, self.ttl, 95);
446        } else {
447            self.refresh_no_more();
448        }
449
450        true
451    }
452
453    /// Returns the remaining TTL in seconds
454    fn get_remaining_ttl(&self, now: u64) -> u32 {
455        let remaining_millis = get_expiration_time(self.created, self.ttl, 100) - now;
456        cmp::max(0, remaining_millis / 1000) as u32
457    }
458
459    /// Return the absolute time for this record being created
460    pub const fn get_created(&self) -> u64 {
461        self.created
462    }
463
464    /// Set the absolute expiration time in millis
465    fn set_expire(&mut self, expire_at: u64) {
466        self.expires = expire_at;
467    }
468
469    fn reset_ttl(&mut self, other: &Self) {
470        self.ttl = other.ttl;
471        self.created = other.created;
472        self.expires = get_expiration_time(self.created, self.ttl, 100);
473        self.refresh = if self.ttl > 1 {
474            get_expiration_time(self.created, self.ttl, 80)
475        } else {
476            // If TTL is 1, it means this record is expiring,
477            // then we set refresh to the same time as expires.
478            self.expires
479        };
480    }
481
482    /// Modify TTL to reflect the remaining life time from `now`.
483    pub fn update_ttl(&mut self, now: u64) {
484        if now > self.created {
485            let elapsed = now - self.created;
486            self.ttl -= (elapsed / 1000) as u32;
487        }
488    }
489
490    pub fn set_new_name(&mut self, new_name: String) {
491        if new_name == self.entry.name {
492            self.new_name = None;
493        } else {
494            self.new_name = Some(new_name);
495        }
496    }
497
498    pub fn get_new_name(&self) -> Option<&str> {
499        self.new_name.as_deref()
500    }
501
502    /// Return the new name if exists, otherwise the regular name in DnsEntry.
503    pub(crate) fn get_name(&self) -> &str {
504        self.new_name.as_deref().unwrap_or(&self.entry.name)
505    }
506
507    pub fn get_original_name(&self) -> &str {
508        &self.entry.name
509    }
510}
511
512impl PartialEq for DnsRecord {
513    fn eq(&self, other: &Self) -> bool {
514        self.entry == other.entry
515    }
516}
517
518/// Common methods for DNS resource records.
519pub trait DnsRecordExt: fmt::Debug {
520    fn get_record(&self) -> &DnsRecord;
521    fn get_record_mut(&mut self) -> &mut DnsRecord;
522    fn write(&self, packet: &mut DnsOutPacket);
523    fn any(&self) -> &dyn Any;
524
525    /// Returns whether `other` record is considered the same except TTL.
526    fn matches(&self, other: &dyn DnsRecordExt) -> bool;
527
528    /// Returns whether `other` record has the same rdata.
529    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool;
530
531    /// Returns the result based on a byte-level comparison of `rdata`.
532    /// If `other` is not valid, returns `Greater`.
533    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering;
534
535    /// Returns the result based on "lexicographically later" defined below.
536    fn compare(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
537        /*
538        RFC 6762: https://datatracker.ietf.org/doc/html/rfc6762#section-8.2
539
540        ... The determination of "lexicographically later" is performed by first
541        comparing the record class (excluding the cache-flush bit described
542        in Section 10.2), then the record type, then raw comparison of the
543        binary content of the rdata without regard for meaning or structure.
544        If the record classes differ, then the numerically greater class is
545        considered "lexicographically later".  Otherwise, if the record types
546        differ, then the numerically greater type is considered
547        "lexicographically later".  If the rrtype and rrclass both match,
548        then the rdata is compared. ...
549        */
550        match self.get_class().cmp(&other.get_class()) {
551            cmp::Ordering::Equal => match self.get_type().cmp(&other.get_type()) {
552                cmp::Ordering::Equal => self.compare_rdata(other),
553                not_equal => not_equal,
554            },
555            not_equal => not_equal,
556        }
557    }
558
559    /// Returns a human-readable string of rdata.
560    fn rdata_print(&self) -> String;
561
562    /// Returns the class only, excluding class_flush / unique bit.
563    fn get_class(&self) -> u16 {
564        self.get_record().entry.class
565    }
566
567    fn get_cache_flush(&self) -> bool {
568        self.get_record().entry.cache_flush
569    }
570
571    /// Return the new name if exists, otherwise the regular name in DnsEntry.
572    fn get_name(&self) -> &str {
573        self.get_record().get_name()
574    }
575
576    fn get_type(&self) -> RRType {
577        self.get_record().entry.ty
578    }
579
580    /// Resets TTL using `other` record.
581    /// `self.refresh` and `self.expires` are also reset.
582    fn reset_ttl(&mut self, other: &dyn DnsRecordExt) {
583        self.get_record_mut().reset_ttl(other.get_record());
584    }
585
586    fn get_created(&self) -> u64 {
587        self.get_record().get_created()
588    }
589
590    fn get_expire(&self) -> u64 {
591        self.get_record().get_expire_time()
592    }
593
594    fn set_expire(&mut self, expire_at: u64) {
595        self.get_record_mut().set_expire(expire_at);
596    }
597
598    /// Set expire as `expire_at` if it is sooner than the current `expire`.
599    fn set_expire_sooner(&mut self, expire_at: u64) {
600        if expire_at < self.get_expire() {
601            self.get_record_mut().set_expire(expire_at);
602        }
603    }
604
605    /// Returns true if the record expires in 1 second from `now`.
606    fn expires_soon(&self, now: u64) -> bool {
607        self.get_record().expires_soon(now)
608    }
609
610    /// Given `now`, if the record is due to refresh, this method updates the refresh time
611    /// and returns the new refresh time. Otherwise, returns None.
612    fn updated_refresh_time(&mut self, now: u64) -> Option<u64> {
613        if self.get_record_mut().refresh_maybe(now) {
614            Some(self.get_record().get_refresh_time())
615        } else {
616            None
617        }
618    }
619
620    /// Returns true if another record has matched content,
621    /// and if its TTL is at least half of this record's.
622    fn suppressed_by_answer(&self, other: &dyn DnsRecordExt) -> bool {
623        self.matches(other) && (other.get_record().ttl > self.get_record().ttl / 2)
624    }
625
626    /// Required by RFC 6762 Section 7.1: Known-Answer Suppression.
627    fn suppressed_by(&self, msg: &DnsIncoming) -> bool {
628        for answer in msg.answers.iter() {
629            if self.suppressed_by_answer(answer.as_ref()) {
630                return true;
631            }
632        }
633        false
634    }
635
636    fn clone_box(&self) -> DnsRecordBox;
637
638    fn boxed(self) -> DnsRecordBox;
639}
640
641/// Resource Record for IPv4 address or IPv6 address.
642#[derive(Debug, Clone)]
643pub(crate) struct DnsAddress {
644    pub(crate) record: DnsRecord,
645    address: IpAddr,
646    pub(crate) interface_id: InterfaceId,
647}
648
649impl DnsAddress {
650    pub fn new(
651        name: &str,
652        ty: RRType,
653        class: u16,
654        ttl: u32,
655        address: IpAddr,
656        interface_id: InterfaceId,
657    ) -> Self {
658        let record = DnsRecord::new(name, ty, class, ttl);
659        Self {
660            record,
661            address,
662            interface_id,
663        }
664    }
665
666    pub fn address(&self) -> ScopedIp {
667        match self.address {
668            IpAddr::V4(v4) => ScopedIp::V4(ScopedIpV4 { addr: v4 }),
669            IpAddr::V6(v6) => ScopedIp::V6(ScopedIpV6 {
670                addr: v6,
671                scope_id: self.interface_id.clone(),
672            }),
673        }
674    }
675}
676
677impl DnsRecordExt for DnsAddress {
678    fn get_record(&self) -> &DnsRecord {
679        &self.record
680    }
681
682    fn get_record_mut(&mut self) -> &mut DnsRecord {
683        &mut self.record
684    }
685
686    fn write(&self, packet: &mut DnsOutPacket) {
687        match self.address {
688            IpAddr::V4(addr) => packet.write_bytes(addr.octets().as_ref()),
689            IpAddr::V6(addr) => packet.write_bytes(addr.octets().as_ref()),
690        };
691    }
692
693    fn any(&self) -> &dyn Any {
694        self
695    }
696
697    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
698        if let Some(other_a) = other.any().downcast_ref::<Self>() {
699            return self.address == other_a.address
700                && self.record.entry == other_a.record.entry
701                && self.interface_id == other_a.interface_id;
702        }
703        false
704    }
705
706    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
707        if let Some(other_a) = other.any().downcast_ref::<Self>() {
708            return self.address == other_a.address;
709        }
710        false
711    }
712
713    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
714        if let Some(other_a) = other.any().downcast_ref::<Self>() {
715            self.address.cmp(&other_a.address)
716        } else {
717            cmp::Ordering::Greater
718        }
719    }
720
721    fn rdata_print(&self) -> String {
722        format!("{}", self.address)
723    }
724
725    fn clone_box(&self) -> DnsRecordBox {
726        Box::new(self.clone())
727    }
728
729    fn boxed(self) -> DnsRecordBox {
730        Box::new(self)
731    }
732}
733
734/// Resource Record for a DNS pointer
735#[derive(Debug, Clone)]
736pub struct DnsPointer {
737    record: DnsRecord,
738    alias: String, // the full name of Service Instance
739}
740
741impl DnsPointer {
742    pub fn new(name: &str, ty: RRType, class: u16, ttl: u32, alias: String) -> Self {
743        let record = DnsRecord::new(name, ty, class, ttl);
744        Self { record, alias }
745    }
746
747    pub fn alias(&self) -> &str {
748        &self.alias
749    }
750}
751
752impl DnsRecordExt for DnsPointer {
753    fn get_record(&self) -> &DnsRecord {
754        &self.record
755    }
756
757    fn get_record_mut(&mut self) -> &mut DnsRecord {
758        &mut self.record
759    }
760
761    fn write(&self, packet: &mut DnsOutPacket) {
762        packet.write_name(&self.alias);
763    }
764
765    fn any(&self) -> &dyn Any {
766        self
767    }
768
769    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
770        if let Some(other_ptr) = other.any().downcast_ref::<Self>() {
771            return self.alias == other_ptr.alias && self.record.entry == other_ptr.record.entry;
772        }
773        false
774    }
775
776    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
777        if let Some(other_ptr) = other.any().downcast_ref::<Self>() {
778            return self.alias == other_ptr.alias;
779        }
780        false
781    }
782
783    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
784        if let Some(other_ptr) = other.any().downcast_ref::<Self>() {
785            self.alias.cmp(&other_ptr.alias)
786        } else {
787            cmp::Ordering::Greater
788        }
789    }
790
791    fn rdata_print(&self) -> String {
792        self.alias.clone()
793    }
794
795    fn clone_box(&self) -> DnsRecordBox {
796        Box::new(self.clone())
797    }
798
799    fn boxed(self) -> DnsRecordBox {
800        Box::new(self)
801    }
802}
803
804/// Resource Record for a DNS service.
805#[derive(Debug, Clone)]
806pub struct DnsSrv {
807    pub(crate) record: DnsRecord,
808    pub(crate) priority: u16, // lower number means higher priority. Should be 0 in common cases.
809    pub(crate) weight: u16,   // Should be 0 in common cases
810    host: String,
811    port: u16,
812}
813
814impl DnsSrv {
815    pub fn new(
816        name: &str,
817        class: u16,
818        ttl: u32,
819        priority: u16,
820        weight: u16,
821        port: u16,
822        host: String,
823    ) -> Self {
824        let record = DnsRecord::new(name, RRType::SRV, class, ttl);
825        Self {
826            record,
827            priority,
828            weight,
829            host,
830            port,
831        }
832    }
833
834    pub fn host(&self) -> &str {
835        &self.host
836    }
837
838    pub fn port(&self) -> u16 {
839        self.port
840    }
841
842    pub fn set_host(&mut self, host: String) {
843        self.host = host;
844    }
845}
846
847impl DnsRecordExt for DnsSrv {
848    fn get_record(&self) -> &DnsRecord {
849        &self.record
850    }
851
852    fn get_record_mut(&mut self) -> &mut DnsRecord {
853        &mut self.record
854    }
855
856    fn write(&self, packet: &mut DnsOutPacket) {
857        packet.write_short(self.priority);
858        packet.write_short(self.weight);
859        packet.write_short(self.port);
860        packet.write_name(&self.host);
861    }
862
863    fn any(&self) -> &dyn Any {
864        self
865    }
866
867    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
868        if let Some(other_svc) = other.any().downcast_ref::<Self>() {
869            return self.host == other_svc.host
870                && self.port == other_svc.port
871                && self.weight == other_svc.weight
872                && self.priority == other_svc.priority
873                && self.record.entry == other_svc.record.entry;
874        }
875        false
876    }
877
878    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
879        if let Some(other_srv) = other.any().downcast_ref::<Self>() {
880            return self.host == other_srv.host
881                && self.port == other_srv.port
882                && self.weight == other_srv.weight
883                && self.priority == other_srv.priority;
884        }
885        false
886    }
887
888    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
889        let Some(other_srv) = other.any().downcast_ref::<Self>() else {
890            return cmp::Ordering::Greater;
891        };
892
893        // 1. compare `priority`
894        match self
895            .priority
896            .to_be_bytes()
897            .cmp(&other_srv.priority.to_be_bytes())
898        {
899            cmp::Ordering::Equal => {
900                // 2. compare `weight`
901                match self
902                    .weight
903                    .to_be_bytes()
904                    .cmp(&other_srv.weight.to_be_bytes())
905                {
906                    cmp::Ordering::Equal => {
907                        // 3. compare `port`.
908                        match self.port.to_be_bytes().cmp(&other_srv.port.to_be_bytes()) {
909                            cmp::Ordering::Equal => self.host.cmp(&other_srv.host),
910                            not_equal => not_equal,
911                        }
912                    }
913                    not_equal => not_equal,
914                }
915            }
916            not_equal => not_equal,
917        }
918    }
919
920    fn rdata_print(&self) -> String {
921        format!(
922            "priority: {}, weight: {}, port: {}, host: {}",
923            self.priority, self.weight, self.port, self.host
924        )
925    }
926
927    fn clone_box(&self) -> DnsRecordBox {
928        Box::new(self.clone())
929    }
930
931    fn boxed(self) -> DnsRecordBox {
932        Box::new(self)
933    }
934}
935
936/// Resource Record for a DNS TXT record.
937///
938/// From [RFC 6763 section 6]:
939///
940/// The format of each constituent string within the DNS TXT record is a
941/// single length byte, followed by 0-255 bytes of text data.
942///
943/// DNS-SD uses DNS TXT records to store arbitrary key/value pairs
944///    conveying additional information about the named service.  Each
945///    key/value pair is encoded as its own constituent string within the
946///    DNS TXT record, in the form "key=value" (without the quotation
947///    marks).  Everything up to the first '=' character is the key (Section
948///    6.4).  Everything after the first '=' character to the end of the
949///    string (including subsequent '=' characters, if any) is the value
950#[derive(Clone)]
951pub struct DnsTxt {
952    pub(crate) record: DnsRecord,
953    text: Vec<u8>,
954}
955
956impl DnsTxt {
957    pub fn new(name: &str, class: u16, ttl: u32, text: Vec<u8>) -> Self {
958        let record = DnsRecord::new(name, RRType::TXT, class, ttl);
959        Self { record, text }
960    }
961
962    pub fn text(&self) -> &[u8] {
963        &self.text
964    }
965}
966
967impl DnsRecordExt for DnsTxt {
968    fn get_record(&self) -> &DnsRecord {
969        &self.record
970    }
971
972    fn get_record_mut(&mut self) -> &mut DnsRecord {
973        &mut self.record
974    }
975
976    fn write(&self, packet: &mut DnsOutPacket) {
977        packet.write_bytes(&self.text);
978    }
979
980    fn any(&self) -> &dyn Any {
981        self
982    }
983
984    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
985        if let Some(other_txt) = other.any().downcast_ref::<Self>() {
986            return self.text == other_txt.text && self.record.entry == other_txt.record.entry;
987        }
988        false
989    }
990
991    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
992        if let Some(other_txt) = other.any().downcast_ref::<Self>() {
993            return self.text == other_txt.text;
994        }
995        false
996    }
997
998    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
999        if let Some(other_txt) = other.any().downcast_ref::<Self>() {
1000            self.text.cmp(&other_txt.text)
1001        } else {
1002            cmp::Ordering::Greater
1003        }
1004    }
1005
1006    fn rdata_print(&self) -> String {
1007        format!("{:?}", decode_txt(&self.text))
1008    }
1009
1010    fn clone_box(&self) -> DnsRecordBox {
1011        Box::new(self.clone())
1012    }
1013
1014    fn boxed(self) -> DnsRecordBox {
1015        Box::new(self)
1016    }
1017}
1018
1019impl fmt::Debug for DnsTxt {
1020    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1021        let properties = decode_txt(&self.text);
1022        write!(
1023            f,
1024            "DnsTxt {{ record: {:?}, text: {:?} }}",
1025            self.record, properties
1026        )
1027    }
1028}
1029
1030// Convert from DNS TXT record content to key/value pairs
1031fn decode_txt(txt: &[u8]) -> Vec<TxtProperty> {
1032    let mut properties = Vec::new();
1033    let mut offset = 0;
1034    while offset < txt.len() {
1035        let length = txt[offset] as usize;
1036        if length == 0 {
1037            break; // reached the end
1038        }
1039        offset += 1; // move over the length byte
1040
1041        let offset_end = offset + length;
1042        if offset_end > txt.len() {
1043            trace!("ERROR: DNS TXT: size given for property is out of range. (offset={}, length={}, offset_end={}, record length={})", offset, length, offset_end, txt.len());
1044            break; // Skipping the rest of the record content, as the size for this property would already be out of range.
1045        }
1046        let kv_bytes = &txt[offset..offset_end];
1047
1048        // split key and val using the first `=`
1049        let (k, v) = kv_bytes.iter().position(|&x| x == b'=').map_or_else(
1050            || (kv_bytes.to_vec(), None),
1051            |idx| (kv_bytes[..idx].to_vec(), Some(kv_bytes[idx + 1..].to_vec())),
1052        );
1053
1054        // Make sure the key can be stored in UTF-8.
1055        match String::from_utf8(k) {
1056            Ok(k_string) => {
1057                properties.push(TxtProperty {
1058                    key: k_string,
1059                    val: v,
1060                });
1061            }
1062            Err(e) => trace!("ERROR: convert to String from key: {}", e),
1063        }
1064
1065        offset += length;
1066    }
1067
1068    properties
1069}
1070
/// A single property of a TXT record.
#[derive(Clone, PartialEq, Eq)]
pub struct TxtProperty {
    /// Property name, with its original letter case preserved.
    key: String,

    /// Property value. RFC 6763 says values are bytes, not necessarily
    /// UTF-8. `None` means the property has no value at all, i.e. the key
    /// is a boolean key.
    val: Option<Vec<u8>>,
}

impl TxtProperty {
    /// Returns the value of the property as `str`.
    ///
    /// Returns an empty string if the property has no value or if the
    /// value is not valid UTF-8.
    pub fn val_str(&self) -> &str {
        match &self.val {
            Some(bytes) => std::str::from_utf8(bytes).unwrap_or_default(),
            None => "",
        }
    }
}
1091
1092/// Supports constructing from a tuple.
1093impl<K, V> From<&(K, V)> for TxtProperty
1094where
1095    K: ToString,
1096    V: ToString,
1097{
1098    fn from(prop: &(K, V)) -> Self {
1099        Self {
1100            key: prop.0.to_string(),
1101            val: Some(prop.1.to_string().into_bytes()),
1102        }
1103    }
1104}
1105
1106impl<K, V> From<(K, V)> for TxtProperty
1107where
1108    K: ToString,
1109    V: AsRef<[u8]>,
1110{
1111    fn from(prop: (K, V)) -> Self {
1112        Self {
1113            key: prop.0.to_string(),
1114            val: Some(prop.1.as_ref().into()),
1115        }
1116    }
1117}
1118
1119/// Support a property that has no value.
1120impl From<&str> for TxtProperty {
1121    fn from(key: &str) -> Self {
1122        Self {
1123            key: key.to_string(),
1124            val: None,
1125        }
1126    }
1127}
1128
1129impl fmt::Display for TxtProperty {
1130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1131        write!(f, "{}={}", self.key, self.val_str())
1132    }
1133}
1134
1135/// Mimic the default debug output for a struct, with a twist:
1136/// - If self.var is UTF-8, will output it as a string in double quotes.
1137/// - If self.var is not UTF-8, will output its bytes as in hex.
1138impl fmt::Debug for TxtProperty {
1139    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1140        let val_string = self.val.as_ref().map_or_else(
1141            || "None".to_string(),
1142            |v| {
1143                std::str::from_utf8(&v[..]).map_or_else(
1144                    |_| format!("Some({})", u8_slice_to_hex(&v[..])),
1145                    |s| format!("Some(\"{s}\")"),
1146                )
1147            },
1148        );
1149
1150        write!(
1151            f,
1152            "TxtProperty {{key: \"{}\", val: {}}}",
1153            &self.key, &val_string,
1154        )
1155    }
1156}
1157
/// Lowercase hex digits, indexed by nibble value.
const HEX_TABLE: [char; 16] = [
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
];

/// Creates a lowercase hex string from `slice`, prefixed with "0x".
///
/// For example, [1u8, 2u8] -> "0x0102". An empty slice gives "0x".
fn u8_slice_to_hex(slice: &[u8]) -> String {
    let mut out = String::with_capacity(2 + 2 * slice.len());
    out.push_str("0x");
    // Every byte contributes two digits: high nibble first, then low nibble.
    out.extend(slice.iter().flat_map(|&byte| {
        [
            HEX_TABLE[usize::from(byte >> 4)],
            HEX_TABLE[usize::from(byte & 0x0F)],
        ]
    }));
    out
}
1174
1175/// A DNS host information record
1176#[derive(Debug, Clone)]
1177struct DnsHostInfo {
1178    record: DnsRecord,
1179    cpu: String,
1180    os: String,
1181}
1182
1183impl DnsHostInfo {
1184    fn new(name: &str, ty: RRType, class: u16, ttl: u32, cpu: String, os: String) -> Self {
1185        let record = DnsRecord::new(name, ty, class, ttl);
1186        Self { record, cpu, os }
1187    }
1188}
1189
1190impl DnsRecordExt for DnsHostInfo {
1191    fn get_record(&self) -> &DnsRecord {
1192        &self.record
1193    }
1194
1195    fn get_record_mut(&mut self) -> &mut DnsRecord {
1196        &mut self.record
1197    }
1198
1199    fn write(&self, packet: &mut DnsOutPacket) {
1200        println!("writing HInfo: cpu {} os {}", &self.cpu, &self.os);
1201        packet.write_bytes(self.cpu.as_bytes());
1202        packet.write_bytes(self.os.as_bytes());
1203    }
1204
1205    fn any(&self) -> &dyn Any {
1206        self
1207    }
1208
1209    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
1210        if let Some(other_hinfo) = other.any().downcast_ref::<Self>() {
1211            return self.cpu == other_hinfo.cpu
1212                && self.os == other_hinfo.os
1213                && self.record.entry == other_hinfo.record.entry;
1214        }
1215        false
1216    }
1217
1218    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
1219        if let Some(other_hinfo) = other.any().downcast_ref::<Self>() {
1220            return self.cpu == other_hinfo.cpu && self.os == other_hinfo.os;
1221        }
1222        false
1223    }
1224
1225    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
1226        if let Some(other_hinfo) = other.any().downcast_ref::<Self>() {
1227            match self.cpu.cmp(&other_hinfo.cpu) {
1228                cmp::Ordering::Equal => self.os.cmp(&other_hinfo.os),
1229                ordering => ordering,
1230            }
1231        } else {
1232            cmp::Ordering::Greater
1233        }
1234    }
1235
1236    fn rdata_print(&self) -> String {
1237        format!("cpu: {}, os: {}", self.cpu, self.os)
1238    }
1239
1240    fn clone_box(&self) -> DnsRecordBox {
1241        Box::new(self.clone())
1242    }
1243
1244    fn boxed(self) -> DnsRecordBox {
1245        Box::new(self)
1246    }
1247}
1248
1249/// Resource Record for negative responses
1250///
1251/// [RFC4034 section 4.1](https://datatracker.ietf.org/doc/html/rfc4034#section-4.1)
1252/// and
1253/// [RFC6762 section 6.1](https://datatracker.ietf.org/doc/html/rfc6762#section-6.1)
1254#[derive(Debug, Clone)]
1255pub struct DnsNSec {
1256    record: DnsRecord,
1257    next_domain: String,
1258    type_bitmap: Vec<u8>,
1259}
1260
1261impl DnsNSec {
1262    pub fn new(
1263        name: &str,
1264        class: u16,
1265        ttl: u32,
1266        next_domain: String,
1267        type_bitmap: Vec<u8>,
1268    ) -> Self {
1269        let record = DnsRecord::new(name, RRType::NSEC, class, ttl);
1270        Self {
1271            record,
1272            next_domain,
1273            type_bitmap,
1274        }
1275    }
1276
1277    /// Returns the types marked by `type_bitmap`
1278    pub fn _types(&self) -> Vec<u16> {
1279        // From RFC 4034: 4.1.2 The Type Bit Maps Field
1280        // https://datatracker.ietf.org/doc/html/rfc4034#section-4.1.2
1281        //
1282        // Each bitmap encodes the low-order 8 bits of RR types within the
1283        // window block, in network bit order.  The first bit is bit 0.  For
1284        // window block 0, bit 1 corresponds to RR type 1 (A), bit 2 corresponds
1285        // to RR type 2 (NS), and so forth.
1286
1287        let mut bit_num = 0;
1288        let mut results = Vec::new();
1289
1290        for byte in self.type_bitmap.iter() {
1291            let mut bit_mask: u8 = 0x80; // for bit 0 in network bit order
1292
1293            // check every bit in this byte, one by one.
1294            for _ in 0..8 {
1295                if (byte & bit_mask) != 0 {
1296                    results.push(bit_num);
1297                }
1298                bit_num += 1;
1299                bit_mask >>= 1; // mask for the next bit
1300            }
1301        }
1302        results
1303    }
1304}
1305
1306impl DnsRecordExt for DnsNSec {
1307    fn get_record(&self) -> &DnsRecord {
1308        &self.record
1309    }
1310
1311    fn get_record_mut(&mut self) -> &mut DnsRecord {
1312        &mut self.record
1313    }
1314
1315    fn write(&self, packet: &mut DnsOutPacket) {
1316        packet.write_bytes(self.next_domain.as_bytes());
1317        packet.write_bytes(&self.type_bitmap);
1318    }
1319
1320    fn any(&self) -> &dyn Any {
1321        self
1322    }
1323
1324    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
1325        if let Some(other_record) = other.any().downcast_ref::<Self>() {
1326            return self.next_domain == other_record.next_domain
1327                && self.type_bitmap == other_record.type_bitmap
1328                && self.record.entry == other_record.record.entry;
1329        }
1330        false
1331    }
1332
1333    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
1334        if let Some(other_record) = other.any().downcast_ref::<Self>() {
1335            return self.next_domain == other_record.next_domain
1336                && self.type_bitmap == other_record.type_bitmap;
1337        }
1338        false
1339    }
1340
1341    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
1342        if let Some(other_nsec) = other.any().downcast_ref::<Self>() {
1343            match self.next_domain.cmp(&other_nsec.next_domain) {
1344                cmp::Ordering::Equal => self.type_bitmap.cmp(&other_nsec.type_bitmap),
1345                ordering => ordering,
1346            }
1347        } else {
1348            cmp::Ordering::Greater
1349        }
1350    }
1351
1352    fn rdata_print(&self) -> String {
1353        format!(
1354            "next_domain: {}, type_bitmap len: {}",
1355            self.next_domain,
1356            self.type_bitmap.len()
1357        )
1358    }
1359
1360    fn clone_box(&self) -> DnsRecordBox {
1361        Box::new(self.clone())
1362    }
1363
1364    fn boxed(self) -> DnsRecordBox {
1365        Box::new(self)
1366    }
1367}
1368
/// Internal life-cycle state of a [`DnsOutPacket`]; not defined by DNS.
#[derive(PartialEq)]
enum PacketState {
    /// The state of a newly created packet.
    Init = 0,
    /// Set once the header has been written, or once a record did not fit
    /// into the packet.
    Finished = 1,
}
1374
/// A single packet for outgoing DNS message.
pub struct DnsOutPacket {
    /// All bytes in `data` concatenated is the actual packet on the wire.
    /// Each write appends one chunk.
    data: Vec<Vec<u8>>,

    /// Current logical size of the packet. It starts with the size of the mandatory header.
    size: usize,

    /// An internal state, not defined by DNS.
    state: PacketState,

    /// Names already written into this packet, used for name compression.
    /// k: name (or a trailing part of a name), v: the value of `size` at the
    /// moment that name was written.
    names: HashMap<String, u16>,
}
1389
1390impl DnsOutPacket {
    /// Creates an empty packet whose logical size already accounts for the
    /// header, although no header bytes are stored yet.
    fn new() -> Self {
        Self {
            data: Vec::new(),
            size: MSG_HEADER_LEN, // Header is mandatory.
            state: PacketState::Init,
            names: HashMap::new(),
        }
    }

    /// Returns the current logical size of the packet in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Returns all chunks written so far, concatenated into one buffer.
    pub fn to_bytes(&self) -> Vec<u8> {
        self.data.concat()
    }
1407
1408    fn write_question(&mut self, question: &DnsQuestion) {
1409        self.write_name(&question.entry.name);
1410        self.write_short(question.entry.ty as u16);
1411        self.write_short(question.entry.class);
1412    }
1413
1414    /// Writes a record (answer, authoritative answer, additional)
1415    /// Returns false if the packet exceeds the max size with this record, nothing is written to the packet.
1416    /// otherwise returns true.
1417    fn write_record(&mut self, record_ext: &dyn DnsRecordExt, now: u64) -> bool {
1418        let start_data_length = self.data.len();
1419        let start_size = self.size;
1420
1421        let record = record_ext.get_record();
1422        self.write_name(record.get_name());
1423        self.write_short(record.entry.ty as u16);
1424        if record.entry.cache_flush {
1425            // check "multicast"
1426            self.write_short(record.entry.class | CLASS_CACHE_FLUSH);
1427        } else {
1428            self.write_short(record.entry.class);
1429        }
1430
1431        if now == 0 {
1432            self.write_u32(record.ttl);
1433        } else {
1434            self.write_u32(record.get_remaining_ttl(now));
1435        }
1436
1437        let index = self.data.len();
1438
1439        // Adjust size for the short we will write before this record
1440        self.size += 2;
1441        record_ext.write(self);
1442        self.size -= 2;
1443
1444        let length: usize = self.data[index..].iter().map(|x| x.len()).sum();
1445        self.insert_short(index, length as u16);
1446
1447        if self.size > MAX_MSG_ABSOLUTE {
1448            self.data.truncate(start_data_length);
1449            self.size = start_size;
1450            self.state = PacketState::Finished;
1451            return false;
1452        }
1453
1454        true
1455    }
1456
1457    pub(crate) fn insert_short(&mut self, index: usize, value: u16) {
1458        self.data.insert(index, value.to_be_bytes().to_vec());
1459        self.size += 2;
1460    }
1461
1462    // Write name to packet
1463    //
1464    // [RFC1035]
1465    // 4.1.4. Message compression
1466    //
1467    // In order to reduce the size of messages, the domain system utilizes a
1468    // compression scheme which eliminates the repetition of domain names in a
1469    // message.  In this scheme, an entire domain name or a list of labels at
1470    // the end of a domain name is replaced with a pointer to a prior occurrence
1471    // of the same name.
1472    // The pointer takes the form of a two octet sequence:
1473    //     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1474    //     | 1  1|                OFFSET                   |
1475    //     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1476    // The first two bits are ones.  This allows a pointer to be distinguished
1477    // from a label, since the label must begin with two zero bits because
1478    // labels are restricted to 63 octets or less.  (The 10 and 01 combinations
1479    // are reserved for future use.)  The OFFSET field specifies an offset from
1480    // the start of the message (i.e., the first octet of the ID field in the
1481    // domain header).  A zero offset specifies the first byte of the ID field,
1482    // etc.
1483    fn write_name(&mut self, name: &str) {
1484        // ignore the ending "." if exists
1485        let end = name.len();
1486        let end = if end > 0 && &name[end - 1..] == "." {
1487            end - 1
1488        } else {
1489            end
1490        };
1491
1492        let mut here = 0;
1493        while here < end {
1494            const POINTER_MASK: u16 = 0xC000;
1495            let remaining = &name[here..end];
1496
1497            // Check if 'remaining' already appeared in this message
1498            match self.names.get(remaining) {
1499                Some(offset) => {
1500                    let pointer = *offset | POINTER_MASK;
1501                    self.write_short(pointer);
1502                    // println!(
1503                    //     "written pointer {} ({}) for {}",
1504                    //     pointer,
1505                    //     pointer ^ POINTER_MASK,
1506                    //     remaining
1507                    // );
1508                    break;
1509                }
1510                None => {
1511                    // Remember the remaining parts so we can point to it
1512                    self.names.insert(remaining.to_string(), self.size as u16);
1513                    // println!("set offset {} for {}", self.size, remaining);
1514
1515                    // Find the current label to write into the packet
1516                    let stop = remaining.find('.').map_or(end, |i| here + i);
1517                    let label = &name[here..stop];
1518                    self.write_utf8(label);
1519
1520                    here = stop + 1; // move past the current label
1521                }
1522            }
1523
1524            if here >= end {
1525                self.write_byte(0); // name ends with 0 if not using a pointer
1526            }
1527        }
1528    }
1529
    /// Writes `utf` as one length-prefixed string: a single byte holding the
    /// length in bytes, followed by the bytes of `utf`.
    ///
    /// # Panics
    ///
    /// Panics if `utf` is 64 bytes or longer.
    fn write_utf8(&mut self, utf: &str) {
        assert!(utf.len() < 64);
        self.write_byte(utf.len() as u8);
        self.write_bytes(utf.as_bytes());
    }
1535
1536    fn write_bytes(&mut self, s: &[u8]) {
1537        self.data.push(s.to_vec());
1538        self.size += s.len();
1539    }
1540
1541    fn write_u32(&mut self, int: u32) {
1542        self.data.push(int.to_be_bytes().to_vec());
1543        self.size += 4;
1544    }
1545
1546    fn write_short(&mut self, short: u16) {
1547        self.data.push(short.to_be_bytes().to_vec());
1548        self.size += 2;
1549    }
1550
1551    fn write_byte(&mut self, byte: u8) {
1552        self.data.push(vec![byte]);
1553        self.size += 1;
1554    }
1555
1556    /// Writes the header fields and finish the packet.
1557    /// This function should be only called when finishing a packet.
1558    ///
1559    /// The header format is based on RFC 1035 section 4.1.1:
1560    /// https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1
1561    //
1562    //                                  1  1  1  1  1  1
1563    //    0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
1564    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1565    //    |                      ID                       |
1566    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1567    //    |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
1568    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1569    //    |                    QDCOUNT                    |
1570    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1571    //    |                    ANCOUNT                    |
1572    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1573    //    |                    NSCOUNT                    |
1574    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1575    //    |                    ARCOUNT                    |
1576    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1577    //
1578    fn write_header(
1579        &mut self,
1580        id: u16,
1581        flags: u16,
1582        q_count: u16,
1583        a_count: u16,
1584        auth_count: u16,
1585        addi_count: u16,
1586    ) {
1587        self.insert_short(0, addi_count);
1588        self.insert_short(0, auth_count);
1589        self.insert_short(0, a_count);
1590        self.insert_short(0, q_count);
1591        self.insert_short(0, flags);
1592        self.insert_short(0, id);
1593
1594        // Adjust the size as it was already initialized to include the header.
1595        self.size -= MSG_HEADER_LEN;
1596
1597        self.state = PacketState::Finished;
1598    }
1599}
1600
/// Representation of one outgoing DNS message that could be sent in one or more packet(s).
pub struct DnsOutgoing {
    /// Header flags of the message.
    flags: u16,
    /// Message id. Only written to the wire when `multicast` is false.
    id: u16,
    /// When true, packets are encoded with id 0 instead of `id`.
    multicast: bool,
    questions: Vec<DnsQuestion>,
    /// Each answer is stored with the time given when it was added (0 if
    /// none); that time is used when the record is encoded.
    answers: Vec<(DnsRecordBox, u64)>,
    authorities: Vec<DnsRecordBox>,
    additionals: Vec<DnsRecordBox>,
    /// Number of answers left out because the incoming message suppressed them.
    known_answer_count: i64, // for internal maintenance only
}
1612
1613impl DnsOutgoing {
    /// Creates an empty outgoing message with the given header `flags`.
    ///
    /// The message starts with id 0, marked as multicast, and without any
    /// questions or records.
    pub fn new(flags: u16) -> Self {
        Self {
            flags,
            id: 0,
            multicast: true,
            questions: Vec::new(),
            answers: Vec::new(),
            authorities: Vec::new(),
            additionals: Vec::new(),
            known_answer_count: 0,
        }
    }
1626
    /// Returns the questions added so far.
    pub fn questions(&self) -> &[DnsQuestion] {
        &self.questions
    }

    /// For testing purposes only.
    ///
    /// Returns the answers added so far, each paired with its time value.
    pub(crate) fn _answers(&self) -> &[(DnsRecordBox, u64)] {
        &self.answers
    }

    /// Returns the number of answers added so far.
    pub fn answers_count(&self) -> usize {
        self.answers.len()
    }

    /// Returns the authority records added so far.
    pub fn authorities(&self) -> &[DnsRecordBox] {
        &self.authorities
    }

    /// Returns the additional records added so far.
    pub fn additionals(&self) -> &[DnsRecordBox] {
        &self.additionals
    }

    /// Returns how many answers were left out because the incoming message
    /// suppressed them.
    pub fn known_answer_count(&self) -> i64 {
        self.known_answer_count
    }

    /// Sets the message id.
    pub fn set_id(&mut self, id: u16) {
        self.id = id;
    }

    /// Returns true if the QR bits of the flags mark this message as a query.
    pub const fn is_query(&self) -> bool {
        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_QUERY
    }

    /// Returns true if the QR bits of the flags mark this message as a response.
    const fn is_response(&self) -> bool {
        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_RESPONSE
    }
1663
1664    // Adds an additional answer
1665
1666    // From: RFC 6763, DNS-Based Service Discovery, February 2013
1667
1668    // 12.  DNS Additional Record Generation
1669
1670    //    DNS has an efficiency feature whereby a DNS server may place
1671    //    additional records in the additional section of the DNS message.
1672    //    These additional records are records that the client did not
1673    //    explicitly request, but the server has reasonable grounds to expect
1674    //    that the client might request them shortly, so including them can
1675    //    save the client from having to issue additional queries.
1676
1677    //    This section recommends which additional records SHOULD be generated
1678    //    to improve network efficiency, for both Unicast and Multicast DNS-SD
1679    //    responses.
1680
1681    // 12.1.  PTR Records
1682
1683    //    When including a DNS-SD Service Instance Enumeration or Selective
1684    //    Instance Enumeration (subtype) PTR record in a response packet, the
1685    //    server/responder SHOULD include the following additional records:
1686
1687    //    o  The SRV record(s) named in the PTR rdata.
1688    //    o  The TXT record(s) named in the PTR rdata.
1689    //    o  All address records (type "A" and "AAAA") named in the SRV rdata.
1690
1691    // 12.2.  SRV Records
1692
1693    //    When including an SRV record in a response packet, the
1694    //    server/responder SHOULD include the following additional records:
1695
1696    //    o  All address records (type "A" and "AAAA") named in the SRV rdata.
    /// Appends `answer` to the additional section of this message.
    pub fn add_additional_answer(&mut self, answer: impl DnsRecordExt + 'static) {
        trace!("add_additional_answer: {:?}", &answer);
        self.additionals.push(answer.boxed());
    }

    /// A workaround as Rust doesn't allow us to pass DnsRecordBox in as `impl DnsRecordExt`
    ///
    /// Appends `answer_box` to the answer section, with a time value of 0.
    pub fn add_answer_box(&mut self, answer_box: DnsRecordBox) {
        self.answers.push((answer_box, 0));
    }

    /// Appends `record` to the authority section of this message.
    pub fn add_authority(&mut self, record: DnsRecordBox) {
        self.authorities.push(record);
    }
1710
1711    /// Returns true if `answer` is added to the outgoing msg.
1712    /// Returns false if `answer` was not added as it expired or suppressed by the incoming `msg`.
1713    pub fn add_answer(
1714        &mut self,
1715        msg: &DnsIncoming,
1716        answer: impl DnsRecordExt + Send + 'static,
1717    ) -> bool {
1718        trace!("Check for add_answer");
1719        if answer.suppressed_by(msg) {
1720            trace!("my answer is suppressed by incoming msg");
1721            self.known_answer_count += 1;
1722            return false;
1723        }
1724
1725        self.add_answer_at_time(answer, 0)
1726    }
1727
1728    /// Returns true if `answer` is added to the outgoing msg.
1729    /// Returns false if the answer is expired `now` hence not added.
1730    /// If `now` is 0, do not check if the answer expires.
1731    pub fn add_answer_at_time(
1732        &mut self,
1733        answer: impl DnsRecordExt + Send + 'static,
1734        now: u64,
1735    ) -> bool {
1736        if now == 0 || !answer.get_record().is_expired(now) {
1737            trace!("add_answer push: {:?}", &answer);
1738            self.answers.push((answer.boxed(), now));
1739            return true;
1740        }
1741        false
1742    }
1743
1744    pub fn add_question(&mut self, name: &str, qtype: RRType) {
1745        let q = DnsQuestion {
1746            entry: DnsEntry::new(name.to_string(), qtype, CLASS_IN),
1747        };
1748        self.questions.push(q);
1749    }
1750
1751    /// Returns a list of actual DNS packet data to be sent on the wire.
1752    pub fn to_data_on_wire(&self) -> Vec<Vec<u8>> {
1753        let packet_list = self.to_packets();
1754        packet_list.iter().map(|p| p.data.concat()).collect()
1755    }
1756
1757    /// Encode self into one or more packets.
1758    pub fn to_packets(&self) -> Vec<DnsOutPacket> {
1759        let mut packet_list = Vec::new();
1760        let mut packet = DnsOutPacket::new();
1761
1762        let mut question_count = self.questions.len() as u16;
1763        let mut answer_count = 0;
1764        let mut auth_count = 0;
1765        let mut addi_count = 0;
1766        let id = if self.multicast { 0 } else { self.id };
1767
1768        for question in self.questions.iter() {
1769            packet.write_question(question);
1770        }
1771
1772        for (answer, time) in self.answers.iter() {
1773            if packet.write_record(answer.as_ref(), *time) {
1774                answer_count += 1;
1775            }
1776        }
1777
1778        for auth in self.authorities.iter() {
1779            auth_count += u16::from(packet.write_record(auth.as_ref(), 0));
1780        }
1781
1782        for addi in self.additionals.iter() {
1783            if packet.write_record(addi.as_ref(), 0) {
1784                addi_count += 1;
1785                continue;
1786            }
1787
1788            // No more processing for response packets.
1789            if self.is_response() {
1790                break;
1791            }
1792
1793            // For query, the current packet exceeds its max size due to known answers,
1794            // need to truncate.
1795
1796            // finish the current packet first.
1797            packet.write_header(
1798                id,
1799                self.flags | FLAGS_TC,
1800                question_count,
1801                answer_count,
1802                auth_count,
1803                addi_count,
1804            );
1805
1806            packet_list.push(packet);
1807
1808            // create a new packet and reset counts.
1809            packet = DnsOutPacket::new();
1810            packet.write_record(addi.as_ref(), 0);
1811
1812            question_count = 0;
1813            answer_count = 0;
1814            auth_count = 0;
1815            addi_count = 1;
1816        }
1817
1818        packet.write_header(
1819            id,
1820            self.flags,
1821            question_count,
1822            answer_count,
1823            auth_count,
1824            addi_count,
1825        );
1826
1827        packet_list.push(packet);
1828        packet_list
1829    }
1830}
1831
/// An incoming DNS message. It could be a query or a response.
#[derive(Debug)]
pub struct DnsIncoming {
    /// Current read position in `data` while parsing.
    offset: usize,
    /// The raw bytes of the message.
    data: Vec<u8>,
    questions: Vec<DnsQuestion>,
    answers: Vec<DnsRecordBox>,
    authorities: Vec<DnsRecordBox>,
    additional: Vec<DnsRecordBox>,
    /// Message id, as read from the header.
    id: u16,
    /// Header flags, as read from the header.
    flags: u16,
    // The following counts are the values read from the header.
    num_questions: u16,
    num_answers: u16,
    num_authorities: u16,
    num_additionals: u16,
    /// The interface id given when this message was created.
    interface_id: InterfaceId,
}
1849
1850impl DnsIncoming {
    /// Parses `data` as one DNS message, associated with `interface_id`.
    ///
    /// The header is read first, followed by the questions, answers,
    /// authorities and additional records, in that order.
    ///
    /// # Errors
    ///
    /// Returns the error of the first section that fails to parse, for
    /// example when `data` is shorter than a DNS header.
    pub fn new(data: Vec<u8>, interface_id: InterfaceId) -> Result<Self> {
        // Start with an empty message; the read_* calls below fill it in.
        let mut incoming = Self {
            offset: 0,
            data,
            questions: Vec::new(),
            answers: Vec::new(),
            authorities: Vec::new(),
            additional: Vec::new(),
            id: 0,
            flags: 0,
            num_questions: 0,
            num_answers: 0,
            num_authorities: 0,
            num_additionals: 0,
            interface_id,
        };

        /*
        RFC 1035 section 4.1: https://datatracker.ietf.org/doc/html/rfc1035#section-4.1
        ...
        All communications inside of the domain protocol are carried in a single
        format called a message.  The top level format of message is divided
        into 5 sections (some of which are empty in certain cases) shown below:

            +---------------------+
            |        Header       |
            +---------------------+
            |       Question      | the question for the name server
            +---------------------+
            |        Answer       | RRs answering the question
            +---------------------+
            |      Authority      | RRs pointing toward an authority
            +---------------------+
            |      Additional     | RRs holding additional information
            +---------------------+
         */
        // The sections are read in the order they appear in the message.
        incoming.read_header()?;
        incoming.read_questions()?;
        incoming.read_answers()?;
        incoming.read_authorities()?;
        incoming.read_additional()?;

        Ok(incoming)
    }
1895
1896    pub fn id(&self) -> u16 {
1897        self.id
1898    }
1899
1900    pub fn questions(&self) -> &[DnsQuestion] {
1901        &self.questions
1902    }
1903
1904    pub fn answers(&self) -> &[DnsRecordBox] {
1905        &self.answers
1906    }
1907
1908    pub fn authorities(&self) -> &[DnsRecordBox] {
1909        &self.authorities
1910    }
1911
1912    pub fn additionals(&self) -> &[DnsRecordBox] {
1913        &self.additional
1914    }
1915
1916    pub fn answers_mut(&mut self) -> &mut Vec<DnsRecordBox> {
1917        &mut self.answers
1918    }
1919
1920    pub fn authorities_mut(&mut self) -> &mut Vec<DnsRecordBox> {
1921        &mut self.authorities
1922    }
1923
1924    pub fn additionals_mut(&mut self) -> &mut Vec<DnsRecordBox> {
1925        &mut self.additional
1926    }
1927
1928    pub fn all_records(self) -> impl Iterator<Item = DnsRecordBox> {
1929        self.answers
1930            .into_iter()
1931            .chain(self.authorities)
1932            .chain(self.additional)
1933    }
1934
    /// Returns the additional-record count declared in the message header.
    ///
    /// This is the raw header value; it can exceed `additionals().len()`
    /// because records of unsupported types are skipped while decoding.
    pub fn num_additionals(&self) -> u16 {
        self.num_additionals
    }
1938
    /// Returns the authority-record count declared in the message header.
    ///
    /// This is the raw header value; it can exceed `authorities().len()`
    /// because records of unsupported types are skipped while decoding.
    pub fn num_authorities(&self) -> u16 {
        self.num_authorities
    }
1942
    /// Returns the question count declared in the message header.
    pub fn num_questions(&self) -> u16 {
        self.num_questions
    }
1946
    /// Returns `true` when the header flag bits selected by `FLAGS_QR_MASK`
    /// equal `FLAGS_QR_QUERY`, i.e. the message is a query.
    pub const fn is_query(&self) -> bool {
        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_QUERY
    }
1950
    /// Returns `true` when the header flag bits selected by `FLAGS_QR_MASK`
    /// equal `FLAGS_QR_RESPONSE`, i.e. the message is a response.
    pub const fn is_response(&self) -> bool {
        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_RESPONSE
    }
1954
1955    fn read_header(&mut self) -> Result<()> {
1956        if self.data.len() < MSG_HEADER_LEN {
1957            return Err(e_fmt!(
1958                "DNS incoming: header is too short: {} bytes",
1959                self.data.len()
1960            ));
1961        }
1962
1963        let data = &self.data[0..];
1964        self.id = u16_from_be_slice(&data[..2]);
1965        self.flags = u16_from_be_slice(&data[2..4]);
1966        self.num_questions = u16_from_be_slice(&data[4..6]);
1967        self.num_answers = u16_from_be_slice(&data[6..8]);
1968        self.num_authorities = u16_from_be_slice(&data[8..10]);
1969        self.num_additionals = u16_from_be_slice(&data[10..12]);
1970
1971        self.offset = MSG_HEADER_LEN;
1972
1973        trace!(
1974            "read_header: id {}, {} questions {} answers {} authorities {} additionals",
1975            self.id,
1976            self.num_questions,
1977            self.num_answers,
1978            self.num_authorities,
1979            self.num_additionals
1980        );
1981        Ok(())
1982    }
1983
1984    fn read_questions(&mut self) -> Result<()> {
1985        trace!("read_questions: {}", &self.num_questions);
1986        for i in 0..self.num_questions {
1987            let name = self.read_name()?;
1988
1989            let data = &self.data[self.offset..];
1990            if data.len() < 4 {
1991                return Err(Error::Msg(format!(
1992                    "DNS incoming: question idx {} too short: {}",
1993                    i,
1994                    data.len()
1995                )));
1996            }
1997            let ty = u16_from_be_slice(&data[..2]);
1998            let class = u16_from_be_slice(&data[2..4]);
1999            self.offset += 4;
2000
2001            let Some(rr_type) = RRType::from_u16(ty) else {
2002                return Err(Error::Msg(format!(
2003                    "DNS incoming: question idx {i} qtype unknown: {ty}",
2004                )));
2005            };
2006
2007            self.questions.push(DnsQuestion {
2008                entry: DnsEntry::new(name, rr_type, class),
2009            });
2010        }
2011        Ok(())
2012    }
2013
2014    fn read_answers(&mut self) -> Result<()> {
2015        self.answers = self.read_rr_records(self.num_answers)?;
2016        Ok(())
2017    }
2018
2019    fn read_authorities(&mut self) -> Result<()> {
2020        self.authorities = self.read_rr_records(self.num_authorities)?;
2021        Ok(())
2022    }
2023
2024    fn read_additional(&mut self) -> Result<()> {
2025        self.additional = self.read_rr_records(self.num_additionals)?;
2026        Ok(())
2027    }
2028
    /// Decodes a sequence of RR records (in answers, authorities and additionals).
    ///
    /// Reads `count` resource records starting at `self.offset` and returns
    /// the ones whose type is supported. Records with an unknown or
    /// unsupported type are skipped using their RDLENGTH and are not part of
    /// the returned vector, so the result may hold fewer than `count` items.
    ///
    /// # Errors
    ///
    /// Returns an error when a record is truncated, when its RDLENGTH runs
    /// past the end of the packet, when any field fails to decode, or when
    /// decoding the RDATA did not consume exactly RDLENGTH bytes.
    fn read_rr_records(&mut self, count: u16) -> Result<Vec<DnsRecordBox>> {
        trace!("read_rr_records: {}", count);
        let mut rr_records = Vec::new();

        // RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035#section-3.2.1
        //
        // All RRs have the same top level format shown below:
        //                               1  1  1  1  1  1
        // 0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
        // |                                               |
        // /                                               /
        // /                      NAME                     /
        // |                                               |
        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
        // |                      TYPE                     |
        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
        // |                     CLASS                     |
        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
        // |                      TTL                      |
        // |                                               |
        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
        // |                   RDLENGTH                    |
        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
        // /                     RDATA                     /
        // /                                               /
        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+

        // Must have at least TYPE, CLASS, TTL, RDLENGTH fields: 10 bytes.
        const RR_HEADER_REMAIN: usize = 10;

        for _ in 0..count {
            let name = self.read_name()?;
            let slice = &self.data[self.offset..];

            if slice.len() < RR_HEADER_REMAIN {
                return Err(Error::Msg(format!(
                    "read_others: RR '{}' is too short after name: {} bytes",
                    &name,
                    slice.len()
                )));
            }

            let ty = u16_from_be_slice(&slice[..2]);
            let class = u16_from_be_slice(&slice[2..4]);
            let mut ttl = u32_from_be_slice(&slice[4..8]);
            if ttl == 0 && self.is_response() {
                // RFC 6762 section 10.1:
                // "...Queriers receiving a Multicast DNS response with a TTL of zero SHOULD
                // NOT immediately delete the record from the cache, but instead record
                // a TTL of 1 and then delete the record one second later."
                // See https://datatracker.ietf.org/doc/html/rfc6762#section-10.1

                ttl = 1;
            }
            let rdata_len = u16_from_be_slice(&slice[8..10]) as usize;
            // Move past the fixed fields; `self.offset` now points at RDATA.
            self.offset += RR_HEADER_REMAIN;
            // Where the next record must start once this RDATA is consumed.
            let next_offset = self.offset + rdata_len;

            // Sanity check for RDATA length.
            if next_offset > self.data.len() {
                return Err(Error::Msg(format!(
                    "RR {name} RDATA length {rdata_len} is invalid: remain data len: {}",
                    self.data.len() - self.offset
                )));
            }

            // decode RDATA based on the record type.
            // Each reader below advances `self.offset` as it consumes bytes.
            let rec: Option<DnsRecordBox> = match RRType::from_u16(ty) {
                None => None,

                Some(rr_type) => match rr_type {
                    RRType::CNAME | RRType::PTR => {
                        Some(DnsPointer::new(&name, rr_type, class, ttl, self.read_name()?).boxed())
                    }
                    RRType::TXT => {
                        Some(DnsTxt::new(&name, class, ttl, self.read_vec(rdata_len)?).boxed())
                    }
                    RRType::SRV => Some(
                        DnsSrv::new(
                            &name,
                            class,
                            ttl,
                            self.read_u16()?,
                            self.read_u16()?,
                            self.read_u16()?,
                            self.read_name()?,
                        )
                        .boxed(),
                    ),
                    RRType::HINFO => Some(
                        DnsHostInfo::new(
                            &name,
                            rr_type,
                            class,
                            ttl,
                            self.read_char_string()?,
                            self.read_char_string()?,
                        )
                        .boxed(),
                    ),
                    RRType::A => Some(
                        DnsAddress::new(
                            &name,
                            rr_type,
                            class,
                            ttl,
                            self.read_ipv4()?.into(),
                            self.interface_id.clone(),
                        )
                        .boxed(),
                    ),
                    RRType::AAAA => Some(
                        DnsAddress::new(
                            &name,
                            rr_type,
                            class,
                            ttl,
                            self.read_ipv6()?.into(),
                            self.interface_id.clone(),
                        )
                        .boxed(),
                    ),
                    RRType::NSEC => Some(
                        DnsNSec::new(
                            &name,
                            class,
                            ttl,
                            self.read_name()?,
                            self.read_type_bitmap()?,
                        )
                        .boxed(),
                    ),
                    _ => None,
                },
            };

            if let Some(record) = rec {
                trace!("read_rr_records: {:?}", &record);
                rr_records.push(record);
            } else {
                // Nothing was decoded, so skip the whole RDATA by its declared length.
                trace!("Unsupported DNS record type: {} name: {}", ty, &name);
                self.offset += rdata_len;
            }

            // sanity check.
            // The RDATA decoding above must have consumed exactly `rdata_len`
            // bytes; otherwise the record is malformed and later records
            // would be read from the wrong position.
            if self.offset != next_offset {
                return Err(Error::Msg(format!(
                    "read_rr_records: decode offset error for RData type {} offset: {} expected offset: {}",
                    ty, self.offset, next_offset,
                )));
            }
        }

        Ok(rr_records)
    }
2186
2187    fn read_char_string(&mut self) -> Result<String> {
2188        let length = self.data[self.offset];
2189        self.offset += 1;
2190        self.read_string(length as usize)
2191    }
2192
2193    fn read_u16(&mut self) -> Result<u16> {
2194        let slice = &self.data[self.offset..];
2195        if slice.len() < U16_SIZE {
2196            return Err(Error::Msg(format!(
2197                "read_u16: slice len is only {}",
2198                slice.len()
2199            )));
2200        }
2201        let num = u16_from_be_slice(&slice[..U16_SIZE]);
2202        self.offset += U16_SIZE;
2203        Ok(num)
2204    }
2205
2206    /// Reads the "Type Bit Map" block for a DNS NSEC record.
2207    fn read_type_bitmap(&mut self) -> Result<Vec<u8>> {
2208        // From RFC 6762: 6.1.  Negative Responses
2209        // https://datatracker.ietf.org/doc/html/rfc6762#section-6.1
2210        //   o The Type Bit Map block number is 0.
2211        //   o The Type Bit Map block length byte is a value in the range 1-32.
2212        //   o The Type Bit Map data is 1-32 bytes, as indicated by length
2213        //     byte.
2214
2215        // Sanity check: at least 2 bytes to read.
2216        if self.data.len() < self.offset + 2 {
2217            return Err(Error::Msg(format!(
2218                "DnsIncoming is too short: {} at NSEC Type Bit Map offset {}",
2219                self.data.len(),
2220                self.offset
2221            )));
2222        }
2223
2224        let block_num = self.data[self.offset];
2225        self.offset += 1;
2226        if block_num != 0 {
2227            return Err(Error::Msg(format!(
2228                "NSEC block number is not 0: {block_num}"
2229            )));
2230        }
2231
2232        let block_len = self.data[self.offset] as usize;
2233        if !(1..=32).contains(&block_len) {
2234            return Err(Error::Msg(format!(
2235                "NSEC block length must be in the range 1-32: {block_len}"
2236            )));
2237        }
2238        self.offset += 1;
2239
2240        let end = self.offset + block_len;
2241        if end > self.data.len() {
2242            return Err(Error::Msg(format!(
2243                "NSEC block overflow: {} over RData len {}",
2244                end,
2245                self.data.len()
2246            )));
2247        }
2248        let bitmap = self.data[self.offset..end].to_vec();
2249        self.offset += block_len;
2250
2251        Ok(bitmap)
2252    }
2253
2254    fn read_vec(&mut self, length: usize) -> Result<Vec<u8>> {
2255        if self.data.len() < self.offset + length {
2256            return Err(e_fmt!(
2257                "DNS Incoming: not enough data to read a chunk of data"
2258            ));
2259        }
2260
2261        let v = self.data[self.offset..self.offset + length].to_vec();
2262        self.offset += length;
2263        Ok(v)
2264    }
2265
2266    fn read_ipv4(&mut self) -> Result<Ipv4Addr> {
2267        if self.data.len() < self.offset + 4 {
2268            return Err(e_fmt!("DNS Incoming: not enough data to read an IPV4"));
2269        }
2270
2271        let bytes: [u8; 4] = self.data[self.offset..self.offset + 4]
2272            .try_into()
2273            .map_err(|_| e_fmt!("DNS incoming: Not enough bytes for reading an IPV4"))?;
2274        self.offset += bytes.len();
2275        Ok(Ipv4Addr::from(bytes))
2276    }
2277
2278    fn read_ipv6(&mut self) -> Result<Ipv6Addr> {
2279        if self.data.len() < self.offset + 16 {
2280            return Err(e_fmt!("DNS Incoming: not enough data to read an IPV6"));
2281        }
2282
2283        let bytes: [u8; 16] = self.data[self.offset..self.offset + 16]
2284            .try_into()
2285            .map_err(|_| e_fmt!("DNS incoming: Not enough bytes for reading an IPV6"))?;
2286        self.offset += bytes.len();
2287        Ok(Ipv6Addr::from(bytes))
2288    }
2289
2290    fn read_string(&mut self, length: usize) -> Result<String> {
2291        if self.data.len() < self.offset + length {
2292            return Err(e_fmt!("DNS Incoming: not enough data to read a string"));
2293        }
2294
2295        let s = str::from_utf8(&self.data[self.offset..self.offset + length])
2296            .map_err(|e| Error::Msg(e.to_string()))?;
2297        self.offset += length;
2298        Ok(s.to_string())
2299    }
2300
2301    /// Reads a domain name at the current location of `self.data`.
2302    ///
2303    /// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.1 for
2304    /// domain name encoding.
2305    fn read_name(&mut self) -> Result<String> {
2306        let data = &self.data[..];
2307        let start_offset = self.offset;
2308        let mut offset = start_offset;
2309        let mut name = "".to_string();
2310        let mut at_end = false;
2311
2312        // From RFC1035:
2313        // "...Domain names in messages are expressed in terms of a sequence of labels.
2314        // Each label is represented as a one octet length field followed by that
2315        // number of octets."
2316        //
2317        // "...The compression scheme allows a domain name in a message to be
2318        // represented as either:
2319        // - a sequence of labels ending in a zero octet
2320        // - a pointer
2321        // - a sequence of labels ending with a pointer"
2322        loop {
2323            if offset >= data.len() {
2324                return Err(Error::Msg(format!(
2325                    "read_name: offset: {} data len {}. DnsIncoming: {:?}",
2326                    offset,
2327                    data.len(),
2328                    self
2329                )));
2330            }
2331            let length = data[offset];
2332
2333            // From RFC1035:
2334            // "...Since every domain name ends with the null label of
2335            // the root, a domain name is terminated by a length byte of zero."
2336            if length == 0 {
2337                if !at_end {
2338                    self.offset = offset + 1;
2339                }
2340                break; // The end of the name
2341            }
2342
2343            // Check the first 2 bits for possible "Message compression".
2344            match length & 0xC0 {
2345                0x00 => {
2346                    // regular utf8 string with length
2347                    offset += 1;
2348                    let ending = offset + length as usize;
2349
2350                    // Never read beyond the whole data length.
2351                    if ending > data.len() {
2352                        return Err(Error::Msg(format!(
2353                            "read_name: ending {} exceeds data length {}",
2354                            ending,
2355                            data.len()
2356                        )));
2357                    }
2358
2359                    name += str::from_utf8(&data[offset..ending])
2360                        .map_err(|e| Error::Msg(format!("read_name: from_utf8: {e}")))?;
2361                    name += ".";
2362                    offset += length as usize;
2363                }
2364                0xC0 => {
2365                    // Message compression.
2366                    // See https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.4
2367                    let slice = &data[offset..];
2368                    if slice.len() < U16_SIZE {
2369                        return Err(Error::Msg(format!(
2370                            "read_name: u16 slice len is only {}",
2371                            slice.len()
2372                        )));
2373                    }
2374                    let pointer = (u16_from_be_slice(slice) ^ 0xC000) as usize;
2375                    if pointer >= start_offset {
2376                        // Error: could trigger an infinite loop.
2377                        return Err(Error::Msg(format!(
2378                            "Invalid name compression: pointer {} must be less than the start offset {}",
2379                            &pointer, &start_offset
2380                        )));
2381                    }
2382
2383                    // A pointer marks the end of a domain name.
2384                    if !at_end {
2385                        self.offset = offset + U16_SIZE;
2386                        at_end = true;
2387                    }
2388                    offset = pointer;
2389                }
2390                _ => {
2391                    return Err(Error::Msg(format!(
2392                        "Bad name with invalid length: 0x{:x} offset {}, data (so far): {:x?}",
2393                        length,
2394                        offset,
2395                        &data[..offset]
2396                    )));
2397                }
2398            };
2399        }
2400
2401        Ok(name)
2402    }
2403}
2404
/// Returns the current wall-clock time as milliseconds since the UNIX epoch.
///
/// Panics if the system clock reports a time before the UNIX epoch.
fn current_time_millis() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .expect("failed to get current UNIX time");

    since_epoch.as_millis() as u64
}
2412
/// Decodes the first two bytes of `bytes` as a big-endian `u16`.
///
/// Panics if `bytes` holds fewer than two bytes; any extra bytes are ignored.
const fn u16_from_be_slice(bytes: &[u8]) -> u16 {
    let high = bytes[0] as u16;
    let low = bytes[1] as u16;
    (high << 8) | low
}
2417
/// Decodes the first four bytes of `s` as a big-endian `u32`.
///
/// Panics if `s` holds fewer than four bytes; any extra bytes are ignored.
const fn u32_from_be_slice(s: &[u8]) -> u32 {
    let b0 = s[0] as u32;
    let b1 = s[1] as u32;
    let b2 = s[2] as u32;
    let b3 = s[3] as u32;
    (b0 << 24) | (b1 << 16) | (b2 << 8) | b3
}
2422
/// Returns the UNIX time in millis at which this record will have expired
/// by a certain percentage.
///
/// * `created` - creation time in milliseconds since the UNIX epoch.
/// * `ttl` - record time-to-live in seconds.
/// * `percent` - portion of the TTL that must have elapsed, e.g. `80` for 80%.
///
/// The arithmetic is done in `u64` and saturates, so a large TTL received
/// from the network can neither panic (debug builds) nor wrap around to a
/// time in the past (release builds).
const fn get_expiration_time(created: u64, ttl: u32, percent: u32) -> u64 {
    // 'created' is in millis, 'ttl' is in seconds, hence:
    // ttl * 1000 * (percent / 100) => ttl * percent * 10
    let elapsed_millis = (ttl as u64)
        .saturating_mul(percent as u64)
        .saturating_mul(10);
    created.saturating_add(elapsed_millis)
}