flex_dns/
name.rs

1use core::fmt::{Debug, Display, Formatter};
2use core::hash::Hash;
3use crate::{Buffer, DnsError, DnsMessage, DnsMessageError, MutBuffer};
4use crate::parse::ParseBytes;
5use crate::write::WriteBytes;
6
/// Upper bound on the number of labels plus compression-pointer hops that a
/// [`NameIterator`] processes for one name before reporting `NameTooLong`.
/// It also sizes the bookkeeping array of [`CycleSafeNameIterator`].
const MAX_DOMAIN_NAME_DEPTH: usize = 128;
/// Maximum length of a single label in bytes (RFC 1035 §2.3.4). This is the
/// value of the six low bits of a label header byte; the top two bits are
/// flag bits (see `LabelType::from_bytes`).
const MAX_DOMAIN_NAME_LABEL_LENGTH: usize = 0b0011_1111;
9
/// A DNS name in wire format, borrowed from an underlying byte buffer.
///
/// The name is not copied: it is a view into `bytes` that starts at
/// `offset`, so compression pointers inside the name can be resolved against
/// the same buffer.
#[derive(Clone, Copy)]
pub struct DnsName<'a> {
    /// The buffer the name lives in (a whole message or a standalone name).
    bytes: &'a [u8],
    /// Index in `bytes` of the first label header of this name.
    offset: usize,
}
16
/// Create a new [`DnsName`] from a dot-separated domain name given as a byte
/// string. The name is converted to DNS wire format and validated at compile
/// time. The following are compile errors:
/// - a label longer than 63 bytes;
/// - an empty label (e.g. `b"a..b"` or `b".com"`);
/// - a name whose wire form exceeds 255 bytes.
///
/// A single trailing dot (`b"example.com."`) is accepted, and `b""` / `b"."`
/// produce the root name.
///
/// The argument must be an expression usable in a `const` context. If you
/// want to create a [`DnsName`] from bytes only known at runtime, use the
/// [`DnsName::new`] function on wire-format bytes instead.
///
/// # Example
/// ```
/// use flex_dns::dns_name;
/// use flex_dns::name::DnsName;
///
/// const NAME: DnsName = dns_name!(b"example.com");
/// ```
#[macro_export]
macro_rules! dns_name {
    ($value:expr $(,)?) => {
        {
            const NAME: [u8; $value.len() + 2] = {
                let mut result = [0; $value.len() + 2];
                let mut label_start = 0;
                let mut index = 0;
                loop {
                    if index == $value.len() {
                        // End of input: write the length of the last label.
                        // An empty last label (trailing dot or empty input)
                        // leaves a 0 here, which is the root terminator.
                        if index - label_start > 63 {
                            panic!("Label too long, maximum length is 63.");
                        }

                        result[label_start] = (index - label_start) as u8;

                        // Wire length: a trailing dot already supplied the
                        // terminator; otherwise one more zero byte follows.
                        let wire_len = if index > 0 && index == label_start {
                            index + 1
                        } else {
                            index + 2
                        };
                        if wire_len > 255 {
                            panic!("Name too long, maximum length is 255.");
                        }

                        break;
                    }

                    let byte = $value[index];
                    if byte == b'.' {
                        // `..` or a leading dot would write a 0 length byte
                        // mid-name and silently truncate it. A lone "." is
                        // the root name and is allowed.
                        if index == label_start && $value.len() != 1 {
                            panic!("Empty label.");
                        }

                        // Wire labels are limited to 63 bytes: the top two
                        // bits of the length byte are reserved flag bits.
                        if index - label_start > 63 {
                            panic!("Label too long, maximum length is 63.");
                        }

                        result[label_start] = (index - label_start) as u8;
                        label_start = index + 1;
                    } else {
                        result[index + 1] = byte;
                    }

                    index += 1;
                }

                result
            };
            // SAFETY: `NAME` was built above as length-prefixed, non-empty
            // labels of at most 63 bytes followed by a zero terminator, which
            // is the wire format `new_unchecked` expects.
            unsafe { $crate::name::DnsName::new_unchecked(&NAME) }
        }
    };
}
71
72impl<'a> DnsName<'a> {
73    /// Create a new [`DnsName`] from a byte slice. The bytes must be in DNS
74    /// wire format. The constructor will check if the name is valid.
75    #[inline(always)]
76    pub fn new(bytes: &'a [u8]) -> Result<Self, DnsMessageError> {
77        for part in (NameIterator {
78            bytes,
79            offset: 0,
80            depth: 0,
81        }) {
82            part?;
83        }
84
85        Ok(Self { bytes, offset: 0 })
86    }
87
88    /// Create a new [`DnsName`] from a byte slice. The bytes must be in DNS
89    /// wire format. The constructor will not check if the name is valid, hence
90    /// the `unsafe`. Using this function is unsafe cause it can lead to an
91    /// invalid DNS message.
92    #[inline(always)]
93    pub const unsafe fn new_unchecked(bytes: &'a [u8]) -> Self {
94        Self { bytes, offset: 0 }
95    }
96
97    /// Return an iterator over the parts of the name.
98    #[inline(always)]
99    pub fn iter(&self) -> NameIterator<'a> {
100        NameIterator {
101            bytes: self.bytes,
102            offset: self.offset,
103            depth: 0,
104        }
105    }
106
107    pub(crate) fn split_first(&self) -> Result<(&'a [u8], Option<Self>), DnsMessageError> {
108        let mut iter = self.iter();
109        let first = iter.next().unwrap()?;
110        if let Some(next) = iter.next() {
111            let next = next?;
112
113            // Calculate offset from address of the second pointer.
114            let offset = next.as_ptr() as usize - self.bytes.as_ptr() as usize - 1;
115
116            Ok((first, Some(Self {
117                bytes: self.bytes,
118                offset,
119            })))
120        } else {
121            Ok((first, None))
122        }
123    }
124}
125
126impl<'a> ParseBytes<'a> for DnsName<'a> {
127    #[inline]
128    fn parse_bytes(bytes: &'a [u8], i: &mut usize) -> Result<Self, DnsMessageError> {
129        const MAX_LENGTH: usize = 255;
130        let mut j = *i;
131
132        loop {
133            if j - *i >= MAX_LENGTH {
134                return Err(DnsMessageError::DnsError(DnsError::NameTooLong));
135            }
136
137            match LabelType::from_bytes(bytes, &mut j)? {
138                LabelType::Pointer(_) => {
139                    break;
140                }
141                LabelType::Part(len) => {
142                    j += len as usize;
143
144                    if len == 0 {
145                        break;
146                    }
147
148                    if len > MAX_DOMAIN_NAME_LABEL_LENGTH as u8 {
149                        return Err(DnsMessageError::DnsError(DnsError::LabelTooLong));
150                    }
151                }
152            }
153        }
154
155        let offset = *i;
156        *i = j;
157
158        Ok(Self { bytes, offset })
159    }
160}
161
162impl<'a> WriteBytes for DnsName<'a> {
163    #[inline(always)]
164    fn write<
165        const PTR_STORAGE: usize,
166        const DNS_SECTION: usize,
167        B: MutBuffer + Buffer,
168    >(&self, message: &mut DnsMessage<PTR_STORAGE, DNS_SECTION, B>) -> Result<usize, DnsMessageError> {
169        message.write_name(*self)
170    }
171}
172
/// An iterator over the labels of a [`DnsName`], following compression
/// pointers.
///
/// This iterator does not detect pointer cycles by itself. A cyclic name is
/// followed until `MAX_DOMAIN_NAME_DEPTH` (128) labels and pointer hops have
/// been processed, at which point a `NameTooLong` error is produced. Use
/// [`NameIterator::cycle_safe`] to detect cycles explicitly.
pub struct NameIterator<'a> {
    /// Buffer the name lives in; pointer targets are indices into it.
    bytes: &'a [u8],
    /// Index of the next label header to decode.
    offset: usize,
    /// Number of labels and pointer hops processed so far.
    depth: usize,
}
181
182impl<'a> NameIterator<'a> {
183    /// Return a cycle safe version of this iterator. If there is a cycle in the
184    /// name, the iterator will return an error. The cycle safe detection uses
185    /// O(n^2) comparisons, where n is the number of parts in the name.
186    pub fn cycle_safe(self) -> CycleSafeNameIterator<'a> {
187        CycleSafeNameIterator {
188            iter: self,
189            depth: [0; MAX_DOMAIN_NAME_DEPTH],
190        }
191    }
192}
193
194impl<'a> Iterator for NameIterator<'a> {
195    type Item = Result<&'a [u8], DnsMessageError>;
196
197    fn next(&mut self) -> Option<Self::Item> {
198        let mut i = self.offset;
199        loop {
200            self.depth += 1;
201            if self.depth > MAX_DOMAIN_NAME_DEPTH {
202                return Some(Err(DnsMessageError::DnsError(DnsError::NameTooLong)));
203            }
204
205            match LabelType::from_bytes(self.bytes, &mut i).unwrap() {
206                LabelType::Pointer(ptr) => {
207                    if ptr < self.offset as u16 {
208                        // The pointer points to an earlier part of the message.
209                        i = ptr as usize;
210
211                        continue;
212                    } else {
213                        // The pointer points into the future.
214                        return Some(Err(DnsMessageError::DnsError(DnsError::PointerIntoTheFuture)));
215                    }
216                }
217                LabelType::Part(len) => {
218                    if len == 0 {
219                        // We've reached the end of the name.
220                        return None;
221                    }
222
223                    if len > MAX_DOMAIN_NAME_LABEL_LENGTH as u8 {
224                        return Some(Err(DnsMessageError::DnsError(DnsError::LabelTooLong)));
225                    }
226
227                    if self.bytes.len() < i + len as usize {
228                        // The name is longer than the buffer.
229                        return Some(Err(DnsMessageError::DnsError(DnsError::MessageTooShort)));
230                    }
231
232                    let part = &self.bytes[i..i + len as usize];
233                    self.offset = i + len as usize;
234
235                    return Some(Ok(part))
236                }
237            }
238        }
239    }
240}
241
242/// A cycle safe version of [`NameIterator`]. If there is a cycle in the name,
243/// the iterator will return an error.
244pub struct CycleSafeNameIterator<'a> {
245    iter: NameIterator<'a>,
246    depth: [usize; MAX_DOMAIN_NAME_DEPTH],
247}
248
249impl<'a> Iterator for CycleSafeNameIterator<'a> {
250    type Item = Result<&'a [u8], DnsMessageError>;
251
252    fn next(&mut self) -> Option<Self::Item> {
253        let next = self.iter.next();
254
255        if let Some(Ok(part)) = next {
256            let part = part.as_ptr() as usize;
257
258            for &known_part in &self.depth[..self.iter.depth - 1] {
259                if known_part == part {
260                    return Some(Err(DnsMessageError::DnsError(DnsError::PointerCycle)));
261                }
262            }
263
264            self.depth[self.iter.depth - 1] = part;
265        }
266
267        next
268    }
269}
270
271impl PartialEq<DnsName<'_>> for DnsName<'_> {
272    fn eq(&self, other: &DnsName<'_>) -> bool {
273        for (a, b) in self.iter().zip(other.iter()) {
274            match (a, b) {
275                (Ok(a), Ok(b)) => {
276                    if a != b {
277                        return false;
278                    }
279                }
280                _ => {
281                    return false;
282                }
283            }
284        }
285
286        true
287    }
288}
289
290impl Hash for DnsName<'_> {
291    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
292        for part in self.iter() {
293            if let Err(_) = part {
294                // If the name is invalid, we cannot hash it.
295                return;
296            }
297
298            let part = part.unwrap();
299            state.write_u8(part.len() as u8);
300            state.write(part);
301        }
302    }
303}
304
305impl Display for DnsName<'_> {
306    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
307        let mut first = true;
308        for part in self.iter() {
309            if first {
310                first = false;
311            } else {
312                f.write_str(".")?;
313            }
314
315            let part = part.map_err(|_| core::fmt::Error)?;
316            f.write_str(core::str::from_utf8(part).unwrap())?;
317        }
318
319        Ok(())
320    }
321}
322
323impl Debug for DnsName<'_> {
324    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
325        f.write_str("DnsName(")?;
326        Display::fmt(self, f)?;
327        f.write_str(")")?;
328
329        Ok(())
330    }
331}
332
/// The decoded header of a single label in wire format.
#[derive(PartialEq)]
enum LabelType {
    /// A compression pointer: an index into the byte buffer being parsed,
    /// where the rest of the name continues.
    Pointer(u16),
    /// An ordinary label; the value is the label length decoded from the
    /// header byte. A length of 0 terminates the name.
    Part(u8),
}
338
339impl LabelType {
340    fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self, DnsMessageError> {
341        const PTR_MASK: u8 = 0b11000000;
342        const LEN_MASK: u8 = !PTR_MASK;
343
344        let c = u8::parse_bytes(bytes, i)?;
345
346        if c & PTR_MASK == PTR_MASK {
347            let c = c & LEN_MASK;
348            let pointer = u16::from_be_bytes([c, u8::parse_bytes(bytes, i)?]);
349            if pointer >= *i as u16 {
350                // Cannot point to the future.
351                return Err(DnsMessageError::DnsError(DnsError::PointerIntoTheFuture));
352            }
353
354            Ok(Self::Pointer(pointer))
355        } else {
356            let len = c & LEN_MASK;
357
358            Ok(Self::Part(len))
359        }
360    }
361}