async_read_length_limit/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    rustdoc::missing_crate_level_docs,
6    missing_debug_implementations,
7    nonstandard_style,
8    unused_qualifications
9)]
10#![warn(missing_docs)]
11
12//! # async-read-length-limit
13//!
14//! Protects against a certain class of denial-of-service attacks wherein long chunked bodies are
15//! uploaded to web services. Can be applied to any [`AsyncRead`] type.
16//!
17//! # Examples
18//!
19//! ```rust
20//! use futures_lite::{io::Cursor, AsyncReadExt};
21//! use async_read_length_limit::LengthLimitExt;
22//!
23//! # futures_lite::future::block_on(async move {
24//! // input longer than limit returns an error and only reads bytes up to the limit
25//!
26//! let input_data = Cursor::new(b"these are the input data");
27//! let mut output_buf = Vec::new();
28//! let result = input_data.limit_bytes(5).read_to_end(&mut output_buf).await;
29//! assert!(result.is_err());
30//! assert_eq!(output_buf, b"these");
31//!
32//! // input shorter than limit reads transparently
33//!
34//! let input_data = Cursor::new(b"these are the input data");
35//! let mut output_buf = Vec::new();
36//! let result = input_data.limit_kb(1).read_to_end(&mut output_buf).await;
37//! assert!(result.is_ok());
38//! assert_eq!(output_buf, b"these are the input data");
39//! # });
40//! ```
41
42use futures_lite::AsyncRead;
43use std::{
44    error::Error,
45    fmt::Display,
46    io::{ErrorKind, Result},
47    pin::Pin,
48    task::{ready, Context, Poll},
49};
50
51pin_project_lite::pin_project! {
52    /// # [`AsyncRead`] length limiter
53    ///
54    /// The number of bytes will never be more than the provided byte limit. If the byte limit is
55    /// exactly the length of the contained AsyncRead, it is considered an error.
56    ///
57    /// # Errors
58    ///
59    /// This will return an error if the underlying AsyncRead does so or if the read length meets (or
60    /// would exceed) the provided length limit. The returned [`std::io::Error`] will have an error kind
61    /// of [`ErrorKind::InvalidData`] and a contained error of [`LengthLimitExceeded`].
62    #[derive(Debug, Clone, Copy)]
63    pub struct LengthLimit<T> {
64        #[pin]
65        reader:  T,
66        bytes_remaining: usize,
67    }
68}
69
70impl<T> LengthLimit<T>
71where
72    T: AsyncRead,
73{
74    /// Constructs a new [`LengthLimit`] with provided [`AsyncRead`] reader and `max_bytes` byte
75    /// length
76    pub fn new(reader: T, max_bytes: usize) -> Self {
77        Self {
78            reader,
79            bytes_remaining: max_bytes,
80        }
81    }
82
83    /// Returns the number of additional bytes before the limit is reached
84    pub fn bytes_remaining(&self) -> usize {
85        self.bytes_remaining
86    }
87
88    /// Unwraps the contained AsyncRead, allowing it to be read to completion. bytes remaining data
89    /// are discarded
90    pub fn into_inner(self) -> T {
91        self.reader
92    }
93}
94
95impl<T> AsRef<T> for LengthLimit<T> {
96    fn as_ref(&self) -> &T {
97        &self.reader
98    }
99}
100
/// Error marker signalling that a length limit was reached.
///
/// It is a unit value and carries no additional details.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LengthLimitExceeded;

impl Display for LengthLimitExceeded {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "Length limit exceeded")
    }
}

impl Error for LengthLimitExceeded {}

impl From<LengthLimitExceeded> for std::io::Error {
    /// Wraps the marker in an I/O error of kind [`ErrorKind::InvalidData`]. The marker stays
    /// available for downcasting through [`std::io::Error::get_ref`].
    fn from(exceeded: LengthLimitExceeded) -> Self {
        std::io::Error::new(ErrorKind::InvalidData, exceeded)
    }
}
117
118impl<T: AsyncRead> AsyncRead for LengthLimit<T> {
119    fn poll_read(
120        self: Pin<&mut Self>,
121        cx: &mut Context<'_>,
122        mut buf: &mut [u8],
123    ) -> Poll<Result<usize>> {
124        let projection = self.project();
125        let reader = projection.reader;
126        let bytes_remaining = *projection.bytes_remaining;
127
128        if bytes_remaining == 0 {
129            return Poll::Ready(Err(LengthLimitExceeded.into()));
130        }
131
132        if bytes_remaining < buf.len() {
133            buf = &mut buf[..bytes_remaining];
134        }
135
136        let new_bytes = ready!(reader.poll_read(cx, buf))?;
137        *projection.bytes_remaining = bytes_remaining.saturating_sub(new_bytes);
138        Poll::Ready(Ok(new_bytes))
139    }
140}
141
142/// Extension trait to add length limiting behavior to any AsyncRead
143///
144/// Full explanation of the behavior at [`LengthLimit`]
145pub trait LengthLimitExt: Sized + AsyncRead {
146    /// Applies a LengthLimit to self with an exclusive maxiumum of `max_bytes` bytes
147    fn limit_bytes(self, max_bytes: usize) -> LengthLimit<Self> {
148        LengthLimit::new(self, max_bytes)
149    }
150
151    /// Applies a LengthLimit to self with an exclusive maxiumum of `max_kb` kilobytes (defined as
152    /// 1024 bytes)
153    fn limit_kb(self, max_kb: usize) -> LengthLimit<Self> {
154        self.limit_bytes(max_kb * 1024)
155    }
156
157    /// Applies a LengthLimit to self with an exclusive maxiumum of `max_mb` megabytes (defined as
158    /// 1024 kilobytes, or 1,048,576 bytes)
159    fn limit_mb(self, max_mb: usize) -> LengthLimit<Self> {
160        self.limit_kb(max_mb * 1024)
161    }
162
163    /// Applies a LengthLimit to self with an exclusive maxiumum of `max_gb` kilobytes (defined as
164    /// 1024 megabytes, or 1,073,741,824 bytes)
165    fn limit_gb(self, max_gb: usize) -> LengthLimit<Self> {
166        self.limit_mb(max_gb * 1024)
167    }
168}
169
170impl<T> LengthLimitExt for T where T: AsyncRead + Unpin {}