use std::sync::Arc;
use std::collections::HashMap;
use axum::{
extract::Request,
http::{HeaderMap, Method, header},
response::{Response, IntoResponse},
};
use elif_http::{
middleware::{Middleware, BoxFuture},
ElifStatusCode, };
use sha2::{Sha256, Digest};
use rand::{thread_rng, Rng};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
pub use crate::config::CsrfConfig;
use crate::SecurityError;
/// Shared map from an issued token string to its server-side metadata.
///
/// Held behind an `Arc` so every clone of a [`CsrfMiddleware`] sees the same store,
/// and behind a `tokio::sync::RwLock` so it can be locked from async code. The map
/// lives in process memory only; `CsrfMiddleware::new` creates it empty.
type TokenStore = Arc<tokio::sync::RwLock<HashMap<String, CsrfTokenData>>>;
/// Server-side record for one issued CSRF token.
#[derive(Debug, Clone)]
pub struct CsrfTokenData {
/// The token string itself (the same string is used as the key in the token store).
pub token: String,
/// UTC expiry instant; tokens past it are rejected by validation and purged on cleanup.
pub expires_at: time::OffsetDateTime,
/// Lowercase hex SHA-256 of the User-Agent the token was issued to, or `None` when it
/// was issued without one (in which case validation accepts any User-Agent).
pub user_agent_hash: Option<String>,
}
/// Middleware that requires a valid CSRF token on state-changing requests.
///
/// Clones share the same token store, because the store is behind an `Arc`.
#[derive(Debug, Clone)]
pub struct CsrfMiddleware {
/// Header/cookie names, token lifetime (seconds) and exempt paths.
config: CsrfConfig,
/// Tokens issued by `generate_token` and looked up by `validate_token`.
token_store: TokenStore,
}
impl CsrfMiddleware {
    /// Creates a middleware from `config` with an empty in-memory token store.
    pub fn new(config: CsrfConfig) -> Self {
        Self {
            config,
            token_store: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
        }
    }

    /// Returns a [`CsrfMiddlewareBuilder`] for configuring the middleware fluently.
    pub fn builder() -> CsrfMiddlewareBuilder {
        CsrfMiddlewareBuilder::new()
    }

    /// Issues a new token, records it in the store and returns it.
    ///
    /// The token is 32 random bytes encoded as unpadded URL-safe base64. When
    /// `user_agent` is provided, the token is bound to it: [`Self::validate_token`]
    /// only accepts the token together with the same User-Agent string. As a side
    /// effect, expired tokens are purged from the store.
    pub async fn generate_token(&self, user_agent: Option<&str>) -> String {
        // `thread_rng()` is used as a temporary so the `!Send` `ThreadRng` handle is
        // dropped at the end of this statement. Binding it to a local kept it alive
        // across the `.await` below and made this future `!Send`, which breaks callers
        // that need a `Send` future.
        let token_bytes: [u8; 32] = thread_rng().gen();
        let token = URL_SAFE_NO_PAD.encode(token_bytes);

        let token_data = CsrfTokenData {
            token: token.clone(),
            expires_at: self.expiry_from_now(),
            user_agent_hash: user_agent.map(Self::hash_user_agent),
        };

        let mut store = self.token_store.write().await;
        store.insert(token.clone(), token_data);
        Self::cleanup_expired_tokens(&mut store);
        token
    }

    /// Checks whether `token` is currently acceptable.
    ///
    /// Returns `true` only if the token was issued by this middleware and has not
    /// expired. If it was bound to a User-Agent, `user_agent` must also be present and
    /// hash to the same value. The token is NOT consumed; see [`Self::consume_token`].
    pub async fn validate_token(&self, token: &str, user_agent: Option<&str>) -> bool {
        let store = self.token_store.read().await;
        let token_data = match store.get(token) {
            Some(data) => data,
            None => return false,
        };

        if !Self::is_live(token_data, time::OffsetDateTime::now_utc()) {
            return false;
        }

        match (&token_data.user_agent_hash, user_agent) {
            // Issued without a User-Agent binding: any (or no) User-Agent is fine.
            (None, _) => true,
            (Some(stored_hash), Some(ua)) => *stored_hash == Self::hash_user_agent(ua),
            // Bound token presented without a User-Agent header.
            (Some(_), None) => false,
        }
    }

    /// Removes `token` from the store so it can no longer be validated.
    /// Does nothing if the token is unknown.
    pub async fn consume_token(&self, token: &str) {
        let mut store = self.token_store.write().await;
        store.remove(token);
    }

    /// Computes the expiry instant for a token issued now.
    ///
    /// `token_lifetime` is a `u64` number of seconds. Values above `i64::MAX` are
    /// clamped instead of wrapping to a negative lifetime, which would expire tokens
    /// immediately. The addition saturates instead of panicking when the result would
    /// exceed the range `OffsetDateTime` can represent.
    fn expiry_from_now(&self) -> time::OffsetDateTime {
        let lifetime_secs = i64::try_from(self.config.token_lifetime).unwrap_or(i64::MAX);
        time::OffsetDateTime::now_utc().saturating_add(time::Duration::seconds(lifetime_secs))
    }

    /// Whether `data` is still valid at `now`.
    ///
    /// Shared by validation and cleanup so both agree on the boundary: a token counts
    /// as expired once `now` reaches `expires_at`.
    fn is_live(data: &CsrfTokenData, now: time::OffsetDateTime) -> bool {
        data.expires_at > now
    }

    /// Lowercase hex SHA-256 of a User-Agent string, as stored in `user_agent_hash`.
    fn hash_user_agent(user_agent: &str) -> String {
        let mut hasher = Sha256::new();
        hasher.update(user_agent.as_bytes());
        format!("{:x}", hasher.finalize())
    }

    /// Drops every token in `store` that is no longer live.
    fn cleanup_expired_tokens(store: &mut HashMap<String, CsrfTokenData>) {
        let now = time::OffsetDateTime::now_utc();
        store.retain(|_, data| Self::is_live(data, now));
    }

    /// Returns whether `path` is exempt from CSRF checks.
    ///
    /// An entry ending in `*` matches any path that starts with the text before the
    /// `*` (so `/public/*` matches `/public/a/b`). Any other entry must match exactly.
    fn is_exempt_path(&self, path: &str) -> bool {
        // Exact hit first: an O(1) set lookup before the linear wildcard scan.
        self.config.exempt_paths.contains(path)
            || self
                .config
                .exempt_paths
                .iter()
                .any(|exempt| match exempt.strip_suffix('*') {
                    Some(prefix) => path.starts_with(prefix),
                    None => path == exempt,
                })
    }

    /// Extracts the submitted token from the request headers.
    ///
    /// The configured token header wins if present and valid UTF-8. Otherwise the
    /// first `Cookie` header is searched for the configured cookie name.
    ///
    /// NOTE(review): accepting the token from the cookie alone is weak CSRF protection.
    /// Browsers attach cookies to cross-site requests automatically (subject to
    /// `SameSite`), so a cookie-only submission does not prove the request came from
    /// the application's own page. Consider requiring the header, or checking that the
    /// header matches the cookie (double-submit). This is left unchanged because
    /// existing callers and tests rely on the cookie fallback.
    fn extract_token(&self, headers: &HeaderMap) -> Option<String> {
        let from_header = headers
            .get(&self.config.token_header)
            .and_then(|value| value.to_str().ok());
        if let Some(token) = from_header {
            return Some(token.to_string());
        }

        headers
            .get(header::COOKIE)
            .and_then(|value| value.to_str().ok())
            .and_then(|cookies| {
                cookies
                    .split(';')
                    .filter_map(|pair| pair.trim().split_once('='))
                    .find(|(name, _)| *name == self.config.cookie_name)
                    .map(|(_, value)| value.to_string())
            })
    }
}
impl Middleware for CsrfMiddleware {
    /// Enforces CSRF protection on an incoming request.
    ///
    /// `GET`, `HEAD` and `OPTIONS` requests, and requests to exempt paths, pass through
    /// untouched. Any other request must carry a token (header first, then cookie) that
    /// `validate_token` accepts for the request's `User-Agent`. Otherwise the request is
    /// short-circuited with a `403` response carrying a JSON error body.
    fn process_request<'a>(
        &'a self,
        request: Request,
    ) -> BoxFuture<'a, Result<Request, Response>> {
        Box::pin(async move {
            let read_only_method = matches!(
                *request.method(),
                Method::GET | Method::HEAD | Method::OPTIONS
            );
            if read_only_method || self.is_exempt_path(request.uri().path()) {
                return Ok(request);
            }

            let headers = request.headers();
            let user_agent = headers
                .get(header::USER_AGENT)
                .and_then(|value| value.to_str().ok());
            let token_accepted = match self.extract_token(headers) {
                Some(token) => self.validate_token(&token, user_agent).await,
                None => false,
            };
            if token_accepted {
                return Ok(request);
            }

            let error_response = Response::builder()
                .status(ElifStatusCode::FORBIDDEN)
                .header("Content-Type", "application/json")
                .body(r#"{"error":{"code":"CSRF_VALIDATION_FAILED","message":"CSRF token validation failed"}}"#.into())
                .unwrap();
            Err(error_response)
        })
    }

    /// Identifier reported for this middleware.
    fn name(&self) -> &'static str {
        "CsrfMiddleware"
    }
}
impl IntoResponse for SecurityError {
fn into_response(self) -> Response {
let (status, message) = match self {
SecurityError::CsrfValidationFailed => {
(ElifStatusCode::FORBIDDEN, "CSRF token validation failed")
}
_ => (ElifStatusCode::INTERNAL_SERVER_ERROR, "Security error"),
};
(status, message).into_response()
}
}
/// Fluent builder for [`CsrfMiddleware`], starting from `CsrfConfig::default()`.
#[derive(Debug)]
pub struct CsrfMiddlewareBuilder {
    config: CsrfConfig,
}

impl CsrfMiddlewareBuilder {
    /// Starts a builder from the default CSRF configuration.
    pub fn new() -> Self {
        let config = CsrfConfig::default();
        Self { config }
    }

    /// Applies `update` to the configuration being built and returns the builder.
    fn with_config(mut self, update: impl FnOnce(&mut CsrfConfig)) -> Self {
        update(&mut self.config);
        self
    }

    /// Sets the request header the token is read from.
    pub fn token_header<S: Into<String>>(self, header: S) -> Self {
        self.with_config(|config| config.token_header = header.into())
    }

    /// Sets the name of the cookie that carries the token.
    pub fn cookie_name<S: Into<String>>(self, name: S) -> Self {
        self.with_config(|config| config.cookie_name = name.into())
    }

    /// Sets how long issued tokens stay valid, in seconds.
    pub fn token_lifetime(self, seconds: u64) -> Self {
        self.with_config(|config| config.token_lifetime = seconds)
    }

    /// Sets the `secure_cookie` configuration flag.
    pub fn secure_cookie(self, secure: bool) -> Self {
        self.with_config(|config| config.secure_cookie = secure)
    }

    /// Adds one exempt path (a trailing `*` makes it a prefix match).
    pub fn exempt_path<S: Into<String>>(self, path: S) -> Self {
        self.with_config(|config| {
            config.exempt_paths.insert(path.into());
        })
    }

    /// Adds every path yielded by `paths` to the exempt set.
    pub fn exempt_paths<I, S>(self, paths: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.with_config(|config| {
            config
                .exempt_paths
                .extend(paths.into_iter().map(Into::<String>::into))
        })
    }

    /// Finishes the builder, producing a middleware with a fresh token store.
    pub fn build(self) -> CsrfMiddleware {
        CsrfMiddleware::new(self.config)
    }
}

impl Default for CsrfMiddlewareBuilder {
    /// Same as [`CsrfMiddlewareBuilder::new`].
    fn default() -> Self {
        CsrfMiddlewareBuilder::new()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::{HeaderValue, Method};
use elif_http::middleware::MiddlewarePipeline;
use std::collections::HashSet;
fn create_test_middleware() -> CsrfMiddleware {
let mut exempt_paths = HashSet::new();
exempt_paths.insert("/api/webhook".to_string());
exempt_paths.insert("/public/*".to_string());
let config = CsrfConfig {
token_header: "X-CSRF-Token".to_string(),
cookie_name: "_csrf_token".to_string(),
token_lifetime: 3600,
secure_cookie: false, exempt_paths,
};
CsrfMiddleware::new(config)
}
#[tokio::test]
async fn test_csrf_token_generation() {
let middleware = create_test_middleware();
let token1 = middleware.generate_token(Some("Mozilla/5.0")).await;
let token2 = middleware.generate_token(Some("Mozilla/5.0")).await;
assert_ne!(token1, token2);
assert!(token1.len() > 20); assert!(token2.len() > 20);
}
#[tokio::test]
async fn test_csrf_token_validation() {
let middleware = create_test_middleware();
let user_agent = Some("Mozilla/5.0");
let token = middleware.generate_token(user_agent).await;
assert!(middleware.validate_token(&token, user_agent).await);
assert!(!middleware.validate_token("invalid_token", user_agent).await);
assert!(!middleware.validate_token(&token, Some("Different Agent")).await);
}
#[tokio::test]
async fn test_csrf_token_expiration() {
let config = CsrfConfig {
token_lifetime: 1, ..Default::default()
};
let middleware = CsrfMiddleware::new(config);
let token = middleware.generate_token(None).await;
assert!(middleware.validate_token(&token, None).await);
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
assert!(!middleware.validate_token(&token, None).await);
}
#[tokio::test]
async fn test_csrf_exempt_paths() {
let middleware = create_test_middleware();
assert!(middleware.is_exempt_path("/api/webhook"));
assert!(middleware.is_exempt_path("/public/assets/style.css"));
assert!(middleware.is_exempt_path("/public/images/logo.png"));
assert!(!middleware.is_exempt_path("/api/users"));
assert!(!middleware.is_exempt_path("/admin/dashboard"));
}
#[tokio::test]
async fn test_csrf_builder_pattern() {
let middleware = CsrfMiddleware::builder()
.token_header("X-Custom-CSRF-Token")
.cookie_name("_custom_csrf")
.token_lifetime(7200)
.secure_cookie(true)
.exempt_path("/api/public")
.exempt_paths(vec!["/webhook", "/status"])
.build();
assert_eq!(middleware.config.token_header, "X-Custom-CSRF-Token");
assert_eq!(middleware.config.cookie_name, "_custom_csrf");
assert_eq!(middleware.config.token_lifetime, 7200);
assert!(middleware.config.secure_cookie);
assert!(middleware.config.exempt_paths.contains("/api/public"));
assert!(middleware.config.exempt_paths.contains("/webhook"));
assert!(middleware.config.exempt_paths.contains("/status"));
}
#[tokio::test]
async fn test_csrf_middleware_get_requests() {
let middleware = create_test_middleware();
let pipeline = MiddlewarePipeline::new().add(middleware);
let request = Request::builder()
.method(Method::GET)
.uri("/test")
.body(axum::body::Body::empty())
.unwrap();
let result = pipeline.process_request(request).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_csrf_middleware_post_without_token() {
let middleware = create_test_middleware();
let pipeline = MiddlewarePipeline::new().add(middleware);
let request = Request::builder()
.method(Method::POST)
.uri("/test")
.body(axum::body::Body::empty())
.unwrap();
let result = pipeline.process_request(request).await;
assert!(result.is_err());
if let Err(response) = result {
assert_eq!(response.status(), ElifStatusCode::FORBIDDEN);
}
}
#[tokio::test]
async fn test_csrf_middleware_post_with_valid_token() {
let middleware = create_test_middleware();
let token = middleware.generate_token(Some("TestAgent")).await;
let pipeline = MiddlewarePipeline::new().add(middleware);
let request = Request::builder()
.method(Method::POST)
.uri("/test")
.header("X-CSRF-Token", &token)
.header("User-Agent", "TestAgent")
.body(axum::body::Body::empty())
.unwrap();
let result = pipeline.process_request(request).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_csrf_middleware_exempt_paths() {
let middleware = create_test_middleware();
let pipeline = MiddlewarePipeline::new().add(middleware);
let request1 = Request::builder()
.method(Method::POST)
.uri("/api/webhook")
.body(axum::body::Body::empty())
.unwrap();
let result1 = pipeline.process_request(request1).await;
assert!(result1.is_ok());
let request2 = Request::builder()
.method(Method::POST)
.uri("/public/upload")
.body(axum::body::Body::empty())
.unwrap();
let result2 = pipeline.process_request(request2).await;
assert!(result2.is_ok());
}
#[tokio::test]
async fn test_csrf_token_cleanup() {
let config = CsrfConfig {
token_lifetime: 1, ..Default::default()
};
let middleware = CsrfMiddleware::new(config);
let _token1 = middleware.generate_token(None).await;
let _token2 = middleware.generate_token(None).await;
let _token3 = middleware.generate_token(None).await;
{
let store = middleware.token_store.read().await;
assert_eq!(store.len(), 3);
}
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
let _new_token = middleware.generate_token(None).await;
{
let store = middleware.token_store.read().await;
assert_eq!(store.len(), 1); }
}
#[tokio::test]
async fn test_csrf_cookie_extraction() {
let middleware = create_test_middleware();
let mut headers = HeaderMap::new();
headers.insert(
header::COOKIE,
HeaderValue::from_str("_csrf_token=test_token_123; other_cookie=value").unwrap()
);
let token = middleware.extract_token(&headers);
assert_eq!(token, Some("test_token_123".to_string()));
headers.insert(
"X-CSRF-Token",
HeaderValue::from_str("header_token_456").unwrap()
);
let token = middleware.extract_token(&headers);
assert_eq!(token, Some("header_token_456".to_string()));
}
#[tokio::test]
async fn test_csrf_user_agent_binding() {
let middleware = create_test_middleware();
let token = middleware.generate_token(Some("SpecificAgent")).await;
assert!(middleware.validate_token(&token, Some("SpecificAgent")).await);
assert!(!middleware.validate_token(&token, Some("DifferentAgent")).await);
assert!(!middleware.validate_token(&token, None).await);
}
}