elif_security/middleware/
csrf.rs

1//! CSRF (Cross-Site Request Forgery) protection middleware
2//!
3//! Provides comprehensive CSRF protection including token generation,
4//! validation, and secure cookie handling.
5
6use std::sync::Arc;
7use std::collections::HashMap;
8use axum::{
9    extract::{Request, State},
10    http::{HeaderMap, Method, StatusCode, header},
11    middleware::Next,
12    response::{IntoResponse, Response},
13};
14use sha2::{Sha256, Digest};
15use rand::{thread_rng, Rng};
16use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
17
18pub use crate::config::CsrfConfig;
19use crate::SecurityError;
20
21/// CSRF token store - in production this would be backed by Redis/database
22type TokenStore = Arc<tokio::sync::RwLock<HashMap<String, CsrfTokenData>>>;
23
24/// CSRF token data with expiration
25#[derive(Debug, Clone)]
26pub struct CsrfTokenData {
27    pub token: String,
28    pub expires_at: time::OffsetDateTime,
29    pub user_agent_hash: Option<String>,
30}
31
32/// CSRF protection middleware
33#[derive(Debug, Clone)]
34pub struct CsrfMiddleware {
35    config: CsrfConfig,
36    token_store: TokenStore,
37}
38
39impl CsrfMiddleware {
40    /// Create new CSRF middleware with configuration
41    pub fn new(config: CsrfConfig) -> Self {
42        Self { 
43            config,
44            token_store: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
45        }
46    }
47    
48    /// Create middleware with builder pattern
49    pub fn builder() -> CsrfMiddlewareBuilder {
50        CsrfMiddlewareBuilder::new()
51    }
52    
53    /// Generate a new CSRF token
54    pub async fn generate_token(&self, user_agent: Option<&str>) -> String {
55        let mut rng = thread_rng();
56        let token_bytes: [u8; 32] = rng.gen();
57        let token = URL_SAFE_NO_PAD.encode(token_bytes);
58        
59        let user_agent_hash = user_agent.map(|ua| {
60            let mut hasher = Sha256::new();
61            hasher.update(ua.as_bytes());
62            format!("{:x}", hasher.finalize())
63        });
64        
65        let token_data = CsrfTokenData {
66            token: token.clone(),
67            expires_at: time::OffsetDateTime::now_utc() + 
68                time::Duration::seconds(self.config.token_lifetime as i64),
69            user_agent_hash,
70        };
71        
72        // Store token
73        let mut store = self.token_store.write().await;
74        store.insert(token.clone(), token_data);
75        
76        // Clean up expired tokens periodically
77        self.cleanup_expired_tokens(&mut store).await;
78        
79        token
80    }
81    
82    /// Validate a CSRF token
83    pub async fn validate_token(&self, token: &str, user_agent: Option<&str>) -> bool {
84        let store = self.token_store.read().await;
85        
86        if let Some(token_data) = store.get(token) {
87            // Check expiration
88            if time::OffsetDateTime::now_utc() > token_data.expires_at {
89                return false;
90            }
91            
92            // Check user agent if configured
93            if let Some(stored_hash) = &token_data.user_agent_hash {
94                if let Some(ua) = user_agent {
95                    let mut hasher = Sha256::new();
96                    hasher.update(ua.as_bytes());
97                    let ua_hash = format!("{:x}", hasher.finalize());
98                    if stored_hash != &ua_hash {
99                        return false;
100                    }
101                } else {
102                    return false;
103                }
104            }
105            
106            true
107        } else {
108            false
109        }
110    }
111    
112    /// Remove a token after successful validation (single-use)
113    pub async fn consume_token(&self, token: &str) {
114        let mut store = self.token_store.write().await;
115        store.remove(token);
116    }
117    
118    /// Clean up expired tokens
119    async fn cleanup_expired_tokens(&self, store: &mut HashMap<String, CsrfTokenData>) {
120        let now = time::OffsetDateTime::now_utc();
121        store.retain(|_, data| data.expires_at > now);
122    }
123    
124    /// Check if path is exempt from CSRF protection
125    fn is_exempt_path(&self, path: &str) -> bool {
126        self.config.exempt_paths.contains(path) ||
127        self.config.exempt_paths.iter().any(|exempt| {
128            // Simple glob pattern matching
129            if exempt.ends_with('*') {
130                path.starts_with(&exempt[..exempt.len()-1])
131            } else {
132                path == exempt
133            }
134        })
135    }
136    
137    /// Extract CSRF token from request
138    fn extract_token(&self, headers: &HeaderMap) -> Option<String> {
139        // Try header first
140        if let Some(header_value) = headers.get(&self.config.token_header) {
141            if let Ok(token) = header_value.to_str() {
142                return Some(token.to_string());
143            }
144        }
145        
146        // Try cookie (would need cookie parsing here - simplified for now)
147        if let Some(cookie_header) = headers.get(header::COOKIE) {
148            if let Ok(cookies) = cookie_header.to_str() {
149                for cookie in cookies.split(';') {
150                    let cookie = cookie.trim();
151                    if let Some((name, value)) = cookie.split_once('=') {
152                        if name == self.config.cookie_name {
153                            return Some(value.to_string());
154                        }
155                    }
156                }
157            }
158        }
159        
160        None
161    }
162}
163
164/// Axum middleware implementation
165pub async fn csrf_middleware(
166    State(middleware): State<CsrfMiddleware>,
167    request: Request,
168    next: Next,
169) -> Result<Response, SecurityError> {
170    let method = request.method();
171    let uri = request.uri();
172    let headers = request.headers();
173    
174    // Skip CSRF protection for safe methods (GET, HEAD, OPTIONS)
175    if matches!(method, &Method::GET | &Method::HEAD | &Method::OPTIONS) {
176        return Ok(next.run(request).await);
177    }
178    
179    // Skip exempt paths
180    if middleware.is_exempt_path(uri.path()) {
181        return Ok(next.run(request).await);
182    }
183    
184    // Extract and validate token
185    let user_agent = headers.get(header::USER_AGENT)
186        .and_then(|h| h.to_str().ok());
187        
188    if let Some(token) = middleware.extract_token(headers) {
189        if middleware.validate_token(&token, user_agent).await {
190            // Consume token for single-use (optional - can be configured)
191            // middleware.consume_token(&token).await;
192            return Ok(next.run(request).await);
193        }
194    }
195    
196    // CSRF validation failed
197    Err(SecurityError::CsrfValidationFailed)
198}
199
200impl IntoResponse for SecurityError {
201    fn into_response(self) -> Response {
202        let (status, message) = match self {
203            SecurityError::CsrfValidationFailed => {
204                (StatusCode::FORBIDDEN, "CSRF token validation failed")
205            }
206            _ => (StatusCode::INTERNAL_SERVER_ERROR, "Security error"),
207        };
208        
209        (status, message).into_response()
210    }
211}
212
213/// Builder for CSRF middleware configuration
214#[derive(Debug)]
215pub struct CsrfMiddlewareBuilder {
216    config: CsrfConfig,
217}
218
219impl CsrfMiddlewareBuilder {
220    pub fn new() -> Self {
221        Self {
222            config: CsrfConfig::default(),
223        }
224    }
225    
226    pub fn token_header<S: Into<String>>(mut self, header: S) -> Self {
227        self.config.token_header = header.into();
228        self
229    }
230    
231    pub fn cookie_name<S: Into<String>>(mut self, name: S) -> Self {
232        self.config.cookie_name = name.into();
233        self
234    }
235    
236    pub fn token_lifetime(mut self, seconds: u64) -> Self {
237        self.config.token_lifetime = seconds;
238        self
239    }
240    
241    pub fn secure_cookie(mut self, secure: bool) -> Self {
242        self.config.secure_cookie = secure;
243        self
244    }
245    
246    pub fn exempt_path<S: Into<String>>(mut self, path: S) -> Self {
247        self.config.exempt_paths.insert(path.into());
248        self
249    }
250    
251    pub fn exempt_paths<I, S>(mut self, paths: I) -> Self 
252    where 
253        I: IntoIterator<Item = S>,
254        S: Into<String>,
255    {
256        for path in paths {
257            self.config.exempt_paths.insert(path.into());
258        }
259        self
260    }
261    
262    pub fn build(self) -> CsrfMiddleware {
263        CsrfMiddleware::new(self.config)
264    }
265}
266
267impl Default for CsrfMiddlewareBuilder {
268    fn default() -> Self {
269        Self::new()
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        http::HeaderValue,
        middleware::from_fn_with_state,
        routing::{get, post},
        Router,
    };
    use axum_test::TestServer;
    use std::collections::HashSet;

    /// Every route under test answers with this body.
    async fn test_handler() -> &'static str {
        "OK"
    }

    /// Middleware with one exact exempt path (`/api/webhook`), one prefix
    /// pattern (`/public/*`), and a one-hour token lifetime.
    fn create_test_middleware() -> CsrfMiddleware {
        let mut exempt_paths = HashSet::new();
        exempt_paths.insert("/api/webhook".to_string());
        exempt_paths.insert("/public/*".to_string());

        let config = CsrfConfig {
            token_header: "X-CSRF-Token".to_string(),
            cookie_name: "_csrf_token".to_string(),
            token_lifetime: 3600,
            secure_cookie: false, // For testing
            exempt_paths,
        };

        CsrfMiddleware::new(config)
    }

    /// Moves every stored token's expiry one second into the past. The
    /// expiry paths can then be exercised deterministically, without
    /// sleeping on the wall clock.
    async fn expire_all_tokens(middleware: &CsrfMiddleware) {
        let past = time::OffsetDateTime::now_utc() - time::Duration::seconds(1);
        let mut store = middleware.token_store.write().await;
        for data in store.values_mut() {
            data.expires_at = past;
        }
    }

    #[tokio::test]
    async fn test_csrf_token_generation() {
        let middleware = create_test_middleware();

        let token1 = middleware.generate_token(Some("Mozilla/5.0")).await;
        let token2 = middleware.generate_token(Some("Mozilla/5.0")).await;

        // Tokens should be different
        assert_ne!(token1, token2);
        assert!(token1.len() > 20); // Should be base64 encoded
        assert!(token2.len() > 20);
    }

    #[tokio::test]
    async fn test_csrf_token_validation() {
        let middleware = create_test_middleware();
        let user_agent = Some("Mozilla/5.0");

        let token = middleware.generate_token(user_agent).await;

        // Valid token should pass
        assert!(middleware.validate_token(&token, user_agent).await);

        // Invalid token should fail
        assert!(!middleware.validate_token("invalid_token", user_agent).await);

        // Different user agent should fail if token was generated with one
        assert!(!middleware.validate_token(&token, Some("Different Agent")).await);
    }

    #[tokio::test]
    async fn test_csrf_token_expiration() {
        let config = CsrfConfig {
            token_lifetime: 1, // 1 second
            ..Default::default()
        };
        let middleware = CsrfMiddleware::new(config);

        let before = time::OffsetDateTime::now_utc();
        let token = middleware.generate_token(None).await;
        let after = time::OffsetDateTime::now_utc();

        // The stored expiry reflects the configured lifetime.
        {
            let store = middleware.token_store.read().await;
            let expires_at = store
                .get(&token)
                .expect("freshly issued token is stored")
                .expires_at;
            assert!(expires_at >= before + time::Duration::seconds(1));
            assert!(expires_at <= after + time::Duration::seconds(1));
        }

        // Should be valid immediately
        assert!(middleware.validate_token(&token, None).await);

        // Once its expiry has passed it must be rejected
        expire_all_tokens(&middleware).await;
        assert!(!middleware.validate_token(&token, None).await);
    }

    #[tokio::test]
    async fn test_csrf_exempt_paths() {
        let middleware = create_test_middleware();

        // Exact match
        assert!(middleware.is_exempt_path("/api/webhook"));

        // Glob pattern match
        assert!(middleware.is_exempt_path("/public/assets/style.css"));
        assert!(middleware.is_exempt_path("/public/images/logo.png"));

        // Non-exempt paths
        assert!(!middleware.is_exempt_path("/api/users"));
        assert!(!middleware.is_exempt_path("/admin/dashboard"));
    }

    #[tokio::test]
    async fn test_csrf_builder_pattern() {
        let middleware = CsrfMiddleware::builder()
            .token_header("X-Custom-CSRF-Token")
            .cookie_name("_custom_csrf")
            .token_lifetime(7200)
            .secure_cookie(true)
            .exempt_path("/api/public")
            .exempt_paths(vec!["/webhook", "/status"])
            .build();

        assert_eq!(middleware.config.token_header, "X-Custom-CSRF-Token");
        assert_eq!(middleware.config.cookie_name, "_custom_csrf");
        assert_eq!(middleware.config.token_lifetime, 7200);
        assert!(middleware.config.secure_cookie);
        assert!(middleware.config.exempt_paths.contains("/api/public"));
        assert!(middleware.config.exempt_paths.contains("/webhook"));
        assert!(middleware.config.exempt_paths.contains("/status"));
    }

    #[tokio::test]
    async fn test_csrf_middleware_get_requests() {
        let middleware = create_test_middleware();

        let app = Router::new()
            .route("/test", get(test_handler))
            .layer(from_fn_with_state(middleware, csrf_middleware));

        let server = TestServer::new(app).unwrap();

        // GET requests should pass without CSRF token
        let response = server.get("/test").await;
        response.assert_status_ok();
        response.assert_text("OK");
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_without_token() {
        let middleware = create_test_middleware();

        let app = Router::new()
            .route("/test", post(test_handler))
            .layer(from_fn_with_state(middleware, csrf_middleware));

        let server = TestServer::new(app).unwrap();

        // POST without CSRF token should fail
        let response = server.post("/test").await;
        response.assert_status_forbidden();
    }

    #[tokio::test]
    async fn test_csrf_middleware_post_with_valid_token() {
        let middleware = create_test_middleware();
        let token = middleware.generate_token(Some("TestAgent")).await;

        let app = Router::new()
            .route("/test", post(test_handler))
            .layer(from_fn_with_state(middleware, csrf_middleware));

        let server = TestServer::new(app).unwrap();

        // POST with valid CSRF token should pass
        let response = server
            .post("/test")
            .add_header("X-CSRF-Token", &token)
            .add_header("User-Agent", "TestAgent")
            .await;

        response.assert_status_ok();
        response.assert_text("OK");
    }

    #[tokio::test]
    async fn test_csrf_middleware_exempt_paths() {
        let middleware = create_test_middleware();

        let app = Router::new()
            .route("/api/webhook", post(test_handler))
            .route("/public/upload", post(test_handler))
            .layer(from_fn_with_state(middleware, csrf_middleware));

        let server = TestServer::new(app).unwrap();

        // Exempt paths should pass without CSRF token
        let response1 = server.post("/api/webhook").await;
        response1.assert_status_ok();

        let response2 = server.post("/public/upload").await;
        response2.assert_status_ok();
    }

    #[tokio::test]
    async fn test_csrf_token_cleanup() {
        let config = CsrfConfig {
            token_lifetime: 1, // 1 second
            ..Default::default()
        };
        let middleware = CsrfMiddleware::new(config);

        // Generate several tokens
        for _ in 0..3 {
            middleware.generate_token(None).await;
        }
        assert_eq!(middleware.token_store.read().await.len(), 3);

        // Make them all expired without sleeping
        expire_all_tokens(&middleware).await;

        // Generating a new token triggers cleanup of the expired ones
        let new_token = middleware.generate_token(None).await;

        let store = middleware.token_store.read().await;
        assert_eq!(store.len(), 1); // Only the new token should remain
        assert!(store.contains_key(&new_token));
    }

    #[tokio::test]
    async fn test_csrf_cookie_extraction() {
        let middleware = create_test_middleware();
        let mut headers = HeaderMap::new();

        // Test cookie extraction
        headers.insert(
            header::COOKIE,
            HeaderValue::from_str("_csrf_token=test_token_123; other_cookie=value").unwrap(),
        );

        let token = middleware.extract_token(&headers);
        assert_eq!(token, Some("test_token_123".to_string()));

        // Test header extraction (should take precedence)
        headers.insert(
            "X-CSRF-Token",
            HeaderValue::from_str("header_token_456").unwrap(),
        );

        let token = middleware.extract_token(&headers);
        assert_eq!(token, Some("header_token_456".to_string()));
    }

    #[tokio::test]
    async fn test_csrf_user_agent_binding() {
        let middleware = create_test_middleware();

        let token = middleware.generate_token(Some("SpecificAgent")).await;

        // Same user agent should work
        assert!(middleware.validate_token(&token, Some("SpecificAgent")).await);

        // Different user agent should fail
        assert!(!middleware.validate_token(&token, Some("DifferentAgent")).await);

        // No user agent should fail when token was created with one
        assert!(!middleware.validate_token(&token, None).await);
    }
}