bramble_data/
de.rs

1use crate::{constants::*, Error, Result};
2use serde::{
3    de::{
4        self, DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess,
5        Visitor,
6    },
7    Deserialize,
8};
9use std::{convert::TryFrom, io, io::Read, marker::PhantomData, slice};
10
11/// Deserializes a value from BDF data in a slice.
12pub fn from_slice<'de, T>(slice: &'de [u8]) -> Result<T>
13where
14    T: Deserialize<'de>,
15{
16    from_reader(slice)
17}
18
19/// Deserializes a value from BDF data in a [`Read`](std::io::Read).
20pub fn from_reader<'de, R, T>(reader: R) -> Result<T>
21where
22    T: Deserialize<'de>,
23    R: Read,
24{
25    let mut de = Deserializer::new(reader);
26    let value = Deserialize::deserialize(&mut de)?;
27    de.end()?;
28    Ok(value)
29}
30
/// A deserializer that parses data as BDF.
pub struct Deserializer<R: Read> {
    /// Source of the encoded bytes.
    reader: R,
    /// Discriminator already read by a peek but not yet consumed, stored as
    /// `(type bits, remaining low bits)`.
    last_discriminator: Option<(u8, u8)>,
}
39
40impl<R> Deserializer<R>
41where
42    R: Read,
43{
44    /// Creates a new deserializer from a reader.
45    pub fn new(reader: R) -> Self {
46        Self {
47            reader,
48            last_discriminator: None,
49        }
50    }
51
52    /// Finishes deserializing BDF.
53    ///
54    /// This should be called after deserializing to ensure that there is no trailing data.
55    pub fn end(&mut self) -> Result<()> {
56        match self.read_discriminator() {
57            Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(()),
58            _ => Err(Error::TrailingBytes),
59        }
60    }
61
62    /// Creates a deserializer that deserializes multiple values from a single BDF stream.
63    #[allow(clippy::should_implement_trait)]
64    pub fn into_iter<'de, T>(self) -> StreamDeserializer<'de, R, T>
65    where
66        T: Deserialize<'de>,
67    {
68        StreamDeserializer {
69            de: self,
70            failed: false,
71            output: PhantomData,
72            lifetime: PhantomData,
73        }
74    }
75
76    fn read_discriminator(&mut self) -> Result<(u8, u8)> {
77        let mut d = 0;
78        self.reader.read_exact(slice::from_mut(&mut d))?;
79        Ok((d & TYPE_MASK, d & !TYPE_MASK))
80    }
81
82    fn peek_discriminator(&mut self) -> Result<(u8, u8)> {
83        if self.last_discriminator.is_none() {
84            self.last_discriminator = Some(self.read_discriminator()?);
85        }
86        Ok(self.last_discriminator.unwrap())
87    }
88
89    fn consume_discriminator(&mut self) -> Result<(u8, u8)> {
90        self.last_discriminator
91            .take()
92            .map(Result::Ok)
93            .unwrap_or_else(|| self.read_discriminator())
94    }
95
96    fn read_i64(&mut self, len: usize) -> Result<i64> {
97        let mut buf = [0u8; 8];
98        let start = 8 - len;
99        self.reader.read_exact(&mut buf[start..])?;
100        if buf[start] & 0x80 != 0 {
101            // sign-extend
102            buf[0..start].fill(0xFF);
103        }
104        Ok(i64::from_be_bytes(buf))
105    }
106
107    fn read_null(&mut self) -> Result<()> {
108        let (typ, bits) = self.consume_discriminator()?;
109        if typ != TYPE_NULL {
110            return Err(Error::WrongType);
111        }
112        if bits != 0 {
113            return Err(Error::InvalidValue);
114        }
115        Ok(())
116    }
117
118    fn read_boolean(&mut self) -> Result<bool> {
119        let (typ, bits) = self.consume_discriminator()?;
120        if typ != TYPE_BOOLEAN {
121            return Err(Error::WrongType);
122        }
123        if bits > 1 {
124            return Err(Error::InvalidValue);
125        }
126        Ok(bits == 1)
127    }
128
129    fn read_integer(&mut self) -> Result<i64> {
130        let (typ, len) = self.consume_discriminator()?;
131        if typ != TYPE_INTEGER {
132            return Err(Error::WrongType);
133        }
134        if !len.is_power_of_two() {
135            return Err(Error::InvalidLength);
136        }
137        self.read_i64(len as usize)
138    }
139
140    fn read_float(&mut self) -> Result<f64> {
141        let (typ, len) = self.consume_discriminator()?;
142        if typ != TYPE_FLOAT {
143            return Err(Error::WrongType);
144        }
145        if len != 8 {
146            return Err(Error::InvalidLength);
147        }
148        let mut buf = [0u8; 8];
149        self.reader.read_exact(&mut buf)?;
150        Ok(f64::from_be_bytes(buf))
151    }
152
153    fn read_string(&mut self) -> Result<String> {
154        let (typ, llen) = self.consume_discriminator()?;
155        if typ != TYPE_STRING {
156            return Err(Error::WrongType);
157        }
158        if !llen.is_power_of_two() {
159            return Err(Error::InvalidLengthOfLength);
160        }
161        let len = self.read_i64(llen as usize)?;
162        if len < 0 {
163            return Err(Error::InvalidLength);
164        }
165        let mut s = String::with_capacity(len as usize);
166        let read = (&mut self.reader).take(len as u64).read_to_string(&mut s)?;
167        if read != len as usize {
168            return Err(Error::eof());
169        }
170        Ok(s)
171    }
172
173    fn read_raw(&mut self) -> Result<Vec<u8>> {
174        let (typ, llen) = self.consume_discriminator()?;
175        if typ != TYPE_RAW {
176            return Err(Error::WrongType);
177        }
178        if !llen.is_power_of_two() {
179            return Err(Error::InvalidLengthOfLength);
180        }
181        let len = self.read_i64(llen as usize)?;
182        if len < 0 {
183            return Err(Error::InvalidLength);
184        }
185        let mut v = Vec::with_capacity(len as usize);
186        let read = (&mut self.reader).take(len as u64).read_to_end(&mut v)?;
187        if read != len as usize {
188            return Err(Error::eof());
189        }
190        Ok(v)
191    }
192
193    fn read_list_start(&mut self) -> Result<()> {
194        let (typ, bits) = self.consume_discriminator()?;
195        if typ != TYPE_LIST {
196            return Err(Error::WrongType);
197        }
198        if bits != 0 {
199            return Err(Error::InvalidValue);
200        }
201        Ok(())
202    }
203
204    fn read_dictionary_start(&mut self) -> Result<()> {
205        let (typ, bits) = self.consume_discriminator()?;
206        if typ != TYPE_DICTIONARY {
207            return Err(Error::WrongType);
208        }
209        if bits != 0 {
210            return Err(Error::InvalidValue);
211        }
212        Ok(())
213    }
214
215    fn peek_end(&mut self) -> Result<bool> {
216        let (typ, bits) = self.peek_discriminator()?;
217        if typ != TYPE_END {
218            return Ok(false);
219        }
220        if bits != 0 {
221            return Err(Error::InvalidValue);
222        }
223        Ok(true)
224    }
225
226    fn read_end(&mut self) -> Result<()> {
227        let (typ, bits) = self.consume_discriminator()?;
228        if typ != TYPE_END {
229            return Err(Error::WrongType);
230        }
231        if bits != 0 {
232            return Err(Error::InvalidValue);
233        }
234        Ok(())
235    }
236}
237
238impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer<R>
239where
240    R: Read,
241{
242    type Error = Error;
243
244    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
245    where
246        V: Visitor<'de>,
247    {
248        let (typ, _) = self.peek_discriminator()?;
249
250        match typ {
251            TYPE_NULL => self.deserialize_unit(visitor),
252            TYPE_BOOLEAN => self.deserialize_bool(visitor),
253            TYPE_INTEGER => self.deserialize_i64(visitor),
254            TYPE_FLOAT => self.deserialize_f64(visitor),
255            TYPE_STRING => self.deserialize_str(visitor),
256            TYPE_RAW => self.deserialize_bytes(visitor),
257            TYPE_LIST => self.deserialize_seq(visitor),
258            TYPE_DICTIONARY => self.deserialize_map(visitor),
259            _ => Err(Error::InvalidType),
260        }
261    }
262
263    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
264    where
265        V: Visitor<'de>,
266    {
267        let value = self.read_boolean()?;
268        visitor.visit_bool(value)
269    }
270
271    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
272    where
273        V: Visitor<'de>,
274    {
275        let value = self.read_integer()?;
276        visitor.visit_i8(i8::try_from(value)?)
277    }
278
279    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
280    where
281        V: Visitor<'de>,
282    {
283        let value = self.read_integer()?;
284        visitor.visit_i16(i16::try_from(value)?)
285    }
286
287    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
288    where
289        V: Visitor<'de>,
290    {
291        let value = self.read_integer()?;
292        visitor.visit_i32(i32::try_from(value)?)
293    }
294
295    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
296    where
297        V: Visitor<'de>,
298    {
299        let value = self.read_integer()?;
300        visitor.visit_i64(value)
301    }
302
303    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
304    where
305        V: Visitor<'de>,
306    {
307        let value = self.read_integer()?;
308        visitor.visit_u8(u8::try_from(value)?)
309    }
310
311    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
312    where
313        V: Visitor<'de>,
314    {
315        let value = self.read_integer()?;
316        visitor.visit_u16(u16::try_from(value)?)
317    }
318
319    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
320    where
321        V: Visitor<'de>,
322    {
323        let value = self.read_integer()?;
324        visitor.visit_u32(u32::try_from(value)?)
325    }
326
327    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
328    where
329        V: Visitor<'de>,
330    {
331        let value = self.read_integer()?;
332        visitor.visit_u64(u64::try_from(value)?)
333    }
334
335    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
336    where
337        V: Visitor<'de>,
338    {
339        let value = self.read_float()?;
340        visitor.visit_f32(value as f32)
341    }
342
343    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
344    where
345        V: Visitor<'de>,
346    {
347        let value = self.read_float()?;
348        visitor.visit_f64(value)
349    }
350
351    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
352    where
353        V: Visitor<'de>,
354    {
355        let value = self.read_integer()?;
356        visitor.visit_char(char::try_from(u32::try_from(value)?)?)
357    }
358
359    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
360    where
361        V: Visitor<'de>,
362    {
363        self.deserialize_string(visitor)
364    }
365
366    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
367    where
368        V: Visitor<'de>,
369    {
370        let value = self.read_string()?;
371        visitor.visit_string(value)
372    }
373
374    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
375    where
376        V: Visitor<'de>,
377    {
378        self.deserialize_byte_buf(visitor)
379    }
380
381    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
382    where
383        V: Visitor<'de>,
384    {
385        let value = self.read_raw()?;
386        visitor.visit_byte_buf(value)
387    }
388
389    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
390    where
391        V: Visitor<'de>,
392    {
393        let (typ, _) = self.peek_discriminator()?;
394        match typ {
395            TYPE_NULL => visitor.visit_none(),
396            TYPE_BOOLEAN | TYPE_INTEGER | TYPE_FLOAT | TYPE_STRING | TYPE_RAW | TYPE_LIST
397            | TYPE_DICTIONARY => visitor.visit_some(self),
398            _ => Err(Error::WrongType),
399        }
400    }
401
402    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
403    where
404        V: Visitor<'de>,
405    {
406        self.read_null()?;
407        visitor.visit_unit()
408    }
409
410    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
411    where
412        V: Visitor<'de>,
413    {
414        self.deserialize_unit(visitor)
415    }
416
417    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
418    where
419        V: Visitor<'de>,
420    {
421        visitor.visit_newtype_struct(self)
422    }
423
424    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
425    where
426        V: Visitor<'de>,
427    {
428        self.read_list_start()?;
429        let value = visitor.visit_seq(&mut *self)?;
430        self.read_end()?;
431        Ok(value)
432    }
433
434    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
435    where
436        V: Visitor<'de>,
437    {
438        self.deserialize_seq(visitor)
439    }
440
441    fn deserialize_tuple_struct<V>(
442        self,
443        _name: &'static str,
444        _len: usize,
445        visitor: V,
446    ) -> Result<V::Value>
447    where
448        V: Visitor<'de>,
449    {
450        self.deserialize_seq(visitor)
451    }
452
453    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
454    where
455        V: Visitor<'de>,
456    {
457        self.read_dictionary_start()?;
458        let value = visitor.visit_map(&mut *self)?;
459        self.read_end()?;
460        Ok(value)
461    }
462
463    fn deserialize_struct<V>(
464        self,
465        _name: &'static str,
466        fields: &'static [&'static str],
467        visitor: V,
468    ) -> Result<V::Value>
469    where
470        V: Visitor<'de>,
471    {
472        self.deserialize_tuple(fields.len(), visitor)
473    }
474
475    fn deserialize_enum<V>(
476        self,
477        _name: &'static str,
478        _variants: &'static [&'static str],
479        visitor: V,
480    ) -> Result<V::Value>
481    where
482        V: Visitor<'de>,
483    {
484        visitor.visit_enum(self)
485    }
486
487    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
488    where
489        V: Visitor<'de>,
490    {
491        self.deserialize_str(visitor)
492    }
493
494    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
495    where
496        V: Visitor<'de>,
497    {
498        self.deserialize_any(visitor)
499    }
500}
501
502impl<'de, 'a, R> SeqAccess<'de> for Deserializer<R>
503where
504    R: Read,
505{
506    type Error = Error;
507
508    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
509    where
510        T: DeserializeSeed<'de>,
511    {
512        if self.peek_end()? {
513            return Ok(None);
514        }
515
516        seed.deserialize(self).map(Some)
517    }
518}
519
520impl<'de, 'a, R> MapAccess<'de> for Deserializer<R>
521where
522    R: Read,
523{
524    type Error = Error;
525
526    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
527    where
528        K: DeserializeSeed<'de>,
529    {
530        if self.peek_end()? {
531            return Ok(None);
532        }
533
534        seed.deserialize(self).map(Some)
535    }
536
537    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
538    where
539        V: DeserializeSeed<'de>,
540    {
541        seed.deserialize(self)
542    }
543}
544
545impl<'de, 'a, R> EnumAccess<'de> for &'a mut Deserializer<R>
546where
547    R: Read,
548{
549    type Error = Error;
550    type Variant = Self;
551
552    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
553    where
554        V: DeserializeSeed<'de>,
555    {
556        let (typ, _) = self.peek_discriminator()?;
557        match typ {
558            TYPE_INTEGER => {}
559            TYPE_LIST => self.read_list_start()?,
560            _ => return Err(Error::WrongType),
561        }
562        let variant_index = u32::try_from(self.read_integer()?)?;
563        let value: Result<_> = seed.deserialize(variant_index.into_deserializer());
564        Ok((value?, self))
565    }
566}
567
568impl<'de, 'a, R> VariantAccess<'de> for &'a mut Deserializer<R>
569where
570    R: Read,
571{
572    type Error = Error;
573
574    fn unit_variant(self) -> Result<()> {
575        Ok(())
576    }
577
578    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
579    where
580        T: DeserializeSeed<'de>,
581    {
582        let value = seed.deserialize(&mut *self)?;
583        self.read_end()?;
584        Ok(value)
585    }
586
587    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
588    where
589        V: Visitor<'de>,
590    {
591        self.read_list_start()?;
592        let value = visitor.visit_seq(&mut *self)?;
593        self.read_end()?;
594        self.read_end()?;
595        Ok(value)
596    }
597
598    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
599    where
600        V: Visitor<'de>,
601    {
602        self.read_list_start()?;
603        let value = visitor.visit_seq(&mut *self)?;
604        self.read_end()?;
605        self.read_end()?;
606        Ok(value)
607    }
608}
609
610/// Iterator that deserializes a stream into multiple BDF values
611pub struct StreamDeserializer<'de, R, T>
612where
613    R: Read,
614    T: Deserialize<'de>,
615{
616    de: Deserializer<R>,
617    failed: bool,
618    output: PhantomData<T>,
619    lifetime: PhantomData<&'de ()>,
620}
621
622impl<'de, R, T> Iterator for StreamDeserializer<'de, R, T>
623where
624    R: Read,
625    T: Deserialize<'de>,
626{
627    type Item = Result<T>;
628
629    fn next(&mut self) -> Option<Result<T>> {
630        if self.failed {
631            return None;
632        }
633
634        match Deserialize::deserialize(&mut self.de) {
635            Err(e) => {
636                self.failed = true;
637                if e.is_eof() {
638                    None
639                } else {
640                    Some(Err(e))
641                }
642            }
643            ok => Some(ok),
644        }
645    }
646}
647
#[cfg(test)]
mod test {
    use super::*;
    use hex_literal::hex;
    use serde::Deserialize;
    use std::collections::HashMap;

    #[test]
    fn from_slice_maps() {
        // dictionary { "bar": 456, "foo": 123 }
        let encoded = hex!("70 4103626172 2201C8 4103666F6F 217B 80");
        let expected: HashMap<String, u32> = [("foo", 123), ("bar", 456)]
            .iter()
            .map(|&(key, value)| (key.to_string(), value))
            .collect();

        let decoded: HashMap<String, u32> = from_slice(&encoded).unwrap();
        assert_eq!(expected, decoded);
    }

    #[test]
    fn from_slice_structs() {
        #[derive(Deserialize, Debug, PartialEq, Eq)]
        struct Test {
            x: bool,
            y: u32,
            z: Vec<String>,
        }

        // Struct fields are encoded positionally inside a list.
        let encoded = hex!("60 11 2111 60 4103666F6F 4103626172 80 80");
        let decoded: Test = from_slice(&encoded).unwrap();
        assert_eq!(
            Test {
                x: true,
                y: 17,
                z: vec!["foo".into(), "bar".into()],
            },
            decoded
        );
    }

    #[test]
    fn from_slice_enums() {
        #[derive(Deserialize, Debug, PartialEq, Eq)]
        enum Test {
            UnitVariant,
            NewTypeVariant(u32),
            TupleVariant(bool, u32),
            StructVariant { x: bool, y: u32 },
        }

        fn decode(encoded: &[u8]) -> Test {
            from_slice(encoded).unwrap()
        }

        // Unit variants are a bare index; the rest are a list of [index, payload].
        assert_eq!(Test::UnitVariant, decode(&hex!("2100")));
        assert_eq!(Test::NewTypeVariant(17), decode(&hex!("60 2101 2111 80")));
        assert_eq!(
            Test::TupleVariant(true, 17),
            decode(&hex!("60 2102 60 11 2111 80 80"))
        );
        assert_eq!(
            Test::StructVariant { x: true, y: 17 },
            decode(&hex!("60 2103 60 11 2111 80 80"))
        );
    }

    #[test]
    fn from_slice_options() {
        let absent: Option<i32> = from_slice(&hex!("00")).unwrap();
        assert_eq!(None, absent);

        let present: Option<i32> = from_slice(&hex!("2111")).unwrap();
        assert_eq!(Some(17), present);
    }

    #[test]
    fn stream_deserializer() {
        // The stray `end` marker (80) stops the stream with an error after two good values.
        let encoded = hex!("2100 2101 80 2103");
        let results: Vec<Result<u64>> = Deserializer::new(&encoded[..]).into_iter().collect();
        assert_eq!(3, results.len());
        assert_eq!(&0, results[0].as_ref().unwrap());
        assert_eq!(&1, results[1].as_ref().unwrap());
        assert!(matches!(results[2], Err(Error::WrongType)));

        let encoded = hex!("2100 2101 2102 2103");
        let values = Deserializer::new(&encoded[..])
            .into_iter()
            .collect::<Result<Vec<u64>>>()
            .unwrap();
        assert_eq!(vec![0, 1, 2, 3], values);
    }
}
747
748// TODO read buffer size limits for ddos protection
749// TODO zero-copy reads
750// TODO nesting limit