use crate::bytes::{Buf, BufMut};
use crate::encoding::{
check_wire_type, decode_varint, encode_varint, skip_field_depth, varint_len, Tag, WireType,
};
use crate::error::DecodeError;
use crate::types;
use crate::{DecodeContext, EnumValue, Enumeration, Message, SizeCache};
use core::hash::Hash;
/// Alias for the hash map type re-exported as `crate::__private::HashMap`,
/// with key type `K` and value type `V`.
pub type Map<K, V> = crate::__private::HashMap<K, V>;
/// Storage abstraction over the container backing a map field.
///
/// The length/encode/decode helpers in this module only touch a map through
/// these four operations. This file implements the trait for the hash map
/// selected by the `std` feature and for `BTreeMap`.
#[rustversion::attr(
since(1.78),
diagnostic::on_unimplemented(
message = "`{Self}` cannot be used as a buffa custom map type",
note = "buffa owns `MapStorage`, so a foreign type can't implement it directly (orphan rule). \
Wrap it in a crate-local newtype and implement `MapStorage` on the newtype. \
See the `custom-types` example in the buffa repository for a template."
)
)]
pub trait MapStorage {
/// Type of the keys held by the container.
type Key;
/// Type of the values held by the container.
type Value;
/// Returns the number of entries currently stored.
fn storage_len(&self) -> usize;
/// Stores `value` under `key`.
fn storage_insert(&mut self, key: Self::Key, value: Self::Value);
/// Removes every entry.
fn storage_clear(&mut self);
/// Iterates over borrowed `(key, value)` pairs.
fn storage_iter<'a>(&'a self) -> impl Iterator<Item = (&'a Self::Key, &'a Self::Value)>
where
Self::Key: 'a,
Self::Value: 'a;
}
// Generates a `MapStorage` impl for a hash-map type path that takes
// `<K, V, S>` generic arguments. Every operation forwards to the map's own
// method of the matching name.
macro_rules! map_storage_hashmap {
    ($($map:tt)*) => {
        impl<K, V, S> MapStorage for $($map)*<K, V, S>
        where
            K: Eq + Hash,
            S: core::hash::BuildHasher + Default,
        {
            type Key = K;
            type Value = V;

            #[inline]
            fn storage_len(&self) -> usize {
                self.len()
            }

            #[inline]
            fn storage_insert(&mut self, key: K, value: V) {
                // Any value previously stored under `key` is discarded.
                self.insert(key, value);
            }

            #[inline]
            fn storage_clear(&mut self) {
                self.clear();
            }

            #[inline]
            fn storage_iter<'a>(&'a self) -> impl Iterator<Item = (&'a K, &'a V)>
            where
                K: 'a,
                V: 'a,
            {
                self.iter()
            }
        }
    };
}

// Exactly one of the two invocations is compiled, depending on whether the
// `std` feature is enabled.
#[cfg(feature = "std")]
map_storage_hashmap!(std::collections::HashMap);
#[cfg(not(feature = "std"))]
map_storage_hashmap!(hashbrown::HashMap);
/// `MapStorage` backed by an ordered `BTreeMap`; each operation forwards to
/// the map's method of the matching name.
impl<K, V> MapStorage for crate::alloc::collections::BTreeMap<K, V>
where
    K: Ord,
{
    type Key = K;
    type Value = V;

    #[inline]
    fn storage_len(&self) -> usize {
        self.len()
    }

    #[inline]
    fn storage_insert(&mut self, key: K, value: V) {
        // Any value previously stored under `key` is discarded.
        self.insert(key, value);
    }

    #[inline]
    fn storage_clear(&mut self) {
        self.clear();
    }

    #[inline]
    fn storage_iter<'a>(&'a self) -> impl Iterator<Item = (&'a K, &'a V)>
    where
        K: 'a,
        V: 'a,
    {
        self.iter()
    }
}
// Private module holding the supertrait required by `MapValueDecode`.
// Because the module is not `pub`, code outside this module tree cannot name
// `Sealed` and therefore cannot implement `MapValueDecode` for its own types.
mod sealed {
/// Marker supertrait of `MapValueDecode`.
pub trait Sealed {}
}
/// Decode half of a map key/value codec.
///
/// Requires `sealed::Sealed`, which lives in a private module.
pub trait MapValueDecode: sealed::Sealed {
/// In-memory type produced by decoding. It must be default-constructible
/// because `merge_entry_with_unknowns` starts every entry from
/// `Default::default()`.
type Value: Default;
/// Wire type the entry decoder checks the field's tag against before
/// calling `merge`.
const WIRE_TYPE: WireType;
/// Whether `merge` can report `MapValueDecodeStatus::Unknown`. Defaults to
/// `false`; `merge_entry_with_unknowns` only takes its entry-preserving
/// path when a key or value codec sets this to `true`.
const MAY_RETURN_UNKNOWN: bool = false;
/// Decodes one value from `buf` into `value` and reports whether the value
/// was recognized.
fn merge(
value: &mut Self::Value,
buf: &mut impl Buf,
ctx: DecodeContext<'_>,
) -> Result<MapValueDecodeStatus, DecodeError>;
}
/// Outcome reported by `MapValueDecode::merge`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum MapValueDecodeStatus {
/// The value was decoded and stored; the entry decoders insert such entries
/// into the map.
Known,
/// The value was not recognized (in this file only `ClosedEnum` returns
/// this). The entry decoders do not insert such entries.
Unknown,
}
/// Encode half of a map key/value codec, layered on top of `MapValueDecode`.
pub trait MapCodec: MapValueDecode {
/// Encoded length in bytes when it is the same for every value, otherwise
/// `None` (the default). `field_len` multiplies instead of iterating when
/// both the key and the value codec provide it.
const FIXED_LEN: Option<u32> = None;
/// Returns the number of bytes `encode` writes for `value` (tag excluded;
/// the callers in this file write the tag themselves).
fn encoded_len(value: &Self::Value) -> u32;
/// Writes `value` to `buf` without a tag.
fn encode(value: &Self::Value, buf: &mut impl BufMut);
}
// Declares a unit struct `$name` and implements `Sealed`, `MapValueDecode`
// and `MapCodec` for it.
//
// Arguments, in order:
//   - optional attributes (typically doc comments) forwarded to the struct,
//   - `$name`: the codec struct to declare,
//   - `$value`: the in-memory value type,
//   - `$wire`: the `WireType` constant,
//   - `$fixed`: the `FIXED_LEN` constant (`None` or `Some(len)`),
//   - `len:` callable `&$value -> length` (the result is cast to `u32`),
//   - `encode:` callable `(&$value, buf)` writing the value,
//   - `decode:` callable `buf -> Result<$value, DecodeError>`.
macro_rules! scalar_codec {
($(#[$doc:meta])* $name:ident, $value:ty, $wire:expr, $fixed:expr,
len: $len:expr, encode: $encode:expr, decode: $decode:expr) => {
$(#[$doc])*
pub struct $name;
impl sealed::Sealed for $name {}
impl MapValueDecode for $name {
type Value = $value;
const WIRE_TYPE: WireType = $wire;
#[inline]
fn merge(
value: &mut Self::Value,
buf: &mut impl Buf,
_ctx: DecodeContext<'_>,
) -> Result<MapValueDecodeStatus, DecodeError> {
// Scalars are overwritten, never merged, and always count as known.
*value = $decode(buf)?;
Ok(MapValueDecodeStatus::Known)
}
}
impl MapCodec for $name {
const FIXED_LEN: Option<u32> = $fixed;
#[inline]
// `$len` is usually a closure literal, so it is invoked in place.
#[allow(clippy::redundant_closure_call)]
fn encoded_len(value: &Self::Value) -> u32 {
($len)(value) as u32
}
#[inline]
// `$encode` is usually a closure literal, so it is invoked in place.
#[allow(clippy::redundant_closure_call)]
fn encode(value: &Self::Value, buf: &mut impl BufMut) {
($encode)(value, buf)
}
}
};
}
scalar_codec!(
/// Varint codec for `i32` values, backed by the `types::*int32*` helpers.
Int32, i32, WireType::Varint, None,
len: |v: &i32| types::int32_encoded_len(*v),
encode: |v: &i32, buf: &mut _| types::encode_int32(*v, buf),
decode: types::decode_int32
);
scalar_codec!(
/// Varint codec for `i64` values, backed by the `types::*int64*` helpers.
Int64, i64, WireType::Varint, None,
len: |v: &i64| types::int64_encoded_len(*v),
encode: |v: &i64, buf: &mut _| types::encode_int64(*v, buf),
decode: types::decode_int64
);
scalar_codec!(
/// Varint codec for `u32` values, backed by the `types::*uint32*` helpers.
Uint32, u32, WireType::Varint, None,
len: |v: &u32| types::uint32_encoded_len(*v),
encode: |v: &u32, buf: &mut _| types::encode_uint32(*v, buf),
decode: types::decode_uint32
);
scalar_codec!(
/// Varint codec for `u64` values, backed by the `types::*uint64*` helpers.
Uint64, u64, WireType::Varint, None,
len: |v: &u64| types::uint64_encoded_len(*v),
encode: |v: &u64, buf: &mut _| types::encode_uint64(*v, buf),
decode: types::decode_uint64
);
scalar_codec!(
/// Varint codec for `i32` values, backed by the `types::*sint32*` helpers.
Sint32, i32, WireType::Varint, None,
len: |v: &i32| types::sint32_encoded_len(*v),
encode: |v: &i32, buf: &mut _| types::encode_sint32(*v, buf),
decode: types::decode_sint32
);
scalar_codec!(
/// Varint codec for `i64` values, backed by the `types::*sint64*` helpers.
Sint64, i64, WireType::Varint, None,
len: |v: &i64| types::sint64_encoded_len(*v),
encode: |v: &i64, buf: &mut _| types::encode_sint64(*v, buf),
decode: types::decode_sint64
);
scalar_codec!(
/// Varint codec for `bool` values with the fixed length
/// `types::BOOL_ENCODED_LEN`.
Bool, bool, WireType::Varint, Some(types::BOOL_ENCODED_LEN as u32),
len: |_: &bool| types::BOOL_ENCODED_LEN,
encode: |v: &bool, buf: &mut _| types::encode_bool(*v, buf),
decode: types::decode_bool
);
scalar_codec!(
/// Fixed32 codec for `u32` values with the fixed length
/// `types::FIXED32_ENCODED_LEN`.
Fixed32, u32, WireType::Fixed32, Some(types::FIXED32_ENCODED_LEN as u32),
len: |_: &u32| types::FIXED32_ENCODED_LEN,
encode: |v: &u32, buf: &mut _| types::encode_fixed32(*v, buf),
decode: types::decode_fixed32
);
scalar_codec!(
/// Fixed64 codec for `u64` values with the fixed length
/// `types::FIXED64_ENCODED_LEN`.
Fixed64, u64, WireType::Fixed64, Some(types::FIXED64_ENCODED_LEN as u32),
len: |_: &u64| types::FIXED64_ENCODED_LEN,
encode: |v: &u64, buf: &mut _| types::encode_fixed64(*v, buf),
decode: types::decode_fixed64
);
scalar_codec!(
/// Fixed32 codec for `i32` values with the fixed length
/// `types::FIXED32_ENCODED_LEN`.
Sfixed32, i32, WireType::Fixed32, Some(types::FIXED32_ENCODED_LEN as u32),
len: |_: &i32| types::FIXED32_ENCODED_LEN,
encode: |v: &i32, buf: &mut _| types::encode_sfixed32(*v, buf),
decode: types::decode_sfixed32
);
scalar_codec!(
/// Fixed64 codec for `i64` values with the fixed length
/// `types::FIXED64_ENCODED_LEN`.
Sfixed64, i64, WireType::Fixed64, Some(types::FIXED64_ENCODED_LEN as u32),
len: |_: &i64| types::FIXED64_ENCODED_LEN,
encode: |v: &i64, buf: &mut _| types::encode_sfixed64(*v, buf),
decode: types::decode_sfixed64
);
scalar_codec!(
/// Fixed32 codec for `f32` values with the fixed length
/// `types::FIXED32_ENCODED_LEN`.
Float, f32, WireType::Fixed32, Some(types::FIXED32_ENCODED_LEN as u32),
len: |_: &f32| types::FIXED32_ENCODED_LEN,
encode: |v: &f32, buf: &mut _| types::encode_float(*v, buf),
decode: types::decode_float
);
scalar_codec!(
/// Fixed64 codec for `f64` values with the fixed length
/// `types::FIXED64_ENCODED_LEN`.
Double, f64, WireType::Fixed64, Some(types::FIXED64_ENCODED_LEN as u32),
len: |_: &f64| types::FIXED64_ENCODED_LEN,
encode: |v: &f64, buf: &mut _| types::encode_double(*v, buf),
decode: types::decode_double
);
scalar_codec!(
/// Length-delimited codec for owned `String` values.
Str, crate::alloc::string::String, WireType::LengthDelimited, None,
len: |v: &crate::alloc::string::String| types::string_encoded_len(v),
encode: |v: &crate::alloc::string::String, buf: &mut _| types::encode_string(v, buf),
decode: types::decode_string
);
scalar_codec!(
/// Length-delimited codec for `Vec<u8>` values.
BytesVec, crate::alloc::vec::Vec<u8>, WireType::LengthDelimited, None,
len: |v: &crate::alloc::vec::Vec<u8>| types::bytes_encoded_len(v),
encode: |v: &crate::alloc::vec::Vec<u8>, buf: &mut _| types::encode_bytes(v, buf),
decode: types::decode_bytes
);
scalar_codec!(
/// Length-delimited codec for `Bytes` values; decodes through
/// `types::decode_bytes_to_bytes`.
BytesBuf, crate::bytes::Bytes, WireType::LengthDelimited, None,
len: |v: &crate::bytes::Bytes| types::bytes_encoded_len(v),
encode: |v: &crate::bytes::Bytes, buf: &mut _| types::encode_bytes(v, buf),
decode: types::decode_bytes_to_bytes
);
/// Length-delimited codec for a caller-chosen bytes type `B` implementing
/// `crate::types::ProtoBytes`.
///
/// The struct is a zero-sized marker; `B` is carried only via `PhantomData`.
pub struct ProtoBytesMap<B>(core::marker::PhantomData<B>);

impl<B> sealed::Sealed for ProtoBytesMap<B> where B: crate::types::ProtoBytes {}

impl<B> MapValueDecode for ProtoBytesMap<B>
where
    B: crate::types::ProtoBytes,
{
    type Value = B;
    const WIRE_TYPE: WireType = WireType::LengthDelimited;

    /// Replaces `value` with the bytes decoded from `buf`; always `Known`.
    #[inline]
    fn merge(
        value: &mut Self::Value,
        buf: &mut impl Buf,
        _ctx: DecodeContext<'_>,
    ) -> Result<MapValueDecodeStatus, DecodeError> {
        let decoded = crate::types::decode_bytes_to::<B>(buf)?;
        *value = decoded;
        Ok(MapValueDecodeStatus::Known)
    }
}

impl<B> MapCodec for ProtoBytesMap<B>
where
    B: crate::types::ProtoBytes,
{
    /// Encoded size of the value's byte view, as computed by
    /// `types::bytes_encoded_len`.
    #[inline]
    fn encoded_len(value: &Self::Value) -> u32 {
        let len = types::bytes_encoded_len(value.as_ref());
        len as u32
    }

    /// Writes the value's byte view through `types::encode_bytes`.
    #[inline]
    fn encode(value: &Self::Value, buf: &mut impl BufMut) {
        types::encode_bytes(value.as_ref(), buf);
    }
}
/// Length-delimited codec for a caller-chosen string type `S` implementing
/// `crate::types::ProtoString`.
///
/// The struct is a zero-sized marker; `S` is carried only via `PhantomData`.
pub struct ProtoStringMap<S>(core::marker::PhantomData<S>);

impl<S> sealed::Sealed for ProtoStringMap<S> where S: crate::types::ProtoString {}

impl<S> MapValueDecode for ProtoStringMap<S>
where
    S: crate::types::ProtoString,
{
    type Value = S;
    const WIRE_TYPE: WireType = WireType::LengthDelimited;

    /// Replaces `value` with the string decoded from `buf`; always `Known`.
    #[inline]
    fn merge(
        value: &mut Self::Value,
        buf: &mut impl Buf,
        _ctx: DecodeContext<'_>,
    ) -> Result<MapValueDecodeStatus, DecodeError> {
        let decoded = crate::types::decode_string_to::<S>(buf)?;
        *value = decoded;
        Ok(MapValueDecodeStatus::Known)
    }
}

impl<S> MapCodec for ProtoStringMap<S>
where
    S: crate::types::ProtoString,
{
    /// Encoded size of the value's string view, as computed by
    /// `types::string_encoded_len`.
    #[inline]
    fn encoded_len(value: &Self::Value) -> u32 {
        let len = types::string_encoded_len(value.as_ref());
        len as u32
    }

    /// Writes the value's string view through `types::encode_string`.
    #[inline]
    fn encode(value: &Self::Value, buf: &mut impl BufMut) {
        types::encode_string(value.as_ref(), buf);
    }
}
/// Varint codec for enum values stored as `EnumValue<E>`.
///
/// Every decoded `i32` is converted with `EnumValue::from`, so decoding never
/// reports `Unknown` at the entry level.
pub struct OpenEnum<E>(core::marker::PhantomData<E>);

impl<E> sealed::Sealed for OpenEnum<E> where E: Enumeration {}

impl<E> MapValueDecode for OpenEnum<E>
where
    E: Enumeration,
{
    type Value = EnumValue<E>;
    const WIRE_TYPE: WireType = WireType::Varint;

    /// Reads an `i32` and stores it as an `EnumValue<E>`.
    #[inline]
    fn merge(
        value: &mut Self::Value,
        buf: &mut impl Buf,
        _ctx: DecodeContext<'_>,
    ) -> Result<MapValueDecodeStatus, DecodeError> {
        let raw = types::decode_int32(buf)?;
        *value = EnumValue::from(raw);
        Ok(MapValueDecodeStatus::Known)
    }
}

impl<E> MapCodec for OpenEnum<E>
where
    E: Enumeration,
{
    /// Length of the value's `i32` number when encoded as an int32.
    #[inline]
    fn encoded_len(value: &Self::Value) -> u32 {
        let number = value.to_i32();
        types::int32_encoded_len(number) as u32
    }

    /// Writes the value's `i32` number as an int32.
    #[inline]
    fn encode(value: &Self::Value, buf: &mut impl BufMut) {
        let number = value.to_i32();
        types::encode_int32(number, buf);
    }
}
/// Varint codec for enum values stored directly as `E`.
///
/// A number that `E::from_i32` rejects leaves the target untouched and is
/// reported as `MapValueDecodeStatus::Unknown`.
pub struct ClosedEnum<E>(core::marker::PhantomData<E>);

impl<E> sealed::Sealed for ClosedEnum<E> where E: Enumeration + Default {}

impl<E> MapValueDecode for ClosedEnum<E>
where
    E: Enumeration + Default,
{
    type Value = E;
    const WIRE_TYPE: WireType = WireType::Varint;
    const MAY_RETURN_UNKNOWN: bool = true;

    /// Reads an `i32`; stores the matching variant, or reports `Unknown`
    /// without modifying `value` when no variant matches.
    #[inline]
    fn merge(
        value: &mut Self::Value,
        buf: &mut impl Buf,
        _ctx: DecodeContext<'_>,
    ) -> Result<MapValueDecodeStatus, DecodeError> {
        let number = types::decode_int32(buf)?;
        let status = match E::from_i32(number) {
            Some(variant) => {
                *value = variant;
                MapValueDecodeStatus::Known
            }
            None => MapValueDecodeStatus::Unknown,
        };
        Ok(status)
    }
}

impl<E> MapCodec for ClosedEnum<E>
where
    E: Enumeration + Default,
{
    /// Length of the variant's `i32` number when encoded as an int32.
    #[inline]
    fn encoded_len(value: &Self::Value) -> u32 {
        let number = value.to_i32();
        types::int32_encoded_len(number) as u32
    }

    /// Writes the variant's `i32` number as an int32.
    #[inline]
    fn encode(value: &Self::Value, buf: &mut impl BufMut) {
        let number = value.to_i32();
        types::encode_int32(number, buf);
    }
}
/// Decode-only codec for message-typed map values.
///
/// It implements `MapValueDecode` but not `MapCodec`; message values are sized
/// and written by `message_field_len` / `write_message_field` instead.
pub struct Msg<M>(core::marker::PhantomData<M>);

impl<M> sealed::Sealed for Msg<M> where M: Message + Default {}

impl<M> MapValueDecode for Msg<M>
where
    M: Message + Default,
{
    type Value = M;
    const WIRE_TYPE: WireType = WireType::LengthDelimited;

    /// Delegates to `Message::merge_length_delimited` on `value`, forwarding
    /// the decode context; always reports `Known` on success.
    #[inline]
    fn merge(
        value: &mut Self::Value,
        buf: &mut impl Buf,
        ctx: DecodeContext<'_>,
    ) -> Result<MapValueDecodeStatus, DecodeError> {
        <M as Message>::merge_length_delimited(value, buf, ctx)?;
        Ok(MapValueDecodeStatus::Known)
    }
}
// Bytes budgeted for the two tags written inside every entry (key = field 1,
// value = field 2); presumably one byte each — NOTE(review): confirm against
// `Tag::encode`.
const ENTRY_TAG_LEN: u32 = 2;
/// Payload length of a single entry: both inner tags plus the encoded key and
/// value. The entry's own length prefix and outer tag are not included.
#[inline]
fn entry_len<KC: MapCodec, VC: MapCodec>(k: &KC::Value, v: &VC::Value) -> u32 {
    let key_len = KC::encoded_len(k);
    let value_len = VC::encoded_len(v);
    ENTRY_TAG_LEN + key_len + value_len
}
/// Total encoded size of a map field whose keys and values both have a
/// `MapCodec`.
///
/// `outer_tag_len` is the encoded length of the map field's own tag, which is
/// repeated once per entry. When both codecs advertise `FIXED_LEN`, the size
/// is computed by multiplication instead of visiting every entry.
pub fn field_len<KC: MapCodec, VC: MapCodec, C>(map: &C, outer_tag_len: u32) -> u32
where
    C: MapStorage<Key = KC::Value, Value = VC::Value>,
{
    // One complete record: outer tag + length prefix + entry payload.
    let record_len =
        |payload: u32| outer_tag_len + varint_len(payload as u64) as u32 + payload;

    match (KC::FIXED_LEN, VC::FIXED_LEN) {
        (Some(key_len), Some(value_len)) => {
            let payload = ENTRY_TAG_LEN + key_len + value_len;
            map.storage_len() as u32 * record_len(payload)
        }
        _ => map.storage_iter().fold(0u32, |total, (key, value)| {
            total + record_len(entry_len::<KC, VC>(key, value))
        }),
    }
}
/// Writes every entry of `map` to `buf` as a length-delimited record under
/// `field_number`.
///
/// Each record is: outer tag, entry payload length, key tag (field 1), key,
/// value tag (field 2), value. Entries are written in `storage_iter` order.
pub fn write_field<KC: MapCodec, VC: MapCodec, C>(map: &C, field_number: u32, buf: &mut impl BufMut)
where
    C: MapStorage<Key = KC::Value, Value = VC::Value>,
{
    for (key, value) in map.storage_iter() {
        let payload_len = entry_len::<KC, VC>(key, value);

        // Record header.
        let outer_tag = Tag::new(field_number, WireType::LengthDelimited);
        outer_tag.encode(buf);
        encode_varint(payload_len as u64, buf);

        // Key, then value.
        let key_tag = Tag::new(1, KC::WIRE_TYPE);
        key_tag.encode(buf);
        KC::encode(key, buf);
        let value_tag = Tag::new(2, VC::WIRE_TYPE);
        value_tag.encode(buf);
        VC::encode(value, buf);
    }
}
/// Total encoded size of a map field whose values are messages.
///
/// For every entry a slot is reserved in `cache` before the value's size is
/// computed and then filled with that size, so `write_message_field` can
/// read the sizes back in the same iteration order.
pub fn message_field_len<KC: MapCodec, M: Message, C>(
    map: &C,
    outer_tag_len: u32,
    cache: &mut SizeCache,
) -> u32
where
    C: MapStorage<Key = KC::Value, Value = M>,
{
    let mut total = 0u32;
    for (key, message) in map.storage_iter() {
        // The slot must be reserved before the nested size computation runs.
        let slot = cache.reserve();
        let body_len = message.compute_size(cache);
        cache.set(slot, body_len);

        let key_len = KC::encoded_len(key);
        let body_prefix_len = varint_len(body_len as u64) as u32;
        let payload_len = ENTRY_TAG_LEN + key_len + body_prefix_len + body_len;
        total += outer_tag_len + varint_len(payload_len as u64) as u32 + payload_len;
    }
    total
}
/// Writes every entry of a message-valued map to `buf` under `field_number`.
///
/// Each value's size is taken from `cache` via `consume_next`, so `cache`
/// is expected to have been filled by a prior `message_field_len` pass over
/// the same map in the same iteration order.
pub fn write_message_field<KC: MapCodec, M: Message, C>(
    map: &C,
    field_number: u32,
    cache: &mut SizeCache,
    buf: &mut impl BufMut,
) where
    C: MapStorage<Key = KC::Value, Value = M>,
{
    for (key, message) in map.storage_iter() {
        let body_len = cache.consume_next();
        let key_len = KC::encoded_len(key);
        let body_prefix_len = varint_len(body_len as u64) as u32;
        let payload_len = ENTRY_TAG_LEN + key_len + body_prefix_len + body_len;

        // Record header.
        let outer_tag = Tag::new(field_number, WireType::LengthDelimited);
        outer_tag.encode(buf);
        encode_varint(payload_len as u64, buf);

        // Key.
        let key_tag = Tag::new(1, KC::WIRE_TYPE);
        key_tag.encode(buf);
        KC::encode(key, buf);

        // Value: length-delimited nested message.
        let value_tag = Tag::new(2, WireType::LengthDelimited);
        value_tag.encode(buf);
        encode_varint(body_len as u64, buf);
        message.write_to(cache, buf);
    }
}
/// Decodes the fields of one map entry from `buf` until it is exhausted.
///
/// `buf` must contain exactly the entry payload. Field 1 is merged into `key`,
/// field 2 into `val`, and any other field is skipped. Returns the status of
/// the last value field seen, or `Known` when the entry has no value field.
fn merge_entry_contents<KC, VC>(
    key: &mut KC::Value,
    val: &mut VC::Value,
    buf: &mut impl Buf,
    ctx: DecodeContext<'_>,
) -> Result<MapValueDecodeStatus, DecodeError>
where
    KC: MapValueDecode,
    VC: MapValueDecode,
{
    let mut status = MapValueDecodeStatus::Known;
    while buf.has_remaining() {
        let tag = Tag::decode(buf)?;
        let number = tag.field_number();
        if number == 1 {
            check_wire_type(tag, KC::WIRE_TYPE)?;
            // The key's decode status is intentionally not tracked.
            KC::merge(key, buf, ctx)?;
        } else if number == 2 {
            check_wire_type(tag, VC::WIRE_TYPE)?;
            // A later value field overrides the status of an earlier one.
            status = VC::merge(val, buf, ctx)?;
        } else {
            skip_field_depth(tag, buf, ctx.depth())?;
        }
    }
    Ok(status)
}
pub fn merge_entry<KC, VC, C>(
map: &mut C,
buf: &mut impl Buf,
ctx: DecodeContext<'_>,
) -> Result<(), DecodeError>
where
KC: MapValueDecode,
VC: MapValueDecode,
C: MapStorage<Key = KC::Value, Value = VC::Value>,
{
merge_entry_with_unknowns::<KC, VC, C>(map, buf, ctx, None)
}
/// Decodes one length-prefixed map entry from `buf` and inserts it into `map`.
///
/// `buf` must be positioned at the entry's length prefix (the callers in this
/// file decode the map field's tag first). Missing key or value fields leave
/// the corresponding `Default::default()` in place.
///
/// When `unknown_fields` is `Some((field_number, sink))` and either codec sets
/// `MAY_RETURN_UNKNOWN`, an entry whose value decodes as
/// `MapValueDecodeStatus::Unknown` is not inserted; instead its raw payload is
/// pushed onto `sink` as a length-delimited unknown field numbered
/// `field_number`. In every other configuration such an entry is dropped.
///
/// # Errors
///
/// * `DecodeError::MessageTooLarge` if the declared length does not fit in
///   `usize`.
/// * `DecodeError::UnexpectedEof` if `buf` holds fewer bytes than the declared
///   length, or (on the non-preserving path) a field inside the entry consumed
///   bytes past the entry's end.
/// * Any error propagated from tag decoding, wire-type checks, the codecs,
///   field skipping, or `ctx.register_unknown_field()`.
pub fn merge_entry_with_unknowns<KC, VC, C>(
    map: &mut C,
    buf: &mut impl Buf,
    ctx: DecodeContext<'_>,
    unknown_fields: Option<(u32, &mut crate::UnknownFields)>,
) -> Result<(), DecodeError>
where
    KC: MapValueDecode,
    VC: MapValueDecode,
    C: MapStorage<Key = KC::Value, Value = VC::Value>,
{
    let entry_len = decode_varint(buf)?;
    let entry_len = usize::try_from(entry_len).map_err(|_| DecodeError::MessageTooLarge)?;
    if buf.remaining() < entry_len {
        return Err(DecodeError::UnexpectedEof);
    }
    let mut key: KC::Value = Default::default();
    let mut val: VC::Value = Default::default();

    // Preserving path: only taken when the caller supplied a sink AND one of
    // the codecs can actually report `Unknown`, because it needs access to the
    // raw entry bytes.
    if unknown_fields.is_some() && (KC::MAY_RETURN_UNKNOWN || VC::MAY_RETURN_UNKNOWN) {
        if buf.chunk().len() >= entry_len {
            // The whole entry is contiguous in the current chunk: decode from
            // a borrowed slice and copy the bytes only if they must be kept.
            let preserved = {
                let entry_slice = &buf.chunk()[..entry_len];
                let mut entry_cur = entry_slice;
                let status =
                    merge_entry_contents::<KC, VC>(&mut key, &mut val, &mut entry_cur, ctx)?;
                matches!(status, MapValueDecodeStatus::Unknown).then(|| entry_slice.to_vec())
            };
            buf.advance(entry_len);
            match preserved {
                None => map.storage_insert(key, val),
                Some(payload) => {
                    if let Some((field_number, unknown_fields)) = unknown_fields {
                        ctx.register_unknown_field()?;
                        unknown_fields.push(crate::UnknownField {
                            number: field_number,
                            data: crate::UnknownFieldData::LengthDelimited(payload),
                        });
                    }
                }
            }
            return Ok(());
        }

        // The entry spans several chunks: take the payload out of `buf` first,
        // then decode from a clone so the original bytes stay available.
        let entry_payload = buf.copy_to_bytes(entry_len);
        let mut entry_cur = entry_payload.clone();
        let status = merge_entry_contents::<KC, VC>(&mut key, &mut val, &mut entry_cur, ctx)?;
        if matches!(status, MapValueDecodeStatus::Known) {
            map.storage_insert(key, val);
        } else if let Some((field_number, unknown_fields)) = unknown_fields {
            ctx.register_unknown_field()?;
            unknown_fields.push(crate::UnknownField {
                number: field_number,
                data: crate::UnknownFieldData::LengthDelimited(entry_payload.to_vec()),
            });
        }
        return Ok(());
    }

    // Non-preserving path: decode in place, using the number of bytes that
    // must remain in `buf` once the entry has been consumed as the boundary.
    // The subtraction cannot underflow thanks to the length check above.
    let entry_limit = buf.remaining() - entry_len;
    let mut val_status = MapValueDecodeStatus::Known;
    while buf.remaining() > entry_limit {
        let entry_tag = Tag::decode(buf)?;
        match entry_tag.field_number() {
            1 => {
                check_wire_type(entry_tag, KC::WIRE_TYPE)?;
                KC::merge(&mut key, buf, ctx)?;
            }
            2 => {
                check_wire_type(entry_tag, VC::WIRE_TYPE)?;
                val_status = VC::merge(&mut val, buf, ctx)?;
            }
            _ => {
                skip_field_depth(entry_tag, buf, ctx.depth())?;
            }
        }
    }
    // The loop exits only once `remaining <= entry_limit`, so the sole way to
    // miss the boundary is a field that read past the end of the entry.
    if buf.remaining() < entry_limit {
        return Err(DecodeError::UnexpectedEof);
    }
    if matches!(val_status, MapValueDecodeStatus::Known) {
        map.storage_insert(key, val);
    }
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::alloc::string::String;
use crate::alloc::vec::Vec;
// Test helper: encodes `map` with `write_field` and asserts that the number
// of bytes written equals what `field_len` predicted.
fn encode_field<KC: MapCodec, VC: MapCodec>(
    map: &Map<KC::Value, VC::Value>,
    field_number: u32,
    outer_tag_len: u32,
) -> Vec<u8>
where
    KC::Value: Eq + Hash,
{
    let predicted = field_len::<KC, VC, _>(map, outer_tag_len);
    let mut wire = Vec::new();
    write_field::<KC, VC, _>(map, field_number, &mut wire);
    assert_eq!(wire.len() as u32, predicted, "field_len must match written bytes");
    wire
}
// Test helper: decodes a sequence of map-field records (tag followed by a
// length-prefixed entry) into a fresh `Map`, panicking on any decode error.
fn decode_field<KC, VC>(mut wire: &[u8]) -> Map<KC::Value, VC::Value>
where
    KC: MapValueDecode,
    KC::Value: Eq + Hash,
    VC: MapValueDecode,
{
    let unknown_limit = core::cell::Cell::new(crate::DEFAULT_UNKNOWN_FIELD_LIMIT);
    let mut decoded = Map::default();
    while !wire.is_empty() {
        let outer_tag = Tag::decode(&mut wire).unwrap();
        assert_eq!(outer_tag.wire_type(), WireType::LengthDelimited);
        let ctx = DecodeContext::new(crate::RECURSION_LIMIT, &unknown_limit);
        merge_entry::<KC, VC, _>(&mut decoded, &mut wire, ctx).unwrap();
    }
    decoded
}
// A string -> int32 map survives an encode/decode round trip, including a
// negative value.
#[test]
fn string_int32_round_trip() {
    let original: Map<String, i32> = [("a", 1), ("bee", -7)]
        .iter()
        .map(|&(k, v)| (String::from(k), v))
        .collect();
    let wire = encode_field::<Str, Int32>(&original, 5, 1);
    let decoded = decode_field::<Str, Int32>(&wire);
    assert_eq!(decoded, original);
}
// `ProtoStringMap<String>` and `Str` must be wire-compatible in both
// directions, as key codec and as value codec.
#[test]
fn proto_string_map_codec_matches_str() {
let mut map: Map<String, i32> = Map::default();
map.insert("a".into(), 1);
map.insert("bee".into(), -7);
let str_wire = encode_field::<Str, Int32>(&map, 5, 1);
let custom_wire = encode_field::<ProtoStringMap<String>, Int32>(&map, 5, 1);
// Cross-decode: each codec reads what the other one wrote.
assert_eq!(
decode_field::<ProtoStringMap<String>, Int32>(&str_wire),
map
);
assert_eq!(decode_field::<Str, Int32>(&custom_wire), map);
// Same codec in value position.
let mut vmap: Map<i32, String> = Map::default();
vmap.insert(1, "x".into());
let vwire = encode_field::<Int32, ProtoStringMap<String>>(&vmap, 3, 1);
assert_eq!(decode_field::<Int32, ProtoStringMap<String>>(&vwire), vmap);
}
// Both codecs have a `FIXED_LEN`, so `field_len` takes its multiplication
// shortcut; `encode_field` asserts that it still matches the bytes written.
#[test]
fn fixed_fixed_len_fold_matches_written_bytes() {
    let original: Map<u32, f64> = [(1u32, 0.5f64), (9, -2.25), (1000, 0.0)]
        .iter()
        .copied()
        .collect();
    let wire = encode_field::<Fixed32, Double>(&original, 3, 1);
    let decoded = decode_field::<Fixed32, Double>(&wire);
    assert_eq!(decoded, original);
}
// A zero-length entry carries neither key nor value, so the default key ("")
// is inserted with the default value (0).
#[test]
fn missing_key_and_value_take_defaults() {
// Single byte: entry length prefix of 0.
let wire = [0x00u8];
let mut map: Map<String, i32> = Map::default();
let limit = core::cell::Cell::new(crate::DEFAULT_UNKNOWN_FIELD_LIMIT);
merge_entry::<Str, Int32, _>(&mut map, &mut &wire[..], DecodeContext::new(10, &limit))
.unwrap();
assert_eq!(map.get(""), Some(&0));
}
// A field other than 1 (key) or 2 (value) inside an entry is skipped without
// disturbing the key and value around it.
#[test]
fn unknown_entry_field_is_skipped() {
let mut entry = Vec::new();
// Key: field 1 = "a".
Tag::new(1, WireType::LengthDelimited).encode(&mut entry);
types::encode_string("a", &mut entry);
// Unrecognized field 3, placed between key and value.
Tag::new(3, WireType::Varint).encode(&mut entry);
encode_varint(99, &mut entry);
// Value: field 2 = 7.
Tag::new(2, WireType::Varint).encode(&mut entry);
types::encode_int32(7, &mut entry);
// Prefix the entry with its length, as `merge_entry` expects.
let mut wire = Vec::new();
encode_varint(entry.len() as u64, &mut wire);
wire.extend_from_slice(&entry);
let mut map: Map<String, i32> = Map::default();
let limit = core::cell::Cell::new(crate::DEFAULT_UNKNOWN_FIELD_LIMIT);
merge_entry::<Str, Int32, _>(
&mut map,
&mut wire.as_slice(),
DecodeContext::new(10, &limit),
)
.unwrap();
assert_eq!(map.get("a"), Some(&7));
}
// A key field whose tag carries a wire type other than the codec's
// `WIRE_TYPE` is rejected with `WireTypeMismatch`.
#[test]
fn entry_wire_type_mismatch_errors() {
let mut entry = Vec::new();
// Field 1 tagged Fixed64, while `Str` expects LengthDelimited.
Tag::new(1, WireType::Fixed64).encode(&mut entry);
entry.extend_from_slice(&[0u8; 8]);
let mut wire = Vec::new();
encode_varint(entry.len() as u64, &mut wire);
wire.extend_from_slice(&entry);
let mut map: Map<String, i32> = Map::default();
let limit = core::cell::Cell::new(crate::DEFAULT_UNKNOWN_FIELD_LIMIT);
let err = merge_entry::<Str, Int32, _>(
&mut map,
&mut wire.as_slice(),
DecodeContext::new(10, &limit),
)
.unwrap_err();
assert!(matches!(err, DecodeError::WireTypeMismatch { .. }));
}
// The length prefix promises more bytes than the buffer holds, so decoding
// fails up front with `UnexpectedEof`.
#[test]
fn truncated_entry_errors() {
// Length prefix 5, but only one byte follows.
let wire = [0x05u8, 0x08];
let mut map: Map<String, i32> = Map::default();
let limit = core::cell::Cell::new(crate::DEFAULT_UNKNOWN_FIELD_LIMIT);
let err =
merge_entry::<Str, Int32, _>(&mut map, &mut &wire[..], DecodeContext::new(10, &limit))
.unwrap_err();
assert!(matches!(err, DecodeError::UnexpectedEof));
}
// Message-valued maps are encoded in two passes (`message_field_len` fills the
// size cache, `write_message_field` consumes it). The bytes written must match
// the computed length and decode back to the original map.
#[test]
fn message_map_two_pass_round_trip() {
use crate::{DefaultInstance, SizeCache};
// Minimal hand-written message with a single int32 field (number 1) that
// is omitted from the wire when zero.
#[derive(Clone, PartialEq, Eq, Debug, Default)]
struct FlatMsg {
value: i32,
}
impl DefaultInstance for FlatMsg {
fn default_instance() -> &'static Self {
static INST: crate::__private::OnceBox<FlatMsg> = crate::__private::OnceBox::new();
INST.get_or_init(|| crate::alloc::boxed::Box::new(FlatMsg::default()))
}
}
impl Message for FlatMsg {
fn compute_size(&self, _cache: &mut SizeCache) -> u32 {
if self.value != 0 {
// One tag byte plus the varint payload.
1 + types::int32_encoded_len(self.value) as u32
} else {
0
}
}
fn write_to(&self, _cache: &mut SizeCache, buf: &mut impl BufMut) {
if self.value != 0 {
Tag::new(1, WireType::Varint).encode(buf);
types::encode_int32(self.value, buf);
}
}
fn merge_field(
&mut self,
tag: Tag,
buf: &mut impl Buf,
ctx: DecodeContext<'_>,
) -> Result<(), DecodeError> {
match tag.field_number() {
1 => self.value = types::decode_int32(buf)?,
_ => skip_field_depth(tag, buf, ctx.depth())?,
}
Ok(())
}
fn clear(&mut self) {
*self = Self::default();
}
}
let mut map: Map<i32, FlatMsg> = Map::default();
// Covers an empty message body (value 0), a negative value, and a small
// positive value.
map.insert(1, FlatMsg { value: 0 }); map.insert(2, FlatMsg { value: -3 }); map.insert(9, FlatMsg { value: 7 });
let mut cache = SizeCache::default();
// Pass 1: compute sizes and populate the cache.
let len = message_field_len::<Int32, FlatMsg, _>(&map, 1, &mut cache);
let mut wire = Vec::new();
// Pass 2: write, consuming the cached sizes.
write_message_field::<Int32, FlatMsg, _>(&map, 4, &mut cache, &mut wire);
assert_eq!(wire.len() as u32, len, "size pass must match write pass");
let back = decode_field::<Int32, Msg<FlatMsg>>(&wire);
assert_eq!(back, map);
}
// `OpenEnum` round-trips a number with no matching variant as
// `EnumValue::Unknown`. Decoding the same bytes with `ClosedEnum` through
// `merge_entry` (no unknown-field sink) drops the entry instead.
#[test]
fn open_enum_value_preserves_unknown() {
// Single-variant enum: only 0 is a recognized number.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Default)]
#[repr(i32)]
enum E {
#[default]
A = 0,
}
impl Enumeration for E {
fn from_i32(value: i32) -> Option<Self> {
(value == 0).then_some(E::A)
}
fn to_i32(&self) -> i32 {
*self as i32
}
fn proto_name(&self) -> &'static str {
"A"
}
fn from_proto_name(name: &str) -> Option<Self> {
(name == "A").then_some(E::A)
}
}
let mut map: Map<i32, EnumValue<E>> = Map::default();
map.insert(1, EnumValue::Unknown(42));
let wire = encode_field::<Int32, OpenEnum<E>>(&map, 2, 1);
let back = decode_field::<Int32, OpenEnum<E>>(&wire);
assert_eq!(back.get(&1), Some(&EnumValue::Unknown(42)));
// Closed decoding of the same wire bytes: 42 is not a variant, so the
// entry is not inserted.
let back = decode_field::<Int32, ClosedEnum<E>>(&wire);
assert!(!back.contains_key(&1));
}
// With an unknown-field sink supplied, a closed-enum entry whose value number
// is unrecognized is kept verbatim as a length-delimited unknown field rather
// than inserted. Three scenarios are covered: a contiguous buffer, a buffer
// split across chunks, and an entry whose last value field is recognized.
#[test]
fn closed_enum_unknown_preserves_whole_entry_when_requested() {
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Default)]
#[repr(i32)]
enum E {
#[default]
A = 0,
B = 1,
}
impl Enumeration for E {
fn from_i32(value: i32) -> Option<Self> {
match value {
0 => Some(E::A),
1 => Some(E::B),
_ => None,
}
}
fn to_i32(&self) -> i32 {
*self as i32
}
fn proto_name(&self) -> &'static str {
match self {
E::A => "A",
E::B => "B",
}
}
}
// Entry: key 7, value 99 (99 is not a variant of `E`).
let mut entry = Vec::new();
Tag::new(1, WireType::Varint).encode(&mut entry);
types::encode_int32(7, &mut entry);
Tag::new(2, WireType::Varint).encode(&mut entry);
types::encode_int32(99, &mut entry);
let mut wire = Vec::new();
encode_varint(entry.len() as u64, &mut wire);
wire.extend_from_slice(&entry);
// Scenario 1: the whole entry sits in one contiguous slice.
let mut map: Map<i32, E> = Map::default();
let mut unknown_fields = crate::UnknownFields::new();
let limit = core::cell::Cell::new(crate::DEFAULT_UNKNOWN_FIELD_LIMIT);
merge_entry_with_unknowns::<Int32, ClosedEnum<E>, _>(
&mut map,
&mut wire.as_slice(),
DecodeContext::new(10, &limit),
Some((5, &mut unknown_fields)),
)
.unwrap();
assert!(map.is_empty());
let unknowns: Vec<_> = unknown_fields.iter().collect();
assert_eq!(unknowns.len(), 1);
assert_eq!(unknowns[0].number, 5);
assert!(matches!(
&unknowns[0].data,
crate::UnknownFieldData::LengthDelimited(payload) if payload == &entry
));
// Scenario 2: the same bytes split into two chained chunks, so the entry
// is not contiguous in the first chunk.
let mut map: Map<i32, E> = Map::default();
let mut unknown_fields = crate::UnknownFields::new();
let (a, b) = wire.split_at(3);
let mut chained = bytes::Buf::chain(a, b);
merge_entry_with_unknowns::<Int32, ClosedEnum<E>, _>(
&mut map,
&mut chained,
DecodeContext::new(10, &limit),
Some((5, &mut unknown_fields)),
)
.unwrap();
assert!(map.is_empty());
let unknowns: Vec<_> = unknown_fields.iter().collect();
assert_eq!(unknowns.len(), 1);
assert!(matches!(
&unknowns[0].data,
crate::UnknownFieldData::LengthDelimited(payload) if payload == &entry
));
// Scenario 3: the value field appears twice, first unrecognized (99) and
// then recognized (1). The last occurrence decides, so the entry is
// inserted and nothing is recorded as unknown.
let mut entry = Vec::new();
Tag::new(1, WireType::Varint).encode(&mut entry);
types::encode_int32(7, &mut entry);
Tag::new(2, WireType::Varint).encode(&mut entry);
types::encode_int32(99, &mut entry);
Tag::new(2, WireType::Varint).encode(&mut entry);
types::encode_int32(1, &mut entry);
let mut wire = Vec::new();
encode_varint(entry.len() as u64, &mut wire);
wire.extend_from_slice(&entry);
let mut map: Map<i32, E> = Map::default();
let mut unknown_fields = crate::UnknownFields::new();
merge_entry_with_unknowns::<Int32, ClosedEnum<E>, _>(
&mut map,
&mut wire.as_slice(),
DecodeContext::new(10, &limit),
Some((5, &mut unknown_fields)),
)
.unwrap();
assert_eq!(map.get(&7), Some(&E::B));
assert_eq!(unknown_fields.iter().count(), 0);
}
}