bramble-data 0.1.1

Bramble's Binary Data Format
Documentation
use crate::{constants::*, Error, Result};
use serde::{
    de::{
        self, DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess,
        Visitor,
    },
    Deserialize,
};
use std::{convert::TryFrom, io, io::Read, marker::PhantomData, slice};

/// Deserializes a value of type `T` from a byte slice containing BDF data.
///
/// The slice is used as a reader and handed to [`from_reader`], so the same
/// trailing-data check applies.
pub fn from_slice<'de, T>(slice: &'de [u8]) -> Result<T>
where
    T: Deserialize<'de>,
{
    // A byte slice is itself a `Read` implementation.
    let reader: &'de [u8] = slice;
    from_reader(reader)
}

/// Deserializes a value of type `T` from BDF data supplied by a [`Read`](std::io::Read).
///
/// After the value has been decoded, [`Deserializer::end`] is called, so any
/// bytes left in the reader are reported as an error.
pub fn from_reader<'de, R, T>(reader: R) -> Result<T>
where
    T: Deserialize<'de>,
    R: Read,
{
    let mut deserializer = Deserializer::new(reader);
    let parsed = T::deserialize(&mut deserializer)?;
    // Only hand the value back once the input is known to be exhausted.
    deserializer.end().map(|()| parsed)
}

/// A deserializer that parses data as BDF.
///
/// Bytes are pulled on demand from the wrapped [`Read`](std::io::Read); at most
/// one discriminator byte is buffered so that it can be peeked at.
pub struct Deserializer<R>
where
    R: Read,
{
    // Source of the encoded bytes.
    reader: R,
    // A discriminator that has been read from `reader` by a peek but not yet
    // consumed, already split into `(type, remaining bits)`.
    last_discriminator: Option<(u8, u8)>,
}

impl<R> Deserializer<R>
where
    R: Read,
{
    /// Creates a new deserializer from a reader.
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            last_discriminator: None,
        }
    }

    /// Finishes deserializing BDF.
    ///
    /// This should be called after deserializing to ensure that there is no trailing data.
    pub fn end(&mut self) -> Result<()> {
        match self.read_discriminator() {
            Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(()),
            _ => Err(Error::TrailingBytes),
        }
    }

    /// Creates a deserializer that deserializes multiple values from a single BDF stream.
    #[allow(clippy::should_implement_trait)]
    pub fn into_iter<'de, T>(self) -> StreamDeserializer<'de, R, T>
    where
        T: Deserialize<'de>,
    {
        StreamDeserializer {
            de: self,
            failed: false,
            output: PhantomData,
            lifetime: PhantomData,
        }
    }

    fn read_discriminator(&mut self) -> Result<(u8, u8)> {
        let mut d = 0;
        self.reader.read_exact(slice::from_mut(&mut d))?;
        Ok((d & TYPE_MASK, d & !TYPE_MASK))
    }

    fn peek_discriminator(&mut self) -> Result<(u8, u8)> {
        if self.last_discriminator.is_none() {
            self.last_discriminator = Some(self.read_discriminator()?);
        }
        Ok(self.last_discriminator.unwrap())
    }

    fn consume_discriminator(&mut self) -> Result<(u8, u8)> {
        self.last_discriminator
            .take()
            .map(Result::Ok)
            .unwrap_or_else(|| self.read_discriminator())
    }

    fn read_i64(&mut self, len: usize) -> Result<i64> {
        let mut buf = [0u8; 8];
        let start = 8 - len;
        self.reader.read_exact(&mut buf[start..])?;
        if buf[start] & 0x80 != 0 {
            // sign-extend
            buf[0..start].fill(0xFF);
        }
        Ok(i64::from_be_bytes(buf))
    }

    fn read_null(&mut self) -> Result<()> {
        let (typ, bits) = self.consume_discriminator()?;
        if typ != TYPE_NULL {
            return Err(Error::WrongType);
        }
        if bits != 0 {
            return Err(Error::InvalidValue);
        }
        Ok(())
    }

    fn read_boolean(&mut self) -> Result<bool> {
        let (typ, bits) = self.consume_discriminator()?;
        if typ != TYPE_BOOLEAN {
            return Err(Error::WrongType);
        }
        if bits > 1 {
            return Err(Error::InvalidValue);
        }
        Ok(bits == 1)
    }

    fn read_integer(&mut self) -> Result<i64> {
        let (typ, len) = self.consume_discriminator()?;
        if typ != TYPE_INTEGER {
            return Err(Error::WrongType);
        }
        if !len.is_power_of_two() {
            return Err(Error::InvalidLength);
        }
        self.read_i64(len as usize)
    }

    fn read_float(&mut self) -> Result<f64> {
        let (typ, len) = self.consume_discriminator()?;
        if typ != TYPE_FLOAT {
            return Err(Error::WrongType);
        }
        if len != 8 {
            return Err(Error::InvalidLength);
        }
        let mut buf = [0u8; 8];
        self.reader.read_exact(&mut buf)?;
        Ok(f64::from_be_bytes(buf))
    }

    fn read_string(&mut self) -> Result<String> {
        let (typ, llen) = self.consume_discriminator()?;
        if typ != TYPE_STRING {
            return Err(Error::WrongType);
        }
        if !llen.is_power_of_two() {
            return Err(Error::InvalidLengthOfLength);
        }
        let len = self.read_i64(llen as usize)?;
        if len < 0 {
            return Err(Error::InvalidLength);
        }
        let mut s = String::with_capacity(len as usize);
        let read = (&mut self.reader).take(len as u64).read_to_string(&mut s)?;
        if read != len as usize {
            return Err(Error::eof());
        }
        Ok(s)
    }

    fn read_raw(&mut self) -> Result<Vec<u8>> {
        let (typ, llen) = self.consume_discriminator()?;
        if typ != TYPE_RAW {
            return Err(Error::WrongType);
        }
        if !llen.is_power_of_two() {
            return Err(Error::InvalidLengthOfLength);
        }
        let len = self.read_i64(llen as usize)?;
        if len < 0 {
            return Err(Error::InvalidLength);
        }
        let mut v = Vec::with_capacity(len as usize);
        let read = (&mut self.reader).take(len as u64).read_to_end(&mut v)?;
        if read != len as usize {
            return Err(Error::eof());
        }
        Ok(v)
    }

    fn read_list_start(&mut self) -> Result<()> {
        let (typ, bits) = self.consume_discriminator()?;
        if typ != TYPE_LIST {
            return Err(Error::WrongType);
        }
        if bits != 0 {
            return Err(Error::InvalidValue);
        }
        Ok(())
    }

    fn read_dictionary_start(&mut self) -> Result<()> {
        let (typ, bits) = self.consume_discriminator()?;
        if typ != TYPE_DICTIONARY {
            return Err(Error::WrongType);
        }
        if bits != 0 {
            return Err(Error::InvalidValue);
        }
        Ok(())
    }

    fn peek_end(&mut self) -> Result<bool> {
        let (typ, bits) = self.peek_discriminator()?;
        if typ != TYPE_END {
            return Ok(false);
        }
        if bits != 0 {
            return Err(Error::InvalidValue);
        }
        Ok(true)
    }

    fn read_end(&mut self) -> Result<()> {
        let (typ, bits) = self.consume_discriminator()?;
        if typ != TYPE_END {
            return Err(Error::WrongType);
        }
        if bits != 0 {
            return Err(Error::InvalidValue);
        }
        Ok(())
    }
}

impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer<R>
where
    R: Read,
{
    type Error = Error;

    /// Dispatches on the type of the next discriminator.
    ///
    /// Integers are always surfaced as `i64` and floats as `f64`; an end
    /// marker or an unknown type yields [`Error::InvalidType`].
    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let (typ, _) = self.peek_discriminator()?;

        match typ {
            TYPE_NULL => self.deserialize_unit(visitor),
            TYPE_BOOLEAN => self.deserialize_bool(visitor),
            TYPE_INTEGER => self.deserialize_i64(visitor),
            TYPE_FLOAT => self.deserialize_f64(visitor),
            TYPE_STRING => self.deserialize_str(visitor),
            TYPE_RAW => self.deserialize_bytes(visitor),
            TYPE_LIST => self.deserialize_seq(visitor),
            TYPE_DICTIONARY => self.deserialize_map(visitor),
            _ => Err(Error::InvalidType),
        }
    }

    /// Reads a boolean.
    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_boolean()?;
        visitor.visit_bool(value)
    }

    /// Reads an integer and fails if it does not fit in an `i8`.
    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_integer()?;
        visitor.visit_i8(i8::try_from(value)?)
    }

    /// Reads an integer and fails if it does not fit in an `i16`.
    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_integer()?;
        visitor.visit_i16(i16::try_from(value)?)
    }

    /// Reads an integer and fails if it does not fit in an `i32`.
    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_integer()?;
        visitor.visit_i32(i32::try_from(value)?)
    }

    /// Reads an integer.
    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_integer()?;
        visitor.visit_i64(value)
    }

    /// Reads an integer and fails if it does not fit in a `u8`.
    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_integer()?;
        visitor.visit_u8(u8::try_from(value)?)
    }

    /// Reads an integer and fails if it does not fit in a `u16`.
    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_integer()?;
        visitor.visit_u16(u16::try_from(value)?)
    }

    /// Reads an integer and fails if it does not fit in a `u32`.
    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_integer()?;
        visitor.visit_u32(u32::try_from(value)?)
    }

    /// Reads an integer and fails if it is negative.
    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_integer()?;
        visitor.visit_u64(u64::try_from(value)?)
    }

    /// Reads a float and narrows it to `f32` with an `as` cast.
    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_float()?;
        visitor.visit_f32(value as f32)
    }

    /// Reads a float.
    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_float()?;
        visitor.visit_f64(value)
    }

    /// Reads an integer and interprets it as a Unicode scalar value.
    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_integer()?;
        visitor.visit_char(char::try_from(u32::try_from(value)?)?)
    }

    /// Reads a string; the data is always handed over as an owned `String`.
    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_string(visitor)
    }

    /// Reads a string.
    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_string()?;
        visitor.visit_string(value)
    }

    /// Reads raw bytes; the data is always handed over as an owned `Vec<u8>`.
    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_byte_buf(visitor)
    }

    /// Reads raw bytes.
    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let value = self.read_raw()?;
        visitor.visit_byte_buf(value)
    }

    /// Reads an optional value: a null is `None`, any other value type is
    /// `Some` of that value.
    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let (typ, _) = self.peek_discriminator()?;
        match typ {
            TYPE_NULL => {
                // The peeked null must be consumed here. If it stayed cached,
                // it would be mistaken for the discriminator of whatever
                // value follows this option.
                self.read_null()?;
                visitor.visit_none()
            }
            TYPE_BOOLEAN | TYPE_INTEGER | TYPE_FLOAT | TYPE_STRING | TYPE_RAW | TYPE_LIST
            | TYPE_DICTIONARY => visitor.visit_some(self),
            _ => Err(Error::WrongType),
        }
    }

    /// Reads a null as the unit value.
    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.read_null()?;
        visitor.visit_unit()
    }

    /// Unit structs are read exactly like the unit value.
    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_unit(visitor)
    }

    /// Newtype structs are transparent: the wrapped value is read directly.
    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        visitor.visit_newtype_struct(self)
    }

    /// Reads a list: a start marker, the elements, then an end marker.
    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.read_list_start()?;
        let value = visitor.visit_seq(&mut *self)?;
        self.read_end()?;
        Ok(value)
    }

    /// Tuples are read as lists; the expected length is not checked here.
    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }

    /// Tuple structs are read as lists.
    fn deserialize_tuple_struct<V>(
        self,
        _name: &'static str,
        _len: usize,
        visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }

    /// Reads a dictionary: a start marker, the entries, then an end marker.
    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.read_dictionary_start()?;
        let value = visitor.visit_map(&mut *self)?;
        self.read_end()?;
        Ok(value)
    }

    /// Structs are read as a list of their field values rather than as a
    /// dictionary keyed by field name.
    fn deserialize_struct<V>(
        self,
        _name: &'static str,
        fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_tuple(fields.len(), visitor)
    }

    /// Hands the deserializer to the visitor as an `EnumAccess`.
    fn deserialize_enum<V>(
        self,
        _name: &'static str,
        _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        visitor.visit_enum(self)
    }

    /// Identifiers are read as strings.
    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_str(visitor)
    }

    /// Ignored values are still fully decoded, via `deserialize_any`.
    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.deserialize_any(visitor)
    }
}

// Element-by-element access to a list whose start marker has already been
// consumed. The closing end marker is only peeked at here; the caller that
// opened the list is responsible for consuming it.
impl<'de, R> SeqAccess<'de> for Deserializer<R>
where
    R: Read,
{
    type Error = Error;

    /// Returns the next element, or `None` once the next discriminator is an
    /// end marker.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: DeserializeSeed<'de>,
    {
        if self.peek_end()? {
            Ok(None)
        } else {
            seed.deserialize(self).map(Some)
        }
    }
}

// Entry-by-entry access to a dictionary whose start marker has already been
// consumed. Keys and values are simply consecutive values in the stream; the
// closing end marker is only peeked at here and is consumed by the caller.
impl<'de, R> MapAccess<'de> for Deserializer<R>
where
    R: Read,
{
    type Error = Error;

    /// Returns the next key, or `None` once the next discriminator is an end
    /// marker.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: DeserializeSeed<'de>,
    {
        if self.peek_end()? {
            Ok(None)
        } else {
            seed.deserialize(self).map(Some)
        }
    }

    /// Returns the value that follows the key just read.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: DeserializeSeed<'de>,
    {
        seed.deserialize(self)
    }
}

impl<'de, 'a, R> EnumAccess<'de> for &'a mut Deserializer<R>
where
    R: Read,
{
    type Error = Error;
    type Variant = Self;

    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
    where
        V: DeserializeSeed<'de>,
    {
        let (typ, _) = self.peek_discriminator()?;
        match typ {
            TYPE_INTEGER => {}
            TYPE_LIST => self.read_list_start()?,
            _ => return Err(Error::WrongType),
        }
        let variant_index = u32::try_from(self.read_integer()?)?;
        let value: Result<_> = seed.deserialize(variant_index.into_deserializer());
        Ok((value?, self))
    }
}

// Reads the payload of an enum variant after `variant_seed` has consumed the
// variant index (and, for list-wrapped variants, the opening list marker).
impl<'de, 'a, R> VariantAccess<'de> for &'a mut Deserializer<R>
where
    R: Read,
{
    type Error = Error;

    /// A unit variant has no payload, so nothing further is read.
    fn unit_variant(self) -> Result<()> {
        Ok(())
    }

    /// Reads the single payload value, then the end marker of the list that
    /// wraps the variant.
    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
    where
        T: DeserializeSeed<'de>,
    {
        let payload = seed.deserialize(&mut *self)?;
        self.read_end().map(|()| payload)
    }

    /// Reads the payload as a nested list of fields.
    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.read_list_start()?;
        let payload = visitor.visit_seq(&mut *self)?;
        // Two end markers follow: one closes the payload list, the other
        // closes the list that wraps the variant.
        for _ in 0..2 {
            self.read_end()?;
        }
        Ok(payload)
    }

    /// Struct variants use the same encoding as tuple variants: a nested list
    /// holding the field values.
    fn struct_variant<V>(self, fields: &'static [&'static str], visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.tuple_variant(fields.len(), visitor)
    }
}

/// Iterator that deserializes a stream into multiple BDF values
///
/// Created by [`Deserializer::into_iter`]. Each call to `next` decodes one
/// value of type `T` from the underlying reader.
pub struct StreamDeserializer<'de, R, T>
where
    R: Read,
    T: Deserialize<'de>,
{
    // The deserializer the values are pulled from.
    de: Deserializer<R>,
    // Set once deserialization has returned an error (including end of
    // input); from then on the iterator only yields `None`.
    failed: bool,
    // Zero-sized marker tying the iterator to its output type `T`.
    output: PhantomData<T>,
    // Zero-sized marker carrying the `'de` lifetime required by `T`'s bound.
    lifetime: PhantomData<&'de ()>,
}

impl<'de, R, T> Iterator for StreamDeserializer<'de, R, T>
where
    R: Read,
    T: Deserialize<'de>,
{
    type Item = Result<T>;

    /// Decodes the next value from the stream.
    ///
    /// Returns `None` when the input ends cleanly between two values, and
    /// `Some(Err(_))` for any other failure, including input that ends in the
    /// middle of a value. After the first `None` or error, every later call
    /// returns `None`.
    fn next(&mut self) -> Option<Result<T>> {
        if self.failed {
            return None;
        }

        // Look at the next discriminator before decoding. Running out of
        // input here means the stream is simply finished; running out of
        // input later means a value was cut short and must be reported.
        if let Err(e) = self.de.peek_discriminator() {
            self.failed = true;
            return if e.is_eof() { None } else { Some(Err(e)) };
        }

        let item = T::deserialize(&mut self.de);
        if item.is_err() {
            self.failed = true;
        }
        Some(item)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use hex_literal::hex;
    use serde::Deserialize;
    use std::collections::HashMap;

    /// A dictionary with string keys decodes into a `HashMap`.
    #[test]
    fn from_slice_maps() {
        let encoded = hex!("70 4103626172 2201C8 4103666F6F 217B 80");
        let decoded: HashMap<String, u32> = from_slice(&encoded).unwrap();

        let expected: HashMap<String, u32> = vec![("foo", 123u32), ("bar", 456u32)]
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect();

        assert_eq!(expected, decoded);
    }

    /// A struct decodes from a list holding its field values in order.
    #[test]
    fn from_slice_structs() {
        #[derive(Deserialize, Debug, PartialEq, Eq)]
        struct Test {
            x: bool,
            y: u32,
            z: Vec<String>,
        }

        let encoded = hex!("60 11 2111 60 4103666F6F 4103626172 80 80");
        let decoded: Test = from_slice(&encoded).unwrap();

        assert_eq!(
            Test {
                x: true,
                y: 17,
                z: vec![String::from("foo"), String::from("bar")],
            },
            decoded
        );
    }

    /// Every kind of enum variant decodes from its encoding.
    #[test]
    fn from_slice_enums() {
        #[derive(Deserialize, Debug, PartialEq, Eq)]
        enum Test {
            UnitVariant,
            NewTypeVariant(u32),
            TupleVariant(bool, u32),
            StructVariant { x: bool, y: u32 },
        }

        // A unit variant is just its index.
        let unit = hex!("2100");
        let decoded: Test = from_slice(&unit).unwrap();
        assert_eq!(Test::UnitVariant, decoded);

        // A newtype variant is a list of the index and the payload.
        let newtype = hex!("60 2101 2111 80");
        let decoded: Test = from_slice(&newtype).unwrap();
        assert_eq!(Test::NewTypeVariant(17), decoded);

        // Tuple and struct variants carry their fields in a nested list.
        let tuple = hex!("60 2102 60 11 2111 80 80");
        let decoded: Test = from_slice(&tuple).unwrap();
        assert_eq!(Test::TupleVariant(true, 17), decoded);

        let structure = hex!("60 2103 60 11 2111 80 80");
        let decoded: Test = from_slice(&structure).unwrap();
        assert_eq!(Test::StructVariant { x: true, y: 17 }, decoded);
    }

    /// A null decodes to `None`; any other value decodes to `Some`.
    #[test]
    fn from_slice_options() {
        let null = hex!("00");
        let decoded: Option<u32> = from_slice(&null).unwrap();
        assert_eq!(None, decoded);

        let present = hex!("2111");
        let decoded = from_slice(&present).unwrap();
        assert_eq!(Some(17), decoded);
    }

    /// The stream iterator yields values until the first error or the end of input.
    #[test]
    fn stream_deserializer() {
        // The stray end marker in third position stops the iteration with an error.
        let broken = hex!("2100 2101 80 2103");
        let items: Vec<Result<u64>> = Deserializer::new(&broken[..]).into_iter().collect();
        assert_eq!(items.len(), 3);
        assert!(matches!(items[0], Ok(0)));
        assert!(matches!(items[1], Ok(1)));
        assert!(matches!(items[2], Err(Error::WrongType)));

        let clean = hex!("2100 2101 2102 2103");
        let values = Deserializer::new(&clean[..])
            .into_iter()
            .collect::<Result<Vec<u64>>>()
            .unwrap();
        assert_eq!(vec![0, 1, 2, 3], values);
    }
}

// TODO: enforce read buffer size limits for DoS protection
// TODO: support zero-copy reads
// TODO: enforce a nesting limit