1use crate::{
2 err::{io::LimitedReadError, Layer, LenError},
3 *,
4};
5
/// Wrapper around a reader that tracks how many bytes have been read in the
/// current layer and refuses `read_exact` calls that would exceed a maximum
/// length.
///
/// The length/layer bookkeeping stored here is copied into the `LenError`
/// produced when a read would go past the limit.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub struct LimitedReader<T> {
    /// Wrapped reader the data is read from.
    reader: T,

    /// Maximum number of bytes that may be read in the current layer.
    max_len: usize,

    /// Source of `max_len`; reported as `len_source` in length errors.
    len_source: LenSource,

    /// Layer currently being read; reported as `layer` in length errors.
    layer: Layer,

    /// Offset of the start of the current layer; reported as
    /// `layer_start_offset` in length errors.
    layer_offset: usize,

    /// Number of bytes read since construction or the last `start_layer` call.
    read_len: usize,
}
30
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<T: std::io::Read + Sized> LimitedReader<T> {
    /// Creates a new reader that allows at most `max_len` bytes to be read
    /// from `reader`.
    ///
    /// `len_source`, `layer_offset` and `layer` are stored as given and are
    /// used to fill in the `LenError` returned when a read would exceed
    /// `max_len`. The read counter starts at zero.
    pub fn new(
        reader: T,
        max_len: usize,
        len_source: LenSource,
        layer_offset: usize,
        layer: Layer,
    ) -> LimitedReader<T> {
        LimitedReader {
            reader,
            max_len,
            len_source,
            layer,
            layer_offset,
            read_len: 0,
        }
    }

    /// Maximum number of bytes that may be read in the current layer.
    pub fn max_len(&self) -> usize {
        self.max_len
    }

    /// Source of the maximum length (reported in length errors).
    pub fn len_source(&self) -> LenSource {
        self.len_source
    }

    /// Layer that is currently being read (reported in length errors).
    pub fn layer(&self) -> Layer {
        self.layer
    }

    /// Offset of the start of the current layer (reported in length errors).
    pub fn layer_offset(&self) -> usize {
        self.layer_offset
    }

    /// Number of bytes read since construction or the last call to
    /// [`LimitedReader::start_layer`].
    pub fn read_len(&self) -> usize {
        self.read_len
    }

    /// Starts a new layer at the current read position.
    ///
    /// The bytes read so far are added to the layer offset and removed from
    /// the remaining maximum length, the read counter is reset to zero and
    /// `layer` becomes the current layer.
    pub fn start_layer(&mut self, layer: Layer) {
        // Saturate instead of overflowing: the offset is only used for error
        // reporting and a caller supplied start offset may be arbitrarily big.
        self.layer_offset = self.layer_offset.saturating_add(self.read_len);
        // `read_len` never exceeds `max_len` (enforced by `read_exact`), so
        // this subtraction cannot underflow.
        self.max_len -= self.read_len;
        self.read_len = 0;
        self.layer = layer;
    }

    /// Fills `buf` completely from the wrapped reader, provided the read stays
    /// within the maximum length.
    ///
    /// # Errors
    ///
    /// * `LimitedReadError::Len` if fewer than `buf.len()` bytes remain before
    ///   `max_len` is reached. The wrapped reader is not touched in that case.
    /// * `LimitedReadError::Io` if the wrapped reader's `read_exact` fails. The
    ///   read counter is not advanced in that case.
    pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), LimitedReadError> {
        // `read_len <= max_len` always holds, so this cannot underflow.
        let remaining = self.max_len - self.read_len;
        if buf.len() > remaining {
            return Err(LimitedReadError::Len(LenError {
                // Saturate so that an oversized request is reported instead of
                // overflowing (panic in debug builds, wrap in release builds).
                required_len: self.read_len.saturating_add(buf.len()),
                len: self.max_len,
                len_source: self.len_source,
                layer: self.layer,
                layer_start_offset: self.layer_offset,
            }));
        }

        self.reader
            .read_exact(buf)
            .map_err(LimitedReadError::Io)?;
        // Cannot overflow: `buf.len() <= max_len - read_len` was checked above.
        self.read_len += buf.len();
        Ok(())
    }

    /// Consumes the `LimitedReader` and returns the wrapped reader.
    pub fn take_reader(self) -> T {
        self.reader
    }
}
110
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<T: core::fmt::Debug> core::fmt::Debug for LimitedReader<T> {
    /// Formats the reader as a struct named `LimitedReader`, listing all of
    /// its fields in declaration order.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct without
        // updating this impl becomes a compile error.
        let Self {
            reader,
            max_len,
            len_source,
            layer,
            layer_offset,
            read_len,
        } = self;

        let mut out = f.debug_struct("LimitedReader");
        out.field("reader", reader);
        out.field("max_len", max_len);
        out.field("len_source", len_source);
        out.field("layer", layer);
        out.field("layer_offset", layer_offset);
        out.field("read_len", read_len);
        out.finish()
    }
}
125
#[cfg(all(test, feature = "std"))]
mod tests {
    use std::format;
    use std::io::Cursor;

    use super::*;

    /// The constructor stores its arguments unchanged and starts with a read
    /// length of zero; every getter mirrors its field.
    #[test]
    fn new() {
        let data = [1, 2, 3, 4];
        let reader = LimitedReader::new(
            Cursor::new(&data),
            data.len(),
            LenSource::Slice,
            5,
            Layer::Ipv4Header,
        );

        // each value is checked via the field and via the getter
        assert_eq!(reader.max_len, data.len());
        assert_eq!(reader.max_len(), data.len());

        assert_eq!(reader.len_source, LenSource::Slice);
        assert_eq!(reader.len_source(), LenSource::Slice);

        assert_eq!(reader.layer, Layer::Ipv4Header);
        assert_eq!(reader.layer(), Layer::Ipv4Header);

        assert_eq!(reader.layer_offset, 5);
        assert_eq!(reader.layer_offset(), 5);

        assert_eq!(reader.read_len, 0);
        assert_eq!(reader.read_len(), 0);
    }

    /// Starting a layer moves the already read bytes from the remaining
    /// length into the layer offset and switches the reported layer.
    #[test]
    fn start_layer() {
        let data = [1, 2, 3, 4, 5];
        let mut reader = LimitedReader::new(
            Cursor::new(&data),
            data.len(),
            LenSource::Slice,
            6,
            Layer::Ipv4Header,
        );

        // consume two bytes in the initial layer
        let mut first = [0u8; 2];
        reader.read_exact(&mut first).unwrap();
        assert_eq!(first, [1, 2]);

        reader.start_layer(Layer::IpAuthHeader);

        // bookkeeping after the switch
        assert_eq!(reader.max_len, 3);
        assert_eq!(reader.len_source, LenSource::Slice);
        assert_eq!(reader.layer, Layer::IpAuthHeader);
        assert_eq!(reader.layer_offset, 2 + 6);
        assert_eq!(reader.read_len, 0);

        // a read past the remaining length reports the new layer values
        let mut too_big = [0u8; 4];
        let expected = LenError {
            required_len: 4,
            len: 3,
            len_source: LenSource::Slice,
            layer: Layer::IpAuthHeader,
            layer_start_offset: 2 + 6,
        };
        assert_eq!(
            reader.read_exact(&mut too_big).unwrap_err().len().unwrap(),
            expected
        );
    }

    /// `read_exact` succeeds within the limit, reports a length error beyond
    /// it and forwards io errors of the wrapped reader.
    #[test]
    fn read_exact() {
        let data = [1, 2, 3, 4, 5];
        // the limit is one byte larger than the data that is available
        let mut reader = LimitedReader::new(
            Cursor::new(&data),
            data.len() + 1,
            LenSource::Ipv4HeaderTotalLen,
            10,
            Layer::Ipv4Header,
        );

        // normal read
        let mut ok_buf = [0u8; 2];
        reader.read_exact(&mut ok_buf).unwrap();
        assert_eq!(ok_buf, [1, 2]);

        // length error: 2 bytes already read + 5 requested > limit of 6
        let mut len_buf = [0u8; 5];
        let expected = LenError {
            required_len: 7,
            len: 6,
            len_source: LenSource::Ipv4HeaderTotalLen,
            layer: Layer::Ipv4Header,
            layer_start_offset: 10,
        };
        assert_eq!(
            reader.read_exact(&mut len_buf).unwrap_err().len().unwrap(),
            expected
        );

        // io error: 4 bytes fit the limit, but only 3 are left in the cursor
        let mut io_buf = [0u8; 4];
        assert!(reader.read_exact(&mut io_buf).unwrap_err().io().is_some());
    }

    /// `take_reader` returns the wrapped reader in its advanced state.
    #[test]
    fn take_reader() {
        let data = [1, 2, 3, 4, 5];
        let mut reader = LimitedReader::new(
            Cursor::new(&data),
            data.len(),
            LenSource::Slice,
            6,
            Layer::Ipv4Header,
        );

        let mut buf = [0u8; 2];
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(buf, [1, 2]);

        let cursor = reader.take_reader();
        assert_eq!(2, cursor.position());
    }

    /// The debug output lists all fields in declaration order.
    #[test]
    fn debug() {
        let data = [1, 2, 3, 4];
        let reader = LimitedReader::new(
            Cursor::new(&data),
            data.len(),
            LenSource::Slice,
            5,
            Layer::Ipv4Header,
        );

        let expected = format!(
            "LimitedReader {{ reader: {:?}, max_len: {:?}, len_source: {:?}, layer: {:?}, layer_offset: {:?}, read_len: {:?} }}",
            &reader.reader,
            &reader.max_len,
            &reader.len_source,
            &reader.layer,
            &reader.layer_offset,
            &reader.read_len
        );
        assert_eq!(format!("{:?}", reader), expected);
    }
}