dns_message_parser/decode/domain_name.rs

1use crate::{
2    decode::Decoder, domain_name::DOMAIN_NAME_MAX_RECURSION, DecodeError, DecodeResult, DomainName,
3};
4use std::{collections::HashSet, str::from_utf8};
5
/// Mask for the two high-order bits of a label length octet.
///
/// When both bits are set, the octet is the first half of a compression pointer
/// rather than a label length.
const COMPRESSION_BITS: u8 = 0xC0;
/// Complement of [`COMPRESSION_BITS`]: the six low-order bits of the first pointer
/// octet, which become the high bits of the 14-bit pointer offset.
const COMPRESSION_BITS_REV: u8 = !COMPRESSION_BITS;
8
9#[inline]
10const fn is_compressed(length: u8) -> bool {
11    (length & COMPRESSION_BITS) == COMPRESSION_BITS
12}
13
14#[inline]
15const fn get_offset(length_1: u8, length_2: u8) -> u16 {
16    (((length_1 & COMPRESSION_BITS_REV) as u16) << 8) | length_2 as u16
17}
18
/// What the octet(s) at the start of a domain-name component encode.
///
/// This is a small plain value, so it derives `Copy` and comparison traits, and
/// `Debug` so it can be inspected while debugging the decoder.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DomainNameLength {
    /// A compression pointer; holds the 14-bit offset it points to.
    Compressed(u16),
    /// A plain label length octet; a length of `0` terminates the name.
    Label(u8),
}
23
24impl<'a, 'b: 'a> Decoder<'a, 'b> {
25    fn domain_name_length(&mut self) -> DecodeResult<DomainNameLength> {
26        let length = self.u8()?;
27        if is_compressed(length) {
28            let offset = self.u8()?;
29            Ok(DomainNameLength::Compressed(get_offset(length, offset)))
30        } else {
31            Ok(DomainNameLength::Label(length))
32        }
33    }
34
35    pub(super) fn domain_name(&mut self) -> DecodeResult<DomainName> {
36        let mut domain_name = DomainName::default();
37
38        loop {
39            match self.domain_name_length()? {
40                DomainNameLength::Compressed(offset) => {
41                    self.domain_name_recursion(&mut domain_name, offset)?;
42                    return Ok(domain_name);
43                }
44                DomainNameLength::Label(0) => return Ok(domain_name),
45                DomainNameLength::Label(length) => {
46                    self.domain_name_label(&mut domain_name, length)?
47                }
48            }
49        }
50    }
51
52    fn domain_name_label(&mut self, domain_name: &mut DomainName, length: u8) -> DecodeResult<()> {
53        let buffer = self.read(length as usize)?;
54        let label = from_utf8(buffer.as_ref())?;
55        let label = label.parse()?;
56        domain_name.append_label(label)?;
57        Ok(())
58    }
59
60    fn domain_name_recursion(&self, domain_name: &mut DomainName, offset: u16) -> DecodeResult<()> {
61        let mut decoder = self.new_main_offset(offset);
62        let mut recursions = HashSet::new();
63
64        loop {
65            match decoder.domain_name_length()? {
66                DomainNameLength::Compressed(offset) => {
67                    if recursions.insert(offset) {
68                        let recursions_len = recursions.len();
69                        if recursions_len > DOMAIN_NAME_MAX_RECURSION {
70                            return Err(DecodeError::MaxRecursion(recursions_len));
71                        }
72                    } else {
73                        return Err(DecodeError::EndlessRecursion(offset));
74                    }
75
76                    decoder.offset = offset as usize;
77                }
78                DomainNameLength::Label(0) => return Ok(()),
79                DomainNameLength::Label(length) => {
80                    decoder.domain_name_label(domain_name, length)?;
81                }
82            }
83        }
84    }
85}
86
87impl_decode!(DomainName, domain_name);