1use std::future::Future;
46use std::pin::Pin;
47use std::sync::Arc;
48use std::task::{Context, Poll};
49
50use axum::extract::{FromRequestParts, OptionalFromRequestParts};
51use axum::http::{Request, Response, StatusCode};
52use http::header::HeaderName;
53
54use tower::{Layer, Service};
55use uuid::Uuid;
56
57use super::config::CsrfConfig;
58
/// Message used for every CSRF rejection: it is the plain-text 403 body and is
/// also passed to the problem-details builder for the JSON 403 response.
const CSRF_FORBIDDEN_MESSAGE: &str = "CSRF token missing or invalid";
61
/// Name of the form field in which the CSRF middleware looks for the token.
///
/// `CsrfService` inserts this into the request extensions on every request,
/// using the configured `form_field`. Handlers can extract it to learn which
/// field name to submit the token under in `application/x-www-form-urlencoded`
/// bodies.
#[derive(Clone, Debug)]
pub struct CsrfFormField(pub String);
69
/// CSRF token associated with the current request.
///
/// `CsrfService` inserts it into the request extensions. Its value is either
/// the token already held in the client's CSRF cookie, or a freshly generated
/// UUID when the cookie was absent or unusable. When signing keys are
/// configured, a generated value has the form `{uuid}.{signature}`.
#[derive(Clone, Debug)]
pub struct CsrfToken(String);

impl CsrfToken {
    /// Borrow the token text, e.g. to embed it in a form or a request header.
    #[must_use]
    pub fn token(&self) -> &str {
        self.0.as_str()
    }

    /// Wrap an arbitrary string as a token (test-only constructor).
    #[cfg(test)]
    pub(crate) const fn new(token: String) -> Self {
        CsrfToken(token)
    }
}

impl std::fmt::Display for CsrfToken {
    /// Formats as the bare token text, with no wrapping or escaping.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.token())
    }
}
113
114impl<S> FromRequestParts<S> for CsrfToken
115where
116 S: Send + Sync,
117{
118 type Rejection = (StatusCode, &'static str);
119
120 async fn from_request_parts(
121 parts: &mut axum::http::request::Parts,
122 _state: &S,
123 ) -> Result<Self, Self::Rejection> {
124 parts.extensions.get::<Self>().cloned().ok_or((
125 StatusCode::INTERNAL_SERVER_ERROR,
126 "CSRF token not found in request extensions. Is CsrfLayer enabled?",
127 ))
128 }
129}
130
131impl<S> OptionalFromRequestParts<S> for CsrfToken
132where
133 S: Send + Sync,
134{
135 type Rejection = std::convert::Infallible;
136
137 async fn from_request_parts(
138 parts: &mut axum::http::request::Parts,
139 _state: &S,
140 ) -> Result<Option<Self>, Self::Rejection> {
141 Ok(parts.extensions.get::<Self>().cloned())
142 }
143}
144
145impl<S> FromRequestParts<S> for CsrfFormField
146where
147 S: Send + Sync,
148{
149 type Rejection = (StatusCode, &'static str);
150
151 async fn from_request_parts(
152 parts: &mut axum::http::request::Parts,
153 _state: &S,
154 ) -> Result<Self, Self::Rejection> {
155 parts.extensions.get::<Self>().cloned().ok_or((
156 StatusCode::INTERNAL_SERVER_ERROR,
157 "CSRF form field not found in request extensions. Is CsrfLayer enabled?",
158 ))
159 }
160}
161
162impl<S> OptionalFromRequestParts<S> for CsrfFormField
163where
164 S: Send + Sync,
165{
166 type Rejection = std::convert::Infallible;
167
168 async fn from_request_parts(
169 parts: &mut axum::http::request::Parts,
170 _state: &S,
171 ) -> Result<Option<Self>, Self::Rejection> {
172 Ok(parts.extensions.get::<Self>().cloned())
173 }
174}
175
/// Resolved CSRF settings, shared through an `Arc` by a [`CsrfLayer`] and every
/// [`CsrfService`] it creates.
#[derive(Debug, Clone)]
struct CsrfSettings {
    /// Name of the cookie that stores the CSRF token.
    cookie_name: String,
    /// Request header checked for the submitted token.
    token_header: HeaderName,
    /// Field name looked up in `application/x-www-form-urlencoded` bodies.
    form_field: String,
    /// Methods that skip verification (they still receive a token and cookie).
    safe_methods: Vec<http::Method>,
    /// Path prefixes that skip verification (plain `starts_with` match).
    exempt_paths: Vec<String>,
    /// When set, generated tokens are signed as `{uuid}.{signature}` and
    /// incoming cookie tokens must carry a valid signature.
    signing_keys: Option<Arc<crate::security::config::ResolvedSigningKeys>>,
}
186
/// Tower [`Layer`] that wraps a service with CSRF protection.
///
/// Build it with [`CsrfLayer::from_config`], and optionally enable signed
/// tokens with [`CsrfLayer::with_signing_keys`].
#[derive(Clone, Debug)]
pub struct CsrfLayer {
    /// Settings shared with every service this layer produces.
    settings: Arc<CsrfSettings>,
}
194
195impl CsrfLayer {
196 #[must_use]
198 pub fn from_config(config: &CsrfConfig) -> Self {
199 let safe_methods = config
200 .safe_methods
201 .iter()
202 .filter_map(|m| m.parse::<http::Method>().ok())
203 .collect();
204
205 let token_header = config
206 .token_header
207 .parse::<HeaderName>()
208 .unwrap_or_else(|_| HeaderName::from_static("x-csrf-token"));
209
210 Self {
211 settings: Arc::new(CsrfSettings {
212 cookie_name: config.cookie_name.clone(),
213 token_header,
214 form_field: config.form_field.clone(),
215 safe_methods,
216 exempt_paths: config.exempt_paths.clone(),
217 signing_keys: None,
218 }),
219 }
220 }
221
222 #[must_use]
228 pub fn with_signing_keys(
229 mut self,
230 keys: Arc<crate::security::config::ResolvedSigningKeys>,
231 ) -> Self {
232 Arc::make_mut(&mut self.settings).signing_keys = Some(keys);
233 self
234 }
235}
236
237impl<S> Layer<S> for CsrfLayer {
238 type Service = CsrfService<S>;
239
240 fn layer(&self, inner: S) -> Self::Service {
241 CsrfService {
242 inner,
243 settings: Arc::clone(&self.settings),
244 }
245 }
246}
247
/// Service produced by [`CsrfLayer`].
///
/// It issues the CSRF cookie when needed and exposes [`CsrfToken`] and
/// [`CsrfFormField`] to handlers. It rejects requests with `403 Forbidden` when
/// the method is not safe, the path is not exempt, and the submitted token does
/// not match the cookie.
#[derive(Clone, Debug)]
pub struct CsrfService<S> {
    /// The wrapped service.
    inner: S,
    /// Settings shared with the layer that created this service.
    settings: Arc<CsrfSettings>,
}
254
255use subtle::{Choice, ConstantTimeEq};
256
257#[inline(never)]
263fn constant_time_eq(a: &str, b: &str) -> bool {
264 let a = a.as_bytes();
265 let b = b.as_bytes();
266
267 let len_eq = a.len().ct_eq(&b.len());
269
270 let mut bytes_eq = Choice::from(1u8);
277 for (i, &a_byte) in a.iter().enumerate() {
278 let b_byte = *b.get(i).unwrap_or(&0xFF);
279 bytes_eq &= a_byte.ct_eq(&b_byte);
280 }
281
282 (len_eq & bytes_eq).into()
283}
284
285fn extract_cookie_token(req_headers: &http::HeaderMap, cookie_name: &str) -> Option<String> {
287 let mut found_token = None;
288
289 for cookie_header in &req_headers.get_all(http::header::COOKIE) {
290 let Ok(cookie_str) = cookie_header.to_str() else {
291 continue;
292 };
293
294 for pair in cookie_str.split(';') {
295 let pair = pair.trim();
296 let Some((name, value)) = pair.split_once('=') else {
297 continue;
298 };
299
300 if name.trim() != cookie_name {
301 continue;
302 }
303
304 if found_token.is_some() {
305 return None;
309 }
310
311 found_token = Some(value.trim().to_owned());
312 }
313 }
314
315 found_token
316}
317
318impl<S, ResBody> Service<Request<axum::body::Body>> for CsrfService<S>
319where
320 S: Service<Request<axum::body::Body>, Response = Response<ResBody>> + Clone + Send + 'static,
321 S::Future: Send + 'static,
322 S::Error: Send + 'static,
323 ResBody: From<&'static str> + From<String> + Default + Send + 'static,
324{
325 type Response = S::Response;
326 type Error = S::Error;
327 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
328
329 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
330 self.inner.poll_ready(cx)
331 }
332
333 fn call(&mut self, mut req: Request<axum::body::Body>) -> Self::Future {
334 let path = req.uri().path();
335 let is_exempt = self
336 .settings
337 .exempt_paths
338 .iter()
339 .any(|prefix| path.starts_with(prefix.as_str()));
340 let is_safe = is_exempt || self.settings.safe_methods.contains(req.method());
341 let raw_cookie_token = extract_cookie_token(req.headers(), &self.settings.cookie_name);
342
343 let cookie_token = match (&raw_cookie_token, &self.settings.signing_keys) {
347 (Some(tok), Some(_)) if !validate_cookie_token_hmac(tok, &self.settings) => None,
348 _ => raw_cookie_token.clone(),
349 };
350
351 let token = cookie_token.clone().unwrap_or_else(|| {
354 let raw = Uuid::new_v4().to_string();
355 if let Some(keys) = &self.settings.signing_keys {
356 let sig = keys.sign(raw.as_bytes());
357 format!("{raw}.{sig}")
358 } else {
359 raw
360 }
361 });
362
363 req.extensions_mut().insert(CsrfToken(token.clone()));
365 req.extensions_mut()
366 .insert(CsrfFormField(self.settings.form_field.clone()));
367
368 let set_cookie = if cookie_token.is_none() {
370 Some(format!(
371 "{}={}; Path=/; SameSite=Lax; HttpOnly",
372 self.settings.cookie_name, token
373 ))
374 } else {
375 None
376 };
377
378 let settings = Arc::clone(&self.settings);
379 let mut inner = self.inner.clone();
380
381 std::mem::swap(&mut self.inner, &mut inner);
383
384 Box::pin(async move {
385 if !is_safe && !verify_csrf_token(&mut req, &settings, cookie_token.as_deref()).await {
386 let request_id = req
387 .extensions()
388 .get::<crate::middleware::RequestId>()
389 .map(std::string::ToString::to_string);
390 let instance = Some(req.uri().path().to_owned());
391 if wants_problem_details(req.headers()) {
392 return Ok(csrf_problem_response(request_id, instance));
393 }
394
395 let mut response = Response::new(ResBody::from(CSRF_FORBIDDEN_MESSAGE));
396 *response.status_mut() = StatusCode::FORBIDDEN;
397 response.headers_mut().insert(
398 http::header::CONTENT_TYPE,
399 http::HeaderValue::from_static("text/plain; charset=utf-8"),
400 );
401 return Ok(response);
402 }
403
404 let mut response = inner.call(req).await?;
406
407 if let Some(cookie) = set_cookie
408 && let Ok(val) = http::header::HeaderValue::from_str(&cookie)
409 {
410 response.headers_mut().append(http::header::SET_COOKIE, val);
411 }
412
413 Ok(response)
414 })
415 }
416}
417
/// Whether a CSRF rejection should be rendered as problem-details JSON.
///
/// Returns `true` unless `accept_prefers_html` reports that the request's
/// headers prefer HTML. In that case the CSRF middleware sends its plain-text
/// 403 body instead.
fn wants_problem_details(headers: &http::HeaderMap) -> bool {
    !crate::middleware::error_page_filter::accept_prefers_html(headers)
}
421
422fn csrf_problem_response<ResBody: From<String> + Default>(
423 request_id: Option<String>,
424 instance: Option<String>,
425) -> Response<ResBody> {
426 let mut problem = crate::error::problem_details(
427 StatusCode::FORBIDDEN,
428 CSRF_FORBIDDEN_MESSAGE.to_owned(),
429 None,
430 Some("https://autumn.dev/problems/csrf"),
431 request_id,
432 instance,
433 true,
434 );
435 "autumn.csrf".clone_into(&mut problem.code);
436 let body = crate::error::problem_details_to_json_string(&problem);
437
438 Response::builder()
439 .status(StatusCode::FORBIDDEN)
440 .header(http::header::CONTENT_TYPE, "application/problem+json")
441 .body(ResBody::from(body))
442 .unwrap_or_default()
443}
444
445fn validate_cookie_token_hmac(cookie_token: &str, settings: &CsrfSettings) -> bool {
450 let Some(keys) = &settings.signing_keys else {
451 return true; };
453 let Some((uuid_part, sig)) = cookie_token.split_once('.') else {
455 return false; };
457 keys.verify(uuid_part.as_bytes(), sig)
458}
459
460async fn verify_csrf_token(
461 req: &mut Request<axum::body::Body>,
462 settings: &CsrfSettings,
463 cookie_token: Option<&str>,
464) -> bool {
465 let mut token_found = false;
466
467 let header_token = req
469 .headers()
470 .get(&settings.token_header)
471 .and_then(|v| v.to_str().ok());
472
473 if let (Some(c), Some(h)) = (cookie_token, header_token)
474 && !c.is_empty()
475 && !h.is_empty()
476 && validate_cookie_token_hmac(c, settings)
477 && constant_time_eq(c, h)
478 {
479 token_found = true;
480 }
481
482 if token_found {
483 return true;
484 }
485
486 let content_type = req
488 .headers()
489 .get(http::header::CONTENT_TYPE)
490 .and_then(|v| v.to_str().ok())
491 .unwrap_or_default();
492
493 if !content_type.starts_with("application/x-www-form-urlencoded") {
494 return false;
495 }
496
497 let body = std::mem::replace(req.body_mut(), axum::body::Body::empty());
499
500 let bytes = axum::body::to_bytes(body, 2 * 1024 * 1024)
502 .await
503 .unwrap_or_else(|_| axum::body::Bytes::new());
504
505 for (key, value) in url::form_urlencoded::parse(&bytes) {
506 if key == settings.form_field {
507 if let Some(c) = cookie_token
508 && !c.is_empty()
509 && !value.is_empty()
510 && validate_cookie_token_hmac(c, settings)
511 && constant_time_eq(c, value.as_ref())
512 {
513 token_found = true;
514 }
515 break;
516 }
517 }
518
519 *req.body_mut() = axum::body::Body::from(bytes);
521
522 token_found
523}
524
#[cfg(test)]
mod tests {
    //! Behavioural tests for the CSRF middleware: safe and unsafe methods,
    //! exempt paths, header and url-encoded form tokens, cookie parsing edge
    //! cases, configuration fallbacks, and HMAC-signed tokens (including key
    //! rotation).

    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::routing::{get, post};
    use tower::ServiceExt;

    /// CSRF config with protection enabled and every other setting at its default.
    fn default_csrf_config() -> CsrfConfig {
        CsrfConfig {
            enabled: true,
            ..Default::default()
        }
    }

    /// Signing keys built from a fixed 32-character test secret, with no previous keys.
    fn test_signing_keys() -> Arc<crate::security::config::ResolvedSigningKeys> {
        use crate::security::config::{SigningSecretConfig, resolve_signing_keys};

        Arc::new(resolve_signing_keys(&SigningSecretConfig {
            secret: Some("k".repeat(32)),
            previous_secrets: vec![],
        }))
    }

    #[tokio::test]
    async fn safe_method_passes_without_token() {
        let app = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        let response = app
            .oneshot(
                Request::builder()
                    .method("GET")
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn safe_method_sets_csrf_cookie() {
        let app = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        let response = app
            .oneshot(
                Request::builder()
                    .method("GET")
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        let set_cookie = response
            .headers()
            .get("set-cookie")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(set_cookie.starts_with("autumn-csrf="));
        assert!(set_cookie.contains("HttpOnly"));
    }

    #[tokio::test]
    async fn post_without_token_returns_403() {
        let app = Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/submit")
                    .header(http::header::ACCEPT, "text/html")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn forbidden_response_has_clear_error_body() {
        let app = Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        // `Accept: text/html` selects the plain-text rejection body.
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/submit")
                    .header(http::header::ACCEPT, "text/html")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
        assert_eq!(
            response
                .headers()
                .get(http::header::CONTENT_TYPE)
                .map(|v| v.to_str().unwrap_or_default()),
            Some("text/plain; charset=utf-8")
        );
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let text = std::str::from_utf8(&body).unwrap();
        assert!(
            text.contains("CSRF"),
            "expected CSRF error message, got: {text:?}"
        );
    }

    #[tokio::test]
    async fn exempt_path_skips_csrf_validation() {
        let config = CsrfConfig {
            enabled: true,
            exempt_paths: vec!["/api/".to_string()],
            ..Default::default()
        };
        let app = Router::new()
            .route("/api/items", post(|| async { "created" }))
            .route("/form/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&config));

        // Under the exempt prefix: no token required.
        let response = app
            .clone()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/api/items")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // Outside the exempt prefix: still protected.
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/form/submit")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_valid_token_passes() {
        let token = Uuid::new_v4().to_string();
        let app = Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/submit")
                    .header("Cookie", format!("autumn-csrf={token}"))
                    .header("X-CSRF-Token", &token)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn post_with_url_encoded_token_passes() {
        // The form value is percent-decoded before being compared with the raw cookie.
        let raw_token = "abc+123/xyz=456";
        let encoded_token = "abc%2B123%2Fxyz%3D456";
        let app = Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/submit")
                    .header("Cookie", format!("autumn-csrf={raw_token}"))
                    .header("Content-Type", "application/x-www-form-urlencoded")
                    .body(Body::from(format!("_csrf={encoded_token}")))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn post_with_mismatched_token_returns_403() {
        let cookie_token = Uuid::new_v4().to_string();
        let header_token = Uuid::new_v4().to_string();
        let app = Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/submit")
                    .header("Cookie", format!("autumn-csrf={cookie_token}"))
                    .header("X-CSRF-Token", &header_token)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn csrf_token_extractor_works() {
        async fn handler(csrf: CsrfToken) -> String {
            csrf.token().to_owned()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let token_str = String::from_utf8(body.to_vec()).unwrap();
        // Without signing keys, the generated token is a bare UUID.
        assert!(Uuid::parse_str(&token_str).is_ok());
    }

    #[test]
    fn extract_cookie_from_header() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::COOKIE,
            "autumn-csrf=abc123; other=xyz".parse().unwrap(),
        );
        assert_eq!(
            extract_cookie_token(&headers, "autumn-csrf"),
            Some("abc123".to_owned())
        );
    }

    #[test]
    fn missing_cookie_returns_none() {
        let headers = http::HeaderMap::new();
        assert_eq!(extract_cookie_token(&headers, "autumn-csrf"), None);
    }

    #[test]
    fn extract_cookie_rejects_multiple_cookies() {
        // Duplicate within a single Cookie header.
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::COOKIE,
            "autumn-csrf=abc123; autumn-csrf=xyz456".parse().unwrap(),
        );
        assert_eq!(extract_cookie_token(&headers, "autumn-csrf"), None);

        // Duplicate spread across two Cookie headers.
        let mut headers2 = http::HeaderMap::new();
        headers2.append(http::header::COOKIE, "autumn-csrf=abc123".parse().unwrap());
        headers2.append(http::header::COOKIE, "autumn-csrf=xyz456".parse().unwrap());
        assert_eq!(extract_cookie_token(&headers2, "autumn-csrf"), None);
    }

    #[test]
    fn extract_cookie_ignores_malformed_cookies() {
        // No `=` separator: not a cookie pair at all.
        let mut headers = http::HeaderMap::new();
        headers.insert(http::header::COOKIE, "autumn-csrf abc123".parse().unwrap());
        assert_eq!(extract_cookie_token(&headers, "autumn-csrf"), None);

        // Whitespace around names and values is tolerated.
        headers.insert(
            http::header::COOKIE,
            " autumn-csrf = abc123 ; other=xyz".parse().unwrap(),
        );
        assert_eq!(
            extract_cookie_token(&headers, "autumn-csrf"),
            Some("abc123".to_owned())
        );
    }

    #[test]
    fn test_constant_time_eq() {
        assert!(super::constant_time_eq("abc", "abc"));
        assert!(!super::constant_time_eq("abc", "ab"));
        assert!(!super::constant_time_eq("abc", "abd"));
        assert!(super::constant_time_eq("", ""));
        assert!(!super::constant_time_eq("a", "b"));
        assert!(!super::constant_time_eq("a", "A"));
    }

    #[tokio::test]
    async fn post_with_empty_cookie_but_valid_header() {
        let token = Uuid::new_v4().to_string();
        let app = Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/submit")
                    .header("Cookie", "autumn-csrf=")
                    .header("X-CSRF-Token", &token)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_valid_cookie_but_empty_header() {
        let token = Uuid::new_v4().to_string();
        let app = Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/submit")
                    .header("Cookie", format!("autumn-csrf={token}"))
                    .header("X-CSRF-Token", "")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_empty_cookie_but_valid_form_field() {
        let token = Uuid::new_v4().to_string();
        let app = Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/submit")
                    .header("Cookie", "autumn-csrf=")
                    .header("Content-Type", "application/x-www-form-urlencoded")
                    .body(Body::from(format!("_csrf={token}")))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_valid_cookie_but_empty_form_field() {
        let token = Uuid::new_v4().to_string();
        let app = Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/submit")
                    .header("Cookie", format!("autumn-csrf={token}"))
                    .header("Content-Type", "application/x-www-form-urlencoded")
                    .body(Body::from("_csrf="))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_large_body_fails_csrf() {
        let token = Uuid::new_v4().to_string();
        let app = Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&default_csrf_config()));

        // The body exceeds the 2 MiB buffering limit, so the form field is never
        // found even though the token itself is correct.
        let large_padding = "a".repeat(2 * 1024 * 1024 + 10);
        let body_content = format!("_csrf={token}&pad={large_padding}");

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/submit")
                    .header("Cookie", format!("autumn-csrf={token}"))
                    .header("Content-Type", "application/x-www-form-urlencoded")
                    .body(Body::from(body_content))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_empty_tokens_returns_403() {
        let app = Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&CsrfConfig {
                enabled: true,
                ..Default::default()
            }));

        // Empty cookie and empty header must not be treated as a matching pair.
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/submit")
                    .header("Cookie", "autumn-csrf=")
                    .header("X-CSRF-Token", "")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn post_with_empty_form_tokens_returns_403() {
        let app = Router::new()
            .route("/submit", post(|| async { "created" }))
            .layer(CsrfLayer::from_config(&CsrfConfig {
                enabled: true,
                ..Default::default()
            }));

        // Empty cookie and empty form field must not be treated as a matching pair.
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/submit")
                    .header("Cookie", "autumn-csrf=")
                    .header("Content-Type", "application/x-www-form-urlencoded")
                    .body(Body::from("_csrf="))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn from_config_filters_invalid_methods() {
        let config = CsrfConfig {
            safe_methods: vec![
                "GET".to_string(),
                "INVALID METHOD".to_string(),
                "POST".to_string(),
            ],
            ..Default::default()
        };
        let layer = CsrfLayer::from_config(&config);
        assert_eq!(layer.settings.safe_methods.len(), 2);
        assert!(layer.settings.safe_methods.contains(&http::Method::GET));
        assert!(layer.settings.safe_methods.contains(&http::Method::POST));
    }

    #[test]
    fn from_config_handles_invalid_header_name() {
        let config = CsrfConfig {
            token_header: "Invalid Header Name\n".to_string(),
            ..Default::default()
        };
        let layer = CsrfLayer::from_config(&config);
        assert_eq!(layer.settings.token_header.as_str(), "x-csrf-token");
    }

    #[tokio::test]
    async fn csrf_token_is_hmac_signed_when_keys_set() {
        let layer = CsrfLayer::from_config(&default_csrf_config()).with_signing_keys(test_signing_keys());

        let app = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(layer);

        let resp = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        let set_cookie = resp
            .headers()
            .get("set-cookie")
            .expect("should set CSRF cookie")
            .to_str()
            .unwrap();
        let cookie_value = set_cookie
            .split('=')
            .nth(1)
            .unwrap()
            .split(';')
            .next()
            .unwrap()
            .trim();

        assert!(
            cookie_value.contains('.'),
            "signed CSRF cookie must be {{uuid}}.{{hmac}}, got: {cookie_value}"
        );
        let (_uuid_part, sig_part) = cookie_value.split_once('.').unwrap();
        assert_eq!(sig_part.len(), 64, "HMAC hex must be 64 chars");
    }

    #[tokio::test]
    async fn csrf_signed_token_validates_on_post() {
        let keys = test_signing_keys();
        let layer =
            CsrfLayer::from_config(&default_csrf_config()).with_signing_keys(Arc::clone(&keys));

        let app = Router::new()
            .route("/", post(|| async { "created" }))
            .layer(layer);

        // Sign with the very keys the layer verifies against.
        let uuid = uuid::Uuid::new_v4().to_string();
        let sig = keys.sign(uuid.as_bytes());
        let signed_token = format!("{uuid}.{sig}");

        let resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("Cookie", format!("autumn-csrf={signed_token}"))
                    .header("X-CSRF-Token", &signed_token)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn csrf_unsigned_token_rejected_when_signing_active() {
        let layer = CsrfLayer::from_config(&default_csrf_config()).with_signing_keys(test_signing_keys());

        let app = Router::new()
            .route("/", post(|| async { "created" }))
            .layer(layer);

        // Cookie and header agree, but the cookie carries no signature.
        let raw_token = uuid::Uuid::new_v4().to_string();
        let resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("Cookie", format!("autumn-csrf={raw_token}"))
                    .header("X-CSRF-Token", &raw_token)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(
            resp.status(),
            StatusCode::FORBIDDEN,
            "unsigned CSRF token must be rejected when signing is active"
        );
    }

    #[tokio::test]
    async fn csrf_previous_key_signed_token_accepted() {
        use crate::security::config::{
            ResolvedSigningKeys, SigningSecretConfig, resolve_signing_keys,
        };

        let old_secret = "old-key".repeat(5);
        let old_keys = resolve_signing_keys(&SigningSecretConfig {
            secret: Some(old_secret.clone()),
            previous_secrets: vec![],
        });

        let uuid = uuid::Uuid::new_v4().to_string();
        let old_sig = old_keys.sign(uuid.as_bytes());
        let old_signed_token = format!("{uuid}.{old_sig}");

        // Rotate: a new current key, with the old secret kept as a previous key.
        let new_keys = Arc::new(ResolvedSigningKeys::new(
            "new-key".repeat(5).into_bytes(),
            vec![old_secret.into_bytes()],
        ));
        let layer = CsrfLayer::from_config(&default_csrf_config()).with_signing_keys(new_keys);

        let app = Router::new()
            .route("/", post(|| async { "created" }))
            .layer(layer);

        let resp = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("Cookie", format!("autumn-csrf={old_signed_token}"))
                    .header("X-CSRF-Token", &old_signed_token)
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(
            resp.status(),
            StatusCode::OK,
            "previous-key-signed CSRF token must pass during grace window"
        );
    }
}