use axum::{
extract::Request,
response::{Response, IntoResponse},
body::Body,
http::{StatusCode, HeaderValue},
};
use tracing::{warn, error};
use crate::{
middleware::{Middleware, BoxFuture},
HttpError,
};
/// Settings that control how `BodyLimitMiddleware` checks and rejects requests.
#[derive(Debug, Clone)]
pub struct BodyLimitConfig {
/// Largest accepted body size in bytes; compared against the request's `Content-Length`.
pub max_size: usize,
/// When `true`, a `warn!` event is emitted for oversized requests and 413 responses.
pub log_oversized: bool,
/// Message handed to `HttpError::payload_too_large` when a request is rejected.
pub error_message: String,
/// When `true`, rejections carry a size detail text and an `X-Max-Body-Size` header.
pub include_headers: bool,
}
impl Default for BodyLimitConfig {
    /// Returns the baseline configuration: a 2 MiB limit, logging enabled,
    /// a generic error message, and informational headers enabled.
    fn default() -> Self {
        // 2 << 20 == 2 * 1024 * 1024 bytes.
        const TWO_MIB: usize = 2 << 20;

        Self {
            include_headers: true,
            error_message: String::from("Request body too large"),
            log_oversized: true,
            max_size: TWO_MIB,
        }
    }
}
impl BodyLimitConfig {
    /// Builds a configuration with the given byte limit and default values
    /// for every other field.
    pub fn new(max_size: usize) -> Self {
        Self::default().with_max_size(max_size)
    }

    /// Returns the configuration with `max_size` (in bytes) replaced.
    pub fn with_max_size(self, max_size: usize) -> Self {
        Self { max_size, ..self }
    }

    /// Returns the configuration with logging of oversized requests switched
    /// on or off.
    pub fn with_logging(self, log_oversized: bool) -> Self {
        Self {
            log_oversized,
            ..self
        }
    }

    /// Returns the configuration with the rejection message replaced.
    pub fn with_message<S>(self, message: S) -> Self
    where
        S: Into<String>,
    {
        Self {
            error_message: message.into(),
            ..self
        }
    }

    /// Returns the configuration with the `include_headers` flag replaced.
    pub fn with_headers(self, include_headers: bool) -> Self {
        Self {
            include_headers,
            ..self
        }
    }
}
/// Middleware that rejects requests whose declared `Content-Length` exceeds a
/// configured limit.
///
/// The type only wraps a [`BodyLimitConfig`], which is itself `Debug + Clone`,
/// so both traits are derived here to make the middleware loggable and cheap
/// to duplicate.
#[derive(Debug, Clone)]
pub struct BodyLimitMiddleware {
    /// Limit, logging, message and header settings applied to every request.
    config: BodyLimitConfig,
}
impl BodyLimitMiddleware {
    /// Creates a middleware backed by [`BodyLimitConfig::default`].
    pub fn new() -> Self {
        Self {
            config: BodyLimitConfig::default(),
        }
    }

    /// Creates a middleware with a `max_size` byte limit and default values
    /// for every other setting.
    pub fn with_limit(max_size: usize) -> Self {
        Self {
            config: BodyLimitConfig::new(max_size),
        }
    }

    /// Creates a middleware from a fully specified configuration.
    pub fn with_config(config: BodyLimitConfig) -> Self {
        Self { config }
    }

    /// Builder: replaces the byte limit.
    pub fn max_size(mut self, size: usize) -> Self {
        self.config = self.config.with_max_size(size);
        self
    }

    /// Builder: enables or disables `warn!` logging of oversized requests.
    pub fn logging(mut self, enabled: bool) -> Self {
        self.config = self.config.with_logging(enabled);
        self
    }

    /// Builder: replaces the message used in rejection responses.
    pub fn message<S: Into<String>>(mut self, message: S) -> Self {
        self.config = self.config.with_message(message);
        self
    }

    /// Builder: enables or disables the size detail and `X-Max-Body-Size`
    /// header on rejection responses.
    ///
    /// Mirrors [`BodyLimitConfig::with_headers`] so the flag can be set without
    /// constructing a configuration by hand.
    pub fn headers(mut self, enabled: bool) -> Self {
        self.config = self.config.with_headers(enabled);
        self
    }

    /// Returns the configured limit in bytes.
    pub fn limit(&self) -> usize {
        self.config.max_size
    }

    /// Builds the rejection response from `HttpError::payload_too_large`.
    ///
    /// `content_length` is the declared size when it is known. When
    /// `include_headers` is set, the error gets a detail text (mentioning the
    /// declared size if available) and the response gets an `X-Max-Body-Size`
    /// header holding the limit.
    fn create_error_response(&self, content_length: Option<usize>) -> Response {
        let limit = self.config.max_size;
        let mut error = HttpError::payload_too_large(&self.config.error_message);

        if self.config.include_headers {
            let detail = match content_length {
                Some(size) => format!(
                    "Request body size {} bytes exceeds limit of {} bytes",
                    size, limit
                ),
                None => format!("Request body exceeds limit of {} bytes", limit),
            };
            error = error.with_detail(&detail);
        }

        let mut response = error.into_response();

        if self.config.include_headers {
            if let Ok(max_size_header) = HeaderValue::from_str(&limit.to_string()) {
                response
                    .headers_mut()
                    .insert("X-Max-Body-Size", max_size_header);
            }
        }

        response
    }

    /// Validates the request's `Content-Length` header against the limit.
    ///
    /// Returns:
    /// * `Ok(Some(n))` when the header parses to `n` and `n <= max_size`;
    /// * `Ok(None)` when the header is absent, is not valid visible ASCII, or
    ///   is not a number at all;
    /// * `Err(response)` when the declared size exceeds the limit, including
    ///   a numeric value too large to fit in `usize`.
    fn check_content_length(&self, request: &Request) -> Result<Option<usize>, Response> {
        let limit = self.config.max_size;

        let declared = match request
            .headers()
            .get("content-length")
            .and_then(|value| value.to_str().ok())
        {
            Some(text) => text,
            None => return Ok(None),
        };

        match declared.parse::<usize>() {
            Ok(size) if size <= limit => Ok(Some(size)),
            Ok(size) => {
                if self.config.log_oversized {
                    warn!(
                        "Request body size {} bytes exceeds limit of {} bytes (Content-Length check)",
                        size, limit
                    );
                }
                Err(self.create_error_response(Some(size)))
            }
            // A numeric value that overflows `usize` is necessarily larger than
            // any limit we can express. Treating it as "unknown" would let the
            // request bypass the check (notably on 32-bit targets), so reject.
            Err(err) if matches!(err.kind(), std::num::IntErrorKind::PosOverflow) => {
                if self.config.log_oversized {
                    warn!(
                        "Content-Length {} does not fit in usize and exceeds limit of {} bytes (Content-Length check)",
                        declared, limit
                    );
                }
                Err(self.create_error_response(None))
            }
            // Anything else is not a usable length; nothing to compare against.
            Err(_) => Ok(None),
        }
    }
}
impl Default for BodyLimitMiddleware {
fn default() -> Self {
Self::new()
}
}
impl Middleware for BodyLimitMiddleware {
    /// Checks the declared `Content-Length`; on success, attaches a
    /// [`BodyLimitInfo`] to the request extensions and passes the request on.
    /// On failure, short-circuits with the rejection response.
    fn process_request<'a>(
        &'a self,
        request: Request,
    ) -> BoxFuture<'a, Result<Request, Response>> {
        Box::pin(async move {
            self.check_content_length(&request).map(|content_length| {
                let info = BodyLimitInfo {
                    max_size: self.config.max_size,
                    content_length,
                    error_message: self.config.error_message.clone(),
                };

                let mut annotated = request;
                annotated.extensions_mut().insert(info);
                annotated
            })
        })
    }

    /// Passes the response through unchanged, logging a warning when it is a
    /// 413 and logging is enabled.
    fn process_response<'a>(&'a self, response: Response) -> BoxFuture<'a, Response> {
        Box::pin(async move {
            let is_too_large = response.status() == StatusCode::PAYLOAD_TOO_LARGE;
            if self.config.log_oversized && is_too_large {
                warn!("Returned 413 Payload Too Large response due to body size limit");
            }
            response
        })
    }

    /// Static identifier of this middleware.
    fn name(&self) -> &'static str {
        "BodyLimitMiddleware"
    }
}
/// Body-limit details that `BodyLimitMiddleware::process_request` stores in the
/// request extensions for requests that pass the `Content-Length` check.
#[derive(Debug, Clone)]
pub struct BodyLimitInfo {
/// The middleware's configured limit in bytes.
pub max_size: usize,
/// The parsed `Content-Length` value, or `None` when no usable value was found.
pub content_length: Option<usize>,
/// Copy of the middleware's configured rejection message.
pub error_message: String,
}
pub fn limit_body_size(body: Body, max_size: usize) -> LimitedBody {
LimitedBody {
body,
max_size,
consumed: 0,
}
}
/// A request body paired with a byte limit and a consumed-bytes counter.
///
/// NOTE(review): nothing in this file reads `body` or advances `consumed`;
/// presumably streaming enforcement lives elsewhere or is unfinished — confirm.
pub struct LimitedBody {
/// The wrapped request body.
body: Body,
/// Limit in bytes.
max_size: usize,
/// Bytes counted so far; starts at 0.
consumed: usize,
}
impl LimitedBody {
    /// Wraps `body` with a `max_size` byte limit; the consumed counter starts
    /// at zero.
    pub fn new(body: Body, max_size: usize) -> Self {
        Self {
            consumed: 0,
            max_size,
            body,
        }
    }

    /// Bytes left before the limit is reached; zero once the limit has been
    /// met or passed.
    pub fn remaining(&self) -> usize {
        self.max_size.checked_sub(self.consumed).unwrap_or(0)
    }

    /// Bytes counted so far.
    pub fn consumed(&self) -> usize {
        self.consumed
    }

    /// `true` only when strictly more than `max_size` bytes have been counted.
    pub fn is_exceeded(&self) -> bool {
        self.max_size < self.consumed
    }
}
/// Byte-size constants and ready-made middleware presets.
pub mod limits {
    /// 1024 bytes.
    pub const KB: usize = 1 << 10;
    /// 1024 * 1024 bytes.
    pub const MB: usize = KB * KB;
    /// Ten times [`MB`].
    pub const MB_10: usize = MB * 10;
    /// One hundred times [`MB`].
    pub const MB_100: usize = MB * 100;
    /// 1024 times [`MB`].
    pub const GB: usize = MB * KB;

    /// Preconfigured [`BodyLimitMiddleware`](super::BodyLimitMiddleware)
    /// instances for common limits.
    pub mod presets {
        use super::super::BodyLimitMiddleware;
        use super::{KB, MB, MB_10, MB_100};

        /// Shared constructor: a middleware with the given limit and message.
        fn capped(limit: usize, message: &'static str) -> BodyLimitMiddleware {
            BodyLimitMiddleware::with_limit(limit).message(message)
        }

        /// 1 MB limit intended for API payloads.
        pub fn small_api() -> BodyLimitMiddleware {
            capped(MB, "API request body too large (1MB limit)")
        }

        /// 10 MB limit intended for file uploads.
        pub fn file_upload() -> BodyLimitMiddleware {
            capped(MB_10, "File upload too large (10MB limit)")
        }

        /// 100 MB limit intended for large file uploads.
        pub fn large_upload() -> BodyLimitMiddleware {
            capped(MB_100, "Large file upload too large (100MB limit)")
        }

        /// 64 KB limit.
        pub fn tiny() -> BodyLimitMiddleware {
            capped(64 * KB, "Request body too large (64KB limit)")
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Method;

    /// Builds a bodiless `POST /test` request, optionally declaring a
    /// `Content-Length` header with the given raw value.
    fn post_request(content_length: Option<&str>) -> Request {
        let mut builder = Request::builder().method(Method::POST).uri("/test");
        if let Some(value) = content_length {
            builder = builder.header("content-length", value);
        }
        builder.body(Body::empty()).unwrap()
    }

    // ---- Async tests: these drive `process_request` and need a runtime. ----

    #[tokio::test]
    async fn test_body_limit_middleware_basic() {
        let middleware = BodyLimitMiddleware::new();

        let result = middleware.process_request(post_request(None)).await;
        assert!(result.is_ok());

        let processed_request = result.unwrap();
        let body_limit_info = processed_request.extensions().get::<BodyLimitInfo>();
        assert!(body_limit_info.is_some());

        let body_limit_info = body_limit_info.unwrap();
        assert_eq!(body_limit_info.max_size, 2 * 1024 * 1024);
        assert!(body_limit_info.content_length.is_none());
    }

    #[tokio::test]
    async fn test_content_length_check_within_limit() {
        let middleware = BodyLimitMiddleware::with_limit(1000);

        let result = middleware.process_request(post_request(Some("500"))).await;
        assert!(result.is_ok());

        let processed_request = result.unwrap();
        let body_limit_info = processed_request
            .extensions()
            .get::<BodyLimitInfo>()
            .unwrap();
        assert_eq!(body_limit_info.content_length, Some(500));
    }

    /// A declared size exactly equal to the limit is accepted; only strictly
    /// larger values are rejected.
    #[tokio::test]
    async fn test_content_length_check_at_limit() {
        let middleware = BodyLimitMiddleware::with_limit(100);

        let result = middleware.process_request(post_request(Some("100"))).await;
        assert!(result.is_ok());

        let processed_request = result.unwrap();
        let body_limit_info = processed_request
            .extensions()
            .get::<BodyLimitInfo>()
            .unwrap();
        assert_eq!(body_limit_info.content_length, Some(100));
    }

    #[tokio::test]
    async fn test_content_length_check_exceeds_limit() {
        let middleware = BodyLimitMiddleware::with_limit(100);

        let result = middleware.process_request(post_request(Some("200"))).await;
        assert!(result.is_err());

        let error_response = result.unwrap_err();
        assert_eq!(error_response.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert!(error_response.headers().contains_key("X-Max-Body-Size"));
        assert_eq!(
            error_response.headers().get("X-Max-Body-Size").unwrap(),
            "100"
        );
    }

    #[tokio::test]
    async fn test_invalid_content_length_header() {
        let middleware = BodyLimitMiddleware::with_limit(1000);

        let result = middleware
            .process_request(post_request(Some("not-a-number")))
            .await;
        assert!(result.is_ok());

        let processed_request = result.unwrap();
        let body_limit_info = processed_request
            .extensions()
            .get::<BodyLimitInfo>()
            .unwrap();
        assert!(body_limit_info.content_length.is_none());
    }

    // ---- Synchronous tests: nothing here awaits, so no runtime is needed. ----

    #[test]
    fn test_body_limit_middleware_custom_limit() {
        let middleware = BodyLimitMiddleware::with_limit(1024);
        assert_eq!(middleware.limit(), 1024);
    }

    #[test]
    fn test_body_limit_middleware_builder() {
        let middleware = BodyLimitMiddleware::new()
            .max_size(512)
            .logging(false)
            .message("Too big!");

        assert_eq!(middleware.config.max_size, 512);
        assert!(!middleware.config.log_oversized);
        assert_eq!(middleware.config.error_message, "Too big!");
    }

    #[test]
    fn test_body_limit_config() {
        let config = BodyLimitConfig::new(512)
            .with_logging(false)
            .with_message("Custom message")
            .with_headers(false);
        let middleware = BodyLimitMiddleware::with_config(config);

        assert_eq!(middleware.config.max_size, 512);
        assert!(!middleware.config.log_oversized);
        assert_eq!(middleware.config.error_message, "Custom message");
        assert!(!middleware.config.include_headers);
    }

    #[test]
    fn test_body_limit_middleware_name() {
        let middleware = BodyLimitMiddleware::new();
        assert_eq!(middleware.name(), "BodyLimitMiddleware");
    }

    #[test]
    fn test_limited_body_creation() {
        let limited = limit_body_size(Body::empty(), 1024);

        assert_eq!(limited.remaining(), 1024);
        assert_eq!(limited.consumed(), 0);
        assert!(!limited.is_exceeded());
    }

    #[test]
    fn test_body_limit_presets() {
        let small = limits::presets::small_api();
        assert_eq!(small.limit(), limits::MB);

        let upload = limits::presets::file_upload();
        assert_eq!(upload.limit(), limits::MB_10);

        let large = limits::presets::large_upload();
        assert_eq!(large.limit(), limits::MB_100);

        let tiny = limits::presets::tiny();
        assert_eq!(tiny.limit(), 64 * limits::KB);
    }

    #[test]
    fn test_body_limit_constants() {
        assert_eq!(limits::KB, 1024);
        assert_eq!(limits::MB, 1024 * 1024);
        assert_eq!(limits::MB_10, 10 * 1024 * 1024);
        assert_eq!(limits::MB_100, 100 * 1024 * 1024);
        assert_eq!(limits::GB, 1024 * 1024 * 1024);
    }
}