use std::str;
use std::fmt::Display;
use std::convert::TryInto;
use serde::de::{
Deserialize, DeserializeSeed, Deserializer, value::BorrowedStrDeserializer,
SeqAccess, MapAccess, EnumAccess, VariantAccess,
Visitor, IgnoredAny,
Expected,
};
use crate::error::Error;
use crate::format::*;
use super::*;
/// Deserializes a `T` from a complete binary buffer.
///
/// The buffer may start with a symbol table. Non-empty strings and blobs in
/// the body refer to it by index and are borrowed from `bytes`, not copied.
///
/// # Errors
///
/// Returns an error if the symbol table or body is malformed, if the encoded
/// data does not fit `T`, or if unread bytes remain after the value.
pub fn from_bytes<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result<T, Error> {
    let mut deserializer = BinarySliceDeserializer::new(bytes)?;
    let decoded = T::deserialize(&mut deserializer)?;
    // Trailing bytes are an error; only hand back the value once the whole
    // input has been accounted for.
    deserializer.finalize().map(|()| decoded)
}
/// One symbol-table entry, with its payload borrowed from the input buffer.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
enum SymbolRef<'de> {
    /// An uninterpreted byte payload.
    Blob(&'de [u8]),
    /// A payload that was checked to be valid UTF-8.
    Str(&'de str),
}
impl<'de> SymbolRef<'de> {
#[allow(clippy::shadow_reuse)]
fn parse(bytes: &'de [u8]) -> Result<(Self, &[u8]), Error> {
let (tag, rest) = match bytes.split_first() {
Some((&tag, rest)) => (tag, rest),
None => return corrupted("missing symbol")
};
let SymbolFlags { is_big, is_string, is_multi } = tag.try_into()?;
let (buf_len, rest) = if is_big {
let buf_len_len = tag.decode_log_length();
uint_from_parts(buf_len_len, rest)?
} else {
let buf_len = usize::from(decode_small_uint(tag));
(buf_len, rest)
};
let rest = if is_multi {
let (tag, rest) = match rest.split_first() {
Some((&tag, rest)) => (tag, rest),
None => return corrupted("multi-use symbol missing use count"),
};
if tag.is_major(MAJOR_TYPE_SMALL_UINT) {
rest
} else if tag.is_major_minor(MAJOR_TYPE_BIG_VALUE, MINOR_TYPE_UINT) {
match rest.get(tag.decode_log_length()..) {
Some(slice) => slice,
None => return corrupted(
"symbol use count buffer too short or missing"
),
}
} else {
return corrupted(format_args!(
"invalid type for use count: major {:08b} minor {:08b}",
tag & MAJOR_TYPE_MASK,
tag & MINOR_TYPE_MASK,
));
}
} else {
rest
};
let (payload, rest) = if buf_len <= rest.len() {
rest.split_at(buf_len)
} else {
return corrupted("symbol payload buffer too short or missing");
};
let symbol = if is_string {
let s = str::from_utf8(payload).map_err(|e| {
Error::custom(format_args!(
"interned string symbol is invalid UTF-8 (cause: {})", e
))
})?;
SymbolRef::Str(s)
} else {
SymbolRef::Blob(payload)
};
Ok((symbol, rest))
}
const fn as_bytes(self) -> &'de [u8] {
match self {
SymbolRef::Blob(b) => b,
SymbolRef::Str(s) => s.as_bytes(),
}
}
fn try_as_str(self) -> Result<&'de str, Error> {
match self {
SymbolRef::Blob(_) => corrupted("attempted to use blob as string"),
SymbolRef::Str(s) => Ok(s),
}
}
}
/// The symbols declared at the start of an input buffer, in declaration order.
///
/// String and blob values in the body refer to entries by their position in
/// `symbols`.
#[derive(Debug, Clone, Default)]
#[derive(PartialEq, Eq, Hash)]
struct SymbolRefTable<'de> {
    /// Parsed entries. Each one borrows its payload from the input.
    symbols: Vec<SymbolRef<'de>>,
}
impl<'de> SymbolRefTable<'de> {
#[allow(clippy::shadow_reuse)]
fn parse(bytes: &[u8]) -> Result<(SymbolRefTable, &[u8]), Error> {
let (marker, rest) = match bytes.split_first() {
Some((&x, rest)) => {
if x.is_major_minor(MAJOR_TYPE_SIMPLE, MINOR_TYPE_SYMTAB) {
(x, rest)
} else {
return Ok((SymbolRefTable::default(), bytes));
}
}
None => return Ok((SymbolRefTable::default(), bytes))
};
let symtab_len_len = marker.decode_log_length();
let (symtab_len, mut rest) = uint_from_parts(symtab_len_len, rest)?;
if symtab_len > rest.len() {
return corrupted(format_args!(
"symtab length is {} but only {} bytes of symbol data follow",
symtab_len,
rest.len(),
));
}
let mut symtab = SymbolRefTable {
symbols: Vec::with_capacity(symtab_len)
};
for _ in 0..symtab_len {
let (sym, next) = SymbolRef::parse(rest)?;
symtab.symbols.push(sym);
rest = next;
}
Ok((symtab, rest))
}
fn len(&self) -> usize {
self.symbols.len()
}
fn get_blob(&self, index: usize) -> Result<&'de [u8], Error> {
if let Some(symbol) = self.symbols.get(index) {
Ok(symbol.as_bytes())
} else {
corrupted(format_args!(
"borrowed blob #{} out of bounds for symtab of size {}",
index, self.len()
))
}
}
fn get_str(&self, index: usize) -> Result<&'de str, Error> {
if let Some(symbol) = self.symbols.get(index) {
symbol.try_as_str()
} else {
corrupted(format_args!(
"borrowed string #{} out of bounds for symtab of size {}",
index, self.len()
))
}
}
}
/// A serde deserializer over an in-memory byte slice.
///
/// Strings and blobs are borrowed from the slice rather than copied.
#[derive(Debug, Clone)]
pub struct BinarySliceDeserializer<'de> {
    /// Symbols parsed from the head of the input. Values in `body` refer to
    /// them by index.
    symbol_table: SymbolRefTable<'de>,
    /// The not-yet-consumed part of the input after the symbol table.
    body: &'de [u8],
}
impl<'de> BinarySliceDeserializer<'de> {
    /// Creates a deserializer over `bytes`, parsing the leading symbol table
    /// (if any) up front.
    ///
    /// # Errors
    ///
    /// Returns an error if a symbol table is present but malformed.
    pub fn new(bytes: &'de [u8]) -> Result<Self, Error> {
        SymbolRefTable::parse(bytes)
            .map(|(symbol_table, body)| BinarySliceDeserializer { symbol_table, body })
    }

    /// Consumes the deserializer and checks that the whole input was read.
    ///
    /// # Errors
    ///
    /// Returns an error if any bytes remain in the body.
    pub fn finalize(self) -> Result<(), Error> {
        let leftover = self.body.len();
        if leftover == 0 {
            return Ok(());
        }
        corrupted(format_args!("junk of length {} after data", leftover))
    }

    /// Removes and returns the first byte of the body.
    /// Fails with `error_msg` if the body is empty.
    fn eat_byte<T: Display>(&mut self, error_msg: T) -> Result<u8, Error> {
        let (&first, tail) = match self.body.split_first() {
            Some(split) => split,
            None => return corrupted(error_msg),
        };
        self.body = tail;
        Ok(first)
    }

    /// Removes and returns the next `length` bytes of the body.
    /// Fails with `error_msg` if fewer than `length` bytes remain.
    fn eat_slice<T: Display>(
        &mut self,
        length: usize,
        error_msg: T,
    ) -> Result<&'de [u8], Error> {
        if length > self.body.len() {
            return corrupted(error_msg);
        }
        let (head, tail) = self.body.split_at(length);
        self.body = tail;
        Ok(head)
    }

    /// Reads the next value header from the body.
    ///
    /// `exp` describes what the caller wanted and only appears in the error
    /// raised when the body is empty. The callback lets the free
    /// `read_value_header` helper take extra bytes from the body
    /// (presumably for multi-byte headers).
    fn read_value_header(&mut self, exp: &dyn Expected) -> Result<ValueHeader, Error> {
        let tag = self.eat_byte(format_args!("missing value; expected {}", exp))?;
        read_value_header(tag, |len| {
            self.eat_slice(len, "unexpected end of input in big value")
        })
    }

    /// Reads a header and passes it, together with `visitor`, to `visit_number`.
    fn deserialize_number<V: Visitor<'de>>(
        &mut self,
        visitor: V,
    ) -> Result<V::Value, Error> {
        let header = self.read_value_header(&visitor)?;
        visit_number(header, visitor)
    }

    /// Lets `visitor` read up to `count` array elements, then skips any it
    /// left unread so the body stays aligned on the next value.
    fn visit_and_exhaust_seq<V: Visitor<'de>>(
        &mut self,
        count: usize,
        visitor: V,
    ) -> Result<V::Value, Error> {
        let mut elements = SeqDeserializer::new(self, count)?;
        let visited = visitor.visit_seq(&mut elements)?;
        elements.exhaust().map(|()| visited)
    }

    /// Lets `visitor` read up to `count` map entries, then skips any it left
    /// unread so the body stays aligned on the next value.
    fn visit_and_exhaust_map<V: Visitor<'de>>(
        &mut self,
        count: usize,
        visitor: V,
    ) -> Result<V::Value, Error> {
        let mut entries = MapDeserializer::new(self, count)?;
        let visited = visitor.visit_map(&mut entries)?;
        entries.exhaust().map(|()| visited)
    }
}
/// Maps serde's data model onto the binary body of the input.
///
/// Non-empty strings and blobs are never copied. Their headers carry an index
/// into the symbol table, and visitors receive a slice borrowed from the input.
impl<'de> Deserializer<'de> for &mut BinarySliceDeserializer<'de> {
    type Error = Error;

    /// This format is binary, not human-readable.
    fn is_human_readable(&self) -> bool {
        false
    }

    /// Self-describing path: dispatches on whatever header comes next.
    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        use ValueHeader::*;
        match self.read_value_header(&visitor)? {
            Null => visitor.visit_unit(),
            // The wrapped value is read from the body right after this header.
            Opt => visitor.visit_some(self),
            Bool(b) => visitor.visit_bool(b),
            I8(x) => visitor.visit_i8(x),
            I16(x) => visitor.visit_i16(x),
            I32(x) => visitor.visit_i32(x),
            I64(x) => visitor.visit_i64(x),
            U8(x) => visitor.visit_u8(x),
            U16(x) => visitor.visit_u16(x),
            U32(x) => visitor.visit_u32(x),
            U64(x) => visitor.visit_u64(x),
            F32(x) => visitor.visit_f32(x.into()),
            F64(x) => visitor.visit_f64(x.into()),
            EmptyString => visitor.visit_borrowed_str(""),
            EmptyBlob => visitor.visit_borrowed_bytes(&[]),
            String(index) => {
                let string = self.symbol_table.get_str(index)?;
                visitor.visit_borrowed_str(string)
            },
            Blob(index) => {
                let bytes = self.symbol_table.get_blob(index)?;
                visitor.visit_borrowed_bytes(bytes)
            },
            Array(count) => self.visit_and_exhaust_seq(count, visitor),
            Map(count) => self.visit_and_exhaust_map(count, visitor),
        }
    }

    fn deserialize_bool<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        match self.read_value_header(&visitor)? {
            ValueHeader::Bool(b) => visitor.visit_bool(b),
            other => type_error(other, &visitor),
        }
    }

    // All numeric requests share one path: `deserialize_number` passes
    // whatever header is present to `visit_number`, so any stored numeric
    // width can satisfy any requested one (subject to the visitor).

    fn deserialize_i8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_i16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_i32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_i128<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_u8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_u16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_u32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_u128<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_f32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    fn deserialize_f64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_number(visitor)
    }

    /// Characters are stored as strings.
    fn deserialize_char<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_str(visitor)
    }

    /// Accepts string headers, and also blob headers whose bytes are valid
    /// UTF-8. An empty blob yields the empty string.
    fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        use ValueHeader::*;
        match self.read_value_header(&visitor)? {
            EmptyString | EmptyBlob => visitor.visit_borrowed_str(""),
            String(index) => {
                let string = self.symbol_table.get_str(index)?;
                visitor.visit_borrowed_str(string)
            },
            Blob(index) => {
                let bytes = self.symbol_table.get_blob(index)?;
                let string = str::from_utf8(bytes)?;
                visitor.visit_borrowed_str(string)
            },
            other => type_error(other, &visitor),
        }
    }

    fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_str(visitor)
    }

    /// Accepts blobs, strings (passed to the visitor as `str`), and arrays
    /// (passed as a sequence, e.g. for byte buffers encoded element-wise).
    fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        use ValueHeader::*;
        match self.read_value_header(&visitor)? {
            EmptyString => visitor.visit_borrowed_str(""),
            EmptyBlob => visitor.visit_borrowed_bytes(&[]),
            String(index) => {
                let string = self.symbol_table.get_str(index)?;
                visitor.visit_borrowed_str(string)
            },
            Blob(index) => {
                let bytes = self.symbol_table.get_blob(index)?;
                visitor.visit_borrowed_bytes(bytes)
            },
            Array(count) => self.visit_and_exhaust_seq(count, visitor),
            other => type_error(other, &visitor),
        }
    }

    fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_bytes(visitor)
    }

    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        let value = self.read_value_header(&visitor)?;
        visit_option(value, self, visitor)
    }

    fn deserialize_unit<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        visit_unit(self.read_value_header(&visitor)?, visitor)
    }

    fn deserialize_unit_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> Result<V::Value, Self::Error> {
        self.deserialize_unit(visitor)
    }

    /// Newtype structs are transparent: the inner value is read directly.
    fn deserialize_newtype_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> Result<V::Value, Self::Error> {
        visitor.visit_newtype_struct(self)
    }

    fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        use ValueHeader::Array;
        match self.read_value_header(&visitor)? {
            Array(count) => self.visit_and_exhaust_seq(count, visitor),
            other => type_error(other, &visitor),
        }
    }

    /// Tuples are arrays. The declared length is not checked here; extra
    /// elements are skipped after the visitor finishes.
    fn deserialize_tuple<V: Visitor<'de>>(
        self,
        _len: usize,
        visitor: V,
    ) -> Result<V::Value, Self::Error> {
        self.deserialize_seq(visitor)
    }

    fn deserialize_tuple_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        len: usize,
        visitor: V,
    ) -> Result<V::Value, Self::Error> {
        self.deserialize_tuple(len, visitor)
    }

    fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        use ValueHeader::Map;
        match self.read_value_header(&visitor)? {
            Map(count) => self.visit_and_exhaust_map(count, visitor),
            other => type_error(other, &visitor),
        }
    }

    /// Structs are maps keyed by field identifier.
    fn deserialize_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        _fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error> {
        self.deserialize_map(visitor)
    }

    fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_str(visitor)
    }

    /// Accepts an enum in one of two forms:
    /// - a bare string, which `BorrowedStrDeserializer` can only decode as a
    ///   unit variant;
    /// - a map with exactly one entry, whose key selects the variant and whose
    ///   value is the payload.
    fn deserialize_enum<V: Visitor<'de>>(
        self,
        _type_name: &'static str,
        _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error> {
        match self.read_value_header(&visitor)? {
            ValueHeader::String(index) => {
                let string = self.symbol_table.get_str(index)?;
                let deserializer = BorrowedStrDeserializer::new(string);
                visitor.visit_enum(deserializer)
            },
            ValueHeader::EmptyString => {
                let deserializer = BorrowedStrDeserializer::new("");
                visitor.visit_enum(deserializer)
            },
            ValueHeader::Map(count) => {
                if count == 1 {
                    visitor.visit_enum(self)
                } else {
                    Err(Error::invalid_length(count, &"enum as single-key map"))
                }
            },
            other => type_error(other, &"enum as string or single-key map"),
        }
    }

    /// Skips one complete value of any shape, then reports unit to `visitor`.
    fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
        self.deserialize_any(IgnoredAny).and_then(|_| visitor.visit_unit())
    }
}
/// Enum access for the single-entry-map encoding: `deserialize_enum` hands
/// the deserializer itself over once it has seen a one-entry map header.
impl<'de> EnumAccess<'de> for &mut BinarySliceDeserializer<'de> {
    type Error = Error;
    type Variant = Self;

    /// Reads the map key that selects the variant. The same deserializer is
    /// then returned to read the payload that follows.
    fn variant_seed<V: DeserializeSeed<'de>>(
        self,
        seed: V
    ) -> Result<(V::Value, Self::Variant), Self::Error> {
        // Reborrow for the key so `self` can still be returned afterwards.
        let variant = seed.deserialize(&mut *self)?;
        Ok((variant, self))
    }
}
impl<'de> VariantAccess<'de> for &mut BinarySliceDeserializer<'de> {
type Error = <Self as EnumAccess<'de>>::Error;
fn unit_variant(self) -> Result<(), Self::Error> {
Deserialize::deserialize(self)
}
fn newtype_variant_seed<T: DeserializeSeed<'de>>(
self,
seed: T,
) -> Result<T::Value, Self::Error> {
seed.deserialize(self)
}
fn tuple_variant<V: Visitor<'de>>(
self,
len: usize,
visitor: V,
) -> Result<V::Value, Self::Error> {
self.deserialize_tuple(len, visitor)
}
fn struct_variant<V: Visitor<'de>>(
self,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error> {
self.deserialize_map(visitor)
}
}
/// `SeqAccess` over a fixed number of elements, read in place from the
/// underlying deserializer's body.
#[derive(Debug)]
struct SeqDeserializer<'a, 'de: 'a> {
    /// The deserializer the elements are read from.
    deserializer: &'a mut BinarySliceDeserializer<'de>,
    /// Number of elements not yet handed out.
    remaining: usize,
}
impl<'a, 'de> SeqDeserializer<'a, 'de> {
fn new(de: &'a mut BinarySliceDeserializer<'de>, count: usize) -> Result<Self, Error> {
let bytes_remaining = de.body.len();
if count <= bytes_remaining {
Ok(SeqDeserializer {
deserializer: de,
remaining: count,
})
} else {
corrupted(format_args!(
"sequence count is {} but only {} bytes are remaining",
count,
bytes_remaining,
))
}
}
fn exhaust(&mut self) -> Result<(), Error> {
while let Some(IgnoredAny) = self.next_element()? {}
Ok(())
}
}
impl<'a, 'de> SeqAccess<'de> for SeqDeserializer<'a, 'de> {
    type Error = Error;

    /// Deserializes the next element, or returns `None` once all have been read.
    fn next_element_seed<T: DeserializeSeed<'de>>(
        &mut self,
        seed: T,
    ) -> Result<Option<T::Value>, Self::Error> {
        match self.remaining.checked_sub(1) {
            None => Ok(None),
            Some(left) => {
                self.remaining = left;
                seed.deserialize(&mut *self.deserializer).map(Some)
            }
        }
    }

    /// The exact number of elements still to be read.
    fn size_hint(&self) -> Option<usize> {
        Some(self.remaining)
    }
}
/// `MapAccess` over a fixed number of key/value entries, read in place from
/// the underlying deserializer's body.
#[derive(Debug)]
struct MapDeserializer<'a, 'de: 'a> {
    /// The deserializer the entries are read from.
    deserializer: &'a mut BinarySliceDeserializer<'de>,
    /// Number of keys not yet handed out.
    remaining: usize,
}
impl<'a, 'de> MapDeserializer<'a, 'de> {
fn new(de: &'a mut BinarySliceDeserializer<'de>, count: usize) -> Result<Self, Error> {
let bytes_remaining = de.body.len();
if count <= bytes_remaining / 2 {
Ok(MapDeserializer {
deserializer: de,
remaining: count,
})
} else {
corrupted(format_args!(
"key-value count is 2 * {} but only {} bytes are remaining",
count,
bytes_remaining,
))
}
}
fn exhaust(&mut self) -> Result<(), Error> {
while let Some((IgnoredAny, IgnoredAny)) = self.next_entry()? {}
Ok(())
}
}
impl<'a, 'de> MapAccess<'de> for MapDeserializer<'a, 'de> {
    type Error = Error;

    /// Deserializes the next key, or returns `None` once all entries have been read.
    fn next_key_seed<K: DeserializeSeed<'de>>(
        &mut self,
        seed: K,
    ) -> Result<Option<K::Value>, Self::Error> {
        match self.remaining.checked_sub(1) {
            None => Ok(None),
            Some(left) => {
                self.remaining = left;
                seed.deserialize(&mut *self.deserializer).map(Some)
            }
        }
    }

    /// Deserializes the value belonging to the key most recently returned.
    fn next_value_seed<V: DeserializeSeed<'de>>(
        &mut self,
        seed: V,
    ) -> Result<V::Value, Self::Error> {
        let deserializer = &mut *self.deserializer;
        seed.deserialize(deserializer)
    }

    /// The exact number of entries still to be read.
    fn size_hint(&self) -> Option<usize> {
        Some(self.remaining)
    }
}