1use hmac::{Hmac, Mac};
2use sha2::Sha256;
3use subtle::ConstantTimeEq;
4
5use crate::types::SessionToken;
6
/// HMAC instantiated with SHA-256; the MAC used to derive and verify CSRF tokens.
type HmacSha256 = Hmac<Sha256>;
8
9pub fn derive_csrf_token(session_token: &SessionToken, secret: &[u8]) -> String {
14 let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
15 mac.update(session_token.as_str().as_bytes());
16 let result = mac.finalize();
17 format!("{:x}", result.into_bytes())
18}
19
20pub fn verify_csrf_token(session_token: &SessionToken, secret: &[u8], submitted: &str) -> bool {
25 let expected = derive_csrf_token(session_token, secret);
26 if expected.len() != submitted.len() {
27 return false;
28 }
29 expected.as_bytes().ct_eq(submitted.as_bytes()).into()
30}
31
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SessionToken` from a string literal for test inputs.
    fn token(s: &str) -> SessionToken {
        SessionToken::from_encoded(s.to_string())
    }

    /// A 32-byte key used by most tests.
    const SECRET: &[u8] = b"test-secret-key-32bytes-padding!";

    #[test]
    fn derive_is_deterministic() {
        let t = token("abc123");
        let a = derive_csrf_token(&t, SECRET);
        let b = derive_csrf_token(&t, SECRET);
        assert_eq!(a, b);
    }

    #[test]
    fn derive_differs_for_different_tokens() {
        let a = derive_csrf_token(&token("token_a"), SECRET);
        let b = derive_csrf_token(&token("token_b"), SECRET);
        assert_ne!(a, b);
    }

    #[test]
    fn derive_differs_for_different_secrets() {
        let t = token("same_token");
        let a = derive_csrf_token(&t, b"secret_one_32bytes_padding_here!");
        let b = derive_csrf_token(&t, b"secret_two_32bytes_padding_here!");
        assert_ne!(a, b);
    }

    #[test]
    fn verify_accepts_correct_token() {
        let t = token("abc123");
        let csrf = derive_csrf_token(&t, SECRET);
        assert!(verify_csrf_token(&t, SECRET, &csrf));
    }

    #[test]
    fn verify_rejects_wrong_token() {
        let t = token("abc123");
        // Same length as a genuine token, so the rejection comes from the
        // constant-time comparison rather than the early length check.
        let forged = "0".repeat(64);
        assert_ne!(forged, derive_csrf_token(&t, SECRET));
        assert!(!verify_csrf_token(&t, SECRET, &forged));
    }

    #[test]
    fn verify_rejects_single_char_tamper() {
        let t = token("abc123");
        let mut tampered = derive_csrf_token(&t, SECRET);
        let last = tampered.pop().expect("CSRF token is never empty");
        tampered.push(if last == '0' { '1' } else { '0' });
        assert!(!verify_csrf_token(&t, SECRET, &tampered));
    }

    #[test]
    fn verify_rejects_token_from_other_session() {
        let other = derive_csrf_token(&token("other_session"), SECRET);
        assert!(!verify_csrf_token(&token("abc123"), SECRET, &other));
    }

    #[test]
    fn verify_rejects_token_under_other_secret() {
        let t = token("abc123");
        let csrf = derive_csrf_token(&t, b"secret_one_32bytes_padding_here!");
        assert!(!verify_csrf_token(&t, SECRET, &csrf));
    }

    #[test]
    fn verify_rejects_different_length() {
        let t = token("abc123");
        assert!(!verify_csrf_token(&t, SECRET, "short"));
        assert!(!verify_csrf_token(&t, SECRET, ""));
    }

    #[test]
    fn output_is_64_hex_chars() {
        let t = token("any_token_value");
        let csrf = derive_csrf_token(&t, SECRET);
        assert_eq!(csrf.len(), 64);
        assert!(csrf.chars().all(|c| c.is_ascii_hexdigit()));
    }
}