dns_protocol_patch/
ser.rs

1//! A serializer/deserializer compatible with the DNS wire format.
2//!
3//! Various notes about the wire format:
4//!
5//! - All integers are in big endian format.
6//! - All strings are length-prefixed.
7//! - Names consist of a series of labels, each prefixed with a length byte.
8//! - Names end with a zero byte.
9//! - Names are compressed by using a pointer to a previous name.
10//! - This means that we need access to the entire buffer.
11
12use super::Error;
13
14use core::convert::TryInto;
15use core::fmt;
16use core::hash;
17use core::iter;
18use core::mem;
19use core::num::NonZeroUsize;
20
21use memchr::Memchr;
22
23/// An object that is able to be serialized to a series of bytes.
24pub trait Serialize<'a> {
25    /// The number of bytes needed to serialize this object.
26    fn serialized_len(&self) -> usize;
27
28    /// Serialize this object into a series of bytes.
29    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error>;
30}
31
32/// An object that is able to be deserialized from a series of bytes.
33pub trait Deserialize<'a> {
34    /// Deserialize this object from a series of bytes.
35    fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error>;
36}
37
/// A read position inside a byte buffer.
///
/// The full buffer is kept alongside the position because DNS name compression
/// pointers refer to absolute offsets from the start of the message.
#[derive(Clone, Copy, Debug)]
pub struct Cursor<'a> {
    /// The complete buffer being parsed.
    bytes: &'a [u8],

    /// Absolute offset of the next unread byte within `bytes`.
    cursor: usize,
}
49
50impl<'a> Cursor<'a> {
51    /// Create a new cursor from a series of bytes.
52    pub fn new(bytes: &'a [u8]) -> Self {
53        Self { bytes, cursor: 0 }
54    }
55
56    /// Get the original bytes that this cursor was created from.
57    pub fn original(&self) -> &'a [u8] {
58        self.bytes
59    }
60
61    /// Get the slice of remaining bytes.
62    pub fn remaining(&self) -> &'a [u8] {
63        &self.bytes[self.cursor..]
64    }
65
66    /// Get the length of the slice of remaining bytes.
67    #[allow(clippy::len_without_is_empty)]
68    pub fn len(&self) -> usize {
69        self.bytes.len() - self.cursor
70    }
71
72    /// Get a new cursor at the given absolute position.
73    pub fn at(&self, pos: usize) -> Self {
74        Self {
75            bytes: self.bytes,
76            cursor: pos,
77        }
78    }
79
80    /// Advance the cursor by the given number of bytes.
81    pub fn advance(mut self, n: usize) -> Result<Self, Error> {
82        if n == 0 {
83            return Ok(self);
84        }
85
86        if self.cursor + n > self.bytes.len() {
87            return Err(Error::NotEnoughReadBytes {
88                tried_to_read: NonZeroUsize::new(self.cursor.saturating_add(n)).unwrap(),
89                available: self.bytes.len(),
90            });
91        }
92
93        self.cursor += n;
94        Ok(self)
95    }
96
97    /// Error for when a read of `n` bytes failed.
98    fn read_error(&self, n: usize) -> Error {
99        Error::NotEnoughReadBytes {
100            tried_to_read: NonZeroUsize::new(self.cursor.saturating_add(n)).unwrap(),
101            available: self.bytes.len(),
102        }
103    }
104}
105
106impl Serialize<'_> for () {
107    fn serialized_len(&self) -> usize {
108        0
109    }
110
111    fn serialize(&self, _bytes: &mut [u8]) -> Result<usize, Error> {
112        Ok(0)
113    }
114}
115
116impl<'a> Deserialize<'a> for () {
117    fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
118        Ok(bytes)
119    }
120}
121
/// A DNS name, such as `www.example.com`.
#[derive(Copy, Clone)]
pub struct Label<'a> {
    repr: Repr<'a>,
}

/// How a [`Label`] stores its name.
#[derive(Copy, Clone)]
enum Repr<'a> {
    /// A span of an existing wire-format buffer.
    ///
    /// Segments are only decoded on demand.
    Bytes {
        /// The entire buffer the name was read from (needed to follow pointers).
        original: &'a [u8],

        /// Offset of the first byte of the name inside `original`.
        start: usize,

        /// Offset one past the last byte of the name inside `original`.
        end: usize,
    },

    /// A dotted textual name that is split into segments on demand.
    String {
        /// The dotted text, e.g. `"example.com"`.
        string: &'a str,
    },
}
151
152impl Default for Label<'_> {
153    fn default() -> Self {
154        // An empty label.
155        Self {
156            repr: Repr::Bytes {
157                original: &[0],
158                start: 0,
159                end: 1,
160            },
161        }
162    }
163}
164
165impl<'a> PartialEq<Label<'a>> for Label<'_> {
166    fn eq(&self, other: &Label<'a>) -> bool {
167        self.segments().eq(other.segments())
168    }
169}
170
171impl Eq for Label<'_> {}
172
173impl<'a> PartialOrd<Label<'a>> for Label<'_> {
174    fn partial_cmp(&self, other: &Label<'a>) -> Option<core::cmp::Ordering> {
175        self.segments().partial_cmp(other.segments())
176    }
177}
178
179impl Ord for Label<'_> {
180    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
181        self.segments().cmp(other.segments())
182    }
183}
184
185impl hash::Hash for Label<'_> {
186    fn hash<H: hash::Hasher>(&self, state: &mut H) {
187        for segment in self.segments() {
188            segment.hash(state);
189        }
190    }
191}
192
193impl<'a> Label<'a> {
194    /// Get an iterator over the label segments in this name.
195    pub fn segments(&self) -> impl Iterator<Item = LabelSegment<'a>> {
196        match self.repr {
197            Repr::Bytes {
198                original, start, ..
199            } => Either::A(parse_bytes(original, start)),
200            Repr::String { string } => Either::B(parse_string(string)),
201        }
202    }
203
204    /// Get an iterator over the strings making up this name.
205    pub fn names(&self) -> impl Iterator<Item = Result<&'a str, &'a [u8]>> {
206        match self.repr {
207            Repr::String { string } => {
208                // Guaranteed to have no pointers.
209                Either::A(parse_string(string).filter_map(|seg| seg.as_str().map(Ok)))
210            }
211            Repr::Bytes {
212                original, start, ..
213            } => {
214                // We may have to deal with pointers. Parse it manually.
215                let mut cursor = Cursor {
216                    bytes: original,
217                    cursor: start,
218                };
219
220                Either::B(iter::from_fn(move || {
221                    loop {
222                        let mut ls: LabelSegment<'_> = LabelSegment::Empty;
223                        cursor = ls.deserialize(cursor).ok()?;
224
225                        // TODO: Handle the case where utf-8 errors out.
226                        match ls {
227                            LabelSegment::Empty => return None,
228                            LabelSegment::Pointer(pos) => {
229                                // Change to another location.
230                                cursor = cursor.at(pos.into());
231                            }
232                            LabelSegment::String(label) => return Some(Ok(label)),
233                        }
234                    }
235                }))
236            }
237        }
238    }
239}
240
241impl<'a> From<&'a str> for Label<'a> {
242    fn from(string: &'a str) -> Self {
243        Self {
244            repr: Repr::String { string },
245        }
246    }
247}
248
249impl fmt::Debug for Label<'_> {
250    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251        struct LabelFmt<'a>(&'a Label<'a>);
252
253        impl fmt::Debug for LabelFmt<'_> {
254            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255                fmt::Display::fmt(self.0, f)
256            }
257        }
258
259        f.debug_tuple("Label").field(&LabelFmt(self)).finish()
260    }
261}
262
263impl fmt::Display for Label<'_> {
264    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265        self.names().enumerate().try_for_each(|(i, name)| {
266            if i > 0 {
267                f.write_str(".")?;
268            }
269
270            match name {
271                Ok(name) => f.write_str(name),
272                Err(_) => f.write_str("???"),
273            }
274        })
275    }
276}
277
278impl<'a> Serialize<'a> for Label<'a> {
279    fn serialized_len(&self) -> usize {
280        if let Repr::Bytes { start, end, .. } = self.repr {
281            return end - start;
282        }
283
284        self.segments()
285            .map(|item| item.serialized_len())
286            .fold(0, |a, b| a.saturating_add(b))
287    }
288
289    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
290        // Fast path: just copy our bytes.
291        if let Repr::Bytes {
292            original,
293            start,
294            end,
295        } = self.repr
296        {
297            bytes[..end - start].copy_from_slice(&original[start..end]);
298            return Ok(end - start);
299        }
300
301        self.segments().try_fold(0, |mut offset, item| {
302            let len = item.serialize(&mut bytes[offset..])?;
303            offset += len;
304            Ok(offset)
305        })
306    }
307}
308
309impl<'a> Deserialize<'a> for Label<'a> {
310    fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
311        let original = cursor.original();
312        let start = cursor.cursor;
313
314        // Figure out where the end is.
315        let mut end = start;
316        loop {
317            let len_char = match original.get(end) {
318                Some(0) => {
319                    end += 1;
320                    break;
321                }
322                Some(ptr) if ptr & PTR_MASK != 0 => {
323                    // Pointer goes for one more byte and then ends.
324                    end += 2;
325                    break;
326                }
327                Some(len_char) => *len_char,
328                None => {
329                    return Err(Error::NotEnoughReadBytes {
330                        tried_to_read: NonZeroUsize::new(cursor.cursor + end).unwrap(),
331                        available: original.len(),
332                    })
333                }
334            };
335
336            let len = len_char as usize;
337            end += len + 1;
338        }
339
340        self.repr = Repr::Bytes {
341            original,
342            start,
343            end,
344        };
345        cursor.advance(end - start)
346    }
347}
348
349/// Parse a set of bytes as a DNS name.
350fn parse_bytes(bytes: &[u8], position: usize) -> impl Iterator<Item = LabelSegment<'_>> + '_ {
351    let mut cursor = Cursor {
352        bytes,
353        cursor: position,
354    };
355    let mut keep_going = true;
356
357    iter::from_fn(move || {
358        if !keep_going {
359            return None;
360        }
361
362        let mut segment = LabelSegment::Empty;
363        cursor = segment.deserialize(cursor).ok()?;
364
365        match segment {
366            LabelSegment::String(_) => {}
367            _ => {
368                keep_going = false;
369            }
370        }
371
372        Some(segment)
373    })
374}
375
376/// Parse a string as a DNS name.
377fn parse_string(str: &str) -> impl Iterator<Item = LabelSegment<'_>> + '_ {
378    let dot = Memchr::new(b'.', str.as_bytes());
379    let mut last_index = 0;
380
381    // Add one extra entry to `dot` that's at the end of the string.
382    let dot = dot.chain(Some(str.len()));
383
384    dot.filter_map(move |index| {
385        let item = &str[last_index..index];
386        last_index = index.saturating_add(1);
387
388        if item.is_empty() {
389            None
390        } else {
391            Some(LabelSegment::String(item))
392        }
393    })
394    .chain(Some(LabelSegment::Empty))
395}
396
/// One piece of a DNS name as it appears on the wire.
///
/// Variant order matters: the derived ordering sorts `Empty` < `String` < `Pointer`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum LabelSegment<'a> {
    /// The zero-length terminator that ends a name.
    Empty,

    /// A length-prefixed textual label.
    String(&'a str),

    /// A compression pointer to an earlier offset in the message.
    Pointer(u16),
}

/// Largest label length encodable in a length byte: every bit outside `PTR_MASK` (63).
const MAX_STR_LEN: usize = (!PTR_MASK) as usize;

/// The two high bits that, when both set in a length byte, mark a compression pointer.
const PTR_MASK: u8 = 0xC0;

impl<'a> LabelSegment<'a> {
    /// The label text, if this segment is a string label.
    fn as_str(&self) -> Option<&'a str> {
        if let LabelSegment::String(text) = *self {
            Some(text)
        } else {
            None
        }
    }
}

impl Default for LabelSegment<'_> {
    /// The terminator segment.
    fn default() -> Self {
        LabelSegment::Empty
    }
}
427
428impl<'a> Serialize<'a> for LabelSegment<'a> {
429    fn serialized_len(&self) -> usize {
430        match self {
431            Self::Empty => 1,
432            Self::Pointer(_) => 2,
433            Self::String(s) => 1 + s.len(),
434        }
435    }
436
437    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
438        match self {
439            Self::Empty => {
440                // An empty label segment is just a zero byte.
441                bytes[0] = 0;
442                Ok(1)
443            }
444            Self::Pointer(ptr) => {
445                // Apply the pointer mask to the first byte.
446                let [mut b1, b2] = ptr.to_be_bytes();
447                b1 |= PTR_MASK;
448                bytes[0] = b1;
449                bytes[1] = b2;
450                Ok(2)
451            }
452            Self::String(s) => {
453                // First, serialize the length byte.
454                let len = s.len();
455
456                if len > MAX_STR_LEN {
457                    return Err(Error::NameTooLong(len));
458                }
459
460                if len > bytes.len() {
461                    panic!("not enough bytes to serialize string");
462                }
463
464                bytes[0] = len as u8;
465                bytes[1..=len].copy_from_slice(s.as_bytes());
466                Ok(len + 1)
467            }
468        }
469    }
470}
471
472impl<'a> Deserialize<'a> for LabelSegment<'a> {
473    fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
474        // The type is determined by the first byte.
475        let b1 = *cursor
476            .remaining()
477            .first()
478            .ok_or_else(|| cursor.read_error(1))?;
479
480        if b1 == 0 {
481            // An empty label segment is just a zero byte.
482            *self = Self::Empty;
483            cursor.advance(1)
484        } else if b1 & PTR_MASK == PTR_MASK {
485            // A pointer is a 2-byte value with the pointer mask applied to the first byte.
486            let [b1, b2]: [u8; 2] = cursor.remaining()[..2]
487                .try_into()
488                .map_err(|_| cursor.read_error(2))?;
489            let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]);
490            *self = Self::Pointer(ptr);
491            cursor.advance(2)
492        } else {
493            // A string label is a length byte followed by the string.
494            let len = b1 as usize;
495
496            if len > MAX_STR_LEN {
497                return Err(Error::NameTooLong(len));
498            }
499
500            // Parse the string's bytes.
501            let bytes = cursor.remaining()[1..=len]
502                .try_into()
503                .map_err(|_| cursor.read_error(len + 1))?;
504
505            // Parse as UTF8
506            let s = simdutf8::compat::from_utf8(bytes)?;
507            *self = Self::String(s);
508            cursor.advance(len + 1)
509        }
510    }
511}
512
513macro_rules! serialize_num {
514    ($($num_ty: ident),*) => {
515        $(
516            impl<'a> Serialize<'a> for $num_ty {
517                fn serialized_len(&self) -> usize {
518                    mem::size_of::<$num_ty>()
519                }
520
521                fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
522                    if bytes.len() < mem::size_of::<$num_ty>() {
523                        panic!("Not enough space to serialize a {}", stringify!($num_ty));
524                    }
525
526                    let value = (*self).to_be_bytes();
527                    bytes[..mem::size_of::<$num_ty>()].copy_from_slice(&value);
528
529                    Ok(mem::size_of::<$num_ty>())
530                }
531            }
532
533            impl<'a> Deserialize<'a> for $num_ty {
534                fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
535                    if bytes.len() < mem::size_of::<$num_ty>() {
536                        return Err(bytes.read_error(mem::size_of::<$num_ty>()));
537                    }
538
539                    let mut value = [0; mem::size_of::<$num_ty>()];
540                    value.copy_from_slice(&bytes.remaining()[..mem::size_of::<$num_ty>()]);
541                    *self = $num_ty::from_be_bytes(value);
542
543                    bytes.advance(mem::size_of::<$num_ty>())
544                }
545            }
546        )*
547    }
548}
549
550serialize_num! {
551    u8, u16, u32, u64,
552    i8, i16, i32, i64
553}
554
/// An iterator that is one of two concrete iterator types yielding the same item.
///
/// Lets a function returning `impl Iterator` pick between two implementations at runtime.
enum Either<L, R> {
    A(L),
    B(R),
}

impl<L, R> Iterator for Either<L, R>
where
    L: Iterator,
    R: Iterator<Item = L::Item>,
{
    type Item = L::Item;

    fn next(&mut self) -> Option<L::Item> {
        match self {
            Self::A(left) => left.next(),
            Self::B(right) => right.next(),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::A(left) => left.size_hint(),
            Self::B(right) => right.size_hint(),
        }
    }

    /// Forwarded so the inner iterator's own (possibly faster) `fold` is used.
    fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        Self: Sized,
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        match self {
            Self::A(left) => left.fold(init, f),
            Self::B(right) => right.fold(init, f),
        }
    }
}