1use crate::WireType;
2use crate::varint::Varint;
3
4use std::fmt;
5use std::hash::Hash;
6use std::collections::HashMap;
7
8use bytes::{Bytes, BytesRead, BytesReadRef};
9
10
/// Errors produced while decoding protobuf wire data.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum DecodeError {
    /// The input ended before the value being read was complete.
    UnexpectedEof,
    /// Bytes were still left when the input was expected to be exhausted.
    ExpectedEof,
    /// Reading a varint failed.
    InvalidVarint,
    /// A wire type value that is not valid; carries the offending value.
    InvalidWireType(u8),
    /// Two wire types that were required to match did not.
    WireTypeMismatch,
    /// The value was expected to be a [`FieldKind::Varint`].
    ExpectedVarintWireType,
    /// The value was expected to be a [`FieldKind::I32`].
    ExpectedI32WireType,
    /// The value was expected to be a [`FieldKind::I64`].
    ExpectedI64WireType,
    /// The value was expected to be a [`FieldKind::Len`].
    ExpectedLenWireType,
    /// A payload decoded as a string was not valid UTF-8.
    ExpectedUtf8,
    /// A fixed-size array got a payload of the wrong length; carries the
    /// length that was required.
    ExpectedArrayLen(usize),
    /// Any other failure, described by a free-form message.
    Other(String),
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Variants with a payload format it directly; the rest share a fixed
        // message that is written once at the end.
        let message = match self {
            Self::InvalidWireType(t) => {
                return write!(f, "the wiretype {t} is invalid");
            }
            Self::ExpectedArrayLen(n) => {
                return write!(f, "expected an array length of {n}");
            }
            Self::Other(s) => return write!(f, "decode error: {s}"),
            Self::UnexpectedEof => "unexpected end of file",
            Self::ExpectedEof => "expected end of file",
            Self::InvalidVarint => "varint is invalid",
            Self::WireTypeMismatch => "wire types don't match",
            Self::ExpectedVarintWireType => "expected a varint wire type",
            Self::ExpectedI32WireType => "expected a i32 wire type",
            Self::ExpectedI64WireType => "expected a i64 wire type",
            Self::ExpectedLenWireType => "expected the len wire type",
            Self::ExpectedUtf8 => "expected a valid utf8 string",
        };

        f.write_str(message)
    }
}

impl std::error::Error for DecodeError {}
56
57#[derive(Debug)]
58pub struct MessageDecoder<'a> {
59 inner: Bytes<'a>
60}
61
62impl<'a> MessageDecoder<'a> {
63 pub fn new(bytes: &'a [u8]) -> Self {
64 Self {
65 inner: Bytes::from(bytes)
66 }
67 }
68
69 pub fn try_from_kind(kind: FieldKind<'a>) -> Result<Self, DecodeError> {
70 kind.try_unwrap_len().map(Self::new)
71 }
72
73 pub(crate) fn next_varint(&mut self) -> Result<u64, DecodeError> {
74 Varint::read(&mut self.inner)
75 .map(|v| v.0)
76 .map_err(|_| DecodeError::InvalidVarint)
77 }
78
79 fn next_kind(
80 &mut self,
81 ty: WireType
82 ) -> Result<FieldKind<'a>, DecodeError> {
83 let kind = match ty {
84 WireType::Varint => FieldKind::Varint(self.next_varint()?),
85 WireType::I64 => FieldKind::I64(
86 self.inner.try_read_le_u64()
87 .map_err(|_| DecodeError::UnexpectedEof)?
88 ),
89 WireType::I32 => FieldKind::I32(
90 self.inner.try_read_le_u32()
91 .map_err(|_| DecodeError::UnexpectedEof)?
92 ),
93 WireType::Len => {
94 let len = self.next_varint()?;
95 let bytes = self.inner.try_read_ref(len as usize)
96 .map_err(|_| DecodeError::UnexpectedEof)?;
97
98 FieldKind::Len(bytes)
99 }
100 };
101
102 Ok(kind)
103 }
104
105 pub(crate) fn maybe_next_kind(
107 &mut self,
108 ty: WireType
109 ) -> Result<Option<FieldKind<'a>>, DecodeError> {
110 if self.inner.remaining().is_empty() {
111 return Ok(None)
112 }
113
114 self.next_kind(ty).map(Some)
115 }
116
117 pub fn next(&mut self) -> Result<Option<Field<'a>>, DecodeError> {
120 if self.inner.remaining().is_empty() {
121 return Ok(None)
122 }
123
124 let tag = self.next_varint()?;
125 let wtype = WireType::from_tag(tag)?;
126 let number = tag >> 3;
127
128 let kind = self.next_kind(wtype)?;
129
130 Ok(Some(Field { number, kind }))
131 }
132
133 pub fn finish(self) -> Result<(), DecodeError> {
134 if self.inner.remaining().is_empty() {
135 Ok(())
136 } else {
137 Err(DecodeError::ExpectedEof)
138 }
139 }
140}
141
/// One decoded field: its number plus the raw value that followed the tag.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Field<'a> {
    /// Field number taken from the tag.
    pub number: u64,
    /// Raw value read for the field.
    pub kind: FieldKind<'a>,
}

/// A raw field value, distinguished by the wire type it was read with.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FieldKind<'a> {
    /// A varint value.
    Varint(u64),
    /// A 32-bit fixed-width value (read little-endian by `MessageDecoder`).
    I32(u32),
    /// A 64-bit fixed-width value (read little-endian by `MessageDecoder`).
    I64(u64),
    /// A length-delimited payload borrowed from the input buffer.
    Len(&'a [u8]),
}
163
164impl<'a> FieldKind<'a> {
165 pub fn is_len(&self) -> bool {
166 matches!(self, Self::Len(_))
167 }
168
169 pub fn wire_type(&self) -> WireType {
170 match self {
171 Self::Varint(_) => WireType::Varint,
172 Self::I32(_) => WireType::I32,
173 Self::I64(_) => WireType::I64,
174 Self::Len(_) => WireType::Len
175 }
176 }
177
178 pub fn try_unwrap_varint(&self) -> Result<u64, DecodeError> {
179 match self {
180 Self::Varint(n) => Ok(*n),
181 _ => Err(DecodeError::ExpectedVarintWireType)
182 }
183 }
184
185 pub fn try_unwrap_i32(&self) -> Result<u32, DecodeError> {
186 match self {
187 Self::I32(n) => Ok(*n),
188 _ => Err(DecodeError::ExpectedI32WireType)
189 }
190 }
191
192 pub fn try_unwrap_i64(&self) -> Result<u64, DecodeError> {
193 match self {
194 Self::I64(n) => Ok(*n),
195 _ => Err(DecodeError::ExpectedI64WireType)
196 }
197 }
198
199 pub fn try_unwrap_len(&self) -> Result<&'a [u8], DecodeError> {
201 match self {
202 Self::Len(b) => Ok(b),
203 _ => Err(DecodeError::ExpectedLenWireType)
204 }
205 }
206}
207
208
209
210pub trait DecodeMessage<'m> {
211 const WIRE_TYPE: WireType;
215
216 fn parse_from_bytes(b: &'m [u8]) -> Result<Self, DecodeError>
217 where Self: Sized {
218 let mut this = Self::decode_default();
219
220 this.merge(FieldKind::Len(b), false)?;
221
222 Ok(this)
223 }
224
225 fn decode_default() -> Self;
226
227 fn merge(
231 &mut self,
232 kind: FieldKind<'m>,
233 is_field: bool
234 ) -> Result<(), DecodeError>;
235}
236
237pub trait DecodeMessageOwned: for<'m> DecodeMessage<'m> {}
238
239impl<T> DecodeMessageOwned for T
240where T: for<'m> DecodeMessage<'m> {}
241
242impl<'m, V> DecodeMessage<'m> for Vec<V>
244where V: DecodeMessage<'m> {
245 const WIRE_TYPE: WireType = WireType::Len;
246
247 fn decode_default() -> Self {
248 Self::new()
249 }
250
251 fn merge(
252 &mut self,
253 kind: FieldKind<'m>,
254 is_field: bool
255 ) -> Result<(), DecodeError> {
256 if !is_field {
260 let mut parser = MessageDecoder::try_from_kind(kind)?;
261
262 while let Some(field) = parser.next()? {
263 if field.number != 1 {
264 continue
265 }
266
267 self.merge(field.kind, true)?;
269 }
270
271 return parser.finish();
272 }
273
274 if kind.is_len() && V::WIRE_TYPE.can_be_packed() {
276 let mut parser = MessageDecoder::try_from_kind(kind)?;
277 while let Some(k) = parser.maybe_next_kind(V::WIRE_TYPE)? {
278 let mut v = V::decode_default();
279 v.merge(k, false)?;
280
281 self.push(v);
282 }
283
284 return parser.finish()
285 }
286
287
288 let mut v = V::decode_default();
289 v.merge(kind, false)?;
290
291 self.push(v);
292
293 Ok(())
294 }
295}
296
297impl<'m, K, V> DecodeMessage<'m> for HashMap<K, V>
298where
299 K: DecodeMessage<'m> + Eq + Hash,
300 V: DecodeMessage<'m>
301{
302 const WIRE_TYPE: WireType = WireType::Len;
303
304 fn decode_default() -> Self {
305 Self::new()
306 }
307
308 fn merge(
309 &mut self,
310 kind: FieldKind<'m>,
311 is_field: bool
312 ) -> Result<(), DecodeError> {
313 if !is_field {
317 let mut parser = MessageDecoder::try_from_kind(kind)?;
318
319 while let Some(field) = parser.next()? {
320 if field.number != 1 {
321 continue
322 }
323
324 self.merge(field.kind, true)?;
326 }
327
328 return parser.finish();
329 }
330
331 let mut field = <(K, V)>::decode_default();
332 field.merge(kind, false)?;
333
334 self.insert(field.0, field.1);
335
336 Ok(())
337 }
338}
339
340impl<'m> DecodeMessage<'m> for Vec<u8> {
341 const WIRE_TYPE: WireType = WireType::Len;
342
343 fn decode_default() -> Self {
344 Self::new()
345 }
346
347 fn merge(
348 &mut self,
349 kind: FieldKind<'m>,
350 _is_field: bool
351 ) -> Result<(), DecodeError> {
352 let bytes = kind.try_unwrap_len()?;
353 self.clear();
354 self.extend_from_slice(bytes);
355
356 Ok(())
357 }
358}
359
360impl<'m, const S: usize> DecodeMessage<'m> for [u8; S] {
361 const WIRE_TYPE: WireType = WireType::Len;
362
363 fn decode_default() -> Self {
364 [0; S]
365 }
366
367 fn merge(
368 &mut self,
369 kind: FieldKind<'m>,
370 _is_field: bool
371 ) -> Result<(), DecodeError> {
372 let bytes = kind.try_unwrap_len()?;
373
374 if bytes.len() != S {
375 return Err(DecodeError::ExpectedArrayLen(S))
376 }
377
378 self.copy_from_slice(bytes);
379
380 Ok(())
381 }
382}
383
// Implements `DecodeMessage` for tuples. A tuple is always read from a
// length-delimited payload that is parsed as a message.
//
// Each `$gen, $idx` pair names a type parameter and its tuple index. The same
// `$idx` token is used both as the field number matched in the payload and as
// the tuple element the field is merged into. Field numbers that match no
// element are skipped.
//
// NOTE(review): this makes the first element field number 0 and the second
// field number 1. Standard protobuf does not allow field number 0, and its map
// entries use key = 1 / value = 2. The `HashMap` impl decodes entries through
// `(K, V)`, so it presumably cannot read maps produced by other protobuf
// encoders. Confirm against this crate's encoder before changing the numbering.
macro_rules! impl_tuple {
    ($($gen:ident, $idx:tt),*) => (
        impl<'m, $($gen),*> DecodeMessage<'m> for ($($gen),*)
        where
            $($gen: DecodeMessage<'m>),*
        {
            const WIRE_TYPE: WireType = WireType::Len;

            fn decode_default() -> Self {
                ($(
                    $gen::decode_default()
                ),*)
            }

            fn merge(
                &mut self,
                kind: FieldKind<'m>,
                _is_field: bool
            ) -> Result<(), DecodeError> {
                // The payload is parsed as a message regardless of
                // `_is_field`.
                let mut parser = MessageDecoder::try_from_kind(kind)?;

                while let Some(field) = parser.next()? {
                    match field.number {
                        $(
                            $idx => self.$idx.merge(field.kind, true)?
                        ),*,
                        // Fields without a matching tuple element are ignored.
                        _ => {}
                    }
                }

                parser.finish()
            }
        }
    )
}

// Implementations for tuples of 2 to 6 elements.
impl_tuple![
    A, 0,
    B, 1
];
impl_tuple![
    A, 0,
    B, 1,
    C, 2
];
impl_tuple![
    A, 0,
    B, 1,
    C, 2,
    D, 3
];
impl_tuple![
    A, 0,
    B, 1,
    C, 2,
    D, 3,
    E, 4
];
impl_tuple![
    A, 0,
    B, 1,
    C, 2,
    D, 3,
    E, 4,
    F, 5
];
454];
455
456impl<'m> DecodeMessage<'m> for String {
457 const WIRE_TYPE: WireType = WireType::Len;
458
459 fn decode_default() -> Self {
460 Self::new()
461 }
462
463 fn merge(
464 &mut self,
465 kind: FieldKind<'m>,
466 _is_field: bool
467 ) -> Result<(), DecodeError> {
468 let bytes = kind.try_unwrap_len()?;
469 self.clear();
470 let s = std::str::from_utf8(bytes)
471 .map_err(|_| DecodeError::ExpectedUtf8)?;
472 self.push_str(s);
473
474 Ok(())
475 }
476}
477
478impl<'m, V> DecodeMessage<'m> for Option<V>
479where V: DecodeMessage<'m> {
480 const WIRE_TYPE: WireType = WireType::Len;
481
482 fn decode_default() -> Self {
483 None
484 }
485
486 fn merge(
487 &mut self,
488 kind: FieldKind<'m>,
489 is_field: bool
490 ) -> Result<(), DecodeError> {
491 if !is_field {
495 let mut parser = MessageDecoder::try_from_kind(kind)?;
496
497 while let Some(field) = parser.next()? {
498 if field.number != 1 {
499 continue
500 }
501
502 self.merge(field.kind, true)?;
504 }
505
506 return parser.finish();
507 }
508
509 match self {
510 Some(v) => {
511 v.merge(kind, false)?;
512 }
513 None => {
514 let mut v = V::decode_default();
515 v.merge(kind, false)?;
516 *self = Some(v);
517 }
518 }
519
520 Ok(())
521 }
522}
523
524impl<'m> DecodeMessage<'m> for bool {
525 const WIRE_TYPE: WireType = WireType::Varint;
526
527 fn decode_default() -> Self {
528 false
529 }
530
531 fn merge(
532 &mut self,
533 kind: FieldKind<'m>,
534 _is_field: bool
535 ) -> Result<(), DecodeError> {
536 let num = kind.try_unwrap_varint()?;
537 *self = num != 0;
538
539 Ok(())
540 }
541}
542
543macro_rules! impl_varint {
545 ($($ty:ty),*) => ($(
546 impl<'m> DecodeMessage<'m> for $ty {
547 const WIRE_TYPE: WireType = WireType::Varint;
548
549 fn decode_default() -> Self {
550 Default::default()
551 }
552
553 fn merge(
554 &mut self,
555 kind: FieldKind<'m>,
556 _is_field: bool
557 ) -> Result<(), DecodeError> {
558 let num = kind.try_unwrap_varint()?;
559 *self = num as $ty;
560
561 Ok(())
562 }
563 }
564 )*)
565}
566
567impl_varint![i32, i64, u32, u64];
568
569macro_rules! impl_floats {
570 ($($src:ident, $wtype:ident as $ty:ty),*) => ($(
571 impl<'m> DecodeMessage<'m> for $ty {
572 const WIRE_TYPE: WireType = WireType::$wtype;
573
574 fn decode_default() -> Self {
575 Default::default()
576 }
577
578 fn merge(
579 &mut self,
580 kind: FieldKind<'m>,
581 _is_field: bool
582 ) -> Result<(), DecodeError> {
583 let num = kind.$src()?;
584 *self = num as $ty;
585
586 Ok(())
587 }
588 }
589 )*)
590}
591
592impl_floats![
593 try_unwrap_i32, I32 as f32,
594 try_unwrap_i64, I64 as f64
595];
596
/// Wrapper that selects ZigZag varint decoding for the wrapped signed
/// integer (the protobuf `sint32`/`sint64` encoding).
#[repr(transparent)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ZigZag<T>(
    /// The decoded value.
    pub T,
);
600
601macro_rules! impl_zigzag {
602 ($($ty:ty),*) => ($(
603 impl<'m> DecodeMessage<'m> for ZigZag<$ty> {
604 const WIRE_TYPE: WireType = WireType::Varint;
605
606 fn decode_default() -> Self {
607 Default::default()
608 }
609
610 fn merge(
611 &mut self,
612 kind: FieldKind<'m>,
613 _is_field: bool
614 ) -> Result<(), DecodeError> {
615 let num = kind.try_unwrap_varint()? as $ty;
616 let num = (num >> 1) ^ -(num & 1);
617 *self = ZigZag(num);
618
619 Ok(())
620 }
621 }
622 )*)
623}
624
625impl_zigzag![i32, i64];
626
/// Wrapper that selects fixed-width decoding (`I32`/`I64` wire types)
/// instead of varint decoding for the wrapped integer.
#[repr(transparent)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct Fixed<T>(
    /// The decoded value.
    pub T,
);
630
631macro_rules! impl_fixed {
632 ($($src:ident, $wtype:ident as $ty:ty),*) => ($(
633 impl<'m> DecodeMessage<'m> for Fixed<$ty> {
634 const WIRE_TYPE: WireType = WireType::$wtype;
635
636 fn decode_default() -> Self {
637 Default::default()
638 }
639
640 fn merge(
641 &mut self,
642 kind: FieldKind<'m>,
643 _is_field: bool
644 ) -> Result<(), DecodeError> {
645 let num = kind.$src()?;
646 *self = Fixed(num as $ty);
647
648 Ok(())
649 }
650 }
651 )*)
652}
653
654impl_fixed![
655 try_unwrap_i32, I32 as u32, try_unwrap_i32, I32 as i32,
656 try_unwrap_i64, I64 as u64, try_unwrap_i64, I64 as i64
657];
658
659
#[cfg(test)]
mod tests {
    use super::*;

    use hex_literal::hex;

    /// Reads every remaining field from `decoder`, panicking on any error.
    fn collect_fields<'a>(decoder: &mut MessageDecoder<'a>) -> Vec<Field<'a>> {
        let mut fields = Vec::new();
        while let Some(field) = decoder.next().unwrap() {
            fields.push(field);
        }
        fields
    }

    #[test]
    fn string_and_repeated_test_4() {
        const MSG: &[u8] = &hex!("220568656c6c6f280128022803");

        let mut decoder = MessageDecoder::new(MSG);
        let expected = [
            Field { number: 4, kind: FieldKind::Len(b"hello") },
            Field { number: 5, kind: FieldKind::Varint(1) },
            Field { number: 5, kind: FieldKind::Varint(2) },
            Field { number: 5, kind: FieldKind::Varint(3) },
        ];

        assert_eq!(collect_fields(&mut decoder), expected);
    }

    #[test]
    fn repeated_packet() {
        const MSG: &[u8] = &hex!("3206038e029ea705");

        let mut outer = MessageDecoder::new(MSG);
        let field = outer.next().unwrap().unwrap();
        assert_eq!(field.number, 6);

        let payload = field
            .kind
            .try_unwrap_len()
            .expect("field 6 should be length-delimited");

        let mut packed = MessageDecoder::new(payload);
        for expected in [3, 270, 86942] {
            assert_eq!(packed.next_varint().unwrap(), expected);
        }
    }

    #[test]
    fn empty_bytes() {
        const MSG: &[u8] = &[10, 0];

        let mut decoder = MessageDecoder::new(MSG);
        assert_eq!(
            collect_fields(&mut decoder),
            [Field { number: 1, kind: FieldKind::Len(&[]) }]
        );
    }
}