Skip to main content

allowthem_server/
csrf.rs

1use axum::{
2    body::Body,
3    extract::FromRequestParts,
4    http::{Request, StatusCode, header, request::Parts},
5    middleware::Next,
6    response::Response,
7};
8use subtle::ConstantTimeEq;
9use uuid::Uuid;
10
11use allowthem_core::{AllowThem, derive_csrf_token, verify_csrf_token};
12
/// Name of the double-submit cookie that carries the CSRF token for requests
/// made before a session exists (e.g. login and registration forms).
const PRE_AUTH_CSRF_COOKIE: &str = "csrf_pre";
14
/// The CSRF token that applies to the request currently being handled.
///
/// `csrf_middleware` stores it in the request extensions, and handlers receive it
/// through the `FromRequestParts` extractor implemented below. Render it as a
/// hidden form field named `csrf_token`, or send it in the `X-CSRF-Token` header
/// from AJAX/HTMX clients.
#[derive(Clone)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Borrows the token text without copying it.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
28
29impl<S: Send + Sync> FromRequestParts<S> for CsrfToken {
30    type Rejection = StatusCode;
31
32    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
33        parts
34            .extensions
35            .get::<CsrfToken>()
36            .cloned()
37            .ok_or(StatusCode::INTERNAL_SERVER_ERROR)
38    }
39}
40
41/// CSRF protection middleware using session-bound HMAC derivation.
42///
43/// **Authenticated requests (session cookie present):**
44/// The CSRF token is `HMAC-SHA256(csrf_key, session_token_bytes)`. No DB read
45/// needed — the token is derived from the cookie value already in memory.
46/// The derived token is stable for the session lifetime (SPA/HTMX friendly).
47///
48/// **Pre-auth requests (no session cookie, e.g. login/register forms):**
49/// Falls back to a double-submit cookie pattern using a `csrf_pre` cookie.
50/// A random UUID is generated on GET and stored in `csrf_pre`; POST must
51/// echo it back via `X-CSRF-Token` header or `csrf_token` form field.
52///
53/// Returns 403 on CSRF mismatch and 500 if `csrf_key` is not configured.
54pub async fn csrf_middleware(
55    mut request: Request<Body>,
56    next: Next,
57) -> Result<Response, StatusCode> {
58    let ath = request
59        .extensions()
60        .get::<AllowThem>()
61        .cloned()
62        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
63
64    let csrf_key = ath
65        .csrf_key()
66        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
67
68    let method = request.method().clone();
69    let is_safe = matches!(
70        method,
71        axum::http::Method::GET | axum::http::Method::HEAD | axum::http::Method::OPTIONS
72    );
73
74    let session_token = ath.parse_session_cookie(
75        request
76            .headers()
77            .get(header::COOKIE)
78            .and_then(|v| v.to_str().ok())
79            .unwrap_or(""),
80    );
81
82    if is_safe {
83        let csrf_token = match &session_token {
84            Some(tok) => derive_csrf_token(tok, csrf_key),
85            None => extract_pre_auth_csrf_cookie(request.headers())
86                .unwrap_or_else(|| Uuid::new_v4().to_string()),
87        };
88
89        let is_new_pre_auth =
90            session_token.is_none() && extract_pre_auth_csrf_cookie(request.headers()).is_none();
91
92        request
93            .extensions_mut()
94            .insert(CsrfToken(csrf_token.clone()));
95
96        let mut response = next.run(request).await;
97
98        if is_new_pre_auth {
99            let secure = ath.session_config().secure;
100            set_pre_auth_csrf_cookie(&mut response, &csrf_token, secure);
101        }
102
103        Ok(response)
104    } else {
105        let submitted = extract_submitted_token(&mut request).await?;
106
107        match &session_token {
108            Some(tok) => {
109                if !verify_csrf_token(tok, csrf_key, &submitted) {
110                    return Err(StatusCode::FORBIDDEN);
111                }
112                request.extensions_mut().insert(CsrfToken(submitted));
113            }
114            None => {
115                let cookie_val =
116                    extract_pre_auth_csrf_cookie(request.headers()).ok_or(StatusCode::FORBIDDEN)?;
117                if cookie_val.len() != submitted.len() {
118                    return Err(StatusCode::FORBIDDEN);
119                }
120                let matches: bool = cookie_val.as_bytes().ct_eq(submitted.as_bytes()).into();
121                if !matches {
122                    return Err(StatusCode::FORBIDDEN);
123                }
124                request.extensions_mut().insert(CsrfToken(submitted));
125            }
126        }
127
128        Ok(next.run(request).await)
129    }
130}
131
132fn extract_pre_auth_csrf_cookie(headers: &header::HeaderMap) -> Option<String> {
133    let cookie_header = headers.get(header::COOKIE)?.to_str().ok()?;
134    for pair in cookie_header.split("; ") {
135        if let Some((name, value)) = pair.split_once('=')
136            && name.trim() == PRE_AUTH_CSRF_COOKIE
137        {
138            return Some(value.trim().to_string());
139        }
140    }
141    None
142}
143
144fn set_pre_auth_csrf_cookie(response: &mut Response, token: &str, secure: bool) {
145    let mut cookie = format!(
146        "{}={}; SameSite=Lax; Path=/; Max-Age=1800",
147        PRE_AUTH_CSRF_COOKIE, token
148    );
149    if secure {
150        cookie.push_str("; Secure");
151    }
152    if let Ok(value) = cookie.parse() {
153        response.headers_mut().append(header::SET_COOKIE, value);
154    }
155}
156
157/// Extract the submitted CSRF token from `X-CSRF-Token` header or
158/// `csrf_token` field in an `application/x-www-form-urlencoded` body.
159///
160/// Consumes and replaces the request body so the handler still receives it.
161async fn extract_submitted_token(request: &mut Request<Body>) -> Result<String, StatusCode> {
162    if let Some(header_val) = request.headers().get("x-csrf-token")
163        && let Ok(token) = header_val.to_str()
164    {
165        return Ok(token.to_string());
166    }
167
168    let is_form = request
169        .headers()
170        .get(header::CONTENT_TYPE)
171        .and_then(|v| v.to_str().ok())
172        .map(|ct| ct.starts_with("application/x-www-form-urlencoded"))
173        .unwrap_or(false);
174
175    if !is_form {
176        return Err(StatusCode::FORBIDDEN);
177    }
178
179    let body = std::mem::replace(request.body_mut(), Body::empty());
180    let bytes = axum::body::to_bytes(body, 64 * 1024)
181        .await
182        .map_err(|_| StatusCode::BAD_REQUEST)?;
183
184    *request.body_mut() = Body::from(bytes.clone());
185
186    let body_str = std::str::from_utf8(&bytes).map_err(|_| StatusCode::BAD_REQUEST)?;
187    for pair in body_str.split('&') {
188        if let Some((key, value)) = pair.split_once('=')
189            && key == "csrf_token"
190        {
191            return Ok(value.to_string());
192        }
193    }
194
195    Err(StatusCode::FORBIDDEN)
196}
197
#[cfg(test)]
mod tests {
    // These tests drive `csrf_middleware` through a real axum `Router`, with the
    // `AllowThem` handle injected by `crate::cors::inject_ath_into_extensions`.
    use super::*;
    use allowthem_core::{AllowThemBuilder, Email, generate_token, hash_token};
    use axum::{
        Router, middleware,
        routing::{get, post},
    };
    use chrono::{Duration, Utc};
    use tower::ServiceExt;

    const TEST_CSRF_KEY: [u8; 32] = *b"test-csrf-key-32bytes-padding!!!";

    async fn ok_handler() -> StatusCode {
        StatusCode::OK
    }

    /// Echoes the request body so tests can verify the middleware restored it.
    async fn echo_handler(body: String) -> String {
        body
    }

    async fn build_ath() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .csrf_key(TEST_CSRF_KEY)
            .build()
            .await
            .unwrap()
    }

    /// Wraps `router` with `csrf_middleware` (inner) and the layer that injects
    /// `ath` into request extensions (outer, so it runs first).
    fn with_csrf(router: Router, ath: AllowThem) -> Router {
        router
            .layer(middleware::from_fn(csrf_middleware))
            .layer(middleware::from_fn_with_state(
                ath,
                crate::cors::inject_ath_into_extensions,
            ))
    }

    fn test_app(ath: AllowThem) -> Router {
        with_csrf(
            Router::new().route("/", get(ok_handler).post(ok_handler)),
            ath,
        )
    }

    fn get_set_cookie(response: &Response) -> Option<String> {
        response
            .headers()
            .get(header::SET_COOKIE)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string())
    }

    fn extract_token_from_set_cookie(set_cookie: &str) -> String {
        set_cookie
            .split(';')
            .next()
            .and_then(|pair| pair.split_once('='))
            .map(|(_, v)| v.trim().to_string())
            .expect("csrf token not found in Set-Cookie")
    }

    // --- Pre-auth path (no session cookie) ---

    #[tokio::test]
    async fn pre_auth_get_sets_csrf_pre_cookie() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let set_cookie = get_set_cookie(&response).expect("Set-Cookie header missing");
        assert!(set_cookie.starts_with("csrf_pre="));
        assert!(set_cookie.contains("SameSite=Lax"));
        assert!(set_cookie.contains("Max-Age=1800"));
        assert!(!set_cookie.contains("Secure"));
    }

    #[tokio::test]
    async fn pre_auth_get_does_not_reset_existing_csrf_pre_cookie() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header(header::COOKIE, "csrf_pre=existing_value")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(get_set_cookie(&response).is_none());
    }

    #[tokio::test]
    async fn pre_auth_post_accepts_matching_cookie_and_header() {
        let app = test_app(build_ath().await);
        let get_resp = app
            .clone()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        let set_cookie = get_set_cookie(&get_resp).expect("Set-Cookie missing");
        let token = extract_token_from_set_cookie(&set_cookie);
        let post_resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, format!("csrf_pre={token}"))
                    .header("x-csrf-token", &token)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(post_resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn pre_auth_post_rejects_mismatched_token() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_pre=correct")
                    .header("x-csrf-token", "wrong")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn pre_auth_post_rejects_missing_cookie() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("x-csrf-token", "sometoken")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn pre_auth_post_accepts_form_token() {
        let app = test_app(build_ath().await);
        let get_resp = app
            .clone()
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        let set_cookie = get_set_cookie(&get_resp).expect("Set-Cookie missing");
        let token = extract_token_from_set_cookie(&set_cookie);
        let body = format!("username=alice&csrf_token={token}");
        let post_resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, format!("csrf_pre={token}"))
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(post_resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn pre_auth_post_rejects_non_form_body_without_header() {
        // A token inside a JSON body is not looked at; only the header or a
        // urlencoded form field count.
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_pre=abc")
                    .header(header::CONTENT_TYPE, "application/json")
                    .body(Body::from(r#"{"csrf_token":"abc"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn form_body_is_forwarded_to_handler() {
        let ath = build_ath().await;
        let app = with_csrf(Router::new().route("/", post(echo_handler)), ath);
        let body = "username=alice&csrf_token=tok123";
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, "csrf_pre=tok123")
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let echoed = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&echoed[..], body.as_bytes());
    }

    // --- Session-bound path ---

    async fn make_session_cookie(ath: &AllowThem) -> (String, String) {
        let email = Email::new("user@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password", None, None)
            .await
            .unwrap();
        let token = generate_token();
        let hash = hash_token(&token);
        let expires = Utc::now() + Duration::hours(24);
        ath.db()
            .create_session(user.id, hash, None, None, expires)
            .await
            .unwrap();
        let cookie_header = ath.session_cookie(&token);
        let cookie_value = cookie_header.split(';').next().unwrap().to_string();
        let csrf = derive_csrf_token(&token, &TEST_CSRF_KEY);
        (cookie_value, csrf)
    }

    #[tokio::test]
    async fn session_bound_get_does_not_set_csrf_pre_cookie() {
        let ath = build_ath().await;
        let (session_cookie, _) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(get_set_cookie(&response).is_none());
    }

    #[tokio::test]
    async fn session_bound_post_accepts_derived_token_in_header() {
        let ath = build_ath().await;
        let (session_cookie, csrf) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .header("x-csrf-token", &csrf)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn session_bound_post_rejects_wrong_token() {
        let ath = build_ath().await;
        let (session_cookie, _) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .header(
                        "x-csrf-token",
                        "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
                    )
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn session_bound_post_rejects_missing_token() {
        let ath = build_ath().await;
        let (session_cookie, _) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn session_bound_post_ignores_pre_auth_cookie() {
        // With a session present, only the session-derived token is valid; a
        // matching `csrf_pre` double-submit pair must not be accepted instead.
        let ath = build_ath().await;
        let (session_cookie, _) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, format!("{session_cookie}; csrf_pre=abc"))
                    .header("x-csrf-token", "abc")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn session_bound_post_accepts_form_token() {
        let ath = build_ath().await;
        let (session_cookie, csrf) = make_session_cookie(&ath).await;
        let app = test_app(ath);
        let body = format!("field=value&csrf_token={csrf}");
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header(header::COOKIE, &session_cookie)
                    .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
                    .body(Body::from(body))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    // --- Configuration and method handling ---

    #[tokio::test]
    async fn returns_500_when_csrf_key_not_configured() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap();
        let app = test_app(ath);

        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn head_does_not_require_csrf() {
        let app = test_app(build_ath().await);
        let response = app
            .oneshot(
                Request::builder()
                    .method("HEAD")
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}