1use std::{
4 fmt,
5 marker::PhantomData,
6 pin::Pin,
7 task::{Context, Poll, ready},
8};
9
10use actix_web::{
11 FromRequest, HttpMessage, HttpRequest, ResponseError, dev::Payload, http::header, web,
12};
13use derive_more::{Display, Error};
14use futures_core::Stream as _;
15use http::StatusCode;
16use serde::de::DeserializeOwned;
17use tracing::debug;
18
/// Default JSON payload size limit, in bytes (2_097_152 = 2 MiB).
///
/// Used as the default value of the `LIMIT` const parameter of [`Json`].
pub const DEFAULT_JSON_LIMIT: usize = 2_097_152;
21
/// JSON extractor with a compile-time payload size limit.
///
/// Wraps a value of type `T` deserialized from a request body. `LIMIT` is the maximum accepted
/// body size in bytes and defaults to [`DEFAULT_JSON_LIMIT`]. The inner value is a public field
/// and can also be obtained with [`Json::into_inner`].
#[derive(Debug)]
pub struct Json<T, const LIMIT: usize = DEFAULT_JSON_LIMIT>(pub T);
79
80mod waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic {
81 use super::*;
82
83 impl<T, const LIMIT: usize> std::ops::Deref for Json<T, LIMIT> {
84 type Target = T;
85
86 fn deref(&self) -> &Self::Target {
87 &self.0
88 }
89 }
90
91 impl<T, const LIMIT: usize> std::ops::DerefMut for Json<T, LIMIT> {
92 fn deref_mut(&mut self) -> &mut Self::Target {
93 &mut self.0
94 }
95 }
96
97 impl<T: std::fmt::Display, const LIMIT: usize> std::fmt::Display for Json<T, LIMIT> {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 std::fmt::Display::fmt(&self.0, f)
100 }
101 }
102}
103
104impl<T, const LIMIT: usize> Json<T, LIMIT> {
105 pub fn into_inner(self) -> T {
107 self.0
108 }
109}
110
111impl<T: DeserializeOwned, const LIMIT: usize> FromRequest for Json<T, LIMIT> {
113 type Error = JsonPayloadError;
114 type Future = JsonExtractFut<T, LIMIT>;
115
116 #[inline]
117 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
118 JsonExtractFut {
119 req: Some(req.clone()),
120 fut: JsonBody::new(req, payload),
121 }
122 }
123}
124
/// Future returned by the [`FromRequest`] implementation of [`Json`].
///
/// Resolves to the extracted [`Json`] value or to a [`JsonPayloadError`].
// No `Debug` impl is provided for this type, hence the lint allowance.
#[allow(missing_debug_implementations)]
pub struct JsonExtractFut<T, const LIMIT: usize> {
    /// Handle to the originating request; taken when a failure is logged.
    req: Option<HttpRequest>,
    /// Body-reading future that this future drives to completion.
    fut: JsonBody<T, LIMIT>,
}
130
131impl<T: DeserializeOwned, const LIMIT: usize> Future for JsonExtractFut<T, LIMIT> {
132 type Output = Result<Json<T, LIMIT>, JsonPayloadError>;
133
134 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135 let this = self.get_mut();
136
137 let res = ready!(Pin::new(&mut this.fut).poll(cx));
138
139 let res = match res {
140 Err(err) => {
141 let req = this.req.take().unwrap();
142 debug!(
143 "Failed to deserialize Json<{}> from payload in handler: {}",
144 core::any::type_name::<T>(),
145 req.match_name().unwrap_or_else(|| req.path())
146 );
147
148 Err(err)
149 }
150 Ok(data) => Ok(Json(data)),
151 };
152
153 Poll::Ready(res)
154 }
155}
156
/// Future that buffers a request payload and deserializes it as JSON into `T`.
///
/// `LIMIT` is the maximum number of body bytes that will be accepted.
pub enum JsonBody<T, const LIMIT: usize> {
    /// An error detected while constructing the future; it is returned when polled.
    Error(Option<JsonPayloadError>),
    /// The payload is being read and buffered.
    Body {
        /// Parsed `Content-Length` header value, if one was present and valid.
        // Stored at construction but not read while polling, hence the lint allowance.
        #[allow(dead_code)]
        length: Option<usize>,
        /// Stream of incoming body chunks.
        payload: Payload,
        /// Accumulates the body bytes received so far.
        buf: web::BytesMut,
        /// Marker tying the deserialization target type `T` to this future.
        _res: PhantomData<T>,
    },
}
179
// Declares `JsonBody` as `Unpin` for every `T`, which the `Future` impl relies on when it calls
// `Pin::get_mut`. Presumably needed because the `PhantomData<T>` field would otherwise make the
// auto-implementation conditional on `T: Unpin`.
impl<T, const LIMIT: usize> Unpin for JsonBody<T, LIMIT> {}
181
182impl<T: DeserializeOwned, const LIMIT: usize> JsonBody<T, LIMIT> {
183 pub fn new(req: &HttpRequest, payload: &mut Payload) -> Self {
185 let can_parse_json = if let Ok(Some(mime)) = req.mime_type() {
187 mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
188 } else {
189 false
190 };
191
192 if !can_parse_json {
193 return JsonBody::Error(Some(JsonPayloadError::ContentType));
194 }
195
196 let length = req
197 .headers()
198 .get(&header::CONTENT_LENGTH)
199 .and_then(|l| l.to_str().ok())
200 .and_then(|s| s.parse::<usize>().ok());
201
202 let payload = payload.take();
203
204 if let Some(len) = length {
205 if len > LIMIT {
206 return JsonBody::Error(Some(JsonPayloadError::Overflow {
207 limit: LIMIT,
208 length: Some(len),
209 }));
210 }
211 }
212
213 JsonBody::Body {
214 length,
215 payload,
216 buf: web::BytesMut::with_capacity(8192),
217 _res: PhantomData,
218 }
219 }
220}
221
222impl<T: DeserializeOwned, const LIMIT: usize> Future for JsonBody<T, LIMIT> {
223 type Output = Result<T, JsonPayloadError>;
224
225 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
226 let this = self.get_mut();
227
228 match this {
229 JsonBody::Body { buf, payload, .. } => loop {
230 let res = ready!(Pin::new(&mut *payload).poll_next(cx));
231
232 match res {
233 Some(chunk) => {
234 let chunk =
235 chunk.map_err(|err| JsonPayloadError::Payload { source: err })?;
236
237 let buf_len = buf.len() + chunk.len();
238 if buf_len > LIMIT {
239 return Poll::Ready(Err(JsonPayloadError::Overflow {
240 limit: LIMIT,
241 length: None,
242 }));
243 } else {
244 buf.extend_from_slice(&chunk);
245 }
246 }
247
248 None => {
249 let mut de = serde_json::Deserializer::from_slice(buf);
250 let json = serde_path_to_error::deserialize(&mut de).map_err(|err| {
251 JsonPayloadError::Deserialize {
252 source: JsonDeserializeError {
253 path: err.path().clone(),
254 source: err.into_inner(),
255 },
256 }
257 })?;
258
259 return Poll::Ready(Ok(json));
260 }
261 }
262 },
263
264 JsonBody::Error(err) => Poll::Ready(Err(err.take().unwrap())),
265 }
266 }
267}
268
269#[derive(Debug, Display, Error)]
271#[non_exhaustive]
272pub enum JsonPayloadError {
273 #[display(
275 "JSON payload {}is larger than allowed (limit: {limit} bytes)",
276 length.map(|length| format!("({length} bytes) ")).unwrap_or("".to_owned()),
277 )]
278 Overflow {
279 limit: usize,
281
282 length: Option<usize>,
284 },
285
286 #[display("Content type error")]
288 ContentType,
289
290 #[display("Deserialization error")]
292 Deserialize {
293 source: JsonDeserializeError,
295 },
296
297 #[display("Error that occur during reading payload")]
299 Payload {
300 source: actix_web::error::PayloadError,
302 },
303}
304
/// A JSON deserialization failure together with the path at which it occurred.
#[derive(Debug, Error)]
pub struct JsonDeserializeError {
    /// Path recorded by `serde_path_to_error` for the failure.
    path: serde_path_to_error::Path,

    /// The underlying `serde_json` error.
    source: serde_json::Error,
}
314
315impl JsonDeserializeError {
316 pub fn path(&self) -> impl fmt::Display + '_ {
318 &self.path
319 }
320
321 pub fn source(&self) -> &serde_json::Error {
323 &self.source
324 }
325}
326
327impl fmt::Display for JsonDeserializeError {
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 f.write_str("JSON deserialization failed")?;
330
331 if self.path.iter().len() > 0 {
332 write!(f, " at path: {}", &self.path)?;
333 }
334
335 Ok(())
336 }
337}
338
339impl ResponseError for JsonPayloadError {
340 fn status_code(&self) -> StatusCode {
341 match self {
342 Self::Overflow { .. } => StatusCode::PAYLOAD_TOO_LARGE,
343 Self::Payload { source } => source.status_code(),
344 Self::Deserialize { source: err } if err.source().is_data() => {
345 StatusCode::UNPROCESSABLE_ENTITY
346 }
347 Self::Deserialize { .. } => StatusCode::BAD_REQUEST,
348 Self::ContentType => StatusCode::NOT_ACCEPTABLE,
349 }
350 }
351}
352
#[cfg(test)]
mod tests {
    use actix_web::{http::header, test::TestRequest, web::Bytes};
    use serde::Deserialize;

    use super::*;

    /// Deserialization target shared by most tests.
    #[derive(Debug, PartialEq, Deserialize)]
    struct MyObject {
        name: String,
    }

    /// JSON document used by most tests; it is 16 bytes long.
    const NAME_TEST_JSON: &[u8] = b"{\"name\": \"test\"}";

    /// Builds a JSON request carrying `NAME_TEST_JSON` with a matching `Content-Length` header.
    fn name_test_request() -> (HttpRequest, Payload) {
        TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(NAME_TEST_JSON))
            .to_http_parts()
    }

    /// Compares two errors by variant only; only `Overflow` and `ContentType` can compare equal.
    fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool {
        matches!(
            (err, other),
            (
                JsonPayloadError::Overflow { .. },
                JsonPayloadError::Overflow { .. }
            ) | (JsonPayloadError::ContentType, JsonPayloadError::ContentType)
        )
    }

    #[actix_web::test]
    async fn test_extract() {
        // Within the default limit: extraction succeeds.
        let (req, mut pl) = name_test_request();
        let s = Json::<MyObject, DEFAULT_JSON_LIMIT>::from_request(&req, &mut pl)
            .await
            .unwrap();
        assert_eq!(s.name, "test");
        assert_eq!(
            s.into_inner(),
            MyObject {
                name: "test".to_string()
            }
        );

        // Content-Length above the limit: exact error message.
        let (req, mut pl) = name_test_request();
        let res = Json::<MyObject, 10>::from_request(&req, &mut pl).await;
        let err = res.unwrap_err();
        assert_eq!(
            "JSON payload (16 bytes) is larger than allowed (limit: 10 bytes)",
            err.to_string(),
        );

        // Same scenario, checked loosely on the message.
        let (req, mut pl) = name_test_request();
        let s = Json::<MyObject, 10>::from_request(&req, &mut pl).await;
        let err = s.unwrap_err();
        assert!(
            err.to_string().contains("larger than allowed"),
            "unexpected error string: {err:?}"
        );
    }

    #[actix_web::test]
    async fn test_json_body() {
        // No content type at all.
        let (req, mut pl) = TestRequest::default().to_http_parts();
        let json = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert!(json_eq(json.unwrap_err(), JsonPayloadError::ContentType));

        // Non-JSON content type.
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/text"),
            ))
            .to_http_parts();
        let json = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert!(json_eq(json.unwrap_err(), JsonPayloadError::ContentType));

        // Content-Length header above the limit.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("10000"),
            ))
            .to_http_parts();
        let json = JsonBody::<MyObject, 100>::new(&req, &mut pl).await;
        assert!(json_eq(
            json.unwrap_err(),
            JsonPayloadError::Overflow {
                limit: 100,
                length: Some(10000),
            }
        ));

        // No Content-Length header, but the body itself exceeds the limit.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .set_payload(Bytes::from_static(&[0u8; 1000]))
            .to_http_parts();
        let json = JsonBody::<MyObject, 100>::new(&req, &mut pl).await;
        assert!(json_eq(
            json.unwrap_err(),
            JsonPayloadError::Overflow {
                limit: 100,
                length: None
            }
        ));

        // Valid request within the limit.
        let (req, mut pl) = name_test_request();
        let json = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert_eq!(
            json.ok().unwrap(),
            MyObject {
                name: "test".to_owned()
            }
        );
    }

    #[actix_web::test]
    async fn test_with_json_and_bad_content_type() {
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("text/plain"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(NAME_TEST_JSON))
            .to_http_parts();

        Json::<MyObject, 4096>::from_request(&req, &mut pl)
            .await
            .unwrap_err();
    }

    #[actix_web::test]
    async fn test_with_config_in_data_wrapper() {
        // The Content-Length header is supplied as an integer here rather than as a string.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((header::CONTENT_LENGTH, 16))
            .set_payload(Bytes::from_static(NAME_TEST_JSON))
            .to_http_parts();

        let res = Json::<MyObject, 10>::from_request(&req, &mut pl).await;
        let err = res.unwrap_err();
        assert_eq!(
            "JSON payload (16 bytes) is larger than allowed (limit: 10 bytes)",
            err.to_string(),
        );
    }

    #[actix_web::test]
    async fn json_deserialize_errors_contain_path() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct Names {
            names: Vec<String>,
        }

        // The second array element has the wrong type, so the reported path must point at it.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .set_payload(Bytes::from_static(b"{\"names\": [\"test\", 1]}"))
            .to_http_parts();

        let res = Json::<Names>::from_request(&req, &mut pl).await;
        match res.unwrap_err() {
            JsonPayloadError::Deserialize { source: err } => {
                assert_eq!("names[1]", err.path().to_string());
            }
            err => panic!("unexpected error variant: {err}"),
        }
    }
}