1use std::io::{Cursor, Read};
4
5use bytes::Buf;
6use serde::de::{IgnoredAny, Visitor};
7use serde::{self, Deserialize};
8
9use error::Error;
10use internal::gob::{Message, Stream};
11use internal::types::{TypeId, Types, WireType};
12use internal::utils::{Bow, Buffer};
13
14use internal::de::FieldValueDeserializer;
15use internal::de::ValueDeserializer;
16
17pub struct StreamDeserializer<R> {
18 defs: Types,
19 stream: Stream<R>,
20 buffer: Buffer,
21 prev_len: usize,
22}
23
24impl<R> StreamDeserializer<R> {
25 pub fn new(read: R) -> Self {
26 StreamDeserializer {
27 defs: Types::new(),
28 stream: Stream::new(read),
29 buffer: Buffer::new(),
30 prev_len: 0,
31 }
32 }
33
34 pub fn deserialize<'de, T>(&'de mut self) -> Result<Option<T>, Error>
35 where
36 R: Read,
37 T: Deserialize<'de>,
38 {
39 if let Some(deserializer) = self.deserializer()? {
40 Ok(Some(T::deserialize(deserializer)?))
41 } else {
42 Ok(None)
43 }
44 }
45
46 pub fn deserializer<'de>(&'de mut self) -> Result<Option<Deserializer<'de>>, Error>
47 where
48 R: Read,
49 {
50 if self.prev_len > 0 {
51 self.buffer.advance(self.prev_len);
52 self.prev_len = 0;
53 }
54 loop {
55 let header = match self.stream.read_section(&mut self.buffer)? {
56 Some(header) => header,
57 None => return Ok(None),
58 };
59
60 if header.type_id >= 0 {
61 let slice = &self.buffer.bytes()[header.payload_range.clone()];
62 let msg = Message::new(Cursor::new(slice));
63 self.prev_len = header.payload_range.end;
64 return Ok(Some(Deserializer {
65 defs: Bow::Borrowed(&mut self.defs),
66 msg: msg,
67 type_id: Some(TypeId(header.type_id)),
68 }));
69 }
70
71 let wire_type = {
72 let slice = &self.buffer.bytes()[header.payload_range.clone()];
73 let mut msg = Message::new(Cursor::new(slice));
74 let de = FieldValueDeserializer::new(TypeId::WIRE_TYPE, &self.defs, &mut msg);
75 WireType::deserialize(de)
76 }?;
77
78 if -header.type_id != wire_type.common().id.0 {
79 return Err(Error::deserialize("type id mismatch"));
80 }
81
82 self.defs.insert(wire_type);
83 self.buffer.advance(header.payload_range.end);
84 }
85 }
86
87 pub fn get_ref(&self) -> &R {
88 self.stream.get_ref()
89 }
90
91 pub fn get_mut(&mut self) -> &mut R {
92 self.stream.get_mut()
93 }
94
95 pub fn into_inner(self) -> R {
96 self.stream.into_inner()
97 }
98}
99
100pub struct Deserializer<'de> {
101 defs: Bow<'de, Types>,
102 msg: Message<Cursor<&'de [u8]>>,
103 type_id: Option<TypeId>,
104}
105
106impl<'de> Deserializer<'de> {
107 pub fn from_slice(input: &'de [u8]) -> Deserializer<'de> {
108 Deserializer {
109 defs: Bow::Owned(Types::new()),
110 msg: Message::new(Cursor::new(input)),
111 type_id: None,
112 }
113 }
114
115 fn value_deserializer<'t>(&'t mut self) -> Result<ValueDeserializer<'t, 'de>, Error> {
116 if let Some(type_id) = self.type_id {
117 return Ok(ValueDeserializer::new(type_id, &self.defs, &mut self.msg));
118 }
119
120 loop {
121 let _len = self.msg.read_bytes_len()?;
122 let type_id = self.msg.read_int()?;
123
124 if type_id >= 0 {
125 return Ok(ValueDeserializer::new(
126 TypeId(type_id),
127 &self.defs,
128 &mut self.msg,
129 ));
130 }
131
132 let wire_type = {
133 let de = FieldValueDeserializer::new(TypeId::WIRE_TYPE, &self.defs, &mut self.msg);
134 WireType::deserialize(de)
135 }?;
136
137 if -type_id != wire_type.common().id.0 {
138 return Err(serde::de::Error::custom(format!("type id mismatch")));
139 }
140
141 self.defs.insert(wire_type);
142 }
143 }
144}
145
146impl<'de> serde::Deserializer<'de> for Deserializer<'de> {
147 type Error = Error;
148
149 fn deserialize_any<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
150 where
151 V: Visitor<'de>,
152 {
153 self.value_deserializer()?.deserialize_any(visitor)
154 }
155
156 fn deserialize_enum<V>(
157 mut self,
158 name: &'static str,
159 variants: &'static [&'static str],
160 visitor: V,
161 ) -> Result<V::Value, Self::Error>
162 where
163 V: Visitor<'de>,
164 {
165 self.value_deserializer()?
166 .deserialize_enum(name, variants, visitor)
167 }
168
169 fn deserialize_struct<V>(
170 mut self,
171 name: &'static str,
172 fields: &'static [&'static str],
173 visitor: V,
174 ) -> Result<V::Value, Self::Error>
175 where
176 V: Visitor<'de>,
177 {
178 self.value_deserializer()?
179 .deserialize_struct(name, fields, visitor)
180 }
181
182 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
183 where
184 V: Visitor<'de>,
185 {
186 let int = i64::deserialize(self)?;
187 if let Some(c) = ::std::char::from_u32(int as u32) {
188 visitor.visit_char(c)
189 } else {
190 Err(serde::de::Error::custom(format!(
191 "invalid char code {}",
192 int
193 )))
194 }
195 }
196
197 #[inline]
198 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
199 where
200 V: Visitor<'de>,
201 {
202 self.deserialize_ignored_any(IgnoredAny)?;
203 visitor.visit_unit()
204 }
205
206 forward_to_deserialize_any! {
207 bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 str string bytes
208 byte_buf option unit_struct newtype_struct seq tuple
209 tuple_struct map identifier ignored_any
210 }
211}