use axum::{
body::Body,
extract::Request,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use tracing::warn;
/// Middleware that refuses requests whose headers exceed configured limits.
///
/// Two checks run in order, and the first one that fails short-circuits:
///
/// 1. The header count reported by `HeaderMap::len` must not exceed
///    `max_header_count`.
/// 2. The sum of every header name length plus value length, in bytes, must
///    not exceed `max_header_bytes`. Separators and line terminators are not
///    part of this sum.
///
/// Both limits are inclusive: a request sitting exactly on a limit passes.
///
/// A rejected request is logged at `warn` level and answered with
/// `431 Request Header Fields Too Large` and a short plain-text body. An
/// accepted request is handed to `next` unchanged.
pub async fn header_limits_middleware(
    request: Request<Body>,
    next: Next,
    max_header_count: usize,
    max_header_bytes: usize,
) -> Response {
    // Decide inside a nested scope so the borrow of the header map ends
    // before `request` is moved into `next.run`.
    let rejection = {
        let header_map = request.headers();
        let header_count = header_map.len();

        if header_count > max_header_count {
            warn!(header_count, max_header_count, "Request rejected: too many headers");
            Some("Too many request headers")
        } else {
            // Only computed once the count check has passed.
            let total_bytes = header_map
                .iter()
                .fold(0usize, |acc, (name, value)| acc + name.as_str().len() + value.len());

            if total_bytes > max_header_bytes {
                warn!(total_bytes, max_header_bytes, "Request rejected: headers too large");
                Some("Request headers too large")
            } else {
                None
            }
        }
    };

    match rejection {
        Some(message) => {
            (StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, message).into_response()
        }
        None => next.run(request).await,
    }
}
#[cfg(test)]
mod tests {
    use axum::{Router, body::Body, middleware, routing::get};
    use http::Request;
    use tower::ServiceExt;

    use super::*;

    /// Minimal handler; reaching it means the middleware let the request through.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Builds a router with a single `GET /` route wrapped in
    /// `header_limits_middleware` configured with the given limits.
    fn test_app(max_count: usize, max_bytes: usize) -> Router {
        Router::new()
            .route("/", get(ok_handler))
            .layer(middleware::from_fn(move |req, next| {
                header_limits_middleware(req, next, max_count, max_bytes)
            }))
    }

    /// Sends `req` through `app` once and returns the resulting status code.
    async fn status_of(app: Router, req: Request<Body>) -> StatusCode {
        app.oneshot(req)
            .await
            .expect("Reason: oneshot should not fail in test")
            .status()
    }

    #[tokio::test]
    async fn accepts_request_within_limits() {
        let app = test_app(10, 4096);
        let req = Request::builder()
            .uri("/")
            .header("x-test", "value")
            .body(Body::empty())
            .expect("Reason: test request builder should not fail");

        assert_eq!(status_of(app, req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn rejects_too_many_headers() {
        let app = test_app(3, 65_536);
        let mut builder = Request::builder().uri("/");
        for i in 0..10 {
            builder = builder.header(format!("x-test-{i}"), "value");
        }
        let req = builder
            .body(Body::empty())
            .expect("Reason: test request builder should not fail");

        assert_eq!(status_of(app, req).await, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE);
    }

    #[tokio::test]
    async fn rejects_headers_too_large() {
        let app = test_app(100, 64);
        // 7-byte name + 100-byte value = 107 bytes, well above the 64-byte limit.
        let req = Request::builder()
            .uri("/")
            .header("x-large", "a".repeat(100))
            .body(Body::empty())
            .expect("Reason: test request builder should not fail");

        assert_eq!(status_of(app, req).await, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE);
    }

    #[tokio::test]
    async fn accepts_at_exact_count_limit() {
        let app = test_app(5, 65_536);
        let mut builder = Request::builder().uri("/");
        for i in 0..5 {
            builder = builder.header(format!("x-h-{i}"), "v");
        }
        let req = builder
            .body(Body::empty())
            .expect("Reason: test request builder should not fail");

        assert_eq!(status_of(app, req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepts_at_exact_byte_limit() {
        // "x-h" (3 bytes) + "v" (1 byte) = 4 bytes, exactly the limit.
        let app = test_app(10, 4);
        let req = Request::builder()
            .uri("/")
            .header("x-h", "v")
            .body(Body::empty())
            .expect("Reason: test request builder should not fail");

        assert_eq!(status_of(app, req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn rejects_one_byte_over_limit() {
        // "x-h" (3 bytes) + "v" (1 byte) = 4 bytes, one more than the limit of 3.
        let app = test_app(10, 3);
        let req = Request::builder()
            .uri("/")
            .header("x-h", "v")
            .body(Body::empty())
            .expect("Reason: test request builder should not fail");

        assert_eq!(status_of(app, req).await, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE);
    }
}