// allowthem_server/csrf.rs

1use axum::{
2    body::Body,
3    extract::FromRequestParts,
4    http::{Request, StatusCode, header, request::Parts},
5    middleware::Next,
6    response::Response,
7};
8use uuid::Uuid;
9
/// Name of the cookie that carries the CSRF token in the double-submit pattern.
const CSRF_COOKIE_NAME: &str = "csrf_token";

/// A CSRF token for the current request.
///
/// Available to handlers via extractor after the `csrf_middleware` layer has run.
/// Embed this in forms as a hidden field named `csrf_token`, or send it as the
/// `X-CSRF-Token` header for AJAX requests.
#[derive(Clone, Debug)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Borrows the token as a string slice, e.g. for rendering into a form field.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
25
26impl<S: Send + Sync> FromRequestParts<S> for CsrfToken {
27    type Rejection = StatusCode;
28
29    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
30        parts
31            .extensions
32            .get::<CsrfToken>()
33            .cloned()
34            .ok_or(StatusCode::INTERNAL_SERVER_ERROR)
35    }
36}
37
38/// CSRF protection middleware using the double-submit cookie pattern.
39///
40/// **Safe methods (GET, HEAD, OPTIONS):** Reads or generates a CSRF token, sets it
41/// as a cookie on the response (not `HttpOnly` so JS/HTMX can read it), and inserts
42/// it into request extensions so handlers can embed it in forms via [`CsrfToken`].
43///
44/// **Unsafe methods (POST, PUT, DELETE, PATCH):** Requires the submitted token
45/// (from `X-CSRF-Token` header or `csrf_token` form field) to match the CSRF
46/// cookie. Returns 403 on mismatch or missing token.
47pub async fn csrf_middleware(
48    mut request: Request<Body>,
49    next: Next,
50) -> Result<Response, StatusCode> {
51    let method = request.method().clone();
52    let is_safe = matches!(
53        method,
54        axum::http::Method::GET | axum::http::Method::HEAD | axum::http::Method::OPTIONS
55    );
56
57    let cookie_token = extract_csrf_cookie(request.headers());
58
59    if is_safe {
60        let is_new = cookie_token.is_none();
61        let token = cookie_token.unwrap_or_else(|| Uuid::new_v4().to_string());
62
63        request.extensions_mut().insert(CsrfToken(token.clone()));
64
65        let mut response = next.run(request).await;
66
67        if is_new {
68            let cookie = format!("{}={}; SameSite=Lax; Path=/", CSRF_COOKIE_NAME, token);
69            if let Ok(value) = cookie.parse() {
70                response.headers_mut().append(header::SET_COOKIE, value);
71            }
72        }
73
74        Ok(response)
75    } else {
76        let submitted = extract_submitted_token(&mut request).await?;
77
78        let cookie_val = cookie_token.ok_or(StatusCode::FORBIDDEN)?;
79
80        if submitted != cookie_val {
81            return Err(StatusCode::FORBIDDEN);
82        }
83
84        request.extensions_mut().insert(CsrfToken(cookie_val));
85
86        Ok(next.run(request).await)
87    }
88}
89
90/// Extract the CSRF token from the `csrf_token` cookie in the `Cookie` header.
91fn extract_csrf_cookie(headers: &header::HeaderMap) -> Option<String> {
92    let cookie_header = headers.get(header::COOKIE)?.to_str().ok()?;
93    for pair in cookie_header.split("; ") {
94        if let Some((name, value)) = pair.split_once('=')
95            && name.trim() == CSRF_COOKIE_NAME
96        {
97            return Some(value.trim().to_string());
98        }
99    }
100    None
101}
102
103/// Extract the submitted CSRF token from either the `X-CSRF-Token` header or
104/// the `csrf_token` field in a `application/x-www-form-urlencoded` body.
105///
106/// Consumes and then replaces the request body so the handler still receives it.
107async fn extract_submitted_token(request: &mut Request<Body>) -> Result<String, StatusCode> {
108    // Check header first — preferred for AJAX/HTMX.
109    if let Some(header_val) = request.headers().get("x-csrf-token")
110        && let Ok(token) = header_val.to_str()
111    {
112        return Ok(token.to_string());
113    }
114
115    // Fall back to form body for traditional form submissions.
116    let is_form = request
117        .headers()
118        .get(header::CONTENT_TYPE)
119        .and_then(|v| v.to_str().ok())
120        .map(|ct| ct.starts_with("application/x-www-form-urlencoded"))
121        .unwrap_or(false);
122
123    if !is_form {
124        return Err(StatusCode::FORBIDDEN);
125    }
126
127    // Consume the body to search for the token.
128    let body = std::mem::replace(request.body_mut(), Body::empty());
129    let bytes = axum::body::to_bytes(body, 64 * 1024)
130        .await
131        .map_err(|_| StatusCode::BAD_REQUEST)?;
132
133    // Put the body back so the handler can read it.
134    *request.body_mut() = Body::from(bytes.clone());
135
136    // Parse without serde_urlencoded: find csrf_token=<value> pair.
137    let body_str = std::str::from_utf8(&bytes).map_err(|_| StatusCode::BAD_REQUEST)?;
138    for pair in body_str.split('&') {
139        if let Some((key, value)) = pair.split_once('=')
140            && key == "csrf_token"
141        {
142            return Ok(value.to_string());
143        }
144    }
145
146    Err(StatusCode::FORBIDDEN)
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        Router, middleware,
        routing::{any, get, post},
    };
    use tower::ServiceExt;

    async fn ok_handler() -> StatusCode {
        StatusCode::OK
    }

    /// Echoes the raw request body, proving the middleware restored it.
    async fn echo_body(body: String) -> String {
        body
    }

    /// Returns the token the middleware exposed through the [`CsrfToken`] extractor.
    async fn echo_token(token: CsrfToken) -> String {
        token.as_str().to_owned()
    }

    fn test_app() -> Router {
        Router::new()
            .route("/", get(ok_handler).post(ok_handler))
            .layer(middleware::from_fn(csrf_middleware))
    }

    fn get_set_cookie(response: &Response) -> Option<String> {
        response
            .headers()
            .get(header::SET_COOKIE)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string())
    }

    fn extract_token_from_set_cookie(set_cookie: &str) -> String {
        // Format: "csrf_token=<value>; SameSite=Lax; Path=/"
        set_cookie
            .split(';')
            .next()
            .and_then(|pair| pair.split_once('='))
            .map(|(_, v)| v.trim().to_string())
            .expect("csrf token not found in Set-Cookie")
    }

    /// Reads a response body to completion as UTF-8 text.
    async fn body_string(response: Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("failed to read response body");
        String::from_utf8(bytes.to_vec()).expect("response body is not UTF-8")
    }

    #[tokio::test]
    async fn get_sets_csrf_cookie() {
        let app = test_app();
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let set_cookie = get_set_cookie(&response).expect("Set-Cookie header missing");
        assert!(set_cookie.starts_with("csrf_token="));
        assert!(set_cookie.contains("SameSite=Lax"));
    }

    #[tokio::test]
    async fn head_does_not_require_csrf() {
        let app = Router::new()
            .route("/", any(ok_handler))
            .layer(middleware::from_fn(csrf_middleware));

        let response = app
            .oneshot(
                Request::builder()
                    .method("HEAD")
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn post_with_valid_header_token_passes() {
        let app = test_app();

        // First GET to obtain a token.
        let get_resp = app
            .clone()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        let set_cookie = get_set_cookie(&get_resp).expect("Set-Cookie missing");
        let token = extract_token_from_set_cookie(&set_cookie);

        // POST with the token in the header and the cookie set.
        let post_resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, format!("csrf_token={token}"))
                    .header("x-csrf-token", &token)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(post_resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn post_with_valid_form_token_passes() {
        let app = test_app();

        let get_resp = app
            .clone()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        let set_cookie = get_set_cookie(&get_resp).expect("Set-Cookie missing");
        let token = extract_token_from_set_cookie(&set_cookie);

        let body = format!("username=alice&csrf_token={token}");
        let post_resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, format!("csrf_token={token}"))
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(post_resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn form_body_is_still_readable_by_handler() {
        let app = Router::new()
            .route("/echo", post(echo_body))
            .layer(middleware::from_fn(csrf_middleware));

        let body = "username=alice&csrf_token=tok123";
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/echo")
                    .header(header::COOKIE, "csrf_token=tok123")
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        // The middleware consumed the body to find the token; the handler must
        // still see it unchanged.
        assert_eq!(body_string(response).await, body);
    }

    #[tokio::test]
    async fn post_with_missing_token_returns_403() {
        let app = test_app();

        // POST with a cookie but no submitted token.
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_token=someval")
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from("username=alice"))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_wrong_token_returns_403() {
        let app = test_app();

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_token=correct")
                    .header("x-csrf-token", "wrong")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_missing_cookie_returns_403() {
        let app = test_app();

        // Token in header but no cookie.
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("x-csrf-token", "sometoken")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn other_unsafe_methods_require_token() {
        for method in ["PUT", "PATCH", "DELETE"] {
            let app = Router::new()
                .route("/", any(ok_handler))
                .layer(middleware::from_fn(csrf_middleware));

            let response = app
                .oneshot(
                    Request::builder()
                        .method(method)
                        .uri("/")
                        .header(header::COOKIE, "csrf_token=someval")
                        .body(Body::empty())
                        .unwrap(),
                )
                .await
                .unwrap();

            assert_eq!(
                response.status(),
                StatusCode::FORBIDDEN,
                "{method} must be CSRF-checked"
            );
        }
    }

    #[tokio::test]
    async fn existing_cookie_not_overwritten_on_get() {
        let app = test_app();

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header(header::COOKIE, "csrf_token=existing_token")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        // No new Set-Cookie should be issued since the cookie already exists.
        assert!(get_set_cookie(&response).is_none());
    }

    #[tokio::test]
    async fn extractor_sees_existing_cookie_token() {
        let app = Router::new()
            .route("/", get(echo_token))
            .layer(middleware::from_fn(csrf_middleware));

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header(header::COOKIE, "csrf_token=existing_token")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_string(response).await, "existing_token");
    }

    #[tokio::test]
    async fn extractor_sees_freshly_minted_token() {
        let app = Router::new()
            .route("/", get(echo_token))
            .layer(middleware::from_fn(csrf_middleware));

        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let set_cookie = get_set_cookie(&response).expect("Set-Cookie header missing");
        let cookie_token = extract_token_from_set_cookie(&set_cookie);
        // The token handed to the handler must be the one the browser will echo back.
        assert_eq!(body_string(response).await, cookie_token);
    }

    #[tokio::test]
    async fn extractor_without_middleware_returns_500() {
        let app = Router::new().route("/", get(echo_token));

        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}