use bytes::{Buf, BufMut};
use crate::error::DecodeError;
use crate::message_field::DefaultInstance;
/// Default nesting budget handed to the top-level `DecodeContext`; each sub-message
/// (`merge_length_delimited`) and group (`merge_group`) consumes one level.
pub const RECURSION_LIMIT: u32 = 100;
/// Default number of unknown fields one decode call may register (via
/// `DecodeContext::register_unknown_field`) before it fails with
/// `DecodeError::UnknownFieldLimitExceeded`.
pub const DEFAULT_UNKNOWN_FIELD_LIMIT: usize = 1_000_000;
/// Per-call decoding state passed by value into `merge` / `merge_field`.
///
/// The type is `Copy`. The unknown-field counter is held by shared reference, so every
/// copy made while descending draws from the same budget.
#[derive(Clone, Copy, Debug)]
pub struct DecodeContext<'a> {
    // Nesting levels still allowed; `descend` fails once this is zero.
    depth: u32,
    // Shared countdown of unknown fields still permitted for this decode call.
    unknown_fields_remaining: &'a core::cell::Cell<usize>,
}
impl<'a> DecodeContext<'a> {
    /// Builds a context that permits `depth` further nesting levels and draws
    /// unknown-field allowances from the shared `unknown_field_limit` counter.
    #[must_use]
    pub fn new(depth: u32, unknown_field_limit: &'a core::cell::Cell<usize>) -> Self {
        let unknown_fields_remaining = unknown_field_limit;
        Self {
            depth,
            unknown_fields_remaining,
        }
    }

    /// Number of nesting levels still available below this point.
    #[must_use]
    pub fn depth(&self) -> u32 {
        let Self { depth, .. } = *self;
        depth
    }

    /// Number of unknown fields that may still be registered before the limit trips.
    #[must_use]
    pub fn remaining_unknown_fields(&self) -> usize {
        core::cell::Cell::get(self.unknown_fields_remaining)
    }

    /// Returns a context one nesting level deeper.
    ///
    /// # Errors
    ///
    /// `DecodeError::RecursionLimitExceeded` when no depth remains.
    pub fn descend(self) -> Result<Self, DecodeError> {
        match self.depth.checked_sub(1) {
            Some(depth) => Ok(Self { depth, ..self }),
            None => Err(DecodeError::RecursionLimitExceeded),
        }
    }

    /// Consumes one unit of the shared unknown-field budget.
    ///
    /// # Errors
    ///
    /// `DecodeError::UnknownFieldLimitExceeded` when the budget is already exhausted;
    /// the counter is left untouched in that case.
    pub fn register_unknown_field(&self) -> Result<(), DecodeError> {
        let next = self
            .unknown_fields_remaining
            .get()
            .checked_sub(1)
            .ok_or(DecodeError::UnknownFieldLimitExceeded)?;
        self.unknown_fields_remaining.set(next);
        Ok(())
    }
}
/// A protobuf message that can be written to and parsed from the binary wire format.
///
/// Implementors supply [`Message::compute_size`], [`Message::write_to`],
/// [`Message::merge_field`] and [`Message::clear`]; every other method has a default
/// built on those four.
pub trait Message: DefaultInstance + Clone + PartialEq + Send + Sync {
    /// Returns the encoded length of `self` in bytes.
    ///
    /// The default methods always call this with the same `cache` they then hand to
    /// [`Message::write_to`], so implementations may store intermediate sizes in it.
    fn compute_size(&self, cache: &mut crate::SizeCache) -> u32;

    /// Writes the wire encoding of `self` into `buf`, using the `cache` that the
    /// preceding [`Message::compute_size`] call populated.
    fn write_to(&self, cache: &mut crate::SizeCache, buf: &mut impl BufMut);

    /// Encodes `self` into `buf` with a freshly created size cache.
    fn encode(&self, buf: &mut impl BufMut) {
        let mut sizes = crate::SizeCache::new();
        self.compute_size(&mut sizes);
        self.write_to(&mut sizes, buf);
    }

    /// Encodes `self` into `buf` using a caller-owned `cache`, which is cleared first so
    /// one cache can be reused across calls.
    fn encode_with_cache(&self, cache: &mut crate::SizeCache, buf: &mut impl BufMut) {
        cache.clear();
        self.compute_size(cache);
        self.write_to(cache, buf);
    }

    /// Returns the encoded size reported by [`Message::compute_size`].
    #[must_use]
    fn encoded_len(&self) -> u32 {
        let mut scratch = crate::SizeCache::new();
        self.compute_size(&mut scratch)
    }

    /// Encodes `self` preceded by its byte length as a varint.
    fn encode_length_delimited(&self, buf: &mut impl BufMut) {
        let mut sizes = crate::SizeCache::new();
        let body_len = u64::from(self.compute_size(&mut sizes));
        crate::encoding::encode_varint(body_len, buf);
        self.write_to(&mut sizes, buf);
    }

    /// Encodes `self` into a new `Vec` pre-sized to the computed length.
    #[must_use]
    fn encode_to_vec(&self) -> alloc::vec::Vec<u8> {
        let mut sizes = crate::SizeCache::new();
        let capacity = self.compute_size(&mut sizes) as usize;
        let mut out = alloc::vec::Vec::with_capacity(capacity);
        self.write_to(&mut sizes, &mut out);
        out
    }

    /// Encodes `self` into a new, frozen `Bytes` pre-sized to the computed length.
    #[must_use]
    fn encode_to_bytes(&self) -> bytes::Bytes {
        let mut sizes = crate::SizeCache::new();
        let capacity = self.compute_size(&mut sizes) as usize;
        let mut out = bytes::BytesMut::with_capacity(capacity);
        self.write_to(&mut sizes, &mut out);
        out.freeze()
    }

    /// Decodes a message from everything left in `buf`, using [`RECURSION_LIMIT`] and
    /// [`DEFAULT_UNKNOWN_FIELD_LIMIT`].
    ///
    /// # Errors
    ///
    /// Any `DecodeError` raised while merging fields.
    fn decode(buf: &mut impl Buf) -> Result<Self, DecodeError>
    where
        Self: Sized,
    {
        let budget = core::cell::Cell::new(DEFAULT_UNKNOWN_FIELD_LIMIT);
        let mut msg = Self::default();
        msg.merge(buf, DecodeContext::new(RECURSION_LIMIT, &budget))
            .map(|()| msg)
    }

    /// Decodes a message from the whole of `data`; see [`Message::decode`].
    fn decode_from_slice(data: &[u8]) -> Result<Self, DecodeError>
    where
        Self: Sized,
    {
        let mut cursor = data;
        Self::decode(&mut cursor)
    }

    /// Decodes one varint-length-prefixed message, leaving `buf` positioned just past it.
    ///
    /// # Errors
    ///
    /// * `MessageTooLarge` if the prefix exceeds `0x7FFF_FFFF`.
    /// * `UnexpectedEof` if `buf` holds fewer bytes than the prefix claims, or if a field
    ///   runs past the declared end.
    /// * Any error from merging the payload.
    fn decode_length_delimited(buf: &mut impl Buf) -> Result<Self, DecodeError>
    where
        Self: Sized,
    {
        const MAX_MESSAGE_BYTES: u64 = 0x7FFF_FFFF;
        let end = delimited_payload_end(buf, MAX_MESSAGE_BYTES)?;
        let budget = core::cell::Cell::new(DEFAULT_UNKNOWN_FIELD_LIMIT);
        let mut msg = Self::default();
        msg.merge_to_limit(buf, DecodeContext::new(RECURSION_LIMIT, &budget), end)?;
        settle_at_payload_end(buf, end)?;
        Ok(msg)
    }

    /// Merges a single field identified by `tag` into `self`, reading its value from `buf`.
    fn merge_field(
        &mut self,
        tag: crate::encoding::Tag,
        buf: &mut impl Buf,
        ctx: DecodeContext<'_>,
    ) -> Result<(), DecodeError>;

    /// Merges fields until `buf.remaining()` drops to `limit` or below.
    fn merge_to_limit(
        &mut self,
        buf: &mut impl Buf,
        ctx: DecodeContext<'_>,
        limit: usize,
    ) -> Result<(), DecodeError> {
        loop {
            if buf.remaining() <= limit {
                return Ok(());
            }
            let tag = crate::encoding::Tag::decode(buf)?;
            self.merge_field(tag, buf, ctx)?;
        }
    }

    /// Merges the body of a group whose start tag has already been consumed, stopping
    /// after the matching end-group tag.
    ///
    /// # Errors
    ///
    /// * `RecursionLimitExceeded` if no nesting depth remains.
    /// * `UnexpectedEof` if `buf` runs out before the end tag.
    /// * `InvalidEndGroup(n)` if an end-group tag for a different field `n` appears.
    fn merge_group(
        &mut self,
        buf: &mut impl Buf,
        ctx: DecodeContext<'_>,
        field_number: u32,
    ) -> Result<(), DecodeError> {
        let inner = ctx.descend()?;
        loop {
            if !buf.has_remaining() {
                return Err(DecodeError::UnexpectedEof);
            }
            let tag = crate::encoding::Tag::decode(buf)?;
            if tag.wire_type() != crate::encoding::WireType::EndGroup {
                self.merge_field(tag, buf, inner)?;
                continue;
            }
            let closing = tag.field_number();
            return if closing == field_number {
                Ok(())
            } else {
                Err(DecodeError::InvalidEndGroup(closing))
            };
        }
    }

    /// Merges every field remaining in `buf` into `self`.
    fn merge(&mut self, buf: &mut impl Buf, ctx: DecodeContext<'_>) -> Result<(), DecodeError> {
        // A limit of zero means "read until the buffer is exhausted".
        self.merge_to_limit(buf, ctx, 0)
    }

    /// Merges all of `data` into `self` with the default limits.
    fn merge_from_slice(&mut self, data: &[u8]) -> Result<(), DecodeError> {
        let budget = core::cell::Cell::new(DEFAULT_UNKNOWN_FIELD_LIMIT);
        let mut cursor = data;
        self.merge(&mut cursor, DecodeContext::new(RECURSION_LIMIT, &budget))
    }

    /// Merges one varint-length-prefixed sub-message into `self`, one nesting level deeper.
    ///
    /// # Errors
    ///
    /// * `RecursionLimitExceeded` if no nesting depth remains.
    /// * `MessageTooLarge` if the prefix exceeds `0x7FFF_FFFF`.
    /// * `UnexpectedEof` if the payload is truncated or a field overruns it.
    fn merge_length_delimited(
        &mut self,
        buf: &mut impl Buf,
        ctx: DecodeContext<'_>,
    ) -> Result<(), DecodeError> {
        const MAX_SUB_MESSAGE_BYTES: u64 = 0x7FFF_FFFF;
        let inner = ctx.descend()?;
        let end = delimited_payload_end(buf, MAX_SUB_MESSAGE_BYTES)?;
        self.merge_to_limit(buf, inner, end)?;
        settle_at_payload_end(buf, end)
    }

    /// Resets every field of `self` to its default value.
    fn clear(&mut self);
}

/// Reads a varint length prefix, rejecting values above `max_len`, and returns the
/// `buf.remaining()` value at which the prefixed payload ends.
fn delimited_payload_end(buf: &mut impl Buf, max_len: u64) -> Result<usize, DecodeError> {
    let declared = crate::encoding::decode_varint(buf)?;
    if declared > max_len {
        return Err(DecodeError::MessageTooLarge);
    }
    let payload_len = usize::try_from(declared).map_err(|_| DecodeError::MessageTooLarge)?;
    buf.remaining()
        .checked_sub(payload_len)
        .ok_or(DecodeError::UnexpectedEof)
}

/// Skips any payload bytes a merge left unread before `end`, or reports `UnexpectedEof`
/// if the merge consumed past it.
fn settle_at_payload_end(buf: &mut impl Buf, end: usize) -> Result<(), DecodeError> {
    let left = buf.remaining();
    if left > end {
        buf.advance(left - end);
    } else if left < end {
        return Err(DecodeError::UnexpectedEof);
    }
    Ok(())
}
/// Compile-time naming metadata for a message type.
///
/// NOTE(review): the meanings below are inferred from the constant names; confirm them
/// against the code generator that implements this trait.
pub trait MessageName {
    /// Package the message belongs to (presumably the `.proto` `package`).
    const PACKAGE: &'static str;
    /// Short, unqualified message name.
    const NAME: &'static str;
    /// Fully qualified name (presumably `PACKAGE` joined with `NAME`).
    const FULL_NAME: &'static str;
    /// Type URL (presumably the form used for `google.protobuf.Any`).
    const TYPE_URL: &'static str;
}
/// Configurable decoding limits, built with `DecodeOptions::new()` and the `with_*`
/// builder methods.
#[derive(Debug, Clone)]
pub struct DecodeOptions {
    // Nesting depth given to the top-level `DecodeContext`.
    recursion_limit: u32,
    // Largest accepted input / length prefix in bytes; `with_max_message_size` clamps it
    // to `DEFAULT_MAX_MESSAGE_SIZE`.
    max_message_size: usize,
    // When true (std builds only), `decode_reader` reads the whole stream without a cap.
    unbounded_reader_size: bool,
    // Starting value of the shared unknown-field countdown for each decode call.
    unknown_field_limit: usize,
}
/// Protobuf's message size ceiling, `0x7FFF_FFFF` bytes (2 GiB - 1). It is both the
/// default and the upper clamp for `DecodeOptions::max_message_size`.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 0x7FFF_FFFF;
impl Default for DecodeOptions {
fn default() -> Self {
Self::new()
}
}
impl DecodeOptions {
    /// Creates options populated with the crate defaults.
    ///
    /// The defaults are:
    /// * recursion limit [`RECURSION_LIMIT`];
    /// * maximum message size `0x7FFF_FFFF` bytes;
    /// * a bounded reader;
    /// * an unknown-field budget of [`DEFAULT_UNKNOWN_FIELD_LIMIT`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            recursion_limit: RECURSION_LIMIT,
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
            unbounded_reader_size: false,
            unknown_field_limit: DEFAULT_UNKNOWN_FIELD_LIMIT,
        }
    }

    /// Sets the nesting budget for sub-messages and groups.
    #[must_use]
    pub fn with_recursion_limit(mut self, limit: u32) -> Self {
        self.recursion_limit = limit;
        self
    }

    /// Sets the maximum accepted message size in bytes, clamped to `0x7FFF_FFFF`.
    ///
    /// Calling this also re-enables the reader size limit if
    /// `without_reader_size_limit` was called earlier in the chain.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `max_bytes` exceeds `0x7FFF_FFFF`.
    #[must_use]
    pub fn with_max_message_size(mut self, max_bytes: usize) -> Self {
        debug_assert!(
            max_bytes <= DEFAULT_MAX_MESSAGE_SIZE,
            "DecodeOptions::with_max_message_size clamps values above the protobuf 2 GiB limit; \
             on std builds, use DecodeOptions::without_reader_size_limit for intentionally \
             unbounded reader input — there is no unbounded slice/Buf path"
        );
        self.max_message_size = max_bytes.min(DEFAULT_MAX_MESSAGE_SIZE);
        self.unbounded_reader_size = false;
        self
    }

    /// Lets `decode_reader` consume the entire stream regardless of `max_message_size`.
    ///
    /// The slice, `Buf` and length-delimited entry points keep their limits.
    #[cfg(feature = "std")]
    #[must_use]
    pub fn without_reader_size_limit(mut self) -> Self {
        self.unbounded_reader_size = true;
        self
    }

    /// Sets how many unknown fields one decode call may register before failing.
    #[must_use]
    pub fn with_unknown_field_limit(mut self, count: usize) -> Self {
        self.unknown_field_limit = count;
        self
    }

    /// Configured nesting budget.
    #[must_use]
    pub fn recursion_limit(&self) -> u32 {
        self.recursion_limit
    }

    /// Configured unknown-field budget.
    #[must_use]
    pub fn unknown_field_limit(&self) -> usize {
        self.unknown_field_limit
    }

    /// Configured maximum message size in bytes.
    #[must_use]
    pub fn max_message_size(&self) -> usize {
        self.max_message_size
    }

    /// Whether `decode_reader` currently ignores `max_message_size`.
    #[cfg(feature = "std")]
    #[must_use]
    pub fn is_reader_size_unbounded(&self) -> bool {
        self.unbounded_reader_size
    }

    /// Fails with `MessageTooLarge` when `len` exceeds the configured maximum.
    fn check_message_size(&self, len: usize) -> Result<(), DecodeError> {
        if len > self.max_message_size {
            Err(DecodeError::MessageTooLarge)
        } else {
            Ok(())
        }
    }

    /// Largest accepted length prefix: the configured maximum, never above the
    /// protobuf ceiling.
    fn length_prefix_cap(&self) -> u64 {
        core::cmp::min(
            self.max_message_size as u64,
            DEFAULT_MAX_MESSAGE_SIZE as u64,
        )
    }

    /// Wraps a decode failure as an `InvalidData` I/O error.
    #[cfg(feature = "std")]
    fn invalid_data(err: DecodeError) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidData, err)
    }

    /// Decodes a message from everything left in `buf`.
    ///
    /// # Errors
    ///
    /// `MessageTooLarge` if `buf.remaining()` exceeds the maximum, otherwise any error
    /// raised while merging.
    pub fn decode<M: Message>(&self, buf: &mut impl Buf) -> Result<M, DecodeError> {
        self.check_message_size(buf.remaining())?;
        let limit = core::cell::Cell::new(self.unknown_field_limit);
        let mut msg = M::default();
        msg.merge(buf, DecodeContext::new(self.recursion_limit, &limit))?;
        Ok(msg)
    }

    /// Decodes a message from the whole of `data`.
    ///
    /// # Errors
    ///
    /// `MessageTooLarge` if `data` exceeds the maximum, otherwise any merge error.
    pub fn decode_from_slice<M: Message>(&self, data: &[u8]) -> Result<M, DecodeError> {
        self.check_message_size(data.len())?;
        self.decode_from_slice_unchecked_size(data)
    }

    /// Decodes `data` without re-checking its length; callers must already have
    /// enforced (or deliberately waived) the size limit.
    fn decode_from_slice_unchecked_size<M: Message>(&self, data: &[u8]) -> Result<M, DecodeError> {
        let limit = core::cell::Cell::new(self.unknown_field_limit);
        let mut msg = M::default();
        msg.merge(
            &mut &*data,
            DecodeContext::new(self.recursion_limit, &limit),
        )?;
        Ok(msg)
    }

    /// Decodes one varint-length-prefixed message, leaving `buf` positioned just past it.
    ///
    /// Only the declared length is checked against the maximum, so `buf` may hold
    /// further messages after this one.
    ///
    /// # Errors
    ///
    /// * `MessageTooLarge` if the prefix exceeds the maximum.
    /// * `UnexpectedEof` if the payload is truncated or a field overruns it.
    /// * Any merge error.
    pub fn decode_length_delimited<M: Message>(
        &self,
        buf: &mut impl Buf,
    ) -> Result<M, DecodeError> {
        let declared = crate::encoding::decode_varint(buf)?;
        if declared > self.length_prefix_cap() {
            return Err(DecodeError::MessageTooLarge);
        }
        let len = usize::try_from(declared).map_err(|_| DecodeError::MessageTooLarge)?;
        if buf.remaining() < len {
            return Err(DecodeError::UnexpectedEof);
        }
        // `merge_to_limit` stops once `remaining()` falls to this value, i.e. at the
        // end of the prefixed payload.
        let limit = buf.remaining() - len;
        let field_limit = core::cell::Cell::new(self.unknown_field_limit);
        let mut msg = M::default();
        msg.merge_to_limit(
            buf,
            DecodeContext::new(self.recursion_limit, &field_limit),
            limit,
        )?;
        let remaining = buf.remaining();
        if remaining > limit {
            buf.advance(remaining - limit);
        } else if remaining < limit {
            // A field ran past the declared payload end.
            return Err(DecodeError::UnexpectedEof);
        }
        Ok(msg)
    }

    /// Merges everything left in `buf` into `msg`.
    ///
    /// # Errors
    ///
    /// `MessageTooLarge` if `buf.remaining()` exceeds the maximum, otherwise any merge error.
    pub fn merge<M: Message>(&self, msg: &mut M, buf: &mut impl Buf) -> Result<(), DecodeError> {
        self.check_message_size(buf.remaining())?;
        let limit = core::cell::Cell::new(self.unknown_field_limit);
        msg.merge(buf, DecodeContext::new(self.recursion_limit, &limit))
    }

    /// Merges all of `data` into `msg`.
    ///
    /// # Errors
    ///
    /// `MessageTooLarge` if `data` exceeds the maximum, otherwise any merge error.
    pub fn merge_from_slice<M: Message>(
        &self,
        msg: &mut M,
        data: &[u8],
    ) -> Result<(), DecodeError> {
        self.check_message_size(data.len())?;
        let limit = core::cell::Cell::new(self.unknown_field_limit);
        msg.merge(
            &mut &*data,
            DecodeContext::new(self.recursion_limit, &limit),
        )
    }

    /// Decodes a borrowing view over `buf`.
    ///
    /// # Errors
    ///
    /// `MessageTooLarge` if `buf` exceeds the maximum, otherwise any view decode error.
    pub fn decode_view<'a, V: crate::view::MessageView<'a>>(
        &self,
        buf: &'a [u8],
    ) -> Result<V, DecodeError> {
        self.check_message_size(buf.len())?;
        let limit = core::cell::Cell::new(self.unknown_field_limit);
        V::decode_view_with_ctx(buf, DecodeContext::new(self.recursion_limit, &limit))
    }

    /// Decodes a lazy borrowing view over `buf`.
    ///
    /// # Errors
    ///
    /// `MessageTooLarge` if `buf` exceeds the maximum, otherwise any view decode error.
    pub fn decode_lazy_view<'a, L: crate::view::LazyMessageView<'a>>(
        &self,
        buf: &'a [u8],
    ) -> Result<L, DecodeError> {
        self.check_message_size(buf.len())?;
        let limit = core::cell::Cell::new(self.unknown_field_limit);
        L::decode_lazy_with_ctx(buf, DecodeContext::new(self.recursion_limit, &limit))
    }

    /// Reads `reader` to EOF and decodes the bytes as one message.
    ///
    /// The read is capped at `max_message_size` unless `without_reader_size_limit`
    /// is in effect.
    ///
    /// # Errors
    ///
    /// * Read errors are propagated unchanged.
    /// * Oversize input and decode failures become `ErrorKind::InvalidData`.
    #[cfg(feature = "std")]
    pub fn decode_reader<M: Message>(
        &self,
        reader: &mut impl std::io::Read,
    ) -> Result<M, std::io::Error> {
        let bytes = self.read_limited(reader)?;
        self.decode_from_slice_unchecked_size::<M>(&bytes)
            .map_err(Self::invalid_data)
    }

    /// Reads one varint-length-prefixed message from `reader`.
    ///
    /// Exactly the prefix and the payload are consumed, so sequential calls walk a
    /// stream of delimited messages.
    ///
    /// # Errors
    ///
    /// * `InvalidData` for an oversize prefix or a decode failure.
    /// * `UnexpectedEof` if the stream ends early.
    /// * Other read errors are propagated.
    #[cfg(feature = "std")]
    pub fn decode_length_delimited_reader<M: Message>(
        &self,
        reader: &mut impl std::io::Read,
    ) -> Result<M, std::io::Error> {
        use std::io::Read as _;
        // Bound the up-front allocation so a forged prefix cannot force a huge
        // reservation before any payload bytes arrive.
        const INITIAL_CAPACITY_CAP: usize = 64 * 1024;
        let declared = read_varint(reader)?;
        if declared > self.length_prefix_cap() {
            return Err(Self::invalid_data(DecodeError::MessageTooLarge));
        }
        let len = usize::try_from(declared)
            .map_err(|_| Self::invalid_data(DecodeError::MessageTooLarge))?;
        let mut body = alloc::vec::Vec::with_capacity(len.min(INITIAL_CAPACITY_CAP));
        reader.take(declared).read_to_end(&mut body)?;
        if body.len() < len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                DecodeError::UnexpectedEof,
            ));
        }
        // `len` is already within the size cap, so skip the redundant slice-size check.
        self.decode_from_slice_unchecked_size::<M>(&body)
            .map_err(Self::invalid_data)
    }

    /// Reads `reader` to EOF, honoring the reader size limit unless it is disabled.
    #[cfg(feature = "std")]
    fn read_limited(
        &self,
        reader: &mut impl std::io::Read,
    ) -> Result<alloc::vec::Vec<u8>, std::io::Error> {
        use std::io::Read as _;
        let mut buf = alloc::vec::Vec::new();
        if self.unbounded_reader_size {
            reader.read_to_end(&mut buf)?;
            return Ok(buf);
        }
        // Read at most one byte past the limit: enough to detect oversize input
        // without buffering an arbitrarily long stream.
        reader
            .take((self.max_message_size as u64).saturating_add(1))
            .read_to_end(&mut buf)?;
        if buf.len() > self.max_message_size {
            return Err(Self::invalid_data(DecodeError::MessageTooLarge));
        }
        Ok(buf)
    }
}
/// Reads one base-128 varint from `reader`, a single byte at a time.
///
/// # Errors
///
/// * The error from `read_exact` (e.g. `ErrorKind::UnexpectedEof`) if the stream ends
///   before the terminating byte.
/// * `ErrorKind::InvalidData` wrapping `DecodeError::VarintTooLong` if the tenth byte
///   holds more than the single bit left in a `u64`.
#[cfg(feature = "std")]
fn read_varint(reader: &mut impl std::io::Read) -> Result<u64, std::io::Error> {
    let mut next_byte = || -> Result<u8, std::io::Error> {
        let mut cell = [0u8; 1];
        reader.read_exact(&mut cell)?;
        Ok(cell[0])
    };
    let mut value: u64 = 0;
    // Bytes one through nine each contribute seven payload bits (bits 0..=62).
    for shift in (0u32..63).step_by(7) {
        let byte = next_byte()?;
        value |= u64::from(byte & 0x7F) << shift;
        if byte & 0x80 == 0 {
            return Ok(value);
        }
    }
    // The tenth byte can only supply bit 63; any larger value would overflow.
    let last = next_byte()?;
    if last > 0x01 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            DecodeError::VarintTooLong,
        ));
    }
    Ok(value | (u64::from(last) << 63))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoding::encode_varint;
    use crate::error::DecodeError;
    use crate::message_field::DefaultInstance;
    use crate::SizeCache;

    // Minimal hand-written message with a single varint field number 1, used to drive the
    // `Message` default methods without generated code.
    #[derive(Clone, Debug, Default, PartialEq)]
    struct FlatMsg {
        value: i32,
    }

    impl DefaultInstance for FlatMsg {
        fn default_instance() -> &'static Self {
            static INST: crate::__private::OnceBox<FlatMsg> = crate::__private::OnceBox::new();
            INST.get_or_init(|| alloc::boxed::Box::new(FlatMsg::default()))
        }
    }

    impl Message for FlatMsg {
        fn compute_size(&self, _cache: &mut SizeCache) -> u32 {
            // One tag byte for field 1 plus the encoded value; zero is omitted.
            if self.value != 0 {
                1 + crate::types::int32_encoded_len(self.value) as u32
            } else {
                0
            }
        }
        fn write_to(&self, _cache: &mut SizeCache, buf: &mut impl BufMut) {
            if self.value != 0 {
                crate::encoding::Tag::new(1, crate::encoding::WireType::Varint).encode(buf);
                crate::types::encode_int32(self.value, buf);
            }
        }
        fn merge_field(
            &mut self,
            tag: crate::encoding::Tag,
            buf: &mut impl Buf,
            _ctx: DecodeContext<'_>,
        ) -> Result<(), DecodeError> {
            match tag.field_number() {
                1 => {
                    self.value = crate::types::decode_int32(buf)?;
                }
                _ => {
                    crate::encoding::skip_field(tag, buf)?;
                }
            }
            Ok(())
        }
        fn clear(&mut self) {
            *self = Self::default();
        }
    }

    // Length-prefixed encoding of `msg`.
    fn wire_bytes(msg: &FlatMsg) -> alloc::vec::Vec<u8> {
        let mut buf = alloc::vec::Vec::new();
        msg.encode_length_delimited(&mut buf);
        buf
    }

    #[test]
    fn test_merge_length_delimited_basic() {
        let src = FlatMsg { value: 42 };
        let mut dst = FlatMsg::default();
        dst.merge_length_delimited(
            &mut wire_bytes(&src).as_slice(),
            crate::test_ctx(RECURSION_LIMIT),
        )
        .unwrap();
        assert_eq!(dst.value, 42);
    }

    #[test]
    fn test_merge_length_delimited_merges_into_existing() {
        let mut dst = FlatMsg::default();
        dst.merge_length_delimited(
            &mut wire_bytes(&FlatMsg { value: 1 }).as_slice(),
            crate::test_ctx(RECURSION_LIMIT),
        )
        .unwrap();
        assert_eq!(dst.value, 1);
        dst.merge_length_delimited(
            &mut wire_bytes(&FlatMsg { value: 2 }).as_slice(),
            crate::test_ctx(RECURSION_LIMIT),
        )
        .unwrap();
        assert_eq!(dst.value, 2);
    }

    #[test]
    fn test_merge_length_delimited_truncated() {
        // Prefix claims 10 bytes but only 2 follow.
        let mut buf = alloc::vec::Vec::new();
        encode_varint(10, &mut buf);
        buf.extend_from_slice(&[0x01, 0x01]);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_length_delimited(&mut buf.as_slice(), crate::test_ctx(RECURSION_LIMIT)),
            Err(DecodeError::UnexpectedEof)
        );
    }

    #[test]
    fn test_merge_length_delimited_oversized() {
        // One byte past the 0x7FFF_FFFF sub-message cap.
        let mut buf = alloc::vec::Vec::new();
        encode_varint(0x8000_0000u64, &mut buf);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_length_delimited(&mut buf.as_slice(), crate::test_ctx(RECURSION_LIMIT)),
            Err(DecodeError::MessageTooLarge)
        );
    }

    #[test]
    fn test_merge_length_delimited_recursion_limit() {
        let src = FlatMsg { value: 7 };
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_length_delimited(&mut wire_bytes(&src).as_slice(), crate::test_ctx(0)),
            Err(DecodeError::RecursionLimitExceeded)
        );
        dst.merge_length_delimited(&mut wire_bytes(&src).as_slice(), crate::test_ctx(1))
            .unwrap();
        assert_eq!(dst.value, 7);
    }

    #[test]
    fn test_decode_from_slice_basic() {
        let src = FlatMsg { value: 42 };
        let bytes = src.encode_to_vec();
        let dst = FlatMsg::decode_from_slice(&bytes).unwrap();
        assert_eq!(dst.value, 42);
    }

    #[test]
    fn test_encode_to_bytes_matches_encode_to_vec() {
        let src = FlatMsg { value: 42 };
        let vec = src.encode_to_vec();
        let bytes = src.encode_to_bytes();
        assert_eq!(vec.as_slice(), bytes.as_ref());
        let dst = FlatMsg::decode_from_slice(&bytes).unwrap();
        assert_eq!(dst.value, 42);
        assert!(FlatMsg::default().encode_to_bytes().is_empty());
    }

    #[test]
    fn test_decode_from_slice_empty() {
        let dst = FlatMsg::decode_from_slice(&[]).unwrap();
        assert_eq!(dst.value, 0);
    }

    #[test]
    fn test_decode_from_slice_invalid_returns_error() {
        let result = FlatMsg::decode_from_slice(&[0xFF]);
        assert!(result.is_err());
    }

    #[test]
    fn test_merge_from_slice_basic() {
        let src = FlatMsg { value: 7 };
        let bytes = src.encode_to_vec();
        let mut dst = FlatMsg::default();
        dst.merge_from_slice(&bytes).unwrap();
        assert_eq!(dst.value, 7);
    }

    #[test]
    fn test_merge_from_slice_last_wins() {
        let src1 = FlatMsg { value: 1 };
        let src2 = FlatMsg { value: 2 };
        let mut dst = FlatMsg::default();
        dst.merge_from_slice(&src1.encode_to_vec()).unwrap();
        dst.merge_from_slice(&src2.encode_to_vec()).unwrap();
        assert_eq!(dst.value, 2);
    }

    #[test]
    fn test_decode_options_default_works() {
        let src = FlatMsg { value: 99 };
        let bytes = src.encode_to_vec();
        let msg: FlatMsg = DecodeOptions::new().decode_from_slice(&bytes).unwrap();
        assert_eq!(msg.value, 99);
    }

    #[test]
    fn test_decode_options_max_message_size_rejects() {
        let src = FlatMsg { value: 42 };
        let bytes = src.encode_to_vec();
        let result: Result<FlatMsg, _> = DecodeOptions::new()
            .with_max_message_size(1)
            .decode_from_slice(&bytes);
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_decode_options_max_message_size_exact_boundary() {
        let src = FlatMsg { value: 42 };
        let bytes = src.encode_to_vec();
        let msg: FlatMsg = DecodeOptions::new()
            .with_max_message_size(bytes.len())
            .decode_from_slice(&bytes)
            .unwrap();
        assert_eq!(msg.value, 42);
        let result: Result<FlatMsg, _> = DecodeOptions::new()
            .with_max_message_size(bytes.len() - 1)
            .decode_from_slice(&bytes);
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_decode_options_custom_recursion_limit() {
        let src = FlatMsg { value: 7 };
        let bytes = src.encode_to_vec();
        let msg: FlatMsg = DecodeOptions::new()
            .with_recursion_limit(1)
            .decode_from_slice(&bytes)
            .unwrap();
        assert_eq!(msg.value, 7);
    }

    #[test]
    fn test_decode_options_merge() {
        let src = FlatMsg { value: 55 };
        let bytes = src.encode_to_vec();
        let mut msg = FlatMsg::default();
        DecodeOptions::new()
            .merge_from_slice(&mut msg, &bytes)
            .unwrap();
        assert_eq!(msg.value, 55);
    }

    #[test]
    fn test_decode_options_merge_rejects_oversize() {
        let src = FlatMsg { value: 55 };
        let bytes = src.encode_to_vec();
        let mut msg = FlatMsg::default();
        let result = DecodeOptions::new()
            .with_max_message_size(1)
            .merge_from_slice(&mut msg, &bytes);
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_decode_options_length_delimited() {
        let src = FlatMsg { value: 42 };
        let mut ld_bytes = alloc::vec::Vec::new();
        src.encode_length_delimited(&mut ld_bytes);
        let msg: FlatMsg = DecodeOptions::new()
            .decode_length_delimited(&mut ld_bytes.as_slice())
            .unwrap();
        assert_eq!(msg.value, 42);
    }

    #[test]
    fn test_decode_options_length_delimited_rejects_oversize() {
        let src = FlatMsg { value: 42 };
        let mut ld_bytes = alloc::vec::Vec::new();
        src.encode_length_delimited(&mut ld_bytes);
        let result: Result<FlatMsg, _> = DecodeOptions::new()
            .with_max_message_size(1)
            .decode_length_delimited(&mut ld_bytes.as_slice());
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn decode_options_getters_return_defaults() {
        let opts = DecodeOptions::new();
        assert_eq!(opts.recursion_limit(), RECURSION_LIMIT);
        assert_eq!(opts.max_message_size(), 0x7FFF_FFFF);
        assert_eq!(opts.unknown_field_limit(), DEFAULT_UNKNOWN_FIELD_LIMIT);
    }

    #[test]
    fn decode_options_getters_return_custom_values() {
        let opts = DecodeOptions::new()
            .with_recursion_limit(42)
            .with_max_message_size(1024)
            .with_unknown_field_limit(2048);
        assert_eq!(opts.recursion_limit(), 42);
        assert_eq!(opts.max_message_size(), 1024);
        assert_eq!(opts.unknown_field_limit(), 2048);
    }

    // Debug builds assert on out-of-range sizes; release builds silently clamp.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "protobuf 2 GiB limit")]
    fn decode_options_max_message_size_above_protobuf_limit_debug_asserts() {
        let _ = DecodeOptions::new().with_max_message_size(DEFAULT_MAX_MESSAGE_SIZE + 1);
    }

    #[cfg(not(debug_assertions))]
    #[test]
    fn decode_options_max_message_size_above_protobuf_limit_saturates() {
        let opts = DecodeOptions::new().with_max_message_size(DEFAULT_MAX_MESSAGE_SIZE + 1);
        assert_eq!(opts.max_message_size(), DEFAULT_MAX_MESSAGE_SIZE);
    }

    #[test]
    fn test_decode_options_default_impl() {
        let opts = DecodeOptions::default();
        assert_eq!(opts.recursion_limit(), RECURSION_LIMIT);
        assert_eq!(opts.max_message_size(), 0x7FFF_FFFF);
        assert_eq!(opts.unknown_field_limit(), DEFAULT_UNKNOWN_FIELD_LIMIT);
    }

    #[test]
    fn test_decode_options_decode_buf() {
        let src = FlatMsg { value: 123 };
        let bytes = src.encode_to_vec();
        let msg: FlatMsg = DecodeOptions::new().decode(&mut bytes.as_slice()).unwrap();
        assert_eq!(msg.value, 123);
        let result: Result<FlatMsg, _> = DecodeOptions::new()
            .with_max_message_size(1)
            .decode(&mut bytes.as_slice());
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_decode_options_merge_buf() {
        let src = FlatMsg { value: 77 };
        let bytes = src.encode_to_vec();
        let mut msg = FlatMsg::default();
        DecodeOptions::new()
            .merge(&mut msg, &mut bytes.as_slice())
            .unwrap();
        assert_eq!(msg.value, 77);
        let mut msg = FlatMsg::default();
        let result = DecodeOptions::new()
            .with_max_message_size(1)
            .merge(&mut msg, &mut bytes.as_slice());
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_message_encode_trait_default() {
        let src = FlatMsg { value: 42 };
        let mut buf = alloc::vec::Vec::new();
        src.encode(&mut buf);
        assert_eq!(buf, src.encode_to_vec());
    }

    #[test]
    fn test_message_decode_length_delimited_trait_default() {
        let src = FlatMsg { value: 42 };
        let mut ld = alloc::vec::Vec::new();
        src.encode_length_delimited(&mut ld);
        let got = FlatMsg::decode_length_delimited(&mut ld.as_slice()).unwrap();
        assert_eq!(got.value, 42);
    }

    #[test]
    fn test_message_decode_length_delimited_oversize() {
        let mut buf = alloc::vec::Vec::new();
        encode_varint(0x8000_0000u64, &mut buf);
        let result = FlatMsg::decode_length_delimited(&mut buf.as_slice());
        assert_eq!(result, Err(DecodeError::MessageTooLarge));
    }

    #[test]
    fn test_message_decode_length_delimited_truncated() {
        let mut buf = alloc::vec::Vec::new();
        encode_varint(10, &mut buf);
        buf.push(0x08);
        buf.push(0x01);
        let result = FlatMsg::decode_length_delimited(&mut buf.as_slice());
        assert_eq!(result, Err(DecodeError::UnexpectedEof));
    }

    #[test]
    fn test_message_decode_length_delimited_with_trailing() {
        // Two back-to-back delimited messages: each decode must stop at its own boundary.
        let a = FlatMsg { value: 1 };
        let b = FlatMsg { value: 2 };
        let mut buf = alloc::vec::Vec::new();
        a.encode_length_delimited(&mut buf);
        b.encode_length_delimited(&mut buf);
        let mut cur = buf.as_slice();
        let first = FlatMsg::decode_length_delimited(&mut cur).unwrap();
        assert_eq!(first.value, 1);
        let second = FlatMsg::decode_length_delimited(&mut cur).unwrap();
        assert_eq!(second.value, 2);
        assert!(cur.is_empty());
    }

    // Group body (optional field 1, then the end-group tag) with the start tag already
    // consumed, as `merge_group` expects.
    fn group_bytes(value: i32, group_field_number: u32) -> alloc::vec::Vec<u8> {
        use crate::encoding::{Tag, WireType};
        let mut buf = alloc::vec::Vec::new();
        if value != 0 {
            Tag::new(1, WireType::Varint).encode(&mut buf);
            crate::types::encode_int32(value, &mut buf);
        }
        Tag::new(group_field_number, WireType::EndGroup).encode(&mut buf);
        buf
    }

    #[test]
    fn test_merge_group_basic() {
        let data = group_bytes(42, 5);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data.as_slice(), crate::test_ctx(RECURSION_LIMIT), 5)
            .unwrap();
        assert_eq!(dst.value, 42);
    }

    #[test]
    fn test_merge_group_empty() {
        let data = group_bytes(0, 3);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data.as_slice(), crate::test_ctx(RECURSION_LIMIT), 3)
            .unwrap();
        assert_eq!(dst.value, 0);
    }

    #[test]
    fn test_merge_group_merges_into_existing() {
        let data1 = group_bytes(1, 5);
        let data2 = group_bytes(2, 5);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data1.as_slice(), crate::test_ctx(RECURSION_LIMIT), 5)
            .unwrap();
        assert_eq!(dst.value, 1);
        dst.merge_group(&mut data2.as_slice(), crate::test_ctx(RECURSION_LIMIT), 5)
            .unwrap();
        assert_eq!(dst.value, 2);
    }

    #[test]
    fn test_merge_group_recursion_limit_zero() {
        let data = group_bytes(42, 5);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_group(&mut data.as_slice(), crate::test_ctx(0), 5),
            Err(DecodeError::RecursionLimitExceeded)
        );
    }

    #[test]
    fn test_merge_group_recursion_limit_one_succeeds() {
        let data = group_bytes(7, 5);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data.as_slice(), crate::test_ctx(1), 5)
            .unwrap();
        assert_eq!(dst.value, 7);
    }

    #[test]
    fn test_merge_group_mismatched_end() {
        use crate::encoding::{Tag, WireType};
        let mut data = alloc::vec::Vec::new();
        Tag::new(99, WireType::EndGroup).encode(&mut data);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_group(&mut data.as_slice(), crate::test_ctx(RECURSION_LIMIT), 5),
            Err(DecodeError::InvalidEndGroup(99))
        );
    }

    #[test]
    fn test_merge_group_truncated() {
        // Field data present but no end-group tag.
        use crate::encoding::{Tag, WireType};
        let mut data = alloc::vec::Vec::new();
        Tag::new(1, WireType::Varint).encode(&mut data);
        crate::types::encode_int32(42, &mut data);
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_group(&mut data.as_slice(), crate::test_ctx(RECURSION_LIMIT), 5),
            Err(DecodeError::UnexpectedEof)
        );
    }

    #[test]
    fn test_merge_group_empty_buffer() {
        let mut dst = FlatMsg::default();
        assert_eq!(
            dst.merge_group(&mut [].as_slice(), crate::test_ctx(RECURSION_LIMIT), 5),
            Err(DecodeError::UnexpectedEof)
        );
    }

    #[test]
    fn test_merge_group_unknown_fields_skipped() {
        use crate::encoding::{Tag, WireType};
        let mut data = alloc::vec::Vec::new();
        Tag::new(99, WireType::Varint).encode(&mut data);
        crate::encoding::encode_varint(0, &mut data);
        Tag::new(1, WireType::Varint).encode(&mut data);
        crate::types::encode_int32(99, &mut data);
        Tag::new(5, WireType::EndGroup).encode(&mut data);
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut data.as_slice(), crate::test_ctx(RECURSION_LIMIT), 5)
            .unwrap();
        assert_eq!(dst.value, 99);
    }

    #[test]
    fn test_merge_group_trailing_data_preserved() {
        // Bytes after the end-group tag must be left unread.
        let mut data = group_bytes(42, 5);
        data.extend_from_slice(&[0xDE, 0xAD]);
        let mut cur = data.as_slice();
        let mut dst = FlatMsg::default();
        dst.merge_group(&mut cur, crate::test_ctx(RECURSION_LIMIT), 5)
            .unwrap();
        assert_eq!(dst.value, 42);
        assert_eq!(cur, &[0xDE, 0xAD]);
    }

    #[cfg(feature = "std")]
    mod read_varint_tests {
        use super::super::read_varint;
        use crate::encoding::encode_varint;

        #[test]
        fn roundtrip_values() {
            let cases: &[u64] = &[0, 1, 127, 128, 300, 1 << 14, 1 << 35, 1 << 63, u64::MAX];
            for &v in cases {
                let mut buf = Vec::new();
                encode_varint(v, &mut buf);
                let got = read_varint(&mut buf.as_slice()).unwrap();
                assert_eq!(got, v, "roundtrip failed for {v}");
            }
        }

        #[test]
        fn rejects_10th_byte_overflow() {
            let mut bad: Vec<u8> = vec![0xFF; 9];
            bad.push(0x02);
            let err = read_varint(&mut bad.as_slice()).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[test]
        fn rejects_11th_byte() {
            // Ten continuation bytes: the tenth (0xFF) already exceeds the one bit left.
            let bad: &[u8] = &[0xFF; 10];
            let err = read_varint(&mut &bad[..]).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[test]
        fn u64_max_roundtrips() {
            let mut buf = Vec::new();
            encode_varint(u64::MAX, &mut buf);
            assert_eq!(buf.len(), 10);
            assert_eq!(buf[9], 0x01);
            let got = read_varint(&mut buf.as_slice()).unwrap();
            assert_eq!(got, u64::MAX);
        }

        #[test]
        fn eof_before_terminator_is_error() {
            let bad: &[u8] = &[0x80];
            let err = read_varint(&mut &bad[..]).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }

        #[test]
        fn empty_input_is_error() {
            let err = read_varint(&mut &[][..]).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }
    }

    #[cfg(feature = "std")]
    mod reader_tests {
        use super::*;

        #[test]
        fn decode_reader_basic() {
            let src = FlatMsg { value: 42 };
            let bytes = src.encode_to_vec();
            let msg: FlatMsg = DecodeOptions::new()
                .decode_reader(&mut bytes.as_slice())
                .unwrap();
            assert_eq!(msg.value, 42);
        }

        #[test]
        fn decode_reader_rejects_oversize() {
            let src = FlatMsg { value: 42 };
            let bytes = src.encode_to_vec();
            let err = DecodeOptions::new()
                .with_max_message_size(1)
                .decode_reader::<FlatMsg>(&mut bytes.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[test]
        fn decode_reader_exact_boundary() {
            let src = FlatMsg { value: 42 };
            let bytes = src.encode_to_vec();
            let msg: FlatMsg = DecodeOptions::new()
                .with_max_message_size(bytes.len())
                .decode_reader(&mut bytes.as_slice())
                .unwrap();
            assert_eq!(msg.value, 42);
        }

        #[test]
        fn decode_options_reader_size_unbounded_getter_tracks_builder_order() {
            let opts = DecodeOptions::new();
            assert!(!opts.is_reader_size_unbounded());
            let opts = opts.without_reader_size_limit();
            assert!(opts.is_reader_size_unbounded());
            let opts = opts.with_max_message_size(1024);
            assert!(!opts.is_reader_size_unbounded());
        }

        #[test]
        fn decode_reader_without_size_limit_does_not_overflow() {
            let src = FlatMsg { value: 42 };
            let bytes = src.encode_to_vec();
            let msg: FlatMsg = DecodeOptions::new()
                .without_reader_size_limit()
                .decode_reader(&mut bytes.as_slice())
                .unwrap();
            assert_eq!(msg.value, 42);
        }

        #[test]
        fn decode_reader_with_max_message_size_after_without_limit_reenables_limit() {
            let src = FlatMsg { value: 42 };
            let bytes = src.encode_to_vec();
            let opts = DecodeOptions::new()
                .without_reader_size_limit()
                .with_max_message_size(1);
            assert!(!opts.is_reader_size_unbounded());
            let err = opts
                .decode_reader::<FlatMsg>(&mut bytes.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[test]
        fn decode_reader_without_size_limit_keeps_slice_limit() {
            let src = FlatMsg { value: 42 };
            let bytes = src.encode_to_vec();
            let opts = DecodeOptions::new()
                .with_max_message_size(1)
                .without_reader_size_limit();
            assert!(opts.is_reader_size_unbounded());
            let slice_result: Result<FlatMsg, _> = opts.decode_from_slice(&bytes);
            assert_eq!(slice_result, Err(DecodeError::MessageTooLarge));
            let msg: FlatMsg = opts.decode_reader(&mut bytes.as_slice()).unwrap();
            assert_eq!(msg.value, 42);
        }

        #[test]
        fn decode_reader_propagates_read_error() {
            struct ErrReader;
            impl std::io::Read for ErrReader {
                fn read(&mut self, _: &mut [u8]) -> std::io::Result<usize> {
                    Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone"))
                }
            }
            let err = DecodeOptions::new()
                .decode_reader::<FlatMsg>(&mut ErrReader)
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
        }

        #[test]
        fn decode_length_delimited_reader_basic() {
            let src = FlatMsg { value: 99 };
            let mut ld = Vec::new();
            src.encode_length_delimited(&mut ld);
            let msg: FlatMsg = DecodeOptions::new()
                .decode_length_delimited_reader(&mut ld.as_slice())
                .unwrap();
            assert_eq!(msg.value, 99);
        }

        #[test]
        fn decode_length_delimited_reader_rejects_oversize_prefix() {
            let src = FlatMsg { value: 99 };
            let mut ld = Vec::new();
            src.encode_length_delimited(&mut ld);
            let err = DecodeOptions::new()
                .with_max_message_size(1)
                .decode_length_delimited_reader::<FlatMsg>(&mut ld.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[test]
        fn decode_length_delimited_reader_zero_length() {
            // A zero-length prefix must consume only the prefix byte, not the trailing 0xFF.
            let stream = [0x00u8, 0xFF];
            let mut cursor = std::io::Cursor::new(stream);
            let msg: FlatMsg = DecodeOptions::new()
                .decode_length_delimited_reader(&mut cursor)
                .unwrap();
            assert_eq!(msg, FlatMsg::default());
            assert_eq!(cursor.position(), 1);
        }

        #[test]
        fn decode_length_delimited_reader_truncated_reports_eof() {
            let src = FlatMsg { value: 99 };
            let mut ld = Vec::new();
            src.encode_length_delimited(&mut ld);
            ld.truncate(ld.len() - 1);
            let err = DecodeOptions::new()
                .decode_length_delimited_reader::<FlatMsg>(&mut ld.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }

        #[test]
        fn decode_length_delimited_reader_huge_claim_without_delivery() {
            // A prefix just under the cap with no payload must fail with EOF, not
            // allocate the claimed size up front.
            let mut stream = Vec::new();
            encode_varint(0x7FFF_FFF0, &mut stream);
            let err = DecodeOptions::new()
                .decode_length_delimited_reader::<FlatMsg>(&mut stream.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }

        #[test]
        fn decode_length_delimited_without_reader_size_limit_keeps_protobuf_cap() {
            let mut stream = Vec::new();
            encode_varint(0x8000_0000, &mut stream);
            let result: Result<FlatMsg, _> = DecodeOptions::new()
                .without_reader_size_limit()
                .decode_length_delimited(&mut stream.as_slice());
            assert_eq!(result, Err(DecodeError::MessageTooLarge));
        }

        #[test]
        fn decode_length_delimited_reader_without_size_limit_keeps_protobuf_cap() {
            let mut stream = Vec::new();
            encode_varint(0x8000_0000, &mut stream);
            let err = DecodeOptions::new()
                .without_reader_size_limit()
                .decode_length_delimited_reader::<FlatMsg>(&mut stream.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }

        #[test]
        fn decode_length_delimited_reader_sequential() {
            let a = FlatMsg { value: 10 };
            let b = FlatMsg { value: 20 };
            let mut stream = Vec::new();
            a.encode_length_delimited(&mut stream);
            b.encode_length_delimited(&mut stream);
            let mut cursor = std::io::Cursor::new(stream);
            let first: FlatMsg = DecodeOptions::new()
                .decode_length_delimited_reader(&mut cursor)
                .unwrap();
            assert_eq!(first.value, 10);
            let second: FlatMsg = DecodeOptions::new()
                .decode_length_delimited_reader(&mut cursor)
                .unwrap();
            assert_eq!(second.value, 20);
        }

        #[test]
        fn decode_length_delimited_reader_truncated_body() {
            let mut buf = Vec::new();
            crate::encoding::encode_varint(100, &mut buf);
            buf.push(0x08);
            let err = DecodeOptions::new()
                .decode_length_delimited_reader::<FlatMsg>(&mut buf.as_slice())
                .unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        }
    }
}