elif_security/
integration.rs1use elif_http::middleware::MiddlewarePipeline;
7use crate::{
8 middleware::{cors::CorsMiddleware, csrf::CsrfMiddleware},
9 config::{CorsConfig, CsrfConfig},
10};
11
12#[derive(Debug, Default)]
15pub struct SecurityMiddlewareBuilder {
16 cors_config: Option<CorsConfig>,
17 csrf_config: Option<CsrfConfig>,
18}
19
20impl SecurityMiddlewareBuilder {
21 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn with_cors(mut self, config: CorsConfig) -> Self {
28 self.cors_config = Some(config);
29 self
30 }
31
32 pub fn with_cors_permissive(mut self) -> Self {
34 self.cors_config = Some(CorsConfig::default());
35 self
36 }
37
38 pub fn with_csrf(mut self, config: CsrfConfig) -> Self {
40 self.csrf_config = Some(config);
41 self
42 }
43
44 pub fn with_csrf_default(mut self) -> Self {
46 self.csrf_config = Some(CsrfConfig::default());
47 self
48 }
49
50 pub fn build(self) -> MiddlewarePipeline {
56 let mut pipeline = MiddlewarePipeline::new();
57
58 if let Some(cors_config) = self.cors_config {
60 let cors_middleware = CorsMiddleware::new(cors_config);
61 pipeline = pipeline.add(cors_middleware);
62 }
63
64 if let Some(csrf_config) = self.csrf_config {
66 let csrf_middleware = CsrfMiddleware::new(csrf_config);
67 pipeline = pipeline.add(csrf_middleware);
68 }
69
70 pipeline
71 }
72}
73
74pub fn basic_security_pipeline() -> MiddlewarePipeline {
78 SecurityMiddlewareBuilder::new()
79 .with_cors_permissive()
80 .with_csrf_default()
81 .build()
82}
83
84pub fn strict_security_pipeline(allowed_origins: Vec<String>) -> MiddlewarePipeline {
86 use std::collections::HashSet;
87
88 let cors_config = CorsConfig {
89 allowed_origins: Some(allowed_origins.into_iter().collect::<HashSet<_>>()),
90 allow_credentials: true,
91 max_age: Some(300), ..CorsConfig::default()
93 };
94
95 let csrf_config = CsrfConfig {
96 secure_cookie: true,
97 token_lifetime: 3600, ..CsrfConfig::default()
99 };
100
101 SecurityMiddlewareBuilder::new()
102 .with_cors(cors_config)
103 .with_csrf(csrf_config)
104 .build()
105}
106
107pub fn development_security_pipeline() -> MiddlewarePipeline {
109 let cors_config = CorsConfig {
110 allowed_origins: None, allow_credentials: false,
112 ..CorsConfig::default()
113 };
114
115 let csrf_config = CsrfConfig {
116 secure_cookie: false, token_lifetime: 7200, ..CsrfConfig::default()
119 };
120
121 SecurityMiddlewareBuilder::new()
122 .with_cors(cors_config)
123 .with_csrf(csrf_config)
124 .build()
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{extract::Request, http::Method, body::Body};

    /// Builds a bodyless `GET /` request carrying the given `Origin` header.
    fn get_with_origin(origin: &str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/")
            .header("Origin", origin)
            .body(Body::empty())
            .expect("static method, URI and header values are valid")
    }

    #[tokio::test]
    async fn test_basic_security_pipeline() {
        let pipeline = basic_security_pipeline();

        assert_eq!(pipeline.len(), 2);
        // CORS must be registered ahead of CSRF.
        assert_eq!(pipeline.names(), vec!["CorsMiddleware", "CsrfMiddleware"]);
    }

    #[tokio::test]
    async fn test_security_middleware_builder() {
        let cors_config = CorsConfig::default();
        let csrf_config = CsrfConfig::default();

        let pipeline = SecurityMiddlewareBuilder::new()
            .with_cors(cors_config)
            .with_csrf(csrf_config)
            .build();

        assert_eq!(pipeline.len(), 2);
        assert_eq!(pipeline.names(), vec!["CorsMiddleware", "CsrfMiddleware"]);
    }

    #[tokio::test]
    async fn test_builder_order_is_independent_of_call_order() {
        // Configure CSRF first; `build` must still place CORS ahead of it.
        let pipeline = SecurityMiddlewareBuilder::new()
            .with_csrf_default()
            .with_cors_permissive()
            .build();

        assert_eq!(pipeline.names(), vec!["CorsMiddleware", "CsrfMiddleware"]);
    }

    #[tokio::test]
    async fn test_repeated_configuration_replaces_previous() {
        // Each `with_*` call overwrites the stored config, so no middleware is duplicated.
        let pipeline = SecurityMiddlewareBuilder::new()
            .with_cors_permissive()
            .with_cors(CorsConfig::default())
            .with_csrf_default()
            .with_csrf(CsrfConfig::default())
            .build();

        assert_eq!(pipeline.len(), 2);
        assert_eq!(pipeline.names(), vec!["CorsMiddleware", "CsrfMiddleware"]);
    }

    #[tokio::test]
    async fn test_empty_builder_adds_no_middleware() {
        let pipeline = SecurityMiddlewareBuilder::new().build();

        assert_eq!(pipeline.len(), 0);
        assert!(pipeline.names().is_empty());
    }

    #[tokio::test]
    async fn test_cors_only_pipeline() {
        let pipeline = SecurityMiddlewareBuilder::new()
            .with_cors_permissive()
            .build();

        assert_eq!(pipeline.len(), 1);
        assert_eq!(pipeline.names(), vec!["CorsMiddleware"]);
    }

    #[tokio::test]
    async fn test_csrf_only_pipeline() {
        let pipeline = SecurityMiddlewareBuilder::new()
            .with_csrf_default()
            .build();

        assert_eq!(pipeline.len(), 1);
        assert_eq!(pipeline.names(), vec!["CsrfMiddleware"]);
    }

    #[tokio::test]
    async fn test_security_pipeline_processing() {
        let pipeline = basic_security_pipeline();

        // A safe-method request with an Origin header should pass both middleware.
        let result = pipeline.process_request(get_with_origin("https://example.com")).await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_strict_security_pipeline() {
        let allowed_origins = vec!["https://trusted.com".to_string()];
        let pipeline = strict_security_pipeline(allowed_origins);

        assert_eq!(pipeline.len(), 2);

        // Whitelisted origin is accepted.
        let result = pipeline.process_request(get_with_origin("https://trusted.com")).await;
        assert!(result.is_ok());

        // Origin outside the whitelist is rejected.
        let result = pipeline.process_request(get_with_origin("https://evil.com")).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_development_security_pipeline() {
        let pipeline = development_security_pipeline();

        assert_eq!(pipeline.len(), 2);

        let result = pipeline.process_request(get_with_origin("http://localhost:3000")).await;
        assert!(result.is_ok());
    }
}