1use byteorder::{ReadBytesExt, BE};
2use serde;
3use serde::de::{Deserialize, Visitor};
4use std;
5use std::error::Error as StdError;
6use std::fmt;
7use std::io::{self, Read};
8use std::mem::transmute;
9use std::{i16, i32, i64, i8};
10use utf8;
11
/// Decoder for the `bytekey` binary format.
///
/// Wraps an arbitrary byte source. The serde `Deserializer` implementation
/// requires `R: io::BufRead`, while the inherent helper methods only need
/// `R: io::Read`.
#[derive(Debug)]
pub struct Deserializer<R> {
    /// Source of the encoded bytes.
    reader: R,
}
19
/// Everything that can go wrong while decoding `bytekey` data.
#[derive(Debug)]
pub enum Error {
    /// A self-describing entry point (such as `deserialize_any`) was used;
    /// the format has no type tags to support it.
    DeserializeAnyUnsupported,
    /// The input ended while UTF-8 text was being decoded.
    UnexpectedEof,
    /// The input bytes were not valid UTF-8.
    InvalidUtf8,
    /// A failure reported by the underlying reader.
    Io(io::Error),
    /// A free-form message, e.g. from a `Deserialize` impl or a bad tag byte.
    Message(String),
}

/// Shorthand for results whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
32
33pub fn deserialize<T>(bytes: &[u8]) -> Result<T>
43 where
44 T: for<'de> Deserialize<'de>,
45{
46 deserialize_from(bytes)
47}
48
49pub fn deserialize_from<R, T>(reader: R) -> Result<T>
60 where
61 R: io::BufRead,
62 T: for<'de> Deserialize<'de>,
63{
64 let mut deserializer = Deserializer::new(reader);
65 T::deserialize(&mut deserializer)
66}
67
68impl<R: io::Read> Deserializer<R> {
69 pub fn new(reader: R) -> Deserializer<R> {
71 Deserializer { reader: reader }
72 }
73
74 pub fn deserialize_var_u64(&mut self) -> Result<u64> {
76 let header = self.reader.read_u8()?;
77 let n = header >> 4;
78 let (mut val, _) = ((header & 0x0F) as u64).overflowing_shl(n as u32 * 8);
79 for i in 1..n + 1 {
80 let byte = self.reader.read_u8()?;
81 val += (byte as u64) << ((n - i) * 8);
82 }
83 Ok(val)
84 }
85
86 pub fn deserialize_var_i64(&mut self) -> Result<i64> {
88 let header = self.reader.read_u8()?;
89 let mask = ((header ^ 0x80) as i8 >> 7) as u8;
90 let n = ((header >> 3) ^ mask) & 0x0F;
91 let (mut val, _) = (((header ^ mask) & 0x07) as u64).overflowing_shl(n as u32 * 8);
92 for i in 1..n + 1 {
93 let byte = self.reader.read_u8()?;
94 val += ((byte ^ mask) as u64) << ((n - i) * 8);
95 }
96 let final_mask = (((mask as i64) << 63) >> 63) as u64;
97 val ^= final_mask;
98 Ok(val as i64)
99 }
100}
101
102impl<'de, 'a, R> serde::de::Deserializer<'de> for &'a mut Deserializer<R>
103 where
104 R: io::BufRead,
105{
106 type Error = Error;
107
108 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
109 where
110 V: Visitor<'de>,
111 {
112 Err(Error::DeserializeAnyUnsupported)
113 }
114
115 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
116 where
117 V: Visitor<'de>,
118 {
119 let b = match self.reader.read_u8()? {
120 0 => false,
121 _ => true,
122 };
123 visitor.visit_bool(b)
124 }
125
126 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
127 where
128 V: Visitor<'de>,
129 {
130 let i = self.reader.read_i8()?;
131 visitor.visit_i8(i ^ i8::MIN)
132 }
133
134 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
135 where
136 V: Visitor<'de>,
137 {
138 let i = self.reader.read_i16::<BE>()?;
139 visitor.visit_i16(i ^ i16::MIN)
140 }
141
142 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
143 where
144 V: Visitor<'de>,
145 {
146 let i = self.reader.read_i32::<BE>()?;
147 visitor.visit_i32(i ^ i32::MIN)
148 }
149
150 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
151 where
152 V: Visitor<'de>,
153 {
154 let i = self.reader.read_i64::<BE>()?;
155 visitor.visit_i64(i ^ i64::MIN)
156 }
157
158 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
159 where
160 V: Visitor<'de>,
161 {
162 let u = self.reader.read_u8()?;
163 visitor.visit_u8(u)
164 }
165
166 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
167 where
168 V: Visitor<'de>,
169 {
170 let u = self.reader.read_u16::<BE>()?;
171 visitor.visit_u16(u)
172 }
173
174 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
175 where
176 V: Visitor<'de>,
177 {
178 let u = self.reader.read_u32::<BE>()?;
179 visitor.visit_u32(u)
180 }
181
182 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
183 where
184 V: Visitor<'de>,
185 {
186 let u = self.reader.read_u64::<BE>()?;
187 visitor.visit_u64(u)
188 }
189
190 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
191 where
192 V: Visitor<'de>,
193 {
194 let val = self.reader.read_i32::<BE>()?;
195 let t = ((val ^ i32::MIN) >> 31) | i32::MIN;
196 let f: f32 = unsafe { transmute(val ^ t) };
197 visitor.visit_f32(f)
198 }
199
200 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
201 where
202 V: Visitor<'de>,
203 {
204 let val = self.reader.read_i64::<BE>()?;
205 let t = ((val ^ i64::MIN) >> 63) | i64::MIN;
206 let f: f64 = unsafe { transmute(val ^ t) };
207 visitor.visit_f64(f)
208 }
209
210 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
211 where
212 V: Visitor<'de>,
213 {
214 let mut utf8_decoder = utf8::BufReadDecoder::new(&mut self.reader);
215 match utf8_decoder.next_strict() {
216 Some(Ok(s)) => {
217 let ch = s.chars().next().expect("expected at least one `char`");
218 visitor.visit_char(ch)
219 }
220 Some(Err(err)) => return Err(err.into()),
221 None => return Err(Error::UnexpectedEof.into()),
222 }
223 }
224
225 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
226 where
227 V: Visitor<'de>,
228 {
229 let mut string = String::new();
230 let mut utf8_decoder = utf8::BufReadDecoder::new(&mut self.reader);
231 while let Some(res) = utf8_decoder.next_strict() {
232 match res {
233 Ok(mut s) => {
234 const EOF: char = '\u{0}';
237 const EOF_STR: &'static str = "\u{0}";
238 if s.len() >= EOF.len_utf8() {
239 let eof_start = s.len() - EOF.len_utf8();
240 if &s[eof_start..] == EOF_STR {
241 s = &s[..eof_start];
242 }
243 }
244 string.push_str(s);
245 }
246 Err(utf8::BufReadDecoderError::Io(err)) => return Err(err.into()),
247 Err(utf8::BufReadDecoderError::InvalidByteSequence(_)) => break,
248 }
249 }
250 let mut tmp = [0u8; 1];
251 self.reader.read(&mut tmp).unwrap();
252 assert_eq!(tmp[0], 0xFF);
253
254 visitor.visit_string(string)
255 }
256
257 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
258 where
259 V: Visitor<'de>,
260 {
261 self.deserialize_str(visitor)
262 }
263
264 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
265 where
266 V: Visitor<'de>,
267 {
268 let mut bytes = vec![];
269 for byte in (&mut self.reader).bytes() {
270 bytes.push(byte?);
271 }
272 visitor.visit_byte_buf(bytes)
273 }
274
275 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
276 where
277 V: Visitor<'de>,
278 {
279 self.deserialize_bytes(visitor)
280 }
281
282 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
283 where
284 V: Visitor<'de>,
285 {
286 match self.reader.read_u8()? {
287 0 => visitor.visit_none(),
288 1 => visitor.visit_some(&mut *self),
289 b => {
290 let msg = format!("expected `0` or `1` for option tag - found {}", b);
291 Err(Error::Message(msg))
292 }
293 }
294 }
295
296 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
297 where
298 V: Visitor<'de>,
299 {
300 visitor.visit_unit()
301 }
302
303 fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
304 where
305 V: Visitor<'de>,
306 {
307 visitor.visit_unit()
308 }
309
310 fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
311 where
312 V: Visitor<'de>,
313 {
314 visitor.visit_newtype_struct(self)
315 }
316
317 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
318 where
319 V: Visitor<'de>,
320 {
321 struct Access<'a, R>
322 where
323 R: 'a + io::BufRead,
324 {
325 deserializer: &'a mut Deserializer<R>,
326 }
327
328 impl<'de, 'a, R> serde::de::SeqAccess<'de> for Access<'a, R>
329 where
330 R: io::BufRead,
331 {
332 type Error = Error;
333
334 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
335 where
336 T: serde::de::DeserializeSeed<'de>,
337 {
338 match serde::de::DeserializeSeed::deserialize(seed, &mut *self.deserializer) {
339 Ok(v) => Ok(Some(v)),
340 Err(Error::Io(ref err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
341 Ok(None)
342 }
343 Err(err) => Err(err),
344 }
345 }
346 }
347
348 visitor.visit_seq(Access { deserializer: self })
349 }
350
351 fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value>
352 where
353 V: Visitor<'de>,
354 {
355 struct Access<'a, R>
356 where
357 R: 'a + io::BufRead,
358 {
359 deserializer: &'a mut Deserializer<R>,
360 len: usize,
361 }
362
363 impl<'de, 'a, R> serde::de::SeqAccess<'de> for Access<'a, R>
364 where
365 R: io::BufRead,
366 {
367 type Error = Error;
368
369 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
370 where
371 T: serde::de::DeserializeSeed<'de>,
372 {
373 if self.len == 0 {
374 return Ok(None);
375 }
376 self.len -= 1;
377 let value = serde::de::DeserializeSeed::deserialize(seed, &mut *self.deserializer)?;
378 Ok(Some(value))
379 }
380
381 fn size_hint(&self) -> Option<usize> {
382 Some(self.len)
383 }
384 }
385
386 visitor.visit_seq(Access {
387 deserializer: self,
388 len: len,
389 })
390 }
391
392 fn deserialize_tuple_struct<V>(
393 self,
394 _name: &'static str,
395 len: usize,
396 visitor: V,
397 ) -> Result<V::Value>
398 where
399 V: Visitor<'de>,
400 {
401 self.deserialize_tuple(len, visitor)
402 }
403
404 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
405 where
406 V: Visitor<'de>,
407 {
408 struct Access<'a, R>
409 where
410 R: 'a + io::BufRead,
411 {
412 deserializer: &'a mut Deserializer<R>,
413 }
414
415 impl<'de, 'a, R> serde::de::MapAccess<'de> for Access<'a, R>
416 where
417 R: io::BufRead,
418 {
419 type Error = Error;
420
421 fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
422 where
423 T: serde::de::DeserializeSeed<'de>,
424 {
425 match serde::de::DeserializeSeed::deserialize(seed, &mut *self.deserializer) {
426 Ok(v) => Ok(Some(v)),
427 Err(Error::Io(ref err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
428 Ok(None)
429 }
430 Err(err) => Err(err),
431 }
432 }
433
434 fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
435 where
436 T: serde::de::DeserializeSeed<'de>,
437 {
438 serde::de::DeserializeSeed::deserialize(seed, &mut *self.deserializer)
439 }
440 }
441
442 visitor.visit_map(Access { deserializer: self })
443 }
444
445 fn deserialize_struct<V>(
446 self,
447 _name: &'static str,
448 fields: &'static [&'static str],
449 visitor: V,
450 ) -> Result<V::Value>
451 where
452 V: Visitor<'de>,
453 {
454 self.deserialize_tuple(fields.len(), visitor)
455 }
456
457 fn deserialize_enum<V>(
458 self,
459 _name: &'static str,
460 _fields: &'static [&'static str],
461 visitor: V,
462 ) -> Result<V::Value>
463 where
464 V: Visitor<'de>,
465 {
466 impl<'de, 'a, R> serde::de::EnumAccess<'de> for &'a mut Deserializer<R>
467 where
468 R: io::BufRead,
469 {
470 type Error = Error;
471 type Variant = Self;
472
473 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
474 where
475 V: serde::de::DeserializeSeed<'de>,
476 {
477 let idx: u32 = serde::de::Deserialize::deserialize(&mut *self)?;
478 let val: Result<_> =
479 seed.deserialize(serde::de::IntoDeserializer::into_deserializer(idx));
480 Ok((val?, self))
481 }
482 }
483
484 impl<'de, 'a, R> serde::de::VariantAccess<'de> for &'a mut Deserializer<R>
485 where
486 R: io::BufRead,
487 {
488 type Error = Error;
489
490 fn unit_variant(self) -> Result<()> {
491 Ok(())
492 }
493
494 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
495 where
496 T: serde::de::DeserializeSeed<'de>,
497 {
498 serde::de::DeserializeSeed::deserialize(seed, self)
499 }
500
501 fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value>
502 where
503 V: serde::de::Visitor<'de>,
504 {
505 serde::de::Deserializer::deserialize_tuple(self, len, visitor)
506 }
507
508 fn struct_variant<V>(
509 self,
510 fields: &'static [&'static str],
511 visitor: V,
512 ) -> Result<V::Value>
513 where
514 V: serde::de::Visitor<'de>,
515 {
516 serde::de::Deserializer::deserialize_tuple(self, fields.len(), visitor)
517 }
518 }
519
520 visitor.visit_enum(self)
521 }
522
523 fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value>
524 where
525 V: serde::de::Visitor<'de>,
526 {
527 Err(Error::DeserializeAnyUnsupported)
528 }
529
530 fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
531 where
532 V: serde::de::Visitor<'de>,
533 {
534 Err(Error::DeserializeAnyUnsupported)
535 }
536}
537
538impl<'a> From<utf8::BufReadDecoderError<'a>> for Error {
539 fn from(_err: utf8::BufReadDecoderError) -> Self {
540 Error::InvalidUtf8
541 }
542}
543
544impl From<io::Error> for Error {
545 fn from(err: io::Error) -> Self {
546 Error::Io(err)
547 }
548}
549
550#[allow(deprecated)]
551impl fmt::Display for Error {
552 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
553 write!(f, "{}", match *self {
554 Error::DeserializeAnyUnsupported => "`bytekey` is not a self-describing format",
555 Error::UnexpectedEof => "encountered unexpected EOF when deserializing utf8",
556 Error::InvalidUtf8 => "attempted to deserialize invalid utf8",
557 Error::Io(ref err) => err.description(),
558 Error::Message(ref msg) => msg,
559 })
560 }
561}
562
563impl StdError for Error {
564 fn source(&self) -> Option<&(dyn StdError + 'static)> {
565 match *self {
566 Error::DeserializeAnyUnsupported => None,
567 Error::UnexpectedEof => None,
568 Error::InvalidUtf8 => None,
569 Error::Io(ref err) => Some(err),
570 Error::Message(ref _msg) => None,
571 }
572 }
573}
574
575impl serde::de::Error for Error {
576 fn custom<T: fmt::Display>(msg: T) -> Self {
577 Error::Message(msg.to_string())
578 }
579}