use alloc::borrow::Cow;
use alloc::collections::{BTreeMap, BTreeSet};
use alloc::string::String;
use alloc::vec::Vec;
use core::mem;
use core::str;
use bytes::{Buf, BufMut, Bytes};
use crate::buf::ReverseBuf;
use crate::encoding::{
delegate_encoding, delegate_value_encoding, encode_varint, encoded_len_varint,
encoder_where_value_encoder, prepend_varint, Canonicity, Capped, DecodeContext, DecodeError,
DistinguishedValueEncoder, Encoder, Fixed, Map, PlainBytes, Unpacked, ValueEncoder, Varint,
WireType, Wiretyped,
};
use crate::message::{merge, merge_distinguished, RawDistinguishedMessage, RawMessage};
use crate::Blob;
use crate::DecodeErrorKind::InvalidValue;
/// Marker type naming the `General` encoding.
///
/// The impls and delegations that follow in this file select, for each supported value type,
/// the concrete representation used when a value is encoded under `General`.
pub struct General;
// NOTE(review): the macro body is not visible from here; presumably it derives the field-level
// `Encoder` impls for `General` from its `ValueEncoder` impls — confirm at the macro definition.
encoder_where_value_encoder!(General);
// Sequence- and set-like containers: `General` is delegated to `Unpacked<General>`.
delegate_encoding!(delegate from (General) to (Unpacked<General>)
for type (Vec<T>) including distinguished with generics (T));
delegate_encoding!(delegate from (General) to (Unpacked<General>)
for type (Cow<'a, [T]>) including distinguished
with where clause (T: Clone)
with generics ('a, T));
// Optional third-party vector types, each gated on its own cargo feature.
#[cfg(feature = "arrayvec")]
delegate_encoding!(delegate from (General) to (Unpacked<General>)
for type (arrayvec::ArrayVec<T, N>) including distinguished
with generics (T, const N: usize));
#[cfg(feature = "smallvec")]
delegate_encoding!(delegate from (General) to (Unpacked<General>)
for type (smallvec::SmallVec<A>) including distinguished
with where clause (A: smallvec::Array<Item = T>)
with generics (T, A));
#[cfg(feature = "thin-vec")]
delegate_encoding!(delegate from (General) to (Unpacked<General>)
for type (thin_vec::ThinVec<T>) including distinguished with generics (T));
#[cfg(feature = "tinyvec")]
delegate_encoding!(delegate from (General) to (Unpacked<General>)
for type (tinyvec::ArrayVec<A>) including distinguished
with where clause (A: tinyvec::Array<Item = T>)
with generics (T, A));
#[cfg(feature = "tinyvec")]
delegate_encoding!(delegate from (General) to (Unpacked<General>)
for type (tinyvec::TinyVec<A>) including distinguished
with where clause (A: tinyvec::Array<Item = T>)
with generics (T, A));
delegate_encoding!(delegate from (General) to (Unpacked<General>)
for type (BTreeSet<T>) including distinguished with generics (T));
// Ordered maps: delegated to `Map<General, General>`, i.e. `General` for both keys and values.
delegate_value_encoding!(delegate from (General) to (Map<General, General>)
for type (BTreeMap<K, V>) including distinguished
with where clause for expedient (K: Ord)
with where clause for distinguished (V: Eq)
with generics (K, V));
// Hash-based sets and maps receive the same delegations but are NOT marked
// `including distinguished`.
// NOTE(review): presumably because their iteration order is not deterministic — confirm.
#[cfg(feature = "std")]
delegate_encoding!(delegate from (General) to (Unpacked<General>)
for type (std::collections::HashSet<T, S>)
with where clause (S: Default + core::hash::BuildHasher)
with generics (T, S));
#[cfg(feature = "std")]
delegate_value_encoding!(delegate from (General) to (Map<General, General>)
for type (std::collections::HashMap<K, V, S>)
with where clause (K: Eq + core::hash::Hash, S: Default + core::hash::BuildHasher)
with generics (K, V, S));
#[cfg(feature = "hashbrown")]
delegate_encoding!(delegate from (General) to (Unpacked<General>)
for type (hashbrown::HashSet<T, S>)
with where clause (S: Default + core::hash::BuildHasher)
with generics (T, S));
#[cfg(feature = "hashbrown")]
delegate_value_encoding!(delegate from (General) to (Map<General, General>)
for type (hashbrown::HashMap<K, V, S>)
with where clause (K: Eq + core::hash::Hash, S: Default + core::hash::BuildHasher)
with generics (K, V, S));
// `bool` and the 16-, 32- and 64-bit integers: delegated to `Varint`.
delegate_value_encoding!(delegate from (General) to (Varint)
for type (bool) including distinguished);
delegate_value_encoding!(delegate from (General) to (Varint)
for type (u16) including distinguished);
delegate_value_encoding!(delegate from (General) to (Varint)
for type (i16) including distinguished);
delegate_value_encoding!(delegate from (General) to (Varint)
for type (u32) including distinguished);
delegate_value_encoding!(delegate from (General) to (Varint)
for type (i32) including distinguished);
delegate_value_encoding!(delegate from (General) to (Varint)
for type (u64) including distinguished);
delegate_value_encoding!(delegate from (General) to (Varint)
for type (i64) including distinguished);
// Floating point: delegated to `Fixed`, and not marked `including distinguished`.
delegate_value_encoding!(delegate from (General) to (Fixed) for type (f32));
delegate_value_encoding!(delegate from (General) to (Fixed) for type (f64));
/// Under `General`, a `String` value is carried with the length-delimited wire type.
impl Wiretyped<General> for String {
const WIRE_TYPE: WireType = WireType::LengthDelimited;
}
/// `General` encoding of `String`: a varint byte length followed by the string's bytes.
impl ValueEncoder<General> for String {
    /// Appends the byte length as a varint, then the bytes of `value`.
    #[inline]
    fn encode_value<B: BufMut + ?Sized>(value: &String, buf: &mut B) {
        let bytes = value.as_bytes();
        encode_varint(bytes.len() as u64, buf);
        buf.put_slice(bytes);
    }

    /// Prepends the same encoding: the bytes are prepended first and the length varint is
    /// prepended after them.
    #[inline]
    fn prepend_value<B: ReverseBuf + ?Sized>(value: &String, buf: &mut B) {
        let bytes = value.as_bytes();
        buf.prepend_slice(bytes);
        prepend_varint(bytes.len() as u64, buf);
    }

    /// Returns the encoded size: the payload length plus the size of its length varint.
    #[inline]
    fn value_encoded_len(value: &String) -> usize {
        let payload_len = value.len();
        payload_len + encoded_len_varint(payload_len as u64)
    }

    /// Replaces the contents of `value` with the length-delimited payload read from `buf`.
    ///
    /// The payload is copied straight into the string's own allocation and validated
    /// afterwards. If it is not valid UTF-8 an `InvalidValue` error is returned and `value`
    /// is left empty.
    #[inline]
    fn decode_value<B: Buf + ?Sized>(
        value: &mut String,
        mut buf: Capped<B>,
        _ctx: DecodeContext,
    ) -> Result<(), DecodeError> {
        /// Empties the wrapped byte vector when dropped, unless it has been forgotten first.
        struct ClearOnDrop<'a>(&'a mut Vec<u8>);

        impl Drop for ClearOnDrop<'_> {
            #[inline]
            fn drop(&mut self) {
                self.0.clear();
            }
        }

        let source = buf.take_length_delimited()?.take_all();
        value.clear();
        value.reserve(source.remaining());

        // SAFETY: `as_mut_vec` exposes the string's raw bytes, so the string must hold valid
        // UTF-8 again by the time this borrow ends. Either the bytes pass validation below and
        // the guard is forgotten, or the guard's `Drop` empties the vector (on the error
        // return, and while unwinding if `put` panics); an empty vector is valid UTF-8.
        let guard = ClearOnDrop(unsafe { value.as_mut_vec() });
        guard.0.put(source);

        if str::from_utf8(guard.0).is_err() {
            // `guard` is dropped on this return, which clears the invalid bytes.
            return Err(DecodeError::new(InvalidValue));
        }

        // The bytes are valid: keep them by skipping the guard's destructor.
        mem::forget(guard);
        Ok(())
    }
}
/// Distinguished decoding of `String` under `General`.
impl DistinguishedValueEncoder<General> for String {
    /// Decodes exactly like `decode_value`, then reports canonicity: the result is
    /// `NotCanonical` only when `ALLOW_EMPTY` is false and the decoded string is empty.
    #[inline]
    fn decode_value_distinguished<const ALLOW_EMPTY: bool>(
        value: &mut String,
        buf: Capped<impl Buf + ?Sized>,
        ctx: DecodeContext,
    ) -> Result<Canonicity, DecodeError> {
        Self::decode_value(value, buf, ctx)?;
        let canonicity = if ALLOW_EMPTY || !value.is_empty() {
            Canonicity::Canonical
        } else {
            Canonicity::NotCanonical
        };
        Ok(canonicity)
    }
}
// Tests generated by `check_type_test!` for `String` under `General`: one invocation for
// `expedient` and one for `distinguished`, both with the length-delimited wire type.
#[cfg(test)]
mod string {
use super::{General, String};
use crate::encoding::test::check_type_test;
check_type_test!(General, expedient, String, WireType::LengthDelimited);
check_type_test!(General, distinguished, String, WireType::LengthDelimited);
}
/// Under `General`, a `Cow<str>` value is carried with the length-delimited wire type.
impl Wiretyped<General> for Cow<'_, str> {
const WIRE_TYPE: WireType = WireType::LengthDelimited;
}
/// `General` encoding of `Cow<str>`: a varint byte length followed by the string's bytes.
impl ValueEncoder<General> for Cow<'_, str> {
    /// Appends the byte length as a varint, then the bytes of `value`.
    #[inline]
    fn encode_value<B: BufMut + ?Sized>(value: &Cow<str>, buf: &mut B) {
        let bytes = value.as_bytes();
        encode_varint(bytes.len() as u64, buf);
        buf.put_slice(bytes);
    }

    /// Prepends the bytes of `value`, then prepends the length varint after them.
    #[inline]
    fn prepend_value<B: ReverseBuf + ?Sized>(value: &Cow<str>, buf: &mut B) {
        let bytes = value.as_bytes();
        buf.prepend_slice(bytes);
        prepend_varint(bytes.len() as u64, buf);
    }

    /// Returns the encoded size: the payload length plus the size of its length varint.
    #[inline]
    fn value_encoded_len(value: &Cow<str>) -> usize {
        let payload_len = value.len();
        payload_len + encoded_len_varint(payload_len as u64)
    }

    /// Decodes into the owned form of `value`: `to_mut` yields a `&mut String` (converting a
    /// borrowed `Cow` to owned first), and the `General` decoder for that string is reused.
    #[inline]
    fn decode_value<B: Buf + ?Sized>(
        value: &mut Cow<str>,
        buf: Capped<B>,
        ctx: DecodeContext,
    ) -> Result<(), DecodeError> {
        let owned = value.to_mut();
        ValueEncoder::<General>::decode_value(owned, buf, ctx)
    }
}
/// Distinguished decoding of `Cow<str>` under `General`.
impl DistinguishedValueEncoder<General> for Cow<'_, str> {
    /// Forwards to the distinguished `General` decoder of the owned string obtained from
    /// `to_mut`, passing `ALLOW_EMPTY` through unchanged.
    #[inline]
    fn decode_value_distinguished<const ALLOW_EMPTY: bool>(
        value: &mut Cow<str>,
        buf: Capped<impl Buf + ?Sized>,
        ctx: DecodeContext,
    ) -> Result<Canonicity, DecodeError> {
        let owned = value.to_mut();
        DistinguishedValueEncoder::<General>::decode_value_distinguished::<ALLOW_EMPTY>(
            owned, buf, ctx,
        )
    }
}
// Tests generated by `check_type_test!` for `Cow<str>` under `General`: one invocation for
// `expedient` and one for `distinguished`, both with the length-delimited wire type.
#[cfg(test)]
mod cow_string {
use super::{Cow, General};
use crate::encoding::test::check_type_test;
check_type_test!(General, expedient, Cow<str>, WireType::LengthDelimited);
check_type_test!(General, distinguished, Cow<str>, WireType::LengthDelimited);
}
/// Under `General`, a `ByteString` value is carried with the length-delimited wire type.
/// Only compiled with the `bytestring` feature.
#[cfg(feature = "bytestring")]
impl Wiretyped<General> for bytestring::ByteString {
const WIRE_TYPE: WireType = WireType::LengthDelimited;
}
/// `General` encoding of `ByteString`: a varint byte length followed by the string's bytes.
/// Only compiled with the `bytestring` feature.
#[cfg(feature = "bytestring")]
impl ValueEncoder<General> for bytestring::ByteString {
    /// Appends the byte length as a varint, then the bytes of `value`.
    #[inline]
    fn encode_value<B: BufMut + ?Sized>(value: &bytestring::ByteString, buf: &mut B) {
        let bytes = value.as_bytes();
        encode_varint(bytes.len() as u64, buf);
        buf.put_slice(bytes);
    }

    /// Prepends the bytes of `value`, then prepends the length varint after them.
    #[inline]
    fn prepend_value<B: ReverseBuf + ?Sized>(value: &bytestring::ByteString, buf: &mut B) {
        let bytes = value.as_bytes();
        buf.prepend_slice(bytes);
        prepend_varint(bytes.len() as u64, buf);
    }

    /// Returns the encoded size: the payload length plus the size of its length varint.
    #[inline]
    fn value_encoded_len(value: &bytestring::ByteString) -> usize {
        let payload_len = value.len();
        payload_len + encoded_len_varint(payload_len as u64)
    }

    /// Reads the length-delimited payload as `Bytes` and converts it with `try_from`.
    ///
    /// A failed conversion is reported as `InvalidValue`, and `value` is left untouched in
    /// that case.
    #[inline]
    fn decode_value<B: Buf + ?Sized>(
        value: &mut bytestring::ByteString,
        mut buf: Capped<B>,
        _ctx: DecodeContext,
    ) -> Result<(), DecodeError> {
        let mut payload = buf.take_length_delimited()?;
        let payload_len = payload.remaining_before_cap();
        let raw = payload.copy_to_bytes(payload_len);
        match bytestring::ByteString::try_from(raw) {
            Ok(decoded) => {
                *value = decoded;
                Ok(())
            }
            Err(_) => Err(DecodeError::new(InvalidValue)),
        }
    }
}
/// Distinguished decoding of `ByteString` under `General`.
/// Only compiled with the `bytestring` feature.
#[cfg(feature = "bytestring")]
impl DistinguishedValueEncoder<General> for bytestring::ByteString {
    /// Decodes exactly like `decode_value`; an empty result is `NotCanonical` unless
    /// `ALLOW_EMPTY` is set.
    #[inline]
    fn decode_value_distinguished<const ALLOW_EMPTY: bool>(
        value: &mut bytestring::ByteString,
        buf: Capped<impl Buf + ?Sized>,
        ctx: DecodeContext,
    ) -> Result<Canonicity, DecodeError> {
        Self::decode_value(value, buf, ctx)?;
        if ALLOW_EMPTY || !value.is_empty() {
            Ok(Canonicity::Canonical)
        } else {
            Ok(Canonicity::NotCanonical)
        }
    }
}
// Tests generated by `check_type_test!` for `ByteString` under `General`, using the macro's
// `from String, into bytestring::ByteString` form: one invocation for `expedient` and one for
// `distinguished`. Only compiled with the `bytestring` feature.
#[cfg(feature = "bytestring")]
#[cfg(test)]
mod bytestring_string {
use super::{General, String};
use crate::encoding::test::check_type_test;
check_type_test!(General, expedient, from String, into bytestring::ByteString, WireType::LengthDelimited);
check_type_test!(General, distinguished, from String, into bytestring::ByteString,
WireType::LengthDelimited);
}
/// Under `General`, a `Bytes` value is carried with the length-delimited wire type.
impl Wiretyped<General> for Bytes {
const WIRE_TYPE: WireType = WireType::LengthDelimited;
}
/// `General` encoding of `Bytes`: a varint byte length followed by the raw bytes.
impl ValueEncoder<General> for Bytes {
    /// Appends the byte length as a varint, then the payload.
    #[inline]
    fn encode_value<B: BufMut + ?Sized>(value: &Bytes, buf: &mut B) {
        encode_varint(value.len() as u64, buf);
        // `Bytes` is one contiguous slice, so `put_slice` writes the whole payload. Unlike
        // `put`, it has no `Self: Sized` bound, so it works on `B` directly and needs neither
        // a clone of the handle nor a `&mut &mut B` reborrow.
        buf.put_slice(value);
    }

    /// Prepends the payload, then prepends the length varint after it.
    #[inline]
    fn prepend_value<B: ReverseBuf + ?Sized>(value: &Bytes, buf: &mut B) {
        buf.prepend_slice(value);
        prepend_varint(value.len() as u64, buf);
    }

    /// Returns the encoded size: the payload length plus the size of its length varint.
    #[inline]
    fn value_encoded_len(value: &Bytes) -> usize {
        let payload_len = value.len();
        payload_len + encoded_len_varint(payload_len as u64)
    }

    /// Replaces `value` with the whole length-delimited payload, obtained via
    /// `copy_to_bytes`.
    #[inline]
    fn decode_value<B: Buf + ?Sized>(
        value: &mut Bytes,
        mut buf: Capped<B>,
        _ctx: DecodeContext,
    ) -> Result<(), DecodeError> {
        let mut payload = buf.take_length_delimited()?;
        let payload_len = payload.remaining_before_cap();
        *value = payload.copy_to_bytes(payload_len);
        Ok(())
    }
}
/// Distinguished decoding of `Bytes` under `General`.
impl DistinguishedValueEncoder<General> for Bytes {
    /// Decodes exactly like `decode_value`; an empty payload is reported as `NotCanonical`
    /// when `ALLOW_EMPTY` is false, and every other outcome is `Canonical`.
    #[inline]
    fn decode_value_distinguished<const ALLOW_EMPTY: bool>(
        value: &mut Bytes,
        buf: Capped<impl Buf + ?Sized>,
        ctx: DecodeContext,
    ) -> Result<Canonicity, DecodeError> {
        Self::decode_value(value, buf, ctx)?;
        let empty_disallowed = !ALLOW_EMPTY && value.is_empty();
        Ok(match empty_disallowed {
            true => Canonicity::NotCanonical,
            false => Canonicity::Canonical,
        })
    }
}
// Tests generated by `check_type_test!` for `Bytes` under `General`, using the macro's
// `from Vec<u8>, into Bytes` form: one invocation for `expedient` and one for `distinguished`.
#[cfg(test)]
mod bytes_blob {
use super::{Bytes, General, Vec};
use crate::encoding::test::check_type_test;
check_type_test!(General, expedient, from Vec<u8>, into Bytes, WireType::LengthDelimited);
check_type_test!(General, distinguished, from Vec<u8>, into Bytes, WireType::LengthDelimited);
}
/// Under `General`, a `Blob` value is carried with the length-delimited wire type.
impl Wiretyped<General> for Blob {
const WIRE_TYPE: WireType = WireType::LengthDelimited;
}
/// `General` encoding of `Blob`.
///
/// Encoding, length computation and decoding are forwarded to the `PlainBytes` encoding of
/// the value that `Blob` dereferences to; only `prepend_value` is written out here.
impl ValueEncoder<General> for Blob {
    /// Encodes the dereferenced contents with `PlainBytes`.
    #[inline]
    fn encode_value<B: BufMut + ?Sized>(value: &Blob, buf: &mut B) {
        let inner = &**value;
        ValueEncoder::<PlainBytes>::encode_value(inner, buf)
    }

    /// Prepends the blob's bytes, then prepends their length as a varint.
    #[inline]
    fn prepend_value<B: ReverseBuf + ?Sized>(value: &Blob, buf: &mut B) {
        let payload_len = value.len();
        buf.prepend_slice(value);
        prepend_varint(payload_len as u64, buf);
    }

    /// Returns the `PlainBytes` encoded length of the dereferenced contents.
    #[inline]
    fn value_encoded_len(value: &Blob) -> usize {
        let inner = &**value;
        ValueEncoder::<PlainBytes>::value_encoded_len(inner)
    }

    /// Decodes into the dereferenced contents with `PlainBytes`.
    #[inline]
    fn decode_value<B: Buf + ?Sized>(
        value: &mut Blob,
        buf: Capped<B>,
        ctx: DecodeContext,
    ) -> Result<(), DecodeError> {
        let inner = &mut **value;
        ValueEncoder::<PlainBytes>::decode_value(inner, buf, ctx)
    }
}
/// Distinguished decoding of `Blob` under `General`.
impl DistinguishedValueEncoder<General> for Blob {
    /// Forwards to the distinguished `PlainBytes` decoder of the value that `Blob`
    /// dereferences to, passing `ALLOW_EMPTY` through unchanged.
    #[inline]
    fn decode_value_distinguished<const ALLOW_EMPTY: bool>(
        value: &mut Blob,
        buf: Capped<impl Buf + ?Sized>,
        ctx: DecodeContext,
    ) -> Result<Canonicity, DecodeError> {
        let inner = &mut **value;
        DistinguishedValueEncoder::<PlainBytes>::decode_value_distinguished::<ALLOW_EMPTY>(
            inner, buf, ctx,
        )
    }
}
// Tests generated by `check_type_test!` for `Blob` under `General`: one invocation for
// `expedient` and one for `distinguished`, both with the length-delimited wire type.
#[cfg(test)]
mod blob {
use super::{Blob, General};
use crate::encoding::test::check_type_test;
check_type_test!(General, expedient, Blob, WireType::LengthDelimited);
check_type_test!(General, distinguished, Blob, WireType::LengthDelimited);
}
/// Blanket impl: under `General`, every `RawMessage` type is carried with the
/// length-delimited wire type.
impl<T> Wiretyped<General> for T
where
T: RawMessage,
{
const WIRE_TYPE: WireType = WireType::LengthDelimited;
}
/// Blanket `General` encoding for message types: a varint length followed by the message's
/// raw encoding.
impl<T> ValueEncoder<General> for T
where
    T: RawMessage,
{
    /// Appends `raw_encoded_len()` as a varint, then the output of `raw_encode`.
    #[inline]
    fn encode_value<B: BufMut + ?Sized>(value: &T, buf: &mut B) {
        let body_len = value.raw_encoded_len();
        encode_varint(body_len as u64, buf);
        value.raw_encode(buf);
    }

    /// Prepends the message body, then prepends its length.
    ///
    /// The length is not computed up front: it is measured as the growth of
    /// `buf.remaining()` across the `raw_prepend` call.
    #[inline]
    fn prepend_value<B: ReverseBuf + ?Sized>(value: &T, buf: &mut B) {
        let before = buf.remaining();
        value.raw_prepend(buf);
        let written = buf.remaining() - before;
        prepend_varint(written as u64, buf);
    }

    /// Returns the encoded size: the raw body length plus the size of its length varint.
    #[inline]
    fn value_encoded_len(value: &T) -> usize {
        let body_len = value.raw_encoded_len();
        body_len + encoded_len_varint(body_len as u64)
    }

    /// Merges the length-delimited payload into `value`.
    ///
    /// `ctx.limit_reached()?` is checked before anything is read, and the nested merge runs
    /// with `ctx.enter_recursion()`.
    // NOTE(review): `limit_reached` is presumably the recursion-depth guard — confirm against
    // `DecodeContext`.
    #[inline]
    fn decode_value<B: Buf + ?Sized>(
        value: &mut T,
        mut buf: Capped<B>,
        ctx: DecodeContext,
    ) -> Result<(), DecodeError> {
        ctx.limit_reached()?;
        let body = buf.take_length_delimited()?;
        merge(value, body, ctx.enter_recursion())
    }
}
/// Blanket distinguished decoding for message types that are `RawDistinguishedMessage + Eq`.
impl<T> DistinguishedValueEncoder<General> for T
where
    T: RawDistinguishedMessage + Eq,
{
    /// Merges the length-delimited payload into `value` via `merge_distinguished`.
    ///
    /// When `ALLOW_EMPTY` is false and the payload is zero bytes long, `NotCanonical` is
    /// returned immediately and no merge is performed. As in the expedient decoder,
    /// `ctx.limit_reached()?` is checked first and the nested merge runs with
    /// `ctx.enter_recursion()`.
    #[inline]
    fn decode_value_distinguished<const ALLOW_EMPTY: bool>(
        value: &mut T,
        mut buf: Capped<impl Buf + ?Sized>,
        ctx: DecodeContext,
    ) -> Result<Canonicity, DecodeError> {
        ctx.limit_reached()?;
        let body = buf.take_length_delimited()?;
        if ALLOW_EMPTY || body.remaining_before_cap() != 0 {
            merge_distinguished(value, body, ctx.enter_recursion())
        } else {
            Ok(Canonicity::NotCanonical)
        }
    }
}