use crate::middleware::Middleware;
use crate::{Request, Response};
use http::StatusCode;
use serde::Serialize;
use tracing::warn;
use super::Next;
/// Middleware that rejects requests whose body is larger than a configured limit.
///
/// Construct it with [`new`](Self::new), a unit helper such as
/// [`megabytes`](Self::megabytes), one of the preset constructors, or the builder.
#[derive(Debug, Clone)]
pub struct BodySizeLimitMiddleware {
    /// Largest accepted body length in bytes; a body of exactly this length is accepted.
    max_size: usize,
    /// Response message used instead of the generated one, when set.
    custom_message: Option<String>,
    /// Whether rejected requests are logged at `warn` level.
    log_rejections: bool,
    /// Whether size details are included in the rejection response.
    include_size_info: bool,
}

impl BodySizeLimitMiddleware {
    /// Creates a limit of `max_size` bytes with logging and size info enabled and no
    /// custom message.
    pub fn new(max_size: usize) -> Self {
        Self {
            max_size,
            custom_message: None,
            log_rejections: true,
            include_size_info: true,
        }
    }

    /// Creates a limit of `max_size` bytes that answers rejections with `message`.
    ///
    /// Equivalent to `BodySizeLimitMiddleware::new(max_size).message(message)`.
    pub fn with_message(max_size: usize, message: impl Into<String>) -> Self {
        // Delegate so the defaults live in exactly one place (`new`).
        Self::new(max_size).message(message)
    }

    /// Replaces the rejection message with `message`.
    #[must_use]
    pub fn message(mut self, message: impl Into<String>) -> Self {
        self.custom_message = Some(message.into());
        self
    }

    /// Creates a limit of `mb` mebibytes (`mb * 1024 * 1024` bytes).
    ///
    /// Saturates at `usize::MAX` instead of overflowing for very large inputs.
    pub fn megabytes(mb: usize) -> Self {
        Self::new(mb.saturating_mul(1024 * 1024))
    }

    /// Creates a limit of `kb` kibibytes (`kb * 1024` bytes).
    ///
    /// Saturates at `usize::MAX` instead of overflowing for very large inputs.
    pub fn kilobytes(kb: usize) -> Self {
        Self::new(kb.saturating_mul(1024))
    }

    /// Creates a limit of `gb` gibibytes (`gb * 1024^3` bytes).
    ///
    /// Saturates at `usize::MAX` instead of overflowing. This matters on 32-bit
    /// targets, where `gb >= 4` does not fit in a `usize`.
    pub fn gigabytes(gb: usize) -> Self {
        Self::new(gb.saturating_mul(1024 * 1024 * 1024))
    }

    /// Enables or disables the `warn` log emitted for each rejected request.
    #[must_use]
    pub fn with_logging(mut self, enabled: bool) -> Self {
        self.log_rejections = enabled;
        self
    }

    /// Enables or disables size details in the rejection response.
    #[must_use]
    pub fn with_size_info(mut self, enabled: bool) -> Self {
        self.include_size_info = enabled;
        self
    }

    /// Returns the configured limit in bytes.
    pub fn max_size(&self) -> usize {
        self.max_size
    }

    /// Formats a byte count using 1024-based units, e.g. `1536` becomes `"1.5 KB"`.
    ///
    /// Values below 1024 are printed as whole bytes (`"512 bytes"`). Larger values
    /// are printed with one decimal place. `TB` is the largest unit used.
    pub fn format_size(size: usize) -> String {
        const UNITS: &[&str] = &["bytes", "KB", "MB", "GB", "TB"];
        let mut size = size as f64;
        let mut unit_index = 0;
        while size >= 1024.0 && unit_index < UNITS.len() - 1 {
            size /= 1024.0;
            unit_index += 1;
        }
        if unit_index == 0 {
            format!("{} {}", size as usize, UNITS[unit_index])
        } else {
            format!("{:.1} {}", size, UNITS[unit_index])
        }
    }
}
/// Ready-made limits for common kinds of endpoints.
///
/// Each preset pairs a fixed byte limit with a client-facing message that states
/// that limit. Logging and size info keep their default (enabled) settings.
impl BodySizeLimitMiddleware {
    /// 1 MiB limit intended for JSON API endpoints.
    pub fn json_api() -> Self {
        Self::with_message(1024 * 1024, "Request body too large for JSON API (max 1MB)")
    }

    /// 10 MiB limit intended for general file uploads.
    pub fn file_upload() -> Self {
        Self::with_message(10 * 1024 * 1024, "File too large for upload (max 10MB)")
    }

    /// 500 KiB limit intended for avatar images.
    pub fn avatar_upload() -> Self {
        Self::with_message(500 * 1024, "Avatar image too large (max 500KB)")
    }

    /// 64 KiB limit intended for form submissions.
    pub fn form_data() -> Self {
        Self::with_message(64 * 1024, "Form data too large (max 64KB)")
    }
}
#[async_trait::async_trait]
impl Middleware for BodySizeLimitMiddleware {
    /// Forwards the request to `next` when `req.body.len()` is within the limit.
    /// Otherwise it short-circuits with a `413 Payload Too Large` JSON response.
    ///
    /// NOTE(review): the check reads the length of `req.body`, so the body appears
    /// to be fully buffered before this middleware runs. This limit therefore does
    /// not bound how much is read off the wire. Confirm against the server layer.
    async fn handle(&self, req: Request, next: Next) -> Response {
        let body_size = req.body.len();
        if body_size <= self.max_size {
            return next.run(req).await;
        }

        if self.log_rejections {
            warn!(
                "Request body size limit exceeded: {} > {} ({})",
                Self::format_size(body_size),
                Self::format_size(self.max_size),
                req.uri.path()
            );
        }

        // A custom message always wins. Otherwise the generated text mentions the
        // sizes only when size info is enabled.
        let message = if let Some(custom_msg) = &self.custom_message {
            custom_msg.clone()
        } else if self.include_size_info {
            format!(
                "Request body size ({}) exceeds maximum allowed size ({})",
                Self::format_size(body_size),
                Self::format_size(self.max_size)
            )
        } else {
            "Request body too large".to_string()
        };

        // 413 is the status defined for "request content larger than the server is
        // willing to process". 400 would wrongly tell clients the request is malformed.
        Response::json(PayloadTooLargeError {
            message,
            current_size: body_size,
            max_size: self.max_size,
            include_metadata: self.include_size_info,
        })
        .with_status(StatusCode::PAYLOAD_TOO_LARGE)
    }
}
#[derive(Debug, Serialize)]
struct PayloadTooLargeError {
message: String,
current_size: usize,
max_size: usize,
include_metadata: bool,
}
impl std::fmt::Display for PayloadTooLargeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}
impl BodySizeLimitMiddleware {
    /// Starts configuring a middleware step by step; see [`BodySizeLimitBuilder`].
    pub fn builder() -> BodySizeLimitBuilder {
        BodySizeLimitBuilder::new()
    }
}

/// Step-by-step configuration for [`BodySizeLimitMiddleware`].
///
/// Options left unset fall back to a 10 MiB limit, rejection logging enabled,
/// size info enabled, and no custom message.
#[derive(Debug)]
#[must_use = "a builder does nothing until `build()` is called"]
pub struct BodySizeLimitBuilder {
    max_size: Option<usize>,
    custom_message: Option<String>,
    log_rejections: bool,
    include_size_info: bool,
}

impl BodySizeLimitBuilder {
    /// Limit used by [`build`](Self::build) when no size was configured (10 MiB).
    const DEFAULT_MAX_SIZE: usize = 10 * 1024 * 1024;

    fn new() -> Self {
        Self {
            max_size: None,
            custom_message: None,
            log_rejections: true,
            include_size_info: true,
        }
    }

    /// Sets the limit in bytes.
    pub fn max_size(mut self, bytes: usize) -> Self {
        self.max_size = Some(bytes);
        self
    }

    /// Sets the limit in KiB (`kb * 1024` bytes), saturating at `usize::MAX`.
    pub fn max_size_kb(self, kb: usize) -> Self {
        self.max_size(kb.saturating_mul(1024))
    }

    /// Sets the limit in MiB (`mb * 1024^2` bytes), saturating at `usize::MAX`.
    pub fn max_size_mb(self, mb: usize) -> Self {
        self.max_size(mb.saturating_mul(1024 * 1024))
    }

    /// Sets the limit in GiB (`gb * 1024^3` bytes), saturating at `usize::MAX`.
    ///
    /// Without saturation, `gb >= 4` would overflow on 32-bit targets.
    pub fn max_size_gb(self, gb: usize) -> Self {
        self.max_size(gb.saturating_mul(1024 * 1024 * 1024))
    }

    /// Replaces the generated rejection message with `message`.
    pub fn message(mut self, message: impl Into<String>) -> Self {
        self.custom_message = Some(message.into());
        self
    }

    /// Stops the middleware from logging rejected requests.
    pub fn disable_logging(mut self) -> Self {
        self.log_rejections = false;
        self
    }

    /// Leaves body and limit sizes out of the rejection response.
    pub fn hide_size_info(mut self) -> Self {
        self.include_size_info = false;
        self
    }

    /// Finishes configuration; an unset limit becomes [`Self::DEFAULT_MAX_SIZE`].
    pub fn build(self) -> BodySizeLimitMiddleware {
        BodySizeLimitMiddleware {
            max_size: self.max_size.unwrap_or(Self::DEFAULT_MAX_SIZE),
            custom_message: self.custom_message,
            log_rejections: self.log_rejections,
            include_size_info: self.include_size_info,
        }
    }
}