1use crate::code;
2use crate::unpack;
3use crate::unpack_error;
4use crate::BufferedRead;
5
6use serde;
7use serde::de::{self, Deserialize, DeserializeOwned, DeserializeSeed, Visitor};
8use serde::forward_to_deserialize_any;
9use std::io;
10
11use std::error;
12use std::fmt::{self, Display};
13
14#[derive(Debug)]
15pub enum DeError {
16 InvalidSize,
17 UnpackError(unpack_error::UnpackError),
18 Custom(String),
19}
20
21impl From<unpack_error::UnpackError> for DeError {
22 fn from(err: unpack_error::UnpackError) -> DeError {
23 DeError::UnpackError(err)
24 }
25}
26
27impl Display for DeError {
28 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
29 error::Error::description(self).fmt(f)
30 }
31}
32
33impl error::Error for DeError {
34 fn description(&self) -> &str {
35 use DeError::*;
36
37 match *self {
38 InvalidSize => "invalid size",
39 UnpackError(ref e) => e.description(),
40 Custom(ref s) => s,
41 }
42 }
43
44 fn cause(&self) -> Option<&dyn error::Error> {
45 use DeError::*;
46
47 match *self {
48 UnpackError(ref e) => Some(e),
49 Custom(_) => None,
50 InvalidSize => None,
51 }
52 }
53}
54
55impl serde::de::Error for DeError {
56 fn custom<T: Display>(msg: T) -> DeError {
57 DeError::Custom(msg.to_string())
58 }
59}
60
61struct PeekReader<R> {
62 code: Option<code::Code>,
63 reader: R,
64}
65
66impl<R: io::Read> PeekReader<R> {
67 pub fn peek_code(&mut self) -> Result<&code::Code, unpack_error::UnpackError> {
68 if let Some(ref v) = self.code {
69 Ok(v)
70 } else {
71 let code = unpack::read_code(&mut self.reader)?;
72 self.code = Some(code);
73 Ok(self.code.as_ref().unwrap())
74 }
75 }
76
77 pub fn consume_code(&mut self) -> Option<code::Code> {
78 self.code.take()
79 }
80}
81
82impl<R: io::Read> io::Read for PeekReader<R> {
83 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
84 if let Some(ref v) = self.code {
85 buf[0] = v.to_u8();
86 if buf.len() > 1 {
87 self.reader.read(&mut buf[1..])
88 } else {
89 Ok(1)
90 }
91 } else {
92 self.reader.read(buf)
93 }
94 }
95}
96
97impl<'a, R: BufferedRead<'a>> BufferedRead<'a> for PeekReader<R> {
98 fn fill_buf(&self) -> io::Result<&'a [u8]> {
99 self.reader.fill_buf()
100 }
101
102 fn consume(&mut self, len: usize) {
103 self.reader.consume(len)
104 }
105}
106
107struct SeqAccess<'a, R: io::Read + 'a> {
108 de: &'a mut Deserializer<R>,
109 len: usize,
110}
111
112impl<'de, 'a, R> serde::de::SeqAccess<'de> for SeqAccess<'a, R>
113where
114 R: BufferedRead<'de> + 'a,
115{
116 type Error = DeError;
117
118 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
119 where
120 T: serde::de::DeserializeSeed<'de>,
121 {
122 if self.len > 0 {
123 self.len -= 1;
124 Ok(Some(seed.deserialize(&mut *self.de)?))
125 } else {
126 Ok(None)
127 }
128 }
129
130 fn size_hint(&self) -> Option<usize> {
131 Some(self.len)
132 }
133}
134
135struct MapAccess<'a, R: 'a> {
136 de: &'a mut Deserializer<R>,
137 len: usize,
138}
139
140impl<'de, 'a, R> de::MapAccess<'de> for MapAccess<'a, R>
141where
142 R: BufferedRead<'de> + 'a,
143{
144 type Error = DeError;
145
146 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
147 where
148 K: DeserializeSeed<'de>,
149 {
150 if self.len > 0 {
151 self.len -= 1;
152 Ok(Some(seed.deserialize(&mut *self.de)?))
153 } else {
154 Ok(None)
155 }
156 }
157
158 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
159 where
160 V: DeserializeSeed<'de>,
161 {
162 Ok(seed.deserialize(&mut *self.de)?)
163 }
164
165 fn size_hint(&self) -> Option<usize> {
166 Some(self.len)
167 }
168}
169
170pub struct Deserializer<R> {
200 reader: PeekReader<R>,
201}
202
203impl<R> Deserializer<R> {
204 pub fn new(r: R) -> Self {
205 Deserializer {
206 reader: PeekReader {
207 code: None,
208 reader: r,
209 },
210 }
211 }
212}
213
/// Generates one numeric `deserialize_*` method.
///
/// Arguments: the `serde::Deserializer` method to define, the visitor method
/// to call with the decoded value, and the `unpack` function that decodes it.
macro_rules! impl_nums {
    ($deserialize:ident, $visit:ident, $unpack:ident) => {
        #[inline]
        fn $deserialize<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: serde::de::Visitor<'de>,
        {
            let value = unpack::$unpack(&mut self.reader)?;
            visitor.$visit(value)
        }
    };
}
226
227impl<'de, 'a, R> serde::Deserializer<'de> for &'a mut Deserializer<R>
228where
229 R: BufferedRead<'de>,
230{
231 type Error = DeError;
232
233 impl_nums!(deserialize_u8, visit_u8, unpack_u8);
234 impl_nums!(deserialize_u16, visit_u16, unpack_u16);
235 impl_nums!(deserialize_u32, visit_u32, unpack_u32);
236 impl_nums!(deserialize_u64, visit_u64, unpack_u64);
237 impl_nums!(deserialize_i8, visit_i8, unpack_i8);
238 impl_nums!(deserialize_i16, visit_i16, unpack_i16);
239 impl_nums!(deserialize_i32, visit_i32, unpack_i32);
240 impl_nums!(deserialize_i64, visit_i64, unpack_i64);
241 impl_nums!(deserialize_f32, visit_f32, unpack_f32);
242 impl_nums!(deserialize_f64, visit_f64, unpack_f64);
243
244 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
245 where
246 V: Visitor<'de>,
247 {
248 use code::Code;
249
250 match self.reader.peek_code()? {
251 Code::Nil => self.deserialize_unit(visitor),
252 Code::True | Code::False => self.deserialize_bool(visitor),
253 Code::Uint8 | Code::PosInt(_) => self.deserialize_u8(visitor),
254 Code::Uint16 => self.deserialize_u16(visitor),
255 Code::Uint32 => self.deserialize_u32(visitor),
256 Code::Uint64 => self.deserialize_u64(visitor),
257 Code::Int8 | Code::NegInt(_) => self.deserialize_i8(visitor),
258 Code::Int16 => self.deserialize_i16(visitor),
259 Code::Int32 => self.deserialize_i32(visitor),
260 Code::Int64 => self.deserialize_i64(visitor),
261 Code::Float32 => self.deserialize_f32(visitor),
262 Code::Float64 => self.deserialize_f64(visitor),
263 Code::FixStr(_) | Code::Str8 | Code::Str16 | Code::Str32 => {
264 self.deserialize_string(visitor)
265 }
266 Code::Bin8 | Code::Bin16 | Code::Bin32 => self.deserialize_bytes(visitor),
267 Code::FixArray(_) | Code::Array16 | Code::Array32 => self.deserialize_seq(visitor),
268 Code::FixMap(_) | Code::Map16 | Code::Map32 => self.deserialize_map(visitor),
269 Code::Reserved => unreachable!(), _ => unreachable!(),
279 }
280 }
281
282 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
283 where
284 V: serde::de::Visitor<'de>,
285 {
286 match self.reader.peek_code()? {
287 code::Code::Nil => {
288 let _ = self.reader.consume_code();
289 visitor.visit_none()
290 }
291 _ => visitor.visit_some(self),
292 }
293 }
294
295 fn deserialize_enum<V>(
296 self,
297 _name: &str,
298 _variants: &[&str],
299 visitor: V,
300 ) -> Result<V::Value, Self::Error>
301 where
302 V: Visitor<'de>,
303 {
304 visitor.visit_none()
305 }
306
307 fn deserialize_newtype_struct<V>(
308 self,
309 _name: &'static str,
310 visitor: V,
311 ) -> Result<V::Value, Self::Error>
312 where
313 V: Visitor<'de>,
314 {
315 visitor.visit_newtype_struct(self)
316 }
317
318 fn deserialize_unit_struct<V>(
319 self,
320 _name: &'static str,
321 visitor: V,
322 ) -> Result<V::Value, Self::Error>
323 where
324 V: Visitor<'de>,
325 {
326 visitor.visit_unit()
327 }
328
329 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
330 where
331 V: serde::de::Visitor<'de>,
332 {
333 match unpack::unpack_bool(&mut self.reader)? {
334 true => visitor.visit_bool(true),
335 false => visitor.visit_bool(false),
336 }
337 }
338
339 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
340 where
341 V: serde::de::Visitor<'de>,
342 {
343 let size = unpack::unpack_ary_header(&mut self.reader)?;
344
345 visitor.visit_seq(SeqAccess {
346 de: self,
347 len: size,
348 })
349 }
350
351 fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
352 where
353 V: serde::de::Visitor<'de>,
354 {
355 let size = unpack::unpack_ary_header(&mut self.reader)?;
356 if size != len {
357 return Err(Self::Error::InvalidSize);
358 }
359
360 visitor.visit_seq(SeqAccess {
361 de: self,
362 len: size,
363 })
364 }
365
366 fn deserialize_tuple_struct<V>(
367 self,
368 _name: &'static str,
369 len: usize,
370 visitor: V,
371 ) -> Result<V::Value, Self::Error>
372 where
373 V: serde::de::Visitor<'de>,
374 {
375 self.deserialize_tuple(len, visitor)
376 }
377
378 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
379 where
380 V: serde::de::Visitor<'de>,
381 {
382 visitor.visit_unit()
383 }
384
385 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
386 where
387 V: serde::de::Visitor<'de>,
388 {
389 let body = unpack::unpack_str(&mut self.reader)?;
390 visitor.visit_string(body)
392 }
393
394 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
395 where
396 V: serde::de::Visitor<'de>,
397 {
398 let body = unpack::unpack_str_ref(&mut self.reader)?;
399 visitor.visit_str(body)
401 }
402
403 fn deserialize_struct<V>(
404 self,
405 _name: &str,
406 fields: &'static [&'static str],
407 visitor: V,
408 ) -> Result<V::Value, Self::Error>
409 where
410 V: serde::de::Visitor<'de>,
411 {
412 let _size = unpack::unpack_map_header(&mut self.reader)?;
414 visitor.visit_map(MapAccess {
415 de: self,
416 len: fields.len(),
417 })
418 }
419
420 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
421 where
422 V: serde::de::Visitor<'de>,
423 {
424 let size = unpack::unpack_map_header(&mut self.reader)?;
425 visitor.visit_map(MapAccess {
426 de: self,
427 len: size,
428 })
429 }
430
431 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
432 where
433 V: de::Visitor<'de>,
434 {
435 self.deserialize_str(visitor)
436 }
437
438 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
439 where
440 V: serde::de::Visitor<'de>,
441 {
442 let body = unpack::unpack_bin_ref(&mut self.reader)?;
443 visitor.visit_bytes(body)
444 }
445
446 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
447 where
448 V: serde::de::Visitor<'de>,
449 {
450 let body = unpack::unpack_bin(&mut self.reader)?;
451 visitor.visit_byte_buf(body)
452 }
453
454 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
455 where
456 V: de::Visitor<'de>,
457 {
458 self.deserialize_str(visitor)
459 }
460
461 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
462 where
463 V: de::Visitor<'de>,
464 {
465 visitor.visit_unit()
466 }
467}