1use super::Error;
13
14use core::convert::TryInto;
15use core::fmt;
16use core::hash;
17use core::iter;
18use core::mem;
19use core::num::NonZeroUsize;
20
21use memchr::Memchr;
22
23pub trait Serialize<'a> {
25 fn serialized_len(&self) -> usize;
27
28 fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error>;
30}
31
32pub trait Deserialize<'a> {
34 fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error>;
36}
37
/// A read position inside a borrowed byte buffer.
///
/// The whole buffer is kept (not just the unread tail) so that absolute offsets, such as
/// the targets of compression pointers, can be resolved against it later.
#[derive(Debug, Copy, Clone)]
pub struct Cursor<'a> {
    /// The complete buffer being read.
    bytes: &'a [u8],

    /// Absolute offset of the next unread byte within `bytes`.
    cursor: usize,
}
49
50impl<'a> Cursor<'a> {
51 pub fn new(bytes: &'a [u8]) -> Self {
53 Self { bytes, cursor: 0 }
54 }
55
56 pub fn original(&self) -> &'a [u8] {
58 self.bytes
59 }
60
61 pub fn remaining(&self) -> &'a [u8] {
63 &self.bytes[self.cursor..]
64 }
65
66 #[allow(clippy::len_without_is_empty)]
68 pub fn len(&self) -> usize {
69 self.bytes.len() - self.cursor
70 }
71
72 pub fn at(&self, pos: usize) -> Self {
74 Self {
75 bytes: self.bytes,
76 cursor: pos,
77 }
78 }
79
80 pub fn advance(mut self, n: usize) -> Result<Self, Error> {
82 if n == 0 {
83 return Ok(self);
84 }
85
86 if self.cursor + n > self.bytes.len() {
87 return Err(Error::NotEnoughReadBytes {
88 tried_to_read: NonZeroUsize::new(self.cursor.saturating_add(n)).unwrap(),
89 available: self.bytes.len(),
90 });
91 }
92
93 self.cursor += n;
94 Ok(self)
95 }
96
97 fn read_error(&self, n: usize) -> Error {
99 Error::NotEnoughReadBytes {
100 tried_to_read: NonZeroUsize::new(self.cursor.saturating_add(n)).unwrap(),
101 available: self.bytes.len(),
102 }
103 }
104}
105
106impl Serialize<'_> for () {
107 fn serialized_len(&self) -> usize {
108 0
109 }
110
111 fn serialize(&self, _bytes: &mut [u8]) -> Result<usize, Error> {
112 Ok(0)
113 }
114}
115
116impl<'a> Deserialize<'a> for () {
117 fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
118 Ok(bytes)
119 }
120}
121
122#[derive(Clone, Copy)]
124pub struct Label<'a> {
125 repr: Repr<'a>,
126}
127
/// Backing storage for a [`Label`].
#[derive(Clone, Copy)]
enum Repr<'a> {
    /// A label read from a buffer.
    Bytes {
        /// The entire buffer the label was read from; kept whole so compression pointers
        /// (absolute offsets) can be followed.
        original: &'a [u8],

        /// Offset of the label's first byte within `original`.
        start: usize,

        /// Offset one past the label's last byte, i.e. just after its zero terminator or its
        /// trailing two-byte pointer.
        end: usize,
    },

    /// A label supplied as dotted text.
    String {
        /// The dotted name, e.g. `"example.com"`.
        string: &'a str,
    },
}
151
152impl Default for Label<'_> {
153 fn default() -> Self {
154 Self {
156 repr: Repr::Bytes {
157 original: &[0],
158 start: 0,
159 end: 1,
160 },
161 }
162 }
163}
164
165impl<'a> PartialEq<Label<'a>> for Label<'_> {
166 fn eq(&self, other: &Label<'a>) -> bool {
167 self.segments().eq(other.segments())
168 }
169}
170
171impl Eq for Label<'_> {}
172
173impl<'a> PartialOrd<Label<'a>> for Label<'_> {
174 fn partial_cmp(&self, other: &Label<'a>) -> Option<core::cmp::Ordering> {
175 self.segments().partial_cmp(other.segments())
176 }
177}
178
179impl Ord for Label<'_> {
180 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
181 self.segments().cmp(other.segments())
182 }
183}
184
185impl hash::Hash for Label<'_> {
186 fn hash<H: hash::Hasher>(&self, state: &mut H) {
187 for segment in self.segments() {
188 segment.hash(state);
189 }
190 }
191}
192
193impl<'a> Label<'a> {
194 pub fn segments(&self) -> impl Iterator<Item = LabelSegment<'a>> {
196 match self.repr {
197 Repr::Bytes {
198 original, start, ..
199 } => Either::A(parse_bytes(original, start)),
200 Repr::String { string } => Either::B(parse_string(string)),
201 }
202 }
203
204 pub fn names(&self) -> impl Iterator<Item = Result<&'a str, &'a [u8]>> {
206 match self.repr {
207 Repr::String { string } => {
208 Either::A(parse_string(string).filter_map(|seg| seg.as_str().map(Ok)))
210 }
211 Repr::Bytes {
212 original, start, ..
213 } => {
214 let mut cursor = Cursor {
216 bytes: original,
217 cursor: start,
218 };
219
220 Either::B(iter::from_fn(move || {
221 loop {
222 let mut ls: LabelSegment<'_> = LabelSegment::Empty;
223 cursor = ls.deserialize(cursor).ok()?;
224
225 match ls {
227 LabelSegment::Empty => return None,
228 LabelSegment::Pointer(pos) => {
229 cursor = cursor.at(pos.into());
231 }
232 LabelSegment::String(label) => return Some(Ok(label)),
233 }
234 }
235 }))
236 }
237 }
238 }
239}
240
241impl<'a> From<&'a str> for Label<'a> {
242 fn from(string: &'a str) -> Self {
243 Self {
244 repr: Repr::String { string },
245 }
246 }
247}
248
249impl fmt::Debug for Label<'_> {
250 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251 struct LabelFmt<'a>(&'a Label<'a>);
252
253 impl fmt::Debug for LabelFmt<'_> {
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 fmt::Display::fmt(self.0, f)
256 }
257 }
258
259 f.debug_tuple("Label").field(&LabelFmt(self)).finish()
260 }
261}
262
263impl fmt::Display for Label<'_> {
264 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265 self.names().enumerate().try_for_each(|(i, name)| {
266 if i > 0 {
267 f.write_str(".")?;
268 }
269
270 match name {
271 Ok(name) => f.write_str(name),
272 Err(_) => f.write_str("???"),
273 }
274 })
275 }
276}
277
278impl<'a> Serialize<'a> for Label<'a> {
279 fn serialized_len(&self) -> usize {
280 if let Repr::Bytes { start, end, .. } = self.repr {
281 return end - start;
282 }
283
284 self.segments()
285 .map(|item| item.serialized_len())
286 .fold(0, |a, b| a.saturating_add(b))
287 }
288
289 fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
290 if let Repr::Bytes {
292 original,
293 start,
294 end,
295 } = self.repr
296 {
297 bytes[..end - start].copy_from_slice(&original[start..end]);
298 return Ok(end - start);
299 }
300
301 self.segments().try_fold(0, |mut offset, item| {
302 let len = item.serialize(&mut bytes[offset..])?;
303 offset += len;
304 Ok(offset)
305 })
306 }
307}
308
309impl<'a> Deserialize<'a> for Label<'a> {
310 fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
311 let original = cursor.original();
312 let start = cursor.cursor;
313
314 let mut end = start;
316 loop {
317 let len_char = match original.get(end) {
318 Some(0) => {
319 end += 1;
320 break;
321 }
322 Some(ptr) if ptr & PTR_MASK != 0 => {
323 end += 2;
325 break;
326 }
327 Some(len_char) => *len_char,
328 None => {
329 return Err(Error::NotEnoughReadBytes {
330 tried_to_read: NonZeroUsize::new(cursor.cursor + end).unwrap(),
331 available: original.len(),
332 })
333 }
334 };
335
336 let len = len_char as usize;
337 end += len + 1;
338 }
339
340 self.repr = Repr::Bytes {
341 original,
342 start,
343 end,
344 };
345 cursor.advance(end - start)
346 }
347}
348
349fn parse_bytes(bytes: &[u8], position: usize) -> impl Iterator<Item = LabelSegment<'_>> + '_ {
351 let mut cursor = Cursor {
352 bytes,
353 cursor: position,
354 };
355 let mut keep_going = true;
356
357 iter::from_fn(move || {
358 if !keep_going {
359 return None;
360 }
361
362 let mut segment = LabelSegment::Empty;
363 cursor = segment.deserialize(cursor).ok()?;
364
365 match segment {
366 LabelSegment::String(_) => {}
367 _ => {
368 keep_going = false;
369 }
370 }
371
372 Some(segment)
373 })
374}
375
376fn parse_string(str: &str) -> impl Iterator<Item = LabelSegment<'_>> + '_ {
378 let dot = Memchr::new(b'.', str.as_bytes());
379 let mut last_index = 0;
380
381 let dot = dot.chain(Some(str.len()));
383
384 dot.filter_map(move |index| {
385 let item = &str[last_index..index];
386 last_index = index.saturating_add(1);
387
388 if item.is_empty() {
389 None
390 } else {
391 Some(LabelSegment::String(item))
392 }
393 })
394 .chain(Some(LabelSegment::Empty))
395}
396
/// One element of a wire-format name.
///
/// Variant order is significant: the derived `Ord` places `Empty < String < Pointer`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LabelSegment<'a> {
    /// The zero-length terminator that ends a name (a single `0` byte).
    Empty,

    /// A single component of text, encoded as a length byte followed by its bytes.
    String(&'a str),

    /// A compression pointer: an absolute offset into the enclosing buffer, encoded as two
    /// big-endian bytes with the top two bits set.
    Pointer(u16),
}
409
/// Top two bits of a length byte; when both are set the byte begins a compression pointer.
const PTR_MASK: u8 = 0b1100_0000;
/// Longest allowed string segment: every bit of the length byte not reserved by
/// `PTR_MASK` (63).
const MAX_STR_LEN: usize = (!PTR_MASK) as usize;
412
413impl<'a> LabelSegment<'a> {
414 fn as_str(&self) -> Option<&'a str> {
415 match self {
416 Self::String(s) => Some(s),
417 _ => None,
418 }
419 }
420}
421
422impl Default for LabelSegment<'_> {
423 fn default() -> Self {
424 Self::Empty
425 }
426}
427
428impl<'a> Serialize<'a> for LabelSegment<'a> {
429 fn serialized_len(&self) -> usize {
430 match self {
431 Self::Empty => 1,
432 Self::Pointer(_) => 2,
433 Self::String(s) => 1 + s.len(),
434 }
435 }
436
437 fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
438 match self {
439 Self::Empty => {
440 bytes[0] = 0;
442 Ok(1)
443 }
444 Self::Pointer(ptr) => {
445 let [mut b1, b2] = ptr.to_be_bytes();
447 b1 |= PTR_MASK;
448 bytes[0] = b1;
449 bytes[1] = b2;
450 Ok(2)
451 }
452 Self::String(s) => {
453 let len = s.len();
455
456 if len > MAX_STR_LEN {
457 return Err(Error::NameTooLong(len));
458 }
459
460 if len > bytes.len() {
461 panic!("not enough bytes to serialize string");
462 }
463
464 bytes[0] = len as u8;
465 bytes[1..=len].copy_from_slice(s.as_bytes());
466 Ok(len + 1)
467 }
468 }
469 }
470}
471
472impl<'a> Deserialize<'a> for LabelSegment<'a> {
473 fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
474 let b1 = *cursor
476 .remaining()
477 .first()
478 .ok_or_else(|| cursor.read_error(1))?;
479
480 if b1 == 0 {
481 *self = Self::Empty;
483 cursor.advance(1)
484 } else if b1 & PTR_MASK == PTR_MASK {
485 let [b1, b2]: [u8; 2] = cursor.remaining()[..2]
487 .try_into()
488 .map_err(|_| cursor.read_error(2))?;
489 let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]);
490 *self = Self::Pointer(ptr);
491 cursor.advance(2)
492 } else {
493 let len = b1 as usize;
495
496 if len > MAX_STR_LEN {
497 return Err(Error::NameTooLong(len));
498 }
499
500 let bytes = cursor.remaining()[1..=len]
502 .try_into()
503 .map_err(|_| cursor.read_error(len + 1))?;
504
505 let s = simdutf8::compat::from_utf8(bytes)?;
507 *self = Self::String(s);
508 cursor.advance(len + 1)
509 }
510 }
511}
512
/// Implements [`Serialize`] and [`Deserialize`] for primitive integer types using their
/// big-endian (network byte order) encoding.
macro_rules! serialize_num {
    ($($num_ty: ident),*) => {
        $(
            impl<'a> Serialize<'a> for $num_ty {
                fn serialized_len(&self) -> usize {
                    mem::size_of::<$num_ty>()
                }

                fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
                    const WIDTH: usize = mem::size_of::<$num_ty>();

                    match bytes.get_mut(..WIDTH) {
                        Some(dest) => dest.copy_from_slice(&self.to_be_bytes()),
                        None => panic!("Not enough space to serialize a {}", stringify!($num_ty)),
                    }

                    Ok(WIDTH)
                }
            }

            impl<'a> Deserialize<'a> for $num_ty {
                fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
                    const WIDTH: usize = mem::size_of::<$num_ty>();

                    let src = bytes
                        .remaining()
                        .get(..WIDTH)
                        .ok_or_else(|| bytes.read_error(WIDTH))?;

                    let mut raw = [0; WIDTH];
                    raw.copy_from_slice(src);
                    *self = $num_ty::from_be_bytes(raw);

                    bytes.advance(WIDTH)
                }
            }
        )*
    };
}
549
550serialize_num! {
551 u8, u16, u32, u64,
552 i8, i16, i32, i64
553}
554
/// One of two iterator types, used so a function returning `impl Iterator` can pick its
/// concrete iterator at runtime.
enum Either<L, R> {
    /// The first alternative.
    A(L),
    /// The second alternative.
    B(R),
}
560
561impl<A: Iterator, Other: Iterator<Item = A::Item>> Iterator for Either<A, Other> {
562 type Item = A::Item;
563
564 fn next(&mut self) -> Option<Self::Item> {
565 match self {
566 Either::A(a) => a.next(),
567 Either::B(b) => b.next(),
568 }
569 }
570
571 fn size_hint(&self) -> (usize, Option<usize>) {
572 match self {
573 Either::A(a) => a.size_hint(),
574 Either::B(b) => b.size_hint(),
575 }
576 }
577
578 fn fold<B, F>(self, init: B, f: F) -> B
579 where
580 Self: Sized,
581 F: FnMut(B, Self::Item) -> B,
582 {
583 match self {
584 Either::A(a) => a.fold(init, f),
585 Either::B(b) => b.fold(init, f),
586 }
587 }
588}