better_limit_reader/
lib.rs

1#![warn(missing_docs)]
2
//! # better-limit-reader
//!
//! Exposes [`LimitReader`], a limit reader that protects against zip bombs and other nefarious activities.
//!
//! This crate is heavily inspired by Jon Gjengset's "Crust of Rust" episode on the inner workings of git on `YouTube` (<https://youtu.be/u0VotuGzD_w?si=oIuV9CITSWHJXKBu&t=3503>) and mitigating zip bombs.
8
9use derive_builder::Builder;
10use error::LimitReaderError;
11use flate2::read::ZlibDecoder;
12use readable::MyBufReader;
13use readable::Readable;
14use readable::{falible::LimitReaderFallible, infalible::LimitReaderInfallible};
15use std::fmt::Display;
16use std::fmt::Formatter;
17use std::io;
18use std::io::prelude::*;
19use std::io::BufReader;
20use std::path::PathBuf;
21
22use LimitReaderResult as Result;
23
24pub(crate) mod error;
25pub(crate) mod readable;
26
27/// Default result type for [`LimitReader`]
28pub type LimitReaderResult<T> = std::result::Result<T, LimitReaderError>;
29
30/// Re-exports Traits and macros used by most projects. Add `use better_limit_reader::prelude::*;` to your code to quickly get started with [`LimitReader`].
31pub mod prelude {
32
33    pub use crate::{error::LimitReaderError, LimitReader, LimitReaderOutput, LimitReaderResult};
34}
35
36#[allow(dead_code)]
37/// The [`LimitReader`] reads into `buf` which is held within the record struct.
38pub struct LimitReader {
39    buf: [u8; Self::DEFAULT_BUF_SIZE],
40    expected_size: u64,
41    decode_zlib: bool,
42    decode_gzip: bool,
43}
44
45impl Default for LimitReader {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51// Holds a `LimitReader` with a default buffer of size `LimitReader::DEFAULT_BUF_SIZE`
52impl LimitReader {
53    /// Default buffer size for the internal `LimitReader`
54    pub const DEFAULT_BUF_SIZE: usize = 1024;
55
56    /// Create a new [`LimitReader`] with a [`LimitReader::DEFAULT_BUF_SIZE`] for the limit-readers max threshold.
57    #[must_use]
58    pub fn new() -> Self {
59        Self {
60            buf: [0; Self::DEFAULT_BUF_SIZE],
61            expected_size: (Self::DEFAULT_BUF_SIZE - 1) as u64,
62            decode_zlib: false,
63            decode_gzip: false,
64        }
65    }
66
67    /// Return a reference to the internal buffer.
68    #[must_use]
69    pub fn buffer(&self) -> &[u8; Self::DEFAULT_BUF_SIZE] {
70        &self.buf
71    }
72
73    /// Increase the allowed limit on the [`LimitReader`]
74    pub fn limit(&mut self, limit: u64) -> &mut Self {
75        self.expected_size = limit;
76
77        self
78    }
79
80    /// Enable decoding from compressed Zlib
81    pub fn enable_decode_zlib(&mut self) -> &mut Self {
82        self.decode_zlib = true;
83
84        self
85    }
86
87    #[allow(dead_code)]
88    // NOTE: This is private until this is implemented in the future.
89    /// Enable decoding from compressed Gzip
90    fn enable_decode_gzip(&mut self) -> &mut Self {
91        self.decode_gzip = true;
92
93        self
94    }
95
96    /// Read from provided source file.  If the source data is already Zlib compressed, optionally decode the data stream before reading it through a limit-reader.
97    ///
98    /// # Panics
99    ///
100    /// If the provided source file does not exist or is inaccessible, it will panic.  Refer to [`std::fs::File::open`] for details.  This will return [`LimitReaderError`].
101    ///
102    /// # Errors
103    ///
104    /// If this function encounters an error of the kind [`LimitReaderError`], this error will be returned.
105    ///
106    pub fn read(&mut self, source: PathBuf) -> Result<usize> {
107        let f = std::fs::File::open(source).expect("Unable to open file");
108        if self.decode_zlib {
109            let z = ZlibDecoder::new(f);
110            let buf_reader = MyBufReader(z);
111            let reader = LimitReaderFallible::new(buf_reader, self.expected_size);
112
113            self.try_read(reader)
114        } else {
115            let buf_reader = MyBufReader(BufReader::new(f));
116            let reader = LimitReaderFallible::new(buf_reader, self.expected_size);
117
118            self.try_read(reader)
119        }
120    }
121
122    /// Given an accessible source file, this will automatically limit the contents read to the size of the buffer itself.  This will silently truncate read bytes into the buffer, without raising an error.
123    ///
124    /// # Errors
125    ///
126    /// If this function encounters an error of the kind [`LimitReaderError`], this error will be returned.
127    ///
128    pub fn read_limited(&mut self, source: PathBuf) -> Result<LimitReaderOutput> {
129        let source_bytes = std::fs::metadata(&source)?.len();
130        let f = std::fs::File::open(source)?;
131
132        let bytes_read = if self.decode_zlib {
133            let z = ZlibDecoder::new(f);
134            let buf_reader = MyBufReader(z);
135            let reader = LimitReaderInfallible::new(buf_reader, self.expected_size);
136
137            self.try_read(reader)?
138        } else {
139            let buf_reader = MyBufReader(BufReader::new(f));
140            let reader = LimitReaderInfallible::new(buf_reader, self.expected_size);
141
142            self.try_read(reader)?
143        };
144
145        Ok(LimitReaderOutputBuilder::default()
146            .source_size(source_bytes)
147            .bytes_read(bytes_read as u64)
148            .build()?)
149    }
150
151    fn try_read(&mut self, mut reader: impl Readable) -> Result<usize> {
152        let try_read = reader.perform_read(&mut self.buf);
153        match try_read {
154            Ok(value) => Ok(value),
155            Err(err) => Err(LimitReaderError::new(error::ErrorKind::ReadError, err)),
156        }
157    }
158}
159
160/// [`LimitReader`]'s output
161#[allow(missing_docs)]
162#[derive(Default, Builder)]
163#[builder(setter(into))]
164pub struct LimitReaderOutput {
165    source_size: u64,
166    bytes_read: u64,
167}
168
169impl LimitReaderOutput {
170    /// Return bytes read by the underlying reader.
171    #[must_use]
172    pub fn bytes_read(&self) -> u64 {
173        self.bytes_read
174    }
175
176    /// Size in bytes of the underlying file accessible to the reader.
177    #[must_use]
178    pub fn source_size(&self) -> u64 {
179        self.source_size
180    }
181
182    /// Unread bytes (from the underlying file accessible to the reader).
183    #[must_use]
184    pub fn bytes_remaining(&self) -> u64 {
185        self.source_size - self.bytes_read
186    }
187}
188
189impl Display for LimitReaderOutput {
190    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
191        write!(
192            f,
193            "{{ source_size: {}, bytes_read:{} }}",
194            self.source_size, self.bytes_read
195        )
196    }
197}
198
#[cfg(test)]
mod tests {
    use crate::LimitReader;
    use flate2::write::ZlibEncoder;
    use flate2::Compression;
    use std::fs::File;
    use std::io::Write;
    use tempfile::tempdir;

    /// Payload written to the fixture file by every test.
    const TEXT: &str = "Mike was here. Briefly.";

    /// Zlib-compress `text` in memory with the default compression level.
    fn zlib_compress(text: &str) -> Vec<u8> {
        let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(text.as_bytes()).unwrap();
        encoder.finish().unwrap()
    }

    /// Decode the first `len` bytes of the reader's buffer as UTF-8, panicking if they are not.
    fn buffered_text(reader: &LimitReader, len: usize) -> String {
        String::from_utf8(reader.buf[..len].to_vec()).unwrap()
    }

    /// Tests for [`LimitReader::read`], which reports over-long input as an error.
    mod falible {
        use super::*;

        #[test]
        fn it_works() {
            let dir = tempdir().unwrap();
            let file_path = dir.path().join("test_output.txt");

            // Plain text: the line (newline included) fits under the default limit.
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "{}", TEXT).unwrap();

            let mut limit_reader = LimitReader::new();
            let read_size = match limit_reader.read(file_path.clone()) {
                Ok(read_size) => read_size,
                Err(_) => unreachable!(),
            };
            assert_eq!(
                buffered_text(&limit_reader, read_size),
                format!("{}\n", TEXT)
            );

            // ZlibDecode: overwrite the same file with the compressed payload and decode it.
            let mut file = File::create(&file_path).unwrap();
            file.write_all(&zlib_compress(TEXT)).unwrap();

            let mut limit_reader = LimitReader::new();
            limit_reader.enable_decode_zlib();

            let read_size = match limit_reader.read(file_path) {
                Ok(read_size) => read_size,
                Err(_) => unreachable!(),
            };
            assert_eq!(buffered_text(&limit_reader, read_size), TEXT);

            drop(file);
            dir.close().unwrap();
        }

        #[test]
        fn panic_due_to_limit_constraint() {
            let dir = tempdir().unwrap();
            let file_path = dir.path().join("test_output.txt");

            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "{}", TEXT).unwrap();

            let limit = 8_u64;
            let mut limit_reader = LimitReader::new();
            limit_reader.limit(limit);

            // Either outcome is accepted: exactly `limit` bytes, or the over-limit error.
            match limit_reader.read(file_path) {
                Ok(read_size) => assert_eq!(read_size, usize::try_from(limit).unwrap()),
                Err(err) => assert_eq!("Error: too many bytes", err.to_string()),
            }

            drop(file);
            dir.close().unwrap();
        }

        #[test]
        fn panic_with_decode_zlib_due_to_limit_constraint() {
            let dir = tempdir().unwrap();
            let file_path = dir.path().join("test_output.txt");

            let compressed = zlib_compress(TEXT);
            let mut file = File::create(&file_path).unwrap();
            file.write_all(&compressed).unwrap();

            // NOTE: This should error due to exceeding our limit.
            // NOTE(review): despite the test name, zlib decoding is never enabled here, so the
            // limit is applied to the raw compressed bytes — confirm whether that is intended.
            let mut limit_reader = LimitReader::new();
            limit_reader.limit(8);

            match limit_reader.read(file_path) {
                Ok(read_size) => assert_eq!(buffered_text(&limit_reader, read_size), TEXT),
                Err(err) => assert_eq!("Error: too many bytes", err.to_string()),
            }

            drop(file);
            dir.close().unwrap();
        }

        #[test]
        fn panic_decode_zlib_error_on_corrupt_deflate_stream() {
            let dir = tempdir().unwrap();
            let file_path = dir.path().join("test_output.txt");

            // The file holds plain text, which is not a valid zlib stream.
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "{}", TEXT).unwrap();

            let mut limit_reader = LimitReader::new();
            limit_reader.enable_decode_zlib();

            match limit_reader.read(file_path) {
                Err(err) => assert_eq!("Error: corrupt deflate stream", err.to_string()),
                Ok(_) => unreachable!(),
            }

            drop(file);
            dir.close().unwrap();
        }
    }

    /// Tests for [`LimitReader::read_limited`], which truncates instead of failing.
    mod infalible {
        use super::*;

        #[test]
        fn it_works() {
            let dir = tempdir().unwrap();
            let file_path = dir.path().join("test_output.txt");

            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "{}", TEXT).unwrap();

            let limit = 8_u64;
            let mut limit_reader = LimitReader::new();
            limit_reader.limit(limit);

            // Plain text longer than the limit is cut down to exactly `limit` bytes.
            match limit_reader.read_limited(file_path.clone()) {
                Ok(reader_output) => assert!(reader_output.bytes_read() == limit),
                Err(_) => unreachable!(),
            }

            // ZlibDecode: the decoded prefix must match the start of the original text.
            let mut file = File::create(&file_path).unwrap();
            file.write_all(&zlib_compress(TEXT)).unwrap();

            let mut limit_reader = LimitReader::new();
            limit_reader.limit(limit).enable_decode_zlib();

            let decoded_len = match limit_reader.read_limited(file_path) {
                Ok(reader_output) => reader_output.bytes_read() as usize,
                Err(_) => unreachable!(),
            };
            assert_eq!(
                buffered_text(&limit_reader, decoded_len),
                &TEXT[..decoded_len]
            );

            drop(file);
            dir.close().unwrap();
        }

        #[test]
        fn panic_decode_zlib_error_on_corrupt_deflate_stream() {
            let dir = tempdir().unwrap();
            let file_path = dir.path().join("test_output.txt");

            // The file holds plain text, which is not a valid zlib stream.
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "{}", TEXT).unwrap();

            let limit = 8_u64;
            let mut limit_reader = LimitReader::new();
            limit_reader.limit(limit).enable_decode_zlib();

            // NOTE(review): this module covers `read_limited`, yet the call below goes through
            // `read` — confirm which entry point was meant to be exercised.
            match limit_reader.read(file_path) {
                Err(err) => assert_eq!("Error: corrupt deflate stream", err.to_string()),
                Ok(_) => unreachable!(),
            }

            drop(file);
            dir.close().unwrap();
        }
    }
}