dns_protocol_patch/lib.rs
1//! An implementation of the DNS protocol, [sans I/O].
2//!
3//! [sans I/O]: https://sans-io.readthedocs.io/
4//!
//! This crate implements the Domain Name System (DNS) protocol, used for name lookups among
//! other things. It is intended to be used in conjunction with a transport layer, such
//! as UDP, TCP or HTTPS. The goal of this crate is to provide a runtime- and transport-agnostic
//! implementation of the DNS protocol, so that it can be used in a variety of contexts.
9//!
//! This crate is not only `no_std`; it does not use an allocator at all. This means that it
11//! can be used on embedded systems that do not have allocators, and that it does no allocation
12//! of its own. However, this comes with a catch: the user is expected to provide their own
13//! buffers for the various operations. This is done to avoid the need for an allocator, and
14//! to allow the user to control the memory usage of the library.
15//!
16//! This crate is also `#![forbid(unsafe_code)]`, and is intended to remain so.
17//!
18//! # Example
19//!
20//! ```no_run
21//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
22//! use dns_protocol::{Message, Question, ResourceRecord, ResourceType, Flags};
23//! use std::net::UdpSocket;
24//!
25//! // Allocate a buffer for the message.
26//! let mut buf = vec![0; 1024];
27//!
28//! // Create a message. This is a query for the A record of example.com.
29//! let mut questions = [
30//! Question::new(
31//! "example.com",
32//! ResourceType::A,
33//! 0,
34//! )
35//! ];
36//! let mut answers = [ResourceRecord::default()];
37//! let message = Message::new(0x42, Flags::default(), &mut questions, &mut answers, &mut [], &mut []);
38//!
39//! // Serialize the message into the buffer
40//! assert!(message.space_needed() <= buf.len());
41//! let len = message.write(&mut buf)?;
42//!
43//! // Write the buffer to the socket.
44//! let socket = UdpSocket::bind("localhost:0")?;
45//! socket.send_to(&buf[..len], "1.2.3.4:53")?;
46//!
47//! // Read new data from the socket.
48//! let data_len = socket.recv(&mut buf)?;
49//!
50//! // Parse the data as a message.
51//! let message = Message::read(
52//! &buf[..data_len],
53//! &mut questions,
54//! &mut answers,
55//! &mut [],
56//! &mut [],
57//! )?;
58//!
59//! // Read the answer from the message.
60//! let answer = message.answers()[0];
61//! println!("Answer Data: {:?}", answer.data());
62//! # Ok(())
63//! # }
64//! ```
65//!
66//! # Features
67//!
//! - `std` (enabled by default) - Enables the `std` library for use in `Error` types.
//!   Disable this feature to use on `no_std` targets.
71
72#![forbid(
73 unsafe_code,
74 missing_docs,
75 missing_debug_implementations,
76 rust_2018_idioms,
77 future_incompatible
78)]
79#![no_std]
80
81#[cfg(feature = "std")]
82extern crate std;
83
84use core::convert::{TryFrom, TryInto};
85use core::fmt;
86use core::iter;
87use core::mem;
88use core::num::NonZeroUsize;
89use core::str;
90
91use core::error::Error as StdError;
92
93mod ser;
94pub use ser::{Cursor, Deserialize, Label, LabelSegment, Serialize};
95
/// Declare a struct and derive field-by-field `Serialize` and `Deserialize` impls for it.
///
/// Fields are encoded and decoded strictly in declaration order, so the order in which they
/// are written inside the macro invocation *is* the wire order.
macro_rules! serialize {
    (
        $(#[$outer:meta])*
        pub struct $name:ident $(<$lt: lifetime>)? {
            $(
                $(#[$inner:meta])*
                $vis: vis $field:ident: $ty:ty,
            )*
        }
    ) => {
        $(#[$outer])*
        pub struct $name $(<$lt>)? {
            $(
                $(#[$inner])*
                $vis $field: $ty,
            )*
        }

        impl<'a> Serialize<'a> for $name $(<$lt>)? {
            fn serialized_len(&self) -> usize {
                // Sum of every field's encoded size.
                0 $(+ self.$field.serialized_len())*
            }

            fn serialize(&self, cursor: &mut [u8]) -> Result<usize, Error> {
                let mut written = 0;
                $(
                    written += self.$field.serialize(&mut cursor[written..])?;
                )*
                Ok(written)
            }
        }

        impl<'a> Deserialize<'a> for $name $(<$lt>)? {
            fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
                // Thread the cursor through each field in turn.
                let rest = cursor;
                $(
                    let rest = self.$field.deserialize(rest)?;
                )*
                Ok(rest)
            }
        }
    };
}
143
/// Declare a `#[repr(u16)]` enum of protocol codes, together with its conversions to and from
/// `u16` and its two-byte wire (de)serialization.
///
/// Converting a `u16` that matches none of the listed variants fails with [`InvalidCode`].
macro_rules! num_enum {
    (
        $(#[$outer:meta])*
        pub enum $name:ident {
            $(
                $(#[$inner:meta])*
                $variant:ident = $value:expr,
            )*
        }
    ) => {
        $(#[$outer])*
        #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
        #[repr(u16)]
        // New codes may be added in the future.
        #[non_exhaustive]
        pub enum $name {
            $(
                $(#[$inner])*
                $variant = $value,
            )*
        }

        impl TryFrom<u16> for $name {
            type Error = InvalidCode;

            fn try_from(raw: u16) -> Result<Self, Self::Error> {
                $(
                    if raw == $value {
                        return Ok(Self::$variant);
                    }
                )*
                Err(InvalidCode(raw))
            }
        }

        impl From<$name> for u16 {
            fn from(code: $name) -> Self {
                code as u16
            }
        }

        impl<'a> Serialize<'a> for $name {
            fn serialized_len(&self) -> usize {
                mem::size_of::<u16>()
            }

            fn serialize(&self, cursor: &mut [u8]) -> Result<usize, Error> {
                u16::from(*self).serialize(cursor)
            }
        }

        impl<'a> Deserialize<'a> for $name {
            fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
                // Read the raw number first, then reject values with no matching variant.
                let mut raw: u16 = 0;
                let rest = raw.deserialize(cursor)?;
                *self = Self::try_from(raw)?;
                Ok(rest)
            }
        }
    };
}
207
208/// The buffer type when dealing with DNS messages.
209#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::IsVariant, derive_more::Display)]
210pub enum BufferType {
211 /// The question buffer.
212 #[display("question")]
213 Question,
214 /// The answer buffer.
215 #[display("answer")]
216 Answer,
217 /// The authority buffer.
218 #[display("authority")]
219 Authority,
220 /// The additional buffer.
221 #[display("additional")]
222 Additional,
223}
224
225/// An error that may occur while using the DNS protocol.
226#[derive(Debug)]
227#[non_exhaustive]
228pub enum Error {
229 /// We are trying to write to a buffer, but the buffer doesn't have enough space.
230 ///
231 /// This error can be fixed by increasing the size of the buffer.
232 NotEnoughWriteSpace {
233 /// The number of entries we tried to write.
234 tried_to_write: NonZeroUsize,
235
236 /// The number of entries that were available in the buffer.
237 available: usize,
238
239 /// The type of the buffer that we tried to write to.
240 buffer_type: BufferType,
241 },
242
243 /// We attempted to read from a buffer, but we ran out of room before we could read the entire
244 /// value.
245 ///
246 /// This error can be fixed by reading more bytes.
247 NotEnoughReadBytes {
248 /// The number of bytes we tried to read.
249 tried_to_read: NonZeroUsize,
250
251 /// The number of bytes that were available in the buffer.
252 available: usize,
253 },
254
255 /// We tried to parse this value, but it was invalid.
256 Parse {
257 /// The name of the value we tried to parse.
258 name: &'static str,
259 },
260
261 /// We tried to serialize a string longer than 256 bytes.
262 NameTooLong(usize),
263
264 /// We could not create a valid UTF-8 string from the bytes we read.
265 InvalidUtf8(simdutf8::compat::Utf8Error),
266
267 /// We could not convert a raw number to a code.
268 InvalidCode(InvalidCode),
269
270 /// We do not support this many URL segments.
271 TooManyUrlSegments(usize),
272}
273
274impl fmt::Display for Error {
275 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276 match self {
277 Error::NotEnoughWriteSpace {
278 tried_to_write,
279 available,
280 buffer_type,
281 } => {
282 write!(
283 f,
284 "not enough write space: tried to write {} entries to {} buffer, but only {} were available",
285 tried_to_write, buffer_type, available
286 )
287 }
288 Error::NotEnoughReadBytes {
289 tried_to_read,
290 available,
291 } => {
292 write!(
293 f,
294 "not enough read bytes: tried to read {} bytes, but only {} were available",
295 tried_to_read, available
296 )
297 }
298 Error::Parse { name } => {
299 write!(f, "parse error: could not parse a {}", name)
300 }
301 Error::NameTooLong(len) => {
302 write!(f, "name too long: name was {} bytes long", len)
303 }
304 Error::InvalidUtf8(err) => {
305 write!(f, "invalid UTF-8: {}", err)
306 }
307 Error::TooManyUrlSegments(segments) => {
308 write!(f, "too many URL segments: {} segments", segments)
309 }
310 Error::InvalidCode(err) => {
311 write!(f, "{}", err)
312 }
313 }
314 }
315}
316
317impl StdError for Error {
318 fn source(&self) -> Option<&(dyn StdError + 'static)> {
319 match self {
320 Error::InvalidCode(err) => Some(err),
321 _ => None,
322 }
323 }
324}
325
326impl From<simdutf8::compat::Utf8Error> for Error {
327 fn from(err: simdutf8::compat::Utf8Error) -> Self {
328 Error::InvalidUtf8(err)
329 }
330}
331
332impl From<InvalidCode> for Error {
333 fn from(err: InvalidCode) -> Self {
334 Error::InvalidCode(err)
335 }
336}
337
338/// The message for a DNS query.
339#[derive(Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
340pub struct Message<'arrays, 'innards> {
341 /// The header of the message.
342 header: Header,
343
344 /// The questions in the message.
345 ///
346 /// **Invariant:** This is always greater than or equal to `header.question_count`.
347 questions: &'arrays mut [Question<'innards>],
348
349 /// The answers in the message.
350 ///
351 /// **Invariant:** This is always greater than or equal to `header.answer_count`.
352 answers: &'arrays mut [ResourceRecord<'innards>],
353
354 /// The authorities in the message.
355 ///
356 /// **Invariant:** This is always greater than or equal to `header.authority_count`.
357 authorities: &'arrays mut [ResourceRecord<'innards>],
358
359 /// The additional records in the message.
360 ///
361 /// **Invariant:** This is always greater than or equal to `header.additional_count`.
362 additional: &'arrays mut [ResourceRecord<'innards>],
363}
364
365impl fmt::Debug for Message<'_, '_> {
366 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
367 f.debug_struct("Message")
368 .field("header", &self.header)
369 .field("questions", &self.questions())
370 .field("answers", &self.answers())
371 .field("authorities", &self.authorities())
372 .field("additional", &self.additional())
373 .finish()
374 }
375}
376
377impl<'arrays, 'innards> Message<'arrays, 'innards> {
378 /// Create a new message from a set of buffers for each section.
379 ///
380 /// # Panics
381 ///
382 /// This function panics if the number of questions, answers, authorities, or additional records
383 /// is greater than `u16::MAX`.
384 pub fn new(
385 id: u16,
386 flags: Flags,
387 questions: &'arrays mut [Question<'innards>],
388 answers: &'arrays mut [ResourceRecord<'innards>],
389 authorities: &'arrays mut [ResourceRecord<'innards>],
390 additional: &'arrays mut [ResourceRecord<'innards>],
391 ) -> Self {
392 Self {
393 header: Header {
394 id,
395 flags,
396 question_count: questions.len().try_into().unwrap(),
397 answer_count: answers.len().try_into().unwrap(),
398 authority_count: authorities.len().try_into().unwrap(),
399 additional_count: additional.len().try_into().unwrap(),
400 },
401 questions,
402 answers,
403 authorities,
404 additional,
405 }
406 }
407
408 /// Get the ID of this message.
409 pub fn id(&self) -> u16 {
410 self.header.id
411 }
412
413 /// Get a mutable reference to the ID of this message.
414 pub fn id_mut(&mut self) -> &mut u16 {
415 &mut self.header.id
416 }
417
418 /// Get the header of this message.
419 pub fn header(&self) -> Header {
420 self.header
421 }
422
423 /// Get the flags for this message.
424 pub fn flags(&self) -> Flags {
425 self.header.flags
426 }
427
428 /// Get a mutable reference to the flags for this message.
429 pub fn flags_mut(&mut self) -> &mut Flags {
430 &mut self.header.flags
431 }
432
433 /// Get the questions in this message.
434 pub fn questions(&self) -> &[Question<'innards>] {
435 &self.questions[..self.header.question_count as usize]
436 }
437
438 /// Get a mutable reference to the questions in this message.
439 pub fn questions_mut(&mut self) -> &mut [Question<'innards>] {
440 &mut self.questions[..self.header.question_count as usize]
441 }
442
443 /// Get the answers in this message.
444 pub fn answers(&self) -> &[ResourceRecord<'innards>] {
445 &self.answers[..self.header.answer_count as usize]
446 }
447
448 /// Get a mutable reference to the answers in this message.
449 pub fn answers_mut(&mut self) -> &mut [ResourceRecord<'innards>] {
450 &mut self.answers[..self.header.answer_count as usize]
451 }
452
453 /// Get the authorities in this message.
454 pub fn authorities(&self) -> &[ResourceRecord<'innards>] {
455 &self.authorities[..self.header.authority_count as usize]
456 }
457
458 /// Get a mutable reference to the authorities in this message.
459 pub fn authorities_mut(&mut self) -> &mut [ResourceRecord<'innards>] {
460 &mut self.authorities[..self.header.authority_count as usize]
461 }
462
463 /// Get the additional records in this message.
464 pub fn additional(&self) -> &[ResourceRecord<'innards>] {
465 &self.additional[..self.header.additional_count as usize]
466 }
467
468 /// Get a mutable reference to the additional records in this message.
469 pub fn additional_mut(&mut self) -> &mut [ResourceRecord<'innards>] {
470 &mut self.additional[..self.header.additional_count as usize]
471 }
472
473 /// Get the buffer space needed to serialize this message.
474 pub fn space_needed(&self) -> usize {
475 self.serialized_len()
476 }
477
478 /// Write this message to a buffer.
479 ///
480 /// Returns the number of bytes written.
481 ///
482 /// # Errors
483 ///
484 /// This function may raise [`Error::NameTooLong`] if a `Label` is too long to be serialized.
485 ///
486 /// # Panics
487 ///
488 /// This function panics if the buffer is not large enough to hold the serialized message. This
489 /// panic can be avoided by ensuring the buffer contains at least [`space_needed`] bytes.
490 pub fn write(&self, buffer: &mut [u8]) -> Result<usize, Error> {
491 self.serialize(buffer)
492 }
493
494 /// Read a message from a buffer.
495 ///
496 /// # Errors
497 ///
498 /// This function may raise one of the following errors:
499 ///
500 /// - [`Error::NotEnoughReadBytes`] if the buffer is not large enough to hold the entire structure.
501 /// You may need to read more data before calling this function again.
502 /// - [`Error::NotEnoughWriteSpace`] if the buffers provided are not large enough to hold the
503 /// entire structure. You may need to allocate larger buffers before calling this function.
504 /// - [`Error::InvalidUtf8`] if a domain name contains invalid UTF-8.
505 /// - [`Error::NameTooLong`] if a domain name is too long to be deserialized.
506 /// - [`Error::InvalidCode`] if a domain name contains an invalid label code.
507 pub fn read(
508 buffer: &'innards [u8],
509 questions: &'arrays mut [Question<'innards>],
510 answers: &'arrays mut [ResourceRecord<'innards>],
511 authorities: &'arrays mut [ResourceRecord<'innards>],
512 additional: &'arrays mut [ResourceRecord<'innards>],
513 ) -> Result<Message<'arrays, 'innards>, Error> {
514 // Create a message and cursor, then deserialize the message from the cursor.
515 let mut message = Message::new(
516 0,
517 Flags::default(),
518 questions,
519 answers,
520 authorities,
521 additional,
522 );
523 let cursor = Cursor::new(buffer);
524
525 message.deserialize(cursor)?;
526
527 Ok(message)
528 }
529}
530
531impl<'innards> Serialize<'innards> for Message<'_, 'innards> {
532 fn serialized_len(&self) -> usize {
533 iter::once(self.header.serialized_len())
534 .chain(self.questions().iter().map(Serialize::serialized_len))
535 .chain(self.answers().iter().map(Serialize::serialized_len))
536 .chain(self.authorities().iter().map(Serialize::serialized_len))
537 .chain(self.additional().iter().map(Serialize::serialized_len))
538 .fold(0, |a, b| a.saturating_add(b))
539 }
540
541 fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
542 let mut offset = 0;
543 offset += self.header.serialize(&mut bytes[offset..])?;
544 for question in self.questions.iter() {
545 offset += question.serialize(&mut bytes[offset..])?;
546 }
547 for answer in self.answers.iter() {
548 offset += answer.serialize(&mut bytes[offset..])?;
549 }
550 for authority in self.authorities.iter() {
551 offset += authority.serialize(&mut bytes[offset..])?;
552 }
553 for additional in self.additional.iter() {
554 offset += additional.serialize(&mut bytes[offset..])?;
555 }
556 Ok(offset)
557 }
558}
559
560impl<'innards> Deserialize<'innards> for Message<'_, 'innards> {
561 fn deserialize(&mut self, cursor: Cursor<'innards>) -> Result<Cursor<'innards>, Error> {
562 /// Read a set of `T`, bounded by `count`.
563 fn try_read_set<'a, T: Deserialize<'a>>(
564 mut cursor: Cursor<'a>,
565 count: usize,
566 items: &mut [T],
567 ty: BufferType,
568 ) -> Result<Cursor<'a>, Error> {
569 let len = items.len();
570
571 if count == 0 {
572 return Ok(cursor);
573 }
574
575 for i in 0..count {
576 cursor = items
577 .get_mut(i)
578 .ok_or_else(|| Error::NotEnoughWriteSpace {
579 tried_to_write: NonZeroUsize::new(count).unwrap(),
580 available: len,
581 buffer_type: ty,
582 })?
583 .deserialize(cursor)?;
584 }
585
586 Ok(cursor)
587 }
588
589 let cursor = self.header.deserialize(cursor)?;
590
591 // If we're truncated, we can't read the rest of the message.
592 if self.header.flags.truncated() {
593 self.header.clear();
594 return Ok(cursor);
595 }
596
597 let cursor = try_read_set(
598 cursor,
599 self.header.question_count as usize,
600 self.questions,
601 BufferType::Question,
602 )?;
603 let cursor = try_read_set(
604 cursor,
605 self.header.answer_count as usize,
606 self.answers,
607 BufferType::Answer,
608 )?;
609 let cursor = try_read_set(
610 cursor,
611 self.header.authority_count as usize,
612 self.authorities,
613 BufferType::Authority,
614 )?;
615 let cursor = try_read_set(
616 cursor,
617 self.header.additional_count as usize,
618 self.additional,
619 BufferType::Additional,
620 )?;
621
622 Ok(cursor)
623 }
624}
625
626serialize! {
627 /// The header for a DNS query.
628 #[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
629 pub struct Header {
630 /// The ID of this query.
631 id: u16,
632
633 /// The flags associated with this query.
634 flags: Flags,
635
636 /// The number of questions in this query.
637 question_count: u16,
638
639 /// The number of answers in this query.
640 answer_count: u16,
641
642 /// The number of authorities in this query.
643 authority_count: u16,
644
645 /// The number of additional records in this query.
646 additional_count: u16,
647 }
648}
649
650impl Header {
651 fn clear(&mut self) {
652 self.question_count = 0;
653 self.answer_count = 0;
654 self.authority_count = 0;
655 self.additional_count = 0;
656 }
657}
658
659serialize! {
660 /// The question in a DNS query.
661 #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
662 pub struct Question<'a> {
663 /// The name of the question.
664 name: Label<'a>,
665
666 /// The type of the question.
667 ty: ResourceType,
668
669 /// The class of the question.
670 class: u16,
671 }
672}
673
674impl<'a> Question<'a> {
675 /// Create a new question.
676 pub fn new(label: impl Into<Label<'a>>, ty: ResourceType, class: u16) -> Self {
677 Self {
678 name: label.into(),
679 ty,
680 class,
681 }
682 }
683
684 /// Get the name of the question.
685 pub fn name(&self) -> Label<'a> {
686 self.name
687 }
688
689 /// Get the type of the question.
690 pub fn ty(&self) -> ResourceType {
691 self.ty
692 }
693
694 /// Get the class of the question.
695 pub fn class(&self) -> u16 {
696 self.class
697 }
698}
699
700serialize! {
701 /// A resource record in a DNS query.
702 #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
703 pub struct ResourceRecord<'a> {
704 /// The name of the resource record.
705 name: Label<'a>,
706
707 /// The type of the resource record.
708 ty: ResourceType,
709
710 /// The class of the resource record.
711 class: u16,
712
713 /// The time-to-live of the resource record.
714 ttl: u32,
715
716 /// The data of the resource record.
717 data: ResourceData<'a>,
718 }
719}
720
721impl<'a> ResourceRecord<'a> {
722 /// Create a new `ResourceRecord`.
723 pub fn new(
724 name: impl Into<Label<'a>>,
725 ty: ResourceType,
726 class: u16,
727 ttl: u32,
728 data: &'a [u8],
729 ) -> Self {
730 Self {
731 name: name.into(),
732 ty,
733 class,
734 ttl,
735 data: data.into(),
736 }
737 }
738
739 /// Get the name of the resource record.
740 pub fn name(&self) -> Label<'a> {
741 self.name
742 }
743
744 /// Get the type of the resource record.
745 pub fn ty(&self) -> ResourceType {
746 self.ty
747 }
748
749 /// Get the class of the resource record.
750 pub fn class(&self) -> u16 {
751 self.class
752 }
753
754 /// Get the time-to-live of the resource record.
755 pub fn ttl(&self) -> u32 {
756 self.ttl
757 }
758
759 /// Get the data of the resource record.
760 pub fn data(&self) -> &'a [u8] {
761 self.data.0
762 }
763}
764
/// The payload bytes of a resource record, borrowed from the caller or from a parsed message.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct ResourceData<'a>(&'a [u8]);

impl<'a> From<&'a [u8]> for ResourceData<'a> {
    fn from(bytes: &'a [u8]) -> Self {
        Self(bytes)
    }
}
774
775impl<'a> Serialize<'a> for ResourceData<'a> {
776 fn serialized_len(&self) -> usize {
777 2 + self.0.len()
778 }
779
780 fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
781 let len = self.serialized_len();
782 if bytes.len() < len {
783 panic!("not enough bytes to serialize resource data");
784 }
785
786 // Write the length as a big-endian u16
787 let [b1, b2] = (self.0.len() as u16).to_be_bytes();
788 bytes[0] = b1;
789 bytes[1] = b2;
790
791 // Write the data
792 bytes[2..len].copy_from_slice(self.0);
793
794 Ok(len)
795 }
796}
797
798impl<'a> Deserialize<'a> for ResourceData<'a> {
799 fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
800 // Deserialize a u16 for the length
801 let mut len = 0u16;
802 let cursor = len.deserialize(cursor)?;
803
804 if len == 0 {
805 self.0 = &[];
806 return Ok(cursor);
807 }
808
809 // Read in the data
810 if cursor.len() < len as usize {
811 return Err(Error::NotEnoughReadBytes {
812 tried_to_read: NonZeroUsize::new(len as usize).unwrap(),
813 available: cursor.len(),
814 });
815 }
816
817 self.0 = &cursor.remaining()[..len as usize];
818 cursor.advance(len as usize)
819 }
820}
821
/// The 16-bit flags word from a DNS message header.
///
/// Individual fields are read and changed through the accessor and setter methods;
/// [`Flags::raw`] exposes the underlying value.
#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct Flags(u16);
826
827impl Flags {
828 // Values used to manipulate the inside.
829 const RAW_QR: u16 = 1 << 15;
830 const RAW_OPCODE_SHIFT: u16 = 11;
831 const RAW_OPCODE_MASK: u16 = 0b1111;
832 const RAW_AA: u16 = 1 << 10;
833 const RAW_TC: u16 = 1 << 9;
834 const RAW_RD: u16 = 1 << 8;
835 const RAW_RA: u16 = 1 << 7;
836 const RAW_RCODE_SHIFT: u16 = 0;
837 const RAW_RCODE_MASK: u16 = 0b1111;
838
839 /// Create a new, empty set of flags.
840 pub const fn new() -> Self {
841 Self(0)
842 }
843
844 /// Use the standard set of flags for a DNS query.
845 ///
846 /// This is identical to `new()` but uses recursive querying.
847 pub const fn standard_query() -> Self {
848 Self(0x0100)
849 }
850
851 /// Get the query/response flag.
852 pub fn qr(&self) -> MessageType {
853 if self.0 & Self::RAW_QR != 0 {
854 MessageType::Reply
855 } else {
856 MessageType::Query
857 }
858 }
859
860 /// Set the message's query/response flag.
861 pub fn set_qr(&mut self, qr: MessageType) -> &mut Self {
862 if qr == MessageType::Reply {
863 self.0 |= Self::RAW_QR;
864 } else {
865 self.0 &= !Self::RAW_QR;
866 }
867
868 self
869 }
870
871 /// Get the opcode.
872 pub fn opcode(&self) -> Opcode {
873 let raw = (self.0 >> Self::RAW_OPCODE_SHIFT) & Self::RAW_OPCODE_MASK;
874 raw.try_into()
875 .unwrap_or_else(|_| panic!("invalid opcode: {}", raw))
876 }
877
878 /// Set the opcode.
879 pub fn set_opcode(&mut self, opcode: Opcode) {
880 self.0 |= (opcode as u16) << Self::RAW_OPCODE_SHIFT;
881 }
882
883 /// Get whether this message is authoritative.
884 pub fn authoritative(&self) -> bool {
885 self.0 & Self::RAW_AA != 0
886 }
887
888 /// Set whether this message is authoritative.
889 pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
890 if authoritative {
891 self.0 |= Self::RAW_AA;
892 } else {
893 self.0 &= !Self::RAW_AA;
894 }
895
896 self
897 }
898
899 /// Get whether this message is truncated.
900 pub fn truncated(&self) -> bool {
901 self.0 & Self::RAW_TC != 0
902 }
903
904 /// Set whether this message is truncated.
905 pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
906 if truncated {
907 self.0 |= Self::RAW_TC;
908 } else {
909 self.0 &= !Self::RAW_TC;
910 }
911
912 self
913 }
914
915 /// Get whether this message is recursive.
916 pub fn recursive(&self) -> bool {
917 self.0 & Self::RAW_RD != 0
918 }
919
920 /// Set whether this message is recursive.
921 pub fn set_recursive(&mut self, recursive: bool) -> &mut Self {
922 if recursive {
923 self.0 |= Self::RAW_RD;
924 } else {
925 self.0 &= !Self::RAW_RD;
926 }
927
928 self
929 }
930
931 /// Get whether recursion is available for this message.
932 pub fn recursion_available(&self) -> bool {
933 self.0 & Self::RAW_RA != 0
934 }
935
936 /// Set whether recursion is available for this message.
937 pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
938 if recursion_available {
939 self.0 |= Self::RAW_RA;
940 } else {
941 self.0 &= !Self::RAW_RA;
942 }
943
944 self
945 }
946
947 /// Get the response code.
948 pub fn response_code(&self) -> ResponseCode {
949 let raw = (self.0 >> Self::RAW_RCODE_SHIFT) & Self::RAW_RCODE_MASK;
950 raw.try_into()
951 .unwrap_or_else(|_| panic!("invalid response code: {}", raw))
952 }
953
954 /// Set the response code.
955 pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
956 self.0 |= (response_code as u16) << Self::RAW_RCODE_SHIFT;
957 self
958 }
959
960 /// Get the raw value of these flags.
961 pub fn raw(self) -> u16 {
962 self.0
963 }
964}
965
966impl fmt::Debug for Flags {
967 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
968 let mut list = f.debug_list();
969
970 list.entry(&self.qr());
971 list.entry(&self.opcode());
972
973 if self.authoritative() {
974 list.entry(&"authoritative");
975 }
976
977 if self.truncated() {
978 list.entry(&"truncated");
979 }
980
981 if self.recursive() {
982 list.entry(&"recursive");
983 }
984
985 if self.recursion_available() {
986 list.entry(&"recursion available");
987 }
988
989 list.entry(&self.response_code());
990
991 list.finish()
992 }
993}
994
995impl Serialize<'_> for Flags {
996 fn serialized_len(&self) -> usize {
997 2
998 }
999
1000 fn serialize(&self, buf: &mut [u8]) -> Result<usize, Error> {
1001 self.0.serialize(buf)
1002 }
1003}
1004
1005impl<'a> Deserialize<'a> for Flags {
1006 fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
1007 u16::deserialize(&mut self.0, bytes)
1008 }
1009}
1010
/// Distinguishes a query from a reply; stored in the QR bit of the header flags.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MessageType {
    /// This message asks a question.
    Query,

    /// This message answers a question.
    Reply,
}
1020
1021num_enum! {
1022 /// The operation code for the query.
1023 pub enum Opcode {
1024 /// A standard query.
1025 Query = 0,
1026
1027 /// A reverse query.
1028 IQuery = 1,
1029
1030 /// A server status request.
1031 Status = 2,
1032
1033 /// A notification of zone change.
1034 Notify = 4,
1035
1036 /// A dynamic update.
1037 Update = 5,
1038
1039 /// DSO query.
1040 Dso = 6,
1041 }
1042}
1043
1044num_enum! {
1045 /// The response code for a query.
1046 pub enum ResponseCode {
1047 /// There was no error in the query.
1048 NoError = 0,
1049
1050 /// The query was malformed.
1051 FormatError = 1,
1052
1053 /// The server failed to fulfill the query.
1054 ServerFailure = 2,
1055
1056 /// The name does not exist.
1057 NameError = 3,
1058
1059 /// The query is not implemented.
1060 NotImplemented = 4,
1061
1062 /// The query is refused.
1063 Refused = 5,
1064
1065 /// The name exists, but the query type is not supported.
1066 YxDomain = 6,
1067
1068 /// The name does not exist, but the query type is supported.
1069 YxRrSet = 7,
1070
1071 /// The name exists, but the query type is not supported.
1072 NxRrSet = 8,
1073
1074 /// The server is not authoritative for the zone.
1075 NotAuth = 9,
1076
1077 /// The name does not exist in the zone.
1078 NotZone = 10,
1079
1080 /// The DSO-TYPE is not supported.
1081 DsoTypeNi = 11,
1082
1083 /// Bad OPT version.
1084 BadVers = 16,
1085
1086 /// Key not recognized.
1087 BadKey = 17,
1088
1089 /// Signature out of time window.
1090 BadTime = 18,
1091
1092 /// Bad TKEY mode.
1093 BadMode = 19,
1094
1095 /// Duplicate key name.
1096 BadName = 20,
1097
1098 /// Algorithm not supported.
1099 BadAlg = 21,
1100
1101 /// Bad truncation.
1102 BadTrunc = 22,
1103
1104 /// Bad/missing server cookie.
1105 BadCookie = 23,
1106 }
1107}
1108
1109num_enum! {
1110 /// The resource types that a question can ask for.
1111 pub enum ResourceType {
1112 /// Get the host's IPv4 address.
1113 A = 1,
1114
1115 /// Get the authoritative name servers for a domain.
1116 NS = 2,
1117
1118 /// Get the mail server for a domain.
1119 MD = 3,
1120
1121 /// Get the mail forwarder for a domain.
1122 MF = 4,
1123
1124 /// Get the canonical name for a domain.
1125 CName = 5,
1126
1127 /// Get the start of authority record for a domain.
1128 Soa = 6,
1129
1130 /// Get the mailbox for a domain.
1131 MB = 7,
1132
1133 /// Get the mail group member for a domain.
1134 MG = 8,
1135
1136 /// Get the mail rename domain for a domain.
1137 MR = 9,
1138
1139 /// Get the null record for a domain.
1140 Null = 10,
1141
1142 /// Get the well known services for a domain.
1143 Wks = 11,
1144
1145 /// Get the domain pointer for a domain.
1146 Ptr = 12,
1147
1148 /// Get the host information for a domain.
1149 HInfo = 13,
1150
1151 /// Get the mailbox or mail list information for a domain.
1152 MInfo = 14,
1153
1154 /// Get the mail exchange for a domain.
1155 MX = 15,
1156
1157 /// Get the text for a domain.
1158 Txt = 16,
1159
1160 /// Get the responsible person for a domain.
1161 RP = 17,
1162
1163 /// Get the AFS database location for a domain.
1164 AfsDb = 18,
1165
1166 /// Get the X.25 address for a domain.
1167 X25 = 19,
1168
1169 /// Get the ISDN address for a domain.
1170 Isdn = 20,
1171
1172 /// Get the router for a domain.
1173 Rt = 21,
1174
1175 /// Get the NSAP address for a domain.
1176 NSap = 22,
1177
1178 /// Get the reverse NSAP address for a domain.
1179 NSapPtr = 23,
1180
1181 /// Get the security signature for a domain.
1182 Sig = 24,
1183
1184 /// Get the key for a domain.
1185 Key = 25,
1186
1187 /// Get the X.400 mail mapping for a domain.
1188 Px = 26,
1189
1190 /// Get the geographical location for a domain.
1191 GPos = 27,
1192
1193 /// Get the IPv6 address for a domain.
1194 AAAA = 28,
1195
1196 /// Get the location for a domain.
1197 Loc = 29,
1198
1199 /// Get the next domain name in a zone.
1200 Nxt = 30,
1201
1202 /// Get the endpoint identifier for a domain.
1203 EId = 31,
1204
1205 /// Get the Nimrod locator for a domain.
1206 NimLoc = 32,
1207
1208 /// Get the server selection for a domain.
1209 Srv = 33,
1210
1211 /// Get the ATM address for a domain.
1212 AtmA = 34,
1213
1214 /// Get the naming authority pointer for a domain.
1215 NAPtr = 35,
1216
1217 /// Get the key exchange for a domain.
1218 Kx = 36,
1219
1220 /// Get the certificate for a domain.
1221 Cert = 37,
1222
1223 /// Get the IPv6 address for a domain.
1224 ///
1225 /// This is obsolete; use `AAAA` instead.
1226 A6 = 38,
1227
1228 /// Get the DNAME for a domain.
1229 DName = 39,
1230
1231 /// Get the sink for a domain.
1232 Sink = 40,
1233
1234 /// Get the OPT for a domain.
1235 Opt = 41,
1236
1237 /// Get the address prefix list for a domain.
1238 ApL = 42,
1239
1240 /// Get the delegation signer for a domain.
1241 DS = 43,
1242
1243 /// Get the SSH key fingerprint for a domain.
1244 SshFp = 44,
1245
1246 /// Get the IPSEC key for a domain.
1247 IpSecKey = 45,
1248
1249 /// Get the resource record signature for a domain.
1250 RRSig = 46,
1251
1252 /// Get the next secure record for a domain.
1253 NSEC = 47,
1254
1255 /// Get the DNSKEY for a domain.
1256 DNSKey = 48,
1257
1258 /// Get the DHCID for a domain.
1259 DHCID = 49,
1260
1261 /// Get the NSEC3 for a domain.
1262 NSEC3 = 50,
1263
1264 /// Get the NSEC3 parameters for a domain.
1265 NSEC3Param = 51,
1266
1267 /// Get the TLSA for a domain.
1268 TLSA = 52,
1269
1270 /// Get the S/MIME certificate association for a domain.
1271 SMimeA = 53,
1272
1273 /// Get the host information for a domain.
1274 HIP = 55,
1275
1276 /// Get the NINFO for a domain.
1277 NInfo = 56,
1278
1279 /// Get the RKEY for a domain.
1280 RKey = 57,
1281
1282 /// Get the trust anchor link for a domain.
1283 TALink = 58,
1284
1285 /// Get the child DS for a domain.
1286 CDS = 59,
1287
1288 /// Get the DNSKEY for a domain.
1289 CDNSKey = 60,
1290
1291 /// Get the OpenPGP key for a domain.
1292 OpenPGPKey = 61,
1293
1294 /// Get the Child-to-Parent Synchronization for a domain.
1295 CSync = 62,
1296
1297 /// Get the Zone Data Message for a domain.
1298 ZoneMD = 63,
1299
1300 /// Get the General Purpose Service Binding for a domain.
1301 Svcb = 64,
1302
1303 /// Get the HTTP Service Binding for a domain.
1304 Https = 65,
1305
1306 /// Get the Sender Policy Framework for a domain.
1307 Spf = 99,
1308
1309 /// Get the UINFO for a domain.
1310 UInfo = 100,
1311
1312 /// Get the UID for a domain.
1313 UID = 101,
1314
1315 /// Get the GID for a domain.
1316 GID = 102,
1317
1318 /// Get the UNSPEC for a domain.
1319 Unspec = 103,
1320
1321 /// Get the NID for a domain.
1322 NID = 104,
1323
1324 /// Get the L32 for a domain.
1325 L32 = 105,
1326
1327 /// Get the L64 for a domain.
1328 L64 = 106,
1329
1330 /// Get the LP for a domain.
1331 LP = 107,
1332
1333 /// Get the EUI48 for a domain.
1334 EUI48 = 108,
1335
1336 /// Get the EUI64 for a domain.
1337 EUI64 = 109,
1338
1339 /// Get the transaction key for a domain.
1340 TKey = 249,
1341
1342 /// Get the transaction signature for a domain.
1343 TSig = 250,
1344
1345 /// Get the incremental transfer for a domain.
1346 Ixfr = 251,
1347
1348 /// Get the transfer of an entire zone for a domain.
1349 Axfr = 252,
1350
1351 /// Get the mailbox-related records for a domain.
1352 MailB = 253,
1353
1354 /// Get the mail agent RRs for a domain.
1355 ///
1356 /// This is obsolete; use `MX` instead.
1357 MailA = 254,
1358
1359 /// Get the wildcard match for a domain.
1360 Wildcard = 255,
1361
1362 /// Get the URI for a domain.
1363 Uri = 256,
1364
1365 /// Get the certification authority authorization for a domain.
1366 Caa = 257,
1367
1368 /// Get the application visibility and control for a domain.
1369 Avc = 258,
1370
1371 /// Get the digital object architecture for a domain.
1372 Doa = 259,
1373
1374 /// Get the automatic network discovery for a domain.
1375 Amtrelay = 260,
1376
1377 /// Get the DNSSEC trust authorities for a domain.
1378 TA = 32768,
1379
1380 /// Get the DNSSEC lookaside validation for a domain.
1381 DLV = 32769,
1382 }
1383}
1384
1385impl Default for ResourceType {
1386 fn default() -> Self {
1387 Self::A
1388 }
1389}
1390
/// The given value is not a valid code.
///
/// The rejected value can be recovered with [`InvalidCode::code`]. The type is a
/// plain `u16` wrapper, so it is cheap to copy and can be compared, hashed, and
/// matched on in tests or error-handling code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InvalidCode(u16);

impl InvalidCode {
    /// Get the invalid code.
    ///
    /// This is a `const fn`, so it can also be used in constant contexts.
    #[inline]
    #[must_use]
    pub const fn code(&self) -> u16 {
        self.0
    }
}

impl fmt::Display for InvalidCode {
    /// Formats the error as `invalid code: <value>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "invalid code: {}", self.0)
    }
}

impl StdError for InvalidCode {}
1409
#[cfg(test)]
mod tests {
    use super::*;

    /// A resource payload is emitted as a big-endian `u16` length prefix
    /// followed by the raw payload bytes.
    #[test]
    fn resource_data_serialization() {
        const PAYLOAD: [u8; 5] = [0x1f, 0xfe, 0x02, 0x24, 0x75];

        let mut out = [0u8; 7];
        let written = ResourceData(&PAYLOAD)
            .serialize(&mut out)
            .expect("serialized into provided buffer");

        assert_eq!(written, 7);
        assert_eq!(out, [0x00, 0x05, 0x1f, 0xfe, 0x02, 0x24, 0x75]);
    }
}