// ax_codec_core/limits.rs

1use crate::{BufferReader, DecodeError};
2use core::marker::PhantomData;
/// Upper bounds a decode operation is allowed to consume.
///
/// Use [`Default`] for the standard profile, [`DecodeLimits::conservative`] for
/// tight bounds, or [`DecodeLimits::unlimited`] to disable every bound.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecodeLimits {
    /// Maximum running total that `DecodeContext::check_alloc` will accept.
    pub max_alloc: usize,
    /// Maximum nesting depth that `DecodeContext::depth_enter` will accept.
    pub max_depth: u16,
    /// Maximum length of a single string. Not read by `DecodeContext`;
    /// presumably enforced by the string decoders (TODO confirm).
    pub max_string_len: usize,
    /// Maximum number of items in a single vector. Not read by `DecodeContext`;
    /// presumably enforced by the sequence decoders (TODO confirm).
    pub max_vec_len: usize,
}

11impl Default for DecodeLimits {
12    fn default() -> Self {
13        Self {
14            max_alloc: 16 * 1024 * 1024, // 16 MiB
15            max_depth: 256,
16            max_string_len: 4 * 1024 * 1024, // 4 MiB
17            max_vec_len: 1024 * 1024,        // 1M items
18        }
19    }
20}
21
22impl DecodeLimits {
23    pub const fn unlimited() -> Self {
24        Self {
25            max_alloc: usize::MAX,
26            max_depth: u16::MAX,
27            max_string_len: usize::MAX,
28            max_vec_len: usize::MAX,
29        }
30    }
31
32    pub const fn conservative() -> Self {
33        Self {
34            max_alloc: 64 * 1024, // 64 KiB
35            max_depth: 32,
36            max_string_len: 4096,
37            max_vec_len: 4096,
38        }
39    }
40}
41
42#[derive(Debug)]
43pub struct DecodeContext {
44    pub limits: DecodeLimits,
45    pub alloc_used: usize,
46    pub current_depth: u16,
47}
48
49impl DecodeContext {
50    pub fn new(limits: DecodeLimits) -> Self {
51        Self {
52            limits,
53            alloc_used: 0,
54            current_depth: 0,
55        }
56    }
57
58    #[inline]
59    pub fn check_alloc(&mut self, n: usize) -> Result<(), DecodeError> {
60        if n > self.limits.max_alloc - self.alloc_used {
61            return Err(DecodeError::AllocationLimitExceeded);
62        }
63        self.alloc_used += n;
64        Ok(())
65    }
66
67    #[inline]
68    pub fn depth_enter(&mut self) -> Result<(), DecodeError> {
69        if self.current_depth >= self.limits.max_depth {
70            return Err(DecodeError::DepthLimitExceeded);
71        }
72        self.current_depth += 1;
73        Ok(())
74    }
75
76    #[inline]
77    pub fn depth_exit(&mut self) {
78        self.current_depth = self.current_depth.saturating_sub(1);
79    }
80}
81
82#[derive(Debug)]
83pub struct LimitedReader<'a, R: BufferReader<'a>> {
84    inner: R,
85    ctx: DecodeContext,
86    _marker: PhantomData<&'a ()>,
87}
88
89impl<'a, R: BufferReader<'a>> LimitedReader<'a, R> {
90    #[inline]
91    pub fn new(inner: R, limits: DecodeLimits) -> Self {
92        Self {
93            inner,
94            ctx: DecodeContext::new(limits),
95            _marker: PhantomData,
96        }
97    }
98
99    #[inline]
100    pub fn finish(self) -> (R, DecodeContext) {
101        (self.inner, self.ctx)
102    }
103
104    #[inline]
105    pub fn context(&self) -> &DecodeContext {
106        &self.ctx
107    }
108
109    #[inline]
110    pub fn context_mut(&mut self) -> &mut DecodeContext {
111        &mut self.ctx
112    }
113}
114
115impl<'a, R: BufferReader<'a>> BufferReader<'a> for LimitedReader<'a, R> {
116    #[inline]
117    fn peek(&self) -> Option<u8> {
118        self.inner.peek()
119    }
120
121    #[inline]
122    fn next(&mut self) -> Option<u8> {
123        self.inner.next()
124    }
125
126    #[inline]
127    fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), DecodeError> {
128        self.inner.read_exact(buf)
129    }
130
131    #[inline]
132    fn remaining(&self) -> &'a [u8] {
133        self.inner.remaining()
134    }
135
136    #[inline]
137    fn advance(&mut self, n: usize) -> Result<(), DecodeError> {
138        self.inner.advance(n)
139    }
140
141    #[inline]
142    fn check_alloc(&mut self, n: usize) -> Result<(), DecodeError> {
143        self.ctx.check_alloc(n)
144    }
145
146    #[inline]
147    fn depth_enter(&mut self) -> Result<(), DecodeError> {
148        self.ctx.depth_enter()
149    }
150
151    #[inline]
152    fn depth_exit(&mut self) {
153        self.ctx.depth_exit()
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::SliceReader;

    #[test]
    fn limited_reader_enforces_alloc() {
        let mut reader = LimitedReader::new(
            SliceReader::new(b"0123456789"),
            DecodeLimits {
                max_alloc: 10,
                ..DecodeLimits::default()
            },
        );
        // Two charges of 5 exactly exhaust the budget of 10...
        for _ in 0..2 {
            reader.check_alloc(5).unwrap();
        }
        // ...so even one more unit must be refused.
        assert!(reader.check_alloc(1).is_err());
    }

    #[test]
    fn limited_reader_enforces_depth() {
        let mut reader = LimitedReader::new(
            SliceReader::new(b""),
            DecodeLimits {
                max_depth: 2,
                ..DecodeLimits::default()
            },
        );
        // Fill both permitted levels; a third must be rejected.
        reader.depth_enter().unwrap();
        reader.depth_enter().unwrap();
        assert!(reader.depth_enter().is_err());

        // Leaving one level frees room to enter again (depth back to 2).
        reader.depth_exit();
        reader.depth_enter().unwrap();
    }
}