1use hmac::{Hmac, Mac};
4use rand::Rng;
5use sha2::Sha256;
6
/// Expected length, in characters, of the hex-encoded HMAC-SHA256 token part.
///
/// `verify_token_with_timestamp` rejects any token part whose length differs from this.
/// A 32-byte SHA-256 tag hex-encodes to 64 characters.
pub const CSRF_TOKEN_LENGTH: usize = 64;

/// Length, in bytes, of a CSRF signing secret (32 bytes, the size `get_secret_bytes` produces).
pub const CSRF_SECRET_LENGTH: usize = 32;

/// ASCII alphanumeric alphabet.
///
/// NOTE(review): not referenced by the code visible here. It presumably backs the
/// `REASON_INVALID_CHARACTERS` check elsewhere — confirm.
pub const CSRF_ALLOWED_CHARS: &str =
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";

/// Session key name for the CSRF token.
///
/// NOTE(review): not referenced by the code visible here; presumably the key under which
/// the token is stored in the session — confirm.
pub const CSRF_SESSION_KEY: &str = "_csrf_token";

// Human-readable rejection reasons placed into `RejectRequest::reason`.

/// `check_origin`: the `Origin` header matched no allowed origin.
pub const REASON_BAD_ORIGIN: &str = "Origin checking failed - does not match any trusted origins.";
/// `check_referer`: the `Referer` header matched no allowed origin.
pub const REASON_BAD_REFERER: &str =
    "Referer checking failed - does not match any trusted origins.";
/// The request carried no CSRF token (not used by the code visible here).
pub const REASON_CSRF_TOKEN_MISSING: &str = "CSRF token missing.";
/// The CSRF token had the wrong length (not used by the code visible here).
pub const REASON_INCORRECT_LENGTH: &str = "CSRF token has incorrect length.";
/// `check_referer`: an `http://` referer arrived on a secure request.
pub const REASON_INSECURE_REFERER: &str =
    "Referer checking failed - Referer is insecure while host is secure.";
/// The CSRF token contained disallowed characters (not used by the code visible here).
pub const REASON_INVALID_CHARACTERS: &str = "CSRF token has invalid characters.";
/// `check_referer`: the `Referer` header was present but empty.
pub const REASON_MALFORMED_REFERER: &str = "Referer checking failed - Referer is malformed.";
/// No CSRF cookie was set (not used by the code visible here).
pub const REASON_NO_CSRF_COOKIE: &str = "CSRF cookie not set.";
/// `check_referer`: the request had no `Referer` header at all.
pub const REASON_NO_REFERER: &str = "Referer checking failed - no Referer.";
40
/// Error returned when a request fails a CSRF check.
///
/// `reason` is a human-readable explanation, such as one of the `REASON_*` constants or an
/// "Invalid token format (...)" / "CSRF token mismatch (...)" message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RejectRequest {
    /// Why the request was rejected.
    pub reason: String,
}
47
/// Error describing a structurally invalid CSRF token.
///
/// NOTE(review): not constructed by the code visible here; the timestamp verifier reports
/// format problems via `RejectRequest` instead — confirm where this type is used.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidTokenFormat {
    /// Why the token format is invalid.
    pub reason: String,
}
54
/// Metadata carrying a CSRF token value.
///
/// NOTE(review): not used by the code visible here; presumably attached to a request or
/// response by the middleware — confirm.
#[derive(Debug, Clone)]
pub struct CsrfMeta {
    /// The CSRF token string.
    pub token: String,
}
61
/// `SameSite` policy option for the CSRF cookie.
///
/// Presumably rendered as the cookie's `SameSite` attribute — rendering is not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SameSite {
    /// `SameSite=Strict`.
    Strict,
    /// `SameSite=Lax` — the `Default` variant.
    #[default]
    Lax,
    /// `SameSite=None`.
    None,
}
73
/// Configuration for the CSRF middleware: cookie attributes and token-rotation policy.
///
/// See `CsrfConfig::default()` for the development defaults and `CsrfConfig::production()`
/// for the hardened preset.
#[derive(Debug, Clone)]
pub struct CsrfConfig {
    /// Name of the cookie that carries the CSRF token.
    pub cookie_name: String,
    /// Name of the request header expected to echo the token.
    pub header_name: String,
    /// Whether the cookie is marked `HttpOnly`.
    pub cookie_httponly: bool,
    /// Whether the cookie is marked `Secure`.
    pub cookie_secure: bool,
    /// `SameSite` policy for the cookie.
    pub cookie_samesite: SameSite,
    /// Optional `Domain` attribute for the cookie.
    pub cookie_domain: Option<String>,
    /// `Path` attribute for the cookie.
    pub cookie_path: String,
    /// Optional cookie max age (presumably seconds — confirm against the cookie writer).
    pub cookie_max_age: Option<i64>,
    /// Whether tokens should be rotated periodically.
    pub enable_token_rotation: bool,
    /// Rotation interval (presumably seconds, matching `should_rotate_token`); `None` means
    /// no interval.
    pub token_rotation_interval: Option<u64>,
}
100
101impl Default for CsrfConfig {
102 fn default() -> Self {
103 Self {
104 cookie_name: "csrftoken".to_string(),
105 header_name: "X-CSRFToken".to_string(),
106 cookie_httponly: false, cookie_secure: false, cookie_samesite: SameSite::Lax,
109 cookie_domain: None,
110 cookie_path: "/".to_string(),
111 cookie_max_age: None, enable_token_rotation: false, token_rotation_interval: None, }
115 }
116}
117
118impl CsrfConfig {
119 pub fn production() -> Self {
132 Self {
133 cookie_name: "csrftoken".to_string(),
134 header_name: "X-CSRFToken".to_string(),
135 cookie_httponly: false, cookie_secure: true, cookie_samesite: SameSite::Strict,
138 cookie_domain: None,
139 cookie_path: "/".to_string(),
140 cookie_max_age: Some(31449600), enable_token_rotation: true, token_rotation_interval: Some(3600), }
144 }
145
146 pub fn with_token_rotation(mut self, interval: Option<u64>) -> Self {
158 self.enable_token_rotation = true;
159 self.token_rotation_interval = interval;
160 self
161 }
162}
163
/// CSRF middleware holding its `CsrfConfig`.
pub struct CsrfMiddleware {
    // `config` is stored but never read by the code visible here. The allow suppresses the
    // resulting unused-field warning until request handling consumes it.
    #[allow(dead_code)]
    config: CsrfConfig,
}
170
171impl CsrfMiddleware {
172 pub fn new() -> Self {
174 Self {
175 config: CsrfConfig::default(),
176 }
177 }
178
179 pub fn with_config(config: CsrfConfig) -> Self {
181 Self { config }
182 }
183}
184
185impl Default for CsrfMiddleware {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
/// Newtype wrapper around a CSRF token string.
#[derive(Debug, Clone)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Wraps an owned token string.
    pub fn new(token: String) -> Self {
        CsrfToken(token)
    }

    /// Borrows the wrapped token as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
206
/// HMAC keyed with SHA-256. Its 32-byte tag hex-encodes to `CSRF_TOKEN_LENGTH` (64) characters.
type HmacSha256 = Hmac<Sha256>;
209
210pub fn generate_token_hmac(secret: &[u8], message: &str) -> String {
235 let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC can take key of any size");
236 mac.update(message.as_bytes());
237 let result = mac.finalize();
238 hex::encode(result.into_bytes())
239}
240
241pub fn verify_token_hmac(token: &str, secret: &[u8], message: &str) -> bool {
270 let Ok(token_bytes) = hex::decode(token) else {
272 return false;
273 };
274
275 let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC can take key of any size");
277 mac.update(message.as_bytes());
278
279 mac.verify_slice(&token_bytes).is_ok()
281}
282
283pub fn get_secret_bytes() -> Vec<u8> {
296 let mut rng = rand::rng();
297 let mut secret = vec![0u8; 32];
298 rng.fill(&mut secret[..]);
299 secret
300}
301
/// Computes the session-bound (non-timestamped) CSRF token.
///
/// The token is the hex-encoded HMAC-SHA256 of `session_id` under `secret_bytes`. It is
/// deterministic: the same secret and session id always yield the same token.
///
/// NOTE(review): the MAC input here is the bare session id, whereas
/// `generate_token_with_timestamp` MACs `"{session_id}:{timestamp}"`. If both schemes share
/// a secret, a session id shaped like `x:123` yields the same tag as a timestamped token for
/// session `x` at time 123. Consider domain separation if session ids can contain `:`.
pub fn get_token_hmac(secret_bytes: &[u8], session_id: &str) -> String {
    generate_token_hmac(secret_bytes, session_id)
}
329
330pub fn check_token_hmac(
357 request_token: &str,
358 secret_bytes: &[u8],
359 session_id: &str,
360) -> Result<(), RejectRequest> {
361 if !verify_token_hmac(request_token, secret_bytes, session_id) {
362 return Err(RejectRequest {
363 reason: "CSRF token mismatch (HMAC verification failed)".to_string(),
364 });
365 }
366 Ok(())
367}
368
369pub fn check_origin(origin: &str, allowed_origins: &[String]) -> Result<(), RejectRequest> {
371 if !allowed_origins.iter().any(|o| o == origin) {
372 return Err(RejectRequest {
373 reason: REASON_BAD_ORIGIN.to_string(),
374 });
375 }
376 Ok(())
377}
378
379pub fn check_referer(
381 referer: Option<&str>,
382 allowed_origins: &[String],
383 is_secure: bool,
384) -> Result<(), RejectRequest> {
385 let referer = referer.ok_or_else(|| RejectRequest {
386 reason: REASON_NO_REFERER.to_string(),
387 })?;
388
389 if referer.is_empty() {
390 return Err(RejectRequest {
391 reason: REASON_MALFORMED_REFERER.to_string(),
392 });
393 }
394
395 if is_secure && referer.starts_with("http://") {
396 return Err(RejectRequest {
397 reason: REASON_INSECURE_REFERER.to_string(),
398 });
399 }
400
401 if !allowed_origins.iter().any(|o| referer.starts_with(o)) {
402 return Err(RejectRequest {
403 reason: REASON_BAD_REFERER.to_string(),
404 });
405 }
406
407 Ok(())
408}
409
/// Returns `true` when the two domain strings are byte-for-byte identical.
///
/// The comparison is case-sensitive and has no wildcard or subdomain matching.
pub fn is_same_domain(domain1: &str, domain2: &str) -> bool {
    domain1.as_bytes() == domain2.as_bytes()
}
414
/// Current Unix time in whole seconds.
///
/// Returns `0` if the system clock reports a time before the Unix epoch.
pub fn get_token_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_secs()
}
431
/// Decides whether a token issued at `token_timestamp` is due for rotation at
/// `current_timestamp`.
///
/// Returns `false` when `rotation_interval` is `None`. Otherwise returns `true` once the
/// elapsed time reaches the interval. The elapsed time uses saturating subtraction, so a
/// future-dated token counts as 0 elapsed instead of underflowing.
pub fn should_rotate_token(
    token_timestamp: u64,
    current_timestamp: u64,
    rotation_interval: Option<u64>,
) -> bool {
    let Some(interval) = rotation_interval else {
        return false;
    };
    let elapsed = current_timestamp.saturating_sub(token_timestamp);
    elapsed >= interval
}
454
455pub fn generate_token_with_timestamp(secret_bytes: &[u8], session_id: &str) -> String {
468 let timestamp = get_token_timestamp();
469 let message = format!("{}:{}", session_id, timestamp);
470 let token = generate_token_hmac(secret_bytes, &message);
471 format!("{}:{}", token, timestamp)
472}
473
474pub fn verify_token_with_timestamp(
488 token_data: &str,
489 secret_bytes: &[u8],
490 session_id: &str,
491) -> Result<u64, RejectRequest> {
492 if token_data.is_empty() {
493 return Err(RejectRequest {
494 reason: "Invalid token format (empty token)".to_string(),
495 });
496 }
497
498 let mut parts = token_data.rsplitn(2, ':');
501 let timestamp_str = parts.next().ok_or_else(|| RejectRequest {
502 reason: "Invalid token format (missing timestamp)".to_string(),
503 })?;
504 let token = parts.next().ok_or_else(|| RejectRequest {
505 reason: "Invalid token format (missing delimiter)".to_string(),
506 })?;
507
508 if token.is_empty() {
509 return Err(RejectRequest {
510 reason: "Invalid token format (empty token value)".to_string(),
511 });
512 }
513
514 if timestamp_str.is_empty() {
515 return Err(RejectRequest {
516 reason: "Invalid token format (empty timestamp)".to_string(),
517 });
518 }
519
520 if token.len() != CSRF_TOKEN_LENGTH {
522 return Err(RejectRequest {
523 reason: format!(
524 "Invalid token format (expected {} hex characters, got {})",
525 CSRF_TOKEN_LENGTH,
526 token.len()
527 ),
528 });
529 }
530
531 if !token.chars().all(|c| c.is_ascii_hexdigit()) {
532 return Err(RejectRequest {
533 reason: "Invalid token format (token contains non-hex characters)".to_string(),
534 });
535 }
536
537 let timestamp: u64 = timestamp_str.parse().map_err(|_| RejectRequest {
538 reason: "Invalid token format (timestamp is not a valid number)".to_string(),
539 })?;
540
541 let message = format!("{}:{}", session_id, timestamp);
542 if !verify_token_hmac(token, secret_bytes, &message) {
543 return Err(RejectRequest {
544 reason: "CSRF token mismatch (HMAC verification failed)".to_string(),
545 });
546 }
547
548 Ok(timestamp)
549}
550
#[cfg(test)]
mod tests {
    //! Unit tests for timestamped CSRF token verification and rotation decisions.

    use super::*;
    use rstest::rstest;

    /// Fixed secret shared by the tests so results are reproducible.
    fn test_secret() -> Vec<u8> {
        b"test-secret-key-at-least-32-bytes".to_vec()
    }

    #[rstest]
    fn test_verify_token_with_timestamp_valid_token() {
        let secret = test_secret();
        let session_id = "user-session-12345";
        let token_data = generate_token_with_timestamp(&secret, session_id);

        let result = verify_token_with_timestamp(&token_data, &secret, session_id);

        assert!(result.is_ok(), "Expected valid token to pass verification");
        assert!(result.unwrap() > 0, "Expected positive timestamp");
    }

    #[rstest]
    fn test_verify_token_with_timestamp_rejects_empty_input() {
        let secret = test_secret();

        let result = verify_token_with_timestamp("", &secret, "session");

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().reason,
            "Invalid token format (empty token)"
        );
    }

    #[rstest]
    #[case("no-delimiter-at-all")]
    #[case("abcdef")]
    fn test_verify_token_with_timestamp_rejects_missing_delimiter(#[case] input: &str) {
        let secret = test_secret();

        let result = verify_token_with_timestamp(input, &secret, "session");

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().reason,
            "Invalid token format (missing delimiter)"
        );
    }

    #[rstest]
    fn test_verify_token_with_timestamp_rejects_empty_token_value() {
        let secret = test_secret();

        let result = verify_token_with_timestamp(":12345", &secret, "session");

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().reason,
            "Invalid token format (empty token value)"
        );
    }

    #[rstest]
    fn test_verify_token_with_timestamp_rejects_empty_timestamp() {
        let secret = test_secret();

        let result = verify_token_with_timestamp(
            "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2:",
            &secret,
            "session",
        );

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().reason,
            "Invalid token format (empty timestamp)"
        );
    }

    #[rstest]
    #[case("short:12345")]
    #[case("ab:12345")]
    fn test_verify_token_with_timestamp_rejects_wrong_token_length(#[case] input: &str) {
        let secret = test_secret();

        let result = verify_token_with_timestamp(input, &secret, "session");

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .reason
                .contains("expected 64 hex characters"),
            "Expected token length error"
        );
    }

    #[rstest]
    fn test_verify_token_with_timestamp_rejects_non_hex_token() {
        let secret = test_secret();
        let bad_token = "g1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6z1b2";
        let input = format!("{}:12345", bad_token);

        let result = verify_token_with_timestamp(&input, &secret, "session");

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().reason,
            "Invalid token format (token contains non-hex characters)"
        );
    }

    #[rstest]
    #[case("a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2:not_a_number")]
    #[case("a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2:-1")]
    #[case("a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2:12.34")]
    fn test_verify_token_with_timestamp_rejects_invalid_timestamp(#[case] input: &str) {
        let secret = test_secret();

        let result = verify_token_with_timestamp(input, &secret, "session");

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().reason,
            "Invalid token format (timestamp is not a valid number)"
        );
    }

    #[rstest]
    fn test_verify_token_with_timestamp_rejects_tampered_token() {
        let secret = test_secret();
        let session_id = "user-session-12345";
        let token_data = generate_token_with_timestamp(&secret, session_id);

        let result = verify_token_with_timestamp(&token_data, &secret, "different-session");

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().reason,
            "CSRF token mismatch (HMAC verification failed)"
        );
    }

    #[rstest]
    fn test_verify_token_with_timestamp_rejects_wrong_secret() {
        let secret = test_secret();
        let wrong_secret = b"wrong-secret-key-at-least-32-byte".to_vec();
        let session_id = "user-session-12345";
        let token_data = generate_token_with_timestamp(&secret, session_id);

        let result = verify_token_with_timestamp(&token_data, &wrong_secret, session_id);

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().reason,
            "CSRF token mismatch (HMAC verification failed)"
        );
    }

    #[rstest]
    fn test_verify_token_with_timestamp_handles_extra_colons_in_crafted_input() {
        let secret = test_secret();
        // Splitting on the last ':' leaves "<64 hex>:extra" (70 chars) as the token part.
        let input = "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2:extra:12345";

        let result = verify_token_with_timestamp(input, &secret, "session");

        assert!(result.is_err());
        let reason = result.unwrap_err().reason;
        assert!(
            reason.contains("expected 64 hex characters"),
            "Expected the embedded colon to trip the length check, got: {reason}"
        );
    }

    #[rstest]
    fn test_should_rotate_token_normal_case() {
        let token_timestamp = 1000u64;
        let current_timestamp = 4700u64;
        let interval = 3600u64;

        let result = should_rotate_token(token_timestamp, current_timestamp, Some(interval));

        assert!(result, "Token older than interval should trigger rotation");
    }

    #[rstest]
    fn test_should_rotate_token_future_timestamp_no_panic() {
        let token_timestamp = 5000u64;
        let current_timestamp = 1000u64;
        let interval = 3600u64;

        let result = should_rotate_token(token_timestamp, current_timestamp, Some(interval));

        assert!(!result, "Future-dated token should not trigger rotation");
    }

    #[rstest]
    fn test_should_rotate_token_equal_timestamps() {
        let timestamp = 1000u64;
        let interval = 3600u64;

        let result = should_rotate_token(timestamp, timestamp, Some(interval));

        assert!(
            !result,
            "Equal timestamps (0 elapsed) should not trigger rotation"
        );
    }
}
806}