use crate::consts::{U16_TAG, U32_TAG, U64_TAG, U8_TAG};
use crate::error::SerializerError;
use serde::de;
use serde::de::{DeserializeSeed, EnumAccess, MapAccess, SeqAccess, VariantAccess, Visitor};
/// Deserializer for this crate's binary format, reading from a borrowed byte slice.
///
/// Values are decoded front-to-back; borrowed `&str` / `&[u8]` results point
/// directly into the input (lifetime `'de`) rather than being copied.
pub struct BinDeserializer<'de> {
/// Input that has not been consumed yet; each read removes bytes from the front.
pub(crate) bts: &'de [u8],
}
/// Sequence / map accessor that yields a fixed number of items from the
/// wrapped deserializer.
pub struct SubDeserializerSeq<'a, 'de> {
/// Underlying deserializer that every element (or map key/value) is read from.
deserializer: &'a mut BinDeserializer<'de>,
/// Number of elements (for maps: entries, counted per key) still to be yielded;
/// once it reaches zero the accessor reports the end of the sequence.
rest_len: usize,
}
/// Enum accessor: reads the variant selector and then the variant's payload
/// from the wrapped deserializer.
pub struct SubDeserializerEnum<'a, 'de> {
/// Underlying deserializer both the variant selector and its payload are read from.
deserializer: &'a mut BinDeserializer<'de>,
}
impl<'de> BinDeserializer<'de> {
    /// Reads a length-prefixed byte string and validates it as UTF-8.
    ///
    /// # Errors
    /// Anything `get_bytes` can return, plus `InvalidData` when the bytes are
    /// not valid UTF-8.
    fn get_str(&mut self) -> Result<&'de str, SerializerError> {
        let raw = self.get_bytes()?;
        std::str::from_utf8(raw).map_err(|_| SerializerError::InvalidData)
    }

    /// Reads a length prefix (see `get_size`) followed by that many raw bytes.
    ///
    /// The returned slice borrows from the input, so no copy is made.
    ///
    /// # Errors
    /// Anything `get_size` can return, plus `UnexpectedEOF` when fewer bytes
    /// remain than the prefix announces.
    fn get_bytes(&mut self) -> Result<&'de [u8], SerializerError> {
        let len = self.get_size()?;
        self.take_slice(len)
    }

    /// Reads a variable-width size: one tag byte selecting the width
    /// (`U8_TAG`, `U16_TAG`, `U32_TAG` or `U64_TAG`) followed by a
    /// little-endian integer of that width.
    ///
    /// # Errors
    /// * `UnexpectedEOF` if the input ends before the tag or the integer.
    /// * `InvalidData` for an unknown tag, or when the value does not fit in
    ///   `usize` on the current target (previously it was silently truncated).
    fn get_size(&mut self) -> Result<usize, SerializerError> {
        let tag = self.get_u8()?;
        // The tags are compared with `==` rather than used as match patterns so
        // this works regardless of how the constants are declared.
        if tag == U8_TAG {
            self.get_u8().map(usize::from)
        } else if tag == U16_TAG {
            self.get_u16().map(usize::from)
        } else if tag == U32_TAG {
            self.get_u32()
                .and_then(|num| Self::usize_from_u64(u64::from(num)))
        } else if tag == U64_TAG {
            self.get_u64().and_then(Self::usize_from_u64)
        } else {
            Err(SerializerError::InvalidData)
        }
    }

    /// Reads a single byte.
    ///
    /// # Errors
    /// `UnexpectedEOF` if the input is empty.
    fn get_u8(&mut self) -> Result<u8, SerializerError> {
        Ok(self.take_slice(1)?[0])
    }

    /// Reads a little-endian `u16`.
    ///
    /// # Errors
    /// `UnexpectedEOF` if fewer than 2 bytes remain (input is left untouched).
    fn get_u16(&mut self) -> Result<u16, SerializerError> {
        let mut buf = [0u8; 2];
        buf.copy_from_slice(self.take_slice(buf.len())?);
        Ok(u16::from_le_bytes(buf))
    }

    /// Reads a little-endian `u32`.
    ///
    /// # Errors
    /// `UnexpectedEOF` if fewer than 4 bytes remain (input is left untouched).
    fn get_u32(&mut self) -> Result<u32, SerializerError> {
        let mut buf = [0u8; 4];
        buf.copy_from_slice(self.take_slice(buf.len())?);
        Ok(u32::from_le_bytes(buf))
    }

    /// Reads a little-endian `u64`.
    ///
    /// # Errors
    /// `UnexpectedEOF` if fewer than 8 bytes remain (input is left untouched).
    fn get_u64(&mut self) -> Result<u64, SerializerError> {
        let mut buf = [0u8; 8];
        buf.copy_from_slice(self.take_slice(buf.len())?);
        Ok(u64::from_le_bytes(buf))
    }

    /// Splits off the next `len` bytes of input and advances past them.
    ///
    /// # Errors
    /// `UnexpectedEOF` when fewer than `len` bytes remain; the input is not
    /// consumed in that case.
    fn take_slice(&mut self, len: usize) -> Result<&'de [u8], SerializerError> {
        if self.bts.len() < len {
            return Err(SerializerError::UnexpectedEOF);
        }
        let input: &'de [u8] = self.bts;
        let (head, tail) = input.split_at(len);
        self.bts = tail;
        Ok(head)
    }

    /// Converts a decoded length to `usize`, rejecting values that would be
    /// truncated on targets where `usize` is narrower than 64 bits.
    fn usize_from_u64(num: u64) -> Result<usize, SerializerError> {
        use std::convert::TryFrom;
        usize::try_from(num).map_err(|_| SerializerError::InvalidData)
    }
}
impl<'de, 'a> de::Deserializer<'de> for &'a mut BinDeserializer<'de> {
type Error = SerializerError;
fn deserialize_any<V>(self, _: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
panic!("any is not implemented")
}
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
if self.bts.is_empty() {
return Err(SerializerError::UnexpectedEOF);
}
self.bts = &self.bts[1..];
visitor.visit_bool(self.bts[0] != 0)
}
fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
if self.bts.is_empty() {
return Err(SerializerError::UnexpectedEOF);
}
let val = i8::from_le_bytes([self.bts[0]]);
self.bts = &self.bts[1..];
visitor.visit_i8(val)
}
fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
const BUF_LEN: usize = 2;
if self.bts.len() < BUF_LEN {
return Err(SerializerError::UnexpectedEOF);
}
let val = i16::from_le_bytes([self.bts[0], self.bts[1]]);
self.bts = &self.bts[BUF_LEN..];
visitor.visit_i16(val)
}
fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
const BUF_LEN: usize = 4;
if self.bts.len() < BUF_LEN {
return Err(SerializerError::UnexpectedEOF);
}
let val = i32::from_le_bytes([self.bts[0], self.bts[1], self.bts[2], self.bts[3]]);
self.bts = &self.bts[BUF_LEN..];
visitor.visit_i32(val)
}
fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
const BUF_LEN: usize = 8;
if self.bts.len() < BUF_LEN {
return Err(SerializerError::UnexpectedEOF);
}
let val = i64::from_le_bytes([
self.bts[0],
self.bts[1],
self.bts[2],
self.bts[3],
self.bts[4],
self.bts[5],
self.bts[6],
self.bts[7],
]);
self.bts = &self.bts[BUF_LEN..];
visitor.visit_i64(val)
}
fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_u8(self.get_u8()?)
}
fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_u16(self.get_u16()?)
}
fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_u32(self.get_u32()?)
}
fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_u64(self.get_u64()?)
}
fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
const BUF_LEN: usize = 4;
if self.bts.len() < BUF_LEN {
return Err(SerializerError::UnexpectedEOF);
}
let val = f32::from_le_bytes([self.bts[0], self.bts[1], self.bts[2], self.bts[3]]);
self.bts = &self.bts[BUF_LEN..];
visitor.visit_f32(val)
}
fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
const BUF_LEN: usize = 8;
if self.bts.len() < BUF_LEN {
return Err(SerializerError::UnexpectedEOF);
}
let val = f64::from_le_bytes([
self.bts[0],
self.bts[1],
self.bts[2],
self.bts[3],
self.bts[4],
self.bts[5],
self.bts[6],
self.bts[7],
]);
self.bts = &self.bts[BUF_LEN..];
visitor.visit_f64(val)
}
fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
match self.get_str()?.chars().next() {
Some(sym) => visitor.visit_char(sym),
None => Err(SerializerError::InvalidData),
}
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_str(self.get_str()?)
}
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let s = self.get_str()?.to_string();
visitor.visit_string(s)
}
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_bytes(self.get_bytes()?)
}
fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let buf = self.get_bytes()?.to_vec();
visitor.visit_byte_buf(buf)
}
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
if self.get_u8()? == 0 {
visitor.visit_none()
} else {
visitor.visit_some(self)
}
}
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_unit_struct<V>(
self,
_: &'static str,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_newtype_struct<V>(
self,
_: &'static str,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let rest_len = self.get_size()?;
let sub = SubDeserializerSeq {
deserializer: self,
rest_len,
};
visitor.visit_seq(sub)
}
fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let sub = SubDeserializerSeq {
deserializer: self,
rest_len: len,
};
visitor.visit_seq(sub)
}
fn deserialize_tuple_struct<V>(
self,
_: &'static str,
len: usize,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.deserialize_tuple(len, visitor)
}
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let rest_len = self.get_size()?;
let sub = SubDeserializerSeq {
deserializer: self,
rest_len,
};
visitor.visit_map(sub)
}
fn deserialize_struct<V>(
self,
_: &'static str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.deserialize_tuple(fields.len(), visitor)
}
fn deserialize_enum<V>(
self,
_: &'static str,
_: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let sub = SubDeserializerEnum { deserializer: self };
visitor.visit_enum(sub)
}
fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let index = self.get_size()?;
visitor.visit_u32(index as u32)
}
fn deserialize_ignored_any<V>(self, _: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
panic!("ignored any is not implemented")
}
}
impl<'a, 'de> SeqAccess<'de> for SubDeserializerSeq<'a, 'de> {
    type Error = SerializerError;

    /// Deserializes the next element with `seed`.
    ///
    /// Returns `Ok(None)` once `rest_len` elements have been produced.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        if self.rest_len == 0 {
            return Ok(None);
        }
        self.rest_len -= 1;
        seed.deserialize(&mut *self.deserializer).map(Some)
    }

    /// The remaining element count is known exactly, so report it to let
    /// collection visitors preallocate. serde's built-in visitors clamp this
    /// hint themselves, so an untrusted length does not force an oversized
    /// allocation.
    fn size_hint(&self) -> Option<usize> {
        Some(self.rest_len)
    }
}
impl<'a, 'de> MapAccess<'de> for SubDeserializerSeq<'a, 'de> {
    type Error = SerializerError;

    /// Reads the next key.
    ///
    /// The entry counter (`rest_len`) is decremented once per key via the
    /// sequence accessor; `Ok(None)` marks the end of the map.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
    where
        K: DeserializeSeed<'de>,
    {
        self.next_element_seed(seed)
    }

    /// Reads the value belonging to the key just returned. This does not touch
    /// the entry counter.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
    where
        V: DeserializeSeed<'de>,
    {
        seed.deserialize(&mut *self.deserializer)
    }

    /// Exact number of entries still to be read. `rest_len` is only decremented
    /// by `next_key_seed`, so this reports whole entries. serde's built-in map
    /// visitors clamp the hint before preallocating.
    fn size_hint(&self) -> Option<usize> {
        Some(self.rest_len)
    }
}
impl<'a, 'de> EnumAccess<'de> for SubDeserializerEnum<'a, 'de> {
type Error = SerializerError;
type Variant = Self;
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
where
V: DeserializeSeed<'de>,
{
let val = seed.deserialize(&mut *self.deserializer)?;
Ok((val, self))
}
}
impl<'de, 'a> VariantAccess<'de> for SubDeserializerEnum<'a, 'de> {
type Error = SerializerError;
fn unit_variant(self) -> Result<(), Self::Error> {
Ok(())
}
fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
where
T: DeserializeSeed<'de>,
{
seed.deserialize(&mut *self.deserializer)
}
fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let sub = SubDeserializerSeq {
rest_len: len,
deserializer: self.deserializer,
};
visitor.visit_seq(sub)
}
fn struct_variant<V>(
self,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let sub = SubDeserializerSeq {
rest_len: fields.len(),
deserializer: self.deserializer,
};
visitor.visit_seq(sub)
}
}