1pub mod content_security_policy;
109pub mod content_type_options;
110pub mod cors;
111pub mod dns_prefetch_control;
112pub mod download_options;
113pub mod expect_ct;
114pub mod frame_guard;
115pub mod hsts;
116pub mod permitted_cross_domain_policies;
117pub mod powered_by;
118pub mod referrer_policy;
119pub mod request_signing;
120pub mod xss_filter;
121
122use armature_core::HttpResponse;
123use std::collections::HashMap;
124
/// Response middleware that writes a configurable set of security-related HTTP headers.
///
/// Fields wrapped in `Option` correspond to headers that are only emitted when configured;
/// the remaining header fields are always emitted by `apply`. `SecurityMiddleware::new()`
/// leaves the optional headers off, while `SecurityMiddleware::default()` enables them all.
#[derive(Debug, Clone)]
pub struct SecurityMiddleware {
    /// `Content-Security-Policy` configuration; the header is omitted when `None`.
    pub csp: Option<content_security_policy::CspConfig>,

    /// Value written to the `X-DNS-Prefetch-Control` header.
    pub dns_prefetch_control: dns_prefetch_control::DnsPrefetchControl,

    /// `Expect-CT` configuration; the header is omitted when `None`.
    pub expect_ct: Option<expect_ct::ExpectCtConfig>,

    /// Value written to the `X-Frame-Options` header.
    pub frame_guard: frame_guard::FrameGuard,

    /// `Strict-Transport-Security` configuration; the header is omitted when `None`.
    pub hsts: Option<hsts::HstsConfig>,

    /// When `true`, `apply` removes the `X-Powered-By` header from the response.
    pub hide_powered_by: bool,

    /// Value written to the `Referrer-Policy` header.
    pub referrer_policy: referrer_policy::ReferrerPolicy,

    /// Value written to the `X-XSS-Protection` header.
    pub xss_filter: xss_filter::XssFilter,

    /// Value written to the `X-Content-Type-Options` header.
    pub content_type_options: content_type_options::ContentTypeOptions,

    /// Value written to the `X-Download-Options` header.
    pub download_options: download_options::DownloadOptions,

    /// Value written to the `X-Permitted-Cross-Domain-Policies` header.
    pub permitted_cross_domain_policies:
        permitted_cross_domain_policies::PermittedCrossDomainPolicies,
}
162
163impl SecurityMiddleware {
164 pub fn new() -> Self {
166 Self {
167 csp: None,
168 dns_prefetch_control: dns_prefetch_control::DnsPrefetchControl::Off,
169 expect_ct: None,
170 frame_guard: frame_guard::FrameGuard::Deny,
171 hsts: None,
172 hide_powered_by: false,
173 referrer_policy: referrer_policy::ReferrerPolicy::NoReferrer,
174 xss_filter: xss_filter::XssFilter::Enabled,
175 content_type_options: content_type_options::ContentTypeOptions::NoSniff,
176 download_options: download_options::DownloadOptions::NoOpen,
177 permitted_cross_domain_policies:
178 permitted_cross_domain_policies::PermittedCrossDomainPolicies::None,
179 }
180 }
181
182 pub fn with_csp(mut self, config: content_security_policy::CspConfig) -> Self {
184 self.csp = Some(config);
185 self
186 }
187
188 pub fn with_dns_prefetch_control(
190 mut self,
191 control: dns_prefetch_control::DnsPrefetchControl,
192 ) -> Self {
193 self.dns_prefetch_control = control;
194 self
195 }
196
197 pub fn with_expect_ct(mut self, config: expect_ct::ExpectCtConfig) -> Self {
199 self.expect_ct = Some(config);
200 self
201 }
202
203 pub fn with_frame_guard(mut self, guard: frame_guard::FrameGuard) -> Self {
205 self.frame_guard = guard;
206 self
207 }
208
209 pub fn with_hsts(mut self, config: hsts::HstsConfig) -> Self {
211 self.hsts = Some(config);
212 self
213 }
214
215 pub fn hide_powered_by(mut self, hide: bool) -> Self {
217 self.hide_powered_by = hide;
218 self
219 }
220
221 pub fn with_referrer_policy(mut self, policy: referrer_policy::ReferrerPolicy) -> Self {
223 self.referrer_policy = policy;
224 self
225 }
226
227 pub fn with_xss_filter(mut self, filter: xss_filter::XssFilter) -> Self {
229 self.xss_filter = filter;
230 self
231 }
232
233 pub fn apply(&self, mut response: HttpResponse) -> HttpResponse {
235 let mut headers = HashMap::new();
236
237 if let Some(ref csp) = self.csp {
239 headers.insert("Content-Security-Policy".to_string(), csp.to_header_value());
240 }
241
242 headers.insert(
244 "X-DNS-Prefetch-Control".to_string(),
245 self.dns_prefetch_control.to_header_value(),
246 );
247
248 if let Some(ref expect_ct) = self.expect_ct {
250 headers.insert("Expect-CT".to_string(), expect_ct.to_header_value());
251 }
252
253 headers.insert(
255 "X-Frame-Options".to_string(),
256 self.frame_guard.to_header_value(),
257 );
258
259 if let Some(ref hsts) = self.hsts {
261 headers.insert(
262 "Strict-Transport-Security".to_string(),
263 hsts.to_header_value(),
264 );
265 }
266
267 headers.insert(
269 "Referrer-Policy".to_string(),
270 self.referrer_policy.to_header_value(),
271 );
272
273 headers.insert(
275 "X-XSS-Protection".to_string(),
276 self.xss_filter.to_header_value(),
277 );
278
279 headers.insert(
281 "X-Content-Type-Options".to_string(),
282 self.content_type_options.to_header_value(),
283 );
284
285 headers.insert(
287 "X-Download-Options".to_string(),
288 self.download_options.to_header_value(),
289 );
290
291 headers.insert(
293 "X-Permitted-Cross-Domain-Policies".to_string(),
294 self.permitted_cross_domain_policies.to_header_value(),
295 );
296
297 for (key, value) in headers {
299 response.headers.insert(key, value);
300 }
301
302 if self.hide_powered_by {
304 response.headers.remove("X-Powered-By");
305 }
306
307 response
308 }
309
310 pub fn enable_all(max_age_seconds: u64) -> Self {
312 Self {
313 csp: Some(content_security_policy::CspConfig::default()),
314 dns_prefetch_control: dns_prefetch_control::DnsPrefetchControl::Off,
315 expect_ct: Some(expect_ct::ExpectCtConfig::new(max_age_seconds)),
316 frame_guard: frame_guard::FrameGuard::Deny,
317 hsts: Some(hsts::HstsConfig::new(max_age_seconds)),
318 hide_powered_by: true,
319 referrer_policy: referrer_policy::ReferrerPolicy::NoReferrer,
320 xss_filter: xss_filter::XssFilter::Enabled,
321 content_type_options: content_type_options::ContentTypeOptions::NoSniff,
322 download_options: download_options::DownloadOptions::NoOpen,
323 permitted_cross_domain_policies:
324 permitted_cross_domain_policies::PermittedCrossDomainPolicies::None,
325 }
326 }
327}
328
329impl Default for SecurityMiddleware {
330 fn default() -> Self {
331 Self::enable_all(31536000) }
333}
334
335pub mod prelude {
341 pub use crate::SecurityMiddleware;
342 pub use crate::content_security_policy::CspConfig;
343 pub use crate::cors::CorsConfig;
344 pub use crate::frame_guard::FrameGuard;
345 pub use crate::hsts::HstsConfig;
346 pub use crate::referrer_policy::ReferrerPolicy;
347 pub use crate::request_signing::{RequestSigner, RequestSigningMiddleware, RequestVerifier};
348}
349
350#[async_trait::async_trait]
353impl armature_core::Middleware for SecurityMiddleware {
354 async fn handle(
355 &self,
356 req: armature_core::HttpRequest,
357 next: Box<
358 dyn FnOnce(
359 armature_core::HttpRequest,
360 ) -> std::pin::Pin<
361 Box<
362 dyn std::future::Future<
363 Output = Result<armature_core::HttpResponse, armature_core::Error>,
364 > + Send,
365 >,
366 > + Send,
367 >,
368 ) -> Result<armature_core::HttpResponse, armature_core::Error> {
369 let response = next(req).await?;
371 Ok(self.apply(response))
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_security_middleware_new() {
        let middleware = SecurityMiddleware::new();
        assert!(middleware.csp.is_none());
        assert!(!middleware.hide_powered_by);
    }

    #[test]
    fn test_security_middleware_default() {
        let middleware = SecurityMiddleware::default();
        assert!(middleware.csp.is_some());
        assert!(middleware.hsts.is_some());
        assert!(middleware.hide_powered_by);
    }

    #[test]
    fn test_security_middleware_apply() {
        let middleware = SecurityMiddleware::default();
        let response = HttpResponse::ok();
        let secured = middleware.apply(response);

        assert!(secured.headers.contains_key("X-Frame-Options"));
        assert!(secured.headers.contains_key("X-Content-Type-Options"));
        assert!(secured.headers.contains_key("X-XSS-Protection"));
        assert!(secured.headers.contains_key("Strict-Transport-Security"));
        assert!(secured.headers.contains_key("Content-Security-Policy"));
    }

    /// `new()` leaves CSP, HSTS and Expect-CT unset, so `apply` must not emit them,
    /// while the always-on headers are still written.
    #[test]
    fn test_new_omits_optional_headers() {
        let secured = SecurityMiddleware::new().apply(HttpResponse::ok());

        assert!(!secured.headers.contains_key("Content-Security-Policy"));
        assert!(!secured.headers.contains_key("Strict-Transport-Security"));
        assert!(!secured.headers.contains_key("Expect-CT"));
        assert!(secured.headers.contains_key("X-Frame-Options"));
        assert!(secured.headers.contains_key("Referrer-Policy"));
    }

    #[test]
    fn test_hide_powered_by() {
        let middleware = SecurityMiddleware::new().hide_powered_by(true);
        let mut response = HttpResponse::ok();
        response
            .headers
            .insert("X-Powered-By".to_string(), "Armature".to_string());

        let secured = middleware.apply(response);
        assert!(!secured.headers.contains_key("X-Powered-By"));
    }

    /// Without `hide_powered_by(true)`, an existing `X-Powered-By` header is preserved.
    #[test]
    fn test_powered_by_kept_when_not_hidden() {
        let mut response = HttpResponse::ok();
        response
            .headers
            .insert("X-Powered-By".to_string(), "Armature".to_string());

        let secured = SecurityMiddleware::new().apply(response);
        assert_eq!(
            secured.headers.get("X-Powered-By"),
            Some(&"Armature".to_string())
        );
    }

    #[test]
    fn test_custom_configuration() {
        let middleware = SecurityMiddleware::new()
            .with_frame_guard(frame_guard::FrameGuard::SameOrigin)
            .with_referrer_policy(referrer_policy::ReferrerPolicy::StrictOriginWhenCrossOrigin)
            .hide_powered_by(true);

        let response = HttpResponse::ok();
        let secured = middleware.apply(response);

        assert_eq!(
            secured.headers.get("X-Frame-Options"),
            Some(&"SAMEORIGIN".to_string())
        );
        assert_eq!(
            secured.headers.get("Referrer-Policy"),
            Some(&"strict-origin-when-cross-origin".to_string())
        );
    }
}