datasynth_server/rest/
request_validation.rs1use axum::{
6 body::Body,
7 http::{header, Method, Request, StatusCode},
8 middleware::Next,
9 response::{IntoResponse, Response},
10};
11
12pub async fn request_validation_middleware(request: Request<Body>, next: Next) -> Response {
17 let method = request.method().clone();
18
19 if matches!(method, Method::POST | Method::PUT | Method::PATCH) {
21 let has_body = request
23 .headers()
24 .get(header::CONTENT_LENGTH)
25 .and_then(|v| v.to_str().ok())
26 .and_then(|v| v.parse::<u64>().ok())
27 .unwrap_or(0)
28 > 0;
29
30 if has_body {
31 let content_type = request
32 .headers()
33 .get(header::CONTENT_TYPE)
34 .and_then(|v| v.to_str().ok())
35 .unwrap_or("");
36
37 if !content_type.starts_with("application/json") {
38 return (
39 StatusCode::UNSUPPORTED_MEDIA_TYPE,
40 "Content-Type must be application/json",
41 )
42 .into_response();
43 }
44 }
45 }
46
47 next.run(request).await
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{routing::post, Router};
    use tower::ServiceExt;

    /// Minimal handler; reaching it means the middleware let the request through.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Builds a router with a single `POST /test` route wrapped in the
    /// validation middleware.
    fn test_router() -> Router {
        let validation = axum::middleware::from_fn(request_validation_middleware);
        Router::new()
            .route("/test", post(ok_handler))
            .layer(validation)
    }

    /// Sends one `POST /test` through a fresh router and returns the status.
    ///
    /// `payload` is `Some((content_type, body))` for a request that carries a
    /// body (its `Content-Length` is set to the body's byte length), or `None`
    /// for a request with an empty body and neither header.
    async fn post_status(payload: Option<(&str, &'static str)>) -> StatusCode {
        let builder = Request::builder().method(Method::POST).uri("/test");

        let request = match payload {
            Some((content_type, body)) => builder
                .header(header::CONTENT_TYPE, content_type)
                .header(header::CONTENT_LENGTH, body.len().to_string())
                .body(Body::from(body)),
            None => builder.body(Body::empty()),
        }
        .unwrap();

        test_router().oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_post_with_json_content_type() {
        let status = post_status(Some(("application/json", "{}"))).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_post_with_wrong_content_type() {
        let status = post_status(Some(("text/plain", "hello"))).await;
        assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn test_post_without_body_passes() {
        let status = post_status(None).await;
        assert_eq!(status, StatusCode::OK);
    }
}