parity_scale_codec/
mem_tracking.rs1use crate::{Decode, Error, Input};
17use impl_trait_for_tuples::impl_for_tuples;
18
19pub trait DecodeWithMemTracking: Decode {}
22
23const DECODE_OOM_MSG: &str = "Heap memory limit exceeded while decoding";
24
25#[impl_for_tuples(18)]
26impl DecodeWithMemTracking for Tuple {}
27
/// [`Input`] adapter that tallies the heap memory announced while decoding and enforces an
/// upper bound on it.
///
/// Byte reads are forwarded unchanged to the wrapped input. Every call to
/// [`Input::on_before_alloc_mem`] adds the requested size to a running total, and the call
/// fails once that total reaches `mem_limit`.
///
/// `Debug` is derived (available whenever the wrapped input type is `Debug`) so that the
/// adapter can be inspected and embedded in other `Debug` types.
#[derive(Debug)]
pub struct MemTrackingInput<'a, I> {
    /// The wrapped input that actually supplies the bytes.
    input: &'a mut I,
    /// Bytes announced so far via `on_before_alloc_mem`; saturates at `usize::MAX`.
    used_mem: usize,
    /// Exclusive bound: `on_before_alloc_mem` errors once `used_mem` reaches this value.
    mem_limit: usize,
}
34
35impl<'a, I: Input> MemTrackingInput<'a, I> {
36 pub fn new(input: &'a mut I, mem_limit: usize) -> Self {
38 Self { input, used_mem: 0, mem_limit }
39 }
40
41 pub fn used_mem(&self) -> usize {
43 self.used_mem
44 }
45}
46
47impl<I: Input> Input for MemTrackingInput<'_, I> {
48 fn remaining_len(&mut self) -> Result<Option<usize>, Error> {
49 self.input.remaining_len()
50 }
51
52 fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
53 self.input.read(into)
54 }
55
56 fn read_byte(&mut self) -> Result<u8, Error> {
57 self.input.read_byte()
58 }
59
60 fn descend_ref(&mut self) -> Result<(), Error> {
61 self.input.descend_ref()
62 }
63
64 fn ascend_ref(&mut self) {
65 self.input.ascend_ref()
66 }
67
68 fn on_before_alloc_mem(&mut self, size: usize) -> Result<(), Error> {
69 self.input.on_before_alloc_mem(size)?;
70
71 self.used_mem = self.used_mem.saturating_add(size);
72 if self.used_mem >= self.mem_limit {
73 return Err(DECODE_OOM_MSG.into());
74 }
75
76 Ok(())
77 }
78}
79
80pub trait DecodeWithMemLimit: DecodeWithMemTracking {
82 fn decode_with_mem_limit<I: Input>(input: &mut I, mem_limit: usize) -> Result<Self, Error>;
87}
88
89impl<T> DecodeWithMemLimit for T
90where
91 T: DecodeWithMemTracking,
92{
93 fn decode_with_mem_limit<I: Input>(input: &mut I, mem_limit: usize) -> Result<Self, Error> {
94 let mut input = MemTrackingInput::new(input, mem_limit);
95 T::decode(&mut input)
96 }
97}