elif_security/middleware/
csrf.rs1use std::sync::Arc;
7use std::collections::HashMap;
8use axum::{
9 extract::{Request, State},
10 http::{HeaderMap, Method, StatusCode, header},
11 middleware::Next,
12 response::{IntoResponse, Response},
13};
14use sha2::{Sha256, Digest};
15use rand::{thread_rng, Rng};
16use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
17
18pub use crate::config::CsrfConfig;
19use crate::SecurityError;
20
21type TokenStore = Arc<tokio::sync::RwLock<HashMap<String, CsrfTokenData>>>;
23
24#[derive(Debug, Clone)]
26pub struct CsrfTokenData {
27 pub token: String,
28 pub expires_at: time::OffsetDateTime,
29 pub user_agent_hash: Option<String>,
30}
31
32#[derive(Debug, Clone)]
34pub struct CsrfMiddleware {
35 config: CsrfConfig,
36 token_store: TokenStore,
37}
38
39impl CsrfMiddleware {
40 pub fn new(config: CsrfConfig) -> Self {
42 Self {
43 config,
44 token_store: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
45 }
46 }
47
48 pub fn builder() -> CsrfMiddlewareBuilder {
50 CsrfMiddlewareBuilder::new()
51 }
52
53 pub async fn generate_token(&self, user_agent: Option<&str>) -> String {
55 let mut rng = thread_rng();
56 let token_bytes: [u8; 32] = rng.gen();
57 let token = URL_SAFE_NO_PAD.encode(token_bytes);
58
59 let user_agent_hash = user_agent.map(|ua| {
60 let mut hasher = Sha256::new();
61 hasher.update(ua.as_bytes());
62 format!("{:x}", hasher.finalize())
63 });
64
65 let token_data = CsrfTokenData {
66 token: token.clone(),
67 expires_at: time::OffsetDateTime::now_utc() +
68 time::Duration::seconds(self.config.token_lifetime as i64),
69 user_agent_hash,
70 };
71
72 let mut store = self.token_store.write().await;
74 store.insert(token.clone(), token_data);
75
76 self.cleanup_expired_tokens(&mut store).await;
78
79 token
80 }
81
82 pub async fn validate_token(&self, token: &str, user_agent: Option<&str>) -> bool {
84 let store = self.token_store.read().await;
85
86 if let Some(token_data) = store.get(token) {
87 if time::OffsetDateTime::now_utc() > token_data.expires_at {
89 return false;
90 }
91
92 if let Some(stored_hash) = &token_data.user_agent_hash {
94 if let Some(ua) = user_agent {
95 let mut hasher = Sha256::new();
96 hasher.update(ua.as_bytes());
97 let ua_hash = format!("{:x}", hasher.finalize());
98 if stored_hash != &ua_hash {
99 return false;
100 }
101 } else {
102 return false;
103 }
104 }
105
106 true
107 } else {
108 false
109 }
110 }
111
112 pub async fn consume_token(&self, token: &str) {
114 let mut store = self.token_store.write().await;
115 store.remove(token);
116 }
117
118 async fn cleanup_expired_tokens(&self, store: &mut HashMap<String, CsrfTokenData>) {
120 let now = time::OffsetDateTime::now_utc();
121 store.retain(|_, data| data.expires_at > now);
122 }
123
124 fn is_exempt_path(&self, path: &str) -> bool {
126 self.config.exempt_paths.contains(path) ||
127 self.config.exempt_paths.iter().any(|exempt| {
128 if exempt.ends_with('*') {
130 path.starts_with(&exempt[..exempt.len()-1])
131 } else {
132 path == exempt
133 }
134 })
135 }
136
137 fn extract_token(&self, headers: &HeaderMap) -> Option<String> {
139 if let Some(header_value) = headers.get(&self.config.token_header) {
141 if let Ok(token) = header_value.to_str() {
142 return Some(token.to_string());
143 }
144 }
145
146 if let Some(cookie_header) = headers.get(header::COOKIE) {
148 if let Ok(cookies) = cookie_header.to_str() {
149 for cookie in cookies.split(';') {
150 let cookie = cookie.trim();
151 if let Some((name, value)) = cookie.split_once('=') {
152 if name == self.config.cookie_name {
153 return Some(value.to_string());
154 }
155 }
156 }
157 }
158 }
159
160 None
161 }
162}
163
164pub async fn csrf_middleware(
166 State(middleware): State<CsrfMiddleware>,
167 request: Request,
168 next: Next,
169) -> Result<Response, SecurityError> {
170 let method = request.method();
171 let uri = request.uri();
172 let headers = request.headers();
173
174 if matches!(method, &Method::GET | &Method::HEAD | &Method::OPTIONS) {
176 return Ok(next.run(request).await);
177 }
178
179 if middleware.is_exempt_path(uri.path()) {
181 return Ok(next.run(request).await);
182 }
183
184 let user_agent = headers.get(header::USER_AGENT)
186 .and_then(|h| h.to_str().ok());
187
188 if let Some(token) = middleware.extract_token(headers) {
189 if middleware.validate_token(&token, user_agent).await {
190 return Ok(next.run(request).await);
193 }
194 }
195
196 Err(SecurityError::CsrfValidationFailed)
198}
199
200impl IntoResponse for SecurityError {
201 fn into_response(self) -> Response {
202 let (status, message) = match self {
203 SecurityError::CsrfValidationFailed => {
204 (StatusCode::FORBIDDEN, "CSRF token validation failed")
205 }
206 _ => (StatusCode::INTERNAL_SERVER_ERROR, "Security error"),
207 };
208
209 (status, message).into_response()
210 }
211}
212
213#[derive(Debug)]
215pub struct CsrfMiddlewareBuilder {
216 config: CsrfConfig,
217}
218
219impl CsrfMiddlewareBuilder {
220 pub fn new() -> Self {
221 Self {
222 config: CsrfConfig::default(),
223 }
224 }
225
226 pub fn token_header<S: Into<String>>(mut self, header: S) -> Self {
227 self.config.token_header = header.into();
228 self
229 }
230
231 pub fn cookie_name<S: Into<String>>(mut self, name: S) -> Self {
232 self.config.cookie_name = name.into();
233 self
234 }
235
236 pub fn token_lifetime(mut self, seconds: u64) -> Self {
237 self.config.token_lifetime = seconds;
238 self
239 }
240
241 pub fn secure_cookie(mut self, secure: bool) -> Self {
242 self.config.secure_cookie = secure;
243 self
244 }
245
246 pub fn exempt_path<S: Into<String>>(mut self, path: S) -> Self {
247 self.config.exempt_paths.insert(path.into());
248 self
249 }
250
251 pub fn exempt_paths<I, S>(mut self, paths: I) -> Self
252 where
253 I: IntoIterator<Item = S>,
254 S: Into<String>,
255 {
256 for path in paths {
257 self.config.exempt_paths.insert(path.into());
258 }
259 self
260 }
261
262 pub fn build(self) -> CsrfMiddleware {
263 CsrfMiddleware::new(self.config)
264 }
265}
266
267impl Default for CsrfMiddlewareBuilder {
268 fn default() -> Self {
269 Self::new()
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        http::HeaderValue,
        middleware::from_fn_with_state,
        routing::{get, post},
        Router,
    };
    use axum_test::TestServer;
    use std::collections::HashSet;

    async fn test_handler() -> &'static str {
        "OK"
    }

    fn create_test_middleware() -> CsrfMiddleware {
        let mut exempt_paths = HashSet::new();
        exempt_paths.insert("/api/webhook".to_string());
        exempt_paths.insert("/public/*".to_string());

        let config = CsrfConfig {
            token_header: "X-CSRF-Token".to_string(),
            cookie_name: "_csrf_token".to_string(),
            token_lifetime: 3600,
            secure_cookie: false,
            exempt_paths,
        };

        CsrfMiddleware::new(config)
    }

    /// Backdates a stored token's expiry by one second, so expiry paths can
    /// be exercised without sleeping through a real token lifetime.
    async fn expire_token(middleware: &CsrfMiddleware, token: &str) {
        let mut store = middleware.token_store.write().await;
        let data = store
            .get_mut(token)
            .expect("token should be present in the store");
        data.expires_at = time::OffsetDateTime::now_utc() - time::Duration::seconds(1);
    }

    #[tokio::test]
    async fn test_csrf_token_generation() {
        let middleware = create_test_middleware();

        let token1 = middleware.generate_token(Some("Mozilla/5.0")).await;
        let token2 = middleware.generate_token(Some("Mozilla/5.0")).await;

        assert_ne!(token1, token2);
        // 32 random bytes encode to 43 characters of unpadded URL-safe base64.
        assert_eq!(token1.len(), 43);
        assert_eq!(token2.len(), 43);
    }

    #[tokio::test]
    async fn test_csrf_token_validation() {
        let middleware = create_test_middleware();
        let user_agent = Some("Mozilla/5.0");

        let token = middleware.generate_token(user_agent).await;

        assert!(middleware.validate_token(&token, user_agent).await);
        assert!(!middleware.validate_token("invalid_token", user_agent).await);
        assert!(!middleware.validate_token(&token, Some("Different Agent")).await);
    }

    #[tokio::test]
    async fn test_csrf_token_lifetime_applied() {
        let middleware = create_test_middleware();

        let before = time::OffsetDateTime::now_utc();
        let token = middleware.generate_token(None).await;
        let after = time::OffsetDateTime::now_utc();

        let store = middleware.token_store.read().await;
        let expires_at = store[&token].expires_at;
        let lifetime = time::Duration::seconds(3600);
        assert!(expires_at >= before + lifetime);
        assert!(expires_at <= after + lifetime);
    }

    #[tokio::test]
    async fn test_csrf_token_expiration() {
        let config = CsrfConfig {
            token_lifetime: 1,
            ..Default::default()
        };
        let middleware = CsrfMiddleware::new(config);

        let token = middleware.generate_token(None).await;
        assert!(middleware.validate_token(&token, None).await);

        expire_token(&middleware, &token).await;

        assert!(!middleware.validate_token(&token, None).await);
    }

    #[tokio::test]
    async fn test_csrf_exempt_paths() {
        let middleware = create_test_middleware();

        assert!(middleware.is_exempt_path("/api/webhook"));

        assert!(middleware.is_exempt_path("/public/assets/style.css"));
        assert!(middleware.is_exempt_path("/public/images/logo.png"));

        assert!(!middleware.is_exempt_path("/api/users"));
        assert!(!middleware.is_exempt_path("/admin/dashboard"));
    }

    #[tokio::test]
    async fn test_csrf_builder_pattern() {
        let middleware = CsrfMiddleware::builder()
            .token_header("X-Custom-CSRF-Token")
            .cookie_name("_custom_csrf")
            .token_lifetime(7200)
            .secure_cookie(true)
            .exempt_path("/api/public")
            .exempt_paths(vec!["/webhook", "/status"])
            .build();

        assert_eq!(middleware.config.token_header, "X-Custom-CSRF-Token");
        assert_eq!(middleware.config.cookie_name, "_custom_csrf");
        assert_eq!(middleware.config.token_lifetime, 7200);
        assert!(middleware.config.secure_cookie);
        assert!(middleware.config.exempt_paths.contains("/api/public"));
        assert!(middleware.config.exempt_paths.contains("/webhook"));
        assert!(middleware.config.exempt_paths.contains("/status"));
    }

    #[tokio::test]
    async fn test_csrf_middleware_get_requests() {
        let middleware = create_test_middleware();

        let app = Router::new()
            .route("/test", get(test_handler))
            .layer(from_fn_with_state(middleware, csrf_middleware));

        let server = TestServer::new(app).unwrap();

        let response = server.get("/test").await;
        response.assert_status_ok();
        response.assert_text("OK");
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_without_token() {
        let middleware = create_test_middleware();

        let app = Router::new()
            .route("/test", post(test_handler))
            .layer(from_fn_with_state(middleware, csrf_middleware));

        let server = TestServer::new(app).unwrap();

        let response = server.post("/test").await;
        response.assert_status_forbidden();
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_with_valid_token() {
        let middleware = create_test_middleware();
        let token = middleware.generate_token(Some("TestAgent")).await;

        let app = Router::new()
            .route("/test", post(test_handler))
            .layer(from_fn_with_state(middleware, csrf_middleware));

        let server = TestServer::new(app).unwrap();

        let response = server
            .post("/test")
            .add_header("X-CSRF-Token", &token)
            .add_header("User-Agent", "TestAgent")
            .await;

        response.assert_status_ok();
        response.assert_text("OK");
    }

    #[tokio::test]
    async fn test_csrf_middleware_exempt_paths() {
        let middleware = create_test_middleware();

        let app = Router::new()
            .route("/api/webhook", post(test_handler))
            .route("/public/upload", post(test_handler))
            .layer(from_fn_with_state(middleware, csrf_middleware));

        let server = TestServer::new(app).unwrap();

        let response1 = server.post("/api/webhook").await;
        response1.assert_status_ok();

        let response2 = server.post("/public/upload").await;
        response2.assert_status_ok();
    }

    #[tokio::test]
    async fn test_csrf_token_cleanup() {
        let middleware = create_test_middleware();

        let mut tokens = Vec::new();
        for _ in 0..3 {
            tokens.push(middleware.generate_token(None).await);
        }
        assert_eq!(middleware.token_store.read().await.len(), 3);

        for token in &tokens {
            expire_token(&middleware, token).await;
        }

        // Generating a token purges every expired entry.
        let fresh = middleware.generate_token(None).await;

        let store = middleware.token_store.read().await;
        assert_eq!(store.len(), 1);
        assert!(store.contains_key(&fresh));
    }

    #[tokio::test]
    async fn test_csrf_cookie_extraction() {
        let middleware = create_test_middleware();
        let mut headers = HeaderMap::new();

        headers.insert(
            header::COOKIE,
            HeaderValue::from_str("_csrf_token=test_token_123; other_cookie=value").unwrap(),
        );

        let token = middleware.extract_token(&headers);
        assert_eq!(token, Some("test_token_123".to_string()));

        headers.insert(
            "X-CSRF-Token",
            HeaderValue::from_str("header_token_456").unwrap(),
        );

        let token = middleware.extract_token(&headers);
        assert_eq!(token, Some("header_token_456".to_string()));
    }

    #[tokio::test]
    async fn test_csrf_user_agent_binding() {
        let middleware = create_test_middleware();

        let token = middleware.generate_token(Some("SpecificAgent")).await;

        assert!(middleware.validate_token(&token, Some("SpecificAgent")).await);
        assert!(!middleware.validate_token(&token, Some("DifferentAgent")).await);
        assert!(!middleware.validate_token(&token, None).await);
    }
}