ferro_rs/csrf/
middleware.rs1use crate::http::{HttpResponse, Response};
4use crate::middleware::{Middleware, Next};
5use crate::session::get_csrf_token;
6use crate::Request;
7use async_trait::async_trait;
8
/// Middleware that rejects state-changing requests lacking a valid CSRF token.
///
/// Requests whose method is in the protected set (`POST`, `PUT`, `PATCH`,
/// `DELETE` by default) must present a token matching the session's CSRF
/// token, unless the request path matches a pattern registered with
/// [`CsrfMiddleware::except`].
#[derive(Debug, Clone)]
pub struct CsrfMiddleware {
    /// HTTP methods that require a token; compared case-sensitively against
    /// `request.method().as_str()`.
    protected_methods: Vec<&'static str>,
    /// Path patterns exempt from CSRF checks; a trailing `*` makes the
    /// pattern a prefix match, otherwise it must equal the path exactly.
    except: Vec<String>,
}

impl CsrfMiddleware {
    /// Creates a middleware protecting `POST`, `PUT`, `PATCH` and `DELETE`,
    /// with no excluded paths.
    pub fn new() -> Self {
        Self {
            protected_methods: vec!["POST", "PUT", "PATCH", "DELETE"],
            except: Vec::new(),
        }
    }

    /// Sets the path patterns that bypass CSRF validation.
    ///
    /// A pattern ending in `*` matches any path that starts with the text
    /// before the `*` (e.g. `"/webhooks/*"` matches `"/webhooks/stripe"` but
    /// not `"/webhooks"`). Any other pattern must equal the path exactly.
    ///
    /// Accepts any iterable of string-like values (a `Vec`, an array, an
    /// iterator, ...). Calling this again **replaces** the previously
    /// configured patterns rather than appending to them.
    pub fn except(mut self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.except = paths.into_iter().map(|p| p.into()).collect();
        self
    }

    /// Returns `true` if `path` matches at least one configured exclusion pattern.
    fn is_excluded(&self, path: &str) -> bool {
        self.except
            .iter()
            .any(|pattern| match pattern.strip_suffix('*') {
                // Exactly one trailing `*` is stripped; the rest is a literal prefix.
                Some(prefix) => path.starts_with(prefix),
                None => pattern.as_str() == path,
            })
    }
}

impl Default for CsrfMiddleware {
    /// Equivalent to [`CsrfMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}
81
82#[async_trait]
83impl Middleware for CsrfMiddleware {
84 async fn handle(&self, request: Request, next: Next) -> Response {
85 let method = request.method().as_str();
86
87 if !self.protected_methods.contains(&method) {
89 return next(request).await;
90 }
91
92 if self.is_excluded(request.path()) {
94 return next(request).await;
95 }
96
97 let expected_token = match get_csrf_token() {
99 Some(token) => token,
100 None => {
101 return Err(HttpResponse::json(serde_json::json!({
102 "message": "Session not found. CSRF validation failed."
103 }))
104 .status(500));
105 }
106 };
107
108 let provided_token = request
111 .header("X-CSRF-TOKEN")
112 .or_else(|| request.header("X-XSRF-TOKEN"))
113 .map(|s| s.to_string());
114
115 match provided_token {
116 Some(token) if constant_time_compare(&token, &expected_token) => {
117 next(request).await
119 }
120 _ => {
121 Err(HttpResponse::json(serde_json::json!({
124 "message": "CSRF token mismatch."
125 }))
126 .status(419))
127 }
128 }
129 }
130}
131
/// Compares two strings for equality without stopping at the first mismatch.
///
/// Returns `true` only when `a` and `b` have the same length and identical
/// bytes. Inputs of different lengths are rejected up front, so only the
/// length (not the content) is intended to be observable through timing.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }

    // Accumulate the XOR of every byte pair; any difference leaves a set bit.
    // Every pair is visited regardless of where a mismatch occurs.
    let mut diff: u8 = 0;
    for (x, y) in lhs.iter().zip(rhs) {
        diff |= x ^ y;
    }
    diff == 0
}
146
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_time_compare() {
        assert!(constant_time_compare("abc123", "abc123"));
        assert!(!constant_time_compare("abc123", "abc124"));
        assert!(!constant_time_compare("abc123", "abc12"));
        assert!(!constant_time_compare("", "a"));
        // Two empty strings compare equal; rejecting empty tokens is the
        // caller's responsibility.
        assert!(constant_time_compare("", ""));
    }

    #[test]
    fn test_is_excluded() {
        let csrf = CsrfMiddleware::new().except(vec!["/webhooks/*", "/api/public"]);

        assert!(csrf.is_excluded("/webhooks/stripe"));
        assert!(csrf.is_excluded("/webhooks/github/events"));
        assert!(csrf.is_excluded("/api/public"));
        assert!(!csrf.is_excluded("/api/private"));
        assert!(!csrf.is_excluded("/login"));
        // Exact patterns do not tolerate a trailing slash.
        assert!(!csrf.is_excluded("/api/public/"));
        // The wildcard prefix includes the trailing slash.
        assert!(!csrf.is_excluded("/webhooks"));
    }

    #[test]
    fn test_wildcard_only_pattern_excludes_everything() {
        let csrf = CsrfMiddleware::new().except(vec!["*"]);

        assert!(csrf.is_excluded("/"));
        assert!(csrf.is_excluded("/anything/at/all"));
    }

    #[test]
    fn test_except_replaces_previous_patterns() {
        let csrf = CsrfMiddleware::new()
            .except(vec!["/first"])
            .except(vec!["/second"]);

        assert!(!csrf.is_excluded("/first"));
        assert!(csrf.is_excluded("/second"));
    }

    #[test]
    fn test_default_protected_methods() {
        let csrf = CsrfMiddleware::default();

        assert_eq!(csrf.protected_methods, vec!["POST", "PUT", "PATCH", "DELETE"]);
        assert!(!csrf.protected_methods.contains(&"GET"));
        assert!(csrf.except.is_empty());
    }
}