1use crate::varint::Varint;
2use crate::WireType;
3
4use std::collections::HashMap;
5use std::fmt;
6use std::hash::Hash;
7
8use bytes::{Bytes, BytesRead, BytesReadRef};
9
/// Errors that can occur while decoding protobuf-encoded data.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum DecodeError {
    /// The input ended before the value being read was complete.
    UnexpectedEof,
    /// Bytes were left over when the input was expected to be fully consumed.
    ExpectedEof,
    /// A varint could not be read.
    InvalidVarint,
    /// A wire type value that is not valid; holds the offending value.
    InvalidWireType(u8),
    /// Wire types that were required to agree did not.
    WireTypeMismatch,
    /// A varint value was required but a different kind was found.
    ExpectedVarintWireType,
    /// A 32-bit fixed-width value was required but a different kind was found.
    ExpectedI32WireType,
    /// A 64-bit fixed-width value was required but a different kind was found.
    ExpectedI64WireType,
    /// A length-delimited value was required but a different kind was found.
    ExpectedLenWireType,
    /// A length-delimited payload decoded as a string was not valid UTF-8.
    ExpectedUtf8,
    /// A fixed-size array received a payload of another length; holds the
    /// length that was expected.
    ExpectedArrayLen(usize),
    /// Any other error, described by the contained message.
    Other(String),
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::UnexpectedEof => write!(f, "unexpected end of file"),
            Self::ExpectedEof => write!(f, "expected end of file"),
            Self::InvalidVarint => write!(f, "varint is invalid"),
            Self::InvalidWireType(t) => write!(f, "the wiretype {t} is invalid"),
            Self::WireTypeMismatch => write!(f, "wire types don't match"),
            Self::ExpectedVarintWireType => write!(f, "expected a varint wire type"),
            // "an": "i32"/"i64" are read starting with a vowel sound.
            Self::ExpectedI32WireType => write!(f, "expected an i32 wire type"),
            Self::ExpectedI64WireType => write!(f, "expected an i64 wire type"),
            Self::ExpectedLenWireType => write!(f, "expected the len wire type"),
            Self::ExpectedUtf8 => write!(f, "expected a valid utf8 string"),
            Self::ExpectedArrayLen(n) => write!(f, "expected an array length of {n}"),
            Self::Other(s) => write!(f, "decode error: {s}"),
        }
    }
}

impl std::error::Error for DecodeError {}
55
56#[derive(Debug)]
57pub struct MessageDecoder<'a> {
58 inner: Bytes<'a>,
59}
60
61impl<'a> MessageDecoder<'a> {
62 pub fn new(bytes: &'a [u8]) -> Self {
63 Self {
64 inner: Bytes::from(bytes),
65 }
66 }
67
68 pub fn try_from_kind(kind: FieldKind<'a>) -> Result<Self, DecodeError> {
69 kind.try_unwrap_len().map(Self::new)
70 }
71
72 pub(crate) fn next_varint(&mut self) -> Result<u64, DecodeError> {
73 Varint::read(&mut self.inner)
74 .map(|v| v.0)
75 .map_err(|_| DecodeError::InvalidVarint)
76 }
77
78 fn next_kind(
79 &mut self,
80 ty: WireType,
81 ) -> Result<FieldKind<'a>, DecodeError> {
82 let kind = match ty {
83 WireType::Varint => FieldKind::Varint(self.next_varint()?),
84 WireType::I64 => FieldKind::I64(
85 self.inner
86 .try_read_le_u64()
87 .map_err(|_| DecodeError::UnexpectedEof)?,
88 ),
89 WireType::I32 => FieldKind::I32(
90 self.inner
91 .try_read_le_u32()
92 .map_err(|_| DecodeError::UnexpectedEof)?,
93 ),
94 WireType::Len => {
95 let len = self.next_varint()?;
96 let bytes = self
97 .inner
98 .try_read_ref(len as usize)
99 .map_err(|_| DecodeError::UnexpectedEof)?;
100
101 FieldKind::Len(bytes)
102 }
103 };
104
105 Ok(kind)
106 }
107
108 pub(crate) fn maybe_next_kind(
110 &mut self,
111 ty: WireType,
112 ) -> Result<Option<FieldKind<'a>>, DecodeError> {
113 if self.inner.remaining().is_empty() {
114 return Ok(None);
115 }
116
117 self.next_kind(ty).map(Some)
118 }
119
120 pub fn next(&mut self) -> Result<Option<Field<'a>>, DecodeError> {
123 if self.inner.remaining().is_empty() {
124 return Ok(None);
125 }
126
127 let tag = self.next_varint()?;
128 let wtype = WireType::from_tag(tag)?;
129 let number = tag >> 3;
130
131 let kind = self.next_kind(wtype)?;
132
133 Ok(Some(Field { number, kind }))
134 }
135
136 pub fn finish(self) -> Result<(), DecodeError> {
137 if self.inner.remaining().is_empty() {
138 Ok(())
139 } else {
140 Err(DecodeError::ExpectedEof)
141 }
142 }
143}
144
145#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
146pub struct Field<'a> {
147 pub number: u64,
148 pub kind: FieldKind<'a>,
149}
150
151#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
152pub enum FieldKind<'a> {
153 Varint(u64),
157 I32(u32),
160 I64(u64),
163
164 Len(&'a [u8]),
165}
166
167impl<'a> FieldKind<'a> {
168 pub fn is_len(&self) -> bool {
169 matches!(self, Self::Len(_))
170 }
171
172 pub fn wire_type(&self) -> WireType {
173 match self {
174 Self::Varint(_) => WireType::Varint,
175 Self::I32(_) => WireType::I32,
176 Self::I64(_) => WireType::I64,
177 Self::Len(_) => WireType::Len,
178 }
179 }
180
181 pub fn try_unwrap_varint(&self) -> Result<u64, DecodeError> {
182 match self {
183 Self::Varint(n) => Ok(*n),
184 _ => Err(DecodeError::ExpectedVarintWireType),
185 }
186 }
187
188 pub fn try_unwrap_i32(&self) -> Result<u32, DecodeError> {
189 match self {
190 Self::I32(n) => Ok(*n),
191 _ => Err(DecodeError::ExpectedI32WireType),
192 }
193 }
194
195 pub fn try_unwrap_i64(&self) -> Result<u64, DecodeError> {
196 match self {
197 Self::I64(n) => Ok(*n),
198 _ => Err(DecodeError::ExpectedI64WireType),
199 }
200 }
201
202 pub fn try_unwrap_len(&self) -> Result<&'a [u8], DecodeError> {
204 match self {
205 Self::Len(b) => Ok(b),
206 _ => Err(DecodeError::ExpectedLenWireType),
207 }
208 }
209}
210
211pub trait DecodeMessage<'m> {
212 const WIRE_TYPE: WireType;
216
217 fn parse_from_bytes(b: &'m [u8]) -> Result<Self, DecodeError>
218 where
219 Self: Sized,
220 {
221 let mut this = Self::decode_default();
222
223 this.merge(FieldKind::Len(b), false)?;
224
225 Ok(this)
226 }
227
228 fn decode_default() -> Self;
229
230 fn merge(
234 &mut self,
235 kind: FieldKind<'m>,
236 is_field: bool,
237 ) -> Result<(), DecodeError>;
238}
239
240pub trait DecodeMessageOwned: for<'m> DecodeMessage<'m> {}
241
242impl<T> DecodeMessageOwned for T where T: for<'m> DecodeMessage<'m> {}
243
244impl<'m, V> DecodeMessage<'m> for Vec<V>
246where
247 V: DecodeMessage<'m>,
248{
249 const WIRE_TYPE: WireType = WireType::Len;
250
251 fn decode_default() -> Self {
252 Self::new()
253 }
254
255 fn merge(
256 &mut self,
257 kind: FieldKind<'m>,
258 is_field: bool,
259 ) -> Result<(), DecodeError> {
260 if !is_field {
264 let mut parser = MessageDecoder::try_from_kind(kind)?;
265
266 while let Some(field) = parser.next()? {
267 if field.number != 1 {
268 continue;
269 }
270
271 self.merge(field.kind, true)?;
273 }
274
275 return parser.finish();
276 }
277
278 if kind.is_len() && V::WIRE_TYPE.can_be_packed() {
280 let mut parser = MessageDecoder::try_from_kind(kind)?;
281 while let Some(k) = parser.maybe_next_kind(V::WIRE_TYPE)? {
282 let mut v = V::decode_default();
283 v.merge(k, false)?;
284
285 self.push(v);
286 }
287
288 return parser.finish();
289 }
290
291 let mut v = V::decode_default();
292 v.merge(kind, false)?;
293
294 self.push(v);
295
296 Ok(())
297 }
298}
299
300impl<'m, K, V> DecodeMessage<'m> for HashMap<K, V>
301where
302 K: DecodeMessage<'m> + Eq + Hash,
303 V: DecodeMessage<'m>,
304{
305 const WIRE_TYPE: WireType = WireType::Len;
306
307 fn decode_default() -> Self {
308 Self::new()
309 }
310
311 fn merge(
312 &mut self,
313 kind: FieldKind<'m>,
314 is_field: bool,
315 ) -> Result<(), DecodeError> {
316 if !is_field {
320 let mut parser = MessageDecoder::try_from_kind(kind)?;
321
322 while let Some(field) = parser.next()? {
323 if field.number != 1 {
324 continue;
325 }
326
327 self.merge(field.kind, true)?;
329 }
330
331 return parser.finish();
332 }
333
334 let mut field = <(K, V)>::decode_default();
335 field.merge(kind, false)?;
336
337 self.insert(field.0, field.1);
338
339 Ok(())
340 }
341}
342
343impl<'m> DecodeMessage<'m> for Vec<u8> {
344 const WIRE_TYPE: WireType = WireType::Len;
345
346 fn decode_default() -> Self {
347 Self::new()
348 }
349
350 fn merge(
351 &mut self,
352 kind: FieldKind<'m>,
353 _is_field: bool,
354 ) -> Result<(), DecodeError> {
355 let bytes = kind.try_unwrap_len()?;
356 self.clear();
357 self.extend_from_slice(bytes);
358
359 Ok(())
360 }
361}
362
363impl<'m, const S: usize> DecodeMessage<'m> for [u8; S] {
364 const WIRE_TYPE: WireType = WireType::Len;
365
366 fn decode_default() -> Self {
367 [0; S]
368 }
369
370 fn merge(
371 &mut self,
372 kind: FieldKind<'m>,
373 _is_field: bool,
374 ) -> Result<(), DecodeError> {
375 let bytes = kind.try_unwrap_len()?;
376
377 if bytes.len() != S {
378 return Err(DecodeError::ExpectedArrayLen(S));
379 }
380
381 self.copy_from_slice(bytes);
382
383 Ok(())
384 }
385}
386
// Implements `DecodeMessage` for tuples. Each invocation pairs the tuple's
// type parameters with their positions (`A, 0, B, 1, ...`).
//
// A tuple is always read as a length-delimited message (`_is_field` is
// ignored). Fields are read until the payload is exhausted. The field whose
// number equals `$idx` is merged into element `$idx`, and any other field
// number is skipped.
//
// NOTE(review): element i is read from field number i, so the first element
// uses field number 0. Standard protobuf field numbers start at 1, and map
// entries use 1 for the key and 2 for the value. The `HashMap` impl decodes
// its entries through `(K, V)`, so confirm the encoder uses the same
// numbering.
macro_rules! impl_tuple {
    ($($gen:ident, $idx:tt),*) => (
        impl<'m, $($gen),*> DecodeMessage<'m> for ($($gen),*)
        where
            $($gen: DecodeMessage<'m>),*
        {
            const WIRE_TYPE: WireType = WireType::Len;

            fn decode_default() -> Self {
                ($(
                    $gen::decode_default()
                ),*)
            }

            fn merge(
                &mut self,
                kind: FieldKind<'m>,
                _is_field: bool
            ) -> Result<(), DecodeError> {
                let mut parser = MessageDecoder::try_from_kind(kind)?;

                while let Some(field) = parser.next()? {
                    match field.number {
                        $(
                            // Each element is merged as a field occurrence.
                            $idx => self.$idx.merge(field.kind, true)?
                        ),*,
                        // Unknown field numbers are ignored.
                        _ => {}
                    }
                }

                parser.finish()
            }
        }
    )
}

// Tuples of two to six elements.
impl_tuple![A, 0, B, 1];
impl_tuple![A, 0, B, 1, C, 2];
impl_tuple![A, 0, B, 1, C, 2, D, 3];
impl_tuple![A, 0, B, 1, C, 2, D, 3, E, 4];
impl_tuple![A, 0, B, 1, C, 2, D, 3, E, 4, F, 5];
433
434impl<'m> DecodeMessage<'m> for String {
435 const WIRE_TYPE: WireType = WireType::Len;
436
437 fn decode_default() -> Self {
438 Self::new()
439 }
440
441 fn merge(
442 &mut self,
443 kind: FieldKind<'m>,
444 _is_field: bool,
445 ) -> Result<(), DecodeError> {
446 let bytes = kind.try_unwrap_len()?;
447 self.clear();
448 let s = std::str::from_utf8(bytes)
449 .map_err(|_| DecodeError::ExpectedUtf8)?;
450 self.push_str(s);
451
452 Ok(())
453 }
454}
455
456impl<'m, V> DecodeMessage<'m> for Option<V>
457where
458 V: DecodeMessage<'m>,
459{
460 const WIRE_TYPE: WireType = WireType::Len;
461
462 fn decode_default() -> Self {
463 None
464 }
465
466 fn merge(
467 &mut self,
468 kind: FieldKind<'m>,
469 is_field: bool,
470 ) -> Result<(), DecodeError> {
471 if !is_field {
475 let mut parser = MessageDecoder::try_from_kind(kind)?;
476
477 while let Some(field) = parser.next()? {
478 if field.number != 1 {
479 continue;
480 }
481
482 self.merge(field.kind, true)?;
484 }
485
486 return parser.finish();
487 }
488
489 match self {
490 Some(v) => {
491 v.merge(kind, false)?;
492 }
493 None => {
494 let mut v = V::decode_default();
495 v.merge(kind, false)?;
496 *self = Some(v);
497 }
498 }
499
500 Ok(())
501 }
502}
503
504impl<'m> DecodeMessage<'m> for bool {
505 const WIRE_TYPE: WireType = WireType::Varint;
506
507 fn decode_default() -> Self {
508 false
509 }
510
511 fn merge(
512 &mut self,
513 kind: FieldKind<'m>,
514 _is_field: bool,
515 ) -> Result<(), DecodeError> {
516 let num = kind.try_unwrap_varint()?;
517 *self = num != 0;
518
519 Ok(())
520 }
521}
522
523macro_rules! impl_varint {
525 ($($ty:ty),*) => ($(
526 impl<'m> DecodeMessage<'m> for $ty {
527 const WIRE_TYPE: WireType = WireType::Varint;
528
529 fn decode_default() -> Self {
530 Default::default()
531 }
532
533 fn merge(
534 &mut self,
535 kind: FieldKind<'m>,
536 _is_field: bool
537 ) -> Result<(), DecodeError> {
538 let num = kind.try_unwrap_varint()?;
539 *self = num as $ty;
540
541 Ok(())
542 }
543 }
544 )*)
545}
546
547impl_varint![i32, i64, u32, u64];
548
549macro_rules! impl_floats {
550 ($($src:ident, $wtype:ident as $ty:ty),*) => ($(
551 impl<'m> DecodeMessage<'m> for $ty {
552 const WIRE_TYPE: WireType = WireType::$wtype;
553
554 fn decode_default() -> Self {
555 Default::default()
556 }
557
558 fn merge(
559 &mut self,
560 kind: FieldKind<'m>,
561 _is_field: bool
562 ) -> Result<(), DecodeError> {
563 let num = kind.$src()?;
564 *self = <$ty>::from_bits(num);
565
566 Ok(())
567 }
568 }
569 )*)
570}
571
572impl_floats![try_unwrap_i32, I32 as f32, try_unwrap_i64, I64 as f64];
573
574#[repr(transparent)]
575#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
576pub struct ZigZag<T>(pub T);
577
578macro_rules! impl_zigzag {
579 ($($ty:ty),*) => ($(
580 impl<'m> DecodeMessage<'m> for ZigZag<$ty> {
581 const WIRE_TYPE: WireType = WireType::Varint;
582
583 fn decode_default() -> Self {
584 Default::default()
585 }
586
587 fn merge(
588 &mut self,
589 kind: FieldKind<'m>,
590 _is_field: bool
591 ) -> Result<(), DecodeError> {
592 let num = kind.try_unwrap_varint()? as $ty;
593 let num = (num >> 1) ^ -(num & 1);
594 *self = ZigZag(num);
595
596 Ok(())
597 }
598 }
599 )*)
600}
601
602impl_zigzag![i32, i64];
603
604#[repr(transparent)]
605#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
606pub struct Fixed<T>(pub T);
607
608macro_rules! impl_fixed {
609 ($($src:ident, $wtype:ident as $ty:ty),*) => ($(
610 impl<'m> DecodeMessage<'m> for Fixed<$ty> {
611 const WIRE_TYPE: WireType = WireType::$wtype;
612
613 fn decode_default() -> Self {
614 Default::default()
615 }
616
617 fn merge(
618 &mut self,
619 kind: FieldKind<'m>,
620 _is_field: bool
621 ) -> Result<(), DecodeError> {
622 let num = kind.$src()?;
623 *self = Fixed(num as $ty);
624
625 Ok(())
626 }
627 }
628 )*)
629}
630
631impl_fixed![
632 try_unwrap_i32,
633 I32 as u32,
634 try_unwrap_i32,
635 I32 as i32,
636 try_unwrap_i64,
637 I64 as u64,
638 try_unwrap_i64,
639 I64 as i64
640];
641
#[cfg(test)]
mod tests {
    use super::*;

    use hex_literal::hex;

    /// Field 4 holds the string "hello"; field 5 repeats 1, 2, 3 unpacked.
    #[test]
    fn string_and_repeated_test_4() {
        const MSG: &[u8] = &hex!("220568656c6c6f280128022803");

        let mut parser = MessageDecoder::new(MSG);

        let hello_str = Field {
            number: 4,
            kind: FieldKind::Len(b"hello"),
        };
        assert_eq!(parser.next().unwrap().unwrap(), hello_str);

        let mut repeated = Field {
            number: 5,
            kind: FieldKind::Varint(1),
        };

        assert_eq!(parser.next().unwrap().unwrap(), repeated);
        repeated.kind = FieldKind::Varint(2);
        assert_eq!(parser.next().unwrap().unwrap(), repeated);
        repeated.kind = FieldKind::Varint(3);
        assert_eq!(parser.next().unwrap().unwrap(), repeated);

        assert!(parser.next().unwrap().is_none());
        parser.finish().unwrap();
    }

    /// Field 6 holds the packed varints 3, 270 and 86942.
    #[test]
    fn repeated_packed() {
        const MSG: &[u8] = &hex!("3206038e029ea705");

        let mut outer = MessageDecoder::new(MSG);

        let field = outer.next().unwrap().unwrap();
        assert_eq!(field.number, 6);
        let packed = field
            .kind
            .try_unwrap_len()
            .expect("field 6 should be length-delimited");
        outer.finish().unwrap();

        let mut parser = MessageDecoder::new(packed);
        assert_eq!(parser.next_varint().unwrap(), 3);
        assert_eq!(parser.next_varint().unwrap(), 270);
        assert_eq!(parser.next_varint().unwrap(), 86942);
        parser.finish().unwrap();
    }

    /// A zero-length `Len` field decodes to an empty slice.
    #[test]
    fn empty_bytes() {
        const MSG: &[u8] = &[10, 0];

        let mut parser = MessageDecoder::new(MSG);

        let field = parser.next().unwrap().unwrap();
        assert_eq!(field.number, 1);
        assert_eq!(field.kind, FieldKind::Len(&[]));
        assert!(parser.next().unwrap().is_none());
        parser.finish().unwrap();
    }
}