Skip to main content

mdns_proto/wire/
name.rs

//! Compression-aware, zero-copy DNS name reference.
//!
//! [`NameRef`] borrows into the full message buffer and iterates labels via
//! [`NameLabels`]. Pointer chains are bounded by
//! [`crate::constants::MAX_POINTER_HOPS`], and pointers that do not point
//! strictly backwards are rejected (RFC 1035 §4.1.4 only allows a pointer to a
//! prior occurrence of a name).
7
8use crate::{
9  constants::{MAX_LABEL_BYTES, MAX_NAME_BYTES, MAX_POINTER_HOPS},
10  error::{BufferTooShortDetail, ParseError, PointerForwardDetail},
11  name::LabelTooLongDetail,
12};
13
/// Selects the two high "kind" bits of a label-length octet.
const POINTER_TAG_MASK: u8 = 0xC0;
/// Masked kind bits `11`, marking a two-octet compression pointer.
const POINTER_TAG: u8 = 0xC0;
/// Selects the six high bits of a pointer's 14-bit target offset.
const POINTER_OFFSET_HI_MASK: u8 = 0x3F;
17
/// Borrowed, zero-copy handle to a DNS name stored inside a message buffer.
///
/// The whole message is retained (not just the name's own bytes) because
/// compression pointers in the name may refer to earlier offsets in it. The
/// derived `PartialEq`/`Eq`/`Hash` are structural: they compare the message
/// bytes and `start`, not the decoded name. Use
/// [`NameRef::equals_ignoring_case`] for DNS name equality.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct NameRef<'a> {
  /// The complete DNS message the name lives in.
  message: &'a [u8],
  /// Byte offset of the name's first octet within `message`.
  start: u16,
}
24
25impl<'a> NameRef<'a> {
26  /// Parses a name starting at `offset` in `message`. Returns the `NameRef`
27  /// and the number of bytes consumed at the *inline* position (i.e. how
28  /// far to advance the parser past this name in the original section —
29  /// not how long the fully-resolved name is).
30  pub fn try_parse(message: &'a [u8], offset: usize) -> Result<(Self, usize), ParseError> {
31    let mut cursor = offset;
32    loop {
33      let kind_byte = match message.get(cursor).copied() {
34        Some(b) => b,
35        None => {
36          return Err(ParseError::BufferTooShort(BufferTooShortDetail::new(
37            1, cursor, 0,
38          )));
39        }
40      };
41
42      if kind_byte == 0 {
43        // Root label terminator.
44        let consumed = cursor.saturating_sub(offset).saturating_add(1);
45        let start = u16::try_from(offset)?;
46        return Ok((Self { message, start }, consumed));
47      }
48
49      match kind_byte & POINTER_TAG_MASK {
50        0 => {
51          // Length-prefixed label.
52          let len = kind_byte as usize;
53          if kind_byte > MAX_LABEL_BYTES {
54            return Err(ParseError::LabelTooLong(LabelTooLongDetail::new(len)));
55          }
56          let after_len = cursor.saturating_add(1);
57          let label_end = after_len.saturating_add(len);
58          if message.len() < label_end {
59            return Err(ParseError::BufferTooShort(BufferTooShortDetail::new(
60              len,
61              after_len,
62              message.len().saturating_sub(after_len),
63            )));
64          }
65          cursor = label_end;
66        }
67        POINTER_TAG => {
68          // Two-byte pointer. Consume them and return.
69          let after_first = cursor.saturating_add(1);
70          let _second = match message.get(after_first).copied() {
71            Some(b) => b,
72            None => {
73              return Err(ParseError::BufferTooShort(BufferTooShortDetail::new(
74                1,
75                after_first,
76                0,
77              )));
78            }
79          };
80          let consumed = after_first.saturating_add(1).saturating_sub(offset);
81          let start = u16::try_from(offset)?;
82          return Ok((Self { message, start }, consumed));
83        }
84        _ => {
85          return Err(ParseError::InvalidLabelKind(kind_byte));
86        }
87      }
88    }
89  }
90
91  /// Returns an iterator over the labels of this name, following pointers
92  /// with a bounded hop count.
93  pub fn labels(&self) -> NameLabels<'a> {
94    NameLabels {
95      message: self.message,
96      cursor: self.start,
97      hops_remaining: MAX_POINTER_HOPS,
98      wire_octets: 1, // the eventual root terminator
99      done: false,
100    }
101  }
102
103  /// Returns the byte offset at which this name begins in the message.
104  #[inline(always)]
105  pub const fn start(&self) -> u16 {
106    self.start
107  }
108
109  cfg_heap! {
110  /// Writes this name's labels in DNS wire form (`len`-octet + label bytes …,
111  /// terminating root `0x00`), following any compression pointers so the
112  /// result is SELF-CONTAINED — it no longer references the source message.
113  ///
114  /// When `fold_case` is true, ASCII label bytes are lowercased — the canonical
115  /// case-insensitive form (RFC 6762 §16) used for record IDENTITY (cache
116  /// dedup / TTL=0 removal / cache-flush). When false, case is preserved so a
117  /// caller can surface the name for display.
118  ///
119  /// Enforces the RFC 1035 §3.1 limit: the *fully encoded* name (each length
120  /// octet + its label bytes + the root terminator) must be ≤
121  /// [`MAX_NAME_BYTES`]. The [`NameLabels`] iterator now bounds the EXPANDED
122  /// wire length too, so an over-long name yields
123  /// [`ParseError::NameTooLong`] from `label?` below; the redundant tally here
124  /// is kept as defense-in-depth. Also errors if the label iterator is
125  /// malformed (pointer cycle, forward pointer, or truncation).
126    pub(crate) fn write_wire(
127      &self,
128      out: &mut std::vec::Vec<u8>,
129      fold_case: bool,
130    ) -> Result<(), crate::error::ParseError> {
131      let mut encoded = 1usize; // the trailing root terminator
132      for label in self.labels() {
133        let label = label?;
134        if label.is_empty() {
135          break; // root; terminator appended below
136        }
137        let len = label.len().min(usize::from(MAX_LABEL_BYTES));
138        encoded = encoded.saturating_add(1usize.saturating_add(len));
139        if encoded > MAX_NAME_BYTES {
140          return Err(crate::error::ParseError::NameTooLong(encoded));
141        }
142        #[allow(clippy::cast_possible_truncation)]
143        out.push(len as u8);
144        if fold_case {
145          out.extend(
146            label
147              .iter()
148              .take(usize::from(MAX_LABEL_BYTES))
149              .map(u8::to_ascii_lowercase),
150          );
151        } else {
152          out.extend(label.iter().take(usize::from(MAX_LABEL_BYTES)).copied());
153        }
154      }
155      out.push(0);
156      Ok(())
157    }
158  }
159
160  /// Compares two names case-insensitively (mDNS rule, RFC 6762 §16).
161  pub fn equals_ignoring_case(&self, other: &NameRef<'_>) -> bool {
162    let mut a = self.labels();
163    let mut b = other.labels();
164    loop {
165      match (a.next(), b.next()) {
166        (None, None) => return true,
167        (Some(Ok(la)), Some(Ok(lb))) => {
168          if la.len() != lb.len() {
169            return false;
170          }
171          let mut i = 0usize;
172          while i < la.len() {
173            let ca = match la.get(i) {
174              Some(c) => *c,
175              None => return false,
176            };
177            let cb = match lb.get(i) {
178              Some(c) => *c,
179              None => return false,
180            };
181            if !ca.eq_ignore_ascii_case(&cb) {
182              return false;
183            }
184            i = i.saturating_add(1);
185          }
186        }
187        _ => return false,
188      }
189    }
190  }
191}
192
/// Iterator over the labels of a [`NameRef`], yielding
/// `Result<&[u8], ParseError>`.
///
/// Compression pointers are followed transparently and the root label is not
/// yielded. The first error ends iteration.
#[derive(Debug, Clone)]
pub struct NameLabels<'a> {
  /// The full message the name is resolved against.
  message: &'a [u8],
  /// Offset of the next length octet or pointer to decode.
  cursor: u16,
  /// Compression pointers that may still be followed before erroring.
  hops_remaining: u8,
  /// Accumulated EXPANDED wire length: the root terminator (seeded at 1) plus,
  /// per label, one length octet + the label payload. The RFC 1035 §3.1
  /// 255-octet cap applies to this expanded form. Counting full wire octets
  /// (not just label payload) therefore makes EVERY consumer of `labels()`,
  /// including service canonicalization of peer names, reject an over-long
  /// name.
  wire_octets: usize,
  /// Set once the root terminator or an error has been reached.
  done: bool,
}
206
207impl<'a> Iterator for NameLabels<'a> {
208  type Item = Result<&'a [u8], ParseError>;
209
210  #[allow(clippy::arithmetic_side_effects)]
211  fn next(&mut self) -> Option<Self::Item> {
212    if self.done {
213      return None;
214    }
215    loop {
216      let cursor_usize = self.cursor as usize;
217      let kind_byte = match self.message.get(cursor_usize).copied() {
218        Some(b) => b,
219        None => {
220          self.done = true;
221          return Some(Err(ParseError::BufferTooShort(BufferTooShortDetail::new(
222            1,
223            cursor_usize,
224            0,
225          ))));
226        }
227      };
228
229      if kind_byte == 0 {
230        self.done = true;
231        return None;
232      }
233
234      match kind_byte & POINTER_TAG_MASK {
235        0 => {
236          let len = kind_byte as usize;
237          if kind_byte > MAX_LABEL_BYTES {
238            self.done = true;
239            return Some(Err(ParseError::LabelTooLong(LabelTooLongDetail::new(len))));
240          }
241          let after_len = cursor_usize.saturating_add(1);
242          let label_end = after_len.saturating_add(len);
243          let label = match self.message.get(after_len..label_end) {
244            Some(s) => s,
245            None => {
246              self.done = true;
247              return Some(Err(ParseError::BufferTooShort(BufferTooShortDetail::new(
248                len,
249                after_len,
250                self.message.len().saturating_sub(after_len),
251              ))));
252            }
253          };
254          // Count the FULL wire length: this label's length octet + payload,
255          // on top of the root octet seeded at construction.
256          self.wire_octets = self.wire_octets.saturating_add(1usize.saturating_add(len));
257          if self.wire_octets > MAX_NAME_BYTES {
258            self.done = true;
259            return Some(Err(ParseError::NameTooLong(self.wire_octets)));
260          }
261          let next_cursor_usize = label_end;
262          self.cursor = match u16::try_from(next_cursor_usize) {
263            Ok(v) => v,
264            Err(e) => {
265              self.done = true;
266              return Some(Err(ParseError::IntegerConversion(e)));
267            }
268          };
269          return Some(Ok(label));
270        }
271        POINTER_TAG => {
272          if self.hops_remaining == 0 {
273            self.done = true;
274            return Some(Err(ParseError::PointerChainTooLong(MAX_POINTER_HOPS)));
275          }
276          self.hops_remaining = self.hops_remaining.saturating_sub(1);
277
278          let after_first = cursor_usize.saturating_add(1);
279          let second = match self.message.get(after_first).copied() {
280            Some(b) => b,
281            None => {
282              self.done = true;
283              return Some(Err(ParseError::BufferTooShort(BufferTooShortDetail::new(
284                1,
285                after_first,
286                0,
287              ))));
288            }
289          };
290          let hi = (kind_byte & POINTER_OFFSET_HI_MASK) as u16;
291          let target = (hi << 8) | (second as u16);
292          if target >= self.cursor {
293            self.done = true;
294            return Some(Err(ParseError::PointerForward(PointerForwardDetail::new(
295              target,
296              self.cursor,
297            ))));
298          }
299          self.cursor = target;
300          // Loop: resolve the new cursor as a fresh label.
301        }
302        _ => {
303          self.done = true;
304          return Some(Err(ParseError::InvalidLabelKind(kind_byte)));
305        }
306      }
307    }
308  }
309}
310
// Unit tests are compiled only for test builds with the `alloc` or `std`
// feature enabled; `unwrap` is permitted inside them.
#[cfg(all(test, any(feature = "alloc", feature = "std")))]
#[allow(clippy::unwrap_used)]
mod tests;