//! Request validation middleware (`datasynth_server/rest/request_validation.rs`).

use axum::{
    body::Body,
    http::{header, HeaderMap, Method, Request, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};

12pub async fn request_validation_middleware(request: Request<Body>, next: Next) -> Response {
17 let method = request.method().clone();
18
19 if matches!(method, Method::POST | Method::PUT | Method::PATCH) {
21 let has_body = request
23 .headers()
24 .get(header::CONTENT_LENGTH)
25 .and_then(|v| v.to_str().ok())
26 .and_then(|v| v.parse::<u64>().ok())
27 .unwrap_or(0)
28 > 0;
29
30 if has_body {
31 let content_type = request
32 .headers()
33 .get(header::CONTENT_TYPE)
34 .and_then(|v| v.to_str().ok())
35 .unwrap_or("");
36
37 if !content_type.starts_with("application/json") {
38 return (
39 StatusCode::UNSUPPORTED_MEDIA_TYPE,
40 "Content-Type must be application/json",
41 )
42 .into_response();
43 }
44 }
45 }
46
47 next.run(request).await
48}
49
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Panicking on builder/service errors is the desired test failure mode.
mod tests {
    use super::*;
    use axum::{routing::post, Router};
    use tower::ServiceExt;

    /// Handler that answers `"ok"` whenever the middleware lets a request through.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Builds a router exposing `POST /test`, wrapped in the validation middleware.
    fn test_router() -> Router {
        let validation = axum::middleware::from_fn(request_validation_middleware);
        Router::new().route("/test", post(ok_handler)).layer(validation)
    }

    /// Sends `POST /test` with the given headers (in order) and body through a
    /// fresh router, and returns the resulting status code.
    async fn post_status(headers: &[(header::HeaderName, &str)], body: Body) -> StatusCode {
        let mut builder = Request::builder().method(Method::POST).uri("/test");
        for (name, value) in headers {
            builder = builder.header(name.clone(), *value);
        }
        let request = builder.body(body).unwrap();

        test_router().oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_post_with_json_content_type() {
        let status = post_status(
            &[
                (header::CONTENT_TYPE, "application/json"),
                (header::CONTENT_LENGTH, "2"),
            ],
            Body::from("{}"),
        )
        .await;

        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_post_with_wrong_content_type() {
        let status = post_status(
            &[
                (header::CONTENT_TYPE, "text/plain"),
                (header::CONTENT_LENGTH, "5"),
            ],
            Body::from("hello"),
        )
        .await;

        assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn test_post_without_body_passes() {
        let status = post_status(&[], Body::empty()).await;

        assert_eq!(status, StatusCode::OK);
    }
}