use alloc::collections::BTreeMap;
use alloc::format;
use alloc::vec::Vec;
use ::bytes::Buf;
use ::bytes::BufMut;
mod bytes_adapter;
mod map;
mod primitives;
pub use bytes_adapter::*;
pub use map::*;
pub use primitives::*;
pub mod varint;
pub use varint::decode_varint;
pub use varint::encode_varint;
pub use varint::encoded_len_varint;
pub mod length_delimiter;
pub use length_delimiter::decode_length_delimiter;
pub use length_delimiter::encode_length_delimiter;
pub use length_delimiter::length_delimiter_len;
pub mod wire_type;
pub use wire_type::WireType;
pub use wire_type::check_wire_type;
use crate::error::DecodeError;
/// Smallest valid field tag; `decode_key` rejects any key whose tag is below this value.
pub const MIN_TAG: u32 = 1;
/// Largest valid field tag: a key is `(tag << 3) | wire_type` in a `u32`, which leaves 29 bits
/// for the tag.
pub const MAX_TAG: u32 = (1 << 29) - 1;
/// State that is passed by value through nested decoding calls.
///
/// Unless the `no-recursion-limit` feature is enabled, the context carries the number of
/// nesting levels that may still be entered. With the feature enabled the struct has no
/// fields and `Default` is derived instead of hand-written.
#[derive(Clone, Debug, Copy)]
#[cfg_attr(feature = "no-recursion-limit", derive(Default))]
pub struct DecodeContext {
/// Remaining recursion budget. It is decremented by `enter_recursion`, and
/// `limit_reached` reports an error once it is zero.
#[cfg(not(feature = "no-recursion-limit"))]
recurse_count: u32,
}
/// With the recursion limit active, a fresh context starts with the crate-wide
/// `RECURSION_LIMIT` as its budget.
#[cfg(not(feature = "no-recursion-limit"))]
impl Default for DecodeContext {
    #[inline]
    fn default() -> Self {
        let recurse_count = crate::RECURSION_LIMIT;
        Self { recurse_count }
    }
}
impl DecodeContext {
    /// Returns the context to hand to a decoder working one nesting level deeper.
    ///
    /// The returned context has one unit less of recursion budget than `self`.
    /// The subtraction saturates at zero, so calling this on an already exhausted
    /// context is harmless: [`Self::limit_reached`] keeps reporting the error.
    #[cfg(not(feature = "no-recursion-limit"))]
    #[inline]
    #[must_use]
    pub const fn enter_recursion(&self) -> DecodeContext {
        DecodeContext {
            // A plain `- 1` would panic in debug builds and wrap to `u32::MAX` in
            // release builds when the budget is already zero; wrapping would silently
            // hand out a fresh, effectively unlimited budget.
            recurse_count: self.recurse_count.saturating_sub(1),
        }
    }

    /// Returns the context to hand to a decoder working one nesting level deeper.
    ///
    /// With the `no-recursion-limit` feature there is no budget to track, so this
    /// simply produces another empty context.
    // `&self` is kept so that both feature configurations expose the same signature.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    #[cfg(feature = "no-recursion-limit")]
    #[inline]
    #[must_use]
    pub const fn enter_recursion(&self) -> DecodeContext {
        DecodeContext {}
    }

    /// Checks whether the recursion budget has been used up.
    ///
    /// # Errors
    ///
    /// Returns a `DecodeError` with the message `"recursion limit reached"` when the
    /// remaining budget is zero.
    #[cfg(not(feature = "no-recursion-limit"))]
    #[inline]
    pub fn limit_reached(&self) -> Result<(), DecodeError> {
        if self.recurse_count == 0 {
            Err(DecodeError::new("recursion limit reached"))
        } else {
            Ok(())
        }
    }

    /// Checks whether the recursion budget has been used up.
    ///
    /// With the `no-recursion-limit` feature this always returns `Ok(())`.
    // `&self` and the `Result` wrapper are kept so that both feature configurations
    // expose the same signature to callers.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    #[allow(clippy::unnecessary_wraps)]
    #[cfg(feature = "no-recursion-limit")]
    #[inline]
    pub const fn limit_reached(&self) -> Result<(), DecodeError> {
        Ok(())
    }
}
/// Writes a field key to `buf` as a varint.
///
/// The key is the tag shifted left by three bits, with the wire type stored in the
/// low three bits. Debug builds assert that `tag` lies in `MIN_TAG..=MAX_TAG`.
#[inline]
pub fn encode_key(tag: u32, wire_type: WireType, buf: &mut impl BufMut) {
    debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
    let wire_bits = wire_type as u32;
    let packed = wire_bits | (tag << 3);
    encode_varint(u64::from(packed), buf);
}
/// Reads a field key from `buf` and splits it into its tag and wire type.
///
/// # Errors
///
/// Fails when the varint cannot be decoded, when the decoded value does not fit in a
/// `u32`, when the low three bits are not a valid wire type, or when the tag is below
/// `MIN_TAG`.
#[inline]
pub fn decode_key(buf: &mut impl Buf) -> Result<(u32, WireType), DecodeError> {
    let key = decode_varint(buf)?;
    let key_fits_u32 = key <= u64::from(u32::MAX);
    if !key_fits_u32 {
        return Err(DecodeError::new(format!("invalid key value: {key}")));
    }

    // Low three bits select the wire type; the remaining bits are the tag.
    let wire_bits = key & 0x07;
    let wire_type = WireType::try_from(wire_bits)?;
    let tag = (key as u32) >> 3;

    if tag >= MIN_TAG {
        Ok((tag, wire_type))
    } else {
        Err(DecodeError::new("invalid tag value: 0"))
    }
}
/// Returns the number of bytes the key for `tag` occupies when varint-encoded.
///
/// The length is computed from `tag << 3`, i.e. with the wire-type bits left at zero.
#[inline]
pub const fn key_len(tag: u32) -> usize {
    let shifted = tag << 3;
    encoded_len_varint(shifted as u64)
}
/// Decodes a length-delimited region by calling `merge` repeatedly.
///
/// A varint length prefix is read from `buf` first. `merge` is then invoked with
/// `value`, `buf` and `ctx` for as long as bytes of the delimited region remain.
///
/// # Errors
///
/// * `"buffer underflow"` when the prefix announces more bytes than `buf` holds.
/// * `"delimited length exceeded"` when `merge` consumed past the end of the region.
/// * Any error returned by the length decoding or by `merge` itself.
pub fn merge_loop<T, M, B>(value: &mut T, buf: &mut B, ctx: DecodeContext, mut merge: M) -> Result<(), DecodeError>
where
    M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>,
    B: Buf,
{
    let declared_len = decode_varint(buf)?;
    let available = buf.remaining();
    if declared_len > available as u64 {
        return Err(DecodeError::new("buffer underflow"));
    }

    // Bytes that must still be left in `buf` once the delimited region is consumed.
    // The cast cannot truncate: `declared_len <= available`, which is a `usize`.
    let stop_at = available - declared_len as usize;

    loop {
        let left = buf.remaining();
        if left == stop_at {
            return Ok(());
        }
        if left < stop_at {
            return Err(DecodeError::new("delimited length exceeded"));
        }
        merge(value, buf, ctx)?;
    }
}
/// Advances `buf` past one field value of the given wire type without decoding it.
///
/// `tag` is the tag of the field being skipped; it is only consulted for groups, where
/// the closing end-group key must carry the same tag. Fields nested inside a group are
/// skipped recursively with `ctx.enter_recursion()`.
///
/// # Errors
///
/// * The error from `ctx.limit_reached()` when the recursion budget is exhausted.
/// * `"unexpected end group tag"` for a stray end-group or a mismatching group tag.
/// * `"buffer underflow"` when the value is longer than the bytes left in `buf`.
/// * Any error from decoding a varint or a nested key.
pub fn skip_field(wire_type: WireType, tag: u32, buf: &mut impl Buf, ctx: DecodeContext) -> Result<(), DecodeError> {
    ctx.limit_reached()?;

    // Number of bytes still to be skipped after the match has consumed any prefix.
    let skip_len: u64 = match wire_type {
        WireType::Varint => {
            // Decoding the varint already consumes it.
            decode_varint(buf)?;
            0
        }
        WireType::ThirtyTwoBit => 4,
        WireType::SixtyFourBit => 8,
        WireType::LengthDelimited => decode_varint(buf)?,
        WireType::StartGroup => {
            loop {
                let (inner_tag, inner_wire_type) = decode_key(buf)?;
                if let WireType::EndGroup = inner_wire_type {
                    if inner_tag != tag {
                        return Err(DecodeError::new("unexpected end group tag"));
                    }
                    break;
                }
                skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?;
            }
            0
        }
        WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
    };

    let available = buf.remaining() as u64;
    if skip_len > available {
        return Err(DecodeError::new("buffer underflow"));
    }
    buf.advance(skip_len as usize);
    Ok(())
}
#[cfg(test)]
mod test {
use alloc::string::ToString;
use core::borrow::Borrow;
use core::fmt::Debug;
use proptest::prelude::*;
use proptest::test_runner::TestCaseResult;
use super::*;
use crate::bytes::Bytes;
use crate::bytes::BytesMut;
pub fn check_type<T, B>(
value: T,
tag: u32,
wire_type: WireType,
encode: fn(u32, &B, &mut BytesMut),
merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
encoded_len: fn(u32, &B) -> usize,
) -> TestCaseResult
where
T: Debug + Default + PartialEq + Borrow<B>,
B: ?Sized,
{
prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));
let expected_len = encoded_len(tag, value.borrow());
let mut buf = BytesMut::with_capacity(expected_len);
encode(tag, value.borrow(), &mut buf);
let mut buf = buf.freeze();
prop_assert_eq!(
buf.remaining(),
expected_len,
"encoded_len wrong; expected: {}, actual: {}",
expected_len,
buf.remaining()
);
if !buf.has_remaining() {
return Ok(());
}
let (decoded_tag, decoded_wire_type) = decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
prop_assert_eq!(
tag,
decoded_tag,
"decoded tag does not match; expected: {}, actual: {}",
tag,
decoded_tag
);
prop_assert_eq!(
wire_type,
decoded_wire_type,
"decoded wire type does not match; expected: {:?}, actual: {:?}",
wire_type,
decoded_wire_type,
);
match wire_type {
WireType::SixtyFourBit if buf.remaining() != 8 => Err(TestCaseError::fail(format!(
"64bit wire type illegal remaining: {}, tag: {}",
buf.remaining(),
tag
))),
WireType::ThirtyTwoBit if buf.remaining() != 4 => Err(TestCaseError::fail(format!(
"32bit wire type illegal remaining: {}, tag: {}",
buf.remaining(),
tag
))),
_ => Ok(()),
}?;
let mut roundtrip_value = T::default();
merge(wire_type, &mut roundtrip_value, &mut buf, DecodeContext::default())
.map_err(|error| TestCaseError::fail(error.to_string()))?;
prop_assert!(!buf.has_remaining(), "expected buffer to be empty, remaining: {}", buf.remaining());
prop_assert_eq!(value, roundtrip_value);
Ok(())
}
pub fn check_collection_type<T, B, E, M, L>(
value: T,
tag: u32,
wire_type: WireType,
encode: E,
mut merge: M,
encoded_len: L,
) -> TestCaseResult
where
T: Debug + Default + PartialEq + Borrow<B>,
B: ?Sized,
E: FnOnce(u32, &B, &mut BytesMut),
M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
L: FnOnce(u32, &B) -> usize,
{
prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));
let expected_len = encoded_len(tag, value.borrow());
let mut buf = BytesMut::with_capacity(expected_len);
encode(tag, value.borrow(), &mut buf);
let mut buf = buf.freeze();
prop_assert_eq!(
buf.remaining(),
expected_len,
"encoded_len wrong; expected: {}, actual: {}",
expected_len,
buf.remaining()
);
let mut roundtrip_value = Default::default();
while buf.has_remaining() {
let (decoded_tag, decoded_wire_type) = decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
prop_assert_eq!(
tag,
decoded_tag,
"decoded tag does not match; expected: {}, actual: {}",
tag,
decoded_tag
);
prop_assert_eq!(
wire_type,
decoded_wire_type,
"decoded wire type does not match; expected: {:?}, actual: {:?}",
wire_type,
decoded_wire_type
);
merge(wire_type, &mut roundtrip_value, &mut buf, DecodeContext::default())
.map_err(|error| TestCaseError::fail(error.to_string()))?;
}
prop_assert_eq!(value, roundtrip_value);
Ok(())
}
/// Merging a length-delimited payload that is not valid UTF-8 must fail and must leave
/// the destination string empty.
#[test]
fn string_merge_invalid_utf8() {
    let mut target = String::new();
    // Length prefix 2 followed by two continuation bytes, which are not valid UTF-8.
    let mut input: &[u8] = b"\x02\x80\x80";

    let outcome = string::merge(WireType::LengthDelimited, &mut target, &mut input, DecodeContext::default());

    outcome.expect_err("must be an error");
    assert!(target.is_empty());
}
// Generates round-trip property tests for map fields. It is only compiled when the
// `std_legacy` feature is enabled.
//
// The public arm takes a list of `(rust type, proto module)` pairs for keys and another
// for values. It expands into a `hash_map` and a `btree_map` module, each containing one
// module per key kind and one proptest function per value kind.
#[cfg(feature = "std_legacy")]
macro_rules! map_tests {
// Entry point: fan out over the two supported map implementations.
(keys: $keys:tt,
vals: $vals:tt) => {
mod hash_map {
map_tests!(@private HashMap, hash_map, $keys, $vals);
}
mod btree_map {
map_tests!(@private BTreeMap, btree_map, $keys, $vals);
}
};
// Internal arm: one module per key kind, named after the key's proto module.
(@private $map_type:ident,
$mod_name:ident,
[$(($key_ty:ty, $key_proto:ident)),*],
$vals:tt) => {
$(
mod $key_proto {
use std::collections::$map_type;
use proptest::prelude::*;
use crate::encoding::*;
use crate::encoding::test::check_collection_type;
map_tests!(@private $map_type, $mod_name, ($key_ty, $key_proto), $vals);
}
)*
};
// Internal arm: for a fixed key kind, one proptest per value kind. Each test feeds the
// map module's `encode`, `merge` and `encoded_len` into `check_collection_type`.
(@private $map_type:ident,
$mod_name:ident,
($key_ty:ty, $key_proto:ident),
[$(($val_ty:ty, $val_proto:ident)),*]) => {
$(
proptest! {
#[test]
fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) {
check_collection_type(values, tag, WireType::LengthDelimited,
|tag, values, buf| {
$mod_name::encode($key_proto::_encode_by_ref_tagged,
$key_proto::_encoded_len_by_ref_tagged,
$val_proto::_encode_by_ref_tagged,
$val_proto::_encoded_len_by_ref_tagged,
tag,
values,
buf)
},
|wire_type, values, buf, ctx| {
check_wire_type(WireType::LengthDelimited, wire_type)?;
$mod_name::merge($key_proto::merge,
$val_proto::merge,
values,
buf,
ctx)
},
|tag, values| {
$mod_name::encoded_len($key_proto::_encoded_len_by_ref_tagged,
$val_proto::_encoded_len_by_ref_tagged,
tag,
values)
})?;
}
}
)*
};
}
// Instantiates the map round-trip tests for every listed key kind crossed with every
// listed value kind. Each entry is `(Rust type, proto encoding module)`. The value list
// additionally contains `float`, `double` and `bytes`, which do not appear in the key list.
#[cfg(feature = "std_legacy")]
map_tests!(keys: [
(i32, int32),
(i64, int64),
(u32, uint32),
(u64, uint64),
(i32, sint32),
(i64, sint64),
(u32, fixed32),
(u64, fixed64),
(i32, sfixed32),
(i64, sfixed64),
(bool, bool),
(String, string)
],
vals: [
(f32, float),
(f64, double),
(i32, int32),
(i64, int64),
(u32, uint32),
(u64, uint64),
(i32, sint32),
(i64, sint64),
(u32, fixed32),
(u64, fixed64),
(i32, sfixed32),
(i64, sfixed64),
(bool, bool),
(String, string),
(Vec<u8>, bytes)
]);
/// A varint must decode correctly when its bytes are split across two chained buffers.
///
/// The samples are 128 plus the values just below and exactly at `1 << (7 * i)` for
/// `i` in `2..9`.
#[test]
fn split_varint_decoding() {
    let mut samples: Vec<u64> = Vec::with_capacity(10 * 2);
    samples.push(128);
    for septets in 2..9 {
        let boundary = 1u64 << (7 * septets);
        samples.push(boundary - 1);
        samples.push(boundary);
    }

    for expected in samples {
        let mut head = BytesMut::with_capacity(10);
        encode_varint(expected, &mut head);

        // Cut the encoding in the middle and present it as two chained chunks.
        let total_len = head.len();
        let split_at = total_len / 2;
        let tail = head.split_off(split_at);
        let mut chained = head.chain(tail);

        assert_eq!(chained.remaining(), total_len);
        assert_eq!(chained.chunk().len(), split_at);
        assert_eq!(expected, decode_varint(&mut chained).unwrap());
    }
}
}