use crate::{constants::*, Error, Result};
use serde::{
de::{
self, DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess,
Visitor,
},
Deserialize,
};
use std::{convert::TryFrom, io, io::Read, marker::PhantomData, slice};
/// Deserializes a single value of type `T` from an in-memory byte slice.
///
/// The slice is used as the [`Read`] source of [`from_reader`], so the same rules
/// apply: the slice must contain exactly one encoded value and nothing after it.
///
/// # Errors
///
/// Returns whatever error [`from_reader`] produces for the given bytes.
pub fn from_slice<'de, T>(slice: &'de [u8]) -> Result<T>
where
    T: Deserialize<'de>,
{
    let bytes: &[u8] = slice;
    from_reader(bytes)
}
/// Deserializes a single value of type `T` from `reader`.
///
/// Once the value has been decoded, [`Deserializer::end`] is called to make sure
/// the reader holds no further data.
///
/// # Errors
///
/// Returns the error raised while decoding the value, or the error reported by
/// [`Deserializer::end`] when the reader is not exhausted afterwards.
pub fn from_reader<'de, R, T>(reader: R) -> Result<T>
where
    T: Deserialize<'de>,
    R: Read,
{
    let mut deserializer = Deserializer::new(reader);
    let decoded = T::deserialize(&mut deserializer)?;
    deserializer.end().map(|()| decoded)
}
/// Decoder that reads encoded values from any [`Read`] source.
pub struct Deserializer<R: Read> {
    /// Source of the encoded bytes.
    reader: R,
    /// Discriminator that has been read from `reader` but not consumed yet,
    /// stored as `(type, remaining bits)`.
    last_discriminator: Option<(u8, u8)>,
}
impl<R> Deserializer<R>
where
R: Read,
{
pub fn new(reader: R) -> Self {
Self {
reader,
last_discriminator: None,
}
}
pub fn end(&mut self) -> Result<()> {
match self.read_discriminator() {
Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(()),
_ => Err(Error::TrailingBytes),
}
}
#[allow(clippy::should_implement_trait)]
pub fn into_iter<'de, T>(self) -> StreamDeserializer<'de, R, T>
where
T: Deserialize<'de>,
{
StreamDeserializer {
de: self,
failed: false,
output: PhantomData,
lifetime: PhantomData,
}
}
fn read_discriminator(&mut self) -> Result<(u8, u8)> {
let mut d = 0;
self.reader.read_exact(slice::from_mut(&mut d))?;
Ok((d & TYPE_MASK, d & !TYPE_MASK))
}
fn peek_discriminator(&mut self) -> Result<(u8, u8)> {
if self.last_discriminator.is_none() {
self.last_discriminator = Some(self.read_discriminator()?);
}
Ok(self.last_discriminator.unwrap())
}
fn consume_discriminator(&mut self) -> Result<(u8, u8)> {
self.last_discriminator
.take()
.map(Result::Ok)
.unwrap_or_else(|| self.read_discriminator())
}
fn read_i64(&mut self, len: usize) -> Result<i64> {
let mut buf = [0u8; 8];
let start = 8 - len;
self.reader.read_exact(&mut buf[start..])?;
if buf[start] & 0x80 != 0 {
buf[0..start].fill(0xFF);
}
Ok(i64::from_be_bytes(buf))
}
fn read_null(&mut self) -> Result<()> {
let (typ, bits) = self.consume_discriminator()?;
if typ != TYPE_NULL {
return Err(Error::WrongType);
}
if bits != 0 {
return Err(Error::InvalidValue);
}
Ok(())
}
fn read_boolean(&mut self) -> Result<bool> {
let (typ, bits) = self.consume_discriminator()?;
if typ != TYPE_BOOLEAN {
return Err(Error::WrongType);
}
if bits > 1 {
return Err(Error::InvalidValue);
}
Ok(bits == 1)
}
fn read_integer(&mut self) -> Result<i64> {
let (typ, len) = self.consume_discriminator()?;
if typ != TYPE_INTEGER {
return Err(Error::WrongType);
}
if !len.is_power_of_two() {
return Err(Error::InvalidLength);
}
self.read_i64(len as usize)
}
fn read_float(&mut self) -> Result<f64> {
let (typ, len) = self.consume_discriminator()?;
if typ != TYPE_FLOAT {
return Err(Error::WrongType);
}
if len != 8 {
return Err(Error::InvalidLength);
}
let mut buf = [0u8; 8];
self.reader.read_exact(&mut buf)?;
Ok(f64::from_be_bytes(buf))
}
fn read_string(&mut self) -> Result<String> {
let (typ, llen) = self.consume_discriminator()?;
if typ != TYPE_STRING {
return Err(Error::WrongType);
}
if !llen.is_power_of_two() {
return Err(Error::InvalidLengthOfLength);
}
let len = self.read_i64(llen as usize)?;
if len < 0 {
return Err(Error::InvalidLength);
}
let mut s = String::with_capacity(len as usize);
let read = (&mut self.reader).take(len as u64).read_to_string(&mut s)?;
if read != len as usize {
return Err(Error::eof());
}
Ok(s)
}
fn read_raw(&mut self) -> Result<Vec<u8>> {
let (typ, llen) = self.consume_discriminator()?;
if typ != TYPE_RAW {
return Err(Error::WrongType);
}
if !llen.is_power_of_two() {
return Err(Error::InvalidLengthOfLength);
}
let len = self.read_i64(llen as usize)?;
if len < 0 {
return Err(Error::InvalidLength);
}
let mut v = Vec::with_capacity(len as usize);
let read = (&mut self.reader).take(len as u64).read_to_end(&mut v)?;
if read != len as usize {
return Err(Error::eof());
}
Ok(v)
}
fn read_list_start(&mut self) -> Result<()> {
let (typ, bits) = self.consume_discriminator()?;
if typ != TYPE_LIST {
return Err(Error::WrongType);
}
if bits != 0 {
return Err(Error::InvalidValue);
}
Ok(())
}
fn read_dictionary_start(&mut self) -> Result<()> {
let (typ, bits) = self.consume_discriminator()?;
if typ != TYPE_DICTIONARY {
return Err(Error::WrongType);
}
if bits != 0 {
return Err(Error::InvalidValue);
}
Ok(())
}
fn peek_end(&mut self) -> Result<bool> {
let (typ, bits) = self.peek_discriminator()?;
if typ != TYPE_END {
return Ok(false);
}
if bits != 0 {
return Err(Error::InvalidValue);
}
Ok(true)
}
fn read_end(&mut self) -> Result<()> {
let (typ, bits) = self.consume_discriminator()?;
if typ != TYPE_END {
return Err(Error::WrongType);
}
if bits != 0 {
return Err(Error::InvalidValue);
}
Ok(())
}
}
/// serde front end: maps each serde data-model request onto the matching
/// low-level reader of [`Deserializer`].
impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer<R>
where
    R: Read,
{
    type Error = Error;

    /// Dispatches on the type of the next discriminator.
    ///
    /// Returns [`Error::InvalidType`] for a type that does not start a value.
    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        let (typ, _) = self.peek_discriminator()?;
        match typ {
            TYPE_NULL => self.deserialize_unit(visitor),
            TYPE_BOOLEAN => self.deserialize_bool(visitor),
            TYPE_INTEGER => self.deserialize_i64(visitor),
            TYPE_FLOAT => self.deserialize_f64(visitor),
            TYPE_STRING => self.deserialize_str(visitor),
            TYPE_RAW => self.deserialize_bytes(visitor),
            TYPE_LIST => self.deserialize_seq(visitor),
            TYPE_DICTIONARY => self.deserialize_map(visitor),
            _ => Err(Error::InvalidType),
        }
    }

    /// Reads a boolean.
    fn deserialize_bool<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        visitor.visit_bool(self.read_boolean()?)
    }

    /// Reads an integer and narrows it to `i8`, failing if it is out of range.
    fn deserialize_i8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        let value = self.read_integer()?;
        visitor.visit_i8(i8::try_from(value)?)
    }

    /// Reads an integer and narrows it to `i16`, failing if it is out of range.
    fn deserialize_i16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        let value = self.read_integer()?;
        visitor.visit_i16(i16::try_from(value)?)
    }

    /// Reads an integer and narrows it to `i32`, failing if it is out of range.
    fn deserialize_i32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        let value = self.read_integer()?;
        visitor.visit_i32(i32::try_from(value)?)
    }

    /// Reads an integer as `i64`.
    fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        visitor.visit_i64(self.read_integer()?)
    }

    /// Reads an integer and converts it to `u8`, failing if it is out of range.
    fn deserialize_u8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        let value = self.read_integer()?;
        visitor.visit_u8(u8::try_from(value)?)
    }

    /// Reads an integer and converts it to `u16`, failing if it is out of range.
    fn deserialize_u16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        let value = self.read_integer()?;
        visitor.visit_u16(u16::try_from(value)?)
    }

    /// Reads an integer and converts it to `u32`, failing if it is out of range.
    fn deserialize_u32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        let value = self.read_integer()?;
        visitor.visit_u32(u32::try_from(value)?)
    }

    /// Reads an integer and converts it to `u64`, failing if it is negative.
    fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        let value = self.read_integer()?;
        visitor.visit_u64(u64::try_from(value)?)
    }

    /// Reads a float and narrows it to `f32`.
    fn deserialize_f32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        let value = self.read_float()?;
        visitor.visit_f32(value as f32)
    }

    /// Reads a float as `f64`.
    fn deserialize_f64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        visitor.visit_f64(self.read_float()?)
    }

    /// Reads an integer and interprets it as a Unicode scalar value.
    fn deserialize_char<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        let value = self.read_integer()?;
        let code_point = u32::try_from(value)?;
        visitor.visit_char(char::try_from(code_point)?)
    }

    /// Reads a string; the data is always handed over as an owned `String`.
    fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.deserialize_string(visitor)
    }

    /// Reads a string into an owned `String`.
    fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        visitor.visit_string(self.read_string()?)
    }

    /// Reads raw bytes; the data is always handed over as an owned `Vec<u8>`.
    fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.deserialize_byte_buf(visitor)
    }

    /// Reads raw bytes into an owned `Vec<u8>`.
    fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        visitor.visit_byte_buf(self.read_raw()?)
    }

    /// Reads an optional value: a null is `None`, any other value is `Some`.
    ///
    /// Returns [`Error::WrongType`] if the next discriminator does not start a value.
    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        let (typ, _) = self.peek_discriminator()?;
        match typ {
            TYPE_NULL => {
                // The null has only been peeked so far. It must be consumed, or the
                // next read would see the same null again.
                self.read_null()?;
                visitor.visit_none()
            }
            TYPE_BOOLEAN | TYPE_INTEGER | TYPE_FLOAT | TYPE_STRING | TYPE_RAW | TYPE_LIST
            | TYPE_DICTIONARY => visitor.visit_some(self),
            _ => Err(Error::WrongType),
        }
    }

    /// Reads a null as the unit value.
    fn deserialize_unit<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.read_null()?;
        visitor.visit_unit()
    }

    /// Unit structs are read like the unit value.
    fn deserialize_unit_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> Result<V::Value> {
        self.deserialize_unit(visitor)
    }

    /// Newtype structs are transparent: the wrapped value is read directly.
    fn deserialize_newtype_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> Result<V::Value> {
        visitor.visit_newtype_struct(self)
    }

    /// Reads a list: start marker, elements, end marker.
    fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.read_list_start()?;
        let value = visitor.visit_seq(&mut *self)?;
        self.read_end()?;
        Ok(value)
    }

    /// Tuples are read as lists; the expected length is not checked here.
    fn deserialize_tuple<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value> {
        self.deserialize_seq(visitor)
    }

    /// Tuple structs are read as lists.
    fn deserialize_tuple_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        _len: usize,
        visitor: V,
    ) -> Result<V::Value> {
        self.deserialize_seq(visitor)
    }

    /// Reads a dictionary: start marker, key/value pairs, end marker.
    fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.read_dictionary_start()?;
        let value = visitor.visit_map(&mut *self)?;
        self.read_end()?;
        Ok(value)
    }

    /// Structs are read as a list of their field values in declaration order.
    fn deserialize_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value> {
        self.deserialize_tuple(fields.len(), visitor)
    }

    /// Hands control to the `EnumAccess` implementation of this deserializer.
    fn deserialize_enum<V: Visitor<'de>>(
        self,
        _name: &'static str,
        _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value> {
        visitor.visit_enum(self)
    }

    /// Identifiers are read as strings.
    fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.deserialize_str(visitor)
    }

    /// Skips a value by decoding it through [`Self::deserialize_any`].
    fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.deserialize_any(visitor)
    }
}
/// Element access for lists: yields elements until the end marker is next.
impl<'de, R> SeqAccess<'de> for Deserializer<R>
where
    R: Read,
{
    type Error = Error;

    /// Decodes the next element, or returns `None` when the end marker is next.
    ///
    /// The end marker itself is only peeked here; the caller that opened the list
    /// is responsible for consuming it.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: DeserializeSeed<'de>,
    {
        if self.peek_end()? {
            Ok(None)
        } else {
            seed.deserialize(self).map(Some)
        }
    }
}
/// Entry access for dictionaries: keys and values are stored one after another
/// until the end marker.
impl<'de, R> MapAccess<'de> for Deserializer<R>
where
    R: Read,
{
    type Error = Error;

    /// Decodes the next key, or returns `None` when the end marker is next.
    ///
    /// The end marker itself is only peeked here; the caller that opened the
    /// dictionary is responsible for consuming it.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: DeserializeSeed<'de>,
    {
        if self.peek_end()? {
            Ok(None)
        } else {
            seed.deserialize(self).map(Some)
        }
    }

    /// Decodes the value that follows the key read last.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: DeserializeSeed<'de>,
    {
        seed.deserialize(self)
    }
}
impl<'de, 'a, R> EnumAccess<'de> for &'a mut Deserializer<R>
where
R: Read,
{
type Error = Error;
type Variant = Self;
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
where
V: DeserializeSeed<'de>,
{
let (typ, _) = self.peek_discriminator()?;
match typ {
TYPE_INTEGER => {}
TYPE_LIST => self.read_list_start()?,
_ => return Err(Error::WrongType),
}
let variant_index = u32::try_from(self.read_integer()?)?;
let value: Result<_> = seed.deserialize(variant_index.into_deserializer());
Ok((value?, self))
}
}
/// Reads the payload of an enum variant after `variant_seed` has consumed the
/// variant index.
impl<'de, 'a, R> VariantAccess<'de> for &'a mut Deserializer<R>
where
    R: Read,
{
    type Error = Error;

    /// Unit variants carry no payload, so nothing is read.
    // NOTE(review): if the index was wrapped in a list, the closing end marker of
    // that list is not consumed here — confirm the encoder never emits that form.
    fn unit_variant(self) -> Result<()> {
        Ok(())
    }

    /// Reads the single payload value, then the end marker that closes the list
    /// around the variant.
    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
    where
        T: DeserializeSeed<'de>,
    {
        let payload = seed.deserialize(&mut *self)?;
        self.read_end().map(|()| payload)
    }

    /// Reads the payload as a nested list, then closes both the nested list and
    /// the list around the variant.
    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.read_list_start()?;
        let payload = visitor.visit_seq(&mut *self)?;
        // First the payload list, then the list around the variant.
        for _ in 0..2 {
            self.read_end()?;
        }
        Ok(payload)
    }

    /// Struct variants are stored exactly like tuple variants: the field values
    /// form a nested list.
    fn struct_variant<V>(self, fields: &'static [&'static str], visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.tuple_variant(fields.len(), visitor)
    }
}
/// Iterator that decodes a sequence of `T` values from one reader.
///
/// Created by [`Deserializer::into_iter`].
pub struct StreamDeserializer<'de, R: Read, T: Deserialize<'de>> {
    /// Deserializer the values are read from.
    de: Deserializer<R>,
    /// Set once decoding has failed; no further items are produced afterwards.
    failed: bool,
    /// Marker for the item type.
    output: PhantomData<T>,
    /// Marker for the `'de` lifetime required by `T: Deserialize<'de>`.
    lifetime: PhantomData<&'de ()>,
}
impl<'de, R, T> Iterator for StreamDeserializer<'de, R, T>
where
    R: Read,
    T: Deserialize<'de>,
{
    type Item = Result<T>;

    /// Decodes the next value from the reader.
    ///
    /// Returns `None` when the reader is exhausted at a value boundary, and after
    /// any error has been yielded. Running out of data in the middle of a value
    /// is reported as an error instead of silently ending the stream.
    fn next(&mut self) -> Option<Result<T>> {
        if self.failed {
            return None;
        }
        // Probe for the next discriminator first. Only an end of file here, before
        // any byte of a value has been read, is a clean end of the stream.
        if let Err(e) = self.de.peek_discriminator() {
            self.failed = true;
            return if e.is_eof() { None } else { Some(Err(e)) };
        }
        let item = T::deserialize(&mut self.de);
        if item.is_err() {
            self.failed = true;
        }
        Some(item)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use hex_literal::hex;
    use serde::Deserialize;
    use std::collections::HashMap;

    /// A dictionary with string keys decodes into a `HashMap`.
    #[test]
    fn from_slice_maps() {
        // Dictionary holding "bar" => 456 and "foo" => 123.
        let encoded = hex!("70 4103626172 2201C8 4103666F6F 217B 80");
        let decoded: HashMap<String, u32> = from_slice(&encoded).unwrap();

        let expected: HashMap<String, u32> = [("foo", 123u32), ("bar", 456u32)]
            .iter()
            .map(|&(key, value)| (key.to_string(), value))
            .collect();
        assert_eq!(expected, decoded);
    }

    /// A struct decodes from the list of its field values.
    #[test]
    fn from_slice_structs() {
        #[derive(Deserialize, Debug, PartialEq, Eq)]
        struct Test {
            x: bool,
            y: u32,
            z: Vec<String>,
        }

        // List of: true, 17, list of "foo" and "bar".
        let encoded = hex!("60 11 2111 60 4103666F6F 4103626172 80 80");
        let decoded: Test = from_slice(&encoded).unwrap();

        let expected = Test {
            x: true,
            y: 17,
            z: vec!["foo".to_string(), "bar".to_string()],
        };
        assert_eq!(expected, decoded);
    }

    /// Every kind of enum variant decodes from its variant index plus payload.
    #[test]
    fn from_slice_enums() {
        #[derive(Deserialize, Debug, PartialEq, Eq)]
        enum Test {
            UnitVariant,
            NewTypeVariant(u32),
            TupleVariant(bool, u32),
            StructVariant { x: bool, y: u32 },
        }

        /// Decodes `encoded` as a `Test` and compares it with `expected`.
        fn check(encoded: &[u8], expected: Test) {
            let decoded: Test = from_slice(encoded).unwrap();
            assert_eq!(expected, decoded);
        }

        // Bare variant index.
        check(&hex!("2100"), Test::UnitVariant);
        // Variant index followed by the single payload value.
        check(&hex!("60 2101 2111 80"), Test::NewTypeVariant(17));
        // Variant index followed by a nested list of payload values.
        check(&hex!("60 2102 60 11 2111 80 80"), Test::TupleVariant(true, 17));
        check(
            &hex!("60 2103 60 11 2111 80 80"),
            Test::StructVariant { x: true, y: 17 },
        );
    }

    /// A null decodes to `None`, any other value to `Some`.
    #[test]
    fn from_slice_options() {
        let encoded = hex!("00");
        let decoded: Option<u32> = from_slice(&encoded).unwrap();
        assert_eq!(None, decoded);

        let encoded = hex!("2111");
        let decoded: Option<u32> = from_slice(&encoded).unwrap();
        assert_eq!(Some(17), decoded);
    }

    /// The stream iterator yields values until the input ends or an error occurs.
    #[test]
    fn stream_deserializer() {
        // The third item is an end marker, which is not an integer.
        let encoded = hex!("2100 2101 80 2103");
        let items: Vec<Result<u64>> = Deserializer::new(&encoded[..]).into_iter().collect();
        assert_eq!(items.len(), 3);
        assert_eq!(items[0].as_ref().unwrap(), &0);
        assert_eq!(items[1].as_ref().unwrap(), &1);
        assert!(matches!(items[2].as_ref().unwrap_err(), Error::WrongType));

        let encoded = hex!("2100 2101 2102 2103");
        let values: Vec<u64> = Deserializer::new(&encoded[..])
            .into_iter()
            .collect::<Result<Vec<u64>>>()
            .unwrap();
        assert_eq!(vec![0, 1, 2, 3], values);
    }
}