1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
// Decoding of domain names (with pointer-based compression) from wire format.

use crate::decode::Decoder;
use crate::domain_name::DOMAIN_NAME_MAX_RECURSION;
use crate::{DecodeError, DecodeResult, DomainName};
use std::str::from_utf8;

/// Mask for the two high bits of a length byte; when both are set the byte
/// starts a two-byte compression pointer instead of giving a label length.
const COMPRESSION_BITS: u8 = 0b1100_0000;
/// Mask for the six low bits of a pointer's first byte, which carry the
/// upper part of the 14-bit pointer offset.
const COMPRESSION_BITS_REV: u8 = 0b0011_1111;

/// Returns `true` when `length` has both high bits set, i.e. it is the first
/// byte of a compression pointer.
#[inline]
fn is_compressed(length: u8) -> bool {
    let flag_bits = length & COMPRESSION_BITS;
    flag_bits == COMPRESSION_BITS
}

/// Builds the 14-bit pointer offset from the two pointer bytes: the low six
/// bits of `high` become the upper byte and `low` the lower byte.
#[inline]
fn get_offset(high: u8, low: u8) -> usize {
    usize::from(u16::from_be_bytes([high & COMPRESSION_BITS_REV, low]))
}

impl<'a, 'b: 'a> Decoder<'a, 'b> {
    /// Decodes one domain name starting at the decoder's current position.
    ///
    /// # Errors
    ///
    /// Returns whatever `domain_name_recursion` reports. That covers:
    /// - errors from the underlying `u8`/`read` calls (presumably truncated input);
    /// - a label that is not valid UTF-8;
    /// - a label rejected by `DomainName::append_label`;
    /// - `DecodeError::MaxRecursionError` when too many pointers are chained.
    pub(super) fn domain_name(&mut self) -> DecodeResult<DomainName> {
        let mut name = DomainName::default();
        self.domain_name_recursion(&mut name, 0).map(|()| name)
    }

    /// Appends the labels found at the current position to `domain_name`,
    /// following compression pointers.
    ///
    /// `recursion` is the number of pointers followed so far. Once it exceeds
    /// `DOMAIN_NAME_MAX_RECURSION` the call fails with
    /// `DecodeError::MaxRecursionError(recursion)`, which bounds pointer loops.
    ///
    /// On reaching a pointer, this decoder is left just past the two pointer
    /// bytes. The remainder of the name is read by a decoder obtained from
    /// `new_main_offset`. That decoder is presumably positioned at the offset
    /// within the whole message (confirm against `new_main_offset`).
    fn domain_name_recursion(
        &mut self,
        domain_name: &mut DomainName,
        recursion: usize,
    ) -> DecodeResult<()> {
        if recursion > DOMAIN_NAME_MAX_RECURSION {
            return Err(DecodeError::MaxRecursionError(recursion));
        }

        loop {
            let length = self.u8()?;
            // A zero length byte terminates the name.
            if length == 0 {
                return Ok(());
            }

            if is_compressed(length) {
                // The rest of the name lives at the pointer target; hand off to
                // a decoder positioned there and count one more jump.
                let low = self.u8()?;
                let mut target = self.new_main_offset(get_offset(length, low));
                return target.domain_name_recursion(domain_name, recursion + 1);
            }

            // NOTE(review): RFC 1035 reserves the high-bit patterns 01 and 10.
            // Length bytes using them are taken as plain label lengths here.
            // Presumably `append_label` rejects over-long labels — confirm.
            let raw = self.read(usize::from(length))?;
            domain_name.append_label(from_utf8(raw.as_ref())?)?;
        }
    }
}

impl_decode!(DomainName, domain_name);