1use core::fmt::{Debug, Display, Formatter};
2use core::hash::Hash;
3use crate::{Buffer, DnsError, DnsMessage, DnsMessageError, MutBuffer};
4use crate::parse::ParseBytes;
5use crate::write::WriteBytes;
6
/// Upper bound on the number of label headers (labels and compression-pointer hops)
/// that `NameIterator` reads while decoding a single name; also the capacity of the
/// address table kept by `CycleSafeNameIterator`.
const MAX_DOMAIN_NAME_DEPTH: usize = 1 << 7;
/// Longest allowed single label: the largest value of the 6-bit length field
/// of a label header (RFC 1035 §2.3.4 limits labels to 63 octets).
const MAX_DOMAIN_NAME_LABEL_LENGTH: usize = (1 << 6) - 1;
9
/// A domain name in DNS wire format, borrowed from an underlying byte buffer.
///
/// The name starts at `offset` inside `bytes`. The buffer is kept whole because
/// compression pointers are resolved as indices into it (presumably it is the whole
/// DNS message when the name was parsed from one).
#[derive(Clone, Copy)]
pub struct DnsName<'a> {
    /// Index in `bytes` of this name's first label header (label or pointer).
    offset: usize,
    /// Buffer holding the encoded name; compression pointers index into it.
    bytes: &'a [u8],
}
16
/// Builds a `DnsName` from a dotted byte-string constant, encoding it into DNS wire
/// format at compile time.
///
/// `$value` must be usable in a constant context and index to `u8`, e.g. a byte-string
/// literal such as `b"example.com"`.
///
/// Each dot-separated label is written as a length byte followed by its bytes, and the
/// name is terminated by a zero byte. The encoding is therefore `$value.len() + 2` bytes
/// long. A trailing dot (`b"example.com."`) is accepted and simply leaves one extra
/// zero byte after the terminator.
///
/// # Panics
///
/// Fails constant evaluation, and therefore compilation, if any label is longer than
/// 63 bytes. That is the DNS label limit; larger length bytes would collide with the
/// pointer/reserved bits of the label header.
///
/// # Examples
///
/// ```ignore
/// let name = flex_dns::dns_name!(b"example.com");
/// ```
#[macro_export]
macro_rules! dns_name {
    ($value:expr $(,)?) => {{
        const NAME: [u8; $value.len() + 2] = {
            // Every '.' becomes the length byte of the label that follows it, and an
            // extra leading byte holds the first label's length. The final byte stays 0
            // and terminates the name.
            let mut result = [0; $value.len() + 2];
            let mut label_start = 0;
            let mut index = 0;
            loop {
                let at_end = index == $value.len();
                if at_end || $value[index] == b'.' {
                    if index - label_start > 63 {
                        panic!("Label too long, maximum length is 63.");
                    }

                    result[label_start] = (index - label_start) as u8;

                    if at_end {
                        break;
                    }
                    label_start = index + 1;
                } else {
                    result[index + 1] = $value[index];
                }

                index += 1;
            }

            result
        };
        // SAFETY: `NAME` was built above as length-prefixed labels of at most 63 bytes
        // followed by a zero byte, with no compression pointers.
        // `$crate` keeps this path valid inside the defining crate and under renames.
        unsafe { $crate::name::DnsName::new_unchecked(&NAME) }
    }};
}
71
72impl<'a> DnsName<'a> {
73 #[inline(always)]
76 pub fn new(bytes: &'a [u8]) -> Result<Self, DnsMessageError> {
77 for part in (NameIterator {
78 bytes,
79 offset: 0,
80 depth: 0,
81 }) {
82 part?;
83 }
84
85 Ok(Self { bytes, offset: 0 })
86 }
87
88 #[inline(always)]
93 pub const unsafe fn new_unchecked(bytes: &'a [u8]) -> Self {
94 Self { bytes, offset: 0 }
95 }
96
97 #[inline(always)]
99 pub fn iter(&self) -> NameIterator<'a> {
100 NameIterator {
101 bytes: self.bytes,
102 offset: self.offset,
103 depth: 0,
104 }
105 }
106
107 pub(crate) fn split_first(&self) -> Result<(&'a [u8], Option<Self>), DnsMessageError> {
108 let mut iter = self.iter();
109 let first = iter.next().unwrap()?;
110 if let Some(next) = iter.next() {
111 let next = next?;
112
113 let offset = next.as_ptr() as usize - self.bytes.as_ptr() as usize - 1;
115
116 Ok((first, Some(Self {
117 bytes: self.bytes,
118 offset,
119 })))
120 } else {
121 Ok((first, None))
122 }
123 }
124}
125
126impl<'a> ParseBytes<'a> for DnsName<'a> {
127 #[inline]
128 fn parse_bytes(bytes: &'a [u8], i: &mut usize) -> Result<Self, DnsMessageError> {
129 const MAX_LENGTH: usize = 255;
130 let mut j = *i;
131
132 loop {
133 if j - *i >= MAX_LENGTH {
134 return Err(DnsMessageError::DnsError(DnsError::NameTooLong));
135 }
136
137 match LabelType::from_bytes(bytes, &mut j)? {
138 LabelType::Pointer(_) => {
139 break;
140 }
141 LabelType::Part(len) => {
142 j += len as usize;
143
144 if len == 0 {
145 break;
146 }
147
148 if len > MAX_DOMAIN_NAME_LABEL_LENGTH as u8 {
149 return Err(DnsMessageError::DnsError(DnsError::LabelTooLong));
150 }
151 }
152 }
153 }
154
155 let offset = *i;
156 *i = j;
157
158 Ok(Self { bytes, offset })
159 }
160}
161
162impl<'a> WriteBytes for DnsName<'a> {
163 #[inline(always)]
164 fn write<
165 const PTR_STORAGE: usize,
166 const DNS_SECTION: usize,
167 B: MutBuffer + Buffer,
168 >(&self, message: &mut DnsMessage<PTR_STORAGE, DNS_SECTION, B>) -> Result<usize, DnsMessageError> {
169 message.write_name(*self)
170 }
171}
172
/// Iterator over the labels of an encoded name that follows compression pointers.
///
/// Each label's bytes (without the length prefix) are yielded as `Ok`. Iteration ends
/// with `None` at the zero-length root label. Malformed encodings yield `Err`.
pub struct NameIterator<'a> {
    /// Number of label headers (labels and pointer hops) read so far; checked
    /// against `MAX_DOMAIN_NAME_DEPTH`.
    depth: usize,
    /// Index in `bytes` where the next label header will be read.
    offset: usize,
    /// Buffer the name lives in; compression pointers index into it.
    bytes: &'a [u8],
}
181
182impl<'a> NameIterator<'a> {
183 pub fn cycle_safe(self) -> CycleSafeNameIterator<'a> {
187 CycleSafeNameIterator {
188 iter: self,
189 depth: [0; MAX_DOMAIN_NAME_DEPTH],
190 }
191 }
192}
193
194impl<'a> Iterator for NameIterator<'a> {
195 type Item = Result<&'a [u8], DnsMessageError>;
196
197 fn next(&mut self) -> Option<Self::Item> {
198 let mut i = self.offset;
199 loop {
200 self.depth += 1;
201 if self.depth > MAX_DOMAIN_NAME_DEPTH {
202 return Some(Err(DnsMessageError::DnsError(DnsError::NameTooLong)));
203 }
204
205 match LabelType::from_bytes(self.bytes, &mut i).unwrap() {
206 LabelType::Pointer(ptr) => {
207 if ptr < self.offset as u16 {
208 i = ptr as usize;
210
211 continue;
212 } else {
213 return Some(Err(DnsMessageError::DnsError(DnsError::PointerIntoTheFuture)));
215 }
216 }
217 LabelType::Part(len) => {
218 if len == 0 {
219 return None;
221 }
222
223 if len > MAX_DOMAIN_NAME_LABEL_LENGTH as u8 {
224 return Some(Err(DnsMessageError::DnsError(DnsError::LabelTooLong)));
225 }
226
227 if self.bytes.len() < i + len as usize {
228 return Some(Err(DnsMessageError::DnsError(DnsError::MessageTooShort)));
230 }
231
232 let part = &self.bytes[i..i + len as usize];
233 self.offset = i + len as usize;
234
235 return Some(Ok(part))
236 }
237 }
238 }
239 }
240}
241
242pub struct CycleSafeNameIterator<'a> {
245 iter: NameIterator<'a>,
246 depth: [usize; MAX_DOMAIN_NAME_DEPTH],
247}
248
249impl<'a> Iterator for CycleSafeNameIterator<'a> {
250 type Item = Result<&'a [u8], DnsMessageError>;
251
252 fn next(&mut self) -> Option<Self::Item> {
253 let next = self.iter.next();
254
255 if let Some(Ok(part)) = next {
256 let part = part.as_ptr() as usize;
257
258 for &known_part in &self.depth[..self.iter.depth - 1] {
259 if known_part == part {
260 return Some(Err(DnsMessageError::DnsError(DnsError::PointerCycle)));
261 }
262 }
263
264 self.depth[self.iter.depth - 1] = part;
265 }
266
267 next
268 }
269}
270
271impl PartialEq<DnsName<'_>> for DnsName<'_> {
272 fn eq(&self, other: &DnsName<'_>) -> bool {
273 for (a, b) in self.iter().zip(other.iter()) {
274 match (a, b) {
275 (Ok(a), Ok(b)) => {
276 if a != b {
277 return false;
278 }
279 }
280 _ => {
281 return false;
282 }
283 }
284 }
285
286 true
287 }
288}
289
290impl Hash for DnsName<'_> {
291 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
292 for part in self.iter() {
293 if let Err(_) = part {
294 return;
296 }
297
298 let part = part.unwrap();
299 state.write_u8(part.len() as u8);
300 state.write(part);
301 }
302 }
303}
304
305impl Display for DnsName<'_> {
306 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
307 let mut first = true;
308 for part in self.iter() {
309 if first {
310 first = false;
311 } else {
312 f.write_str(".")?;
313 }
314
315 let part = part.map_err(|_| core::fmt::Error)?;
316 f.write_str(core::str::from_utf8(part).unwrap())?;
317 }
318
319 Ok(())
320 }
321}
322
323impl Debug for DnsName<'_> {
324 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
325 f.write_str("DnsName(")?;
326 Display::fmt(self, f)?;
327 f.write_str(")")?;
328
329 Ok(())
330 }
331}
332
333#[derive(PartialEq)]
334enum LabelType {
335 Pointer(u16),
336 Part(u8),
337}
338
339impl LabelType {
340 fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self, DnsMessageError> {
341 const PTR_MASK: u8 = 0b11000000;
342 const LEN_MASK: u8 = !PTR_MASK;
343
344 let c = u8::parse_bytes(bytes, i)?;
345
346 if c & PTR_MASK == PTR_MASK {
347 let c = c & LEN_MASK;
348 let pointer = u16::from_be_bytes([c, u8::parse_bytes(bytes, i)?]);
349 if pointer >= *i as u16 {
350 return Err(DnsMessageError::DnsError(DnsError::PointerIntoTheFuture));
352 }
353
354 Ok(Self::Pointer(pointer))
355 } else {
356 let len = c & LEN_MASK;
357
358 Ok(Self::Part(len))
359 }
360 }
361}