1use crate::{BufferReader, DecodeError};
2use core::marker::PhantomData;
/// Resource limits consulted while decoding.
///
/// Start from a preset — [`Default`], [`DecodeLimits::unlimited`] or
/// [`DecodeLimits::conservative`] — and override individual fields with
/// struct-update syntax, e.g.
/// `DecodeLimits { max_depth: 8, ..DecodeLimits::default() }`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DecodeLimits {
    /// Upper bound on the running total accepted by `DecodeContext::check_alloc`.
    pub max_alloc: usize,
    /// Upper bound on the nesting depth accepted by `DecodeContext::depth_enter`.
    pub max_depth: u16,
    /// Upper bound on string length.
    // NOTE(review): not consulted anywhere in this file; confirm where it is enforced.
    pub max_string_len: usize,
    /// Upper bound on vector length.
    // NOTE(review): not consulted anywhere in this file; confirm where it is enforced.
    pub max_vec_len: usize,
}

impl Default for DecodeLimits {
    /// Moderate limits: `max_alloc` = 16 * 2^20, `max_depth` = 256,
    /// `max_string_len` = 4 * 2^20 and `max_vec_len` = 2^20.
    fn default() -> Self {
        const MI: usize = 1024 * 1024;
        Self {
            max_depth: 256,
            max_alloc: 16 * MI,
            max_vec_len: MI,
            max_string_len: 4 * MI,
        }
    }
}

impl DecodeLimits {
    /// Sets every limit to the maximum value of its field type.
    pub const fn unlimited() -> Self {
        let everything = usize::MAX;
        Self {
            max_alloc: everything,
            max_depth: u16::MAX,
            max_string_len: everything,
            max_vec_len: everything,
        }
    }

    /// Tight limits for small inputs: `max_alloc` = 64 * 1024, `max_depth` = 32,
    /// and 4096 for both `max_string_len` and `max_vec_len`.
    pub const fn conservative() -> Self {
        const SMALL: usize = 4096;
        Self {
            max_alloc: 64 * 1024,
            max_depth: 32,
            max_string_len: SMALL,
            max_vec_len: SMALL,
        }
    }
}
41
42#[derive(Debug)]
43pub struct DecodeContext {
44 pub limits: DecodeLimits,
45 pub alloc_used: usize,
46 pub current_depth: u16,
47}
48
49impl DecodeContext {
50 pub fn new(limits: DecodeLimits) -> Self {
51 Self {
52 limits,
53 alloc_used: 0,
54 current_depth: 0,
55 }
56 }
57
58 #[inline]
59 pub fn check_alloc(&mut self, n: usize) -> Result<(), DecodeError> {
60 if n > self.limits.max_alloc - self.alloc_used {
61 return Err(DecodeError::AllocationLimitExceeded);
62 }
63 self.alloc_used += n;
64 Ok(())
65 }
66
67 #[inline]
68 pub fn depth_enter(&mut self) -> Result<(), DecodeError> {
69 if self.current_depth >= self.limits.max_depth {
70 return Err(DecodeError::DepthLimitExceeded);
71 }
72 self.current_depth += 1;
73 Ok(())
74 }
75
76 #[inline]
77 pub fn depth_exit(&mut self) {
78 self.current_depth = self.current_depth.saturating_sub(1);
79 }
80}
81
82#[derive(Debug)]
83pub struct LimitedReader<'a, R: BufferReader<'a>> {
84 inner: R,
85 ctx: DecodeContext,
86 _marker: PhantomData<&'a ()>,
87}
88
89impl<'a, R: BufferReader<'a>> LimitedReader<'a, R> {
90 #[inline]
91 pub fn new(inner: R, limits: DecodeLimits) -> Self {
92 Self {
93 inner,
94 ctx: DecodeContext::new(limits),
95 _marker: PhantomData,
96 }
97 }
98
99 #[inline]
100 pub fn finish(self) -> (R, DecodeContext) {
101 (self.inner, self.ctx)
102 }
103
104 #[inline]
105 pub fn context(&self) -> &DecodeContext {
106 &self.ctx
107 }
108
109 #[inline]
110 pub fn context_mut(&mut self) -> &mut DecodeContext {
111 &mut self.ctx
112 }
113}
114
115impl<'a, R: BufferReader<'a>> BufferReader<'a> for LimitedReader<'a, R> {
116 #[inline]
117 fn peek(&self) -> Option<u8> {
118 self.inner.peek()
119 }
120
121 #[inline]
122 fn next(&mut self) -> Option<u8> {
123 self.inner.next()
124 }
125
126 #[inline]
127 fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), DecodeError> {
128 self.inner.read_exact(buf)
129 }
130
131 #[inline]
132 fn remaining(&self) -> &'a [u8] {
133 self.inner.remaining()
134 }
135
136 #[inline]
137 fn advance(&mut self, n: usize) -> Result<(), DecodeError> {
138 self.inner.advance(n)
139 }
140
141 #[inline]
142 fn check_alloc(&mut self, n: usize) -> Result<(), DecodeError> {
143 self.ctx.check_alloc(n)
144 }
145
146 #[inline]
147 fn depth_enter(&mut self) -> Result<(), DecodeError> {
148 self.ctx.depth_enter()
149 }
150
151 #[inline]
152 fn depth_exit(&mut self) {
153 self.ctx.depth_exit()
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::SliceReader;

    #[test]
    fn limited_reader_enforces_alloc() {
        let limits = DecodeLimits {
            max_alloc: 10,
            ..DecodeLimits::default()
        };
        let mut r = LimitedReader::new(SliceReader::new(b"0123456789"), limits);
        r.check_alloc(5).unwrap();
        r.check_alloc(5).unwrap();
        // Pin the exact error variant rather than accepting any error.
        assert!(matches!(
            r.check_alloc(1),
            Err(DecodeError::AllocationLimitExceeded)
        ));
        // A rejected request must not be charged against the budget.
        assert_eq!(r.context().alloc_used, 10);
    }

    #[test]
    fn limited_reader_enforces_depth() {
        let limits = DecodeLimits {
            max_depth: 2,
            ..DecodeLimits::default()
        };
        let mut r = LimitedReader::new(SliceReader::new(b""), limits);
        r.depth_enter().unwrap();
        r.depth_enter().unwrap();
        assert!(matches!(
            r.depth_enter(),
            Err(DecodeError::DepthLimitExceeded)
        ));
        // A rejected entry must leave the depth unchanged.
        assert_eq!(r.context().current_depth, 2);
        r.depth_exit();
        r.depth_enter().unwrap();
        assert_eq!(r.context().current_depth, 2);
    }

    #[test]
    fn depth_exit_at_zero_is_a_no_op() {
        let mut ctx = DecodeContext::new(DecodeLimits::default());
        ctx.depth_exit();
        assert_eq!(ctx.current_depth, 0);
    }

    #[test]
    fn unlimited_accepts_full_budget_once() {
        let mut r = LimitedReader::new(SliceReader::new(b""), DecodeLimits::unlimited());
        r.check_alloc(usize::MAX).unwrap();
        assert!(matches!(
            r.check_alloc(1),
            Err(DecodeError::AllocationLimitExceeded)
        ));
    }

    #[test]
    fn finish_returns_accumulated_context() {
        let mut r = LimitedReader::new(SliceReader::new(b"abc"), DecodeLimits::conservative());
        r.check_alloc(7).unwrap();
        r.depth_enter().unwrap();
        let (_inner, ctx) = r.finish();
        assert_eq!(ctx.alloc_used, 7);
        assert_eq!(ctx.current_depth, 1);
        assert_eq!(ctx.limits, DecodeLimits::conservative());
    }
}