1use crate::{BufferReader, DecodeError};
2use core::marker::PhantomData;
/// Resource caps applied while decoding.
///
/// Three presets are provided: [`Default`] for general use,
/// [`DecodeLimits::unlimited`] which sets every cap to its type's maximum, and
/// [`DecodeLimits::conservative`] for small, tightly bounded inputs. Individual
/// fields can be overridden with struct-update syntax, e.g.
/// `DecodeLimits { max_depth: 8, ..DecodeLimits::default() }`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DecodeLimits {
    /// Allocation budget consulted by `DecodeContext::check_alloc`.
    pub max_alloc: usize,
    /// Maximum nesting depth accepted by `DecodeContext::depth_enter`.
    pub max_depth: u16,
    /// Cap reported to decoders as the maximum string length
    /// (how it is enforced is up to the decoder).
    pub max_string_len: usize,
    /// Cap reported to decoders as the maximum vector length
    /// (how it is enforced is up to the decoder).
    pub max_vec_len: usize,
    /// Cap reported to decoders as the maximum slice length
    /// (how it is enforced is up to the decoder).
    pub max_slice_len: usize,
}

impl Default for DecodeLimits {
    /// General-purpose limits: `max_alloc` and `max_string_len` of 16 * 2^20,
    /// `max_vec_len` and `max_slice_len` of 2^20, and `max_depth` of 256.
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        Self {
            max_depth: 256,
            max_alloc: 16 * MIB,
            max_string_len: 16 * MIB,
            max_vec_len: MIB,
            max_slice_len: MIB,
        }
    }
}

impl DecodeLimits {
    /// Sets every cap to the largest value its type can hold.
    pub const fn unlimited() -> Self {
        const NO_CAP: usize = usize::MAX;
        Self {
            max_depth: u16::MAX,
            max_alloc: NO_CAP,
            max_string_len: NO_CAP,
            max_vec_len: NO_CAP,
            max_slice_len: NO_CAP,
        }
    }

    /// Tight limits for small inputs: `max_alloc` of 2^16 (64 * 1024),
    /// 4096 for each length cap, and `max_depth` of 32.
    pub const fn conservative() -> Self {
        const LEN_CAP: usize = 4096;
        Self {
            max_depth: 32,
            max_alloc: 1 << 16,
            max_string_len: LEN_CAP,
            max_vec_len: LEN_CAP,
            max_slice_len: LEN_CAP,
        }
    }
}
45
46#[derive(Debug)]
47pub struct DecodeContext {
48 pub limits: DecodeLimits,
49 pub alloc_used: usize,
50 pub current_depth: u16,
51}
52
53impl DecodeContext {
54 pub fn new(limits: DecodeLimits) -> Self {
55 Self {
56 limits,
57 alloc_used: 0,
58 current_depth: 0,
59 }
60 }
61
62 #[inline]
63 pub fn check_alloc(&mut self, n: usize) -> Result<(), DecodeError> {
64 if n > self.limits.max_alloc {
65 return Err(DecodeError::AllocationLimitExceeded);
66 }
67 self.alloc_used += n;
68 Ok(())
69 }
70
71 #[inline]
72 pub fn depth_enter(&mut self) -> Result<(), DecodeError> {
73 if self.current_depth >= self.limits.max_depth {
74 return Err(DecodeError::DepthLimitExceeded);
75 }
76 self.current_depth += 1;
77 Ok(())
78 }
79
80 #[inline]
81 pub fn depth_exit(&mut self) {
82 self.current_depth = self.current_depth.saturating_sub(1);
83 }
84}
85
86#[derive(Debug)]
87pub struct LimitedReader<'a, R: BufferReader<'a>> {
88 inner: R,
89 ctx: DecodeContext,
90 _marker: PhantomData<&'a ()>,
91}
92
93impl<'a, R: BufferReader<'a>> LimitedReader<'a, R> {
94 #[inline]
95 pub fn new(inner: R, limits: DecodeLimits) -> Self {
96 Self {
97 inner,
98 ctx: DecodeContext::new(limits),
99 _marker: PhantomData,
100 }
101 }
102
103 #[inline]
104 pub fn finish(self) -> (R, DecodeContext) {
105 (self.inner, self.ctx)
106 }
107
108 #[inline]
109 pub fn context(&self) -> &DecodeContext {
110 &self.ctx
111 }
112
113 #[inline]
114 pub fn context_mut(&mut self) -> &mut DecodeContext {
115 &mut self.ctx
116 }
117}
118
119impl<'a, R: BufferReader<'a>> BufferReader<'a> for LimitedReader<'a, R> {
120 #[inline]
121 fn peek(&self) -> Option<u8> {
122 self.inner.peek()
123 }
124
125 #[inline]
126 fn next(&mut self) -> Option<u8> {
127 self.inner.next()
128 }
129
130 #[inline]
131 fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), DecodeError> {
132 self.inner.read_exact(buf)
133 }
134
135 #[inline]
136 fn remaining(&self) -> &'a [u8] {
137 self.inner.remaining()
138 }
139
140 #[inline]
141 fn advance(&mut self, n: usize) -> Result<(), DecodeError> {
142 self.inner.advance(n)
143 }
144
145 #[inline]
146 fn check_alloc(&mut self, n: usize) -> Result<(), DecodeError> {
147 self.ctx.check_alloc(n)
148 }
149
150 #[inline]
151 fn depth_enter(&mut self) -> Result<(), DecodeError> {
152 self.ctx.depth_enter()
153 }
154
155 #[inline]
156 fn depth_exit(&mut self) {
157 self.ctx.depth_exit()
158 }
159
160 #[inline]
161 fn max_slice_len(&self) -> usize {
162 self.ctx.limits.max_slice_len
163 }
164
165 #[inline]
166 fn max_string_len(&self) -> usize {
167 self.ctx.limits.max_string_len
168 }
169
170 #[inline]
171 fn max_vec_len(&self) -> usize {
172 self.ctx.limits.max_vec_len
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::SliceReader;

    /// Two 5-byte charges exhaust a 10-byte budget; a further byte must fail.
    #[test]
    fn limited_reader_enforces_alloc() {
        let mut reader = LimitedReader::new(
            SliceReader::new(b"0123456789"),
            DecodeLimits {
                max_alloc: 10,
                ..Default::default()
            },
        );
        for chunk in [5, 5] {
            reader.check_alloc(chunk).unwrap();
        }
        assert!(reader.check_alloc(1).is_err());
    }

    /// With `max_depth = 2`, a third nested entry fails until one level is exited.
    #[test]
    fn limited_reader_enforces_depth() {
        let mut reader = LimitedReader::new(
            SliceReader::new(b""),
            DecodeLimits {
                max_depth: 2,
                ..Default::default()
            },
        );
        for _ in 0..2 {
            reader.depth_enter().unwrap();
        }
        assert!(reader.depth_enter().is_err());
        reader.depth_exit();
        reader.depth_enter().unwrap();
    }
}