use crate::SecurityError;
use rand::RngCore;
use std::collections::HashMap;
use tokio::sync::RwLock;
/// A CSRF token value bundled with the header and parameter names it is
/// associated with.
///
/// All fields are public, so the derived `Debug` output includes the raw
/// token value.
#[derive(Debug, Clone)]
pub struct CsrfToken {
    /// The token value itself.
    pub token: String,
    /// Header name associated with this token (presumably the request header
    /// a client echoes the token in; the lookup is not done in this file).
    pub header_name: String,
    /// Parameter name associated with this token (presumably the form/query
    /// parameter carrying the token; the lookup is not done in this file).
    pub parameter_name: String,
}

impl CsrfToken {
    /// Builds a token from anything convertible into `String`.
    ///
    /// The three arguments are stored verbatim in the fields of the same
    /// name; no validation or normalisation is applied.
    pub fn new(
        token: impl Into<String>,
        header_name: impl Into<String>,
        parameter_name: impl Into<String>,
    ) -> Self {
        let token = token.into();
        let header_name = header_name.into();
        let parameter_name = parameter_name.into();
        CsrfToken {
            token,
            header_name,
            parameter_name,
        }
    }
}
/// Asynchronous storage abstraction for CSRF tokens, keyed by a
/// caller-supplied string identifier.
///
/// Implementors must be `Send + Sync`. None of the operations can report an
/// error through their return type.
#[async_trait::async_trait]
pub trait CsrfTokenRepository: Send + Sync {
/// Produces a new token for `identifier` and returns it.
///
/// NOTE(review): both implementations in this file also store the new token
/// under `identifier` before returning; the trait signature alone does not
/// require that.
async fn generate_token(&self, identifier: &str) -> CsrfToken;
/// Stores `token` under `identifier`.
async fn save_token(&self, identifier: &str, token: &CsrfToken);
/// Returns the token stored under `identifier`, or `None` if there is none.
async fn load_token(&self, identifier: &str) -> Option<CsrfToken>;
/// Discards whatever token is stored under `identifier`.
async fn remove_token(&self, identifier: &str);
}
/// CSRF token repository that keeps tokens in a process-local `HashMap`.
///
/// The map sits behind an `Arc`, so clones of a repository (via the derived
/// `Clone`) read and write the same underlying map.
#[derive(Debug, Clone)]
pub struct InMemoryCsrfTokenRepository {
// Copied into the `header_name` of every token this repository generates.
header_name: String,
// Copied into the `parameter_name` of every token this repository generates.
parameter_name: String,
// Number of random bytes per generated token; the token string is the hex
// encoding of those bytes, i.e. twice this many characters.
token_length: usize,
// Identifier -> token map guarded by an async read/write lock.
store: std::sync::Arc<RwLock<HashMap<String, CsrfToken>>>,
}
impl InMemoryCsrfTokenRepository {
pub fn new() -> Self {
Self {
header_name: "X-CSRF-TOKEN".to_string(),
parameter_name: "_csrf".to_string(),
token_length: 32,
store: std::sync::Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn header_name(mut self, name: impl Into<String>) -> Self {
self.header_name = name.into();
self
}
pub fn parameter_name(mut self, name: impl Into<String>) -> Self {
self.parameter_name = name.into();
self
}
pub fn token_length(mut self, length: usize) -> Self {
self.token_length = length;
self
}
fn random_token(&self) -> String {
let mut buf = vec![0u8; self.token_length];
rand::rng().fill_bytes(&mut buf);
hex::encode(&buf)
}
}
impl Default for InMemoryCsrfTokenRepository {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl CsrfTokenRepository for InMemoryCsrfTokenRepository {
    /// Generates a random token carrying this repository's header and
    /// parameter names, stores it under `identifier` (replacing any previous
    /// entry), and returns it.
    async fn generate_token(&self, identifier: &str) -> CsrfToken {
        let fresh = CsrfToken::new(
            self.random_token(),
            self.header_name.clone(),
            self.parameter_name.clone(),
        );
        self.save_token(identifier, &fresh).await;
        fresh
    }

    /// Inserts a clone of `token` under `identifier`, overwriting any
    /// existing entry.
    async fn save_token(&self, identifier: &str, token: &CsrfToken) {
        let mut guard = self.store.write().await;
        guard.insert(identifier.to_owned(), token.clone());
    }

    /// Returns a clone of the token stored under `identifier`, if any.
    async fn load_token(&self, identifier: &str) -> Option<CsrfToken> {
        let guard = self.store.read().await;
        guard.get(identifier).cloned()
    }

    /// Removes the entry for `identifier`; a missing entry is not an error.
    async fn remove_token(&self, identifier: &str) {
        let mut guard = self.store.write().await;
        guard.remove(identifier);
    }
}
/// CSRF token repository carrying cookie-related settings.
///
/// NOTE(review): within this file the cookie settings are only stored, never
/// read; tokens are kept in the in-memory `store` map. Presumably the code
/// that writes the actual cookie lives elsewhere and reads these public
/// fields (TODO confirm).
#[derive(Debug, Clone)]
pub struct CookieCsrfTokenRepository {
/// Cookie name setting.
pub cookie_name: String,
/// Copied into the `header_name` of every token this repository generates.
pub header_name: String,
/// Copied into the `parameter_name` of every token this repository generates.
pub parameter_name: String,
/// Cookie `HttpOnly` setting.
pub cookie_http_only: bool,
/// Cookie `Secure` setting.
pub cookie_secure: bool,
/// Cookie path setting.
pub cookie_path: String,
/// Optional cookie max-age setting (unit not stated here; presumably
/// seconds, TODO confirm).
pub cookie_max_age: Option<u64>,
/// Cookie `SameSite` setting, kept as free-form text.
pub cookie_same_site: String,
// Identifier -> token map guarded by an async read/write lock; behind an
// `Arc`, so clones of the repository share the same map.
store: std::sync::Arc<RwLock<HashMap<String, CsrfToken>>>,
}
impl CookieCsrfTokenRepository {
pub fn new() -> Self {
Self {
cookie_name: "XSRF-TOKEN".to_string(),
header_name: "X-XSRF-TOKEN".to_string(),
parameter_name: "_csrf".to_string(),
cookie_http_only: false,
cookie_secure: false,
cookie_path: "/".to_string(),
cookie_max_age: None,
cookie_same_site: "Lax".to_string(),
store: std::sync::Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn with_http_only_false() -> Self {
Self::new()
}
pub fn cookie_name(mut self, name: impl Into<String>) -> Self {
self.cookie_name = name.into();
self
}
pub fn header_name(mut self, name: impl Into<String>) -> Self {
self.header_name = name.into();
self
}
pub fn parameter_name(mut self, name: impl Into<String>) -> Self {
self.parameter_name = name.into();
self
}
pub fn cookie_http_only(mut self, http_only: bool) -> Self {
self.cookie_http_only = http_only;
self
}
pub fn cookie_secure(mut self, secure: bool) -> Self {
self.cookie_secure = secure;
self
}
pub fn cookie_path(mut self, path: impl Into<String>) -> Self {
self.cookie_path = path.into();
self
}
pub fn cookie_max_age(mut self, max_age: Option<u64>) -> Self {
self.cookie_max_age = max_age;
self
}
pub fn cookie_same_site(mut self, same_site: impl Into<String>) -> Self {
self.cookie_same_site = same_site.into();
self
}
fn random_token() -> String {
let mut buf = vec![0u8; 32];
rand::rng().fill_bytes(&mut buf);
hex::encode(&buf)
}
}
impl Default for CookieCsrfTokenRepository {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl CsrfTokenRepository for CookieCsrfTokenRepository {
    /// Generates a random token carrying this repository's header and
    /// parameter names, stores it under `identifier` (replacing any previous
    /// entry), and returns it.
    async fn generate_token(&self, identifier: &str) -> CsrfToken {
        let fresh = CsrfToken::new(
            Self::random_token(),
            self.header_name.clone(),
            self.parameter_name.clone(),
        );
        self.save_token(identifier, &fresh).await;
        fresh
    }

    /// Inserts a clone of `token` under `identifier`, overwriting any
    /// existing entry.
    async fn save_token(&self, identifier: &str, token: &CsrfToken) {
        let mut guard = self.store.write().await;
        guard.insert(identifier.to_owned(), token.clone());
    }

    /// Returns a clone of the token stored under `identifier`, if any.
    async fn load_token(&self, identifier: &str) -> Option<CsrfToken> {
        let guard = self.store.read().await;
        guard.get(identifier).cloned()
    }

    /// Removes the entry for `identifier`; a missing entry is not an error.
    async fn remove_token(&self, identifier: &str) {
        let mut guard = self.store.write().await;
        guard.remove(identifier);
    }
}
/// Settings that control how `CsrfValidator` behaves.
#[derive(Debug, Clone)]
pub struct CsrfProtectionConfig {
/// Master switch. When `false`, `CsrfValidator::validate_token` accepts
/// everything and `CsrfValidator::generate_token` returns an error.
pub enabled: bool,
/// HTTP methods for which `CsrfValidator::validate` skips the token check.
pub ignored_methods: Vec<http::Method>,
/// Header name setting. Not read anywhere in this file; presumably used by
/// the request-handling layer to locate the submitted token (TODO confirm).
pub token_header_name: String,
/// Parameter name setting. Not read anywhere in this file; presumably used
/// by the request-handling layer to locate the submitted token (TODO confirm).
pub token_param_name: String,
}
impl Default for CsrfProtectionConfig {
fn default() -> Self {
Self {
enabled: true,
ignored_methods: vec![
http::Method::GET,
http::Method::HEAD,
http::Method::OPTIONS,
http::Method::TRACE,
],
token_header_name: "X-CSRF-TOKEN".to_string(),
token_param_name: "_csrf".to_string(),
}
}
}
impl CsrfProtectionConfig {
    /// Returns the default configuration with `enabled` switched off.
    pub fn disabled() -> Self {
        let mut config = Self::default();
        config.enabled = false;
        config
    }

    /// Builder-style setter for the `enabled` flag.
    pub fn enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }

    /// Adds `method` to the ignored list unless it is already present, so the
    /// list never contains duplicates introduced through this method.
    pub fn ignore_method(mut self, method: http::Method) -> Self {
        let already_listed = self.ignored_methods.iter().any(|known| *known == method);
        if !already_listed {
            self.ignored_methods.push(method);
        }
        self
    }

    /// Builder-style setter for the token header name.
    pub fn token_header_name(self, name: impl Into<String>) -> Self {
        Self {
            token_header_name: name.into(),
            ..self
        }
    }

    /// Builder-style setter for the token parameter name.
    pub fn token_param_name(self, name: impl Into<String>) -> Self {
        Self {
            token_param_name: name.into(),
            ..self
        }
    }

    /// Reports whether `method` is in the ignored list.
    pub fn is_method_ignored(&self, method: &http::Method) -> bool {
        self.ignored_methods.iter().any(|known| known == method)
    }
}
/// Issues and checks CSRF tokens by combining a `CsrfProtectionConfig` with a
/// token repository of type `R`.
pub struct CsrfValidator<R: CsrfTokenRepository> {
// Controls whether validation is active and which HTTP methods are skipped.
config: CsrfProtectionConfig,
// Where tokens are generated, looked up, and removed.
repository: R,
}
impl<R: CsrfTokenRepository> CsrfValidator<R> {
    /// Creates a validator from a configuration and a token repository.
    pub fn new(config: CsrfProtectionConfig, repository: R) -> Self {
        Self { config, repository }
    }

    /// Asks the repository to generate a token for `identifier`.
    ///
    /// # Errors
    ///
    /// Returns `SecurityError::CsrfValidationFailed` when CSRF protection is
    /// disabled in the configuration.
    pub async fn generate_token(&self, identifier: &str) -> crate::SecurityResult<CsrfToken> {
        if !self.config.enabled {
            return Err(SecurityError::CsrfValidationFailed(
                "CSRF protection is disabled".to_string(),
            ));
        }
        Ok(self.repository.generate_token(identifier).await)
    }

    /// Checks `submitted_token` against the token stored for `identifier`.
    ///
    /// When protection is disabled every call succeeds. Otherwise the stored
    /// token is loaded and compared in constant time. An empty
    /// `submitted_token` never validates, even if the stored token is itself
    /// empty (e.g. a repository configured with a zero token length);
    /// otherwise a request carrying no token at all would be accepted.
    ///
    /// # Errors
    ///
    /// Returns `SecurityError::CsrfValidationFailed` when no token is stored
    /// for `identifier`, or when the submitted token is empty or does not
    /// match the stored one.
    pub async fn validate_token(
        &self,
        identifier: &str,
        submitted_token: &str,
    ) -> crate::SecurityResult<()> {
        if !self.config.enabled {
            return Ok(());
        }
        let stored = self
            .repository
            .load_token(identifier)
            .await
            .ok_or_else(|| {
                SecurityError::CsrfValidationFailed("No CSRF token found for session".to_string())
            })?;
        // Constant-time comparison so timing does not reveal how much of the
        // token matched.
        let tokens_equal: bool =
            subtle::ConstantTimeEq::ct_eq(stored.token.as_bytes(), submitted_token.as_bytes())
                .into();
        // Emptiness is not secret, so checking it outside the constant-time
        // comparison leaks nothing useful.
        if tokens_equal && !submitted_token.is_empty() {
            Ok(())
        } else {
            Err(SecurityError::CsrfValidationFailed("CSRF token mismatch".to_string()))
        }
    }

    /// Reports whether the configuration exempts `method` from validation.
    pub fn is_method_ignored(&self, method: &http::Method) -> bool {
        self.config.is_method_ignored(method)
    }

    /// Validates a request: ignored methods pass immediately, everything else
    /// is delegated to [`Self::validate_token`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::validate_token`] for methods that are not ignored.
    pub async fn validate(
        &self,
        method: &http::Method,
        identifier: &str,
        submitted_token: &str,
    ) -> crate::SecurityResult<()> {
        if self.is_method_ignored(method) {
            return Ok(());
        }
        self.validate_token(identifier, submitted_token).await
    }

    /// Removes the token stored for `identifier` from the repository.
    pub async fn remove_token(&self, identifier: &str) {
        self.repository.remove_token(identifier).await;
    }
}
// Unit tests covering the token type, the configuration builder, both
// repository implementations, and the validator.
#[cfg(test)]
mod tests {
use super::*;
// `CsrfToken::new` stores its three arguments verbatim.
#[test]
fn test_csrf_token_new() {
let token = CsrfToken::new("abc123", "X-CSRF-TOKEN", "_csrf");
assert_eq!(token.token, "abc123");
assert_eq!(token.header_name, "X-CSRF-TOKEN");
assert_eq!(token.parameter_name, "_csrf");
}
// Default config: enabled, safe methods ignored, state-changing methods checked.
#[test]
fn test_default_config() {
let config = CsrfProtectionConfig::default();
assert!(config.enabled);
assert!(config.is_method_ignored(&http::Method::GET));
assert!(config.is_method_ignored(&http::Method::HEAD));
assert!(config.is_method_ignored(&http::Method::OPTIONS));
assert!(config.is_method_ignored(&http::Method::TRACE));
assert!(!config.is_method_ignored(&http::Method::POST));
assert!(!config.is_method_ignored(&http::Method::PUT));
assert!(!config.is_method_ignored(&http::Method::DELETE));
}
// `disabled()` only flips `enabled`; the ignored-method list keeps its defaults.
#[test]
fn test_disabled_config() {
let config = CsrfProtectionConfig::disabled();
assert!(!config.enabled);
assert!(config.is_method_ignored(&http::Method::GET));
}
// Builder setters overwrite the corresponding fields.
#[test]
fn test_config_builder() {
let config = CsrfProtectionConfig::default()
.enabled(false)
.token_header_name("X-MY-CSRF")
.token_param_name("csrf_field");
assert!(!config.enabled);
assert_eq!(config.token_header_name, "X-MY-CSRF");
assert_eq!(config.token_param_name, "csrf_field");
}
// A method that is not ignored by default can be added to the list.
#[test]
fn test_config_ignore_method() {
let config = CsrfProtectionConfig::default().ignore_method(http::Method::POST);
assert!(config.is_method_ignored(&http::Method::POST));
}
// Adding an already-ignored method does not create a duplicate entry.
#[test]
fn test_config_ignore_method_no_duplicates() {
let config = CsrfProtectionConfig::default()
.ignore_method(http::Method::GET)
.ignore_method(http::Method::GET);
assert_eq!(
config
.ignored_methods
.iter()
.filter(|m| **m == http::Method::GET)
.count(),
1
);
}
// A generated token is non-empty, carries the default names, and can be loaded back.
#[tokio::test]
async fn test_in_memory_generate_and_load() {
let repo = InMemoryCsrfTokenRepository::new();
let token = repo.generate_token("session-1").await;
assert!(!token.token.is_empty());
assert_eq!(token.header_name, "X-CSRF-TOKEN");
assert_eq!(token.parameter_name, "_csrf");
let loaded = repo.load_token("session-1").await;
assert!(loaded.is_some());
assert_eq!(loaded.unwrap().token, token.token);
}
// An externally built token can be saved and loaded back unchanged.
#[tokio::test]
async fn test_in_memory_save_and_load() {
let repo = InMemoryCsrfTokenRepository::new();
let token = CsrfToken::new("my-token", "X-CSRF", "csrf");
repo.save_token("session-2", &token).await;
let loaded = repo.load_token("session-2").await;
assert!(loaded.is_some());
assert_eq!(loaded.unwrap().token, "my-token");
}
// Loading an unknown identifier yields `None`.
#[tokio::test]
async fn test_in_memory_load_missing() {
let repo = InMemoryCsrfTokenRepository::new();
let loaded = repo.load_token("nonexistent").await;
assert!(loaded.is_none());
}
// A removed token can no longer be loaded.
#[tokio::test]
async fn test_in_memory_remove() {
let repo = InMemoryCsrfTokenRepository::new();
repo.generate_token("session-3").await;
repo.remove_token("session-3").await;
let loaded = repo.load_token("session-3").await;
assert!(loaded.is_none());
}
// Custom names are applied; 16 random bytes hex-encode to 32 characters.
#[tokio::test]
async fn test_in_memory_custom_settings() {
let repo = InMemoryCsrfTokenRepository::new()
.header_name("X-MY-CSRF")
.parameter_name("my_csrf")
.token_length(16);
let token = repo.generate_token("s").await;
assert_eq!(token.token.len(), 32);
assert_eq!(token.header_name, "X-MY-CSRF");
assert_eq!(token.parameter_name, "my_csrf");
}
// Pins the cookie repository's default settings.
#[test]
fn test_cookie_repo_default_settings() {
let repo = CookieCsrfTokenRepository::new();
assert_eq!(repo.cookie_name, "XSRF-TOKEN");
assert_eq!(repo.header_name, "X-XSRF-TOKEN");
assert_eq!(repo.parameter_name, "_csrf");
assert!(!repo.cookie_http_only);
assert!(!repo.cookie_secure);
assert_eq!(repo.cookie_path, "/");
assert_eq!(repo.cookie_same_site, "Lax");
assert!(repo.cookie_max_age.is_none());
}
// The named constructor leaves `HttpOnly` off.
#[test]
fn test_cookie_repo_with_http_only_false() {
let repo = CookieCsrfTokenRepository::with_http_only_false();
assert!(!repo.cookie_http_only);
}
// Every cookie builder setter overwrites its field.
#[test]
fn test_cookie_repo_builder() {
let repo = CookieCsrfTokenRepository::new()
.cookie_name("MY-CSRF-COOKIE")
.header_name("X-MY-CSRF")
.parameter_name("my_csrf")
.cookie_http_only(true)
.cookie_secure(true)
.cookie_path("/api")
.cookie_max_age(Some(3600))
.cookie_same_site("Strict");
assert_eq!(repo.cookie_name, "MY-CSRF-COOKIE");
assert_eq!(repo.header_name, "X-MY-CSRF");
assert!(repo.cookie_http_only);
assert!(repo.cookie_secure);
assert_eq!(repo.cookie_path, "/api");
assert_eq!(repo.cookie_max_age, Some(3600));
assert_eq!(repo.cookie_same_site, "Strict");
}
// The cookie repository also round-trips a generated token.
#[tokio::test]
async fn test_cookie_repo_generate_and_load() {
let repo = CookieCsrfTokenRepository::new();
let token = repo.generate_token("cookie-session-1").await;
assert!(!token.token.is_empty());
let loaded = repo.load_token("cookie-session-1").await;
assert!(loaded.is_some());
assert_eq!(loaded.unwrap().token, token.token);
}
// A freshly generated token validates for the same identifier.
#[tokio::test]
async fn test_validator_generate_and_validate() {
let config = CsrfProtectionConfig::default();
let repo = InMemoryCsrfTokenRepository::new();
let validator = CsrfValidator::new(config, repo);
let token = validator.generate_token("sess-1").await.unwrap();
let result = validator.validate_token("sess-1", &token.token).await;
assert!(result.is_ok());
}
// A wrong token value is rejected.
#[tokio::test]
async fn test_validator_invalid_token() {
let config = CsrfProtectionConfig::default();
let repo = InMemoryCsrfTokenRepository::new();
let validator = CsrfValidator::new(config, repo);
validator.generate_token("sess-2").await.unwrap();
let result = validator.validate_token("sess-2", "wrong-token").await;
assert!(result.is_err());
}
// Validation fails when nothing is stored for the identifier.
#[tokio::test]
async fn test_validator_missing_session() {
let config = CsrfProtectionConfig::default();
let repo = InMemoryCsrfTokenRepository::new();
let validator = CsrfValidator::new(config, repo);
let result = validator
.validate_token("no-such-session", "any-token")
.await;
assert!(result.is_err());
}
// With protection disabled, validation always succeeds.
#[tokio::test]
async fn test_validator_disabled() {
let config = CsrfProtectionConfig::disabled();
let repo = InMemoryCsrfTokenRepository::new();
let validator = CsrfValidator::new(config, repo);
let result = validator.validate_token("any", "any").await;
assert!(result.is_ok());
}
// With protection disabled, token generation is refused.
#[tokio::test]
async fn test_validator_generate_when_disabled() {
let config = CsrfProtectionConfig::disabled();
let repo = InMemoryCsrfTokenRepository::new();
let validator = CsrfValidator::new(config, repo);
let result = validator.generate_token("sess").await;
assert!(result.is_err());
}
// Ignored methods pass `validate` even with no stored token and an empty submission.
#[tokio::test]
async fn test_validator_validate_safe_methods() {
let config = CsrfProtectionConfig::default();
let repo = InMemoryCsrfTokenRepository::new();
let validator = CsrfValidator::new(config, repo);
for method in &[
http::Method::GET,
http::Method::HEAD,
http::Method::OPTIONS,
http::Method::TRACE,
] {
let result = validator.validate(method, "no-token-needed", "").await;
assert!(result.is_ok(), "Method {} should be ignored", method);
}
}
// Non-ignored methods require the correct token.
#[tokio::test]
async fn test_validator_validate_unsafe_methods() {
let config = CsrfProtectionConfig::default();
let repo = InMemoryCsrfTokenRepository::new();
let validator = CsrfValidator::new(config, repo);
let token = validator.generate_token("sess-unsafe").await.unwrap();
let result = validator
.validate(&http::Method::POST, "sess-unsafe", &token.token)
.await;
assert!(result.is_ok());
let result = validator
.validate(&http::Method::PUT, "sess-unsafe", "bad")
.await;
assert!(result.is_err());
}
// After removal, validation fails because no token is stored.
#[tokio::test]
async fn test_validator_remove_token() {
let config = CsrfProtectionConfig::default();
let repo = InMemoryCsrfTokenRepository::new();
let validator = CsrfValidator::new(config, repo);
validator.generate_token("sess-rm").await.unwrap();
validator.remove_token("sess-rm").await;
let result = validator.validate_token("sess-rm", "any").await;
assert!(result.is_err());
}
}