elif_security/middleware/
csrf.rs

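//! CSRF (Cross-Site Request Forgery) protection middleware.
//!
//! Provides token generation, server-side token validation (expiry and
//! optional User-Agent binding), and request filtering that rejects
//! non-GET/HEAD/OPTIONS requests lacking a valid token, with configurable
//! path exemptions.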
5
6use std::sync::Arc;
7use std::collections::HashMap;
8use axum::{
9    extract::Request,
10    http::{HeaderMap, Method, StatusCode, header},
11    response::{IntoResponse, Response},
12};
13use elif_http::middleware::{Middleware, BoxFuture};
14use sha2::{Sha256, Digest};
15use rand::{thread_rng, Rng};
16use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
17
18pub use crate::config::CsrfConfig;
19use crate::SecurityError;
20
21/// CSRF token store - in production this would be backed by Redis/database
22type TokenStore = Arc<tokio::sync::RwLock<HashMap<String, CsrfTokenData>>>;
23
24/// CSRF token data with expiration
25#[derive(Debug, Clone)]
26pub struct CsrfTokenData {
27    pub token: String,
28    pub expires_at: time::OffsetDateTime,
29    pub user_agent_hash: Option<String>,
30}
31
32/// CSRF protection middleware
33#[derive(Debug, Clone)]
34pub struct CsrfMiddleware {
35    config: CsrfConfig,
36    token_store: TokenStore,
37}
38
39impl CsrfMiddleware {
40    /// Create new CSRF middleware with configuration
41    pub fn new(config: CsrfConfig) -> Self {
42        Self { 
43            config,
44            token_store: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
45        }
46    }
47    
48    /// Create middleware with builder pattern
49    pub fn builder() -> CsrfMiddlewareBuilder {
50        CsrfMiddlewareBuilder::new()
51    }
52    
53    /// Generate a new CSRF token
54    pub async fn generate_token(&self, user_agent: Option<&str>) -> String {
55        let mut rng = thread_rng();
56        let token_bytes: [u8; 32] = rng.gen();
57        let token = URL_SAFE_NO_PAD.encode(token_bytes);
58        
59        let user_agent_hash = user_agent.map(|ua| {
60            let mut hasher = Sha256::new();
61            hasher.update(ua.as_bytes());
62            format!("{:x}", hasher.finalize())
63        });
64        
65        let token_data = CsrfTokenData {
66            token: token.clone(),
67            expires_at: time::OffsetDateTime::now_utc() + 
68                time::Duration::seconds(self.config.token_lifetime as i64),
69            user_agent_hash,
70        };
71        
72        // Store token
73        let mut store = self.token_store.write().await;
74        store.insert(token.clone(), token_data);
75        
76        // Clean up expired tokens periodically
77        self.cleanup_expired_tokens(&mut store).await;
78        
79        token
80    }
81    
82    /// Validate a CSRF token
83    pub async fn validate_token(&self, token: &str, user_agent: Option<&str>) -> bool {
84        let store = self.token_store.read().await;
85        
86        if let Some(token_data) = store.get(token) {
87            // Check expiration
88            if time::OffsetDateTime::now_utc() > token_data.expires_at {
89                return false;
90            }
91            
92            // Check user agent if configured
93            if let Some(stored_hash) = &token_data.user_agent_hash {
94                if let Some(ua) = user_agent {
95                    let mut hasher = Sha256::new();
96                    hasher.update(ua.as_bytes());
97                    let ua_hash = format!("{:x}", hasher.finalize());
98                    if stored_hash != &ua_hash {
99                        return false;
100                    }
101                } else {
102                    return false;
103                }
104            }
105            
106            true
107        } else {
108            false
109        }
110    }
111    
112    /// Remove a token after successful validation (single-use)
113    pub async fn consume_token(&self, token: &str) {
114        let mut store = self.token_store.write().await;
115        store.remove(token);
116    }
117    
118    /// Clean up expired tokens
119    async fn cleanup_expired_tokens(&self, store: &mut HashMap<String, CsrfTokenData>) {
120        let now = time::OffsetDateTime::now_utc();
121        store.retain(|_, data| data.expires_at > now);
122    }
123    
124    /// Check if path is exempt from CSRF protection
125    fn is_exempt_path(&self, path: &str) -> bool {
126        self.config.exempt_paths.contains(path) ||
127        self.config.exempt_paths.iter().any(|exempt| {
128            // Simple glob pattern matching
129            if exempt.ends_with('*') {
130                path.starts_with(&exempt[..exempt.len()-1])
131            } else {
132                path == exempt
133            }
134        })
135    }
136    
137    /// Extract CSRF token from request
138    fn extract_token(&self, headers: &HeaderMap) -> Option<String> {
139        // Try header first
140        if let Some(header_value) = headers.get(&self.config.token_header) {
141            if let Ok(token) = header_value.to_str() {
142                return Some(token.to_string());
143            }
144        }
145        
146        // Try cookie (would need cookie parsing here - simplified for now)
147        if let Some(cookie_header) = headers.get(header::COOKIE) {
148            if let Ok(cookies) = cookie_header.to_str() {
149                for cookie in cookies.split(';') {
150                    let cookie = cookie.trim();
151                    if let Some((name, value)) = cookie.split_once('=') {
152                        if name == self.config.cookie_name {
153                            return Some(value.to_string());
154                        }
155                    }
156                }
157            }
158        }
159        
160        None
161    }
162}
163
164/// Implementation of our Middleware trait for CSRF protection
165impl Middleware for CsrfMiddleware {
166    fn process_request<'a>(
167        &'a self, 
168        request: Request
169    ) -> BoxFuture<'a, Result<Request, Response>> {
170        Box::pin(async move {
171            let method = request.method();
172            let uri = request.uri();
173            let headers = request.headers();
174            
175            // Skip CSRF protection for safe methods (GET, HEAD, OPTIONS)
176            if matches!(method, &Method::GET | &Method::HEAD | &Method::OPTIONS) {
177                return Ok(request);
178            }
179            
180            // Skip exempt paths
181            if self.is_exempt_path(uri.path()) {
182                return Ok(request);
183            }
184            
185            // Extract and validate token
186            let user_agent = headers.get(header::USER_AGENT)
187                .and_then(|h| h.to_str().ok());
188                
189            if let Some(token) = self.extract_token(headers) {
190                if self.validate_token(&token, user_agent).await {
191                    // Consume token for single-use (optional - can be configured)
192                    // self.consume_token(&token).await;
193                    return Ok(request);
194                }
195            }
196            
197            // CSRF validation failed - return 403 Forbidden
198            let error_response = Response::builder()
199                .status(StatusCode::FORBIDDEN)
200                .header("Content-Type", "application/json")
201                .body(r#"{"error":{"code":"CSRF_VALIDATION_FAILED","message":"CSRF token validation failed"}}"#.into())
202                .unwrap();
203                
204            Err(error_response)
205        })
206    }
207    
208    fn name(&self) -> &'static str {
209        "CSRF Protection"
210    }
211}
212
213impl IntoResponse for SecurityError {
214    fn into_response(self) -> Response {
215        let (status, message) = match self {
216            SecurityError::CsrfValidationFailed => {
217                (StatusCode::FORBIDDEN, "CSRF token validation failed")
218            }
219            _ => (StatusCode::INTERNAL_SERVER_ERROR, "Security error"),
220        };
221        
222        (status, message).into_response()
223    }
224}
225
226/// Builder for CSRF middleware configuration
227#[derive(Debug)]
228pub struct CsrfMiddlewareBuilder {
229    config: CsrfConfig,
230}
231
232impl CsrfMiddlewareBuilder {
233    pub fn new() -> Self {
234        Self {
235            config: CsrfConfig::default(),
236        }
237    }
238    
239    pub fn token_header<S: Into<String>>(mut self, header: S) -> Self {
240        self.config.token_header = header.into();
241        self
242    }
243    
244    pub fn cookie_name<S: Into<String>>(mut self, name: S) -> Self {
245        self.config.cookie_name = name.into();
246        self
247    }
248    
249    pub fn token_lifetime(mut self, seconds: u64) -> Self {
250        self.config.token_lifetime = seconds;
251        self
252    }
253    
254    pub fn secure_cookie(mut self, secure: bool) -> Self {
255        self.config.secure_cookie = secure;
256        self
257    }
258    
259    pub fn exempt_path<S: Into<String>>(mut self, path: S) -> Self {
260        self.config.exempt_paths.insert(path.into());
261        self
262    }
263    
264    pub fn exempt_paths<I, S>(mut self, paths: I) -> Self 
265    where 
266        I: IntoIterator<Item = S>,
267        S: Into<String>,
268    {
269        for path in paths {
270            self.config.exempt_paths.insert(path.into());
271        }
272        self
273    }
274    
275    pub fn build(self) -> CsrfMiddleware {
276        CsrfMiddleware::new(self.config)
277    }
278}
279
280impl Default for CsrfMiddlewareBuilder {
281    fn default() -> Self {
282        Self::new()
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderValue, Method};
    use elif_http::middleware::MiddlewarePipeline;
    use std::collections::HashSet;

    /// Just over one second: enough for a token with a 1-second lifetime to
    /// expire, while keeping the expiry tests fast.
    const PAST_ONE_SECOND: tokio::time::Duration = tokio::time::Duration::from_millis(1100);

    /// Middleware that reads tokens from `X-CSRF-Token` / `_csrf_token` and
    /// exempts `/api/webhook` (exact) and `/public/*` (prefix).
    fn create_test_middleware() -> CsrfMiddleware {
        let mut exempt_paths = HashSet::new();
        exempt_paths.insert("/api/webhook".to_string());
        exempt_paths.insert("/public/*".to_string());

        let config = CsrfConfig {
            token_header: "X-CSRF-Token".to_string(),
            cookie_name: "_csrf_token".to_string(),
            token_lifetime: 3600,
            secure_cookie: false, // For testing
            exempt_paths,
        };

        CsrfMiddleware::new(config)
    }

    /// Build a bodiless request with `method`, `uri` and the given headers.
    fn build_request(method: Method, uri: &str, headers: &[(&str, &str)]) -> Request {
        let mut builder = Request::builder().method(method).uri(uri);
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder.body(axum::body::Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_csrf_token_generation() {
        let middleware = create_test_middleware();

        let token1 = middleware.generate_token(Some("Mozilla/5.0")).await;
        let token2 = middleware.generate_token(Some("Mozilla/5.0")).await;

        assert_ne!(token1, token2, "every call must produce a fresh token");
        // 32 random bytes in unpadded base64 are exactly 43 characters.
        assert_eq!(token1.len(), 43);
        assert_eq!(token2.len(), 43);
    }

    #[tokio::test]
    async fn test_csrf_token_validation() {
        let middleware = create_test_middleware();
        let user_agent = Some("Mozilla/5.0");

        let token = middleware.generate_token(user_agent).await;

        assert!(middleware.validate_token(&token, user_agent).await);
        assert!(!middleware.validate_token("invalid_token", user_agent).await);
        assert!(!middleware.validate_token(&token, Some("Different Agent")).await);
    }

    #[tokio::test]
    async fn test_csrf_token_expiration() {
        let config = CsrfConfig {
            token_lifetime: 1, // 1 second
            ..Default::default()
        };
        let middleware = CsrfMiddleware::new(config);

        let token = middleware.generate_token(None).await;
        assert!(middleware.validate_token(&token, None).await);

        tokio::time::sleep(PAST_ONE_SECOND).await;

        assert!(!middleware.validate_token(&token, None).await);
    }

    #[test]
    fn test_csrf_exempt_paths() {
        let middleware = create_test_middleware();

        // Exact match
        assert!(middleware.is_exempt_path("/api/webhook"));

        // Prefix ("glob") match
        assert!(middleware.is_exempt_path("/public/assets/style.css"));
        assert!(middleware.is_exempt_path("/public/images/logo.png"));

        // Non-exempt paths
        assert!(!middleware.is_exempt_path("/api/users"));
        assert!(!middleware.is_exempt_path("/admin/dashboard"));

        // Exact patterns do not cover sub-paths, and the prefix keeps its slash.
        assert!(!middleware.is_exempt_path("/api/webhook/extra"));
        assert!(!middleware.is_exempt_path("/public"));
    }

    #[test]
    fn test_csrf_builder_pattern() {
        let middleware = CsrfMiddleware::builder()
            .token_header("X-Custom-CSRF-Token")
            .cookie_name("_custom_csrf")
            .token_lifetime(7200)
            .secure_cookie(true)
            .exempt_path("/api/public")
            .exempt_paths(vec!["/webhook", "/status"])
            .build();

        assert_eq!(middleware.config.token_header, "X-Custom-CSRF-Token");
        assert_eq!(middleware.config.cookie_name, "_custom_csrf");
        assert_eq!(middleware.config.token_lifetime, 7200);
        assert!(middleware.config.secure_cookie);
        assert!(middleware.config.exempt_paths.contains("/api/public"));
        assert!(middleware.config.exempt_paths.contains("/webhook"));
        assert!(middleware.config.exempt_paths.contains("/status"));
    }

    #[tokio::test]
    async fn test_csrf_middleware_get_requests() {
        let pipeline = MiddlewarePipeline::new().add(create_test_middleware());

        // Safe methods pass without any CSRF token.
        for method in [Method::GET, Method::HEAD, Method::OPTIONS] {
            let result = pipeline
                .process_request(build_request(method, "/test", &[]))
                .await;
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_without_token() {
        let pipeline = MiddlewarePipeline::new().add(create_test_middleware());

        let request = build_request(Method::POST, "/test", &[]);
        match pipeline.process_request(request).await {
            Ok(_) => panic!("POST without a CSRF token must be rejected"),
            Err(response) => assert_eq!(response.status(), StatusCode::FORBIDDEN),
        }
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_with_unknown_token() {
        let pipeline = MiddlewarePipeline::new().add(create_test_middleware());

        let request = build_request(
            Method::POST,
            "/test",
            &[("X-CSRF-Token", "not-a-real-token"), ("User-Agent", "TestAgent")],
        );
        match pipeline.process_request(request).await {
            Ok(_) => panic!("POST with an unknown CSRF token must be rejected"),
            Err(response) => assert_eq!(response.status(), StatusCode::FORBIDDEN),
        }
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_with_valid_token() {
        let middleware = create_test_middleware();
        let token = middleware.generate_token(Some("TestAgent")).await;
        let pipeline = MiddlewarePipeline::new().add(middleware);

        let request = build_request(
            Method::POST,
            "/test",
            &[("X-CSRF-Token", &token), ("User-Agent", "TestAgent")],
        );
        assert!(pipeline.process_request(request).await.is_ok());
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_with_mismatched_user_agent() {
        let middleware = create_test_middleware();
        let token = middleware.generate_token(Some("TestAgent")).await;
        let pipeline = MiddlewarePipeline::new().add(middleware);

        let request = build_request(
            Method::POST,
            "/test",
            &[("X-CSRF-Token", &token), ("User-Agent", "OtherAgent")],
        );
        match pipeline.process_request(request).await {
            Ok(_) => panic!("token bound to another User-Agent must be rejected"),
            Err(response) => assert_eq!(response.status(), StatusCode::FORBIDDEN),
        }
    }

    #[tokio::test]
    async fn test_csrf_middleware_exempt_paths() {
        let pipeline = MiddlewarePipeline::new().add(create_test_middleware());

        // Exempt exact path
        let result1 = pipeline
            .process_request(build_request(Method::POST, "/api/webhook", &[]))
            .await;
        assert!(result1.is_ok());

        // Exempt prefix path
        let result2 = pipeline
            .process_request(build_request(Method::POST, "/public/upload", &[]))
            .await;
        assert!(result2.is_ok());

        // Near miss: shares the prefix text but not the "/public/" pattern.
        let result3 = pipeline
            .process_request(build_request(Method::POST, "/publicity", &[]))
            .await;
        assert!(result3.is_err());
    }

    #[tokio::test]
    async fn test_csrf_token_cleanup() {
        let config = CsrfConfig {
            token_lifetime: 1, // 1 second
            ..Default::default()
        };
        let middleware = CsrfMiddleware::new(config);

        for _ in 0..3 {
            middleware.generate_token(None).await;
        }
        assert_eq!(middleware.token_store.read().await.len(), 3);

        tokio::time::sleep(PAST_ONE_SECOND).await;

        // Issuing a new token triggers the cleanup of the expired ones.
        middleware.generate_token(None).await;
        assert_eq!(middleware.token_store.read().await.len(), 1);
    }

    #[test]
    fn test_csrf_cookie_extraction() {
        let middleware = create_test_middleware();
        let mut headers = HeaderMap::new();

        headers.insert(
            header::COOKIE,
            HeaderValue::from_str("_csrf_token=test_token_123; other_cookie=value").unwrap(),
        );
        assert_eq!(
            middleware.extract_token(&headers),
            Some("test_token_123".to_string())
        );

        // The header takes precedence over the cookie.
        headers.insert(
            "X-CSRF-Token",
            HeaderValue::from_str("header_token_456").unwrap(),
        );
        assert_eq!(
            middleware.extract_token(&headers),
            Some("header_token_456".to_string())
        );
    }

    #[tokio::test]
    async fn test_csrf_user_agent_binding() {
        let middleware = create_test_middleware();

        let token = middleware.generate_token(Some("SpecificAgent")).await;

        assert!(middleware.validate_token(&token, Some("SpecificAgent")).await);
        assert!(!middleware.validate_token(&token, Some("DifferentAgent")).await);
        // No User-Agent fails when the token was bound to one.
        assert!(!middleware.validate_token(&token, None).await);
    }
}