Skip to main content

datasynth_server/rest/
request_validation.rs

//! Request validation middleware.
//!
//! Enforces a JSON `Content-Type` on mutation requests (POST/PUT/PATCH).
4
5use axum::{
6    body::Body,
7    http::{header, Method, Request, StatusCode},
8    middleware::Next,
9    response::{IntoResponse, Response},
10};
11
12/// Request validation middleware.
13///
14/// - POST, PUT, PATCH requests must include `Content-Type: application/json`
15/// - GET, DELETE, OPTIONS, HEAD bypass this check
16pub async fn request_validation_middleware(request: Request<Body>, next: Next) -> Response {
17    let method = request.method().clone();
18
19    // Only validate mutation methods
20    if matches!(method, Method::POST | Method::PUT | Method::PATCH) {
21        // Allow empty bodies (some POST endpoints don't need a body)
22        let has_body = request
23            .headers()
24            .get(header::CONTENT_LENGTH)
25            .and_then(|v| v.to_str().ok())
26            .and_then(|v| v.parse::<u64>().ok())
27            .unwrap_or(0)
28            > 0;
29
30        if has_body {
31            let content_type = request
32                .headers()
33                .get(header::CONTENT_TYPE)
34                .and_then(|v| v.to_str().ok())
35                .unwrap_or("");
36
37            if !content_type.starts_with("application/json") {
38                return (
39                    StatusCode::UNSUPPORTED_MEDIA_TYPE,
40                    "Content-Type must be application/json",
41                )
42                    .into_response();
43            }
44        }
45    }
46
47    next.run(request).await
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{routing::post, Router};
    use tower::ServiceExt;

    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Router with a single POST route wrapped in the validation middleware.
    fn test_router() -> Router {
        Router::new()
            .route("/test", post(ok_handler))
            .layer(axum::middleware::from_fn(request_validation_middleware))
    }

    /// Builds a POST to `/test`, setting only the headers that are supplied.
    fn post_request(
        content_type: Option<&str>,
        content_length: Option<&str>,
        body: Body,
    ) -> Request<Body> {
        let mut builder = Request::builder().method(Method::POST).uri("/test");
        if let Some(ct) = content_type {
            builder = builder.header(header::CONTENT_TYPE, ct);
        }
        if let Some(len) = content_length {
            builder = builder.header(header::CONTENT_LENGTH, len);
        }
        builder.body(body).unwrap()
    }

    /// Sends `request` through a fresh test router and returns the status.
    async fn status_of(request: Request<Body>) -> StatusCode {
        test_router().oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_post_with_json_content_type() {
        let request = post_request(Some("application/json"), Some("2"), Body::from("{}"));
        assert_eq!(status_of(request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_post_with_wrong_content_type() {
        let request = post_request(Some("text/plain"), Some("5"), Body::from("hello"));
        assert_eq!(status_of(request).await, StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn test_post_without_body_passes() {
        let request = post_request(None, None, Body::empty());
        assert_eq!(status_of(request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_post_with_json_charset_parameter() {
        let request = post_request(
            Some("application/json; charset=utf-8"),
            Some("2"),
            Body::from("{}"),
        );
        assert_eq!(status_of(request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_post_with_zero_length_and_wrong_content_type_passes() {
        // An explicitly empty body needs no JSON Content-Type.
        let request = post_request(Some("text/plain"), Some("0"), Body::empty());
        assert_eq!(status_of(request).await, StatusCode::OK);
    }
}