use std::cmp;
use std::mem;
use std::str;
use std::slice;
use std::convert::{ TryFrom, TryInto };
use std::io::{ Read, BufReader, Cursor, Chain };
use std::borrow::Cow;
use serde::de::{
Deserialize, DeserializeSeed, Deserializer,
IntoDeserializer, value::BorrowedStrDeserializer,
SeqAccess, MapAccess, EnumAccess, VariantAccess,
Visitor, IgnoredAny,
};
use byteorder::{ ByteOrder, LittleEndian };
use crate::error::{ Error, ResultExt };
use crate::format::*;
use super::*;
/// Deserializes a value of type `T` from a binary stream.
///
/// A [`BinaryStreamDeserializer`] is constructed over `reader` and then
/// handed to `T`'s `Deserialize` implementation.
///
/// # Errors
///
/// Returns an [`Error`] if the deserializer cannot be constructed or if
/// decoding the value fails.
pub fn from_reader<R, T>(reader: R) -> Result<T, Error>
where
    R: Read,
    T: for<'de> Deserialize<'de>,
{
    BinaryStreamDeserializer::new(reader)
        .and_then(|mut stream| T::deserialize(&mut stream))
}
/// Same as [`from_reader`], except that `reader` is wrapped in a
/// [`BufReader`] before any data is read from it.
///
/// # Errors
///
/// Propagates whatever error [`from_reader`] reports.
pub fn from_reader_buffered<R, T>(reader: R) -> Result<T, Error>
where
    R: Read,
    T: for<'de> Deserialize<'de>,
{
    let buffered = BufReader::new(reader);
    from_reader(buffered)
}
fn visit_cow_str<'de, V: Visitor<'de>>(cow: Cow<str>, visitor: V) -> Result<V::Value, Error> {
match cow {
Cow::Borrowed(s) => visitor.visit_str(s),
Cow::Owned(s) => visitor.visit_string(s),
}
}
fn visit_cow_bytes<'de, V: Visitor<'de>>(cow: Cow<[u8]>, visitor: V) -> Result<V::Value, Error> {
match cow {
Cow::Borrowed(bytes) => visitor.visit_bytes(bytes),
Cow::Owned(bytes) => visitor.visit_byte_buf(bytes),
}
}
/// Reinterprets a byte buffer as a string, keeping its borrowed/owned state.
///
/// # Errors
///
/// Fails if the bytes are not valid UTF-8.
fn cow_str_from_utf8(cow: Cow<[u8]>) -> Result<Cow<str>, Error> {
    let text = match cow {
        Cow::Owned(buffer) => Cow::Owned(String::from_utf8(buffer)?),
        Cow::Borrowed(slice) => Cow::Borrowed(str::from_utf8(slice)?),
    };
    Ok(text)
}
/// Clamps `len` to at most `u16::MAX`.
///
/// Callers in this file use the result as a pre-allocation capacity and as a
/// serde size hint, so the value read from the input never directly decides
/// how much memory is reserved up front.
fn cautious_size_hint(len: usize) -> usize {
    let ceiling = usize::from(u16::MAX);
    if len < ceiling {
        len
    } else {
        ceiling
    }
}
/// Low-level reading helpers layered on top of any [`Read`] implementation.
///
/// All methods report failures through the crate's [`Error`] type rather than
/// `io::Error`.
trait ReadExt: Read + Sized {
    /// Reads exactly one byte from the stream.
    fn read_byte(&mut self) -> Result<u8, Error> {
        let mut byte = 0x00;
        self.read_exact(slice::from_mut(&mut byte))?;
        Ok(byte)
    }

    /// Reads exactly `len` bytes into a freshly allocated buffer.
    ///
    /// The initial capacity is capped by `cautious_size_hint`, so a huge
    /// `len` does not by itself cause a huge up-front allocation; the buffer
    /// only grows as bytes actually arrive.
    ///
    /// # Errors
    ///
    /// Fails with "unexpected end of input" if the stream ends before `len`
    /// bytes have been read, or with the underlying I/O error.
    fn read_buf_exact(&mut self, len: usize) -> Result<Vec<u8>, Error> {
        let capacity = cautious_size_hint(len);
        let mut buf = Vec::with_capacity(capacity);
        let actually_read = self
            .by_ref()
            .take(u64::try_from(len)?)
            .read_to_end(&mut buf)?;

        if len == actually_read {
            Ok(buf)
        } else {
            Err(Error::custom("unexpected end of input"))
        }
    }

    /// Reads an unsigned integer that is either packed into the tag byte
    /// itself (small uint) or stored in the bytes following the tag (big
    /// uint).
    ///
    /// # Errors
    ///
    /// Any other tag is reported as corrupted input.
    fn read_uint(&mut self) -> Result<usize, Error> {
        let tag = self.read_byte()?;

        if tag.is_major(MAJOR_TYPE_SMALL_UINT) {
            Ok(usize::from(decode_small_uint(tag)))
        } else if tag.is_major_minor(MAJOR_TYPE_BIG_VALUE, MINOR_TYPE_UINT) {
            self.read_big_uint(tag.decode_log_length())
        } else {
            corrupted(format_args!(
                "invalid type for use count: major {:08b} minor {:08b}",
                tag & MAJOR_TYPE_MASK,
                tag & MINOR_TYPE_MASK,
            ))
        }
    }

    /// Reads a little-endian unsigned integer occupying `length` bytes and
    /// converts it to `usize`.
    ///
    /// # Errors
    ///
    /// Fails if `length` is not in `1..=8`, if the stream ends early, or if
    /// the decoded value does not fit into `usize`.
    fn read_big_uint(&mut self, length: usize) -> Result<usize, Error> {
        let mut buf: [u8; 8] = [0; 8];

        // Both the slice below and `ByteOrder::read_uint` panic on a length
        // outside `1..=8`; report such a length as an error instead.
        if length == 0 || length > buf.len() {
            return corrupted(format_args!(
                "invalid big integer length: {} bytes (expected 1 to {})",
                length,
                buf.len(),
            ));
        }

        let slice = &mut buf[..length];
        self.read_exact(slice)?;
        let num = LittleEndian::read_uint(slice, length);
        num.try_into().conv_err()
    }

    /// Returns a reader that yields `byte` (if any) before the contents of
    /// `self`.
    ///
    /// The return type is the same whether or not a byte is supplied: with
    /// `None`, the one-byte cursor is positioned at its end and therefore
    /// contributes nothing.
    fn prepend_byte(self, byte: Option<u8>) -> Chain<Cursor<[u8; 1]>, Self> {
        let buf = [byte.unwrap_or_default()];
        // Position 0 exposes the byte; position 1 marks the cursor as
        // already exhausted.
        let pos = u64::from(byte.is_none());
        let mut cursor = Cursor::new(buf);
        cursor.set_position(pos);
        cursor.chain(self)
    }
}
// Blanket impl: every `Read` type gets the `ReadExt` helper methods.
impl<R: Read> ReadExt for R {}
/// Owned payload of a symbol: either raw bytes or text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum SymbolBuf {
/// Arbitrary binary data.
Blob(Box<[u8]>),
/// String data; `Symbol::parse` validates it as UTF-8 before storing it.
Str(Box<str>),
}
/// A symbol payload together with the number of times it may still be used.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Symbol {
/// The symbol's payload.
buf: SymbolBuf,
/// Number of uses left. `use_blob` and `use_str` decrement it, refuse to
/// go below zero, and move the payload out on the final use.
remaining_uses: usize,
}
impl Symbol {
fn parse<R: Read>(reader: &mut R) -> Result<Self, Error> {
let tag = reader.read_byte()?;
let SymbolFlags { is_big, is_string, is_multi } = tag.try_into()?;
let buf_len = if is_big {
reader.read_big_uint(tag.decode_log_length())?
} else {
usize::from(decode_small_uint(tag))
};
let remaining_uses = if is_multi {
reader.read_uint()?
} else {
1
};
let payload = reader.read_buf_exact(buf_len)?;
let buf = if is_string {
SymbolBuf::Str(String::from_utf8(payload)?.into())
} else {
SymbolBuf::Blob(payload.into())
};
Ok(Symbol { buf, remaining_uses })
}
fn use_blob(&mut self) -> Result<Cow<[u8]>, Error> {
if self.remaining_uses == 0 {
return corrupted("used blob symbol after last declared use");
}
self.remaining_uses -= 1;
if self.remaining_uses == 0 {
match self.buf {
SymbolBuf::Blob(ref mut b) => {
let buf = mem::take(b);
Ok(Cow::Owned(Vec::from(buf)))
}
SymbolBuf::Str(ref mut s) => {
let buf = mem::take(s);
Ok(Cow::Owned(String::from(buf).into_bytes()))
}
}
} else {
match self.buf {
SymbolBuf::Blob(ref b) => Ok(Cow::Borrowed(b)),
SymbolBuf::Str(ref s) => Ok(Cow::Borrowed(s.as_bytes())),
}
}
}
fn use_str(&mut self) -> Result<Cow<str>, Error> {
let s = match self.buf {
SymbolBuf::Blob(_) => corrupted("attempted to use blob as string")?,
SymbolBuf::Str(ref mut s) => s,
};
if self.remaining_uses == 0 {
return corrupted("used string symbol after last declared use");
}
self.remaining_uses -= 1;
if self.remaining_uses == 0 {
let buf = mem::take(s);
Ok(Cow::Owned(String::from(buf)))
} else {
Ok(Cow::Borrowed(s))
}
}
}
/// The table of symbols (strings and blobs) parsed from the head of the
/// stream.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
struct SymbolTable {
/// Symbols in the order they were parsed; `use_blob` and `use_str` look
/// them up by index into this vector.
symbols: Vec<Symbol>,
}
impl SymbolTable {
    /// Parses an optional symbol table from the start of `reader`.
    ///
    /// The first byte decides whether a table is present. If it is, the
    /// table length and then that many symbols are read, and `None` is
    /// returned next to the table. Otherwise an empty table is returned
    /// together with the byte that was consumed, so that the caller can put
    /// it back in front of the stream.
    fn parse<R: Read>(reader: &mut R) -> Result<(Self, Option<u8>), Error> {
        let head = reader.read_byte()?;

        if !head.is_major_minor(MAJOR_TYPE_SIMPLE, MINOR_TYPE_SYMTAB) {
            return Ok((SymbolTable::default(), Some(head)));
        }

        let declared_len = reader.read_big_uint(head.decode_log_length())?;
        // The declared length comes from the input, so only a capped amount
        // is pre-allocated; the vector grows as symbols are actually parsed.
        let mut symbols = Vec::with_capacity(cautious_size_hint(declared_len));

        for _ in 0..declared_len {
            let symbol = Symbol::parse(reader)?;
            symbols.push(symbol);
        }

        Ok((SymbolTable { symbols }, None))
    }

    /// Number of symbols in the table.
    fn len(&self) -> usize {
        self.symbols.len()
    }

    /// Consumes one use of the symbol at `index` and returns it as bytes.
    ///
    /// # Errors
    ///
    /// Fails if `index` is out of bounds or the symbol has no uses left.
    fn use_blob(&mut self, index: usize) -> Result<Cow<[u8]>, Error> {
        // Captured up front because the mutable lookup below borrows `self`.
        let len = self.len();

        match self.symbols.get_mut(index) {
            Some(symbol) => symbol.use_blob(),
            None => corrupted(format_args!(
                "owned blob #{} out of bounds for symtab of size {}",
                index, len
            )),
        }
    }

    /// Consumes one use of the symbol at `index` and returns it as a string.
    ///
    /// # Errors
    ///
    /// Fails if `index` is out of bounds, the symbol is a blob, or the
    /// symbol has no uses left.
    fn use_str(&mut self, index: usize) -> Result<Cow<str>, Error> {
        // Captured up front because the mutable lookup below borrows `self`.
        let len = self.len();

        match self.symbols.get_mut(index) {
            Some(symbol) => symbol.use_str(),
            None => corrupted(format_args!(
                "owned string #{} out of bounds for symtab of size {}",
                index, len,
            )),
        }
    }
}
/// Serde deserializer that reads the binary format from an arbitrary `Read`
/// stream.
#[derive(Debug)]
pub struct BinaryStreamDeserializer<R> {
/// The input stream, preceded by a one-byte cursor. The cursor holds the
/// first byte of the stream when that byte was consumed while probing for
/// a symbol table but turned out not to start one.
reader: Chain<Cursor<[u8; 1]>, R>,
/// Symbol table parsed from the head of the stream (empty if absent).
symtab: SymbolTable,
}
impl<R: Read> BinaryStreamDeserializer<R> {
    /// Creates a deserializer over `reader`.
    ///
    /// The optional symbol table is parsed eagerly. If the stream does not
    /// start with one, the byte consumed while checking is put back in front
    /// of the reader.
    ///
    /// # Errors
    ///
    /// Fails if the first byte cannot be read or the symbol table is
    /// malformed.
    pub fn new(mut reader: R) -> Result<Self, Error> {
        let (symtab, head) = SymbolTable::parse(&mut reader)?;

        Ok(BinaryStreamDeserializer {
            reader: reader.prepend_byte(head),
            symtab,
        })
    }

    /// Reads the tag byte of the next value and decodes its header.
    ///
    /// `exp` is used only to describe what was expected if the tag byte
    /// cannot be read.
    fn read_value_header(&mut self, exp: &dyn Expected) -> Result<ValueHeader, Error> {
        let tag = self.reader.read_byte().chain(
            || format!("missing value; expected {}", exp)
        )?;
        // Scratch space for any header bytes that follow the tag.
        let mut scratch = [0_u8; 8];

        read_value_header(tag, |len| {
            let payload = &mut scratch[..len];
            self.reader.read_exact(payload)?;
            Ok(payload)
        })
    }

    /// Reads the next value header and dispatches it to `visitor` as a
    /// number.
    fn deserialize_number<'de, V: Visitor<'de>>(
        &mut self,
        visitor: V,
    ) -> Result<V::Value, Error> {
        let header = self.read_value_header(&visitor)?;
        visit_number(header, visitor)
    }

    /// Lets `visitor` consume a sequence of `count` elements, then skips
    /// whatever elements the visitor left unread so the stream stays aligned
    /// with the next value.
    fn visit_and_exhaust_seq<'de, V: Visitor<'de>>(
        &mut self,
        count: usize,
        visitor: V,
    ) -> Result<V::Value, Error> {
        let mut elements = SeqDeserializer::new(self, count);
        let value = visitor.visit_seq(&mut elements)?;
        elements.exhaust().map(|()| value)
    }

    /// Lets `visitor` consume a map of `count` entries, then skips whatever
    /// entries the visitor left unread so the stream stays aligned with the
    /// next value.
    fn visit_and_exhaust_map<'de, V: Visitor<'de>>(
        &mut self,
        count: usize,
        visitor: V,
    ) -> Result<V::Value, Error> {
        let mut entries = MapDeserializer::new(self, count);
        let value = visitor.visit_map(&mut entries)?;
        entries.exhaust().map(|()| value)
    }
}
/// Serde front end of the binary stream format.
///
/// Most methods read one value header from the stream and dispatch on it.
/// Strings and blobs are resolved through the symbol table, which hands out
/// either a borrow or, on the last declared use, the owned buffer.
impl<'de, R: Read> Deserializer<'de> for &mut BinaryStreamDeserializer<R> {
    type Error = Error;

    /// This is a binary format, so it reports itself as not human-readable.
    fn is_human_readable(&self) -> bool {
        false
    }

    /// Deserializes whatever value comes next, based solely on its header.
    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        use ValueHeader::*;

        match self.read_value_header(&visitor)? {
            Null => visitor.visit_unit(),
            Opt => visitor.visit_some(self),
            Bool(b) => visitor.visit_bool(b),
            I8(x) => visitor.visit_i8(x),
            I16(x) => visitor.visit_i16(x),
            I32(x) => visitor.visit_i32(x),
            I64(x) => visitor.visit_i64(x),
            U8(x) => visitor.visit_u8(x),
            U16(x) => visitor.visit_u16(x),
            U32(x) => visitor.visit_u32(x),
            U64(x) => visitor.visit_u64(x),
            F32(x) => visitor.visit_f32(x.into()),
            F64(x) => visitor.visit_f64(x.into()),
            EmptyString => visitor.visit_borrowed_str(""),
            EmptyBlob => visitor.visit_borrowed_bytes(&[]),
            String(index) => {
                let string = self.symtab.use_str(index)?;
                visit_cow_str(string, visitor)
            },
            Blob(index) => {
                let bytes = self.symtab.use_blob(index)?;
                visit_cow_bytes(bytes, visitor)
            },
            Array(count) => self.visit_and_exhaust_seq(count, visitor),
            Map(count) => self.visit_and_exhaust_map(count, visitor),
        }
    }

    /// Accepts only a boolean header.
    fn deserialize_bool<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        match self.read_value_header(&visitor)? {
            ValueHeader::Bool(b) => visitor.visit_bool(b),
            value => type_error(value, &visitor),
        }
    }

    // All numeric entry points share one implementation: the header is read
    // and passed to `visit_number` together with the visitor.

    fn deserialize_i8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_i16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_i32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_i128<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_u8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_u16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_u32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_u128<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_f32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_f64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    /// Characters are read through the string path.
    fn deserialize_char<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_str(visitor)
    }

    /// Accepts strings, and also blobs provided their contents are valid
    /// UTF-8. Both empty-string and empty-blob headers yield `""`.
    fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        use ValueHeader::*;

        match self.read_value_header(&visitor)? {
            EmptyString | EmptyBlob => visitor.visit_borrowed_str(""),
            String(index) => {
                let string = self.symtab.use_str(index)?;
                visit_cow_str(string, visitor)
            },
            Blob(index) => {
                let blob = self.symtab.use_blob(index)?;
                let string = cow_str_from_utf8(blob)?;
                visit_cow_str(string, visitor)
            },
            value => type_error(value, &visitor),
        }
    }

    fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_str(visitor)
    }

    /// Accepts blobs, strings (passed to the visitor as strings, not
    /// converted to bytes) and arrays (passed to the visitor as a sequence).
    fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        use ValueHeader::*;

        match self.read_value_header(&visitor)? {
            EmptyString => visitor.visit_borrowed_str(""),
            EmptyBlob => visitor.visit_borrowed_bytes(&[]),
            String(index) => {
                let string = self.symtab.use_str(index)?;
                visit_cow_str(string, visitor)
            },
            Blob(index) => {
                let bytes = self.symtab.use_blob(index)?;
                visit_cow_bytes(bytes, visitor)
            },
            Array(count) => self.visit_and_exhaust_seq(count, visitor),
            value => type_error(value, &visitor),
        }
    }

    fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_bytes(visitor)
    }

    /// Reads the header and lets `visit_option` decide between the absent
    /// and present cases.
    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        let value = self.read_value_header(&visitor)?;
        visit_option(value, self, visitor)
    }

    /// Reads the header and passes it to `visit_unit`.
    fn deserialize_unit<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        visit_unit(self.read_value_header(&visitor)?, visitor)
    }

    /// Unit structs are read exactly like the unit value.
    fn deserialize_unit_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> Result<V::Value, Self::Error> {
        self.deserialize_unit(visitor)
    }

    /// Newtype structs are transparent: the visitor reads the inner value
    /// directly from this deserializer.
    fn deserialize_newtype_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> Result<V::Value, Self::Error> {
        visitor.visit_newtype_struct(self)
    }

    /// Accepts only an array header. Elements the visitor leaves unread are
    /// skipped.
    fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        use ValueHeader::Array;

        match self.read_value_header(&visitor)? {
            Array(count) => self.visit_and_exhaust_seq(count, visitor),
            value => type_error(value, &visitor),
        }
    }

    /// Tuples are read as sequences; the expected length is not checked
    /// here.
    fn deserialize_tuple<V: Visitor<'de>>(
        self,
        _len: usize,
        visitor: V,
    ) -> Result<V::Value, Self::Error> {
        self.deserialize_seq(visitor)
    }

    fn deserialize_tuple_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        len: usize,
        visitor: V,
    ) -> Result<V::Value, Self::Error> {
        self.deserialize_tuple(len, visitor)
    }

    /// Accepts only a map header. Entries the visitor leaves unread are
    /// skipped.
    fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        use ValueHeader::Map;

        match self.read_value_header(&visitor)? {
            Map(count) => self.visit_and_exhaust_map(count, visitor),
            value => type_error(value, &visitor),
        }
    }

    /// Structs are read as maps; the struct and field names are not used.
    fn deserialize_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        _fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error> {
        self.deserialize_map(visitor)
    }

    /// Identifiers are read as strings.
    fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_str(visitor)
    }

    /// Accepts an enum written either as a bare string (the variant name) or
    /// as a map with exactly one entry (variant name => payload).
    fn deserialize_enum<V: Visitor<'de>>(
        self,
        _type_name: &'static str,
        _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error> {
        match self.read_value_header(&visitor)? {
            ValueHeader::String(index) => {
                let string = self.symtab.use_str(index)?;
                visitor.visit_enum(string.into_deserializer())
            },
            ValueHeader::EmptyString => {
                let deserializer = BorrowedStrDeserializer::new("");
                visitor.visit_enum(deserializer)
            },
            ValueHeader::Map(count) => {
                if count == 1 {
                    // The key and the value of the single entry are read
                    // through the `EnumAccess` / `VariantAccess` impls.
                    visitor.visit_enum(self)
                } else {
                    Err(Error::invalid_length(count, &"enum as single-key map"))
                }
            },
            value => type_error(value, &"enum as string or single-key map"),
        }
    }

    /// Reads and discards the next value, whatever its type, and then
    /// reports unit to the visitor.
    fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_any(IgnoredAny).and_then(|_| visitor.visit_unit())
    }
}
/// Enum access for the single-key-map representation: the key names the
/// variant, and the deserializer itself then serves as the variant accessor.
impl<'de, R: Read> EnumAccess<'de> for &mut BinaryStreamDeserializer<R> {
    type Error = Error;
    type Variant = Self;

    /// Deserializes the variant identifier from the stream and returns it
    /// together with the deserializer, which is positioned at the payload.
    fn variant_seed<V: DeserializeSeed<'de>>(
        self,
        seed: V
    ) -> Result<(V::Value, Self::Variant), Self::Error> {
        let variant = seed.deserialize(&mut *self)?;
        Ok((variant, self))
    }
}
impl<'de, R: Read> VariantAccess<'de> for &mut BinaryStreamDeserializer<R> {
type Error = <Self as EnumAccess<'de>>::Error;
fn unit_variant(self) -> Result<(), Self::Error> {
Deserialize::deserialize(self)
}
fn newtype_variant_seed<T: DeserializeSeed<'de>>(
self,
seed: T,
) -> Result<T::Value, Self::Error> {
seed.deserialize(self)
}
fn tuple_variant<V: Visitor<'de>>(
self,
len: usize,
visitor: V,
) -> Result<V::Value, Self::Error> {
self.deserialize_tuple(len, visitor)
}
fn struct_variant<V: Visitor<'de>>(
self,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error> {
self.deserialize_map(visitor)
}
}
/// Sequence accessor that yields a fixed number of elements from the
/// underlying stream deserializer.
#[derive(Debug)]
struct SeqDeserializer<'a, R> {
/// The stream deserializer the elements are read from.
deserializer: &'a mut BinaryStreamDeserializer<R>,
/// Number of elements not yet handed out.
remaining: usize,
}
impl<'a, R: Read> SeqDeserializer<'a, R> {
fn new(de: &'a mut BinaryStreamDeserializer<R>, count: usize) -> Self {
SeqDeserializer {
deserializer: de,
remaining: count,
}
}
fn exhaust(&mut self) -> Result<(), Error> {
while let Some(IgnoredAny) = self.next_element()? {}
Ok(())
}
}
impl<'a, 'de, R: Read> SeqAccess<'de> for SeqDeserializer<'a, R> {
    type Error = Error;

    /// Deserializes the next element, or returns `None` once the declared
    /// element count has been reached.
    fn next_element_seed<T: DeserializeSeed<'de>>(
        &mut self,
        seed: T,
    ) -> Result<Option<T::Value>, Self::Error> {
        let left = match self.remaining.checked_sub(1) {
            Some(left) => left,
            None => return Ok(None),
        };
        // The element counts as consumed even if deserializing it fails.
        self.remaining = left;
        seed.deserialize(&mut *self.deserializer).map(Some)
    }

    /// Number of remaining elements, capped by `cautious_size_hint`.
    fn size_hint(&self) -> Option<usize> {
        Some(cautious_size_hint(self.remaining))
    }
}
/// Map accessor that yields a fixed number of entries from the underlying
/// stream deserializer.
#[derive(Debug)]
struct MapDeserializer<'a, R> {
/// The stream deserializer the keys and values are read from.
deserializer: &'a mut BinaryStreamDeserializer<R>,
/// Number of entries whose key has not yet been handed out.
remaining: usize,
}
impl<'a, R: Read> MapDeserializer<'a, R> {
fn new(de: &'a mut BinaryStreamDeserializer<R>, count: usize) -> Self {
MapDeserializer {
deserializer: de,
remaining: count,
}
}
fn exhaust(&mut self) -> Result<(), Error> {
while let Some((IgnoredAny, IgnoredAny)) = self.next_entry()? {}
Ok(())
}
}
impl<'a, 'de, R: Read> MapAccess<'de> for MapDeserializer<'a, R> {
    type Error = Error;

    /// Deserializes the key of the next entry, or returns `None` once the
    /// declared entry count has been reached.
    fn next_key_seed<K: DeserializeSeed<'de>>(
        &mut self,
        seed: K,
    ) -> Result<Option<K::Value>, Self::Error> {
        let left = match self.remaining.checked_sub(1) {
            Some(left) => left,
            None => return Ok(None),
        };
        // Entries are counted when their key is read; the matching value is
        // read by `next_value_seed` without touching the counter.
        self.remaining = left;
        seed.deserialize(&mut *self.deserializer).map(Some)
    }

    /// Deserializes the value belonging to the most recently read key.
    fn next_value_seed<V: DeserializeSeed<'de>>(
        &mut self,
        seed: V,
    ) -> Result<V::Value, Self::Error> {
        let stream = &mut *self.deserializer;
        seed.deserialize(stream)
    }

    /// Number of remaining entries, capped by `cautious_size_hint`.
    fn size_hint(&self) -> Option<usize> {
        Some(cautious_size_hint(self.remaining))
    }
}