dns_message_parser/decode/
domain_name.rs1use crate::{
2 decode::Decoder, domain_name::DOMAIN_NAME_MAX_RECURSION, DecodeError, DecodeResult, DomainName,
3};
4use std::{collections::HashSet, str::from_utf8};
5
/// Mask for the two high bits of a length octet; when both are set the octet starts a
/// compression pointer rather than a label (RFC 1035 §4.1.4).
const COMPRESSION_BITS: u8 = 0xC0;
/// Complement of `COMPRESSION_BITS`: keeps the low six bits of a pointer's first octet,
/// which form the high part of the pointer offset.
const COMPRESSION_BITS_REV: u8 = !COMPRESSION_BITS;

9#[inline]
10const fn is_compressed(length: u8) -> bool {
11 (length & COMPRESSION_BITS) == COMPRESSION_BITS
12}
13
14#[inline]
15const fn get_offset(length_1: u8, length_2: u8) -> u16 {
16 (((length_1 & COMPRESSION_BITS_REV) as u16) << 8) | length_2 as u16
17}
18
/// Classification of the length octet that starts each part of an encoded domain name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DomainNameLength {
    /// A compression pointer; holds the 14-bit offset produced by `get_offset`.
    Compressed(u16),
    /// An uncompressed label of this many octets; `0` terminates the name.
    Label(u8),
}

24impl<'a, 'b: 'a> Decoder<'a, 'b> {
25 fn domain_name_length(&mut self) -> DecodeResult<DomainNameLength> {
26 let length = self.u8()?;
27 if is_compressed(length) {
28 let offset = self.u8()?;
29 Ok(DomainNameLength::Compressed(get_offset(length, offset)))
30 } else {
31 Ok(DomainNameLength::Label(length))
32 }
33 }
34
35 pub(super) fn domain_name(&mut self) -> DecodeResult<DomainName> {
36 let mut domain_name = DomainName::default();
37
38 loop {
39 match self.domain_name_length()? {
40 DomainNameLength::Compressed(offset) => {
41 self.domain_name_recursion(&mut domain_name, offset)?;
42 return Ok(domain_name);
43 }
44 DomainNameLength::Label(0) => return Ok(domain_name),
45 DomainNameLength::Label(length) => {
46 self.domain_name_label(&mut domain_name, length)?
47 }
48 }
49 }
50 }
51
52 fn domain_name_label(&mut self, domain_name: &mut DomainName, length: u8) -> DecodeResult<()> {
53 let buffer = self.read(length as usize)?;
54 let label = from_utf8(buffer.as_ref())?;
55 let label = label.parse()?;
56 domain_name.append_label(label)?;
57 Ok(())
58 }
59
60 fn domain_name_recursion(&self, domain_name: &mut DomainName, offset: u16) -> DecodeResult<()> {
61 let mut decoder = self.new_main_offset(offset);
62 let mut recursions = HashSet::new();
63
64 loop {
65 match decoder.domain_name_length()? {
66 DomainNameLength::Compressed(offset) => {
67 if recursions.insert(offset) {
68 let recursions_len = recursions.len();
69 if recursions_len > DOMAIN_NAME_MAX_RECURSION {
70 return Err(DecodeError::MaxRecursion(recursions_len));
71 }
72 } else {
73 return Err(DecodeError::EndlessRecursion(offset));
74 }
75
76 decoder.offset = offset as usize;
77 }
78 DomainNameLength::Label(0) => return Ok(()),
79 DomainNameLength::Label(length) => {
80 decoder.domain_name_label(domain_name, length)?;
81 }
82 }
83 }
84 }
85}
86
87impl_decode!(DomainName, domain_name);