Skip to main content

fraiseql_server/middleware/
header_limits.rs

1//! HTTP header count and size limit middleware.
2//!
3//! Rejects requests that exceed configured header count or total header byte
4//! limits, preventing header-flooding `DoS` attacks that exhaust memory.
5
6use axum::{
7    body::Body,
8    extract::Request,
9    http::StatusCode,
10    middleware::Next,
11    response::{IntoResponse, Response},
12};
13use tracing::warn;
14
15/// Middleware that enforces header count and total header byte size limits.
16///
17/// Returns 431 Request Header Fields Too Large when either limit is exceeded.
18pub async fn header_limits_middleware(
19    request: Request<Body>,
20    next: Next,
21    max_header_count: usize,
22    max_header_bytes: usize,
23) -> Response {
24    let headers = request.headers();
25    let header_count = headers.len();
26
27    if header_count > max_header_count {
28        warn!(header_count, max_header_count, "Request rejected: too many headers");
29        return (StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, "Too many request headers")
30            .into_response();
31    }
32
33    let total_bytes: usize =
34        headers.iter().map(|(name, value)| name.as_str().len() + value.len()).sum();
35
36    if total_bytes > max_header_bytes {
37        warn!(total_bytes, max_header_bytes, "Request rejected: headers too large");
38        return (StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, "Request headers too large")
39            .into_response();
40    }
41
42    next.run(request).await
43}
44
#[cfg(test)]
mod tests {
    use axum::{Router, body::Body, middleware, routing::get};
    use http::Request;
    use tower::ServiceExt;

    use super::*;

    /// Minimal handler; reaching it means the middleware let the request through.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Builds a router serving `GET /` behind the header-limit middleware
    /// configured with the given count and byte limits.
    fn test_app(max_count: usize, max_bytes: usize) -> Router {
        let routes = Router::new().route("/", get(ok_handler));
        routes.layer(middleware::from_fn(move |req, next| {
            header_limits_middleware(req, next, max_count, max_bytes)
        }))
    }

    /// Drives `req` through `app` exactly once and returns the response status.
    async fn status_of(app: Router, req: Request<Body>) -> StatusCode {
        let response = app.oneshot(req).await.expect("Reason: oneshot should not fail in test");
        response.status()
    }

    #[tokio::test]
    async fn accepts_request_within_limits() {
        // One small header, comfortably below both limits.
        let request = Request::builder()
            .uri("/")
            .header("x-test", "value")
            .body(Body::empty())
            .expect("Reason: test request builder should not fail");

        let status = status_of(test_app(10, 4096), request).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn rejects_too_many_headers() {
        // Ten headers against a limit of three; the byte limit is generous so
        // only the count check can trigger.
        let request = (0..10)
            .fold(Request::builder().uri("/"), |builder, i| {
                builder.header(format!("x-test-{i}"), "value")
            })
            .body(Body::empty())
            .expect("Reason: test request builder should not fail");

        let status = status_of(test_app(3, 65_536), request).await;
        assert_eq!(status, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE);
    }

    #[tokio::test]
    async fn rejects_headers_too_large() {
        // A single 200-byte value against a 64-byte total limit; the count
        // limit is generous so only the size check can trigger.
        let request = Request::builder()
            .uri("/")
            .header("x-large", "a]".repeat(100))
            .body(Body::empty())
            .expect("Reason: test request builder should not fail");

        let status = status_of(test_app(100, 64), request).await;
        assert_eq!(status, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE);
    }

    #[tokio::test]
    async fn accepts_at_exact_count_limit() {
        // Exactly five headers against a count limit of five: the middleware
        // rejects only when the count exceeds the limit, so this must pass.
        let request = (0..5)
            .fold(Request::builder().uri("/"), |builder, i| {
                builder.header(format!("x-h-{i}"), "v")
            })
            .body(Body::empty())
            .expect("Reason: test request builder should not fail");

        let status = status_of(test_app(5, 65_536), request).await;
        assert_eq!(status, StatusCode::OK);
    }
}