1use byteorder::{BigEndian, ReadBytesExt};
13use std::{
14 convert::TryInto,
15 error::Error,
16 io::{Cursor, Read},
17 mem::size_of,
18 num::TryFromIntError,
19};
20
21#[derive(Debug, thiserror::Error)]
23#[non_exhaustive]
24pub enum CodecError {
25 #[error("I/O error")]
27 Io(#[from] std::io::Error),
28
29 #[error("{0} bytes left in buffer after decoding value")]
31 BytesLeftOver(usize),
32
33 #[error("length prefix of encoded vector overflows buffer: {0}")]
35 LengthPrefixTooBig(usize),
36
37 #[error("vector length exceeded range of length prefix")]
39 LengthPrefixOverflow,
40
41 #[error("other error: {0}")]
43 Other(#[source] Box<dyn Error + 'static + Send + Sync>),
44
45 #[error("unexpected value")]
47 UnexpectedValue,
48}
49
50pub trait Decode: Sized {
52 fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError>;
56
57 fn get_decoded(bytes: &[u8]) -> Result<Self, CodecError> {
60 Self::get_decoded_with_param(&(), bytes)
61 }
62}
63
64pub trait ParameterizedDecode<P>: Sized {
67 fn decode_with_param(
72 decoding_parameter: &P,
73 bytes: &mut Cursor<&[u8]>,
74 ) -> Result<Self, CodecError>;
75
76 fn get_decoded_with_param(decoding_parameter: &P, bytes: &[u8]) -> Result<Self, CodecError> {
79 let mut cursor = Cursor::new(bytes);
80 let decoded = Self::decode_with_param(decoding_parameter, &mut cursor)?;
81 if cursor.position() as usize != bytes.len() {
82 return Err(CodecError::BytesLeftOver(
83 bytes.len() - cursor.position() as usize,
84 ));
85 }
86
87 Ok(decoded)
88 }
89}
90
91impl<D: Decode, T> ParameterizedDecode<T> for D {
94 fn decode_with_param(
95 _decoding_parameter: &T,
96 bytes: &mut Cursor<&[u8]>,
97 ) -> Result<Self, CodecError> {
98 Self::decode(bytes)
99 }
100}
101
102pub trait Encode {
104 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError>;
106
107 fn get_encoded(&self) -> Result<Vec<u8>, CodecError> {
109 self.get_encoded_with_param(&())
110 }
111
112 fn encoded_len(&self) -> Option<usize> {
115 None
116 }
117}
118
119pub trait ParameterizedEncode<P> {
121 fn encode_with_param(
125 &self,
126 encoding_parameter: &P,
127 bytes: &mut Vec<u8>,
128 ) -> Result<(), CodecError>;
129
130 fn get_encoded_with_param(&self, encoding_parameter: &P) -> Result<Vec<u8>, CodecError> {
132 let mut ret = if let Some(length) = self.encoded_len_with_param(encoding_parameter) {
133 Vec::with_capacity(length)
134 } else {
135 Vec::new()
136 };
137 self.encode_with_param(encoding_parameter, &mut ret)?;
138 Ok(ret)
139 }
140
141 fn encoded_len_with_param(&self, _encoding_parameter: &P) -> Option<usize> {
144 None
145 }
146}
147
148impl<E: Encode + ?Sized, T> ParameterizedEncode<T> for E {
151 fn encode_with_param(
152 &self,
153 _encoding_parameter: &T,
154 bytes: &mut Vec<u8>,
155 ) -> Result<(), CodecError> {
156 self.encode(bytes)
157 }
158
159 fn encoded_len_with_param(&self, _encoding_parameter: &T) -> Option<usize> {
160 <Self as Encode>::encoded_len(self)
161 }
162}
163
164impl Decode for () {
165 fn decode(_bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
166 Ok(())
167 }
168}
169
170impl Encode for () {
171 fn encode(&self, _bytes: &mut Vec<u8>) -> Result<(), CodecError> {
172 Ok(())
173 }
174
175 fn encoded_len(&self) -> Option<usize> {
176 Some(0)
177 }
178}
179
180impl Decode for u8 {
181 fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
182 let mut value = [0u8; size_of::<u8>()];
183 bytes.read_exact(&mut value)?;
184 Ok(value[0])
185 }
186}
187
188impl Encode for u8 {
189 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
190 bytes.push(*self);
191 Ok(())
192 }
193
194 fn encoded_len(&self) -> Option<usize> {
195 Some(1)
196 }
197}
198
199impl Decode for u16 {
200 fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
201 Ok(bytes.read_u16::<BigEndian>()?)
202 }
203}
204
205impl Encode for u16 {
206 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
207 bytes.extend_from_slice(&u16::to_be_bytes(*self));
208 Ok(())
209 }
210
211 fn encoded_len(&self) -> Option<usize> {
212 Some(2)
213 }
214}
215
/// A 24-bit unsigned integer carried in a `u32`.
///
/// Nothing in the type itself restricts the wrapped value to 24 bits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct U24(pub u32);
220
221impl Decode for U24 {
222 fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
223 Ok(U24(bytes.read_u24::<BigEndian>()?))
224 }
225}
226
227impl Encode for U24 {
228 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
229 bytes.extend_from_slice(&u32::to_be_bytes(self.0)[1..]);
231 Ok(())
232 }
233
234 fn encoded_len(&self) -> Option<usize> {
235 Some(3)
236 }
237}
238
239impl Decode for u32 {
240 fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
241 Ok(bytes.read_u32::<BigEndian>()?)
242 }
243}
244
245impl Encode for u32 {
246 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
247 bytes.extend_from_slice(&u32::to_be_bytes(*self));
248 Ok(())
249 }
250
251 fn encoded_len(&self) -> Option<usize> {
252 Some(4)
253 }
254}
255
256impl Decode for u64 {
257 fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
258 Ok(bytes.read_u64::<BigEndian>()?)
259 }
260}
261
262impl Encode for u64 {
263 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
264 bytes.extend_from_slice(&u64::to_be_bytes(*self));
265 Ok(())
266 }
267
268 fn encoded_len(&self) -> Option<usize> {
269 Some(8)
270 }
271}
272
273pub fn encode_u8_items<P, E: ParameterizedEncode<P>>(
277 bytes: &mut Vec<u8>,
278 encoding_parameter: &P,
279 items: &[E],
280) -> Result<(), CodecError> {
281 let len_offset = bytes.len();
283 bytes.push(0);
284
285 for item in items {
286 item.encode_with_param(encoding_parameter, bytes)?;
287 }
288
289 let len =
290 u8::try_from(bytes.len() - len_offset - 1).map_err(|_| CodecError::LengthPrefixOverflow)?;
291 bytes[len_offset] = len;
292 Ok(())
293}
294
295pub fn decode_u8_items<P, D: ParameterizedDecode<P>>(
300 decoding_parameter: &P,
301 bytes: &mut Cursor<&[u8]>,
302) -> Result<Vec<D>, CodecError> {
303 let length = usize::from(u8::decode(bytes)?);
305
306 decode_items(length, decoding_parameter, bytes)
307}
308
309pub fn encode_u16_items<P, E: ParameterizedEncode<P>>(
313 bytes: &mut Vec<u8>,
314 encoding_parameter: &P,
315 items: &[E],
316) -> Result<(), CodecError> {
317 let len_offset = bytes.len();
319 0u16.encode(bytes)?;
320
321 for item in items {
322 item.encode_with_param(encoding_parameter, bytes)?;
323 }
324
325 let len = u16::try_from(bytes.len() - len_offset - 2)
326 .map_err(|_| CodecError::LengthPrefixOverflow)?;
327 bytes[len_offset..len_offset + 2].copy_from_slice(&len.to_be_bytes());
328 Ok(())
329}
330
331pub fn decode_u16_items<P, D: ParameterizedDecode<P>>(
336 decoding_parameter: &P,
337 bytes: &mut Cursor<&[u8]>,
338) -> Result<Vec<D>, CodecError> {
339 let length = usize::from(u16::decode(bytes)?);
341
342 decode_items(length, decoding_parameter, bytes)
343}
344
345pub fn encode_u24_items<P, E: ParameterizedEncode<P>>(
350 bytes: &mut Vec<u8>,
351 encoding_parameter: &P,
352 items: &[E],
353) -> Result<(), CodecError> {
354 let len_offset = bytes.len();
356 U24(0).encode(bytes)?;
357
358 for item in items {
359 item.encode_with_param(encoding_parameter, bytes)?;
360 }
361
362 let len = u32::try_from(bytes.len() - len_offset - 3)
363 .map_err(|_| CodecError::LengthPrefixOverflow)?;
364 if len > 0xffffff {
365 return Err(CodecError::LengthPrefixOverflow);
366 }
367 bytes[len_offset..len_offset + 3].copy_from_slice(&len.to_be_bytes()[1..]);
368 Ok(())
369}
370
371pub fn decode_u24_items<P, D: ParameterizedDecode<P>>(
376 decoding_parameter: &P,
377 bytes: &mut Cursor<&[u8]>,
378) -> Result<Vec<D>, CodecError> {
379 let length = U24::decode(bytes)?.0 as usize;
381
382 decode_items(length, decoding_parameter, bytes)
383}
384
385pub fn encode_u32_items<P, E: ParameterizedEncode<P>>(
390 bytes: &mut Vec<u8>,
391 encoding_parameter: &P,
392 items: &[E],
393) -> Result<(), CodecError> {
394 let len_offset = bytes.len();
396 0u32.encode(bytes)?;
397
398 for item in items {
399 item.encode_with_param(encoding_parameter, bytes)?;
400 }
401
402 let len = u32::try_from(bytes.len() - len_offset - 4)
403 .map_err(|_| CodecError::LengthPrefixOverflow)?;
404 bytes[len_offset..len_offset + 4].copy_from_slice(&len.to_be_bytes());
405 Ok(())
406}
407
408pub fn decode_u32_items<P, D: ParameterizedDecode<P>>(
413 decoding_parameter: &P,
414 bytes: &mut Cursor<&[u8]>,
415) -> Result<Vec<D>, CodecError> {
416 let len: usize = u32::decode(bytes)?
418 .try_into()
419 .map_err(|err: TryFromIntError| CodecError::Other(err.into()))?;
420
421 decode_items(len, decoding_parameter, bytes)
422}
423
424fn decode_items<P, D: ParameterizedDecode<P>>(
426 length: usize,
427 decoding_parameter: &P,
428 bytes: &mut Cursor<&[u8]>,
429) -> Result<Vec<D>, CodecError> {
430 let mut decoded = Vec::new();
431 let initial_position = bytes.position() as usize;
432
433 let inner = bytes.get_ref();
435
436 let (items_end, overflowed) = initial_position.overflowing_add(length);
438 if overflowed || items_end > inner.len() {
439 return Err(CodecError::LengthPrefixTooBig(length));
440 }
441
442 let mut sub = Cursor::new(&bytes.get_ref()[initial_position..items_end]);
443
444 while sub.position() < length as u64 {
445 decoded.push(D::decode_with_param(decoding_parameter, &mut sub)?);
446 }
447
448 bytes.set_position(initial_position as u64 + sub.position());
450
451 Ok(decoded)
452}
453
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use super::*;
    use assert_matches::assert_matches;

    /// Encoding the unit type must not write any bytes.
    #[test]
    fn encode_nothing() {
        let mut encoded = Vec::new();
        ().encode(&mut encoded).unwrap();
        assert!(encoded.is_empty());
    }

    #[test]
    fn roundtrip_u8() {
        let original = 100u8;

        let mut encoded = Vec::new();
        original.encode(&mut encoded).unwrap();
        assert_eq!(encoded.len(), 1);

        let mut reader = Cursor::new(encoded.as_slice());
        assert_eq!(u8::decode(&mut reader).unwrap(), original);
    }

    #[test]
    fn roundtrip_u16() {
        let original = 1000u16;

        let mut encoded = Vec::new();
        original.encode(&mut encoded).unwrap();
        assert_eq!(encoded.len(), 2);
        // Check endianness of the encoding.
        assert_eq!(encoded, [3, 232]);

        let mut reader = Cursor::new(encoded.as_slice());
        assert_eq!(u16::decode(&mut reader).unwrap(), original);
    }

    #[test]
    fn roundtrip_u24() {
        let original = U24(1_000_000u32);

        let mut encoded = Vec::new();
        original.encode(&mut encoded).unwrap();
        assert_eq!(encoded.len(), 3);
        // Check endianness of the encoding.
        assert_eq!(encoded, [15, 66, 64]);

        let mut reader = Cursor::new(encoded.as_slice());
        assert_eq!(U24::decode(&mut reader).unwrap(), original);
    }

    #[test]
    fn roundtrip_u32() {
        let original = 134_217_728u32;

        let mut encoded = Vec::new();
        original.encode(&mut encoded).unwrap();
        assert_eq!(encoded.len(), 4);
        // Check endianness of the encoding.
        assert_eq!(encoded, [8, 0, 0, 0]);

        let mut reader = Cursor::new(encoded.as_slice());
        assert_eq!(u32::decode(&mut reader).unwrap(), original);
    }

    #[test]
    fn roundtrip_u64() {
        let original = 137_438_953_472u64;

        let mut encoded = Vec::new();
        original.encode(&mut encoded).unwrap();
        assert_eq!(encoded.len(), 8);
        // Check endianness of the encoding.
        assert_eq!(encoded, [0, 0, 0, 32, 0, 0, 0, 0]);

        let mut reader = Cursor::new(encoded.as_slice());
        assert_eq!(u64::decode(&mut reader).unwrap(), original);
    }

    /// A message made of one field of every primitive width the codec supports.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage {
        field_u8: u8,
        field_u16: u16,
        field_u24: U24,
        field_u32: u32,
        field_u64: u64,
    }

    impl Encode for TestMessage {
        fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
            self.field_u8.encode(bytes)?;
            self.field_u16.encode(bytes)?;
            self.field_u24.encode(bytes)?;
            self.field_u32.encode(bytes)?;
            self.field_u64.encode(bytes)?;
            Ok(())
        }

        fn encoded_len(&self) -> Option<usize> {
            // Summing `Option`s yields `None` as soon as any field has no length hint.
            [
                self.field_u8.encoded_len(),
                self.field_u16.encoded_len(),
                self.field_u24.encoded_len(),
                self.field_u32.encoded_len(),
                self.field_u64.encoded_len(),
            ]
            .iter()
            .copied()
            .sum()
        }
    }

    impl Decode for TestMessage {
        fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
            // Struct fields are evaluated in the order written, matching `encode`.
            Ok(TestMessage {
                field_u8: u8::decode(bytes)?,
                field_u16: u16::decode(bytes)?,
                field_u24: U24::decode(bytes)?,
                field_u32: u32::decode(bytes)?,
                field_u64: u64::decode(bytes)?,
            })
        }
    }

    impl TestMessage {
        /// Expected size in bytes of one encoded message.
        fn encoded_length() -> usize {
            // u8 + u16 + u24 + u32 + u64
            1 + 2 + 3 + 4 + 8
        }
    }

    /// The fixture message shared by the tests below.
    fn sample_message() -> TestMessage {
        TestMessage {
            field_u8: 0,
            field_u16: 300,
            field_u24: U24(1_000_000),
            field_u32: 134_217_728,
            field_u64: 137_438_953_472,
        }
    }

    #[test]
    fn roundtrip_message() {
        let original = sample_message();

        let mut encoded = Vec::new();
        original.encode(&mut encoded).unwrap();
        assert_eq!(encoded.len(), TestMessage::encoded_length());
        assert_eq!(
            original.encoded_len().unwrap(),
            TestMessage::encoded_length()
        );

        let mut reader = Cursor::new(encoded.as_slice());
        assert_eq!(TestMessage::decode(&mut reader).unwrap(), original);
    }

    /// Three identical fixture messages.
    fn messages_vec() -> Vec<TestMessage> {
        (0..3).map(|_| sample_message()).collect()
    }

    #[test]
    fn roundtrip_variable_length_u8() {
        let messages = messages_vec();
        let mut encoded = Vec::new();
        encode_u8_items(&mut encoded, &(), &messages).unwrap();

        // One length byte followed by three messages.
        assert_eq!(encoded.len(), 1 + 3 * TestMessage::encoded_length());

        let mut reader = Cursor::new(encoded.as_slice());
        let decoded: Vec<TestMessage> = decode_u8_items(&(), &mut reader).unwrap();
        assert_eq!(messages, decoded);
    }

    #[test]
    fn roundtrip_variable_length_u16() {
        let messages = messages_vec();
        let mut encoded = Vec::new();
        encode_u16_items(&mut encoded, &(), &messages).unwrap();

        // Two length bytes followed by three messages.
        let payload_len = 3 * TestMessage::encoded_length();
        assert_eq!(encoded.len(), 2 + payload_len);

        // Check endianness of the encoded length.
        assert_eq!(encoded[..2], [0, payload_len as u8]);

        let mut reader = Cursor::new(encoded.as_slice());
        let decoded: Vec<TestMessage> = decode_u16_items(&(), &mut reader).unwrap();
        assert_eq!(messages, decoded);
    }

    #[test]
    fn roundtrip_variable_length_u24() {
        let messages = messages_vec();
        let mut encoded = Vec::new();
        encode_u24_items(&mut encoded, &(), &messages).unwrap();

        // Three length bytes followed by three messages.
        let payload_len = 3 * TestMessage::encoded_length();
        assert_eq!(encoded.len(), 3 + payload_len);

        // Check endianness of the encoded length.
        assert_eq!(encoded[..3], [0, 0, payload_len as u8]);

        let mut reader = Cursor::new(encoded.as_slice());
        let decoded: Vec<TestMessage> = decode_u24_items(&(), &mut reader).unwrap();
        assert_eq!(messages, decoded);
    }

    #[test]
    fn roundtrip_variable_length_u32() {
        let messages = messages_vec();
        let mut encoded = Vec::new();
        encode_u32_items(&mut encoded, &(), &messages).unwrap();

        // Four length bytes followed by three messages.
        let payload_len = 3 * TestMessage::encoded_length();
        assert_eq!(encoded.len(), 4 + payload_len);

        // Check endianness of the encoded length.
        assert_eq!(encoded[..4], [0, 0, 0, payload_len as u8]);

        let mut reader = Cursor::new(encoded.as_slice());
        let decoded: Vec<TestMessage> = decode_u32_items(&(), &mut reader).unwrap();
        assert_eq!(messages, decoded);
    }

    #[test]
    fn decode_too_short() {
        let messages = messages_vec();
        let mut encoded = Vec::new();
        encode_u32_items(&mut encoded, &(), &messages).unwrap();

        // Input ends in the middle of the length prefix.
        let mut truncated_prefix = Cursor::new(&encoded[..3]);
        let error = decode_u32_items::<_, TestMessage>(&(), &mut truncated_prefix).unwrap_err();
        assert_matches!(error, CodecError::Io(e) => assert_eq!(e.kind(), ErrorKind::UnexpectedEof));

        // Input holds the whole prefix but none of the payload it announces.
        let mut prefix_only = Cursor::new(&encoded[..4]);
        let error = decode_u32_items::<_, TestMessage>(&(), &mut prefix_only).unwrap_err();
        assert_matches!(error, CodecError::LengthPrefixTooBig(_));
    }

    #[test]
    fn decode_items_overflow() {
        let buffer = vec![1u8];

        let mut reader = Cursor::new(buffer.as_slice());
        reader.set_position(1);

        // Position plus length overflows `usize`.
        let error = decode_items::<(), u8>(usize::MAX, &(), &mut reader).unwrap_err();
        assert_matches!(error, CodecError::LengthPrefixTooBig(usize::MAX));
    }

    #[test]
    fn decode_items_too_big() {
        let buffer = vec![1u8];

        let mut reader = Cursor::new(buffer.as_slice());
        reader.set_position(1);

        // Position plus length runs past the end of the buffer.
        let error = decode_items::<(), u8>(2, &(), &mut reader).unwrap_err();
        assert_matches!(error, CodecError::LengthPrefixTooBig(2));
    }

    /// Asserts that the length hint of `value` equals the size of its actual encoding.
    fn assert_hint_matches_encoding<E: Encode>(value: E) {
        assert_eq!(
            value.encoded_len().unwrap(),
            value.get_encoded().unwrap().len()
        );
    }

    #[test]
    fn length_hint_correctness() {
        assert_hint_matches_encoding(());
        assert_hint_matches_encoding(0u8);
        assert_hint_matches_encoding(0u16);
        assert_hint_matches_encoding(U24(0));
        assert_hint_matches_encoding(0u32);
        assert_hint_matches_encoding(0u64);
    }

    #[test]
    fn get_decoded_leftover() {
        let exact: [u8; 4] = [1, 2, 3, 4];
        assert_eq!(u32::get_decoded(&exact).unwrap(), 0x0102_0304u32);

        let one_too_many: [u8; 5] = [1, 2, 3, 4, 5];
        let error = u32::get_decoded(&one_too_many).unwrap_err();
        assert_matches!(error, CodecError::BytesLeftOver(1));
    }

    #[test]
    fn encoded_len_backwards_compatibility() {
        // An implementation that only provides `encode` must still work through the
        // default `encoded_len` and `get_encoded`.
        struct MyMessage;

        impl Encode for MyMessage {
            fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
                bytes.extend_from_slice(b"Hello, world");
                Ok(())
            }
        }

        assert_eq!(MyMessage.encoded_len(), None);

        assert_eq!(MyMessage.get_encoded().unwrap(), b"Hello, world");
    }

    #[test]
    fn encode_length_prefix_overflow() {
        // One byte more than a single-byte length prefix can describe.
        let items = [1u8; u8::MAX as usize + 1];
        let mut encoded = Vec::new();
        let error = encode_u8_items(&mut encoded, &(), &items).unwrap_err();
        assert_matches!(error, CodecError::LengthPrefixOverflow);
    }
}