elif_security/middleware/
csrf.rs1use std::sync::Arc;
7use std::collections::HashMap;
8use axum::{
9 extract::Request,
10 http::{HeaderMap, Method, header},
11 response::{Response, IntoResponse},
12};
13use elif_http::{
14 middleware::{Middleware, BoxFuture},
15 ElifStatusCode, };
17use sha2::{Sha256, Digest};
18use rand::{thread_rng, Rng};
19use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
20
21pub use crate::config::CsrfConfig;
22use crate::SecurityError;
23
24type TokenStore = Arc<tokio::sync::RwLock<HashMap<String, CsrfTokenData>>>;
26
27#[derive(Debug, Clone)]
29pub struct CsrfTokenData {
30 pub token: String,
31 pub expires_at: time::OffsetDateTime,
32 pub user_agent_hash: Option<String>,
33}
34
35#[derive(Debug, Clone)]
37pub struct CsrfMiddleware {
38 config: CsrfConfig,
39 token_store: TokenStore,
40}
41
42impl CsrfMiddleware {
43 pub fn new(config: CsrfConfig) -> Self {
45 Self {
46 config,
47 token_store: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
48 }
49 }
50
51 pub fn builder() -> CsrfMiddlewareBuilder {
53 CsrfMiddlewareBuilder::new()
54 }
55
56 pub async fn generate_token(&self, user_agent: Option<&str>) -> String {
58 let mut rng = thread_rng();
59 let token_bytes: [u8; 32] = rng.gen();
60 let token = URL_SAFE_NO_PAD.encode(token_bytes);
61
62 let user_agent_hash = user_agent.map(|ua| {
63 let mut hasher = Sha256::new();
64 hasher.update(ua.as_bytes());
65 format!("{:x}", hasher.finalize())
66 });
67
68 let token_data = CsrfTokenData {
69 token: token.clone(),
70 expires_at: time::OffsetDateTime::now_utc() +
71 time::Duration::seconds(self.config.token_lifetime as i64),
72 user_agent_hash,
73 };
74
75 let mut store = self.token_store.write().await;
77 store.insert(token.clone(), token_data);
78
79 self.cleanup_expired_tokens(&mut store).await;
81
82 token
83 }
84
85 pub async fn validate_token(&self, token: &str, user_agent: Option<&str>) -> bool {
87 let store = self.token_store.read().await;
88
89 if let Some(token_data) = store.get(token) {
90 if time::OffsetDateTime::now_utc() > token_data.expires_at {
92 return false;
93 }
94
95 if let Some(stored_hash) = &token_data.user_agent_hash {
97 if let Some(ua) = user_agent {
98 let mut hasher = Sha256::new();
99 hasher.update(ua.as_bytes());
100 let ua_hash = format!("{:x}", hasher.finalize());
101 if stored_hash != &ua_hash {
102 return false;
103 }
104 } else {
105 return false;
106 }
107 }
108
109 true
110 } else {
111 false
112 }
113 }
114
115 pub async fn consume_token(&self, token: &str) {
117 let mut store = self.token_store.write().await;
118 store.remove(token);
119 }
120
121 async fn cleanup_expired_tokens(&self, store: &mut HashMap<String, CsrfTokenData>) {
123 let now = time::OffsetDateTime::now_utc();
124 store.retain(|_, data| data.expires_at > now);
125 }
126
127 fn is_exempt_path(&self, path: &str) -> bool {
129 self.config.exempt_paths.contains(path) ||
130 self.config.exempt_paths.iter().any(|exempt| {
131 if exempt.ends_with('*') {
133 path.starts_with(&exempt[..exempt.len()-1])
134 } else {
135 path == exempt
136 }
137 })
138 }
139
140 fn extract_token(&self, headers: &HeaderMap) -> Option<String> {
142 if let Some(header_value) = headers.get(&self.config.token_header) {
144 if let Ok(token) = header_value.to_str() {
145 return Some(token.to_string());
146 }
147 }
148
149 if let Some(cookie_header) = headers.get(header::COOKIE) {
151 if let Ok(cookies) = cookie_header.to_str() {
152 for cookie in cookies.split(';') {
153 let cookie = cookie.trim();
154 if let Some((name, value)) = cookie.split_once('=') {
155 if name == self.config.cookie_name {
156 return Some(value.to_string());
157 }
158 }
159 }
160 }
161 }
162
163 None
164 }
165}
166
167impl Middleware for CsrfMiddleware {
169 fn process_request<'a>(
170 &'a self,
171 request: Request
172 ) -> BoxFuture<'a, Result<Request, Response>> {
173 Box::pin(async move {
174 let method = request.method();
175 let uri = request.uri();
176 let headers = request.headers();
177
178 if matches!(method, &Method::GET | &Method::HEAD | &Method::OPTIONS) {
180 return Ok(request);
181 }
182
183 if self.is_exempt_path(uri.path()) {
185 return Ok(request);
186 }
187
188 let user_agent = headers.get(header::USER_AGENT)
190 .and_then(|h| h.to_str().ok());
191
192 if let Some(token) = self.extract_token(headers) {
193 if self.validate_token(&token, user_agent).await {
194 return Ok(request);
197 }
198 }
199
200 let error_response = Response::builder()
202 .status(ElifStatusCode::FORBIDDEN)
203 .header("Content-Type", "application/json")
204 .body(r#"{"error":{"code":"CSRF_VALIDATION_FAILED","message":"CSRF token validation failed"}}"#.into())
205 .unwrap();
206
207 Err(error_response)
208 })
209 }
210
211 fn name(&self) -> &'static str {
212 "CsrfMiddleware"
213 }
214}
215
216impl IntoResponse for SecurityError {
217 fn into_response(self) -> Response {
218 let (status, message) = match self {
219 SecurityError::CsrfValidationFailed => {
220 (ElifStatusCode::FORBIDDEN, "CSRF token validation failed")
221 }
222 _ => (ElifStatusCode::INTERNAL_SERVER_ERROR, "Security error"),
223 };
224
225 (status, message).into_response()
226 }
227}
228
229#[derive(Debug)]
231pub struct CsrfMiddlewareBuilder {
232 config: CsrfConfig,
233}
234
235impl CsrfMiddlewareBuilder {
236 pub fn new() -> Self {
237 Self {
238 config: CsrfConfig::default(),
239 }
240 }
241
242 pub fn token_header<S: Into<String>>(mut self, header: S) -> Self {
243 self.config.token_header = header.into();
244 self
245 }
246
247 pub fn cookie_name<S: Into<String>>(mut self, name: S) -> Self {
248 self.config.cookie_name = name.into();
249 self
250 }
251
252 pub fn token_lifetime(mut self, seconds: u64) -> Self {
253 self.config.token_lifetime = seconds;
254 self
255 }
256
257 pub fn secure_cookie(mut self, secure: bool) -> Self {
258 self.config.secure_cookie = secure;
259 self
260 }
261
262 pub fn exempt_path<S: Into<String>>(mut self, path: S) -> Self {
263 self.config.exempt_paths.insert(path.into());
264 self
265 }
266
267 pub fn exempt_paths<I, S>(mut self, paths: I) -> Self
268 where
269 I: IntoIterator<Item = S>,
270 S: Into<String>,
271 {
272 for path in paths {
273 self.config.exempt_paths.insert(path.into());
274 }
275 self
276 }
277
278 pub fn build(self) -> CsrfMiddleware {
279 CsrfMiddleware::new(self.config)
280 }
281}
282
283impl Default for CsrfMiddlewareBuilder {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderValue, Method};
    use elif_http::middleware::MiddlewarePipeline;
    use std::collections::HashSet;

    /// Middleware with one exact exemption, one wildcard exemption and a
    /// one-hour token lifetime.
    fn create_test_middleware() -> CsrfMiddleware {
        let mut exempt_paths = HashSet::new();
        exempt_paths.insert("/api/webhook".to_string());
        exempt_paths.insert("/public/*".to_string());

        let config = CsrfConfig {
            token_header: "X-CSRF-Token".to_string(),
            cookie_name: "_csrf_token".to_string(),
            token_lifetime: 3600,
            secure_cookie: false,
            exempt_paths,
        };

        CsrfMiddleware::new(config)
    }

    /// Back-dates every stored token so it is already expired. This keeps the
    /// expiry tests deterministic and fast instead of sleeping on the wall clock.
    async fn expire_all_tokens(middleware: &CsrfMiddleware) {
        let past = time::OffsetDateTime::now_utc() - time::Duration::seconds(60);
        let mut store = middleware.token_store.write().await;
        for data in store.values_mut() {
            data.expires_at = past;
        }
    }

    #[tokio::test]
    async fn test_csrf_token_generation() {
        let middleware = create_test_middleware();

        let token1 = middleware.generate_token(Some("Mozilla/5.0")).await;
        let token2 = middleware.generate_token(Some("Mozilla/5.0")).await;

        // Tokens are random, so two consecutive tokens must differ.
        assert_ne!(token1, token2);
        assert!(token1.len() > 20);
        assert!(token2.len() > 20);
    }

    #[tokio::test]
    async fn test_csrf_token_validation() {
        let middleware = create_test_middleware();
        let user_agent = Some("Mozilla/5.0");

        let token = middleware.generate_token(user_agent).await;

        assert!(middleware.validate_token(&token, user_agent).await);
        assert!(!middleware.validate_token("invalid_token", user_agent).await);
        assert!(!middleware.validate_token(&token, Some("Different Agent")).await);
    }

    #[tokio::test]
    async fn test_csrf_token_expiration() {
        let middleware = create_test_middleware();

        let token = middleware.generate_token(None).await;
        assert!(middleware.validate_token(&token, None).await);

        expire_all_tokens(&middleware).await;

        assert!(!middleware.validate_token(&token, None).await);
    }

    #[tokio::test]
    async fn test_csrf_consume_token() {
        let middleware = create_test_middleware();

        let token = middleware.generate_token(None).await;
        assert!(middleware.validate_token(&token, None).await);

        middleware.consume_token(&token).await;

        assert!(!middleware.validate_token(&token, None).await);
    }

    #[tokio::test]
    async fn test_csrf_exempt_paths() {
        let middleware = create_test_middleware();

        // Exact match.
        assert!(middleware.is_exempt_path("/api/webhook"));

        // Wildcard prefix match.
        assert!(middleware.is_exempt_path("/public/assets/style.css"));
        assert!(middleware.is_exempt_path("/public/images/logo.png"));

        // Everything else is protected.
        assert!(!middleware.is_exempt_path("/api/users"));
        assert!(!middleware.is_exempt_path("/admin/dashboard"));
    }

    #[tokio::test]
    async fn test_csrf_builder_pattern() {
        let middleware = CsrfMiddleware::builder()
            .token_header("X-Custom-CSRF-Token")
            .cookie_name("_custom_csrf")
            .token_lifetime(7200)
            .secure_cookie(true)
            .exempt_path("/api/public")
            .exempt_paths(vec!["/webhook", "/status"])
            .build();

        assert_eq!(middleware.config.token_header, "X-Custom-CSRF-Token");
        assert_eq!(middleware.config.cookie_name, "_custom_csrf");
        assert_eq!(middleware.config.token_lifetime, 7200);
        assert!(middleware.config.secure_cookie);
        assert!(middleware.config.exempt_paths.contains("/api/public"));
        assert!(middleware.config.exempt_paths.contains("/webhook"));
        assert!(middleware.config.exempt_paths.contains("/status"));
    }

    #[tokio::test]
    async fn test_csrf_middleware_get_requests() {
        let middleware = create_test_middleware();
        let pipeline = MiddlewarePipeline::new().add(middleware);

        let request = Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap();

        let result = pipeline.process_request(request).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_without_token() {
        let middleware = create_test_middleware();
        let pipeline = MiddlewarePipeline::new().add(middleware);

        let request = Request::builder()
            .method(Method::POST)
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap();

        // A single exhaustive match checks both "was rejected" and "with 403".
        match pipeline.process_request(request).await {
            Err(response) => assert_eq!(response.status(), ElifStatusCode::FORBIDDEN),
            Ok(_) => panic!("POST without a CSRF token must be rejected"),
        }
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_with_valid_token() {
        let middleware = create_test_middleware();
        let token = middleware.generate_token(Some("TestAgent")).await;
        let pipeline = MiddlewarePipeline::new().add(middleware);

        let request = Request::builder()
            .method(Method::POST)
            .uri("/test")
            .header("X-CSRF-Token", &token)
            .header("User-Agent", "TestAgent")
            .body(axum::body::Body::empty())
            .unwrap();

        let result = pipeline.process_request(request).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_csrf_middleware_exempt_paths() {
        let middleware = create_test_middleware();
        let pipeline = MiddlewarePipeline::new().add(middleware);

        // Exact exemption.
        let request1 = Request::builder()
            .method(Method::POST)
            .uri("/api/webhook")
            .body(axum::body::Body::empty())
            .unwrap();

        let result1 = pipeline.process_request(request1).await;
        assert!(result1.is_ok());

        // Wildcard exemption.
        let request2 = Request::builder()
            .method(Method::POST)
            .uri("/public/upload")
            .body(axum::body::Body::empty())
            .unwrap();

        let result2 = pipeline.process_request(request2).await;
        assert!(result2.is_ok());
    }

    #[tokio::test]
    async fn test_csrf_token_cleanup() {
        let middleware = create_test_middleware();

        let _token1 = middleware.generate_token(None).await;
        let _token2 = middleware.generate_token(None).await;
        let _token3 = middleware.generate_token(None).await;

        {
            let store = middleware.token_store.read().await;
            assert_eq!(store.len(), 3);
        }

        expire_all_tokens(&middleware).await;

        // Issuing a token prunes the expired ones, leaving only the new token.
        let _new_token = middleware.generate_token(None).await;

        {
            let store = middleware.token_store.read().await;
            assert_eq!(store.len(), 1);
        }
    }

    #[tokio::test]
    async fn test_csrf_cookie_extraction() {
        let middleware = create_test_middleware();
        let mut headers = HeaderMap::new();

        headers.insert(
            header::COOKIE,
            HeaderValue::from_str("_csrf_token=test_token_123; other_cookie=value").unwrap(),
        );

        let token = middleware.extract_token(&headers);
        assert_eq!(token, Some("test_token_123".to_string()));

        // The dedicated header takes precedence over the cookie.
        headers.insert(
            "X-CSRF-Token",
            HeaderValue::from_str("header_token_456").unwrap(),
        );

        let token = middleware.extract_token(&headers);
        assert_eq!(token, Some("header_token_456".to_string()));
    }

    #[tokio::test]
    async fn test_csrf_user_agent_binding() {
        let middleware = create_test_middleware();

        let token = middleware.generate_token(Some("SpecificAgent")).await;

        assert!(middleware.validate_token(&token, Some("SpecificAgent")).await);
        assert!(!middleware.validate_token(&token, Some("DifferentAgent")).await);
        // A bound token presented without any user agent is rejected.
        assert!(!middleware.validate_token(&token, None).await);
    }
}