bytekey_fix/
de.rs

1use byteorder::{ReadBytesExt, BE};
2use serde;
3use serde::de::{Deserialize, Visitor};
4use std;
5use std::error::Error as StdError;
6use std::fmt;
7use std::io::{self, Read};
8use std::mem::transmute;
9use std::{i16, i32, i64, i8};
10use utf8;
11
/// Decoder that turns bytes in the order-preserving `bytekey` format back into values.
///
/// The exact byte layout of every type is described in the **Serializer** documentation.
#[derive(Debug)]
pub struct Deserializer<R> {
    /// Source the encoded bytes are read from.
    reader: R,
}
19
/// Everything that can go wrong while decoding `bytekey` data.
#[derive(Debug)]
pub enum Error {
    /// The caller asked for a self-describing operation (such as `deserialize_any`),
    /// which `bytekey` cannot support.
    DeserializeAnyUnsupported,
    /// The input ended while UTF-8 text was still expected.
    UnexpectedEof,
    /// Bytes that should have been UTF-8 text were not valid UTF-8.
    InvalidUtf8,
    /// A failure reported by the underlying reader, including running out of input.
    Io(io::Error),
    /// A free-form error, e.g. from `serde::de::Error::custom` or a malformed tag byte.
    Message(String),
}

/// Result type used throughout this module, with [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
32
33/// Deserialize data from the given slice of bytes.
34///
35/// #### Usage
36///
37/// ```
38/// # use bytekey::{serialize, deserialize};
39/// let bytes = serialize(&42usize).unwrap();
40/// assert_eq!(42usize, deserialize::<usize>(&bytes).unwrap());
41/// ```
42pub fn deserialize<T>(bytes: &[u8]) -> Result<T>
43    where
44        T: for<'de> Deserialize<'de>,
45{
46    deserialize_from(bytes)
47}
48
49/// Deserialize data from the given byte reader.
50///
51/// #### Usage
52///
53/// ```
54/// # use bytekey::{serialize, deserialize_from};
55/// let bytes = serialize(&42u64).unwrap();
56/// let result: u64 = deserialize_from(&bytes[..]).unwrap();
57/// assert_eq!(42u64, result);
58/// ```
59pub fn deserialize_from<R, T>(reader: R) -> Result<T>
60    where
61        R: io::BufRead,
62        T: for<'de> Deserialize<'de>,
63{
64    let mut deserializer = Deserializer::new(reader);
65    T::deserialize(&mut deserializer)
66}
67
68impl<R: io::Read> Deserializer<R> {
69    /// Creates a new ordered bytes encoder whose output will be written to the provided writer.
70    pub fn new(reader: R) -> Deserializer<R> {
71        Deserializer { reader: reader }
72    }
73
74    /// Deserialize a `u64` that has been serialized using the `serialize_var_u64` method.
75    pub fn deserialize_var_u64(&mut self) -> Result<u64> {
76        let header = self.reader.read_u8()?;
77        let n = header >> 4;
78        let (mut val, _) = ((header & 0x0F) as u64).overflowing_shl(n as u32 * 8);
79        for i in 1..n + 1 {
80            let byte = self.reader.read_u8()?;
81            val += (byte as u64) << ((n - i) * 8);
82        }
83        Ok(val)
84    }
85
86    /// Deserialize an `i64` that has been serialized using the `serialize_var_i64` method.
87    pub fn deserialize_var_i64(&mut self) -> Result<i64> {
88        let header = self.reader.read_u8()?;
89        let mask = ((header ^ 0x80) as i8 >> 7) as u8;
90        let n = ((header >> 3) ^ mask) & 0x0F;
91        let (mut val, _) = (((header ^ mask) & 0x07) as u64).overflowing_shl(n as u32 * 8);
92        for i in 1..n + 1 {
93            let byte = self.reader.read_u8()?;
94            val += ((byte ^ mask) as u64) << ((n - i) * 8);
95        }
96        let final_mask = (((mask as i64) << 63) >> 63) as u64;
97        val ^= final_mask;
98        Ok(val as i64)
99    }
100}
101
102impl<'de, 'a, R> serde::de::Deserializer<'de> for &'a mut Deserializer<R>
103    where
104        R: io::BufRead,
105{
106    type Error = Error;
107
108    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
109        where
110            V: Visitor<'de>,
111    {
112        Err(Error::DeserializeAnyUnsupported)
113    }
114
115    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
116        where
117            V: Visitor<'de>,
118    {
119        let b = match self.reader.read_u8()? {
120            0 => false,
121            _ => true,
122        };
123        visitor.visit_bool(b)
124    }
125
126    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
127        where
128            V: Visitor<'de>,
129    {
130        let i = self.reader.read_i8()?;
131        visitor.visit_i8(i ^ i8::MIN)
132    }
133
134    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
135        where
136            V: Visitor<'de>,
137    {
138        let i = self.reader.read_i16::<BE>()?;
139        visitor.visit_i16(i ^ i16::MIN)
140    }
141
142    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
143        where
144            V: Visitor<'de>,
145    {
146        let i = self.reader.read_i32::<BE>()?;
147        visitor.visit_i32(i ^ i32::MIN)
148    }
149
150    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
151        where
152            V: Visitor<'de>,
153    {
154        let i = self.reader.read_i64::<BE>()?;
155        visitor.visit_i64(i ^ i64::MIN)
156    }
157
158    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
159        where
160            V: Visitor<'de>,
161    {
162        let u = self.reader.read_u8()?;
163        visitor.visit_u8(u)
164    }
165
166    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
167        where
168            V: Visitor<'de>,
169    {
170        let u = self.reader.read_u16::<BE>()?;
171        visitor.visit_u16(u)
172    }
173
174    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
175        where
176            V: Visitor<'de>,
177    {
178        let u = self.reader.read_u32::<BE>()?;
179        visitor.visit_u32(u)
180    }
181
182    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
183        where
184            V: Visitor<'de>,
185    {
186        let u = self.reader.read_u64::<BE>()?;
187        visitor.visit_u64(u)
188    }
189
190    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
191        where
192            V: Visitor<'de>,
193    {
194        let val = self.reader.read_i32::<BE>()?;
195        let t = ((val ^ i32::MIN) >> 31) | i32::MIN;
196        let f: f32 = unsafe { transmute(val ^ t) };
197        visitor.visit_f32(f)
198    }
199
200    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
201        where
202            V: Visitor<'de>,
203    {
204        let val = self.reader.read_i64::<BE>()?;
205        let t = ((val ^ i64::MIN) >> 63) | i64::MIN;
206        let f: f64 = unsafe { transmute(val ^ t) };
207        visitor.visit_f64(f)
208    }
209
210    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
211        where
212            V: Visitor<'de>,
213    {
214        let mut utf8_decoder = utf8::BufReadDecoder::new(&mut self.reader);
215        match utf8_decoder.next_strict() {
216            Some(Ok(s)) => {
217                let ch = s.chars().next().expect("expected at least one `char`");
218                visitor.visit_char(ch)
219            }
220            Some(Err(err)) => return Err(err.into()),
221            None => return Err(Error::UnexpectedEof.into()),
222        }
223    }
224
225    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
226        where
227            V: Visitor<'de>,
228    {
229        let mut string = String::new();
230        let mut utf8_decoder = utf8::BufReadDecoder::new(&mut self.reader);
231        while let Some(res) = utf8_decoder.next_strict() {
232            match res {
233                Ok(mut s) => {
234                    // The only way for us to know whether `String` or `Vec` deserialization is
235                    // complete is to check for a EOF character yielded by the reader.
236                    const EOF: char = '\u{0}';
237                    const EOF_STR: &'static str = "\u{0}";
238                    if s.len() >= EOF.len_utf8() {
239                        let eof_start = s.len() - EOF.len_utf8();
240                        if &s[eof_start..] == EOF_STR {
241                            s = &s[..eof_start];
242                        }
243                    }
244                    string.push_str(s);
245                }
246                Err(utf8::BufReadDecoderError::Io(err)) => return Err(err.into()),
247                Err(utf8::BufReadDecoderError::InvalidByteSequence(_)) => break,
248            }
249        }
250        let mut tmp = [0u8; 1];
251        self.reader.read(&mut tmp).unwrap();
252        assert_eq!(tmp[0], 0xFF);
253
254        visitor.visit_string(string)
255    }
256
257    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
258        where
259            V: Visitor<'de>,
260    {
261        self.deserialize_str(visitor)
262    }
263
264    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
265        where
266            V: Visitor<'de>,
267    {
268        let mut bytes = vec![];
269        for byte in (&mut self.reader).bytes() {
270            bytes.push(byte?);
271        }
272        visitor.visit_byte_buf(bytes)
273    }
274
275    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
276        where
277            V: Visitor<'de>,
278    {
279        self.deserialize_bytes(visitor)
280    }
281
282    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
283        where
284            V: Visitor<'de>,
285    {
286        match self.reader.read_u8()? {
287            0 => visitor.visit_none(),
288            1 => visitor.visit_some(&mut *self),
289            b => {
290                let msg = format!("expected `0` or `1` for option tag - found {}", b);
291                Err(Error::Message(msg))
292            }
293        }
294    }
295
296    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
297        where
298            V: Visitor<'de>,
299    {
300        visitor.visit_unit()
301    }
302
303    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
304        where
305            V: Visitor<'de>,
306    {
307        visitor.visit_unit()
308    }
309
310    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
311        where
312            V: Visitor<'de>,
313    {
314        visitor.visit_newtype_struct(self)
315    }
316
317    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
318        where
319            V: Visitor<'de>,
320    {
321        struct Access<'a, R>
322            where
323                R: 'a + io::BufRead,
324        {
325            deserializer: &'a mut Deserializer<R>,
326        }
327
328        impl<'de, 'a, R> serde::de::SeqAccess<'de> for Access<'a, R>
329            where
330                R: io::BufRead,
331        {
332            type Error = Error;
333
334            fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
335                where
336                    T: serde::de::DeserializeSeed<'de>,
337            {
338                match serde::de::DeserializeSeed::deserialize(seed, &mut *self.deserializer) {
339                    Ok(v) => Ok(Some(v)),
340                    Err(Error::Io(ref err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
341                        Ok(None)
342                    }
343                    Err(err) => Err(err),
344                }
345            }
346        }
347
348        visitor.visit_seq(Access { deserializer: self })
349    }
350
351    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value>
352        where
353            V: Visitor<'de>,
354    {
355        struct Access<'a, R>
356            where
357                R: 'a + io::BufRead,
358        {
359            deserializer: &'a mut Deserializer<R>,
360            len: usize,
361        }
362
363        impl<'de, 'a, R> serde::de::SeqAccess<'de> for Access<'a, R>
364            where
365                R: io::BufRead,
366        {
367            type Error = Error;
368
369            fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
370                where
371                    T: serde::de::DeserializeSeed<'de>,
372            {
373                if self.len == 0 {
374                    return Ok(None);
375                }
376                self.len -= 1;
377                let value = serde::de::DeserializeSeed::deserialize(seed, &mut *self.deserializer)?;
378                Ok(Some(value))
379            }
380
381            fn size_hint(&self) -> Option<usize> {
382                Some(self.len)
383            }
384        }
385
386        visitor.visit_seq(Access {
387            deserializer: self,
388            len: len,
389        })
390    }
391
392    fn deserialize_tuple_struct<V>(
393        self,
394        _name: &'static str,
395        len: usize,
396        visitor: V,
397    ) -> Result<V::Value>
398        where
399            V: Visitor<'de>,
400    {
401        self.deserialize_tuple(len, visitor)
402    }
403
404    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
405        where
406            V: Visitor<'de>,
407    {
408        struct Access<'a, R>
409            where
410                R: 'a + io::BufRead,
411        {
412            deserializer: &'a mut Deserializer<R>,
413        }
414
415        impl<'de, 'a, R> serde::de::MapAccess<'de> for Access<'a, R>
416            where
417                R: io::BufRead,
418        {
419            type Error = Error;
420
421            fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
422                where
423                    T: serde::de::DeserializeSeed<'de>,
424            {
425                match serde::de::DeserializeSeed::deserialize(seed, &mut *self.deserializer) {
426                    Ok(v) => Ok(Some(v)),
427                    Err(Error::Io(ref err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
428                        Ok(None)
429                    }
430                    Err(err) => Err(err),
431                }
432            }
433
434            fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
435                where
436                    T: serde::de::DeserializeSeed<'de>,
437            {
438                serde::de::DeserializeSeed::deserialize(seed, &mut *self.deserializer)
439            }
440        }
441
442        visitor.visit_map(Access { deserializer: self })
443    }
444
445    fn deserialize_struct<V>(
446        self,
447        _name: &'static str,
448        fields: &'static [&'static str],
449        visitor: V,
450    ) -> Result<V::Value>
451        where
452            V: Visitor<'de>,
453    {
454        self.deserialize_tuple(fields.len(), visitor)
455    }
456
457    fn deserialize_enum<V>(
458        self,
459        _name: &'static str,
460        _fields: &'static [&'static str],
461        visitor: V,
462    ) -> Result<V::Value>
463        where
464            V: Visitor<'de>,
465    {
466        impl<'de, 'a, R> serde::de::EnumAccess<'de> for &'a mut Deserializer<R>
467            where
468                R: io::BufRead,
469        {
470            type Error = Error;
471            type Variant = Self;
472
473            fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
474                where
475                    V: serde::de::DeserializeSeed<'de>,
476            {
477                let idx: u32 = serde::de::Deserialize::deserialize(&mut *self)?;
478                let val: Result<_> =
479                    seed.deserialize(serde::de::IntoDeserializer::into_deserializer(idx));
480                Ok((val?, self))
481            }
482        }
483
484        impl<'de, 'a, R> serde::de::VariantAccess<'de> for &'a mut Deserializer<R>
485            where
486                R: io::BufRead,
487        {
488            type Error = Error;
489
490            fn unit_variant(self) -> Result<()> {
491                Ok(())
492            }
493
494            fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
495                where
496                    T: serde::de::DeserializeSeed<'de>,
497            {
498                serde::de::DeserializeSeed::deserialize(seed, self)
499            }
500
501            fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value>
502                where
503                    V: serde::de::Visitor<'de>,
504            {
505                serde::de::Deserializer::deserialize_tuple(self, len, visitor)
506            }
507
508            fn struct_variant<V>(
509                self,
510                fields: &'static [&'static str],
511                visitor: V,
512            ) -> Result<V::Value>
513                where
514                    V: serde::de::Visitor<'de>,
515            {
516                serde::de::Deserializer::deserialize_tuple(self, fields.len(), visitor)
517            }
518        }
519
520        visitor.visit_enum(self)
521    }
522
523    fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value>
524        where
525            V: serde::de::Visitor<'de>,
526    {
527        Err(Error::DeserializeAnyUnsupported)
528    }
529
530    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
531        where
532            V: serde::de::Visitor<'de>,
533    {
534        Err(Error::DeserializeAnyUnsupported)
535    }
536}
537
538impl<'a> From<utf8::BufReadDecoderError<'a>> for Error {
539    fn from(_err: utf8::BufReadDecoderError) -> Self {
540        Error::InvalidUtf8
541    }
542}
543
544impl From<io::Error> for Error {
545    fn from(err: io::Error) -> Self {
546        Error::Io(err)
547    }
548}
549
550#[allow(deprecated)]
551impl fmt::Display for Error {
552    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
553        write!(f, "{}", match *self {
554            Error::DeserializeAnyUnsupported => "`bytekey` is not a self-describing format",
555            Error::UnexpectedEof => "encountered unexpected EOF when deserializing utf8",
556            Error::InvalidUtf8 => "attempted to deserialize invalid utf8",
557            Error::Io(ref err) => err.description(),
558            Error::Message(ref msg) => msg,
559        })
560    }
561}
562
563impl StdError for Error {
564    fn source(&self) -> Option<&(dyn StdError + 'static)> {
565        match *self {
566            Error::DeserializeAnyUnsupported => None,
567            Error::UnexpectedEof => None,
568            Error::InvalidUtf8 => None,
569            Error::Io(ref err) => Some(err),
570            Error::Message(ref _msg) => None,
571        }
572    }
573}
574
575impl serde::de::Error for Error {
576    fn custom<T: fmt::Display>(msg: T) -> Self {
577        Error::Message(msg.to_string())
578    }
579}