elif_security/middleware/
csrf.rs1use std::sync::Arc;
7use std::collections::HashMap;
8use axum::{
9 extract::Request,
10 http::{HeaderMap, Method, StatusCode, header},
11 response::{IntoResponse, Response},
12};
13use elif_http::middleware::{Middleware, BoxFuture};
14use sha2::{Sha256, Digest};
15use rand::{thread_rng, Rng};
16use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
17
18pub use crate::config::CsrfConfig;
19use crate::SecurityError;
20
21type TokenStore = Arc<tokio::sync::RwLock<HashMap<String, CsrfTokenData>>>;
23
24#[derive(Debug, Clone)]
26pub struct CsrfTokenData {
27 pub token: String,
28 pub expires_at: time::OffsetDateTime,
29 pub user_agent_hash: Option<String>,
30}
31
32#[derive(Debug, Clone)]
34pub struct CsrfMiddleware {
35 config: CsrfConfig,
36 token_store: TokenStore,
37}
38
39impl CsrfMiddleware {
40 pub fn new(config: CsrfConfig) -> Self {
42 Self {
43 config,
44 token_store: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
45 }
46 }
47
48 pub fn builder() -> CsrfMiddlewareBuilder {
50 CsrfMiddlewareBuilder::new()
51 }
52
53 pub async fn generate_token(&self, user_agent: Option<&str>) -> String {
55 let mut rng = thread_rng();
56 let token_bytes: [u8; 32] = rng.gen();
57 let token = URL_SAFE_NO_PAD.encode(token_bytes);
58
59 let user_agent_hash = user_agent.map(|ua| {
60 let mut hasher = Sha256::new();
61 hasher.update(ua.as_bytes());
62 format!("{:x}", hasher.finalize())
63 });
64
65 let token_data = CsrfTokenData {
66 token: token.clone(),
67 expires_at: time::OffsetDateTime::now_utc() +
68 time::Duration::seconds(self.config.token_lifetime as i64),
69 user_agent_hash,
70 };
71
72 let mut store = self.token_store.write().await;
74 store.insert(token.clone(), token_data);
75
76 self.cleanup_expired_tokens(&mut store).await;
78
79 token
80 }
81
82 pub async fn validate_token(&self, token: &str, user_agent: Option<&str>) -> bool {
84 let store = self.token_store.read().await;
85
86 if let Some(token_data) = store.get(token) {
87 if time::OffsetDateTime::now_utc() > token_data.expires_at {
89 return false;
90 }
91
92 if let Some(stored_hash) = &token_data.user_agent_hash {
94 if let Some(ua) = user_agent {
95 let mut hasher = Sha256::new();
96 hasher.update(ua.as_bytes());
97 let ua_hash = format!("{:x}", hasher.finalize());
98 if stored_hash != &ua_hash {
99 return false;
100 }
101 } else {
102 return false;
103 }
104 }
105
106 true
107 } else {
108 false
109 }
110 }
111
112 pub async fn consume_token(&self, token: &str) {
114 let mut store = self.token_store.write().await;
115 store.remove(token);
116 }
117
118 async fn cleanup_expired_tokens(&self, store: &mut HashMap<String, CsrfTokenData>) {
120 let now = time::OffsetDateTime::now_utc();
121 store.retain(|_, data| data.expires_at > now);
122 }
123
124 fn is_exempt_path(&self, path: &str) -> bool {
126 self.config.exempt_paths.contains(path) ||
127 self.config.exempt_paths.iter().any(|exempt| {
128 if exempt.ends_with('*') {
130 path.starts_with(&exempt[..exempt.len()-1])
131 } else {
132 path == exempt
133 }
134 })
135 }
136
137 fn extract_token(&self, headers: &HeaderMap) -> Option<String> {
139 if let Some(header_value) = headers.get(&self.config.token_header) {
141 if let Ok(token) = header_value.to_str() {
142 return Some(token.to_string());
143 }
144 }
145
146 if let Some(cookie_header) = headers.get(header::COOKIE) {
148 if let Ok(cookies) = cookie_header.to_str() {
149 for cookie in cookies.split(';') {
150 let cookie = cookie.trim();
151 if let Some((name, value)) = cookie.split_once('=') {
152 if name == self.config.cookie_name {
153 return Some(value.to_string());
154 }
155 }
156 }
157 }
158 }
159
160 None
161 }
162}
163
164impl Middleware for CsrfMiddleware {
166 fn process_request<'a>(
167 &'a self,
168 request: Request
169 ) -> BoxFuture<'a, Result<Request, Response>> {
170 Box::pin(async move {
171 let method = request.method();
172 let uri = request.uri();
173 let headers = request.headers();
174
175 if matches!(method, &Method::GET | &Method::HEAD | &Method::OPTIONS) {
177 return Ok(request);
178 }
179
180 if self.is_exempt_path(uri.path()) {
182 return Ok(request);
183 }
184
185 let user_agent = headers.get(header::USER_AGENT)
187 .and_then(|h| h.to_str().ok());
188
189 if let Some(token) = self.extract_token(headers) {
190 if self.validate_token(&token, user_agent).await {
191 return Ok(request);
194 }
195 }
196
197 let error_response = Response::builder()
199 .status(StatusCode::FORBIDDEN)
200 .header("Content-Type", "application/json")
201 .body(r#"{"error":{"code":"CSRF_VALIDATION_FAILED","message":"CSRF token validation failed"}}"#.into())
202 .unwrap();
203
204 Err(error_response)
205 })
206 }
207
208 fn name(&self) -> &'static str {
209 "CSRF Protection"
210 }
211}
212
213impl IntoResponse for SecurityError {
214 fn into_response(self) -> Response {
215 let (status, message) = match self {
216 SecurityError::CsrfValidationFailed => {
217 (StatusCode::FORBIDDEN, "CSRF token validation failed")
218 }
219 _ => (StatusCode::INTERNAL_SERVER_ERROR, "Security error"),
220 };
221
222 (status, message).into_response()
223 }
224}
225
226#[derive(Debug)]
228pub struct CsrfMiddlewareBuilder {
229 config: CsrfConfig,
230}
231
232impl CsrfMiddlewareBuilder {
233 pub fn new() -> Self {
234 Self {
235 config: CsrfConfig::default(),
236 }
237 }
238
239 pub fn token_header<S: Into<String>>(mut self, header: S) -> Self {
240 self.config.token_header = header.into();
241 self
242 }
243
244 pub fn cookie_name<S: Into<String>>(mut self, name: S) -> Self {
245 self.config.cookie_name = name.into();
246 self
247 }
248
249 pub fn token_lifetime(mut self, seconds: u64) -> Self {
250 self.config.token_lifetime = seconds;
251 self
252 }
253
254 pub fn secure_cookie(mut self, secure: bool) -> Self {
255 self.config.secure_cookie = secure;
256 self
257 }
258
259 pub fn exempt_path<S: Into<String>>(mut self, path: S) -> Self {
260 self.config.exempt_paths.insert(path.into());
261 self
262 }
263
264 pub fn exempt_paths<I, S>(mut self, paths: I) -> Self
265 where
266 I: IntoIterator<Item = S>,
267 S: Into<String>,
268 {
269 for path in paths {
270 self.config.exempt_paths.insert(path.into());
271 }
272 self
273 }
274
275 pub fn build(self) -> CsrfMiddleware {
276 CsrfMiddleware::new(self.config)
277 }
278}
279
280impl Default for CsrfMiddlewareBuilder {
281 fn default() -> Self {
282 Self::new()
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderValue, Method};
    use elif_http::middleware::MiddlewarePipeline;
    use std::collections::HashSet;

    /// Middleware with a one-hour token lifetime, an exact exemption for
    /// `/api/webhook` and a wildcard exemption for everything under `/public/`.
    fn create_test_middleware() -> CsrfMiddleware {
        let exempt_paths: HashSet<String> = ["/api/webhook", "/public/*"]
            .iter()
            .map(|p| p.to_string())
            .collect();

        CsrfMiddleware::new(CsrfConfig {
            token_header: "X-CSRF-Token".to_string(),
            cookie_name: "_csrf_token".to_string(),
            token_lifetime: 3600,
            secure_cookie: false,
            exempt_paths,
        })
    }

    /// Builds a body-less request with the given method and URI.
    fn bare_request(method: Method, uri: &str) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .body(axum::body::Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn test_csrf_token_generation() {
        let csrf = create_test_middleware();
        let agent = Some("Mozilla/5.0");

        let first = csrf.generate_token(agent).await;
        let second = csrf.generate_token(agent).await;

        assert_ne!(first, second);
        for token in [&first, &second] {
            assert!(token.len() > 20);
        }
    }

    #[tokio::test]
    async fn test_csrf_token_validation() {
        let csrf = create_test_middleware();
        let agent = Some("Mozilla/5.0");
        let token = csrf.generate_token(agent).await;

        assert!(csrf.validate_token(&token, agent).await);
        assert!(!csrf.validate_token("invalid_token", agent).await);
        assert!(!csrf.validate_token(&token, Some("Different Agent")).await);
    }

    #[tokio::test]
    async fn test_csrf_token_expiration() {
        let csrf = CsrfMiddleware::new(CsrfConfig {
            token_lifetime: 1,
            ..Default::default()
        });

        let token = csrf.generate_token(None).await;
        assert!(csrf.validate_token(&token, None).await);

        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        assert!(!csrf.validate_token(&token, None).await);
    }

    #[tokio::test]
    async fn test_csrf_exempt_paths() {
        let csrf = create_test_middleware();

        for path in [
            "/api/webhook",
            "/public/assets/style.css",
            "/public/images/logo.png",
        ] {
            assert!(csrf.is_exempt_path(path), "{path} should be exempt");
        }
        for path in ["/api/users", "/admin/dashboard"] {
            assert!(!csrf.is_exempt_path(path), "{path} should be protected");
        }
    }

    #[tokio::test]
    async fn test_csrf_builder_pattern() {
        let csrf = CsrfMiddleware::builder()
            .token_header("X-Custom-CSRF-Token")
            .cookie_name("_custom_csrf")
            .token_lifetime(7200)
            .secure_cookie(true)
            .exempt_path("/api/public")
            .exempt_paths(vec!["/webhook", "/status"])
            .build();

        let cfg = &csrf.config;
        assert_eq!(cfg.token_header, "X-Custom-CSRF-Token");
        assert_eq!(cfg.cookie_name, "_custom_csrf");
        assert_eq!(cfg.token_lifetime, 7200);
        assert!(cfg.secure_cookie);
        for path in ["/api/public", "/webhook", "/status"] {
            assert!(cfg.exempt_paths.contains(path));
        }
    }

    #[tokio::test]
    async fn test_csrf_middleware_get_requests() {
        let pipeline = MiddlewarePipeline::new().add(create_test_middleware());

        let outcome = pipeline
            .process_request(bare_request(Method::GET, "/test"))
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_without_token() {
        let pipeline = MiddlewarePipeline::new().add(create_test_middleware());

        let outcome = pipeline
            .process_request(bare_request(Method::POST, "/test"))
            .await;
        match outcome {
            Ok(_) => panic!("a POST without a CSRF token must be rejected"),
            Err(response) => assert_eq!(response.status(), StatusCode::FORBIDDEN),
        }
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_with_valid_token() {
        let csrf = create_test_middleware();
        let token = csrf.generate_token(Some("TestAgent")).await;
        let pipeline = MiddlewarePipeline::new().add(csrf);

        let request = Request::builder()
            .method(Method::POST)
            .uri("/test")
            .header("X-CSRF-Token", &token)
            .header("User-Agent", "TestAgent")
            .body(axum::body::Body::empty())
            .unwrap();

        assert!(pipeline.process_request(request).await.is_ok());
    }

    #[tokio::test]
    async fn test_csrf_middleware_exempt_paths() {
        let pipeline = MiddlewarePipeline::new().add(create_test_middleware());

        for uri in ["/api/webhook", "/public/upload"] {
            let outcome = pipeline
                .process_request(bare_request(Method::POST, uri))
                .await;
            assert!(outcome.is_ok(), "POST to {uri} should bypass CSRF checks");
        }
    }

    #[tokio::test]
    async fn test_csrf_token_cleanup() {
        let csrf = CsrfMiddleware::new(CsrfConfig {
            token_lifetime: 1,
            ..Default::default()
        });

        for _ in 0..3 {
            let _ = csrf.generate_token(None).await;
        }
        assert_eq!(csrf.token_store.read().await.len(), 3);

        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        // Issuing a new token purges the three expired ones.
        let _ = csrf.generate_token(None).await;
        assert_eq!(csrf.token_store.read().await.len(), 1);
    }

    #[tokio::test]
    async fn test_csrf_cookie_extraction() {
        let csrf = create_test_middleware();
        let mut headers = HeaderMap::new();

        headers.insert(
            header::COOKIE,
            HeaderValue::from_static("_csrf_token=test_token_123; other_cookie=value"),
        );
        assert_eq!(
            csrf.extract_token(&headers).as_deref(),
            Some("test_token_123")
        );

        // The explicit header takes precedence over the cookie.
        headers.insert("X-CSRF-Token", HeaderValue::from_static("header_token_456"));
        assert_eq!(
            csrf.extract_token(&headers).as_deref(),
            Some("header_token_456")
        );
    }

    #[tokio::test]
    async fn test_csrf_user_agent_binding() {
        let csrf = create_test_middleware();
        let token = csrf.generate_token(Some("SpecificAgent")).await;

        assert!(csrf.validate_token(&token, Some("SpecificAgent")).await);
        assert!(!csrf.validate_token(&token, Some("DifferentAgent")).await);
        assert!(!csrf.validate_token(&token, None).await);
    }
}