use bytes::Buf;
use derive_new::new;
use serde::de::{self, DeserializeOwned, DeserializeSeed, SeqAccess, Visitor};
use std::fmt::{self, Display};
/// Result alias used by this module: any fallible operation here fails with [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Decodes a value of type `T` from the bytes in `buf`.
///
/// A fresh [`Deserializer`] is created over `buf` and handed to
/// `T::deserialize`. Nothing here checks whether input is left over once `T`
/// has been built.
///
/// # Errors
///
/// Any error reported during deserialization is converted into an
/// `anyhow::Error` and returned.
pub fn decode<T: DeserializeOwned>(buf: &[u8]) -> anyhow::Result<T> {
    // `?` performs the `Error` -> `anyhow::Error` conversion.
    Ok(T::deserialize(&mut Deserializer::new(buf))?)
}
/// Error type of this module: a plain, human-readable message.
///
/// Equality is total (the only field is a `String`), so `Eq` is derived in
/// addition to `PartialEq`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Error(String);
impl From<std::io::Error> for Error {
fn from(e: std::io::Error) -> Self {
Self(e.to_string())
}
}
/// Lets serde build this module's [`Error`] from its own diagnostics.
impl de::Error for Error {
    /// Creates an `Error` whose message is the `Display` rendering of `msg`.
    fn custom<T: Display>(msg: T) -> Self {
        let rendered = msg.to_string();
        Self(rendered)
    }
}
impl Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
}
}
// Marks `Error` as a standard error type; every trait method keeps its default.
// `decode` relies on this when `?` converts an `Error` into `anyhow::Error`.
impl std::error::Error for Error {}
/// Deserializer that reads values out of a borrowed byte slice.
///
/// The `new` derive supplies the `Deserializer::new(buf)` constructor that
/// `decode` calls.
#[derive(new)]
pub struct Deserializer<'de> {
// Input that has not been consumed yet; reads advance this slice.
buf: &'de [u8],
}
impl<'de> Deserializer<'de> {
    /// Consumes all remaining input and returns it as a `String`.
    ///
    /// Invalid UTF-8 is not an error: it is converted lossily by
    /// `String::from_utf8_lossy`. Afterwards the buffer is empty.
    fn parse_str(&mut self) -> String {
        let bytes = self.buf;
        // Leave an empty tail behind so later reads see no input.
        self.buf = &bytes[bytes.len()..];
        String::from_utf8_lossy(bytes).into_owned()
    }

    /// Verifies that at least `len` bytes of input remain.
    ///
    /// Returns an `Error` describing the expected and the actual number of
    /// bytes when the buffer is too short.
    fn check(&mut self, len: usize) -> Result<()> {
        let available = self.buf.len();
        if available >= len {
            return Ok(());
        }
        Err(Error(format!(
            "Buffer is too short: expect {} but {}",
            len, available
        )))
    }
}
impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
type Error = Error;
fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.check(1)?;
visitor.visit_bool(self.buf.get_u8() != 0)
}
fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.check(1)?;
visitor.visit_i8(self.buf.get_i8())
}
fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.check(2)?;
visitor.visit_i16(self.buf.get_i16_le())
}
fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.check(4)?;
visitor.visit_i32(self.buf.get_i32_le())
}
fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.check(8)?;
visitor.visit_i64(self.buf.get_i64_le())
}
fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.check(1)?;
visitor.visit_u8(self.buf.get_u8())
}
fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.check(2)?;
visitor.visit_u16(self.buf.get_u16_le())
}
fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.check(4)?;
visitor.visit_u32(self.buf.get_u32_le())
}
fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.check(8)?;
visitor.visit_u64(self.buf.get_u64_le())
}
fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_string(self.parse_str())
}
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_str(visitor)
}
fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
fn deserialize_unit_struct<V>(self, _name: &'static str, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_seq(self)
}
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
_len: usize,
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
fn deserialize_struct<V>(
self,
_name: &'static str,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
_visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unreachable!()
}
}
/// Sequence access backed directly by the deserializer's remaining input.
impl<'de, 'a> SeqAccess<'de> for &'a mut Deserializer<'de> {
    type Error = Error;

    /// Yields the next element, or `None` once the input is exhausted.
    ///
    /// No element count is consulted: the sequence ends when the buffer is
    /// empty.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: DeserializeSeed<'de>,
    {
        if self.buf.is_empty() {
            return Ok(None);
        }
        // Reborrow the underlying deserializer for the element.
        let element = seed.deserialize(&mut **self)?;
        Ok(Some(element))
    }
}