Skip to main content

autumn_web/security/
csrf.rs

1//! CSRF (Cross-Site Request Forgery) protection middleware.
2//!
3//! Protects against CSRF attacks by requiring a token on mutating
4//! HTTP methods (POST, PUT, DELETE, PATCH). The token is stored in a
5//! cookie and must be echoed back via a request header or form field.
6//!
7//! # How it works
8//!
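//! 1. On every response produced by the wrapped service, a CSRF cookie is
//!    set (if the request did not already carry a usable one) containing a
//!    random UUID v4 token, HMAC-signed when signing keys are configured.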
11//! 2. On mutating requests, the middleware checks that the token from
12//!    the cookie matches the token in the `X-CSRF-Token` header (or
13//!    `_csrf` form field).
14//! 3. Safe methods (GET, HEAD, OPTIONS, TRACE) are exempt.
15//!
16//! # Configuration
17//!
18//! See [`CsrfConfig`] for available settings.
19//!
20//! # Examples
21//!
22//! ## Template integration (Maud)
23//!
24//! ```rust,ignore
25//! use autumn_web::prelude::*;
26//! use autumn_web::security::CsrfToken;
27//!
28//! #[get("/form")]
29//! async fn form(csrf: CsrfToken) -> Markup {
30//!     html! {
31//!         form method="POST" action="/submit" {
32//!             input type="hidden" name="_csrf" value=(csrf.token());
33//!             input type="text" name="title";
34//!             button { "Submit" }
35//!         }
36//!     }
37//! }
38//! ```
39//!
40//! ## JavaScript / htmx
41//!
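//! The `autumn-csrf` cookie is set with `HttpOnly`, so scripts cannot read
//! it from `document.cookie`. Render the token into the page instead (for
//! example in a `<meta>` tag populated from [`CsrfToken`]) and send it
//! as an `X-CSRF-Token` header with every mutating request.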
44
45use std::future::Future;
46use std::pin::Pin;
47use std::sync::Arc;
48use std::task::{Context, Poll};
49
50use axum::extract::{FromRequestParts, OptionalFromRequestParts};
51use axum::http::{Request, Response, StatusCode};
52use http::header::HeaderName;
53
54use tower::{Layer, Service};
55use uuid::Uuid;
56
57use super::config::CsrfConfig;
58
/// Error body returned with a `403 Forbidden` when CSRF validation fails.
///
/// Used both as the plain-text response body and as the message passed to the
/// problem-details response, so the two variants report the same text.
const CSRF_FORBIDDEN_MESSAGE: &str = "CSRF token missing or invalid";
61
/// The configured CSRF form field name, placed in request extensions by [`CsrfLayer`].
///
/// [`ChangesetForm`](crate::form::ChangesetForm) reads this so `form_tag` emits the
/// hidden input under the correct field name even when `security.csrf.form_field` has
/// been customised from its default `"_csrf"`.
///
/// The wrapped `String` is the field name itself; it is the same name the
/// middleware looks for when it checks URL-encoded form bodies.
#[derive(Clone, Debug)]
pub struct CsrfFormField(pub String);
69
/// A CSRF token extracted from the request.
///
/// Use this as a handler parameter to access the CSRF token for embedding
/// in HTML forms. The [`CsrfLayer`] stores it in request extensions on every
/// request: the value of the request's CSRF cookie is reused when one is
/// present and accepted, otherwise a fresh token is generated.
///
/// ## Examples
///
/// ```rust,ignore
/// use autumn_web::prelude::*;
/// use autumn_web::security::CsrfToken;
///
/// #[get("/edit")]
/// async fn edit_form(csrf: CsrfToken) -> Markup {
///     html! {
///         form method="POST" {
///             input type="hidden" name="_csrf" value=(csrf.token());
///             // ...
///         }
///     }
/// }
/// ```
#[derive(Clone, Debug)]
pub struct CsrfToken(String);
94
95impl CsrfToken {
96    /// Returns the CSRF token value for embedding in forms or headers.
97    #[must_use]
98    pub fn token(&self) -> &str {
99        &self.0
100    }
101
102    #[cfg(test)]
103    pub(crate) const fn new(token: String) -> Self {
104        Self(token)
105    }
106}
107
108impl std::fmt::Display for CsrfToken {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        f.write_str(&self.0)
111    }
112}
113
114impl<S> FromRequestParts<S> for CsrfToken
115where
116    S: Send + Sync,
117{
118    type Rejection = (StatusCode, &'static str);
119
120    async fn from_request_parts(
121        parts: &mut axum::http::request::Parts,
122        _state: &S,
123    ) -> Result<Self, Self::Rejection> {
124        parts.extensions.get::<Self>().cloned().ok_or((
125            StatusCode::INTERNAL_SERVER_ERROR,
126            "CSRF token not found in request extensions. Is CsrfLayer enabled?",
127        ))
128    }
129}
130
131impl<S> OptionalFromRequestParts<S> for CsrfToken
132where
133    S: Send + Sync,
134{
135    type Rejection = std::convert::Infallible;
136
137    async fn from_request_parts(
138        parts: &mut axum::http::request::Parts,
139        _state: &S,
140    ) -> Result<Option<Self>, Self::Rejection> {
141        Ok(parts.extensions.get::<Self>().cloned())
142    }
143}
144
145impl<S> FromRequestParts<S> for CsrfFormField
146where
147    S: Send + Sync,
148{
149    type Rejection = (StatusCode, &'static str);
150
151    async fn from_request_parts(
152        parts: &mut axum::http::request::Parts,
153        _state: &S,
154    ) -> Result<Self, Self::Rejection> {
155        parts.extensions.get::<Self>().cloned().ok_or((
156            StatusCode::INTERNAL_SERVER_ERROR,
157            "CSRF form field not found in request extensions. Is CsrfLayer enabled?",
158        ))
159    }
160}
161
162impl<S> OptionalFromRequestParts<S> for CsrfFormField
163where
164    S: Send + Sync,
165{
166    type Rejection = std::convert::Infallible;
167
168    async fn from_request_parts(
169        parts: &mut axum::http::request::Parts,
170        _state: &S,
171    ) -> Result<Option<Self>, Self::Rejection> {
172        Ok(parts.extensions.get::<Self>().cloned())
173    }
174}
175
/// Shared CSRF configuration.
///
/// Built once by [`CsrfLayer::from_config`] and shared behind an `Arc` between
/// the layer and every [`CsrfService`] it produces.
#[derive(Debug, Clone)]
struct CsrfSettings {
    /// Name of the cookie that carries the CSRF token.
    cookie_name: String,
    /// Request header checked first for the echoed token.
    token_header: HeaderName,
    /// Form field checked in URL-encoded bodies when the header does not match.
    form_field: String,
    /// HTTP methods that skip token validation.
    safe_methods: Vec<http::Method>,
    /// Path prefixes that skip token validation (matched with `starts_with`).
    exempt_paths: Vec<String>,
    /// When set, tokens must be in `{uuid}.{hmac_hex}` form and carry a valid signature.
    signing_keys: Option<Arc<crate::security::config::ResolvedSigningKeys>>,
}
186
/// Tower [`Layer`] that applies CSRF protection.
///
/// Applied automatically when `security.csrf.enabled = true` in config.
///
/// Holds only an `Arc` to the shared settings, so cloning the layer does not
/// copy the configuration.
#[derive(Clone, Debug)]
pub struct CsrfLayer {
    /// Settings shared with every service produced by this layer.
    settings: Arc<CsrfSettings>,
}
194
195impl CsrfLayer {
196    /// Create a new CSRF layer from configuration.
197    #[must_use]
198    pub fn from_config(config: &CsrfConfig) -> Self {
199        let safe_methods = config
200            .safe_methods
201            .iter()
202            .filter_map(|m| m.parse::<http::Method>().ok())
203            .collect();
204
205        let token_header = config
206            .token_header
207            .parse::<HeaderName>()
208            .unwrap_or_else(|_| HeaderName::from_static("x-csrf-token"));
209
210        Self {
211            settings: Arc::new(CsrfSettings {
212                cookie_name: config.cookie_name.clone(),
213                token_header,
214                form_field: config.form_field.clone(),
215                safe_methods,
216                exempt_paths: config.exempt_paths.clone(),
217                signing_keys: None,
218            }),
219        }
220    }
221
222    /// Attach signing keys so CSRF tokens are HMAC-signed.
223    ///
224    /// When set, tokens are in `{uuid}.{hmac_hex}` format. Unsigned tokens are
225    /// rejected. Previous keys (see `ResolvedSigningKeys`) allow tokens signed
226    /// with an old key to remain valid during a rotation grace window.
227    #[must_use]
228    pub fn with_signing_keys(
229        mut self,
230        keys: Arc<crate::security::config::ResolvedSigningKeys>,
231    ) -> Self {
232        Arc::make_mut(&mut self.settings).signing_keys = Some(keys);
233        self
234    }
235}
236
237impl<S> Layer<S> for CsrfLayer {
238    type Service = CsrfService<S>;
239
240    fn layer(&self, inner: S) -> Self::Service {
241        CsrfService {
242            inner,
243            settings: Arc::clone(&self.settings),
244        }
245    }
246}
247
/// Tower [`Service`] produced by [`CsrfLayer`].
///
/// Validates the CSRF token on requests that are neither safe nor exempt
/// before forwarding them to the wrapped service.
#[derive(Clone, Debug)]
pub struct CsrfService<S> {
    /// The wrapped service that receives requests which pass validation.
    inner: S,
    /// Settings shared with the originating [`CsrfLayer`].
    settings: Arc<CsrfSettings>,
}
254
255use subtle::{Choice, ConstantTimeEq};
256
/// Constant-time string comparison to prevent timing attacks when verifying CSRF tokens.
///
/// The comparison always processes exactly `a.len()` bytes, so the amount of
/// work does not depend on the length or content of `b`. Neither a length
/// mismatch nor a short `b` causes an early exit; both are folded into the
/// final result.
///
/// Returns `true` only when both strings have the same length and every byte
/// matches.
#[inline(never)]
fn constant_time_eq(a: &str, b: &str) -> bool {
    let a = a.as_bytes();
    let b = b.as_bytes();

    // Constant-time length check — no early exit.
    let len_eq = a.len().ct_eq(&b.len());

    // Iterate over `a` (the trusted stored token) so the loop count is fixed
    // at the server-side token length, regardless of what the caller submits
    // as `b`.  Callers pass the attacker-controlled value as `b`, so iterating
    // over `a` ensures every submission — short or long — executes the same
    // amount of work.  Out-of-range positions in `b` use the sentinel 0xFF,
    // which can never match a valid ASCII/UTF-8 token byte.
    let mut bytes_eq = Choice::from(1u8);
    for (i, &a_byte) in a.iter().enumerate() {
        let b_byte = *b.get(i).unwrap_or(&0xFF);
        bytes_eq &= a_byte.ct_eq(&b_byte);
    }

    // Both the length check and the per-byte check must hold.
    (len_eq & bytes_eq).into()
}
284
285/// Extract the CSRF cookie value from the Cookie header.
286fn extract_cookie_token(req_headers: &http::HeaderMap, cookie_name: &str) -> Option<String> {
287    let mut found_token = None;
288
289    for cookie_header in &req_headers.get_all(http::header::COOKIE) {
290        let Ok(cookie_str) = cookie_header.to_str() else {
291            continue;
292        };
293
294        for pair in cookie_str.split(';') {
295            let pair = pair.trim();
296            let Some((name, value)) = pair.split_once('=') else {
297                continue;
298            };
299
300            if name.trim() != cookie_name {
301                continue;
302            }
303
304            if found_token.is_some() {
305                // Multiple cookies with the same name found.
306                // This indicates a potential Cookie Tossing attack!
307                // Reject by returning None.
308                return None;
309            }
310
311            found_token = Some(value.trim().to_owned());
312        }
313    }
314
315    found_token
316}
317
318impl<S, ResBody> Service<Request<axum::body::Body>> for CsrfService<S>
319where
320    S: Service<Request<axum::body::Body>, Response = Response<ResBody>> + Clone + Send + 'static,
321    S::Future: Send + 'static,
322    S::Error: Send + 'static,
323    ResBody: From<&'static str> + From<String> + Default + Send + 'static,
324{
325    type Response = S::Response;
326    type Error = S::Error;
327    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
328
329    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
330        self.inner.poll_ready(cx)
331    }
332
333    fn call(&mut self, mut req: Request<axum::body::Body>) -> Self::Future {
334        let path = req.uri().path();
335        let is_exempt = self
336            .settings
337            .exempt_paths
338            .iter()
339            .any(|prefix| path.starts_with(prefix.as_str()));
340        let is_safe = is_exempt || self.settings.safe_methods.contains(req.method());
341        let raw_cookie_token = extract_cookie_token(req.headers(), &self.settings.cookie_name);
342
343        // When signing is active, discard any cookie that fails HMAC verification
344        // (unsigned pre-upgrade cookies, removed-key cookies, etc.) so a fresh signed
345        // token is minted and the Set-Cookie header refreshes the browser value.
346        let cookie_token = match (&raw_cookie_token, &self.settings.signing_keys) {
347            (Some(tok), Some(_)) if !validate_cookie_token_hmac(tok, &self.settings) => None,
348            _ => raw_cookie_token.clone(),
349        };
350
351        // Generate a new token if none exists in the cookie.
352        // When signing keys are active, the token is {uuid}.{hmac_hex}.
353        let token = cookie_token.clone().unwrap_or_else(|| {
354            let raw = Uuid::new_v4().to_string();
355            if let Some(keys) = &self.settings.signing_keys {
356                let sig = keys.sign(raw.as_bytes());
357                format!("{raw}.{sig}")
358            } else {
359                raw
360            }
361        });
362
363        // Insert CsrfToken and the configured form field name into request extensions.
364        req.extensions_mut().insert(CsrfToken(token.clone()));
365        req.extensions_mut()
366            .insert(CsrfFormField(self.settings.form_field.clone()));
367
368        // Check if we need to set a cookie
369        let set_cookie = if cookie_token.is_none() {
370            Some(format!(
371                "{}={}; Path=/; SameSite=Lax; HttpOnly",
372                self.settings.cookie_name, token
373            ))
374        } else {
375            None
376        };
377
378        let settings = Arc::clone(&self.settings);
379        let mut inner = self.inner.clone();
380
381        // Swap to ensure correct poll_ready semantics
382        std::mem::swap(&mut self.inner, &mut inner);
383
384        Box::pin(async move {
385            if !is_safe && !verify_csrf_token(&mut req, &settings, cookie_token.as_deref()).await {
386                let request_id = req
387                    .extensions()
388                    .get::<crate::middleware::RequestId>()
389                    .map(std::string::ToString::to_string);
390                let instance = Some(req.uri().path().to_owned());
391                if wants_problem_details(req.headers()) {
392                    return Ok(csrf_problem_response(request_id, instance));
393                }
394
395                let mut response = Response::new(ResBody::from(CSRF_FORBIDDEN_MESSAGE));
396                *response.status_mut() = StatusCode::FORBIDDEN;
397                response.headers_mut().insert(
398                    http::header::CONTENT_TYPE,
399                    http::HeaderValue::from_static("text/plain; charset=utf-8"),
400                );
401                return Ok(response);
402            }
403
404            // Validation passed (or method is safe)
405            let mut response = inner.call(req).await?;
406
407            if let Some(cookie) = set_cookie
408                && let Ok(val) = http::header::HeaderValue::from_str(&cookie)
409            {
410                response.headers_mut().append(http::header::SET_COOKIE, val);
411            }
412
413            Ok(response)
414        })
415    }
416}
417
418fn wants_problem_details(headers: &http::HeaderMap) -> bool {
419    !crate::middleware::error_page_filter::accept_prefers_html(headers)
420}
421
422fn csrf_problem_response<ResBody: From<String> + Default>(
423    request_id: Option<String>,
424    instance: Option<String>,
425) -> Response<ResBody> {
426    let mut problem = crate::error::problem_details(
427        StatusCode::FORBIDDEN,
428        CSRF_FORBIDDEN_MESSAGE.to_owned(),
429        None,
430        Some("https://autumn.dev/problems/csrf"),
431        request_id,
432        instance,
433        true,
434    );
435    "autumn.csrf".clone_into(&mut problem.code);
436    let body = crate::error::problem_details_to_json_string(&problem);
437
438    Response::builder()
439        .status(StatusCode::FORBIDDEN)
440        .header(http::header::CONTENT_TYPE, "application/problem+json")
441        .body(ResBody::from(body))
442        .unwrap_or_default()
443}
444
445/// Validate a CSRF cookie token's HMAC when signing is active.
446///
447/// Returns `false` when signing keys are set but the token is unsigned or carries
448/// an invalid HMAC (catches tampered or pre-rotation unsigned tokens).
449fn validate_cookie_token_hmac(cookie_token: &str, settings: &CsrfSettings) -> bool {
450    let Some(keys) = &settings.signing_keys else {
451        return true; // signing not active — accept raw token
452    };
453    // Signed format: "{uuid}.{hmac_hex}"
454    let Some((uuid_part, sig)) = cookie_token.split_once('.') else {
455        return false; // unsigned token rejected when signing is required
456    };
457    keys.verify(uuid_part.as_bytes(), sig)
458}
459
460async fn verify_csrf_token(
461    req: &mut Request<axum::body::Body>,
462    settings: &CsrfSettings,
463    cookie_token: Option<&str>,
464) -> bool {
465    let mut token_found = false;
466
467    // 1. Check header
468    let header_token = req
469        .headers()
470        .get(&settings.token_header)
471        .and_then(|v| v.to_str().ok());
472
473    if let (Some(c), Some(h)) = (cookie_token, header_token)
474        && !c.is_empty()
475        && !h.is_empty()
476        && validate_cookie_token_hmac(c, settings)
477        && constant_time_eq(c, h)
478    {
479        token_found = true;
480    }
481
482    if token_found {
483        return true;
484    }
485
486    // 2. Check form field (if not found in header)
487    let content_type = req
488        .headers()
489        .get(http::header::CONTENT_TYPE)
490        .and_then(|v| v.to_str().ok())
491        .unwrap_or_default();
492
493    if !content_type.starts_with("application/x-www-form-urlencoded") {
494        return false;
495    }
496
497    // Temporarily take ownership of the body
498    let body = std::mem::replace(req.body_mut(), axum::body::Body::empty());
499
500    // Limit body size to avoid DoS when extracting form field
501    let bytes = axum::body::to_bytes(body, 2 * 1024 * 1024)
502        .await
503        .unwrap_or_else(|_| axum::body::Bytes::new());
504
505    for (key, value) in url::form_urlencoded::parse(&bytes) {
506        if key == settings.form_field {
507            if let Some(c) = cookie_token
508                && !c.is_empty()
509                && !value.is_empty()
510                && validate_cookie_token_hmac(c, settings)
511                && constant_time_eq(c, value.as_ref())
512            {
513                token_found = true;
514            }
515            break;
516        }
517    }
518
519    // Restore request body
520    *req.body_mut() = axum::body::Body::from(bytes);
521
522    token_found
523}
524
525#[cfg(test)]
526mod tests {
527    #[tokio::test]
528    async fn post_with_url_encoded_token_passes() {
529        let raw_token = "abc+123/xyz=456";
530        let encoded_token = "abc%2B123%2Fxyz%3D456";
531        let app = Router::new()
532            .route("/submit", post(|| async { "created" }))
533            .layer(CsrfLayer::from_config(&default_csrf_config()));
534
535        let response = app
536            .oneshot(
537                Request::builder()
538                    .method("POST")
539                    .uri("/submit")
540                    .header("Cookie", format!("autumn-csrf={raw_token}"))
541                    .header("Content-Type", "application/x-www-form-urlencoded")
542                    .body(Body::from(format!("_csrf={encoded_token}")))
543                    .unwrap(),
544            )
545            .await
546            .unwrap();
547
548        assert_eq!(response.status(), StatusCode::OK);
549    }
550
551    use super::*;
552    use axum::Router;
553    use axum::body::Body;
554    use axum::routing::{get, post};
555    use tower::ServiceExt;
556
557    fn default_csrf_config() -> CsrfConfig {
558        CsrfConfig {
559            enabled: true,
560            ..Default::default()
561        }
562    }
563
564    #[tokio::test]
565    async fn safe_method_passes_without_token() {
566        let app = Router::new()
567            .route("/", get(|| async { "ok" }))
568            .layer(CsrfLayer::from_config(&default_csrf_config()));
569
570        let response = app
571            .oneshot(
572                Request::builder()
573                    .method("GET")
574                    .uri("/")
575                    .body(Body::empty())
576                    .unwrap(),
577            )
578            .await
579            .unwrap();
580
581        assert_eq!(response.status(), StatusCode::OK);
582    }
583
584    #[tokio::test]
585    async fn safe_method_sets_csrf_cookie() {
586        let app = Router::new()
587            .route("/", get(|| async { "ok" }))
588            .layer(CsrfLayer::from_config(&default_csrf_config()));
589
590        let response = app
591            .oneshot(
592                Request::builder()
593                    .method("GET")
594                    .uri("/")
595                    .body(Body::empty())
596                    .unwrap(),
597            )
598            .await
599            .unwrap();
600
601        let set_cookie = response
602            .headers()
603            .get("set-cookie")
604            .unwrap()
605            .to_str()
606            .unwrap();
607        assert!(set_cookie.starts_with("autumn-csrf="));
608        assert!(set_cookie.contains("HttpOnly"));
609    }
610
611    #[tokio::test]
612    async fn post_without_token_returns_403() {
613        let app = Router::new()
614            .route("/submit", post(|| async { "created" }))
615            .layer(CsrfLayer::from_config(&default_csrf_config()));
616
617        let response = app
618            .oneshot(
619                Request::builder()
620                    .method("POST")
621                    .uri("/submit")
622                    .header(http::header::ACCEPT, "text/html")
623                    .body(Body::empty())
624                    .unwrap(),
625            )
626            .await
627            .unwrap();
628
629        assert_eq!(response.status(), StatusCode::FORBIDDEN);
630    }
631
632    #[tokio::test]
633    async fn forbidden_response_has_clear_error_body() {
634        let app = Router::new()
635            .route("/submit", post(|| async { "created" }))
636            .layer(CsrfLayer::from_config(&default_csrf_config()));
637
638        let response = app
639            .oneshot(
640                Request::builder()
641                    .method("POST")
642                    .uri("/submit")
643                    .header(http::header::ACCEPT, "text/html")
644                    .body(Body::empty())
645                    .unwrap(),
646            )
647            .await
648            .unwrap();
649
650        assert_eq!(response.status(), StatusCode::FORBIDDEN);
651        assert_eq!(
652            response
653                .headers()
654                .get(http::header::CONTENT_TYPE)
655                .map(|v| v.to_str().unwrap_or_default()),
656            Some("text/plain; charset=utf-8")
657        );
658        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
659            .await
660            .unwrap();
661        let text = std::str::from_utf8(&body).unwrap();
662        assert!(
663            text.contains("CSRF"),
664            "expected CSRF error message, got: {text:?}"
665        );
666    }
667
668    #[tokio::test]
669    async fn exempt_path_skips_csrf_validation() {
670        let config = CsrfConfig {
671            enabled: true,
672            exempt_paths: vec!["/api/".to_string()],
673            ..Default::default()
674        };
675        let app = Router::new()
676            .route("/api/items", post(|| async { "created" }))
677            .route("/form/submit", post(|| async { "created" }))
678            .layer(CsrfLayer::from_config(&config));
679
680        // Exempt API path: POST with no token should succeed.
681        let response = app
682            .clone()
683            .oneshot(
684                Request::builder()
685                    .method("POST")
686                    .uri("/api/items")
687                    .body(Body::empty())
688                    .unwrap(),
689            )
690            .await
691            .unwrap();
692        assert_eq!(response.status(), StatusCode::OK);
693
694        // Non-exempt form path: POST with no token should still be blocked.
695        let response = app
696            .oneshot(
697                Request::builder()
698                    .method("POST")
699                    .uri("/form/submit")
700                    .body(Body::empty())
701                    .unwrap(),
702            )
703            .await
704            .unwrap();
705        assert_eq!(response.status(), StatusCode::FORBIDDEN);
706    }
707
708    #[tokio::test]
709    async fn post_with_valid_token_passes() {
710        let token = Uuid::new_v4().to_string();
711        let app = Router::new()
712            .route("/submit", post(|| async { "created" }))
713            .layer(CsrfLayer::from_config(&default_csrf_config()));
714
715        let response = app
716            .oneshot(
717                Request::builder()
718                    .method("POST")
719                    .uri("/submit")
720                    .header("Cookie", format!("autumn-csrf={token}"))
721                    .header("X-CSRF-Token", &token)
722                    .body(Body::empty())
723                    .unwrap(),
724            )
725            .await
726            .unwrap();
727
728        assert_eq!(response.status(), StatusCode::OK);
729    }
730
731    #[tokio::test]
732    async fn post_with_mismatched_token_returns_403() {
733        let cookie_token = Uuid::new_v4().to_string();
734        let header_token = Uuid::new_v4().to_string();
735        let app = Router::new()
736            .route("/submit", post(|| async { "created" }))
737            .layer(CsrfLayer::from_config(&default_csrf_config()));
738
739        let response = app
740            .oneshot(
741                Request::builder()
742                    .method("POST")
743                    .uri("/submit")
744                    .header("Cookie", format!("autumn-csrf={cookie_token}"))
745                    .header("X-CSRF-Token", &header_token)
746                    .body(Body::empty())
747                    .unwrap(),
748            )
749            .await
750            .unwrap();
751
752        assert_eq!(response.status(), StatusCode::FORBIDDEN);
753    }
754
755    #[tokio::test]
756    async fn csrf_token_extractor_works() {
757        async fn handler(csrf: CsrfToken) -> String {
758            csrf.token().to_owned()
759        }
760
761        let app = Router::new()
762            .route("/", get(handler))
763            .layer(CsrfLayer::from_config(&default_csrf_config()));
764
765        let response = app
766            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
767            .await
768            .unwrap();
769
770        assert_eq!(response.status(), StatusCode::OK);
771        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
772            .await
773            .unwrap();
774        let token_str = String::from_utf8(body.to_vec()).unwrap();
775        assert!(Uuid::parse_str(&token_str).is_ok());
776    }
777
778    #[test]
779    fn extract_cookie_from_header() {
780        let mut headers = http::HeaderMap::new();
781        headers.insert(
782            http::header::COOKIE,
783            "autumn-csrf=abc123; other=xyz".parse().unwrap(),
784        );
785        assert_eq!(
786            extract_cookie_token(&headers, "autumn-csrf"),
787            Some("abc123".to_owned())
788        );
789    }
790
791    #[test]
792    fn missing_cookie_returns_none() {
793        let headers = http::HeaderMap::new();
794        assert_eq!(extract_cookie_token(&headers, "autumn-csrf"), None);
795    }
796
797    #[test]
798    fn extract_cookie_rejects_multiple_cookies() {
799        // Multiple cookies with the same name in a single header
800        let mut headers = http::HeaderMap::new();
801        headers.insert(
802            http::header::COOKIE,
803            "autumn-csrf=abc123; autumn-csrf=xyz456".parse().unwrap(),
804        );
805        assert_eq!(extract_cookie_token(&headers, "autumn-csrf"), None);
806
807        // Multiple headers with the same cookie
808        let mut headers2 = http::HeaderMap::new();
809        headers2.append(http::header::COOKIE, "autumn-csrf=abc123".parse().unwrap());
810        headers2.append(http::header::COOKIE, "autumn-csrf=xyz456".parse().unwrap());
811        assert_eq!(extract_cookie_token(&headers2, "autumn-csrf"), None);
812    }
813
814    #[test]
815    fn extract_cookie_ignores_malformed_cookies() {
816        let mut headers = http::HeaderMap::new();
817        // Missing '='
818        headers.insert(http::header::COOKIE, "autumn-csrf abc123".parse().unwrap());
819        assert_eq!(extract_cookie_token(&headers, "autumn-csrf"), None);
820
821        // Multiple spaces
822        headers.insert(
823            http::header::COOKIE,
824            "   autumn-csrf  =  abc123  ; other=xyz".parse().unwrap(),
825        );
826        assert_eq!(
827            extract_cookie_token(&headers, "autumn-csrf"),
828            Some("abc123".to_owned())
829        );
830    }
831
832    #[test]
833    fn test_constant_time_eq() {
834        assert!(super::constant_time_eq("abc", "abc"));
835        assert!(!super::constant_time_eq("abc", "ab"));
836        assert!(!super::constant_time_eq("abc", "abd"));
837        assert!(super::constant_time_eq("", ""));
838        assert!(!super::constant_time_eq("a", "b"));
839        assert!(!super::constant_time_eq("a", "A"));
840    }
841
842    #[tokio::test]
843    async fn post_with_empty_cookie_but_valid_header() {
844        let token = Uuid::new_v4().to_string();
845        let app = Router::new()
846            .route("/submit", post(|| async { "created" }))
847            .layer(CsrfLayer::from_config(&default_csrf_config()));
848
849        let response = app
850            .oneshot(
851                Request::builder()
852                    .method("POST")
853                    .uri("/submit")
854                    .header("Cookie", "autumn-csrf=")
855                    .header("X-CSRF-Token", &token)
856                    .body(Body::empty())
857                    .unwrap(),
858            )
859            .await
860            .unwrap();
861
862        assert_eq!(response.status(), StatusCode::FORBIDDEN);
863    }
864
865    #[tokio::test]
866    async fn post_with_valid_cookie_but_empty_header() {
867        let token = Uuid::new_v4().to_string();
868        let app = Router::new()
869            .route("/submit", post(|| async { "created" }))
870            .layer(CsrfLayer::from_config(&default_csrf_config()));
871
872        let response = app
873            .oneshot(
874                Request::builder()
875                    .method("POST")
876                    .uri("/submit")
877                    .header("Cookie", format!("autumn-csrf={token}"))
878                    .header("X-CSRF-Token", "")
879                    .body(Body::empty())
880                    .unwrap(),
881            )
882            .await
883            .unwrap();
884
885        assert_eq!(response.status(), StatusCode::FORBIDDEN);
886    }
887
888    #[tokio::test]
889    async fn post_with_empty_cookie_but_valid_form_field() {
890        let token = Uuid::new_v4().to_string();
891        let app = Router::new()
892            .route("/submit", post(|| async { "created" }))
893            .layer(CsrfLayer::from_config(&default_csrf_config()));
894
895        let response = app
896            .oneshot(
897                Request::builder()
898                    .method("POST")
899                    .uri("/submit")
900                    .header("Cookie", "autumn-csrf=")
901                    .header("Content-Type", "application/x-www-form-urlencoded")
902                    .body(Body::from(format!("_csrf={token}")))
903                    .unwrap(),
904            )
905            .await
906            .unwrap();
907
908        assert_eq!(response.status(), StatusCode::FORBIDDEN);
909    }
910
911    #[tokio::test]
912    async fn post_with_valid_cookie_but_empty_form_field() {
913        let token = Uuid::new_v4().to_string();
914        let app = Router::new()
915            .route("/submit", post(|| async { "created" }))
916            .layer(CsrfLayer::from_config(&default_csrf_config()));
917
918        let response = app
919            .oneshot(
920                Request::builder()
921                    .method("POST")
922                    .uri("/submit")
923                    .header("Cookie", format!("autumn-csrf={token}"))
924                    .header("Content-Type", "application/x-www-form-urlencoded")
925                    .body(Body::from("_csrf="))
926                    .unwrap(),
927            )
928            .await
929            .unwrap();
930
931        assert_eq!(response.status(), StatusCode::FORBIDDEN);
932    }
933
934    #[tokio::test]
935    async fn post_with_large_body_fails_csrf() {
936        let token = Uuid::new_v4().to_string();
937        let app = Router::new()
938            .route("/submit", post(|| async { "created" }))
939            .layer(CsrfLayer::from_config(&default_csrf_config()));
940
941        // Create a body just slightly over 2MB. The CSRF extractor limits to 2MB.
942        let large_padding = "a".repeat(2 * 1024 * 1024 + 10);
943        let body_content = format!("_csrf={token}&pad={large_padding}");
944
945        let response = app
946            .oneshot(
947                Request::builder()
948                    .method("POST")
949                    .uri("/submit")
950                    .header("Cookie", format!("autumn-csrf={token}"))
951                    .header("Content-Type", "application/x-www-form-urlencoded")
952                    .body(Body::from(body_content))
953                    .unwrap(),
954            )
955            .await
956            .unwrap();
957
958        assert_eq!(response.status(), StatusCode::FORBIDDEN);
959    }
960
961    #[tokio::test]
962    async fn post_with_empty_tokens_returns_403() {
963        let app = Router::new()
964            .route("/submit", post(|| async { "created" }))
965            .layer(CsrfLayer::from_config(&CsrfConfig {
966                enabled: true,
967                ..Default::default()
968            }));
969
970        let response = app
971            .oneshot(
972                Request::builder()
973                    .method("POST")
974                    .uri("/submit")
975                    .header("Cookie", "autumn-csrf=")
976                    .header("X-CSRF-Token", "")
977                    .body(Body::empty())
978                    .unwrap(),
979            )
980            .await
981            .unwrap();
982
983        assert_eq!(response.status(), StatusCode::FORBIDDEN);
984    }
985
986    #[tokio::test]
987    async fn post_with_empty_form_tokens_returns_403() {
988        let app = Router::new()
989            .route("/submit", post(|| async { "created" }))
990            .layer(CsrfLayer::from_config(&CsrfConfig {
991                enabled: true,
992                ..Default::default()
993            }));
994
995        let response = app
996            .oneshot(
997                Request::builder()
998                    .method("POST")
999                    .uri("/submit")
1000                    .header("Cookie", "autumn-csrf=")
1001                    .header("Content-Type", "application/x-www-form-urlencoded")
1002                    .body(Body::from("_csrf="))
1003                    .unwrap(),
1004            )
1005            .await
1006            .unwrap();
1007
1008        assert_eq!(response.status(), StatusCode::FORBIDDEN);
1009    }
1010
1011    #[test]
1012    fn from_config_filters_invalid_methods() {
1013        let config = CsrfConfig {
1014            safe_methods: vec![
1015                "GET".to_string(),
1016                "INVALID METHOD".to_string(),
1017                "POST".to_string(),
1018            ],
1019            ..Default::default()
1020        };
1021        let layer = CsrfLayer::from_config(&config);
1022        assert_eq!(layer.settings.safe_methods.len(), 2);
1023        assert!(layer.settings.safe_methods.contains(&http::Method::GET));
1024        assert!(layer.settings.safe_methods.contains(&http::Method::POST));
1025    }
1026
1027    #[test]
1028    fn from_config_handles_invalid_header_name() {
1029        let config = CsrfConfig {
1030            token_header: "Invalid Header Name\n".to_string(),
1031            ..Default::default()
1032        };
1033        let layer = CsrfLayer::from_config(&config);
1034        assert_eq!(layer.settings.token_header.as_str(), "x-csrf-token");
1035    }
1036
1037    // ── Signed CSRF tokens (RED phase) ────────────────────────────────────
1038
1039    #[tokio::test]
1040    async fn csrf_token_is_hmac_signed_when_keys_set() {
1041        use crate::security::config::{SigningSecretConfig, resolve_signing_keys};
1042        use std::sync::Arc;
1043
1044        let keys = Arc::new(resolve_signing_keys(&SigningSecretConfig {
1045            secret: Some("k".repeat(32)),
1046            previous_secrets: vec![],
1047        }));
1048        let layer = CsrfLayer::from_config(&default_csrf_config()).with_signing_keys(keys);
1049
1050        let app = Router::new()
1051            .route("/", get(|| async { "ok" }))
1052            .layer(layer);
1053
1054        let resp = app
1055            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
1056            .await
1057            .unwrap();
1058
1059        let set_cookie = resp
1060            .headers()
1061            .get("set-cookie")
1062            .expect("should set CSRF cookie")
1063            .to_str()
1064            .unwrap();
1065        let cookie_value = set_cookie
1066            .split('=')
1067            .nth(1)
1068            .unwrap()
1069            .split(';')
1070            .next()
1071            .unwrap()
1072            .trim();
1073
1074        assert!(
1075            cookie_value.contains('.'),
1076            "signed CSRF cookie must be {{uuid}}.{{hmac}}, got: {cookie_value}"
1077        );
1078        let (_uuid_part, sig_part) = cookie_value.split_once('.').unwrap();
1079        assert_eq!(sig_part.len(), 64, "HMAC hex must be 64 chars");
1080    }
1081
1082    #[tokio::test]
1083    async fn csrf_signed_token_validates_on_post() {
1084        use crate::security::config::{SigningSecretConfig, resolve_signing_keys};
1085        use std::sync::Arc;
1086
1087        let keys = Arc::new(resolve_signing_keys(&SigningSecretConfig {
1088            secret: Some("k".repeat(32)),
1089            previous_secrets: vec![],
1090        }));
1091        let layer = CsrfLayer::from_config(&default_csrf_config()).with_signing_keys(keys);
1092
1093        let app = Router::new()
1094            .route("/", post(|| async { "created" }))
1095            .layer(layer);
1096
1097        // Mint a valid signed token
1098        let config = SigningSecretConfig {
1099            secret: Some("k".repeat(32)),
1100            previous_secrets: vec![],
1101        };
1102        let signing_keys = resolve_signing_keys(&config);
1103        let uuid = uuid::Uuid::new_v4().to_string();
1104        let sig = signing_keys.sign(uuid.as_bytes());
1105        let signed_token = format!("{uuid}.{sig}");
1106
1107        let resp = app
1108            .oneshot(
1109                Request::builder()
1110                    .method("POST")
1111                    .uri("/")
1112                    .header("Cookie", format!("autumn-csrf={signed_token}"))
1113                    .header("X-CSRF-Token", &signed_token)
1114                    .body(Body::empty())
1115                    .unwrap(),
1116            )
1117            .await
1118            .unwrap();
1119
1120        assert_eq!(resp.status(), StatusCode::OK);
1121    }
1122
1123    #[tokio::test]
1124    async fn csrf_unsigned_token_rejected_when_signing_active() {
1125        use crate::security::config::{SigningSecretConfig, resolve_signing_keys};
1126        use std::sync::Arc;
1127
1128        let keys = Arc::new(resolve_signing_keys(&SigningSecretConfig {
1129            secret: Some("k".repeat(32)),
1130            previous_secrets: vec![],
1131        }));
1132        let layer = CsrfLayer::from_config(&default_csrf_config()).with_signing_keys(keys);
1133
1134        let app = Router::new()
1135            .route("/", post(|| async { "created" }))
1136            .layer(layer);
1137
1138        // Raw UUID without HMAC — should be rejected when signing is active
1139        let raw_token = uuid::Uuid::new_v4().to_string();
1140        let resp = app
1141            .oneshot(
1142                Request::builder()
1143                    .method("POST")
1144                    .uri("/")
1145                    .header("Cookie", format!("autumn-csrf={raw_token}"))
1146                    .header("X-CSRF-Token", &raw_token)
1147                    .body(Body::empty())
1148                    .unwrap(),
1149            )
1150            .await
1151            .unwrap();
1152
1153        assert_eq!(
1154            resp.status(),
1155            StatusCode::FORBIDDEN,
1156            "unsigned CSRF token must be rejected when signing is active"
1157        );
1158    }
1159
1160    #[tokio::test]
1161    async fn csrf_previous_key_signed_token_accepted() {
1162        use crate::security::config::{
1163            ResolvedSigningKeys, SigningSecretConfig, resolve_signing_keys,
1164        };
1165        use std::sync::Arc;
1166
1167        let old_secret = "old-key".repeat(5); // 35 bytes
1168        let old_keys = resolve_signing_keys(&SigningSecretConfig {
1169            secret: Some(old_secret.clone()),
1170            previous_secrets: vec![],
1171        });
1172
1173        let uuid = uuid::Uuid::new_v4().to_string();
1174        let old_sig = old_keys.sign(uuid.as_bytes());
1175        let old_signed_token = format!("{uuid}.{old_sig}");
1176
1177        let new_keys = Arc::new(ResolvedSigningKeys::new(
1178            "new-key".repeat(5).into_bytes(),
1179            vec![old_secret.into_bytes()],
1180        ));
1181        let layer = CsrfLayer::from_config(&default_csrf_config()).with_signing_keys(new_keys);
1182
1183        let app = Router::new()
1184            .route("/", post(|| async { "created" }))
1185            .layer(layer);
1186
1187        let resp = app
1188            .oneshot(
1189                Request::builder()
1190                    .method("POST")
1191                    .uri("/")
1192                    .header("Cookie", format!("autumn-csrf={old_signed_token}"))
1193                    .header("X-CSRF-Token", &old_signed_token)
1194                    .body(Body::empty())
1195                    .unwrap(),
1196            )
1197            .await
1198            .unwrap();
1199
1200        assert_eq!(
1201            resp.status(),
1202            StatusCode::OK,
1203            "previous-key-signed CSRF token must pass during grace window"
1204        );
1205    }
1206}