1use bytes::{Buf, BufMut};
9
10use crate::error::DecodeError;
11use crate::message_field::DefaultInstance;
12
/// Default nesting budget used by [`Message::decode`], [`Message::decode_from_slice`],
/// [`Message::decode_length_delimited`], [`Message::merge_from_slice`] and
/// `DecodeOptions::new`.
///
/// Each nested length-delimited sub-message or group consumes one level
/// (see `Message::merge_length_delimited` / `Message::merge_group`); running out
/// yields `DecodeError::RecursionLimitExceeded`.
pub const RECURSION_LIMIT: u32 = 100;
25
26pub trait Message: DefaultInstance + Clone + PartialEq + Send + Sync {
75 fn compute_size(&self) -> u32;
87
88 fn write_to(&self, buf: &mut impl BufMut);
93
94 fn encode(&self, buf: &mut impl BufMut) {
96 self.compute_size();
97 self.write_to(buf);
98 }
99
100 fn encode_length_delimited(&self, buf: &mut impl BufMut) {
102 let len = self.compute_size();
103 crate::encoding::encode_varint(len as u64, buf);
104 self.write_to(buf);
105 }
106
107 fn encode_to_vec(&self) -> alloc::vec::Vec<u8> {
109 let size = self.compute_size() as usize;
110 let mut buf = alloc::vec::Vec::with_capacity(size);
111 self.write_to(&mut buf);
112 buf
113 }
114
115 fn encode_to_bytes(&self) -> bytes::Bytes {
124 let size = self.compute_size() as usize;
125 let mut buf = bytes::BytesMut::with_capacity(size);
126 self.write_to(&mut buf);
127 buf.freeze()
128 }
129
130 fn decode(buf: &mut impl Buf) -> Result<Self, DecodeError>
132 where
133 Self: Sized,
134 {
135 let mut msg = Self::default();
136 msg.merge(buf, RECURSION_LIMIT)?;
137 Ok(msg)
138 }
139
140 fn decode_from_slice(mut data: &[u8]) -> Result<Self, DecodeError>
145 where
146 Self: Sized,
147 {
148 Self::decode(&mut data)
151 }
152
153 fn decode_length_delimited(buf: &mut impl Buf) -> Result<Self, DecodeError>
168 where
169 Self: Sized,
170 {
171 const MAX_MESSAGE_BYTES: u64 = 0x7FFF_FFFF;
174 let len_u64 = crate::encoding::decode_varint(buf)?;
175 if len_u64 > MAX_MESSAGE_BYTES {
176 return Err(DecodeError::MessageTooLarge);
177 }
178 let len = usize::try_from(len_u64).map_err(|_| DecodeError::MessageTooLarge)?;
180 if buf.remaining() < len {
181 return Err(DecodeError::UnexpectedEof);
182 }
183 let limit = buf.remaining() - len;
188 let mut msg = Self::default();
189 msg.merge_to_limit(buf, RECURSION_LIMIT, limit)?;
190 if buf.remaining() != limit {
191 let remaining = buf.remaining();
192 if remaining > limit {
193 buf.advance(remaining - limit);
194 } else {
195 return Err(DecodeError::UnexpectedEof);
196 }
197 }
198 Ok(msg)
199 }
200
201 fn merge_field(
217 &mut self,
218 tag: crate::encoding::Tag,
219 buf: &mut impl Buf,
220 depth: u32,
221 ) -> Result<(), DecodeError>;
222
223 fn merge_to_limit(
240 &mut self,
241 buf: &mut impl Buf,
242 depth: u32,
243 limit: usize,
244 ) -> Result<(), DecodeError> {
245 while buf.remaining() > limit {
246 let tag = crate::encoding::Tag::decode(buf)?;
247 self.merge_field(tag, buf, depth)?;
248 }
249 Ok(())
250 }
251
252 fn merge_group(
268 &mut self,
269 buf: &mut impl Buf,
270 depth: u32,
271 field_number: u32,
272 ) -> Result<(), DecodeError> {
273 let depth = depth
274 .checked_sub(1)
275 .ok_or(DecodeError::RecursionLimitExceeded)?;
276 loop {
277 if !buf.has_remaining() {
278 return Err(DecodeError::UnexpectedEof);
279 }
280 let tag = crate::encoding::Tag::decode(buf)?;
281 if tag.wire_type() == crate::encoding::WireType::EndGroup {
282 return if tag.field_number() == field_number {
283 Ok(())
284 } else {
285 Err(DecodeError::InvalidEndGroup(tag.field_number()))
286 };
287 }
288 self.merge_field(tag, buf, depth)?;
289 }
290 }
291
292 fn merge(&mut self, buf: &mut impl Buf, depth: u32) -> Result<(), DecodeError> {
306 self.merge_to_limit(buf, depth, 0)
307 }
308
309 fn merge_from_slice(&mut self, mut data: &[u8]) -> Result<(), DecodeError> {
314 self.merge(&mut data, RECURSION_LIMIT)
315 }
316
317 fn merge_length_delimited(
342 &mut self,
343 buf: &mut impl Buf,
344 depth: u32,
345 ) -> Result<(), DecodeError> {
346 let depth = depth
347 .checked_sub(1)
348 .ok_or(DecodeError::RecursionLimitExceeded)?;
349 const MAX_SUB_MESSAGE_BYTES: u64 = 0x7FFF_FFFF;
350 let len_u64 = crate::encoding::decode_varint(buf)?;
351 if len_u64 > MAX_SUB_MESSAGE_BYTES {
352 return Err(DecodeError::MessageTooLarge);
353 }
354 let len = usize::try_from(len_u64).map_err(|_| DecodeError::MessageTooLarge)?;
355 if buf.remaining() < len {
356 return Err(DecodeError::UnexpectedEof);
357 }
358 let limit = buf.remaining() - len;
364 self.merge_to_limit(buf, depth, limit)?;
365 if buf.remaining() != limit {
366 let remaining = buf.remaining();
367 if remaining > limit {
368 buf.advance(remaining - limit);
370 } else {
371 return Err(DecodeError::UnexpectedEof);
372 }
373 }
374 Ok(())
375 }
376
377 fn cached_size(&self) -> u32;
381
382 fn clear(&mut self);
384}
385
/// Configurable limits for decoding messages.
///
/// Built with [`DecodeOptions::new`] (or `Default`) and adjusted with the
/// `with_*` builder methods; the `decode*` / `merge*` methods enforce them.
#[derive(Debug, Clone)]
pub struct DecodeOptions {
    // Nesting budget passed as `depth` to `Message::merge` / `merge_to_limit`.
    recursion_limit: u32,
    // Largest accepted input, in bytes.
    max_message_size: usize,
}

// Default for `DecodeOptions::max_message_size`: 0x7FFF_FFFF (2 GiB - 1) bytes.
// Also caps the length prefix accepted by the `decode_length_delimited*` methods
// regardless of the configured maximum.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 0x7FFF_FFFF;
415
416impl Default for DecodeOptions {
417 fn default() -> Self {
418 Self::new()
419 }
420}
421
422impl DecodeOptions {
423 pub fn new() -> Self {
429 Self {
430 recursion_limit: RECURSION_LIMIT,
431 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
432 }
433 }
434
435 #[must_use]
443 pub fn with_recursion_limit(mut self, limit: u32) -> Self {
444 self.recursion_limit = limit;
445 self
446 }
447
448 #[must_use]
459 pub fn with_max_message_size(mut self, max_bytes: usize) -> Self {
460 self.max_message_size = max_bytes;
461 self
462 }
463
464 pub fn recursion_limit(&self) -> u32 {
466 self.recursion_limit
467 }
468
469 pub fn max_message_size(&self) -> usize {
471 self.max_message_size
472 }
473
474 pub fn decode<M: Message>(&self, buf: &mut impl Buf) -> Result<M, DecodeError> {
476 if buf.remaining() > self.max_message_size {
477 return Err(DecodeError::MessageTooLarge);
478 }
479 let mut msg = M::default();
480 msg.merge(buf, self.recursion_limit)?;
481 Ok(msg)
482 }
483
484 pub fn decode_from_slice<M: Message>(&self, data: &[u8]) -> Result<M, DecodeError> {
486 if data.len() > self.max_message_size {
487 return Err(DecodeError::MessageTooLarge);
488 }
489 let mut msg = M::default();
490 msg.merge(&mut &*data, self.recursion_limit)?;
491 Ok(msg)
492 }
493
494 pub fn decode_length_delimited<M: Message>(
496 &self,
497 buf: &mut impl Buf,
498 ) -> Result<M, DecodeError> {
499 let max = core::cmp::min(
503 self.max_message_size as u64,
504 DEFAULT_MAX_MESSAGE_SIZE as u64,
505 );
506 let len_u64 = crate::encoding::decode_varint(buf)?;
507 if len_u64 > max {
508 return Err(DecodeError::MessageTooLarge);
509 }
510 let len = usize::try_from(len_u64).map_err(|_| DecodeError::MessageTooLarge)?;
511 if buf.remaining() < len {
512 return Err(DecodeError::UnexpectedEof);
513 }
514 let limit = buf.remaining() - len;
515 let mut msg = M::default();
516 msg.merge_to_limit(buf, self.recursion_limit, limit)?;
517 if buf.remaining() != limit {
518 let remaining = buf.remaining();
519 if remaining > limit {
520 buf.advance(remaining - limit);
521 } else {
522 return Err(DecodeError::UnexpectedEof);
523 }
524 }
525 Ok(msg)
526 }
527
528 pub fn merge<M: Message>(&self, msg: &mut M, buf: &mut impl Buf) -> Result<(), DecodeError> {
530 if buf.remaining() > self.max_message_size {
531 return Err(DecodeError::MessageTooLarge);
532 }
533 msg.merge(buf, self.recursion_limit)
534 }
535
536 pub fn merge_from_slice<M: Message>(
538 &self,
539 msg: &mut M,
540 data: &[u8],
541 ) -> Result<(), DecodeError> {
542 if data.len() > self.max_message_size {
543 return Err(DecodeError::MessageTooLarge);
544 }
545 msg.merge(&mut &*data, self.recursion_limit)
546 }
547
548 pub fn decode_view<'a, V: crate::view::MessageView<'a>>(
550 &self,
551 buf: &'a [u8],
552 ) -> Result<V, DecodeError> {
553 if buf.len() > self.max_message_size {
554 return Err(DecodeError::MessageTooLarge);
555 }
556 V::decode_view_with_limit(buf, self.recursion_limit)
557 }
558
559 #[cfg(feature = "std")]
565 pub fn decode_reader<M: Message>(
566 &self,
567 reader: &mut impl std::io::Read,
568 ) -> Result<M, std::io::Error> {
569 let bytes = self.read_limited(reader)?;
570 self.decode_from_slice::<M>(&bytes)
571 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
572 }
573
574 #[cfg(feature = "std")]
580 pub fn decode_length_delimited_reader<M: Message>(
581 &self,
582 reader: &mut impl std::io::Read,
583 ) -> Result<M, std::io::Error> {
584 let len = read_varint(reader)?;
585 let max = core::cmp::min(
586 self.max_message_size as u64,
587 DEFAULT_MAX_MESSAGE_SIZE as u64,
588 );
589 if len > max {
590 return Err(std::io::Error::new(
591 std::io::ErrorKind::InvalidData,
592 DecodeError::MessageTooLarge,
593 ));
594 }
595 let len = len as usize;
596 let mut buf = alloc::vec![0u8; len];
597 reader.read_exact(&mut buf)?;
598 self.decode_from_slice::<M>(&buf)
599 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
600 }
601
602 #[cfg(feature = "std")]
604 fn read_limited(
605 &self,
606 reader: &mut impl std::io::Read,
607 ) -> Result<alloc::vec::Vec<u8>, std::io::Error> {
608 use std::io::Read as _;
609 let mut buf = alloc::vec::Vec::new();
610 reader
611 .take(self.max_message_size as u64 + 1)
612 .read_to_end(&mut buf)?;
613 if buf.len() > self.max_message_size {
614 return Err(std::io::Error::new(
615 std::io::ErrorKind::InvalidData,
616 DecodeError::MessageTooLarge,
617 ));
618 }
619 Ok(buf)
620 }
621}
622
/// Reads one base-128 varint from `reader`, one byte at a time.
///
/// At most ten bytes are consumed. Bytes one through nine contribute seven
/// payload bits each; in the tenth byte only the lowest bit may be set, and it
/// becomes bit 63 of the result.
///
/// # Errors
///
/// * Errors from `reader.read_exact` are returned unchanged (running out of
///   input surfaces as `ErrorKind::UnexpectedEof`).
/// * A tenth byte greater than `0x01` yields `ErrorKind::InvalidData` wrapping
///   `DecodeError::VarintTooLong`.
#[cfg(feature = "std")]
fn read_varint(reader: &mut impl std::io::Read) -> Result<u64, std::io::Error> {
    let mut acc: u64 = 0;
    let mut one = [0u8; 1];
    // Shifts 0, 7, ..., 56: the nine bytes that carry a full seven-bit group.
    for shift in (0u32..63).step_by(7) {
        reader.read_exact(&mut one)?;
        let byte = one[0];
        acc |= u64::from(byte & 0x7F) << shift;
        if byte & 0x80 == 0 {
            return Ok(acc);
        }
    }
    // Tenth byte: only bit 63 is left to fill.
    reader.read_exact(&mut one)?;
    match one[0] {
        last @ 0..=1 => Ok(acc | (u64::from(last) << 63)),
        _ => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            DecodeError::VarintTooLong,
        )),
    }
}
657
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cached_size::CachedSize;
    use crate::encoding::encode_varint;
    use crate::error::DecodeError;
    use crate::message_field::DefaultInstance;

    /// Minimal hand-written message: one `int32` field numbered 1.
    #[derive(Clone, Debug, Default, PartialEq)]
    struct FlatMsg {
        value: i32,
        __buffa_cached_size: CachedSize,
    }

    // NOTE(review): the safety contract of `DefaultInstance` lives in
    // `message_field`; this impl hands out a lazily-initialised static equal to
    // `FlatMsg::default()`.
    unsafe impl DefaultInstance for FlatMsg {
        fn default_instance() -> &'static Self {
            static INSTANCE: crate::__private::OnceBox<FlatMsg> =
                crate::__private::OnceBox::new();
            INSTANCE.get_or_init(|| alloc::boxed::Box::new(FlatMsg::default()))
        }
    }

    impl Message for FlatMsg {
        fn compute_size(&self) -> u32 {
            // One tag byte plus the varint payload; zero is not written at all.
            let size = match self.value {
                0 => 0,
                v => 1 + crate::types::int32_encoded_len(v) as u32,
            };
            self.__buffa_cached_size.set(size);
            size
        }

        fn write_to(&self, buf: &mut impl BufMut) {
            if self.value == 0 {
                return;
            }
            crate::encoding::Tag::new(1, crate::encoding::WireType::Varint).encode(buf);
            crate::types::encode_int32(self.value, buf);
        }

        fn merge_field(
            &mut self,
            tag: crate::encoding::Tag,
            buf: &mut impl Buf,
            _depth: u32,
        ) -> Result<(), DecodeError> {
            if tag.field_number() == 1 {
                self.value = crate::types::decode_int32(buf)?;
            } else {
                crate::encoding::skip_field(tag, buf)?;
            }
            Ok(())
        }

        fn cached_size(&self) -> u32 {
            self.__buffa_cached_size.get()
        }

        fn clear(&mut self) {
            *self = Self::default();
        }
    }

    /// A `FlatMsg` carrying `value` and a fresh cached size.
    fn flat(value: i32) -> FlatMsg {
        FlatMsg {
            value,
            ..Default::default()
        }
    }

    /// Length-delimited encoding of `msg`.
    fn wire_bytes(msg: &FlatMsg) -> alloc::vec::Vec<u8> {
        let mut out = alloc::vec::Vec::new();
        msg.encode_length_delimited(&mut out);
        out
    }

    /// A varint length prefix `len` followed by the raw bytes `tail`.
    fn prefix_then(len: u64, tail: &[u8]) -> alloc::vec::Vec<u8> {
        let mut out = alloc::vec::Vec::new();
        encode_varint(len, &mut out);
        out.extend_from_slice(tail);
        out
    }

    #[test]
    fn test_merge_length_delimited_basic() {
        let wire = wire_bytes(&flat(42));
        let mut dst = FlatMsg::default();
        dst.merge_length_delimited(&mut wire.as_slice(), RECURSION_LIMIT)
            .unwrap();
        assert_eq!(dst.value, 42);
    }

    #[test]
    fn test_merge_length_delimited_merges_into_existing() {
        // Successive merges into one message: the latest scalar wins.
        let mut dst = FlatMsg::default();
        for v in [1, 2] {
            let wire = wire_bytes(&flat(v));
            dst.merge_length_delimited(&mut wire.as_slice(), RECURSION_LIMIT)
                .unwrap();
            assert_eq!(dst.value, v);
        }
    }

    #[test]
    fn test_merge_length_delimited_truncated() {
        // The prefix claims 10 bytes but only 2 follow.
        let wire = prefix_then(10, &[0x01, 0x01]);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_length_delimited(&mut wire.as_slice(), RECURSION_LIMIT),
            Err(DecodeError::UnexpectedEof)
        );
    }

    #[test]
    fn test_merge_length_delimited_oversized() {
        // One past the 0x7FFF_FFFF cap.
        let wire = prefix_then(0x8000_0000, &[]);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_length_delimited(&mut wire.as_slice(), RECURSION_LIMIT),
            Err(DecodeError::MessageTooLarge)
        );
    }

    #[test]
    fn test_merge_length_delimited_recursion_limit() {
        let wire = wire_bytes(&flat(7));
        let mut dst = FlatMsg::default();
        // Depth 0 leaves no budget for the sub-message itself.
        assert_eq!(
            dst.merge_length_delimited(&mut wire.as_slice(), 0),
            Err(DecodeError::RecursionLimitExceeded)
        );
        // Depth 1 is exactly enough for a flat message.
        dst.merge_length_delimited(&mut wire.as_slice(), 1).unwrap();
        assert_eq!(dst.value, 7);
    }

    #[test]
    fn test_decode_from_slice_basic() {
        let bytes = flat(42).encode_to_vec();
        assert_eq!(FlatMsg::decode_from_slice(&bytes).unwrap().value, 42);
    }

    #[test]
    fn test_encode_to_bytes_matches_encode_to_vec() {
        let src = flat(42);
        let vec = src.encode_to_vec();
        let bytes = src.encode_to_bytes();
        assert_eq!(vec.as_slice(), bytes.as_ref());
        assert_eq!(FlatMsg::decode_from_slice(&bytes).unwrap().value, 42);
        // A default message encodes to nothing.
        assert!(FlatMsg::default().encode_to_bytes().is_empty());
    }

    #[test]
    fn test_decode_from_slice_empty() {
        assert_eq!(FlatMsg::decode_from_slice(&[]).unwrap().value, 0);
    }

    #[test]
    fn test_decode_from_slice_invalid_returns_error() {
        // A lone continuation byte is a truncated tag varint.
        assert!(FlatMsg::decode_from_slice(&[0xFF]).is_err());
    }

    #[test]
    fn test_merge_from_slice_basic() {
        let mut dst = FlatMsg::default();
        dst.merge_from_slice(&flat(7).encode_to_vec()).unwrap();
        assert_eq!(dst.value, 7);
    }

    #[test]
    fn test_merge_from_slice_last_wins() {
        let mut dst = FlatMsg::default();
        for v in [1, 2] {
            dst.merge_from_slice(&flat(v).encode_to_vec()).unwrap();
        }
        assert_eq!(dst.value, 2);
    }

    #[test]
    fn test_decode_options_default_works() {
        let bytes = flat(99).encode_to_vec();
        let msg: FlatMsg = DecodeOptions::new().decode_from_slice(&bytes).unwrap();
        assert_eq!(msg.value, 99);
    }

    #[test]
    fn test_decode_options_max_message_size_rejects() {
        let bytes = flat(42).encode_to_vec();
        let result = DecodeOptions::new()
            .with_max_message_size(1)
            .decode_from_slice::<FlatMsg>(&bytes);
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_decode_options_max_message_size_exact_boundary() {
        let bytes = flat(42).encode_to_vec();
        let at_limit = DecodeOptions::new().with_max_message_size(bytes.len());
        assert_eq!(
            at_limit.decode_from_slice::<FlatMsg>(&bytes).unwrap().value,
            42
        );
        let below_limit = DecodeOptions::new().with_max_message_size(bytes.len() - 1);
        assert_eq!(
            below_limit.decode_from_slice::<FlatMsg>(&bytes),
            Err(DecodeError::MessageTooLarge)
        );
    }

    #[test]
    fn test_decode_options_custom_recursion_limit() {
        // A flat top-level message needs no depth budget beyond 1.
        let bytes = flat(7).encode_to_vec();
        let msg = DecodeOptions::new()
            .with_recursion_limit(1)
            .decode_from_slice::<FlatMsg>(&bytes)
            .unwrap();
        assert_eq!(msg.value, 7);
    }

    #[test]
    fn test_decode_options_merge() {
        let bytes = flat(55).encode_to_vec();
        let mut msg = FlatMsg::default();
        DecodeOptions::new()
            .merge_from_slice(&mut msg, &bytes)
            .unwrap();
        assert_eq!(msg.value, 55);
    }

    #[test]
    fn test_decode_options_merge_rejects_oversize() {
        let bytes = flat(55).encode_to_vec();
        let mut msg = FlatMsg::default();
        assert_eq!(
            DecodeOptions::new()
                .with_max_message_size(1)
                .merge_from_slice(&mut msg, &bytes),
            Err(DecodeError::MessageTooLarge)
        );
    }

    #[test]
    fn test_decode_options_length_delimited() {
        let ld = wire_bytes(&flat(42));
        let msg: FlatMsg = DecodeOptions::new()
            .decode_length_delimited(&mut ld.as_slice())
            .unwrap();
        assert_eq!(msg.value, 42);
    }

    #[test]
    fn test_decode_options_length_delimited_rejects_oversize() {
        let ld = wire_bytes(&flat(42));
        let result: Result<FlatMsg, _> = DecodeOptions::new()
            .with_max_message_size(1)
            .decode_length_delimited(&mut ld.as_slice());
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn decode_options_getters_return_defaults() {
        let opts = DecodeOptions::new();
        assert_eq!(
            (opts.recursion_limit(), opts.max_message_size()),
            (RECURSION_LIMIT, 0x7FFF_FFFF)
        );
    }

    #[test]
    fn decode_options_getters_return_custom_values() {
        let opts = DecodeOptions::new()
            .with_recursion_limit(42)
            .with_max_message_size(1024);
        assert_eq!(
            (opts.recursion_limit(), opts.max_message_size()),
            (42, 1024)
        );
    }

    #[test]
    fn test_decode_options_default_impl() {
        let opts = DecodeOptions::default();
        assert_eq!(
            (opts.recursion_limit(), opts.max_message_size()),
            (RECURSION_LIMIT, 0x7FFF_FFFF)
        );
    }

    #[test]
    fn test_decode_options_decode_buf() {
        let bytes = flat(123).encode_to_vec();
        let msg: FlatMsg = DecodeOptions::new().decode(&mut bytes.as_slice()).unwrap();
        assert_eq!(msg.value, 123);
        let result: Result<FlatMsg, _> = DecodeOptions::new()
            .with_max_message_size(1)
            .decode(&mut bytes.as_slice());
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_decode_options_merge_buf() {
        let bytes = flat(77).encode_to_vec();
        let mut msg = FlatMsg::default();
        DecodeOptions::new()
            .merge(&mut msg, &mut bytes.as_slice())
            .unwrap();
        assert_eq!(msg.value, 77);
        let mut fresh = FlatMsg::default();
        assert_eq!(
            DecodeOptions::new()
                .with_max_message_size(1)
                .merge(&mut fresh, &mut bytes.as_slice()),
            Err(DecodeError::MessageTooLarge)
        );
    }

    #[test]
    fn test_message_encode_trait_default() {
        let src = flat(42);
        let mut buf = alloc::vec::Vec::new();
        src.encode(&mut buf);
        assert_eq!(buf, src.encode_to_vec());
    }

    #[test]
    fn test_message_decode_length_delimited_trait_default() {
        let ld = wire_bytes(&flat(42));
        let got = FlatMsg::decode_length_delimited(&mut ld.as_slice()).unwrap();
        assert_eq!(got.value, 42);
    }

    #[test]
    fn test_message_decode_length_delimited_oversize() {
        let wire = prefix_then(0x8000_0000, &[]);
        assert_eq!(
            FlatMsg::decode_length_delimited(&mut wire.as_slice()),
            Err(DecodeError::MessageTooLarge)
        );
    }

    #[test]
    fn test_message_decode_length_delimited_truncated() {
        // The prefix claims 10 bytes but only 2 follow.
        let wire = prefix_then(10, &[0x08, 0x01]);
        assert_eq!(
            FlatMsg::decode_length_delimited(&mut wire.as_slice()),
            Err(DecodeError::UnexpectedEof)
        );
    }

    #[test]
    fn test_message_decode_length_delimited_with_trailing() {
        // Two frames back to back: each call must consume exactly one.
        let mut stream = wire_bytes(&flat(1));
        stream.extend_from_slice(&wire_bytes(&flat(2)));
        let mut cur = stream.as_slice();
        for expected in [1, 2] {
            let got = FlatMsg::decode_length_delimited(&mut cur).unwrap();
            assert_eq!(got.value, expected);
        }
        assert!(cur.is_empty());
    }

    /// A group body holding `value` (omitted when zero), terminated by an
    /// `EndGroup` tag for `group_field_number`.
    fn group_bytes(value: i32, group_field_number: u32) -> alloc::vec::Vec<u8> {
        let mut out = alloc::vec::Vec::new();
        // `FlatMsg::write_to` emits field 1 only for a non-zero value.
        flat(value).write_to(&mut out);
        crate::encoding::Tag::new(group_field_number, crate::encoding::WireType::EndGroup)
            .encode(&mut out);
        out
    }

    #[test]
    fn test_merge_group_basic() {
        let data = group_bytes(42, 5);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5)
            .unwrap();
        assert_eq!(dst.value, 42);
    }

    #[test]
    fn test_merge_group_empty() {
        let data = group_bytes(0, 3);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 3)
            .unwrap();
        assert_eq!(dst.value, 0);
    }

    #[test]
    fn test_merge_group_merges_into_existing() {
        let mut dst = FlatMsg::default();
        for v in [1, 2] {
            let data = group_bytes(v, 5);
            dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5)
                .unwrap();
            assert_eq!(dst.value, v);
        }
    }

    #[test]
    fn test_merge_group_recursion_limit_zero() {
        let data = group_bytes(42, 5);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_group(&mut data.as_slice(), 0, 5),
            Err(DecodeError::RecursionLimitExceeded)
        );
    }

    #[test]
    fn test_merge_group_recursion_limit_one_succeeds() {
        let data = group_bytes(7, 5);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data.as_slice(), 1, 5).unwrap();
        assert_eq!(dst.value, 7);
    }

    #[test]
    fn test_merge_group_mismatched_end() {
        // An empty body closed by field 99 instead of the expected 5.
        let data = group_bytes(0, 99);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5),
            Err(DecodeError::InvalidEndGroup(99))
        );
    }

    #[test]
    fn test_merge_group_truncated() {
        // A field with no closing EndGroup tag.
        let data = flat(42).encode_to_vec();
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5),
            Err(DecodeError::UnexpectedEof)
        );
    }

    #[test]
    fn test_merge_group_empty_buffer() {
        let mut empty: &[u8] = &[];
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_group(&mut empty, RECURSION_LIMIT, 5),
            Err(DecodeError::UnexpectedEof)
        );
    }

    #[test]
    fn test_merge_group_unknown_fields_skipped() {
        use crate::encoding::{Tag, WireType};
        let mut data = alloc::vec::Vec::new();
        // Unknown field 99 (varint 0) must be skipped, not rejected.
        Tag::new(99, WireType::Varint).encode(&mut data);
        encode_varint(0, &mut data);
        flat(99).write_to(&mut data);
        Tag::new(5, WireType::EndGroup).encode(&mut data);

        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5)
            .unwrap();
        assert_eq!(dst.value, 99);
    }

    #[test]
    fn test_merge_group_trailing_data_preserved() {
        // Bytes after the EndGroup tag belong to the caller.
        let mut data = group_bytes(42, 5);
        data.extend_from_slice(&[0xDE, 0xAD]);

        let mut cur = data.as_slice();
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut cur, RECURSION_LIMIT, 5).unwrap();
        assert_eq!(dst.value, 42);
        assert_eq!(cur, &[0xDE, 0xAD]);
    }

    #[cfg(feature = "std")]
    mod read_varint_tests {
        use super::super::read_varint;
        use crate::encoding::encode_varint;

        /// Varint encoding of `v` produced by the crate's own encoder.
        fn encoded(v: u64) -> Vec<u8> {
            let mut out = Vec::new();
            encode_varint(v, &mut out);
            out
        }

        #[test]
        fn roundtrip_values() {
            let cases: [u64; 9] = [0, 1, 127, 128, 300, 1 << 14, 1 << 35, 1 << 63, u64::MAX];
            for v in cases {
                let got = read_varint(&mut encoded(v).as_slice()).unwrap();
                assert_eq!(got, v, "roundtrip failed for {v}");
            }
        }

        #[test]
        fn rejects_10th_byte_overflow() {
            // Nine continuation bytes, then a tenth byte with more than bit 0 set.
            let bad: [u8; 10] = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];
            let err = read_varint(&mut &bad[..]).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[test]
        fn rejects_11th_byte() {
            // The tenth byte still has its continuation bit set.
            let bad = [0xFFu8; 10];
            let err = read_varint(&mut &bad[..]).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[test]
        fn u64_max_roundtrips() {
            let buf = encoded(u64::MAX);
            assert_eq!(buf.len(), 10);
            assert_eq!(buf[9], 0x01);
            assert_eq!(read_varint(&mut buf.as_slice()).unwrap(), u64::MAX);
        }

        #[test]
        fn eof_before_terminator_is_error() {
            let bad: &[u8] = &[0x80];
            let err = read_varint(&mut &bad[..]).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }

        #[test]
        fn empty_input_is_error() {
            let err = read_varint(&mut &[][..]).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }
    }

    #[cfg(feature = "std")]
    mod reader_tests {
        use super::*;

        #[test]
        fn decode_reader_basic() {
            let bytes = flat(42).encode_to_vec();
            let msg: FlatMsg = DecodeOptions::new()
                .decode_reader(&mut bytes.as_slice())
                .unwrap();
            assert_eq!(msg.value, 42);
        }

        #[test]
        fn decode_reader_rejects_oversize() {
            let bytes = flat(42).encode_to_vec();
            let err = DecodeOptions::new()
                .with_max_message_size(1)
                .decode_reader::<FlatMsg>(&mut bytes.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[test]
        fn decode_reader_exact_boundary() {
            let bytes = flat(42).encode_to_vec();
            let msg: FlatMsg = DecodeOptions::new()
                .with_max_message_size(bytes.len())
                .decode_reader(&mut bytes.as_slice())
                .unwrap();
            assert_eq!(msg.value, 42);
        }

        #[test]
        fn decode_reader_propagates_read_error() {
            /// A reader that fails on every call.
            struct ErrReader;

            impl std::io::Read for ErrReader {
                fn read(&mut self, _: &mut [u8]) -> std::io::Result<usize> {
                    Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone"))
                }
            }

            let err = DecodeOptions::new()
                .decode_reader::<FlatMsg>(&mut ErrReader)
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
        }

        #[test]
        fn decode_length_delimited_reader_basic() {
            let ld = wire_bytes(&flat(99));
            let msg: FlatMsg = DecodeOptions::new()
                .decode_length_delimited_reader(&mut ld.as_slice())
                .unwrap();
            assert_eq!(msg.value, 99);
        }

        #[test]
        fn decode_length_delimited_reader_rejects_oversize_prefix() {
            let ld = wire_bytes(&flat(99));
            let err = DecodeOptions::new()
                .with_max_message_size(1)
                .decode_length_delimited_reader::<FlatMsg>(&mut ld.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[test]
        fn decode_length_delimited_reader_sequential() {
            let mut stream = wire_bytes(&flat(10));
            stream.extend_from_slice(&wire_bytes(&flat(20)));

            let mut cursor = std::io::Cursor::new(stream);
            for expected in [10, 20] {
                let msg: FlatMsg = DecodeOptions::new()
                    .decode_length_delimited_reader(&mut cursor)
                    .unwrap();
                assert_eq!(msg.value, expected);
            }
        }

        #[test]
        fn decode_length_delimited_reader_truncated_body() {
            // The prefix claims 100 bytes but only one follows.
            let buf = prefix_then(100, &[0x08]);
            let err = DecodeOptions::new()
                .decode_length_delimited_reader::<FlatMsg>(&mut buf.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }
    }
}