Skip to main content

ax_codec_core/
limits.rs

1use crate::{BufferReader, DecodeError};
2use core::marker::PhantomData;
/// Upper bounds that a decoder is allowed to reach while processing input.
///
/// This type only carries the numbers; the checks that read them live in
/// `DecodeContext` (allocation budget, nesting depth) and in the reader
/// adapters that expose the per-collection maxima.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecodeLimits {
    /// Allocation budget consulted by `DecodeContext::check_alloc`.
    pub max_alloc: usize,
    /// Maximum nesting depth accepted by `DecodeContext::depth_enter`.
    pub max_depth: u16,
    /// Maximum accepted string length.
    pub max_string_len: usize,
    /// Maximum accepted element count for vectors.
    pub max_vec_len: usize,
    /// Maximum accepted slice length.
    pub max_slice_len: usize,
}

impl Default for DecodeLimits {
    /// Moderate defaults: a 16 MiB allocation and string budget, depth 256,
    /// and 1 Mi (1,048,576) items per vector or slice.
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        Self {
            max_depth: 256,
            max_alloc: 16 * MIB,
            max_string_len: 16 * MIB,
            max_vec_len: MIB,
            max_slice_len: MIB,
        }
    }
}

impl DecodeLimits {
    /// Effectively disables limiting: every bound is set to the maximum
    /// value of its field type.
    pub const fn unlimited() -> Self {
        Self {
            max_depth: u16::MAX,
            max_alloc: usize::MAX,
            max_string_len: usize::MAX,
            max_vec_len: usize::MAX,
            max_slice_len: usize::MAX,
        }
    }

    /// Tight preset for small, highly untrusted inputs: a 64 KiB allocation
    /// budget, depth 32, and 4096 for every length/count bound.
    pub const fn conservative() -> Self {
        const SMALL: usize = 4096;
        Self {
            max_depth: 32,
            max_alloc: 64 * 1024,
            max_string_len: SMALL,
            max_vec_len: SMALL,
            max_slice_len: SMALL,
        }
    }
}
45
46#[derive(Debug)]
47pub struct DecodeContext {
48    pub limits: DecodeLimits,
49    pub alloc_used: usize,
50    pub current_depth: u16,
51}
52
53impl DecodeContext {
54    pub fn new(limits: DecodeLimits) -> Self {
55        Self {
56            limits,
57            alloc_used: 0,
58            current_depth: 0,
59        }
60    }
61
62    #[inline]
63    pub fn check_alloc(&mut self, n: usize) -> Result<(), DecodeError> {
64        if n > self.limits.max_alloc {
65            return Err(DecodeError::AllocationLimitExceeded);
66        }
67        self.alloc_used += n;
68        Ok(())
69    }
70
71    #[inline]
72    pub fn depth_enter(&mut self) -> Result<(), DecodeError> {
73        if self.current_depth >= self.limits.max_depth {
74            return Err(DecodeError::DepthLimitExceeded);
75        }
76        self.current_depth += 1;
77        Ok(())
78    }
79
80    #[inline]
81    pub fn depth_exit(&mut self) {
82        self.current_depth = self.current_depth.saturating_sub(1);
83    }
84}
85
86#[derive(Debug)]
87pub struct LimitedReader<'a, R: BufferReader<'a>> {
88    inner: R,
89    ctx: DecodeContext,
90    _marker: PhantomData<&'a ()>,
91}
92
93impl<'a, R: BufferReader<'a>> LimitedReader<'a, R> {
94    #[inline]
95    pub fn new(inner: R, limits: DecodeLimits) -> Self {
96        Self {
97            inner,
98            ctx: DecodeContext::new(limits),
99            _marker: PhantomData,
100        }
101    }
102
103    #[inline]
104    pub fn finish(self) -> (R, DecodeContext) {
105        (self.inner, self.ctx)
106    }
107
108    #[inline]
109    pub fn context(&self) -> &DecodeContext {
110        &self.ctx
111    }
112
113    #[inline]
114    pub fn context_mut(&mut self) -> &mut DecodeContext {
115        &mut self.ctx
116    }
117}
118
119impl<'a, R: BufferReader<'a>> BufferReader<'a> for LimitedReader<'a, R> {
120    #[inline]
121    fn peek(&self) -> Option<u8> {
122        self.inner.peek()
123    }
124
125    #[inline]
126    fn next(&mut self) -> Option<u8> {
127        self.inner.next()
128    }
129
130    #[inline]
131    fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), DecodeError> {
132        self.inner.read_exact(buf)
133    }
134
135    #[inline]
136    fn remaining(&self) -> &'a [u8] {
137        self.inner.remaining()
138    }
139
140    #[inline]
141    fn advance(&mut self, n: usize) -> Result<(), DecodeError> {
142        self.inner.advance(n)
143    }
144
145    #[inline]
146    fn check_alloc(&mut self, n: usize) -> Result<(), DecodeError> {
147        self.ctx.check_alloc(n)
148    }
149
150    #[inline]
151    fn depth_enter(&mut self) -> Result<(), DecodeError> {
152        self.ctx.depth_enter()
153    }
154
155    #[inline]
156    fn depth_exit(&mut self) {
157        self.ctx.depth_exit()
158    }
159
160    #[inline]
161    fn max_slice_len(&self) -> usize {
162        self.ctx.limits.max_slice_len
163    }
164
165    #[inline]
166    fn max_string_len(&self) -> usize {
167        self.ctx.limits.max_string_len
168    }
169
170    #[inline]
171    fn max_vec_len(&self) -> usize {
172        self.ctx.limits.max_vec_len
173    }
174}
175
/// Unit tests for limit enforcement through `LimitedReader` and
/// `DecodeContext`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::SliceReader;

    /// Small requests must add up against the shared budget.
    #[test]
    fn limited_reader_enforces_alloc() {
        let limits = DecodeLimits {
            max_alloc: 10,
            ..DecodeLimits::default()
        };
        let mut r = LimitedReader::new(SliceReader::new(b"0123456789"), limits);
        r.check_alloc(5).unwrap();
        r.check_alloc(5).unwrap();
        assert!(r.check_alloc(1).is_err());
        // A rejected request must not be charged.
        assert_eq!(r.context().alloc_used, 10);
    }

    /// A single request larger than the whole budget is refused outright.
    #[test]
    fn limited_reader_rejects_single_oversized_alloc() {
        let limits = DecodeLimits {
            max_alloc: 10,
            ..DecodeLimits::default()
        };
        let mut r = LimitedReader::new(SliceReader::new(b""), limits);
        assert!(r.check_alloc(11).is_err());
        assert_eq!(r.context().alloc_used, 0);
    }

    /// Overflowing the running total is a limit error, not a panic or a wrap.
    #[test]
    fn check_alloc_overflow_is_an_error() {
        let mut ctx = DecodeContext::new(DecodeLimits::unlimited());
        ctx.check_alloc(usize::MAX).unwrap();
        assert!(ctx.check_alloc(1).is_err());
        assert_eq!(ctx.alloc_used, usize::MAX);
    }

    /// The depth limit is enforced, and exiting frees a level again.
    #[test]
    fn limited_reader_enforces_depth() {
        let limits = DecodeLimits {
            max_depth: 2,
            ..DecodeLimits::default()
        };
        let mut r = LimitedReader::new(SliceReader::new(b""), limits);
        r.depth_enter().unwrap();
        r.depth_enter().unwrap();
        assert!(r.depth_enter().is_err());
        r.depth_exit();
        r.depth_enter().unwrap(); // back to 2
    }

    /// Exiting more often than entering saturates at zero.
    #[test]
    fn depth_exit_saturates_at_zero() {
        let mut ctx = DecodeContext::new(DecodeLimits::default());
        ctx.depth_exit();
        assert_eq!(ctx.current_depth, 0);
    }
}