use core::convert::TryInto;
use serde::de::{self, DeserializeSeed, IntoDeserializer, Visitor};
use crate::error::{Error, Result};
use crate::varint::VarintUsize;
/// A deserializer that reads values out of a borrowed byte slice.
///
/// Decoded strings and byte slices borrow from the original buffer for the
/// `'de` lifetime.
pub struct Deserializer<'de> {
    /// The bytes that have not been consumed yet.
    pub(crate) input: &'de [u8],
}

impl<'de> Deserializer<'de> {
    /// Creates a deserializer positioned at the start of `input`.
    pub fn from_bytes(input: &'de [u8]) -> Self {
        Self { input }
    }
}
impl<'de> Deserializer<'de> {
fn try_take_n(&mut self, ct: usize) -> Result<&'de [u8]> {
if self.input.len() >= ct {
let (a, b) = self.input.split_at(ct);
self.input = b;
Ok(a)
} else {
Err(Error::DeserializeUnexpectedEnd)
}
}
fn try_take_varint(&mut self) -> Result<usize> {
for i in 0..VarintUsize::varint_usize_max() {
let val = self.input.get(i).ok_or(Error::DeserializeUnexpectedEnd)?;
if (val & 0x80) == 0 {
let (a, b) = self.input.split_at(i + 1);
self.input = b;
let mut out = 0usize;
for byte in a.iter().rev() {
out <<= 7;
out |= (byte & 0x7F) as usize;
}
return Ok(out);
}
}
Err(Error::DeserializeBadVarint)
}
}
/// Hands out a fixed number of items read from the underlying deserializer.
///
/// Implements both `SeqAccess` (items are elements) and `MapAccess` (items
/// are key/value entries).
struct MultiAccess<'a, 'b: 'a> {
    /// The deserializer every item is read from.
    deserializer: &'a mut Deserializer<'b>,
    /// Number of elements (or map entries) still to be produced.
    len: usize,
}
impl<'a, 'b: 'a> serde::de::SeqAccess<'b> for MultiAccess<'a, 'b> {
    type Error = Error;

    /// Yields the next element, or `None` once the remaining count hits zero.
    ///
    /// The counter is decremented before the element is decoded, so a failed
    /// decode still uses up one slot.
    fn next_element_seed<V: DeserializeSeed<'b>>(&mut self, seed: V) -> Result<Option<V::Value>> {
        if self.len == 0 {
            return Ok(None);
        }
        self.len -= 1;
        let element = seed.deserialize(&mut *self.deserializer)?;
        Ok(Some(element))
    }

    /// Reports exactly how many elements are still pending.
    fn size_hint(&self) -> Option<usize> {
        let remaining = self.len;
        Some(remaining)
    }
}
impl<'de, 'a> serde::de::MapAccess<'de> for MultiAccess<'a, 'de> {
    type Error = Error;

    /// Yields the next key, or `None` once every entry has been produced.
    ///
    /// The remaining-entry counter is decremented here, once per key; reading
    /// the matching value does not touch it.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: DeserializeSeed<'de>,
    {
        if self.len == 0 {
            return Ok(None);
        }
        self.len -= 1;
        let key = seed.deserialize(&mut *self.deserializer)?;
        Ok(Some(key))
    }

    /// Decodes the value that follows the most recently read key.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: DeserializeSeed<'de>,
    {
        let value = seed.deserialize(&mut *self.deserializer)?;
        Ok(value)
    }
}
impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
    type Error = Error;

    /// Always fails with `Error::WontImplement`.
    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        Err(Error::WontImplement)
    }

    /// Reads one byte: `0` is `false`, `1` is `true`.
    ///
    /// Any other value yields `Error::DeserializeBadBool`.
    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let val = match self.try_take_n(1)?[0] {
            0 => false,
            1 => true,
            _ => return Err(Error::DeserializeBadBool),
        };
        visitor.visit_bool(val)
    }

    /// Reads a single byte as an `i8`.
    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let mut buf = [0u8; 1];
        buf[..].copy_from_slice(self.try_take_n(1)?);
        visitor.visit_i8(i8::from_le_bytes(buf))
    }

    /// Reads two bytes as a little-endian `i16`.
    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let mut buf = [0u8; 2];
        buf[..].copy_from_slice(self.try_take_n(2)?);
        visitor.visit_i16(i16::from_le_bytes(buf))
    }

    /// Reads four bytes as a little-endian `i32`.
    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let mut buf = [0u8; 4];
        buf[..].copy_from_slice(self.try_take_n(4)?);
        visitor.visit_i32(i32::from_le_bytes(buf))
    }

    /// Reads eight bytes as a little-endian `i64`.
    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let mut buf = [0u8; 8];
        buf[..].copy_from_slice(self.try_take_n(8)?);
        visitor.visit_i64(i64::from_le_bytes(buf))
    }

    /// Reads a single byte as a `u8`.
    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        visitor.visit_u8(self.try_take_n(1)?[0])
    }

    /// Reads two bytes as a little-endian `u16`.
    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let mut buf = [0u8; 2];
        buf[..].copy_from_slice(self.try_take_n(2)?);
        visitor.visit_u16(u16::from_le_bytes(buf))
    }

    /// Reads four bytes as a little-endian `u32`.
    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let mut buf = [0u8; 4];
        buf[..].copy_from_slice(self.try_take_n(4)?);
        visitor.visit_u32(u32::from_le_bytes(buf))
    }

    /// Reads eight bytes as a little-endian `u64`.
    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let mut buf = [0u8; 8];
        buf[..].copy_from_slice(self.try_take_n(8)?);
        visitor.visit_u64(u64::from_le_bytes(buf))
    }

    /// Reads four bytes as a little-endian `f32`.
    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let bytes = self.try_take_n(4)?;
        // Cannot fail: `try_take_n(4)` returned exactly four bytes.
        visitor.visit_f32(f32::from_le_bytes(bytes.try_into().unwrap()))
    }

    /// Reads eight bytes as a little-endian `f64`.
    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let bytes = self.try_take_n(8)?;
        // Cannot fail: `try_take_n(8)` returned exactly eight bytes.
        visitor.visit_f64(f64::from_le_bytes(bytes.try_into().unwrap()))
    }

    /// Reads four bytes as a little-endian `u32` and converts it to a `char`.
    ///
    /// Yields `Error::DeserializeBadChar` when the integer is not a valid
    /// `char`.
    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let mut buf = [0u8; 4];
        let bytes = self.try_take_n(4)?;
        buf.copy_from_slice(bytes);
        let integer = u32::from_le_bytes(buf);
        visitor.visit_char(core::char::from_u32(integer).ok_or(Error::DeserializeBadChar)?)
    }

    /// Reads a varint byte length followed by that many bytes of UTF-8.
    ///
    /// The string is handed to the visitor borrowed from the input buffer.
    /// Invalid UTF-8 yields `Error::DeserializeBadUtf8`.
    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let sz = self.try_take_varint()?;
        let bytes: &'de [u8] = self.try_take_n(sz)?;
        let str_sl = core::str::from_utf8(bytes).map_err(|_| Error::DeserializeBadUtf8)?;
        visitor.visit_borrowed_str(str_sl)
    }

    /// Same encoding as [`deserialize_str`](Self::deserialize_str).
    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_str(visitor)
    }

    /// Reads a varint byte length followed by that many raw bytes.
    ///
    /// The bytes are handed to the visitor borrowed from the input buffer.
    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let sz = self.try_take_varint()?;
        let bytes: &'de [u8] = self.try_take_n(sz)?;
        visitor.visit_borrowed_bytes(bytes)
    }

    /// Same encoding as [`deserialize_bytes`](Self::deserialize_bytes).
    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_bytes(visitor)
    }

    /// Reads a one-byte tag: `0` is `None`, `1` is `Some` with the payload
    /// following immediately.
    ///
    /// Any other tag yields `Error::DeserializeBadOption`.
    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        match self.try_take_n(1)?[0] {
            0 => visitor.visit_none(),
            1 => visitor.visit_some(self),
            _ => Err(Error::DeserializeBadOption),
        }
    }

    /// Produces a unit value without consuming any input.
    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        visitor.visit_unit()
    }

    /// Handled exactly like a plain unit; the struct name is not read.
    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_unit(visitor)
    }

    /// Decodes the wrapped value directly; the struct name is not read.
    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        visitor.visit_newtype_struct(self)
    }

    /// Reads a varint element count, then exposes that many elements.
    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let len = self.try_take_varint()?;
        visitor.visit_seq(MultiAccess {
            deserializer: self,
            len,
        })
    }

    /// Exposes exactly `len` elements; no length prefix is read.
    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        visitor.visit_seq(MultiAccess {
            deserializer: self,
            len,
        })
    }

    /// Exposes exactly `len` elements; no length prefix is read.
    ///
    /// A tuple struct has a statically known arity, so it is decoded the same
    /// way as tuples, structs and tuple/struct enum variants in this file.
    /// Routing it through `deserialize_seq` would consume the first payload
    /// bytes as a bogus element count.
    fn deserialize_tuple_struct<V>(
        self,
        _name: &'static str,
        len: usize,
        visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_tuple(len, visitor)
    }

    /// Reads a varint entry count, then exposes that many key/value pairs.
    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let len = self.try_take_varint()?;
        visitor.visit_map(MultiAccess {
            deserializer: self,
            len,
        })
    }

    /// Decodes a struct as a tuple of `fields.len()` values.
    ///
    /// Field names are not read from the input.
    fn deserialize_struct<V>(
        self,
        _name: &'static str,
        fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_tuple(fields.len(), visitor)
    }

    /// Hands this deserializer to the visitor as the enum accessor.
    ///
    /// The enum and variant names are not read here.
    fn deserialize_enum<V>(
        self,
        _name: &'static str,
        _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        visitor.visit_enum(self)
    }

    /// Always fails with `Error::WontImplement`.
    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        Err(Error::WontImplement)
    }

    /// Always fails with `Error::WontImplement`.
    fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        Err(Error::WontImplement)
    }
}
impl<'de, 'a> serde::de::VariantAccess<'de> for &'a mut Deserializer<'de> {
type Error = Error;
fn unit_variant(self) -> Result<()> {
Ok(())
}
fn newtype_variant_seed<V: DeserializeSeed<'de>>(self, seed: V) -> Result<V::Value> {
DeserializeSeed::deserialize(seed, self)
}
fn tuple_variant<V: Visitor<'de>>(self, len: usize, visitor: V) -> Result<V::Value> {
serde::de::Deserializer::deserialize_tuple(self, len, visitor)
}
fn struct_variant<V: Visitor<'de>>(
self,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value> {
serde::de::Deserializer::deserialize_tuple(self, fields.len(), visitor)
}
}
impl<'de, 'a> serde::de::EnumAccess<'de> for &'a mut Deserializer<'de> {
type Error = Error;
type Variant = Self;
fn variant_seed<V: DeserializeSeed<'de>>(self, seed: V) -> Result<(V::Value, Self)> {
let varint = self.try_take_varint()?;
if varint > 0xFFFF_FFFF {
return Err(Error::DeserializeBadEnum);
}
let v = DeserializeSeed::deserialize(seed, (varint as u32).into_deserializer())?;
Ok((v, self))
}
}