dns_protocol/
lib.rs

//! An implementation of the DNS protocol, [sans I/O].
//!
//! [sans I/O]: https://sans-io.readthedocs.io/
//!
//! This crate implements the Domain Name System (DNS) protocol, which is used for name lookups
//! among other things. It is intended to be used in conjunction with a transport layer, such
//! as UDP, TCP or HTTPS. The goal of this crate is to provide a runtime- and protocol-agnostic
//! implementation of the DNS protocol, so that it can be used in a variety of contexts.
//!
//! This crate is not only `no_std`; it does not use an allocator either. This means that it
//! can be used on embedded systems that do not have allocators, and that it does no allocation
//! of its own. However, this comes with a catch: the user is expected to provide their own
//! buffers for the various operations. This is done to avoid the need for an allocator, and
//! to allow the user to control the memory usage of the library.
//!
//! This crate is also `#![forbid(unsafe_code)]`, and is intended to remain so.
//!
//! # Example
//!
//! ```no_run
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! use dns_protocol::{Message, Question, ResourceRecord, ResourceType, Flags};
//! use std::net::UdpSocket;
//!
//! // Allocate a buffer for the message.
//! let mut buf = vec![0; 1024];
//!
//! // Create a message. This is a recursive query for the A record of example.com in the
//! // Internet class (class 1). A query carries no records, so those sections are empty.
//! let mut questions = [
//!     Question::new(
//!         "example.com",
//!         ResourceType::A,
//!         1,
//!     )
//! ];
//! let message = Message::new(
//!     0x42,
//!     Flags::standard_query(),
//!     &mut questions,
//!     &mut [],
//!     &mut [],
//!     &mut [],
//! );
//!
//! // Serialize the message into the buffer.
//! assert!(message.space_needed() <= buf.len());
//! let len = message.write(&mut buf)?;
//!
//! // Write the buffer to the socket.
//! let socket = UdpSocket::bind("localhost:0")?;
//! socket.send_to(&buf[..len], "1.2.3.4:53")?;
//!
//! // Read new data from the socket.
//! let data_len = socket.recv(&mut buf)?;
//!
//! // Parse the data as a message. Each buffer must be large enough to hold every entry
//! // that the reply's header announces for that section.
//! let mut answers = [ResourceRecord::default()];
//! let message = Message::read(
//!     &buf[..data_len],
//!     &mut questions,
//!     &mut answers,
//!     &mut [],
//!     &mut [],
//! )?;
//!
//! // Read the answer from the message, if the server sent one.
//! if let Some(answer) = message.answers().first() {
//!     println!("Answer Data: {:?}", answer.data());
//! }
//! # Ok(())
//! # }
//! ```
//!
//! # Features
//!
//! - `std` (enabled by default) - Enables the `std` library for use in `Error` types.
//!   Disable this feature to use on `no_std` targets.
71
72#![forbid(
73    unsafe_code,
74    missing_docs,
75    missing_debug_implementations,
76    rust_2018_idioms,
77    future_incompatible
78)]
79#![no_std]
80
81#[cfg(feature = "std")]
82extern crate std;
83
84use core::convert::{TryFrom, TryInto};
85use core::fmt;
86use core::iter;
87use core::mem;
88use core::num::NonZeroUsize;
89use core::str;
90
91#[cfg(feature = "std")]
92use std::error::Error as StdError;
93
94mod ser;
95use ser::{Cursor, Serialize};
96pub use ser::{Label, LabelSegment};
97
/// Define a struct together with a field-by-field `Serialize` implementation.
///
/// The generated implementation handles every field in declaration order, with nothing written
/// between fields:
/// - the length is the sum of the field lengths;
/// - serialization writes each field right after the previous one;
/// - deserialization threads the cursor through each field in turn.
macro_rules! serialize {
    (
        $(#[$outer:meta])*
        pub struct $name:ident $(<$lt: lifetime>)? {
            $(
                $(#[$inner:meta])*
                $vis: vis $field:ident: $ty:ty,
            )*
        }
    ) => {
        $(#[$outer])*
        pub struct $name $(<$lt>)? {
            $(
                $(#[$inner])*
                $vis $field: $ty,
            )*
        }

        impl<'a> Serialize<'a> for $name $(<$lt>)? {
            fn serialized_len(&self) -> usize {
                0 $(+ self.$field.serialized_len())*
            }

            fn serialize(&self, cursor: &mut [u8]) -> Result<usize, Error> {
                let mut written = 0;
                $(
                    written += self.$field.serialize(&mut cursor[written..])?;
                )*
                Ok(written)
            }

            fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
                let cursor = cursor;
                $(
                    let cursor = self.$field.deserialize(cursor)?;
                )*
                Ok(cursor)
            }
        }
    };
}
143
/// Define a `u16`-backed enum with explicit discriminants.
///
/// Alongside the enum, the macro generates:
/// - a checked `TryFrom<u16>` that yields `InvalidCode` for unknown values;
/// - a `From<Enum> for u16`;
/// - a `Serialize` impl that reads and writes the discriminant as a `u16`.
macro_rules! num_enum {
    (
        $(#[$outer:meta])*
        pub enum $name:ident {
            $(
                $(#[$inner:meta])*
                $variant:ident = $value:expr,
            )*
        }
    ) => {
        $(#[$outer])*
        #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
        #[repr(u16)]
        // New codes may be added in the future.
        #[non_exhaustive]
        pub enum $name {
            $(
                $(#[$inner])*
                $variant = $value,
            )*
        }

        impl TryFrom<u16> for $name {
            type Error = InvalidCode;

            fn try_from(raw: u16) -> Result<Self, Self::Error> {
                // Check each known discriminant in declaration order.
                $(
                    if raw == $value {
                        return Ok(Self::$variant);
                    }
                )*
                Err(InvalidCode(raw))
            }
        }

        impl From<$name> for u16 {
            fn from(value: $name) -> Self {
                value as Self
            }
        }

        impl<'a> Serialize<'a> for $name {
            fn serialized_len(&self) -> usize {
                mem::size_of::<u16>()
            }

            fn serialize(&self, cursor: &mut [u8]) -> Result<usize, Error> {
                u16::from(*self).serialize(cursor)
            }

            fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
                let mut raw = 0u16;
                let rest = raw.deserialize(cursor)?;
                *self = Self::try_from(raw)?;
                Ok(rest)
            }
        }
    };
}
205
/// An error that may occur while using the DNS protocol.
///
/// This enum is `#[non_exhaustive]`: new variants may be added without a breaking change, so
/// code outside this crate must match it with a wildcard arm.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// We are trying to write to a buffer, but the buffer doesn't have enough space.
    ///
    /// This error can be fixed by increasing the size of the buffer.
    NotEnoughWriteSpace {
        /// The number of entries we tried to write.
        tried_to_write: NonZeroUsize,

        /// The number of entries that were available in the buffer.
        available: usize,

        /// The type of the buffer that we tried to write to.
        buffer_type: &'static str,
    },

    /// We attempted to read from a buffer, but we ran out of room before we could read the entire
    /// value.
    ///
    /// This error can be fixed by reading more bytes.
    NotEnoughReadBytes {
        /// The number of bytes we tried to read.
        tried_to_read: NonZeroUsize,

        /// The number of bytes that were available in the buffer.
        available: usize,
    },

    /// We tried to parse this value, but it was invalid.
    Parse {
        /// The name of the value we tried to parse.
        name: &'static str,
    },

    /// We tried to serialize a string longer than 256 bytes.
    ///
    /// The payload is the offending length in bytes.
    // NOTE(review): RFC 1035 limits a full name to 255 octets and a single label to 63; confirm
    // which limit the `ser` module actually enforces and update the wording above to match.
    NameTooLong(usize),

    /// We could not create a valid UTF-8 string from the bytes we read.
    InvalidUtf8(str::Utf8Error),

    /// We could not convert a raw number to a code.
    ///
    /// This covers the numeric enums in this crate, such as the record type.
    InvalidCode(InvalidCode),

    /// We do not support this many URL segments.
    ///
    /// The payload is the number of segments that was encountered.
    TooManyUrlSegments(usize),
}
254
255impl fmt::Display for Error {
256    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
257        match self {
258            Error::NotEnoughWriteSpace {
259                tried_to_write,
260                available,
261                buffer_type,
262            } => {
263                write!(
264                    f,
265                    "not enough write space: tried to write {} entries to {} buffer, but only {} were available",
266                    tried_to_write, buffer_type, available
267                )
268            }
269            Error::NotEnoughReadBytes {
270                tried_to_read,
271                available,
272            } => {
273                write!(
274                    f,
275                    "not enough read bytes: tried to read {} bytes, but only {} were available",
276                    tried_to_read, available
277                )
278            }
279            Error::Parse { name } => {
280                write!(f, "parse error: could not parse a {}", name)
281            }
282            Error::NameTooLong(len) => {
283                write!(f, "name too long: name was {} bytes long", len)
284            }
285            Error::InvalidUtf8(err) => {
286                write!(f, "invalid UTF-8: {}", err)
287            }
288            Error::TooManyUrlSegments(segments) => {
289                write!(f, "too many URL segments: {} segments", segments)
290            }
291            Error::InvalidCode(err) => {
292                write!(f, "{}", err)
293            }
294        }
295    }
296}
297
#[cfg(feature = "std")]
impl StdError for Error {
    /// Expose the wrapped lower-level error, for the variants that carry one.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        if let Error::InvalidUtf8(utf8_err) = self {
            return Some(utf8_err);
        }
        if let Error::InvalidCode(code) = self {
            return Some(code);
        }
        None
    }
}
308
309impl From<str::Utf8Error> for Error {
310    fn from(err: str::Utf8Error) -> Self {
311        Error::InvalidUtf8(err)
312    }
313}
314
315impl From<InvalidCode> for Error {
316    fn from(err: InvalidCode) -> Self {
317        Error::InvalidCode(err)
318    }
319}
320
/// The message for a DNS query.
///
/// A `Message` does not allocate. It borrows caller-provided buffers for each section, and the
/// section counts in its header select how many leading entries of each buffer belong to the
/// message. The accessor methods (`questions()`, `answers()`, ...) return only those entries.
// NOTE(review): the derived `PartialEq`/`Ord`/`Hash` compare the entire backing buffers,
// including entries beyond the header counts, so two messages with identical logical content
// can compare unequal. Confirm whether that is intended.
#[derive(Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Message<'arrays, 'innards> {
    /// The header of the message.
    header: Header,

    /// The questions in the message.
    ///
    /// **Invariant:** This is always greater than or equal to `header.question_count`.
    questions: &'arrays mut [Question<'innards>],

    /// The answers in the message.
    ///
    /// **Invariant:** This is always greater than or equal to `header.answer_count`.
    answers: &'arrays mut [ResourceRecord<'innards>],

    /// The authorities in the message.
    ///
    /// **Invariant:** This is always greater than or equal to `header.authority_count`.
    authorities: &'arrays mut [ResourceRecord<'innards>],

    /// The additional records in the message.
    ///
    /// **Invariant:** This is always greater than or equal to `header.additional_count`.
    additional: &'arrays mut [ResourceRecord<'innards>],
}
347
348impl fmt::Debug for Message<'_, '_> {
349    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350        f.debug_struct("Message")
351            .field("header", &self.header)
352            .field("questions", &self.questions())
353            .field("answers", &self.answers())
354            .field("authorities", &self.authorities())
355            .field("additional", &self.additional())
356            .finish()
357    }
358}
359
360impl<'arrays, 'innards> Message<'arrays, 'innards> {
361    /// Create a new message from a set of buffers for each section.
362    ///
363    /// # Panics
364    ///
365    /// This function panics if the number of questions, answers, authorities, or additional records
366    /// is greater than `u16::MAX`.
367    pub fn new(
368        id: u16,
369        flags: Flags,
370        questions: &'arrays mut [Question<'innards>],
371        answers: &'arrays mut [ResourceRecord<'innards>],
372        authorities: &'arrays mut [ResourceRecord<'innards>],
373        additional: &'arrays mut [ResourceRecord<'innards>],
374    ) -> Self {
375        Self {
376            header: Header {
377                id,
378                flags,
379                question_count: questions.len().try_into().unwrap(),
380                answer_count: answers.len().try_into().unwrap(),
381                authority_count: authorities.len().try_into().unwrap(),
382                additional_count: additional.len().try_into().unwrap(),
383            },
384            questions,
385            answers,
386            authorities,
387            additional,
388        }
389    }
390
391    /// Get the ID of this message.
392    pub fn id(&self) -> u16 {
393        self.header.id
394    }
395
396    /// Get a mutable reference to the ID of this message.
397    pub fn id_mut(&mut self) -> &mut u16 {
398        &mut self.header.id
399    }
400
401    /// Get the header of this message.
402    pub fn header(&self) -> Header {
403        self.header
404    }
405
406    /// Get the flags for this message.
407    pub fn flags(&self) -> Flags {
408        self.header.flags
409    }
410
411    /// Get a mutable reference to the flags for this message.
412    pub fn flags_mut(&mut self) -> &mut Flags {
413        &mut self.header.flags
414    }
415
416    /// Get the questions in this message.
417    pub fn questions(&self) -> &[Question<'innards>] {
418        &self.questions[..self.header.question_count as usize]
419    }
420
421    /// Get a mutable reference to the questions in this message.
422    pub fn questions_mut(&mut self) -> &mut [Question<'innards>] {
423        &mut self.questions[..self.header.question_count as usize]
424    }
425
426    /// Get the answers in this message.
427    pub fn answers(&self) -> &[ResourceRecord<'innards>] {
428        &self.answers[..self.header.answer_count as usize]
429    }
430
431    /// Get a mutable reference to the answers in this message.
432    pub fn answers_mut(&mut self) -> &mut [ResourceRecord<'innards>] {
433        &mut self.answers[..self.header.answer_count as usize]
434    }
435
436    /// Get the authorities in this message.
437    pub fn authorities(&self) -> &[ResourceRecord<'innards>] {
438        &self.authorities[..self.header.authority_count as usize]
439    }
440
441    /// Get a mutable reference to the authorities in this message.
442    pub fn authorities_mut(&mut self) -> &mut [ResourceRecord<'innards>] {
443        &mut self.authorities[..self.header.authority_count as usize]
444    }
445
446    /// Get the additional records in this message.
447    pub fn additional(&self) -> &[ResourceRecord<'innards>] {
448        &self.additional[..self.header.additional_count as usize]
449    }
450
451    /// Get a mutable reference to the additional records in this message.
452    pub fn additional_mut(&mut self) -> &mut [ResourceRecord<'innards>] {
453        &mut self.additional[..self.header.additional_count as usize]
454    }
455
456    /// Get the buffer space needed to serialize this message.
457    pub fn space_needed(&self) -> usize {
458        self.serialized_len()
459    }
460
461    /// Write this message to a buffer.
462    ///
463    /// Returns the number of bytes written.
464    ///
465    /// # Errors
466    ///
467    /// This function may raise [`Error::NameTooLong`] if a `Label` is too long to be serialized.
468    ///
469    /// # Panics
470    ///
471    /// This function panics if the buffer is not large enough to hold the serialized message. This
472    /// panic can be avoided by ensuring the buffer contains at least [`space_needed`] bytes.
473    pub fn write(&self, buffer: &mut [u8]) -> Result<usize, Error> {
474        self.serialize(buffer)
475    }
476
477    /// Read a message from a buffer.
478    ///
479    /// # Errors
480    ///
481    /// This function may raise one of the following errors:
482    ///
483    /// - [`Error::NotEnoughReadBytes`] if the buffer is not large enough to hold the entire structure.
484    ///   You may need to read more data before calling this function again.
485    /// - [`Error::NotEnoughWriteSpace`] if the buffers provided are not large enough to hold the
486    ///   entire structure. You may need to allocate larger buffers before calling this function.
487    /// - [`Error::InvalidUtf8`] if a domain name contains invalid UTF-8.
488    /// - [`Error::NameTooLong`] if a domain name is too long to be deserialized.
489    /// - [`Error::InvalidCode`] if a domain name contains an invalid label code.
490    pub fn read(
491        buffer: &'innards [u8],
492        questions: &'arrays mut [Question<'innards>],
493        answers: &'arrays mut [ResourceRecord<'innards>],
494        authorities: &'arrays mut [ResourceRecord<'innards>],
495        additional: &'arrays mut [ResourceRecord<'innards>],
496    ) -> Result<Message<'arrays, 'innards>, Error> {
497        // Create a message and cursor, then deserialize the message from the cursor.
498        let mut message = Message::new(
499            0,
500            Flags::default(),
501            questions,
502            answers,
503            authorities,
504            additional,
505        );
506        let cursor = Cursor::new(buffer);
507
508        message.deserialize(cursor)?;
509
510        Ok(message)
511    }
512}
513
514impl<'arrays, 'innards> Serialize<'innards> for Message<'arrays, 'innards> {
515    fn serialized_len(&self) -> usize {
516        iter::once(self.header.serialized_len())
517            .chain(self.questions().iter().map(Serialize::serialized_len))
518            .chain(self.answers().iter().map(Serialize::serialized_len))
519            .chain(self.authorities().iter().map(Serialize::serialized_len))
520            .chain(self.additional().iter().map(Serialize::serialized_len))
521            .fold(0, |a, b| a.saturating_add(b))
522    }
523
524    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
525        let mut offset = 0;
526        offset += self.header.serialize(&mut bytes[offset..])?;
527        for question in self.questions.iter() {
528            offset += question.serialize(&mut bytes[offset..])?;
529        }
530        for answer in self.answers.iter() {
531            offset += answer.serialize(&mut bytes[offset..])?;
532        }
533        for authority in self.authorities.iter() {
534            offset += authority.serialize(&mut bytes[offset..])?;
535        }
536        for additional in self.additional.iter() {
537            offset += additional.serialize(&mut bytes[offset..])?;
538        }
539        Ok(offset)
540    }
541
542    fn deserialize(&mut self, cursor: Cursor<'innards>) -> Result<Cursor<'innards>, Error> {
543        /// Read a set of `T`, bounded by `count`.
544        fn try_read_set<'a, T: Serialize<'a>>(
545            mut cursor: Cursor<'a>,
546            count: usize,
547            items: &mut [T],
548            name: &'static str,
549        ) -> Result<Cursor<'a>, Error> {
550            let len = items.len();
551
552            if count == 0 {
553                return Ok(cursor);
554            }
555
556            for i in 0..count {
557                cursor = items
558                    .get_mut(i)
559                    .ok_or_else(|| Error::NotEnoughWriteSpace {
560                        tried_to_write: NonZeroUsize::new(count).unwrap(),
561                        available: len,
562                        buffer_type: name,
563                    })?
564                    .deserialize(cursor)?;
565            }
566
567            Ok(cursor)
568        }
569
570        let cursor = self.header.deserialize(cursor)?;
571
572        // If we're truncated, we can't read the rest of the message.
573        if self.header.flags.truncated() {
574            self.header.clear();
575            return Ok(cursor);
576        }
577
578        let cursor = try_read_set(
579            cursor,
580            self.header.question_count as usize,
581            self.questions,
582            "Question",
583        )?;
584        let cursor = try_read_set(
585            cursor,
586            self.header.answer_count as usize,
587            self.answers,
588            "Answer",
589        )?;
590        let cursor = try_read_set(
591            cursor,
592            self.header.authority_count as usize,
593            self.authorities,
594            "Authority",
595        )?;
596        let cursor = try_read_set(
597            cursor,
598            self.header.additional_count as usize,
599            self.additional,
600            "Additional",
601        )?;
602
603        Ok(cursor)
604    }
605}
606
serialize! {
    /// The header for a DNS query.
    ///
    /// Fields are serialized in declaration order: the ID, the flags word, and then the four
    /// section counts. Each field uses its own `Serialize` implementation.
    #[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct Header {
        /// The ID of this query.
        id: u16,

        /// The flags associated with this query.
        flags: Flags,

        /// The number of questions in this query.
        question_count: u16,

        /// The number of answers in this query.
        answer_count: u16,

        /// The number of authorities in this query.
        authority_count: u16,

        /// The number of additional records in this query.
        additional_count: u16,
    }
}
630
631impl Header {
632    fn clear(&mut self) {
633        self.question_count = 0;
634        self.answer_count = 0;
635        self.authority_count = 0;
636        self.additional_count = 0;
637    }
638}
639
serialize! {
    /// The question in a DNS query.
    ///
    /// Fields are serialized in declaration order: name, type, then class.
    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct Question<'a> {
        /// The name of the question.
        name: Label<'a>,

        /// The type of the question.
        ty: ResourceType,

        /// The class of the question, as a raw number (presumably 1 for the Internet class).
        class: u16,
    }
}
654
655impl<'a> Question<'a> {
656    /// Create a new question.
657    pub fn new(label: impl Into<Label<'a>>, ty: ResourceType, class: u16) -> Self {
658        Self {
659            name: label.into(),
660            ty,
661            class,
662        }
663    }
664
665    /// Get the name of the question.
666    pub fn name(&self) -> Label<'a> {
667        self.name
668    }
669
670    /// Get the type of the question.
671    pub fn ty(&self) -> ResourceType {
672        self.ty
673    }
674
675    /// Get the class of the question.
676    pub fn class(&self) -> u16 {
677        self.class
678    }
679}
680
serialize! {
    /// A resource record in a DNS query.
    ///
    /// Fields are serialized in declaration order: name, type, class, TTL, then the
    /// length-prefixed record data.
    #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct ResourceRecord<'a> {
        /// The name of the resource record.
        name: Label<'a>,

        /// The type of the resource record.
        ty: ResourceType,

        /// The class of the resource record, as a raw number.
        class: u16,

        /// The time-to-live of the resource record.
        ttl: u32,

        /// The data of the resource record, borrowed rather than copied.
        data: ResourceData<'a>,
    }
}
701
702impl<'a> ResourceRecord<'a> {
703    /// Create a new `ResourceRecord`.
704    pub fn new(
705        name: impl Into<Label<'a>>,
706        ty: ResourceType,
707        class: u16,
708        ttl: u32,
709        data: &'a [u8],
710    ) -> Self {
711        Self {
712            name: name.into(),
713            ty,
714            class,
715            ttl,
716            data: data.into(),
717        }
718    }
719
720    /// Get the name of the resource record.
721    pub fn name(&self) -> Label<'a> {
722        self.name
723    }
724
725    /// Get the type of the resource record.
726    pub fn ty(&self) -> ResourceType {
727        self.ty
728    }
729
730    /// Get the class of the resource record.
731    pub fn class(&self) -> u16 {
732        self.class
733    }
734
735    /// Get the time-to-live of the resource record.
736    pub fn ttl(&self) -> u32 {
737        self.ttl
738    }
739
740    /// Get the data of the resource record.
741    pub fn data(&self) -> &'a [u8] {
742        self.data.0
743    }
744}
745
/// The raw data payload carried by a resource record, borrowed from its source.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct ResourceData<'a>(&'a [u8]);

impl<'a> From<&'a [u8]> for ResourceData<'a> {
    /// Wrap a borrowed byte slice without copying it.
    fn from(bytes: &'a [u8]) -> Self {
        Self(bytes)
    }
}
755
756impl<'a> Serialize<'a> for ResourceData<'a> {
757    fn serialized_len(&self) -> usize {
758        2 + self.0.len()
759    }
760
761    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
762        let len = self.serialized_len();
763        if bytes.len() < len {
764            panic!("not enough bytes to serialize resource data");
765        }
766
767        // Write the length as a big-endian u16
768        let [b1, b2] = (self.0.len() as u16).to_be_bytes();
769        bytes[0] = b1;
770        bytes[1] = b2;
771
772        // Write the data
773        bytes[2..len].copy_from_slice(self.0);
774
775        Ok(len)
776    }
777
778    fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
779        // Deserialize a u16 for the length
780        let mut len = 0u16;
781        let cursor = len.deserialize(cursor)?;
782
783        if len == 0 {
784            self.0 = &[];
785            return Ok(cursor);
786        }
787
788        // Read in the data
789        if cursor.len() < len as usize {
790            return Err(Error::NotEnoughReadBytes {
791                tried_to_read: NonZeroUsize::new(len as usize).unwrap(),
792                available: cursor.len(),
793            });
794        }
795
796        self.0 = &cursor.remaining()[..len as usize];
797        cursor.advance(len as usize)
798    }
799}
800
/// The flags associated with a DNS message.
///
/// This wraps the raw 16-bit flags word of the message header. Use the accessor methods to read
/// or modify individual fields, or `raw()` to obtain the whole word.
#[derive(Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct Flags(u16);
805
806impl Flags {
807    // Values used to manipulate the inside.
808    const RAW_QR: u16 = 1 << 15;
809    const RAW_OPCODE_SHIFT: u16 = 11;
810    const RAW_OPCODE_MASK: u16 = 0b1111;
811    const RAW_AA: u16 = 1 << 10;
812    const RAW_TC: u16 = 1 << 9;
813    const RAW_RD: u16 = 1 << 8;
814    const RAW_RA: u16 = 1 << 7;
815    const RAW_RCODE_SHIFT: u16 = 0;
816    const RAW_RCODE_MASK: u16 = 0b1111;
817
818    /// Create a new, empty set of flags.
819    pub const fn new() -> Self {
820        Self(0)
821    }
822
823    /// Use the standard set of flags for a DNS query.
824    ///
825    /// This is identical to `new()` but uses recursive querying.
826    pub const fn standard_query() -> Self {
827        Self(0x0100)
828    }
829
830    /// Get the query/response flag.
831    pub fn qr(&self) -> MessageType {
832        if self.0 & Self::RAW_QR != 0 {
833            MessageType::Reply
834        } else {
835            MessageType::Query
836        }
837    }
838
839    /// Set the message's query/response flag.
840    pub fn set_qr(&mut self, qr: MessageType) -> &mut Self {
841        if qr == MessageType::Reply {
842            self.0 |= Self::RAW_QR;
843        } else {
844            self.0 &= !Self::RAW_QR;
845        }
846
847        self
848    }
849
850    /// Get the opcode.
851    pub fn opcode(&self) -> Opcode {
852        let raw = (self.0 >> Self::RAW_OPCODE_SHIFT) & Self::RAW_OPCODE_MASK;
853        raw.try_into()
854            .unwrap_or_else(|_| panic!("invalid opcode: {}", raw))
855    }
856
857    /// Set the opcode.
858    pub fn set_opcode(&mut self, opcode: Opcode) {
859        self.0 |= (opcode as u16) << Self::RAW_OPCODE_SHIFT;
860    }
861
862    /// Get whether this message is authoritative.
863    pub fn authoritative(&self) -> bool {
864        self.0 & Self::RAW_AA != 0
865    }
866
867    /// Set whether this message is authoritative.
868    pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
869        if authoritative {
870            self.0 |= Self::RAW_AA;
871        } else {
872            self.0 &= !Self::RAW_AA;
873        }
874
875        self
876    }
877
878    /// Get whether this message is truncated.
879    pub fn truncated(&self) -> bool {
880        self.0 & Self::RAW_TC != 0
881    }
882
883    /// Set whether this message is truncated.
884    pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
885        if truncated {
886            self.0 |= Self::RAW_TC;
887        } else {
888            self.0 &= !Self::RAW_TC;
889        }
890
891        self
892    }
893
894    /// Get whether this message is recursive.
895    pub fn recursive(&self) -> bool {
896        self.0 & Self::RAW_RD != 0
897    }
898
899    /// Set whether this message is recursive.
900    pub fn set_recursive(&mut self, recursive: bool) -> &mut Self {
901        if recursive {
902            self.0 |= Self::RAW_RD;
903        } else {
904            self.0 &= !Self::RAW_RD;
905        }
906
907        self
908    }
909
910    /// Get whether recursion is available for this message.
911    pub fn recursion_available(&self) -> bool {
912        self.0 & Self::RAW_RA != 0
913    }
914
915    /// Set whether recursion is available for this message.
916    pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
917        if recursion_available {
918            self.0 |= Self::RAW_RA;
919        } else {
920            self.0 &= !Self::RAW_RA;
921        }
922
923        self
924    }
925
926    /// Get the response code.
927    pub fn response_code(&self) -> ResponseCode {
928        let raw = (self.0 >> Self::RAW_RCODE_SHIFT) & Self::RAW_RCODE_MASK;
929        raw.try_into()
930            .unwrap_or_else(|_| panic!("invalid response code: {}", raw))
931    }
932
933    /// Set the response code.
934    pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
935        self.0 |= (response_code as u16) << Self::RAW_RCODE_SHIFT;
936        self
937    }
938
939    /// Get the raw value of these flags.
940    pub fn raw(self) -> u16 {
941        self.0
942    }
943}
944
945impl fmt::Debug for Flags {
946    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
947        let mut list = f.debug_list();
948
949        list.entry(&self.qr());
950        list.entry(&self.opcode());
951
952        if self.authoritative() {
953            list.entry(&"authoritative");
954        }
955
956        if self.truncated() {
957            list.entry(&"truncated");
958        }
959
960        if self.recursive() {
961            list.entry(&"recursive");
962        }
963
964        if self.recursion_available() {
965            list.entry(&"recursion available");
966        }
967
968        list.entry(&self.response_code());
969
970        list.finish()
971    }
972}
973
974impl<'a> Serialize<'a> for Flags {
975    fn serialized_len(&self) -> usize {
976        2
977    }
978
979    fn serialize(&self, buf: &mut [u8]) -> Result<usize, Error> {
980        self.0.serialize(buf)
981    }
982
983    fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
984        u16::deserialize(&mut self.0, bytes)
985    }
986}
987
/// Whether a message is a query or a reply.
///
/// This is the decoded form of the QR bit in the header flags (see `Flags::qr` and
/// `Flags::set_qr`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MessageType {
    /// The message is a query (QR bit clear).
    Query,

    /// The message is a reply (QR bit set).
    Reply,
}
997
num_enum! {
    /// The operation code for the query.
    ///
    /// Value 3 is not listed here, so converting it (e.g. via `Flags::opcode`) fails.
    pub enum Opcode {
        /// A standard query.
        Query = 0,

        /// An inverse query (IQUERY).
        IQuery = 1,

        /// A server status request.
        Status = 2,

        /// A notification of zone change.
        Notify = 4,

        /// A dynamic update.
        Update = 5,

        /// DNS Stateful Operations (DSO) query.
        Dso = 6,
    }
}
1020
num_enum! {
    /// The response code for a query.
    ///
    /// Values 16 and above are EDNS extended response codes. Only their low four bits fit
    /// in the message header; the remaining bits are carried in the OPT record.
    pub enum ResponseCode {
        /// No error condition (NOERROR).
        NoError = 0,

        /// The server was unable to interpret the query (FORMERR).
        FormatError = 1,

        /// The server was unable to process the query due to a problem with the server
        /// (SERVFAIL).
        ServerFailure = 2,

        /// The domain name referenced in the query does not exist (NXDOMAIN).
        NameError = 3,

        /// The server does not support the requested kind of query (NOTIMP).
        NotImplemented = 4,

        /// The server refuses to perform the requested operation (REFUSED).
        Refused = 5,

        /// A name that ought not to exist does exist (YXDOMAIN).
        YxDomain = 6,

        /// An RRset that ought not to exist does exist (YXRRSET).
        YxRrSet = 7,

        /// An RRset that ought to exist does not exist (NXRRSET).
        NxRrSet = 8,

        /// The server is not authoritative for the zone, or the request is not
        /// authorized (NOTAUTH).
        NotAuth = 9,

        /// A name used in the request is not within the zone it targets (NOTZONE).
        NotZone = 10,

        /// The DSO-TYPE is not implemented (DSOTYPENI).
        DsoTypeNi = 11,

        /// Bad OPT version (BADVERS). TSIG uses the same value for BADSIG.
        BadVers = 16,

        /// Key not recognized (BADKEY).
        BadKey = 17,

        /// Signature out of time window (BADTIME).
        BadTime = 18,

        /// Bad TKEY mode (BADMODE).
        BadMode = 19,

        /// Duplicate key name (BADNAME).
        BadName = 20,

        /// Algorithm not supported (BADALG).
        BadAlg = 21,

        /// Bad truncation (BADTRUNC).
        BadTrunc = 22,

        /// Bad/missing server cookie (BADCOOKIE).
        BadCookie = 23,
    }
}
1085
num_enum! {
    /// The resource types that a question can ask for.
    pub enum ResourceType {
        /// Get the host's IPv4 address.
        A = 1,

        /// Get the authoritative name servers for a domain.
        NS = 2,

        /// Get the mail destination for a domain.
        ///
        /// This is obsolete; use `MX` instead.
        MD = 3,

        /// Get the mail forwarder for a domain.
        ///
        /// This is obsolete; use `MX` instead.
        MF = 4,

        /// Get the canonical name for a domain.
        CName = 5,

        /// Get the start of authority record for a domain.
        Soa = 6,

        /// Get the mailbox for a domain.
        MB = 7,

        /// Get the mail group member for a domain.
        MG = 8,

        /// Get the mail rename domain for a domain.
        MR = 9,

        /// Get the null record for a domain.
        Null = 10,

        /// Get the well known services for a domain.
        Wks = 11,

        /// Get the domain name pointer for a domain.
        Ptr = 12,

        /// Get the host information for a domain.
        HInfo = 13,

        /// Get the mailbox or mail list information for a domain.
        MInfo = 14,

        /// Get the mail exchange for a domain.
        MX = 15,

        /// Get the text for a domain.
        Txt = 16,

        /// Get the responsible person for a domain.
        RP = 17,

        /// Get the AFS database location for a domain.
        AfsDb = 18,

        /// Get the X.25 address for a domain.
        X25 = 19,

        /// Get the ISDN address for a domain.
        Isdn = 20,

        /// Get the route-through host for a domain.
        Rt = 21,

        /// Get the NSAP address for a domain.
        NSap = 22,

        /// Get the domain name pointer for an NSAP address.
        NSapPtr = 23,

        /// Get the security signature for a domain.
        Sig = 24,

        /// Get the key for a domain.
        Key = 25,

        /// Get the X.400 mail mapping for a domain.
        Px = 26,

        /// Get the geographical position for a domain.
        GPos = 27,

        /// Get the IPv6 address for a domain.
        AAAA = 28,

        /// Get the location for a domain.
        Loc = 29,

        /// Get the next domain name in a zone.
        ///
        /// This is obsolete; use `NSEC` instead.
        Nxt = 30,

        /// Get the endpoint identifier for a domain.
        EId = 31,

        /// Get the Nimrod locator for a domain.
        NimLoc = 32,

        /// Get the server selection (service location) for a domain.
        Srv = 33,

        /// Get the ATM address for a domain.
        AtmA = 34,

        /// Get the naming authority pointer for a domain.
        NAPtr = 35,

        /// Get the key exchange for a domain.
        Kx = 36,

        /// Get the certificate for a domain.
        Cert = 37,

        /// Get the IPv6 address for a domain.
        ///
        /// This is obsolete; use `AAAA` instead.
        A6 = 38,

        /// Get the delegation name (DNAME) for a domain.
        DName = 39,

        /// Get the sink for a domain.
        Sink = 40,

        /// The EDNS OPT pseudo-record.
        ///
        /// This carries message-level options rather than data about a domain.
        Opt = 41,

        /// Get the address prefix list for a domain.
        ApL = 42,

        /// Get the delegation signer for a domain.
        DS = 43,

        /// Get the SSH key fingerprint for a domain.
        SshFp = 44,

        /// Get the IPSEC key for a domain.
        IpSecKey = 45,

        /// Get the resource record signature for a domain.
        RRSig = 46,

        /// Get the next secure record for a domain.
        NSEC = 47,

        /// Get the DNSKEY for a domain.
        DNSKey = 48,

        /// Get the DHCID for a domain.
        DHCID = 49,

        /// Get the NSEC3 for a domain.
        NSEC3 = 50,

        /// Get the NSEC3 parameters for a domain.
        NSEC3Param = 51,

        /// Get the TLSA for a domain.
        TLSA = 52,

        /// Get the S/MIME certificate association for a domain.
        SMimeA = 53,

        /// Get the Host Identity Protocol information for a domain.
        HIP = 55,

        /// Get the NINFO for a domain.
        NInfo = 56,

        /// Get the RKEY for a domain.
        RKey = 57,

        /// Get the trust anchor link for a domain.
        TALink = 58,

        /// Get the child DS for a domain.
        CDS = 59,

        /// Get the child DNSKEY for a domain.
        CDNSKey = 60,

        /// Get the OpenPGP key for a domain.
        OpenPGPKey = 61,

        /// Get the Child-to-Parent Synchronization for a domain.
        CSync = 62,

        /// Get the zone message digest for a domain.
        ZoneMD = 63,

        /// Get the General Purpose Service Binding for a domain.
        Svcb = 64,

        /// Get the HTTP Service Binding for a domain.
        Https = 65,

        /// Get the Sender Policy Framework for a domain.
        Spf = 99,

        /// Get the UINFO for a domain.
        UInfo = 100,

        /// Get the UID for a domain.
        UID = 101,

        /// Get the GID for a domain.
        GID = 102,

        /// Get the UNSPEC for a domain.
        Unspec = 103,

        /// Get the NID for a domain.
        NID = 104,

        /// Get the L32 for a domain.
        L32 = 105,

        /// Get the L64 for a domain.
        L64 = 106,

        /// Get the LP for a domain.
        LP = 107,

        /// Get the EUI48 for a domain.
        EUI48 = 108,

        /// Get the EUI64 for a domain.
        EUI64 = 109,

        /// Get the transaction key for a domain.
        TKey = 249,

        /// Get the transaction signature for a domain.
        TSig = 250,

        /// Get the incremental zone transfer for a domain.
        Ixfr = 251,

        /// Get the transfer of an entire zone for a domain.
        Axfr = 252,

        /// Get the mailbox-related records (`MB`, `MG` or `MR`) for a domain.
        MailB = 253,

        /// Get the mail agent RRs for a domain.
        ///
        /// This is obsolete; use `MX` instead.
        MailA = 254,

        /// Request all records for a domain (the `*` / `ANY` query type).
        Wildcard = 255,

        /// Get the URI for a domain.
        Uri = 256,

        /// Get the certification authority authorization for a domain.
        Caa = 257,

        /// Get the application visibility and control for a domain.
        Avc = 258,

        /// Get the digital object architecture for a domain.
        Doa = 259,

        /// Get the Automatic Multicast Tunneling relay for a domain.
        Amtrelay = 260,

        /// Get the DNSSEC trust authorities for a domain.
        TA = 32768,

        /// Get the DNSSEC lookaside validation for a domain.
        DLV = 32769,
    }
}
1361
1362impl Default for ResourceType {
1363    fn default() -> Self {
1364        Self::A
1365    }
1366}
1367
/// The given value is not a valid code.
///
/// Holds the raw numeric value that was rejected. It is a plain value type, so it can be
/// copied, compared and hashed (for example, to check which code a conversion rejected).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InvalidCode(u16);
1371
1372impl InvalidCode {
1373    /// Get the invalid code.
1374    pub fn code(&self) -> u16 {
1375        self.0
1376    }
1377}
1378
1379impl fmt::Display for InvalidCode {
1380    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1381        write!(f, "invalid code: {}", self.0)
1382    }
1383}
1384
// The standard error trait is only implemented when the `std` feature is enabled,
// which keeps the crate usable without `std`. The empty body is enough because the
// trait's required `Debug` and `Display` supertraits are implemented for `InvalidCode`.
#[cfg(feature = "std")]
impl StdError for InvalidCode {}
1387
#[cfg(test)]
mod tests {
    use super::*;

    /// `ResourceData` is written as a two-byte big-endian length followed by the payload.
    #[test]
    fn resource_data_serialization() {
        let payload: [u8; 5] = [0x1f, 0xfe, 0x02, 0x24, 0x75];
        let mut out = [0u8; 7];

        let written = ResourceData(&payload)
            .serialize(&mut out)
            .expect("serialized into provided buffer");
        assert_eq!(written, 7);

        let (prefix, body) = out.split_at(2);
        assert_eq!(prefix, [0x00u8, 0x05]);
        assert_eq!(body, payload);
    }
}
1402}