// elif_security/middleware/csrf.rs

//! CSRF (Cross-Site Request Forgery) protection middleware
//!
//! Provides CSRF token generation, expiry-aware token validation (optionally bound to the
//! client's User-Agent), and request-level enforcement for state-changing HTTP methods.

6use std::sync::Arc;
7use std::collections::HashMap;
8use axum::{
9    extract::Request,
10    http::{HeaderMap, Method, header},
11    response::{Response, IntoResponse},
12};
13use elif_http::{
14    middleware::{Middleware, BoxFuture},
15    ElifStatusCode,  // Use framework-native status codes
16};
17use sha2::{Sha256, Digest};
18use rand::{thread_rng, Rng};
19use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
20
21pub use crate::config::CsrfConfig;
22use crate::SecurityError;
23
24/// CSRF token store - in production this would be backed by Redis/database
25type TokenStore = Arc<tokio::sync::RwLock<HashMap<String, CsrfTokenData>>>;
26
27/// CSRF token data with expiration
28#[derive(Debug, Clone)]
29pub struct CsrfTokenData {
30    pub token: String,
31    pub expires_at: time::OffsetDateTime,
32    pub user_agent_hash: Option<String>,
33}
34
35/// CSRF protection middleware
36#[derive(Debug, Clone)]
37pub struct CsrfMiddleware {
38    config: CsrfConfig,
39    token_store: TokenStore,
40}
41
42impl CsrfMiddleware {
43    /// Create new CSRF middleware with configuration
44    pub fn new(config: CsrfConfig) -> Self {
45        Self { 
46            config,
47            token_store: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
48        }
49    }
50    
51    /// Create middleware with builder pattern
52    pub fn builder() -> CsrfMiddlewareBuilder {
53        CsrfMiddlewareBuilder::new()
54    }
55    
56    /// Generate a new CSRF token
57    pub async fn generate_token(&self, user_agent: Option<&str>) -> String {
58        let mut rng = thread_rng();
59        let token_bytes: [u8; 32] = rng.gen();
60        let token = URL_SAFE_NO_PAD.encode(token_bytes);
61        
62        let user_agent_hash = user_agent.map(|ua| {
63            let mut hasher = Sha256::new();
64            hasher.update(ua.as_bytes());
65            format!("{:x}", hasher.finalize())
66        });
67        
68        let token_data = CsrfTokenData {
69            token: token.clone(),
70            expires_at: time::OffsetDateTime::now_utc() + 
71                time::Duration::seconds(self.config.token_lifetime as i64),
72            user_agent_hash,
73        };
74        
75        // Store token
76        let mut store = self.token_store.write().await;
77        store.insert(token.clone(), token_data);
78        
79        // Clean up expired tokens periodically
80        self.cleanup_expired_tokens(&mut store).await;
81        
82        token
83    }
84    
85    /// Validate a CSRF token
86    pub async fn validate_token(&self, token: &str, user_agent: Option<&str>) -> bool {
87        let store = self.token_store.read().await;
88        
89        if let Some(token_data) = store.get(token) {
90            // Check expiration
91            if time::OffsetDateTime::now_utc() > token_data.expires_at {
92                return false;
93            }
94            
95            // Check user agent if configured
96            if let Some(stored_hash) = &token_data.user_agent_hash {
97                if let Some(ua) = user_agent {
98                    let mut hasher = Sha256::new();
99                    hasher.update(ua.as_bytes());
100                    let ua_hash = format!("{:x}", hasher.finalize());
101                    if stored_hash != &ua_hash {
102                        return false;
103                    }
104                } else {
105                    return false;
106                }
107            }
108            
109            true
110        } else {
111            false
112        }
113    }
114    
115    /// Remove a token after successful validation (single-use)
116    pub async fn consume_token(&self, token: &str) {
117        let mut store = self.token_store.write().await;
118        store.remove(token);
119    }
120    
121    /// Clean up expired tokens
122    async fn cleanup_expired_tokens(&self, store: &mut HashMap<String, CsrfTokenData>) {
123        let now = time::OffsetDateTime::now_utc();
124        store.retain(|_, data| data.expires_at > now);
125    }
126    
127    /// Check if path is exempt from CSRF protection
128    fn is_exempt_path(&self, path: &str) -> bool {
129        self.config.exempt_paths.contains(path) ||
130        self.config.exempt_paths.iter().any(|exempt| {
131            // Simple glob pattern matching
132            if exempt.ends_with('*') {
133                path.starts_with(&exempt[..exempt.len()-1])
134            } else {
135                path == exempt
136            }
137        })
138    }
139    
140    /// Extract CSRF token from request
141    fn extract_token(&self, headers: &HeaderMap) -> Option<String> {
142        // Try header first
143        if let Some(header_value) = headers.get(&self.config.token_header) {
144            if let Ok(token) = header_value.to_str() {
145                return Some(token.to_string());
146            }
147        }
148        
149        // Try cookie (would need cookie parsing here - simplified for now)
150        if let Some(cookie_header) = headers.get(header::COOKIE) {
151            if let Ok(cookies) = cookie_header.to_str() {
152                for cookie in cookies.split(';') {
153                    let cookie = cookie.trim();
154                    if let Some((name, value)) = cookie.split_once('=') {
155                        if name == self.config.cookie_name {
156                            return Some(value.to_string());
157                        }
158                    }
159                }
160            }
161        }
162        
163        None
164    }
165}
166
167/// Implementation of our Middleware trait for CSRF protection
168impl Middleware for CsrfMiddleware {
169    fn process_request<'a>(
170        &'a self, 
171        request: Request
172    ) -> BoxFuture<'a, Result<Request, Response>> {
173        Box::pin(async move {
174            let method = request.method();
175            let uri = request.uri();
176            let headers = request.headers();
177            
178            // Skip CSRF protection for safe methods (GET, HEAD, OPTIONS)
179            if matches!(method, &Method::GET | &Method::HEAD | &Method::OPTIONS) {
180                return Ok(request);
181            }
182            
183            // Skip exempt paths
184            if self.is_exempt_path(uri.path()) {
185                return Ok(request);
186            }
187            
188            // Extract and validate token
189            let user_agent = headers.get(header::USER_AGENT)
190                .and_then(|h| h.to_str().ok());
191                
192            if let Some(token) = self.extract_token(headers) {
193                if self.validate_token(&token, user_agent).await {
194                    // Consume token for single-use (optional - can be configured)
195                    // self.consume_token(&token).await;
196                    return Ok(request);
197                }
198            }
199            
200            // CSRF validation failed - return 403 Forbidden
201            let error_response = Response::builder()
202                .status(ElifStatusCode::FORBIDDEN)
203                .header("Content-Type", "application/json")
204                .body(r#"{"error":{"code":"CSRF_VALIDATION_FAILED","message":"CSRF token validation failed"}}"#.into())
205                .unwrap();
206                
207            Err(error_response)
208        })
209    }
210    
211    fn name(&self) -> &'static str {
212        "CsrfMiddleware"
213    }
214}
215
216impl IntoResponse for SecurityError {
217    fn into_response(self) -> Response {
218        let (status, message) = match self {
219            SecurityError::CsrfValidationFailed => {
220                (ElifStatusCode::FORBIDDEN, "CSRF token validation failed")
221            }
222            _ => (ElifStatusCode::INTERNAL_SERVER_ERROR, "Security error"),
223        };
224        
225        (status, message).into_response()
226    }
227}
228
229/// Builder for CSRF middleware configuration
230#[derive(Debug)]
231pub struct CsrfMiddlewareBuilder {
232    config: CsrfConfig,
233}
234
235impl CsrfMiddlewareBuilder {
236    pub fn new() -> Self {
237        Self {
238            config: CsrfConfig::default(),
239        }
240    }
241    
242    pub fn token_header<S: Into<String>>(mut self, header: S) -> Self {
243        self.config.token_header = header.into();
244        self
245    }
246    
247    pub fn cookie_name<S: Into<String>>(mut self, name: S) -> Self {
248        self.config.cookie_name = name.into();
249        self
250    }
251    
252    pub fn token_lifetime(mut self, seconds: u64) -> Self {
253        self.config.token_lifetime = seconds;
254        self
255    }
256    
257    pub fn secure_cookie(mut self, secure: bool) -> Self {
258        self.config.secure_cookie = secure;
259        self
260    }
261    
262    pub fn exempt_path<S: Into<String>>(mut self, path: S) -> Self {
263        self.config.exempt_paths.insert(path.into());
264        self
265    }
266    
267    pub fn exempt_paths<I, S>(mut self, paths: I) -> Self 
268    where 
269        I: IntoIterator<Item = S>,
270        S: Into<String>,
271    {
272        for path in paths {
273            self.config.exempt_paths.insert(path.into());
274        }
275        self
276    }
277    
278    pub fn build(self) -> CsrfMiddleware {
279        CsrfMiddleware::new(self.config)
280    }
281}
282
283impl Default for CsrfMiddlewareBuilder {
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderValue, Method};
    use elif_http::middleware::MiddlewarePipeline;
    use std::collections::HashSet;

    /// Middleware shared by most tests: `X-CSRF-Token` header, `_csrf_token` cookie,
    /// one-hour tokens, and two exemptions (`/api/webhook` exactly, `/public/*` by prefix).
    fn create_test_middleware() -> CsrfMiddleware {
        let exempt_paths: HashSet<String> = ["/api/webhook", "/public/*"]
            .iter()
            .map(|path| path.to_string())
            .collect();

        CsrfMiddleware::new(CsrfConfig {
            token_header: "X-CSRF-Token".to_string(),
            cookie_name: "_csrf_token".to_string(),
            token_lifetime: 3600,
            secure_cookie: false, // For testing
            exempt_paths,
        })
    }

    /// Middleware with the default configuration except for a one-second token lifetime.
    fn short_lived_middleware() -> CsrfMiddleware {
        CsrfMiddleware::new(CsrfConfig {
            token_lifetime: 1, // 1 second
            ..Default::default()
        })
    }

    /// Builds a body-less request with the given method, URI and extra headers.
    fn build_request(method: Method, uri: &str, headers: &[(&str, &str)]) -> Request {
        let mut builder = Request::builder().method(method).uri(uri);
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder.body(axum::body::Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_csrf_token_generation() {
        let middleware = create_test_middleware();

        let first = middleware.generate_token(Some("Mozilla/5.0")).await;
        let second = middleware.generate_token(Some("Mozilla/5.0")).await;

        // Each call yields a distinct, base64-encoded token.
        assert_ne!(first, second);
        assert!(first.len() > 20);
        assert!(second.len() > 20);
    }

    #[tokio::test]
    async fn test_csrf_token_validation() {
        let middleware = create_test_middleware();
        let agent = Some("Mozilla/5.0");
        let token = middleware.generate_token(agent).await;

        assert!(middleware.validate_token(&token, agent).await, "issued token must validate");
        assert!(
            !middleware.validate_token("invalid_token", agent).await,
            "unknown token must fail"
        );
        assert!(
            !middleware.validate_token(&token, Some("Different Agent")).await,
            "token bound to another user agent must fail"
        );
    }

    #[tokio::test]
    async fn test_csrf_token_expiration() {
        let middleware = short_lived_middleware();
        let token = middleware.generate_token(None).await;

        assert!(middleware.validate_token(&token, None).await);

        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        assert!(!middleware.validate_token(&token, None).await);
    }

    #[test]
    fn test_csrf_exempt_paths() {
        let middleware = create_test_middleware();

        let exempt = ["/api/webhook", "/public/assets/style.css", "/public/images/logo.png"];
        let protected = ["/api/users", "/admin/dashboard"];

        for path in exempt.iter() {
            assert!(middleware.is_exempt_path(path), "{} should be exempt", path);
        }
        for path in protected.iter() {
            assert!(!middleware.is_exempt_path(path), "{} should be protected", path);
        }
    }

    #[test]
    fn test_csrf_builder_pattern() {
        let middleware = CsrfMiddleware::builder()
            .token_header("X-Custom-CSRF-Token")
            .cookie_name("_custom_csrf")
            .token_lifetime(7200)
            .secure_cookie(true)
            .exempt_path("/api/public")
            .exempt_paths(vec!["/webhook", "/status"])
            .build();
        let config = &middleware.config;

        assert_eq!(config.token_header, "X-Custom-CSRF-Token");
        assert_eq!(config.cookie_name, "_custom_csrf");
        assert_eq!(config.token_lifetime, 7200);
        assert!(config.secure_cookie);
        for path in ["/api/public", "/webhook", "/status"].iter() {
            assert!(config.exempt_paths.contains(*path), "{} should be exempt", path);
        }
    }

    #[tokio::test]
    async fn test_csrf_middleware_get_requests() {
        let pipeline = MiddlewarePipeline::new().add(create_test_middleware());
        let request = build_request(Method::GET, "/test", &[]);

        // Safe methods pass without any token.
        assert!(pipeline.process_request(request).await.is_ok());
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_without_token() {
        let pipeline = MiddlewarePipeline::new().add(create_test_middleware());
        let request = build_request(Method::POST, "/test", &[]);

        match pipeline.process_request(request).await {
            Err(response) => assert_eq!(response.status(), ElifStatusCode::FORBIDDEN),
            Ok(_) => panic!("POST without a CSRF token must be rejected"),
        }
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_with_valid_token() {
        let middleware = create_test_middleware();
        let token = middleware.generate_token(Some("TestAgent")).await;
        let pipeline = MiddlewarePipeline::new().add(middleware);

        let request = build_request(
            Method::POST,
            "/test",
            &[("X-CSRF-Token", token.as_str()), ("User-Agent", "TestAgent")],
        );

        assert!(pipeline.process_request(request).await.is_ok());
    }

    #[tokio::test]
    async fn test_csrf_middleware_exempt_paths() {
        let pipeline = MiddlewarePipeline::new().add(create_test_middleware());

        // One exact exemption and one matched by the `/public/*` prefix pattern.
        for uri in ["/api/webhook", "/public/upload"].iter() {
            let request = build_request(Method::POST, uri, &[]);
            assert!(
                pipeline.process_request(request).await.is_ok(),
                "POST to {} should bypass CSRF checks",
                uri
            );
        }
    }

    #[tokio::test]
    async fn test_csrf_token_cleanup() {
        let middleware = short_lived_middleware();

        for _ in 0..3 {
            middleware.generate_token(None).await;
        }
        assert_eq!(middleware.token_store.read().await.len(), 3);

        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        // Issuing a new token evicts the expired ones; only the new token remains.
        middleware.generate_token(None).await;
        assert_eq!(middleware.token_store.read().await.len(), 1);
    }

    #[test]
    fn test_csrf_cookie_extraction() {
        let middleware = create_test_middleware();
        let mut headers = HeaderMap::new();

        headers.insert(
            header::COOKIE,
            HeaderValue::from_static("_csrf_token=test_token_123; other_cookie=value"),
        );
        assert_eq!(
            middleware.extract_token(&headers).as_deref(),
            Some("test_token_123")
        );

        // The token header takes precedence over the cookie.
        headers.insert("X-CSRF-Token", HeaderValue::from_static("header_token_456"));
        assert_eq!(
            middleware.extract_token(&headers).as_deref(),
            Some("header_token_456")
        );
    }

    #[tokio::test]
    async fn test_csrf_user_agent_binding() {
        let middleware = create_test_middleware();
        let token = middleware.generate_token(Some("SpecificAgent")).await;

        let cases = [
            (Some("SpecificAgent"), true),
            (Some("DifferentAgent"), false),
            (None, false),
        ];
        for (agent, expected) in cases.iter() {
            assert_eq!(
                middleware.validate_token(&token, *agent).await,
                *expected,
                "user agent {:?}",
                agent
            );
        }
    }
}