protopuffer/
decode.rs

1use crate::varint::Varint;
2use crate::WireType;
3
4use std::collections::HashMap;
5use std::fmt;
6use std::hash::Hash;
7
8use bytes::{Bytes, BytesRead, BytesReadRef};
9
/// Errors produced while decoding protobuf-encoded data.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum DecodeError {
	/// The input ended before a complete value could be read.
	UnexpectedEof,
	/// Unread input remained where the end of the data was expected.
	ExpectedEof,
	/// A varint could not be read from the input.
	InvalidVarint,
	/// A wire type value that is not recognized (carries the raw value).
	InvalidWireType(u8),
	/// Two wire types did not match.
	WireTypeMismatch,
	/// A varint value was required but another wire type was found.
	ExpectedVarintWireType,
	/// A 32-bit fixed-width value was required but another wire type was
	/// found.
	ExpectedI32WireType,
	/// A 64-bit fixed-width value was required but another wire type was
	/// found.
	ExpectedI64WireType,
	/// A length-delimited value was required but another wire type was
	/// found.
	ExpectedLenWireType,
	/// A string payload was not valid UTF-8.
	ExpectedUtf8,
	/// A fixed-size array received a payload of the wrong length; carries
	/// the length that was expected.
	ExpectedArrayLen(usize),
	/// Any other error, described by the contained message.
	Other(String),
}

impl fmt::Display for DecodeError {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		// Variants that carry data are formatted directly; every other
		// variant maps to a fixed message.
		let message = match self {
			Self::InvalidWireType(t) => {
				return write!(f, "the wiretype {t} is invalid");
			}
			Self::ExpectedArrayLen(n) => {
				return write!(f, "expected an array length of {n}");
			}
			Self::Other(s) => return write!(f, "decode error: {s}"),
			Self::UnexpectedEof => "unexpected end of file",
			Self::ExpectedEof => "expected end of file",
			Self::InvalidVarint => "varint is invalid",
			Self::WireTypeMismatch => "wire types don't match",
			Self::ExpectedVarintWireType => "expected a varint wire type",
			Self::ExpectedI32WireType => "expected a i32 wire type",
			Self::ExpectedI64WireType => "expected a i64 wire type",
			Self::ExpectedLenWireType => "expected the len wire type",
			Self::ExpectedUtf8 => "expected a valid utf8 string",
		};

		f.write_str(message)
	}
}

impl std::error::Error for DecodeError {}
55
56#[derive(Debug)]
57pub struct MessageDecoder<'a> {
58	inner: Bytes<'a>,
59}
60
61impl<'a> MessageDecoder<'a> {
62	pub fn new(bytes: &'a [u8]) -> Self {
63		Self {
64			inner: Bytes::from(bytes),
65		}
66	}
67
68	pub fn try_from_kind(kind: FieldKind<'a>) -> Result<Self, DecodeError> {
69		kind.try_unwrap_len().map(Self::new)
70	}
71
72	pub(crate) fn next_varint(&mut self) -> Result<u64, DecodeError> {
73		Varint::read(&mut self.inner)
74			.map(|v| v.0)
75			.map_err(|_| DecodeError::InvalidVarint)
76	}
77
78	fn next_kind(
79		&mut self,
80		ty: WireType,
81	) -> Result<FieldKind<'a>, DecodeError> {
82		let kind = match ty {
83			WireType::Varint => FieldKind::Varint(self.next_varint()?),
84			WireType::I64 => FieldKind::I64(
85				self.inner
86					.try_read_le_u64()
87					.map_err(|_| DecodeError::UnexpectedEof)?,
88			),
89			WireType::I32 => FieldKind::I32(
90				self.inner
91					.try_read_le_u32()
92					.map_err(|_| DecodeError::UnexpectedEof)?,
93			),
94			WireType::Len => {
95				let len = self.next_varint()?;
96				let bytes = self
97					.inner
98					.try_read_ref(len as usize)
99					.map_err(|_| DecodeError::UnexpectedEof)?;
100
101				FieldKind::Len(bytes)
102			}
103		};
104
105		Ok(kind)
106	}
107
108	/// should only be used for reading packed values
109	pub(crate) fn maybe_next_kind(
110		&mut self,
111		ty: WireType,
112	) -> Result<Option<FieldKind<'a>>, DecodeError> {
113		if self.inner.remaining().is_empty() {
114			return Ok(None);
115		}
116
117		self.next_kind(ty).map(Some)
118	}
119
120	/// If this returns Ok(None), this means there will never be any
121	/// more fields
122	pub fn next(&mut self) -> Result<Option<Field<'a>>, DecodeError> {
123		if self.inner.remaining().is_empty() {
124			return Ok(None);
125		}
126
127		let tag = self.next_varint()?;
128		let wtype = WireType::from_tag(tag)?;
129		let number = tag >> 3;
130
131		let kind = self.next_kind(wtype)?;
132
133		Ok(Some(Field { number, kind }))
134	}
135
136	pub fn finish(self) -> Result<(), DecodeError> {
137		if self.inner.remaining().is_empty() {
138			Ok(())
139		} else {
140			Err(DecodeError::ExpectedEof)
141		}
142	}
143}
144
145#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
146pub struct Field<'a> {
147	pub number: u64,
148	pub kind: FieldKind<'a>,
149}
150
151#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
152pub enum FieldKind<'a> {
153	// backwards compatible:
154	// - int32, uint32, int64, uint64, bool
155	// - sint32, sint64
156	Varint(u64),
157	// backwards compatible:
158	// - fixed32, sfixed32
159	I32(u32),
160	// backwards compatible:
161	// - fixed64, sfixed64
162	I64(u64),
163
164	Len(&'a [u8]),
165}
166
167impl<'a> FieldKind<'a> {
168	pub fn is_len(&self) -> bool {
169		matches!(self, Self::Len(_))
170	}
171
172	pub fn wire_type(&self) -> WireType {
173		match self {
174			Self::Varint(_) => WireType::Varint,
175			Self::I32(_) => WireType::I32,
176			Self::I64(_) => WireType::I64,
177			Self::Len(_) => WireType::Len,
178		}
179	}
180
181	pub fn try_unwrap_varint(&self) -> Result<u64, DecodeError> {
182		match self {
183			Self::Varint(n) => Ok(*n),
184			_ => Err(DecodeError::ExpectedVarintWireType),
185		}
186	}
187
188	pub fn try_unwrap_i32(&self) -> Result<u32, DecodeError> {
189		match self {
190			Self::I32(n) => Ok(*n),
191			_ => Err(DecodeError::ExpectedI32WireType),
192		}
193	}
194
195	pub fn try_unwrap_i64(&self) -> Result<u64, DecodeError> {
196		match self {
197			Self::I64(n) => Ok(*n),
198			_ => Err(DecodeError::ExpectedI64WireType),
199		}
200	}
201
202	/// Returns ExpectedLenWireType if the kind is not Len
203	pub fn try_unwrap_len(&self) -> Result<&'a [u8], DecodeError> {
204		match self {
205			Self::Len(b) => Ok(b),
206			_ => Err(DecodeError::ExpectedLenWireType),
207		}
208	}
209}
210
211pub trait DecodeMessage<'m> {
212	/// This field is just a hint, merge might accept another type
213	///
214	/// mostly this is used for detecting if we can pack a message
215	const WIRE_TYPE: WireType;
216
217	fn parse_from_bytes(b: &'m [u8]) -> Result<Self, DecodeError>
218	where
219		Self: Sized,
220	{
221		let mut this = Self::decode_default();
222
223		this.merge(FieldKind::Len(b), false)?;
224
225		Ok(this)
226	}
227
228	fn decode_default() -> Self;
229
230	/// kind does not need to be the same as Self::WIRE_TYPE
231	///
232	/// is_field is true if this message is a field of a struct or enum
233	fn merge(
234		&mut self,
235		kind: FieldKind<'m>,
236		is_field: bool,
237	) -> Result<(), DecodeError>;
238}
239
240pub trait DecodeMessageOwned: for<'m> DecodeMessage<'m> {}
241
242impl<T> DecodeMessageOwned for T where T: for<'m> DecodeMessage<'m> {}
243
244// a vec is represented as repeated ty = 1;
245impl<'m, V> DecodeMessage<'m> for Vec<V>
246where
247	V: DecodeMessage<'m>,
248{
249	const WIRE_TYPE: WireType = WireType::Len;
250
251	fn decode_default() -> Self {
252		Self::new()
253	}
254
255	fn merge(
256		&mut self,
257		kind: FieldKind<'m>,
258		is_field: bool,
259	) -> Result<(), DecodeError> {
260		// if this is not a field
261		// we need to create a struct / message
262		// which contains one field which is repeatable
263		if !is_field {
264			let mut parser = MessageDecoder::try_from_kind(kind)?;
265
266			while let Some(field) = parser.next()? {
267				if field.number != 1 {
268					continue;
269				}
270
271				// were now in a field of our virtual message/struct
272				self.merge(field.kind, true)?;
273			}
274
275			return parser.finish();
276		}
277
278		// the data could be packet
279		if kind.is_len() && V::WIRE_TYPE.can_be_packed() {
280			let mut parser = MessageDecoder::try_from_kind(kind)?;
281			while let Some(k) = parser.maybe_next_kind(V::WIRE_TYPE)? {
282				let mut v = V::decode_default();
283				v.merge(k, false)?;
284
285				self.push(v);
286			}
287
288			return parser.finish();
289		}
290
291		let mut v = V::decode_default();
292		v.merge(kind, false)?;
293
294		self.push(v);
295
296		Ok(())
297	}
298}
299
300impl<'m, K, V> DecodeMessage<'m> for HashMap<K, V>
301where
302	K: DecodeMessage<'m> + Eq + Hash,
303	V: DecodeMessage<'m>,
304{
305	const WIRE_TYPE: WireType = WireType::Len;
306
307	fn decode_default() -> Self {
308		Self::new()
309	}
310
311	fn merge(
312		&mut self,
313		kind: FieldKind<'m>,
314		is_field: bool,
315	) -> Result<(), DecodeError> {
316		// if this is not a field
317		// we need to create a struct / message
318		// which contains one field which is repeatable
319		if !is_field {
320			let mut parser = MessageDecoder::try_from_kind(kind)?;
321
322			while let Some(field) = parser.next()? {
323				if field.number != 1 {
324					continue;
325				}
326
327				// were now in a field of our virtual message/struct
328				self.merge(field.kind, true)?;
329			}
330
331			return parser.finish();
332		}
333
334		let mut field = <(K, V)>::decode_default();
335		field.merge(kind, false)?;
336
337		self.insert(field.0, field.1);
338
339		Ok(())
340	}
341}
342
343impl<'m> DecodeMessage<'m> for Vec<u8> {
344	const WIRE_TYPE: WireType = WireType::Len;
345
346	fn decode_default() -> Self {
347		Self::new()
348	}
349
350	fn merge(
351		&mut self,
352		kind: FieldKind<'m>,
353		_is_field: bool,
354	) -> Result<(), DecodeError> {
355		let bytes = kind.try_unwrap_len()?;
356		self.clear();
357		self.extend_from_slice(bytes);
358
359		Ok(())
360	}
361}
362
363impl<'m, const S: usize> DecodeMessage<'m> for [u8; S] {
364	const WIRE_TYPE: WireType = WireType::Len;
365
366	fn decode_default() -> Self {
367		[0; S]
368	}
369
370	fn merge(
371		&mut self,
372		kind: FieldKind<'m>,
373		_is_field: bool,
374	) -> Result<(), DecodeError> {
375		let bytes = kind.try_unwrap_len()?;
376
377		if bytes.len() != S {
378			return Err(DecodeError::ExpectedArrayLen(S));
379		}
380
381		self.copy_from_slice(bytes);
382
383		Ok(())
384	}
385}
386
387/// a tuple behaves the same way as a struct
388macro_rules! impl_tuple {
389	($($gen:ident, $idx:tt),*) => (
390		impl<'m, $($gen),*> DecodeMessage<'m> for ($($gen),*)
391		where
392			$($gen: DecodeMessage<'m>),*
393		{
394			const WIRE_TYPE: WireType = WireType::Len;
395
396			fn decode_default() -> Self {
397				($(
398					$gen::decode_default()
399				),*)
400			}
401
402			fn merge(
403				&mut self,
404				kind: FieldKind<'m>,
405				_is_field: bool
406			) -> Result<(), DecodeError> {
407				let mut parser = MessageDecoder::try_from_kind(kind)?;
408
409				while let Some(field) = parser.next()? {
410					match field.number {
411						$(
412							$idx => self.$idx.merge(field.kind, true)?
413						),*,
414						// ignore unknown fields
415						_ => {}
416					}
417				}
418
419				parser.finish()
420			}
421		}
422	)
423}
424
425// impl_tuple![
426// 	A, 0
427// ];
428impl_tuple![A, 0, B, 1];
429impl_tuple![A, 0, B, 1, C, 2];
430impl_tuple![A, 0, B, 1, C, 2, D, 3];
431impl_tuple![A, 0, B, 1, C, 2, D, 3, E, 4];
432impl_tuple![A, 0, B, 1, C, 2, D, 3, E, 4, F, 5];
433
434impl<'m> DecodeMessage<'m> for String {
435	const WIRE_TYPE: WireType = WireType::Len;
436
437	fn decode_default() -> Self {
438		Self::new()
439	}
440
441	fn merge(
442		&mut self,
443		kind: FieldKind<'m>,
444		_is_field: bool,
445	) -> Result<(), DecodeError> {
446		let bytes = kind.try_unwrap_len()?;
447		self.clear();
448		let s = std::str::from_utf8(bytes)
449			.map_err(|_| DecodeError::ExpectedUtf8)?;
450		self.push_str(s);
451
452		Ok(())
453	}
454}
455
456impl<'m, V> DecodeMessage<'m> for Option<V>
457where
458	V: DecodeMessage<'m>,
459{
460	const WIRE_TYPE: WireType = WireType::Len;
461
462	fn decode_default() -> Self {
463		None
464	}
465
466	fn merge(
467		&mut self,
468		kind: FieldKind<'m>,
469		is_field: bool,
470	) -> Result<(), DecodeError> {
471		// if this is not a field
472		// we need to create a struct / message
473		// which contains one field which represent V
474		if !is_field {
475			let mut parser = MessageDecoder::try_from_kind(kind)?;
476
477			while let Some(field) = parser.next()? {
478				if field.number != 1 {
479					continue;
480				}
481
482				// were now in a field of our virtual message/struct
483				self.merge(field.kind, true)?;
484			}
485
486			return parser.finish();
487		}
488
489		match self {
490			Some(v) => {
491				v.merge(kind, false)?;
492			}
493			None => {
494				let mut v = V::decode_default();
495				v.merge(kind, false)?;
496				*self = Some(v);
497			}
498		}
499
500		Ok(())
501	}
502}
503
504impl<'m> DecodeMessage<'m> for bool {
505	const WIRE_TYPE: WireType = WireType::Varint;
506
507	fn decode_default() -> Self {
508		false
509	}
510
511	fn merge(
512		&mut self,
513		kind: FieldKind<'m>,
514		_is_field: bool,
515	) -> Result<(), DecodeError> {
516		let num = kind.try_unwrap_varint()?;
517		*self = num != 0;
518
519		Ok(())
520	}
521}
522
523// impl basic varint
524macro_rules! impl_varint {
525	($($ty:ty),*) => ($(
526		impl<'m> DecodeMessage<'m> for $ty {
527			const WIRE_TYPE: WireType = WireType::Varint;
528
529			fn decode_default() -> Self {
530				Default::default()
531			}
532
533			fn merge(
534				&mut self,
535				kind: FieldKind<'m>,
536				_is_field: bool
537			) -> Result<(), DecodeError> {
538				let num = kind.try_unwrap_varint()?;
539				*self = num as $ty;
540
541				Ok(())
542			}
543		}
544	)*)
545}
546
547impl_varint![i32, i64, u32, u64];
548
549macro_rules! impl_floats {
550	($($src:ident, $wtype:ident as $ty:ty),*) => ($(
551		impl<'m> DecodeMessage<'m> for $ty {
552			const WIRE_TYPE: WireType = WireType::$wtype;
553
554			fn decode_default() -> Self {
555				Default::default()
556			}
557
558			fn merge(
559				&mut self,
560				kind: FieldKind<'m>,
561				_is_field: bool
562			) -> Result<(), DecodeError> {
563				let num = kind.$src()?;
564				*self = <$ty>::from_bits(num);
565
566				Ok(())
567			}
568		}
569	)*)
570}
571
572impl_floats![try_unwrap_i32, I32 as f32, try_unwrap_i64, I64 as f64];
573
574#[repr(transparent)]
575#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
576pub struct ZigZag<T>(pub T);
577
578macro_rules! impl_zigzag {
579	($($ty:ty),*) => ($(
580		impl<'m> DecodeMessage<'m> for ZigZag<$ty> {
581			const WIRE_TYPE: WireType = WireType::Varint;
582
583			fn decode_default() -> Self {
584				Default::default()
585			}
586
587			fn merge(
588				&mut self,
589				kind: FieldKind<'m>,
590				_is_field: bool
591			) -> Result<(), DecodeError> {
592				let num = kind.try_unwrap_varint()? as $ty;
593				let num = (num >> 1) ^ -(num & 1);
594				*self = ZigZag(num);
595
596				Ok(())
597			}
598		}
599	)*)
600}
601
602impl_zigzag![i32, i64];
603
604#[repr(transparent)]
605#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
606pub struct Fixed<T>(pub T);
607
608macro_rules! impl_fixed {
609	($($src:ident, $wtype:ident as $ty:ty),*) => ($(
610		impl<'m> DecodeMessage<'m> for Fixed<$ty> {
611			const WIRE_TYPE: WireType = WireType::$wtype;
612
613			fn decode_default() -> Self {
614				Default::default()
615			}
616
617			fn merge(
618				&mut self,
619				kind: FieldKind<'m>,
620				_is_field: bool
621			) -> Result<(), DecodeError> {
622				let num = kind.$src()?;
623				*self = Fixed(num as $ty);
624
625				Ok(())
626			}
627		}
628	)*)
629}
630
631impl_fixed![
632	try_unwrap_i32,
633	I32 as u32,
634	try_unwrap_i32,
635	I32 as i32,
636	try_unwrap_i64,
637	I64 as u64,
638	try_unwrap_i64,
639	I64 as i64
640];
641
#[cfg(test)]
mod tests {
	use super::*;

	use hex_literal::hex;

	#[test]
	fn string_and_repeated_test_4() {
		const MSG: &[u8] = &hex!("220568656c6c6f280128022803");

		let mut parser = MessageDecoder::new(MSG);

		let hello_str = Field {
			number: 4,
			kind: FieldKind::Len(b"hello"),
		};
		assert_eq!(parser.next().unwrap().unwrap(), hello_str);

		let mut repeated = Field {
			number: 5,
			kind: FieldKind::Varint(1),
		};

		assert_eq!(parser.next().unwrap().unwrap(), repeated);
		repeated.kind = FieldKind::Varint(2);
		assert_eq!(parser.next().unwrap().unwrap(), repeated);
		repeated.kind = FieldKind::Varint(3);
		assert_eq!(parser.next().unwrap().unwrap(), repeated);

		assert!(parser.next().unwrap().is_none());
	}

	#[test]
	fn repeated_packet() {
		const MSG: &[u8] = &hex!("3206038e029ea705");

		let mut parser = MessageDecoder::new(MSG);

		let packed = parser.next().unwrap().unwrap();
		assert_eq!(packed.number, 6);
		let packed = match packed.kind {
			FieldKind::Len(p) => p,
			_ => panic!(),
		};

		let mut parser = MessageDecoder::new(packed);
		assert_eq!(parser.next_varint().unwrap(), 3);
		assert_eq!(parser.next_varint().unwrap(), 270);
		assert_eq!(parser.next_varint().unwrap(), 86942);
	}

	#[test]
	fn empty_bytes() {
		const MSG: &[u8] = &[10, 0];

		let mut parser = MessageDecoder::new(MSG);

		let field = parser.next().unwrap().unwrap();
		assert_eq!(field.number, 1);
		assert_eq!(field.kind, FieldKind::Len(&[]));
		assert!(parser.next().unwrap().is_none());
	}

	#[test]
	fn truncated_len_field_is_unexpected_eof() {
		// field 1, wire type Len, declared length 5 but only 2 bytes follow
		const MSG: &[u8] = &hex!("0a056869");

		let mut parser = MessageDecoder::new(MSG);
		assert_eq!(parser.next(), Err(DecodeError::UnexpectedEof));
	}

	#[test]
	fn finish_rejects_trailing_bytes() {
		assert_eq!(MessageDecoder::new(&[]).finish(), Ok(()));
		assert_eq!(
			MessageDecoder::new(&hex!("0801")).finish(),
			Err(DecodeError::ExpectedEof)
		);
	}

	#[test]
	fn repeated_varints_skip_other_fields() {
		// field 1 = 1, field 2 = 9 (not part of the vec), field 1 = 2
		const MSG: &[u8] = &hex!("080110090802");

		let values = Vec::<u32>::parse_from_bytes(MSG).unwrap();
		assert_eq!(values, vec![1, 2]);
	}

	#[test]
	fn invalid_utf8_string_is_rejected() {
		assert_eq!(
			String::parse_from_bytes(&[0xff]),
			Err(DecodeError::ExpectedUtf8)
		);
	}

	#[test]
	fn byte_array_requires_exact_length() {
		assert_eq!(<[u8; 2]>::parse_from_bytes(&[1, 2]), Ok([1, 2]));
		assert_eq!(
			<[u8; 2]>::parse_from_bytes(&[1, 2, 3]),
			Err(DecodeError::ExpectedArrayLen(2))
		);
	}

	#[test]
	fn zigzag_small_values() {
		for &(raw, expected) in &[(0u64, 0i32), (1, -1), (2, 1), (3, -2)] {
			let mut value = ZigZag::<i32>::decode_default();
			value.merge(FieldKind::Varint(raw), true).unwrap();
			assert_eq!(value, ZigZag(expected));
		}
	}

	/*
	message Target {
		oneof target {
			Unknown unknown = 1;
			Unit unit = 2;
			Weapon weapon = 3;
			Static static = 4;
			Scenery scenery = 5;
			Airbase airbase = 6;
			Cargo cargo = 7;
		}
	}
	*/
}