use async_trait::async_trait;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::timeout;
use super::Middleware;
use crate::core::{Request, Response, router::Handler};
/// Tunable limits applied to incoming requests.
///
/// Every field is public, so a value can be built with a struct literal, or
/// starting from the defaults and overriding individual limits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LimitsConfig {
    /// Maximum time the downstream handler is given to produce a response.
    pub request_timeout: Duration,
    /// Maximum accepted length of the request body (`body().len()`).
    pub max_body_size: usize,
    /// Maximum accepted length of the request path (`path().len()`).
    pub max_path_length: usize,
    /// Maximum number of headers a request may carry.
    pub max_headers: usize,
    /// Maximum combined length of a single header's name and value.
    pub max_header_size: usize,
}
impl Default for LimitsConfig {
fn default() -> Self {
Self {
request_timeout: Duration::from_secs(30),
max_body_size: 1024 * 1024, max_path_length: 2048,
max_headers: 100,
max_header_size: 8 * 1024, }
}
}
impl LimitsConfig {
    /// Creates a configuration holding the default limits.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the configuration with `request_timeout` replaced by `timeout`.
    pub fn request_timeout(self, timeout: Duration) -> Self {
        Self {
            request_timeout: timeout,
            ..self
        }
    }

    /// Returns the configuration with `max_body_size` replaced by `size`.
    pub fn max_body_size(self, size: usize) -> Self {
        Self {
            max_body_size: size,
            ..self
        }
    }

    /// Returns the configuration with `max_path_length` replaced by `length`.
    pub fn max_path_length(self, length: usize) -> Self {
        Self {
            max_path_length: length,
            ..self
        }
    }

    /// Returns the configuration with `max_headers` replaced by `count`.
    pub fn max_headers(self, count: usize) -> Self {
        Self {
            max_headers: count,
            ..self
        }
    }

    /// Returns the configuration with `max_header_size` replaced by `size`.
    pub fn max_header_size(self, size: usize) -> Self {
        Self {
            max_header_size: size,
            ..self
        }
    }
}
/// Middleware that checks each request against a [`LimitsConfig`] and runs
/// the next handler under the configured request timeout.
pub struct LimitsMiddleware {
/// Limits this middleware enforces.
config: LimitsConfig,
}
impl LimitsMiddleware {
pub fn new() -> Self {
Self {
config: LimitsConfig::default(),
}
}
pub fn with_config(config: LimitsConfig) -> Self {
Self { config }
}
fn validate_request(&self, req: &Request) -> Option<Response> {
if req.path().len() > self.config.max_path_length {
tracing::warn!(
"Request path too long: {} > {}",
req.path().len(),
self.config.max_path_length
);
return Some(Response::text(414, "URI Too Long"));
}
if req.headers().len() > self.config.max_headers {
tracing::warn!(
"Too many headers: {} > {}",
req.headers().len(),
self.config.max_headers
);
return Some(Response::text(431, "Request Header Fields Too Large"));
}
for (name, value) in req.headers() {
let name_len = name.as_str().len();
let value_len = value.len();
if name_len + value_len > self.config.max_header_size {
tracing::warn!(
"Header too large: {} + {} > {}",
name_len,
value_len,
self.config.max_header_size
);
return Some(Response::text(431, "Request Header Fields Too Large"));
}
}
if req.body().len() > self.config.max_body_size {
tracing::warn!(
"Request body too large: {} > {}",
req.body().len(),
self.config.max_body_size
);
return Some(Response::text(413, "Payload Too Large"));
}
None
}
}
impl Default for LimitsMiddleware {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Middleware for LimitsMiddleware {
    /// Rejects requests that exceed a configured limit; otherwise forwards
    /// `req` to `next` and waits at most `request_timeout` for its response.
    ///
    /// If the timeout elapses first, a warning is logged and a
    /// `408 Request Timeout` response is returned instead.
    async fn handle(&self, req: Request, next: Arc<dyn Handler>) -> Response {
        if let Some(rejection) = self.validate_request(&req) {
            return rejection;
        }

        let deadline = self.config.request_timeout;
        timeout(deadline, next.handle(req))
            .await
            .unwrap_or_else(|_elapsed| {
                tracing::warn!("Request timeout after {}ms", deadline.as_millis());
                Response::text(408, "Request Timeout")
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Method, Request};

    /// Terminal handler that answers `200 ok`, optionally after a delay.
    struct MockHandler {
        delay: Option<Duration>,
    }

    impl MockHandler {
        /// Handler that responds immediately.
        fn new() -> Arc<Self> {
            Arc::new(MockHandler { delay: None })
        }

        /// Handler that sleeps for `delay` before responding.
        fn with_delay(delay: Duration) -> Arc<Self> {
            Arc::new(MockHandler { delay: Some(delay) })
        }
    }

    #[async_trait]
    impl Handler for MockHandler {
        async fn handle(&self, _req: Request) -> Response {
            if let Some(pause) = self.delay {
                tokio::time::sleep(pause).await;
            }
            Response::text(200, "ok")
        }
    }

    /// Sends `req` through a middleware built from `config`, with `handler`
    /// as the next stage, and returns the resulting response.
    async fn dispatch(config: LimitsConfig, req: Request, handler: Arc<MockHandler>) -> Response {
        let middleware = LimitsMiddleware::with_config(config);
        middleware.handle(req, handler).await
    }

    #[tokio::test]
    async fn test_request_timeout() {
        // The handler takes twice as long as the middleware allows.
        let config = LimitsConfig::new().request_timeout(Duration::from_millis(100));
        let slow_handler = MockHandler::with_delay(Duration::from_millis(200));

        let response = dispatch(config, Request::new(Method::GET, "/test"), slow_handler).await;

        assert_eq!(response.status.as_u16(), 408);
    }

    #[tokio::test]
    async fn test_path_length_limit() {
        let config = LimitsConfig::new().max_path_length(10);
        let req = Request::new(Method::GET, "/very-long-path-that-exceeds-limit");

        let response = dispatch(config, req, MockHandler::new()).await;

        assert_eq!(response.status.as_u16(), 414);
    }

    #[tokio::test]
    async fn test_body_size_limit() {
        let config = LimitsConfig::new().max_body_size(5);
        let req = Request::new(Method::POST, "/test").with_body(b"too long body".to_vec());

        let response = dispatch(config, req, MockHandler::new()).await;

        assert_eq!(response.status.as_u16(), 413);
    }

    #[tokio::test]
    async fn test_header_count_limit() {
        let config = LimitsConfig::new().max_headers(2);
        let mut req = Request::new(Method::GET, "/test");
        // Three headers against a limit of two.
        for (name, value) in [
            ("header1", "value1"),
            ("header2", "value2"),
            ("header3", "value3"),
        ] {
            req.headers_mut().insert(name, value.try_into().unwrap());
        }

        let response = dispatch(config, req, MockHandler::new()).await;

        assert_eq!(response.status.as_u16(), 431);
    }

    #[tokio::test]
    async fn test_header_size_limit() {
        let config = LimitsConfig::new().max_header_size(10);
        let mut req = Request::new(Method::GET, "/test");
        req.headers_mut()
            .insert("x-long", "very-long-value".try_into().unwrap());

        let response = dispatch(config, req, MockHandler::new()).await;

        assert_eq!(response.status.as_u16(), 431);
    }

    #[tokio::test]
    async fn test_valid_request_passes() {
        let req = Request::new(Method::GET, "/test").with_body(b"small".to_vec());

        let response = dispatch(LimitsConfig::new(), req, MockHandler::new()).await;

        assert_eq!(response.status.as_u16(), 200);
    }
}