1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
use crate::decode::Decoder;
use crate::domain_name::DOMAIN_NAME_MAX_RECURSION;
use crate::{DecodeError, DecodeResult, DomainName};
use std::str::from_utf8;

/// Mask over the two high bits of a label-length octet; when both are set the
/// octet starts a compression pointer (RFC 1035 §4.1.4).
const COMPRESSION_BITS: u8 = 0xC0;
/// Complement of [`COMPRESSION_BITS`]: the six low bits of a pointer's first
/// octet, which carry the high part of the offset.
const COMPRESSION_BITS_REV: u8 = !COMPRESSION_BITS;

/// Returns `true` when `length` is the first octet of a compression pointer,
/// i.e. both of its two high bits are set.
#[inline]
const fn is_compressed(length: u8) -> bool {
    let tag = length & COMPRESSION_BITS;
    tag == COMPRESSION_BITS
}

/// Assembles the 14-bit message offset of a compression pointer.
///
/// The low six bits of `length_1` become bits 8..14 of the result and
/// `length_2` becomes bits 0..8.
#[inline]
const fn get_offset(length_1: u8, length_2: u8) -> usize {
    let high = (length_1 & COMPRESSION_BITS_REV) as usize;
    let low = length_2 as usize;
    (high << 8) | low
}

impl<'a, 'b: 'a> Decoder<'a, 'b> {
    /// Decodes a domain name starting at the decoder's current position.
    ///
    /// Labels are read one by one until a zero-length octet terminates the
    /// name. A compression pointer also ends reading at this position: the
    /// rest of the name is decoded from the offset it points to (see
    /// `domain_name_recursion`).
    ///
    /// # Errors
    ///
    /// Propagates errors from reading the buffer, from UTF-8 validation of a
    /// label, and from `DomainName::append_label`. Returns
    /// `DecodeError::MaxRecursion` when more compression pointers are chained
    /// than `DOMAIN_NAME_MAX_RECURSION` allows.
    pub(super) fn domain_name(&mut self) -> DecodeResult<DomainName> {
        let mut name = DomainName::default();
        self.domain_name_recursion(&mut name, 0).map(|()| name)
    }

    /// Appends the labels found at the current position to `domain_name`.
    ///
    /// `recursion` counts the compression pointers followed so far. Once it
    /// exceeds `DOMAIN_NAME_MAX_RECURSION`, decoding stops with
    /// `DecodeError::MaxRecursion`. This caps the work done on pointer loops.
    fn domain_name_recursion(
        &mut self,
        domain_name: &mut DomainName,
        recursion: usize,
    ) -> DecodeResult<()> {
        if recursion > DOMAIN_NAME_MAX_RECURSION {
            return Err(DecodeError::MaxRecursion(recursion));
        }

        loop {
            let length = self.u8()?;
            if length == 0 {
                // Zero-length octet: the name is complete.
                return Ok(());
            }

            if is_compressed(length) {
                // Two-octet pointer: both octets are consumed from `self`, and
                // the rest of the name is read through a separate decoder from
                // `new_main_offset` (presumably positioned at that offset in
                // the full message — confirm). A pointer always ends the name here.
                let low = self.u8()?;
                let mut target = self.new_main_offset(get_offset(length, low));
                return target.domain_name_recursion(domain_name, recursion + 1);
            }

            // NOTE(review): length octets carrying the reserved 0b01/0b10
            // prefixes (RFC 1035 §4.1.4) also land here and are read as plain
            // labels of 64..=191 bytes; rejection, if any, is left to
            // `append_label` — confirm this is intended.
            let raw = self.read(usize::from(length))?;
            domain_name.append_label(from_utf8(raw.as_ref())?)?;
        }
    }
}

// Wires `DomainName` into the crate's decoding machinery through the
// `impl_decode!` macro, naming `Decoder::domain_name` as the decoder method.
// NOTE(review): presumably this generates a `Decode`-style trait impl that
// delegates to `Decoder::domain_name` — confirm against the macro definition.
impl_decode!(DomainName, domain_name);