use super::{BoxedNext, MiddlewareLayer};
use crate::error::ApiError;
use crate::request::Request;
use crate::response::{IntoResponse, Response};
use http::StatusCode;
use std::future::Future;
use std::pin::Pin;
/// Body-size limit used by [`BodyLimitLayer::default`]: 1 MiB (1,048,576 bytes).
pub const DEFAULT_BODY_LIMIT: usize = 1 << 20;
/// Middleware layer that rejects requests whose body exceeds a byte limit.
///
/// Oversized requests are answered with `413 Payload Too Large` instead of
/// being passed to the inner handler.
#[derive(Clone, Debug)]
pub struct BodyLimitLayer {
    /// Maximum accepted body size in bytes; a body of exactly this size is allowed.
    limit: usize,
}
impl BodyLimitLayer {
pub fn new(limit: usize) -> Self {
Self { limit }
}
pub fn default_limit() -> Self {
Self::new(DEFAULT_BODY_LIMIT)
}
pub fn limit(&self) -> usize {
self.limit
}
}
impl Default for BodyLimitLayer {
fn default() -> Self {
Self::default_limit()
}
}
impl MiddlewareLayer for BodyLimitLayer {
    /// Rejects oversized requests with `413 Payload Too Large` (error code
    /// `"payload_too_large"`); otherwise forwards the request to `next`.
    ///
    /// A request is rejected when either:
    /// - its `Content-Length` header declares more than `limit` bytes, or
    /// - its body is `BodyVariant::Buffered` and holds more than `limit` bytes.
    ///
    /// A body of exactly `limit` bytes is accepted. A `Content-Length` value
    /// that cannot be read as a string, or is not a number, is ignored by the
    /// header check. Bodies that are not `BodyVariant::Buffered` are not
    /// measured here, so without a usable `Content-Length` they pass through
    /// this layer unchecked.
    fn call(
        &self,
        req: Request,
        next: BoxedNext,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
        let limit = self.limit;
        Box::pin(async move {
            let declared_too_large = match req
                .headers()
                .get(http::header::CONTENT_LENGTH)
                .and_then(|value| value.to_str().ok())
            {
                Some(text) => match text.parse::<usize>() {
                    Ok(length) => length > limit,
                    // A digit string too large for `usize` necessarily exceeds the
                    // limit; treating the overflow as "unparseable" would let an
                    // enormous declared length bypass this check.
                    Err(err) => matches!(err.kind(), std::num::IntErrorKind::PosOverflow),
                },
                None => false,
            };

            // Also measure the real size: the header may be absent or understate it.
            let buffered_too_large = matches!(
                &req.body,
                crate::request::BodyVariant::Buffered(bytes) if bytes.len() > limit
            );

            if declared_too_large || buffered_too_large {
                ApiError::new(
                    StatusCode::PAYLOAD_TOO_LARGE,
                    "payload_too_large",
                    format!("Request body exceeds limit of {} bytes", limit),
                )
                .into_response()
            } else {
                next(req).await
            }
        })
    }

    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
        Box::new(self.clone())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::path_params::PathParams;
    use crate::request::Request;
    use bytes::Bytes;
    use http::{Extensions, Method};
    use proptest::prelude::*;
    use std::sync::Arc;

    /// Builds a `POST /test` request carrying `body` as a buffered body.
    ///
    /// When `content_length` is `Some`, that value is sent in the
    /// `Content-Length` header. It need not match `body.len()`, which lets
    /// tests exercise the header check and the buffered-body check separately.
    fn build_request(body: Bytes, content_length: Option<usize>) -> Request {
        let uri: http::Uri = "/test".parse().unwrap();
        let mut builder = http::Request::builder().method(Method::POST).uri(uri);
        if let Some(length) = content_length {
            builder = builder.header(http::header::CONTENT_LENGTH, length.to_string());
        }
        let (parts, _) = builder.body(()).unwrap().into_parts();
        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(body),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    /// Request whose `Content-Length` header equals the buffered body length.
    fn create_test_request_with_body(body: Bytes) -> Request {
        let length = body.len();
        build_request(body, Some(length))
    }

    /// Request with a buffered body and no `Content-Length` header.
    fn create_test_request_without_content_length(body: Bytes) -> Request {
        build_request(body, None)
    }

    /// Inner handler that always answers `200 OK`, so any other status seen
    /// in a test must have been produced by `BodyLimitLayer`.
    fn ok_handler() -> BoxedNext {
        Arc::new(|_req: Request| {
            Box::pin(async {
                http::Response::builder()
                    .status(StatusCode::OK)
                    .body(crate::response::Body::from("ok"))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        #[test]
        fn prop_body_size_limit_enforcement(
            limit in 1usize..10240usize,
            body_size_factor in 0.5f64..2.0f64,
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let body_size = ((limit as f64) * body_size_factor) as usize;
                let body = Bytes::from(vec![b'x'; body_size]);
                let request = create_test_request_with_body(body.clone());
                let layer = BodyLimitLayer::new(limit);
                let handler = ok_handler();
                let response = layer.call(request, handler).await;
                if body_size > limit {
                    prop_assert_eq!(
                        response.status(),
                        StatusCode::PAYLOAD_TOO_LARGE,
                        "Expected 413 for body size {} > limit {}",
                        body_size,
                        limit
                    );
                } else {
                    prop_assert_eq!(
                        response.status(),
                        StatusCode::OK,
                        "Expected 200 for body size {} <= limit {}",
                        body_size,
                        limit
                    );
                }
                Ok(())
            })?;
        }

        #[test]
        fn prop_body_limit_without_content_length_header(
            limit in 1usize..10240usize,
            body_size_factor in 0.5f64..2.0f64,
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                let body_size = ((limit as f64) * body_size_factor) as usize;
                let body = Bytes::from(vec![b'x'; body_size]);
                let request = create_test_request_without_content_length(body.clone());
                let layer = BodyLimitLayer::new(limit);
                let handler = ok_handler();
                let response = layer.call(request, handler).await;
                if body_size > limit {
                    prop_assert_eq!(
                        response.status(),
                        StatusCode::PAYLOAD_TOO_LARGE,
                        "Expected 413 for body size {} > limit {} (no Content-Length)",
                        body_size,
                        limit
                    );
                } else {
                    prop_assert_eq!(
                        response.status(),
                        StatusCode::OK,
                        "Expected 200 for body size {} <= limit {} (no Content-Length)",
                        body_size,
                        limit
                    );
                }
                Ok(())
            })?;
        }
    }

    #[tokio::test]
    async fn test_body_at_exact_limit() {
        let limit = 100;
        let body = Bytes::from(vec![b'x'; limit]);
        let request = create_test_request_with_body(body);
        let layer = BodyLimitLayer::new(limit);
        let handler = ok_handler();
        let response = layer.call(request, handler).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_body_one_byte_over_limit() {
        let limit = 100;
        let body = Bytes::from(vec![b'x'; limit + 1]);
        let request = create_test_request_with_body(body);
        let layer = BodyLimitLayer::new(limit);
        let handler = ok_handler();
        let response = layer.call(request, handler).await;
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn test_body_one_byte_under_limit() {
        let limit = 100;
        let body = Bytes::from(vec![b'x'; limit - 1]);
        let request = create_test_request_with_body(body);
        let layer = BodyLimitLayer::new(limit);
        let handler = ok_handler();
        let response = layer.call(request, handler).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_empty_body() {
        let limit = 100;
        let body = Bytes::new();
        let request = create_test_request_with_body(body);
        let layer = BodyLimitLayer::new(limit);
        let handler = ok_handler();
        let response = layer.call(request, handler).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_declared_length_over_limit_rejected_even_if_body_fits() {
        // Only the header check can fire here: the buffered body is tiny.
        let limit = 100;
        let request = build_request(Bytes::from_static(b"small"), Some(limit + 1));
        let response = BodyLimitLayer::new(limit)
            .call(request, ok_handler())
            .await;
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn test_buffered_body_over_limit_rejected_despite_small_declared_length() {
        // Only the buffered-body check can fire here: the header understates the size.
        let limit = 100;
        let request = build_request(Bytes::from(vec![b'x'; limit + 1]), Some(1));
        let response = BodyLimitLayer::new(limit)
            .call(request, ok_handler())
            .await;
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn test_default_limit() {
        let layer = BodyLimitLayer::default();
        assert_eq!(layer.limit(), DEFAULT_BODY_LIMIT);
    }

    #[test]
    fn test_clone() {
        let layer = BodyLimitLayer::new(1024);
        let cloned = layer.clone();
        assert_eq!(layer.limit(), cloned.limit());
    }
}