actix_web_lab/
bytes.rs

1//! Bytes extractor with const-generic payload size limit.
2//!
3//! See docs for [`Bytes`].
4
5use std::{
6    pin::Pin,
7    task::{Context, Poll, ready},
8};
9
10use actix_web::{FromRequest, HttpMessage, HttpRequest, ResponseError, dev, http::StatusCode, web};
11use derive_more::{Display, Error};
12use futures_core::Stream as _;
13use tracing::debug;
14
/// Default payload size limit used by the bytes extractor: 4MiB (4 × 1024 × 1024 bytes).
pub const DEFAULT_BYTES_LIMIT: usize = 4 * 1024 * 1024;
17
/// Bytes extractor with const-generic payload size limit.
///
/// # Extractor
/// Extracts the raw bytes of a request body.
///
/// Use the `LIMIT` const generic parameter to control the payload size limit. The default limit
/// that is exported ([`DEFAULT_BYTES_LIMIT`]) is 4MiB.
///
/// # Differences from `actix_web::web::Bytes`
/// - Does not read `PayloadConfig` from app data.
/// - Supports const-generic size limits.
/// - Will not automatically decompress request bodies.
///
/// # Examples
/// ```
/// use actix_web::{App, post};
/// use actix_web_lab::extract::{Bytes, DEFAULT_BYTES_LIMIT};
///
/// /// Read the raw request body using the default limit.
/// #[post("/")]
/// async fn index(info: Bytes) -> String {
///     format!("Payload up to 4MiB: {info:?}!")
/// }
///
/// const LIMIT_32_MB: usize = 33_554_432;
///
/// /// Read the raw request body with a higher 32MiB limit.
/// #[post("/big-payload")]
/// async fn big_payload(info: Bytes<LIMIT_32_MB>) -> String {
///     format!("Payload up to 32MiB: {info:?}!")
/// }
/// ```
#[derive(Debug)]
// The derive list below is disabled; `Deref`, `DerefMut`, `AsRef` and `AsMut` are implemented by
// hand in the `waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic` module.
// #[derive(Debug, Deref, DerefMut, AsRef, AsMut)]
pub struct Bytes<const LIMIT: usize = DEFAULT_BYTES_LIMIT>(pub web::Bytes);
53
54mod waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic {
55    use super::*;
56
57    impl<const LIMIT: usize> std::ops::Deref for Bytes<LIMIT> {
58        type Target = web::Bytes;
59
60        fn deref(&self) -> &Self::Target {
61            &self.0
62        }
63    }
64
65    impl<const LIMIT: usize> std::ops::DerefMut for Bytes<LIMIT> {
66        fn deref_mut(&mut self) -> &mut Self::Target {
67            &mut self.0
68        }
69    }
70
71    impl<const LIMIT: usize> AsRef<web::Bytes> for Bytes<LIMIT> {
72        fn as_ref(&self) -> &web::Bytes {
73            &self.0
74        }
75    }
76
77    impl<const LIMIT: usize> AsMut<web::Bytes> for Bytes<LIMIT> {
78        fn as_mut(&mut self) -> &mut web::Bytes {
79            &mut self.0
80        }
81    }
82}
83
84impl<const LIMIT: usize> Bytes<LIMIT> {
85    /// Unwraps into inner `Bytes`.
86    pub fn into_inner(self) -> web::Bytes {
87        self.0
88    }
89}
90
91/// See [here](#extractor) for example of usage as an extractor.
92impl<const LIMIT: usize> FromRequest for Bytes<LIMIT> {
93    type Error = actix_web::Error;
94    type Future = BytesExtractFut<LIMIT>;
95
96    #[inline]
97    fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
98        BytesExtractFut {
99            req: Some(req.clone()),
100            fut: BytesBody::new(req, payload),
101        }
102    }
103}
104
/// Future returned by the [`FromRequest`] implementation of [`Bytes`].
///
/// Drives a [`BytesBody`] to completion and wraps the collected buffer in [`Bytes`].
// No `Debug` impl is provided for this type, so the lint is silenced.
#[allow(missing_debug_implementations)]
pub struct BytesExtractFut<const LIMIT: usize> {
    /// Clone of the originating request; taken when extraction fails so the handler name or path
    /// can be logged.
    req: Option<HttpRequest>,
    /// Inner future that reads the payload into a buffer.
    fut: BytesBody<LIMIT>,
}
110
111impl<const LIMIT: usize> Future for BytesExtractFut<LIMIT> {
112    type Output = actix_web::Result<Bytes<LIMIT>>;
113
114    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
115        let this = self.get_mut();
116
117        let res = ready!(Pin::new(&mut this.fut).poll(cx));
118
119        let res = match res {
120            Err(err) => {
121                let req = this.req.take().unwrap();
122
123                debug!(
124                    "Failed to extract Bytes from payload in handler: {}",
125                    req.match_name().unwrap_or_else(|| req.path())
126                );
127
128                Err(err.into())
129            }
130            Ok(data) => Ok(Bytes(data)),
131        };
132
133        Poll::Ready(res)
134    }
135}
136
/// Future that resolves to `Bytes` once the payload has been completely read.
///
/// Returns error if:
/// - `Content-Length` is greater than `LIMIT`;
/// - the accumulated body grows beyond `LIMIT` while it is being read;
/// - reading the payload stream yields an error.
pub enum BytesBody<const LIMIT: usize> {
    /// Error detected when the future was constructed; it is taken and returned when polled.
    Error(Option<BytesPayloadError>),
    /// Payload is being read into `buf`.
    Body {
        /// Length as reported by `Content-Length` header, if present.
        // Stored but not read back when polling, hence the `dead_code` allowance.
        #[allow(dead_code)]
        length: Option<usize>,
        /// Request payload stream that chunks are pulled from.
        payload: dev::Payload,
        /// Buffer accumulating the chunks read so far.
        buf: web::BytesMut,
    },
}
151
// Declares `BytesBody` as `Unpin` so the `Future` impls in this file can use `Pin::new` and
// `Pin::get_mut` on it.
impl<const LIMIT: usize> Unpin for BytesBody<LIMIT> {}
153
154impl<const LIMIT: usize> BytesBody<LIMIT> {
155    /// Create a new future to decode a JSON request payload.
156    pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> Self {
157        let payload = payload.take();
158
159        let length = req
160            .get_header::<crate::header::ContentLength>()
161            .map(|cl| cl.into_inner());
162
163        // Notice the content-length is not checked against limit here as the internal usage always
164        // call BytesBody::limit after BytesBody::new and limit check to return an error variant of
165        // BytesBody happens there.
166
167        if let Some(len) = length {
168            if len > LIMIT {
169                return BytesBody::Error(Some(BytesPayloadError::OverflowKnownLength {
170                    length: len,
171                    limit: LIMIT,
172                }));
173            }
174        }
175
176        BytesBody::Body {
177            length,
178            payload,
179            buf: web::BytesMut::with_capacity(8192),
180        }
181    }
182}
183
184impl<const LIMIT: usize> Future for BytesBody<LIMIT> {
185    type Output = Result<web::Bytes, BytesPayloadError>;
186
187    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
188        let this = self.get_mut();
189
190        match this {
191            BytesBody::Body { buf, payload, .. } => loop {
192                let res = ready!(Pin::new(&mut *payload).poll_next(cx));
193
194                match res {
195                    Some(chunk) => {
196                        let chunk = chunk?;
197                        let buf_len = buf.len() + chunk.len();
198                        if buf_len > LIMIT {
199                            return Poll::Ready(Err(BytesPayloadError::Overflow { limit: LIMIT }));
200                        } else {
201                            buf.extend_from_slice(&chunk);
202                        }
203                    }
204
205                    None => return Poll::Ready(Ok(buf.split().freeze())),
206                }
207            },
208
209            BytesBody::Error(err) => Poll::Ready(Err(err.take().unwrap())),
210        }
211    }
212}
213
/// A set of errors that can occur while reading a raw bytes payload.
#[derive(Debug, Display, Error)]
#[non_exhaustive]
pub enum BytesPayloadError {
    /// The `Content-Length` header declares a payload bigger than allowed. (default: 4MiB)
    ///
    /// `length` is the declared length and `limit` is the maximum allowed, both in bytes.
    #[display("Payload ({length} bytes) is larger than allowed (limit: {limit} bytes).")]
    OverflowKnownLength { length: usize, limit: usize },

    /// The payload grew bigger than allowed while it was being read, i.e. the `Content-Length`
    /// header was absent or did not exceed the limit. (default: 4MiB)
    ///
    /// `limit` is the maximum allowed size in bytes.
    #[display("Payload has exceeded limit ({limit} bytes).")]
    Overflow { limit: usize },

    /// Error produced by the payload stream while reading a chunk.
    #[display("Error that occur during reading payload: {_0}")]
    Payload(actix_web::error::PayloadError),
}
230
231impl From<actix_web::error::PayloadError> for BytesPayloadError {
232    fn from(err: actix_web::error::PayloadError) -> Self {
233        Self::Payload(err)
234    }
235}
236
237impl ResponseError for BytesPayloadError {
238    fn status_code(&self) -> StatusCode {
239        match self {
240            Self::OverflowKnownLength { .. } => StatusCode::PAYLOAD_TOO_LARGE,
241            Self::Overflow { .. } => StatusCode::PAYLOAD_TOO_LARGE,
242            Self::Payload(err) => err.status_code(),
243        }
244    }
245}
246
#[cfg(test)]
mod tests {
    use actix_web::{http::header, test::TestRequest, web};

    use super::*;

    // Test-only equality so that errors can be compared with `assert_eq!`. Only the two overflow
    // variants are compared field by field; any other combination is treated as not equal.
    impl PartialEq for BytesPayloadError {
        fn eq(&self, other: &Self) -> bool {
            match (self, other) {
                (
                    Self::OverflowKnownLength {
                        length: l_length,
                        limit: l_limit,
                    },
                    Self::OverflowKnownLength {
                        length: r_length,
                        limit: r_limit,
                    },
                ) => l_length == r_length && l_limit == r_limit,

                (Self::Overflow { limit: l_limit }, Self::Overflow { limit: r_limit }) => {
                    l_limit == r_limit
                }

                _ => false,
            }
        }
    }

    /// Exercises the extractor through `FromRequest`.
    #[actix_web::test]
    async fn extract() {
        // Body within the default limit is returned unchanged.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(3))
            .set_payload(web::Bytes::from_static(b"foo"))
            .to_http_parts();

        let s = Bytes::<DEFAULT_BYTES_LIMIT>::from_request(&req, &mut pl)
            .await
            .unwrap();
        assert_eq!(s.as_ref(), "foo");

        // Declared length above the limit is rejected with the exact known-length message.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(16))
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();

        let s = Bytes::<10>::from_request(&req, &mut pl).await;
        let err_str = s.unwrap_err().to_string();
        assert_eq!(
            err_str,
            "Payload (16 bytes) is larger than allowed (limit: 10 bytes).",
        );

        // Same request again, checked with a looser substring match.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(16))
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();
        let s = Bytes::<10>::from_request(&req, &mut pl).await;
        let err = s.unwrap_err().to_string();
        assert!(
            err.contains("larger than allowed"),
            "unexpected error string: {err:?}",
        );
    }

    /// Exercises `BytesBody` directly.
    #[actix_web::test]
    async fn body() {
        // A request without a payload resolves successfully.
        let (req, mut pl) = TestRequest::default().to_http_parts();
        let _bytes = BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl)
            .await
            .unwrap();

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType("application/text".parse().unwrap()))
            .to_http_parts();
        // content-type doesn't matter
        BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl)
            .await
            .unwrap();

        // Declared length above the limit fails with the known-length variant.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(10000))
            .to_http_parts();

        let bytes = BytesBody::<100>::new(&req, &mut pl).await;
        assert_eq!(
            bytes.unwrap_err(),
            BytesPayloadError::OverflowKnownLength {
                length: 10000,
                limit: 100
            }
        );

        // No `Content-Length` header: the limit is enforced while reading the body.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .set_payload(web::Bytes::from_static(&[0u8; 1000]))
            .to_http_parts();

        let bytes = BytesBody::<100>::new(&req, &mut pl).await;

        assert_eq!(
            bytes.unwrap_err(),
            BytesPayloadError::Overflow { limit: 100 }
        );

        // Body within the limit is returned unchanged.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(16))
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();

        let bytes = BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl).await;
        assert_eq!(bytes.unwrap(), "foo foo foo foo");
    }

    /// A `PayloadConfig` in app data is ignored: the reported limit is the const-generic `10`,
    /// not the configured `8`.
    #[actix_web::test]
    async fn test_with_config_in_data_wrapper() {
        let (req, mut pl) = TestRequest::default()
            .app_data(web::Data::new(web::PayloadConfig::default().limit(8)))
            .insert_header(header::ContentType::json())
            .insert_header((header::CONTENT_LENGTH, 16))
            .set_payload(web::Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();

        let s = Bytes::<10>::from_request(&req, &mut pl).await;

        // `unwrap_err` already fails the test if extraction unexpectedly succeeded.
        let err_str = s.unwrap_err().to_string();
        assert_eq!(
            err_str,
            "Payload (16 bytes) is larger than allowed (limit: 10 bytes).",
        );
    }
}