read_restrict/lib.rs
1//! # read-restrict
2//!
//! An adaptor around Rust's standard [`io::Take`] which, instead of returning
//! `Ok(0)` when the read limit is exceeded, returns an error of the kind
//! [`ErrorKind::InvalidData`].
6//!
7//! This is intended for enforcing explicit input limits when simply truncating with
8//! `take` could result in incorrect behaviour.
9//!
10//! `read_restrict` also offers restricted variants of
11//! [`std::fs::read`](std::fs::read) and
12//! [`std::fs::read_to_string`](std::fs::read_to_string), to conveniently
13//! prevent unbounded reads of overly-large files.
14//!
15//! # Examples
16//!
17//! ```no_run
18//! use std::io::{self, Read, ErrorKind};
19//!
20//! use read_restrict::ReadExt;
21//!
22//! fn main() -> io::Result<()> {
23//! let f = std::fs::File::open("foo.txt")?;
24//! let mut handle = f.restrict(5);
25//! let mut buf = [0; 8];
26//! assert_eq!(5, handle.read(&mut buf)?); // reads at most 5 bytes
27//! assert_eq!(0, handle.restriction()); // is now exhausted
28//! assert_eq!(ErrorKind::InvalidData, handle.read(&mut buf).unwrap_err().kind());
29//! Ok(())
30//! }
31//! ```
32//!
33//! ```no_run
34//! fn load_config(path: &std::path::Path) -> std::io::Result<String> {
35//! // No sensible configuration is going to exceed 640 KiB
36//! let conf = read_restrict::read_to_string(&path, 640 * 1024)?;
37//! // probably want to parse it here
38//! Ok(conf)
39//! }
40//! ```
41//!
42//! [`io::Take`]: https://doc.rust-lang.org/std/io/struct.Take.html
43//! [`ErrorKind::InvalidData`]: https://doc.rust-lang.org/std/io/enum.ErrorKind.html#variant.InvalidData
44
45use std::fs::File;
46use std::io::{self, BufRead, Read, Result, Take};
47use std::path::Path;
48
49pub trait ReadExt {
50 fn restrict(self, restriction: u64) -> Restrict<Self>
51 where
52 Self: Sized + Read,
53 {
54 Restrict {
55 inner: self.take(restriction),
56 }
57 }
58}
59
60impl<R: Read> ReadExt for R {}
61
/// A reader wrapper that caps how many bytes may be pulled from the reader it
/// wraps, failing with an IO error of kind [`io::ErrorKind::InvalidData`]
/// once that cap is exceeded.
///
/// Values of this type are normally obtained by calling
/// [`ReadExt::restrict`] on a reader; see that method for details.
#[derive(Debug)]
pub struct Restrict<T> {
    // The standard library's limiting adaptor; it tracks the remaining allowance.
    inner: Take<T>,
}
74
75impl<T> Restrict<T> {
76 /// Returns the number of bytes that can be read before this instance will
77 /// return an error.
78 ///
79 /// # Examples
80 ///
81 /// ```no_run
82 /// use std::io;
83 /// use std::io::prelude::*;
84 /// use read_restrict::ReadExt;
85 /// use std::fs::File;
86 ///
87 /// fn main() -> io::Result<()> {
88 /// let f = File::open("foo.txt")?;
89 ///
90 /// // read at most five bytes
91 /// let handle = f.restrict(5);
92 ///
93 /// println!("restriction: {}", handle.restriction());
94 /// Ok(())
95 /// }
96 /// ```
97 pub fn restriction(&self) -> u64 {
98 self.inner.limit()
99 }
100
101 /// Sets the number of bytes that can be read before this instance will
102 /// return an error. This is the same as constructing a new `Restrict` instance, so
103 /// the amount of bytes read and the previous restriction value don't matter when
104 /// calling this method.
105 ///
106 /// # Examples
107 ///
108 /// ```no_run
109 /// use std::io;
110 /// use std::io::prelude::*;
111 /// use std::fs::File;
112 /// use read_restrict::ReadExt;
113 ///
114 /// fn main() -> io::Result<()> {
115 /// let f = File::open("foo.txt")?;
116 ///
117 /// // read at most five bytes
118 /// let mut handle = f.restrict(5);
119 /// handle.set_restriction(10);
120 ///
121 /// assert_eq!(handle.restriction(), 10);
122 /// Ok(())
123 /// }
124 /// ```
125 pub fn set_restriction(&mut self, restriction: u64) {
126 self.inner.set_limit(restriction);
127 }
128
129 /// Consumes the `Restrict`, returning the wrapped reader.
130 ///
131 /// # Examples
132 ///
133 /// ```no_run
134 /// use std::io;
135 /// use std::io::prelude::*;
136 /// use std::fs::File;
137 /// use read_restrict::ReadExt;
138 ///
139 /// fn main() -> io::Result<()> {
140 /// let mut file = File::open("foo.txt")?;
141 ///
142 /// let mut buffer = [0; 5];
143 /// let mut handle = file.restrict(5);
144 /// handle.read(&mut buffer)?;
145 ///
146 /// let file = handle.into_inner();
147 /// Ok(())
148 /// }
149 /// ```
150 pub fn into_inner(self) -> T {
151 self.inner.into_inner()
152 }
153
154 /// Gets a reference to the underlying reader.
155 ///
156 /// # Examples
157 ///
158 /// ```no_run
159 /// use std::io;
160 /// use std::io::prelude::*;
161 /// use std::fs::File;
162 /// use read_restrict::ReadExt;
163 ///
164 /// fn main() -> io::Result<()> {
165 /// let mut file = File::open("foo.txt")?;
166 ///
167 /// let mut buffer = [0; 5];
168 /// let mut handle = file.restrict(5);
169 /// handle.read(&mut buffer)?;
170 ///
171 /// let file = handle.get_ref();
172 /// Ok(())
173 /// }
174 /// ```
175 pub fn get_ref(&self) -> &T {
176 self.inner.get_ref()
177 }
178
179 /// Gets a mutable reference to the underlying reader.
180 ///
181 /// Care should be taken to avoid modifying the internal I/O state of the
182 /// underlying reader as doing so may corrupt the internal limit of this
183 /// `Restrict`.
184 ///
185 /// # Examples
186 ///
187 /// ```no_run
188 /// use std::io;
189 /// use std::io::prelude::*;
190 /// use std::fs::File;
191 /// use read_restrict::ReadExt;
192 ///
193 /// fn main() -> io::Result<()> {
194 /// let mut file = File::open("foo.txt")?;
195 ///
196 /// let mut buffer = [0; 5];
197 /// let mut handle = file.restrict(5);
198 /// handle.read(&mut buffer)?;
199 ///
200 /// let file = handle.get_mut();
201 /// Ok(())
202 /// }
203 /// ```
204 pub fn get_mut(&mut self) -> &mut T {
205 self.inner.get_mut()
206 }
207}
208
209impl<T: Read> Read for Restrict<T> {
210 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
211 if !buf.is_empty() && self.restriction() == 0 {
212 return Err(io::Error::new(
213 io::ErrorKind::InvalidData,
214 "Read restriction exceeded",
215 ));
216 }
217
218 self.inner.read(&mut buf[..])
219 }
220}
221
222impl<T: BufRead> BufRead for Restrict<T> {
223 fn fill_buf(&mut self) -> Result<&[u8]> {
224 if self.restriction() == 0 {
225 return Err(io::Error::new(
226 io::ErrorKind::InvalidData,
227 "Read restriction exceeded",
228 ));
229 }
230
231 self.inner.fill_buf()
232 }
233
234 fn consume(&mut self, amt: usize) {
235 self.inner.consume(amt);
236 }
237}
238
239/// Provided the file at `path` fits within the specified limit, pass a
240/// restricted read handle and a suitable initial buffer size to the closure
241/// and return its result.
242fn open_with_restriction<F, T>(path: &Path, restriction: usize, f: F) -> io::Result<T>
243where
244 F: FnOnce(Restrict<File>, usize) -> io::Result<T>,
245{
246 let file = File::open(path)?;
247 let size = match file.metadata().map(|m| m.len()) {
248 Ok(size) if size > restriction as u64 => {
249 return Err(io::Error::new(
250 io::ErrorKind::InvalidData,
251 "File exceeds size restriction",
252 ))
253 }
254 Ok(size) => (size as usize).saturating_add(1),
255 Err(_) => 0,
256 };
257 f(file.restrict(restriction as u64 + 1), size)
258}
259
260/// Read the entire contents of a file into a bytes vector, provided it fits
261/// within a specified size limit.
262///
263/// This is a restricted alternative to [`std::fs::read`](std::fs::read)
264/// with otherwise identical semantics.
265///
266/// # Examples
267///
268/// ```no_run
269/// fn main() -> std::io::Result<()> {
270/// let vec_at_most_64_bytes = read_restrict::read("foo.txt", 64)?;
271/// Ok(())
272/// }
273/// ```
274pub fn read<P: AsRef<Path>>(path: P, restriction: usize) -> io::Result<Vec<u8>> {
275 open_with_restriction(path.as_ref(), restriction, |mut file, size| {
276 let mut bytes = Vec::with_capacity(size);
277 file.read_to_end(&mut bytes)?;
278 Ok(bytes)
279 })
280}
281
282/// Read the entire contents of a file into a string, provided it fits within a
283/// specified size limit.
284///
285/// This is a restricted alternative to [`std::fs::read_to_string`](std::fs::read_to_string)
286/// with otherwise identical semantics.
287///
288/// # Examples
289///
290/// ```no_run
291/// fn main() -> std::io::Result<()> {
292/// let string_at_most_64_bytes = read_restrict::read_to_string("foo.txt", 64)?;
293/// Ok(())
294/// }
295/// ```
296pub fn read_to_string<P: AsRef<Path>>(path: P, restriction: usize) -> io::Result<String> {
297 open_with_restriction(path.as_ref(), restriction, |mut file, size| {
298 let mut string = String::with_capacity(size);
299 file.read_to_string(&mut string)?;
300 Ok(string)
301 })
302}
303
#[cfg(test)]
mod tests {
    use super::{read, read_to_string, ReadExt};
    use std::io::{self, BufRead, BufReader, Cursor, Read};

    /// `read` succeeds at the exact file size and fails one byte below it.
    #[test]
    fn test_read() {
        let manifest = "Cargo.toml";
        let len = std::fs::metadata(manifest).unwrap().len() as usize;

        let bytes = read(manifest, len).unwrap();
        assert_eq!(len, bytes.len());

        let err = read(manifest, len - 1).unwrap_err();
        assert_eq!(io::ErrorKind::InvalidData, err.kind());
    }

    /// `read_to_string` succeeds at the exact file size and fails one byte below it.
    #[test]
    fn test_read_to_string() {
        let manifest = "Cargo.toml";
        let len = std::fs::metadata(manifest).unwrap().len() as usize;

        let text = read_to_string(manifest, len).unwrap();
        assert_eq!(len, text.len());

        let err = read_to_string(manifest, len - 1).unwrap_err();
        assert_eq!(io::ErrorKind::InvalidData, err.kind());
    }

    /// Walks a restricted cursor through exhaustion, adjustment and re-wrapping.
    #[test]
    fn restrict() {
        let data = b"Stupidity is the same as evil if you judge by the results";
        let mut reader = Cursor::new(&data[..]).restrict(0);

        // empty reads always succeed
        let mut nothing = [0; 0];
        assert_eq!(0, reader.read(&mut nothing).unwrap());

        let mut single = [0; 1];
        let err = reader.read(&mut single).unwrap_err();
        assert_eq!(io::ErrorKind::InvalidData, err.kind());

        // restriction can be dynamically adjusted
        reader.set_restriction(6);
        let mut chunk = [0; 8];
        assert_eq!(6, reader.read(&mut chunk).unwrap());
        assert_eq!(b"Stupid", &chunk[..6]);
        let err = reader.read(&mut chunk).unwrap_err();
        assert_eq!(io::ErrorKind::InvalidData, err.kind());

        // and leaves the reader in a consistent position
        let mut buffered = BufReader::new(reader.into_inner()).restrict(3);
        assert_eq!(b"ity", buffered.fill_buf().unwrap());
        buffered.consume(3);
        let err = buffered.fill_buf().unwrap_err();
        assert_eq!(io::ErrorKind::InvalidData, err.kind());
    }

    /// With a zero restriction the restriction error is reported instead of
    /// the error the wrapped reader would return.
    #[test]
    fn restrict_err() {
        // Reader whose every operation fails with `ErrorKind::Other`.
        struct R;

        impl Read for R {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                Err(io::Error::new(io::ErrorKind::Other, ""))
            }
        }
        impl BufRead for R {
            fn fill_buf(&mut self) -> io::Result<&[u8]> {
                Err(io::Error::new(io::ErrorKind::Other, ""))
            }
            fn consume(&mut self, _amt: usize) {}
        }

        let mut scratch = [0; 1];
        let read_err = R.restrict(0).read(&mut scratch).unwrap_err();
        assert_eq!(io::ErrorKind::InvalidData, read_err.kind());

        let fill_err = R.restrict(0).fill_buf().unwrap_err();
        assert_eq!(io::ErrorKind::InvalidData, fill_err.kind());
    }
}