1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
use crate::http::{self, Mime}; use crate::internal_prelude::*; use std::mem; use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; use futures::stream::StreamExt; use serde::Deserialize; #[derive(Debug, thiserror::Error)] pub enum BodyError { #[error("LengthLimitExceeded")] LengthLimitExceeded, #[error("InvalidFormat: {}", .source)] InvalidFormat { source: Error }, #[error("ContentTypeMismatch")] ContentTypeMismatch, } async fn to_bytes(mut body: Body, length_limit: usize) -> Result<Bytes> { let mut bufs: Vec<Bytes> = Vec::new(); let mut total: usize = 0; while let Some(bytes) = body.next().await.transpose()? { total = match total.checked_add(bytes.len()) { Some(t) if t <= length_limit => t, _ => return Err(BodyError::LengthLimitExceeded.into()), }; bufs.push(bytes); } let mut buf: BytesMut = BytesMut::with_capacity(total); for bytes in bufs { buf.put(bytes); } Ok(buf.freeze()) } fn parse_mime(req: &Request) -> Option<Mime> { req.as_ref_hyper() .headers() .get(http::header::CONTENT_TYPE)? .to_str() .ok()? .parse() .ok() } fn take_body(hreq: &mut HyperRequest) -> Body { mem::take(hreq.body_mut()) } pub struct FullBody(Bytes); #[derive(Debug, Clone)] pub struct JsonParser { length_limit: usize, } impl Default for JsonParser { fn default() -> Self { Self { length_limit: Self::DEFAULT_LENGTH_LIMIT, } } } impl JsonParser { const DEFAULT_LENGTH_LIMIT: usize = 32 * 1024; pub fn length_limit(&mut self, limit: usize) { self.length_limit = limit; } pub async fn parse<'r, T>(&self, req: &'r mut Request) -> Result<T> where T: Deserialize<'r>, { let ct_check = parse_mime(&req) .map(|mime| mime.type_() == mime::APPLICATION && mime.subtype() == mime::JSON) .unwrap_or(false); if !ct_check { return Err(BodyError::ContentTypeMismatch.into()); } { let hreq = req.as_mut_hyper(); if hreq.extensions().get::<FullBody>().is_none() { let full_body = FullBody(to_bytes(take_body(hreq), self.length_limit).await?); hreq.extensions_mut().insert(full_body); } } let full_body = req.as_ref_hyper().extensions().get::<FullBody>().unwrap(); match serde_json::from_slice(&*full_body.0) { Ok(value) => Ok(value), Err(e) => Err(BodyError::InvalidFormat { source: e.into() }.into()), } } } #[async_trait] pub trait JsonExt { async fn parse_json<'r, T: Deserialize<'r>>(&'r mut self, parser: &JsonParser) -> Result<T>; async fn json<'r, T: Deserialize<'r>>(&'r mut self) -> Result<T>; } #[async_trait] impl JsonExt for Request { async fn parse_json<'r, T: Deserialize<'r>>(&'r mut self, parser: &JsonParser) -> Result<T> { parser.parse(self).await } async fn json<'r, T: Deserialize<'r>>(&'r mut self) -> Result<T> { let parser = match self.as_ref_hyper().extensions().get::<JsonParser>() { Some(p) => p.clone(), None => JsonParser::default(), }; self.parse_json(&parser).await } }