1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
use crate::domain_name::DOMAIN_NAME_MAX_RECURSION;
use crate::encode::Encoder;
use crate::{DomainName, EncodeError, EncodeResult};
use std::collections::HashMap;

/// Largest offset that fits in the low 14 bits of a 16-bit value.
const MAX_OFFSET: u16 = 0b0011_1111_1111_1111;
/// The two high bits OR-ed onto an offset before it is written as a compression reference.
const COMPRESSION_BITS: u16 = 0b1100_0000_0000_0000;

impl Encoder {
    /// Tries to replace `domain_name_str` with a reference to an earlier occurrence.
    ///
    /// Looks the name up in `self.domain_name_index`. On a usable hit, the stored offset
    /// is combined with [`COMPRESSION_BITS`], written as a `u16`, and the recursion depth
    /// stored alongside the offset is returned as `Ok(Some(depth))`.
    ///
    /// Returns `Ok(None)` (nothing written) when the name is not indexed or when its stored
    /// depth exceeds [`DOMAIN_NAME_MAX_RECURSION`].
    ///
    /// # Errors
    ///
    /// Returns `EncodeError::CompressionError` when the stored offset is larger than
    /// [`MAX_OFFSET`].
    #[inline]
    fn compress(&mut self, domain_name_str: &str) -> EncodeResult<Option<usize>> {
        // Copy the entry out so the borrow of the index ends before writing.
        let (offset, depth) = match self.domain_name_index.get(domain_name_str) {
            Some((offset, depth)) => (*offset, *depth),
            None => return Ok(None),
        };

        // The offset check comes first: an oversized offset is an error even if the
        // depth would also have disqualified the entry.
        if offset > MAX_OFFSET {
            return Err(EncodeError::CompressionError(offset));
        }
        if depth > DOMAIN_NAME_MAX_RECURSION {
            return Ok(None);
        }

        self.u16(COMPRESSION_BITS | offset);
        Ok(Some(depth))
    }

    /// Writes a single label and returns the offset reported by `get_offset` just
    /// before the label was written.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_offset` or `string`.
    #[inline]
    fn label(&mut self, label: &str) -> EncodeResult<u16> {
        let start = self.get_offset()?;
        self.string(label)?;
        Ok(start)
    }

    /// Records every `(name, offset)` pair of `domain_name_index` in the encoder-wide
    /// index, tagging each with `recursion`. Existing entries for the same name are
    /// overwritten.
    ///
    /// # Errors
    ///
    /// Returns `EncodeError::MaxRecursionError` (and records nothing) when `recursion`
    /// exceeds [`DOMAIN_NAME_MAX_RECURSION`].
    #[inline]
    fn merge_domain_name_index(
        &mut self,
        domain_name_index: HashMap<String, u16>,
        recursion: usize,
    ) -> EncodeResult<()> {
        if recursion > DOMAIN_NAME_MAX_RECURSION {
            return Err(EncodeError::MaxRecursionError(recursion));
        }

        self.domain_name_index.extend(
            domain_name_index
                .into_iter()
                .map(|(name, offset)| (name, (offset, recursion))),
        );
        Ok(())
    }

    /// Encodes `domain_name` label by label, reusing an already-encoded suffix when
    /// one is found in the index.
    ///
    /// For each `(label, suffix)` produced by the name's iterator, a compression
    /// reference is attempted first. If one is written, encoding stops there and the
    /// labels written so far are indexed one level deeper than the referenced entry.
    /// Otherwise the label is written out, and its offset is remembered as long as it
    /// does not exceed [`MAX_OFFSET`]. When no suffix could be reused, an empty string
    /// is written after the last label and the remembered offsets are indexed at depth 0.
    ///
    /// # Errors
    ///
    /// Propagates errors from `compress`, `label`, `string` and
    /// `merge_domain_name_index`.
    pub(super) fn domain_name(&mut self, domain_name: &DomainName) -> EncodeResult<()> {
        // Offsets of the suffixes written during this call; merged only once the final
        // recursion depth is known.
        let mut written = HashMap::new();

        for (label, suffix) in domain_name.iter() {
            match self.compress(suffix)? {
                // NOTE(review): when the referenced entry sits at exactly
                // DOMAIN_NAME_MAX_RECURSION, the reference has already been written
                // and this merge then fails with MaxRecursionError. Confirm that
                // this is the intended behavior.
                Some(depth) => return self.merge_domain_name_index(written, depth + 1),
                None => {
                    let offset = self.label(label)?;
                    if offset <= MAX_OFFSET {
                        written.insert(suffix.to_string(), offset);
                    }
                }
            }
        }

        self.string("")?;
        self.merge_domain_name_index(written, 0)
    }
}

impl DomainName {
    /// Borrows the name as a string slice.
    fn as_str(&self) -> &str {
        &self.0
    }

    /// Returns an iterator over the `(label, suffix)` pairs of this name.
    fn iter(&self) -> DomainNameIter {
        let domain_name_str = self.as_str();
        DomainNameIter { domain_name_str }
    }
}

/// Iterator state: the part of the name that has not been consumed yet.
struct DomainNameIter<'a> {
    domain_name_str: &'a str,
}

impl<'a> Iterator for DomainNameIter<'a> {
    /// `(label, suffix)`: the text before the next `'.'`, and the whole remaining
    /// name starting at that label (including the label itself).
    type Item = (&'a str, &'a str);

    /// Yields the next label together with the suffix it starts.
    ///
    /// Returns `None` when the remaining text contains no `'.'` (so a final label
    /// without a trailing dot is not yielded) or when the label before the next
    /// `'.'` is empty.
    fn next(&mut self) -> Option<Self::Item> {
        let whole = self.domain_name_str;
        let dot = whole.find('.')?;
        let (label, tail) = whole.split_at(dot);

        // `tail` starts with the separator; step over it. The state advances even
        // when the label turns out to be empty and `None` is returned below.
        self.domain_name_str = tail.get(1..)?;

        Some((label, whole)).filter(|(label, _)| !label.is_empty())
    }
}

impl_encode!(DomainName, domain_name);