axum_csrf_sync_pattern/lib.rs
1//! # Axum Synchronizer Token Pattern CSRF prevention
2//!
3//! This crate provides a CSRF protection layer and middleware for use with the [axum](https://docs.rs/axum/) web framework.
4//!
5//! The middleware implements the [CSRF Synchronizer Token Pattern](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#synchronizer-token-pattern)
6//! for AJAX backends and API endpoints as described in the OWASP CSRF prevention cheat sheet.
7//!
8//! ## Scope
9//!
10//! This middleware implements token transfer via [custom request headers](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#use-of-custom-request-headers).
11//!
12//! The middleware requires and is built upon [`axum_sessions`](https://docs.rs/axum-sessions/), which in turn uses [`async_session`](https://docs.rs/async-session/).
13//!
//! The [Same Origin Policy](https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy) prevents the custom request header from being set by foreign scripts.
15//!
16//! ### In which contexts should I use this middleware?
17//!
//! The goal of this middleware is to prevent cross-site request forgery attacks specifically in applications communicating with their backend by means of the JavaScript
//! [`fetch()` API](https://developer.mozilla.org/en-US/docs/Web/API/fetch) or classic [`XMLHttpRequest`](https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest),
//! traditionally called "AJAX".
21//!
//! The Synchronizer Token Pattern is especially useful in [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) contexts,
//! as the underlying session cookie must be secured and kept inaccessible to JavaScript, while the custom HTTP response header carrying the CSRF token can be exposed
//! using the CORS [`Access-Control-Expose-Headers`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers) HTTP response header.
25//!
//! While the [Same Origin Policy](https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy) commonly prevents custom request headers from being set on cross-origin requests,
//! the [`Access-Control-Allow-Headers`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers) CORS HTTP response header
//! can be used to specifically allow CORS requests to carry a required custom HTTP request header.
29//!
30//! This approach ensures that requests forged by auto-submitted forms or other data-submitting scripts from foreign origins are unable to add the required header.
31//!
32//! ### When should I use other CSRF protection patterns or libraries?
33//!
34//! Use other available middleware libraries if you plan on submitting classical HTML forms without the use of JavaScript, and if you do not send the form data across origins.
35//!
36//! ## Security
37//! ### Token randomness
38//!
//! The CSRF tokens are generated using [`rand::rngs::ThreadRng`](https://rust-random.github.io/rand/rand/rngs/struct.ThreadRng.html), which is considered cryptographically secure (CSPRNG).
40//! See ["Our RNGs"](https://rust-random.github.io/book/guide-rngs.html#cryptographically-secure-pseudo-random-number-generators-csprngs) for more.
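//!
//! Each token consists of 32 random bytes encoded as standard (padded) Base64, as can be seen in the
//! `regenerate_token` method in this file, so clients receive a 44-character header value.
//! The arithmetic can be sketched as follows; `base64_padded_len` is a hypothetical helper written
//! for this illustration only and is not part of the crate:
//!
//! ```rust
//! // Length of padded Base64 output: every started group of 3 input bytes becomes 4 characters.
//! // Hypothetical helper, for illustration only.
//! fn base64_padded_len(input_len: usize) -> usize {
//!     (input_len + 2) / 3 * 4
//! }
//! ```
//!
//! With this helper, `base64_padded_len(32)` evaluates to `44`.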
41//!
42//! ### Underlying session security
43//!
//! The security of the underlying session is paramount: the CSRF prevention methods applied can only be as secure as the session carrying the server-side token.
45//!
//! - When creating your [SessionLayer](https://docs.rs/axum-sessions/latest/axum_sessions/struct.SessionLayer.html), make sure to use at least 64 bytes of cryptographically secure randomness for the session secret.
47//! - Do not lower the secure defaults: Keep the session cookie's `secure` flag **on**.
48//! - Use the strictest possible same-site policy.
49//!
50//! ### CORS security
51//!
52//! If you need to provide and secure cross-site requests:
53//!
//! - Allow only your frontend's origin when configuring the [`CorsLayer`](https://docs.rs/tower-http/latest/tower_http/cors/struct.CorsLayer.html).
55//! - Allow only the headers you need. (At least the CSRF request token header.)
56//! - Only expose the headers you need. (At least the CSRF response token header.)
57//!
58//! ### No leaks of error details
59//!
60//! Errors are logged using [`tracing::error!`]. Error responses do not contain error details.
61//!
//! Use [`tower_http::trace::TraceLayer`](https://docs.rs/tower-http/latest/tower_http/trace/struct.TraceLayer.html) to capture and view traces.
63//!
64//! ## Safety
65//!
66//! This crate uses no `unsafe` code.
67//!
//! The layer and middleware functionality is tested. View the module source code to learn more.
69//!
70//! ## Usage
71//!
72//! See the [example projects](https://github.com/LeoniePhiline/axum-csrf-sync-pattern/tree/main/examples/) for same-site and cross-site usage.
73//!
74//! ### Same-site usage
75//!
78//! Configure your session and CSRF protection layer in your backend application:
79//!
//! ```rust
//! use axum::{
//!     body::Body,
//!     http::StatusCode,
//!     routing::{get, Router},
//! };
//! use axum_csrf_sync_pattern::{CsrfLayer, RegenerateToken};
//! use axum_sessions::{async_session::MemoryStore, SessionLayer};
//! use rand::RngCore;
//!
//! let mut secret = [0; 64];
//! rand::thread_rng().try_fill_bytes(&mut secret).unwrap();
//!
//! async fn handler() -> StatusCode {
//!     StatusCode::OK
//! }
//!
//! let app = Router::new()
//!     .route("/", get(handler).post(handler))
//!     .layer(
//!         CsrfLayer::new()
//!
//!         // Optionally, configure the layer with the following options:
//!
//!         // Default: RegenerateToken::PerSession
//!         .regenerate(RegenerateToken::PerUse)
//!         // Default: "X-CSRF-TOKEN"
//!         .request_header("X-Custom-Request-Header")
//!         // Default: "X-CSRF-TOKEN"
//!         .response_header("X-Custom-Response-Header")
//!         // Default: "_csrf_token"
//!         .session_key("_custom_session_key")
//!     )
//!     .layer(SessionLayer::new(MemoryStore::new(), &secret));
//!
//! // Use hyper to run `app` as service and expose on a local port or socket.
//!
//! # use tower::util::ServiceExt;
//! # tokio_test::block_on(async {
//! #     app.oneshot(
//! #         axum::http::Request::builder().body(axum::body::Body::empty()).unwrap()
//! #     ).await.unwrap();
//! # })
//! ```
124//!
125//! Receive the token and send same-site requests, using your custom header:
126//!
//! ```javascript
//! const test = async () => {
//!     // Receive CSRF token (Default response header name: 'X-CSRF-TOKEN')
//!     const token = (await fetch('/')).headers.get('X-Custom-Response-Header');
//!
//!     // Submit data using the token
//!     await fetch('/', {
//!         method: 'POST',
//!         headers: {
//!             'Content-Type': 'application/json',
//!             // Default request header name: 'X-CSRF-TOKEN'
//!             'X-Custom-Request-Header': token,
//!         },
//!         body: JSON.stringify({ /* ... */ }),
//!     });
//! };
//! ```
144//!
145//! ### CORS-enabled usage
146//!
149//! Configure your CORS layer, session and CSRF protection layer in your backend application:
150//!
//! ```rust
//! use axum::{
//!     body::Body,
//!     http::{header, Method, StatusCode},
//!     routing::{get, Router},
//! };
//! use axum_csrf_sync_pattern::{CsrfLayer, RegenerateToken};
//! use axum_sessions::{async_session::MemoryStore, SessionLayer};
//! use rand::RngCore;
//! use tower_http::cors::{AllowOrigin, CorsLayer};
//!
//! let mut secret = [0; 64];
//! rand::thread_rng().try_fill_bytes(&mut secret).unwrap();
//!
//! async fn handler() -> StatusCode {
//!     StatusCode::OK
//! }
//!
//! let app = Router::new()
//!     .route("/", get(handler).post(handler))
//!     .layer(
//!         // See example above for custom layer configuration.
//!         CsrfLayer::new()
//!     )
//!     .layer(SessionLayer::new(MemoryStore::new(), &secret))
//!     .layer(
//!         CorsLayer::new()
//!             .allow_origin(AllowOrigin::list(["https://www.example.com".parse().unwrap()]))
//!             .allow_methods([Method::GET, Method::POST])
//!             .allow_headers([header::CONTENT_TYPE, "X-CSRF-TOKEN".parse().unwrap()])
//!             .allow_credentials(true)
//!             .expose_headers(["X-CSRF-TOKEN".parse().unwrap()]),
//!     );
//!
//! // Use hyper to run `app` as service and expose on a local port or socket.
//!
//! # use tower::util::ServiceExt;
//! # tokio_test::block_on(async {
//! #     app.oneshot(
//! #         axum::http::Request::builder().body(axum::body::Body::empty()).unwrap()
//! #     ).await.unwrap();
//! # })
//! ```
194//!
195//! Receive the token and send cross-site requests, using your custom header:
196//!
//! ```javascript
//! const test = async () => {
//!     // Receive CSRF token
//!     const token = (await fetch('https://backend.example.com/', {
//!         credentials: 'include',
//!     })).headers.get('X-CSRF-TOKEN');
//!
//!     // Submit data using the token
//!     await fetch('https://backend.example.com/', {
//!         method: 'POST',
//!         headers: {
//!             'Content-Type': 'application/json',
//!             'X-CSRF-TOKEN': token,
//!         },
//!         credentials: 'include',
//!         body: JSON.stringify({ /* ... */ }),
//!     });
//! };
//! ```
216//!
217//! ## Contributing
218//!
219//! Pull requests are welcome!
220//!
221
222#![forbid(unsafe_code, future_incompatible)]
223#![deny(
224 missing_debug_implementations,
225 nonstandard_style,
226 missing_docs,
227 unreachable_pub,
228 missing_copy_implementations,
229 unused_qualifications
230)]
231
232use std::{
233 convert::Infallible,
234 future::Future,
235 pin::Pin,
236 task::{Context, Poll},
237};
238
239use axum::http::{self, HeaderValue, Request, StatusCode};
240use axum_core::response::{IntoResponse, Response};
241use axum_sessions::{async_session::Session, SessionHandle};
242use base64::prelude::*;
243use rand::RngCore;
244use tokio::sync::RwLockWriteGuard;
245use tower::Layer;
246
247/// Use `CsrfLayer::new()` to provide the middleware and configuration to axum's service stack.
248///
249/// Use the provided methods to configure details, such as when tokens are regenerated, what request and response
250/// headers should be named, and under which key the token should be stored in the session.
251#[derive(Clone, Copy, Debug)]
252pub struct CsrfLayer {
253 /// Configures when tokens are regenerated: Per session, per use or per request. See [`RegenerateToken`] for details.
254 pub regenerate_token: RegenerateToken,
255
256 /// Configures the request header name accepted by the middleware. Defaults to `"X-CSRF-TOKEN"`.
257 /// This header is set on your JavaScript or WASM requests originating from the browser.
258 pub request_header: &'static str,
259
260 /// Configures the response header name sent by the middleware. Defaults to `"X-CSRF-TOKEN"`.
261 /// This header is received by your JavaScript or WASM code and its name must be used to extract the token from the HTTP response.
262 pub response_header: &'static str,
263
264 /// Configures the key under which the middleware stores the server-side token in the session. Defaults to `"_csrf_token"`.
265 pub session_key: &'static str,
266}
267
268impl Default for CsrfLayer {
269 fn default() -> Self {
270 Self {
271 regenerate_token: Default::default(),
272 request_header: "X-CSRF-TOKEN",
273 response_header: "X-CSRF-TOKEN",
274 session_key: "_csrf_token",
275 }
276 }
277}
278
279impl CsrfLayer {
280 /// Create a new CSRF synchronizer token layer to inject into your middleware stack using
281 /// [`axum::Router::layer()`].
282 pub fn new() -> Self {
283 Self::default()
284 }
285
286 /// Configure when tokens are regenerated: Per session, per use or per request. See [`RegenerateToken`] for details.
287 pub fn regenerate(mut self, regenerate_token: RegenerateToken) -> Self {
288 self.regenerate_token = regenerate_token;
289
290 self
291 }
292
293 /// Configure a custom request header name accepted by the middleware. Defaults to `"X-CSRF-TOKEN"`.
294 ///
295 /// This header is set on your JavaScript or WASM requests originating from the browser.
296 pub fn request_header(mut self, request_header: &'static str) -> Self {
297 self.request_header = request_header;
298
299 self
300 }
301
302 /// Configure a custom response header name sent by the middleware. Defaults to `"X-CSRF-TOKEN"`.
303 ///
304 /// This header is received by your JavaScript or WASM code and its name must be used to extract the token from the HTTP response.
305 pub fn response_header(mut self, response_header: &'static str) -> Self {
306 self.response_header = response_header;
307
308 self
309 }
310
311 /// Configure a custom key under which the middleware stores the server-side token in the session. Defaults to `"_csrf_token"`.
312 pub fn session_key(mut self, session_key: &'static str) -> Self {
313 self.session_key = session_key;
314
315 self
316 }
317
318 fn regenerate_token(
319 &self,
320 session_write: &mut RwLockWriteGuard<Session>,
321 ) -> Result<String, Error> {
322 let mut buf = [0; 32];
323 rand::thread_rng().try_fill_bytes(&mut buf)?;
324 let token = BASE64_STANDARD.encode(buf);
325 session_write.insert(self.session_key, &token)?;
326
327 Ok(token)
328 }
329
330 fn response_with_token(&self, mut response: Response, server_token: &str) -> Response {
331 response.headers_mut().insert(
332 self.response_header,
333 match HeaderValue::from_str(server_token).map_err(Error::from) {
334 Ok(token_header) => token_header,
335 Err(error) => return error.into_response(),
336 },
337 );
338 response
339 }
340}
341
342impl<S> Layer<S> for CsrfLayer {
343 type Service = CsrfMiddleware<S>;
344
345 fn layer(&self, inner: S) -> Self::Service {
346 CsrfMiddleware::new(inner, *self)
347 }
348}
349
/// This enum is used with [`CsrfLayer::regenerate`] to determine
/// on which occasions the CSRF token should be regenerated.
///
/// You could understand these options as modes to choose a level of paranoia, depending on your application's requirements.
///
/// This paranoia level is a trade-off between ergonomics and security: more frequent
/// token invalidation requires more overhead for handling and renewing tokens on the client side,
/// as well as for retrying requests with a fresh token, should they fail.
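///
/// As a rough illustration, the three modes boil down to the decision sketched below. This is a
/// minimal, standalone sketch: `Mode` and `should_regenerate` are hypothetical stand-ins that
/// mirror the check performed by the middleware in this file; they are not part of this crate's API.
///
/// ```rust
/// // Hypothetical stand-in for `RegenerateToken`, so that the sketch needs no dependencies.
/// #[derive(Clone, Copy, Debug)]
/// enum Mode {
///     PerSession,
///     PerUse,
///     PerRequest,
/// }
///
/// // Decide whether a request that passed verification triggers a fresh token.
/// // `method_is_safe` is assumed to be `true` for `GET`, `HEAD`, `OPTIONS` and `TRACE`.
/// fn should_regenerate(mode: Mode, method_is_safe: bool) -> bool {
///     match mode {
///         // The token lives as long as the session.
///         Mode::PerSession => false,
///         // Only unsafe methods consume the token.
///         Mode::PerUse => !method_is_safe,
///         // Every request consumes the token.
///         Mode::PerRequest => true,
///     }
/// }
/// ```
///
/// For example, `should_regenerate(Mode::PerUse, true)` is `false`, since a `GET` request never consumes the token.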
358#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
359#[allow(clippy::enum_variant_names)]
360pub enum RegenerateToken {
361 /// Generate one CSRF token per session and use this token until the session ends.
362 ///
363 /// This is the default behavior and should work for most applications.
364 #[default]
365 PerSession,
    /// Regenerate the CSRF token after each use. A "use" is a verified request with an unsafe HTTP method
    /// (`POST`, `PUT`, `PATCH`, `DELETE`).
    ///
    /// CSRF tokens are not required for, and thus not invalidated by, requests
    /// using safe HTTP methods (`GET`, `HEAD`, `OPTIONS`, `TRACE`).
    /// Safety is decided by `http::Method::is_safe`; note that `CONNECT` is not a safe method.
    PerUse,
    /// Regenerate the CSRF token at each request, including safe HTTP methods (`GET`, `HEAD`, `OPTIONS`, `TRACE`).
    ///
    /// This behavior might require elaborate token handling on the client side,
    /// as any concurrent requests mean race conditions from the client's perspective,
    /// and each request's response yields a new token to be used on the subsequent request.
    PerRequest,
378}
379
380#[derive(thiserror::Error, Debug)]
381enum Error {
382 #[error("Random number generator error")]
383 Rng(#[from] rand::Error),
384
385 #[error("Serde JSON error")]
386 Serde(#[from] axum_sessions::async_session::serde_json::Error),
387
388 #[error("Session extension missing. Is `axum_sessions::SessionLayer` installed and layered around the `axum_csrf_sync_pattern::CsrfLayer`?")]
389 SessionLayerMissing,
390
391 #[error("Incoming CSRF token header was not valid ASCII")]
392 InvalidClientTokenHeader(#[from] http::header::ToStrError),
393
394 #[error("Invalid CSRF token when preparing response header")]
395 InvalidServerTokenHeader(#[from] http::header::InvalidHeaderValue),
396}
397
398impl IntoResponse for Error {
399 fn into_response(self) -> Response {
400 tracing::error!(?self);
401 StatusCode::INTERNAL_SERVER_ERROR.into_response()
402 }
403}
404
405/// This middleware is created by axum by applying the `CsrfLayer`.
406/// It verifies the CSRF token header on incoming requests, regenerates tokens as configured,
407/// and attaches the current token to the outgoing response.
408///
/// In detail, this middleware receives a CSRF token as the value of the `X-CSRF-TOKEN` HTTP request header
/// (unless configured with a different name) and compares it to the token stored in the session.
///
/// Upon response from the inner service, the session token is returned to the
/// client via the `X-CSRF-TOKEN` response header (unless configured with a different name).
415///
416/// Make sure to expose this header in your CORS configuration if necessary!
417///
418/// Requires and uses `axum_sessions`.
419///
/// Optionally regenerates the session's token after successful verification,
/// so that a new token must be used for each unsafe (`POST`, `PUT`, `PATCH`, `DELETE`) request.
/// Enable with [`RegenerateToken::PerUse`].
///
/// For maximum security, but severely reduced ergonomics, optionally regenerates the
/// session's token on each request, to keep the token validity as short as
/// possible. Enable with [`RegenerateToken::PerRequest`].
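///
/// The verification step can be pictured with the following standalone sketch. `Outcome` and
/// `verify` are hypothetical names used for illustration only; the actual middleware operates on
/// `http` request types and the session, and additionally attaches the current token to every response.
///
/// ```rust
/// // What happens to a request after the token check (hypothetical type).
/// #[derive(Debug, PartialEq, Eq)]
/// enum Outcome {
///     // Pass the request on to the inner service.
///     Forward,
///     // Reject the request with `403 Forbidden`.
///     Forbidden,
/// }
///
/// // Safe methods skip verification; unsafe methods need a header matching the session token.
/// fn verify(method_is_safe: bool, client_token: Option<&str>, server_token: &str) -> Outcome {
///     if method_is_safe {
///         return Outcome::Forward;
///     }
///     match client_token {
///         Some(token) if token == server_token => Outcome::Forward,
///         // Missing or mismatching header.
///         _ => Outcome::Forbidden,
///     }
/// }
/// ```
///
/// In this sketch, a `POST` without a token header, `verify(false, None, "abc")`, yields `Outcome::Forbidden`.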
427#[derive(Debug, Clone)]
428pub struct CsrfMiddleware<S> {
429 inner: S,
430 layer: CsrfLayer,
431}
432
433impl<S> CsrfMiddleware<S> {
    /// Create a new middleware from an inner [`tower::Service`] (axum-specific bounds, such as `Infallible` errors, apply!) and a [`CsrfLayer`].
    /// Commonly, the middleware is created by the [`tower::Layer`] and never manually.
436 pub fn new(inner: S, layer: CsrfLayer) -> Self {
437 CsrfMiddleware { inner, layer }
438 }
439
440 /// Create a new CSRF synchronizer token layer.
441 /// Equivalent to calling [`CsrfLayer::new()`].
442 pub fn layer() -> CsrfLayer {
443 CsrfLayer::default()
444 }
445}
446
447impl<S, B: Send + 'static> tower::Service<Request<B>> for CsrfMiddleware<S>
448where
449 S: tower::Service<Request<B>, Response = Response, Error = Infallible> + Send + Clone + 'static,
450 S::Future: Send,
451{
452 type Response = S::Response;
453 type Error = Infallible;
454 type Future =
455 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
456
457 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
458 self.inner.poll_ready(cx)
459 }
460
461 fn call(&mut self, req: Request<B>) -> Self::Future {
462 let clone = self.inner.clone();
463 let mut inner = std::mem::replace(&mut self.inner, clone);
464 let layer = self.layer;
465 Box::pin(async move {
466 let session_handle = match req
467 .extensions()
468 .get::<SessionHandle>()
469 .ok_or(Error::SessionLayerMissing)
470 {
471 Ok(session_handle) => session_handle,
472 Err(error) => return Ok(error.into_response()),
473 };
474
            // Extract the CSRF server side token from the session; create a new one if none has been set yet.
            // Regeneration "per use" or "per request" happens further below, after the incoming token has been verified.
477 let mut session_write = session_handle.write().await;
478 let mut server_token = match session_write.get::<String>(layer.session_key) {
479 Some(token) => token,
480 None => match layer.regenerate_token(&mut session_write) {
481 Ok(token) => token,
482 Err(error) => return Ok(error.into_response()),
483 },
484 };
485
486 if !req.method().is_safe() {
487 // Verify incoming CSRF token for unsafe request methods.
488 let client_token = {
489 match req.headers().get(layer.request_header) {
490 Some(token) => token,
491 None => {
492 tracing::warn!("{} header missing!", layer.request_header);
493 return Ok(layer.response_with_token(
494 StatusCode::FORBIDDEN.into_response(),
495 &server_token,
496 ));
497 }
498 }
499 };
500
501 let client_token = match client_token.to_str().map_err(Error::from) {
502 Ok(token) => token,
503 Err(error) => {
504 return Ok(layer.response_with_token(error.into_response(), &server_token))
505 }
506 };
507 if client_token != server_token {
508 tracing::warn!("{} header mismatch!", layer.request_header);
509 return Ok(layer.response_with_token(
510 (StatusCode::FORBIDDEN).into_response(),
511 &server_token,
512 ));
513 }
514 }
515
516 // Create new token if configured to regenerate per each request,
517 // or if configured to regenerate per use and just used.
518 if layer.regenerate_token == RegenerateToken::PerRequest
519 || (!req.method().is_safe() && layer.regenerate_token == RegenerateToken::PerUse)
520 {
521 server_token = match layer.regenerate_token(&mut session_write) {
522 Ok(token) => token,
523 Err(error) => {
524 return Ok(layer.response_with_token(error.into_response(), &server_token))
525 }
526 };
527 }
528
529 drop(session_write);
530
531 let response = inner.call(req).await.into_response();
532
            // Add the token response header (`X-CSRF-TOKEN` unless configured otherwise).
534 Ok(layer.response_with_token(response, &server_token))
535 })
536 }
537}
538
539#[cfg(test)]
540mod tests {
541 use std::convert::Infallible;
542
543 use axum::{body::Body, routing::get, Router};
544 use axum_core::response::{IntoResponse, Response};
545 use axum_sessions::{async_session::MemoryStore, extractors::ReadableSession, SessionLayer};
546 use http::{
547 header::{COOKIE, SET_COOKIE},
548 Method, Request, StatusCode,
549 };
550 use tower::{Service, ServiceExt};
551
552 use super::*;
553
554 async fn handler() -> Result<Response, Infallible> {
555 Ok((
556 StatusCode::OK,
557 "The default test success response has a body",
558 )
559 .into_response())
560 }
561
562 fn session_layer() -> SessionLayer<MemoryStore> {
563 let mut secret = [0; 64];
564 rand::thread_rng().try_fill_bytes(&mut secret).unwrap();
565 SessionLayer::new(MemoryStore::new(), &secret)
566 }
567
568 fn app(csrf_layer: CsrfLayer) -> Router {
569 Router::new()
570 .route("/", get(handler).post(handler))
571 .layer(csrf_layer)
572 .layer(session_layer())
573 }
574
575 #[tokio::test]
576 async fn get_without_token_succeeds() {
577 let request = Request::builder()
578 .method(Method::GET)
579 .body(Body::empty())
580 .unwrap();
581
582 let response = app(CsrfLayer::new()).oneshot(request).await.unwrap();
583
584 assert_eq!(response.status(), StatusCode::OK);
585
586 let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
587 assert_eq!(BASE64_STANDARD.decode(client_token).unwrap().len(), 32);
588 }
589
590 #[tokio::test]
591 async fn post_without_token_fails() {
592 let request = Request::builder()
593 .method(Method::POST)
594 .body(Body::empty())
595 .unwrap();
596 let response = app(CsrfLayer::new()).oneshot(request).await.unwrap();
597
598 assert_eq!(response.status(), StatusCode::FORBIDDEN);
599
600 // Assert: Response must contain token even on request token failure.
601 let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
602 assert_eq!(BASE64_STANDARD.decode(client_token).unwrap().len(), 32);
603 }
604
605 #[tokio::test]
606 async fn session_token_remains_valid() {
607 let mut app = app(CsrfLayer::new().regenerate(RegenerateToken::PerSession));
608
609 // Get CSRF token
610 let response = app
611 .ready()
612 .await
613 .unwrap()
614 .call(Request::builder().body(Body::empty()).unwrap())
615 .await
616 .unwrap();
617
618 assert_eq!(response.status(), StatusCode::OK);
619
620 // Tokens are bound to the session - must re-use on each consecutive request.
621 let session_cookie = response.headers().get(SET_COOKIE).unwrap().clone();
622
623 let initial_client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
624 assert_eq!(
625 BASE64_STANDARD.decode(initial_client_token).unwrap().len(),
626 32
627 );
628
629 // Use CSRF token for POST request
630 let response = app
631 .ready()
632 .await
633 .unwrap()
634 .call(
635 Request::builder()
636 .method(Method::POST)
637 .header("X-CSRF-TOKEN", initial_client_token)
638 .header(COOKIE, session_cookie.clone())
639 .body(Body::empty())
640 .unwrap(),
641 )
642 .await
643 .unwrap();
644
645 assert_eq!(response.status(), StatusCode::OK);
646
647 // Assert token has not been changed after POST request
648 let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
649 assert_eq!(client_token, initial_client_token);
650
651 // Attempt token re-use for a second POST request
652 let response = app
653 .ready()
654 .await
655 .unwrap()
656 .call(
657 Request::builder()
658 .method(Method::POST)
659 .header("X-CSRF-TOKEN", initial_client_token)
660 .header(COOKIE, session_cookie)
661 .body(Body::empty())
662 .unwrap(),
663 )
664 .await
665 .unwrap();
666
667 assert_eq!(response.status(), StatusCode::OK);
668
669 // Assert token has not been changed after POST request
670 let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
671 assert_eq!(client_token, initial_client_token);
672 }
673
674 #[tokio::test]
675 async fn single_use_token_is_regenerated() {
676 let mut app = app(CsrfLayer::new().regenerate(RegenerateToken::PerUse));
677
678 // Get single-use CSRF token
679 let response = app
680 .ready()
681 .await
682 .unwrap()
683 .call(Request::builder().body(Body::empty()).unwrap())
684 .await
685 .unwrap();
686
687 assert_eq!(response.status(), StatusCode::OK);
688
689 // Tokens are bound to the session - must re-use on each consecutive request.
690 let session_cookie = response.headers().get(SET_COOKIE).unwrap().clone();
691
692 let initial_client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
693 assert_eq!(
694 BASE64_STANDARD.decode(initial_client_token).unwrap().len(),
695 32
696 );
697
698 // Use CSRF token for POST request
699 let response = app
700 .ready()
701 .await
702 .unwrap()
703 .call(
704 Request::builder()
705 .method(Method::POST)
706 .header("X-CSRF-TOKEN", initial_client_token)
707 .header(COOKIE, session_cookie.clone())
708 .body(Body::empty())
709 .unwrap(),
710 )
711 .await
712 .unwrap();
713
714 assert_eq!(response.status(), StatusCode::OK);
715
716 // Assert token has been regenerated after POST request
717 let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
718 assert_ne!(client_token, initial_client_token);
719
720 // Attempt token re-use for a second POST request
721 let response = app
722 .ready()
723 .await
724 .unwrap()
725 .call(
726 Request::builder()
727 .method(Method::POST)
728 .header("X-CSRF-TOKEN", initial_client_token)
729 .header(COOKIE, session_cookie)
730 .body(Body::empty())
731 .unwrap(),
732 )
733 .await
734 .unwrap();
735
736 assert_eq!(response.status(), StatusCode::FORBIDDEN);
737
        // Assert the rejected response carries the current token, not the stale initial one
739 let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
740 assert_ne!(client_token, initial_client_token);
741 }
742
743 #[tokio::test]
744 async fn single_request_token_is_regenerated() {
745 let mut app = app(CsrfLayer::new().regenerate(RegenerateToken::PerRequest));
746
        // Get single-request CSRF token
748 let response = app
749 .ready()
750 .await
751 .unwrap()
752 .call(Request::builder().body(Body::empty()).unwrap())
753 .await
754 .unwrap();
755
756 assert_eq!(response.status(), StatusCode::OK);
757
758 // Tokens are bound to the session - must re-use on each consecutive request.
759 let session_cookie = response.headers().get(SET_COOKIE).unwrap().clone();
760
761 let initial_client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
762 assert_eq!(
763 BASE64_STANDARD.decode(initial_client_token).unwrap().len(),
764 32
765 );
766
767 // Perform another GET request
768 let response = app
769 .ready()
770 .await
771 .unwrap()
772 .call(
773 Request::builder()
774 .method(Method::GET)
775 .header(COOKIE, session_cookie.clone())
776 .body(Body::empty())
777 .unwrap(),
778 )
779 .await
780 .unwrap();
781
782 assert_eq!(response.status(), StatusCode::OK);
783
784 // Assert token has been regenerated after GET request
785 let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
786 assert_ne!(client_token, initial_client_token);
787
788 // Attempt using single-request token for POST request
789 let response = app
790 .ready()
791 .await
792 .unwrap()
793 .call(
794 Request::builder()
795 .method(Method::POST)
796 .header("X-CSRF-TOKEN", client_token)
797 .header(COOKIE, session_cookie)
798 .body(Body::empty())
799 .unwrap(),
800 )
801 .await
802 .unwrap();
803
804 assert_eq!(response.status(), StatusCode::OK);
805
806 // Assert token has been regenerated after POST request
807 let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
808 assert_ne!(client_token, initial_client_token);
809 }
810
811 #[tokio::test]
812 async fn accepts_custom_request_header() {
813 let mut app = app(CsrfLayer::new().request_header("X-Custom-Token-Request-Header"));
814
815 // Get CSRF token
816 let response = app
817 .ready()
818 .await
819 .unwrap()
820 .call(Request::builder().body(Body::empty()).unwrap())
821 .await
822 .unwrap();
823
824 assert_eq!(response.status(), StatusCode::OK);
825
826 // Tokens are bound to the session - must re-use on each consecutive request.
827 let session_cookie = response.headers().get(SET_COOKIE).unwrap().clone();
828
829 let client_token = response.headers().get("X-CSRF-TOKEN").unwrap();
830 assert_eq!(BASE64_STANDARD.decode(client_token).unwrap().len(), 32);
831
832 // Use CSRF token for POST request
833 let response = app
834 .ready()
835 .await
836 .unwrap()
837 .call(
838 Request::builder()
839 .method(Method::POST)
840 .header("X-Custom-Token-Request-Header", client_token)
841 .header(COOKIE, session_cookie.clone())
842 .body(Body::empty())
843 .unwrap(),
844 )
845 .await
846 .unwrap();
847
848 assert_eq!(response.status(), StatusCode::OK);
849 }
850
851 #[tokio::test]
852 async fn sends_custom_response_header() {
853 // Get CSRF token
854 let response = app(CsrfLayer::new().response_header("X-Custom-Token-Response-Header"))
855 .oneshot(Request::builder().body(Body::empty()).unwrap())
856 .await
857 .unwrap();
858
859 assert_eq!(response.status(), StatusCode::OK);
860
861 let client_token = response
862 .headers()
863 .get("X-Custom-Token-Response-Header")
864 .unwrap();
865 assert_eq!(BASE64_STANDARD.decode(client_token).unwrap().len(), 32);
866 }
867
868 #[tokio::test]
869 async fn uses_custom_session_key() {
870 // Custom handler asserting the layer's configured session key is set,
871 // and its value looks like a CSRF token.
872 async fn extract_session(session: ReadableSession) -> StatusCode {
873 let session_csrf_token: String = session.get("custom_session_key").unwrap();
874
875 assert_eq!(
876 BASE64_STANDARD.decode(session_csrf_token).unwrap().len(),
877 32
878 );
879 StatusCode::OK
880 }
881
882 let app = Router::new()
883 .route("/", get(extract_session))
884 .layer(CsrfLayer::new().session_key("custom_session_key"))
885 .layer(session_layer());
886
887 let response = app
888 .oneshot(Request::builder().body(Body::empty()).unwrap())
889 .await
890 .unwrap();
891
892 assert_eq!(response.status(), StatusCode::OK);
893 }
894
895 #[tokio::test]
896 async fn missing_session_layer_error_response() {
897 let app = Router::new()
898 .route("/", get(handler))
899 .layer(CsrfLayer::new());
900
901 let response = app
902 .oneshot(Request::builder().body(Body::empty()).unwrap())
903 .await
904 .unwrap();
905
906 assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
907 }
908
909 #[tokio::test]
910 async fn invalid_token_str_error_response() {
911 let layer = CsrfLayer::new();
912 let response = Response::builder()
913 .status(StatusCode::OK)
914 .body(axum::body::boxed(Body::empty()))
915 .unwrap();
916 let response = layer.response_with_token(response, "\n");
917
918 assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
919 }
920}