parity_scale_codec/
mem_tracking.rs

// Copyright (C) Parity Technologies (UK) Ltd.
// SPDX-License-Identifier: Apache-2.0

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// 	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::{Decode, Error, Input};
17use impl_trait_for_tuples::impl_for_tuples;
18
19/// Marker trait used for identifying types that call the [`Input::on_before_alloc_mem`] hook
20/// while decoding.
21pub trait DecodeWithMemTracking: Decode {}
22
/// Message carried by the [`Error`] that [`MemTrackingInput`] returns once the tracked heap
/// usage reaches its configured limit.
const DECODE_OOM_MSG: &str = "Heap memory limit exceeded while decoding";
24
25#[impl_for_tuples(18)]
26impl DecodeWithMemTracking for Tuple {}
27
/// [`Input`] adapter that caps the heap memory announced while decoding.
///
/// All byte-level operations are passed straight through to the wrapped input; only the
/// sizes announced via [`Input::on_before_alloc_mem`] are accumulated and checked against
/// the limit.
pub struct MemTrackingInput<'a, I> {
	/// Threshold for the running total: once `used_mem` reaches this value, further
	/// allocation notifications are rejected.
	mem_limit: usize,
	/// Running total of every size reported so far (presumably bytes), saturating at
	/// `usize::MAX`.
	used_mem: usize,
	/// The underlying input that actually supplies the bytes.
	input: &'a mut I,
}
34
35impl<'a, I: Input> MemTrackingInput<'a, I> {
36	/// Create a new instance of `MemTrackingInput`.
37	pub fn new(input: &'a mut I, mem_limit: usize) -> Self {
38		Self { input, used_mem: 0, mem_limit }
39	}
40
41	/// Get the `used_mem` field.
42	pub fn used_mem(&self) -> usize {
43		self.used_mem
44	}
45}
46
47impl<I: Input> Input for MemTrackingInput<'_, I> {
48	fn remaining_len(&mut self) -> Result<Option<usize>, Error> {
49		self.input.remaining_len()
50	}
51
52	fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
53		self.input.read(into)
54	}
55
56	fn read_byte(&mut self) -> Result<u8, Error> {
57		self.input.read_byte()
58	}
59
60	fn descend_ref(&mut self) -> Result<(), Error> {
61		self.input.descend_ref()
62	}
63
64	fn ascend_ref(&mut self) {
65		self.input.ascend_ref()
66	}
67
68	fn on_before_alloc_mem(&mut self, size: usize) -> Result<(), Error> {
69		self.input.on_before_alloc_mem(size)?;
70
71		self.used_mem = self.used_mem.saturating_add(size);
72		if self.used_mem >= self.mem_limit {
73			return Err(DECODE_OOM_MSG.into());
74		}
75
76		Ok(())
77	}
78}
79
80/// Extension trait to [`Decode`] for decoding with a maximum memory limit.
81pub trait DecodeWithMemLimit: DecodeWithMemTracking {
82	/// Decode `Self` with the given maximum memory limit and advance `input` by the number of
83	/// bytes consumed.
84	///
85	/// If `mem_limit` is hit, an error is returned.
86	fn decode_with_mem_limit<I: Input>(input: &mut I, mem_limit: usize) -> Result<Self, Error>;
87}
88
89impl<T> DecodeWithMemLimit for T
90where
91	T: DecodeWithMemTracking,
92{
93	fn decode_with_mem_limit<I: Input>(input: &mut I, mem_limit: usize) -> Result<Self, Error> {
94		let mut input = MemTrackingInput::new(input, mem_limit);
95		T::decode(&mut input)
96	}
97}