read_restrict/
lib.rs

1//! # read-restrict
2//!
//! An adaptor around Rust's standard [`io::Take`] which, instead of returning
//! `Ok(0)` when the read limit is exceeded, returns an error of the kind
//! [`ErrorKind::InvalidData`].
6//!
7//! This is intended for enforcing explicit input limits when simply truncating with
8//! `take` could result in incorrect behaviour.
9//!
10//! `read_restrict` also offers restricted variants of
11//! [`std::fs::read`](std::fs::read) and
12//! [`std::fs::read_to_string`](std::fs::read_to_string), to conveniently
13//! prevent unbounded reads of overly-large files.
14//!
15//! # Examples
16//!
17//! ```no_run
18//! use std::io::{self, Read, ErrorKind};
19//!
20//! use read_restrict::ReadExt;
21//!
22//! fn main() -> io::Result<()> {
23//!     let f = std::fs::File::open("foo.txt")?;
24//!     let mut handle = f.restrict(5);
25//!     let mut buf = [0; 8];
26//!     assert_eq!(5, handle.read(&mut buf)?); // reads at most 5 bytes
27//!     assert_eq!(0, handle.restriction()); // is now exhausted
28//!     assert_eq!(ErrorKind::InvalidData, handle.read(&mut buf).unwrap_err().kind());
29//!     Ok(())
30//! }
31//! ```
32//!
33//! ```no_run
34//! fn load_config(path: &std::path::Path) -> std::io::Result<String> {
35//!     // No sensible configuration is going to exceed 640 KiB
36//!     let conf = read_restrict::read_to_string(&path, 640 * 1024)?;
37//!     // probably want to parse it here
38//!     Ok(conf)
39//! }
40//! ```
41//!
42//! [`io::Take`]: https://doc.rust-lang.org/std/io/struct.Take.html
43//! [`ErrorKind::InvalidData`]: https://doc.rust-lang.org/std/io/enum.ErrorKind.html#variant.InvalidData
44
45use std::fs::File;
46use std::io::{self, BufRead, Read, Result, Take};
47use std::path::Path;
48
49pub trait ReadExt {
50    fn restrict(self, restriction: u64) -> Restrict<Self>
51    where
52        Self: Sized + Read,
53    {
54        Restrict {
55            inner: self.take(restriction),
56        }
57    }
58}
59
60impl<R: Read> ReadExt for R {}
61
/// Reader adaptor which restricts the bytes read from an underlying reader,
/// returning an IO error of the kind [`ErrorKind::InvalidData`] when the
/// restriction is exceeded.
///
/// Internally this wraps an [`io::Take`]. `Take` reports an exhausted limit as
/// `Ok(0)`, which cannot be told apart from end-of-file. `Restrict` instead
/// fails any [`read`](Read::read) into a non-empty buffer, and any
/// [`fill_buf`](BufRead::fill_buf), once its restriction has reached zero.
///
/// This struct is generally created by calling [`restrict`] on a reader.
/// Please see the documentation of [`restrict`] for more details.
///
/// [`restrict`]: ReadExt::restrict
/// [`ErrorKind::InvalidData`]: io::ErrorKind::InvalidData
#[derive(Debug)]
pub struct Restrict<T> {
    inner: Take<T>,
}

impl<T> Restrict<T> {
    /// Returns the number of bytes that can still be read before this instance
    /// will return an error.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::io;
    /// use std::io::prelude::*;
    /// use read_restrict::ReadExt;
    /// use std::fs::File;
    ///
    /// fn main() -> io::Result<()> {
    ///     let f = File::open("foo.txt")?;
    ///
    ///     // read at most five bytes
    ///     let handle = f.restrict(5);
    ///
    ///     println!("restriction: {}", handle.restriction());
    ///     Ok(())
    /// }
    /// ```
    pub fn restriction(&self) -> u64 {
        self.inner.limit()
    }

    /// Sets the number of bytes that can be read before this instance will
    /// return an error. This is the same as constructing a new `Restrict`
    /// instance, so neither the number of bytes already read nor the previous
    /// restriction value matters when calling this method.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::io;
    /// use std::io::prelude::*;
    /// use std::fs::File;
    /// use read_restrict::ReadExt;
    ///
    /// fn main() -> io::Result<()> {
    ///     let f = File::open("foo.txt")?;
    ///
    ///     // read at most five bytes
    ///     let mut handle = f.restrict(5);
    ///     handle.set_restriction(10);
    ///
    ///     assert_eq!(handle.restriction(), 10);
    ///     Ok(())
    /// }
    /// ```
    pub fn set_restriction(&mut self, restriction: u64) {
        self.inner.set_limit(restriction);
    }

    /// Consumes the `Restrict`, returning the wrapped reader.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::io;
    /// use std::io::prelude::*;
    /// use std::fs::File;
    /// use read_restrict::ReadExt;
    ///
    /// fn main() -> io::Result<()> {
    ///     let file = File::open("foo.txt")?;
    ///
    ///     let mut buffer = [0; 5];
    ///     let mut handle = file.restrict(5);
    ///     handle.read(&mut buffer)?;
    ///
    ///     let file = handle.into_inner();
    ///     Ok(())
    /// }
    /// ```
    pub fn into_inner(self) -> T {
        self.inner.into_inner()
    }

    /// Gets a reference to the underlying reader.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::io;
    /// use std::io::prelude::*;
    /// use std::fs::File;
    /// use read_restrict::ReadExt;
    ///
    /// fn main() -> io::Result<()> {
    ///     let file = File::open("foo.txt")?;
    ///
    ///     let mut buffer = [0; 5];
    ///     let mut handle = file.restrict(5);
    ///     handle.read(&mut buffer)?;
    ///
    ///     let file = handle.get_ref();
    ///     Ok(())
    /// }
    /// ```
    pub fn get_ref(&self) -> &T {
        self.inner.get_ref()
    }

    /// Gets a mutable reference to the underlying reader.
    ///
    /// Care should be taken to avoid modifying the internal I/O state of the
    /// underlying reader as doing so may corrupt the internal limit of this
    /// `Restrict`.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::io;
    /// use std::io::prelude::*;
    /// use std::fs::File;
    /// use read_restrict::ReadExt;
    ///
    /// fn main() -> io::Result<()> {
    ///     let file = File::open("foo.txt")?;
    ///
    ///     let mut buffer = [0; 5];
    ///     let mut handle = file.restrict(5);
    ///     handle.read(&mut buffer)?;
    ///
    ///     let file = handle.get_mut();
    ///     Ok(())
    /// }
    /// ```
    pub fn get_mut(&mut self) -> &mut T {
        self.inner.get_mut()
    }
}

impl<T: Read> Read for Restrict<T> {
    /// Reads from the inner [`Take`].
    ///
    /// Fails with [`io::ErrorKind::InvalidData`] when `buf` is non-empty and the
    /// restriction is exhausted. A read into an empty `buf` is never rejected
    /// by the restriction.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if !buf.is_empty() && self.restriction() == 0 {
            return Err(restriction_exceeded());
        }

        self.inner.read(buf)
    }
}

impl<T: BufRead> BufRead for Restrict<T> {
    /// Returns the inner [`Take`]'s buffer.
    ///
    /// Fails with [`io::ErrorKind::InvalidData`] whenever the restriction is
    /// exhausted, whatever state the inner reader is in.
    fn fill_buf(&mut self) -> Result<&[u8]> {
        if self.restriction() == 0 {
            return Err(restriction_exceeded());
        }

        self.inner.fill_buf()
    }

    fn consume(&mut self, amt: usize) {
        self.inner.consume(amt);
    }
}

/// Builds the error that `Restrict` returns once its restriction is used up.
/// The `Read` and `BufRead` impls share it so both report the same kind and
/// message.
fn restriction_exceeded() -> io::Error {
    io::Error::new(io::ErrorKind::InvalidData, "Read restriction exceeded")
}
238
239/// Provided the file at `path` fits within the specified limit, pass a
240/// restricted read handle and a suitable initial buffer size to the closure
241/// and return its result.
242fn open_with_restriction<F, T>(path: &Path, restriction: usize, f: F) -> io::Result<T>
243where
244    F: FnOnce(Restrict<File>, usize) -> io::Result<T>,
245{
246    let file = File::open(path)?;
247    let size = match file.metadata().map(|m| m.len()) {
248        Ok(size) if size > restriction as u64 => {
249            return Err(io::Error::new(
250                io::ErrorKind::InvalidData,
251                "File exceeds size restriction",
252            ))
253        }
254        Ok(size) => (size as usize).saturating_add(1),
255        Err(_) => 0,
256    };
257    f(file.restrict(restriction as u64 + 1), size)
258}
259
260/// Read the entire contents of a file into a bytes vector, provided it fits
261/// within a specified size limit.
262///
263/// This is a restricted alternative to [`std::fs::read`](std::fs::read)
264/// with otherwise identical semantics.
265///
266/// # Examples
267///
268/// ```no_run
269/// fn main() -> std::io::Result<()> {
270///     let vec_at_most_64_bytes = read_restrict::read("foo.txt", 64)?;
271///     Ok(())
272/// }
273/// ```
274pub fn read<P: AsRef<Path>>(path: P, restriction: usize) -> io::Result<Vec<u8>> {
275    open_with_restriction(path.as_ref(), restriction, |mut file, size| {
276        let mut bytes = Vec::with_capacity(size);
277        file.read_to_end(&mut bytes)?;
278        Ok(bytes)
279    })
280}
281
282/// Read the entire contents of a file into a string, provided it fits within a
283/// specified size limit.
284///
285/// This is a restricted alternative to [`std::fs::read_to_string`](std::fs::read_to_string)
286/// with otherwise identical semantics.
287///
288/// # Examples
289///
290/// ```no_run
291/// fn main() -> std::io::Result<()> {
292///     let string_at_most_64_bytes = read_restrict::read_to_string("foo.txt", 64)?;
293///     Ok(())
294/// }
295/// ```
296pub fn read_to_string<P: AsRef<Path>>(path: P, restriction: usize) -> io::Result<String> {
297    open_with_restriction(path.as_ref(), restriction, |mut file, size| {
298        let mut string = String::with_capacity(size);
299        file.read_to_string(&mut string)?;
300        Ok(string)
301    })
302}
303
#[cfg(test)]
mod tests {
    use super::{read, read_to_string, ReadExt};
    use std::io::{self, BufRead, BufReader, Cursor, ErrorKind, Read};

    /// This crate's manifest. It is assumed to be readable from the test's
    /// working directory (the crate root under `cargo test`).
    const MANIFEST: &str = "Cargo.toml";

    /// Length in bytes of `MANIFEST` according to its metadata.
    fn manifest_len() -> usize {
        std::fs::metadata(MANIFEST).unwrap().len() as usize
    }

    /// Unwraps the error of a result that is expected to fail and returns its kind.
    fn err_kind<T: std::fmt::Debug>(result: io::Result<T>) -> ErrorKind {
        result.unwrap_err().kind()
    }

    #[test]
    fn test_read() {
        let len = manifest_len();

        // A restriction equal to the file size admits the whole file...
        assert_eq!(read(MANIFEST, len).unwrap().len(), len);
        // ...while one byte less is rejected.
        assert_eq!(err_kind(read(MANIFEST, len - 1)), ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_to_string() {
        let len = manifest_len();

        assert_eq!(read_to_string(MANIFEST, len).unwrap().len(), len);
        assert_eq!(
            err_kind(read_to_string(MANIFEST, len - 1)),
            ErrorKind::InvalidData
        );
    }

    #[test]
    fn restrict() {
        let data = b"Stupidity is the same as evil if you judge by the results";
        let mut reader = Cursor::new(&data[..]).restrict(0);

        // Empty reads always succeed, even with no budget left.
        let mut empty = [0u8; 0];
        assert_eq!(reader.read(&mut empty).unwrap(), 0);

        // Any non-empty read fails once the budget is exhausted.
        let mut byte = [0u8; 1];
        assert_eq!(err_kind(reader.read(&mut byte)), ErrorKind::InvalidData);

        // The restriction can be adjusted after construction.
        reader.set_restriction(6);
        let mut chunk = [0u8; 8];
        assert_eq!(reader.read(&mut chunk).unwrap(), 6);
        assert_eq!(&chunk[..6], b"Stupid");
        assert_eq!(err_kind(reader.read(&mut chunk)), ErrorKind::InvalidData);

        // The inner reader is left at a consistent position.
        let mut buffered = BufReader::new(reader.into_inner()).restrict(3);
        assert_eq!(buffered.fill_buf().unwrap(), b"ity");
        buffered.consume(3);
        assert_eq!(err_kind(buffered.fill_buf()), ErrorKind::InvalidData);
    }

    #[test]
    fn restrict_err() {
        /// A reader whose every operation fails with `ErrorKind::Other`.
        struct AlwaysFails;

        impl Read for AlwaysFails {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                Err(io::Error::new(ErrorKind::Other, ""))
            }
        }

        impl BufRead for AlwaysFails {
            fn fill_buf(&mut self) -> io::Result<&[u8]> {
                Err(io::Error::new(ErrorKind::Other, ""))
            }

            fn consume(&mut self, _amt: usize) {}
        }

        // An exhausted restriction is reported ahead of the inner reader's error.
        let mut byte = [0u8; 1];
        assert_eq!(
            err_kind(AlwaysFails.restrict(0).read(&mut byte)),
            ErrorKind::InvalidData
        );
        assert_eq!(
            err_kind(AlwaysFails.restrict(0).fill_buf()),
            ErrorKind::InvalidData
        );
    }
}