1use super::{
4 error::{Error, Result},
5 types::internal::{text_size, type_of, TypeId},
6 types::{Field, Label, SharedLabel, Type, TypeEnv, TypeInner},
7 CandidType,
8};
9#[cfg(feature = "bignum")]
10use super::{Int, Nat};
11use crate::{
12 binary_parser::{BoolValue, Header, Len, PrincipalBytes},
13 types::subtype::{subtype_with_config, Gamma, OptReport},
14};
15use anyhow::{anyhow, Context};
16use binread::BinRead;
17use byteorder::{LittleEndian, ReadBytesExt};
18use serde::de::{self, Visitor};
19use std::fmt::Write;
20use std::{collections::VecDeque, io::Cursor, mem::replace, rc::Rc};
21
/// Size limit passed to `text_size` when deciding whether a type is small enough to be
/// printed inside an error message. Only consulted when `full_error_message` is disabled.
const MAX_TYPE_LEN: i32 = 500;
23
/// Argument-by-argument decoder for a Candid message.
///
/// The message header (type table and argument types) is parsed on construction. Arguments
/// are then decoded one at a time with `get_value` / `get_value_with_type`.
pub struct IDLDeserialize<'de> {
    // Shared decoding state.
    de: Deserializer<'de>,
}
28impl<'de> IDLDeserialize<'de> {
29 pub fn new(bytes: &'de [u8]) -> Result<Self> {
31 let config = DecoderConfig::new();
32 Self::new_with_config(bytes, &config)
33 }
34 pub fn new_with_config(bytes: &'de [u8], config: &DecoderConfig) -> Result<Self> {
36 let mut de = Deserializer::from_bytes(bytes, config).with_context(|| {
37 if config.full_error_message || bytes.len() <= 500 {
38 format!("Cannot parse header {}", &hex::encode(bytes))
39 } else {
40 "Cannot parse header".to_string()
41 }
42 })?;
43 de.add_cost((de.input.position() as usize).saturating_mul(4))?;
44 Ok(IDLDeserialize { de })
45 }
46 pub fn get_value<T>(&mut self) -> Result<T>
48 where
49 T: de::Deserialize<'de> + CandidType,
50 {
51 self.de.is_untyped = false;
52 self.deserialize_with_type(T::ty())
53 }
54 #[cfg_attr(docsrs, doc(cfg(feature = "value")))]
55 #[cfg(feature = "value")]
56 pub fn get_value_with_type(
58 &mut self,
59 env: &TypeEnv,
60 expected_type: &Type,
61 ) -> Result<crate::types::value::IDLValue> {
62 Rc::make_mut(&mut self.de.table).merge(env)?;
63 self.de.is_untyped = true;
64 self.deserialize_with_type(expected_type.clone())
65 }
    /// Decodes the next wire argument into `T`, with `expected_type` as the target type.
    ///
    /// If no wire arguments are left, decoding still succeeds when the expected type is
    /// `opt`, `reserved` or `null`: the value is decoded as if a `null` were on the wire.
    /// Otherwise an error is returned. If `expected_type` is `Unknown`, the wire type is used
    /// as the expected type and decoding switches to untyped mode.
    fn deserialize_with_type<T>(&mut self, expected_type: Type) -> Result<T>
    where
        T: de::Deserialize<'de> + CandidType,
    {
        let expected_type = self.de.table.trace_type(&expected_type)?;
        if self.de.types.is_empty() {
            // Missing trailing arguments are only tolerated for types that admit `null`.
            if matches!(
                expected_type.as_ref(),
                TypeInner::Opt(_) | TypeInner::Reserved | TypeInner::Null
            ) {
                self.de.expect_type = expected_type;
                self.de.wire_type = TypeInner::Null.into();
                return T::deserialize(&mut self.de);
            } else if self.de.config.full_error_message
                || text_size(&expected_type, MAX_TYPE_LEN).is_ok()
            {
                // Print the type only if full messages are on or it is small enough.
                return Err(Error::msg(format!(
                    "No more values on the wire, the expected type {expected_type} is not opt, null, or reserved"
                )));
            } else {
                return Err(Error::msg("No more values on the wire"));
            }
        }

        // Cannot fail: the empty case returned above.
        let (ind, ty) = self.de.types.pop_front().unwrap();
        // An `Unknown` expected type means "accept whatever is on the wire", decoded untyped.
        self.de.expect_type = if matches!(expected_type.as_ref(), TypeInner::Unknown) {
            self.de.is_untyped = true;
            ty.clone()
        } else {
            expected_type.clone()
        };
        self.de.wire_type = ty.clone();

        let mut v = T::deserialize(&mut self.de).with_context(|| {
            if self.de.config.full_error_message
                || (text_size(&ty, MAX_TYPE_LEN).is_ok()
                    && text_size(&expected_type, MAX_TYPE_LEN).is_ok())
            {
                format!("Fail to decode argument {ind} from {ty} to {expected_type}")
            } else {
                format!("Fail to decode argument {ind}")
            }
        });
        // The decoder-state dump is only attached when full error messages are enabled.
        if self.de.config.full_error_message {
            v = v.with_context(|| self.de.dump_state());
        }
        Ok(v?)
    }
114 pub fn is_done(&self) -> bool {
116 self.de.types.is_empty()
117 }
118 pub fn done(&mut self) -> Result<()> {
120 while !self.is_done() {
121 self.get_value::<crate::Reserved>()?;
122 }
123 let ind = self.de.input.position() as usize;
124 let rest = &self.de.input.get_ref()[ind..];
125 if !rest.is_empty() {
126 if !self.de.config.full_error_message {
127 return Err(Error::msg("Trailing value after finishing deserialization"));
128 } else {
129 return Err(anyhow!(self.de.dump_state()))
130 .context("Trailing value after finishing deserialization")?;
131 }
132 }
133 Ok(())
134 }
135 pub fn get_config(&self) -> DecoderConfig {
137 self.de.config.clone()
138 }
139}
140
/// Resource limits and diagnostics settings for decoding.
#[derive(Clone)]
pub struct DecoderConfig {
    /// Remaining budget for decoding work; `None` means unlimited. Every decoding step
    /// subtracts its cost, multiplied by 50 while decoding in untyped mode.
    pub decoding_quota: Option<usize>,
    /// Remaining budget for work done in untyped mode (e.g. skipping values or decoding
    /// `IDLValue`s); `None` means unlimited.
    pub skipping_quota: Option<usize>,
    /// Whether error messages may include full types, input hex and decoder-state dumps.
    full_error_message: bool,
}
148impl DecoderConfig {
149 pub fn new() -> Self {
153 Self {
154 decoding_quota: None,
155 skipping_quota: None,
156 #[cfg(not(target_arch = "wasm32"))]
157 full_error_message: true,
158 #[cfg(target_arch = "wasm32")]
159 full_error_message: false,
160 }
161 }
162 pub fn set_decoding_quota(&mut self, n: usize) -> &mut Self {
203 self.decoding_quota = Some(n);
204 self
205 }
206 pub fn set_skipping_quota(&mut self, n: usize) -> &mut Self {
212 self.skipping_quota = Some(n);
213 self
214 }
215 pub fn set_full_error_message(&mut self, n: bool) -> &mut Self {
219 self.full_error_message = n;
220 self
221 }
222 pub fn compute_cost(&self, original: &Self) -> Self {
224 let decoding_quota = original
225 .decoding_quota
226 .and_then(|n| Some(n - self.decoding_quota?));
227 let skipping_quota = original
228 .skipping_quota
229 .and_then(|n| Some(n - self.skipping_quota?));
230 Self {
231 decoding_quota,
232 skipping_quota,
233 full_error_message: original.full_error_message,
234 }
235 }
236}
237impl Default for DecoderConfig {
238 fn default() -> Self {
239 Self::new()
240 }
241}
242
// Local replacement for `std::assert!` (it shadows the std macro in this file). Instead of
// panicking, a violated internal invariant makes the enclosing function return an
// "Internal error at <file>:<line>. Please file a bug." error. The enclosing function must
// therefore return `Result<_, Error>`.
macro_rules! assert {
    // `assert!(false)` marks a branch that should be unreachable; it returns unconditionally.
    ( false ) => {{
        return Err(Error::msg(format!(
            "Internal error at {}:{}. Please file a bug.",
            file!(),
            line!()
        )));
    }};
    ( $pred:expr ) => {{
        if !$pred {
            return Err(Error::msg(format!(
                "Internal error at {}:{}. Please file a bug.",
                file!(),
                line!()
            )));
        }
    }};
}
261
// Type-check guard: returns an `Error::Subtype` from the enclosing function. Subtype errors
// are the ones `recoverable_visit_some` catches to decode an `opt` as `null` instead of
// failing.
macro_rules! check {
    // `check!(false)`: unconditional type mismatch, reported with the file and line.
    ( false ) => {{
        return Err(Error::Subtype(format!(
            "Type mismatch at {}:{}",
            file!(),
            line!()
        )));
    }};
    // `check!(cond, msg)`: type mismatch described by `msg` when `cond` is false.
    ($exp:expr, $msg:expr) => {{
        if !$exp {
            return Err(Error::Subtype($msg.to_string()));
        }
    }};
}
// Runs `$this $body` under a recursion guard (non-wasm targets).
//
// The depth counter is incremented first. The call then fails if `stacker` reports fewer than
// 32768 bytes of stack left, or, when the stack size cannot be determined, if the depth
// exceeds 512. The counter is decremented only after the body evaluates normally: a `?`
// inside the body, or the limit error itself, returns early without the decrement.
// `recoverable_visit_some` restores the whole decoder state, including the depth, on
// recovery.
#[cfg(not(target_arch = "wasm32"))]
macro_rules! check_recursion {
    ($this:ident $($body:tt)*) => {
        $this.recursion_depth += 1;
        match stacker::remaining_stack() {
            Some(size) if size < 32768 => return Err(Error::msg(format!("Recursion limit exceeded at depth {}", $this.recursion_depth))),
            None if $this.recursion_depth > 512 => return Err(Error::msg(format!("Recursion limit exceeded at depth {}. Cannot detect stack size, use a conservative bound", $this.recursion_depth))),
            _ => (),
        }
        let __ret = { $this $($body)* };
        $this.recursion_depth -= 1;
        __ret
    };
}
// On wasm32 there is no recursion guard: the body is expanded as-is.
#[cfg(target_arch = "wasm32")]
macro_rules! check_recursion {
    ($this:ident $($body:tt)*) => {
        $this $($body)*
    };
}
297
/// Decoding state shared by all `serde` callbacks of one message.
#[derive(Clone)]
struct Deserializer<'de> {
    /// Input bytes; positioned after the header once constructed.
    input: Cursor<&'de [u8]>,
    /// Type table from the header, plus any definitions merged in later.
    table: Rc<TypeEnv>,
    /// Not-yet-decoded argument types, each paired with its argument index.
    types: VecDeque<(usize, Type)>,
    /// Type of the value currently being decoded, as declared on the wire.
    wire_type: Type,
    /// Type the caller expects for the value currently being decoded.
    expect_type: Type,
    /// State threaded through `subtype_with_config` (presumably a cache of checked pairs).
    gamma: Gamma,
    /// Pending record field label, consumed by `deserialize_identifier`.
    field_name: Option<SharedLabel>,
    /// Whether the current value is decoded without a static Rust type (e.g. `IDLValue` or
    /// skipped values); such work is charged more heavily by `add_cost`.
    is_untyped: bool,
    /// Configuration, including the remaining quotas.
    config: DecoderConfig,
    /// Current nesting depth tracked by `check_recursion!`.
    #[cfg(not(target_arch = "wasm32"))]
    recursion_depth: u16,
}
317
318impl<'de> Deserializer<'de> {
319 fn from_bytes(bytes: &'de [u8], config: &DecoderConfig) -> Result<Self> {
320 let mut reader = Cursor::new(bytes);
321 let header = Header::read(&mut reader)?;
322 let (env, types) = header.to_types()?;
323 Ok(Deserializer {
324 input: reader,
325 table: env.into(),
326 types: types.into_iter().enumerate().collect(),
327 wire_type: TypeInner::Unknown.into(),
328 expect_type: TypeInner::Unknown.into(),
329 gamma: Gamma::default(),
330 field_name: None,
331 is_untyped: false,
332 config: config.clone(),
333 #[cfg(not(target_arch = "wasm32"))]
334 recursion_depth: 0,
335 })
336 }
337 fn dump_state(&self) -> String {
338 let hex = hex::encode(self.input.get_ref());
339 let pos = self.input.position() as usize * 2;
340 let (before, after) = hex.split_at(pos);
341 let mut res = format!("input: {before}_{after}\n");
342 if !self.table.0.is_empty() {
343 write!(&mut res, "table: {}", self.table).unwrap();
344 }
345 write!(
346 &mut res,
347 "wire_type: {}, expect_type: {}",
348 self.wire_type, self.expect_type
349 )
350 .unwrap();
351 if let Some(field) = &self.field_name {
352 write!(&mut res, ", field_name: {field:?}").unwrap();
353 }
354 res
355 }
356 fn borrow_bytes(&mut self, len: usize) -> Result<&'de [u8]> {
357 let pos = self.input.position() as usize;
358 let slice = self.input.get_ref();
359 if len > slice.len() || pos + len > slice.len() {
360 return Err(Error::msg(format!("Cannot read {len} bytes")));
361 }
362 let end = pos + len;
363 let res = &slice[pos..end];
364 self.input.set_position(end as u64);
365 Ok(res)
366 }
367 fn check_subtype(&mut self) -> Result<()> {
368 self.add_cost(self.table.0.len())?;
369 subtype_with_config(
370 OptReport::Silence,
371 &mut self.gamma,
372 &self.table,
373 &self.wire_type,
374 &self.expect_type,
375 )
376 .with_context(|| {
377 if self.config.full_error_message
378 || (text_size(&self.wire_type, MAX_TYPE_LEN).is_ok()
379 && text_size(&self.expect_type, MAX_TYPE_LEN).is_ok())
380 {
381 format!(
382 "{} is not a subtype of {}",
383 self.wire_type, self.expect_type,
384 )
385 } else {
386 "subtype mismatch".to_string()
387 }
388 })
389 .map_err(Error::subtype)?;
390 Ok(())
391 }
392 fn unroll_type(&mut self) -> Result<()> {
393 if matches!(
394 self.expect_type.as_ref(),
395 TypeInner::Var(_) | TypeInner::Knot(_)
396 ) {
397 self.add_cost(1)?;
398 self.expect_type = self.table.trace_type(&self.expect_type)?;
399 }
400 if matches!(
401 self.wire_type.as_ref(),
402 TypeInner::Var(_) | TypeInner::Knot(_)
403 ) {
404 self.add_cost(1)?;
405 self.wire_type = self.table.trace_type(&self.wire_type)?;
406 }
407 Ok(())
408 }
409 fn add_cost(&mut self, cost: usize) -> Result<()> {
410 if let Some(n) = self.config.decoding_quota {
411 let cost = if self.is_untyped {
412 cost.saturating_mul(50)
413 } else {
414 cost
415 };
416 if n < cost {
417 return Err(Error::msg("Decoding cost exceeds the limit"));
418 }
419 self.config.decoding_quota = Some(n - cost);
420 }
421 if self.is_untyped {
422 if let Some(n) = self.config.skipping_quota {
423 if n < cost {
424 return Err(Error::msg("Skipping cost exceeds the limit"));
425 }
426 self.config.skipping_quota = Some(n - cost);
427 }
428 }
429 Ok(())
430 }
431 fn set_field_name(&mut self, field: SharedLabel) {
434 if self.field_name.is_some() {
435 unreachable!();
436 }
437 self.field_name = Some(field);
438 }
439 #[cfg_attr(docsrs, doc(cfg(feature = "bignum")))]
446 #[cfg(feature = "bignum")]
447 fn deserialize_int<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
448 where
449 V: Visitor<'de>,
450 {
451 self.unroll_type()?;
452 assert!(*self.expect_type == TypeInner::Int);
453 let mut bytes = vec![0u8];
454 let pos = self.input.position();
455 let int = match self.wire_type.as_ref() {
456 TypeInner::Int => Int::decode(&mut self.input).map_err(Error::msg)?,
457 TypeInner::Nat => Int(Nat::decode(&mut self.input).map_err(Error::msg)?.0.into()),
458 t => return Err(Error::subtype(format!("{t} cannot be deserialized to int"))),
459 };
460 self.add_cost((self.input.position() - pos) as usize)?;
461 bytes.extend_from_slice(&int.0.to_signed_bytes_le());
462 visitor.visit_byte_buf(bytes)
463 }
464 #[cfg_attr(docsrs, doc(cfg(feature = "bignum")))]
465 #[cfg(feature = "bignum")]
466 fn deserialize_nat<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
467 where
468 V: Visitor<'de>,
469 {
470 self.unroll_type()?;
471 check!(
472 *self.expect_type == TypeInner::Nat && *self.wire_type == TypeInner::Nat,
473 "nat"
474 );
475 let mut bytes = vec![1u8];
476 let pos = self.input.position();
477 let nat = Nat::decode(&mut self.input).map_err(Error::msg)?;
478 self.add_cost((self.input.position() - pos) as usize)?;
479 bytes.extend_from_slice(&nat.0.to_bytes_le());
480 visitor.visit_byte_buf(bytes)
481 }
482 fn deserialize_principal<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
483 where
484 V: Visitor<'de>,
485 {
486 self.unroll_type()?;
487 check!(
488 *self.expect_type == TypeInner::Principal && *self.wire_type == TypeInner::Principal,
489 "principal"
490 );
491 let mut bytes = vec![2u8];
492 let id = PrincipalBytes::read(&mut self.input)?;
493 self.add_cost(std::cmp::max(30, id.len as usize))?;
494 bytes.extend_from_slice(&id.inner);
495 visitor.visit_byte_buf(bytes)
496 }
497 fn deserialize_reserved<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
498 where
499 V: Visitor<'de>,
500 {
501 self.add_cost(1)?;
502 let bytes = vec![3u8];
503 visitor.visit_byte_buf(bytes)
504 }
505 fn deserialize_service<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
506 where
507 V: Visitor<'de>,
508 {
509 self.unroll_type()?;
510 self.check_subtype()?;
511 let mut bytes = vec![4u8];
512 let id = PrincipalBytes::read(&mut self.input)?;
513 self.add_cost(std::cmp::max(30, id.len as usize))?;
514 bytes.extend_from_slice(&id.inner);
515 visitor.visit_byte_buf(bytes)
516 }
517 fn deserialize_function<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
518 where
519 V: Visitor<'de>,
520 {
521 self.unroll_type()?;
522 self.check_subtype()?;
523 if !BoolValue::read(&mut self.input)?.0 {
524 return Err(Error::msg("Opaque reference not supported"));
525 }
526 let mut bytes = vec![5u8];
527 let id = PrincipalBytes::read(&mut self.input)?;
528 let len = Len::read(&mut self.input)?.0;
529 let meth = self.borrow_bytes(len)?;
530 self.add_cost(
531 std::cmp::max(30, id.len as usize)
532 .saturating_add(len)
533 .saturating_add(2),
534 )?;
535 leb128::write::unsigned(&mut bytes, len as u64)?;
537 bytes.extend_from_slice(meth);
538 bytes.extend_from_slice(&id.inner);
539 visitor.visit_byte_buf(bytes)
540 }
541 fn deserialize_blob<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
542 where
543 V: Visitor<'de>,
544 {
545 self.unroll_type()?;
546 check!(
547 self.expect_type.is_blob(&self.table) && self.wire_type.is_blob(&self.table),
548 "blob"
549 );
550 let len = Len::read(&mut self.input)?.0;
551 self.add_cost(len.saturating_add(1))?;
552 let blob = self.borrow_bytes(len)?;
553 let mut bytes = Vec::with_capacity(len + 1);
554 bytes.push(6u8);
555 bytes.extend_from_slice(blob);
556 visitor.visit_byte_buf(bytes)
557 }
558 fn deserialize_empty<'a, V>(&'a mut self, _visitor: V) -> Result<V::Value>
559 where
560 V: Visitor<'de>,
561 {
562 Err(if *self.wire_type == TypeInner::Empty {
563 Error::msg("Cannot decode empty type")
564 } else {
565 Error::subtype("Cannot decode empty type")
566 })
567 }
568 fn deserialize_future<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
569 where
570 V: Visitor<'de>,
571 {
572 let len = Len::read(&mut self.input)?.0 as u64;
573 self.add_cost((len as usize).saturating_add(1))?;
574 Len::read(&mut self.input)?;
575 let slice_len = self.input.get_ref().len() as u64;
576 let pos = self.input.position();
577 if len > slice_len || pos + len > slice_len {
578 return Err(Error::msg(format!("Cannot read {len} bytes")));
579 }
580 self.input.set_position(pos + len);
581 visitor.visit_unit()
582 }
    /// Calls `visitor.visit_some` on the current value.
    ///
    /// If that fails with a subtype error, the decoder is restored to its state before the
    /// attempt; the quotas already spent stay spent. The wire value is then skipped
    /// (10 extra cost units), and `visitor.visit_none()` is returned instead. Any other
    /// error is propagated.
    ///
    /// # Panics
    /// Panics unless the visitor is `IgnoredAny`, serde's `OptionVisitor`, or (with the
    /// `value` feature) `IDLValueVisitor`.
    fn recoverable_visit_some<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        // Brings the serde trait into scope for `deserialize_ignored_any` below.
        use de::Deserializer;
        // Whitelist of visitor types, checked by type name. The `serde_core` prefix
        // presumably covers serde versions that moved the visitor there.
        let tid = type_of(&visitor);
        if tid != TypeId::of::<de::IgnoredAny>() && !tid.name.starts_with("serde::de::impls::OptionVisitor<")
            && !tid.name.starts_with("serde_core::de::impls::OptionVisitor<")
        {
            #[cfg(feature = "value")]
            if tid != TypeId::of::<crate::types::value::IDLValueVisitor>() {
                panic!("Not a valid visitor: {tid:?}");
            }
            #[cfg(not(feature = "value"))]
            panic!("Not a valid visitor: {tid:?}");
        }
        // SAFETY (as relied on here): `visitor` is bitwise-duplicated so that one copy can
        // be consumed by `visit_some` and the other by `visit_none` on the fallback path.
        // This is only acceptable because the whitelist above limits `V` to the visitor
        // types listed. NOTE(review): confirm none of them own resources or implement `Drop`.
        let v = unsafe { std::ptr::read(&visitor) };
        // Snapshot of the full decoder state (cursor, types, field name, depth) for rollback.
        let self_clone = self.clone();
        match v.visit_some(&mut *self) {
            Ok(v) => Ok(v),
            Err(Error::Subtype(_)) => {
                // Roll back everything except the config, so the cost charged so far remains.
                *self = Self {
                    config: self.config.clone(),
                    ..self_clone
                };
                self.add_cost(10)?;
                // Consume the value on the wire so the cursor stays in sync.
                self.deserialize_ignored_any(serde::de::IgnoredAny)?;
                visitor.visit_none()
            }
            Err(e) => Err(e),
        }
    }
625}
626
// Generates `deserialize_<ty>` for a fixed-width primitive.
//
// The generated method requires both the expected and the wire type to equal `$type`,
// charges `$cost` units, and reads the value with the given reader method (e.g.
// `read_i16::<LittleEndian>`). It then passes the value to `visit_<ty>`. A read failure is
// reported as "Cannot read <type> value".
macro_rules! primitive_impl {
    ($ty:ident, $type:expr, $cost:literal, $($value:tt)*) => {
        paste::item! {
            fn [<deserialize_ $ty>]<V>(self, visitor: V) -> Result<V::Value>
            where V: Visitor<'de> {
                self.unroll_type()?;
                check!(*self.expect_type == $type && *self.wire_type == $type, stringify!($type));
                self.add_cost($cost)?;
                let val = self.input.$($value)*().map_err(|_| Error::msg(format!("Cannot read {} value", stringify!($type))))?;
                visitor.[<visit_ $ty>](val)
            }
        }
    };
}
641
642impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
    /// Errors are reported with the module's `Error` type (from `super::error`).
    type Error = Error;
    /// Dispatches on the unrolled expected type to the matching `deserialize_*` method.
    ///
    /// A pending field name takes precedence: in that case the value being requested is a
    /// record-field identifier.
    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        if self.field_name.is_some() {
            return self.deserialize_identifier(visitor);
        }
        self.unroll_type()?;
        match self.expect_type.as_ref() {
            // With `bignum`, `int`/`nat` go to the arbitrary-precision decoders; without it,
            // they are decoded into 128-bit integers.
            #[cfg(feature = "bignum")]
            TypeInner::Int => self.deserialize_int(visitor),
            #[cfg(not(feature = "bignum"))]
            TypeInner::Int => self.deserialize_i128(visitor),
            #[cfg(feature = "bignum")]
            TypeInner::Nat => self.deserialize_nat(visitor),
            #[cfg(not(feature = "bignum"))]
            TypeInner::Nat => self.deserialize_u128(visitor),
            TypeInner::Nat8 => self.deserialize_u8(visitor),
            TypeInner::Nat16 => self.deserialize_u16(visitor),
            TypeInner::Nat32 => self.deserialize_u32(visitor),
            TypeInner::Nat64 => self.deserialize_u64(visitor),
            TypeInner::Int8 => self.deserialize_i8(visitor),
            TypeInner::Int16 => self.deserialize_i16(visitor),
            TypeInner::Int32 => self.deserialize_i32(visitor),
            TypeInner::Int64 => self.deserialize_i64(visitor),
            TypeInner::Float32 => self.deserialize_f32(visitor),
            TypeInner::Float64 => self.deserialize_f64(visitor),
            TypeInner::Bool => self.deserialize_bool(visitor),
            TypeInner::Text => self.deserialize_string(visitor),
            TypeInner::Null => self.deserialize_unit(visitor),
            TypeInner::Reserved => {
                // Any wire value can be decoded as `reserved`: skip it first unless it is
                // itself `reserved`.
                if self.wire_type.as_ref() != &TypeInner::Reserved {
                    self.deserialize_ignored_any(serde::de::IgnoredAny)?;
                }
                self.deserialize_reserved(visitor)
            }
            TypeInner::Empty => self.deserialize_empty(visitor),
            TypeInner::Principal => self.deserialize_principal(visitor),
            TypeInner::Opt(_) => self.deserialize_option(visitor),
            // Blobs get a compact byte-buffer representation instead of a sequence.
            TypeInner::Vec(_) if self.expect_type.is_blob(&self.table) => {
                self.deserialize_blob(visitor)
            }
            TypeInner::Vec(_) => self.deserialize_seq(visitor),
            TypeInner::Record(_) => self.deserialize_struct("_", &[], visitor),
            TypeInner::Variant(_) => self.deserialize_enum("_", &[], visitor),
            TypeInner::Service(_) => self.deserialize_service(visitor),
            TypeInner::Func(_) => self.deserialize_function(visitor),
            TypeInner::Future => self.deserialize_future(visitor),
            // Any other expected type (e.g. `Unknown`) is treated as an internal error.
            _ => assert!(false),
        }
    }
697 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
698 where
699 V: Visitor<'de>,
700 {
701 let is_untyped = replace(&mut self.is_untyped, true);
702 self.expect_type = self.wire_type.clone();
703 let v = self.deserialize_any(visitor);
704 self.is_untyped = is_untyped;
705 v
706 }
707
    // Fixed-width numbers: (serde type, Candid type, cost = byte width, reader method).
    // Multi-byte values are read in little-endian order.
    primitive_impl!(i8, TypeInner::Int8, 1, read_i8);
    primitive_impl!(i16, TypeInner::Int16, 2, read_i16::<LittleEndian>);
    primitive_impl!(i32, TypeInner::Int32, 4, read_i32::<LittleEndian>);
    primitive_impl!(i64, TypeInner::Int64, 8, read_i64::<LittleEndian>);
    primitive_impl!(u8, TypeInner::Nat8, 1, read_u8);
    primitive_impl!(u16, TypeInner::Nat16, 2, read_u16::<LittleEndian>);
    primitive_impl!(u32, TypeInner::Nat32, 4, read_u32::<LittleEndian>);
    primitive_impl!(u64, TypeInner::Nat64, 8, read_u64::<LittleEndian>);
    primitive_impl!(f32, TypeInner::Float32, 4, read_f32::<LittleEndian>);
    primitive_impl!(f64, TypeInner::Float64, 8, read_f64::<LittleEndian>);
718
719 fn is_human_readable(&self) -> bool {
720 false
721 }
722 fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value>
723 where
724 V: Visitor<'de>,
725 {
726 use crate::types::leb128::{decode_int, decode_nat};
727 self.unroll_type()?;
728 assert!(*self.expect_type == TypeInner::Int);
729 self.add_cost(16)?;
730 let value: i128 = match self.wire_type.as_ref() {
731 TypeInner::Int => decode_int(&mut self.input)?,
732 TypeInner::Nat => i128::try_from(decode_nat(&mut self.input)?)
733 .map_err(|_| Error::msg("Cannot convert nat to i128"))?,
734 t => return Err(Error::subtype(format!("{t} cannot be deserialized to int"))),
735 };
736 visitor.visit_i128(value)
737 }
738 fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value>
739 where
740 V: Visitor<'de>,
741 {
742 self.unroll_type()?;
743 check!(
744 *self.expect_type == TypeInner::Nat && *self.wire_type == TypeInner::Nat,
745 "nat"
746 );
747 self.add_cost(16)?;
748 let value = crate::types::leb128::decode_nat(&mut self.input)?;
749 visitor.visit_u128(value)
750 }
751 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
752 where
753 V: Visitor<'de>,
754 {
755 self.unroll_type()?;
756 check!(
757 *self.expect_type == TypeInner::Null && matches!(*self.wire_type, TypeInner::Null),
758 "unit"
759 );
760 self.add_cost(1)?;
761 visitor.visit_unit()
762 }
763 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
764 where
765 V: Visitor<'de>,
766 {
767 self.unroll_type()?;
768 check!(
769 *self.expect_type == TypeInner::Bool && *self.wire_type == TypeInner::Bool,
770 "bool"
771 );
772 self.add_cost(1)?;
773 let res = BoolValue::read(&mut self.input)?;
774 visitor.visit_bool(res.0)
775 }
776 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
777 where
778 V: Visitor<'de>,
779 {
780 self.unroll_type()?;
781 check!(
782 *self.expect_type == TypeInner::Text && *self.wire_type == TypeInner::Text,
783 "text"
784 );
785 let len = Len::read(&mut self.input)?.0;
786 self.add_cost(len.saturating_add(1))?;
787 let bytes = self.borrow_bytes(len)?.to_owned();
788 let value = String::from_utf8(bytes).map_err(Error::msg)?;
789 visitor.visit_string(value)
790 }
791 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
792 where
793 V: Visitor<'de>,
794 {
795 self.unroll_type()?;
796 check!(
797 *self.expect_type == TypeInner::Text && *self.wire_type == TypeInner::Text,
798 "text"
799 );
800 let len = Len::read(&mut self.input)?.0;
801 self.add_cost(len.saturating_add(1))?;
802 let slice = self.borrow_bytes(len)?;
803 let value: &str = std::str::from_utf8(slice).map_err(Error::msg)?;
804 visitor.visit_borrowed_str(value)
805 }
806 fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
807 where
808 V: Visitor<'de>,
809 {
810 self.add_cost(1)?;
811 self.deserialize_unit(visitor)
812 }
813 fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
814 where
815 V: Visitor<'de>,
816 {
817 self.add_cost(1)?;
818 visitor.visit_newtype_struct(self)
819 }
    /// Decodes an `opt` value.
    ///
    /// - `null`/`reserved` on the wire: decodes as `None` without reading input.
    /// - `opt` on the wire: reads the presence flag and, if set, decodes the content.
    /// - any other wire type: the wire value itself is decoded as the option's content.
    ///
    /// Content is decoded through `recoverable_visit_some`, so a subtype mismatch inside
    /// becomes `None` instead of an error. A non-`opt` expected type is a mismatch.
    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        self.unroll_type()?;
        self.add_cost(1)?;
        match (self.wire_type.as_ref(), self.expect_type.as_ref()) {
            (TypeInner::Null | TypeInner::Reserved, TypeInner::Opt(_)) => visitor.visit_none(),
            (TypeInner::Opt(t1), TypeInner::Opt(t2)) => {
                // Descend into the option's content types before reading the flag.
                self.wire_type = t1.clone();
                self.expect_type = t2.clone();
                if BoolValue::read(&mut self.input)?.0 {
                    check_recursion! {
                        self.recoverable_visit_some(visitor)
                    }
                } else {
                    visitor.visit_none()
                }
            }
            (_, TypeInner::Opt(t2)) => {
                // A bare value on the wire: try to decode it as `Some(value)`.
                self.expect_type = self.table.trace_type(t2)?;
                check_recursion! {
                    self.recoverable_visit_some(visitor)
                }
            }
            (_, _) => check!(false),
        }
    }
    /// Decodes a sequence: either a `vec` (both sides `vec`) or a tuple (both sides
    /// tuple-shaped records), exposed to the visitor through `SeqAccess` on `Compound`.
    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        check_recursion! {
            self.unroll_type()?;
            self.add_cost(1)?;
            match (self.expect_type.as_ref(), self.wire_type.as_ref()) {
                (TypeInner::Vec(e), TypeInner::Vec(w)) => {
                    let expect = e.clone();
                    // The wire element type is resolved once here rather than per element.
                    let wire = self.table.trace_type(w)?;
                    let len = Len::read(&mut self.input)?.0;
                    visitor.visit_seq(Compound::new(self, Style::Vector { len, expect, wire }))
                }
                (TypeInner::Record(e), TypeInner::Record(w)) => {
                    let expect = e.clone().into();
                    let wire = w.clone().into();
                    // A non-tuple expected type is a mismatch; a non-tuple wire type gets a
                    // descriptive subtype error.
                    check!(self.expect_type.is_tuple(), "seq_tuple");
                    if !self.wire_type.is_tuple() {
                        return Err(Error::subtype(format!(
                            "{} is not a tuple type",
                            self.wire_type
                        )));
                    }
                    let value =
                        visitor.visit_seq(Compound::new(self, Style::Struct { expect, wire }))?;
                    Ok(value)
                }
                _ => check!(false),
            }
        }
    }
880 fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
881 self.unroll_type()?;
882 check!(
883 *self.expect_type == TypeInner::Vec(TypeInner::Nat8.into())
884 && *self.wire_type == TypeInner::Vec(TypeInner::Nat8.into()),
885 "vec nat8"
886 );
887 let len = Len::read(&mut self.input)?.0;
888 self.add_cost(len.saturating_add(1))?;
889 let bytes = self.borrow_bytes(len)?.to_owned();
890 visitor.visit_byte_buf(bytes)
891 }
892 fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
893 self.unroll_type()?;
894 match self.expect_type.as_ref() {
895 TypeInner::Principal => self.deserialize_principal(visitor),
896 TypeInner::Vec(t) if **t == TypeInner::Nat8 => {
897 let len = Len::read(&mut self.input)?.0;
898 self.add_cost(len.saturating_add(1))?;
899 let slice = self.borrow_bytes(len)?;
900 visitor.visit_borrowed_bytes(slice)
901 }
902 _ => Err(Error::subtype("bytes only takes principal or vec nat8")),
903 }
904 }
    /// Decodes a map encoded as a `vec` of two-field records with field ids 0 (key) and
    /// 1 (value), on both the expected and the wire side. Entries are exposed through
    /// `MapAccess` on `Compound` (`Style::Map`).
    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        check_recursion! {
            self.unroll_type()?;
            self.add_cost(1)?;
            match (self.expect_type.as_ref(), self.wire_type.as_ref()) {
                (TypeInner::Vec(e), TypeInner::Vec(w)) => {
                    // Resolve the element types, which must be key/value records.
                    let e = self.table.trace_type(e)?;
                    let w = self.table.trace_type(w)?;
                    match (e.as_ref(), w.as_ref()) {
                        (TypeInner::Record(ref e), TypeInner::Record(ref w)) => {
                            match (&e[..], &w[..]) {
                                (
                                    [Field { id: e_id0, ty: ek }, Field { id: e_id1, ty: ev }],
                                    [Field { id: w_id0, ty: wk }, Field { id: w_id1, ty: wv }],
                                ) if **e_id0 == Label::Id(0)
                                    && **e_id1 == Label::Id(1)
                                    && **w_id0 == Label::Id(0)
                                    && **w_id1 == Label::Id(1) =>
                                {
                                    let expect = (ek.clone(), ev.clone());
                                    let wire = (wk.clone(), wv.clone());
                                    let len = Len::read(&mut self.input)?.0;
                                    visitor.visit_map(Compound::new(
                                        self,
                                        Style::Map { len, expect, wire },
                                    ))
                                }
                                _ => Err(Error::subtype("expect a key-value pair")),
                            }
                        }
                        _ => Err(Error::subtype("expect a key-value pair")),
                    }
                }
                _ => check!(false),
            }
        }
    }
    /// Delegates to `deserialize_seq` (which accepts a `vec` or a tuple-shaped record),
    /// charging one extra unit and one extra recursion level.
    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        check_recursion! {
            self.add_cost(1)?;
            self.deserialize_seq(visitor)
        }
    }
    /// Tuple structs are decoded like tuples: delegates to `deserialize_seq`, charging one
    /// extra unit and one extra recursion level.
    fn deserialize_tuple_struct<V>(
        self,
        _name: &'static str,
        _len: usize,
        visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        check_recursion! {
            self.add_cost(1)?;
            self.deserialize_seq(visitor)
        }
    }
    /// Decodes a record (both sides must be records) by exposing its fields through
    /// `MapAccess` on `Compound` (`Style::Struct`). Matching expected fields against wire
    /// fields happens in `next_key_seed`.
    fn deserialize_struct<V>(
        self,
        _name: &'static str,
        _fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        check_recursion! {
            self.unroll_type()?;
            self.add_cost(1)?;
            match (self.expect_type.as_ref(), self.wire_type.as_ref()) {
                (TypeInner::Record(e), TypeInner::Record(w)) => {
                    let expect = e.clone().into();
                    let wire = w.clone().into();
                    let value =
                        visitor.visit_map(Compound::new(self, Style::Struct { expect, wire }))?;
                    Ok(value)
                }
                _ => check!(false),
            }
        }
    }
    /// Decodes a variant (both sides must be variants).
    ///
    /// The case index is read from the wire and bounds-checked against the wire variant.
    /// The expected case with the same label is then looked up; a label unknown to the
    /// expected type is a subtype error.
    fn deserialize_enum<V>(
        self,
        _name: &'static str,
        _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        check_recursion! {
            self.unroll_type()?;
            self.add_cost(1)?;
            match (self.expect_type.as_ref(), self.wire_type.as_ref()) {
                (TypeInner::Variant(e), TypeInner::Variant(w)) => {
                    let index = Len::read(&mut self.input)?.0;
                    let len = w.len();
                    if index >= len {
                        return Err(Error::msg(format!(
                            "Variant index {index} larger than length {len}"
                        )));
                    }
                    let wire = w[index].clone();
                    let expect = match e.iter().find(|f| f.id == wire.id) {
                        Some(v) => v.clone(),
                        None => {
                            return Err(Error::subtype(format!("Unknown variant field {}", wire.id)));
                        }
                    };
                    visitor.visit_enum(Compound::new(self, Style::Enum { expect, wire }))
                }
                _ => check!(false),
            }
        }
    }
1026 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
1027 where
1028 V: Visitor<'de>,
1029 {
1030 match self.field_name.take() {
1031 Some(l) => match l.as_ref() {
1032 Label::Named(name) => {
1033 self.add_cost(name.len())?;
1034 visitor.visit_string(name.to_string())
1035 }
1036 Label::Id(hash) | Label::Unnamed(hash) => {
1037 self.add_cost(4)?;
1038 visitor.visit_u32(*hash)
1039 }
1040 },
1041 None => assert!(false),
1042 }
1043 }
1044
    // `deserialize_char` has no dedicated implementation; it forwards to `deserialize_any`.
    serde::forward_to_deserialize_any! {
        char
    }
1048}
1049
/// How a `Compound` walks a composite value.
#[derive(Debug)]
enum Style {
    /// A `vec` with `len` elements left; all elements share the `expect`/`wire` types.
    Vector {
        len: usize,
        expect: Type,
        wire: Type,
    },
    /// A record or tuple: the expected and wire fields not yet consumed.
    Struct {
        expect: VecDeque<Field>,
        wire: VecDeque<Field>,
    },
    /// The selected case of a variant, as expected and as found on the wire.
    Enum {
        expect: Field,
        wire: Field,
    },
    /// A map (`vec` of key/value records) with `len` entries left; each pair holds the key
    /// and value types.
    Map {
        len: usize,
        expect: (Type, Type),
        wire: (Type, Type),
    },
}
1071
/// Access object handed to `visit_seq` / `visit_map` / `visit_enum` for a composite value.
/// It drives the shared `Deserializer` according to `style`.
struct Compound<'a, 'de> {
    // Decoder shared with the enclosing value.
    de: &'a mut Deserializer<'de>,
    // Iteration state for the composite being decoded.
    style: Style,
}
1076
1077impl<'a, 'de> Compound<'a, 'de> {
1078 fn new(de: &'a mut Deserializer<'de>, style: Style) -> Self {
1079 Compound { de, style }
1080 }
1081}
1082
1083impl<'de> de::SeqAccess<'de> for Compound<'_, 'de> {
1084 type Error = Error;
1085
1086 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
1087 where
1088 T: de::DeserializeSeed<'de>,
1089 {
1090 self.de.add_cost(3)?;
1091 match self.style {
1092 Style::Vector {
1093 ref mut len,
1094 ref expect,
1095 ref wire,
1096 } => {
1097 if *len == 0 {
1098 return Ok(None);
1099 }
1100 *len -= 1;
1101 self.de.expect_type = expect.clone();
1102 self.de.wire_type = wire.clone();
1103 seed.deserialize(&mut *self.de).map(Some)
1104 }
1105 Style::Struct {
1106 ref mut expect,
1107 ref mut wire,
1108 } => {
1109 if expect.is_empty() && wire.is_empty() {
1110 return Ok(None);
1111 }
1112 self.de.expect_type = expect
1113 .pop_front()
1114 .map(|f| f.ty)
1115 .unwrap_or_else(|| TypeInner::Reserved.into());
1116 self.de.wire_type = wire
1117 .pop_front()
1118 .map(|f| f.ty)
1119 .unwrap_or_else(|| TypeInner::Null.into());
1120 seed.deserialize(&mut *self.de).map(Some)
1121 }
1122 _ => Err(Error::subtype("expect vector or tuple")),
1123 }
1124 }
1125
1126 fn size_hint(&self) -> Option<usize> {
1127 match &self.style {
1128 Style::Vector { len, .. } => Some(*len),
1129 Style::Struct { expect, wire, .. } => Some(expect.len().min(wire.len())),
1130 _ => None,
1131 }
1132 }
1133}
1134
1135impl<'de> de::MapAccess<'de> for Compound<'_, 'de> {
1136 type Error = Error;
1137 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
1138 where
1139 K: de::DeserializeSeed<'de>,
1140 {
1141 self.de.add_cost(4)?;
1142 match self.style {
1143 Style::Struct {
1144 ref mut expect,
1145 ref mut wire,
1146 } => {
1147 match (expect.front(), wire.front()) {
1148 (Some(e), Some(w)) => {
1149 use std::cmp::Ordering;
1150 match e.id.get_id().cmp(&w.id.get_id()) {
1151 Ordering::Equal => {
1152 self.de.set_field_name(e.id.clone());
1153 self.de.expect_type = expect.pop_front().unwrap().ty;
1154 self.de.wire_type = wire.pop_front().unwrap().ty;
1155 }
1156 Ordering::Less => {
1157 let field = e.id.clone();
1159 self.de.set_field_name(field.clone());
1160 let expect = expect.pop_front().unwrap().ty;
1161 self.de.expect_type = self.de.table.trace_type(&expect)?;
1162 check!(
1163 matches!(
1164 self.de.expect_type.as_ref(),
1165 TypeInner::Opt(_) | TypeInner::Reserved | TypeInner::Null
1166 ),
1167 format!("field {field} is not optional field")
1168 );
1169 self.de.wire_type = TypeInner::Null.into();
1170 }
1171 Ordering::Greater => {
1172 self.de.set_field_name(Label::Named("_".to_owned()).into());
1173 self.de.wire_type = wire.pop_front().unwrap().ty;
1174 self.de.expect_type = TypeInner::Reserved.into();
1175 }
1176 }
1177 }
1178 (None, Some(_)) => {
1179 self.de.set_field_name(Label::Named("_".to_owned()).into());
1180 self.de.wire_type = wire.pop_front().unwrap().ty;
1181 self.de.expect_type = TypeInner::Reserved.into();
1182 }
1183 (Some(e), None) => {
1184 self.de.set_field_name(e.id.clone());
1185 self.de.expect_type = expect.pop_front().unwrap().ty;
1186 self.de.wire_type = TypeInner::Null.into();
1187 }
1188 (None, None) => return Ok(None),
1189 }
1190 seed.deserialize(&mut *self.de).map(Some)
1191 }
1192 Style::Map {
1193 ref mut len,
1194 ref expect,
1195 ref wire,
1196 } => {
1197 if *len == 0 {
1199 return Ok(None);
1200 }
1201 self.de.expect_type = expect.0.clone();
1202 self.de.wire_type = wire.0.clone();
1203 *len -= 1;
1204 seed.deserialize(&mut *self.de).map(Some)
1205 }
1206 _ => Err(Error::msg("expect struct or map")),
1207 }
1208 }
1209 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
1210 where
1211 V: de::DeserializeSeed<'de>,
1212 {
1213 match &self.style {
1214 Style::Map { expect, wire, .. } => {
1215 self.de.add_cost(3)?;
1216 self.de.expect_type = expect.1.clone();
1217 self.de.wire_type = wire.1.clone();
1218 seed.deserialize(&mut *self.de)
1219 }
1220 _ => {
1221 self.de.add_cost(1)?;
1222 seed.deserialize(&mut *self.de)
1223 }
1224 }
1225 }
1226
1227 fn size_hint(&self) -> Option<usize> {
1228 match &self.style {
1229 Style::Map { len, .. } => Some(*len),
1230 Style::Struct { expect, wire, .. } => Some(expect.len().min(wire.len())),
1231 _ => None,
1232 }
1233 }
1234}
1235
1236impl<'de> de::EnumAccess<'de> for Compound<'_, 'de> {
1237 type Error = Error;
1238 type Variant = Self;
1239
1240 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
1241 where
1242 V: de::DeserializeSeed<'de>,
1243 {
1244 self.de.add_cost(4)?;
1245 match &self.style {
1246 Style::Enum { expect, wire } => {
1247 self.de.expect_type = expect.ty.clone();
1248 self.de.wire_type = wire.ty.clone();
1249 let (mut label, label_type) = match expect.id.as_ref() {
1250 Label::Named(name) => (name.clone(), "name"),
1251 Label::Id(hash) | Label::Unnamed(hash) => (hash.to_string(), "id"),
1252 };
1253 if self.de.is_untyped {
1254 let accessor = match expect.ty.as_ref() {
1255 TypeInner::Null => "unit",
1256 TypeInner::Record(_) => "struct",
1257 _ => "newtype",
1258 };
1259 write!(&mut label, ",{label_type},{accessor}").map_err(Error::msg)?;
1260 }
1261 self.de.set_field_name(Label::Named(label).into());
1262 let field = seed.deserialize(&mut *self.de)?;
1263 Ok((field, self))
1264 }
1265 _ => Err(Error::subtype("expect enum")),
1266 }
1267 }
1268}
1269
1270impl<'de> de::VariantAccess<'de> for Compound<'_, 'de> {
1271 type Error = Error;
1272
1273 fn unit_variant(self) -> Result<()> {
1274 check!(
1275 *self.de.expect_type == TypeInner::Null && *self.de.wire_type == TypeInner::Null,
1276 "unit_variant"
1277 );
1278 self.de.add_cost(1)?;
1279 Ok(())
1280 }
1281
1282 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
1283 where
1284 T: de::DeserializeSeed<'de>,
1285 {
1286 self.de.add_cost(1)?;
1287 seed.deserialize(self.de)
1288 }
1289
1290 fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value>
1291 where
1292 V: Visitor<'de>,
1293 {
1294 self.de.add_cost(1)?;
1295 de::Deserializer::deserialize_tuple(self.de, len, visitor)
1296 }
1297
1298 fn struct_variant<V>(self, fields: &'static [&'static str], visitor: V) -> Result<V::Value>
1299 where
1300 V: Visitor<'de>,
1301 {
1302 self.de.add_cost(1)?;
1303 de::Deserializer::deserialize_struct(self.de, "_", fields, visitor)
1304 }
1305}