use bytes::{Buf, BufMut};
use crate::buf::ReverseBuf;
use crate::encoding::value_traits::{
Collection, DistinguishedCollection, EmptyState, NewForOverwrite,
};
use crate::encoding::{
check_wire_type, peek_repeated_field, Capped, DecodeContext, DistinguishedEncoder,
DistinguishedValueEncoder, Encoder, FieldEncoder, General, Packed, TagMeasurer, TagRevWriter,
TagWriter, ValueEncoder, WireType, Wiretyped,
};
use crate::DecodeErrorKind::{InvalidValue, UnexpectedlyRepeated};
use crate::{Canonicity, DecodeError};
/// Encoder type for repeated fields written in unpacked form: every element is emitted as its
/// own occurrence of the field (key followed by value), using the element encoder `E`, which
/// defaults to `General`.
///
/// When decoding, an occurrence whose wire type is length-delimited for an element type that is
/// not itself length-delimited is also accepted and decoded via the `Packed<E>` encoder; in
/// distinguished mode such input is reported as not canonical.
pub struct Unpacked<E = General>(E);
/// Decodes an unpacked repeated field into `collection`.
///
/// `wire_type` belongs to the first occurrence, whose key the caller has already read. Values
/// are decoded and inserted one at a time. After each one, `peek_repeated_field` is asked
/// whether the same field repeats; decoding continues with that occurrence's wire type until it
/// reports no further repetition.
///
/// # Errors
///
/// Fails if any occurrence has a wire type that does not match the item encoding, if a value
/// fails to decode, or if the collection rejects an inserted item.
#[inline]
pub(crate) fn decode<T, E>(
    wire_type: WireType,
    collection: &mut T,
    mut buf: Capped<impl Buf + ?Sized>,
    ctx: DecodeContext,
) -> Result<(), DecodeError>
where
    T: Collection,
    T::Item: NewForOverwrite + ValueEncoder<E>,
{
    let mut occurrence_wire_type = wire_type;
    loop {
        // Every occurrence, including the first, must use the item's wire type.
        check_wire_type(<T::Item as Wiretyped<E>>::WIRE_TYPE, occurrence_wire_type)?;
        let mut item = T::Item::new_for_overwrite();
        ValueEncoder::<E>::decode_value(&mut item, buf.lend(), ctx.clone())?;
        collection.insert(item)?;
        match peek_repeated_field(&mut buf) {
            Some(next_wire_type) => occurrence_wire_type = next_wire_type,
            None => return Ok(()),
        }
    }
}
/// Decodes exactly `N` consecutive occurrences of an unpacked repeated field into `arr`.
///
/// `wire_type` is the wire type of the first occurrence, whose key the caller has already
/// consumed. Each later element must be introduced by a repeated key for the same field (as
/// reported by `peek_repeated_field`) that carries the item's wire type.
///
/// # Errors
///
/// Returns an error if any occurrence's wire type does not match the item encoding, or if an
/// element fails to decode. Returns `InvalidValue` if the field occurs fewer or more than `N`
/// times. A zero-length array has no slot for any value, so any occurrence of the field is
/// rejected.
#[inline]
pub(crate) fn decode_array<T, const N: usize, E>(
    wire_type: WireType,
    arr: &mut [T; N],
    mut buf: Capped<impl Buf + ?Sized>,
    ctx: DecodeContext,
) -> Result<(), DecodeError>
where
    T: ValueEncoder<E>,
{
    check_wire_type(<T as Wiretyped<E>>::WIRE_TYPE, wire_type)?;
    // The caller has already consumed the key of a value that is still sitting in `buf`. With
    // no slot to decode it into (N == 0), returning Ok would leave that value unread for the
    // caller to misinterpret as following keys, so treat it as one value too many.
    let (first, rest) = match arr.split_first_mut() {
        Some(split) => split,
        None => return Err(DecodeError::new(InvalidValue)),
    };
    ValueEncoder::<E>::decode_value(first, buf.lend(), ctx.clone())?;
    for dest in rest {
        match peek_repeated_field(&mut buf) {
            Some(next_wire_type) => {
                check_wire_type(<T as Wiretyped<E>>::WIRE_TYPE, next_wire_type)?
            }
            // The field ended before all N elements were filled.
            None => return Err(DecodeError::new(InvalidValue)),
        }
        ValueEncoder::<E>::decode_value(dest, buf.lend(), ctx.clone())?;
    }
    // Any further repetition means more than N values were present.
    if peek_repeated_field(&mut buf).is_some() {
        Err(DecodeError::new(InvalidValue))
    } else {
        Ok(())
    }
}
/// Decodes an unpacked repeated field into `collection` in distinguished mode.
///
/// `wire_type` belongs to the first occurrence, whose key the caller has already read. Each
/// value is decoded in distinguished mode and inserted with `insert_distinguished`. The
/// canonicity reported by both steps is folded into the returned `Canonicity`. Decoding
/// continues while `peek_repeated_field` reports another occurrence of the same field.
///
/// # Errors
///
/// Fails if any occurrence has a wire type that does not match the item encoding, if a value
/// fails to decode, or if the collection rejects an inserted item.
#[inline]
pub(crate) fn decode_distinguished<T, E>(
    wire_type: WireType,
    collection: &mut T,
    mut buf: Capped<impl Buf + ?Sized>,
    ctx: DecodeContext,
) -> Result<Canonicity, DecodeError>
where
    T: DistinguishedCollection,
    T::Item: NewForOverwrite + Eq + DistinguishedValueEncoder<E>,
{
    let mut occurrence_wire_type = wire_type;
    let mut canon = Canonicity::Canonical;
    loop {
        // Every occurrence, including the first, must use the item's wire type.
        check_wire_type(<T::Item as Wiretyped<E>>::WIRE_TYPE, occurrence_wire_type)?;
        let mut item = T::Item::new_for_overwrite();
        let item_canon = DistinguishedValueEncoder::<E>::decode_value_distinguished::<true>(
            &mut item,
            buf.lend(),
            ctx.clone(),
        )?;
        canon.update(item_canon);
        canon.update(collection.insert_distinguished(item)?);
        match peek_repeated_field(&mut buf) {
            Some(next_wire_type) => occurrence_wire_type = next_wire_type,
            None => return Ok(canon),
        }
    }
}
/// Decodes exactly `N` consecutive occurrences of an unpacked repeated field into `arr` in
/// distinguished mode.
///
/// `wire_type` is the wire type of the first occurrence, whose key the caller has already
/// consumed. Each later element must be introduced by a repeated key for the same field that
/// carries the item's wire type.
///
/// Returns the combined canonicity of the element decodes. An array whose contents are empty
/// (per `EmptyState`) is reported as `NotCanonical`, because the matching encoder writes nothing
/// at all for such an array.
///
/// # Errors
///
/// Returns an error if any occurrence's wire type does not match the item encoding, or if an
/// element fails to decode. Returns `InvalidValue` if the field occurs fewer or more than `N`
/// times. A zero-length array has no slot for any value, so any occurrence of the field is
/// rejected.
#[inline]
pub(crate) fn decode_distinguished_array<T, const N: usize, E>(
    wire_type: WireType,
    arr: &mut [T; N],
    mut buf: Capped<impl Buf + ?Sized>,
    ctx: DecodeContext,
) -> Result<Canonicity, DecodeError>
where
    T: Eq + EmptyState + DistinguishedValueEncoder<E>,
{
    check_wire_type(<T as Wiretyped<E>>::WIRE_TYPE, wire_type)?;
    // The caller has already consumed the key of a value that is still sitting in `buf`. With
    // no slot to decode it into (N == 0), returning Ok would leave that value unread for the
    // caller to misinterpret as following keys, so treat it as one value too many.
    let (first, rest) = match arr.split_first_mut() {
        Some(split) => split,
        None => return Err(DecodeError::new(InvalidValue)),
    };
    let mut canon = Canonicity::Canonical;
    canon.update(DistinguishedValueEncoder::<E>::decode_value_distinguished::<true>(
        first,
        buf.lend(),
        ctx.clone(),
    )?);
    for dest in rest {
        match peek_repeated_field(&mut buf) {
            Some(next_wire_type) => {
                check_wire_type(<T as Wiretyped<E>>::WIRE_TYPE, next_wire_type)?
            }
            // The field ended before all N elements were filled.
            None => return Err(DecodeError::new(InvalidValue)),
        }
        canon.update(DistinguishedValueEncoder::<E>::decode_value_distinguished::<true>(
            dest,
            buf.lend(),
            ctx.clone(),
        )?);
    }
    // Any further repetition means more than N values were present.
    if peek_repeated_field(&mut buf).is_some() {
        return Err(DecodeError::new(InvalidValue));
    }
    Ok(if EmptyState::is_empty(arr) {
        Canonicity::NotCanonical
    } else {
        canon
    })
}
impl<C, T, E> Encoder<Unpacked<E>> for C
where
    C: Collection<Item = T>,
    T: NewForOverwrite + ValueEncoder<E>,
{
    /// Writes each item, in iteration order, as a separate occurrence of field `tag`.
    #[inline]
    fn encode<B: BufMut + ?Sized>(tag: u32, value: &C, buf: &mut B, tw: &mut TagWriter) {
        for item in value.iter() {
            FieldEncoder::<E>::encode_field(tag, item, buf, tw);
        }
    }

    /// Prepends each item as a separate occurrence of field `tag`. Items are visited through
    /// `reversed()` because every field is prepended to the buffer.
    #[inline]
    fn prepend_encode<B: ReverseBuf + ?Sized>(
        tag: u32,
        value: &Self,
        buf: &mut B,
        tw: &mut TagRevWriter,
    ) {
        for item in value.reversed() {
            FieldEncoder::<E>::prepend_field(tag, item, buf, tw);
        }
    }

    /// Returns the encoded size of all occurrences: the full key length of the first, the
    /// combined value lengths, and one extra byte counted for each of the remaining repeated
    /// keys. An empty collection encodes to nothing.
    #[inline]
    fn encoded_len(tag: u32, value: &C, tm: &mut impl TagMeasurer) -> usize {
        if value.is_empty() {
            return 0;
        }
        let first_key_len = tm.key_len(tag);
        let values_len = ValueEncoder::<E>::many_values_encoded_len(value.iter());
        first_key_len + values_len + value.len() - 1
    }

    /// Decodes a run of occurrences of the field into `value`.
    ///
    /// A field reported as `duplicated` is rejected with `UnexpectedlyRepeated`. A
    /// length-delimited occurrence of a non-length-delimited item type is decoded as packed
    /// data; everything else goes through the unpacked decoder.
    #[inline]
    fn decode<B: Buf + ?Sized>(
        wire_type: WireType,
        duplicated: bool,
        value: &mut C,
        buf: Capped<B>,
        ctx: DecodeContext,
    ) -> Result<(), DecodeError> {
        if duplicated {
            return Err(DecodeError::new(UnexpectedlyRepeated));
        }
        let is_packed = wire_type == WireType::LengthDelimited
            && <C::Item as Wiretyped<E>>::WIRE_TYPE != WireType::LengthDelimited;
        if is_packed {
            ValueEncoder::<Packed<E>>::decode_value(value, buf, ctx)
        } else {
            decode::<C, E>(wire_type, value, buf, ctx)
        }
    }
}
impl<C, T, E> DistinguishedEncoder<Unpacked<E>> for C
where
    Self: DistinguishedCollection<Item = T> + ValueEncoder<Packed<E>> + Encoder<Unpacked<E>>,
    T: NewForOverwrite + Eq + DistinguishedValueEncoder<E>,
{
    /// Decodes a run of occurrences of the field into `value` in distinguished mode.
    ///
    /// A field reported as `duplicated` is rejected with `UnexpectedlyRepeated`. Unpacked
    /// input goes through the distinguished unpacked decoder. Packed input (a length-delimited
    /// occurrence of a non-length-delimited item type) is still accepted, but is always
    /// reported as `NotCanonical`, since this encoding writes the unpacked form.
    #[inline]
    fn decode_distinguished<B: Buf + ?Sized>(
        wire_type: WireType,
        duplicated: bool,
        value: &mut C,
        buf: Capped<B>,
        ctx: DecodeContext,
    ) -> Result<Canonicity, DecodeError> {
        if duplicated {
            return Err(DecodeError::new(UnexpectedlyRepeated));
        }
        if wire_type != WireType::LengthDelimited
            || <T as Wiretyped<E>>::WIRE_TYPE == WireType::LengthDelimited
        {
            decode_distinguished::<C, E>(wire_type, value, buf, ctx)
        } else {
            <C as ValueEncoder<Packed<E>>>::decode_value(value, buf, ctx)
                .map(|()| Canonicity::NotCanonical)
        }
    }
}
impl<T, const N: usize, E> Encoder<Unpacked<E>> for [T; N]
where
    T: EmptyState + ValueEncoder<E>,
{
    /// Writes every element as its own occurrence of field `tag`. If the array as a whole is
    /// empty (per `EmptyState`), nothing is written.
    #[inline]
    fn encode<B: BufMut + ?Sized>(tag: u32, value: &[T; N], buf: &mut B, tw: &mut TagWriter) {
        if EmptyState::is_empty(value) {
            return;
        }
        for element in value {
            FieldEncoder::<E>::encode_field(tag, element, buf, tw);
        }
    }

    /// Prepends every element, last to first, as its own occurrence of field `tag`. If the
    /// array as a whole is empty (per `EmptyState`), nothing is written.
    #[inline]
    fn prepend_encode<B: ReverseBuf + ?Sized>(
        tag: u32,
        value: &Self,
        buf: &mut B,
        tw: &mut TagRevWriter,
    ) {
        if EmptyState::is_empty(value) {
            return;
        }
        for element in value.iter().rev() {
            FieldEncoder::<E>::prepend_field(tag, element, buf, tw);
        }
    }

    /// Returns the encoded size of all `N` occurrences: the full key length of the first, the
    /// combined value lengths, and one extra byte counted for each of the other `N - 1`
    /// repeated keys. An empty array encodes to nothing.
    #[inline]
    fn encoded_len(tag: u32, value: &[T; N], tm: &mut impl TagMeasurer) -> usize {
        if EmptyState::is_empty(value) {
            return 0;
        }
        let first_key_len = tm.key_len(tag);
        let values_len = ValueEncoder::<E>::many_values_encoded_len(value.iter());
        first_key_len + values_len + N - 1
    }

    /// Decodes a run of occurrences of the field into `value`.
    ///
    /// A field reported as `duplicated` is rejected with `UnexpectedlyRepeated`. A
    /// length-delimited occurrence of a non-length-delimited element type is decoded as packed
    /// data; everything else goes through the unpacked array decoder.
    #[inline]
    fn decode<B: Buf + ?Sized>(
        wire_type: WireType,
        duplicated: bool,
        value: &mut [T; N],
        buf: Capped<B>,
        ctx: DecodeContext,
    ) -> Result<(), DecodeError> {
        if duplicated {
            return Err(DecodeError::new(UnexpectedlyRepeated));
        }
        let is_packed = wire_type == WireType::LengthDelimited
            && <T as Wiretyped<E>>::WIRE_TYPE != WireType::LengthDelimited;
        if is_packed {
            ValueEncoder::<Packed<E>>::decode_value(value, buf, ctx)
        } else {
            decode_array::<T, N, E>(wire_type, value, buf, ctx)
        }
    }
}
impl<T, const N: usize, E> DistinguishedEncoder<Unpacked<E>> for [T; N]
where
    T: Eq + EmptyState + DistinguishedValueEncoder<E> + ValueEncoder<E>,
{
    /// Decodes a run of occurrences of the field into `value` in distinguished mode.
    ///
    /// A field reported as `duplicated` is rejected with `UnexpectedlyRepeated`. Unpacked
    /// input goes through the distinguished array decoder. Packed input (a length-delimited
    /// occurrence of a non-length-delimited element type) is still accepted, but is always
    /// reported as `NotCanonical`, since this encoding writes the unpacked form.
    #[inline]
    fn decode_distinguished<B: Buf + ?Sized>(
        wire_type: WireType,
        duplicated: bool,
        value: &mut [T; N],
        buf: Capped<B>,
        ctx: DecodeContext,
    ) -> Result<Canonicity, DecodeError> {
        if duplicated {
            return Err(DecodeError::new(UnexpectedlyRepeated));
        }
        let is_packed = wire_type == WireType::LengthDelimited
            && <T as Wiretyped<E>>::WIRE_TYPE != WireType::LengthDelimited;
        if !is_packed {
            return decode_distinguished_array::<T, N, E>(wire_type, value, buf, ctx);
        }
        <[T; N] as ValueEncoder<Packed<E>>>::decode_value(value, buf, ctx)?;
        Ok(Canonicity::NotCanonical)
    }
}
// Property tests for the `Unpacked` encoder. Each case feeds a generated value and tag, plus
// the wire type the element encoding is expected to use, to both the expedient and the
// distinguished `check_type_unpacked` helpers from `crate::encoding::test` (defined elsewhere).
#[cfg(test)]
mod test {
use alloc::string::String;
use alloc::vec::Vec;
use proptest::proptest;
use crate::encoding::test::{distinguished, expedient};
use crate::encoding::{Fixed, Unpacked, WireType};
proptest! {
// `Vec` cases: growable collections. `value` is cloned for the first check because the second
// check takes it by value.
// Default `General` element encoding of `u64`, expected on the varint wire type.
#[test]
fn varint(value: Vec<u64>, tag: u32) {
expedient::check_type_unpacked::<Vec<u64>, Unpacked>(
value.clone(),
tag,
WireType::Varint,
)?;
distinguished::check_type_unpacked::<Vec<u64>, Unpacked>(value, tag, WireType::Varint)?;
}
// `String` elements, expected on the length-delimited wire type.
#[test]
fn length_delimited(value: Vec<String>, tag: u32) {
expedient::check_type_unpacked::<Vec<String>, Unpacked>(
value.clone(),
tag,
WireType::LengthDelimited,
)?;
distinguished::check_type_unpacked::<Vec<String>, Unpacked>(
value,
tag,
WireType::LengthDelimited,
)?;
}
// `u32` elements with the `Fixed` element encoder, expected on the 32-bit wire type.
#[test]
fn fixed32(value: Vec<u32>, tag: u32) {
expedient::check_type_unpacked::<Vec<u32>, Unpacked<Fixed>>(
value.clone(),
tag,
WireType::ThirtyTwoBit,
)?;
distinguished::check_type_unpacked::<Vec<u32>, Unpacked<Fixed>>(
value,
tag,
WireType::ThirtyTwoBit,
)?;
}
// `u64` elements with the `Fixed` element encoder, expected on the 64-bit wire type.
#[test]
fn fixed64(value: Vec<u64>, tag: u32) {
expedient::check_type_unpacked::<Vec<u64>, Unpacked<Fixed>>(
value.clone(),
tag,
WireType::SixtyFourBit,
)?;
distinguished::check_type_unpacked::<Vec<u64>, Unpacked<Fixed>>(
value,
tag,
WireType::SixtyFourBit,
)?;
}
// Fixed-size array cases (`[_; 2]`). Arrays of `u32`/`u64` are `Copy`, so `value` is passed to
// both checks without cloning; `[String; 2]` still needs a clone.
#[test]
fn varint_array(value: [u64; 2], tag: u32) {
expedient::check_type_unpacked::<[u64; 2], Unpacked>(
value,
tag,
WireType::Varint,
)?;
distinguished::check_type_unpacked::<[u64; 2], Unpacked>(value, tag, WireType::Varint)?;
}
#[test]
fn length_delimited_array(value: [String; 2], tag: u32) {
expedient::check_type_unpacked::<[String; 2], Unpacked>(
value.clone(),
tag,
WireType::LengthDelimited,
)?;
distinguished::check_type_unpacked::<[String; 2], Unpacked>(
value,
tag,
WireType::LengthDelimited,
)?;
}
#[test]
fn fixed32_array(value: [u32; 2], tag: u32) {
expedient::check_type_unpacked::<[u32; 2], Unpacked<Fixed>>(
value,
tag,
WireType::ThirtyTwoBit,
)?;
distinguished::check_type_unpacked::<[u32; 2], Unpacked<Fixed>>(
value,
tag,
WireType::ThirtyTwoBit,
)?;
}
#[test]
fn fixed64_array(value: [u64; 2], tag: u32) {
expedient::check_type_unpacked::<[u64; 2], Unpacked<Fixed>>(
value,
tag,
WireType::SixtyFourBit,
)?;
distinguished::check_type_unpacked::<[u64; 2], Unpacked<Fixed>>(
value,
tag,
WireType::SixtyFourBit,
)?;
}
}
}