dns_protocol/
ser.rs

1//! A serializer/deserializer compatible with the DNS wire format.
2//!
3//! Various notes about the wire format:
4//!
5//! - All integers are in big endian format.
6//! - All strings are length-prefixed.
7//! - Names consist of a series of labels, each prefixed with a length byte.
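//! - Names end with a zero byte, or with a compression pointer in place of their remaining labels.
//! - Names are compressed by replacing a trailing run of labels with a pointer to an earlier
//!   occurrence of those labels in the same message.
//! - Following such a pointer means we need access to the entire message buffer, not just the
//!   bytes at the current position.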
11
12use super::Error;
13
14use core::convert::TryInto;
15use core::fmt;
16use core::hash;
17use core::iter;
18use core::mem;
19use core::num::NonZeroUsize;
20
21use memchr::Memchr;
22
23/// An object that is able to be serialized to or deserialized from a series of bytes.
24pub(crate) trait Serialize<'a> {
25    /// The number of bytes needed to serialize this object.
26    fn serialized_len(&self) -> usize;
27
28    /// Serialize this object into a series of bytes.
29    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error>;
30
31    /// Deserialize this object from a series of bytes.
32    fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error>;
33}
34
/// A read position inside a byte buffer.
///
/// The complete buffer is kept next to the position because DNS compression pointers are
/// absolute offsets into the message, so following one needs the whole buffer, not just the
/// unread tail.
#[derive(Debug, Copy, Clone)]
pub(crate) struct Cursor<'a> {
    /// The complete buffer being read from.
    bytes: &'a [u8],

    /// Absolute offset into `bytes` of the next unread byte.
    cursor: usize,
}
46
47impl<'a> Cursor<'a> {
48    /// Create a new cursor from a series of bytes.
49    pub(crate) fn new(bytes: &'a [u8]) -> Self {
50        Self { bytes, cursor: 0 }
51    }
52
53    /// Get the original bytes that this cursor was created from.
54    pub(crate) fn original(&self) -> &'a [u8] {
55        self.bytes
56    }
57
58    /// Get the slice of remaining bytes.
59    pub(crate) fn remaining(&self) -> &'a [u8] {
60        &self.bytes[self.cursor..]
61    }
62
63    /// Get the length of the slice of remaining bytes.
64    pub(crate) fn len(&self) -> usize {
65        self.bytes.len() - self.cursor
66    }
67
68    /// Get a new cursor at the given absolute position.
69    pub(crate) fn at(&self, pos: usize) -> Self {
70        Self {
71            bytes: self.bytes,
72            cursor: pos,
73        }
74    }
75
76    /// Advance the cursor by the given number of bytes.
77    pub(crate) fn advance(mut self, n: usize) -> Result<Self, Error> {
78        if n == 0 {
79            return Ok(self);
80        }
81
82        if self.cursor + n > self.bytes.len() {
83            return Err(Error::NotEnoughReadBytes {
84                tried_to_read: NonZeroUsize::new(self.cursor.saturating_add(n)).unwrap(),
85                available: self.bytes.len(),
86            });
87        }
88
89        self.cursor += n;
90        Ok(self)
91    }
92
93    /// Error for when a read of `n` bytes failed.
94    fn read_error(&self, n: usize) -> Error {
95        Error::NotEnoughReadBytes {
96            tried_to_read: NonZeroUsize::new(self.cursor.saturating_add(n)).unwrap(),
97            available: self.bytes.len(),
98        }
99    }
100}
101
102impl<'a> Serialize<'a> for () {
103    fn serialized_len(&self) -> usize {
104        0
105    }
106
107    fn serialize(&self, _bytes: &mut [u8]) -> Result<usize, Error> {
108        Ok(0)
109    }
110
111    fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
112        Ok(bytes)
113    }
114}
115
116/// A DNS name.
117#[derive(Clone, Copy)]
118pub struct Label<'a> {
119    repr: Repr<'a>,
120}
121
/// The two ways a `Label` can hold its data.
#[derive(Copy, Clone)]
enum Repr<'a> {
    /// A name still in DNS wire format, located at `original[start..end]`.
    ///
    /// Its labels are decoded on demand. Any compression pointer inside it is an offset into
    /// `original`, which is why the whole buffer is kept rather than just the name's span.
    Bytes {
        /// The complete buffer the name was read from.
        original: &'a [u8],

        /// Offset of the name's first byte within `original`.
        start: usize,

        /// Offset one past the name's last byte within `original`.
        end: usize,
    },

    /// A name given as a dotted string such as `"example.com"`; it is split on `.` whenever
    /// its labels are needed.
    String {
        /// The dotted textual form of the name.
        string: &'a str,
    },
}
145
146impl Default for Label<'_> {
147    fn default() -> Self {
148        // An empty label.
149        Self {
150            repr: Repr::Bytes {
151                original: &[0],
152                start: 0,
153                end: 1,
154            },
155        }
156    }
157}
158
159impl<'a, 'b> PartialEq<Label<'a>> for Label<'b> {
160    fn eq(&self, other: &Label<'a>) -> bool {
161        self.segments().eq(other.segments())
162    }
163}
164
165impl Eq for Label<'_> {}
166
167impl<'a, 'b> PartialOrd<Label<'a>> for Label<'b> {
168    fn partial_cmp(&self, other: &Label<'a>) -> Option<core::cmp::Ordering> {
169        self.segments().partial_cmp(other.segments())
170    }
171}
172
173impl Ord for Label<'_> {
174    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
175        self.segments().cmp(other.segments())
176    }
177}
178
179impl hash::Hash for Label<'_> {
180    fn hash<H: hash::Hasher>(&self, state: &mut H) {
181        for segment in self.segments() {
182            segment.hash(state);
183        }
184    }
185}
186
187impl<'a> Label<'a> {
188    /// Get an iterator over the label segments in this name.
189    pub fn segments(&self) -> impl Iterator<Item = LabelSegment<'a>> {
190        match self.repr {
191            Repr::Bytes {
192                original, start, ..
193            } => Either::A(parse_bytes(original, start)),
194            Repr::String { string } => Either::B(parse_string(string)),
195        }
196    }
197
198    /// Get an iterator over the strings making up this name.
199    pub fn names(&self) -> impl Iterator<Item = Result<&'a str, &'a [u8]>> {
200        match self.repr {
201            Repr::String { string } => {
202                // Guaranteed to have no pointers.
203                Either::A(parse_string(string).filter_map(|seg| seg.as_str().map(Ok)))
204            }
205            Repr::Bytes {
206                original, start, ..
207            } => {
208                // We may have to deal with pointers. Parse it manually.
209                let mut cursor = Cursor {
210                    bytes: original,
211                    cursor: start,
212                };
213
214                Either::B(iter::from_fn(move || {
215                    loop {
216                        let mut ls: LabelSegment<'_> = LabelSegment::Empty;
217                        cursor = ls.deserialize(cursor).ok()?;
218
219                        // TODO: Handle the case where utf-8 errors out.
220                        match ls {
221                            LabelSegment::Empty => return None,
222                            LabelSegment::Pointer(pos) => {
223                                // Change to another location.
224                                cursor = cursor.at(pos.into());
225                            }
226                            LabelSegment::String(label) => return Some(Ok(label)),
227                        }
228                    }
229                }))
230            }
231        }
232    }
233}
234
235impl<'a> From<&'a str> for Label<'a> {
236    fn from(string: &'a str) -> Self {
237        Self {
238            repr: Repr::String { string },
239        }
240    }
241}
242
243impl fmt::Debug for Label<'_> {
244    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245        struct LabelFmt<'a>(&'a Label<'a>);
246
247        impl fmt::Debug for LabelFmt<'_> {
248            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249                fmt::Display::fmt(self.0, f)
250            }
251        }
252
253        f.debug_tuple("Label").field(&LabelFmt(self)).finish()
254    }
255}
256
257impl fmt::Display for Label<'_> {
258    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259        self.names().enumerate().try_for_each(|(i, name)| {
260            if i > 0 {
261                f.write_str(".")?;
262            }
263
264            match name {
265                Ok(name) => f.write_str(name),
266                Err(_) => f.write_str("???"),
267            }
268        })
269    }
270}
271
272impl<'a> Serialize<'a> for Label<'a> {
273    fn serialized_len(&self) -> usize {
274        if let Repr::Bytes { start, end, .. } = self.repr {
275            return end - start;
276        }
277
278        self.segments()
279            .map(|item| item.serialized_len())
280            .fold(0, |a, b| a.saturating_add(b))
281    }
282
283    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
284        // Fast path: just copy our bytes.
285        if let Repr::Bytes {
286            original,
287            start,
288            end,
289        } = self.repr
290        {
291            bytes[..end - start].copy_from_slice(&original[start..end]);
292            return Ok(end - start);
293        }
294
295        self.segments().try_fold(0, |mut offset, item| {
296            let len = item.serialize(&mut bytes[offset..])?;
297            offset += len;
298            Ok(offset)
299        })
300    }
301
302    fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
303        let original = cursor.original();
304        let start = cursor.cursor;
305
306        // Figure out where the end is.
307        let mut end = start;
308        loop {
309            let len_char = match original.get(end) {
310                Some(0) => {
311                    end += 1;
312                    break;
313                }
314                Some(ptr) if ptr & PTR_MASK != 0 => {
315                    // Pointer goes for one more byte and then ends.
316                    end += 2;
317                    break;
318                }
319                Some(len_char) => *len_char,
320                None => {
321                    return Err(Error::NotEnoughReadBytes {
322                        tried_to_read: NonZeroUsize::new(cursor.cursor + end).unwrap(),
323                        available: original.len(),
324                    })
325                }
326            };
327
328            let len = len_char as usize;
329            end += len + 1;
330        }
331
332        self.repr = Repr::Bytes {
333            original,
334            start,
335            end,
336        };
337        cursor.advance(end - start)
338    }
339}
340
341/// Parse a set of bytes as a DNS name.
342fn parse_bytes(bytes: &[u8], position: usize) -> impl Iterator<Item = LabelSegment<'_>> + '_ {
343    let mut cursor = Cursor {
344        bytes,
345        cursor: position,
346    };
347    let mut keep_going = true;
348
349    iter::from_fn(move || {
350        if !keep_going {
351            return None;
352        }
353
354        let mut segment = LabelSegment::Empty;
355        cursor = segment.deserialize(cursor).ok()?;
356
357        match segment {
358            LabelSegment::String(_) => {}
359            _ => {
360                keep_going = false;
361            }
362        }
363
364        Some(segment)
365    })
366}
367
368/// Parse a string as a DNS name.
369fn parse_string(str: &str) -> impl Iterator<Item = LabelSegment<'_>> + '_ {
370    let dot = Memchr::new(b'.', str.as_bytes());
371    let mut last_index = 0;
372
373    // Add one extra entry to `dot` that's at the end of the string.
374    let dot = dot.chain(Some(str.len()));
375
376    dot.filter_map(move |index| {
377        let item = &str[last_index..index];
378        last_index = index.saturating_add(1);
379
380        if item.is_empty() {
381            None
382        } else {
383            Some(LabelSegment::String(item))
384        }
385    })
386    .chain(Some(LabelSegment::Empty))
387}
388
/// One piece of a DNS name as it appears on the wire.
///
/// The variant order is significant: the derived ordering ranks `Empty < String < Pointer`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum LabelSegment<'a> {
    /// The zero-length label that terminates a name.
    Empty,

    /// A label holding text, e.g. `"example"` in `example.com`.
    String(&'a str),

    /// A compression pointer: the message offset at which the rest of the name continues.
    Pointer(u16),
}
401
/// The two high bits that mark a length byte as the start of a compression pointer.
const PTR_MASK: u8 = 0b1100_0000;
/// Longest allowed label: the six low bits left over by `PTR_MASK`, i.e. 63 bytes.
const MAX_STR_LEN: usize = (!PTR_MASK) as usize;
404
405impl<'a> LabelSegment<'a> {
406    fn as_str(&self) -> Option<&'a str> {
407        match self {
408            Self::String(s) => Some(s),
409            _ => None,
410        }
411    }
412}
413
414impl Default for LabelSegment<'_> {
415    fn default() -> Self {
416        Self::Empty
417    }
418}
419
420impl<'a> Serialize<'a> for LabelSegment<'a> {
421    fn serialized_len(&self) -> usize {
422        match self {
423            Self::Empty => 1,
424            Self::Pointer(_) => 2,
425            Self::String(s) => 1 + s.len(),
426        }
427    }
428
429    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
430        match self {
431            Self::Empty => {
432                // An empty label segment is just a zero byte.
433                bytes[0] = 0;
434                Ok(1)
435            }
436            Self::Pointer(ptr) => {
437                // Apply the pointer mask to the first byte.
438                let [mut b1, b2] = ptr.to_be_bytes();
439                b1 |= PTR_MASK;
440                bytes[0] = b1;
441                bytes[1] = b2;
442                Ok(2)
443            }
444            Self::String(s) => {
445                // First, serialize the length byte.
446                let len = s.len();
447
448                if len > MAX_STR_LEN {
449                    return Err(Error::NameTooLong(len));
450                }
451
452                if len > bytes.len() {
453                    panic!("not enough bytes to serialize string");
454                }
455
456                bytes[0] = len as u8;
457                bytes[1..=len].copy_from_slice(s.as_bytes());
458                Ok(len + 1)
459            }
460        }
461    }
462
463    fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
464        // The type is determined by the first byte.
465        let b1 = *cursor
466            .remaining()
467            .first()
468            .ok_or_else(|| cursor.read_error(1))?;
469
470        if b1 == 0 {
471            // An empty label segment is just a zero byte.
472            *self = Self::Empty;
473            cursor.advance(1)
474        } else if b1 & PTR_MASK == PTR_MASK {
475            // A pointer is a 2-byte value with the pointer mask applied to the first byte.
476            let [b1, b2]: [u8; 2] = cursor.remaining()[..2]
477                .try_into()
478                .map_err(|_| cursor.read_error(2))?;
479            let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]);
480            *self = Self::Pointer(ptr);
481            cursor.advance(2)
482        } else {
483            // A string label is a length byte followed by the string.
484            let len = b1 as usize;
485
486            if len > MAX_STR_LEN {
487                return Err(Error::NameTooLong(len));
488            }
489
490            // Parse the string's bytes.
491            let bytes = cursor.remaining()[1..=len]
492                .try_into()
493                .map_err(|_| cursor.read_error(len + 1))?;
494
495            // Parse as UTF8
496            let s = core::str::from_utf8(bytes)?;
497            *self = Self::String(s);
498            cursor.advance(len + 1)
499        }
500    }
501}
502
503macro_rules! serialize_num {
504    ($($num_ty: ident),*) => {
505        $(
506            impl<'a> Serialize<'a> for $num_ty {
507                fn serialized_len(&self) -> usize {
508                    mem::size_of::<$num_ty>()
509                }
510
511                fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
512                    if bytes.len() < mem::size_of::<$num_ty>() {
513                        panic!("Not enough space to serialize a {}", stringify!($num_ty));
514                    }
515
516                    let value = (*self).to_be_bytes();
517                    bytes[..mem::size_of::<$num_ty>()].copy_from_slice(&value);
518
519                    Ok(mem::size_of::<$num_ty>())
520                }
521
522                fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
523                    if bytes.len() < mem::size_of::<$num_ty>() {
524                        return Err(bytes.read_error(mem::size_of::<$num_ty>()));
525                    }
526
527                    let mut value = [0; mem::size_of::<$num_ty>()];
528                    value.copy_from_slice(&bytes.remaining()[..mem::size_of::<$num_ty>()]);
529                    *self = $num_ty::from_be_bytes(value);
530
531                    bytes.advance(mem::size_of::<$num_ty>())
532                }
533            }
534        )*
535    }
536}
537
538serialize_num! {
539    u8, u16, u32, u64,
540    i8, i16, i32, i64
541}
542
/// An iterator that is one of two concrete iterator types sharing the same item type.
///
/// This lets a function returning `impl Iterator` choose between two implementations at run
/// time without boxing.
enum Either<L, R> {
    /// The first alternative.
    A(L),
    /// The second alternative.
    B(R),
}

impl<L, R> Iterator for Either<L, R>
where
    L: Iterator,
    R: Iterator<Item = L::Item>,
{
    type Item = L::Item;

    fn next(&mut self) -> Option<L::Item> {
        match self {
            Self::A(left) => left.next(),
            Self::B(right) => right.next(),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::A(left) => left.size_hint(),
            Self::B(right) => right.size_hint(),
        }
    }

    // Forwarding `fold` lets the wrapped iterator use its own internal iteration instead of
    // the default loop over `next`.
    fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        Self: Sized,
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        match self {
            Self::A(left) => left.fold(init, f),
            Self::B(right) => right.fold(init, f),
        }
    }
}