actix_web_lab/
json.rs

1//! JSON extractor with const-generic payload size limit.
2
3use std::{
4    fmt,
5    marker::PhantomData,
6    pin::Pin,
7    task::{Context, Poll, ready},
8};
9
10use actix_web::{
11    FromRequest, HttpMessage, HttpRequest, ResponseError, dev::Payload, http::header, web,
12};
13use derive_more::{Display, Error};
14use futures_core::Stream as _;
15use http::StatusCode;
16use serde::de::DeserializeOwned;
17use tracing::debug;
18
/// Default JSON payload size limit of 2MiB.
pub const DEFAULT_JSON_LIMIT: usize = 2_097_152;

/// JSON extractor with const-generic payload size limit.
///
/// `Json` is used to extract typed data from JSON request payloads.
///
/// # Extractor
/// To extract typed data from a request body, the inner type `T` must implement the
/// [`serde::Deserialize`] trait.
///
/// Use the `LIMIT` const generic parameter to control the payload size limit. The default limit
/// that is exported ([`DEFAULT_JSON_LIMIT`]) is 2MiB.
///
/// ```
/// use actix_web::{error, post, App, HttpRequest, HttpResponse, Responder};
/// use actix_web_lab::extract::{Json, JsonPayloadError, DEFAULT_JSON_LIMIT};
/// use serde::{Deserialize, Serialize};
/// use serde_json::json;
///
/// #[derive(Deserialize, Serialize)]
/// struct Info {
///     username: String,
/// }
///
/// /// Deserialize `Info` from request's body.
/// #[post("/")]
/// async fn index(info: Json<Info>) -> String {
///     format!("Welcome {}!", info.username)
/// }
///
/// const LIMIT_32_MB: usize = 33_554_432;
///
/// /// Deserialize payload with a higher 32MiB limit.
/// #[post("/big-payload")]
/// async fn big_payload(info: Json<Info, LIMIT_32_MB>) -> String {
///     format!("Welcome {}!", info.username)
/// }
///
/// /// Capture the error that may have occurred when deserializing the body.
/// #[post("/normal-payload")]
/// async fn normal_payload(
///     res: Result<Json<Info>, JsonPayloadError>,
///     req: HttpRequest,
/// ) -> actix_web::Result<impl Responder> {
///     let item = res.map_err(|err| {
///         eprintln!("failed to deserialize JSON: {err}");
///         let res = HttpResponse::BadRequest().json(json!({
///             "error": "invalid_json",
///             "detail": err.to_string(),
///         }));
///         error::InternalError::from_response(err, res)
///     })?;
///
///     Ok(HttpResponse::Ok().json(item.0))
/// }
/// ```
#[derive(Debug)]
pub struct Json<T, const LIMIT: usize = DEFAULT_JSON_LIMIT>(pub T);

// `Deref`, `DerefMut` and `Display` are written by hand instead of being derived with
// `derive_more`. The module name records why (a proc-macro panic until derive_more moves to
// syn 2), so these impls can be replaced by derives once that is resolved.
mod waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic {
    use std::{
        fmt,
        ops::{Deref, DerefMut},
    };

    use super::*;

    /// Gives shared access to the wrapped value, e.g. `json.field`.
    impl<T, const LIMIT: usize> Deref for Json<T, LIMIT> {
        type Target = T;

        fn deref(&self) -> &T {
            &self.0
        }
    }

    /// Gives mutable access to the wrapped value.
    impl<T, const LIMIT: usize> DerefMut for Json<T, LIMIT> {
        fn deref_mut(&mut self) -> &mut T {
            &mut self.0
        }
    }

    /// Formats exactly like the wrapped value (formatter flags such as width are forwarded).
    impl<T: fmt::Display, const LIMIT: usize> fmt::Display for Json<T, LIMIT> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt::Display::fmt(&self.0, f)
        }
    }
}

impl<T, const LIMIT: usize> Json<T, LIMIT> {
    /// Unwraps into inner `T` value.
    pub fn into_inner(self) -> T {
        self.0
    }
}
110
111/// See [here](#extractor) for example of usage as an extractor.
112impl<T: DeserializeOwned, const LIMIT: usize> FromRequest for Json<T, LIMIT> {
113    type Error = JsonPayloadError;
114    type Future = JsonExtractFut<T, LIMIT>;
115
116    #[inline]
117    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
118        JsonExtractFut {
119            req: Some(req.clone()),
120            fut: JsonBody::new(req, payload),
121        }
122    }
123}
124
125#[allow(missing_debug_implementations)]
126pub struct JsonExtractFut<T, const LIMIT: usize> {
127    req: Option<HttpRequest>,
128    fut: JsonBody<T, LIMIT>,
129}
130
131impl<T: DeserializeOwned, const LIMIT: usize> Future for JsonExtractFut<T, LIMIT> {
132    type Output = Result<Json<T, LIMIT>, JsonPayloadError>;
133
134    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135        let this = self.get_mut();
136
137        let res = ready!(Pin::new(&mut this.fut).poll(cx));
138
139        let res = match res {
140            Err(err) => {
141                let req = this.req.take().unwrap();
142                debug!(
143                    "Failed to deserialize Json<{}> from payload in handler: {}",
144                    core::any::type_name::<T>(),
145                    req.match_name().unwrap_or_else(|| req.path())
146                );
147
148                Err(err)
149            }
150            Ok(data) => Ok(Json(data)),
151        };
152
153        Poll::Ready(res)
154    }
155}
156
157/// Future that resolves to some `T` when parsed from a JSON payload.
158///
159/// Can deserialize any type `T` that implements [`Deserialize`][serde::Deserialize].
160///
161/// Returns error if:
162/// - `Content-Type` is not `application/json`.
163/// - `Content-Length` is greater than `LIMIT`.
164/// - The payload, when consumed, is not valid JSON.
165pub enum JsonBody<T, const LIMIT: usize> {
166    Error(Option<JsonPayloadError>),
167    Body {
168        /// Length as reported by `Content-Length` header, if present.
169        #[allow(dead_code)]
170        length: Option<usize>,
171        // #[cfg(feature = "__compress")]
172        // payload: Decompress<Payload>,
173        // #[cfg(not(feature = "__compress"))]
174        payload: Payload,
175        buf: web::BytesMut,
176        _res: PhantomData<T>,
177    },
178}
179
180impl<T, const LIMIT: usize> Unpin for JsonBody<T, LIMIT> {}
181
182impl<T: DeserializeOwned, const LIMIT: usize> JsonBody<T, LIMIT> {
183    /// Create a new future to decode a JSON request payload.
184    pub fn new(req: &HttpRequest, payload: &mut Payload) -> Self {
185        // check content-type
186        let can_parse_json = if let Ok(Some(mime)) = req.mime_type() {
187            mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
188        } else {
189            false
190        };
191
192        if !can_parse_json {
193            return JsonBody::Error(Some(JsonPayloadError::ContentType));
194        }
195
196        let length = req
197            .headers()
198            .get(&header::CONTENT_LENGTH)
199            .and_then(|l| l.to_str().ok())
200            .and_then(|s| s.parse::<usize>().ok());
201
202        let payload = payload.take();
203
204        if let Some(len) = length {
205            if len > LIMIT {
206                return JsonBody::Error(Some(JsonPayloadError::Overflow {
207                    limit: LIMIT,
208                    length: Some(len),
209                }));
210            }
211        }
212
213        JsonBody::Body {
214            length,
215            payload,
216            buf: web::BytesMut::with_capacity(8192),
217            _res: PhantomData,
218        }
219    }
220}
221
222impl<T: DeserializeOwned, const LIMIT: usize> Future for JsonBody<T, LIMIT> {
223    type Output = Result<T, JsonPayloadError>;
224
225    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
226        let this = self.get_mut();
227
228        match this {
229            JsonBody::Body { buf, payload, .. } => loop {
230                let res = ready!(Pin::new(&mut *payload).poll_next(cx));
231
232                match res {
233                    Some(chunk) => {
234                        let chunk =
235                            chunk.map_err(|err| JsonPayloadError::Payload { source: err })?;
236
237                        let buf_len = buf.len() + chunk.len();
238                        if buf_len > LIMIT {
239                            return Poll::Ready(Err(JsonPayloadError::Overflow {
240                                limit: LIMIT,
241                                length: None,
242                            }));
243                        } else {
244                            buf.extend_from_slice(&chunk);
245                        }
246                    }
247
248                    None => {
249                        let mut de = serde_json::Deserializer::from_slice(buf);
250                        let json = serde_path_to_error::deserialize(&mut de).map_err(|err| {
251                            JsonPayloadError::Deserialize {
252                                source: JsonDeserializeError {
253                                    path: err.path().clone(),
254                                    source: err.into_inner(),
255                                },
256                            }
257                        })?;
258
259                        return Poll::Ready(Ok(json));
260                    }
261                }
262            },
263
264            JsonBody::Error(err) => Poll::Ready(Err(err.take().unwrap())),
265        }
266    }
267}
268
269/// A set of errors that can occur during parsing json payloads
270#[derive(Debug, Display, Error)]
271#[non_exhaustive]
272pub enum JsonPayloadError {
273    /// Payload size is bigger than allowed header set.
274    #[display(
275        "JSON payload {}is larger than allowed (limit: {limit} bytes)",
276        length.map(|length| format!("({length} bytes) ")).unwrap_or("".to_owned()),
277    )]
278    Overflow {
279        /// Configured payload size limit.
280        limit: usize,
281
282        /// The Content-Length, if sent.
283        length: Option<usize>,
284    },
285
286    /// Content type error.
287    #[display("Content type error")]
288    ContentType,
289
290    /// Deserialization error.
291    #[display("Deserialization error")]
292    Deserialize {
293        /// Deserialization error.
294        source: JsonDeserializeError,
295    },
296
297    /// Payload error.
298    #[display("Error that occur during reading payload")]
299    Payload {
300        /// Payload error.
301        source: actix_web::error::PayloadError,
302    },
303}
304
305/// Deserialization errors that can occur during parsing query strings.
306#[derive(Debug, Error)]
307pub struct JsonDeserializeError {
308    /// Path where deserialization error occurred.
309    path: serde_path_to_error::Path,
310
311    /// Deserialization error.
312    source: serde_json::Error,
313}
314
315impl JsonDeserializeError {
316    /// Returns the path at which the deserialization error occurred.
317    pub fn path(&self) -> impl fmt::Display + '_ {
318        &self.path
319    }
320
321    /// Returns the source error.
322    pub fn source(&self) -> &serde_json::Error {
323        &self.source
324    }
325}
326
327impl fmt::Display for JsonDeserializeError {
328    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329        f.write_str("JSON deserialization failed")?;
330
331        if self.path.iter().len() > 0 {
332            write!(f, " at path: {}", &self.path)?;
333        }
334
335        Ok(())
336    }
337}
338
339impl ResponseError for JsonPayloadError {
340    fn status_code(&self) -> StatusCode {
341        match self {
342            Self::Overflow { .. } => StatusCode::PAYLOAD_TOO_LARGE,
343            Self::Payload { source } => source.status_code(),
344            Self::Deserialize { source: err } if err.source().is_data() => {
345                StatusCode::UNPROCESSABLE_ENTITY
346            }
347            Self::Deserialize { .. } => StatusCode::BAD_REQUEST,
348            Self::ContentType => StatusCode::NOT_ACCEPTABLE,
349        }
350    }
351}
352
#[cfg(test)]
mod tests {
    use actix_web::{http::header, test::TestRequest, web::Bytes};
    use serde::Deserialize;

    use super::*;

    /// Body shared by most tests; exactly 16 bytes long.
    const NAME_BODY: &[u8] = b"{\"name\": \"test\"}";

    #[derive(Debug, PartialEq, Deserialize)]
    struct MyObject {
        name: String,
    }

    /// True when both errors are `Overflow`, or both are `ContentType`. Every other combination
    /// (including two `Deserialize` or two `Payload` errors) compares unequal.
    fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool {
        matches!(
            (err, other),
            (
                JsonPayloadError::Overflow { .. },
                JsonPayloadError::Overflow { .. }
            ) | (JsonPayloadError::ContentType, JsonPayloadError::ContentType)
        )
    }

    /// Request parts carrying [`NAME_BODY`] with a JSON content type and `Content-Length: 16`.
    fn name_request() -> (HttpRequest, Payload) {
        TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(NAME_BODY))
            .to_http_parts()
    }

    #[actix_web::test]
    async fn test_extract() {
        let (req, mut pl) = name_request();
        let extracted = Json::<MyObject, DEFAULT_JSON_LIMIT>::from_request(&req, &mut pl)
            .await
            .unwrap();
        assert_eq!(extracted.name, "test");
        assert_eq!(
            extracted.into_inner(),
            MyObject {
                name: "test".to_string()
            }
        );

        // Declared length (16) is over the 10-byte limit: check the exact message.
        let (req, mut pl) = name_request();
        let err = Json::<MyObject, 10>::from_request(&req, &mut pl)
            .await
            .unwrap_err();
        assert_eq!(
            "JSON payload (16 bytes) is larger than allowed (limit: 10 bytes)",
            err.to_string(),
        );

        // Same scenario, checked loosely.
        let (req, mut pl) = name_request();
        let err = Json::<MyObject, 10>::from_request(&req, &mut pl)
            .await
            .unwrap_err();
        assert!(
            err.to_string().contains("larger than allowed"),
            "unexpected error string: {err:?}"
        );
    }

    #[actix_web::test]
    async fn test_json_body() {
        // No content type at all.
        let (req, mut pl) = TestRequest::default().to_http_parts();
        let outcome = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert!(json_eq(outcome.unwrap_err(), JsonPayloadError::ContentType));

        // Non-JSON content type.
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/text"),
            ))
            .to_http_parts();
        let outcome = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert!(json_eq(outcome.unwrap_err(), JsonPayloadError::ContentType));

        // Declared length over the limit.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("10000"),
            ))
            .to_http_parts();
        let outcome = JsonBody::<MyObject, 100>::new(&req, &mut pl).await;
        assert!(json_eq(
            outcome.unwrap_err(),
            JsonPayloadError::Overflow {
                limit: 100,
                length: Some(10000),
            }
        ));

        // No Content-Length, but the streamed body is over the limit.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .set_payload(Bytes::from_static(&[0u8; 1000]))
            .to_http_parts();
        let outcome = JsonBody::<MyObject, 100>::new(&req, &mut pl).await;
        assert!(json_eq(
            outcome.unwrap_err(),
            JsonPayloadError::Overflow {
                limit: 100,
                length: None
            }
        ));

        // Happy path.
        let (req, mut pl) = name_request();
        let outcome = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert_eq!(
            outcome.ok().unwrap(),
            MyObject {
                name: "test".to_owned()
            }
        );
    }

    #[actix_web::test]
    async fn test_with_json_and_bad_content_type() {
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("text/plain"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(NAME_BODY))
            .to_http_parts();

        Json::<MyObject, 4096>::from_request(&req, &mut pl)
            .await
            .unwrap_err();
    }

    #[actix_web::test]
    async fn test_with_config_in_data_wrapper() {
        // Content-Length given as an integer header value rather than a static string.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((header::CONTENT_LENGTH, 16))
            .set_payload(Bytes::from_static(NAME_BODY))
            .to_http_parts();

        let err = Json::<MyObject, 10>::from_request(&req, &mut pl)
            .await
            .unwrap_err();
        assert_eq!(
            "JSON payload (16 bytes) is larger than allowed (limit: 10 bytes)",
            err.to_string(),
        );
    }

    #[actix_web::test]
    async fn json_deserialize_errors_contain_path() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct Names {
            names: Vec<String>,
        }

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .set_payload(Bytes::from_static(b"{\"names\": [\"test\", 1]}"))
            .to_http_parts();

        match Json::<Names>::from_request(&req, &mut pl).await.unwrap_err() {
            JsonPayloadError::Deserialize { source } => {
                assert_eq!("names[1]", source.path().to_string());
            }
            other => panic!("unexpected error variant: {other}"),
        }
    }
}