use bytes::Bytes;
use http::header::CONTENT_LENGTH;
use http_body_util::{BodyExt, Limited};
use ubyte::ByteUnit;
use crate::blueprint::constructor::{Constructor, RegisteredConstructor};
use crate::blueprint::Blueprint;
use crate::{f, request::body::errors::SizeLimitExceeded, request::RequestHead};
use super::{
errors::{ExtractBufferedBodyError, UnexpectedBufferError},
BodySizeLimit, RawIncomingBody,
};
/// The complete body of an incoming request, read into memory as a single [`Bytes`] buffer.
///
/// Produced by [`BufferedBody::extract`], which applies the configured [`BodySizeLimit`]
/// (when it is enabled) while reading the body.
#[derive(Debug)]
#[non_exhaustive]
pub struct BufferedBody {
    /// The raw bytes that made up the request body.
    pub bytes: Bytes,
}
impl BufferedBody {
    /// Reads the whole incoming request body into memory.
    ///
    /// If `body_size_limit` is [`BodySizeLimit::Enabled`], the body is rejected once it
    /// exceeds `max_size`. That can happen up front, when the `Content-Length` header already
    /// declares a larger size, or while streaming. With [`BodySizeLimit::Disabled`] the body is
    /// read to completion without any cap.
    ///
    /// # Errors
    ///
    /// - A size-limit error when the enabled limit is exceeded.
    /// - An [`UnexpectedBufferError`] when the underlying body fails while being read.
    pub async fn extract(
        request_head: &RequestHead,
        body: RawIncomingBody,
        body_size_limit: BodySizeLimit,
    ) -> Result<Self, ExtractBufferedBodyError> {
        match body_size_limit {
            BodySizeLimit::Disabled => {
                // No cap configured: read the whole body, however large it turns out to be.
                let collected = body
                    .collect()
                    .await
                    .map_err(|e| UnexpectedBufferError { source: e.into() })?;
                Ok(Self {
                    bytes: collected.to_bytes(),
                })
            }
            BodySizeLimit::Enabled { max_size } => {
                Self::_extract_with_limit(request_head, body, max_size).await
            }
        }
    }

    /// Registers [`BufferedBody::default_constructor`] with the given blueprint.
    pub fn register(bp: &mut Blueprint) -> RegisteredConstructor {
        Self::default_constructor().register(bp)
    }

    /// The request-scoped constructor that builds a [`BufferedBody`] via
    /// [`BufferedBody::extract`]. Failures are turned into responses by
    /// `ExtractBufferedBodyError::into_response`.
    pub fn default_constructor() -> Constructor {
        Constructor::request_scoped(f!(super::BufferedBody::extract))
            .error_handler(f!(super::errors::ExtractBufferedBodyError::into_response))
    }

    /// Buffers `body` while refusing to hold more than `max_size` bytes.
    ///
    /// The `Content-Length` header, when present and parseable as `usize`, is checked first so
    /// that an oversized request is rejected without reading any of its body. The body is then
    /// collected through [`Limited`], which enforces the cap even when the header is missing or
    /// understates the real size.
    async fn _extract_with_limit<B>(
        request_head: &RequestHead,
        body: B,
        max_size: ByteUnit,
    ) -> Result<Self, ExtractBufferedBodyError>
    where
        B: hyper::body::Body,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        // An unreadable or non-numeric header value is treated the same as a missing one.
        let content_length = request_head
            .headers
            .get(CONTENT_LENGTH)
            .and_then(|raw| raw.to_str().ok())
            .and_then(|text| text.parse::<usize>().ok());
        let too_large = || SizeLimitExceeded {
            max_size,
            content_length,
        };

        // Fail fast when the client has already declared a size above the limit.
        if matches!(content_length, Some(declared) if declared > max_size) {
            return Err(too_large().into());
        }

        // On platforms where usize is narrower than u64, an out-of-range limit saturates.
        let cap = usize::try_from(max_size.as_u64()).unwrap_or(usize::MAX);
        match Limited::new(body, cap).collect().await {
            Ok(collected) => Ok(Self {
                bytes: collected.to_bytes(),
            }),
            // `Limited` reports an overrun as a boxed `LengthLimitError`; surface it as a
            // size-limit violation rather than a generic buffering failure.
            Err(e) if e.is::<http_body_util::LengthLimitError>() => Err(too_large().into()),
            Err(e) => Err(UnexpectedBufferError { source: e }.into()),
        }
    }
}
impl From<BufferedBody> for Bytes {
fn from(buffered_body: BufferedBody) -> Self {
buffered_body.bytes
}
}
#[cfg(test)]
mod tests {
    use http::HeaderMap;
    use ubyte::ToByteUnit;

    use crate::request::RequestHead;

    use super::{BufferedBody, Bytes};

    /// A minimal `GET /` HTTP/1.1 request head with an empty header map.
    fn dummy_request_head() -> RequestHead {
        RequestHead {
            method: http::Method::GET,
            target: "/".parse().unwrap(),
            version: http::Version::HTTP_11,
            headers: HeaderMap::new(),
        }
    }

    /// With no `Content-Length` header, the limit must still be enforced while streaming.
    #[tokio::test]
    async fn error_if_body_above_size_limit_without_content_length() {
        let limit = 100.bytes();
        let payload = vec![0; 1000];
        // Sanity check: the payload really is larger than the limit under test.
        assert!(payload.len() > limit.as_u64() as usize);

        let body = crate::response::body::raw::Full::new(Bytes::from(payload));
        let head = dummy_request_head();
        let err = BufferedBody::_extract_with_limit(&head, body, limit)
            .await
            .unwrap_err();

        insta::assert_snapshot!(err, @"The request body is larger than the maximum size limit enforced by this server.");
        insta::assert_debug_snapshot!(err, @r###"
SizeLimitExceeded(
SizeLimitExceeded {
max_size: ByteUnit(
100,
),
content_length: None,
},
)
"###);
    }

    /// A declared `Content-Length` above the limit is rejected before the body is read.
    /// The actual body here is smaller than the declared length.
    #[tokio::test]
    async fn error_if_content_length_header_is_larger_than_limit() {
        let limit = 100.bytes();
        let mut head = dummy_request_head();
        head.headers
            .insert("Content-Length", "1000".parse().unwrap());
        let body = crate::response::body::raw::Full::new(Bytes::from(vec![0; 500]));

        let err = BufferedBody::_extract_with_limit(&head, body, limit)
            .await
            .unwrap_err();

        insta::assert_snapshot!(err, @"The request body is larger than the maximum size limit enforced by this server.");
        insta::assert_debug_snapshot!(err, @r###"
SizeLimitExceeded(
SizeLimitExceeded {
max_size: ByteUnit(
100,
),
content_length: Some(
1000,
),
},
)
"###);
    }
}