use serde::{de, forward_to_deserialize_any};
type Result<T, E = Error> = std::result::Result<T, E>;
type Error = de::value::Error;
pub(crate) fn deserialize_struct<T: de::DeserializeOwned>(input: &[u8]) -> Result<T> {
let mut de = ValueDeserializer {
rest: input,
typ: Tag::Struct,
};
let v = T::deserialize(&mut de)?;
if !de.rest.is_empty() {
return Err(de::Error::custom(format_args!(
"unexpected trailing bytes at {}",
input.len() - de.rest.len(),
)));
}
Ok(v)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub(crate) enum Tag {
BoolTrue = 1,
BoolFalse = 2,
I16 = 4,
I32 = 5,
Binary = 8,
Map = 11,
Struct = 12,
UnknownBool = 0,
Invalid = 15,
}
impl Tag {
fn without_inline_bool(self) -> Self {
if let Self::BoolTrue | Self::BoolFalse = self {
Self::UnknownBool
} else {
self
}
}
}
impl TryFrom<u8> for Tag {
type Error = Error;
fn try_from(typ: u8) -> Result<Self> {
Ok(match typ {
1 => Tag::BoolTrue,
2 => Tag::BoolFalse,
4 => Tag::I16,
5 => Tag::I32,
8 => Tag::Binary,
11 => Tag::Map,
12 => Tag::Struct,
_ => {
return Err(de::Error::custom(format_args!(
"invalid or unsupported type tag: {typ:#x}"
)));
}
})
}
}
struct ValueDeserializer<'de> {
rest: &'de [u8],
typ: Tag,
}
impl<'de> ValueDeserializer<'de> {
fn eat_byte(&mut self) -> Result<u8> {
let (&fst, rest) = self
.rest
.split_first()
.ok_or_else(|| de::Error::custom("unexpected EOF"))?;
self.rest = rest;
Ok(fst)
}
fn eat_varint(&mut self) -> Result<u32> {
let mut x = 0u32;
for i in 0..5 {
let b = self.eat_byte()?;
x += u32::from(b & 0x7F) << (i * 7);
if b & 0x80 == 0 {
return Ok(x);
}
}
Err(de::Error::custom("encoded varint is too long"))
}
fn eat_zigzag(&mut self) -> Result<i32> {
let x = self.eat_varint()?;
Ok((x >> 1) as i32 ^ -(x as i32 & 1))
}
}
impl<'de> de::Deserializer<'de> for &mut ValueDeserializer<'de> {
type Error = Error;
fn is_human_readable(&self) -> bool {
false
}
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
where
V: de::Visitor<'de>,
{
match self.typ {
Tag::UnknownBool => visitor.visit_bool(match self.eat_byte()? {
1 => true,
2 => false,
x => {
return Err(de::Error::custom(format_args!(
"invalid value for bool: {x:#x}"
)));
}
}),
Tag::BoolTrue => visitor.visit_bool(true),
Tag::BoolFalse => visitor.visit_bool(false),
Tag::I16 | Tag::I32 => visitor.visit_i32(self.eat_zigzag()?),
Tag::Binary => {
let len = self.eat_varint()?;
let len = usize::try_from(len).unwrap_or(usize::MAX);
let (data, rest) = self
.rest
.split_at_checked(len)
.ok_or_else(|| de::Error::custom("input data is too short"))?;
self.rest = rest;
visitor.visit_borrowed_bytes(data)
}
Tag::Map => {
let len = self.eat_varint()?;
let (ktype, vtype) = if len == 0 {
(Tag::Invalid, Tag::Invalid)
} else {
let typ = self.eat_byte()?;
let ktype = Tag::try_from(typ >> 4)?.without_inline_bool();
let vtype = Tag::try_from(typ & 0xF)?.without_inline_bool();
(ktype, vtype)
};
visitor.visit_map(MapDeserializer {
de: self,
len,
ktype,
vtype,
})
}
Tag::Struct => visitor.visit_map(StructDeserializer {
de: self,
field_id: 0,
value_type: Tag::Invalid,
}),
Tag::Invalid => unreachable!(),
}
}
forward_to_deserialize_any! {
bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
bytes byte_buf option unit unit_struct newtype_struct seq tuple
tuple_struct map struct enum identifier ignored_any
}
}
struct StructDeserializer<'a, 'de> {
de: &'a mut ValueDeserializer<'de>,
field_id: i16,
value_type: Tag,
}
impl<'de> de::MapAccess<'de> for StructDeserializer<'_, 'de> {
type Error = Error;
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
where
K: de::DeserializeSeed<'de>,
{
let b = self.de.eat_byte()?;
if b == 0 {
return Ok(None);
}
let id_delta = i16::from(b >> 4);
self.field_id = if id_delta != 0 {
self.field_id.checked_add(id_delta)
} else {
i16::try_from(self.de.eat_zigzag()?).ok()
}
.ok_or_else(|| de::Error::custom("field id overflow"))?;
self.value_type = Tag::try_from(b & 0xF)?;
let field_id = (self.field_id - 1) as u64;
seed.deserialize(de::value::U64Deserializer::new(field_id))
.map(Some)
}
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
where
V: de::DeserializeSeed<'de>,
{
let prev_typ = std::mem::replace(&mut self.de.typ, self.value_type);
let v = seed.deserialize(&mut *self.de);
self.de.typ = prev_typ;
v
}
}
struct MapDeserializer<'a, 'de> {
de: &'a mut ValueDeserializer<'de>,
len: u32,
ktype: Tag,
vtype: Tag,
}
impl<'de> de::MapAccess<'de> for MapDeserializer<'_, 'de> {
type Error = Error;
fn size_hint(&self) -> Option<usize> {
usize::try_from(self.len).ok()
}
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
where
K: de::DeserializeSeed<'de>,
{
if self.len == 0 {
return Ok(None);
}
self.len -= 1;
let prev_typ = std::mem::replace(&mut self.de.typ, self.ktype);
let k = seed.deserialize(&mut *self.de);
self.de.typ = prev_typ;
k.map(Some)
}
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
where
V: de::DeserializeSeed<'de>,
{
let prev_typ = std::mem::replace(&mut self.de.typ, self.vtype);
let v = seed.deserialize(&mut *self.de);
self.de.typ = prev_typ;
v
}
}