1use super::Error;
13
14use core::convert::TryInto;
15use core::fmt;
16use core::hash;
17use core::iter;
18use core::mem;
19use core::num::NonZeroUsize;
20
21use memchr::Memchr;
22
23pub(crate) trait Serialize<'a> {
25 fn serialized_len(&self) -> usize;
27
28 fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error>;
30
31 fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error>;
33}
34
/// A read position within a borrowed byte buffer.
///
/// The full buffer is kept (not just the unread tail) so that absolute
/// offsets, such as compression pointers, can be resolved against it.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Cursor<'a> {
    /// The complete buffer being read.
    bytes: &'a [u8],

    /// Absolute offset of the next byte to read.
    cursor: usize,
}
46
47impl<'a> Cursor<'a> {
48 pub(crate) fn new(bytes: &'a [u8]) -> Self {
50 Self { bytes, cursor: 0 }
51 }
52
53 pub(crate) fn original(&self) -> &'a [u8] {
55 self.bytes
56 }
57
58 pub(crate) fn remaining(&self) -> &'a [u8] {
60 &self.bytes[self.cursor..]
61 }
62
63 pub(crate) fn len(&self) -> usize {
65 self.bytes.len() - self.cursor
66 }
67
68 pub(crate) fn at(&self, pos: usize) -> Self {
70 Self {
71 bytes: self.bytes,
72 cursor: pos,
73 }
74 }
75
76 pub(crate) fn advance(mut self, n: usize) -> Result<Self, Error> {
78 if n == 0 {
79 return Ok(self);
80 }
81
82 if self.cursor + n > self.bytes.len() {
83 return Err(Error::NotEnoughReadBytes {
84 tried_to_read: NonZeroUsize::new(self.cursor.saturating_add(n)).unwrap(),
85 available: self.bytes.len(),
86 });
87 }
88
89 self.cursor += n;
90 Ok(self)
91 }
92
93 fn read_error(&self, n: usize) -> Error {
95 Error::NotEnoughReadBytes {
96 tried_to_read: NonZeroUsize::new(self.cursor.saturating_add(n)).unwrap(),
97 available: self.bytes.len(),
98 }
99 }
100}
101
102impl<'a> Serialize<'a> for () {
103 fn serialized_len(&self) -> usize {
104 0
105 }
106
107 fn serialize(&self, _bytes: &mut [u8]) -> Result<usize, Error> {
108 Ok(0)
109 }
110
111 fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
112 Ok(bytes)
113 }
114}
115
/// A name made up of label segments, stored either as wire-format bytes or as
/// a dotted string.
#[derive(Copy, Clone)]
pub struct Label<'a> {
    repr: Repr<'a>,
}

/// Backing storage for a `Label`.
#[derive(Copy, Clone)]
enum Repr<'a> {
    /// Wire-format bytes: `original[start..end]` holds the encoded label.
    Bytes {
        /// The whole buffer the label was read from. It is kept so that
        /// compression pointers can be resolved against it.
        original: &'a [u8],

        /// Offset of the label's first byte within `original`.
        start: usize,

        /// Offset one past the label's last byte within `original`.
        end: usize,
    },

    /// A dotted name such as `"example.com"`.
    String {
        /// The dotted text, borrowed as-is.
        string: &'a str,
    },
}

impl Default for Label<'_> {
    /// A byte-backed label consisting of a single zero terminator byte.
    fn default() -> Self {
        const ROOT: &[u8] = &[0];
        Label {
            repr: Repr::Bytes {
                original: ROOT,
                start: 0,
                end: ROOT.len(),
            },
        }
    }
}
158
159impl<'a, 'b> PartialEq<Label<'a>> for Label<'b> {
160 fn eq(&self, other: &Label<'a>) -> bool {
161 self.segments().eq(other.segments())
162 }
163}
164
165impl Eq for Label<'_> {}
166
167impl<'a, 'b> PartialOrd<Label<'a>> for Label<'b> {
168 fn partial_cmp(&self, other: &Label<'a>) -> Option<core::cmp::Ordering> {
169 self.segments().partial_cmp(other.segments())
170 }
171}
172
173impl Ord for Label<'_> {
174 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
175 self.segments().cmp(other.segments())
176 }
177}
178
179impl hash::Hash for Label<'_> {
180 fn hash<H: hash::Hasher>(&self, state: &mut H) {
181 for segment in self.segments() {
182 segment.hash(state);
183 }
184 }
185}
186
187impl<'a> Label<'a> {
188 pub fn segments(&self) -> impl Iterator<Item = LabelSegment<'a>> {
190 match self.repr {
191 Repr::Bytes {
192 original, start, ..
193 } => Either::A(parse_bytes(original, start)),
194 Repr::String { string } => Either::B(parse_string(string)),
195 }
196 }
197
198 pub fn names(&self) -> impl Iterator<Item = Result<&'a str, &'a [u8]>> {
200 match self.repr {
201 Repr::String { string } => {
202 Either::A(parse_string(string).filter_map(|seg| seg.as_str().map(Ok)))
204 }
205 Repr::Bytes {
206 original, start, ..
207 } => {
208 let mut cursor = Cursor {
210 bytes: original,
211 cursor: start,
212 };
213
214 Either::B(iter::from_fn(move || {
215 loop {
216 let mut ls: LabelSegment<'_> = LabelSegment::Empty;
217 cursor = ls.deserialize(cursor).ok()?;
218
219 match ls {
221 LabelSegment::Empty => return None,
222 LabelSegment::Pointer(pos) => {
223 cursor = cursor.at(pos.into());
225 }
226 LabelSegment::String(label) => return Some(Ok(label)),
227 }
228 }
229 }))
230 }
231 }
232 }
233}
234
235impl<'a> From<&'a str> for Label<'a> {
236 fn from(string: &'a str) -> Self {
237 Self {
238 repr: Repr::String { string },
239 }
240 }
241}
242
243impl fmt::Debug for Label<'_> {
244 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245 struct LabelFmt<'a>(&'a Label<'a>);
246
247 impl fmt::Debug for LabelFmt<'_> {
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 fmt::Display::fmt(self.0, f)
250 }
251 }
252
253 f.debug_tuple("Label").field(&LabelFmt(self)).finish()
254 }
255}
256
257impl fmt::Display for Label<'_> {
258 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259 self.names().enumerate().try_for_each(|(i, name)| {
260 if i > 0 {
261 f.write_str(".")?;
262 }
263
264 match name {
265 Ok(name) => f.write_str(name),
266 Err(_) => f.write_str("???"),
267 }
268 })
269 }
270}
271
272impl<'a> Serialize<'a> for Label<'a> {
273 fn serialized_len(&self) -> usize {
274 if let Repr::Bytes { start, end, .. } = self.repr {
275 return end - start;
276 }
277
278 self.segments()
279 .map(|item| item.serialized_len())
280 .fold(0, |a, b| a.saturating_add(b))
281 }
282
283 fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
284 if let Repr::Bytes {
286 original,
287 start,
288 end,
289 } = self.repr
290 {
291 bytes[..end - start].copy_from_slice(&original[start..end]);
292 return Ok(end - start);
293 }
294
295 self.segments().try_fold(0, |mut offset, item| {
296 let len = item.serialize(&mut bytes[offset..])?;
297 offset += len;
298 Ok(offset)
299 })
300 }
301
302 fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
303 let original = cursor.original();
304 let start = cursor.cursor;
305
306 let mut end = start;
308 loop {
309 let len_char = match original.get(end) {
310 Some(0) => {
311 end += 1;
312 break;
313 }
314 Some(ptr) if ptr & PTR_MASK != 0 => {
315 end += 2;
317 break;
318 }
319 Some(len_char) => *len_char,
320 None => {
321 return Err(Error::NotEnoughReadBytes {
322 tried_to_read: NonZeroUsize::new(cursor.cursor + end).unwrap(),
323 available: original.len(),
324 })
325 }
326 };
327
328 let len = len_char as usize;
329 end += len + 1;
330 }
331
332 self.repr = Repr::Bytes {
333 original,
334 start,
335 end,
336 };
337 cursor.advance(end - start)
338 }
339}
340
341fn parse_bytes(bytes: &[u8], position: usize) -> impl Iterator<Item = LabelSegment<'_>> + '_ {
343 let mut cursor = Cursor {
344 bytes,
345 cursor: position,
346 };
347 let mut keep_going = true;
348
349 iter::from_fn(move || {
350 if !keep_going {
351 return None;
352 }
353
354 let mut segment = LabelSegment::Empty;
355 cursor = segment.deserialize(cursor).ok()?;
356
357 match segment {
358 LabelSegment::String(_) => {}
359 _ => {
360 keep_going = false;
361 }
362 }
363
364 Some(segment)
365 })
366}
367
368fn parse_string(str: &str) -> impl Iterator<Item = LabelSegment<'_>> + '_ {
370 let dot = Memchr::new(b'.', str.as_bytes());
371 let mut last_index = 0;
372
373 let dot = dot.chain(Some(str.len()));
375
376 dot.filter_map(move |index| {
377 let item = &str[last_index..index];
378 last_index = index.saturating_add(1);
379
380 if item.is_empty() {
381 None
382 } else {
383 Some(LabelSegment::String(item))
384 }
385 })
386 .chain(Some(LabelSegment::Empty))
387}
388
/// One piece of an encoded label.
///
/// Variant order matters: the derived ordering sorts `Empty` before any
/// `String`, and any `String` before any `Pointer`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum LabelSegment<'a> {
    /// The zero-length segment that terminates a label.
    Empty,

    /// A text segment, without its length prefix.
    String(&'a str),

    /// A compression pointer: an offset into the buffer the label was read from.
    Pointer(u16),
}

/// Bits of a segment's first byte that, when both set, mark a compression pointer.
const PTR_MASK: u8 = 0xC0;

/// Longest string segment: the length must fit in the bits below `PTR_MASK` (63).
const MAX_STR_LEN: usize = (!PTR_MASK) as usize;

impl<'a> LabelSegment<'a> {
    /// Returns the text of a `String` segment, or `None` for any other variant.
    fn as_str(&self) -> Option<&'a str> {
        if let Self::String(text) = *self {
            Some(text)
        } else {
            None
        }
    }
}

impl Default for LabelSegment<'_> {
    /// Defaults to the terminating `Empty` segment.
    fn default() -> Self {
        LabelSegment::Empty
    }
}
419
420impl<'a> Serialize<'a> for LabelSegment<'a> {
421 fn serialized_len(&self) -> usize {
422 match self {
423 Self::Empty => 1,
424 Self::Pointer(_) => 2,
425 Self::String(s) => 1 + s.len(),
426 }
427 }
428
429 fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
430 match self {
431 Self::Empty => {
432 bytes[0] = 0;
434 Ok(1)
435 }
436 Self::Pointer(ptr) => {
437 let [mut b1, b2] = ptr.to_be_bytes();
439 b1 |= PTR_MASK;
440 bytes[0] = b1;
441 bytes[1] = b2;
442 Ok(2)
443 }
444 Self::String(s) => {
445 let len = s.len();
447
448 if len > MAX_STR_LEN {
449 return Err(Error::NameTooLong(len));
450 }
451
452 if len > bytes.len() {
453 panic!("not enough bytes to serialize string");
454 }
455
456 bytes[0] = len as u8;
457 bytes[1..=len].copy_from_slice(s.as_bytes());
458 Ok(len + 1)
459 }
460 }
461 }
462
463 fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
464 let b1 = *cursor
466 .remaining()
467 .first()
468 .ok_or_else(|| cursor.read_error(1))?;
469
470 if b1 == 0 {
471 *self = Self::Empty;
473 cursor.advance(1)
474 } else if b1 & PTR_MASK == PTR_MASK {
475 let [b1, b2]: [u8; 2] = cursor.remaining()[..2]
477 .try_into()
478 .map_err(|_| cursor.read_error(2))?;
479 let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]);
480 *self = Self::Pointer(ptr);
481 cursor.advance(2)
482 } else {
483 let len = b1 as usize;
485
486 if len > MAX_STR_LEN {
487 return Err(Error::NameTooLong(len));
488 }
489
490 let bytes = cursor.remaining()[1..=len]
492 .try_into()
493 .map_err(|_| cursor.read_error(len + 1))?;
494
495 let s = core::str::from_utf8(bytes)?;
497 *self = Self::String(s);
498 cursor.advance(len + 1)
499 }
500 }
501}
502
503macro_rules! serialize_num {
504 ($($num_ty: ident),*) => {
505 $(
506 impl<'a> Serialize<'a> for $num_ty {
507 fn serialized_len(&self) -> usize {
508 mem::size_of::<$num_ty>()
509 }
510
511 fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
512 if bytes.len() < mem::size_of::<$num_ty>() {
513 panic!("Not enough space to serialize a {}", stringify!($num_ty));
514 }
515
516 let value = (*self).to_be_bytes();
517 bytes[..mem::size_of::<$num_ty>()].copy_from_slice(&value);
518
519 Ok(mem::size_of::<$num_ty>())
520 }
521
522 fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
523 if bytes.len() < mem::size_of::<$num_ty>() {
524 return Err(bytes.read_error(mem::size_of::<$num_ty>()));
525 }
526
527 let mut value = [0; mem::size_of::<$num_ty>()];
528 value.copy_from_slice(&bytes.remaining()[..mem::size_of::<$num_ty>()]);
529 *self = $num_ty::from_be_bytes(value);
530
531 bytes.advance(mem::size_of::<$num_ty>())
532 }
533 }
534 )*
535 }
536}
537
538serialize_num! {
539 u8, u16, u32, u64,
540 i8, i16, i32, i64
541}
542
/// Holds one of two iterator types, so a function can return either of them
/// behind a single `impl Iterator`.
enum Either<L, R> {
    A(L),
    B(R),
}

impl<L, R> Iterator for Either<L, R>
where
    L: Iterator,
    R: Iterator<Item = L::Item>,
{
    type Item = L::Item;

    fn next(&mut self) -> Option<L::Item> {
        match self {
            Self::A(left) => left.next(),
            Self::B(right) => right.next(),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::A(left) => left.size_hint(),
            Self::B(right) => right.size_hint(),
        }
    }

    // Forwarded so the wrapped iterator's own (possibly faster) `fold` is used.
    fn fold<Acc, G>(self, init: Acc, g: G) -> Acc
    where
        Self: Sized,
        G: FnMut(Acc, Self::Item) -> Acc,
    {
        match self {
            Self::A(left) => left.fold(init, g),
            Self::B(right) => right.fold(init, g),
        }
    }
}
576}