use crate::buffer::InputSource;
use crate::decode_from::*;
use crate::decoder::Decoder;
use crate::{Error, InvalidDataErrorKind, Result};
#[cfg(feature = "alloc")]
use alloc::collections::BTreeMap;
#[cfg(feature = "alloc")]
use alloc::string::String;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
#[cfg(feature = "alloc")]
use core::fmt::Debug;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(feature = "std")]
use std::hash::Hash;
/// Tag value that terminates a run of tagged fields; `skip_tagged_fields` stops after reading it.
const TAG_END_MARKER: i32 = -1;
/// Builds the error reported when a byte decoded as `bool` is neither `0` nor `1`.
///
/// The offending byte is carried in the error so callers can see what was read.
fn illegal_bool_error(value: u8) -> Error {
    InvalidDataErrorKind::IllegalValue {
        desc: "bools can only have a numeric value of either '0' or '1'",
        value: Some(i128::from(value)),
    }
    .into()
}
impl DecodeFrom for bool {
    /// Decodes a `bool` from a single byte.
    ///
    /// `0` maps to `false` and `1` maps to `true`. Any other byte is rejected with
    /// the error built by `illegal_bool_error`.
    fn decode_from(decoder: &mut Decoder<impl InputSource>) -> crate::Result<Self> {
        match decoder.read_byte()? {
            0 => Ok(false),
            1 => Ok(true),
            other => Err(illegal_bool_error(other)),
        }
    }
}
impl DecodeFrom for u8 {
    /// Decodes a `u8` as the next raw byte of input.
    fn decode_from(decoder: &mut Decoder<impl InputSource>) -> crate::Result<Self> {
        let raw = decoder.read_byte()?;
        Ok(raw)
    }
}
impl DecodeFrom for i8 {
    /// Decodes an `i8` by reinterpreting the next raw byte as a two's-complement value.
    fn decode_from(decoder: &mut Decoder<impl InputSource>) -> crate::Result<Self> {
        decoder.read_byte().map(|raw| i8::from_le_bytes([raw]))
    }
}
// `DecodeFrom` impls for the fixed-width numeric primitives are generated by
// `implement_decode_from_on_numeric_primitive_type!`, which is defined outside this file.
// The string argument describes each type's wire encoding. It is presumably emitted as the
// generated impl's rustdoc text (confirm against the macro definition).
implement_decode_from_on_numeric_primitive_type! {u16, "Decodes a [`u16`] from 2 bytes (little endian)."}
implement_decode_from_on_numeric_primitive_type! {i16, "Decodes a [`i16`] from 2 bytes (little endian) in two's complement form."}
implement_decode_from_on_numeric_primitive_type! {u32, "Decodes a [`u32`] from 4 bytes (little endian)."}
implement_decode_from_on_numeric_primitive_type! {i32, "Decodes a [`i32`] from 4 bytes (little endian) in two's complement form."}
implement_decode_from_on_numeric_primitive_type! {u64, "Decodes a [`u64`] from 8 bytes (little endian)."}
implement_decode_from_on_numeric_primitive_type! {i64, "Decodes a [`i64`] from 8 bytes (little endian) in two's complement form."}
implement_decode_from_on_numeric_primitive_type! {f32, "Decodes a [`f32`] from 4 bytes (little endian) using the \"binary32\" representation defined in IEEE 754-2008."}
implement_decode_from_on_numeric_primitive_type! {f64, "Decodes a [`f64`] from 8 bytes (little endian) using the \"binary64\" representation defined in IEEE 754-2008."}
/// Builds the out-of-range error for a signed varint `value` that does not fit in `T`.
///
/// The reported bounds are those of a signed two's-complement integer whose bit
/// width equals `size_of::<T>() * 8`.
///
/// NOTE(review): if `T` is an unsigned type, these signed bounds do not match
/// `T`'s real range.
fn varint_range_error<T>(value: i64) -> Error {
    let bit_width = core::mem::size_of::<T>() as u32 * 8;
    // Arithmetic right shifts of the i128 extremes yield the extremes of a
    // `bit_width`-bit signed integer.
    let unused_bits = i128::BITS - bit_width;
    InvalidDataErrorKind::OutOfRange {
        value: i128::from(value),
        min: i128::MIN >> unused_bits,
        max: i128::MAX >> unused_bits,
        typename: core::any::type_name::<T>(),
    }
    .into()
}
/// Builds the out-of-range error for an unsigned varint `value` that does not fit in `T`.
///
/// The reported range is `0..=2^(size_of::<T>() * 8) - 1`, cast to `i128` for
/// the error payload.
fn varuint_range_error<T>(value: u64) -> Error {
    let bit_width = core::mem::size_of::<T>() as u32 * 8;
    // Logical right shift of u128::MAX leaves `bit_width` set bits.
    let upper_bound = u128::MAX >> (u128::BITS - bit_width);
    InvalidDataErrorKind::OutOfRange {
        value: i128::from(value),
        min: 0,
        max: upper_bound as i128,
        typename: core::any::type_name::<T>(),
    }
    .into()
}
impl<I: InputSource> Decoder<I> {
    /// Decodes a variable-width signed integer and converts it to `T`.
    ///
    /// The low two bits of the next input byte (inspected with `peek_byte`) select
    /// the encoded width: `0b00` → 1 byte, `0b01` → 2, `0b10` → 4, `0b11` → 8.
    /// That many bytes are decoded as the matching signed fixed-width integer,
    /// widened to `i64`, and arithmetically shifted right by 2 to discard the
    /// width tag while preserving the sign.
    ///
    /// # Errors
    ///
    /// Returns any error from reading the input, or an out-of-range error when the
    /// decoded value cannot be converted into `T`.
    pub fn decode_varint<T: TryFrom<i64>>(&mut self) -> Result<T> {
        // `& 0b11` leaves only four possible values, so the catch-all arm is exactly
        // the 8-byte case and no `unsafe` unreachable hint is needed.
        let tagged = match self.peek_byte()? & 0b11 {
            0b00 => i64::from(i8::decode_from(self)?),
            0b01 => i64::from(i16::decode_from(self)?),
            0b10 => i64::from(i32::decode_from(self)?),
            _ => i64::decode_from(self)?,
        };
        let value = tagged >> 2;
        T::try_from(value).map_err(|_| varint_range_error::<T>(value))
    }

    /// Decodes a variable-width unsigned integer and converts it to `T`.
    ///
    /// Uses the same width tag as [`Self::decode_varint`] (the low two bits of the
    /// next input byte). The bytes are decoded as the matching unsigned
    /// fixed-width integer, widened to `u64`, and shifted right by 2 to discard
    /// the tag.
    ///
    /// # Errors
    ///
    /// Returns any error from reading the input, or an out-of-range error when the
    /// decoded value cannot be converted into `T`.
    pub fn decode_varuint<T: TryFrom<u64>>(&mut self) -> Result<T> {
        // See `decode_varint`: the catch-all arm is exactly the 8-byte case.
        let tagged = match self.peek_byte()? & 0b11 {
            0b00 => u64::from(u8::decode_from(self)?),
            0b01 => u64::from(u16::decode_from(self)?),
            0b10 => u64::from(u32::decode_from(self)?),
            _ => u64::decode_from(self)?,
        };
        let value = tagged >> 2;
        T::try_from(value).map_err(|_| varuint_range_error::<T>(value))
    }

    /// Decodes a size or length prefix, which is an unsigned varint converted to `usize`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::decode_varuint`]. Also fails when the value does not fit in
    /// `usize`.
    pub fn decode_size(&mut self) -> Result<usize> {
        self.decode_varuint()
    }

    /// Skips over a run of tagged fields without interpreting them.
    ///
    /// Each field consists of three parts:
    /// - a signed varint tag, decoded as `i32`;
    /// - a size, read with [`Self::decode_size`];
    /// - that many payload bytes, which are read and discarded.
    ///
    /// Skipping stops once a tag equal to `TAG_END_MARKER` has been consumed.
    ///
    /// # Errors
    ///
    /// Propagates any error from decoding a tag, a size, or the payload bytes.
    pub fn skip_tagged_fields(&mut self) -> Result<()> {
        while self.decode_varint::<i32>()? != TAG_END_MARKER {
            let field_size = self.decode_size()?;
            self.read_byte_slice_exact(field_size)?;
        }
        Ok(())
    }
}
#[cfg(feature = "alloc")]
impl DecodeFrom for String {
    /// Decodes a UTF-8 string stored as an unsigned varint byte length followed by
    /// that many bytes.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the length prefix cannot be decoded or does not fit in `usize`;
    /// - the buffer cannot be allocated;
    /// - the input cannot supply `length` bytes;
    /// - the bytes are not valid UTF-8.
    fn decode_from(decoder: &mut Decoder<impl InputSource>) -> Result<Self> {
        let length = decoder.decode_varuint()?;
        let mut bytes = Vec::new();
        // Fallible reservation, so an unsatisfiable length is reported as an error.
        bytes.try_reserve_exact(length)?;
        // Zero-fill exactly `length` bytes. Capacity is already reserved, so this
        // does not reallocate. The previous approach read into
        // `spare_capacity_mut()`, which had two problems:
        // - that slice may be longer than `length`, because `try_reserve_exact` only
        //   guarantees capacity >= `length`, so extra input could be consumed and
        //   silently dropped;
        // - it exposed uninitialized memory as `&mut [u8]`, which is unsound.
        // NOTE(review): this touches all `length` bytes before reading; consider a
        // length cap if the input is untrusted.
        bytes.resize(length, 0);
        decoder.read_bytes_into_exact(&mut bytes[..])?;
        let string = String::from_utf8(bytes)?;
        Ok(string)
    }
}
#[cfg(feature = "alloc")]
impl<T: DecodeFrom> DecodeFrom for Vec<T> {
    /// Decodes a length-prefixed sequence: an unsigned varint element count
    /// followed by that many `T` values decoded one after another.
    ///
    /// Capacity for the full count is reserved with `try_reserve_exact`, so an
    /// allocation failure is reported as an error.
    fn decode_from(decoder: &mut Decoder<impl InputSource>) -> Result<Self> {
        let count = decoder.decode_varuint()?;
        let mut elements = Vec::new();
        elements.try_reserve_exact(count)?;
        for _ in 0..count {
            elements.push(decoder.decode()?);
        }
        Ok(elements)
    }
}
#[cfg(feature = "std")]
impl<K, V> DecodeFrom for HashMap<K, V>
where
    K: DecodeFrom + Debug + Eq + Hash,
    V: DecodeFrom,
{
    /// Decodes a hash map: an unsigned varint entry count followed by the entries,
    /// which `decode_dictionary_entries!` reads into the map.
    ///
    /// Room for the full count is reserved with `try_reserve` before decoding.
    /// NOTE(review): the `Debug` bound on `K` is presumably required by the macro
    /// (e.g. for key-related errors); confirm against its definition.
    fn decode_from(decoder: &mut Decoder<impl InputSource>) -> Result<Self> {
        let entry_count: usize = decoder.decode_varuint()?;
        let mut entries = HashMap::new();
        entries.try_reserve(entry_count)?;
        decode_dictionary_entries!(entries, decoder, entry_count);
        Ok(entries)
    }
}
#[cfg(feature = "alloc")]
impl<K, V> DecodeFrom for BTreeMap<K, V>
where
    K: DecodeFrom + Debug + Ord,
    V: DecodeFrom,
{
    /// Decodes an ordered map: an unsigned varint entry count followed by the
    /// entries, which `decode_dictionary_entries!` reads into the map.
    ///
    /// There is no up-front reservation because `BTreeMap` has no capacity API.
    fn decode_from(decoder: &mut Decoder<impl InputSource>) -> Result<Self> {
        let entry_count = decoder.decode_varuint()?;
        let mut entries = BTreeMap::new();
        decode_dictionary_entries!(entries, decoder, entry_count);
        Ok(entries)
    }
}