use std::mem;
use serde::{de::{
self, DeserializeSeed, SeqAccess, Visitor,
}, Deserialize};
use byteorder::{BigEndian, ReadBytesExt};
use crate::{error::{Error, Result}, rlp::RlpTree};
use paste::paste;
pub struct Deserializer<'de> {
input: &'de [u8]
}
impl<'de> Deserializer<'de> {
pub fn new(input: &'de [u8]) -> Self {
Self {
input
}
}
pub fn is_empty(&self) -> bool {
self.input.is_empty()
}
pub fn next_is_bytes(&self) -> bool {
self.input[0] < 192
}
pub fn next_bytes(&self) -> Result<(&'de [u8], &'de [u8], Self)> {
let buf = self.input;
let (start, end) = match buf[0] {
0..=127 => (0, 1),
len @ 128..=183 => (1, 1 + (len as usize - 128)),
be_len @ 184..=191 => {
let be_len = be_len as usize - 183;
let len = (&buf[1..]).read_uint::<BigEndian>(be_len)
.or(Err(Error::MalformedData))? as usize;
(1 + be_len, 1 + be_len + len)
},
_ => Err(Error::MalformedData)?
};
Ok((&buf[..end], &buf[start..end], Self::new(&buf[end..])))
}
pub fn next_seq(&self) -> Result<(&'de [u8], Self, Self)> {
let buf = self.input;
let (start, end) = match buf[0] {
len @ 192..=247 => (1, 1 + len as usize - 192),
be_len @ 248..=255 => {
let be_len = be_len as usize - 247;
let len = (&buf[1..]).read_uint::<BigEndian>(be_len)
.or(Err(Error::MalformedData))? as usize;
(1 + be_len, 1 + be_len + len)
},
_ => Err(Error::MalformedData)?
};
Ok((&buf[..end], Self::new(&buf[start..end]), Self::new(&buf[end..])))
}
}
/// Left-pads a big-endian byte string with zeros to exactly `N` bytes.
///
/// The bytes of `src` end up in the low-order (trailing) positions of the result, so the
/// output can be passed to an integer's `from_be_bytes`.
///
/// # Panics
///
/// Panics if `src` is longer than `N` bytes.
fn be_bytes_expand<const N: usize>(src: &[u8]) -> [u8; N] {
    let mut padded = [0_u8; N];
    // Number of leading zero bytes needed in front of `src`.
    let lead = N - src.len();
    let (_, tail) = padded.split_at_mut(lead);
    tail.copy_from_slice(src);
    padded
}
/// Generates `deserialize_<ty>` methods for types this deserializer does not support.
///
/// Each generated method ignores the visitor and returns an error built with
/// `serde::de::Error::custom`, so an unsupported type is reported to the caller instead
/// of aborting with a panic. Must be invoked inside an `impl de::Deserializer<'de>` block.
macro_rules! impl_deseralize_not_supported {
    ($($ity:ident),+) => {
        paste! {$(
            fn [<deserialize_ $ity>]<V>(self, _visitor: V) -> Result<V::Value>
            where
                V: Visitor<'de>,
            {
                Err(de::Error::custom(concat!(
                    "RLP deserialization of `",
                    stringify!($ity),
                    "` is not supported"
                )))
            }
        )+}
    }
}
/// Generates `deserialize_<ty>` methods for fixed-width integers.
///
/// Each generated method reads the next byte-string item, left-pads it with zeros to the
/// width of the integer, interprets it as big-endian and passes the value to the matching
/// `visit_<ty>` method. A byte string longer than the integer yields
/// `Error::MalformedData`. Must be invoked inside an `impl de::Deserializer<'de>` block.
macro_rules! impl_deseralize_integer {
    ($($ity:ident),+) => {
        paste! {$(
            fn [<deserialize_ $ity>]<V>(self, visitor: V) -> Result<V::Value>
            where
                V: Visitor<'de>,
            {
                let (_, bytes, new) = self.next_bytes()?;
                // `be_bytes_expand` panics when given more bytes than the target width,
                // and the length here is controlled by the encoded input.
                if bytes.len() > mem::size_of::<$ity>() {
                    return Err(Error::MalformedData);
                }
                let expanded = be_bytes_expand::<{ mem::size_of::<$ity>() }>(bytes);
                *self = new;
                visitor.[<visit_ $ity>]($ity::from_be_bytes(expanded))
            }
        )+}
    }
}
/// Owned byte buffer produced by `RlpProxyVisitor`, which copies the bytes it is handed
/// into a `Vec<u8>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RlpProxy(Vec<u8>);
impl RlpProxy {
/// Returns the stored bytes as a slice.
pub fn raw(&self) -> &[u8] {
&self.0
}
/// Builds an [`RlpTree`] from the stored bytes.
///
/// # Panics
///
/// Panics if `RlpTree::new` does not produce a value, because its result is unwrapped.
pub fn rlp_tree(&self) -> RlpTree {
RlpTree::new(&self.0).unwrap()
}
}
impl<'de> Deserialize<'de> for RlpProxy {
    /// Deserializes an [`RlpProxy`] by handing an `RlpProxyVisitor` to the deserializer's
    /// `deserialize_any` entry point.
    fn deserialize<D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> std::result::Result<Self, D::Error> {
        let visitor = RlpProxyVisitor;
        deserializer.deserialize_any(visitor)
    }
}
struct RlpProxyVisitor;
impl<'de> Visitor<'de> for RlpProxyVisitor {
type Value = RlpProxy;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("AggregateVisitor Error.")
}
fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> std::result::Result<Self::Value, E>
where
E: de::Error
{
Ok(RlpProxy(v.to_vec()))
}
}
impl<'de: 'a, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
type Error = Error;
impl_deseralize_not_supported! {bool, f32, f64, identifier, ignored_any, map, i16, i32, i64, i8}
impl_deseralize_integer! {u8, u16, u32, u64}
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>
{
let (bytes, new) = if self.next_is_bytes() {
let (bytes, _, new) = self.next_bytes()?;
(bytes, new)
} else {
let (bytes, _, new) = self.next_seq()?;
(bytes, new)
};
*self = new;
visitor.visit_borrowed_bytes(bytes)
}
fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let (_, bytes, new) = self.next_bytes()?;
*self = new;
let string = String::from_utf8(bytes.to_vec())
.or(Err(Error::MalformedData))?;
visitor.visit_char(
string
.as_str()
.chars()
.next()
.ok_or(Error::MalformedData)?
)
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let (_, bytes, new) = self.next_bytes()?;
*self = new;
let string = std::str::from_utf8(bytes)
.or(Err(Error::MalformedData))?;
visitor.visit_borrowed_str(string)
}
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_str(visitor)
}
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let (_, bytes, new) = self.next_bytes()?;
*self = new;
visitor.visit_borrowed_bytes(bytes)
}
fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_bytes(visitor)
}
fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
unimplemented!()
}
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let (_, seq, new) = self.next_seq()?;
*self = new;
if seq.input.is_empty() {
visitor.visit_unit()
} else {
Err(Error::MalformedData)
}
}
fn deserialize_unit_struct<V>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
let (_, bytes, new) = self.next_bytes()?;
*self = new;
if bytes.is_empty() {
visitor.visit_unit()
} else {
Err(Error::MalformedData)
}
}
fn deserialize_newtype_struct<V>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
let (_, seq, new) = self.next_seq()?;
*self = new;
visitor.visit_seq(seq)
}
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
_len: usize,
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_struct<V>(
self,
_name: &'static str,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
_visitor: V,
) -> Result<V::Value>
where
V: Visitor<'de>,
{
unimplemented!()
}
}
/// Lets a cursor over a list payload act as serde's sequence accessor.
impl<'de> SeqAccess<'de> for Deserializer<'de> {
    type Error = Error;

    /// Decodes the next element of the list.
    ///
    /// Returns `Ok(None)` once the payload is exhausted, which is how serde visitors
    /// detect the end of a sequence; otherwise deserializes one element from the front
    /// of the remaining bytes.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: DeserializeSeed<'de>,
    {
        // Without this check, visitors that read until the end (e.g. `Vec<T>`) would
        // try to decode an element from empty input.
        if self.is_empty() {
            return Ok(None);
        }
        seed.deserialize(&mut *self).map(Some)
    }
}