Skip to main content

fraiseql_server/middleware/
content_type.rs

1//! CSRF protection via Content-Type enforcement.
2//!
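//! Rejects POST requests that do not carry `Content-Type: application/json`.
//! Browsers let cross-origin HTML forms submit `text/plain`,
//! `application/x-www-form-urlencoded`, and `multipart/form-data` bodies without a
//! CORS preflight; requiring JSON prevents cross-site request forgery via such submissions.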
6
7use axum::{
8    body::Body,
9    http::{Method, Request, StatusCode, header::CONTENT_TYPE},
10    middleware::Next,
11    response::{IntoResponse, Response},
12};
13
14/// Middleware that rejects POST requests without a JSON Content-Type.
15///
16/// Non-POST methods pass through unconditionally.
17/// POST requests must have `Content-Type` starting with `application/json`
18/// (e.g. `application/json` or `application/json; charset=utf-8`).
19///
20/// # Errors
21///
22/// Returns a `415 Unsupported Media Type` response if the POST request does not carry a JSON
23/// `Content-Type`.
24pub async fn require_json_content_type(
25    req: Request<Body>,
26    next: Next,
27) -> Result<Response, Response> {
28    if req.method() != Method::POST {
29        return Ok(next.run(req).await);
30    }
31
32    let content_type = req.headers().get(CONTENT_TYPE).and_then(|v| v.to_str().ok()).unwrap_or("");
33
34    if !content_type.starts_with("application/json") {
35        let body = serde_json::json!({
36            "errors": [{
37                "message": "Content-Type must be application/json",
38                "extensions": { "code": "UNSUPPORTED_MEDIA_TYPE" }
39            }]
40        });
41        return Err((
42            StatusCode::UNSUPPORTED_MEDIA_TYPE,
43            [(CONTENT_TYPE, "application/json")],
44            serde_json::to_string(&body).unwrap_or_else(|_| {
45                r#"{"errors":[{"message":"Unsupported Media Type"}]}"#.to_owned()
46            }),
47        )
48            .into_response());
49    }
50
51    Ok(next.run(req).await)
52}
53
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics acceptable
    #![allow(clippy::cast_precision_loss)] // Reason: test metrics reporting
    #![allow(clippy::cast_sign_loss)] // Reason: test data uses small positive integers
    #![allow(clippy::cast_possible_truncation)] // Reason: test data values are bounded
    #![allow(clippy::cast_possible_wrap)] // Reason: test data values are bounded
    #![allow(clippy::missing_panics_doc)] // Reason: test helpers
    #![allow(clippy::missing_errors_doc)] // Reason: test helpers
    #![allow(missing_docs)] // Reason: test code
    #![allow(clippy::items_after_statements)] // Reason: test helpers defined near use site

    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode, header::CONTENT_TYPE},
        middleware,
        response::Response,
        routing::{get, post},
    };
    use tower::ServiceExt;

    use super::require_json_content_type;

    /// Minimal GraphQL request body shared by the JSON-shaped cases.
    const TYPENAME_QUERY: &str = r#"{"query":"{ __typename }"}"#;

    /// Trivial handler: getting a 200 from it means the middleware let the request through.
    async fn echo_handler() -> &'static str {
        "ok"
    }

    /// `POST /graphql` wrapped in the content-type middleware.
    fn app() -> Router {
        Router::new()
            .route("/graphql", post(echo_handler))
            .layer(middleware::from_fn(require_json_content_type))
    }

    /// Sends `body` as `POST /graphql`, adding a `Content-Type` header only when one is given.
    async fn send_post(content_type: Option<&str>, body: &'static str) -> Response {
        let mut builder = Request::post("/graphql");
        if let Some(value) = content_type {
            builder = builder.header(CONTENT_TYPE, value);
        }
        let request = builder.body(Body::from(body)).unwrap();
        app().oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn text_plain_rejected_with_415() {
        let res = send_post(Some("text/plain"), TYPENAME_QUERY).await;
        assert_eq!(res.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn form_urlencoded_rejected_with_415() {
        let res =
            send_post(Some("application/x-www-form-urlencoded"), "query=%7B+__typename+%7D").await;
        assert_eq!(res.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn application_json_passes() {
        let res = send_post(Some("application/json"), TYPENAME_QUERY).await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn application_json_with_charset_passes() {
        let res = send_post(Some("application/json; charset=utf-8"), TYPENAME_QUERY).await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn get_request_passes_without_content_type() {
        let router = Router::new()
            .route("/graphql", get(echo_handler))
            .layer(middleware::from_fn(require_json_content_type));

        let request = Request::get("/graphql").body(Body::empty()).unwrap();
        let res = router.oneshot(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn missing_content_type_rejected() {
        let res = send_post(None, TYPENAME_QUERY).await;
        assert_eq!(res.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }
}