1use super::{
4 error::{Error, Result},
5 types::internal::{text_size, type_of, TypeId},
6 types::{Field, Label, SharedLabel, Type, TypeEnv, TypeInner},
7 CandidType,
8};
9#[cfg(feature = "bignum")]
10use super::{Int, Nat};
11use crate::{
12 binary_parser::{Header, Len, PrincipalBytes},
13 types::subtype::{subtype_with_config, Gamma, OptReport},
14};
15use anyhow::{anyhow, Context};
16use binread::BinRead;
17use byteorder::{LittleEndian, ReadBytesExt};
18use serde::de::{self, Visitor};
19use std::fmt::Write;
20use std::{collections::VecDeque, io::Cursor, mem::replace, rc::Rc};
21
/// Size limit passed to `text_size` when building error messages: a type is
/// printed in full only if `text_size(ty, MAX_TYPE_LEN)` succeeds (or when
/// `full_error_message` is enabled).
const MAX_TYPE_LEN: i32 = 500;
23
/// Decodes the arguments of a Candid message one value at a time.
pub struct IDLDeserialize<'de> {
    // Stateful decoder holding the input cursor, type table and quotas.
    de: Deserializer<'de>,
}
28impl<'de> IDLDeserialize<'de> {
29 pub fn new(bytes: &'de [u8]) -> Result<Self> {
31 let config = DecoderConfig::new();
32 Self::new_with_config(bytes, &config)
33 }
34 pub fn new_with_config(bytes: &'de [u8], config: &DecoderConfig) -> Result<Self> {
36 let mut de = Deserializer::from_bytes(bytes, config).with_context(|| {
37 if config.full_error_message || bytes.len() <= 500 {
38 format!("Cannot parse header {}", &hex::encode(bytes))
39 } else {
40 "Cannot parse header".to_string()
41 }
42 })?;
43 de.add_cost((de.input.position() as usize).saturating_mul(4))?;
44 Ok(IDLDeserialize { de })
45 }
46 pub fn get_value<T>(&mut self) -> Result<T>
48 where
49 T: de::Deserialize<'de> + CandidType,
50 {
51 self.de.is_untyped = false;
52 self.deserialize_with_type(T::ty())
53 }
54 #[cfg_attr(docsrs, doc(cfg(feature = "value")))]
55 #[cfg(feature = "value")]
56 pub fn get_value_with_type(
58 &mut self,
59 env: &TypeEnv,
60 expected_type: &Type,
61 ) -> Result<crate::types::value::IDLValue> {
62 Rc::make_mut(&mut self.de.table).merge(env)?;
63 self.de.is_untyped = true;
64 self.deserialize_with_type(expected_type.clone())
65 }
66 fn deserialize_with_type<T>(&mut self, expected_type: Type) -> Result<T>
67 where
68 T: de::Deserialize<'de> + CandidType,
69 {
70 let expected_type = self
71 .de
72 .table
73 .trace_type_with_depth(&expected_type, &self.de.recursion_depth)?;
74 if self.de.types.is_empty() {
75 if matches!(
76 expected_type.as_ref(),
77 TypeInner::Opt(_) | TypeInner::Reserved | TypeInner::Null
78 ) {
79 self.de.expect_type = expected_type;
80 self.de.wire_type = TypeInner::Null.into();
81 return T::deserialize(&mut self.de);
82 } else if self.de.config.full_error_message
83 || text_size(&expected_type, MAX_TYPE_LEN).is_ok()
84 {
85 return Err(Error::msg(format!(
86 "No more values on the wire, the expected type {expected_type} is not opt, null, or reserved"
87 )));
88 } else {
89 return Err(Error::msg("No more values on the wire"));
90 }
91 }
92
93 let (ind, ty) = self.de.types.pop_front().unwrap();
94 self.de.expect_type = if matches!(expected_type.as_ref(), TypeInner::Unknown) {
95 self.de.is_untyped = true;
96 ty.clone()
97 } else {
98 expected_type.clone()
99 };
100 self.de.wire_type = ty.clone();
101
102 let mut v = T::deserialize(&mut self.de).with_context(|| {
103 if self.de.config.full_error_message
104 || (text_size(&ty, MAX_TYPE_LEN).is_ok()
105 && text_size(&expected_type, MAX_TYPE_LEN).is_ok())
106 {
107 format!("Fail to decode argument {ind} from {ty} to {expected_type}")
108 } else {
109 format!("Fail to decode argument {ind}")
110 }
111 });
112 if self.de.config.full_error_message {
113 v = v.with_context(|| self.de.dump_state());
114 }
115 Ok(v?)
116 }
117 pub fn is_done(&self) -> bool {
119 self.de.types.is_empty()
120 }
121 pub fn done(&mut self) -> Result<()> {
123 while !self.is_done() {
124 self.get_value::<crate::Reserved>()?;
125 }
126 let ind = self.de.input.position() as usize;
127 let rest = &self.de.input.get_ref()[ind..];
128 if !rest.is_empty() {
129 if !self.de.config.full_error_message {
130 return Err(Error::msg("Trailing value after finishing deserialization"));
131 } else {
132 return Err(anyhow!(self.de.dump_state()))
133 .context("Trailing value after finishing deserialization")?;
134 }
135 }
136 Ok(())
137 }
138 pub fn get_config(&self) -> DecoderConfig {
140 self.de.config.clone()
141 }
142}
143
/// Limits and diagnostics settings applied while decoding a Candid message.
#[derive(Clone)]
pub struct DecoderConfig {
    /// Remaining decoding budget; `None` means unlimited. Every unit of work
    /// charged by the decoder is deducted (50x for untyped values).
    pub decoding_quota: Option<usize>,
    /// Remaining budget for work done while decoding untyped values (e.g.
    /// skipped values); `None` means unlimited.
    pub skipping_quota: Option<usize>,
    /// Limit forwarded to the header parser; `None` means no limit.
    pub max_type_len: Option<usize>,
    // Whether errors include full type names and a dump of the decoder state.
    full_error_message: bool,
}

impl DecoderConfig {
    /// Creates a configuration with no quotas and no type-length limit.
    ///
    /// Full error messages are enabled, except on `wasm32` targets where they
    /// default to off.
    pub fn new() -> Self {
        Self {
            decoding_quota: None,
            skipping_quota: None,
            max_type_len: None,
            #[cfg(not(target_arch = "wasm32"))]
            full_error_message: true,
            #[cfg(target_arch = "wasm32")]
            full_error_message: false,
        }
    }

    /// Sets the decoding budget. Decoding fails with
    /// "Decoding cost exceeds the limit" once the charged cost exceeds it.
    pub fn set_decoding_quota(&mut self, n: usize) -> &mut Self {
        self.decoding_quota = Some(n);
        self
    }

    /// Sets the budget for untyped (skipped) decoding work. Decoding fails
    /// with "Skipping cost exceeds the limit" once it is exhausted.
    pub fn set_skipping_quota(&mut self, n: usize) -> &mut Self {
        self.skipping_quota = Some(n);
        self
    }

    /// Sets the limit forwarded to the header parser (`Header::read_args`).
    pub fn set_max_type_len(&mut self, n: usize) -> &mut Self {
        self.max_type_len = Some(n);
        self
    }

    /// Enables or disables full error messages (type names and state dumps).
    pub fn set_full_error_message(&mut self, n: bool) -> &mut Self {
        self.full_error_message = n;
        self
    }

    /// Returns the quota consumed between `original` (the configuration the
    /// decode started with) and `self` (the configuration after decoding).
    ///
    /// A quota is `None` in the result when it is `None` in either input.
    /// If `self` holds more quota than `original`, the result saturates at 0
    /// instead of overflowing.
    pub fn compute_cost(&self, original: &Self) -> Self {
        let consumed = |before: Option<usize>, after: Option<usize>| -> Option<usize> {
            Some(before?.saturating_sub(after?))
        };
        Self {
            decoding_quota: consumed(original.decoding_quota, self.decoding_quota),
            skipping_quota: consumed(original.skipping_quota, self.skipping_quota),
            max_type_len: original.max_type_len,
            full_error_message: original.full_error_message,
        }
    }
}

impl Default for DecoderConfig {
    fn default() -> Self {
        Self::new()
    }
}
253
/// Local replacement for `std::assert!`: instead of panicking, it returns an
/// "Internal error at file:line" `Err` from the enclosing function.
/// `assert!(false)` returns unconditionally, so it can serve as a match arm.
macro_rules! assert {
    ( false ) => {{
        return Err(Error::msg(format!(
            "Internal error at {}:{}. Please file a bug.",
            file!(),
            line!()
        )));
    }};
    ( $pred:expr ) => {{
        if !$pred {
            return Err(Error::msg(format!(
                "Internal error at {}:{}. Please file a bug.",
                file!(),
                line!()
            )));
        }
    }};
}
272
/// Returns an `Error::Subtype` from the enclosing function.
/// `check!(false)` always returns, with a "Type mismatch at file:line" message;
/// `check!(cond, msg)` returns only when `cond` is false, using `msg`.
macro_rules! check {
    ( false ) => {{
        return Err(Error::Subtype(format!(
            "Type mismatch at {}:{}",
            file!(),
            line!()
        )));
    }};
    ($exp:expr, $msg:expr) => {{
        if !$exp {
            return Err(Error::Subtype($msg.to_string()));
        }
    }};
}
287
/// Stateful serde deserializer over a Candid message body.
#[derive(Clone)]
struct Deserializer<'de> {
    // Cursor over the whole message; the position marks the next unread byte.
    input: Cursor<&'de [u8]>,
    // Type table from the header (extended by `get_value_with_type`).
    table: Rc<TypeEnv>,
    // Remaining argument types, paired with their argument index.
    types: VecDeque<(usize, Type)>,
    // Type of the value at the cursor, as announced on the wire.
    wire_type: Type,
    // Type the caller expects the current value to decode as.
    expect_type: Type,
    // State passed to `subtype_with_config` (presumably a memo of checked pairs).
    gamma: Gamma,
    // Pending record/variant label; while set, `deserialize_any` decodes it as
    // an identifier.
    field_name: Option<SharedLabel>,
    // When true, costs are charged 50x against the decoding quota and also
    // against the skipping quota.
    is_untyped: bool,
    // Working copy of the configuration; quotas are decremented in place.
    config: DecoderConfig,
    recursion_depth: crate::utils::RecursionDepth,
    // When set, the matching primitive `deserialize_*` skips type checks and
    // per-value cost (already charged by the enclosing vector).
    primitive_vec_fast_path: Option<PrimitiveType>,
    // When set, `deserialize_int`/`deserialize_nat` skip type unrolling/checks.
    #[cfg(feature = "bignum")]
    bignum_vec_fast_path: Option<BigNumFastPath>,
    // When true, `deserialize_str` skips the text type check.
    text_fast_path: bool,
}
310
311impl<'de> Deserializer<'de> {
312 fn from_bytes(bytes: &'de [u8], config: &DecoderConfig) -> Result<Self> {
313 let mut reader = Cursor::new(bytes);
314 let header = Header::read_args(&mut reader, (config.max_type_len,))?;
315 let (env, types) = header.to_types()?;
316 Ok(Deserializer {
317 input: reader,
318 table: env.into(),
319 types: types.into_iter().enumerate().collect(),
320 wire_type: TypeInner::Unknown.into(),
321 expect_type: TypeInner::Unknown.into(),
322 gamma: Gamma::default(),
323 field_name: None,
324 is_untyped: false,
325 config: config.clone(),
326 recursion_depth: crate::utils::RecursionDepth::new(),
327 primitive_vec_fast_path: None,
328 #[cfg(feature = "bignum")]
329 bignum_vec_fast_path: None,
330 text_fast_path: false,
331 })
332 }
333 fn dump_state(&self) -> String {
334 let hex = hex::encode(self.input.get_ref());
335 let pos = self.input.position() as usize * 2;
336 let (before, after) = hex.split_at(pos);
337 let mut res = format!("input: {before}_{after}\n");
338 if !self.table.0.is_empty() {
339 write!(&mut res, "table: {}", self.table).unwrap();
340 }
341 write!(
342 &mut res,
343 "wire_type: {}, expect_type: {}",
344 self.wire_type, self.expect_type
345 )
346 .unwrap();
347 if let Some(field) = &self.field_name {
348 write!(&mut res, ", field_name: {field:?}").unwrap();
349 }
350 res
351 }
352 #[inline]
353 fn read_leb_u64(&mut self) -> Result<u64> {
354 self.try_read_leb_u64()?
355 .ok_or_else(|| Error::msg("LEB128 overflow"))
356 }
357 #[inline]
360 fn try_read_leb_u64(&mut self) -> Result<Option<u64>> {
361 let slice = self.input.get_ref();
362 let mut pos = self.input.position() as usize;
363 let end = slice.len();
364 let mut result: u64 = 0;
365 let mut shift: u32 = 0;
366 loop {
367 if pos >= end {
368 return Err(Error::msg("unexpected end of LEB128"));
369 }
370 let byte = slice[pos];
371 pos += 1;
372 let low = (byte & 0x7f) as u64;
373 if shift < 64 {
374 result |= low << shift;
375 }
376 if byte & 0x80 == 0 {
377 self.input.set_position(pos as u64);
378 return Ok(Some(result));
379 }
380 shift += 7;
381 if shift >= 70 {
382 return Ok(None);
383 }
384 }
385 }
386 #[inline]
389 fn try_read_leb_i64(&mut self) -> Result<Option<i64>> {
390 let slice = self.input.get_ref();
391 let mut pos = self.input.position() as usize;
392 let end = slice.len();
393 let mut result: i64 = 0;
394 let mut shift: u32 = 0;
395 let mut byte;
396 loop {
397 if pos >= end {
398 return Err(Error::msg("unexpected end of LEB128"));
399 }
400 byte = slice[pos];
401 pos += 1;
402 let low = (byte & 0x7f) as i64;
403 if shift < 64 {
404 result |= low << shift;
405 }
406 shift += 7;
407 if byte & 0x80 == 0 {
408 break;
409 }
410 if shift >= 70 {
411 return Ok(None);
412 }
413 }
414 if shift < 64 && byte & 0x40 != 0 {
415 result |= !0i64 << shift;
416 }
417 self.input.set_position(pos as u64);
418 Ok(Some(result))
419 }
420 #[inline]
421 fn read_len(&mut self) -> Result<usize> {
422 let val = self.read_leb_u64()?;
423 usize::try_from(val).map_err(|_| Error::msg("length out of usize range"))
424 }
425 #[inline]
426 fn read_bool_val(&mut self) -> Result<bool> {
427 let byte = self.input.read_u8()?;
428 match byte {
429 0 => Ok(false),
430 1 => Ok(true),
431 _ => Err(Error::msg("Expect 00 or 01")),
432 }
433 }
434 #[inline]
435 fn borrow_bytes(&mut self, len: usize) -> Result<&'de [u8]> {
436 let pos = self.input.position() as usize;
437 let slice = self.input.get_ref();
438 if len > slice.len() || pos + len > slice.len() {
439 return Err(Error::msg(format!("Cannot read {len} bytes")));
440 }
441 let end = pos + len;
442 let res = &slice[pos..end];
443 self.input.set_position(end as u64);
444 Ok(res)
445 }
446 fn check_subtype(&mut self) -> Result<()> {
447 self.add_cost(self.table.0.len())?;
448 subtype_with_config(
449 OptReport::Silence,
450 &mut self.gamma,
451 &self.table,
452 &self.wire_type,
453 &self.expect_type,
454 )
455 .with_context(|| {
456 if self.config.full_error_message
457 || (text_size(&self.wire_type, MAX_TYPE_LEN).is_ok()
458 && text_size(&self.expect_type, MAX_TYPE_LEN).is_ok())
459 {
460 format!(
461 "{} is not a subtype of {}",
462 self.wire_type, self.expect_type,
463 )
464 } else {
465 "subtype mismatch".to_string()
466 }
467 })
468 .map_err(Error::subtype)?;
469 Ok(())
470 }
471 #[inline]
472 fn unroll_type(&mut self) -> Result<()> {
473 if matches!(
474 self.expect_type.as_ref(),
475 TypeInner::Var(_) | TypeInner::Knot(_)
476 ) {
477 self.add_cost(1)?;
478 self.expect_type = self
479 .table
480 .trace_type_with_depth(&self.expect_type, &self.recursion_depth)?;
481 }
482 if matches!(
483 self.wire_type.as_ref(),
484 TypeInner::Var(_) | TypeInner::Knot(_)
485 ) {
486 self.add_cost(1)?;
487 self.wire_type = self
488 .table
489 .trace_type_with_depth(&self.wire_type, &self.recursion_depth)?;
490 }
491 Ok(())
492 }
493 #[inline]
494 fn add_cost(&mut self, cost: usize) -> Result<()> {
495 if let Some(n) = self.config.decoding_quota {
496 let cost = if self.is_untyped {
497 cost.saturating_mul(50)
498 } else {
499 cost
500 };
501 if n < cost {
502 return Err(Error::msg("Decoding cost exceeds the limit"));
503 }
504 self.config.decoding_quota = Some(n - cost);
505 }
506 if self.is_untyped {
507 if let Some(n) = self.config.skipping_quota {
508 if n < cost {
509 return Err(Error::msg("Skipping cost exceeds the limit"));
510 }
511 self.config.skipping_quota = Some(n - cost);
512 }
513 }
514 Ok(())
515 }
516 fn set_field_name(&mut self, field: SharedLabel) {
519 if self.field_name.is_some() {
520 unreachable!();
521 }
522 self.field_name = Some(field);
523 }
524 #[cfg_attr(docsrs, doc(cfg(feature = "bignum")))]
531 #[cfg(feature = "bignum")]
532 fn deserialize_int<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
533 where
534 V: Visitor<'de>,
535 {
536 if self.bignum_vec_fast_path.is_none() {
537 self.unroll_type()?;
538 assert!(*self.expect_type == TypeInner::Int);
539 }
540 if !self.is_untyped {
541 let is_nat = matches!(self.wire_type.as_ref(), TypeInner::Nat);
542 let is_int = matches!(self.wire_type.as_ref(), TypeInner::Int);
543 if is_int {
544 let pos = self.input.position();
545 match self.try_read_leb_i64()? {
546 Some(value) => {
547 self.add_cost((self.input.position() - pos) as usize)?;
548 return visitor.visit_i64(value);
549 }
550 None => {
551 self.input.set_position(pos);
552 }
553 }
554 } else if is_nat {
555 let pos = self.input.position();
556 match self.try_read_leb_u64()? {
557 Some(value) => {
558 self.add_cost((self.input.position() - pos) as usize)?;
559 return visitor.visit_u64(value);
560 }
561 None => {
562 self.input.set_position(pos);
563 }
564 }
565 } else {
566 return Err(Error::subtype(format!(
567 "{} cannot be deserialized to int",
568 self.wire_type
569 )));
570 }
571 }
572 let bignum_pos = self.input.position();
573 let mut bytes = vec![0u8];
574 let int = match self.wire_type.as_ref() {
575 TypeInner::Int => Int::decode(&mut self.input).map_err(Error::msg)?,
576 TypeInner::Nat => Int(Nat::decode(&mut self.input).map_err(Error::msg)?.0.into()),
577 t => return Err(Error::subtype(format!("{t} cannot be deserialized to int"))),
578 };
579 self.add_cost((self.input.position() - bignum_pos) as usize)?;
580 bytes.extend_from_slice(&int.0.to_signed_bytes_le());
581 visitor.visit_byte_buf(bytes)
582 }
583 #[cfg_attr(docsrs, doc(cfg(feature = "bignum")))]
584 #[cfg(feature = "bignum")]
585 fn deserialize_nat<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
586 where
587 V: Visitor<'de>,
588 {
589 if self.bignum_vec_fast_path.is_none() {
590 self.unroll_type()?;
591 check!(
592 *self.expect_type == TypeInner::Nat && *self.wire_type == TypeInner::Nat,
593 "nat"
594 );
595 }
596 if !self.is_untyped {
597 let pos = self.input.position();
598 match self.try_read_leb_u64()? {
599 Some(value) => {
600 self.add_cost((self.input.position() - pos) as usize)?;
601 return visitor.visit_u64(value);
602 }
603 None => {
604 self.input.set_position(pos);
605 }
606 }
607 }
608 let pos = self.input.position();
609 let mut bytes = vec![1u8];
610 let nat = Nat::decode(&mut self.input).map_err(Error::msg)?;
611 self.add_cost((self.input.position() - pos) as usize)?;
612 bytes.extend_from_slice(&nat.0.to_bytes_le());
613 visitor.visit_byte_buf(bytes)
614 }
615 fn deserialize_principal<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
616 where
617 V: Visitor<'de>,
618 {
619 self.unroll_type()?;
620 check!(
621 *self.expect_type == TypeInner::Principal && *self.wire_type == TypeInner::Principal,
622 "principal"
623 );
624 let mut bytes = vec![2u8];
625 let id = PrincipalBytes::read(&mut self.input)?;
626 self.add_cost(std::cmp::max(30, id.len as usize))?;
627 bytes.extend_from_slice(&id.inner);
628 visitor.visit_byte_buf(bytes)
629 }
630 fn deserialize_reserved<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
631 where
632 V: Visitor<'de>,
633 {
634 self.add_cost(1)?;
635 let bytes = vec![3u8];
636 visitor.visit_byte_buf(bytes)
637 }
638 fn deserialize_service<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
639 where
640 V: Visitor<'de>,
641 {
642 self.unroll_type()?;
643 self.check_subtype()?;
644 let mut bytes = vec![4u8];
645 let id = PrincipalBytes::read(&mut self.input)?;
646 self.add_cost(std::cmp::max(30, id.len as usize))?;
647 bytes.extend_from_slice(&id.inner);
648 visitor.visit_byte_buf(bytes)
649 }
650 fn deserialize_function<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
651 where
652 V: Visitor<'de>,
653 {
654 self.unroll_type()?;
655 self.check_subtype()?;
656 if !self.read_bool_val()? {
657 return Err(Error::msg("Opaque reference not supported"));
658 }
659 let mut bytes = vec![5u8];
660 let id = PrincipalBytes::read(&mut self.input)?;
661 let len = self.read_len()?;
662 let meth = self.borrow_bytes(len)?;
663 self.add_cost(
664 std::cmp::max(30, id.len as usize)
665 .saturating_add(len)
666 .saturating_add(2),
667 )?;
668 leb128::write::unsigned(&mut bytes, len as u64)?;
670 bytes.extend_from_slice(meth);
671 bytes.extend_from_slice(&id.inner);
672 visitor.visit_byte_buf(bytes)
673 }
674 fn deserialize_blob<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
675 where
676 V: Visitor<'de>,
677 {
678 self.unroll_type()?;
679 check!(
680 self.expect_type.is_blob(&self.table) && self.wire_type.is_blob(&self.table),
681 "blob"
682 );
683 let len = self.read_len()?;
684 self.add_cost(len.saturating_add(1))?;
685 let blob = self.borrow_bytes(len)?;
686 let mut bytes = Vec::with_capacity(len + 1);
687 bytes.push(6u8);
688 bytes.extend_from_slice(blob);
689 visitor.visit_byte_buf(bytes)
690 }
691 fn deserialize_empty<'a, V>(&'a mut self, _visitor: V) -> Result<V::Value>
692 where
693 V: Visitor<'de>,
694 {
695 Err(if *self.wire_type == TypeInner::Empty {
696 Error::msg("Cannot decode empty type")
697 } else {
698 Error::subtype("Cannot decode empty type")
699 })
700 }
701 fn deserialize_future<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
702 where
703 V: Visitor<'de>,
704 {
705 let len = self.read_len()? as u64;
706 self.add_cost((len as usize).saturating_add(1))?;
707 self.read_len()?;
708 let slice_len = self.input.get_ref().len() as u64;
709 let pos = self.input.position();
710 if len > slice_len || pos + len > slice_len {
711 return Err(Error::msg(format!("Cannot read {len} bytes")));
712 }
713 self.input.set_position(pos + len);
714 visitor.visit_unit()
715 }
716 fn recoverable_visit_some<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
717 where
718 V: Visitor<'de>,
719 {
720 use de::Deserializer;
721 let tid = type_of(&visitor);
722 if tid != TypeId::of::<de::IgnoredAny>() && !tid.name.starts_with("serde::de::impls::OptionVisitor<")
728 && !tid.name.starts_with("serde_core::de::impls::OptionVisitor<")
731 {
732 #[cfg(feature = "value")]
733 if tid != TypeId::of::<crate::types::value::IDLValueVisitor>() {
734 panic!("Not a valid visitor: {tid:?}");
736 }
737 #[cfg(not(feature = "value"))]
738 panic!("Not a valid visitor: {tid:?}");
739 }
740 let v = unsafe { std::ptr::read(&visitor) };
742 let self_clone = self.clone();
743 match v.visit_some(&mut *self) {
744 Ok(v) => Ok(v),
745 Err(Error::Subtype(_)) => {
746 *self = Self {
747 config: self.config.clone(),
749 ..self_clone
750 };
751 self.add_cost(10)?;
752 self.deserialize_ignored_any(serde::de::IgnoredAny)?;
753 visitor.visit_none()
754 }
755 Err(e) => Err(e),
756 }
757 }
758}
759
/// Fixed-width primitive Candid types eligible for the vector fast path.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum PrimitiveType {
    Bool,
    Int8,
    Int16,
    Int32,
    Int64,
    Nat8,
    Nat16,
    Nat32,
    Nat64,
    Float32,
    Float64,
}

/// Returns the on-wire width, in bytes, of one value of primitive type `p`.
fn primitive_byte_cost(p: PrimitiveType) -> usize {
    match p {
        // A Candid bool occupies a single byte on the wire.
        PrimitiveType::Bool => std::mem::size_of::<u8>(),
        PrimitiveType::Int8 => std::mem::size_of::<i8>(),
        PrimitiveType::Nat8 => std::mem::size_of::<u8>(),
        PrimitiveType::Int16 => std::mem::size_of::<i16>(),
        PrimitiveType::Nat16 => std::mem::size_of::<u16>(),
        PrimitiveType::Int32 => std::mem::size_of::<i32>(),
        PrimitiveType::Nat32 => std::mem::size_of::<u32>(),
        PrimitiveType::Float32 => std::mem::size_of::<f32>(),
        PrimitiveType::Int64 => std::mem::size_of::<i64>(),
        PrimitiveType::Nat64 => std::mem::size_of::<u64>(),
        PrimitiveType::Float64 => std::mem::size_of::<f64>(),
    }
}
783
/// Element kind for the `bignum` vector fast path, selected by `deserialize_seq`
/// from the (expected, wire) element types.
#[cfg(feature = "bignum")]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum BigNumFastPath {
    // Expected `nat`, wire `nat`: decoded via `deserialize_nat`.
    Nat,
    // Expected `int`, wire `int`: decoded via `deserialize_int`.
    Int,
    // Expected `int`, wire `nat`: decoded via `deserialize_int`.
    NatAsInt,
}
791
792fn exact_primitive_type(expect: &Type, wire: &Type) -> Option<PrimitiveType> {
793 match (expect.as_ref(), wire.as_ref()) {
794 (TypeInner::Bool, TypeInner::Bool) => Some(PrimitiveType::Bool),
795 (TypeInner::Int8, TypeInner::Int8) => Some(PrimitiveType::Int8),
796 (TypeInner::Int16, TypeInner::Int16) => Some(PrimitiveType::Int16),
797 (TypeInner::Int32, TypeInner::Int32) => Some(PrimitiveType::Int32),
798 (TypeInner::Int64, TypeInner::Int64) => Some(PrimitiveType::Int64),
799 (TypeInner::Nat8, TypeInner::Nat8) => Some(PrimitiveType::Nat8),
800 (TypeInner::Nat16, TypeInner::Nat16) => Some(PrimitiveType::Nat16),
801 (TypeInner::Nat32, TypeInner::Nat32) => Some(PrimitiveType::Nat32),
802 (TypeInner::Nat64, TypeInner::Nat64) => Some(PrimitiveType::Nat64),
803 (TypeInner::Float32, TypeInner::Float32) => Some(PrimitiveType::Float32),
804 (TypeInner::Float64, TypeInner::Float64) => Some(PrimitiveType::Float64),
805 _ => None,
806 }
807}
808
/// Generates `deserialize_<ty>` for a fixed-width primitive.
///
/// Arguments are:
/// - `$ty`: the serde type name.
/// - `$type`: the Candid `TypeInner` that both sides must equal.
/// - `$fast`: the `PrimitiveType` tag for the vector fast path.
/// - `$cost`: the per-value cost.
/// - `$($value)*`: the cursor read method.
///
/// When `primitive_vec_fast_path` equals `$fast`, the type check and per-value
/// cost are skipped, because the enclosing vector has already charged them.
macro_rules! primitive_impl {
    ($ty:ident, $type:expr, $fast:expr, $cost:literal, $($value:tt)*) => {
        paste::item! {
            fn [<deserialize_ $ty>]<V>(self, visitor: V) -> Result<V::Value>
            where V: Visitor<'de> {
                if self.primitive_vec_fast_path == Some($fast) {
                    let val = self.input.$($value)*().map_err(|_| Error::msg(format!("Cannot read {} value", stringify!($type))))?;
                    return visitor.[<visit_ $ty>](val);
                }
                self.unroll_type()?;
                check!(*self.expect_type == $type && *self.wire_type == $type, stringify!($type));
                self.add_cost($cost)?;
                let val = self.input.$($value)*().map_err(|_| Error::msg(format!("Cannot read {} value", stringify!($type))))?;
                visitor.[<visit_ $ty>](val)
            }
        }
    };
}
827
828impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
// All decoding failures use this crate's `Error` type.
type Error = Error;
/// Dispatches on the (unrolled) expected type to the matching typed
/// `deserialize_*` method.
///
/// The following cases are handled before or during dispatch:
/// - A pending field name takes priority and is decoded as an identifier.
/// - With the `bignum` feature, an active bignum vector fast path bypasses
///   type dispatch.
/// - An expected `reserved` first skips the wire value, unless the wire type
///   is `reserved` too.
/// - Any other expected type reports an internal error.
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
where
    V: Visitor<'de>,
{
    if self.field_name.is_some() {
        return self.deserialize_identifier(visitor);
    }
    #[cfg(feature = "bignum")]
    if let Some(fast) = self.bignum_vec_fast_path {
        return match fast {
            BigNumFastPath::Nat => self.deserialize_nat(visitor),
            BigNumFastPath::Int | BigNumFastPath::NatAsInt => self.deserialize_int(visitor),
        };
    }
    self.unroll_type()?;
    match self.expect_type.as_ref() {
        #[cfg(feature = "bignum")]
        TypeInner::Int => self.deserialize_int(visitor),
        #[cfg(not(feature = "bignum"))]
        TypeInner::Int => self.deserialize_i128(visitor),
        #[cfg(feature = "bignum")]
        TypeInner::Nat => self.deserialize_nat(visitor),
        #[cfg(not(feature = "bignum"))]
        TypeInner::Nat => self.deserialize_u128(visitor),
        TypeInner::Nat8 => self.deserialize_u8(visitor),
        TypeInner::Nat16 => self.deserialize_u16(visitor),
        TypeInner::Nat32 => self.deserialize_u32(visitor),
        TypeInner::Nat64 => self.deserialize_u64(visitor),
        TypeInner::Int8 => self.deserialize_i8(visitor),
        TypeInner::Int16 => self.deserialize_i16(visitor),
        TypeInner::Int32 => self.deserialize_i32(visitor),
        TypeInner::Int64 => self.deserialize_i64(visitor),
        TypeInner::Float32 => self.deserialize_f32(visitor),
        TypeInner::Float64 => self.deserialize_f64(visitor),
        TypeInner::Bool => self.deserialize_bool(visitor),
        TypeInner::Text => self.deserialize_string(visitor),
        TypeInner::Null => self.deserialize_unit(visitor),
        TypeInner::Reserved => {
            // Any wire value is acceptable for `reserved`; consume it first.
            if self.wire_type.as_ref() != &TypeInner::Reserved {
                self.deserialize_ignored_any(serde::de::IgnoredAny)?;
            }
            self.deserialize_reserved(visitor)
        }
        TypeInner::Empty => self.deserialize_empty(visitor),
        TypeInner::Principal => self.deserialize_principal(visitor),
        TypeInner::Opt(_) => self.deserialize_option(visitor),
        // Must precede the general `Vec` arm: blobs take the byte-buffer path.
        TypeInner::Vec(_) if self.expect_type.is_blob(&self.table) => {
            self.deserialize_blob(visitor)
        }
        TypeInner::Vec(_) => self.deserialize_seq(visitor),
        TypeInner::Record(_) => self.deserialize_struct("_", &[], visitor),
        TypeInner::Variant(_) => self.deserialize_enum("_", &[], visitor),
        TypeInner::Service(_) => self.deserialize_service(visitor),
        TypeInner::Func(_) => self.deserialize_function(visitor),
        TypeInner::Future => self.deserialize_future(visitor),
        _ => assert!(false),
    }
}
890 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
891 where
892 V: Visitor<'de>,
893 {
894 let is_untyped = replace(&mut self.is_untyped, true);
895 self.expect_type = self.wire_type.clone();
896 let v = self.deserialize_any(visitor);
897 self.is_untyped = is_untyped;
898 v
899 }
900
// Fixed-width primitives: (serde type, Candid type, fast-path tag,
// per-value cost, little-endian reader).
primitive_impl!(i8, TypeInner::Int8, PrimitiveType::Int8, 1, read_i8);
primitive_impl!(
    i16,
    TypeInner::Int16,
    PrimitiveType::Int16,
    2,
    read_i16::<LittleEndian>
);
primitive_impl!(
    i32,
    TypeInner::Int32,
    PrimitiveType::Int32,
    4,
    read_i32::<LittleEndian>
);
primitive_impl!(
    i64,
    TypeInner::Int64,
    PrimitiveType::Int64,
    8,
    read_i64::<LittleEndian>
);
primitive_impl!(u8, TypeInner::Nat8, PrimitiveType::Nat8, 1, read_u8);
primitive_impl!(
    u16,
    TypeInner::Nat16,
    PrimitiveType::Nat16,
    2,
    read_u16::<LittleEndian>
);
primitive_impl!(
    u32,
    TypeInner::Nat32,
    PrimitiveType::Nat32,
    4,
    read_u32::<LittleEndian>
);
primitive_impl!(
    u64,
    TypeInner::Nat64,
    PrimitiveType::Nat64,
    8,
    read_u64::<LittleEndian>
);
primitive_impl!(
    f32,
    TypeInner::Float32,
    PrimitiveType::Float32,
    4,
    read_f32::<LittleEndian>
);
primitive_impl!(
    f64,
    TypeInner::Float64,
    PrimitiveType::Float64,
    8,
    read_f64::<LittleEndian>
);
959
/// The input is a binary encoding, so serde types should use their compact
/// (non-human-readable) representation.
fn is_human_readable(&self) -> bool {
    false
}
963 fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value>
964 where
965 V: Visitor<'de>,
966 {
967 use crate::types::leb128::{decode_int, decode_nat};
968 self.unroll_type()?;
969 assert!(*self.expect_type == TypeInner::Int);
970 self.add_cost(16)?;
971 let value: i128 = match self.wire_type.as_ref() {
972 TypeInner::Int => decode_int(&mut self.input)?,
973 TypeInner::Nat => i128::try_from(decode_nat(&mut self.input)?)
974 .map_err(|_| Error::msg("Cannot convert nat to i128"))?,
975 t => return Err(Error::subtype(format!("{t} cannot be deserialized to int"))),
976 };
977 visitor.visit_i128(value)
978 }
979 fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value>
980 where
981 V: Visitor<'de>,
982 {
983 self.unroll_type()?;
984 check!(
985 *self.expect_type == TypeInner::Nat && *self.wire_type == TypeInner::Nat,
986 "nat"
987 );
988 self.add_cost(16)?;
989 let value = crate::types::leb128::decode_nat(&mut self.input)?;
990 visitor.visit_u128(value)
991 }
992 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
993 where
994 V: Visitor<'de>,
995 {
996 self.unroll_type()?;
997 check!(
998 *self.expect_type == TypeInner::Null && matches!(*self.wire_type, TypeInner::Null),
999 "unit"
1000 );
1001 self.add_cost(1)?;
1002 visitor.visit_unit()
1003 }
1004 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
1007 where
1008 V: Visitor<'de>,
1009 {
1010 if self.primitive_vec_fast_path == Some(PrimitiveType::Bool) {
1011 let val = self.read_bool_val()?;
1012 return visitor.visit_bool(val);
1013 }
1014 self.unroll_type()?;
1015 check!(
1016 *self.expect_type == TypeInner::Bool && *self.wire_type == TypeInner::Bool,
1017 "bool"
1018 );
1019 self.add_cost(1)?;
1020 let val = self.read_bool_val()?;
1021 visitor.visit_bool(val)
1022 }
/// Same as `deserialize_str`; the visitor receives a borrowed `&str`.
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
where
    V: Visitor<'de>,
{
    self.deserialize_str(visitor)
}
1029 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
1030 where
1031 V: Visitor<'de>,
1032 {
1033 if !self.text_fast_path {
1034 self.unroll_type()?;
1035 check!(
1036 *self.expect_type == TypeInner::Text && *self.wire_type == TypeInner::Text,
1037 "text"
1038 );
1039 }
1040 let len = self.read_len()?;
1041 self.add_cost(len.saturating_add(1))?;
1042 let slice = self.borrow_bytes(len)?;
1043 let value: &str = std::str::from_utf8(slice).map_err(Error::msg)?;
1044 visitor.visit_borrowed_str(value)
1045 }
/// Unit structs are encoded as Candid `null`; charges one extra cost unit.
fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
    V: Visitor<'de>,
{
    self.add_cost(1)?;
    self.deserialize_unit(visitor)
}
/// Newtype structs are transparent: the inner value is decoded directly from
/// the current wire/expected types. Charges one cost unit.
fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
    V: Visitor<'de>,
{
    self.add_cost(1)?;
    visitor.visit_newtype_struct(self)
}
/// Decodes a Candid `opt`.
///
/// The wire type determines how the value is read:
/// - `null`/`reserved` on the wire: decodes as `None`.
/// - `opt` on the wire: reads the presence byte, and decodes the payload if
///   present.
/// - Any other wire type: treated as the payload of the expected `opt` and
///   decoded directly.
///
/// Payload decoding goes through `recoverable_visit_some`, which turns a
/// subtype mismatch into `None`. A non-`opt` expected type is a mismatch.
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
where
    V: Visitor<'de>,
{
    self.unroll_type()?;
    self.add_cost(1)?;
    match (self.wire_type.as_ref(), self.expect_type.as_ref()) {
        (TypeInner::Null | TypeInner::Reserved, TypeInner::Opt(_)) => visitor.visit_none(),
        (TypeInner::Opt(t1), TypeInner::Opt(t2)) => {
            self.wire_type = t1.clone();
            self.expect_type = t2.clone();
            if self.read_bool_val()? {
                let _guard = self.recursion_depth.guard()?;
                self.recoverable_visit_some(visitor)
            } else {
                visitor.visit_none()
            }
        }
        (_, TypeInner::Opt(t2)) => {
            // The wire value is decoded as the payload; the wire type is kept.
            self.expect_type = self
                .table
                .trace_type_with_depth(t2, &self.recursion_depth)?;
            let _guard = self.recursion_depth.guard()?;
            self.recoverable_visit_some(visitor)
        }
        (_, _) => check!(false),
    }
}
/// Decodes a Candid `vec` as a serde sequence. A tuple-shaped `record` is
/// also accepted, and its fields are decoded as a sequence (both sides must
/// be tuples).
///
/// Vectors whose element types are the same fixed-width primitive are
/// charged up front at `3 + width` per element:
/// - On little-endian targets, the elements are handed to the visitor
///   straight from the input slice via `PrimitiveVecAccess`.
/// - On other targets, `primitive_vec_fast_path` is set instead.
///
/// With the `bignum` feature, `nat`/`int` element vectors pre-charge 3 per
/// element and set `bignum_vec_fast_path`.
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
where
    V: Visitor<'de>,
{
    let _guard = self.recursion_depth.guard()?;
    self.unroll_type()?;
    self.add_cost(1)?;
    match (self.expect_type.as_ref(), self.wire_type.as_ref()) {
        (TypeInner::Vec(e), TypeInner::Vec(w)) => {
            let expect = e.clone();
            let wire = self.table.trace_type_with_depth(w, &self.recursion_depth)?;
            let len = self.read_len()?;
            let exact_primitive = exact_primitive_type(&expect, &wire);
            if let Some(prim) = exact_primitive {
                // Charge the whole vector once; per-element decoding is then free.
                let per_element_cost = 3 + primitive_byte_cost(prim);
                self.add_cost(
                    len.checked_mul(per_element_cost)
                        .ok_or_else(|| Error::msg("Vec length overflow"))?,
                )?;

                #[cfg(target_endian = "little")]
                {
                    let byte_size = primitive_byte_cost(prim);
                    let total_bytes = len
                        .checked_mul(byte_size)
                        .ok_or_else(|| Error::msg("Vec byte length overflow"))?;
                    let pos = self.input.position() as usize;
                    let slice = self.input.get_ref();
                    if pos + total_bytes > slice.len() {
                        return Err(Error::msg(format!(
                            "Not enough bytes for primitive vec: need {total_bytes}, have {}",
                            slice.len() - pos
                        )));
                    }
                    let data = &slice[pos..pos + total_bytes];
                    let mut access = PrimitiveVecAccess {
                        data,
                        offset: 0,
                        remaining: len,
                        element_size: byte_size,
                        prim,
                    };
                    let result = visitor.visit_seq(&mut access);
                    // `access.offset` presumably counts the bytes the visitor
                    // consumed; the cursor advances by exactly that amount.
                    self.input.set_position((pos + access.offset) as u64);
                    return result;
                }

                #[cfg(not(target_endian = "little"))]
                {
                    // NOTE(review): presumably cleared by `Compound` when the
                    // sequence ends (not visible here) — verify.
                    self.primitive_vec_fast_path = exact_primitive;
                }
            }
            #[cfg(feature = "bignum")]
            let bignum_fast = if exact_primitive.is_none() {
                match (expect.as_ref(), wire.as_ref()) {
                    (TypeInner::Nat, TypeInner::Nat) => Some(BigNumFastPath::Nat),
                    (TypeInner::Int, TypeInner::Int) => Some(BigNumFastPath::Int),
                    (TypeInner::Int, TypeInner::Nat) => Some(BigNumFastPath::NatAsInt),
                    _ => None,
                }
            } else {
                None
            };
            #[cfg(feature = "bignum")]
            if let Some(fast) = bignum_fast {
                self.add_cost(
                    len.checked_mul(3)
                        .ok_or_else(|| Error::msg("Vec length overflow"))?,
                )?;
                self.bignum_vec_fast_path = Some(fast);
                self.expect_type = expect.clone();
                self.wire_type = wire.clone();
            }
            let result = visitor.visit_seq(Compound::new(
                self,
                Style::Vector {
                    len,
                    expect,
                    wire,
                    exact_primitive,
                },
            ));
            result
        }
        (TypeInner::Record(_), TypeInner::Record(_)) => {
            let expect = self.expect_type.clone();
            let wire = self.wire_type.clone();
            check!(self.expect_type.is_tuple(), "seq_tuple");
            if !self.wire_type.is_tuple() {
                return Err(Error::subtype(format!(
                    "{} is not a tuple type",
                    self.wire_type
                )));
            }
            let value = visitor.visit_seq(Compound::new(
                self,
                Style::Struct {
                    expect,
                    wire,
                    expect_idx: 0,
                    wire_idx: 0,
                },
            ))?;
            Ok(value)
        }
        _ => check!(false),
    }
}
1198 fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
1199 self.unroll_type()?;
1200 check!(
1201 *self.expect_type == TypeInner::Vec(TypeInner::Nat8.into())
1202 && *self.wire_type == TypeInner::Vec(TypeInner::Nat8.into()),
1203 "vec nat8"
1204 );
1205 let len = self.read_len()?;
1206 self.add_cost(len.saturating_add(1))?;
1207 let bytes = self.borrow_bytes(len)?.to_owned();
1208 visitor.visit_byte_buf(bytes)
1209 }
1210 fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
1211 self.unroll_type()?;
1212 match self.expect_type.as_ref() {
1213 TypeInner::Principal => self.deserialize_principal(visitor),
1214 TypeInner::Vec(t) if **t == TypeInner::Nat8 => {
1215 let len = self.read_len()?;
1216 self.add_cost(len.saturating_add(1))?;
1217 let slice = self.borrow_bytes(len)?;
1218 visitor.visit_borrowed_bytes(slice)
1219 }
1220 _ => Err(Error::subtype("bytes only takes principal or vec nat8")),
1221 }
1222 }
1223 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
1224 where
1225 V: Visitor<'de>,
1226 {
1227 let _guard = self.recursion_depth.guard()?;
1228 self.unroll_type()?;
1229 self.add_cost(1)?;
1230 match (self.expect_type.as_ref(), self.wire_type.as_ref()) {
1231 (TypeInner::Vec(e), TypeInner::Vec(w)) => {
1232 let e = self.table.trace_type_with_depth(e, &self.recursion_depth)?;
1233 let w = self.table.trace_type_with_depth(w, &self.recursion_depth)?;
1234 match (e.as_ref(), w.as_ref()) {
1235 (TypeInner::Record(ref e), TypeInner::Record(ref w)) => {
1236 match (&e[..], &w[..]) {
1237 (
1238 [Field { id: e_id0, ty: ek }, Field { id: e_id1, ty: ev }],
1239 [Field { id: w_id0, ty: wk }, Field { id: w_id1, ty: wv }],
1240 ) if **e_id0 == Label::Id(0)
1241 && **e_id1 == Label::Id(1)
1242 && **w_id0 == Label::Id(0)
1243 && **w_id1 == Label::Id(1) =>
1244 {
1245 let expect = (ek.clone(), ev.clone());
1246 let wire = (wk.clone(), wv.clone());
1247 let len = self.read_len()?;
1248
1249 let key_text_fast = matches!(ek.as_ref(), TypeInner::Text)
1250 && matches!(wk.as_ref(), TypeInner::Text);
1251 #[cfg(feature = "bignum")]
1252 let value_bignum_fast = match (ev.as_ref(), wv.as_ref()) {
1253 (TypeInner::Nat, TypeInner::Nat) => Some(BigNumFastPath::Nat),
1254 (TypeInner::Int, TypeInner::Int) => Some(BigNumFastPath::Int),
1255 (TypeInner::Int, TypeInner::Nat) => {
1256 Some(BigNumFastPath::NatAsInt)
1257 }
1258 _ => None,
1259 };
1260 #[cfg(feature = "bignum")]
1261 let any_fast = key_text_fast || value_bignum_fast.is_some();
1262 #[cfg(not(feature = "bignum"))]
1263 let any_fast = key_text_fast;
1264
1265 if any_fast {
1266 self.add_cost(
1267 len.checked_mul(7)
1268 .ok_or_else(|| Error::msg("Map length overflow"))?,
1269 )?;
1270 if key_text_fast {
1271 self.text_fast_path = true;
1272 }
1273 #[cfg(feature = "bignum")]
1274 if let Some(fast) = value_bignum_fast {
1275 self.bignum_vec_fast_path = Some(fast);
1276 self.wire_type = wv.clone();
1277 }
1278 }
1279
1280 let result = visitor.visit_map(Compound::new(
1281 self,
1282 Style::Map { len, expect, wire },
1283 ));
1284 self.text_fast_path = false;
1285 #[cfg(feature = "bignum")]
1286 {
1287 self.bignum_vec_fast_path = None;
1288 }
1289 result
1290 }
1291 _ => Err(Error::subtype("expect a key-value pair")),
1292 }
1293 }
1294 _ => Err(Error::subtype("expect a key-value pair")),
1295 }
1296 }
1297 _ => check!(false),
1298 }
1299 }
1300 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
1301 where
1302 V: Visitor<'de>,
1303 {
1304 let _guard = self.recursion_depth.guard()?;
1305 self.add_cost(1)?;
1306 self.deserialize_seq(visitor)
1307 }
1308 fn deserialize_tuple_struct<V>(
1309 self,
1310 _name: &'static str,
1311 _len: usize,
1312 visitor: V,
1313 ) -> Result<V::Value>
1314 where
1315 V: Visitor<'de>,
1316 {
1317 let _guard = self.recursion_depth.guard()?;
1318 self.add_cost(1)?;
1319 self.deserialize_seq(visitor)
1320 }
1321 fn deserialize_struct<V>(
1322 self,
1323 _name: &'static str,
1324 _fields: &'static [&'static str],
1325 visitor: V,
1326 ) -> Result<V::Value>
1327 where
1328 V: Visitor<'de>,
1329 {
1330 let _guard = self.recursion_depth.guard()?;
1331 self.unroll_type()?;
1332 self.add_cost(1)?;
1333 match (self.expect_type.as_ref(), self.wire_type.as_ref()) {
1334 (TypeInner::Record(_), TypeInner::Record(_)) => {
1335 let value = visitor.visit_map(Compound::new(
1336 self,
1337 Style::Struct {
1338 expect: self.expect_type.clone(),
1339 wire: self.wire_type.clone(),
1340 expect_idx: 0,
1341 wire_idx: 0,
1342 },
1343 ))?;
1344 Ok(value)
1345 }
1346 _ => check!(false),
1347 }
1348 }
1349 fn deserialize_enum<V>(
1350 self,
1351 _name: &'static str,
1352 _variants: &'static [&'static str],
1353 visitor: V,
1354 ) -> Result<V::Value>
1355 where
1356 V: Visitor<'de>,
1357 {
1358 let _guard = self.recursion_depth.guard()?;
1359 self.unroll_type()?;
1360 self.add_cost(1)?;
1361 match (self.expect_type.as_ref(), self.wire_type.as_ref()) {
1362 (TypeInner::Variant(e), TypeInner::Variant(w)) => {
1363 let index = Len::read(&mut self.input)?.0;
1364 let len = w.len();
1365 if index >= len {
1366 return Err(Error::msg(format!(
1367 "Variant index {index} larger than length {len}"
1368 )));
1369 }
1370 let wire = w[index].clone();
1371 let expect = match e.iter().find(|f| f.id == wire.id) {
1372 Some(v) => v.clone(),
1373 None => {
1374 return Err(Error::subtype(format!("Unknown variant field {}", wire.id)));
1375 }
1376 };
1377 visitor.visit_enum(Compound::new(self, Style::Enum { expect, wire }))
1378 }
1379 _ => check!(false),
1380 }
1381 }
1382 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
1383 where
1384 V: Visitor<'de>,
1385 {
1386 match self.field_name.take() {
1387 Some(l) => match l.as_ref() {
1388 Label::Named(name) => {
1389 self.add_cost(name.len())?;
1390 visitor.visit_string(name.to_string())
1391 }
1392 Label::Id(hash) | Label::Unnamed(hash) => {
1393 self.add_cost(4)?;
1394 visitor.visit_u32(*hash)
1395 }
1396 },
1397 None => assert!(false),
1398 }
1399 }
1400
// Generates `deserialize_char`, which simply forwards to `deserialize_any`.
serde::forward_to_deserialize_any! {
    char
}
1404}
1405
/// The kind of composite value a [`Compound`] accessor is walking, together
/// with its iteration state.
#[derive(Debug)]
enum Style {
    /// A Candid `vec`. `len` elements remain, each decoded with the element
    /// types `expect` / `wire`.
    Vector {
        len: usize,
        expect: Type,
        wire: Type,
        /// When `Some`, `next_element_seed` does not charge the per-element cost.
        exact_primitive: Option<PrimitiveType>,
    },
    /// A record, also used for tuples. `expect_idx` / `wire_idx` are the
    /// positions of the next unconsumed field in each record type.
    Struct {
        expect: Type,
        wire: Type,
        expect_idx: usize,
        wire_idx: usize,
    },
    /// The variant field selected on the wire, and its expected counterpart.
    Enum {
        expect: Field,
        wire: Field,
    },
    /// A `vec record { 0: K; 1: V }` decoded as a map. `len` entries remain;
    /// the tuples hold the (key, value) types.
    Map {
        len: usize,
        expect: (Type, Type),
        wire: (Type, Type),
    },
}
1430
/// `SeqAccess` over a contiguous buffer of fixed-size primitive elements, read
/// as little-endian. It is used by the primitive fast path of `deserialize_seq`.
#[cfg(target_endian = "little")]
struct PrimitiveVecAccess<'de> {
    /// Encoded elements. `deserialize_seq` builds it as exactly
    /// `remaining * element_size` bytes borrowed from the input.
    data: &'de [u8],
    /// Byte offset of the next element within `data`.
    offset: usize,
    /// Number of elements not yet yielded.
    remaining: usize,
    /// Size in bytes of one encoded element.
    element_size: usize,
    /// The primitive type each element is decoded as.
    prim: PrimitiveType,
}
1439
1440#[cfg(target_endian = "little")]
1441impl<'de> de::SeqAccess<'de> for PrimitiveVecAccess<'de> {
1442 type Error = Error;
1443
1444 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
1445 where
1446 T: de::DeserializeSeed<'de>,
1447 {
1448 use serde::de::IntoDeserializer;
1449 if self.remaining == 0 {
1450 return Ok(None);
1451 }
1452 self.remaining -= 1;
1453 let bytes = &self.data[self.offset..self.offset + self.element_size];
1454 self.offset += self.element_size;
1455
1456 match self.prim {
1457 PrimitiveType::Bool => match bytes[0] {
1458 0 => seed.deserialize(false.into_deserializer()).map(Some),
1459 1 => seed.deserialize(true.into_deserializer()).map(Some),
1460 _ => Err(Error::msg("Expect 00 or 01")),
1461 },
1462 PrimitiveType::Nat8 => seed.deserialize(bytes[0].into_deserializer()).map(Some),
1463 PrimitiveType::Int8 => seed
1464 .deserialize((bytes[0] as i8).into_deserializer())
1465 .map(Some),
1466 PrimitiveType::Nat16 => {
1467 let v = u16::from_le_bytes(bytes.try_into().unwrap());
1468 seed.deserialize(v.into_deserializer()).map(Some)
1469 }
1470 PrimitiveType::Int16 => {
1471 let v = i16::from_le_bytes(bytes.try_into().unwrap());
1472 seed.deserialize(v.into_deserializer()).map(Some)
1473 }
1474 PrimitiveType::Nat32 => {
1475 let v = u32::from_le_bytes(bytes.try_into().unwrap());
1476 seed.deserialize(v.into_deserializer()).map(Some)
1477 }
1478 PrimitiveType::Int32 => {
1479 let v = i32::from_le_bytes(bytes.try_into().unwrap());
1480 seed.deserialize(v.into_deserializer()).map(Some)
1481 }
1482 PrimitiveType::Float32 => {
1483 let v = f32::from_le_bytes(bytes.try_into().unwrap());
1484 seed.deserialize(v.into_deserializer()).map(Some)
1485 }
1486 PrimitiveType::Nat64 => {
1487 let v = u64::from_le_bytes(bytes.try_into().unwrap());
1488 seed.deserialize(v.into_deserializer()).map(Some)
1489 }
1490 PrimitiveType::Int64 => {
1491 let v = i64::from_le_bytes(bytes.try_into().unwrap());
1492 seed.deserialize(v.into_deserializer()).map(Some)
1493 }
1494 PrimitiveType::Float64 => {
1495 let v = f64::from_le_bytes(bytes.try_into().unwrap());
1496 seed.deserialize(v.into_deserializer()).map(Some)
1497 }
1498 }
1499 }
1500
1501 fn size_hint(&self) -> Option<usize> {
1502 Some(self.remaining)
1503 }
1504}
1505
/// Serde accessor for composite Candid values.
///
/// Implements `SeqAccess`, `MapAccess`, `EnumAccess` and `VariantAccess`. It
/// positions the shared deserializer on each element's expected/wire types
/// according to `style`. Dropping it clears the deserializer's fast-path flags.
struct Compound<'a, 'de> {
    /// The deserializer that decodes each element.
    de: &'a mut Deserializer<'de>,
    /// The shape being walked, and how far along the walk is.
    style: Style,
}
1510
1511impl Style {
1512 fn struct_remaining(&self) -> Option<usize> {
1513 match self {
1514 Style::Struct {
1515 expect,
1516 wire,
1517 expect_idx,
1518 wire_idx,
1519 } => {
1520 let remaining = |ty: &Type, idx: usize| match ty.as_ref() {
1521 TypeInner::Record(fields) => fields.len().saturating_sub(idx),
1522 _ => 0,
1523 };
1524 Some(remaining(expect, *expect_idx).min(remaining(wire, *wire_idx)))
1525 }
1526 _ => None,
1527 }
1528 }
1529}
1530
1531impl<'a, 'de> Compound<'a, 'de> {
1532 fn new(de: &'a mut Deserializer<'de>, style: Style) -> Self {
1533 Compound { de, style }
1534 }
1535}
1536
impl<'de> de::SeqAccess<'de> for Compound<'_, 'de> {
    type Error = Error;

    /// Yields the next element of a vector or tuple.
    ///
    /// * `Style::Vector`: yields `len` elements, each decoded with the stored
    ///   element types. The per-element cost of 3 is skipped while a fast path
    ///   is active. For the bignum path, `deserialize_seq` charges `len * 3`
    ///   up front instead.
    /// * `Style::Struct` (tuple-shaped records): walks both records by position
    ///   and stops only when both are exhausted. A wire field with no expected
    ///   counterpart is decoded against `reserved`. An expected field with no
    ///   wire counterpart is decoded as if the wire had sent `null`.
    /// * Any other style is a subtype error.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        match self.style {
            Style::Vector {
                ref mut len,
                ref expect,
                ref wire,
                exact_primitive,
            } => {
                if *len == 0 {
                    return Ok(None);
                }
                *len -= 1;
                // Re-position on the element types: decoding the previous element
                // may have left nested types in `expect_type` / `wire_type`.
                self.de.expect_type = expect.clone();
                self.de.wire_type = wire.clone();
                #[cfg(feature = "bignum")]
                let is_fast = exact_primitive.is_some() || self.de.bignum_vec_fast_path.is_some();
                #[cfg(not(feature = "bignum"))]
                let is_fast = exact_primitive.is_some();
                if !is_fast {
                    self.de.add_cost(3)?;
                }
                seed.deserialize(&mut *self.de).map(Some)
            }
            Style::Struct {
                ref expect,
                ref wire,
                ref mut expect_idx,
                ref mut wire_idx,
            } => {
                self.de.add_cost(3)?;
                // `Style::Struct` is only constructed from record types.
                let expect_fields = match expect.as_ref() {
                    TypeInner::Record(fields) => fields,
                    _ => unreachable!(),
                };
                let wire_fields = match wire.as_ref() {
                    TypeInner::Record(fields) => fields,
                    _ => unreachable!(),
                };
                if *expect_idx >= expect_fields.len() && *wire_idx >= wire_fields.len() {
                    return Ok(None);
                }
                // Expected side exhausted: decode the remaining wire field as `reserved`.
                self.de.expect_type = expect_fields
                    .get(*expect_idx)
                    .map(|f| {
                        *expect_idx += 1;
                        f.ty.clone()
                    })
                    .unwrap_or_else(|| TypeInner::Reserved.into());
                // Wire side exhausted: the expected field sees a `null` wire value.
                self.de.wire_type = wire_fields
                    .get(*wire_idx)
                    .map(|f| {
                        *wire_idx += 1;
                        f.ty.clone()
                    })
                    .unwrap_or_else(|| TypeInner::Null.into());
                seed.deserialize(&mut *self.de).map(Some)
            }
            _ => Err(Error::subtype("expect vector or tuple")),
        }
    }

    /// Remaining elements: exact for vectors, a lower bound for tuples.
    fn size_hint(&self) -> Option<usize> {
        match &self.style {
            Style::Vector { len, .. } => Some(*len),
            _ => self.style.struct_remaining(),
        }
    }
}
1611
1612impl Drop for Compound<'_, '_> {
1613 fn drop(&mut self) {
1614 self.de.primitive_vec_fast_path = None;
1615 self.de.text_fast_path = false;
1616 #[cfg(feature = "bignum")]
1617 {
1618 self.de.bignum_vec_fast_path = None;
1619 }
1620 }
1621}
1622
impl<'de> de::MapAccess<'de> for Compound<'_, 'de> {
    type Error = Error;
    /// Positions the deserializer on the next key and decodes it.
    ///
    /// * `Style::Struct`: merges the expected and wire record fields by
    ///   comparing `get_id()`. NOTE(review): this presumes both field lists are
    ///   sorted by id; that is not checked here.
    ///   - A field present on both sides pairs the two types.
    ///   - A field only in the expected record must be `opt`, `reserved` or
    ///     `null`, and is decoded against a `null` wire value.
    ///   - A field only on the wire is exposed under the name `"_"` and decoded
    ///     against `reserved`.
    /// * `Style::Map`: yields `len` keys.
    ///
    /// A cost of 4 is always charged on entry, even when no key is produced.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: de::DeserializeSeed<'de>,
    {
        self.de.add_cost(4)?;
        match self.style {
            Style::Struct {
                ref expect,
                ref wire,
                ref mut expect_idx,
                ref mut wire_idx,
            } => {
                // `Style::Struct` is only constructed from record types.
                let expect_fields = match expect.as_ref() {
                    TypeInner::Record(fields) => fields,
                    _ => unreachable!(),
                };
                let wire_fields = match wire.as_ref() {
                    TypeInner::Record(fields) => fields,
                    _ => unreachable!(),
                };
                match (expect_fields.get(*expect_idx), wire_fields.get(*wire_idx)) {
                    (Some(e), Some(w)) => {
                        use std::cmp::Ordering;
                        match e.id.get_id().cmp(&w.id.get_id()) {
                            Ordering::Equal => {
                                self.de.set_field_name(e.id.clone());
                                self.de.expect_type = e.ty.clone();
                                self.de.wire_type = w.ty.clone();
                                *expect_idx += 1;
                                *wire_idx += 1;
                            }
                            Ordering::Less => {
                                // The expected field is missing on the wire. By subtyping
                                // rules, its type can only be opt, reserved or null.
                                let field = e.id.clone();
                                self.de.set_field_name(field.clone());
                                let expect = e.ty.clone();
                                *expect_idx += 1;
                                self.de.expect_type = self
                                    .de
                                    .table
                                    .trace_type_with_depth(&expect, &self.de.recursion_depth)?;
                                check!(
                                    matches!(
                                        self.de.expect_type.as_ref(),
                                        TypeInner::Opt(_) | TypeInner::Reserved | TypeInner::Null
                                    ),
                                    format!("field {field} is not optional field")
                                );
                                self.de.wire_type = TypeInner::Null.into();
                            }
                            Ordering::Greater => {
                                // Extra wire field: decode it against `reserved`.
                                self.de.set_field_name(Label::Named("_".to_owned()).into());
                                self.de.wire_type = w.ty.clone();
                                self.de.expect_type = TypeInner::Reserved.into();
                                *wire_idx += 1;
                            }
                        }
                    }
                    (None, Some(_)) => {
                        // Expected fields exhausted: remaining wire fields go to `reserved`.
                        self.de.set_field_name(Label::Named("_".to_owned()).into());
                        self.de.wire_type = wire_fields[*wire_idx].ty.clone();
                        self.de.expect_type = TypeInner::Reserved.into();
                        *wire_idx += 1;
                    }
                    (Some(e), None) => {
                        // Wire fields exhausted: the expected field sees a `null` wire value.
                        // NOTE(review): unlike the `Ordering::Less` arm, optionality is not
                        // checked here. Presumably the field's own deserializer rejects a
                        // non-optional type against wire `null`; confirm.
                        self.de.set_field_name(e.id.clone());
                        self.de.expect_type = e.ty.clone();
                        self.de.wire_type = TypeInner::Null.into();
                        *expect_idx += 1;
                    }
                    (None, None) => return Ok(None),
                }
                seed.deserialize(&mut *self.de).map(Some)
            }
            Style::Map {
                ref mut len,
                ref expect,
                ref wire,
            } => {
                if *len == 0 {
                    return Ok(None);
                }
                *len -= 1;
                // With a fast path active, `deserialize_map` already charged the
                // per-entry cost up front.
                #[cfg(feature = "bignum")]
                let any_fast = self.de.text_fast_path || self.de.bignum_vec_fast_path.is_some();
                #[cfg(not(feature = "bignum"))]
                let any_fast = self.de.text_fast_path;
                if !any_fast {
                    self.de.add_cost(4)?;
                }
                // With `text_fast_path` on, key types are not refreshed.
                // NOTE(review): this relies on the text fast path not consulting
                // `expect_type` / `wire_type`, which then still hold the previous
                // value's types. The flag also stays set while values are decoded;
                // confirm that a value whose wire type is not `text` cannot take it.
                if !self.de.text_fast_path {
                    self.de.expect_type = expect.0.clone();
                    self.de.wire_type = wire.0.clone();
                }
                seed.deserialize(&mut *self.de).map(Some)
            }
            _ => Err(Error::msg("expect struct or map")),
        }
    }
    /// Decodes the value belonging to the key just produced.
    ///
    /// For maps, a cost of 3 is charged unless a fast path is active. The value
    /// types are set unless the bignum value fast path is active. For structs,
    /// `next_key_seed` has already positioned the types, and only a cost of 1 is
    /// charged.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: de::DeserializeSeed<'de>,
    {
        match &self.style {
            Style::Map { expect, wire, .. } => {
                #[cfg(feature = "bignum")]
                let any_fast = self.de.text_fast_path || self.de.bignum_vec_fast_path.is_some();
                #[cfg(not(feature = "bignum"))]
                let any_fast = self.de.text_fast_path;
                if !any_fast {
                    self.de.add_cost(3)?;
                }
                #[cfg(feature = "bignum")]
                let value_fast = self.de.bignum_vec_fast_path.is_some();
                #[cfg(not(feature = "bignum"))]
                let value_fast = false;
                if !value_fast {
                    self.de.expect_type = expect.1.clone();
                    self.de.wire_type = wire.1.clone();
                }
                seed.deserialize(&mut *self.de)
            }
            _ => {
                self.de.add_cost(1)?;
                seed.deserialize(&mut *self.de)
            }
        }
    }

    /// Remaining entries: exact for maps, a lower bound for structs.
    fn size_hint(&self) -> Option<usize> {
        match &self.style {
            Style::Map { len, .. } => Some(*len),
            _ => self.style.struct_remaining(),
        }
    }
}
1761
1762impl<'de> de::EnumAccess<'de> for Compound<'_, 'de> {
1763 type Error = Error;
1764 type Variant = Self;
1765
1766 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
1767 where
1768 V: de::DeserializeSeed<'de>,
1769 {
1770 self.de.add_cost(4)?;
1771 match &self.style {
1772 Style::Enum { expect, wire } => {
1773 self.de.expect_type = expect.ty.clone();
1774 self.de.wire_type = wire.ty.clone();
1775 let (mut label, label_type) = match expect.id.as_ref() {
1776 Label::Named(name) => (name.clone(), "name"),
1777 Label::Id(hash) | Label::Unnamed(hash) => (hash.to_string(), "id"),
1778 };
1779 if self.de.is_untyped {
1780 let accessor = match expect.ty.as_ref() {
1781 TypeInner::Null => "unit",
1782 TypeInner::Record(_) => "struct",
1783 _ => "newtype",
1784 };
1785 write!(&mut label, ",{label_type},{accessor}").map_err(Error::msg)?;
1786 }
1787 self.de.set_field_name(Label::Named(label).into());
1788 let field = seed.deserialize(&mut *self.de)?;
1789 Ok((field, self))
1790 }
1791 _ => Err(Error::subtype("expect enum")),
1792 }
1793 }
1794}
1795
1796impl<'de> de::VariantAccess<'de> for Compound<'_, 'de> {
1797 type Error = Error;
1798
1799 fn unit_variant(self) -> Result<()> {
1800 check!(
1801 *self.de.expect_type == TypeInner::Null && *self.de.wire_type == TypeInner::Null,
1802 "unit_variant"
1803 );
1804 self.de.add_cost(1)?;
1805 Ok(())
1806 }
1807
1808 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
1809 where
1810 T: de::DeserializeSeed<'de>,
1811 {
1812 self.de.add_cost(1)?;
1813 seed.deserialize(&mut *self.de)
1814 }
1815
1816 fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value>
1817 where
1818 V: Visitor<'de>,
1819 {
1820 self.de.add_cost(1)?;
1821 de::Deserializer::deserialize_tuple(&mut *self.de, len, visitor)
1822 }
1823
1824 fn struct_variant<V>(self, fields: &'static [&'static str], visitor: V) -> Result<V::Value>
1825 where
1826 V: Visitor<'de>,
1827 {
1828 self.de.add_cost(1)?;
1829 de::Deserializer::deserialize_struct(&mut *self.de, "_", fields, visitor)
1830 }
1831}