1use {
26 crate::require,
27 byteorder::{
28 ByteOrder,
29 ReadBytesExt,
30 },
31 serde::{
32 de::{
33 EnumAccess,
34 MapAccess,
35 SeqAccess,
36 VariantAccess,
37 },
38 Deserialize,
39 },
40 std::{
41 io::{
42 Cursor,
43 Seek,
44 SeekFrom,
45 },
46 mem::size_of,
47 },
48 thiserror::Error,
49};
50
51pub fn from_slice<'de, B, T>(bytes: &'de [u8]) -> Result<T, DeserializerError>
59where
60 T: Deserialize<'de>,
61 B: ByteOrder,
62{
63 let mut deserializer = Deserializer::<B>::new(bytes);
64 T::deserialize(&mut deserializer)
65}
66
/// Errors produced while deserializing from the wire format.
#[derive(Debug, Error)]
pub enum DeserializerError {
    /// An I/O error reported while reading from the input cursor, for example when the input
    /// ends in the middle of a fixed-width value.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),

    /// A string payload was not valid UTF-8.
    #[error("invalid utf8: {0}")]
    Utf8(#[from] std::str::Utf8Error),

    /// The requested serde data type is not supported by this format.
    #[error("this type is not supported")]
    Unsupported,

    /// A sequence with the given number of elements exceeds the 255-element limit.
    // NOTE(review): not constructed anywhere in the visible part of this file; presumably shared
    // with the serializer side -- confirm.
    #[error("sequence too large ({0} elements), max supported is 255")]
    SequenceTooLarge(usize),

    /// A free-form message, used for errors raised through `serde::de::Error::custom`.
    #[error("message: {0}")]
    Message(Box<str>),

    /// An enum variant index read from the input is not valid for the target enum.
    #[error("invalid enum variant, higher than expected variant range")]
    InvalidEnumVariant,

    /// A length-prefixed string or byte payload claims more bytes than remain in the input.
    #[error("eof")]
    Eof,
}
90
/// A serde deserializer that reads values from a borrowed byte slice.
///
/// `B` selects the byte order used when reading multi-byte integers.
pub struct Deserializer<'de, B>
where
    B: ByteOrder,
{
    /// Read position over the borrowed input.
    cursor: Cursor<&'de [u8]>,
    /// Zero-sized marker that ties the byte order `B` to this deserializer.
    endian: std::marker::PhantomData<B>,
}
98
99impl serde::de::Error for DeserializerError {
100 fn custom<T: std::fmt::Display>(msg: T) -> Self {
101 DeserializerError::Message(msg.to_string().into_boxed_str())
102 }
103}
104
105impl<'de, B> Deserializer<'de, B>
106where
107 B: ByteOrder,
108{
109 pub fn new(buffer: &'de [u8]) -> Self {
110 Self {
111 cursor: Cursor::new(buffer),
112 endian: std::marker::PhantomData,
113 }
114 }
115}
116
117impl<'de, B> serde::de::Deserializer<'de> for &'_ mut Deserializer<'de, B>
118where
119 B: ByteOrder,
120{
121 type Error = DeserializerError;
122
123 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
124 where
125 V: serde::de::Visitor<'de>,
126 {
127 Err(DeserializerError::Unsupported)
128 }
129
130 fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
131 where
132 V: serde::de::Visitor<'de>,
133 {
134 Err(DeserializerError::Unsupported)
135 }
136
137 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
138 where
139 V: serde::de::Visitor<'de>,
140 {
141 let value = self.cursor.read_u8().map_err(DeserializerError::from)?;
142 visitor.visit_bool(value != 0)
143 }
144
145 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
146 where
147 V: serde::de::Visitor<'de>,
148 {
149 let value = self.cursor.read_i8().map_err(DeserializerError::from)?;
150 visitor.visit_i8(value)
151 }
152
153 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
154 where
155 V: serde::de::Visitor<'de>,
156 {
157 let value = self
158 .cursor
159 .read_i16::<B>()
160 .map_err(DeserializerError::from)?;
161
162 visitor.visit_i16(value)
163 }
164
165 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
166 where
167 V: serde::de::Visitor<'de>,
168 {
169 let value = self
170 .cursor
171 .read_i32::<B>()
172 .map_err(DeserializerError::from)?;
173
174 visitor.visit_i32(value)
175 }
176
177 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
178 where
179 V: serde::de::Visitor<'de>,
180 {
181 let value = self
182 .cursor
183 .read_i64::<B>()
184 .map_err(DeserializerError::from)?;
185
186 visitor.visit_i64(value)
187 }
188
189 fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
190 where
191 V: serde::de::Visitor<'de>,
192 {
193 let value = self
194 .cursor
195 .read_i128::<B>()
196 .map_err(DeserializerError::from)?;
197
198 visitor.visit_i128(value)
199 }
200
201 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
202 where
203 V: serde::de::Visitor<'de>,
204 {
205 let value = self.cursor.read_u8().map_err(DeserializerError::from)?;
206 visitor.visit_u8(value)
207 }
208
209 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
210 where
211 V: serde::de::Visitor<'de>,
212 {
213 let value = self
214 .cursor
215 .read_u16::<B>()
216 .map_err(DeserializerError::from)?;
217
218 visitor.visit_u16(value)
219 }
220
221 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
222 where
223 V: serde::de::Visitor<'de>,
224 {
225 let value = self
226 .cursor
227 .read_u32::<B>()
228 .map_err(DeserializerError::from)?;
229
230 visitor.visit_u32(value)
231 }
232
233 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
234 where
235 V: serde::de::Visitor<'de>,
236 {
237 let value = self
238 .cursor
239 .read_u64::<B>()
240 .map_err(DeserializerError::from)?;
241
242 visitor.visit_u64(value)
243 }
244
245 fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
246 where
247 V: serde::de::Visitor<'de>,
248 {
249 let value = self
250 .cursor
251 .read_u128::<B>()
252 .map_err(DeserializerError::from)?;
253
254 visitor.visit_u128(value)
255 }
256
257 fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
258 where
259 V: serde::de::Visitor<'de>,
260 {
261 Err(DeserializerError::Unsupported)
262 }
263
264 fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
265 where
266 V: serde::de::Visitor<'de>,
267 {
268 Err(DeserializerError::Unsupported)
269 }
270
271 fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
272 where
273 V: serde::de::Visitor<'de>,
274 {
275 Err(DeserializerError::Unsupported)
276 }
277
278 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
279 where
280 V: serde::de::Visitor<'de>,
281 {
282 let len = self.cursor.read_u8().map_err(DeserializerError::from)? as u64;
283
284 self.cursor
292 .seek(SeekFrom::Current(len as i64))
293 .map_err(DeserializerError::from)?;
294
295 let buf = {
296 let buf = self.cursor.get_ref();
297 buf[(self.cursor.position() - len) as usize..]
298 .get(..len as usize)
299 .ok_or(DeserializerError::Eof)?
300 };
301
302 visitor.visit_borrowed_str(std::str::from_utf8(buf).map_err(DeserializerError::from)?)
303 }
304
305 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
306 where
307 V: serde::de::Visitor<'de>,
308 {
309 self.deserialize_str(visitor)
310 }
311
312 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
313 where
314 V: serde::de::Visitor<'de>,
315 {
316 let len = self.cursor.read_u8().map_err(DeserializerError::from)? as u64;
317
318 self.cursor
320 .seek(SeekFrom::Current(len as i64))
321 .map_err(DeserializerError::from)?;
322
323 let buf = {
324 let buf = self.cursor.get_ref();
325 buf[(self.cursor.position() - len) as usize..]
326 .get(..len as usize)
327 .ok_or(DeserializerError::Eof)?
328 };
329
330 visitor.visit_borrowed_bytes(buf)
331 }
332
333 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
334 where
335 V: serde::de::Visitor<'de>,
336 {
337 self.deserialize_bytes(visitor)
338 }
339
340 fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
341 where
342 V: serde::de::Visitor<'de>,
343 {
344 Err(DeserializerError::Unsupported)
345 }
346
347 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
348 where
349 V: serde::de::Visitor<'de>,
350 {
351 visitor.visit_unit()
352 }
353
354 fn deserialize_unit_struct<V>(
355 self,
356 _name: &'static str,
357 visitor: V,
358 ) -> Result<V::Value, Self::Error>
359 where
360 V: serde::de::Visitor<'de>,
361 {
362 visitor.visit_unit()
363 }
364
365 fn deserialize_newtype_struct<V>(
366 self,
367 _name: &'static str,
368 visitor: V,
369 ) -> Result<V::Value, Self::Error>
370 where
371 V: serde::de::Visitor<'de>,
372 {
373 visitor.visit_newtype_struct(self)
374 }
375
376 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
377 where
378 V: serde::de::Visitor<'de>,
379 {
380 let len = self.cursor.read_u8().map_err(DeserializerError::from)? as usize;
381 visitor.visit_seq(SequenceIterator::new(self, len))
382 }
383
384 fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
385 where
386 V: serde::de::Visitor<'de>,
387 {
388 visitor.visit_seq(SequenceIterator::new(self, len))
389 }
390
391 fn deserialize_tuple_struct<V>(
392 self,
393 _name: &'static str,
394 len: usize,
395 visitor: V,
396 ) -> Result<V::Value, Self::Error>
397 where
398 V: serde::de::Visitor<'de>,
399 {
400 visitor.visit_seq(SequenceIterator::new(self, len))
401 }
402
403 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
404 where
405 V: serde::de::Visitor<'de>,
406 {
407 let len = self.cursor.read_u8().map_err(DeserializerError::from)? as usize;
408 visitor.visit_map(SequenceIterator::new(self, len))
409 }
410
411 fn deserialize_struct<V>(
412 self,
413 _name: &'static str,
414 fields: &'static [&'static str],
415 visitor: V,
416 ) -> Result<V::Value, Self::Error>
417 where
418 V: serde::de::Visitor<'de>,
419 {
420 visitor.visit_seq(SequenceIterator::new(self, fields.len()))
421 }
422
423 fn deserialize_enum<V>(
424 self,
425 _name: &'static str,
426 variants: &'static [&'static str],
427 visitor: V,
428 ) -> Result<V::Value, Self::Error>
429 where
430 V: serde::de::Visitor<'de>,
431 {
432 let variant = self.cursor.read_u8().map_err(DeserializerError::from)?;
435 if variant >= variants.len() as u8 {
436 return Err(DeserializerError::InvalidEnumVariant);
437 }
438
439 visitor.visit_enum(Enum { de: self, variant })
440 }
441
442 fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
443 where
444 V: serde::de::Visitor<'de>,
445 {
446 Err(DeserializerError::Unsupported)
447 }
448}
449
450impl<'de, 'a, B: ByteOrder> VariantAccess<'de> for &'a mut Deserializer<'de, B> {
451 type Error = DeserializerError;
452
453 fn unit_variant(self) -> Result<(), Self::Error> {
454 Ok(())
455 }
456
457 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
458 where
459 T: serde::de::DeserializeSeed<'de>,
460 {
461 seed.deserialize(self)
462 }
463
464 fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
465 where
466 V: serde::de::Visitor<'de>,
467 {
468 visitor.visit_seq(SequenceIterator::new(self, len))
469 }
470
471 fn struct_variant<V>(
472 self,
473 fields: &'static [&'static str],
474 visitor: V,
475 ) -> Result<V::Value, Self::Error>
476 where
477 V: serde::de::Visitor<'de>,
478 {
479 visitor.visit_seq(SequenceIterator::new(self, fields.len()))
480 }
481}
482
483struct SequenceIterator<'de, 'a, B: ByteOrder> {
484 de: &'a mut Deserializer<'de, B>,
485 len: usize,
486}
487
488impl<'de, 'a, B: ByteOrder> SequenceIterator<'de, 'a, B> {
489 fn new(de: &'a mut Deserializer<'de, B>, len: usize) -> Self {
490 Self { de, len }
491 }
492}
493
494impl<'de, 'a, B: ByteOrder> SeqAccess<'de> for SequenceIterator<'de, 'a, B> {
495 type Error = DeserializerError;
496
497 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
498 where
499 T: serde::de::DeserializeSeed<'de>,
500 {
501 if self.len == 0 {
502 return Ok(None);
503 }
504
505 self.len -= 1;
506 seed.deserialize(&mut *self.de).map(Some)
507 }
508
509 fn size_hint(&self) -> Option<usize> {
510 Some(self.len)
511 }
512}
513
514impl<'de, 'a, B: ByteOrder> MapAccess<'de> for SequenceIterator<'de, 'a, B> {
515 type Error = DeserializerError;
516
517 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
518 where
519 K: serde::de::DeserializeSeed<'de>,
520 {
521 if self.len == 0 {
522 return Ok(None);
523 }
524
525 self.len -= 1;
526 seed.deserialize(&mut *self.de).map(Some)
527 }
528
529 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
530 where
531 V: serde::de::DeserializeSeed<'de>,
532 {
533 seed.deserialize(&mut *self.de)
534 }
535
536 fn size_hint(&self) -> Option<usize> {
537 Some(self.len)
538 }
539}
540
541struct Enum<'de, 'a, B: ByteOrder> {
542 de: &'a mut Deserializer<'de, B>,
543 variant: u8,
544}
545
546impl<'de, 'a, B: ByteOrder> EnumAccess<'de> for Enum<'de, 'a, B> {
547 type Error = DeserializerError;
548 type Variant = &'a mut Deserializer<'de, B>;
549
550 fn variant_seed<V>(self, _: V) -> Result<(V::Value, Self::Variant), Self::Error>
551 where
552 V: serde::de::DeserializeSeed<'de>,
553 {
554 require!(
576 size_of::<u8>() >= size_of::<V::Value>(),
577 DeserializerError::InvalidEnumVariant
578 );
579
580 Ok((
581 unsafe { std::mem::transmute_copy::<u8, V::Value>(&self.variant) },
582 self.de,
583 ))
584 }
585}