//! # Axum Synchronizer Token Pattern CSRF prevention
//!
//! This crate provides a CSRF protection layer and middleware for use with the [axum](https://docs.rs/axum/) web framework.
//!
//! The middleware implements the [CSRF Synchronizer Token Pattern](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#synchronizer-token-pattern)
//! for AJAX backends and API endpoints as described in the OWASP CSRF prevention cheat sheet.
//!
//! ## Scope
//!
//! This middleware implements token transfer via [custom request headers](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#use-of-custom-request-headers).
//!
//! The middleware requires and is built upon [`axum_sessions`](https://docs.rs/axum-sessions/), which in turn uses [`async_session`](https://docs.rs/async-session/).
//!
//! The [Same Origin Policy](https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy) prevents scripts from foreign origins from setting the custom request header.
//!
//! ### In which contexts should I use this middleware?
//!
//! The goal of this middleware is to prevent cross-site request forgery attacks specifically in applications that communicate with their backend through the JavaScript
//! [`fetch()` API](https://developer.mozilla.org/en-US/docs/Web/API/fetch) or the classic [`XMLHttpRequest`](https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest),
//! traditionally called "AJAX".
//!
//! The Synchronizer Token Pattern is especially useful in [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) contexts:
//! the underlying session cookie must be secured and is inaccessible to JavaScript, while the custom HTTP response header carrying the CSRF token can be exposed to JavaScript
//! using the CORS [`Access-Control-Expose-Headers`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers) HTTP response header.
//!
//! While the [Same Origin Policy](https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy) commonly prevents custom request headers from being set on cross-origin requests,
//! the [`Access-Control-Allow-Headers`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers) CORS HTTP response header
//! can be used to explicitly allow cross-origin requests to carry a required custom HTTP request header.
//!
//! This ensures that requests forged by auto-submitted HTML forms, or by scripts from foreign origins that have not been explicitly allowed, cannot add the required header.
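//!
//! To illustrate why forged requests cannot carry the header, the following is a deliberately
//! simplified model of the browser's CORS preflight check, loosely following the Fetch Standard.
//! It ignores wildcards, preflight caching and the exact rules for CORS-safelisted request headers;
//! `PreflightResponse` and `browser_sends_request` are hypothetical names used only for illustration.
//! HTML forms cannot set custom headers at all, and a script on a foreign origin can only send the
//! header if the backend's preflight response explicitly permits that origin and header name.
//!
//! ```rust
//! // Hypothetical, simplified view of the CORS headers in a preflight response.
//! struct PreflightResponse<'a> {
//!     allow_origin: Option<&'a str>,
//!     allow_credentials: bool,
//!     allow_methods: &'a [&'a str],
//!     allow_headers: &'a [&'a str],
//! }
//!
//! // Returns `true` if, in this simplified model, the browser sends a credentialed
//! // cross-origin request. `unsafe_headers` lists the request header names that are not
//! // CORS-safelisted, e.g. `X-CSRF-TOKEN`, or `Content-Type` with value `application/json`.
//! fn browser_sends_request(
//!     preflight: &PreflightResponse<'_>,
//!     origin: &str,
//!     method: &str,
//!     unsafe_headers: &[&str],
//! ) -> bool {
//!     let safelisted_method = matches!(method, "GET" | "HEAD" | "POST");
//!     // A safelisted method without non-safelisted headers needs no preflight: the request
//!     // is sent, but it cannot carry the CSRF header, so the middleware rejects it if its
//!     // method is unsafe (e.g. `POST`).
//!     if safelisted_method && unsafe_headers.is_empty() {
//!         return true;
//!     }
//!     // Otherwise the preflight must allow the exact origin and credentials...
//!     let origin_ok = preflight.allow_origin == Some(origin);
//!     let credentials_ok = preflight.allow_credentials;
//!     // ...a method that is either safelisted or explicitly listed...
//!     let method_ok = safelisted_method || preflight.allow_methods.contains(&method);
//!     // ...and every non-safelisted header name, compared case-insensitively.
//!     let headers_ok = unsafe_headers.iter().all(|name| {
//!         preflight
//!             .allow_headers
//!             .iter()
//!             .any(|allowed| allowed.eq_ignore_ascii_case(name))
//!     });
//!     origin_ok && credentials_ok && method_ok && headers_ok
//! }
//! ```
//!
//! In this model, a request from a foreign origin carrying `X-CSRF-TOKEN` is blocked unless the
//! backend lists that origin in its `Access-Control-Allow-Origin` response header.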
//!
//! ### When should I use other CSRF protection patterns or libraries?
//!
//! Use other middleware libraries if you plan to submit classic HTML forms without JavaScript and do not send the form data across origins.
//!
//! ## Security
//!
//! ### Token randomness
//!
//! The CSRF tokens are generated from 32 random bytes using [`rand::rngs::ThreadRng`](https://rust-random.github.io/rand/rand/rngs/struct.ThreadRng.html), which is considered a cryptographically secure pseudo-random number generator (CSPRNG),
//! and are encoded as standard base64.
//! See ["Our RNGs"](https://rust-random.github.io/book/guide-rngs.html#cryptographically-secure-pseudo-random-number-generators-csprngs) for more.
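//!
//! As a worked example of the resulting token format: padded standard base64 turns every 3 input bytes
//! into 4 characters, so a 32-byte token is `4 * ceil(32 / 3) = 44` characters long and, because
//! `32 % 3 == 2`, ends in exactly one `=`. All base64 characters are visible ASCII, which is what allows
//! the token to travel in an HTTP header. The std-only sketch below checks that shape;
//! `encoded_len` and `looks_like_token` are hypothetical helpers, not part of this crate's API.
//!
//! ```rust
//! // Length of padded standard base64 output for `n` input bytes.
//! fn encoded_len(n: usize) -> usize {
//!     (n + 2) / 3 * 4
//! }
//!
//! // Shape check for a token built from 32 random bytes: length, alphabet and padding only.
//! // It does not decode the token or verify that the unused padding bits are zero.
//! fn looks_like_token(token: &str) -> bool {
//!     let bytes = token.as_bytes();
//!     bytes.len() == encoded_len(32)
//!         && bytes[bytes.len() - 1] == b'='
//!         && bytes[..bytes.len() - 1]
//!             .iter()
//!             .all(|b| b.is_ascii_alphanumeric() || *b == b'+' || *b == b'/')
//! }
//! ```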
//!
//! ### Underlying session security
//!
//! The security of the underlying session is paramount: the CSRF prevention methods applied can only be as secure as the session carrying the server-side token.
//!
//! - When creating your [SessionLayer](https://docs.rs/axum-sessions/latest/axum_sessions/struct.SessionLayer.html), make sure to use a secret of at least 64 bytes of cryptographically secure randomness.
//! - Do not lower the secure defaults: Keep the session cookie's `secure` flag **on**.
//! - Use the strictest possible same-site policy.
//!
//! ### CORS security
//!
//! If you need to provide and secure cross-site requests:
//!
//! - Allow only your frontend origin(s) when configuring the [`CorsLayer`](https://docs.rs/tower-http/latest/tower_http/cors/struct.CorsLayer.html).
//! - Allow only the headers you need. (At least the CSRF request token header.)
//! - Expose only the headers you need. (At least the CSRF response token header.)
//!
//! ### No leaks of error details
//!
//! Errors are logged using [`tracing::error!`]. Error responses do not contain error details.
//!
//! Use [`tower_http::trace::TraceLayer`](https://docs.rs/tower-http/latest/tower_http/trace/struct.TraceLayer.html) to capture and view traces.
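//!
//! The pattern behind this is to record the detailed error on the server and send the client nothing
//! but a status code. A std-only sketch with a hypothetical `InternalError` type and a `Vec<String>`
//! standing in for the log (this crate itself uses `thiserror`, `tracing::error!` and axum's
//! `IntoResponse`, and answers such errors with `500 Internal Server Error`):
//!
//! ```rust
//! use std::fmt;
//!
//! // Hypothetical internal error whose details must not reach the client.
//! #[derive(Debug)]
//! enum InternalError {
//!     SessionLayerMissing,
//!     InvalidServerTokenHeader,
//! }
//!
//! impl fmt::Display for InternalError {
//!     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
//!         match self {
//!             InternalError::SessionLayerMissing => write!(f, "session extension missing"),
//!             InternalError::InvalidServerTokenHeader => write!(f, "invalid token header value"),
//!         }
//!     }
//! }
//!
//! // Logs the detailed message, then returns only a status code and an empty body.
//! fn into_client_response(error: InternalError, log: &mut Vec<String>) -> (u16, String) {
//!     log.push(format!("error: {error}"));
//!     (500, String::new())
//! }
//! ```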
//!
//! ## Safety
//!
//! This crate uses no `unsafe` code.
//!
//! The layer and middleware functionality is tested. See the `tests` module in the crate source code to learn more.
//!
//! ## Usage
//!
//! See the [example projects](https://github.com/LeoniePhiline/axum-csrf-sync-pattern/tree/main/examples/) for same-site and cross-site usage.
//!
//! ### Same-site usage
//!
//! **Note:** The crate repository contains example projects for same-site and cross-site usage!
//!
//! Configure your session and CSRF protection layer in your backend application:
//!
//! ```rust
//! use axum::{
//!     http::StatusCode,
//!     routing::{get, Router},
//! };
//! use axum_csrf_sync_pattern::{CsrfLayer, RegenerateToken};
//! use axum_sessions::{async_session::MemoryStore, SessionLayer};
//! use rand::RngCore;
//!
//! // The session secret must consist of at least 64 bytes of cryptographically secure randomness.
//! let mut secret = [0; 64];
//! rand::thread_rng().try_fill_bytes(&mut secret).unwrap();
//!
//! async fn handler() -> StatusCode {
//!     StatusCode::OK
//! }
//!
//! let app = Router::new()
//!     .route("/", get(handler).post(handler))
//!     .layer(
//!         CsrfLayer::new()
//!
//!         // Optionally, configure the layer with the following options:
//!
//!         // Default: RegenerateToken::PerSession
//!         .regenerate(RegenerateToken::PerUse)
//!         // Default: "X-CSRF-TOKEN"
//!         .request_header("X-Custom-Request-Header")
//!         // Default: "X-CSRF-TOKEN"
//!         .response_header("X-Custom-Response-Header")
//!         // Default: "_csrf_token"
//!         .session_key("_custom_session_key")
//!     )
//!     // The session layer must wrap the CSRF layer, so it is added after it.
//!     .layer(SessionLayer::new(MemoryStore::new(), &secret));
//!
//! // Use hyper to run `app` as a service and expose it on a local port or socket.
//!
//! # use tower::util::ServiceExt;
//! # tokio_test::block_on(async {
//! #     app.oneshot(
//! #         axum::http::Request::builder().body(axum::body::Body::empty()).unwrap()
//! #     ).await.unwrap();
//! # })
//! ```
//!
//! Receive the token and send same-site requests, using your custom header:
//!
//! ```javascript
//! const test = async () => {
//!     // Receive CSRF token (Default response header name: 'X-CSRF-TOKEN')
//!     const token = (await fetch('/')).headers.get('X-Custom-Response-Header');
//!
//!     // Submit data using the token
//!     await fetch('/', {
//!         method: 'POST',
//!         headers: {
//!             'Content-Type': 'application/json',
//!             // Default request header name: 'X-CSRF-TOKEN'
//!             'X-Custom-Request-Header': token,
//!         },
//!         body: JSON.stringify({ /* ... */ }),
//!     });
//! };
//! ```
//!
//! ### CORS-enabled usage
//!
//! **Note:** The crate repository contains example projects for same-site and cross-site usage!
//!
//! Configure your CORS layer, session and CSRF protection layer in your backend application:
//!
//! ```rust
//! use axum::{
//!     http::{header, Method, StatusCode},
//!     routing::{get, Router},
//! };
//! use axum_csrf_sync_pattern::CsrfLayer;
//! use axum_sessions::{async_session::MemoryStore, SessionLayer};
//! use rand::RngCore;
//! use tower_http::cors::{AllowOrigin, CorsLayer};
//!
//! let mut secret = [0; 64];
//! rand::thread_rng().try_fill_bytes(&mut secret).unwrap();
//!
//! async fn handler() -> StatusCode {
//!     StatusCode::OK
//! }
//!
//! let app = Router::new()
//!     .route("/", get(handler).post(handler))
//!     .layer(
//!         // See example above for custom layer configuration.
//!         CsrfLayer::new()
//!     )
//!     .layer(SessionLayer::new(MemoryStore::new(), &secret))
//!     .layer(
//!         CorsLayer::new()
//!             // Allow only your frontend origin.
//!             .allow_origin(AllowOrigin::list(["https://www.example.com".parse().unwrap()]))
//!             .allow_methods([Method::GET, Method::POST])
//!             // Allow the CSRF request token header (and `Content-Type` for JSON bodies).
//!             .allow_headers([header::CONTENT_TYPE, "X-CSRF-TOKEN".parse().unwrap()])
//!             // Required for credentialed requests, i.e. requests carrying the session cookie.
//!             .allow_credentials(true)
//!             // Let JavaScript read the CSRF response token header.
//!             .expose_headers(["X-CSRF-TOKEN".parse().unwrap()]),
//!     );
//!
//! // Use hyper to run `app` as a service and expose it on a local port or socket.
//!
//! # use tower::util::ServiceExt;
//! # tokio_test::block_on(async {
//! #     app.oneshot(
//! #         axum::http::Request::builder().body(axum::body::Body::empty()).unwrap()
//! #     ).await.unwrap();
//! # })
//! ```
//!
//! Receive the token and send cross-site requests, using your custom header:
//!
//! ```javascript
//! const test = async () => {
//!     // Receive CSRF token
//!     const token = (await fetch('https://backend.example.com/', {
//!         credentials: 'include',
//!     })).headers.get('X-CSRF-TOKEN');
//!
//!     // Submit data using the token
//!     await fetch('https://backend.example.com/', {
//!         method: 'POST',
//!         headers: {
//!             'Content-Type': 'application/json',
//!             'X-CSRF-TOKEN': token,
//!         },
//!         credentials: 'include',
//!         body: JSON.stringify({ /* ... */ }),
//!     });
//! };
//! ```
//!
//! ## Contributing
//!
//! Pull requests are welcome!
//!

#![forbid(unsafe_code, future_incompatible)]
#![deny(
    missing_debug_implementations,
    nonstandard_style,
    missing_docs,
    unreachable_pub,
    missing_copy_implementations,
    unused_qualifications
)]

use std::{
    convert::Infallible,
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

use axum::http::{self, HeaderValue, Request, StatusCode};
use axum_core::response::{IntoResponse, Response};
use axum_sessions::{async_session::Session, SessionHandle};
use base64::prelude::*;
use rand::RngCore;
use tokio::sync::RwLockWriteGuard;
use tower::Layer;

/// Use `CsrfLayer::new()` to provide the middleware and configuration to axum's service stack.
///
/// Use the provided methods to configure details, such as when tokens are regenerated, what request and response
/// headers should be named, and under which key the token should be stored in the session.
#[derive(Clone, Copy, Debug)]
pub struct CsrfLayer {
    /// Configures when tokens are regenerated: per session, per use or per request. See [`RegenerateToken`] for details.
    pub regenerate_token: RegenerateToken,

    /// Configures the request header name accepted by the middleware. Defaults to `"X-CSRF-TOKEN"`.
    /// This header is set on your JavaScript or WASM requests originating from the browser.
    pub request_header: &'static str,

    /// Configures the response header name sent by the middleware. Defaults to `"X-CSRF-TOKEN"`.
    /// This header is received by your JavaScript or WASM code and its name must be used to extract the token from the HTTP response.
    pub response_header: &'static str,

    /// Configures the key under which the middleware stores the server-side token in the session. Defaults to `"_csrf_token"`.
    pub session_key: &'static str,
}

impl Default for CsrfLayer {
    fn default() -> Self {
        Self {
            regenerate_token: Default::default(),
            request_header: "X-CSRF-TOKEN",
            response_header: "X-CSRF-TOKEN",
            session_key: "_csrf_token",
        }
    }
}

impl CsrfLayer {
    /// Create a new CSRF synchronizer token layer to inject into your middleware stack using
    /// [`axum::Router::layer()`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Configure when tokens are regenerated: per session, per use or per request. See [`RegenerateToken`] for details.
    pub fn regenerate(mut self, regenerate_token: RegenerateToken) -> Self {
        self.regenerate_token = regenerate_token;

        self
    }

    /// Configure a custom request header name accepted by the middleware. Defaults to `"X-CSRF-TOKEN"`.
    ///
    /// This header is set on your JavaScript or WASM requests originating from the browser.
    pub fn request_header(mut self, request_header: &'static str) -> Self {
        self.request_header = request_header;

        self
    }

    /// Configure a custom response header name sent by the middleware. Defaults to `"X-CSRF-TOKEN"`.
    ///
    /// This header is received by your JavaScript or WASM code and its name must be used to extract the token from the HTTP response.
    pub fn response_header(mut self, response_header: &'static str) -> Self {
        self.response_header = response_header;

        self
    }

    /// Configure a custom key under which the middleware stores the server-side token in the session. Defaults to `"_csrf_token"`.
    pub fn session_key(mut self, session_key: &'static str) -> Self {
        self.session_key = session_key;

        self
    }

    // Generate a new token from 32 bytes of CSPRNG output, store its base64 encoding
    // in the session under `session_key`, and return it.
    fn regenerate_token(
        &self,
        session_write: &mut RwLockWriteGuard<Session>,
    ) -> Result<String, Error> {
        let mut buf = [0; 32];
        rand::thread_rng().try_fill_bytes(&mut buf)?;
        let token = BASE64_STANDARD.encode(buf);
        session_write.insert(self.session_key, &token)?;

        Ok(token)
    }

    // Attach `server_token` to `response` under the configured response header name.
    // If the token cannot be converted into a header value, a generic error response is returned instead.
    fn response_with_token(&self, mut response: Response, server_token: &str) -> Response {
        response.headers_mut().insert(
            self.response_header,
            match HeaderValue::from_str(server_token).map_err(Error::from) {
                Ok(token_header) => token_header,
                Err(error) => return error.into_response(),
            },
        );
        response
    }
}

impl<S> Layer<S> for CsrfLayer {
    type Service = CsrfMiddleware<S>;

    fn layer(&self, inner: S) -> Self::Service {
        CsrfMiddleware::new(inner, *self)
    }
}

/// This enum is used with [`CsrfLayer::regenerate`] to determine
/// when the CSRF token should be regenerated.
///
/// You can think of these options as modes that set a level of paranoia, depending on your application's requirements.
///
/// This paranoia level is a trade-off between ergonomics and security: more frequent
/// token invalidation requires more overhead for handling and renewing tokens on the client side,
/// as well as retrying requests with a fresh token should they fail.
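///
/// The decision taken after a request has passed verification can be sketched as a pure function.
/// The std-only example below uses a hypothetical `Mode` enum standing in for `RegenerateToken`,
/// and treats `GET`, `HEAD`, `OPTIONS` and `TRACE` as the safe methods, as `http::Method::is_safe` does.
/// It does not model the separate case in which a token is created because the session holds none yet.
///
/// ```rust
/// // Hypothetical stand-in for `RegenerateToken`.
/// #[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// enum Mode {
///     PerSession,
///     PerUse,
///     PerRequest,
/// }
///
/// // Safe methods need no token; every other method counts as a "use".
/// fn is_safe(method: &str) -> bool {
///     matches!(method, "GET" | "HEAD" | "OPTIONS" | "TRACE")
/// }
///
/// // Whether a fresh token replaces the current one after a successfully verified request.
/// fn regenerates(mode: Mode, method: &str) -> bool {
///     match mode {
///         Mode::PerSession => false,
///         Mode::PerUse => !is_safe(method),
///         Mode::PerRequest => true,
///     }
/// }
/// ```
///
/// Under `PerUse`, a `POST` therefore invalidates the token it was sent with, while a `GET` does not.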
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[allow(clippy::enum_variant_names)]
pub enum RegenerateToken {
    /// Generate one CSRF token per session and use this token until the session ends.
    ///
    /// This is the default behavior and should work for most applications.
    #[default]
    PerSession,
    /// Regenerate the CSRF token after each use. A "use" is a request with an unsafe HTTP method,
    /// such as `POST`, `PUT`, `PATCH` or `DELETE`.
    ///
    /// CSRF tokens are not required for, and thus not invalidated by, requests
    /// using safe HTTP methods (`GET`, `HEAD`, `OPTIONS`, `TRACE`).
    PerUse,
    /// Regenerate the CSRF token on every request, including those using safe HTTP methods (`GET`, `HEAD`, `OPTIONS`, `TRACE`).
    ///
    /// This behavior might require elaborate token handling on the client side:
    /// concurrent requests lead to race conditions from the client's perspective,
    /// as each response yields a new token that must be used for the subsequent request.
    PerRequest,
}

/// Internal errors. Details are logged; clients only receive `500 Internal Server Error`.
#[derive(thiserror::Error, Debug)]
enum Error {
    #[error("Random number generator error")]
    Rng(#[from] rand::Error),

    #[error("Serde JSON error")]
    Serde(#[from] axum_sessions::async_session::serde_json::Error),

    #[error("Session extension missing. Is `axum_sessions::SessionLayer` installed and layered around the `axum_csrf_sync_pattern::CsrfLayer`?")]
    SessionLayerMissing,

    #[error("Incoming CSRF token header was not valid ASCII")]
    InvalidClientTokenHeader(#[from] http::header::ToStrError),

    #[error("Invalid CSRF token when preparing response header")]
    InvalidServerTokenHeader(#[from] http::header::InvalidHeaderValue),
}

impl IntoResponse for Error {
    fn into_response(self) -> Response {
        tracing::error!(?self);
        StatusCode::INTERNAL_SERVER_ERROR.into_response()
    }
}

/// This middleware is created by axum when applying the [`CsrfLayer`].
/// It verifies the CSRF token header on incoming requests, regenerates tokens as configured,
/// and attaches the current token to the outgoing response.
///
/// In detail, for requests with unsafe HTTP methods, this middleware reads the CSRF token from the
/// `X-CSRF-TOKEN` request header (or the header name configured with [`CsrfLayer::request_header`])
/// and compares it to the token stored in the session.
///
/// When the inner service responds, the session token is returned to the client via the
/// `X-CSRF-TOKEN` response header (or the header name configured with [`CsrfLayer::response_header`]).
/// The token is also attached to `403 Forbidden` responses, so clients can retry with the current token.
///
/// Make sure to expose this header in your CORS configuration if necessary!
///
/// Requires and uses `axum_sessions`.
///
/// Optionally regenerates the session token after successful verification,
/// to ensure a new token is used for each writing (`POST`, `PUT`, `PATCH`, `DELETE`) request.
/// Enable with [`RegenerateToken::PerUse`].
///
/// For maximum security, but severely reduced ergonomics, optionally regenerates the
/// session token after each request, to keep the token validity as short as
/// possible. Enable with [`RegenerateToken::PerRequest`].
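///
/// The verification of a single request can be sketched as a pure function. The std-only example
/// below uses the hypothetical names `Outcome` and `verify`; the status codes follow the middleware:
/// a missing or mismatching token header yields `403 Forbidden`, and a header value that is not
/// visible ASCII (so that `HeaderValue::to_str` fails) yields `500 Internal Server Error`.
/// Session handling and token regeneration are left out.
///
/// ```rust
/// #[derive(Debug, PartialEq, Eq)]
/// enum Outcome {
///     // The request is passed on to the inner service.
///     Pass,
///     // The request is rejected with this HTTP status code.
///     Reject(u16),
/// }
///
/// // `header` is the raw value of the configured request header, if the client sent one.
/// fn verify(method: &str, header: Option<&[u8]>, server_token: &str) -> Outcome {
///     // Safe methods are not checked.
///     if matches!(method, "GET" | "HEAD" | "OPTIONS" | "TRACE") {
///         return Outcome::Pass;
///     }
///     let raw = match header {
///         Some(raw) => raw,
///         None => return Outcome::Reject(403),
///     };
///     // Approximates `HeaderValue::to_str`, which only accepts visible ASCII and tabs.
///     if !raw.iter().all(|&b| (b' '..=b'~').contains(&b) || b == b'\t') {
///         return Outcome::Reject(500);
///     }
///     if raw != server_token.as_bytes() {
///         return Outcome::Reject(403);
///     }
///     Outcome::Pass
/// }
/// ```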
#[derive(Debug, Clone)]
pub struct CsrfMiddleware<S> {
    inner: S,
    layer: CsrfLayer,
}

impl<S> CsrfMiddleware<S> {
    /// Create a new middleware from an inner [`tower::Service`] (axum-specific bounds, such as an `Infallible` error type, apply) and a [`CsrfLayer`].
    /// Usually, the middleware is created by the [`CsrfLayer`] through its [`tower::Layer`] implementation, not manually.
    pub fn new(inner: S, layer: CsrfLayer) -> Self {
        CsrfMiddleware { inner, layer }
    }

    /// Create a new CSRF synchronizer token layer.
    /// Equivalent to calling [`CsrfLayer::new()`].
    pub fn layer() -> CsrfLayer {
        CsrfLayer::default()
    }
}

impl<S, B: Send + 'static> tower::Service<Request<B>> for CsrfMiddleware<S>
where
    S: tower::Service<Request<B>, Response = Response, Error = Infallible> + Send + Clone + 'static,
    S::Future: Send,
{
    type Response = S::Response;
    type Error = Infallible;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<B>) -> Self::Future {
        // Move the service that was driven to readiness into the future,
        // leaving a fresh clone in its place.
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);
        let layer = self.layer;
        Box::pin(async move {
            let session_handle = match req
                .extensions()
                .get::<SessionHandle>()
                .ok_or(Error::SessionLayerMissing)
            {
                Ok(session_handle) => session_handle,
                Err(error) => return Ok(error.into_response()),
            };

            // Extract the CSRF server-side token from the session; create a new one if none has been set yet.
            // With `RegenerateToken::PerRequest` or `PerUse`, the token is replaced further below, after verification.
            let mut session_write = session_handle.write().await;
            let mut server_token = match session_write.get::<String>(layer.session_key) {
                Some(token) => token,
                None => match layer.regenerate_token(&mut session_write) {
                    Ok(token) => token,
                    Err(error) => return Ok(error.into_response()),
                },
            };

            if !req.method().is_safe() {
                // Verify incoming CSRF token for unsafe request methods.
                let client_token = {
                    match req.headers().get(layer.request_header) {
                        Some(token) => token,
                        None => {
                            tracing::warn!("{} header missing!", layer.request_header);
                            return Ok(layer.response_with_token(
                                StatusCode::FORBIDDEN.into_response(),
                                &server_token,
                            ));
                        }
                    }
                };

                let client_token = match client_token.to_str().map_err(Error::from) {
                    Ok(token) => token,
                    Err(error) => {
                        return Ok(layer.response_with_token(error.into_response(), &server_token))
                    }
                };
                if client_token != server_token {
                    tracing::warn!("{} header mismatch!", layer.request_header);
                    return Ok(layer.response_with_token(
                        StatusCode::FORBIDDEN.into_response(),
                        &server_token,
                    ));
                }
            }

            // Create new token if configured to regenerate per each request,
            // or if configured to regenerate per use and just used.
            if layer.regenerate_token == RegenerateToken::PerRequest
                || (!req.method().is_safe() && layer.regenerate_token == RegenerateToken::PerUse)
            {
                server_token = match layer.regenerate_token(&mut session_write) {
                    Ok(token) => token,
                    Err(error) => {
                        return Ok(layer.response_with_token(error.into_response(), &server_token))
                    }
                };
            }

            drop(session_write);

            let response = inner.call(req).await.into_response();

            // Add the CSRF token response header (default: `X-CSRF-TOKEN`).
            Ok(layer.response_with_token(response, &server_token))
        })
    }
}

#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use axum::{body::Body, routing::get, Router};
    use axum_core::response::{IntoResponse, Response};
    use axum_sessions::{async_session::MemoryStore, extractors::ReadableSession, SessionLayer};
    use http::{
        header::{COOKIE, SET_COOKIE},
        Method, Request, StatusCode,
    };
    use tower::{Service, ServiceExt};

    use super::*;

    async fn handler() -> Result<Response, Infallible> {
        Ok((
            StatusCode::OK,
            "The default test success response has a body",
        )
            .into_response())
    }

    fn session_layer() -> SessionLayer<MemoryStore> {
        let mut secret = [0; 64];
        rand::thread_rng().try_fill_bytes(&mut secret).unwrap();
        SessionLayer::new(MemoryStore::new(), &secret)
    }

    fn app(csrf_layer: CsrfLayer) -> Router {
        Router::new()
            .route("/", get(handler).post(handler))
            .layer(csrf_layer)
            .layer(session_layer())
    }

    #[tokio::test]
    async fn get_without_token_succeeds() {
        let request = Request::builder()
            .method(Method::GET)
            .body(Body::empty())
            .unwrap();

        let response = app(CsrfLayer::new()).oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
        assert_eq!(BASE64_STANDARD.decode(client_token).unwrap().len(), 32);
    }

    #[tokio::test]
    async fn post_without_token_fails() {
        let request = Request::builder()
            .method(Method::POST)
            .body(Body::empty())
            .unwrap();
        let response = app(CsrfLayer::new()).oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);

        // Assert: Response must contain token even on request token failure.
        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
        assert_eq!(BASE64_STANDARD.decode(client_token).unwrap().len(), 32);
    }

    #[tokio::test]
    async fn session_token_remains_valid() {
        let mut app = app(CsrfLayer::new().regenerate(RegenerateToken::PerSession));

        // Get CSRF token
        let response = app
            .ready()
            .await
            .unwrap()
            .call(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Tokens are bound to the session, so the session cookie must be sent with each subsequent request.
        let session_cookie = response.headers().get(SET_COOKIE).unwrap().clone();

        let initial_client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
        assert_eq!(
            BASE64_STANDARD.decode(initial_client_token).unwrap().len(),
            32
        );

        // Use CSRF token for POST request
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::POST)
                    .header("X-CSRF-TOKEN", initial_client_token)
                    .header(COOKIE, session_cookie.clone())
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Assert token has not been changed after POST request
        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
        assert_eq!(client_token, initial_client_token);

        // Attempt token re-use for a second POST request
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::POST)
                    .header("X-CSRF-TOKEN", initial_client_token)
                    .header(COOKIE, session_cookie)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Assert token has not been changed after POST request
        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
        assert_eq!(client_token, initial_client_token);
    }

    #[tokio::test]
    async fn single_use_token_is_regenerated() {
        let mut app = app(CsrfLayer::new().regenerate(RegenerateToken::PerUse));

        // Get single-use CSRF token
        let response = app
            .ready()
            .await
            .unwrap()
            .call(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Tokens are bound to the session, so the session cookie must be sent with each subsequent request.
        let session_cookie = response.headers().get(SET_COOKIE).unwrap().clone();

        let initial_client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
        assert_eq!(
            BASE64_STANDARD.decode(initial_client_token).unwrap().len(),
            32
        );

        // Use CSRF token for POST request
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::POST)
                    .header("X-CSRF-TOKEN", initial_client_token)
                    .header(COOKIE, session_cookie.clone())
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Assert token has been regenerated after POST request
        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
        assert_ne!(client_token, initial_client_token);

        // Attempt token re-use for a second POST request
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::POST)
                    .header("X-CSRF-TOKEN", initial_client_token)
                    .header(COOKIE, session_cookie)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);

        // Assert the rejection carries the current (regenerated) token, not the reused one
        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
        assert_ne!(client_token, initial_client_token);
    }

    #[tokio::test]
    async fn single_request_token_is_regenerated() {
        let mut app = app(CsrfLayer::new().regenerate(RegenerateToken::PerRequest));

        // Get single-request CSRF token
        let response = app
            .ready()
            .await
            .unwrap()
            .call(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Tokens are bound to the session, so the session cookie must be sent with each subsequent request.
        let session_cookie = response.headers().get(SET_COOKIE).unwrap().clone();

        let initial_client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
        assert_eq!(
            BASE64_STANDARD.decode(initial_client_token).unwrap().len(),
            32
        );

        // Perform another GET request
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::GET)
                    .header(COOKIE, session_cookie.clone())
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Assert token has been regenerated after GET request
        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
        assert_ne!(client_token, initial_client_token);

        // Attempt using single-request token for POST request
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::POST)
                    .header("X-CSRF-TOKEN", client_token)
                    .header(COOKIE, session_cookie)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Assert token differs from the initial token after POST request
        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
        assert_ne!(client_token, initial_client_token);
    }

811    #[tokio::test]
    async fn accepts_custom_request_header() {
        let mut app = app(CsrfLayer::new().request_header("X-Custom-Token-Request-Header"));

        // Get CSRF token
        let response = app
            .ready()
            .await
            .unwrap()
            .call(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Tokens are bound to the session: send the session cookie with every subsequent request.
        let session_cookie = response.headers().get(SET_COOKIE).unwrap().clone();

        let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
        assert_eq!(BASE64_STANDARD.decode(client_token).unwrap().len(), 32);

        // Send the CSRF token in the custom request header of a POST request
        let response = app
            .ready()
            .await
            .unwrap()
            .call(
                Request::builder()
                    .method(Method::POST)
                    .header("X-Custom-Token-Request-Header", client_token)
                    .header(COOKIE, session_cookie)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn sends_custom_response_header() {
        // Get CSRF token
        let response = app(CsrfLayer::new().response_header("X-Custom-Token-Response-Header"))
            .oneshot(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let client_token = response
            .headers()
            .get("X-Custom-Token-Response-Header")
            .unwrap();
        assert_eq!(BASE64_STANDARD.decode(client_token).unwrap().len(), 32);
    }

    #[tokio::test]
    async fn uses_custom_session_key() {
        // Custom handler asserting the layer's configured session key is set,
        // and its value looks like a CSRF token.
        async fn extract_session(session: ReadableSession) -> StatusCode {
            let session_csrf_token: String = session.get("custom_session_key").unwrap();

            assert_eq!(
                BASE64_STANDARD.decode(session_csrf_token).unwrap().len(),
                32
            );
            StatusCode::OK
        }

        let app = Router::new()
            .route("/", get(extract_session))
            .layer(CsrfLayer::new().session_key("custom_session_key"))
            .layer(session_layer());

        let response = app
            .oneshot(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn missing_session_layer_error_response() {
        // The CSRF layer is built on `axum_sessions`; without a session layer
        // it cannot store the token and must respond with a server error.
        let app = Router::new()
            .route("/", get(handler))
            .layer(CsrfLayer::new());

        let response = app
            .oneshot(Request::builder().body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn invalid_token_str_error_response() {
        let layer = CsrfLayer::new();
        let response = Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::boxed(Body::empty()))
            .unwrap();
        // "\n" is not a valid HTTP header value, so attaching it as the token
        // header must fail and turn the response into a server error.
        let response = layer.response_with_token(response, "\n");

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}