mdns_sd/
dns_parser.rs

1//! DNS parsing utility.
2//!
//! [DnsIncoming] is the logical representation of an incoming DNS packet.
//! [DnsOutgoing] is the logical representation of an outgoing DNS message of one or more packets.
//! [DnsOutPacket] is a single encoded packet of a [DnsOutgoing] message.
6
7#[cfg(feature = "logging")]
8use crate::log::trace;
9
10use crate::error::{e_fmt, Error, Result};
11
12use std::{
13    any::Any,
14    cmp,
15    collections::HashMap,
16    convert::TryInto,
17    fmt,
18    net::{IpAddr, Ipv4Addr, Ipv6Addr},
19    str,
20    time::SystemTime,
21};
22
/// DNS resource record types, stored as `u16`. Can do `as u16` when needed.
///
/// Only the types handled by this crate are listed; [`RRType::from_u16`] converts a raw
/// wire value and yields `None` for anything else.
///
/// See [RFC 1035 section 3.2.2](https://datatracker.ietf.org/doc/html/rfc1035#section-3.2.2)
#[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
#[non_exhaustive]
#[repr(u16)]
pub enum RRType {
    /// IPv4 host address.
    A = 1,

    /// Canonical name (alias target).
    CNAME = 5,

    /// Domain name pointer.
    PTR = 12,

    /// Host information (CPU and OS).
    HINFO = 13,

    /// Text strings (used for service properties).
    TXT = 16,

    /// IPv6 host address.
    AAAA = 28,

    /// Service location.
    SRV = 33,

    /// Next secure record, used here for negative responses.
    NSEC = 47,

    /// Wildcard matching records of any type.
    ANY = 255,
}

impl RRType {
    /// Converts `u16` into `RRType` if possible.
    ///
    /// Returns `None` when `value` is not the discriminant of one of the variants above.
    pub const fn from_u16(value: u16) -> Option<RRType> {
        let ty = match value {
            1 => RRType::A,
            5 => RRType::CNAME,
            12 => RRType::PTR,
            13 => RRType::HINFO,
            16 => RRType::TXT,
            28 => RRType::AAAA,
            33 => RRType::SRV,
            47 => RRType::NSEC,
            255 => RRType::ANY,
            _ => return None,
        };
        Some(ty)
    }
}

/// Formats the type as `TYPE_<NAME>`, e.g. `TYPE_SRV`.
impl fmt::Display for RRType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            RRType::A => "TYPE_A",
            RRType::CNAME => "TYPE_CNAME",
            RRType::PTR => "TYPE_PTR",
            RRType::HINFO => "TYPE_HINFO",
            RRType::TXT => "TYPE_TXT",
            RRType::AAAA => "TYPE_AAAA",
            RRType::SRV => "TYPE_SRV",
            RRType::NSEC => "TYPE_NSEC",
            RRType::ANY => "TYPE_ANY",
        };
        f.write_str(label)
    }
}
91
/// The class value for the Internet.
pub const CLASS_IN: u16 = 1;

/// Cache-flush bit: the most significant bit of the rrclass field of the resource record.
pub const CLASS_CACHE_FLUSH: u16 = 0x8000;

/// Mask that selects the class proper from the rrclass field: every bit except
/// [`CLASS_CACHE_FLUSH`].
pub const CLASS_MASK: u16 = !CLASS_CACHE_FLUSH;

/// Max size of UDP datagram payload.
///
/// A 9000-byte datagram minus the 20-byte IP header and the 8-byte UDP header.
/// Reference: [RFC6762 section 17](https://datatracker.ietf.org/doc/html/rfc6762#section-17)
pub const MAX_MSG_ABSOLUTE: usize = 9000 - 20 - 8;

/// Length in bytes of the fixed DNS message header (RFC 1035 section 4.1.1).
const MSG_HEADER_LEN: usize = 12;
106
// Definitions for the DNS message header "flags" field.
//
// The "flags" field is 16 bits long, laid out as follows
// (RFC 1035 section 4.1.1; bit 0 is the most significant bit):
//
//   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
//

/// Mask for the query/response (QR) bit.
pub const FLAGS_QR_MASK: u16 = 1 << 15;

/// Flag bit to indicate a query
pub const FLAGS_QR_QUERY: u16 = 0;

/// Flag bit to indicate a response
pub const FLAGS_QR_RESPONSE: u16 = 1 << 15;

/// Flag bit for Authoritative Answer
pub const FLAGS_AA: u16 = 1 << 10;

/// Mask for the TC (Truncated) bit.
///
/// 2024-08-10: this flag is currently supported only on the querier side.
///             The responder handles just the first packet and ignores this bit;
///             the follow-up packets carry no questions, so processing them is a
///             no-op. In practice the responder supports Known-Answer lists only
///             within a single packet, while the querier supports both single-packet
///             and multi-packet Known-Answer lists.
pub const FLAGS_TC: u16 = 1 << 9;
137
138/// A convenience type alias for DNS record trait objects.
139pub type DnsRecordBox = Box<dyn DnsRecordExt>;
140
141impl Clone for DnsRecordBox {
142    fn clone(&self) -> Self {
143        self.clone_box()
144    }
145}
146
/// Size in bytes of a `u16` value.
const U16_SIZE: usize = std::mem::size_of::<u16>();
148
149/// Returns `RRType` for a given IP address.
150#[inline]
151pub const fn ip_address_rr_type(address: &IpAddr) -> RRType {
152    match address {
153        IpAddr::V4(_) => RRType::A,
154        IpAddr::V6(_) => RRType::AAAA,
155    }
156}
157
158#[derive(Eq, PartialEq, Debug, Clone)]
159pub struct DnsEntry {
160    pub(crate) name: String, // always lower case.
161    pub(crate) ty: RRType,
162    class: u16,
163    cache_flush: bool,
164}
165
166impl DnsEntry {
167    const fn new(name: String, ty: RRType, class: u16) -> Self {
168        Self {
169            name,
170            ty,
171            class: class & CLASS_MASK,
172            cache_flush: (class & CLASS_CACHE_FLUSH) != 0,
173        }
174    }
175}
176
177/// Common methods for all DNS entries:  questions and resource records.
178pub trait DnsEntryExt: fmt::Debug {
179    fn entry_name(&self) -> &str;
180
181    fn entry_type(&self) -> RRType;
182}
183
184/// A DNS question entry
185#[derive(Debug)]
186pub struct DnsQuestion {
187    pub(crate) entry: DnsEntry,
188}
189
190impl DnsEntryExt for DnsQuestion {
191    fn entry_name(&self) -> &str {
192        &self.entry.name
193    }
194
195    fn entry_type(&self) -> RRType {
196        self.entry.ty
197    }
198}
199
200/// A DNS Resource Record - like a DNS entry, but has a TTL.
201/// RFC: https://www.rfc-editor.org/rfc/rfc1035#section-3.2.1
202///      https://www.rfc-editor.org/rfc/rfc1035#section-4.1.3
203#[derive(Debug, Clone)]
204pub struct DnsRecord {
205    pub(crate) entry: DnsEntry,
206    ttl: u32,     // in seconds, 0 means this record should not be cached
207    created: u64, // UNIX time in millis
208    expires: u64, // expires at this UNIX time in millis
209
210    /// Support re-query an instance before its PTR record expires.
211    /// See https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
212    refresh: u64, // UNIX time in millis
213
214    /// If conflict resolution decides to change the name, this is the new one.
215    new_name: Option<String>,
216}
217
218impl DnsRecord {
219    fn new(name: &str, ty: RRType, class: u16, ttl: u32) -> Self {
220        let created = current_time_millis();
221
222        // From RFC 6762 section 5.2:
223        // "... The querier should plan to issue a query at 80% of the record
224        // lifetime, and then if no answer is received, at 85%, 90%, and 95%."
225        let refresh = get_expiration_time(created, ttl, 80);
226
227        let expires = get_expiration_time(created, ttl, 100);
228
229        Self {
230            entry: DnsEntry::new(name.to_string(), ty, class),
231            ttl,
232            created,
233            expires,
234            refresh,
235            new_name: None,
236        }
237    }
238
239    pub const fn get_ttl(&self) -> u32 {
240        self.ttl
241    }
242
243    pub const fn get_expire_time(&self) -> u64 {
244        self.expires
245    }
246
247    pub const fn get_refresh_time(&self) -> u64 {
248        self.refresh
249    }
250
251    pub const fn is_expired(&self, now: u64) -> bool {
252        now >= self.expires
253    }
254
255    pub const fn refresh_due(&self, now: u64) -> bool {
256        now >= self.refresh
257    }
258
259    /// Returns whether `now` (in millis) has passed half of TTL.
260    pub fn halflife_passed(&self, now: u64) -> bool {
261        let halflife = get_expiration_time(self.created, self.ttl, 50);
262        now > halflife
263    }
264
265    pub fn is_unique(&self) -> bool {
266        self.entry.cache_flush
267    }
268
269    /// Updates the refresh time to be the same as the expire time so that
270    /// this record will not refresh again and will just expire.
271    pub fn refresh_no_more(&mut self) {
272        self.refresh = get_expiration_time(self.created, self.ttl, 100);
273    }
274
275    /// Returns if this record is due for refresh. If yes, `refresh` time is updated.
276    pub fn refresh_maybe(&mut self, now: u64) -> bool {
277        if self.is_expired(now) || !self.refresh_due(now) {
278            return false;
279        }
280
281        trace!(
282            "{} qtype {} is due to refresh",
283            &self.entry.name,
284            self.entry.ty
285        );
286
287        // From RFC 6762 section 5.2:
288        // "... The querier should plan to issue a query at 80% of the record
289        // lifetime, and then if no answer is received, at 85%, 90%, and 95%."
290        //
291        // If the answer is received in time, 'refresh' will be reset outside
292        // this function, back to 80% of the new TTL.
293        if self.refresh == get_expiration_time(self.created, self.ttl, 80) {
294            self.refresh = get_expiration_time(self.created, self.ttl, 85);
295        } else if self.refresh == get_expiration_time(self.created, self.ttl, 85) {
296            self.refresh = get_expiration_time(self.created, self.ttl, 90);
297        } else if self.refresh == get_expiration_time(self.created, self.ttl, 90) {
298            self.refresh = get_expiration_time(self.created, self.ttl, 95);
299        } else {
300            self.refresh_no_more();
301        }
302
303        true
304    }
305
306    /// Returns the remaining TTL in seconds
307    fn get_remaining_ttl(&self, now: u64) -> u32 {
308        let remaining_millis = get_expiration_time(self.created, self.ttl, 100) - now;
309        cmp::max(0, remaining_millis / 1000) as u32
310    }
311
312    /// Return the absolute time for this record being created
313    pub const fn get_created(&self) -> u64 {
314        self.created
315    }
316
317    /// Set the absolute expiration time in millis
318    fn set_expire(&mut self, expire_at: u64) {
319        self.expires = expire_at;
320    }
321
322    fn reset_ttl(&mut self, other: &Self) {
323        self.ttl = other.ttl;
324        self.created = other.created;
325        self.refresh = get_expiration_time(self.created, self.ttl, 80);
326        self.expires = get_expiration_time(self.created, self.ttl, 100);
327    }
328
329    /// Modify TTL to reflect the remaining life time from `now`.
330    pub fn update_ttl(&mut self, now: u64) {
331        if now > self.created {
332            let elapsed = now - self.created;
333            self.ttl -= (elapsed / 1000) as u32;
334        }
335    }
336
337    pub fn set_new_name(&mut self, new_name: String) {
338        if new_name == self.entry.name {
339            self.new_name = None;
340        } else {
341            self.new_name = Some(new_name);
342        }
343    }
344
345    pub fn get_new_name(&self) -> Option<&str> {
346        self.new_name.as_deref()
347    }
348
349    /// Return the new name if exists, otherwise the regular name in DnsEntry.
350    pub(crate) fn get_name(&self) -> &str {
351        self.new_name.as_deref().unwrap_or(&self.entry.name)
352    }
353
354    pub fn get_original_name(&self) -> &str {
355        &self.entry.name
356    }
357}
358
359impl PartialEq for DnsRecord {
360    fn eq(&self, other: &Self) -> bool {
361        self.entry == other.entry
362    }
363}
364
365/// Common methods for DNS resource records.
366pub trait DnsRecordExt: fmt::Debug {
367    fn get_record(&self) -> &DnsRecord;
368    fn get_record_mut(&mut self) -> &mut DnsRecord;
369    fn write(&self, packet: &mut DnsOutPacket);
370    fn any(&self) -> &dyn Any;
371
372    /// Returns whether `other` record is considered the same except TTL.
373    fn matches(&self, other: &dyn DnsRecordExt) -> bool;
374
375    /// Returns whether `other` record has the same rdata.
376    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool;
377
378    /// Returns the result based on a byte-level comparison of `rdata`.
379    /// If `other` is not valid, returns `Greater`.
380    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering;
381
382    /// Returns the result based on "lexicographically later" defined below.
383    fn compare(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
384        /*
385        RFC 6762: https://datatracker.ietf.org/doc/html/rfc6762#section-8.2
386
387        ... The determination of "lexicographically later" is performed by first
388        comparing the record class (excluding the cache-flush bit described
389        in Section 10.2), then the record type, then raw comparison of the
390        binary content of the rdata without regard for meaning or structure.
391        If the record classes differ, then the numerically greater class is
392        considered "lexicographically later".  Otherwise, if the record types
393        differ, then the numerically greater type is considered
394        "lexicographically later".  If the rrtype and rrclass both match,
395        then the rdata is compared. ...
396        */
397        match self.get_class().cmp(&other.get_class()) {
398            cmp::Ordering::Equal => match self.get_type().cmp(&other.get_type()) {
399                cmp::Ordering::Equal => self.compare_rdata(other),
400                not_equal => not_equal,
401            },
402            not_equal => not_equal,
403        }
404    }
405
406    /// Returns a human-readable string of rdata.
407    fn rdata_print(&self) -> String;
408
409    /// Returns the class only, excluding class_flush / unique bit.
410    fn get_class(&self) -> u16 {
411        self.get_record().entry.class
412    }
413
414    fn get_cache_flush(&self) -> bool {
415        self.get_record().entry.cache_flush
416    }
417
418    /// Return the new name if exists, otherwise the regular name in DnsEntry.
419    fn get_name(&self) -> &str {
420        self.get_record().get_name()
421    }
422
423    fn get_type(&self) -> RRType {
424        self.get_record().entry.ty
425    }
426
427    /// Resets TTL using `other` record.
428    /// `self.refresh` and `self.expires` are also reset.
429    fn reset_ttl(&mut self, other: &dyn DnsRecordExt) {
430        self.get_record_mut().reset_ttl(other.get_record());
431    }
432
433    fn get_created(&self) -> u64 {
434        self.get_record().get_created()
435    }
436
437    fn get_expire(&self) -> u64 {
438        self.get_record().get_expire_time()
439    }
440
441    fn set_expire(&mut self, expire_at: u64) {
442        self.get_record_mut().set_expire(expire_at);
443    }
444
445    /// Set expire as `expire_at` if it is sooner than the current `expire`.
446    fn set_expire_sooner(&mut self, expire_at: u64) {
447        if expire_at < self.get_expire() {
448            self.get_record_mut().set_expire(expire_at);
449        }
450    }
451
452    /// Given `now`, if the record is due to refresh, this method updates the refresh time
453    /// and returns the new refresh time. Otherwise, returns None.
454    fn updated_refresh_time(&mut self, now: u64) -> Option<u64> {
455        if self.get_record_mut().refresh_maybe(now) {
456            Some(self.get_record().get_refresh_time())
457        } else {
458            None
459        }
460    }
461
462    /// Returns true if another record has matched content,
463    /// and if its TTL is at least half of this record's.
464    fn suppressed_by_answer(&self, other: &dyn DnsRecordExt) -> bool {
465        self.matches(other) && (other.get_record().ttl > self.get_record().ttl / 2)
466    }
467
468    /// Required by RFC 6762 Section 7.1: Known-Answer Suppression.
469    fn suppressed_by(&self, msg: &DnsIncoming) -> bool {
470        for answer in msg.answers.iter() {
471            if self.suppressed_by_answer(answer.as_ref()) {
472                return true;
473            }
474        }
475        false
476    }
477
478    fn clone_box(&self) -> DnsRecordBox;
479}
480
481/// Resource Record for IPv4 address or IPv6 address.
482#[derive(Debug, Clone)]
483pub struct DnsAddress {
484    pub(crate) record: DnsRecord,
485    address: IpAddr,
486}
487
488impl DnsAddress {
489    pub fn new(name: &str, ty: RRType, class: u16, ttl: u32, address: IpAddr) -> Self {
490        let record = DnsRecord::new(name, ty, class, ttl);
491        Self { record, address }
492    }
493
494    pub fn address(&self) -> IpAddr {
495        self.address
496    }
497}
498
499impl DnsRecordExt for DnsAddress {
500    fn get_record(&self) -> &DnsRecord {
501        &self.record
502    }
503
504    fn get_record_mut(&mut self) -> &mut DnsRecord {
505        &mut self.record
506    }
507
508    fn write(&self, packet: &mut DnsOutPacket) {
509        match self.address {
510            IpAddr::V4(addr) => packet.write_bytes(addr.octets().as_ref()),
511            IpAddr::V6(addr) => packet.write_bytes(addr.octets().as_ref()),
512        };
513    }
514
515    fn any(&self) -> &dyn Any {
516        self
517    }
518
519    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
520        if let Some(other_a) = other.any().downcast_ref::<Self>() {
521            return self.address == other_a.address && self.record.entry == other_a.record.entry;
522        }
523        false
524    }
525
526    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
527        if let Some(other_a) = other.any().downcast_ref::<Self>() {
528            return self.address == other_a.address;
529        }
530        false
531    }
532
533    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
534        if let Some(other_a) = other.any().downcast_ref::<Self>() {
535            self.address.cmp(&other_a.address)
536        } else {
537            cmp::Ordering::Greater
538        }
539    }
540
541    fn rdata_print(&self) -> String {
542        format!("{}", self.address)
543    }
544
545    fn clone_box(&self) -> DnsRecordBox {
546        Box::new(self.clone())
547    }
548}
549
550/// Resource Record for a DNS pointer
551#[derive(Debug, Clone)]
552pub struct DnsPointer {
553    record: DnsRecord,
554    alias: String, // the full name of Service Instance
555}
556
557impl DnsPointer {
558    pub fn new(name: &str, ty: RRType, class: u16, ttl: u32, alias: String) -> Self {
559        let record = DnsRecord::new(name, ty, class, ttl);
560        Self { record, alias }
561    }
562
563    pub fn alias(&self) -> &str {
564        &self.alias
565    }
566}
567
568impl DnsRecordExt for DnsPointer {
569    fn get_record(&self) -> &DnsRecord {
570        &self.record
571    }
572
573    fn get_record_mut(&mut self) -> &mut DnsRecord {
574        &mut self.record
575    }
576
577    fn write(&self, packet: &mut DnsOutPacket) {
578        packet.write_name(&self.alias);
579    }
580
581    fn any(&self) -> &dyn Any {
582        self
583    }
584
585    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
586        if let Some(other_ptr) = other.any().downcast_ref::<Self>() {
587            return self.alias == other_ptr.alias && self.record.entry == other_ptr.record.entry;
588        }
589        false
590    }
591
592    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
593        if let Some(other_ptr) = other.any().downcast_ref::<Self>() {
594            return self.alias == other_ptr.alias;
595        }
596        false
597    }
598
599    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
600        if let Some(other_ptr) = other.any().downcast_ref::<Self>() {
601            self.alias.cmp(&other_ptr.alias)
602        } else {
603            cmp::Ordering::Greater
604        }
605    }
606
607    fn rdata_print(&self) -> String {
608        self.alias.clone()
609    }
610
611    fn clone_box(&self) -> DnsRecordBox {
612        Box::new(self.clone())
613    }
614}
615
616/// Resource Record for a DNS service.
617#[derive(Debug, Clone)]
618pub struct DnsSrv {
619    pub(crate) record: DnsRecord,
620    pub(crate) priority: u16, // lower number means higher priority. Should be 0 in common cases.
621    pub(crate) weight: u16,   // Should be 0 in common cases
622    host: String,
623    port: u16,
624}
625
626impl DnsSrv {
627    pub fn new(
628        name: &str,
629        class: u16,
630        ttl: u32,
631        priority: u16,
632        weight: u16,
633        port: u16,
634        host: String,
635    ) -> Self {
636        let record = DnsRecord::new(name, RRType::SRV, class, ttl);
637        Self {
638            record,
639            priority,
640            weight,
641            host,
642            port,
643        }
644    }
645
646    pub fn host(&self) -> &str {
647        &self.host
648    }
649
650    pub fn port(&self) -> u16 {
651        self.port
652    }
653
654    pub fn set_host(&mut self, host: String) {
655        self.host = host;
656    }
657}
658
659impl DnsRecordExt for DnsSrv {
660    fn get_record(&self) -> &DnsRecord {
661        &self.record
662    }
663
664    fn get_record_mut(&mut self) -> &mut DnsRecord {
665        &mut self.record
666    }
667
668    fn write(&self, packet: &mut DnsOutPacket) {
669        packet.write_short(self.priority);
670        packet.write_short(self.weight);
671        packet.write_short(self.port);
672        packet.write_name(&self.host);
673    }
674
675    fn any(&self) -> &dyn Any {
676        self
677    }
678
679    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
680        if let Some(other_svc) = other.any().downcast_ref::<Self>() {
681            return self.host == other_svc.host
682                && self.port == other_svc.port
683                && self.weight == other_svc.weight
684                && self.priority == other_svc.priority
685                && self.record.entry == other_svc.record.entry;
686        }
687        false
688    }
689
690    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
691        if let Some(other_srv) = other.any().downcast_ref::<Self>() {
692            return self.host == other_srv.host
693                && self.port == other_srv.port
694                && self.weight == other_srv.weight
695                && self.priority == other_srv.priority;
696        }
697        false
698    }
699
700    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
701        let Some(other_srv) = other.any().downcast_ref::<Self>() else {
702            return cmp::Ordering::Greater;
703        };
704
705        // 1. compare `priority`
706        match self
707            .priority
708            .to_be_bytes()
709            .cmp(&other_srv.priority.to_be_bytes())
710        {
711            cmp::Ordering::Equal => {
712                // 2. compare `weight`
713                match self
714                    .weight
715                    .to_be_bytes()
716                    .cmp(&other_srv.weight.to_be_bytes())
717                {
718                    cmp::Ordering::Equal => {
719                        // 3. compare `port`.
720                        match self.port.to_be_bytes().cmp(&other_srv.port.to_be_bytes()) {
721                            cmp::Ordering::Equal => self.host.cmp(&other_srv.host),
722                            not_equal => not_equal,
723                        }
724                    }
725                    not_equal => not_equal,
726                }
727            }
728            not_equal => not_equal,
729        }
730    }
731
732    fn rdata_print(&self) -> String {
733        format!(
734            "priority: {}, weight: {}, port: {}, host: {}",
735            self.priority, self.weight, self.port, self.host
736        )
737    }
738
739    fn clone_box(&self) -> DnsRecordBox {
740        Box::new(self.clone())
741    }
742}
743
744/// Resource Record for a DNS TXT record.
745///
746/// From [RFC 6763 section 6]:
747///
748/// The format of each constituent string within the DNS TXT record is a
749/// single length byte, followed by 0-255 bytes of text data.
750///
751/// DNS-SD uses DNS TXT records to store arbitrary key/value pairs
752///    conveying additional information about the named service.  Each
753///    key/value pair is encoded as its own constituent string within the
754///    DNS TXT record, in the form "key=value" (without the quotation
755///    marks).  Everything up to the first '=' character is the key (Section
756///    6.4).  Everything after the first '=' character to the end of the
757///    string (including subsequent '=' characters, if any) is the value
758#[derive(Clone)]
759pub struct DnsTxt {
760    pub(crate) record: DnsRecord,
761    text: Vec<u8>,
762}
763
764impl DnsTxt {
765    pub fn new(name: &str, class: u16, ttl: u32, text: Vec<u8>) -> Self {
766        let record = DnsRecord::new(name, RRType::TXT, class, ttl);
767        Self { record, text }
768    }
769
770    pub fn text(&self) -> &[u8] {
771        &self.text
772    }
773}
774
775impl DnsRecordExt for DnsTxt {
776    fn get_record(&self) -> &DnsRecord {
777        &self.record
778    }
779
780    fn get_record_mut(&mut self) -> &mut DnsRecord {
781        &mut self.record
782    }
783
784    fn write(&self, packet: &mut DnsOutPacket) {
785        packet.write_bytes(&self.text);
786    }
787
788    fn any(&self) -> &dyn Any {
789        self
790    }
791
792    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
793        if let Some(other_txt) = other.any().downcast_ref::<Self>() {
794            return self.text == other_txt.text && self.record.entry == other_txt.record.entry;
795        }
796        false
797    }
798
799    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
800        if let Some(other_txt) = other.any().downcast_ref::<Self>() {
801            return self.text == other_txt.text;
802        }
803        false
804    }
805
806    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
807        if let Some(other_txt) = other.any().downcast_ref::<Self>() {
808            self.text.cmp(&other_txt.text)
809        } else {
810            cmp::Ordering::Greater
811        }
812    }
813
814    fn rdata_print(&self) -> String {
815        format!("{:?}", decode_txt(&self.text))
816    }
817
818    fn clone_box(&self) -> DnsRecordBox {
819        Box::new(self.clone())
820    }
821}
822
823impl fmt::Debug for DnsTxt {
824    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
825        let properties = decode_txt(&self.text);
826        write!(
827            f,
828            "DnsTxt {{ record: {:?}, text: {:?} }}",
829            self.record, properties
830        )
831    }
832}
833
834// Convert from DNS TXT record content to key/value pairs
835fn decode_txt(txt: &[u8]) -> Vec<TxtProperty> {
836    let mut properties = Vec::new();
837    let mut offset = 0;
838    while offset < txt.len() {
839        let length = txt[offset] as usize;
840        if length == 0 {
841            break; // reached the end
842        }
843        offset += 1; // move over the length byte
844
845        let offset_end = offset + length;
846        if offset_end > txt.len() {
847            trace!("ERROR: DNS TXT: size given for property is out of range. (offset={}, length={}, offset_end={}, record length={})", offset, length, offset_end, txt.len());
848            break; // Skipping the rest of the record content, as the size for this property would already be out of range.
849        }
850        let kv_bytes = &txt[offset..offset_end];
851
852        // split key and val using the first `=`
853        let (k, v) = kv_bytes.iter().position(|&x| x == b'=').map_or_else(
854            || (kv_bytes.to_vec(), None),
855            |idx| (kv_bytes[..idx].to_vec(), Some(kv_bytes[idx + 1..].to_vec())),
856        );
857
858        // Make sure the key can be stored in UTF-8.
859        match String::from_utf8(k) {
860            Ok(k_string) => {
861                properties.push(TxtProperty {
862                    key: k_string,
863                    val: v,
864                });
865            }
866            Err(e) => trace!("ERROR: convert to String from key: {}", e),
867        }
868
869        offset += length;
870    }
871
872    properties
873}
874
/// Represents a property in a TXT record.
#[derive(Clone, PartialEq, Eq)]
pub struct TxtProperty {
    /// The name of the property. The original cases are kept.
    key: String,

    /// RFC 6763 says values are bytes, not necessarily UTF-8.
    /// `None` means the key has no value, i.e. it is a boolean key.
    val: Option<Vec<u8>>,
}

impl TxtProperty {
    /// Returns the value of a property as str.
    ///
    /// Yields `""` when there is no value or when the value is not valid UTF-8.
    pub fn val_str(&self) -> &str {
        match &self.val {
            Some(bytes) => std::str::from_utf8(bytes).unwrap_or_default(),
            None => "",
        }
    }
}

/// Supports constructing from a borrowed tuple; both parts are converted via `ToString`.
impl<K, V> From<&(K, V)> for TxtProperty
where
    K: ToString,
    V: ToString,
{
    fn from(prop: &(K, V)) -> Self {
        let key = prop.0.to_string();
        let val = prop.1.to_string().into_bytes();
        Self {
            key,
            val: Some(val),
        }
    }
}

/// Supports constructing from an owned tuple whose value is raw bytes.
impl<K, V> From<(K, V)> for TxtProperty
where
    K: ToString,
    V: AsRef<[u8]>,
{
    fn from(prop: (K, V)) -> Self {
        let (key, val) = prop;
        Self {
            key: key.to_string(),
            val: Some(val.as_ref().to_vec()),
        }
    }
}

/// Support a property that has no value.
impl From<&str> for TxtProperty {
    fn from(key: &str) -> Self {
        Self {
            key: key.to_owned(),
            val: None,
        }
    }
}

/// Formats as `key=value`, using [`TxtProperty::val_str`] for the value.
impl fmt::Display for TxtProperty {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.val_str();
        write!(f, "{}={}", self.key, value)
    }
}
938
939/// Mimic the default debug output for a struct, with a twist:
940/// - If self.var is UTF-8, will output it as a string in double quotes.
941/// - If self.var is not UTF-8, will output its bytes as in hex.
942impl fmt::Debug for TxtProperty {
943    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
944        let val_string = self.val.as_ref().map_or_else(
945            || "None".to_string(),
946            |v| {
947                std::str::from_utf8(&v[..]).map_or_else(
948                    |_| format!("Some({})", u8_slice_to_hex(&v[..])),
949                    |s| format!("Some(\"{}\")", s),
950                )
951            },
952        );
953
954        write!(
955            f,
956            "TxtProperty {{key: \"{}\", val: {}}}",
957            &self.key, &val_string,
958        )
959    }
960}
961
/// Lower-case hexadecimal digits, indexed by nibble value.
const HEX_TABLE: [char; 16] = [
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
];

/// Create a hex string from `slice`, with a "0x" prefix.
///
/// Each byte becomes two lower-case hex digits, high nibble first.
/// For example, `[1u8, 2u8]` -> `"0x0102"`; an empty slice yields `"0x"`.
fn u8_slice_to_hex(slice: &[u8]) -> String {
    let digits = slice.iter().flat_map(|&byte| {
        [
            HEX_TABLE[usize::from(byte >> 4)],
            HEX_TABLE[usize::from(byte & 0x0F)],
        ]
    });

    let mut hex = String::with_capacity(2 + 2 * slice.len());
    hex.push_str("0x");
    hex.extend(digits);
    hex
}
978
979/// A DNS host information record
980#[derive(Debug, Clone)]
981struct DnsHostInfo {
982    record: DnsRecord,
983    cpu: String,
984    os: String,
985}
986
987impl DnsHostInfo {
988    fn new(name: &str, ty: RRType, class: u16, ttl: u32, cpu: String, os: String) -> Self {
989        let record = DnsRecord::new(name, ty, class, ttl);
990        Self { record, cpu, os }
991    }
992}
993
994impl DnsRecordExt for DnsHostInfo {
995    fn get_record(&self) -> &DnsRecord {
996        &self.record
997    }
998
999    fn get_record_mut(&mut self) -> &mut DnsRecord {
1000        &mut self.record
1001    }
1002
1003    fn write(&self, packet: &mut DnsOutPacket) {
1004        println!("writing HInfo: cpu {} os {}", &self.cpu, &self.os);
1005        packet.write_bytes(self.cpu.as_bytes());
1006        packet.write_bytes(self.os.as_bytes());
1007    }
1008
1009    fn any(&self) -> &dyn Any {
1010        self
1011    }
1012
1013    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
1014        if let Some(other_hinfo) = other.any().downcast_ref::<Self>() {
1015            return self.cpu == other_hinfo.cpu
1016                && self.os == other_hinfo.os
1017                && self.record.entry == other_hinfo.record.entry;
1018        }
1019        false
1020    }
1021
1022    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
1023        if let Some(other_hinfo) = other.any().downcast_ref::<Self>() {
1024            return self.cpu == other_hinfo.cpu && self.os == other_hinfo.os;
1025        }
1026        false
1027    }
1028
1029    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
1030        if let Some(other_hinfo) = other.any().downcast_ref::<Self>() {
1031            match self.cpu.cmp(&other_hinfo.cpu) {
1032                cmp::Ordering::Equal => self.os.cmp(&other_hinfo.os),
1033                ordering => ordering,
1034            }
1035        } else {
1036            cmp::Ordering::Greater
1037        }
1038    }
1039
1040    fn rdata_print(&self) -> String {
1041        format!("cpu: {}, os: {}", self.cpu, self.os)
1042    }
1043
1044    fn clone_box(&self) -> DnsRecordBox {
1045        Box::new(self.clone())
1046    }
1047}
1048
1049/// Resource Record for negative responses
1050///
1051/// [RFC4034 section 4.1](https://datatracker.ietf.org/doc/html/rfc4034#section-4.1)
1052/// and
1053/// [RFC6762 section 6.1](https://datatracker.ietf.org/doc/html/rfc6762#section-6.1)
1054#[derive(Debug, Clone)]
1055pub struct DnsNSec {
1056    record: DnsRecord,
1057    next_domain: String,
1058    type_bitmap: Vec<u8>,
1059}
1060
1061impl DnsNSec {
1062    pub fn new(
1063        name: &str,
1064        class: u16,
1065        ttl: u32,
1066        next_domain: String,
1067        type_bitmap: Vec<u8>,
1068    ) -> Self {
1069        let record = DnsRecord::new(name, RRType::NSEC, class, ttl);
1070        Self {
1071            record,
1072            next_domain,
1073            type_bitmap,
1074        }
1075    }
1076
1077    /// Returns the types marked by `type_bitmap`
1078    pub fn _types(&self) -> Vec<u16> {
1079        // From RFC 4034: 4.1.2 The Type Bit Maps Field
1080        // https://datatracker.ietf.org/doc/html/rfc4034#section-4.1.2
1081        //
1082        // Each bitmap encodes the low-order 8 bits of RR types within the
1083        // window block, in network bit order.  The first bit is bit 0.  For
1084        // window block 0, bit 1 corresponds to RR type 1 (A), bit 2 corresponds
1085        // to RR type 2 (NS), and so forth.
1086
1087        let mut bit_num = 0;
1088        let mut results = Vec::new();
1089
1090        for byte in self.type_bitmap.iter() {
1091            let mut bit_mask: u8 = 0x80; // for bit 0 in network bit order
1092
1093            // check every bit in this byte, one by one.
1094            for _ in 0..8 {
1095                if (byte & bit_mask) != 0 {
1096                    results.push(bit_num);
1097                }
1098                bit_num += 1;
1099                bit_mask >>= 1; // mask for the next bit
1100            }
1101        }
1102        results
1103    }
1104}
1105
1106impl DnsRecordExt for DnsNSec {
1107    fn get_record(&self) -> &DnsRecord {
1108        &self.record
1109    }
1110
1111    fn get_record_mut(&mut self) -> &mut DnsRecord {
1112        &mut self.record
1113    }
1114
1115    fn write(&self, packet: &mut DnsOutPacket) {
1116        packet.write_bytes(self.next_domain.as_bytes());
1117        packet.write_bytes(&self.type_bitmap);
1118    }
1119
1120    fn any(&self) -> &dyn Any {
1121        self
1122    }
1123
1124    fn matches(&self, other: &dyn DnsRecordExt) -> bool {
1125        if let Some(other_record) = other.any().downcast_ref::<Self>() {
1126            return self.next_domain == other_record.next_domain
1127                && self.type_bitmap == other_record.type_bitmap
1128                && self.record.entry == other_record.record.entry;
1129        }
1130        false
1131    }
1132
1133    fn rrdata_match(&self, other: &dyn DnsRecordExt) -> bool {
1134        if let Some(other_record) = other.any().downcast_ref::<Self>() {
1135            return self.next_domain == other_record.next_domain
1136                && self.type_bitmap == other_record.type_bitmap;
1137        }
1138        false
1139    }
1140
1141    fn compare_rdata(&self, other: &dyn DnsRecordExt) -> cmp::Ordering {
1142        if let Some(other_nsec) = other.any().downcast_ref::<Self>() {
1143            match self.next_domain.cmp(&other_nsec.next_domain) {
1144                cmp::Ordering::Equal => self.type_bitmap.cmp(&other_nsec.type_bitmap),
1145                ordering => ordering,
1146            }
1147        } else {
1148            cmp::Ordering::Greater
1149        }
1150    }
1151
1152    fn rdata_print(&self) -> String {
1153        format!(
1154            "next_domain: {}, type_bitmap len: {}",
1155            self.next_domain,
1156            self.type_bitmap.len()
1157        )
1158    }
1159
1160    fn clone_box(&self) -> DnsRecordBox {
1161        Box::new(self.clone())
1162    }
1163}
1164
/// Internal bookkeeping for a [`DnsOutPacket`]; not part of the DNS wire format.
#[derive(PartialEq)]
enum PacketState {
    /// Freshly created; questions and records are still being appended.
    Init,
    /// The header was written, or a record did not fit within the size limit.
    Finished,
}
1170
/// A single packet for outgoing DNS message.
///
/// The packet is built incrementally. Questions and records are appended as byte
/// chunks, and the header is prepended last by `write_header`.
pub struct DnsOutPacket {
    /// All bytes in `data` concatenated is the actual packet on the wire.
    data: Vec<Vec<u8>>,

    /// Current logical size of the packet. It starts with the size of the mandatory header.
    ///
    /// The header bytes are only inserted into `data` by `write_header`. Until then,
    /// `size` exceeds the byte count of `data` by `MSG_HEADER_LEN`.
    size: usize,

    /// An internal state, not defined by DNS.
    state: PacketState,

    /// Name-compression table used by `write_name`.
    ///
    /// k: a domain-name suffix (without the trailing dot); v: its byte offset from the
    /// start of the message, used as the target of a compression pointer.
    names: HashMap<String, u16>,
}
1185
1186impl DnsOutPacket {
1187    fn new() -> Self {
1188        Self {
1189            data: Vec::new(),
1190            size: MSG_HEADER_LEN, // Header is mandatory.
1191            state: PacketState::Init,
1192            names: HashMap::new(),
1193        }
1194    }
1195
1196    pub fn size(&self) -> usize {
1197        self.size
1198    }
1199
1200    pub fn to_bytes(&self) -> Vec<u8> {
1201        self.data.concat()
1202    }
1203
1204    fn write_question(&mut self, question: &DnsQuestion) {
1205        self.write_name(&question.entry.name);
1206        self.write_short(question.entry.ty as u16);
1207        self.write_short(question.entry.class);
1208    }
1209
1210    /// Writes a record (answer, authoritative answer, additional)
1211    /// Returns false if the packet exceeds the max size with this record, nothing is written to the packet.
1212    /// otherwise returns true.
1213    fn write_record(&mut self, record_ext: &dyn DnsRecordExt, now: u64) -> bool {
1214        let start_data_length = self.data.len();
1215        let start_size = self.size;
1216
1217        let record = record_ext.get_record();
1218        self.write_name(record.get_name());
1219        self.write_short(record.entry.ty as u16);
1220        if record.entry.cache_flush {
1221            // check "multicast"
1222            self.write_short(record.entry.class | CLASS_CACHE_FLUSH);
1223        } else {
1224            self.write_short(record.entry.class);
1225        }
1226
1227        if now == 0 {
1228            self.write_u32(record.ttl);
1229        } else {
1230            self.write_u32(record.get_remaining_ttl(now));
1231        }
1232
1233        let index = self.data.len();
1234
1235        // Adjust size for the short we will write before this record
1236        self.size += 2;
1237        record_ext.write(self);
1238        self.size -= 2;
1239
1240        let length: usize = self.data[index..].iter().map(|x| x.len()).sum();
1241        self.insert_short(index, length as u16);
1242
1243        if self.size > MAX_MSG_ABSOLUTE {
1244            self.data.truncate(start_data_length);
1245            self.size = start_size;
1246            self.state = PacketState::Finished;
1247            return false;
1248        }
1249
1250        true
1251    }
1252
1253    pub(crate) fn insert_short(&mut self, index: usize, value: u16) {
1254        self.data.insert(index, value.to_be_bytes().to_vec());
1255        self.size += 2;
1256    }
1257
1258    // Write name to packet
1259    //
1260    // [RFC1035]
1261    // 4.1.4. Message compression
1262    //
1263    // In order to reduce the size of messages, the domain system utilizes a
1264    // compression scheme which eliminates the repetition of domain names in a
1265    // message.  In this scheme, an entire domain name or a list of labels at
1266    // the end of a domain name is replaced with a pointer to a prior occurrence
1267    // of the same name.
1268    // The pointer takes the form of a two octet sequence:
1269    //     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1270    //     | 1  1|                OFFSET                   |
1271    //     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1272    // The first two bits are ones.  This allows a pointer to be distinguished
1273    // from a label, since the label must begin with two zero bits because
1274    // labels are restricted to 63 octets or less.  (The 10 and 01 combinations
1275    // are reserved for future use.)  The OFFSET field specifies an offset from
1276    // the start of the message (i.e., the first octet of the ID field in the
1277    // domain header).  A zero offset specifies the first byte of the ID field,
1278    // etc.
1279    fn write_name(&mut self, name: &str) {
1280        // ignore the ending "." if exists
1281        let end = name.len();
1282        let end = if end > 0 && &name[end - 1..] == "." {
1283            end - 1
1284        } else {
1285            end
1286        };
1287
1288        let mut here = 0;
1289        while here < end {
1290            const POINTER_MASK: u16 = 0xC000;
1291            let remaining = &name[here..end];
1292
1293            // Check if 'remaining' already appeared in this message
1294            match self.names.get(remaining) {
1295                Some(offset) => {
1296                    let pointer = *offset | POINTER_MASK;
1297                    self.write_short(pointer);
1298                    // println!(
1299                    //     "written pointer {} ({}) for {}",
1300                    //     pointer,
1301                    //     pointer ^ POINTER_MASK,
1302                    //     remaining
1303                    // );
1304                    break;
1305                }
1306                None => {
1307                    // Remember the remaining parts so we can point to it
1308                    self.names.insert(remaining.to_string(), self.size as u16);
1309                    // println!("set offset {} for {}", self.size, remaining);
1310
1311                    // Find the current label to write into the packet
1312                    let stop = remaining.find('.').map_or(end, |i| here + i);
1313                    let label = &name[here..stop];
1314                    self.write_utf8(label);
1315
1316                    here = stop + 1; // move past the current label
1317                }
1318            }
1319
1320            if here >= end {
1321                self.write_byte(0); // name ends with 0 if not using a pointer
1322            }
1323        }
1324    }
1325
1326    fn write_utf8(&mut self, utf: &str) {
1327        assert!(utf.len() < 64);
1328        self.write_byte(utf.len() as u8);
1329        self.write_bytes(utf.as_bytes());
1330    }
1331
1332    fn write_bytes(&mut self, s: &[u8]) {
1333        self.data.push(s.to_vec());
1334        self.size += s.len();
1335    }
1336
1337    fn write_u32(&mut self, int: u32) {
1338        self.data.push(int.to_be_bytes().to_vec());
1339        self.size += 4;
1340    }
1341
1342    fn write_short(&mut self, short: u16) {
1343        self.data.push(short.to_be_bytes().to_vec());
1344        self.size += 2;
1345    }
1346
1347    fn write_byte(&mut self, byte: u8) {
1348        self.data.push(vec![byte]);
1349        self.size += 1;
1350    }
1351
1352    /// Writes the header fields and finish the packet.
1353    /// This function should be only called when finishing a packet.
1354    ///
1355    /// The header format is based on RFC 1035 section 4.1.1:
1356    /// https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1
1357    //
1358    //                                  1  1  1  1  1  1
1359    //    0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
1360    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1361    //    |                      ID                       |
1362    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1363    //    |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
1364    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1365    //    |                    QDCOUNT                    |
1366    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1367    //    |                    ANCOUNT                    |
1368    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1369    //    |                    NSCOUNT                    |
1370    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1371    //    |                    ARCOUNT                    |
1372    //    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1373    //
1374    fn write_header(
1375        &mut self,
1376        id: u16,
1377        flags: u16,
1378        q_count: u16,
1379        a_count: u16,
1380        auth_count: u16,
1381        addi_count: u16,
1382    ) {
1383        self.insert_short(0, addi_count);
1384        self.insert_short(0, auth_count);
1385        self.insert_short(0, a_count);
1386        self.insert_short(0, q_count);
1387        self.insert_short(0, flags);
1388        self.insert_short(0, id);
1389
1390        // Adjust the size as it was already initialized to include the header.
1391        self.size -= MSG_HEADER_LEN;
1392
1393        self.state = PacketState::Finished;
1394    }
1395}
1396
/// Representation of one outgoing DNS message that could be sent in one or more packet(s).
pub struct DnsOutgoing {
    /// Header flags word written into each packet's header by `to_packets`.
    flags: u16,
    /// Message ID; `to_packets` writes it only when `multicast` is false.
    id: u16,
    /// When true, `to_packets` writes an ID of 0 instead of `id`.
    multicast: bool,
    questions: Vec<DnsQuestion>,
    /// Answer records paired with the `now` value passed to `write_record`
    /// (0 means "write the record's original TTL").
    answers: Vec<(DnsRecordBox, u64)>,
    authorities: Vec<DnsRecordBox>,
    additionals: Vec<DnsRecordBox>,
    /// Counts answers dropped by known-answer suppression in `add_answer`.
    known_answer_count: i64, // for internal maintenance only
}
1408
1409impl DnsOutgoing {
1410    pub fn new(flags: u16) -> Self {
1411        Self {
1412            flags,
1413            id: 0,
1414            multicast: true,
1415            questions: Vec::new(),
1416            answers: Vec::new(),
1417            authorities: Vec::new(),
1418            additionals: Vec::new(),
1419            known_answer_count: 0,
1420        }
1421    }
1422
1423    pub fn questions(&self) -> &[DnsQuestion] {
1424        &self.questions
1425    }
1426
1427    pub fn answers_count(&self) -> usize {
1428        self.answers.len()
1429    }
1430
1431    pub fn authorities(&self) -> &[DnsRecordBox] {
1432        &self.authorities
1433    }
1434
1435    pub fn additionals(&self) -> &[DnsRecordBox] {
1436        &self.additionals
1437    }
1438
1439    pub fn known_answer_count(&self) -> i64 {
1440        self.known_answer_count
1441    }
1442
1443    pub fn set_id(&mut self, id: u16) {
1444        self.id = id;
1445    }
1446
1447    pub const fn is_query(&self) -> bool {
1448        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_QUERY
1449    }
1450
1451    const fn is_response(&self) -> bool {
1452        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_RESPONSE
1453    }
1454
1455    // Adds an additional answer
1456
1457    // From: RFC 6763, DNS-Based Service Discovery, February 2013
1458
1459    // 12.  DNS Additional Record Generation
1460
1461    //    DNS has an efficiency feature whereby a DNS server may place
1462    //    additional records in the additional section of the DNS message.
1463    //    These additional records are records that the client did not
1464    //    explicitly request, but the server has reasonable grounds to expect
1465    //    that the client might request them shortly, so including them can
1466    //    save the client from having to issue additional queries.
1467
1468    //    This section recommends which additional records SHOULD be generated
1469    //    to improve network efficiency, for both Unicast and Multicast DNS-SD
1470    //    responses.
1471
1472    // 12.1.  PTR Records
1473
1474    //    When including a DNS-SD Service Instance Enumeration or Selective
1475    //    Instance Enumeration (subtype) PTR record in a response packet, the
1476    //    server/responder SHOULD include the following additional records:
1477
1478    //    o  The SRV record(s) named in the PTR rdata.
1479    //    o  The TXT record(s) named in the PTR rdata.
1480    //    o  All address records (type "A" and "AAAA") named in the SRV rdata.
1481
1482    // 12.2.  SRV Records
1483
1484    //    When including an SRV record in a response packet, the
1485    //    server/responder SHOULD include the following additional records:
1486
1487    //    o  All address records (type "A" and "AAAA") named in the SRV rdata.
1488    pub fn add_additional_answer(&mut self, answer: impl DnsRecordExt + 'static) {
1489        trace!("add_additional_answer: {:?}", &answer);
1490        self.additionals.push(Box::new(answer));
1491    }
1492
1493    /// A workaround as Rust doesn't allow us to pass DnsRecordBox in as `impl DnsRecordExt`
1494    pub fn add_answer_box(&mut self, answer_box: DnsRecordBox) {
1495        self.answers.push((answer_box, 0));
1496    }
1497
1498    pub fn add_authority(&mut self, record: DnsRecordBox) {
1499        self.authorities.push(record);
1500    }
1501
1502    /// Returns true if `answer` is added to the outgoing msg.
1503    /// Returns false if `answer` was not added as it expired or suppressed by the incoming `msg`.
1504    pub fn add_answer(
1505        &mut self,
1506        msg: &DnsIncoming,
1507        answer: impl DnsRecordExt + Send + 'static,
1508    ) -> bool {
1509        trace!("Check for add_answer");
1510        if answer.suppressed_by(msg) {
1511            trace!("my answer is suppressed by incoming msg");
1512            self.known_answer_count += 1;
1513            return false;
1514        }
1515
1516        self.add_answer_at_time(answer, 0)
1517    }
1518
1519    /// Returns true if `answer` is added to the outgoing msg.
1520    /// Returns false if the answer is expired `now` hence not added.
1521    /// If `now` is 0, do not check if the answer expires.
1522    pub fn add_answer_at_time(
1523        &mut self,
1524        answer: impl DnsRecordExt + Send + 'static,
1525        now: u64,
1526    ) -> bool {
1527        if now == 0 || !answer.get_record().is_expired(now) {
1528            trace!("add_answer push: {:?}", &answer);
1529            self.answers.push((Box::new(answer), now));
1530            return true;
1531        }
1532        false
1533    }
1534
1535    pub fn add_question(&mut self, name: &str, qtype: RRType) {
1536        let q = DnsQuestion {
1537            entry: DnsEntry::new(name.to_string(), qtype, CLASS_IN),
1538        };
1539        self.questions.push(q);
1540    }
1541
1542    /// Returns a list of actual DNS packet data to be sent on the wire.
1543    pub fn to_data_on_wire(&self) -> Vec<Vec<u8>> {
1544        let packet_list = self.to_packets();
1545        packet_list.iter().map(|p| p.data.concat()).collect()
1546    }
1547
1548    /// Encode self into one or more packets.
1549    pub fn to_packets(&self) -> Vec<DnsOutPacket> {
1550        let mut packet_list = Vec::new();
1551        let mut packet = DnsOutPacket::new();
1552
1553        let mut question_count = self.questions.len() as u16;
1554        let mut answer_count = 0;
1555        let mut auth_count = 0;
1556        let mut addi_count = 0;
1557        let id = if self.multicast { 0 } else { self.id };
1558
1559        for question in self.questions.iter() {
1560            packet.write_question(question);
1561        }
1562
1563        for (answer, time) in self.answers.iter() {
1564            if packet.write_record(answer.as_ref(), *time) {
1565                answer_count += 1;
1566            }
1567        }
1568
1569        for auth in self.authorities.iter() {
1570            auth_count += u16::from(packet.write_record(auth.as_ref(), 0));
1571        }
1572
1573        for addi in self.additionals.iter() {
1574            if packet.write_record(addi.as_ref(), 0) {
1575                addi_count += 1;
1576                continue;
1577            }
1578
1579            // No more processing for response packets.
1580            if self.is_response() {
1581                break;
1582            }
1583
1584            // For query, the current packet exceeds its max size due to known answers,
1585            // need to truncate.
1586
1587            // finish the current packet first.
1588            packet.write_header(
1589                id,
1590                self.flags | FLAGS_TC,
1591                question_count,
1592                answer_count,
1593                auth_count,
1594                addi_count,
1595            );
1596
1597            packet_list.push(packet);
1598
1599            // create a new packet and reset counts.
1600            packet = DnsOutPacket::new();
1601            packet.write_record(addi.as_ref(), 0);
1602
1603            question_count = 0;
1604            answer_count = 0;
1605            auth_count = 0;
1606            addi_count = 1;
1607        }
1608
1609        packet.write_header(
1610            id,
1611            self.flags,
1612            question_count,
1613            answer_count,
1614            auth_count,
1615            addi_count,
1616        );
1617
1618        packet_list.push(packet);
1619        packet_list
1620    }
1621}
1622
/// An incoming DNS message. It could be a query or a response.
#[derive(Debug)]
pub struct DnsIncoming {
    /// Read cursor: index into `data` of the next byte to decode.
    offset: usize,
    /// The raw message bytes passed to `new`.
    data: Vec<u8>,
    questions: Vec<DnsQuestion>,
    answers: Vec<DnsRecordBox>,
    authorities: Vec<DnsRecordBox>,
    additional: Vec<DnsRecordBox>,
    id: u16,
    flags: u16,
    // Section counts as declared in the header (QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT).
    num_questions: u16,
    num_answers: u16,
    num_authorities: u16,
    num_additionals: u16,
}
1639
1640impl DnsIncoming {
1641    pub fn new(data: Vec<u8>) -> Result<Self> {
1642        let mut incoming = Self {
1643            offset: 0,
1644            data,
1645            questions: Vec::new(),
1646            answers: Vec::new(),
1647            authorities: Vec::new(),
1648            additional: Vec::new(),
1649            id: 0,
1650            flags: 0,
1651            num_questions: 0,
1652            num_answers: 0,
1653            num_authorities: 0,
1654            num_additionals: 0,
1655        };
1656
1657        /*
1658        RFC 1035 section 4.1: https://datatracker.ietf.org/doc/html/rfc1035#section-4.1
1659        ...
1660        All communications inside of the domain protocol are carried in a single
1661        format called a message.  The top level format of message is divided
1662        into 5 sections (some of which are empty in certain cases) shown below:
1663
1664            +---------------------+
1665            |        Header       |
1666            +---------------------+
1667            |       Question      | the question for the name server
1668            +---------------------+
1669            |        Answer       | RRs answering the question
1670            +---------------------+
1671            |      Authority      | RRs pointing toward an authority
1672            +---------------------+
1673            |      Additional     | RRs holding additional information
1674            +---------------------+
1675         */
1676        incoming.read_header()?;
1677        incoming.read_questions()?;
1678        incoming.read_answers()?;
1679        incoming.read_authorities()?;
1680        incoming.read_additional()?;
1681
1682        Ok(incoming)
1683    }
1684
1685    pub fn id(&self) -> u16 {
1686        self.id
1687    }
1688
1689    pub fn questions(&self) -> &[DnsQuestion] {
1690        &self.questions
1691    }
1692
1693    pub fn answers(&self) -> &[DnsRecordBox] {
1694        &self.answers
1695    }
1696
1697    pub fn authorities(&self) -> &[DnsRecordBox] {
1698        &self.authorities
1699    }
1700
1701    pub fn additionals(&self) -> &[DnsRecordBox] {
1702        &self.additional
1703    }
1704
1705    pub fn answers_mut(&mut self) -> &mut Vec<DnsRecordBox> {
1706        &mut self.answers
1707    }
1708
1709    pub fn authorities_mut(&mut self) -> &mut Vec<DnsRecordBox> {
1710        &mut self.authorities
1711    }
1712
1713    pub fn additionals_mut(&mut self) -> &mut Vec<DnsRecordBox> {
1714        &mut self.additional
1715    }
1716
1717    pub fn all_records(self) -> impl Iterator<Item = DnsRecordBox> {
1718        self.answers
1719            .into_iter()
1720            .chain(self.authorities)
1721            .chain(self.additional)
1722    }
1723
1724    pub fn num_additionals(&self) -> u16 {
1725        self.num_additionals
1726    }
1727
1728    pub fn num_authorities(&self) -> u16 {
1729        self.num_authorities
1730    }
1731
1732    pub fn num_questions(&self) -> u16 {
1733        self.num_questions
1734    }
1735
1736    pub const fn is_query(&self) -> bool {
1737        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_QUERY
1738    }
1739
1740    pub const fn is_response(&self) -> bool {
1741        (self.flags & FLAGS_QR_MASK) == FLAGS_QR_RESPONSE
1742    }
1743
1744    fn read_header(&mut self) -> Result<()> {
1745        if self.data.len() < MSG_HEADER_LEN {
1746            return Err(e_fmt!(
1747                "DNS incoming: header is too short: {} bytes",
1748                self.data.len()
1749            ));
1750        }
1751
1752        let data = &self.data[0..];
1753        self.id = u16_from_be_slice(&data[..2]);
1754        self.flags = u16_from_be_slice(&data[2..4]);
1755        self.num_questions = u16_from_be_slice(&data[4..6]);
1756        self.num_answers = u16_from_be_slice(&data[6..8]);
1757        self.num_authorities = u16_from_be_slice(&data[8..10]);
1758        self.num_additionals = u16_from_be_slice(&data[10..12]);
1759
1760        self.offset = MSG_HEADER_LEN;
1761
1762        trace!(
1763            "read_header: id {}, {} questions {} answers {} authorities {} additionals",
1764            self.id,
1765            self.num_questions,
1766            self.num_answers,
1767            self.num_authorities,
1768            self.num_additionals
1769        );
1770        Ok(())
1771    }
1772
1773    fn read_questions(&mut self) -> Result<()> {
1774        trace!("read_questions: {}", &self.num_questions);
1775        for i in 0..self.num_questions {
1776            let name = self.read_name()?;
1777
1778            let data = &self.data[self.offset..];
1779            if data.len() < 4 {
1780                return Err(Error::Msg(format!(
1781                    "DNS incoming: question idx {} too short: {}",
1782                    i,
1783                    data.len()
1784                )));
1785            }
1786            let ty = u16_from_be_slice(&data[..2]);
1787            let class = u16_from_be_slice(&data[2..4]);
1788            self.offset += 4;
1789
1790            let Some(rr_type) = RRType::from_u16(ty) else {
1791                return Err(Error::Msg(format!(
1792                    "DNS incoming: question idx {} qtype unknown: {}",
1793                    i, ty
1794                )));
1795            };
1796
1797            self.questions.push(DnsQuestion {
1798                entry: DnsEntry::new(name, rr_type, class),
1799            });
1800        }
1801        Ok(())
1802    }
1803
1804    fn read_answers(&mut self) -> Result<()> {
1805        self.answers = self.read_rr_records(self.num_answers)?;
1806        Ok(())
1807    }
1808
1809    fn read_authorities(&mut self) -> Result<()> {
1810        self.authorities = self.read_rr_records(self.num_authorities)?;
1811        Ok(())
1812    }
1813
1814    fn read_additional(&mut self) -> Result<()> {
1815        self.additional = self.read_rr_records(self.num_additionals)?;
1816        Ok(())
1817    }
1818
1819    /// Decodes a sequence of RR records (in answers, authorities and additionals).
1820    fn read_rr_records(&mut self, count: u16) -> Result<Vec<DnsRecordBox>> {
1821        trace!("read_rr_records: {}", count);
1822        let mut rr_records = Vec::new();
1823
1824        // RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035#section-3.2.1
1825        //
1826        // All RRs have the same top level format shown below:
1827        //                               1  1  1  1  1  1
1828        // 0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
1829        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1830        // |                                               |
1831        // /                                               /
1832        // /                      NAME                     /
1833        // |                                               |
1834        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1835        // |                      TYPE                     |
1836        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1837        // |                     CLASS                     |
1838        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1839        // |                      TTL                      |
1840        // |                                               |
1841        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1842        // |                   RDLENGTH                    |
1843        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
1844        // /                     RDATA                     /
1845        // /                                               /
1846        // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
1847
1848        // Muse have at least TYPE, CLASS, TTL, RDLENGTH fields: 10 bytes.
1849        const RR_HEADER_REMAIN: usize = 10;
1850
1851        for _ in 0..count {
1852            let name = self.read_name()?;
1853            let slice = &self.data[self.offset..];
1854
1855            if slice.len() < RR_HEADER_REMAIN {
1856                return Err(Error::Msg(format!(
1857                    "read_others: RR '{}' is too short after name: {} bytes",
1858                    &name,
1859                    slice.len()
1860                )));
1861            }
1862
1863            let ty = u16_from_be_slice(&slice[..2]);
1864            let class = u16_from_be_slice(&slice[2..4]);
1865            let mut ttl = u32_from_be_slice(&slice[4..8]);
1866            if ttl == 0 && self.is_response() {
1867                // RFC 6762 section 10.1:
1868                // "...Queriers receiving a Multicast DNS response with a TTL of zero SHOULD
1869                // NOT immediately delete the record from the cache, but instead record
1870                // a TTL of 1 and then delete the record one second later."
1871                // See https://datatracker.ietf.org/doc/html/rfc6762#section-10.1
1872
1873                ttl = 1;
1874            }
1875            let rdata_len = u16_from_be_slice(&slice[8..10]) as usize;
1876            self.offset += RR_HEADER_REMAIN;
1877            let next_offset = self.offset + rdata_len;
1878
1879            // Sanity check for RDATA length.
1880            if next_offset > self.data.len() {
1881                return Err(Error::Msg(format!(
1882                    "RR {name} RDATA length {rdata_len} is invalid: remain data len: {}",
1883                    self.data.len() - self.offset
1884                )));
1885            }
1886
1887            // decode RDATA based on the record type.
1888            let rec: Option<DnsRecordBox> = match RRType::from_u16(ty) {
1889                None => None,
1890
1891                Some(rr_type) => match rr_type {
1892                    RRType::CNAME | RRType::PTR => Some(Box::new(DnsPointer::new(
1893                        &name,
1894                        rr_type,
1895                        class,
1896                        ttl,
1897                        self.read_name()?,
1898                    ))),
1899                    RRType::TXT => Some(Box::new(DnsTxt::new(
1900                        &name,
1901                        class,
1902                        ttl,
1903                        self.read_vec(rdata_len),
1904                    ))),
1905                    RRType::SRV => Some(Box::new(DnsSrv::new(
1906                        &name,
1907                        class,
1908                        ttl,
1909                        self.read_u16()?,
1910                        self.read_u16()?,
1911                        self.read_u16()?,
1912                        self.read_name()?,
1913                    ))),
1914                    RRType::HINFO => Some(Box::new(DnsHostInfo::new(
1915                        &name,
1916                        rr_type,
1917                        class,
1918                        ttl,
1919                        self.read_char_string(),
1920                        self.read_char_string(),
1921                    ))),
1922                    RRType::A => Some(Box::new(DnsAddress::new(
1923                        &name,
1924                        rr_type,
1925                        class,
1926                        ttl,
1927                        self.read_ipv4().into(),
1928                    ))),
1929                    RRType::AAAA => Some(Box::new(DnsAddress::new(
1930                        &name,
1931                        rr_type,
1932                        class,
1933                        ttl,
1934                        self.read_ipv6().into(),
1935                    ))),
1936                    RRType::NSEC => Some(Box::new(DnsNSec::new(
1937                        &name,
1938                        class,
1939                        ttl,
1940                        self.read_name()?,
1941                        self.read_type_bitmap()?,
1942                    ))),
1943                    _ => None,
1944                },
1945            };
1946
1947            if let Some(record) = rec {
1948                trace!("read_rr_records: {:?}", &record);
1949                rr_records.push(record);
1950            } else {
1951                trace!("Unsupported DNS record type: {} name: {}", ty, &name);
1952                self.offset += rdata_len;
1953            }
1954
1955            // sanity check.
1956            if self.offset != next_offset {
1957                return Err(Error::Msg(format!(
1958                    "read_rr_records: decode offset error for RData type {} offset: {} expected offset: {}",
1959                    ty, self.offset, next_offset,
1960                )));
1961            }
1962        }
1963
1964        Ok(rr_records)
1965    }
1966
1967    fn read_char_string(&mut self) -> String {
1968        let length = self.data[self.offset];
1969        self.offset += 1;
1970        self.read_string(length as usize)
1971    }
1972
1973    fn read_u16(&mut self) -> Result<u16> {
1974        let slice = &self.data[self.offset..];
1975        if slice.len() < U16_SIZE {
1976            return Err(Error::Msg(format!(
1977                "read_u16: slice len is only {}",
1978                slice.len()
1979            )));
1980        }
1981        let num = u16_from_be_slice(&slice[..U16_SIZE]);
1982        self.offset += U16_SIZE;
1983        Ok(num)
1984    }
1985
1986    /// Reads the "Type Bit Map" block for a DNS NSEC record.
1987    fn read_type_bitmap(&mut self) -> Result<Vec<u8>> {
1988        // From RFC 6762: 6.1.  Negative Responses
1989        // https://datatracker.ietf.org/doc/html/rfc6762#section-6.1
1990        //   o The Type Bit Map block number is 0.
1991        //   o The Type Bit Map block length byte is a value in the range 1-32.
1992        //   o The Type Bit Map data is 1-32 bytes, as indicated by length
1993        //     byte.
1994
1995        // Sanity check: at least 2 bytes to read.
1996        if self.data.len() < self.offset + 2 {
1997            return Err(Error::Msg(format!(
1998                "DnsIncoming is too short: {} at NSEC Type Bit Map offset {}",
1999                self.data.len(),
2000                self.offset
2001            )));
2002        }
2003
2004        let block_num = self.data[self.offset];
2005        self.offset += 1;
2006        if block_num != 0 {
2007            return Err(Error::Msg(format!(
2008                "NSEC block number is not 0: {}",
2009                block_num
2010            )));
2011        }
2012
2013        let block_len = self.data[self.offset] as usize;
2014        if !(1..=32).contains(&block_len) {
2015            return Err(Error::Msg(format!(
2016                "NSEC block length must be in the range 1-32: {}",
2017                block_len
2018            )));
2019        }
2020        self.offset += 1;
2021
2022        let end = self.offset + block_len;
2023        if end > self.data.len() {
2024            return Err(Error::Msg(format!(
2025                "NSEC block overflow: {} over RData len {}",
2026                end,
2027                self.data.len()
2028            )));
2029        }
2030        let bitmap = self.data[self.offset..end].to_vec();
2031        self.offset += block_len;
2032
2033        Ok(bitmap)
2034    }
2035
2036    fn read_vec(&mut self, length: usize) -> Vec<u8> {
2037        let v = self.data[self.offset..self.offset + length].to_vec();
2038        self.offset += length;
2039        v
2040    }
2041
2042    fn read_ipv4(&mut self) -> Ipv4Addr {
2043        let bytes: [u8; 4] = (&self.data)[self.offset..self.offset + 4]
2044            .try_into()
2045            .unwrap();
2046        self.offset += bytes.len();
2047        Ipv4Addr::from(bytes)
2048    }
2049
2050    fn read_ipv6(&mut self) -> Ipv6Addr {
2051        let bytes: [u8; 16] = (&self.data)[self.offset..self.offset + 16]
2052            .try_into()
2053            .unwrap();
2054        self.offset += bytes.len();
2055        Ipv6Addr::from(bytes)
2056    }
2057
2058    fn read_string(&mut self, length: usize) -> String {
2059        let s = str::from_utf8(&self.data[self.offset..self.offset + length]).unwrap();
2060        self.offset += length;
2061        s.to_string()
2062    }
2063
2064    /// Reads a domain name at the current location of `self.data`.
2065    ///
2066    /// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.1 for
2067    /// domain name encoding.
2068    fn read_name(&mut self) -> Result<String> {
2069        let data = &self.data[..];
2070        let start_offset = self.offset;
2071        let mut offset = start_offset;
2072        let mut name = "".to_string();
2073        let mut at_end = false;
2074
2075        // From RFC1035:
2076        // "...Domain names in messages are expressed in terms of a sequence of labels.
2077        // Each label is represented as a one octet length field followed by that
2078        // number of octets."
2079        //
2080        // "...The compression scheme allows a domain name in a message to be
2081        // represented as either:
2082        // - a sequence of labels ending in a zero octet
2083        // - a pointer
2084        // - a sequence of labels ending with a pointer"
2085        loop {
2086            if offset >= data.len() {
2087                return Err(Error::Msg(format!(
2088                    "read_name: offset: {} data len {}. DnsIncoming: {:?}",
2089                    offset,
2090                    data.len(),
2091                    self
2092                )));
2093            }
2094            let length = data[offset];
2095
2096            // From RFC1035:
2097            // "...Since every domain name ends with the null label of
2098            // the root, a domain name is terminated by a length byte of zero."
2099            if length == 0 {
2100                if !at_end {
2101                    self.offset = offset + 1;
2102                }
2103                break; // The end of the name
2104            }
2105
2106            // Check the first 2 bits for possible "Message compression".
2107            match length & 0xC0 {
2108                0x00 => {
2109                    // regular utf8 string with length
2110                    offset += 1;
2111                    let ending = offset + length as usize;
2112
2113                    // Never read beyond the whole data length.
2114                    if ending > data.len() {
2115                        return Err(Error::Msg(format!(
2116                            "read_name: ending {} exceeds data length {}",
2117                            ending,
2118                            data.len()
2119                        )));
2120                    }
2121
2122                    name += str::from_utf8(&data[offset..ending])
2123                        .map_err(|e| Error::Msg(format!("read_name: from_utf8: {}", e)))?;
2124                    name += ".";
2125                    offset += length as usize;
2126                }
2127                0xC0 => {
2128                    // Message compression.
2129                    // See https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.4
2130                    let slice = &data[offset..];
2131                    if slice.len() < U16_SIZE {
2132                        return Err(Error::Msg(format!(
2133                            "read_name: u16 slice len is only {}",
2134                            slice.len()
2135                        )));
2136                    }
2137                    let pointer = (u16_from_be_slice(slice) ^ 0xC000) as usize;
2138                    if pointer >= start_offset {
2139                        // Error: could trigger an infinite loop.
2140                        return Err(Error::Msg(format!(
2141                            "Invalid name compression: pointer {} must be less than the start offset {}",
2142                            &pointer, &start_offset
2143                        )));
2144                    }
2145
2146                    // A pointer marks the end of a domain name.
2147                    if !at_end {
2148                        self.offset = offset + U16_SIZE;
2149                        at_end = true;
2150                    }
2151                    offset = pointer;
2152                }
2153                _ => {
2154                    return Err(Error::Msg(format!(
2155                        "Bad name with invalid length: 0x{:x} offset {}, data (so far): {:x?}",
2156                        length,
2157                        offset,
2158                        &data[..offset]
2159                    )));
2160                }
2161            };
2162        }
2163
2164        Ok(name)
2165    }
2166}
2167
/// Returns the current wall-clock time in milliseconds since the UNIX epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the UNIX epoch.
fn current_time_millis() -> u64 {
    let since_epoch = SystemTime::UNIX_EPOCH
        .elapsed()
        .expect("failed to get current UNIX time");
    // `as_millis` yields a u128. Narrowing to u64 only loses data hundreds of
    // millions of years from now.
    since_epoch.as_millis() as u64
}
2175
/// Decodes a big-endian (network byte order) `u16` from the first two bytes of
/// `bytes`. Trailing bytes are ignored.
///
/// # Panics
///
/// Panics if `bytes` is shorter than two bytes.
const fn u16_from_be_slice(bytes: &[u8]) -> u16 {
    ((bytes[0] as u16) << 8) | (bytes[1] as u16)
}
2180
/// Decodes a big-endian (network byte order) `u32` from the first four bytes of
/// `s`. Trailing bytes are ignored.
///
/// # Panics
///
/// Panics if `s` is shorter than four bytes.
const fn u32_from_be_slice(s: &[u8]) -> u32 {
    ((s[0] as u32) << 24) | ((s[1] as u32) << 16) | ((s[2] as u32) << 8) | (s[3] as u32)
}
2185
/// Returns the UNIX time in millis at which this record will have expired
/// by a certain percentage.
///
/// * `created` - creation time in UNIX millis.
/// * `ttl` - time to live in seconds.
/// * `percent` - how much of the TTL should have elapsed (100 = fully expired).
///
/// `ttl * 1000 * (percent / 100)` simplifies to `ttl * percent * 10`. The product
/// is computed in `u64`. In `u32` it overflowed once `ttl` exceeded about 4.29
/// million seconds at 100%: a panic in debug builds and a wrapped, wrong
/// expiration in release builds. TTLs arrive in incoming packets and can be as
/// large as `u32::MAX`.
const fn get_expiration_time(created: u64, ttl: u32, percent: u32) -> u64 {
    created + (ttl as u64) * (percent as u64) * 10
}