1use byteorder::{BigEndian, ReadBytesExt};
13use std::{
14 convert::TryInto,
15 error::Error,
16 io::{Cursor, Read},
17 mem::size_of,
18 num::TryFromIntError,
19};
20
/// Errors produced while encoding or decoding values with the traits in this module.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum CodecError {
    /// An I/O error from the underlying reader, most commonly `UnexpectedEof` when the input
    /// ends in the middle of a value.
    #[error("I/O error")]
    Io(#[from] std::io::Error),

    /// A value was decoded successfully, but this many bytes of input were left unread.
    #[error("{0} bytes left in buffer after decoding value")]
    BytesLeftOver(usize),

    /// A length (from a length prefix or supplied by the caller) claims more bytes than remain
    /// in the buffer. Carries the offending length.
    #[error("length prefix of encoded vector overflows buffer: {0}")]
    LengthPrefixTooBig(usize),

    /// The encoded items are too long to be described by the width of the length prefix
    /// being written.
    #[error("vector length exceeded range of length prefix")]
    LengthPrefixOverflow,

    /// Any other error, boxed so that arbitrary error types can be carried.
    #[error("other error: {0}")]
    Other(#[source] Box<dyn Error + 'static + Send + Sync>),

    /// A value was read or supplied that the codec does not accept.
    // NOTE(review): the visible code never constructs this variant; presumably it is used by
    // decoders of constrained types elsewhere in the crate — confirm.
    #[error("unexpected value")]
    UnexpectedValue,
}
49
50pub trait Decode: Sized {
52 fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError>;
56
57 fn get_decoded(bytes: &[u8]) -> Result<Self, CodecError> {
60 Self::get_decoded_with_param(&(), bytes)
61 }
62}
63
64pub trait ParameterizedDecode<P>: Sized {
67 fn decode_with_param(
72 decoding_parameter: &P,
73 bytes: &mut Cursor<&[u8]>,
74 ) -> Result<Self, CodecError>;
75
76 fn get_decoded_with_param(decoding_parameter: &P, bytes: &[u8]) -> Result<Self, CodecError> {
79 let mut cursor = Cursor::new(bytes);
80 let decoded = Self::decode_with_param(decoding_parameter, &mut cursor)?;
81 if cursor.position() as usize != bytes.len() {
82 return Err(CodecError::BytesLeftOver(
83 bytes.len() - cursor.position() as usize,
84 ));
85 }
86
87 Ok(decoded)
88 }
89}
90
91impl<D: Decode, T> ParameterizedDecode<T> for D {
94 fn decode_with_param(
95 _decoding_parameter: &T,
96 bytes: &mut Cursor<&[u8]>,
97 ) -> Result<Self, CodecError> {
98 Self::decode(bytes)
99 }
100}
101
102pub trait Encode {
104 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError>;
106
107 fn get_encoded(&self) -> Result<Vec<u8>, CodecError> {
109 self.get_encoded_with_param(&())
110 }
111
112 fn encoded_len(&self) -> Option<usize> {
115 None
116 }
117}
118
119pub trait ParameterizedEncode<P> {
121 fn encode_with_param(
125 &self,
126 encoding_parameter: &P,
127 bytes: &mut Vec<u8>,
128 ) -> Result<(), CodecError>;
129
130 fn get_encoded_with_param(&self, encoding_parameter: &P) -> Result<Vec<u8>, CodecError> {
132 let mut ret = if let Some(length) = self.encoded_len_with_param(encoding_parameter) {
133 Vec::with_capacity(length)
134 } else {
135 Vec::new()
136 };
137 self.encode_with_param(encoding_parameter, &mut ret)?;
138 Ok(ret)
139 }
140
141 fn encoded_len_with_param(&self, _encoding_parameter: &P) -> Option<usize> {
144 None
145 }
146}
147
148impl<E: Encode + ?Sized, T> ParameterizedEncode<T> for E {
151 fn encode_with_param(
152 &self,
153 _encoding_parameter: &T,
154 bytes: &mut Vec<u8>,
155 ) -> Result<(), CodecError> {
156 self.encode(bytes)
157 }
158
159 fn encoded_len_with_param(&self, _encoding_parameter: &T) -> Option<usize> {
160 <Self as Encode>::encoded_len(self)
161 }
162}
163
164impl Decode for () {
165 fn decode(_bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
166 Ok(())
167 }
168}
169
170impl Encode for () {
171 fn encode(&self, _bytes: &mut Vec<u8>) -> Result<(), CodecError> {
172 Ok(())
173 }
174
175 fn encoded_len(&self) -> Option<usize> {
176 Some(0)
177 }
178}
179
180impl Decode for u8 {
181 fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
182 let mut value = [0u8; size_of::<u8>()];
183 bytes.read_exact(&mut value)?;
184 Ok(value[0])
185 }
186}
187
188impl Encode for u8 {
189 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
190 bytes.push(*self);
191 Ok(())
192 }
193
194 fn encoded_len(&self) -> Option<usize> {
195 Some(1)
196 }
197}
198
199impl Decode for u16 {
200 fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
201 Ok(bytes.read_u16::<BigEndian>()?)
202 }
203}
204
205impl Encode for u16 {
206 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
207 bytes.extend_from_slice(&u16::to_be_bytes(*self));
208 Ok(())
209 }
210
211 fn encoded_len(&self) -> Option<usize> {
212 Some(2)
213 }
214}
215
/// A 24-bit unsigned integer held in a `u32`.
///
/// Its `Encode` impl writes the low three bytes of the `u32` in big-endian order, and its
/// `Decode` impl reads three big-endian bytes. In this module it is used for 24-bit length
/// prefixes.
// NOTE(review): nothing prevents constructing a value above 0xFF_FFFF; callers must keep it
// in range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct U24(pub u32);
220
221impl Decode for U24 {
222 fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
223 Ok(U24(bytes.read_u24::<BigEndian>()?))
224 }
225}
226
227impl Encode for U24 {
228 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
229 bytes.extend_from_slice(&u32::to_be_bytes(self.0)[1..]);
231 Ok(())
232 }
233
234 fn encoded_len(&self) -> Option<usize> {
235 Some(3)
236 }
237}
238
239impl Decode for u32 {
240 fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
241 Ok(bytes.read_u32::<BigEndian>()?)
242 }
243}
244
245impl Encode for u32 {
246 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
247 bytes.extend_from_slice(&u32::to_be_bytes(*self));
248 Ok(())
249 }
250
251 fn encoded_len(&self) -> Option<usize> {
252 Some(4)
253 }
254}
255
256impl Decode for u64 {
257 fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
258 Ok(bytes.read_u64::<BigEndian>()?)
259 }
260}
261
262impl Encode for u64 {
263 fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
264 bytes.extend_from_slice(&u64::to_be_bytes(*self));
265 Ok(())
266 }
267
268 fn encoded_len(&self) -> Option<usize> {
269 Some(8)
270 }
271}
272
273pub fn encode_fixlen_items<E: Encode>(bytes: &mut Vec<u8>, items: &[E]) -> Result<(), CodecError> {
277 for item in items {
278 item.encode(bytes)?;
279 }
280
281 Ok(())
282}
283
284pub fn encode_u8_items<P, E: ParameterizedEncode<P>>(
288 bytes: &mut Vec<u8>,
289 encoding_parameter: &P,
290 items: &[E],
291) -> Result<(), CodecError> {
292 let len_offset = bytes.len();
294 bytes.push(0);
295
296 for item in items {
297 item.encode_with_param(encoding_parameter, bytes)?;
298 }
299
300 let len =
301 u8::try_from(bytes.len() - len_offset - 1).map_err(|_| CodecError::LengthPrefixOverflow)?;
302 bytes[len_offset] = len;
303 Ok(())
304}
305
306pub fn decode_u8_items<P, D: ParameterizedDecode<P>>(
311 decoding_parameter: &P,
312 bytes: &mut Cursor<&[u8]>,
313) -> Result<Vec<D>, CodecError> {
314 let length = usize::from(u8::decode(bytes)?);
316
317 decode_fixlen_items(length, decoding_parameter, bytes)
318}
319
320pub fn encode_u16_items<P, E: ParameterizedEncode<P>>(
324 bytes: &mut Vec<u8>,
325 encoding_parameter: &P,
326 items: &[E],
327) -> Result<(), CodecError> {
328 let len_offset = bytes.len();
330 0u16.encode(bytes)?;
331
332 for item in items {
333 item.encode_with_param(encoding_parameter, bytes)?;
334 }
335
336 let len = u16::try_from(bytes.len() - len_offset - 2)
337 .map_err(|_| CodecError::LengthPrefixOverflow)?;
338 bytes[len_offset..len_offset + 2].copy_from_slice(&len.to_be_bytes());
339 Ok(())
340}
341
342pub fn decode_u16_items<P, D: ParameterizedDecode<P>>(
347 decoding_parameter: &P,
348 bytes: &mut Cursor<&[u8]>,
349) -> Result<Vec<D>, CodecError> {
350 let length = usize::from(u16::decode(bytes)?);
352
353 decode_fixlen_items(length, decoding_parameter, bytes)
354}
355
356pub fn encode_u24_items<P, E: ParameterizedEncode<P>>(
361 bytes: &mut Vec<u8>,
362 encoding_parameter: &P,
363 items: &[E],
364) -> Result<(), CodecError> {
365 let len_offset = bytes.len();
367 U24(0).encode(bytes)?;
368
369 for item in items {
370 item.encode_with_param(encoding_parameter, bytes)?;
371 }
372
373 let len = u32::try_from(bytes.len() - len_offset - 3)
374 .map_err(|_| CodecError::LengthPrefixOverflow)?;
375 if len > 0xffffff {
376 return Err(CodecError::LengthPrefixOverflow);
377 }
378 bytes[len_offset..len_offset + 3].copy_from_slice(&len.to_be_bytes()[1..]);
379 Ok(())
380}
381
382pub fn decode_u24_items<P, D: ParameterizedDecode<P>>(
387 decoding_parameter: &P,
388 bytes: &mut Cursor<&[u8]>,
389) -> Result<Vec<D>, CodecError> {
390 let length = U24::decode(bytes)?.0 as usize;
392
393 decode_fixlen_items(length, decoding_parameter, bytes)
394}
395
396pub fn encode_u32_items<P, E: ParameterizedEncode<P>>(
401 bytes: &mut Vec<u8>,
402 encoding_parameter: &P,
403 items: &[E],
404) -> Result<(), CodecError> {
405 let len_offset = bytes.len();
407 0u32.encode(bytes)?;
408
409 for item in items {
410 item.encode_with_param(encoding_parameter, bytes)?;
411 }
412
413 let len = u32::try_from(bytes.len() - len_offset - 4)
414 .map_err(|_| CodecError::LengthPrefixOverflow)?;
415 bytes[len_offset..len_offset + 4].copy_from_slice(&len.to_be_bytes());
416 Ok(())
417}
418
419pub fn decode_u32_items<P, D: ParameterizedDecode<P>>(
424 decoding_parameter: &P,
425 bytes: &mut Cursor<&[u8]>,
426) -> Result<Vec<D>, CodecError> {
427 let len: usize = u32::decode(bytes)?
429 .try_into()
430 .map_err(|err: TryFromIntError| CodecError::Other(err.into()))?;
431
432 decode_fixlen_items(len, decoding_parameter, bytes)
433}
434
435pub fn decode_fixlen_items<P, D: ParameterizedDecode<P>>(
439 length: usize,
440 decoding_parameter: &P,
441 bytes: &mut Cursor<&[u8]>,
442) -> Result<Vec<D>, CodecError> {
443 let mut decoded = Vec::new();
444 let initial_position = bytes.position() as usize;
445
446 let inner = bytes.get_ref();
448
449 let (items_end, overflowed) = initial_position.overflowing_add(length);
451 if overflowed || items_end > inner.len() {
452 return Err(CodecError::LengthPrefixTooBig(length));
453 }
454
455 let mut sub = Cursor::new(&bytes.get_ref()[initial_position..items_end]);
456
457 while sub.position() < length as u64 {
458 decoded.push(D::decode_with_param(decoding_parameter, &mut sub)?);
459 }
460
461 bytes.set_position(initial_position as u64 + sub.position());
463
464 Ok(decoded)
465}
466
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use super::*;
    use assert_matches::assert_matches;

    /// Encoding the unit value writes no bytes.
    #[test]
    fn encode_nothing() {
        let mut bytes = vec![];
        ().encode(&mut bytes).unwrap();
        assert_eq!(bytes.len(), 0);
    }

    #[test]
    fn roundtrip_u8() {
        let value = 100u8;

        let mut bytes = vec![];
        value.encode(&mut bytes).unwrap();
        assert_eq!(bytes.len(), 1);

        let decoded = u8::decode(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(value, decoded);
    }

    #[test]
    fn roundtrip_u16() {
        let value = 1000u16;

        let mut bytes = vec![];
        value.encode(&mut bytes).unwrap();
        assert_eq!(bytes.len(), 2);
        // Big-endian: 1000 == 0x03E8.
        assert_eq!(bytes, vec![3, 232]);

        let decoded = u16::decode(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(value, decoded);
    }

    #[test]
    fn roundtrip_u24() {
        let value = U24(1_000_000u32);

        let mut bytes = vec![];
        value.encode(&mut bytes).unwrap();
        assert_eq!(bytes.len(), 3);
        // Big-endian: 1_000_000 == 0x0F4240.
        assert_eq!(bytes, vec![15, 66, 64]);

        let decoded = U24::decode(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(value, decoded);
    }

    #[test]
    fn roundtrip_u32() {
        let value = 134_217_728u32;

        let mut bytes = vec![];
        value.encode(&mut bytes).unwrap();
        assert_eq!(bytes.len(), 4);
        // Big-endian: 134_217_728 == 2^27 == 0x08000000.
        assert_eq!(bytes, vec![8, 0, 0, 0]);

        let decoded = u32::decode(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(value, decoded);
    }

    #[test]
    fn roundtrip_u64() {
        let value = 137_438_953_472u64;

        let mut bytes = vec![];
        value.encode(&mut bytes).unwrap();
        assert_eq!(bytes.len(), 8);
        // Big-endian: 137_438_953_472 == 2^37 == 0x0000002000000000.
        assert_eq!(bytes, vec![0, 0, 0, 32, 0, 0, 0, 0]);

        let decoded = u64::decode(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(value, decoded);
    }

    /// A fixed-size message made of one field of each integer width, used to exercise the
    /// vector encoders/decoders.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage {
        field_u8: u8,
        field_u16: u16,
        field_u24: U24,
        field_u32: u32,
        field_u64: u64,
    }

    impl Encode for TestMessage {
        fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
            self.field_u8.encode(bytes)?;
            self.field_u16.encode(bytes)?;
            self.field_u24.encode(bytes)?;
            self.field_u32.encode(bytes)?;
            self.field_u64.encode(bytes)
        }

        fn encoded_len(&self) -> Option<usize> {
            Some(
                self.field_u8.encoded_len()?
                    + self.field_u16.encoded_len()?
                    + self.field_u24.encoded_len()?
                    + self.field_u32.encoded_len()?
                    + self.field_u64.encoded_len()?,
            )
        }
    }

    impl Decode for TestMessage {
        fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
            let field_u8 = u8::decode(bytes)?;
            let field_u16 = u16::decode(bytes)?;
            let field_u24 = U24::decode(bytes)?;
            let field_u32 = u32::decode(bytes)?;
            let field_u64 = u64::decode(bytes)?;

            Ok(TestMessage {
                field_u8,
                field_u16,
                field_u24,
                field_u32,
                field_u64,
            })
        }
    }

    impl TestMessage {
        /// Expected encoded size of one `TestMessage`, written out independently of
        /// `Encode::encoded_len` so the two can be checked against each other.
        fn encoded_length() -> usize {
            // u8
            1 +
            // u16
            2 +
            // u24
            3 +
            // u32
            4 +
            // u64
            8
        }
    }

    #[test]
    fn roundtrip_message() {
        let value = TestMessage {
            field_u8: 0,
            field_u16: 300,
            field_u24: U24(1_000_000),
            field_u32: 134_217_728,
            field_u64: 137_438_953_472,
        };

        let mut bytes = vec![];
        value.encode(&mut bytes).unwrap();
        assert_eq!(bytes.len(), TestMessage::encoded_length());
        assert_eq!(value.encoded_len().unwrap(), TestMessage::encoded_length());

        let decoded = TestMessage::decode(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(value, decoded);
    }

    /// Three identical test messages, used as the payload for the vector tests.
    fn messages_vec() -> Vec<TestMessage> {
        vec![
            TestMessage {
                field_u8: 0,
                field_u16: 300,
                field_u24: U24(1_000_000),
                field_u32: 134_217_728,
                field_u64: 137_438_953_472,
            },
            TestMessage {
                field_u8: 0,
                field_u16: 300,
                field_u24: U24(1_000_000),
                field_u32: 134_217_728,
                field_u64: 137_438_953_472,
            },
            TestMessage {
                field_u8: 0,
                field_u16: 300,
                field_u24: U24(1_000_000),
                field_u32: 134_217_728,
                field_u64: 137_438_953_472,
            },
        ]
    }

    #[test]
    fn roundtrip_variable_length_u8() {
        let values = messages_vec();
        let mut bytes = vec![];
        encode_u8_items(&mut bytes, &(), &values).unwrap();

        assert_eq!(
            bytes.len(),
            // Length prefix
            1 +
            // Three messages
            3 * TestMessage::encoded_length()
        );

        let decoded = decode_u8_items(&(), &mut Cursor::new(&bytes)).unwrap();
        assert_eq!(values, decoded);
    }

    #[test]
    fn roundtrip_variable_length_u16() {
        let values = messages_vec();
        let mut bytes = vec![];
        encode_u16_items(&mut bytes, &(), &values).unwrap();

        assert_eq!(
            bytes.len(),
            // Length prefix
            2 +
            // Three messages
            3 * TestMessage::encoded_length()
        );

        // The prefix is the byte length of the items (54), not the item count.
        assert_eq!(bytes[0..2], [0, 3 * TestMessage::encoded_length() as u8]);

        let decoded = decode_u16_items(&(), &mut Cursor::new(&bytes)).unwrap();
        assert_eq!(values, decoded);
    }

    #[test]
    fn roundtrip_variable_length_u24() {
        let values = messages_vec();
        let mut bytes = vec![];
        encode_u24_items(&mut bytes, &(), &values).unwrap();

        assert_eq!(
            bytes.len(),
            // Length prefix
            3 +
            // Three messages
            3 * TestMessage::encoded_length()
        );

        // The prefix is the byte length of the items (54), not the item count.
        assert_eq!(bytes[0..3], [0, 0, 3 * TestMessage::encoded_length() as u8]);

        let decoded = decode_u24_items(&(), &mut Cursor::new(&bytes)).unwrap();
        assert_eq!(values, decoded);
    }

    #[test]
    fn roundtrip_variable_length_u32() {
        let values = messages_vec();
        let mut bytes = Vec::new();
        encode_u32_items(&mut bytes, &(), &values).unwrap();

        assert_eq!(bytes.len(), 4 + 3 * TestMessage::encoded_length());

        assert_eq!(
            // The prefix is the byte length of the items (54), not the item count.
            bytes[0..4],
            [0, 0, 0, 3 * TestMessage::encoded_length() as u8]
        );

        let decoded = decode_u32_items(&(), &mut Cursor::new(&bytes)).unwrap();
        assert_eq!(values, decoded);
    }

    #[test]
    fn roundtrip_fixlen_vector() {
        let values = messages_vec();
        let mut bytes = Vec::new();
        encode_fixlen_items(&mut bytes, &values).unwrap();

        let decoded = decode_fixlen_items(bytes.len(), &(), &mut Cursor::new(&bytes)).unwrap();
        assert_eq!(values, decoded);

        assert_matches!(
            // A window one byte short cuts the last item off, so its decoder hits EOF.
            decode_fixlen_items::<_, TestMessage>(bytes.len() - 1, &(), &mut Cursor::new(&bytes))
                .unwrap_err(),
            CodecError::Io(e) => assert_eq!(e.kind(), ErrorKind::UnexpectedEof)
        );

        assert_matches!(
            // A window one byte longer than the buffer is rejected before any decoding.
            decode_fixlen_items::<_, TestMessage>(bytes.len() + 1, &(), &mut Cursor::new(&bytes))
                .unwrap_err(),
            CodecError::LengthPrefixTooBig(_)
        );
    }

    #[test]
    fn decode_too_short() {
        let values = messages_vec();
        let mut bytes = Vec::new();
        encode_u32_items(&mut bytes, &(), &values).unwrap();

        // Truncated inside the 4-byte length prefix itself.
        let error =
            decode_u32_items::<_, TestMessage>(&(), &mut Cursor::new(&bytes[..3])).unwrap_err();
        assert_matches!(error, CodecError::Io(e) => assert_eq!(e.kind(), ErrorKind::UnexpectedEof));

        // Full prefix present, but none of the items it promises.
        let error =
            decode_u32_items::<_, TestMessage>(&(), &mut Cursor::new(&bytes[..4])).unwrap_err();
        assert_matches!(error, CodecError::LengthPrefixTooBig(_));
    }

    /// `position + length` overflowing `usize` is reported as `LengthPrefixTooBig`, not a panic.
    #[test]
    fn decode_items_overflow() {
        let encoded = vec![1u8];

        let mut cursor = Cursor::new(encoded.as_slice());
        cursor.set_position(1);

        assert_matches!(
            decode_fixlen_items::<(), u8>(usize::MAX, &(), &mut cursor).unwrap_err(),
            CodecError::LengthPrefixTooBig(usize::MAX)
        );
    }

    /// A length reaching past the end of the buffer is reported with the requested length.
    #[test]
    fn decode_items_too_big() {
        let encoded = vec![1u8];

        let mut cursor = Cursor::new(encoded.as_slice());
        cursor.set_position(1);

        assert_matches!(
            decode_fixlen_items::<(), u8>(2, &(), &mut cursor).unwrap_err(),
            CodecError::LengthPrefixTooBig(2)
        );
    }

    /// Each primitive's `encoded_len` agrees with the length of its actual encoding.
    #[test]
    fn length_hint_correctness() {
        assert_eq!(().encoded_len().unwrap(), ().get_encoded().unwrap().len());
        assert_eq!(0u8.encoded_len().unwrap(), 0u8.get_encoded().unwrap().len());
        assert_eq!(
            0u16.encoded_len().unwrap(),
            0u16.get_encoded().unwrap().len()
        );
        assert_eq!(
            U24(0).encoded_len().unwrap(),
            U24(0).get_encoded().unwrap().len()
        );
        assert_eq!(
            0u32.encoded_len().unwrap(),
            0u32.get_encoded().unwrap().len()
        );
        assert_eq!(
            0u64.encoded_len().unwrap(),
            0u64.get_encoded().unwrap().len()
        );
    }

    /// `get_decoded` requires the value to consume the whole input.
    #[test]
    fn get_decoded_leftover() {
        let encoded_good = [1, 2, 3, 4];
        assert_matches!(u32::get_decoded(&encoded_good).unwrap(), 0x01020304u32);

        let encoded_bad = [1, 2, 3, 4, 5];
        let error = u32::get_decoded(&encoded_bad).unwrap_err();
        assert_matches!(error, CodecError::BytesLeftOver(1));
    }

    /// An `Encode` impl that does not override `encoded_len` still works and reports `None`.
    #[test]
    fn encoded_len_backwards_compatibility() {
        struct MyMessage;

        impl Encode for MyMessage {
            fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
                bytes.extend_from_slice(b"Hello, world");
                Ok(())
            }
        }

        assert_eq!(MyMessage.encoded_len(), None);

        assert_eq!(MyMessage.get_encoded().unwrap(), b"Hello, world");
    }

    /// 256 one-byte items cannot be described by a one-byte length prefix.
    #[test]
    fn encode_length_prefix_overflow() {
        let mut bytes = Vec::new();
        let error = encode_u8_items(&mut bytes, &(), &[1u8; u8::MAX as usize + 1]).unwrap_err();
        assert_matches!(error, CodecError::LengthPrefixOverflow);
    }
}