Skip to main content

datasynth_server/rest/
request_validation.rs

//! Request validation middleware.
//!
//! Enforces a JSON `Content-Type` on mutation requests (POST/PUT/PATCH) that carry a body.
4
5use axum::{
6    body::Body,
7    http::{header, Method, Request, StatusCode},
8    middleware::Next,
9    response::{IntoResponse, Response},
10};
11
12/// Request validation middleware.
13///
14/// - POST, PUT, PATCH requests must include `Content-Type: application/json`
15/// - GET, DELETE, OPTIONS, HEAD bypass this check
16pub async fn request_validation_middleware(request: Request<Body>, next: Next) -> Response {
17    let method = request.method().clone();
18
19    // Only validate mutation methods
20    if matches!(method, Method::POST | Method::PUT | Method::PATCH) {
21        // Allow empty bodies (some POST endpoints don't need a body)
22        let has_body = request
23            .headers()
24            .get(header::CONTENT_LENGTH)
25            .and_then(|v| v.to_str().ok())
26            .and_then(|v| v.parse::<u64>().ok())
27            .unwrap_or(0)
28            > 0;
29
30        if has_body {
31            let content_type = request
32                .headers()
33                .get(header::CONTENT_TYPE)
34                .and_then(|v| v.to_str().ok())
35                .unwrap_or("");
36
37            if !content_type.starts_with("application/json") {
38                return (
39                    StatusCode::UNSUPPORTED_MEDIA_TYPE,
40                    "Content-Type must be application/json",
41                )
42                    .into_response();
43            }
44        }
45    }
46
47    next.run(request).await
48}
49
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{routing::post, Router};
    use tower::ServiceExt;

    /// Trivial handler: reaching it means the middleware let the request through.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// A `POST /test` route wrapped in the validation middleware under test.
    fn test_router() -> Router {
        let app = Router::new().route("/test", post(ok_handler));
        app.layer(axum::middleware::from_fn(request_validation_middleware))
    }

    /// Builds a `POST /test` carrying `payload`, with the given `Content-Type`
    /// and a `Content-Length` equal to the payload's byte length.
    fn post_with_body(content_type: &str, payload: &'static str) -> Request<Body> {
        Request::builder()
            .method(Method::POST)
            .uri("/test")
            .header(header::CONTENT_TYPE, content_type)
            .header(header::CONTENT_LENGTH, payload.len().to_string())
            .body(Body::from(payload))
            .unwrap()
    }

    /// Sends `request` through a fresh test router and reports the response status.
    async fn status_of(request: Request<Body>) -> StatusCode {
        let response = test_router().oneshot(request).await.unwrap();
        response.status()
    }

    #[tokio::test]
    async fn test_post_with_json_content_type() {
        let status = status_of(post_with_body("application/json", "{}")).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_post_with_wrong_content_type() {
        let status = status_of(post_with_body("text/plain", "hello")).await;
        assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn test_post_without_body_passes() {
        // Neither Content-Type nor Content-Length is sent: a bodiless POST.
        let bare = Request::builder()
            .method(Method::POST)
            .uri("/test")
            .body(Body::empty())
            .unwrap();
        assert_eq!(status_of(bare).await, StatusCode::OK);
    }
}