1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
use crate::http::{self, Mime}; use crate::internal_prelude::*; use async_trait::async_trait; use bytes::Bytes; use serde::Deserialize; #[derive(Debug, thiserror::Error)] pub enum BodyError { #[error("LengthLimitExceeded")] LengthLimitExceeded, #[error("InvalidFormat: {}", .source)] InvalidFormat { source: Error }, #[error("ContentTypeMismatch")] ContentTypeMismatch, } fn parse_mime(req: &Request) -> Option<Mime> { req.headers() .get(http::header::CONTENT_TYPE)? .to_str() .ok()? .parse() .ok() } struct FullBody(Bytes); #[derive(Debug, Clone)] pub struct JsonParser { length_limit: usize, } impl Default for JsonParser { fn default() -> Self { Self { length_limit: Self::DEFAULT_LENGTH_LIMIT, } } } impl JsonParser { const DEFAULT_LENGTH_LIMIT: usize = 32 * 1024; pub fn length_limit(&mut self, limit: usize) { self.length_limit = limit; } pub async fn parse<'r, T>(&self, req: &'r mut Request) -> Result<T> where T: Deserialize<'r>, { let ct_check = parse_mime(&req) .map(|mime| mime.type_() == mime::APPLICATION && mime.subtype() == mime::JSON) .unwrap_or(false); if !ct_check { return Err(BodyError::ContentTypeMismatch.into()); } { if req.extensions().get::<FullBody>().is_none() { let full_body = FullBody(req.body_bytes(self.length_limit).await?); req.extensions_mut().insert(full_body); } } let full_body = req.extensions().get::<FullBody>().unwrap(); match serde_json::from_slice(&*full_body.0) { Ok(value) => Ok(value), Err(e) => Err(BodyError::InvalidFormat { source: e.into() }.into()), } } } #[async_trait] pub trait JsonExt { async fn parse_json<'r, T: Deserialize<'r>>(&'r mut self, parser: &JsonParser) -> Result<T>; async fn json<'r, T: Deserialize<'r>>(&'r mut self) -> Result<T>; } #[async_trait] impl JsonExt for Request { async fn parse_json<'r, T: Deserialize<'r>>(&'r mut self, parser: &JsonParser) -> Result<T> { parser.parse(self).await } async fn json<'r, T: Deserialize<'r>>(&'r mut self) -> Result<T> { let parser = match self.extensions().get::<JsonParser>() { Some(p) => p.clone(), None => JsonParser::default(), }; self.parse_json(&parser).await } }