dns_protocol_patch/
lib.rs

1//! An implementation of the DNS protocol, [sans I/O].
2//!
3//! [sans I/O]: https://sans-io.readthedocs.io/
4//!
//! This crate implements the Domain Name System (DNS) protocol, used for name lookups among
//! other things. It is intended to be used in conjunction with a transport layer, such
//! as UDP, TCP or HTTPS. The goal of this crate is to provide a runtime- and protocol-agnostic
//! implementation of the DNS protocol, so that it can be used in a variety of contexts.
9//!
//! This crate is not only `no_std`, it does not use an allocator either. This means that it
//! can be used on embedded systems that do not have allocators, and that it does no allocation
//! of its own. However, this comes with a catch: the user is expected to provide their own
//! buffers for the various operations. This is done to avoid the need for an allocator, and
//! to allow the user to control the memory usage of the library.
15//!
16//! This crate is also `#![forbid(unsafe_code)]`, and is intended to remain so.
17//!
18//! # Example
19//!
20//! ```no_run
21//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
22//! use dns_protocol::{Message, Question, ResourceRecord, ResourceType, Flags};
23//! use std::net::UdpSocket;
24//!
25//! // Allocate a buffer for the message.
26//! let mut buf = vec![0; 1024];
27//!
//! // Create a message. This is a query for the A record of example.com
//! // in the IN (Internet) class, which has class code 1.
//! let mut questions = [
//!     Question::new(
//!         "example.com",
//!         ResourceType::A,
//!         1,
//!     )
//! ];
//! // The query itself carries no records; this buffer is only used to read the reply.
//! let mut answers = [ResourceRecord::default()];
//! let message = Message::new(0x42, Flags::standard_query(), &mut questions, &mut [], &mut [], &mut []);
38//!
39//! // Serialize the message into the buffer
40//! assert!(message.space_needed() <= buf.len());
41//! let len = message.write(&mut buf)?;
42//!
43//! // Write the buffer to the socket.
44//! let socket = UdpSocket::bind("localhost:0")?;
45//! socket.send_to(&buf[..len], "1.2.3.4:53")?;
46//!
47//! // Read new data from the socket.
48//! let data_len = socket.recv(&mut buf)?;
49//!
50//! // Parse the data as a message.
51//! let message = Message::read(
52//!     &buf[..data_len],
53//!     &mut questions,
54//!     &mut answers,
55//!     &mut [],
56//!     &mut [],
57//! )?;
58//!
59//! // Read the answer from the message.
60//! let answer = message.answers()[0];
61//! println!("Answer Data: {:?}", answer.data());
62//! # Ok(())
63//! # }
64//! ```
65//!
66//! # Features
67//!
//! - `std` (enabled by default) - Enables the `std` library for use in `Error` types.
//!   Disable this feature to use on `no_std` targets.
71
72#![forbid(
73    unsafe_code,
74    missing_docs,
75    missing_debug_implementations,
76    rust_2018_idioms,
77    future_incompatible
78)]
79#![no_std]
80
81#[cfg(feature = "std")]
82extern crate std;
83
84use core::convert::{TryFrom, TryInto};
85use core::fmt;
86use core::iter;
87use core::mem;
88use core::num::NonZeroUsize;
89use core::str;
90
91use core::error::Error as StdError;
92
93mod ser;
94pub use ser::{Cursor, Deserialize, Label, LabelSegment, Serialize};
95
/// Declare a struct and derive field-by-field `Serialize` / `Deserialize` impls for it.
///
/// Fields are encoded and decoded strictly in declaration order, so the order of the
/// fields in an invocation is also their order on the wire.
macro_rules! serialize {
    (
        $(#[$outer:meta])*
        pub struct $name:ident $(<$lt: lifetime>)? {
            $(
                $(#[$inner:meta])*
                $vis: vis $field:ident: $ty:ty,
            )*
        }
    ) => {
        $(#[$outer])*
        pub struct $name $(<$lt>)? {
            $(
                $(#[$inner])*
                $vis $field: $ty,
            )*
        }

        impl<'a> Serialize<'a> for $name $(<$lt>)? {
            fn serialized_len(&self) -> usize {
                // Total encoded size is the sum of every field's encoded size.
                0 $( + self.$field.serialized_len() )*
            }

            fn serialize(&self, cursor: &mut [u8]) -> Result<usize, Error> {
                let mut written = 0;
                $(
                    // Each field writes into whatever space the previous fields left over.
                    let chunk = &mut cursor[written..];
                    written += self.$field.serialize(chunk)?;
                )*
                Ok(written)
            }
        }

        impl<'a> Deserialize<'a> for $name $(<$lt>)? {
            fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
                $(
                    // Thread the cursor through every field, shadowing it as we go.
                    let cursor = self.$field.deserialize(cursor)?;
                )*
                Ok(cursor)
            }
        }
    };
}
143
/// Declare a `u16`-backed enum together with its conversions and wire codecs.
///
/// Each invocation generates:
/// - the enum itself, `#[repr(u16)]` and `#[non_exhaustive]`, with the listed discriminants;
/// - `TryFrom<u16>`, which rejects any value that is not listed with `InvalidCode`;
/// - `From<Enum> for u16`;
/// - `Serialize` / `Deserialize` impls that encode the enum as its `u16` value.
macro_rules! num_enum {
    (
        $(#[$outer:meta])*
        pub enum $name:ident {
            $(
                $(#[$inner:meta])*
                $variant:ident = $value:expr,
            )*
        }
    ) => {
        $(#[$outer])*
        #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
        #[repr(u16)]
        // New codes may be added in the future.
        #[non_exhaustive]
        pub enum $name {
            $(
                $(#[$inner])*
                $variant = $value,
            )*
        }

        impl TryFrom<u16> for $name {
            type Error = InvalidCode;

            fn try_from(value: u16) -> Result<Self, Self::Error> {
                match value {
                    $(
                        $value => Ok($name::$variant),
                    )*
                    _ => Err(InvalidCode(value)),
                }
            }
        }

        impl From<$name> for u16 {
            fn from(value: $name) -> Self {
                value as u16
            }
        }

        impl<'a> Serialize<'a> for $name {
            fn serialized_len(&self) -> usize {
                mem::size_of::<u16>()
            }

            fn serialize(&self, cursor: &mut [u8]) -> Result<usize, Error> {
                let value: u16 = (*self).into();
                value.serialize(cursor)
            }
        }

        impl<'a> Deserialize<'a> for $name {
            fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
                // The integer type must be explicit: calling a trait method on a binding
                // of ambiguous numeric type (`let mut value = 0;`) is rejected (E0689).
                let mut value = 0u16;
                let cursor = value.deserialize(cursor)?;
                // Unknown codes surface as `Error::InvalidCode` via `From<InvalidCode>`.
                *self = $name::try_from(value)?;
                Ok(cursor)
            }
        }
    };
}
207
208/// The buffer type when dealing with DNS messages.
209#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::IsVariant, derive_more::Display)]
210pub enum BufferType {
211    /// The question buffer.
212    #[display("question")]
213    Question,
214    /// The answer buffer.
215    #[display("answer")]
216    Answer,
217    /// The authority buffer.
218    #[display("authority")]
219    Authority,
220    /// The additional buffer.
221    #[display("additional")]
222    Additional,
223}
224
/// An error that may occur while using the DNS protocol.
///
/// This single error type is shared by serialization (e.g. `Message::write`) and parsing
/// (e.g. `Message::read`).
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// We are trying to write to a buffer, but the buffer doesn't have enough space.
    ///
    /// For example, `Message::read` raises this when the header announces more entries for
    /// a section than the caller-provided buffer for that section can hold.
    ///
    /// This error can be fixed by increasing the size of the buffer.
    NotEnoughWriteSpace {
        /// The number of entries we tried to write (for parsing, the section count from the
        /// message header).
        tried_to_write: NonZeroUsize,

        /// The number of entries that were available in the buffer.
        available: usize,

        /// The type of the buffer that we tried to write to.
        buffer_type: BufferType,
    },

    /// We attempted to read from a buffer, but we ran out of room before we could read the entire
    /// value.
    ///
    /// This error can be fixed by reading more bytes.
    NotEnoughReadBytes {
        /// The number of bytes we tried to read.
        tried_to_read: NonZeroUsize,

        /// The number of bytes that were available in the buffer.
        available: usize,
    },

    /// We tried to parse this value, but it was invalid.
    Parse {
        /// A short description of the value we tried to parse; it is interpolated into the
        /// `Display` message as `could not parse a {name}`.
        name: &'static str,
    },

    /// A name was too long to be serialized; carries the name's length in bytes.
    ///
    /// NOTE(review): the limit itself is enforced outside this file (earlier docs said
    /// "longer than 256 bytes") — confirm against the `ser` module.
    NameTooLong(usize),

    /// We could not create a valid UTF-8 string from the bytes we read.
    InvalidUtf8(simdutf8::compat::Utf8Error),

    /// We could not convert a raw number to a code.
    ///
    /// The wrapped [`InvalidCode`] is also returned from `Error::source`.
    InvalidCode(InvalidCode),

    /// We do not support this many URL segments; carries the number of segments.
    ///
    /// NOTE(review): presumably "segments" are the labels of a domain name — confirm in `ser`.
    TooManyUrlSegments(usize),
}
273
274impl fmt::Display for Error {
275    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276        match self {
277            Error::NotEnoughWriteSpace {
278                tried_to_write,
279                available,
280                buffer_type,
281            } => {
282                write!(
283                    f,
284                    "not enough write space: tried to write {} entries to {} buffer, but only {} were available",
285                    tried_to_write, buffer_type, available
286                )
287            }
288            Error::NotEnoughReadBytes {
289                tried_to_read,
290                available,
291            } => {
292                write!(
293                    f,
294                    "not enough read bytes: tried to read {} bytes, but only {} were available",
295                    tried_to_read, available
296                )
297            }
298            Error::Parse { name } => {
299                write!(f, "parse error: could not parse a {}", name)
300            }
301            Error::NameTooLong(len) => {
302                write!(f, "name too long: name was {} bytes long", len)
303            }
304            Error::InvalidUtf8(err) => {
305                write!(f, "invalid UTF-8: {}", err)
306            }
307            Error::TooManyUrlSegments(segments) => {
308                write!(f, "too many URL segments: {} segments", segments)
309            }
310            Error::InvalidCode(err) => {
311                write!(f, "{}", err)
312            }
313        }
314    }
315}
316
317impl StdError for Error {
318    fn source(&self) -> Option<&(dyn StdError + 'static)> {
319        match self {
320            Error::InvalidCode(err) => Some(err),
321            _ => None,
322        }
323    }
324}
325
326impl From<simdutf8::compat::Utf8Error> for Error {
327    fn from(err: simdutf8::compat::Utf8Error) -> Self {
328        Error::InvalidUtf8(err)
329    }
330}
331
332impl From<InvalidCode> for Error {
333    fn from(err: InvalidCode) -> Self {
334        Error::InvalidCode(err)
335    }
336}
337
/// The message for a DNS query.
///
/// A message is a header plus four sections: questions, answers, authorities and
/// additional records. It owns none of its records. Each section borrows a caller-provided
/// slice for `'arrays`, and the records may in turn borrow data for `'innards` (for example,
/// the input buffer passed to `Message::read`). Only the first `header.*_count` entries of
/// each slice belong to the message, and the accessor methods return exactly that prefix.
///
/// NOTE(review): the derived `PartialEq`/`Ord`/`Hash` impls compare the *entire* backing
/// slices, including entries past the header counts, whereas `Debug` only shows the counted
/// prefix. Two messages that print identically may therefore compare unequal.
#[derive(Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Message<'arrays, 'innards> {
    /// The header of the message, including the per-section entry counts.
    header: Header,

    /// Backing storage for the questions in the message.
    ///
    /// **Invariant:** The length of this slice is always greater than or equal to
    /// `header.question_count`; the accessors slice with that count and would panic otherwise.
    questions: &'arrays mut [Question<'innards>],

    /// Backing storage for the answers in the message.
    ///
    /// **Invariant:** The length of this slice is always greater than or equal to
    /// `header.answer_count`.
    answers: &'arrays mut [ResourceRecord<'innards>],

    /// Backing storage for the authorities in the message.
    ///
    /// **Invariant:** The length of this slice is always greater than or equal to
    /// `header.authority_count`.
    authorities: &'arrays mut [ResourceRecord<'innards>],

    /// Backing storage for the additional records in the message.
    ///
    /// **Invariant:** The length of this slice is always greater than or equal to
    /// `header.additional_count`.
    additional: &'arrays mut [ResourceRecord<'innards>],
}
364
365impl fmt::Debug for Message<'_, '_> {
366    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
367        f.debug_struct("Message")
368            .field("header", &self.header)
369            .field("questions", &self.questions())
370            .field("answers", &self.answers())
371            .field("authorities", &self.authorities())
372            .field("additional", &self.additional())
373            .finish()
374    }
375}
376
377impl<'arrays, 'innards> Message<'arrays, 'innards> {
378    /// Create a new message from a set of buffers for each section.
379    ///
380    /// # Panics
381    ///
382    /// This function panics if the number of questions, answers, authorities, or additional records
383    /// is greater than `u16::MAX`.
384    pub fn new(
385        id: u16,
386        flags: Flags,
387        questions: &'arrays mut [Question<'innards>],
388        answers: &'arrays mut [ResourceRecord<'innards>],
389        authorities: &'arrays mut [ResourceRecord<'innards>],
390        additional: &'arrays mut [ResourceRecord<'innards>],
391    ) -> Self {
392        Self {
393            header: Header {
394                id,
395                flags,
396                question_count: questions.len().try_into().unwrap(),
397                answer_count: answers.len().try_into().unwrap(),
398                authority_count: authorities.len().try_into().unwrap(),
399                additional_count: additional.len().try_into().unwrap(),
400            },
401            questions,
402            answers,
403            authorities,
404            additional,
405        }
406    }
407
408    /// Get the ID of this message.
409    pub fn id(&self) -> u16 {
410        self.header.id
411    }
412
413    /// Get a mutable reference to the ID of this message.
414    pub fn id_mut(&mut self) -> &mut u16 {
415        &mut self.header.id
416    }
417
418    /// Get the header of this message.
419    pub fn header(&self) -> Header {
420        self.header
421    }
422
423    /// Get the flags for this message.
424    pub fn flags(&self) -> Flags {
425        self.header.flags
426    }
427
428    /// Get a mutable reference to the flags for this message.
429    pub fn flags_mut(&mut self) -> &mut Flags {
430        &mut self.header.flags
431    }
432
433    /// Get the questions in this message.
434    pub fn questions(&self) -> &[Question<'innards>] {
435        &self.questions[..self.header.question_count as usize]
436    }
437
438    /// Get a mutable reference to the questions in this message.
439    pub fn questions_mut(&mut self) -> &mut [Question<'innards>] {
440        &mut self.questions[..self.header.question_count as usize]
441    }
442
443    /// Get the answers in this message.
444    pub fn answers(&self) -> &[ResourceRecord<'innards>] {
445        &self.answers[..self.header.answer_count as usize]
446    }
447
448    /// Get a mutable reference to the answers in this message.
449    pub fn answers_mut(&mut self) -> &mut [ResourceRecord<'innards>] {
450        &mut self.answers[..self.header.answer_count as usize]
451    }
452
453    /// Get the authorities in this message.
454    pub fn authorities(&self) -> &[ResourceRecord<'innards>] {
455        &self.authorities[..self.header.authority_count as usize]
456    }
457
458    /// Get a mutable reference to the authorities in this message.
459    pub fn authorities_mut(&mut self) -> &mut [ResourceRecord<'innards>] {
460        &mut self.authorities[..self.header.authority_count as usize]
461    }
462
463    /// Get the additional records in this message.
464    pub fn additional(&self) -> &[ResourceRecord<'innards>] {
465        &self.additional[..self.header.additional_count as usize]
466    }
467
468    /// Get a mutable reference to the additional records in this message.
469    pub fn additional_mut(&mut self) -> &mut [ResourceRecord<'innards>] {
470        &mut self.additional[..self.header.additional_count as usize]
471    }
472
473    /// Get the buffer space needed to serialize this message.
474    pub fn space_needed(&self) -> usize {
475        self.serialized_len()
476    }
477
478    /// Write this message to a buffer.
479    ///
480    /// Returns the number of bytes written.
481    ///
482    /// # Errors
483    ///
484    /// This function may raise [`Error::NameTooLong`] if a `Label` is too long to be serialized.
485    ///
486    /// # Panics
487    ///
488    /// This function panics if the buffer is not large enough to hold the serialized message. This
489    /// panic can be avoided by ensuring the buffer contains at least [`space_needed`] bytes.
490    pub fn write(&self, buffer: &mut [u8]) -> Result<usize, Error> {
491        self.serialize(buffer)
492    }
493
494    /// Read a message from a buffer.
495    ///
496    /// # Errors
497    ///
498    /// This function may raise one of the following errors:
499    ///
500    /// - [`Error::NotEnoughReadBytes`] if the buffer is not large enough to hold the entire structure.
501    ///   You may need to read more data before calling this function again.
502    /// - [`Error::NotEnoughWriteSpace`] if the buffers provided are not large enough to hold the
503    ///   entire structure. You may need to allocate larger buffers before calling this function.
504    /// - [`Error::InvalidUtf8`] if a domain name contains invalid UTF-8.
505    /// - [`Error::NameTooLong`] if a domain name is too long to be deserialized.
506    /// - [`Error::InvalidCode`] if a domain name contains an invalid label code.
507    pub fn read(
508        buffer: &'innards [u8],
509        questions: &'arrays mut [Question<'innards>],
510        answers: &'arrays mut [ResourceRecord<'innards>],
511        authorities: &'arrays mut [ResourceRecord<'innards>],
512        additional: &'arrays mut [ResourceRecord<'innards>],
513    ) -> Result<Message<'arrays, 'innards>, Error> {
514        // Create a message and cursor, then deserialize the message from the cursor.
515        let mut message = Message::new(
516            0,
517            Flags::default(),
518            questions,
519            answers,
520            authorities,
521            additional,
522        );
523        let cursor = Cursor::new(buffer);
524
525        message.deserialize(cursor)?;
526
527        Ok(message)
528    }
529}
530
531impl<'innards> Serialize<'innards> for Message<'_, 'innards> {
532    fn serialized_len(&self) -> usize {
533        iter::once(self.header.serialized_len())
534            .chain(self.questions().iter().map(Serialize::serialized_len))
535            .chain(self.answers().iter().map(Serialize::serialized_len))
536            .chain(self.authorities().iter().map(Serialize::serialized_len))
537            .chain(self.additional().iter().map(Serialize::serialized_len))
538            .fold(0, |a, b| a.saturating_add(b))
539    }
540
541    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
542        let mut offset = 0;
543        offset += self.header.serialize(&mut bytes[offset..])?;
544        for question in self.questions.iter() {
545            offset += question.serialize(&mut bytes[offset..])?;
546        }
547        for answer in self.answers.iter() {
548            offset += answer.serialize(&mut bytes[offset..])?;
549        }
550        for authority in self.authorities.iter() {
551            offset += authority.serialize(&mut bytes[offset..])?;
552        }
553        for additional in self.additional.iter() {
554            offset += additional.serialize(&mut bytes[offset..])?;
555        }
556        Ok(offset)
557    }
558}
559
560impl<'innards> Deserialize<'innards> for Message<'_, 'innards> {
561    fn deserialize(&mut self, cursor: Cursor<'innards>) -> Result<Cursor<'innards>, Error> {
562        /// Read a set of `T`, bounded by `count`.
563        fn try_read_set<'a, T: Deserialize<'a>>(
564            mut cursor: Cursor<'a>,
565            count: usize,
566            items: &mut [T],
567            ty: BufferType,
568        ) -> Result<Cursor<'a>, Error> {
569            let len = items.len();
570
571            if count == 0 {
572                return Ok(cursor);
573            }
574
575            for i in 0..count {
576                cursor = items
577                    .get_mut(i)
578                    .ok_or_else(|| Error::NotEnoughWriteSpace {
579                        tried_to_write: NonZeroUsize::new(count).unwrap(),
580                        available: len,
581                        buffer_type: ty,
582                    })?
583                    .deserialize(cursor)?;
584            }
585
586            Ok(cursor)
587        }
588
589        let cursor = self.header.deserialize(cursor)?;
590
591        // If we're truncated, we can't read the rest of the message.
592        if self.header.flags.truncated() {
593            self.header.clear();
594            return Ok(cursor);
595        }
596
597        let cursor = try_read_set(
598            cursor,
599            self.header.question_count as usize,
600            self.questions,
601            BufferType::Question,
602        )?;
603        let cursor = try_read_set(
604            cursor,
605            self.header.answer_count as usize,
606            self.answers,
607            BufferType::Answer,
608        )?;
609        let cursor = try_read_set(
610            cursor,
611            self.header.authority_count as usize,
612            self.authorities,
613            BufferType::Authority,
614        )?;
615        let cursor = try_read_set(
616            cursor,
617            self.header.additional_count as usize,
618            self.additional,
619            BufferType::Additional,
620        )?;
621
622        Ok(cursor)
623    }
624}
625
serialize! {
    /// The header for a DNS message.
    ///
    /// `serialize!` encodes and decodes fields in declaration order, so the field order
    /// below is also the order on the wire.
    #[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct Header {
        /// The ID of this message.
        id: u16,

        /// The flags associated with this message.
        flags: Flags,

        /// The number of questions in this message.
        question_count: u16,

        /// The number of answers in this message.
        answer_count: u16,

        /// The number of authority records in this message.
        authority_count: u16,

        /// The number of additional records in this message.
        additional_count: u16,
    }
}
649
650impl Header {
651    fn clear(&mut self) {
652        self.question_count = 0;
653        self.answer_count = 0;
654        self.authority_count = 0;
655        self.additional_count = 0;
656    }
657}
658
serialize! {
    /// The question in a DNS query.
    ///
    /// Fields are encoded in declaration order: name, then type, then class.
    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct Question<'a> {
        /// The domain name being asked about.
        name: Label<'a>,

        /// The record type being requested.
        ty: ResourceType,

        /// The raw class code of the question.
        class: u16,
    }
}
673
674impl<'a> Question<'a> {
675    /// Create a new question.
676    pub fn new(label: impl Into<Label<'a>>, ty: ResourceType, class: u16) -> Self {
677        Self {
678            name: label.into(),
679            ty,
680            class,
681        }
682    }
683
684    /// Get the name of the question.
685    pub fn name(&self) -> Label<'a> {
686        self.name
687    }
688
689    /// Get the type of the question.
690    pub fn ty(&self) -> ResourceType {
691        self.ty
692    }
693
694    /// Get the class of the question.
695    pub fn class(&self) -> u16 {
696        self.class
697    }
698}
699
serialize! {
    /// A resource record in a DNS message.
    ///
    /// Fields are encoded in declaration order: name, type, class, TTL, then the
    /// length-prefixed data.
    #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct ResourceRecord<'a> {
        /// The name of the resource record.
        name: Label<'a>,

        /// The type of the resource record.
        ty: ResourceType,

        /// The raw class code of the resource record.
        class: u16,

        /// The time-to-live of the resource record.
        ttl: u32,

        /// The raw data of the resource record, written with a two-byte length prefix.
        data: ResourceData<'a>,
    }
}
720
721impl<'a> ResourceRecord<'a> {
722    /// Create a new `ResourceRecord`.
723    pub fn new(
724        name: impl Into<Label<'a>>,
725        ty: ResourceType,
726        class: u16,
727        ttl: u32,
728        data: &'a [u8],
729    ) -> Self {
730        Self {
731            name: name.into(),
732            ty,
733            class,
734            ttl,
735            data: data.into(),
736        }
737    }
738
739    /// Get the name of the resource record.
740    pub fn name(&self) -> Label<'a> {
741        self.name
742    }
743
744    /// Get the type of the resource record.
745    pub fn ty(&self) -> ResourceType {
746        self.ty
747    }
748
749    /// Get the class of the resource record.
750    pub fn class(&self) -> u16 {
751        self.class
752    }
753
754    /// Get the time-to-live of the resource record.
755    pub fn ttl(&self) -> u32 {
756        self.ttl
757    }
758
759    /// Get the data of the resource record.
760    pub fn data(&self) -> &'a [u8] {
761        self.data.0
762    }
763}
764
/// The raw data bytes of a resource record, borrowed for `'a`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct ResourceData<'a>(&'a [u8]);

impl<'a> From<&'a [u8]> for ResourceData<'a> {
    /// Wrap `bytes` without copying them.
    fn from(bytes: &'a [u8]) -> Self {
        Self(bytes)
    }
}
774
775impl<'a> Serialize<'a> for ResourceData<'a> {
776    fn serialized_len(&self) -> usize {
777        2 + self.0.len()
778    }
779
780    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
781        let len = self.serialized_len();
782        if bytes.len() < len {
783            panic!("not enough bytes to serialize resource data");
784        }
785
786        // Write the length as a big-endian u16
787        let [b1, b2] = (self.0.len() as u16).to_be_bytes();
788        bytes[0] = b1;
789        bytes[1] = b2;
790
791        // Write the data
792        bytes[2..len].copy_from_slice(self.0);
793
794        Ok(len)
795    }
796}
797
798impl<'a> Deserialize<'a> for ResourceData<'a> {
799    fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
800        // Deserialize a u16 for the length
801        let mut len = 0u16;
802        let cursor = len.deserialize(cursor)?;
803
804        if len == 0 {
805            self.0 = &[];
806            return Ok(cursor);
807        }
808
809        // Read in the data
810        if cursor.len() < len as usize {
811            return Err(Error::NotEnoughReadBytes {
812                tried_to_read: NonZeroUsize::new(len as usize).unwrap(),
813                available: cursor.len(),
814            });
815        }
816
817        self.0 = &cursor.remaining()[..len as usize];
818        cursor.advance(len as usize)
819    }
820}
821
822/// The flags associated with a DNS message.
823#[derive(Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
824#[repr(transparent)]
825pub struct Flags(u16);
826
827impl Flags {
828    // Values used to manipulate the inside.
829    const RAW_QR: u16 = 1 << 15;
830    const RAW_OPCODE_SHIFT: u16 = 11;
831    const RAW_OPCODE_MASK: u16 = 0b1111;
832    const RAW_AA: u16 = 1 << 10;
833    const RAW_TC: u16 = 1 << 9;
834    const RAW_RD: u16 = 1 << 8;
835    const RAW_RA: u16 = 1 << 7;
836    const RAW_RCODE_SHIFT: u16 = 0;
837    const RAW_RCODE_MASK: u16 = 0b1111;
838
839    /// Create a new, empty set of flags.
840    pub const fn new() -> Self {
841        Self(0)
842    }
843
844    /// Use the standard set of flags for a DNS query.
845    ///
846    /// This is identical to `new()` but uses recursive querying.
847    pub const fn standard_query() -> Self {
848        Self(0x0100)
849    }
850
851    /// Get the query/response flag.
852    pub fn qr(&self) -> MessageType {
853        if self.0 & Self::RAW_QR != 0 {
854            MessageType::Reply
855        } else {
856            MessageType::Query
857        }
858    }
859
860    /// Set the message's query/response flag.
861    pub fn set_qr(&mut self, qr: MessageType) -> &mut Self {
862        if qr == MessageType::Reply {
863            self.0 |= Self::RAW_QR;
864        } else {
865            self.0 &= !Self::RAW_QR;
866        }
867
868        self
869    }
870
871    /// Get the opcode.
872    pub fn opcode(&self) -> Opcode {
873        let raw = (self.0 >> Self::RAW_OPCODE_SHIFT) & Self::RAW_OPCODE_MASK;
874        raw.try_into()
875            .unwrap_or_else(|_| panic!("invalid opcode: {}", raw))
876    }
877
878    /// Set the opcode.
879    pub fn set_opcode(&mut self, opcode: Opcode) {
880        self.0 |= (opcode as u16) << Self::RAW_OPCODE_SHIFT;
881    }
882
883    /// Get whether this message is authoritative.
884    pub fn authoritative(&self) -> bool {
885        self.0 & Self::RAW_AA != 0
886    }
887
888    /// Set whether this message is authoritative.
889    pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
890        if authoritative {
891            self.0 |= Self::RAW_AA;
892        } else {
893            self.0 &= !Self::RAW_AA;
894        }
895
896        self
897    }
898
899    /// Get whether this message is truncated.
900    pub fn truncated(&self) -> bool {
901        self.0 & Self::RAW_TC != 0
902    }
903
904    /// Set whether this message is truncated.
905    pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
906        if truncated {
907            self.0 |= Self::RAW_TC;
908        } else {
909            self.0 &= !Self::RAW_TC;
910        }
911
912        self
913    }
914
915    /// Get whether this message is recursive.
916    pub fn recursive(&self) -> bool {
917        self.0 & Self::RAW_RD != 0
918    }
919
920    /// Set whether this message is recursive.
921    pub fn set_recursive(&mut self, recursive: bool) -> &mut Self {
922        if recursive {
923            self.0 |= Self::RAW_RD;
924        } else {
925            self.0 &= !Self::RAW_RD;
926        }
927
928        self
929    }
930
931    /// Get whether recursion is available for this message.
932    pub fn recursion_available(&self) -> bool {
933        self.0 & Self::RAW_RA != 0
934    }
935
936    /// Set whether recursion is available for this message.
937    pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
938        if recursion_available {
939            self.0 |= Self::RAW_RA;
940        } else {
941            self.0 &= !Self::RAW_RA;
942        }
943
944        self
945    }
946
947    /// Get the response code.
948    pub fn response_code(&self) -> ResponseCode {
949        let raw = (self.0 >> Self::RAW_RCODE_SHIFT) & Self::RAW_RCODE_MASK;
950        raw.try_into()
951            .unwrap_or_else(|_| panic!("invalid response code: {}", raw))
952    }
953
954    /// Set the response code.
955    pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
956        self.0 |= (response_code as u16) << Self::RAW_RCODE_SHIFT;
957        self
958    }
959
960    /// Get the raw value of these flags.
961    pub fn raw(self) -> u16 {
962        self.0
963    }
964}
965
966impl fmt::Debug for Flags {
967    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
968        let mut list = f.debug_list();
969
970        list.entry(&self.qr());
971        list.entry(&self.opcode());
972
973        if self.authoritative() {
974            list.entry(&"authoritative");
975        }
976
977        if self.truncated() {
978            list.entry(&"truncated");
979        }
980
981        if self.recursive() {
982            list.entry(&"recursive");
983        }
984
985        if self.recursion_available() {
986            list.entry(&"recursion available");
987        }
988
989        list.entry(&self.response_code());
990
991        list.finish()
992    }
993}
994
995impl Serialize<'_> for Flags {
996    fn serialized_len(&self) -> usize {
997        2
998    }
999
1000    fn serialize(&self, buf: &mut [u8]) -> Result<usize, Error> {
1001        self.0.serialize(buf)
1002    }
1003}
1004
1005impl<'a> Deserialize<'a> for Flags {
1006    fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
1007        u16::deserialize(&mut self.0, bytes)
1008    }
1009}
1010
/// The direction of a DNS message: either a question or the answer to one.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum MessageType {
    /// A message asking a question.
    Query,
    /// A message answering a previously sent query.
    Reply,
}
1020
num_enum! {
    /// The operation code (OPCODE) of a message, describing what kind of request it carries.
    // Only the values listed below have variants; e.g. 3 is not assigned.
    pub enum Opcode {
        /// A standard query (QUERY).
        Query = 0,

        /// An inverse query (IQUERY).
        ///
        /// This is obsolete (RFC 3425).
        IQuery = 1,

        /// A server status request (STATUS).
        Status = 2,

        /// A notification of zone change (NOTIFY, RFC 1996).
        Notify = 4,

        /// A dynamic update (UPDATE, RFC 2136).
        Update = 5,

        /// A DNS Stateful Operations message (DSO, RFC 8490).
        Dso = 6,
    }
}
1043
num_enum! {
    /// The response code (RCODE) reported for a query.
    // NOTE: the message header's RCODE field is 4 bits wide (RFC 1035), so only
    // values 0-15 fit there. Values from 16 upward (`BadVers` onwards) are
    // extended RCODEs whose upper bits travel in the EDNS OPT record (RFC 6891).
    pub enum ResponseCode {
        /// There was no error in the query (NOERROR).
        NoError = 0,

        /// The server could not interpret the query (FORMERR).
        FormatError = 1,

        /// The server failed to process the query because of a problem on its side (SERVFAIL).
        ServerFailure = 2,

        /// The queried domain name does not exist (NXDOMAIN).
        NameError = 3,

        /// The server does not support the requested kind of query (NOTIMP).
        NotImplemented = 4,

        /// The server refuses to perform the operation, e.g. for policy reasons (REFUSED).
        Refused = 5,

        /// A name exists when it should not (YXDOMAIN, RFC 2136).
        YxDomain = 6,

        /// An RR set exists when it should not (YXRRSET, RFC 2136).
        YxRrSet = 7,

        /// An RR set that should exist does not (NXRRSET, RFC 2136).
        NxRrSet = 8,

        /// The server is not authoritative for the zone (RFC 2136), or the
        /// request is not authorized (RFC 8945) (NOTAUTH).
        NotAuth = 9,

        /// A name used in the request is not contained in the zone (NOTZONE, RFC 2136).
        NotZone = 10,

        /// The DSO-TYPE is not implemented (DSOTYPENI, RFC 8490).
        DsoTypeNi = 11,

        /// Bad OPT version (BADVERS, RFC 6891).
        ///
        /// The same value is also used as BADSIG, a TSIG signature failure (RFC 8945).
        BadVers = 16,

        /// The key was not recognized (BADKEY).
        BadKey = 17,

        /// The signature is outside its time window (BADTIME).
        BadTime = 18,

        /// Bad TKEY mode (BADMODE, RFC 2930).
        BadMode = 19,

        /// Duplicate key name (BADNAME, RFC 2930).
        BadName = 20,

        /// Algorithm not supported (BADALG, RFC 2930).
        BadAlg = 21,

        /// Bad truncation (BADTRUNC, RFC 8945).
        BadTrunc = 22,

        /// Bad or missing server cookie (BADCOOKIE, RFC 7873).
        BadCookie = 23,
    }
}
1108
num_enum! {
    /// The resource record types (TYPE and QTYPE values) that a question can ask for.
    pub enum ResourceType {
        /// Get the host's IPv4 address.
        A = 1,

        /// Get the authoritative name servers for a domain.
        NS = 2,

        /// Get the mail destination for a domain.
        ///
        /// This is obsolete; use `MX` instead.
        MD = 3,

        /// Get the mail forwarder for a domain.
        ///
        /// This is obsolete; use `MX` instead.
        MF = 4,

        /// Get the canonical name for a domain (the target of an alias).
        CName = 5,

        /// Get the start of authority record for a domain.
        Soa = 6,

        /// Get the mailbox domain name for a domain (experimental).
        MB = 7,

        /// Get the mail group member for a domain (experimental).
        MG = 8,

        /// Get the mail rename domain name for a domain (experimental).
        MR = 9,

        /// Get the null record for a domain (experimental).
        Null = 10,

        /// Get the well known services for a domain.
        Wks = 11,

        /// Get the domain name pointer for a domain (used for reverse lookups).
        Ptr = 12,

        /// Get the host information for a domain.
        HInfo = 13,

        /// Get the mailbox or mail list information for a domain.
        MInfo = 14,

        /// Get the mail exchange for a domain.
        MX = 15,

        /// Get the text strings for a domain.
        Txt = 16,

        /// Get the responsible person for a domain.
        RP = 17,

        /// Get the AFS database location for a domain.
        AfsDb = 18,

        /// Get the X.25 address for a domain.
        X25 = 19,

        /// Get the ISDN address for a domain.
        Isdn = 20,

        /// Get the route-through record for a domain.
        Rt = 21,

        /// Get the NSAP address for a domain.
        NSap = 22,

        /// Get the reverse NSAP address for a domain.
        NSapPtr = 23,

        /// Get the security signature for a domain.
        ///
        /// DNSSEC uses `RRSig` instead.
        Sig = 24,

        /// Get the security key for a domain.
        ///
        /// DNSSEC uses `DNSKey` instead.
        Key = 25,

        /// Get the X.400 mail mapping for a domain.
        Px = 26,

        /// Get the geographical position for a domain.
        GPos = 27,

        /// Get the IPv6 address for a domain.
        AAAA = 28,

        /// Get the location for a domain.
        Loc = 29,

        /// Get the next domain name in a zone.
        ///
        /// This is obsolete; use `NSEC` instead.
        Nxt = 30,

        /// Get the endpoint identifier for a domain.
        EId = 31,

        /// Get the Nimrod locator for a domain.
        NimLoc = 32,

        /// Get the server selection (service location) record for a domain.
        Srv = 33,

        /// Get the ATM address for a domain.
        AtmA = 34,

        /// Get the naming authority pointer for a domain.
        NAPtr = 35,

        /// Get the key exchanger for a domain.
        Kx = 36,

        /// Get the certificate for a domain.
        Cert = 37,

        /// Get the IPv6 address for a domain.
        ///
        /// This is obsolete; use `AAAA` instead.
        A6 = 38,

        /// Get the delegation name (DNAME) for a domain.
        DName = 39,

        /// Get the kitchen-sink record for a domain.
        Sink = 40,

        /// The EDNS(0) OPT pseudo-record.
        ///
        /// It carries extended message options (RFC 6891) rather than data about a domain.
        Opt = 41,

        /// Get the address prefix list for a domain.
        ApL = 42,

        /// Get the delegation signer for a domain.
        DS = 43,

        /// Get the SSH key fingerprint for a domain.
        SshFp = 44,

        /// Get the IPSEC key for a domain.
        IpSecKey = 45,

        /// Get the resource record signature for a domain.
        RRSig = 46,

        /// Get the next secure record for a domain.
        NSEC = 47,

        /// Get the DNSKEY for a domain.
        DNSKey = 48,

        /// Get the DHCP identifier for a domain.
        DHCID = 49,

        /// Get the NSEC3 for a domain.
        NSEC3 = 50,

        /// Get the NSEC3 parameters for a domain.
        NSEC3Param = 51,

        /// Get the TLS certificate association (TLSA) for a domain.
        TLSA = 52,

        /// Get the S/MIME certificate association for a domain.
        SMimeA = 53,

        /// Get the Host Identity Protocol record for a domain.
        HIP = 55,

        /// Get the NINFO for a domain.
        NInfo = 56,

        /// Get the RKEY for a domain.
        RKey = 57,

        /// Get the trust anchor link for a domain.
        TALink = 58,

        /// Get the child copy of the DS record for a domain.
        CDS = 59,

        /// Get the child copy of the DNSKEY record for a domain.
        CDNSKey = 60,

        /// Get the OpenPGP key for a domain.
        OpenPGPKey = 61,

        /// Get the Child-to-Parent Synchronization for a domain.
        CSync = 62,

        /// Get the zone message digest for a domain (RFC 8976).
        ZoneMD = 63,

        /// Get the General Purpose Service Binding for a domain.
        Svcb = 64,

        /// Get the HTTP Service Binding for a domain.
        Https = 65,

        /// Get the Sender Policy Framework for a domain.
        ///
        /// This is obsolete; publish SPF policies as `Txt` records instead (RFC 7208).
        Spf = 99,

        /// Get the UINFO for a domain.
        UInfo = 100,

        /// Get the UID for a domain.
        UID = 101,

        /// Get the GID for a domain.
        GID = 102,

        /// Get the UNSPEC for a domain.
        Unspec = 103,

        /// Get the ILNP node identifier (NID) for a domain.
        NID = 104,

        /// Get the ILNP 32-bit locator (L32) for a domain.
        L32 = 105,

        /// Get the ILNP 64-bit locator (L64) for a domain.
        L64 = 106,

        /// Get the ILNP locator pointer (LP) for a domain.
        LP = 107,

        /// Get the EUI-48 address for a domain.
        EUI48 = 108,

        /// Get the EUI-64 address for a domain.
        EUI64 = 109,

        /// Get the transaction key for a domain.
        TKey = 249,

        /// Get the transaction signature for a domain.
        TSig = 250,

        /// Request an incremental zone transfer.
        Ixfr = 251,

        /// Request a transfer of an entire zone.
        Axfr = 252,

        /// Get the mailbox-related records (`MB`, `MG` or `MR`) for a domain.
        MailB = 253,

        /// Get the mail agent RRs for a domain.
        ///
        /// This is obsolete; use `MX` instead.
        MailA = 254,

        /// Request all records for a domain (the `*` query type, often called "ANY").
        Wildcard = 255,

        /// Get the URI for a domain.
        Uri = 256,

        /// Get the certification authority authorization for a domain.
        Caa = 257,

        /// Get the application visibility and control for a domain.
        Avc = 258,

        /// Get the digital object architecture for a domain.
        Doa = 259,

        /// Get the Automatic Multicast Tunneling relay for a domain (RFC 8777).
        Amtrelay = 260,

        /// Get the DNSSEC trust authorities for a domain.
        TA = 32768,

        /// Get the DNSSEC lookaside validation record for a domain.
        ///
        /// This is obsolete (RFC 8749).
        DLV = 32769,
    }
}
1384
1385impl Default for ResourceType {
1386    fn default() -> Self {
1387        Self::A
1388    }
1389}
1390
/// The given value is not a valid code.
///
/// The rejected value is kept and can be read back with [`InvalidCode::code`].
/// The type is a plain `u16` wrapper, so it is cheap to copy and can be
/// compared and hashed by callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InvalidCode(u16);

impl InvalidCode {
    /// Get the invalid code.
    pub fn code(&self) -> u16 {
        self.0
    }
}

impl fmt::Display for InvalidCode {
    /// Formats as `invalid code: <value>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "invalid code: {}", self.0)
    }
}

impl StdError for InvalidCode {}
1409
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resource_data_serialization() {
        const PAYLOAD: [u8; 5] = [0x1f, 0xfe, 0x02, 0x24, 0x75];

        let mut out = [0u8; 7];
        let written = ResourceData(&PAYLOAD)
            .serialize(&mut out)
            .expect("serialized into provided buffer");

        // Expected layout: a two-byte length prefix (0x0005), then the payload bytes.
        assert_eq!(written, out.len());
        assert_eq!(out[..2], [0x00, 0x05]);
        assert_eq!(out[2..], PAYLOAD);
    }
}