use bytes::{Buf, BufMut};
use crate::error::DecodeError;
use crate::message_field::DefaultInstance;
/// Default nesting budget handed to the merge methods as the initial `depth`.
///
/// Used by [`Message::decode`], [`Message::decode_length_delimited`],
/// [`Message::merge_from_slice`] and `DecodeOptions::new`.
pub const RECURSION_LIMIT: u32 = 100_u32;
/// A protobuf message that can be encoded to and decoded from the binary wire
/// format.
///
/// Implementors supply [`compute_size`](Message::compute_size),
/// [`write_to`](Message::write_to), [`merge_field`](Message::merge_field),
/// [`cached_size`](Message::cached_size) and [`clear`](Message::clear); every
/// other method has a default implementation built on those five.
pub trait Message: DefaultInstance + Clone + PartialEq + Send + Sync {
    /// Computes the encoded size of the message body in bytes.
    ///
    /// The provided encoding methods call this right before
    /// [`write_to`](Message::write_to) and use the result as the length prefix
    /// and as a buffer-capacity hint.
    fn compute_size(&self) -> u32;

    /// Writes the message body (without any length prefix) to `buf`.
    fn write_to(&self, buf: &mut impl BufMut);

    /// Encodes the message body (without any length prefix) into `buf`.
    fn encode(&self, buf: &mut impl BufMut) {
        // The size value itself is unused here, but `compute_size` always runs
        // before `write_to`. NOTE(review): presumably implementations refresh
        // cached sub-message sizes in it that `write_to` relies on — confirm.
        self.compute_size();
        self.write_to(buf);
    }

    /// Encodes the message into `buf` preceded by its size as a varint.
    fn encode_length_delimited(&self, buf: &mut impl BufMut) {
        let len = self.compute_size();
        crate::encoding::encode_varint(u64::from(len), buf);
        self.write_to(buf);
    }

    /// Encodes the message body into a freshly allocated, exactly-sized `Vec`.
    fn encode_to_vec(&self) -> alloc::vec::Vec<u8> {
        let size = self.compute_size() as usize;
        let mut buf = alloc::vec::Vec::with_capacity(size);
        self.write_to(&mut buf);
        buf
    }

    /// Encodes the message body into an immutable [`bytes::Bytes`].
    fn encode_to_bytes(&self) -> bytes::Bytes {
        let size = self.compute_size() as usize;
        let mut buf = bytes::BytesMut::with_capacity(size);
        self.write_to(&mut buf);
        buf.freeze()
    }

    /// Decodes a message from all remaining bytes of `buf`, starting from
    /// `Self::default()` with a nesting budget of [`RECURSION_LIMIT`].
    ///
    /// # Errors
    /// Any [`DecodeError`] produced while merging fields.
    fn decode(buf: &mut impl Buf) -> Result<Self, DecodeError>
    where
        Self: Sized,
    {
        let mut msg = Self::default();
        msg.merge(buf, RECURSION_LIMIT)?;
        Ok(msg)
    }

    /// Decodes a message from the whole of `data`.
    ///
    /// # Errors
    /// Same as [`Message::decode`].
    fn decode_from_slice(mut data: &[u8]) -> Result<Self, DecodeError>
    where
        Self: Sized,
    {
        Self::decode(&mut data)
    }

    /// Decodes one varint-length-prefixed message from `buf`, leaving `buf`
    /// positioned just past it so several messages can be read back to back.
    ///
    /// Unlike [`Message::merge_length_delimited`], this does not consume any
    /// of the [`RECURSION_LIMIT`] budget for the outer message.
    ///
    /// # Errors
    /// - [`DecodeError::MessageTooLarge`] if the prefix exceeds `0x7FFF_FFFF`.
    /// - [`DecodeError::UnexpectedEof`] if the body is truncated or a field runs
    ///   past its end.
    /// - Any error produced while decoding the prefix or merging fields.
    fn decode_length_delimited(buf: &mut impl Buf) -> Result<Self, DecodeError>
    where
        Self: Sized,
    {
        let mut msg = Self::default();
        merge_length_prefixed(&mut msg, buf, RECURSION_LIMIT)?;
        Ok(msg)
    }

    /// Decodes the field introduced by `tag` from `buf` and merges it into
    /// `self`. `depth` is the nesting budget left for any sub-messages.
    fn merge_field(
        &mut self,
        tag: crate::encoding::Tag,
        buf: &mut impl Buf,
        depth: u32,
    ) -> Result<(), DecodeError>;

    /// Merges fields from `buf` until no more than `limit` bytes remain.
    ///
    /// A field that straddles the boundary is decoded in full; callers detect
    /// that by comparing `buf.remaining()` with `limit` afterwards.
    fn merge_to_limit(
        &mut self,
        buf: &mut impl Buf,
        depth: u32,
        limit: usize,
    ) -> Result<(), DecodeError> {
        while buf.remaining() > limit {
            let tag = crate::encoding::Tag::decode(buf)?;
            self.merge_field(tag, buf, depth)?;
        }
        Ok(())
    }

    /// Merges a group body into `self`, consuming fields up to and including
    /// the `EndGroup` tag whose field number is `field_number`.
    ///
    /// # Errors
    /// - [`DecodeError::RecursionLimitExceeded`] if `depth` is already 0.
    /// - [`DecodeError::UnexpectedEof`] if `buf` runs out before the end tag.
    /// - [`DecodeError::InvalidEndGroup`] if an end tag for a different field
    ///   number is found.
    fn merge_group(
        &mut self,
        buf: &mut impl Buf,
        depth: u32,
        field_number: u32,
    ) -> Result<(), DecodeError> {
        let depth = depth
            .checked_sub(1)
            .ok_or(DecodeError::RecursionLimitExceeded)?;
        loop {
            // A group has no length prefix, so running dry before the end tag
            // is the only way to detect truncation.
            if !buf.has_remaining() {
                return Err(DecodeError::UnexpectedEof);
            }
            let tag = crate::encoding::Tag::decode(buf)?;
            if tag.wire_type() == crate::encoding::WireType::EndGroup {
                return if tag.field_number() == field_number {
                    Ok(())
                } else {
                    Err(DecodeError::InvalidEndGroup(tag.field_number()))
                };
            }
            self.merge_field(tag, buf, depth)?;
        }
    }

    /// Merges all remaining bytes of `buf` into `self`.
    fn merge(&mut self, buf: &mut impl Buf, depth: u32) -> Result<(), DecodeError> {
        self.merge_to_limit(buf, depth, 0)
    }

    /// Merges the whole of `data` into `self` using [`RECURSION_LIMIT`].
    fn merge_from_slice(&mut self, mut data: &[u8]) -> Result<(), DecodeError> {
        self.merge(&mut data, RECURSION_LIMIT)
    }

    /// Merges one varint-length-prefixed sub-message from `buf` into `self`,
    /// spending one level of the `depth` budget.
    ///
    /// # Errors
    /// - [`DecodeError::RecursionLimitExceeded`] if `depth` is already 0.
    /// - [`DecodeError::MessageTooLarge`] if the prefix exceeds `0x7FFF_FFFF`.
    /// - [`DecodeError::UnexpectedEof`] if the body is truncated or a field runs
    ///   past its end.
    fn merge_length_delimited(
        &mut self,
        buf: &mut impl Buf,
        depth: u32,
    ) -> Result<(), DecodeError> {
        let depth = depth
            .checked_sub(1)
            .ok_or(DecodeError::RecursionLimitExceeded)?;
        merge_length_prefixed(self, buf, depth)
    }

    /// Returns the size cached for this message. NOTE(review): presumably the
    /// value recorded by the last `compute_size` call, as the test
    /// implementation in this file does — confirm for generated code.
    fn cached_size(&self) -> u32;

    /// Resets `self`; the test implementation in this file restores
    /// `Self::default()`.
    fn clear(&mut self);
}

/// Largest length prefix (in bytes) accepted by the provided length-delimited
/// decoders: `0x7FFF_FFFF`, i.e. 2 GiB - 1.
const MAX_LENGTH_PREFIX: u64 = 0x7FFF_FFFF;

/// Reads a varint length prefix from `buf` and merges exactly that many of the
/// following bytes into `msg` with nesting budget `depth`.
///
/// On success `buf` is positioned just past the body.
///
/// # Errors
/// - [`DecodeError::MessageTooLarge`] if the prefix exceeds
///   [`MAX_LENGTH_PREFIX`] or does not fit in `usize`.
/// - [`DecodeError::UnexpectedEof`] if fewer than `len` bytes remain, or if a
///   field inside the body consumed bytes beyond its declared end.
/// - Any error from the varint decoder or from `msg`'s field merging.
fn merge_length_prefixed<M: Message>(
    msg: &mut M,
    buf: &mut impl Buf,
    depth: u32,
) -> Result<(), DecodeError> {
    let len_u64 = crate::encoding::decode_varint(buf)?;
    if len_u64 > MAX_LENGTH_PREFIX {
        return Err(DecodeError::MessageTooLarge);
    }
    let len = usize::try_from(len_u64).map_err(|_| DecodeError::MessageTooLarge)?;
    let available = buf.remaining();
    if available < len {
        return Err(DecodeError::UnexpectedEof);
    }
    // Number of bytes that must still be unread once the body is consumed.
    let limit = available - len;
    msg.merge_to_limit(buf, depth, limit)?;
    let remaining = buf.remaining();
    if remaining < limit {
        // The last field decoded inside the body ran past its end.
        return Err(DecodeError::UnexpectedEof);
    }
    if remaining > limit {
        // Only reachable if an overridden `merge_to_limit` stopped early; skip
        // the rest of the body so `buf` still ends up just past it.
        buf.advance(remaining - limit);
    }
    Ok(())
}
/// Configurable limits applied when decoding messages.
///
/// Create with `DecodeOptions::new()` (or `Default`) and adjust with the
/// `with_recursion_limit` / `with_max_message_size` builder methods.
#[derive(Debug, Clone)]
pub struct DecodeOptions {
    /// Initial `depth` budget passed to the `Message` merge methods.
    recursion_limit: u32,
    /// Largest input, in bytes, accepted by the decode/merge entry points.
    max_message_size: usize,
}
/// Default (and, for length-prefixed input, absolute) size cap in bytes:
/// `0x7FFF_FFFF`, i.e. 2 GiB - 1.
const DEFAULT_MAX_MESSAGE_SIZE: usize = i32::MAX as usize;
impl Default for DecodeOptions {
fn default() -> Self {
Self::new()
}
}
impl DecodeOptions {
    /// Creates options with a recursion limit of [`RECURSION_LIMIT`] and a
    /// maximum message size of `0x7FFF_FFFF` bytes.
    pub fn new() -> Self {
        Self {
            recursion_limit: RECURSION_LIMIT,
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
        }
    }

    /// Sets the recursion limit, i.e. the initial `depth` handed to the merge
    /// methods.
    #[must_use]
    pub fn with_recursion_limit(mut self, limit: u32) -> Self {
        self.recursion_limit = limit;
        self
    }

    /// Sets the maximum accepted input size in bytes.
    ///
    /// Length-prefixed decoding never accepts a prefix above `0x7FFF_FFFF`,
    /// whatever this value is. `usize::MAX` effectively disables the check for
    /// the other entry points.
    #[must_use]
    pub fn with_max_message_size(mut self, max_bytes: usize) -> Self {
        self.max_message_size = max_bytes;
        self
    }

    /// Returns the configured recursion limit.
    pub fn recursion_limit(&self) -> u32 {
        self.recursion_limit
    }

    /// Returns the configured maximum message size in bytes.
    pub fn max_message_size(&self) -> usize {
        self.max_message_size
    }

    /// Decodes a message from all remaining bytes of `buf`.
    ///
    /// # Errors
    /// [`DecodeError::MessageTooLarge`] if `buf.remaining()` exceeds the
    /// configured maximum; otherwise any error from [`Message::merge`].
    pub fn decode<M: Message>(&self, buf: &mut impl Buf) -> Result<M, DecodeError> {
        if buf.remaining() > self.max_message_size {
            return Err(DecodeError::MessageTooLarge);
        }
        let mut msg = M::default();
        msg.merge(buf, self.recursion_limit)?;
        Ok(msg)
    }

    /// Decodes a message from the whole of `data`.
    ///
    /// # Errors
    /// Same as [`DecodeOptions::decode`].
    pub fn decode_from_slice<M: Message>(&self, data: &[u8]) -> Result<M, DecodeError> {
        if data.len() > self.max_message_size {
            return Err(DecodeError::MessageTooLarge);
        }
        let mut msg = M::default();
        msg.merge(&mut &*data, self.recursion_limit)?;
        Ok(msg)
    }

    /// Decodes one varint-length-prefixed message from `buf`, leaving `buf`
    /// positioned just past it.
    ///
    /// # Errors
    /// - [`DecodeError::MessageTooLarge`] if the prefix exceeds the smaller of
    ///   the configured maximum and `0x7FFF_FFFF`.
    /// - [`DecodeError::UnexpectedEof`] if the body is truncated or a field runs
    ///   past its end.
    pub fn decode_length_delimited<M: Message>(
        &self,
        buf: &mut impl Buf,
    ) -> Result<M, DecodeError> {
        let max = core::cmp::min(
            self.max_message_size as u64,
            DEFAULT_MAX_MESSAGE_SIZE as u64,
        );
        let len_u64 = crate::encoding::decode_varint(buf)?;
        if len_u64 > max {
            return Err(DecodeError::MessageTooLarge);
        }
        let len = usize::try_from(len_u64).map_err(|_| DecodeError::MessageTooLarge)?;
        if buf.remaining() < len {
            return Err(DecodeError::UnexpectedEof);
        }
        // Number of bytes that must still be unread once the body is consumed.
        let limit = buf.remaining() - len;
        let mut msg = M::default();
        msg.merge_to_limit(buf, self.recursion_limit, limit)?;
        if buf.remaining() != limit {
            let remaining = buf.remaining();
            if remaining > limit {
                buf.advance(remaining - limit);
            } else {
                // The last field decoded inside the body ran past its end.
                return Err(DecodeError::UnexpectedEof);
            }
        }
        Ok(msg)
    }

    /// Merges all remaining bytes of `buf` into `msg`.
    ///
    /// # Errors
    /// [`DecodeError::MessageTooLarge`] if `buf.remaining()` exceeds the
    /// configured maximum; otherwise any error from [`Message::merge`].
    pub fn merge<M: Message>(&self, msg: &mut M, buf: &mut impl Buf) -> Result<(), DecodeError> {
        if buf.remaining() > self.max_message_size {
            return Err(DecodeError::MessageTooLarge);
        }
        msg.merge(buf, self.recursion_limit)
    }

    /// Merges the whole of `data` into `msg`.
    ///
    /// # Errors
    /// Same as [`DecodeOptions::merge`].
    pub fn merge_from_slice<M: Message>(
        &self,
        msg: &mut M,
        data: &[u8],
    ) -> Result<(), DecodeError> {
        if data.len() > self.max_message_size {
            return Err(DecodeError::MessageTooLarge);
        }
        msg.merge(&mut &*data, self.recursion_limit)
    }

    /// Decodes a view type `V` that borrows from `buf`, after checking
    /// `buf.len()` against the configured maximum.
    ///
    /// # Errors
    /// [`DecodeError::MessageTooLarge`] if `buf` is too long; otherwise any
    /// error from `V::decode_view_with_limit`.
    pub fn decode_view<'a, V: crate::view::MessageView<'a>>(
        &self,
        buf: &'a [u8],
    ) -> Result<V, DecodeError> {
        if buf.len() > self.max_message_size {
            return Err(DecodeError::MessageTooLarge);
        }
        V::decode_view_with_limit(buf, self.recursion_limit)
    }

    /// Reads `reader` to EOF, up to the configured maximum, and decodes the
    /// bytes as one message.
    ///
    /// # Errors
    /// I/O errors are returned unchanged. Oversized input and decode failures
    /// are reported as `ErrorKind::InvalidData` wrapping a [`DecodeError`].
    #[cfg(feature = "std")]
    pub fn decode_reader<M: Message>(
        &self,
        reader: &mut impl std::io::Read,
    ) -> Result<M, std::io::Error> {
        let bytes = self.read_limited(reader)?;
        self.decode_from_slice::<M>(&bytes)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
    }

    /// Reads one varint-length-prefixed message from `reader`, consuming
    /// exactly the prefix and the body so further messages can follow.
    ///
    /// # Errors
    /// - `ErrorKind::InvalidData` if the prefix is malformed or exceeds the
    ///   smaller of the configured maximum and `0x7FFF_FFFF`, or if the body
    ///   fails to decode.
    /// - `ErrorKind::UnexpectedEof` if the stream ends before the full body.
    /// - Any other I/O error from `reader`.
    #[cfg(feature = "std")]
    pub fn decode_length_delimited_reader<M: Message>(
        &self,
        reader: &mut impl std::io::Read,
    ) -> Result<M, std::io::Error> {
        use std::io::Read as _;
        // Cap on the up-front allocation; the buffer grows past it only as body
        // bytes actually arrive.
        const PREALLOC_LIMIT: usize = 64 * 1024;

        let len = read_varint(reader)?;
        let max = core::cmp::min(
            self.max_message_size as u64,
            DEFAULT_MAX_MESSAGE_SIZE as u64,
        );
        if len > max {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                DecodeError::MessageTooLarge,
            ));
        }
        let len_usize = usize::try_from(len).map_err(|_| {
            std::io::Error::new(std::io::ErrorKind::InvalidData, DecodeError::MessageTooLarge)
        })?;
        // The prefix is untrusted: allocating `len` bytes eagerly would let a
        // short stream claiming a ~2 GiB body force a huge allocation. Read
        // through `take` instead so memory tracks the bytes actually received.
        let mut buf = alloc::vec::Vec::with_capacity(len_usize.min(PREALLOC_LIMIT));
        reader.take(len).read_to_end(&mut buf)?;
        if buf.len() < len_usize {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                DecodeError::UnexpectedEof,
            ));
        }
        self.decode_from_slice::<M>(&buf)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
    }

    /// Reads `reader` to EOF, failing once more than `max_message_size` bytes
    /// have been read.
    #[cfg(feature = "std")]
    fn read_limited(
        &self,
        reader: &mut impl std::io::Read,
    ) -> Result<alloc::vec::Vec<u8>, std::io::Error> {
        use std::io::Read as _;
        // Read one byte past the limit so "exactly max" and "more than max" can
        // be told apart. `saturating_add` keeps a `usize::MAX` limit from
        // overflowing (a panic in debug, a zero-byte read limit in release).
        let read_cap = (self.max_message_size as u64).saturating_add(1);
        let mut buf = alloc::vec::Vec::new();
        reader.take(read_cap).read_to_end(&mut buf)?;
        if buf.len() > self.max_message_size {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                DecodeError::MessageTooLarge,
            ));
        }
        Ok(buf)
    }
}
/// Reads one base-128 varint from `reader`, one byte at a time.
///
/// At most ten bytes are consumed. The tenth byte can only carry bit 63, so it
/// must be `0x00` or `0x01`.
///
/// # Errors
/// - `ErrorKind::UnexpectedEof` (from `read_exact`) if the stream ends before
///   the varint terminates.
/// - `ErrorKind::InvalidData` wrapping [`DecodeError::VarintTooLong`] if the
///   tenth byte is larger than `0x01`.
#[cfg(feature = "std")]
fn read_varint(reader: &mut impl std::io::Read) -> Result<u64, std::io::Error> {
    /// Pulls a single byte out of `src`.
    fn next_byte(src: &mut impl std::io::Read) -> Result<u8, std::io::Error> {
        let mut cell = [0u8; 1];
        src.read_exact(&mut cell)?;
        Ok(cell[0])
    }

    let mut acc = 0u64;
    // The first nine bytes each contribute seven payload bits (shifts 0..=56).
    for shift in (0u32..63).step_by(7) {
        let byte = next_byte(reader)?;
        acc |= u64::from(byte & 0x7F) << shift;
        if byte & 0x80 == 0 {
            return Ok(acc);
        }
    }
    // Only bit 63 is left; anything above 0x01 would overflow a u64.
    let last = next_byte(reader)?;
    if last > 0x01 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            DecodeError::VarintTooLong,
        ));
    }
    Ok(acc | (u64::from(last) << 63))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cached_size::CachedSize;
    use crate::encoding::encode_varint;
    use crate::error::DecodeError;
    use crate::message_field::DefaultInstance;

    /// Minimal hand-written message with a single `int32 value = 1;` field.
    #[derive(Clone, Debug, Default, PartialEq)]
    struct FlatMsg {
        value: i32,
        __buffa_cached_size: CachedSize,
    }

    // SAFETY: NOTE(review): the invariant required by `DefaultInstance` is
    // defined elsewhere; presumably it must return a `'static` value equal to
    // `FlatMsg::default()`, which this lazily-initialised static does — confirm
    // against the trait's `# Safety` docs.
    unsafe impl DefaultInstance for FlatMsg {
        fn default_instance() -> &'static Self {
            static INST: crate::__private::OnceBox<FlatMsg> = crate::__private::OnceBox::new();
            INST.get_or_init(|| alloc::boxed::Box::new(FlatMsg::default()))
        }
    }

    impl Message for FlatMsg {
        fn compute_size(&self) -> u32 {
            // Proto3 semantics: a zero value is not emitted at all.
            let size = if self.value != 0 {
                1 + crate::types::int32_encoded_len(self.value) as u32
            } else {
                0
            };
            self.__buffa_cached_size.set(size);
            size
        }
        fn write_to(&self, buf: &mut impl BufMut) {
            if self.value != 0 {
                crate::encoding::Tag::new(1, crate::encoding::WireType::Varint).encode(buf);
                crate::types::encode_int32(self.value, buf);
            }
        }
        fn merge_field(
            &mut self,
            tag: crate::encoding::Tag,
            buf: &mut impl Buf,
            _depth: u32,
        ) -> Result<(), DecodeError> {
            match tag.field_number() {
                1 => {
                    self.value = crate::types::decode_int32(buf)?;
                }
                _ => {
                    crate::encoding::skip_field(tag, buf)?;
                }
            }
            Ok(())
        }
        fn cached_size(&self) -> u32 {
            self.__buffa_cached_size.get()
        }
        fn clear(&mut self) {
            *self = Self::default();
        }
    }

    /// Encodes `msg` with a varint length prefix.
    fn wire_bytes(msg: &FlatMsg) -> alloc::vec::Vec<u8> {
        let mut buf = alloc::vec::Vec::new();
        msg.encode_length_delimited(&mut buf);
        buf
    }

    #[test]
    fn test_merge_length_delimited_basic() {
        let src = FlatMsg {
            value: 42,
            __buffa_cached_size: CachedSize::default(),
        };
        let mut dst = FlatMsg::default();
        dst.merge_length_delimited(&mut wire_bytes(&src).as_slice(), RECURSION_LIMIT)
            .unwrap();
        assert_eq!(dst.value, 42);
    }

    #[test]
    fn test_merge_length_delimited_merges_into_existing() {
        let mut dst = FlatMsg::default();
        dst.merge_length_delimited(
            &mut wire_bytes(&FlatMsg {
                value: 1,
                __buffa_cached_size: CachedSize::default(),
            })
            .as_slice(),
            RECURSION_LIMIT,
        )
        .unwrap();
        assert_eq!(dst.value, 1);
        dst.merge_length_delimited(
            &mut wire_bytes(&FlatMsg {
                value: 2,
                __buffa_cached_size: CachedSize::default(),
            })
            .as_slice(),
            RECURSION_LIMIT,
        )
        .unwrap();
        assert_eq!(dst.value, 2);
    }

    // The prefix claims 10 body bytes but only 2 follow.
    #[test]
    fn test_merge_length_delimited_truncated() {
        let mut buf = alloc::vec::Vec::new();
        encode_varint(10, &mut buf);
        buf.extend_from_slice(&[0x01, 0x01]);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_length_delimited(&mut buf.as_slice(), RECURSION_LIMIT),
            Err(DecodeError::UnexpectedEof)
        );
    }

    // A 2^31 prefix is rejected before any body bytes are needed.
    #[test]
    fn test_merge_length_delimited_oversized() {
        let mut buf = alloc::vec::Vec::new();
        encode_varint(0x8000_0000u64, &mut buf);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_length_delimited(&mut buf.as_slice(), RECURSION_LIMIT),
            Err(DecodeError::MessageTooLarge)
        );
    }

    // Entering a sub-message costs one unit of depth: 0 fails, 1 succeeds.
    #[test]
    fn test_merge_length_delimited_recursion_limit() {
        let src = FlatMsg {
            value: 7,
            __buffa_cached_size: CachedSize::default(),
        };
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_length_delimited(&mut wire_bytes(&src).as_slice(), 0),
            Err(DecodeError::RecursionLimitExceeded)
        );
        dst.merge_length_delimited(&mut wire_bytes(&src).as_slice(), 1)
            .unwrap();
        assert_eq!(dst.value, 7);
    }

    #[test]
    fn test_decode_from_slice_basic() {
        let src = FlatMsg {
            value: 42,
            __buffa_cached_size: CachedSize::default(),
        };
        let bytes = src.encode_to_vec();
        let dst = FlatMsg::decode_from_slice(&bytes).unwrap();
        assert_eq!(dst.value, 42);
    }

    #[test]
    fn test_encode_to_bytes_matches_encode_to_vec() {
        let src = FlatMsg {
            value: 42,
            __buffa_cached_size: CachedSize::default(),
        };
        let vec = src.encode_to_vec();
        let bytes = src.encode_to_bytes();
        assert_eq!(vec.as_slice(), bytes.as_ref());
        let dst = FlatMsg::decode_from_slice(&bytes).unwrap();
        assert_eq!(dst.value, 42);
        assert!(FlatMsg::default().encode_to_bytes().is_empty());
    }

    #[test]
    fn test_decode_from_slice_empty() {
        let dst = FlatMsg::decode_from_slice(&[]).unwrap();
        assert_eq!(dst.value, 0);
    }

    // A lone 0xFF is an unterminated tag varint.
    #[test]
    fn test_decode_from_slice_invalid_returns_error() {
        let result = FlatMsg::decode_from_slice(&[0xFF]);
        assert!(result.is_err());
    }

    #[test]
    fn test_merge_from_slice_basic() {
        let src = FlatMsg {
            value: 7,
            __buffa_cached_size: CachedSize::default(),
        };
        let bytes = src.encode_to_vec();
        let mut dst = FlatMsg::default();
        dst.merge_from_slice(&bytes).unwrap();
        assert_eq!(dst.value, 7);
    }

    #[test]
    fn test_merge_from_slice_last_wins() {
        let src1 = FlatMsg {
            value: 1,
            __buffa_cached_size: CachedSize::default(),
        };
        let src2 = FlatMsg {
            value: 2,
            __buffa_cached_size: CachedSize::default(),
        };
        let mut dst = FlatMsg::default();
        dst.merge_from_slice(&src1.encode_to_vec()).unwrap();
        dst.merge_from_slice(&src2.encode_to_vec()).unwrap();
        assert_eq!(dst.value, 2);
    }

    #[test]
    fn test_decode_options_default_works() {
        let src = FlatMsg {
            value: 99,
            __buffa_cached_size: CachedSize::default(),
        };
        let bytes = src.encode_to_vec();
        let msg: FlatMsg = DecodeOptions::new().decode_from_slice(&bytes).unwrap();
        assert_eq!(msg.value, 99);
    }

    #[test]
    fn test_decode_options_max_message_size_rejects() {
        let src = FlatMsg {
            value: 42,
            __buffa_cached_size: CachedSize::default(),
        };
        let bytes = src.encode_to_vec();
        let result: Result<FlatMsg, _> = DecodeOptions::new()
            .with_max_message_size(1)
            .decode_from_slice(&bytes);
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    // The size limit is inclusive: exactly `max` bytes is accepted.
    #[test]
    fn test_decode_options_max_message_size_exact_boundary() {
        let src = FlatMsg {
            value: 42,
            __buffa_cached_size: CachedSize::default(),
        };
        let bytes = src.encode_to_vec();
        let msg: FlatMsg = DecodeOptions::new()
            .with_max_message_size(bytes.len())
            .decode_from_slice(&bytes)
            .unwrap();
        assert_eq!(msg.value, 42);
        let result: Result<FlatMsg, _> = DecodeOptions::new()
            .with_max_message_size(bytes.len() - 1)
            .decode_from_slice(&bytes);
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_decode_options_custom_recursion_limit() {
        let src = FlatMsg {
            value: 7,
            __buffa_cached_size: CachedSize::default(),
        };
        let bytes = src.encode_to_vec();
        let msg: FlatMsg = DecodeOptions::new()
            .with_recursion_limit(1)
            .decode_from_slice(&bytes)
            .unwrap();
        assert_eq!(msg.value, 7);
    }

    #[test]
    fn test_decode_options_merge() {
        let src = FlatMsg {
            value: 55,
            __buffa_cached_size: CachedSize::default(),
        };
        let bytes = src.encode_to_vec();
        let mut msg = FlatMsg::default();
        DecodeOptions::new()
            .merge_from_slice(&mut msg, &bytes)
            .unwrap();
        assert_eq!(msg.value, 55);
    }

    #[test]
    fn test_decode_options_merge_rejects_oversize() {
        let src = FlatMsg {
            value: 55,
            __buffa_cached_size: CachedSize::default(),
        };
        let bytes = src.encode_to_vec();
        let mut msg = FlatMsg::default();
        let result = DecodeOptions::new()
            .with_max_message_size(1)
            .merge_from_slice(&mut msg, &bytes);
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_decode_options_length_delimited() {
        let src = FlatMsg {
            value: 42,
            __buffa_cached_size: CachedSize::default(),
        };
        let mut ld_bytes = alloc::vec::Vec::new();
        src.encode_length_delimited(&mut ld_bytes);
        let msg: FlatMsg = DecodeOptions::new()
            .decode_length_delimited(&mut ld_bytes.as_slice())
            .unwrap();
        assert_eq!(msg.value, 42);
    }

    #[test]
    fn test_decode_options_length_delimited_rejects_oversize() {
        let src = FlatMsg {
            value: 42,
            __buffa_cached_size: CachedSize::default(),
        };
        let mut ld_bytes = alloc::vec::Vec::new();
        src.encode_length_delimited(&mut ld_bytes);
        let result: Result<FlatMsg, _> = DecodeOptions::new()
            .with_max_message_size(1)
            .decode_length_delimited(&mut ld_bytes.as_slice());
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn decode_options_getters_return_defaults() {
        let opts = DecodeOptions::new();
        assert_eq!(opts.recursion_limit(), RECURSION_LIMIT);
        assert_eq!(opts.max_message_size(), 0x7FFF_FFFF);
    }

    #[test]
    fn decode_options_getters_return_custom_values() {
        let opts = DecodeOptions::new()
            .with_recursion_limit(42)
            .with_max_message_size(1024);
        assert_eq!(opts.recursion_limit(), 42);
        assert_eq!(opts.max_message_size(), 1024);
    }

    #[test]
    fn test_decode_options_default_impl() {
        let opts = DecodeOptions::default();
        assert_eq!(opts.recursion_limit(), RECURSION_LIMIT);
        assert_eq!(opts.max_message_size(), 0x7FFF_FFFF);
    }

    #[test]
    fn test_decode_options_decode_buf() {
        let src = FlatMsg {
            value: 123,
            ..Default::default()
        };
        let bytes = src.encode_to_vec();
        let msg: FlatMsg = DecodeOptions::new().decode(&mut bytes.as_slice()).unwrap();
        assert_eq!(msg.value, 123);
        let result: Result<FlatMsg, _> = DecodeOptions::new()
            .with_max_message_size(1)
            .decode(&mut bytes.as_slice());
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_decode_options_merge_buf() {
        let src = FlatMsg {
            value: 77,
            ..Default::default()
        };
        let bytes = src.encode_to_vec();
        let mut msg = FlatMsg::default();
        DecodeOptions::new()
            .merge(&mut msg, &mut bytes.as_slice())
            .unwrap();
        assert_eq!(msg.value, 77);
        let mut msg = FlatMsg::default();
        let result = DecodeOptions::new()
            .with_max_message_size(1)
            .merge(&mut msg, &mut bytes.as_slice());
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_message_encode_trait_default() {
        let src = FlatMsg {
            value: 42,
            ..Default::default()
        };
        let mut buf = alloc::vec::Vec::new();
        src.encode(&mut buf);
        assert_eq!(buf, src.encode_to_vec());
    }

    #[test]
    fn test_message_decode_length_delimited_trait_default() {
        let src = FlatMsg {
            value: 42,
            ..Default::default()
        };
        let mut ld = alloc::vec::Vec::new();
        src.encode_length_delimited(&mut ld);
        let got = FlatMsg::decode_length_delimited(&mut ld.as_slice()).unwrap();
        assert_eq!(got.value, 42);
    }

    #[test]
    fn test_message_decode_length_delimited_oversize() {
        let mut buf = alloc::vec::Vec::new();
        encode_varint(0x8000_0000u64, &mut buf);
        let result = FlatMsg::decode_length_delimited(&mut buf.as_slice());
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_message_decode_length_delimited_truncated() {
        let mut buf = alloc::vec::Vec::new();
        encode_varint(10, &mut buf);
        buf.push(0x08);
        buf.push(0x01);
        let result = FlatMsg::decode_length_delimited(&mut buf.as_slice());
        assert_eq!(result, Err(DecodeError::UnexpectedEof));
    }

    // Two framed messages back to back: each decode consumes exactly one.
    #[test]
    fn test_message_decode_length_delimited_with_trailing() {
        let a = FlatMsg {
            value: 1,
            ..Default::default()
        };
        let b = FlatMsg {
            value: 2,
            ..Default::default()
        };
        let mut buf = alloc::vec::Vec::new();
        a.encode_length_delimited(&mut buf);
        b.encode_length_delimited(&mut buf);
        let mut cur = buf.as_slice();
        let first = FlatMsg::decode_length_delimited(&mut cur).unwrap();
        assert_eq!(first.value, 1);
        let second = FlatMsg::decode_length_delimited(&mut cur).unwrap();
        assert_eq!(second.value, 2);
        assert!(cur.is_empty());
    }

    /// Builds a group body (the part after the StartGroup tag): an optional
    /// `value` field followed by the EndGroup tag for `group_field_number`.
    fn group_bytes(value: i32, group_field_number: u32) -> alloc::vec::Vec<u8> {
        use crate::encoding::{Tag, WireType};
        let mut buf = alloc::vec::Vec::new();
        if value != 0 {
            Tag::new(1, WireType::Varint).encode(&mut buf);
            crate::types::encode_int32(value, &mut buf);
        }
        Tag::new(group_field_number, WireType::EndGroup).encode(&mut buf);
        buf
    }

    #[test]
    fn test_merge_group_basic() {
        let data = group_bytes(42, 5);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5)
            .unwrap();
        assert_eq!(dst.value, 42);
    }

    #[test]
    fn test_merge_group_empty() {
        let data = group_bytes(0, 3);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 3)
            .unwrap();
        assert_eq!(dst.value, 0);
    }

    #[test]
    fn test_merge_group_merges_into_existing() {
        let data1 = group_bytes(1, 5);
        let data2 = group_bytes(2, 5);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data1.as_slice(), RECURSION_LIMIT, 5)
            .unwrap();
        assert_eq!(dst.value, 1);
        dst.merge_group(&mut data2.as_slice(), RECURSION_LIMIT, 5)
            .unwrap();
        assert_eq!(dst.value, 2);
    }

    #[test]
    fn test_merge_group_recursion_limit_zero() {
        let data = group_bytes(42, 5);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_group(&mut data.as_slice(), 0, 5),
            Err(DecodeError::RecursionLimitExceeded)
        );
    }

    #[test]
    fn test_merge_group_recursion_limit_one_succeeds() {
        let data = group_bytes(7, 5);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data.as_slice(), 1, 5).unwrap();
        assert_eq!(dst.value, 7);
    }

    #[test]
    fn test_merge_group_mismatched_end() {
        use crate::encoding::{Tag, WireType};
        let mut data = alloc::vec::Vec::new();
        Tag::new(99, WireType::EndGroup).encode(&mut data);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5),
            Err(DecodeError::InvalidEndGroup(99))
        );
    }

    // A field with no EndGroup tag after it: the buffer runs dry first.
    #[test]
    fn test_merge_group_truncated() {
        use crate::encoding::{Tag, WireType};
        let mut data = alloc::vec::Vec::new();
        Tag::new(1, WireType::Varint).encode(&mut data);
        crate::types::encode_int32(42, &mut data);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5),
            Err(DecodeError::UnexpectedEof)
        );
    }

    #[test]
    fn test_merge_group_empty_buffer() {
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_group(&mut [].as_slice(), RECURSION_LIMIT, 5),
            Err(DecodeError::UnexpectedEof)
        );
    }

    #[test]
    fn test_merge_group_unknown_fields_skipped() {
        use crate::encoding::{Tag, WireType};
        let mut data = alloc::vec::Vec::new();
        Tag::new(99, WireType::Varint).encode(&mut data);
        crate::encoding::encode_varint(0, &mut data);
        Tag::new(1, WireType::Varint).encode(&mut data);
        crate::types::encode_int32(99, &mut data);
        Tag::new(5, WireType::EndGroup).encode(&mut data);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data.as_slice(), RECURSION_LIMIT, 5)
            .unwrap();
        assert_eq!(dst.value, 99);
    }

    // Bytes after the EndGroup tag must be left unread for the caller.
    #[test]
    fn test_merge_group_trailing_data_preserved() {
        let mut data = group_bytes(42, 5);
        data.extend_from_slice(&[0xDE, 0xAD]);
        let mut cur = data.as_slice();
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut cur, RECURSION_LIMIT, 5).unwrap();
        assert_eq!(dst.value, 42);
        assert_eq!(cur, &[0xDE, 0xAD]);
    }

    #[cfg(feature = "std")]
    mod read_varint_tests {
        use super::super::read_varint;
        use crate::encoding::encode_varint;

        #[test]
        fn roundtrip_values() {
            let cases: &[u64] = &[0, 1, 127, 128, 300, 1 << 14, 1 << 35, 1 << 63, u64::MAX];
            for &v in cases {
                let mut buf = Vec::new();
                encode_varint(v, &mut buf);
                let got = read_varint(&mut buf.as_slice()).unwrap();
                assert_eq!(got, v, "roundtrip failed for {v}");
            }
        }

        // Nine continuation bytes then 0x02: the tenth byte would set bit 64.
        #[test]
        fn rejects_10th_byte_overflow() {
            let mut bad: Vec<u8> = vec![0xFF; 9];
            bad.push(0x02);
            let err = read_varint(&mut bad.as_slice()).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        // Ten 0xFF bytes: rejected at the tenth byte, before an eleventh is read.
        #[test]
        fn rejects_11th_byte() {
            let bad: &[u8] = &[0xFF; 10];
            let err = read_varint(&mut &bad[..]).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[test]
        fn u64_max_roundtrips() {
            let mut buf = Vec::new();
            encode_varint(u64::MAX, &mut buf);
            assert_eq!(buf.len(), 10);
            assert_eq!(buf[9], 0x01);
            let got = read_varint(&mut buf.as_slice()).unwrap();
            assert_eq!(got, u64::MAX);
        }

        #[test]
        fn eof_before_terminator_is_error() {
            let bad: &[u8] = &[0x80];
            let err = read_varint(&mut &bad[..]).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }

        #[test]
        fn empty_input_is_error() {
            let err = read_varint(&mut &[][..]).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }
    }

    #[cfg(feature = "std")]
    mod reader_tests {
        use super::*;

        #[test]
        fn decode_reader_basic() {
            let src = FlatMsg {
                value: 42,
                ..Default::default()
            };
            let bytes = src.encode_to_vec();
            let msg: FlatMsg = DecodeOptions::new()
                .decode_reader(&mut bytes.as_slice())
                .unwrap();
            assert_eq!(msg.value, 42);
        }

        #[test]
        fn decode_reader_rejects_oversize() {
            let src = FlatMsg {
                value: 42,
                ..Default::default()
            };
            let bytes = src.encode_to_vec();
            let err = DecodeOptions::new()
                .with_max_message_size(1)
                .decode_reader::<FlatMsg>(&mut bytes.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[test]
        fn decode_reader_exact_boundary() {
            let src = FlatMsg {
                value: 42,
                ..Default::default()
            };
            let bytes = src.encode_to_vec();
            let msg: FlatMsg = DecodeOptions::new()
                .with_max_message_size(bytes.len())
                .decode_reader(&mut bytes.as_slice())
                .unwrap();
            assert_eq!(msg.value, 42);
        }

        // I/O errors must surface with their original kind, not as InvalidData.
        #[test]
        fn decode_reader_propagates_read_error() {
            struct ErrReader;
            impl std::io::Read for ErrReader {
                fn read(&mut self, _: &mut [u8]) -> std::io::Result<usize> {
                    Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone"))
                }
            }
            let err = DecodeOptions::new()
                .decode_reader::<FlatMsg>(&mut ErrReader)
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
        }

        #[test]
        fn decode_length_delimited_reader_basic() {
            let src = FlatMsg {
                value: 99,
                ..Default::default()
            };
            let mut ld = Vec::new();
            src.encode_length_delimited(&mut ld);
            let msg: FlatMsg = DecodeOptions::new()
                .decode_length_delimited_reader(&mut ld.as_slice())
                .unwrap();
            assert_eq!(msg.value, 99);
        }

        #[test]
        fn decode_length_delimited_reader_rejects_oversize_prefix() {
            let src = FlatMsg {
                value: 99,
                ..Default::default()
            };
            let mut ld = Vec::new();
            src.encode_length_delimited(&mut ld);
            let err = DecodeOptions::new()
                .with_max_message_size(1)
                .decode_length_delimited_reader::<FlatMsg>(&mut ld.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        // The reader must stop exactly at the end of each framed message.
        #[test]
        fn decode_length_delimited_reader_sequential() {
            let a = FlatMsg {
                value: 10,
                ..Default::default()
            };
            let b = FlatMsg {
                value: 20,
                ..Default::default()
            };
            let mut stream = Vec::new();
            a.encode_length_delimited(&mut stream);
            b.encode_length_delimited(&mut stream);
            let mut cursor = std::io::Cursor::new(stream);
            let first: FlatMsg = DecodeOptions::new()
                .decode_length_delimited_reader(&mut cursor)
                .unwrap();
            assert_eq!(first.value, 10);
            let second: FlatMsg = DecodeOptions::new()
                .decode_length_delimited_reader(&mut cursor)
                .unwrap();
            assert_eq!(second.value, 20);
        }

        // Prefix claims 100 bytes but only one follows.
        #[test]
        fn decode_length_delimited_reader_truncated_body() {
            let mut buf = Vec::new();
            crate::encoding::encode_varint(100, &mut buf);
            buf.push(0x08);
            let err = DecodeOptions::new()
                .decode_length_delimited_reader::<FlatMsg>(&mut buf.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }
    }
}