// allowthem_server/csrf.rs

1use axum::{
2    body::Body,
3    extract::{FromRequestParts, State},
4    http::{Request, StatusCode, header, request::Parts},
5    middleware::Next,
6    response::Response,
7};
8use subtle::ConstantTimeEq;
9use uuid::Uuid;
10
11use allowthem_core::{AllowThem, derive_csrf_token, verify_csrf_token};
12
/// Cookie name for the double-submit CSRF token issued to requests that
/// carry no session cookie yet (e.g. login and registration forms).
const PRE_AUTH_CSRF_COOKIE: &str = "csrf_pre";
14
/// A CSRF token for the current request.
///
/// Available to handlers via extractor after `csrf_middleware` has run.
/// Embed in forms as a hidden field named `csrf_token`, or send as
/// `X-CSRF-Token` header for AJAX/HTMX requests.
///
/// The `Debug` representation is redacted so the token value never ends up
/// in logs or panic messages.
#[derive(Clone)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Borrow the raw token string, e.g. for rendering into a template.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl std::fmt::Debug for CsrfToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The value is derived from the session cookie (or is the pre-auth
        // cookie value), so it must be treated as a secret.
        f.debug_tuple("CsrfToken").field(&"<redacted>").finish()
    }
}
28
29impl<S: Send + Sync> FromRequestParts<S> for CsrfToken {
30    type Rejection = StatusCode;
31
32    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
33        parts
34            .extensions
35            .get::<CsrfToken>()
36            .cloned()
37            .ok_or(StatusCode::INTERNAL_SERVER_ERROR)
38    }
39}
40
41/// CSRF protection middleware using session-bound HMAC derivation.
42///
43/// **Authenticated requests (session cookie present):**
44/// The CSRF token is `HMAC-SHA256(csrf_key, session_token_bytes)`. No DB read
45/// needed — the token is derived from the cookie value already in memory.
46/// The derived token is stable for the session lifetime (SPA/HTMX friendly).
47///
48/// **Pre-auth requests (no session cookie, e.g. login/register forms):**
49/// Falls back to a double-submit cookie pattern using a `csrf_pre` cookie.
50/// A random UUID is generated on GET and stored in `csrf_pre`; POST must
51/// echo it back via `X-CSRF-Token` header or `csrf_token` form field.
52///
53/// Returns 403 on CSRF mismatch and 500 if `csrf_key` is not configured.
54pub async fn csrf_middleware(
55    State(ath): State<AllowThem>,
56    mut request: Request<Body>,
57    next: Next,
58) -> Result<Response, StatusCode> {
59    let csrf_key = ath
60        .csrf_key()
61        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
62
63    let method = request.method().clone();
64    let is_safe = matches!(
65        method,
66        axum::http::Method::GET | axum::http::Method::HEAD | axum::http::Method::OPTIONS
67    );
68
69    let session_token = ath.parse_session_cookie(
70        request
71            .headers()
72            .get(header::COOKIE)
73            .and_then(|v| v.to_str().ok())
74            .unwrap_or(""),
75    );
76
77    if is_safe {
78        let csrf_token = match &session_token {
79            Some(tok) => derive_csrf_token(tok, csrf_key),
80            None => extract_pre_auth_csrf_cookie(request.headers())
81                .unwrap_or_else(|| Uuid::new_v4().to_string()),
82        };
83
84        let is_new_pre_auth =
85            session_token.is_none() && extract_pre_auth_csrf_cookie(request.headers()).is_none();
86
87        request
88            .extensions_mut()
89            .insert(CsrfToken(csrf_token.clone()));
90
91        let mut response = next.run(request).await;
92
93        if is_new_pre_auth {
94            let secure = ath.session_config().secure;
95            set_pre_auth_csrf_cookie(&mut response, &csrf_token, secure);
96        }
97
98        Ok(response)
99    } else {
100        let submitted = extract_submitted_token(&mut request).await?;
101
102        match &session_token {
103            Some(tok) => {
104                if !verify_csrf_token(tok, csrf_key, &submitted) {
105                    return Err(StatusCode::FORBIDDEN);
106                }
107                request.extensions_mut().insert(CsrfToken(submitted));
108            }
109            None => {
110                let cookie_val =
111                    extract_pre_auth_csrf_cookie(request.headers()).ok_or(StatusCode::FORBIDDEN)?;
112                if cookie_val.len() != submitted.len() {
113                    return Err(StatusCode::FORBIDDEN);
114                }
115                let matches: bool = cookie_val.as_bytes().ct_eq(submitted.as_bytes()).into();
116                if !matches {
117                    return Err(StatusCode::FORBIDDEN);
118                }
119                request.extensions_mut().insert(CsrfToken(submitted));
120            }
121        }
122
123        Ok(next.run(request).await)
124    }
125}
126
127fn extract_pre_auth_csrf_cookie(headers: &header::HeaderMap) -> Option<String> {
128    let cookie_header = headers.get(header::COOKIE)?.to_str().ok()?;
129    for pair in cookie_header.split("; ") {
130        if let Some((name, value)) = pair.split_once('=')
131            && name.trim() == PRE_AUTH_CSRF_COOKIE
132        {
133            return Some(value.trim().to_string());
134        }
135    }
136    None
137}
138
139fn set_pre_auth_csrf_cookie(response: &mut Response, token: &str, secure: bool) {
140    let mut cookie = format!(
141        "{}={}; SameSite=Lax; Path=/; Max-Age=1800",
142        PRE_AUTH_CSRF_COOKIE, token
143    );
144    if secure {
145        cookie.push_str("; Secure");
146    }
147    if let Ok(value) = cookie.parse() {
148        response.headers_mut().append(header::SET_COOKIE, value);
149    }
150}
151
152/// Extract the submitted CSRF token from `X-CSRF-Token` header or
153/// `csrf_token` field in an `application/x-www-form-urlencoded` body.
154///
155/// Consumes and replaces the request body so the handler still receives it.
156async fn extract_submitted_token(request: &mut Request<Body>) -> Result<String, StatusCode> {
157    if let Some(header_val) = request.headers().get("x-csrf-token")
158        && let Ok(token) = header_val.to_str()
159    {
160        return Ok(token.to_string());
161    }
162
163    let is_form = request
164        .headers()
165        .get(header::CONTENT_TYPE)
166        .and_then(|v| v.to_str().ok())
167        .map(|ct| ct.starts_with("application/x-www-form-urlencoded"))
168        .unwrap_or(false);
169
170    if !is_form {
171        return Err(StatusCode::FORBIDDEN);
172    }
173
174    let body = std::mem::replace(request.body_mut(), Body::empty());
175    let bytes = axum::body::to_bytes(body, 64 * 1024)
176        .await
177        .map_err(|_| StatusCode::BAD_REQUEST)?;
178
179    *request.body_mut() = Body::from(bytes.clone());
180
181    let body_str = std::str::from_utf8(&bytes).map_err(|_| StatusCode::BAD_REQUEST)?;
182    for pair in body_str.split('&') {
183        if let Some((key, value)) = pair.split_once('=')
184            && key == "csrf_token"
185        {
186            return Ok(value.to_string());
187        }
188    }
189
190    Err(StatusCode::FORBIDDEN)
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use allowthem_core::{AllowThemBuilder, Email, generate_token, hash_token};
    use axum::{Router, middleware, routing::get};
    use chrono::{Duration, Utc};
    use tower::ServiceExt;

    const TEST_CSRF_KEY: [u8; 32] = *b"test-csrf-key-32bytes-padding!!!";

    async fn ok_handler() -> StatusCode {
        StatusCode::OK
    }

    /// In-memory `AllowThem` with insecure cookies and the test CSRF key.
    async fn build_ath() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .csrf_key(TEST_CSRF_KEY)
            .build()
            .await
            .unwrap()
    }

    /// Router with a single `/` route (GET + POST) wrapped in `csrf_middleware`.
    fn test_app(ath: AllowThem) -> Router {
        let csrf_layer = middleware::from_fn_with_state(ath.clone(), csrf_middleware);
        Router::new()
            .route("/", get(ok_handler).post(ok_handler))
            .layer(csrf_layer)
            .with_state(ath)
    }

    /// Dispatch `request` through `app`; the router itself never errors.
    async fn send(app: Router, request: Request<Body>) -> Response {
        app.oneshot(request).await.unwrap()
    }

    /// Start building a request to `/` with the given HTTP method.
    fn to_root(method: &str) -> axum::http::request::Builder {
        Request::builder().method(method).uri("/")
    }

    /// First `Set-Cookie` header of `response`, if present and readable.
    fn get_set_cookie(response: &Response) -> Option<String> {
        let raw = response.headers().get(header::SET_COOKIE)?;
        raw.to_str().ok().map(str::to_owned)
    }

    /// Value half of the leading `name=value` pair of a `Set-Cookie` header.
    fn extract_token_from_set_cookie(set_cookie: &str) -> String {
        let leading_pair = set_cookie.split(';').next().unwrap_or_default();
        let (_, value) = leading_pair
            .split_once('=')
            .expect("csrf token not found in Set-Cookie");
        value.trim().to_owned()
    }

    /// Do an anonymous GET and return the freshly issued `csrf_pre` token.
    async fn fetch_pre_auth_token(app: &Router) -> String {
        let response = send(app.clone(), to_root("GET").body(Body::empty()).unwrap()).await;
        let set_cookie = get_set_cookie(&response).expect("Set-Cookie missing");
        extract_token_from_set_cookie(&set_cookie)
    }

    // --- Pre-auth path (no session cookie) ---

    #[tokio::test]
    async fn pre_auth_get_sets_csrf_pre_cookie() {
        let app = test_app(build_ath().await);
        let response = send(app, to_root("GET").body(Body::empty()).unwrap()).await;
        assert_eq!(response.status(), StatusCode::OK);

        let set_cookie = get_set_cookie(&response).expect("Set-Cookie header missing");
        assert!(set_cookie.starts_with("csrf_pre="));
        for attribute in ["SameSite=Lax", "Max-Age=1800"] {
            assert!(set_cookie.contains(attribute));
        }
        assert!(!set_cookie.contains("Secure"));
    }

    #[tokio::test]
    async fn pre_auth_get_does_not_reset_existing_csrf_pre_cookie() {
        let app = test_app(build_ath().await);
        let request = to_root("GET")
            .header(header::COOKIE, "csrf_pre=existing_value")
            .body(Body::empty())
            .unwrap();
        let response = send(app, request).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(get_set_cookie(&response).is_none());
    }

    #[tokio::test]
    async fn pre_auth_post_accepts_matching_cookie_and_header() {
        let app = test_app(build_ath().await);
        let token = fetch_pre_auth_token(&app).await;
        let request = to_root("POST")
            .header(header::COOKIE, format!("csrf_pre={token}"))
            .header("x-csrf-token", &token)
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, request).await.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn pre_auth_post_rejects_mismatched_token() {
        let app = test_app(build_ath().await);
        let request = to_root("POST")
            .header(header::COOKIE, "csrf_pre=correct")
            .header("x-csrf-token", "wrong")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, request).await.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn pre_auth_post_rejects_missing_cookie() {
        let app = test_app(build_ath().await);
        let request = to_root("POST")
            .header("x-csrf-token", "sometoken")
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(app, request).await.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn pre_auth_post_accepts_form_token() {
        let app = test_app(build_ath().await);
        let token = fetch_pre_auth_token(&app).await;
        let request = to_root("POST")
            .header(header::COOKIE, format!("csrf_pre={token}"))
            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
            .body(Body::from(format!("username=alice&csrf_token={token}")))
            .unwrap();
        assert_eq!(send(app, request).await.status(), StatusCode::OK);
    }

    // --- Session-bound path ---

    /// Create a user with a live session and return the `name=value` session
    /// cookie together with the CSRF token derived from it.
    async fn make_session_cookie(ath: &AllowThem) -> (String, String) {
        let email = Email::new("user@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password", None, None)
            .await
            .unwrap();

        let token = generate_token();
        let expires_at = Utc::now() + Duration::hours(24);
        ath.db()
            .create_session(user.id, hash_token(&token), None, None, expires_at)
            .await
            .unwrap();

        let full_cookie = ath.session_cookie(&token);
        let session_cookie = full_cookie.split(';').next().unwrap().to_string();
        let csrf = derive_csrf_token(&token, &TEST_CSRF_KEY);
        (session_cookie, csrf)
    }

    #[tokio::test]
    async fn session_bound_get_does_not_set_csrf_pre_cookie() {
        let ath = build_ath().await;
        let (session_cookie, _) = make_session_cookie(&ath).await;
        let request = to_root("GET")
            .header(header::COOKIE, &session_cookie)
            .body(Body::empty())
            .unwrap();
        let response = send(test_app(ath), request).await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(get_set_cookie(&response).is_none());
    }

    #[tokio::test]
    async fn session_bound_post_accepts_derived_token_in_header() {
        let ath = build_ath().await;
        let (session_cookie, csrf) = make_session_cookie(&ath).await;
        let request = to_root("POST")
            .header(header::COOKIE, &session_cookie)
            .header("x-csrf-token", &csrf)
            .body(Body::empty())
            .unwrap();
        assert_eq!(send(test_app(ath), request).await.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn session_bound_post_rejects_wrong_token() {
        let ath = build_ath().await;
        let (session_cookie, _) = make_session_cookie(&ath).await;
        let request = to_root("POST")
            .header(header::COOKIE, &session_cookie)
            .header(
                "x-csrf-token",
                "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            )
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            send(test_app(ath), request).await.status(),
            StatusCode::FORBIDDEN
        );
    }

    #[tokio::test]
    async fn session_bound_post_accepts_form_token() {
        let ath = build_ath().await;
        let (session_cookie, csrf) = make_session_cookie(&ath).await;
        let request = to_root("POST")
            .header(header::COOKIE, &session_cookie)
            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
            .body(Body::from(format!("field=value&csrf_token={csrf}")))
            .unwrap();
        assert_eq!(send(test_app(ath), request).await.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn returns_500_when_csrf_key_not_configured() {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap();
        let response = send(test_app(ath), to_root("GET").body(Body::empty()).unwrap()).await;
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn head_does_not_require_csrf() {
        let app = test_app(build_ath().await);
        let response = send(app, to_root("HEAD").body(Body::empty()).unwrap()).await;
        assert_eq!(response.status(), StatusCode::OK);
    }
}