// fire_protobuf/decode.rs

1use crate::WireType;
2use crate::varint::Varint;
3
4use std::fmt;
5use std::hash::Hash;
6use std::collections::HashMap;
7
8use bytes::{Bytes, BytesRead, BytesReadRef};
9
10
/// Errors produced while decoding protobuf wire data.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
	/// A read ran past the end of the available input.
	UnexpectedEof,
	/// Input was left over where the decoder expected to be finished.
	ExpectedEof,
	/// A varint could not be read.
	InvalidVarint,
	/// A tag carried a wire type that is not recognised (raw value attached).
	InvalidWireType(u8),
	/// Two wire types that had to agree did not.
	WireTypeMismatch,
	/// A varint payload was required but another wire type was found.
	ExpectedVarintWireType,
	/// A 32-bit fixed payload was required but another wire type was found.
	ExpectedI32WireType,
	/// A 64-bit fixed payload was required but another wire type was found.
	ExpectedI64WireType,
	/// A length-delimited payload was required but another wire type was found.
	ExpectedLenWireType,
	/// A string payload was not valid UTF-8.
	ExpectedUtf8,
	/// A fixed-size array got a payload of the wrong length; carries the
	/// length that was expected.
	ExpectedArrayLen(usize),
	/// Any other failure, described by the attached message.
	Other(String),
}
27
28impl fmt::Display for DecodeError {
29	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
30		match self {
31			Self::UnexpectedEof => write!(f, "unexpected end of file"),
32			Self::ExpectedEof => write!(f, "expected end of file"),
33			Self::InvalidVarint => write!(f, "varint is invalid"),
34			Self::InvalidWireType(t) => {
35				write!(f, "the wiretype {t} is invalid")
36			},
37			Self::WireTypeMismatch => write!(f, "wire types don't match"),
38			Self::ExpectedVarintWireType => {
39				write!(f, "expected a varint wire type")
40			},
41			Self::ExpectedI32WireType => write!(f, "expected a i32 wire type"),
42			Self::ExpectedI64WireType => write!(f, "expected a i64 wire type"),
43			Self::ExpectedLenWireType => {
44				write!(f, "expected the len wire type")
45			},
46			Self::ExpectedUtf8 => write!(f, "expected a valid utf8 string"),
47			Self::ExpectedArrayLen(n) => {
48				write!(f, "expected an array length of {n}")
49			},
50			Self::Other(s) => write!(f, "decode error: {s}")
51		}
52	}
53}
54
55impl std::error::Error for DecodeError {}
56
57#[derive(Debug)]
58pub struct MessageDecoder<'a> {
59	inner: Bytes<'a>
60}
61
62impl<'a> MessageDecoder<'a> {
63	pub fn new(bytes: &'a [u8]) -> Self {
64		Self {
65			inner: Bytes::from(bytes)
66		}
67	}
68
69	pub fn try_from_kind(kind: FieldKind<'a>) -> Result<Self, DecodeError> {
70		kind.try_unwrap_len().map(Self::new)
71	}
72
73	pub(crate) fn next_varint(&mut self) -> Result<u64, DecodeError> {
74		Varint::read(&mut self.inner)
75			.map(|v| v.0)
76			.map_err(|_| DecodeError::InvalidVarint)
77	}
78
79	fn next_kind(
80		&mut self,
81		ty: WireType
82	) -> Result<FieldKind<'a>, DecodeError> {
83		let kind = match ty {
84			WireType::Varint => FieldKind::Varint(self.next_varint()?),
85			WireType::I64 => FieldKind::I64(
86				self.inner.try_read_le_u64()
87					.map_err(|_| DecodeError::UnexpectedEof)?
88			),
89			WireType::I32 => FieldKind::I32(
90				self.inner.try_read_le_u32()
91					.map_err(|_| DecodeError::UnexpectedEof)?
92			),
93			WireType::Len => {
94				let len = self.next_varint()?;
95				let bytes = self.inner.try_read_ref(len as usize)
96					.map_err(|_| DecodeError::UnexpectedEof)?;
97
98				FieldKind::Len(bytes)
99			}
100		};
101
102		Ok(kind)
103	}
104
105	/// should only be used for reading packed values
106	pub(crate) fn maybe_next_kind(
107		&mut self,
108		ty: WireType
109	) -> Result<Option<FieldKind<'a>>, DecodeError> {
110		if self.inner.remaining().is_empty() {
111			return Ok(None)
112		}
113
114		self.next_kind(ty).map(Some)
115	}
116
117	/// If this returns Ok(None), this means there will never be any
118	/// more fields
119	pub fn next(&mut self) -> Result<Option<Field<'a>>, DecodeError> {
120		if self.inner.remaining().is_empty() {
121			return Ok(None)
122		}
123
124		let tag = self.next_varint()?;
125		let wtype = WireType::from_tag(tag)?;
126		let number = tag >> 3;
127
128		let kind = self.next_kind(wtype)?;
129
130		Ok(Some(Field { number, kind }))
131	}
132
133	pub fn finish(self) -> Result<(), DecodeError> {
134		if self.inner.remaining().is_empty() {
135			Ok(())
136		} else {
137			Err(DecodeError::ExpectedEof)
138		}
139	}
140}
141
142#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
143pub struct Field<'a> {
144	pub number: u64,
145	pub kind: FieldKind<'a>
146}
147
/// The payload of a field, one variant per supported wire type.
///
/// Variant order is kept stable because the derived `Hash` depends on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FieldKind<'a> {
	/// Varint payload. Backwards compatible among: int32, uint32, int64,
	/// uint64, bool; and among sint32, sint64.
	Varint(u64),
	/// 32-bit fixed payload. Backwards compatible among: fixed32, sfixed32.
	I32(u32),
	/// 64-bit fixed payload. Backwards compatible among: fixed64, sfixed64.
	I64(u64),
	/// Length-delimited payload, borrowed from the input.
	Len(&'a [u8]),
}
163
164impl<'a> FieldKind<'a> {
165	pub fn is_len(&self) -> bool {
166		matches!(self, Self::Len(_))
167	}
168
169	pub fn wire_type(&self) -> WireType {
170		match self {
171			Self::Varint(_) => WireType::Varint,
172			Self::I32(_) => WireType::I32,
173			Self::I64(_) => WireType::I64,
174			Self::Len(_) => WireType::Len
175		}
176	}
177
178	pub fn try_unwrap_varint(&self) -> Result<u64, DecodeError> {
179		match self {
180			Self::Varint(n) => Ok(*n),
181			_ => Err(DecodeError::ExpectedVarintWireType)
182		}
183	}
184
185	pub fn try_unwrap_i32(&self) -> Result<u32, DecodeError> {
186		match self {
187			Self::I32(n) => Ok(*n),
188			_ => Err(DecodeError::ExpectedI32WireType)
189		}
190	}
191
192	pub fn try_unwrap_i64(&self) -> Result<u64, DecodeError> {
193		match self {
194			Self::I64(n) => Ok(*n),
195			_ => Err(DecodeError::ExpectedI64WireType)
196		}
197	}
198
199	/// Returns ExpectedLenWireType if the kind is not Len
200	pub fn try_unwrap_len(&self) -> Result<&'a [u8], DecodeError> {
201		match self {
202			Self::Len(b) => Ok(b),
203			_ => Err(DecodeError::ExpectedLenWireType)
204		}
205	}
206}
207
208
209
210pub trait DecodeMessage<'m> {
211	/// This field is just a hint, merge might accept another type
212	/// 
213	/// mostly this is used for detecting if we can pack a message
214	const WIRE_TYPE: WireType;
215
216	fn parse_from_bytes(b: &'m [u8]) -> Result<Self, DecodeError>
217	where Self: Sized {
218		let mut this = Self::decode_default();
219
220		this.merge(FieldKind::Len(b), false)?;
221
222		Ok(this)
223	}
224
225	fn decode_default() -> Self;
226
227	/// kind does not need to be the same as Self::WIRE_TYPE
228	/// 
229	/// is_field is true if this message is a field of a struct or enum
230	fn merge(
231		&mut self,
232		kind: FieldKind<'m>,
233		is_field: bool
234	) -> Result<(), DecodeError>;
235}
236
237pub trait DecodeMessageOwned: for<'m> DecodeMessage<'m> {}
238
239impl<T> DecodeMessageOwned for T
240where T: for<'m> DecodeMessage<'m> {}
241
242// a vec is represented as repeated ty = 1;
243impl<'m, V> DecodeMessage<'m> for Vec<V>
244where V: DecodeMessage<'m> {
245	const WIRE_TYPE: WireType = WireType::Len;
246
247	fn decode_default() -> Self {
248		Self::new()
249	}
250
251	fn merge(
252		&mut self,
253		kind: FieldKind<'m>,
254		is_field: bool
255	) -> Result<(), DecodeError> {
256		// if this is not a field
257		// we need to create a struct / message
258		// which contains one field which is repeatable
259		if !is_field {
260			let mut parser = MessageDecoder::try_from_kind(kind)?;
261
262			while let Some(field) = parser.next()? {
263				if field.number != 1 {
264					continue
265				}
266
267				// were now in a field of our virtual message/struct
268				self.merge(field.kind, true)?;
269			}
270
271			return parser.finish();
272		}
273
274		// the data could be packet
275		if kind.is_len() && V::WIRE_TYPE.can_be_packed() {
276			let mut parser = MessageDecoder::try_from_kind(kind)?;
277			while let Some(k) = parser.maybe_next_kind(V::WIRE_TYPE)? {
278				let mut v = V::decode_default();
279				v.merge(k, false)?;
280
281				self.push(v);
282			}
283
284			return parser.finish()
285		}
286
287
288		let mut v = V::decode_default();
289		v.merge(kind, false)?;
290
291		self.push(v);
292
293		Ok(())
294	}
295}
296
297impl<'m, K, V> DecodeMessage<'m> for HashMap<K, V>
298where
299	K: DecodeMessage<'m> + Eq + Hash,
300	V: DecodeMessage<'m>
301{
302	const WIRE_TYPE: WireType = WireType::Len;
303
304	fn decode_default() -> Self {
305		Self::new()
306	}
307
308	fn merge(
309		&mut self,
310		kind: FieldKind<'m>,
311		is_field: bool
312	) -> Result<(), DecodeError> {
313		// if this is not a field
314		// we need to create a struct / message
315		// which contains one field which is repeatable
316		if !is_field {
317			let mut parser = MessageDecoder::try_from_kind(kind)?;
318
319			while let Some(field) = parser.next()? {
320				if field.number != 1 {
321					continue
322				}
323
324				// were now in a field of our virtual message/struct
325				self.merge(field.kind, true)?;
326			}
327
328			return parser.finish();
329		}
330
331		let mut field = <(K, V)>::decode_default();
332		field.merge(kind, false)?;
333
334		self.insert(field.0, field.1);
335
336		Ok(())
337	}
338}
339
340impl<'m> DecodeMessage<'m> for Vec<u8> {
341	const WIRE_TYPE: WireType = WireType::Len;
342
343	fn decode_default() -> Self {
344		Self::new()
345	}
346
347	fn merge(
348		&mut self,
349		kind: FieldKind<'m>,
350		_is_field: bool
351	) -> Result<(), DecodeError> {
352		let bytes = kind.try_unwrap_len()?;
353		self.clear();
354		self.extend_from_slice(bytes);
355
356		Ok(())
357	}
358}
359
360impl<'m, const S: usize> DecodeMessage<'m> for [u8; S] {
361	const WIRE_TYPE: WireType = WireType::Len;
362
363	fn decode_default() -> Self {
364		[0; S]
365	}
366
367	fn merge(
368		&mut self,
369		kind: FieldKind<'m>,
370		_is_field: bool
371	) -> Result<(), DecodeError> {
372		let bytes = kind.try_unwrap_len()?;
373
374		if bytes.len() != S {
375			return Err(DecodeError::ExpectedArrayLen(S))
376		}
377
378		self.copy_from_slice(bytes);
379
380		Ok(())
381	}
382}
383
384/// a tuple behaves the same way as a struct
385macro_rules! impl_tuple {
386	($($gen:ident, $idx:tt),*) => (
387		impl<'m, $($gen),*> DecodeMessage<'m> for ($($gen),*)
388		where
389			$($gen: DecodeMessage<'m>),*
390		{
391			const WIRE_TYPE: WireType = WireType::Len;
392
393			fn decode_default() -> Self {
394				($(
395					$gen::decode_default()
396				),*)
397			}
398
399			fn merge(
400				&mut self,
401				kind: FieldKind<'m>,
402				_is_field: bool
403			) -> Result<(), DecodeError> {
404				let mut parser = MessageDecoder::try_from_kind(kind)?;
405
406				while let Some(field) = parser.next()? {
407					match field.number {
408						$(
409							$idx => self.$idx.merge(field.kind, true)?
410						),*,
411						// ignore unknown fields
412						_ => {}
413					}
414				}
415
416				parser.finish()
417			}
418		}
419	)
420}
421
422// impl_tuple![
423// 	A, 0
424// ];
425impl_tuple![
426	A, 0,
427	B, 1
428];
429impl_tuple![
430	A, 0,
431	B, 1,
432	C, 2
433];
434impl_tuple![
435	A, 0,
436	B, 1,
437	C, 2,
438	D, 3
439];
440impl_tuple![
441	A, 0,
442	B, 1,
443	C, 2,
444	D, 3,
445	E, 4
446];
447impl_tuple![
448	A, 0,
449	B, 1,
450	C, 2,
451	D, 3,
452	E, 4,
453	F, 5
454];
455
456impl<'m> DecodeMessage<'m> for String {
457	const WIRE_TYPE: WireType = WireType::Len;
458
459	fn decode_default() -> Self {
460		Self::new()
461	}
462
463	fn merge(
464		&mut self,
465		kind: FieldKind<'m>,
466		_is_field: bool
467	) -> Result<(), DecodeError> {
468		let bytes = kind.try_unwrap_len()?;
469		self.clear();
470		let s = std::str::from_utf8(bytes)
471			.map_err(|_| DecodeError::ExpectedUtf8)?;
472		self.push_str(s);
473
474		Ok(())
475	}
476}
477
478impl<'m, V> DecodeMessage<'m> for Option<V>
479where V: DecodeMessage<'m> {
480	const WIRE_TYPE: WireType = WireType::Len;
481
482	fn decode_default() -> Self {
483		None
484	}
485
486	fn merge(
487		&mut self,
488		kind: FieldKind<'m>,
489		is_field: bool
490	) -> Result<(), DecodeError> {
491		// if this is not a field
492		// we need to create a struct / message
493		// which contains one field which represent V
494		if !is_field {
495			let mut parser = MessageDecoder::try_from_kind(kind)?;
496
497			while let Some(field) = parser.next()? {
498				if field.number != 1 {
499					continue
500				}
501
502				// were now in a field of our virtual message/struct
503				self.merge(field.kind, true)?;
504			}
505
506			return parser.finish();
507		}
508
509		match self {
510			Some(v) => {
511				v.merge(kind, false)?;
512			}
513			None => {
514				let mut v = V::decode_default();
515				v.merge(kind, false)?;
516				*self = Some(v);
517			}
518		}
519
520		Ok(())
521	}
522}
523
524impl<'m> DecodeMessage<'m> for bool {
525	const WIRE_TYPE: WireType = WireType::Varint;
526
527	fn decode_default() -> Self {
528		false
529	}
530
531	fn merge(
532		&mut self,
533		kind: FieldKind<'m>,
534		_is_field: bool
535	) -> Result<(), DecodeError> {
536		let num = kind.try_unwrap_varint()?;
537		*self = num != 0;
538
539		Ok(())
540	}
541}
542
543// impl basic varint
544macro_rules! impl_varint {
545	($($ty:ty),*) => ($(
546		impl<'m> DecodeMessage<'m> for $ty {
547			const WIRE_TYPE: WireType = WireType::Varint;
548
549			fn decode_default() -> Self {
550				Default::default()
551			}
552
553			fn merge(
554				&mut self,
555				kind: FieldKind<'m>,
556				_is_field: bool
557			) -> Result<(), DecodeError> {
558				let num = kind.try_unwrap_varint()?;
559				*self = num as $ty;
560
561				Ok(())
562			}
563		}
564	)*)
565}
566
567impl_varint![i32, i64, u32, u64];
568
569macro_rules! impl_floats {
570	($($src:ident, $wtype:ident as $ty:ty),*) => ($(
571		impl<'m> DecodeMessage<'m> for $ty {
572			const WIRE_TYPE: WireType = WireType::$wtype;
573
574			fn decode_default() -> Self {
575				Default::default()
576			}
577
578			fn merge(
579				&mut self,
580				kind: FieldKind<'m>,
581				_is_field: bool
582			) -> Result<(), DecodeError> {
583				let num = kind.$src()?;
584				*self = num as $ty;
585
586				Ok(())
587			}
588		}
589	)*)
590}
591
592impl_floats![
593	try_unwrap_i32, I32 as f32,
594	try_unwrap_i64, I64 as f64
595];
596
597#[repr(transparent)]
598#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
599pub struct ZigZag<T>(pub T);
600
601macro_rules! impl_zigzag {
602	($($ty:ty),*) => ($(
603		impl<'m> DecodeMessage<'m> for ZigZag<$ty> {
604			const WIRE_TYPE: WireType = WireType::Varint;
605
606			fn decode_default() -> Self {
607				Default::default()
608			}
609
610			fn merge(
611				&mut self,
612				kind: FieldKind<'m>,
613				_is_field: bool
614			) -> Result<(), DecodeError> {
615				let num = kind.try_unwrap_varint()? as $ty;
616				let num = (num >> 1) ^ -(num & 1);
617				*self = ZigZag(num);
618
619				Ok(())
620			}
621		}
622	)*)
623}
624
625impl_zigzag![i32, i64];
626
627#[repr(transparent)]
628#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
629pub struct Fixed<T>(pub T);
630
631macro_rules! impl_fixed {
632	($($src:ident, $wtype:ident as $ty:ty),*) => ($(
633		impl<'m> DecodeMessage<'m> for Fixed<$ty> {
634			const WIRE_TYPE: WireType = WireType::$wtype;
635
636			fn decode_default() -> Self {
637				Default::default()
638			}
639
640			fn merge(
641				&mut self,
642				kind: FieldKind<'m>,
643				_is_field: bool
644			) -> Result<(), DecodeError> {
645				let num = kind.$src()?;
646				*self = Fixed(num as $ty);
647
648				Ok(())
649			}
650		}
651	)*)
652}
653
654impl_fixed![
655	try_unwrap_i32, I32 as u32, try_unwrap_i32, I32 as i32,
656	try_unwrap_i64, I64 as u64, try_unwrap_i64, I64 as i64
657];
658
659
#[cfg(test)]
mod tests {
	use super::*;

	use hex_literal::hex;

	#[test]
	fn string_and_repeated_test_4() {
		const MSG: &[u8] = &hex!("220568656c6c6f280128022803");

		let mut parser = MessageDecoder::new(MSG);

		// field 4: the string "hello"
		assert_eq!(
			parser.next().unwrap().unwrap(),
			Field { number: 4, kind: FieldKind::Len(b"hello") }
		);

		// field 5: the unpacked repeated values 1, 2 and 3
		for value in 1..=3 {
			assert_eq!(
				parser.next().unwrap().unwrap(),
				Field { number: 5, kind: FieldKind::Varint(value) }
			);
		}

		assert!(parser.next().unwrap().is_none());
	}

	#[test]
	fn repeated_packet() {
		const MSG: &[u8] = &hex!("3206038e029ea705");

		let mut outer = MessageDecoder::new(MSG);

		let field = outer.next().unwrap().unwrap();
		assert_eq!(field.number, 6);
		let packed = field.kind.try_unwrap_len().unwrap();

		let mut inner = MessageDecoder::new(packed);
		for expected in [3, 270, 86942] {
			assert_eq!(inner.next_varint().unwrap(), expected);
		}
	}

	#[test]
	fn empty_bytes() {
		const MSG: &[u8] = &[10, 0];

		let mut parser = MessageDecoder::new(MSG);

		let field = parser.next().unwrap().unwrap();
		assert_eq!(field.number, 1);
		assert_eq!(field.kind, FieldKind::Len(&[]));
		assert!(parser.next().unwrap().is_none());
	}

	// TODO: add coverage for oneof-style messages, e.g.
	// `message Target { oneof target { Unknown unknown = 1; Unit unit = 2; ... } }`.
}