use hmac::{Hmac, Mac};
use rand::Rng;
use sha2::Sha256;
// --- Token and secret sizing -------------------------------------------------

/// Number of hex characters in the HMAC part of a timestamped token, as enforced by
/// [`verify_token_with_timestamp`]. HMAC-SHA256 yields 32 bytes, i.e. 64 hex digits.
pub const CSRF_TOKEN_LENGTH: usize = 64;

/// Byte length of CSRF secret key material (the size produced by [`get_secret_bytes`]).
pub const CSRF_SECRET_LENGTH: usize = 32;

/// Alphanumeric character set: lowercase ASCII letters, uppercase ASCII letters, digits.
pub const CSRF_ALLOWED_CHARS: &str =
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";

/// Key name `_csrf_token`; presumably the session key under which a token is stored —
/// confirm with callers.
pub const CSRF_SESSION_KEY: &str = "_csrf_token";

// --- Origin / Referer rejection reasons --------------------------------------

/// Reason used by [`check_origin`] when the origin is not trusted.
pub const REASON_BAD_ORIGIN: &str = "Origin checking failed - does not match any trusted origins.";

/// Reason used by [`check_referer`] when no trusted origin matches the referer.
pub const REASON_BAD_REFERER: &str =
    "Referer checking failed - does not match any trusted origins.";

/// Reason used by [`check_referer`] for an `http://` referer on a secure request.
pub const REASON_INSECURE_REFERER: &str =
    "Referer checking failed - Referer is insecure while host is secure.";

/// Reason used by [`check_referer`] when the referer is present but empty.
pub const REASON_MALFORMED_REFERER: &str = "Referer checking failed - Referer is malformed.";

/// Reason used by [`check_referer`] when no referer was supplied.
pub const REASON_NO_REFERER: &str = "Referer checking failed - no Referer.";

// --- Token / cookie rejection reasons ----------------------------------------

/// Rejection message: the CSRF token is absent.
pub const REASON_CSRF_TOKEN_MISSING: &str = "CSRF token missing.";

/// Rejection message: the CSRF token has the wrong length.
pub const REASON_INCORRECT_LENGTH: &str = "CSRF token has incorrect length.";

/// Rejection message: the CSRF token contains characters outside the allowed set.
pub const REASON_INVALID_CHARACTERS: &str = "CSRF token has invalid characters.";

/// Rejection message: the CSRF cookie is absent.
pub const REASON_NO_CSRF_COOKIE: &str = "CSRF cookie not set.";
/// Error value describing why a request failed a CSRF check.
///
/// The `reason` text is a human-readable explanation (for example one of the `REASON_*`
/// constants or a token-format message).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RejectRequest {
    /// Human-readable explanation of the rejection.
    pub reason: String,
}
/// Error value describing a malformed CSRF token.
///
/// NOTE(review): not constructed anywhere in this file; presumably used by callers.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InvalidTokenFormat {
    /// Human-readable explanation of what is wrong with the token.
    pub reason: String,
}
/// Carries a CSRF token string.
#[derive(Clone, Debug)]
pub struct CsrfMeta {
    /// The token value.
    pub token: String,
}
/// Value for a cookie's `SameSite` attribute.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SameSite {
    Strict,
    Lax,
    None,
}

impl Default for SameSite {
    /// Returns [`SameSite::Lax`].
    fn default() -> Self {
        Self::Lax
    }
}
/// Settings for the CSRF cookie, the CSRF header name, and token rotation.
#[derive(Clone, Debug)]
pub struct CsrfConfig {
    /// Name of the CSRF cookie.
    pub cookie_name: String,
    /// Name of the CSRF request header.
    pub header_name: String,
    /// `HttpOnly` flag for the CSRF cookie.
    pub cookie_httponly: bool,
    /// `Secure` flag for the CSRF cookie.
    pub cookie_secure: bool,
    /// `SameSite` attribute for the CSRF cookie.
    pub cookie_samesite: SameSite,
    /// Optional `Domain` attribute for the CSRF cookie.
    pub cookie_domain: Option<String>,
    /// `Path` attribute for the CSRF cookie.
    pub cookie_path: String,
    /// Optional cookie max-age; presumably in seconds — confirm with the cookie writer.
    pub cookie_max_age: Option<i64>,
    /// Whether token rotation is enabled.
    pub enable_token_rotation: bool,
    /// Token rotation interval; presumably seconds, matching [`get_token_timestamp`] — confirm.
    pub token_rotation_interval: Option<u64>,
}
impl Default for CsrfConfig {
    /// Cookie `csrftoken`, header `X-CSRFToken`, path `/`, not `Secure`, not `HttpOnly`,
    /// `SameSite=Lax`, no domain, no max-age, and token rotation disabled.
    fn default() -> Self {
        let cookie_name = String::from("csrftoken");
        let header_name = String::from("X-CSRFToken");
        let cookie_path = String::from("/");
        Self {
            cookie_name,
            header_name,
            cookie_path,
            cookie_httponly: false,
            cookie_secure: false,
            cookie_samesite: SameSite::Lax,
            cookie_domain: None,
            cookie_max_age: None,
            enable_token_rotation: false,
            token_rotation_interval: None,
        }
    }
}
impl CsrfConfig {
    /// Hardened preset: `Secure` cookie, `SameSite=Strict`, a long max-age, and token
    /// rotation enabled with an interval of 3600.
    pub fn production() -> Self {
        Self {
            cookie_secure: true,
            cookie_samesite: SameSite::Strict,
            // 60 * 60 * 24 * 7 * 52 = 31_449_600, i.e. 52 weeks if interpreted as seconds.
            cookie_max_age: Some(60 * 60 * 24 * 7 * 52),
            enable_token_rotation: true,
            // 60 * 60 = 3600, i.e. one hour if interpreted as seconds.
            token_rotation_interval: Some(60 * 60),
            cookie_name: String::from("csrftoken"),
            header_name: String::from("X-CSRFToken"),
            cookie_httponly: false,
            cookie_domain: None,
            cookie_path: String::from("/"),
        }
    }

    /// Builder-style setter: turns token rotation on and stores `interval` as the
    /// rotation interval, leaving every other field untouched.
    pub fn with_token_rotation(self, interval: Option<u64>) -> Self {
        Self {
            enable_token_rotation: true,
            token_rotation_interval: interval,
            ..self
        }
    }
}
/// CSRF middleware parameterised by a [`CsrfConfig`].
pub struct CsrfMiddleware {
    config: CsrfConfig,
}

impl CsrfMiddleware {
    /// Creates middleware using [`CsrfConfig::default`].
    pub fn new() -> Self {
        Self::with_config(CsrfConfig::default())
    }

    /// Creates middleware using the supplied configuration.
    pub fn with_config(config: CsrfConfig) -> Self {
        Self { config }
    }

    /// Returns the configuration this middleware was built with.
    ///
    /// Without this accessor the stored configuration could never be read, so the choice
    /// made via [`CsrfMiddleware::with_config`] had no observable effect.
    pub fn config(&self) -> &CsrfConfig {
        &self.config
    }
}

impl Default for CsrfMiddleware {
    /// Equivalent to [`CsrfMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Newtype wrapper around a CSRF token string.
#[derive(Clone, Debug)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Wraps `token` as-is; no validation is performed.
    pub fn new(token: String) -> Self {
        CsrfToken(token)
    }

    /// Borrows the wrapped token as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
type HmacSha256 = Hmac<Sha256>;

/// Computes the HMAC-SHA256 of `message` keyed by `secret` and returns it hex-encoded.
///
/// # Panics
///
/// Only if HMAC key construction fails, which the `hmac` crate does not do for HMAC
/// (keys of any length are accepted).
pub fn generate_token_hmac(secret: &[u8], message: &str) -> String {
    let digest = {
        let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC can take key of any size");
        mac.update(message.as_bytes());
        mac.finalize().into_bytes()
    };
    hex::encode(digest)
}
/// Checks that hex-encoded `token` is the HMAC-SHA256 of `message` under `secret`.
///
/// Returns `false` when `token` is not valid hex or the MAC does not match. The tag
/// comparison is delegated to `Mac::verify_slice`, which the `hmac` crate documents as
/// constant-time.
pub fn verify_token_hmac(token: &str, secret: &[u8], message: &str) -> bool {
    match hex::decode(token) {
        Ok(tag) => {
            let mut mac =
                HmacSha256::new_from_slice(secret).expect("HMAC can take key of any size");
            mac.update(message.as_bytes());
            mac.verify_slice(&tag).is_ok()
        }
        Err(_) => false,
    }
}
pub fn get_secret_bytes() -> Vec<u8> {
let mut rng = rand::rng();
let mut secret = vec![0u8; 32];
rng.fill(&mut secret[..]);
secret
}
/// Derives the session-bound token: the hex HMAC-SHA256 of `session_id` keyed by
/// `secret_bytes` (see [`generate_token_hmac`]).
pub fn get_token_hmac(secret_bytes: &[u8], session_id: &str) -> String {
    generate_token_hmac(secret_bytes, session_id)
}
/// Verifies that `request_token` is the HMAC token for `session_id` under `secret_bytes`.
///
/// # Errors
///
/// Returns [`RejectRequest`] when the token is not valid hex or its MAC does not match.
pub fn check_token_hmac(
    request_token: &str,
    secret_bytes: &[u8],
    session_id: &str,
) -> Result<(), RejectRequest> {
    if verify_token_hmac(request_token, secret_bytes, session_id) {
        Ok(())
    } else {
        Err(RejectRequest {
            reason: String::from("CSRF token mismatch (HMAC verification failed)"),
        })
    }
}
/// Accepts `origin` only if it exactly equals one of `allowed_origins` (byte-for-byte).
///
/// # Errors
///
/// Returns [`RejectRequest`] carrying [`REASON_BAD_ORIGIN`] when no entry matches.
pub fn check_origin(origin: &str, allowed_origins: &[String]) -> Result<(), RejectRequest> {
    if allowed_origins
        .iter()
        .any(|trusted| trusted.as_str() == origin)
    {
        Ok(())
    } else {
        Err(RejectRequest {
            reason: REASON_BAD_ORIGIN.to_owned(),
        })
    }
}
/// Validates a request's `Referer` header against the trusted origins.
///
/// The referer is accepted only when all of the following hold:
/// - it is present and non-empty;
/// - it does not use plain `http://` on a secure request;
/// - it starts with one of `allowed_origins` at an origin boundary.
///
/// "At an origin boundary" means the text after the matched origin is empty or begins with
/// `/`, `?` or `#`, or the configured origin itself ends in `/`. A bare prefix test would
/// let a look-alike host such as `https://example.com.evil.test/` pass for the trusted
/// origin `https://example.com`. Empty entries in `allowed_origins` are ignored instead of
/// matching every referer.
///
/// # Errors
///
/// Returns [`RejectRequest`] with:
/// - [`REASON_NO_REFERER`] when the header is absent;
/// - [`REASON_MALFORMED_REFERER`] when it is empty;
/// - [`REASON_INSECURE_REFERER`] when an `http://` referer (scheme compared
///   case-insensitively) accompanies a secure request;
/// - [`REASON_BAD_REFERER`] when no trusted origin matches.
pub fn check_referer(
    referer: Option<&str>,
    allowed_origins: &[String],
    is_secure: bool,
) -> Result<(), RejectRequest> {
    let referer = referer.ok_or_else(|| RejectRequest {
        reason: REASON_NO_REFERER.to_string(),
    })?;
    if referer.is_empty() {
        return Err(RejectRequest {
            reason: REASON_MALFORMED_REFERER.to_string(),
        });
    }
    if is_secure && is_plain_http(referer) {
        return Err(RejectRequest {
            reason: REASON_INSECURE_REFERER.to_string(),
        });
    }
    if !allowed_origins
        .iter()
        .any(|origin| referer_matches_origin(referer, origin))
    {
        return Err(RejectRequest {
            reason: REASON_BAD_REFERER.to_string(),
        });
    }
    Ok(())
}

/// Returns `true` when `referer` uses the plain `http` scheme.
///
/// URI schemes are case-insensitive (RFC 3986 §3.1), so `HTTP://` counts too.
fn is_plain_http(referer: &str) -> bool {
    const PLAIN_HTTP: &str = "http://";
    matches!(referer.get(..PLAIN_HTTP.len()), Some(prefix) if prefix.eq_ignore_ascii_case(PLAIN_HTTP))
}

/// Returns `true` when `referer` starts with the non-empty `origin` and the match ends at
/// an origin boundary (end of string, `/`, `?`, `#`, or an `origin` ending in `/`).
fn referer_matches_origin(referer: &str, origin: &str) -> bool {
    if origin.is_empty() {
        return false;
    }
    match referer.strip_prefix(origin) {
        Some(rest) => {
            origin.ends_with('/') || matches!(rest.chars().next(), None | Some('/' | '?' | '#'))
        }
        None => false,
    }
}
/// Returns `true` if the two host names are equal, ignoring ASCII case.
///
/// DNS names are case-insensitive (RFC 4343), so `Example.COM` and `example.com` name the
/// same domain. No other normalization is performed: trailing dots, ports, and IDNA
/// encodings are compared literally.
pub fn is_same_domain(domain1: &str, domain2: &str) -> bool {
    domain1.eq_ignore_ascii_case(domain2)
}
/// Returns the current Unix time in whole seconds.
///
/// If the system clock reads earlier than the Unix epoch, returns `0` instead of failing.
pub fn get_token_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Decides whether a token issued at `token_timestamp` is due for rotation at
/// `current_timestamp`.
///
/// - `rotation_interval == None` disables rotation and always yields `false`.
/// - Otherwise yields `true` once at least `interval` units have elapsed.
///
/// A token timestamp later than `current_timestamp` counts as zero elapsed, because the
/// subtraction saturates; it never underflows or panics.
pub fn should_rotate_token(
    token_timestamp: u64,
    current_timestamp: u64,
    rotation_interval: Option<u64>,
) -> bool {
    let Some(interval) = rotation_interval else {
        return false;
    };
    let elapsed = current_timestamp.saturating_sub(token_timestamp);
    elapsed >= interval
}
/// Issues a timestamped token of the form `<hex-hmac>:<unix-seconds>`.
///
/// The MAC covers `"<session_id>:<timestamp>"`, binding the token to both the session and
/// the moment it was issued (taken from [`get_token_timestamp`]).
pub fn generate_token_with_timestamp(secret_bytes: &[u8], session_id: &str) -> String {
    let issued_at = get_token_timestamp();
    let signature = generate_token_hmac(secret_bytes, &format!("{session_id}:{issued_at}"));
    format!("{signature}:{issued_at}")
}
/// Verifies a token produced by [`generate_token_with_timestamp`] and returns its issue
/// timestamp.
///
/// `token_data` must have the form `<hex-hmac>:<timestamp>` and is split on the **last**
/// `:`. The two parts must satisfy:
/// - the HMAC part is exactly [`CSRF_TOKEN_LENGTH`] hex characters;
/// - the timestamp is a non-empty run of ASCII digits that fits in `u64`.
///
/// The MAC is checked against `"<session_id>:<timestamp>"` using the timestamp text exactly
/// as received. Because the generator always writes the canonical decimal form, only that
/// form verifies. Alternate spellings of the same number (`+123`, `0123`) are rejected
/// instead of yielding a second valid token.
///
/// # Errors
///
/// Returns [`RejectRequest`] with a descriptive `reason` in two cases:
/// - the input is empty or malformed (missing delimiter, empty parts, wrong length,
///   non-hex characters, or a non-numeric timestamp);
/// - HMAC verification fails (wrong session, wrong secret, or tampering).
pub fn verify_token_with_timestamp(
    token_data: &str,
    secret_bytes: &[u8],
    session_id: &str,
) -> Result<u64, RejectRequest> {
    if token_data.is_empty() {
        return Err(RejectRequest {
            reason: "Invalid token format (empty token)".to_string(),
        });
    }
    // Splitting on the LAST ':' makes the timestamp the trailing segment; any extra colons
    // stay in the token part and are caught by the length/hex checks below.
    let Some((token, timestamp_str)) = token_data.rsplit_once(':') else {
        return Err(RejectRequest {
            reason: "Invalid token format (missing delimiter)".to_string(),
        });
    };
    if token.is_empty() {
        return Err(RejectRequest {
            reason: "Invalid token format (empty token value)".to_string(),
        });
    }
    if timestamp_str.is_empty() {
        return Err(RejectRequest {
            reason: "Invalid token format (empty timestamp)".to_string(),
        });
    }
    if token.len() != CSRF_TOKEN_LENGTH {
        return Err(RejectRequest {
            reason: format!(
                "Invalid token format (expected {} hex characters, got {})",
                CSRF_TOKEN_LENGTH,
                token.len()
            ),
        });
    }
    if !token.chars().all(|c| c.is_ascii_hexdigit()) {
        return Err(RejectRequest {
            reason: "Invalid token format (token contains non-hex characters)".to_string(),
        });
    }
    // `u64::from_str` also accepts a leading '+'; insist on plain digits so the timestamp
    // has a single textual form. Overflowing values are still rejected by `parse`.
    if !timestamp_str.bytes().all(|b| b.is_ascii_digit()) {
        return Err(RejectRequest {
            reason: "Invalid token format (timestamp is not a valid number)".to_string(),
        });
    }
    let timestamp: u64 = timestamp_str.parse().map_err(|_| RejectRequest {
        reason: "Invalid token format (timestamp is not a valid number)".to_string(),
    })?;
    // MAC the timestamp text as received (not a re-formatted number), so zero-padded
    // variants of a genuine token do not verify.
    let message = format!("{}:{}", session_id, timestamp_str);
    if !verify_token_hmac(token, secret_bytes, &message) {
        return Err(RejectRequest {
            reason: "CSRF token mismatch (HMAC verification failed)".to_string(),
        });
    }
    Ok(timestamp)
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
fn test_secret() -> Vec<u8> {
b"test-secret-key-at-least-32-bytes".to_vec()
}
#[rstest]
fn test_verify_token_with_timestamp_valid_token() {
let secret = test_secret();
let session_id = "user-session-12345";
let token_data = generate_token_with_timestamp(&secret, session_id);
let result = verify_token_with_timestamp(&token_data, &secret, session_id);
assert!(result.is_ok(), "Expected valid token to pass verification");
assert!(result.unwrap() > 0, "Expected positive timestamp");
}
#[rstest]
fn test_verify_token_with_timestamp_rejects_empty_input() {
let secret = test_secret();
let result = verify_token_with_timestamp("", &secret, "session");
assert!(result.is_err());
assert_eq!(
result.unwrap_err().reason,
"Invalid token format (empty token)"
);
}
#[rstest]
#[case("no-delimiter-at-all")]
#[case("abcdef")]
fn test_verify_token_with_timestamp_rejects_missing_delimiter(#[case] input: &str) {
let secret = test_secret();
let result = verify_token_with_timestamp(input, &secret, "session");
assert!(result.is_err());
assert_eq!(
result.unwrap_err().reason,
"Invalid token format (missing delimiter)"
);
}
#[rstest]
fn test_verify_token_with_timestamp_rejects_empty_token_value() {
let secret = test_secret();
let result = verify_token_with_timestamp(":12345", &secret, "session");
assert!(result.is_err());
assert_eq!(
result.unwrap_err().reason,
"Invalid token format (empty token value)"
);
}
#[rstest]
fn test_verify_token_with_timestamp_rejects_empty_timestamp() {
let secret = test_secret();
let result = verify_token_with_timestamp(
"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2:",
&secret,
"session",
);
assert!(result.is_err());
assert_eq!(
result.unwrap_err().reason,
"Invalid token format (empty timestamp)"
);
}
#[rstest]
#[case("short:12345")]
#[case("ab:12345")]
fn test_verify_token_with_timestamp_rejects_wrong_token_length(#[case] input: &str) {
let secret = test_secret();
let result = verify_token_with_timestamp(input, &secret, "session");
assert!(result.is_err());
assert!(
result
.unwrap_err()
.reason
.contains("expected 64 hex characters"),
"Expected token length error"
);
}
#[rstest]
fn test_verify_token_with_timestamp_rejects_non_hex_token() {
let secret = test_secret();
let bad_token = "g1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6z1b2";
let input = format!("{}:12345", bad_token);
let result = verify_token_with_timestamp(&input, &secret, "session");
assert!(result.is_err());
assert_eq!(
result.unwrap_err().reason,
"Invalid token format (token contains non-hex characters)"
);
}
#[rstest]
#[case("a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2:not_a_number")]
#[case("a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2:-1")]
#[case("a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2:12.34")]
fn test_verify_token_with_timestamp_rejects_invalid_timestamp(#[case] input: &str) {
let secret = test_secret();
let result = verify_token_with_timestamp(input, &secret, "session");
assert!(result.is_err());
assert_eq!(
result.unwrap_err().reason,
"Invalid token format (timestamp is not a valid number)"
);
}
#[rstest]
fn test_verify_token_with_timestamp_rejects_tampered_token() {
let secret = test_secret();
let session_id = "user-session-12345";
let token_data = generate_token_with_timestamp(&secret, session_id);
let result = verify_token_with_timestamp(&token_data, &secret, "different-session");
assert!(result.is_err());
assert_eq!(
result.unwrap_err().reason,
"CSRF token mismatch (HMAC verification failed)"
);
}
#[rstest]
fn test_verify_token_with_timestamp_rejects_wrong_secret() {
let secret = test_secret();
let wrong_secret = b"wrong-secret-key-at-least-32-byte".to_vec();
let session_id = "user-session-12345";
let token_data = generate_token_with_timestamp(&secret, session_id);
let result = verify_token_with_timestamp(&token_data, &wrong_secret, session_id);
assert!(result.is_err());
assert_eq!(
result.unwrap_err().reason,
"CSRF token mismatch (HMAC verification failed)"
);
}
#[rstest]
fn test_verify_token_with_timestamp_handles_extra_colons_in_crafted_input() {
let secret = test_secret();
let input = "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2:extra:12345";
let result = verify_token_with_timestamp(input, &secret, "session");
assert!(result.is_err());
}
#[rstest]
fn test_should_rotate_token_normal_case() {
let token_timestamp = 1000u64;
let current_timestamp = 4700u64; let interval = 3600u64;
let result = should_rotate_token(token_timestamp, current_timestamp, Some(interval));
assert!(result, "Token older than interval should trigger rotation");
}
#[rstest]
fn test_should_rotate_token_future_timestamp_no_panic() {
let token_timestamp = 5000u64; let current_timestamp = 1000u64;
let interval = 3600u64;
let result = should_rotate_token(token_timestamp, current_timestamp, Some(interval));
assert!(!result, "Future-dated token should not trigger rotation");
}
#[rstest]
fn test_should_rotate_token_equal_timestamps() {
let timestamp = 1000u64;
let interval = 3600u64;
let result = should_rotate_token(timestamp, timestamp, Some(interval));
assert!(
!result,
"Equal timestamps (0 elapsed) should not trigger rotation"
);
}
}