use std::io;
use byteorder::{ReadBytesExt, LittleEndian};
use num_traits::ToPrimitive;
use serde::de::{self, Deserialize, DeserializeOwned, DeserializeSeed, EnumAccess, SeqAccess, VariantAccess, Visitor};
use common::{FALSE_ID, TRUE_ID};
use error;
use identifiable::{Identifiable, Wrapper};
pub struct Deserializer<R: io::Read> {
reader: R,
enum_variant_id: Option<&'static str>,
}
impl<R: io::Read> Deserializer<R> {
pub fn new(reader: R, enum_variant_id: Option<&'static str>) -> Deserializer<R> {
Deserializer {
reader: reader,
enum_variant_id: enum_variant_id,
}
}
fn get_str_info(&mut self) -> error::Result<(usize, usize)> {
let first_byte = self.reader.read_u8()?;
let len;
let rem;
if first_byte <= 253 {
len = first_byte as usize;
rem = (len + 1) % 4;
} else if first_byte == 254 {
len = self.reader.read_uint::<LittleEndian>(3)? as usize;
rem = len % 4;
} else { unreachable!();
}
Ok((len, if rem > 0 { 4 - rem } else { 0 }))
}
fn read_string(&mut self) -> error::Result<String> {
let (len, rem) = self.get_str_info()?;
let mut s_bytes = vec![0; len];
self.reader.read_exact(&mut s_bytes)?;
let s = String::from_utf8(s_bytes)?;
let mut padding = vec![0; rem];
self.reader.read_exact(&mut padding)?;
Ok(s)
}
fn read_byte_buf(&mut self) -> error::Result<Vec<u8>> {
let (len, rem) = self.get_str_info()?;
let mut b = vec![0; len];
self.reader.read_exact(&mut b)?;
let mut padding = vec![0; rem];
self.reader.read_exact(&mut padding)?;
Ok(b)
}
}
macro_rules! impl_deserialize_small_int {
($small_deserialize:ident, $small_visit:ident, $cast_to_small:ident,
$big_read:ident::<$big_endianness:ident>
) => {
fn $small_deserialize<V>(self, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
let value = self.reader.$big_read::<$big_endianness>()?;
let casted = value.$cast_to_small()
.ok_or(error::Error::from(error::ErrorKind::IntegerOverflowingCast))?;
visitor.$small_visit(casted)
}
};
}
macro_rules! impl_deserialize_big_int {
($deserialize:ident, $read:ident::<$endianness:ident>, $visit:ident) => {
fn $deserialize<V>(self, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
let value = self.reader.$read::<$endianness>()?;
visitor.$visit(value)
}
};
}
impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer<R>
where R: io::Read
{
type Error = error::Error;
fn deserialize_any<V>(self, _visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
Err("Non self-described format".into())
}
fn deserialize_bool<V>(self, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
let id_value = self.reader.read_i32::<LittleEndian>()?;
let value = match id_value {
TRUE_ID => true,
FALSE_ID => false,
_ => return Err("Expected a bool".into())
};
visitor.visit_bool(value)
}
impl_deserialize_small_int!(deserialize_i8, visit_i8, to_i8, read_i32::<LittleEndian>);
impl_deserialize_small_int!(deserialize_i16, visit_i16, to_i16, read_i32::<LittleEndian>);
impl_deserialize_small_int!(deserialize_u8, visit_u8, to_u8, read_u32::<LittleEndian>);
impl_deserialize_small_int!(deserialize_u16, visit_u16, to_u16, read_u32::<LittleEndian>);
impl_deserialize_big_int!(deserialize_i32, read_i32::<LittleEndian>, visit_i32);
impl_deserialize_big_int!(deserialize_i64, read_i64::<LittleEndian>, visit_i64);
impl_deserialize_big_int!(deserialize_u32, read_u32::<LittleEndian>, visit_u32);
impl_deserialize_big_int!(deserialize_u64, read_u64::<LittleEndian>, visit_u64);
impl_deserialize_big_int!(deserialize_f32, read_f32::<LittleEndian>, visit_f32);
impl_deserialize_big_int!(deserialize_f64, read_f64::<LittleEndian>, visit_f64);
fn deserialize_char<V>(self, _visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
unreachable!("this method shouldn't be called")
}
fn deserialize_str<V>(self, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
let s = self.read_string()?;
visitor.visit_str(&s)
}
fn deserialize_string<V>(self, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
let s = self.read_string()?;
visitor.visit_string(s)
}
fn deserialize_bytes<V>(self, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
let b = self.read_byte_buf()?;
visitor.visit_bytes(&b)
}
fn deserialize_byte_buf<V>(self, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
let b = self.read_byte_buf()?;
visitor.visit_byte_buf(b)
}
fn deserialize_option<V>(self, _visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
unimplemented!()
}
fn deserialize_unit<V>(self, _visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
unreachable!("this method shouldn't be called")
}
fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
visitor.visit_unit()
}
fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
visitor.visit_newtype_struct(self)
}
fn deserialize_seq<V>(mut self, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
visitor.visit_seq(Combinator::without_count(&mut self))
}
fn deserialize_tuple<V>(mut self, len: usize, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
visitor.visit_seq(Combinator::with_count(&mut self, len))
}
fn deserialize_tuple_struct<V>(self, _name: &'static str, _len: usize, _visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
unimplemented!()
}
fn deserialize_map<V>(self, _visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
unimplemented!()
}
fn deserialize_struct<V>(mut self, _name: &'static str, fields: &'static [&'static str], visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
visitor.visit_seq(Combinator::with_count(&mut self, fields.len()))
}
fn deserialize_enum<V>(mut self, _name: &'static str, _variants: &'static [&'static str], visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
visitor.visit_enum(Combinator::without_count(&mut self))
}
fn deserialize_identifier<V>(self, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
let variant_id = self.enum_variant_id.unwrap();
visitor.visit_str(variant_id)
}
fn deserialize_ignored_any<V>(self, _visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
unimplemented!()
}
}
struct Combinator<'a, R: 'a + io::Read> {
de: &'a mut Deserializer<R>,
next_index: usize,
count: Option<usize>,
}
impl<'a, R: io::Read> Combinator<'a, R> {
fn without_count(de: &'a mut Deserializer<R>) -> Combinator<'a, R> {
Combinator {
de: de,
next_index: 0,
count: None,
}
}
fn with_count(de: &'a mut Deserializer<R>, count: usize) -> Combinator<'a, R> {
Combinator {
de: de,
next_index: 0,
count: Some(count),
}
}
}
impl<'de, 'a, R> SeqAccess<'de> for Combinator<'a, R>
where R: 'a + io::Read
{
type Error = error::Error;
fn next_element_seed<T>(&mut self, seed: T) -> error::Result<Option<T::Value>>
where T: DeserializeSeed<'de>
{
if let Some(count) = self.count {
if self.next_index < count {
self.next_index += 1;
} else {
return Ok(None);
}
}
seed.deserialize(&mut *self.de).map(Some)
}
}
impl<'de, 'a, R> EnumAccess<'de> for Combinator<'a, R>
where R: 'a + io::Read
{
type Error = error::Error;
type Variant = Self;
fn variant_seed<V>(self, seed: V) -> error::Result<(V::Value, Self::Variant)>
where V: DeserializeSeed<'de>
{
let value = seed.deserialize(&mut *self.de)?;
Ok((value, self))
}
}
impl<'de, 'a, R> VariantAccess<'de> for Combinator<'a, R>
where R: 'a + io::Read
{
type Error = error::Error;
fn unit_variant(self) -> error::Result<()> {
Ok(())
}
fn newtype_variant_seed<T>(self, seed: T) -> error::Result<T::Value>
where T: DeserializeSeed<'de>
{
seed.deserialize(self.de)
}
fn tuple_variant<V>(self, len: usize, visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
de::Deserializer::deserialize_tuple_struct(self.de, "", len, visitor)
}
fn struct_variant<V>(self, fields: &'static [&'static str], visitor: V) -> error::Result<V::Value>
where V: Visitor<'de>
{
de::Deserializer::deserialize_struct(self.de, "", fields, visitor)
}
}
pub fn from_bytes<'a, T>(bytes: &'a [u8], enum_variant_id: Option<&'static str>) -> error::Result<T>
where T: Deserialize<'a>
{
let mut de = Deserializer::new(bytes, enum_variant_id);
let value: T = Deserialize::deserialize(&mut de)?;
Ok(value)
}
pub fn from_bytes_identifiable<'a, T>(bytes: &'a [u8], enum_variant_id: Option<&'static str>) -> error::Result<T>
where T: Deserialize<'a> + Identifiable
{
let wrapper: Wrapper<T> = from_bytes(bytes, enum_variant_id)?;
Ok(wrapper.take_data())
}
pub fn from_reader<R, T>(reader: R, enum_variant_id: Option<&'static str>) -> error::Result<T>
where R: io::Read,
T: DeserializeOwned,
{
let mut de = Deserializer::new(reader, enum_variant_id);
let value: T = Deserialize::deserialize(&mut de)?;
Ok(value)
}
pub fn from_reader_identifiable<R, T>(reader: R, enum_variant_id: Option<&'static str>) -> error::Result<T>
where R: io::Read,
T: DeserializeOwned + Identifiable,
{
let wrapper: Wrapper<T> = from_reader(reader, enum_variant_id)?;
Ok(wrapper.take_data())
}