actix_web_lab/
body_limit.rs

//! Body limit extractor.
//!
//! See [`BodyLimit`] docs.
4
5use std::{
6    fmt,
7    pin::Pin,
8    task::{Context, Poll, ready},
9};
10
11use actix_web::{
12    FromRequest, HttpMessage as _, HttpRequest, ResponseError,
13    dev::{self, Payload},
14};
15use derive_more::Display;
16use futures_core::Stream as _;
17
18use crate::header::ContentLength;
19
/// Default body size limit: 2 MiB (2,097,152 bytes).
pub const DEFAULT_BODY_LIMIT: usize = 2 * 1024 * 1024;
22
23/// Extractor wrapper that limits size of payload used.
24///
25/// # Examples
26/// ```no_run
27/// use actix_web::{Responder, get, web::Bytes};
28/// use actix_web_lab::extract::BodyLimit;
29///
30/// const BODY_LIMIT: usize = 1_048_576; // 1MB
31///
32/// #[get("/")]
33/// async fn handler(body: BodyLimit<Bytes, BODY_LIMIT>) -> impl Responder {
34///     let body = body.into_inner();
35///     assert!(body.len() < BODY_LIMIT);
36///     body
37/// }
38/// ```
39#[derive(Debug, PartialEq, Eq)]
40pub struct BodyLimit<T, const LIMIT: usize = DEFAULT_BODY_LIMIT> {
41    inner: T,
42}
43
44mod waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic {
45    use super::*;
46
47    impl<T: std::fmt::Display, const LIMIT: usize> std::fmt::Display for BodyLimit<T, LIMIT> {
48        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49            std::fmt::Display::fmt(&self.inner, f)
50        }
51    }
52
53    impl<T, const LIMIT: usize> AsRef<T> for BodyLimit<T, LIMIT> {
54        fn as_ref(&self) -> &T {
55            &self.inner
56        }
57    }
58
59    impl<T, const LIMIT: usize> From<T> for BodyLimit<T, LIMIT> {
60        fn from(inner: T) -> Self {
61            Self { inner }
62        }
63    }
64}
65
66impl<T, const LIMIT: usize> BodyLimit<T, LIMIT> {
67    /// Returns inner extracted type.
68    pub fn into_inner(self) -> T {
69        self.inner
70    }
71}
72
73impl<T, const LIMIT: usize> FromRequest for BodyLimit<T, LIMIT>
74where
75    T: FromRequest + 'static,
76    T::Error: fmt::Debug + fmt::Display,
77{
78    type Error = BodyLimitError<T>;
79    type Future = BodyLimitFut<T, LIMIT>;
80
81    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
82        // fast check of Content-Length header
83        match req.get_header::<ContentLength>() {
84            // CL header indicated that payload would be too large
85            Some(len) if len > LIMIT => return BodyLimitFut::new_error(BodyLimitError::Overflow),
86            _ => {}
87        }
88
89        let counter = crate::util::fork_request_payload(payload);
90
91        BodyLimitFut {
92            inner: Inner::Body {
93                fut: Box::pin(T::from_request(req, payload)),
94                counter_pl: counter,
95                size: 0,
96            },
97        }
98    }
99}
100
101#[allow(missing_debug_implementations)]
102pub struct BodyLimitFut<T, const LIMIT: usize>
103where
104    T: FromRequest + 'static,
105    T::Error: fmt::Debug + fmt::Display,
106{
107    inner: Inner<T, LIMIT>,
108}
109
110impl<T, const LIMIT: usize> BodyLimitFut<T, LIMIT>
111where
112    T: FromRequest + 'static,
113    T::Error: fmt::Debug + fmt::Display,
114{
115    fn new_error(err: BodyLimitError<T>) -> Self {
116        Self {
117            inner: Inner::Error { err: Some(err) },
118        }
119    }
120}
121
122enum Inner<T, const LIMIT: usize>
123where
124    T: FromRequest + 'static,
125    T::Error: fmt::Debug + fmt::Display,
126{
127    Error {
128        err: Option<BodyLimitError<T>>,
129    },
130
131    Body {
132        /// Wrapped extractor future.
133        fut: Pin<Box<T::Future>>,
134
135        /// Forked request payload.
136        counter_pl: dev::Payload,
137
138        /// Running payload size count.
139        size: usize,
140    },
141}
142
143impl<T, const LIMIT: usize> Unpin for Inner<T, LIMIT>
144where
145    T: FromRequest + 'static,
146    T::Error: fmt::Debug + fmt::Display,
147{
148}
149
150impl<T, const LIMIT: usize> Future for BodyLimitFut<T, LIMIT>
151where
152    T: FromRequest + 'static,
153    T::Error: fmt::Debug + fmt::Display,
154{
155    type Output = Result<BodyLimit<T, LIMIT>, BodyLimitError<T>>;
156
157    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158        let this = &mut self.get_mut().inner;
159
160        match this {
161            Inner::Error { err } => Poll::Ready(Err(err.take().unwrap())),
162
163            Inner::Body {
164                fut,
165                counter_pl,
166                size,
167            } => {
168                // poll inner extractor first which also polls original payload stream
169                let res = ready!(fut.as_mut().poll(cx).map_err(BodyLimitError::Extractor)?);
170
171                // catch up with payload length counter checks
172                while let Poll::Ready(Some(Ok(chunk))) = Pin::new(&mut *counter_pl).poll_next(cx) {
173                    // update running size
174                    *size += chunk.len();
175
176                    if *size > LIMIT {
177                        return Poll::Ready(Err(BodyLimitError::Overflow));
178                    }
179                }
180
181                let ret = BodyLimit { inner: res };
182
183                Poll::Ready(Ok(ret))
184            }
185        }
186    }
187}
188
189#[derive(Display)]
190pub enum BodyLimitError<T>
191where
192    T: FromRequest + 'static,
193    T::Error: fmt::Debug + fmt::Display,
194{
195    #[display("Wrapped extractor error: {_0}")]
196    Extractor(T::Error),
197
198    #[display("Body was too large")]
199    Overflow,
200}
201
202impl<T> fmt::Debug for BodyLimitError<T>
203where
204    T: FromRequest + 'static,
205    T::Error: fmt::Debug + fmt::Display,
206{
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        match self {
209            Self::Extractor(err) => f
210                .debug_tuple("BodyLimitError::Extractor")
211                .field(err)
212                .finish(),
213
214            Self::Overflow => write!(f, "BodyLimitError::Overflow"),
215        }
216    }
217}
218
219impl<T> ResponseError for BodyLimitError<T>
220where
221    T: FromRequest + 'static,
222    T::Error: fmt::Debug + fmt::Display,
223{
224}
225
#[cfg(test)]
mod tests {
    use actix_web::{http::header, test::TestRequest};
    use bytes::Bytes;

    use super::*;

    static_assertions::assert_impl_all!(BodyLimitFut<(), 100>: Unpin);
    static_assertions::assert_impl_all!(BodyLimitFut<Bytes, 100>: Unpin);

    #[actix_web::test]
    async fn within_limit() {
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::plaintext())
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("9"),
            ))
            .set_payload(Bytes::from_static(b"123456789"))
            .to_http_parts();

        let body = BodyLimit::<Bytes, 10>::from_request(&req, &mut pl).await;
        assert_eq!(
            body.ok().unwrap().into_inner(),
            Bytes::from_static(b"123456789")
        );
    }

    #[actix_web::test]
    async fn at_limit_is_accepted() {
        // the limit is inclusive: a body of exactly `LIMIT` bytes must pass both the
        // `Content-Length` fast check and the running byte count
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::plaintext())
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("9"),
            ))
            .set_payload(Bytes::from_static(b"123456789"))
            .to_http_parts();

        let body = BodyLimit::<Bytes, 9>::from_request(&req, &mut pl).await;
        assert_eq!(
            body.ok().unwrap().into_inner(),
            Bytes::from_static(b"123456789")
        );
    }

    #[actix_web::test]
    async fn exceeds_limit() {
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::plaintext())
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("10"),
            ))
            .set_payload(Bytes::from_static(b"0123456789"))
            .to_http_parts();

        let body = BodyLimit::<Bytes, 4>::from_request(&req, &mut pl).await;
        assert!(matches!(body.unwrap_err(), BodyLimitError::Overflow));

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::plaintext())
            .insert_header((
                header::TRANSFER_ENCODING,
                header::HeaderValue::from_static("chunked"),
            ))
            .set_payload(Bytes::from_static(b"10\r\n0123456789\r\n0"))
            .to_http_parts();

        let body = BodyLimit::<Bytes, 4>::from_request(&req, &mut pl).await;
        assert!(matches!(body.unwrap_err(), BodyLimitError::Overflow));
    }
}