1use {
26 crate::require,
27 byteorder::{ByteOrder, ReadBytesExt},
28 serde::{
29 de::{EnumAccess, MapAccess, SeqAccess, VariantAccess},
30 Deserialize,
31 },
32 std::{
33 io::{Cursor, Seek, SeekFrom},
34 mem::size_of,
35 },
36 thiserror::Error,
37};
38
39pub fn from_slice<'de, B, T>(bytes: &'de [u8]) -> Result<T, DeserializerError>
47where
48 T: Deserialize<'de>,
49 B: ByteOrder,
50{
51 let mut deserializer = Deserializer::<B>::new(bytes);
52 T::deserialize(&mut deserializer)
53}
54
/// Errors produced while decoding with [`Deserializer`].
#[derive(Debug, Error)]
pub enum DeserializerError {
    /// An I/O error from the underlying cursor, e.g. a fixed-size read
    /// (integer, bool, length prefix) that ran past the end of the input.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),

    /// A length-prefixed string payload was not valid UTF-8.
    #[error("invalid utf8: {0}")]
    Utf8(#[from] std::str::Utf8Error),

    /// The requested type has no encoding in this format (e.g. `any`,
    /// `ignored_any`, floats, `char`, `Option`, identifiers).
    #[error("this type is not supported")]
    Unsupported,

    /// A sequence had more elements than the 255 a `u8` length prefix allows.
    /// NOTE(review): not constructed anywhere in this deserializer; presumably
    /// raised by the matching serializer — confirm.
    #[error("sequence too large ({0} elements), max supported is 255")]
    SequenceTooLarge(usize),

    /// A free-form message, created through `serde::de::Error::custom`.
    #[error("message: {0}")]
    Message(Box<str>),

    /// The enum variant could not be decoded, e.g. its tag byte is not below
    /// the number of declared variants.
    #[error("invalid enum variant, higher than expected variant range")]
    InvalidEnumVariant,

    /// A length-prefixed string or byte payload extended past the end of the
    /// input.
    #[error("eof")]
    Eof,
}
78
79pub struct Deserializer<'de, B>
80where
81 B: ByteOrder,
82{
83 cursor: Cursor<&'de [u8]>,
84 endian: std::marker::PhantomData<B>,
85}
86
87impl serde::de::Error for DeserializerError {
88 fn custom<T: std::fmt::Display>(msg: T) -> Self {
89 DeserializerError::Message(msg.to_string().into_boxed_str())
90 }
91}
92
93impl<'de, B> Deserializer<'de, B>
94where
95 B: ByteOrder,
96{
97 pub fn new(buffer: &'de [u8]) -> Self {
98 Self {
99 cursor: Cursor::new(buffer),
100 endian: std::marker::PhantomData,
101 }
102 }
103}
104
105impl<'de, B> serde::de::Deserializer<'de> for &'_ mut Deserializer<'de, B>
106where
107 B: ByteOrder,
108{
109 type Error = DeserializerError;
110
111 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
112 where
113 V: serde::de::Visitor<'de>,
114 {
115 Err(DeserializerError::Unsupported)
116 }
117
118 fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
119 where
120 V: serde::de::Visitor<'de>,
121 {
122 Err(DeserializerError::Unsupported)
123 }
124
125 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
126 where
127 V: serde::de::Visitor<'de>,
128 {
129 let value = self.cursor.read_u8().map_err(DeserializerError::from)?;
130 visitor.visit_bool(value != 0)
131 }
132
133 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
134 where
135 V: serde::de::Visitor<'de>,
136 {
137 let value = self.cursor.read_i8().map_err(DeserializerError::from)?;
138 visitor.visit_i8(value)
139 }
140
141 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
142 where
143 V: serde::de::Visitor<'de>,
144 {
145 let value = self
146 .cursor
147 .read_i16::<B>()
148 .map_err(DeserializerError::from)?;
149
150 visitor.visit_i16(value)
151 }
152
153 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
154 where
155 V: serde::de::Visitor<'de>,
156 {
157 let value = self
158 .cursor
159 .read_i32::<B>()
160 .map_err(DeserializerError::from)?;
161
162 visitor.visit_i32(value)
163 }
164
165 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
166 where
167 V: serde::de::Visitor<'de>,
168 {
169 let value = self
170 .cursor
171 .read_i64::<B>()
172 .map_err(DeserializerError::from)?;
173
174 visitor.visit_i64(value)
175 }
176
177 fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
178 where
179 V: serde::de::Visitor<'de>,
180 {
181 let value = self
182 .cursor
183 .read_i128::<B>()
184 .map_err(DeserializerError::from)?;
185
186 visitor.visit_i128(value)
187 }
188
189 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
190 where
191 V: serde::de::Visitor<'de>,
192 {
193 let value = self.cursor.read_u8().map_err(DeserializerError::from)?;
194 visitor.visit_u8(value)
195 }
196
197 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
198 where
199 V: serde::de::Visitor<'de>,
200 {
201 let value = self
202 .cursor
203 .read_u16::<B>()
204 .map_err(DeserializerError::from)?;
205
206 visitor.visit_u16(value)
207 }
208
209 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
210 where
211 V: serde::de::Visitor<'de>,
212 {
213 let value = self
214 .cursor
215 .read_u32::<B>()
216 .map_err(DeserializerError::from)?;
217
218 visitor.visit_u32(value)
219 }
220
221 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
222 where
223 V: serde::de::Visitor<'de>,
224 {
225 let value = self
226 .cursor
227 .read_u64::<B>()
228 .map_err(DeserializerError::from)?;
229
230 visitor.visit_u64(value)
231 }
232
233 fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
234 where
235 V: serde::de::Visitor<'de>,
236 {
237 let value = self
238 .cursor
239 .read_u128::<B>()
240 .map_err(DeserializerError::from)?;
241
242 visitor.visit_u128(value)
243 }
244
245 fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
246 where
247 V: serde::de::Visitor<'de>,
248 {
249 Err(DeserializerError::Unsupported)
250 }
251
252 fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
253 where
254 V: serde::de::Visitor<'de>,
255 {
256 Err(DeserializerError::Unsupported)
257 }
258
259 fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
260 where
261 V: serde::de::Visitor<'de>,
262 {
263 Err(DeserializerError::Unsupported)
264 }
265
266 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
267 where
268 V: serde::de::Visitor<'de>,
269 {
270 let len = self.cursor.read_u8().map_err(DeserializerError::from)? as u64;
271
272 self.cursor
280 .seek(SeekFrom::Current(len as i64))
281 .map_err(DeserializerError::from)?;
282
283 let buf = {
284 let buf = self.cursor.get_ref();
285 buf[(self.cursor.position() - len) as usize..]
286 .get(..len as usize)
287 .ok_or(DeserializerError::Eof)?
288 };
289
290 visitor.visit_borrowed_str(std::str::from_utf8(buf).map_err(DeserializerError::from)?)
291 }
292
293 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
294 where
295 V: serde::de::Visitor<'de>,
296 {
297 self.deserialize_str(visitor)
298 }
299
300 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
301 where
302 V: serde::de::Visitor<'de>,
303 {
304 let len = self.cursor.read_u8().map_err(DeserializerError::from)? as u64;
305
306 self.cursor
308 .seek(SeekFrom::Current(len as i64))
309 .map_err(DeserializerError::from)?;
310
311 let buf = {
312 let buf = self.cursor.get_ref();
313 buf[(self.cursor.position() - len) as usize..]
314 .get(..len as usize)
315 .ok_or(DeserializerError::Eof)?
316 };
317
318 visitor.visit_borrowed_bytes(buf)
319 }
320
321 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
322 where
323 V: serde::de::Visitor<'de>,
324 {
325 self.deserialize_bytes(visitor)
326 }
327
328 fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
329 where
330 V: serde::de::Visitor<'de>,
331 {
332 Err(DeserializerError::Unsupported)
333 }
334
335 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
336 where
337 V: serde::de::Visitor<'de>,
338 {
339 visitor.visit_unit()
340 }
341
342 fn deserialize_unit_struct<V>(
343 self,
344 _name: &'static str,
345 visitor: V,
346 ) -> Result<V::Value, Self::Error>
347 where
348 V: serde::de::Visitor<'de>,
349 {
350 visitor.visit_unit()
351 }
352
353 fn deserialize_newtype_struct<V>(
354 self,
355 _name: &'static str,
356 visitor: V,
357 ) -> Result<V::Value, Self::Error>
358 where
359 V: serde::de::Visitor<'de>,
360 {
361 visitor.visit_newtype_struct(self)
362 }
363
364 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
365 where
366 V: serde::de::Visitor<'de>,
367 {
368 let len = self.cursor.read_u8().map_err(DeserializerError::from)? as usize;
369 visitor.visit_seq(SequenceIterator::new(self, len))
370 }
371
372 fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
373 where
374 V: serde::de::Visitor<'de>,
375 {
376 visitor.visit_seq(SequenceIterator::new(self, len))
377 }
378
379 fn deserialize_tuple_struct<V>(
380 self,
381 _name: &'static str,
382 len: usize,
383 visitor: V,
384 ) -> Result<V::Value, Self::Error>
385 where
386 V: serde::de::Visitor<'de>,
387 {
388 visitor.visit_seq(SequenceIterator::new(self, len))
389 }
390
391 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
392 where
393 V: serde::de::Visitor<'de>,
394 {
395 let len = self.cursor.read_u8().map_err(DeserializerError::from)? as usize;
396 visitor.visit_map(SequenceIterator::new(self, len))
397 }
398
399 fn deserialize_struct<V>(
400 self,
401 _name: &'static str,
402 fields: &'static [&'static str],
403 visitor: V,
404 ) -> Result<V::Value, Self::Error>
405 where
406 V: serde::de::Visitor<'de>,
407 {
408 visitor.visit_seq(SequenceIterator::new(self, fields.len()))
409 }
410
411 fn deserialize_enum<V>(
412 self,
413 _name: &'static str,
414 variants: &'static [&'static str],
415 visitor: V,
416 ) -> Result<V::Value, Self::Error>
417 where
418 V: serde::de::Visitor<'de>,
419 {
420 let variant = self.cursor.read_u8().map_err(DeserializerError::from)?;
423 if variant >= variants.len() as u8 {
424 return Err(DeserializerError::InvalidEnumVariant);
425 }
426
427 visitor.visit_enum(Enum { de: self, variant })
428 }
429
430 fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
431 where
432 V: serde::de::Visitor<'de>,
433 {
434 Err(DeserializerError::Unsupported)
435 }
436}
437
438impl<'de, 'a, B: ByteOrder> VariantAccess<'de> for &'a mut Deserializer<'de, B> {
439 type Error = DeserializerError;
440
441 fn unit_variant(self) -> Result<(), Self::Error> {
442 Ok(())
443 }
444
445 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
446 where
447 T: serde::de::DeserializeSeed<'de>,
448 {
449 seed.deserialize(self)
450 }
451
452 fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
453 where
454 V: serde::de::Visitor<'de>,
455 {
456 visitor.visit_seq(SequenceIterator::new(self, len))
457 }
458
459 fn struct_variant<V>(
460 self,
461 fields: &'static [&'static str],
462 visitor: V,
463 ) -> Result<V::Value, Self::Error>
464 where
465 V: serde::de::Visitor<'de>,
466 {
467 visitor.visit_seq(SequenceIterator::new(self, fields.len()))
468 }
469}
470
471struct SequenceIterator<'de, 'a, B: ByteOrder> {
472 de: &'a mut Deserializer<'de, B>,
473 len: usize,
474}
475
476impl<'de, 'a, B: ByteOrder> SequenceIterator<'de, 'a, B> {
477 fn new(de: &'a mut Deserializer<'de, B>, len: usize) -> Self {
478 Self { de, len }
479 }
480}
481
482impl<'de, 'a, B: ByteOrder> SeqAccess<'de> for SequenceIterator<'de, 'a, B> {
483 type Error = DeserializerError;
484
485 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
486 where
487 T: serde::de::DeserializeSeed<'de>,
488 {
489 if self.len == 0 {
490 return Ok(None);
491 }
492
493 self.len -= 1;
494 seed.deserialize(&mut *self.de).map(Some)
495 }
496
497 fn size_hint(&self) -> Option<usize> {
498 Some(self.len)
499 }
500}
501
502impl<'de, 'a, B: ByteOrder> MapAccess<'de> for SequenceIterator<'de, 'a, B> {
503 type Error = DeserializerError;
504
505 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
506 where
507 K: serde::de::DeserializeSeed<'de>,
508 {
509 if self.len == 0 {
510 return Ok(None);
511 }
512
513 self.len -= 1;
514 seed.deserialize(&mut *self.de).map(Some)
515 }
516
517 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
518 where
519 V: serde::de::DeserializeSeed<'de>,
520 {
521 seed.deserialize(&mut *self.de)
522 }
523
524 fn size_hint(&self) -> Option<usize> {
525 Some(self.len)
526 }
527}
528
529struct Enum<'de, 'a, B: ByteOrder> {
530 de: &'a mut Deserializer<'de, B>,
531 variant: u8,
532}
533
534impl<'de, 'a, B: ByteOrder> EnumAccess<'de> for Enum<'de, 'a, B> {
535 type Error = DeserializerError;
536 type Variant = &'a mut Deserializer<'de, B>;
537
538 fn variant_seed<V>(self, _: V) -> Result<(V::Value, Self::Variant), Self::Error>
539 where
540 V: serde::de::DeserializeSeed<'de>,
541 {
542 require!(
564 size_of::<u8>() >= size_of::<V::Value>(),
565 DeserializerError::InvalidEnumVariant
566 );
567
568 Ok((
569 unsafe { std::mem::transmute_copy::<u8, V::Value>(&self.variant) },
570 self.de,
571 ))
572 }
573}