gob/de/
mod.rs

1//! Deserialization
2
3use std::io::{Cursor, Read};
4
5use bytes::Buf;
6use serde::de::{IgnoredAny, Visitor};
7use serde::{self, Deserialize};
8
9use error::Error;
10use internal::gob::{Message, Stream};
11use internal::types::{TypeId, Types, WireType};
12use internal::utils::{Bow, Buffer};
13
14use internal::de::FieldValueDeserializer;
15use internal::de::ValueDeserializer;
16
17pub struct StreamDeserializer<R> {
18    defs: Types,
19    stream: Stream<R>,
20    buffer: Buffer,
21    prev_len: usize,
22}
23
24impl<R> StreamDeserializer<R> {
25    pub fn new(read: R) -> Self {
26        StreamDeserializer {
27            defs: Types::new(),
28            stream: Stream::new(read),
29            buffer: Buffer::new(),
30            prev_len: 0,
31        }
32    }
33
34    pub fn deserialize<'de, T>(&'de mut self) -> Result<Option<T>, Error>
35    where
36        R: Read,
37        T: Deserialize<'de>,
38    {
39        if let Some(deserializer) = self.deserializer()? {
40            Ok(Some(T::deserialize(deserializer)?))
41        } else {
42            Ok(None)
43        }
44    }
45
46    pub fn deserializer<'de>(&'de mut self) -> Result<Option<Deserializer<'de>>, Error>
47    where
48        R: Read,
49    {
50        if self.prev_len > 0 {
51            self.buffer.advance(self.prev_len);
52            self.prev_len = 0;
53        }
54        loop {
55            let header = match self.stream.read_section(&mut self.buffer)? {
56                Some(header) => header,
57                None => return Ok(None),
58            };
59
60            if header.type_id >= 0 {
61                let slice = &self.buffer.bytes()[header.payload_range.clone()];
62                let msg = Message::new(Cursor::new(slice));
63                self.prev_len = header.payload_range.end;
64                return Ok(Some(Deserializer {
65                    defs: Bow::Borrowed(&mut self.defs),
66                    msg: msg,
67                    type_id: Some(TypeId(header.type_id)),
68                }));
69            }
70
71            let wire_type = {
72                let slice = &self.buffer.bytes()[header.payload_range.clone()];
73                let mut msg = Message::new(Cursor::new(slice));
74                let de = FieldValueDeserializer::new(TypeId::WIRE_TYPE, &self.defs, &mut msg);
75                WireType::deserialize(de)
76            }?;
77
78            if -header.type_id != wire_type.common().id.0 {
79                return Err(Error::deserialize("type id mismatch"));
80            }
81
82            self.defs.insert(wire_type);
83            self.buffer.advance(header.payload_range.end);
84        }
85    }
86
87    pub fn get_ref(&self) -> &R {
88        self.stream.get_ref()
89    }
90
91    pub fn get_mut(&mut self) -> &mut R {
92        self.stream.get_mut()
93    }
94
95    pub fn into_inner(self) -> R {
96        self.stream.into_inner()
97    }
98}
99
100pub struct Deserializer<'de> {
101    defs: Bow<'de, Types>,
102    msg: Message<Cursor<&'de [u8]>>,
103    type_id: Option<TypeId>,
104}
105
106impl<'de> Deserializer<'de> {
107    pub fn from_slice(input: &'de [u8]) -> Deserializer<'de> {
108        Deserializer {
109            defs: Bow::Owned(Types::new()),
110            msg: Message::new(Cursor::new(input)),
111            type_id: None,
112        }
113    }
114
115    fn value_deserializer<'t>(&'t mut self) -> Result<ValueDeserializer<'t, 'de>, Error> {
116        if let Some(type_id) = self.type_id {
117            return Ok(ValueDeserializer::new(type_id, &self.defs, &mut self.msg));
118        }
119
120        loop {
121            let _len = self.msg.read_bytes_len()?;
122            let type_id = self.msg.read_int()?;
123
124            if type_id >= 0 {
125                return Ok(ValueDeserializer::new(
126                    TypeId(type_id),
127                    &self.defs,
128                    &mut self.msg,
129                ));
130            }
131
132            let wire_type = {
133                let de = FieldValueDeserializer::new(TypeId::WIRE_TYPE, &self.defs, &mut self.msg);
134                WireType::deserialize(de)
135            }?;
136
137            if -type_id != wire_type.common().id.0 {
138                return Err(serde::de::Error::custom(format!("type id mismatch")));
139            }
140
141            self.defs.insert(wire_type);
142        }
143    }
144}
145
146impl<'de> serde::Deserializer<'de> for Deserializer<'de> {
147    type Error = Error;
148
149    fn deserialize_any<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
150    where
151        V: Visitor<'de>,
152    {
153        self.value_deserializer()?.deserialize_any(visitor)
154    }
155
156    fn deserialize_enum<V>(
157        mut self,
158        name: &'static str,
159        variants: &'static [&'static str],
160        visitor: V,
161    ) -> Result<V::Value, Self::Error>
162    where
163        V: Visitor<'de>,
164    {
165        self.value_deserializer()?
166            .deserialize_enum(name, variants, visitor)
167    }
168
169    fn deserialize_struct<V>(
170        mut self,
171        name: &'static str,
172        fields: &'static [&'static str],
173        visitor: V,
174    ) -> Result<V::Value, Self::Error>
175    where
176        V: Visitor<'de>,
177    {
178        self.value_deserializer()?
179            .deserialize_struct(name, fields, visitor)
180    }
181
182    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
183    where
184        V: Visitor<'de>,
185    {
186        let int = i64::deserialize(self)?;
187        if let Some(c) = ::std::char::from_u32(int as u32) {
188            visitor.visit_char(c)
189        } else {
190            Err(serde::de::Error::custom(format!(
191                "invalid char code {}",
192                int
193            )))
194        }
195    }
196
197    #[inline]
198    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
199    where
200        V: Visitor<'de>,
201    {
202        self.deserialize_ignored_any(IgnoredAny)?;
203        visitor.visit_unit()
204    }
205
206    forward_to_deserialize_any! {
207        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 str string bytes
208        byte_buf option unit_struct newtype_struct seq tuple
209        tuple_struct map identifier ignored_any
210    }
211}