Skip to main content

elif_security/
integration.rs

//! Security Middleware Integration
//!
//! Provides a unified way to integrate the crate's CORS and CSRF security middleware
//! with the framework's `MiddlewarePipeline`, ensuring consistent usage and proper ordering.
5
6use elif_http::middleware::MiddlewarePipeline;
7use crate::{
8    middleware::{cors::CorsMiddleware, csrf::CsrfMiddleware},
9    config::{CorsConfig, CsrfConfig},
10};
11
12/// Security middleware suite builder that helps configure and integrate
13/// all security middleware with the framework's MiddlewarePipeline
14#[derive(Debug, Default)]
15pub struct SecurityMiddlewareBuilder {
16    cors_config: Option<CorsConfig>,
17    csrf_config: Option<CsrfConfig>,
18}
19
20impl SecurityMiddlewareBuilder {
21    /// Create a new security middleware builder
22    pub fn new() -> Self {
23        Self::default()
24    }
25    
26    /// Add CORS middleware with configuration
27    pub fn with_cors(mut self, config: CorsConfig) -> Self {
28        self.cors_config = Some(config);
29        self
30    }
31    
32    /// Add CORS middleware with permissive settings (not recommended for production)
33    pub fn with_cors_permissive(mut self) -> Self {
34        self.cors_config = Some(CorsConfig::default());
35        self
36    }
37    
38    /// Add CSRF middleware with configuration
39    pub fn with_csrf(mut self, config: CsrfConfig) -> Self {
40        self.csrf_config = Some(config);
41        self
42    }
43    
44    /// Add CSRF middleware with default configuration
45    pub fn with_csrf_default(mut self) -> Self {
46        self.csrf_config = Some(CsrfConfig::default());
47        self
48    }
49    
50    /// Build the security middleware pipeline
51    /// 
52    /// The middleware are added in the following order for optimal security:
53    /// 1. CORS middleware (handles preflight requests early)
54    /// 2. CSRF middleware (validates tokens after CORS)
55    pub fn build(self) -> MiddlewarePipeline {
56        let mut pipeline = MiddlewarePipeline::new();
57        
58        // Add CORS middleware first (handles preflight requests)
59        if let Some(cors_config) = self.cors_config {
60            let cors_middleware = CorsMiddleware::new(cors_config);
61            pipeline = pipeline.add(cors_middleware);
62        }
63        
64        // Add CSRF middleware second (validates after CORS)
65        if let Some(csrf_config) = self.csrf_config {
66            let csrf_middleware = CsrfMiddleware::new(csrf_config);
67            pipeline = pipeline.add(csrf_middleware);
68        }
69        
70        pipeline
71    }
72}
73
74/// Quick setup functions for common security configurations
75
76/// Create a basic security pipeline with permissive CORS and default CSRF
77pub fn basic_security_pipeline() -> MiddlewarePipeline {
78    SecurityMiddlewareBuilder::new()
79        .with_cors_permissive()
80        .with_csrf_default()
81        .build()
82}
83
84/// Create a strict security pipeline with restrictive CORS and secure CSRF
85pub fn strict_security_pipeline(allowed_origins: Vec<String>) -> MiddlewarePipeline {
86    use std::collections::HashSet;
87    
88    let cors_config = CorsConfig {
89        allowed_origins: Some(allowed_origins.into_iter().collect::<HashSet<_>>()),
90        allow_credentials: true,
91        max_age: Some(300), // 5 minutes
92        ..CorsConfig::default()
93    };
94    
95    let csrf_config = CsrfConfig {
96        secure_cookie: true,
97        token_lifetime: 3600, // 1 hour
98        ..CsrfConfig::default()
99    };
100    
101    SecurityMiddlewareBuilder::new()
102        .with_cors(cors_config)
103        .with_csrf(csrf_config)
104        .build()
105}
106
107/// Create a development security pipeline with relaxed settings
108pub fn development_security_pipeline() -> MiddlewarePipeline {
109    let cors_config = CorsConfig {
110        allowed_origins: None, // Allow all origins in development
111        allow_credentials: false,
112        ..CorsConfig::default()
113    };
114    
115    let csrf_config = CsrfConfig {
116        secure_cookie: false, // Allow non-HTTPS in development
117        token_lifetime: 7200, // 2 hours for convenience
118        ..CsrfConfig::default()
119    };
120    
121    SecurityMiddlewareBuilder::new()
122        .with_cors(cors_config)
123        .with_csrf(csrf_config)
124        .build()
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, extract::Request, http::Method};

    /// Builds a body-less `GET /` request carrying the given `Origin` header.
    fn get_with_origin(origin: &str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/")
            .header("Origin", origin)
            .body(Body::empty())
            .expect("static method, URI and header values form a valid request")
    }

    #[tokio::test]
    async fn test_basic_security_pipeline() {
        let pipeline = basic_security_pipeline();

        // Should have both CORS and CSRF middleware, CORS first.
        assert_eq!(pipeline.len(), 2);
        assert_eq!(pipeline.names(), vec!["CorsMiddleware", "CsrfMiddleware"]);
    }

    #[tokio::test]
    async fn test_security_middleware_builder() {
        let pipeline = SecurityMiddlewareBuilder::new()
            .with_cors(CorsConfig::default())
            .with_csrf(CsrfConfig::default())
            .build();

        // `build` promises a fixed order (CORS before CSRF), so check the order
        // rather than mere membership.
        assert_eq!(pipeline.len(), 2);
        assert_eq!(pipeline.names(), vec!["CorsMiddleware", "CsrfMiddleware"]);
    }

    #[tokio::test]
    async fn test_order_independent_of_builder_call_order() {
        // Configure CSRF before CORS; the built pipeline must still run CORS first.
        let pipeline = SecurityMiddlewareBuilder::new()
            .with_csrf_default()
            .with_cors_permissive()
            .build();

        assert_eq!(pipeline.names(), vec!["CorsMiddleware", "CsrfMiddleware"]);
    }

    #[tokio::test]
    async fn test_empty_builder_yields_empty_pipeline() {
        let pipeline = SecurityMiddlewareBuilder::new().build();

        assert_eq!(pipeline.len(), 0);
    }

    #[tokio::test]
    async fn test_cors_only_pipeline() {
        let pipeline = SecurityMiddlewareBuilder::new()
            .with_cors_permissive()
            .build();

        assert_eq!(pipeline.len(), 1);
        assert_eq!(pipeline.names(), vec!["CorsMiddleware"]);
    }

    #[tokio::test]
    async fn test_csrf_only_pipeline() {
        let pipeline = SecurityMiddlewareBuilder::new()
            .with_csrf_default()
            .build();

        assert_eq!(pipeline.len(), 1);
        assert_eq!(pipeline.names(), vec!["CsrfMiddleware"]);
    }

    #[tokio::test]
    async fn test_security_pipeline_processing() {
        let pipeline = basic_security_pipeline();

        // A normal GET request should pass CORS and be exempt from CSRF.
        let result = pipeline
            .process_request(get_with_origin("https://example.com"))
            .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_strict_security_pipeline() {
        let allowed_origins = vec!["https://trusted.com".to_string()];
        let pipeline = strict_security_pipeline(allowed_origins);

        assert_eq!(pipeline.len(), 2);

        // Request from the allowed origin passes.
        let result = pipeline
            .process_request(get_with_origin("https://trusted.com"))
            .await;
        assert!(result.is_ok());

        // Request from a disallowed origin is rejected.
        let result = pipeline
            .process_request(get_with_origin("https://evil.com"))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_development_security_pipeline() {
        let pipeline = development_security_pipeline();

        assert_eq!(pipeline.len(), 2);

        // Any origin should be allowed in development mode.
        let result = pipeline
            .process_request(get_with_origin("http://localhost:3000"))
            .await;
        assert!(result.is_ok());
    }
}