dns_message_parser/encode/domain_name.rs

1use crate::domain_name::DOMAIN_NAME_MAX_RECURSION;
2use crate::encode::Encoder;
3use crate::label::Label;
4use crate::{DomainName, EncodeError, EncodeResult};
5use std::collections::HashMap;
6
/// Largest message offset a compression pointer can carry (the low 14 bits of the word).
const MAX_OFFSET: u16 = 0x3FFF;
/// The two high bits OR-ed onto an offset to mark the 16-bit word as a compression pointer.
const COMPRESSION_BITS: u16 = !MAX_OFFSET;
9
10impl Encoder {
11    #[inline]
12    fn compress(&mut self, domain_name: &DomainName) -> EncodeResult<Option<usize>> {
13        if let Some((index, recursion)) = self.domain_name_index.get(domain_name) {
14            let index = *index;
15            if MAX_OFFSET < index {
16                return Err(EncodeError::Compression(index));
17            }
18
19            let recursion = *recursion;
20            if recursion > DOMAIN_NAME_MAX_RECURSION {
21                return Ok(None);
22            }
23
24            let index = COMPRESSION_BITS | index;
25            self.u16(index);
26
27            Ok(Some(recursion))
28        } else {
29            Ok(None)
30        }
31    }
32
33    #[inline]
34    fn label(&mut self, label: &Label) -> EncodeResult<u16> {
35        let index = self.get_offset()?;
36        self.string_with_len(label.as_ref())?;
37        Ok(index)
38    }
39
40    #[inline]
41    fn merge_domain_name_index(
42        &mut self,
43        domain_name_index: HashMap<DomainName, u16>,
44        recursion: usize,
45    ) -> EncodeResult<()> {
46        if recursion > DOMAIN_NAME_MAX_RECURSION {
47            return Err(EncodeError::MaxRecursion(recursion));
48        }
49
50        for (domain_name_str, index) in domain_name_index {
51            self.domain_name_index
52                .insert(domain_name_str, (index, recursion));
53        }
54        Ok(())
55    }
56
57    pub(super) fn domain_name(&mut self, domain_name: &DomainName) -> EncodeResult<()> {
58        let mut domain_name_index = HashMap::new();
59        for (label, domain_name) in domain_name.iter() {
60            if let Some(recursion) = self.compress(&domain_name)? {
61                self.merge_domain_name_index(domain_name_index, recursion + 1)?;
62                return Ok(());
63            }
64
65            let index = self.label(&label)?;
66            if index <= MAX_OFFSET {
67                domain_name_index.insert(domain_name, index);
68            }
69        }
70        self.string_with_len("")?;
71        self.merge_domain_name_index(domain_name_index, 0)?;
72        Ok(())
73    }
74}
75
76impl DomainName {
77    fn iter(&self) -> DomainNameIter<'_> {
78        DomainNameIter { labels: &self.0 }
79    }
80}
81
82struct DomainNameIter<'a> {
83    labels: &'a [Label],
84}
85
86impl<'a> Iterator for DomainNameIter<'a> {
87    type Item = (Label, DomainName);
88
89    fn next(&mut self) -> Option<Self::Item> {
90        if self.labels.is_empty() {
91            return None;
92        }
93
94        let label = self.labels[0].clone();
95        let domain_name = DomainName(self.labels.to_vec());
96        self.labels = &self.labels[1..];
97        Some((label, domain_name))
98    }
99}
100impl_encode!(DomainName, domain_name);