1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
11#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
12#![cfg_attr(docsrs, feature(doc_cfg))]
13
14use std::error::Error as StdError;
15use std::fmt::{self, Debug, Formatter};
16
17mod finder;
18
19pub use finder::{CsrfTokenFinder, FormFinder, HeaderFinder, JsonFinder};
20use rand::Rng;
21use rand::distr::StandardUniform;
22use salvo_core::handler::Skipper;
23use salvo_core::http::{Method, StatusCode};
24use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
25
26#[macro_use]
27mod cfg;
28
29cfg_feature! {
30 #![feature = "cookie-store"]
31
32 mod cookie_store;
33 pub use cookie_store::CookieStore;
34
35 #[must_use] pub fn cookie_store<>() -> CookieStore {
37 CookieStore::new()
38 }
39}
40cfg_feature! {
41 #![feature = "session-store"]
42
43 mod session_store;
44 pub use session_store::SessionStore;
45
46 #[must_use]
48 pub fn session_store() -> SessionStore {
49 SessionStore::new()
50 }
51}
52cfg_feature! {
53 #![feature = "bcrypt-cipher"]
54
55 mod bcrypt_cipher;
56 pub use bcrypt_cipher::BcryptCipher;
57
58 pub fn bcrypt_csrf<S>(store: S, finder: impl CsrfTokenFinder ) -> Csrf<BcryptCipher, S> where S: CsrfStore {
60 Csrf::new(BcryptCipher::new(), store, finder)
61 }
62}
63cfg_feature! {
64 #![all(feature = "bcrypt-cipher", feature = "cookie-store")]
65 pub fn bcrypt_cookie_csrf(finder: impl CsrfTokenFinder ) -> Csrf<BcryptCipher, CookieStore> {
67 Csrf::new(BcryptCipher::new(), CookieStore::new(), finder)
68 }
69}
70cfg_feature! {
71 #![all(feature = "bcrypt-cipher", feature = "session-store")]
72 pub fn bcrypt_session_csrf(finder: impl CsrfTokenFinder ) -> Csrf<BcryptCipher, SessionStore> {
74 Csrf::new(BcryptCipher::new(), SessionStore::new(), finder)
75 }
76}
77
78cfg_feature! {
79 #![feature = "hmac-cipher"]
80
81 mod hmac_cipher;
82 pub use hmac_cipher::HmacCipher;
83
84 pub fn hmac_csrf<S>(hmac_key: [u8; 32], store: S, finder: impl CsrfTokenFinder ) -> Csrf<HmacCipher, S> where S: CsrfStore {
86 Csrf::new(HmacCipher::new(hmac_key), store, finder)
87 }
88}
89cfg_feature! {
90 #![all(feature = "hmac-cipher", feature = "cookie-store")]
91 pub fn hmac_cookie_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<HmacCipher, CookieStore> {
93 Csrf::new(HmacCipher::new(aead_key), CookieStore::new(), finder)
94 }
95}
96cfg_feature! {
97 #![all(feature = "hmac-cipher", feature = "session-store")]
98 pub fn hmac_session_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<HmacCipher, SessionStore> {
100 Csrf::new(HmacCipher::new(aead_key), SessionStore::new(), finder)
101 }
102}
103
104cfg_feature! {
105 #![feature = "aes-gcm-cipher"]
106
107 mod aes_gcm_cipher;
108 pub use aes_gcm_cipher::AesGcmCipher;
109
110 pub fn aes_gcm_csrf<S>(aead_key: [u8; 32], store: S, finder: impl CsrfTokenFinder ) -> Csrf<AesGcmCipher, S> where S: CsrfStore {
112 Csrf::new(AesGcmCipher::new(aead_key), store, finder)
113 }
114}
115cfg_feature! {
116 #![all(feature = "aes-gcm-cipher", feature = "cookie-store")]
117 pub fn aes_gcm_cookie_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<AesGcmCipher, CookieStore> {
119 Csrf::new(AesGcmCipher::new(aead_key), CookieStore::new(), finder)
120 }
121}
122cfg_feature! {
123 #![all(feature = "aes-gcm-cipher", feature = "session-store")]
124 pub fn aes_gcm_session_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<AesGcmCipher, SessionStore> {
126 Csrf::new(AesGcmCipher::new(aead_key), SessionStore::new(), finder)
127 }
128}
129
130cfg_feature! {
131 #![feature = "ccp-cipher"]
132
133 mod ccp_cipher;
134 pub use ccp_cipher::CcpCipher;
135
136 pub fn ccp_csrf<S>(aead_key: [u8; 32], store: S, finder: impl CsrfTokenFinder ) -> Csrf<CcpCipher, S> where S: CsrfStore {
138 Csrf::new(CcpCipher::new(aead_key), store, finder)
139 }
140}
141cfg_feature! {
142 #![all(feature = "ccp-cipher", feature = "cookie-store")]
143 pub fn ccp_cookie_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<CcpCipher, CookieStore> {
145 Csrf::new(CcpCipher::new(aead_key), CookieStore::new(), finder)
146 }
147}
148cfg_feature! {
149 #![all(feature = "ccp-cipher", feature = "session-store")]
150 pub fn ccp_session_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<CcpCipher, SessionStore> {
152 Csrf::new(CcpCipher::new(aead_key), SessionStore::new(), finder)
153 }
154}
155
/// Depot key under which the middleware stores the current CSRF token as a `String`.
///
/// Read it back with [`CsrfDepotExt::csrf_token`].
pub const CSRF_TOKEN_KEY: &str = "salvo.csrf.token";
158
159fn default_skipper(req: &mut Request, _depot: &Depot) -> bool {
160 ![Method::POST, Method::PATCH, Method::DELETE, Method::PUT].contains(req.method())
161}
162
/// Persistence backend for the CSRF `(token, proof)` pair.
///
/// The middleware calls [`CsrfStore::load`] at the start of every request. It calls
/// [`CsrfStore::save`] after generating a new pair because none was loaded.
pub trait CsrfStore: Send + Sync + 'static {
    /// Error returned by [`CsrfStore::save`]; the middleware logs it and keeps going.
    type Error: StdError + Send + Sync + 'static;
    /// Loads the previously stored `(token, proof)` pair for this request, if any.
    ///
    /// Returning `None` makes the middleware reject requests that need verification.
    /// Requests the skipper skips instead receive a freshly generated pair.
    /// NOTE(review): how `cipher` is used here is up to the implementation.
    fn load<C: CsrfCipher>(
        &self,
        req: &mut Request,
        depot: &mut Depot,
        cipher: &C,
    ) -> impl Future<Output = Option<(String, String)>> + Send;
    /// Persists a newly generated `token`/`proof` pair.
    ///
    /// Depending on the implementation, this may write to `res` or to state reachable
    /// from `req`/`depot`.
    fn save(
        &self,
        req: &mut Request,
        depot: &mut Depot,
        res: &mut Response,
        token: &str,
        proof: &str,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
184
185pub trait CsrfCipher: Send + Sync + 'static {
187 fn verify(&self, token: &str, proof: &str) -> bool;
189 fn generate(&self) -> (String, String);
191
192 fn random_bytes(&self, len: usize) -> Vec<u8> {
194 rand::rng().sample_iter(StandardUniform).take(len).collect()
195 }
196}
197
/// Extension trait for reading the CSRF token the middleware placed in a [`Depot`].
pub trait CsrfDepotExt {
    /// Returns the current CSRF token, or `None` if no `String` is stored under
    /// [`CSRF_TOKEN_KEY`].
    fn csrf_token(&self) -> Option<&str>;
}
203
204impl CsrfDepotExt for Depot {
205 #[inline]
206 fn csrf_token(&self) -> Option<&str> {
207 self.get::<String>(CSRF_TOKEN_KEY).map(|v| &**v).ok()
208 }
209}
210
/// CSRF protection middleware, generic over a cipher `C` and a store `S`.
pub struct Csrf<C, S> {
    /// Generates and verifies `(token, proof)` pairs.
    cipher: C,
    /// Loads and saves the `(token, proof)` pair between requests.
    store: S,
    /// Decides which requests bypass token verification.
    skipper: Box<dyn Skipper>,
    /// Consulted in order to locate the token submitted with a request.
    finders: Vec<Box<dyn CsrfTokenFinder>>,
}
218
219impl<C, S> Debug for Csrf<C, S>
220where
221 C: Debug,
222 S: Debug,
223{
224 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
225 f.debug_struct("Csrf")
226 .field("cipher", &self.cipher)
227 .field("store", &self.store)
228 .finish()
229 }
230}
231
232impl<C: CsrfCipher, S: CsrfStore> Csrf<C, S> {
233 #[inline]
235 #[must_use]
236 pub fn new(cipher: C, store: S, finder: impl CsrfTokenFinder) -> Self {
237 Self {
238 cipher,
239 store,
240 skipper: Box::new(default_skipper),
241 finders: vec![Box::new(finder)],
242 }
243 }
244
245 #[inline]
247 #[must_use]
248 pub fn add_finder(mut self, finder: impl CsrfTokenFinder) -> Self {
249 self.finders.push(Box::new(finder));
250 self
251 }
252
253 async fn find_token(&self, req: &mut Request) -> Option<String> {
268 for finder in self.finders.iter() {
269 if let Some(token) = finder.find_token(req).await {
270 return Some(token);
271 }
272 }
273 None
274 }
275}
276
277#[async_trait]
278impl<C: CsrfCipher, S: CsrfStore> Handler for Csrf<C, S> {
279 async fn handle(
280 &self,
281 req: &mut Request,
282 depot: &mut Depot,
283 res: &mut Response,
284 ctrl: &mut FlowCtrl,
285 ) {
286 match self.store.load(req, depot, &self.cipher).await {
287 Some((token, proof)) => {
288 depot.insert(CSRF_TOKEN_KEY, token);
289
290 if !self.skipper.skipped(req, depot) {
291 if let Some(token) = &self.find_token(req).await {
292 tracing::debug!("csrf token: {token}");
293 if !self.cipher.verify(token, &proof) {
294 tracing::debug!(
295 "rejecting request due to invalid or expired CSRF token"
296 );
297 res.status_code(StatusCode::FORBIDDEN);
298 ctrl.skip_rest();
299 return;
300 } else {
301 tracing::debug!("cipher verify CSRF token success");
302 }
303 } else {
304 tracing::debug!("rejecting request due to missing CSRF token",);
305 res.status_code(StatusCode::FORBIDDEN);
306 ctrl.skip_rest();
307 return;
308 }
309 }
310 ctrl.call_next(req, depot, res).await;
311 }
312 None => {
313 if !self.skipper.skipped(req, depot) {
314 tracing::debug!("rejecting request due to missing CSRF token",);
315 res.status_code(StatusCode::FORBIDDEN);
316 ctrl.skip_rest();
317 } else {
318 let (token, proof) = self.cipher.generate();
319 if let Err(e) = self.store.save(req, depot, res, &token, &proof).await {
320 tracing::error!(error = ?e, "salvo csrf token failed");
321 }
322 tracing::debug!("new token: {:?}", token);
323 depot.insert(CSRF_TOKEN_KEY, token);
324 ctrl.call_next(req, depot, res).await;
325 }
326 }
327 }
328 }
329}
330
// These tests construct `BcryptCipher` and `CookieStore` directly, so they only compile
// when both features are enabled.
#[cfg(all(test, feature = "bcrypt-cipher", feature = "cookie-store"))]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    /// Echoes the CSRF token the middleware placed in the depot.
    #[handler]
    async fn get_index(depot: &mut Depot) -> String {
        depot.csrf_token().unwrap().to_owned()
    }
    /// Fixed body used to confirm that a POST got past the middleware.
    #[handler]
    async fn post_index() -> &'static str {
        "POST"
    }

    #[tokio::test]
    async fn test_exposes_csrf_request_extensions() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index);
        let res = TestClient::get("http://127.0.0.1:5801").send(router).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_adds_csrf_cookie_sets_request_token() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index);

        let mut res = TestClient::get("http://127.0.0.1:5801").send(router).await;

        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_ne!(res.take_string().await.unwrap(), "");
        assert_ne!(res.cookie("salvo.csrf"), None);
    }

    #[tokio::test]
    async fn test_validates_token_in_header() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }

    #[tokio::test]
    async fn test_validates_token_in_custom_header() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-mycsrf-header"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-mycsrf-header", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }

    #[tokio::test]
    async fn test_validates_token_in_query() {
        // The token travels in the `csrf-token` header; the request URL additionally
        // carries an unrelated query string.
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        let mut res = TestClient::post("http://127.0.0.1:5801?a=1&b=2")
            .add_header("csrf-token", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }

    #[cfg(feature = "hmac-cipher")]
    #[tokio::test]
    async fn test_validates_token_in_alternate_query() {
        let csrf = Csrf::new(
            HmacCipher::new(*b"01234567012345670123456701234567"),
            CookieStore::new(),
            HeaderFinder::new("my-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        let mut res = TestClient::post("http://127.0.0.1:5801?a=1&b=2")
            .add_header("my-csrf-token", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }

    #[cfg(feature = "hmac-cipher")]
    #[tokio::test]
    async fn test_validates_token_in_form() {
        let csrf = Csrf::new(
            HmacCipher::new(*b"01234567012345670123456701234567"),
            CookieStore::new(),
            FormFinder::new("csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("cookie", cookie.to_string(), true)
            .form(&[("a", "1"), ("csrf-token", &*csrf_token), ("b", "2")])
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }

    #[tokio::test]
    async fn test_validates_token_in_alternate_form() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            FormFinder::new("my-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("cookie", cookie.to_string(), true)
            .form(&[("a", "1"), ("my-csrf-token", &*csrf_token), ("b", "2")])
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }

    #[tokio::test]
    async fn test_rejects_short_token() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        // Send the complete cookie so the stored proof is loaded and the rejection comes
        // from `verify`. A truncated cookie would be rejected as "missing" instead.
        let res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", "aGVsbG8=", true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_rejects_invalid_base64_token() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        // Full cookie: the request reaches `verify` with a malformed token.
        let res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", "aGVsbG8", true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_rejects_mismatched_token() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let csrf_token = res.take_string().await.unwrap();

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let other_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        // Token from the first GET, cookie from the second: must be rejected.
        let res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        // Positive control: the cookie is valid when paired with its own token, so the
        // rejection above is due to the mismatch.
        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", other_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }
}