use crate::http::{HttpResponse, Response};
use crate::middleware::{Middleware, Next};
use crate::session::get_csrf_token;
use crate::Request;
use async_trait::async_trait;
/// Middleware that validates a CSRF token on state-changing requests.
///
/// A request is checked when its HTTP method is one of the protected methods and its
/// path does not match any configured exclusion pattern.
#[derive(Debug, Clone)]
pub struct CsrfMiddleware {
    /// HTTP methods (upper-case) that require a valid CSRF token.
    protected_methods: Vec<&'static str>,
    /// Paths exempt from validation. An entry ending in `*` is a prefix match on
    /// everything before the `*`; any other entry must equal the path exactly.
    except: Vec<String>,
}

impl CsrfMiddleware {
    /// Creates a middleware protecting `POST`, `PUT`, `PATCH` and `DELETE`, with no
    /// excluded paths.
    #[must_use]
    pub fn new() -> Self {
        Self {
            protected_methods: vec!["POST", "PUT", "PATCH", "DELETE"],
            except: Vec::new(),
        }
    }

    /// Sets the paths that bypass CSRF validation and returns the updated middleware.
    ///
    /// The given list *replaces* any previously configured exclusions. A pattern that
    /// ends in `*` (for example `"/webhooks/*"`) matches every path starting with the
    /// text before the `*`; all other patterns match only the identical path.
    #[must_use = "`except` consumes the middleware and returns the configured value"]
    pub fn except(mut self, paths: Vec<impl Into<String>>) -> Self {
        self.except = paths.into_iter().map(Into::into).collect();
        self
    }

    /// Returns `true` when `path` matches at least one exclusion pattern.
    fn is_excluded(&self, path: &str) -> bool {
        self.except
            .iter()
            .any(|pattern| match pattern.strip_suffix('*') {
                // Trailing wildcard: compare against the part before the `*`.
                Some(prefix) => path.starts_with(prefix),
                // No wildcard: the whole path has to be identical.
                None => pattern == path,
            })
    }
}

impl Default for CsrfMiddleware {
    /// Equivalent to [`CsrfMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl Middleware for CsrfMiddleware {
    /// Checks the CSRF token of `request` before handing it to `next`.
    ///
    /// Requests whose method is not protected, or whose path is excluded, are passed
    /// straight through. Otherwise:
    /// - if `get_csrf_token()` yields no token, a JSON error with status 500 is returned;
    /// - if neither the `X-CSRF-TOKEN` nor the `X-XSRF-TOKEN` header carries a value equal
    ///   to that token, a JSON error with status 419 is returned;
    /// - on a match the request continues down the chain.
    async fn handle(&self, request: Request, next: Next) -> Response {
        // Guard: only protected, non-excluded requests need validation.
        let method = request.method().as_str();
        let is_protected = self.protected_methods.contains(&method);
        if !is_protected || self.is_excluded(request.path()) {
            return next(request).await;
        }

        let expected_token = if let Some(token) = get_csrf_token() {
            token
        } else {
            return Err(HttpResponse::json(serde_json::json!({
                "message": "Session not found. CSRF validation failed."
            }))
            .status(500));
        };

        // `X-CSRF-TOKEN` is consulted first; `X-XSRF-TOKEN` is only a fallback.
        // The value is copied so `request` can still be moved into `next` below.
        let supplied_token = request
            .header("X-CSRF-TOKEN")
            .or_else(|| request.header("X-XSRF-TOKEN"))
            .map(|value| value.to_string());

        let token_is_valid = supplied_token.map_or(false, |candidate| {
            constant_time_compare(&candidate, &expected_token)
        });
        if !token_is_valid {
            return Err(HttpResponse::json(serde_json::json!({
                "message": "CSRF token mismatch."
            }))
            .status(419));
        }

        next(request).await
    }
}
/// Compares two strings byte-by-byte without stopping at the first differing byte.
///
/// Strings of different byte length are rejected immediately. For equal-length inputs
/// every byte pair is XOR-ed and OR-ed into an accumulator, so the loop always runs
/// over the full length; the strings are equal exactly when the accumulator stays zero.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }

    let mut diff: u8 = 0;
    for (x, y) in lhs.iter().zip(rhs) {
        diff |= x ^ y;
    }
    diff == 0
}
#[cfg(test)]
mod tests {
    use super::*;

    // Equal inputs compare true; a differing byte or a differing length compares false.
    #[test]
    fn test_constant_time_compare() {
        assert!(constant_time_compare("abc123", "abc123"));
        assert!(!constant_time_compare("abc123", "abc124"));
        assert!(!constant_time_compare("abc123", "abc12"));
        assert!(!constant_time_compare("", "a"));
    }

    // Boundary inputs: both empty, one empty, case differences, mismatch in the first byte.
    #[test]
    fn test_constant_time_compare_edge_cases() {
        assert!(constant_time_compare("", ""));
        assert!(!constant_time_compare("a", ""));
        assert!(!constant_time_compare("Abc123", "abc123"));
        assert!(!constant_time_compare("xbc123", "abc123"));
    }

    // Wildcard entries match by prefix, plain entries match exactly.
    #[test]
    fn test_is_excluded() {
        let csrf = CsrfMiddleware::new().except(vec!["/webhooks/*", "/api/public"]);
        assert!(csrf.is_excluded("/webhooks/stripe"));
        assert!(csrf.is_excluded("/webhooks/github/events"));
        assert!(csrf.is_excluded("/api/public"));
        assert!(!csrf.is_excluded("/api/private"));
        assert!(!csrf.is_excluded("/login"));
    }

    // The prefix of `/webhooks/*` includes the trailing slash, and exact entries do not
    // match longer or shorter paths.
    #[test]
    fn test_is_excluded_boundaries() {
        let csrf = CsrfMiddleware::new().except(vec!["/webhooks/*", "/api/public"]);
        assert!(csrf.is_excluded("/webhooks/"));
        assert!(!csrf.is_excluded("/webhooks"));
        assert!(!csrf.is_excluded("/api/public/extra"));
        assert!(!csrf.is_excluded("/api/publi"));
    }

    // Without any configured exclusions nothing is exempt.
    #[test]
    fn test_no_exclusions_by_default() {
        assert!(!CsrfMiddleware::new().is_excluded("/webhooks/stripe"));
        assert!(!CsrfMiddleware::default().is_excluded("/"));
    }

    // A lone `*` has an empty prefix and therefore matches every path.
    #[test]
    fn test_bare_wildcard_excludes_everything() {
        let csrf = CsrfMiddleware::new().except(vec!["*"]);
        assert!(csrf.is_excluded("/"));
        assert!(csrf.is_excluded("/any/path"));
        assert!(csrf.is_excluded(""));
    }
}