1use axum::{
25 body::Body,
26 http::{header, Request, Response},
27 middleware::Next,
28 response::IntoResponse,
29};
30use std::fmt;
31
32#[derive(Debug, Clone)]
39pub struct SecurityHeadersConfig {
40 pub frame_options: Option<FrameOptions>,
45
46 pub content_type_options: bool,
50
51 pub xss_protection: Option<bool>,
56
57 pub hsts: Option<HstsConfig>,
61
62 pub csp: Option<String>,
66
67 pub referrer_policy: Option<ReferrerPolicy>,
71}
72
/// Allowed values for the `X-Frame-Options` header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameOptions {
    /// Rendered as `DENY`.
    Deny,
    /// Rendered as `SAMEORIGIN`.
    SameOrigin,
}

impl fmt::Display for FrameOptions {
    /// Writes the header token for this variant (`DENY` or `SAMEORIGIN`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let token = match *self {
            Self::Deny => "DENY",
            Self::SameOrigin => "SAMEORIGIN",
        };
        f.write_str(token)
    }
}
90
/// Directives rendered into the `Strict-Transport-Security` header value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HstsConfig {
    /// Number written after `max-age=`.
    pub max_age: u32,
    /// Appends `; includeSubDomains` when `true`.
    pub include_subdomains: bool,
    /// Appends `; preload` when `true`.
    pub preload: bool,
}

impl HstsConfig {
    /// `max-age` used by both presets: 365 days in seconds (31 536 000).
    const ONE_YEAR_SECS: u32 = 365 * 24 * 60 * 60;

    /// Preset with a one-year `max-age` and both `includeSubDomains` and
    /// `preload` enabled.
    #[must_use]
    pub const fn strict() -> Self {
        Self {
            max_age: Self::ONE_YEAR_SECS,
            include_subdomains: true,
            preload: true,
        }
    }

    /// Preset with a one-year `max-age` and neither `includeSubDomains` nor
    /// `preload`.
    #[must_use]
    pub const fn moderate() -> Self {
        Self {
            max_age: Self::ONE_YEAR_SECS,
            include_subdomains: false,
            preload: false,
        }
    }
}

impl fmt::Display for HstsConfig {
    /// Renders `max-age=<n>` followed by the optional `; includeSubDomains`
    /// and `; preload` directives, in that order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let subdomains = if self.include_subdomains {
            "; includeSubDomains"
        } else {
            ""
        };
        let preload = if self.preload { "; preload" } else { "" };
        write!(f, "max-age={}{}{}", self.max_age, subdomains, preload)
    }
}
136
/// Referrer policy tokens; each variant renders as its kebab-case directive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReferrerPolicy {
    /// Rendered as `no-referrer`.
    NoReferrer,
    /// Rendered as `no-referrer-when-downgrade`.
    NoReferrerWhenDowngrade,
    /// Rendered as `origin`.
    Origin,
    /// Rendered as `origin-when-cross-origin`.
    OriginWhenCrossOrigin,
    /// Rendered as `same-origin`.
    SameOrigin,
    /// Rendered as `strict-origin`.
    StrictOrigin,
    /// Rendered as `strict-origin-when-cross-origin`.
    StrictOriginWhenCrossOrigin,
    /// Rendered as `unsafe-url`.
    UnsafeUrl,
}

impl fmt::Display for ReferrerPolicy {
    /// Writes the directive token for this variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let directive = match *self {
            Self::NoReferrer => "no-referrer",
            Self::NoReferrerWhenDowngrade => "no-referrer-when-downgrade",
            Self::Origin => "origin",
            Self::OriginWhenCrossOrigin => "origin-when-cross-origin",
            Self::SameOrigin => "same-origin",
            Self::StrictOrigin => "strict-origin",
            Self::StrictOriginWhenCrossOrigin => "strict-origin-when-cross-origin",
            Self::UnsafeUrl => "unsafe-url",
        };
        f.write_str(directive)
    }
}
172
173impl SecurityHeadersConfig {
174 #[must_use]
183 pub fn strict() -> Self {
184 Self {
185 frame_options: Some(FrameOptions::Deny),
186 content_type_options: true,
187 xss_protection: Some(true),
188 hsts: Some(HstsConfig::strict()),
189 csp: Some("default-src 'self'".to_string()),
190 referrer_policy: Some(ReferrerPolicy::StrictOriginWhenCrossOrigin),
191 }
192 }
193
194 #[must_use]
203 pub fn development() -> Self {
204 Self {
205 frame_options: Some(FrameOptions::SameOrigin),
206 content_type_options: true,
207 xss_protection: None, hsts: None, csp: Some(
210 "default-src 'self' 'unsafe-inline' 'unsafe-eval'; img-src 'self' data:"
211 .to_string(),
212 ),
213 referrer_policy: Some(ReferrerPolicy::StrictOriginWhenCrossOrigin),
214 }
215 }
216
217 #[must_use]
221 pub const fn custom() -> Self {
222 Self {
223 frame_options: None,
224 content_type_options: false,
225 xss_protection: None,
226 hsts: None,
227 csp: None,
228 referrer_policy: None,
229 }
230 }
231
232 #[must_use]
234 pub const fn with_frame_options(mut self, options: FrameOptions) -> Self {
235 self.frame_options = Some(options);
236 self
237 }
238
239 #[must_use]
241 pub const fn with_content_type_options(mut self) -> Self {
242 self.content_type_options = true;
243 self
244 }
245
246 #[must_use]
248 pub const fn with_xss_protection(mut self, block_mode: bool) -> Self {
249 self.xss_protection = Some(block_mode);
250 self
251 }
252
253 #[must_use]
255 pub const fn with_hsts(mut self, config: HstsConfig) -> Self {
256 self.hsts = Some(config);
257 self
258 }
259
260 #[must_use]
262 pub fn with_csp(mut self, policy: String) -> Self {
263 self.csp = Some(policy);
264 self
265 }
266
267 #[must_use]
269 pub const fn with_referrer_policy(mut self, policy: ReferrerPolicy) -> Self {
270 self.referrer_policy = Some(policy);
271 self
272 }
273}
274
275#[derive(Clone)]
292pub struct SecurityHeadersLayer {
293 config: SecurityHeadersConfig,
294}
295
296impl SecurityHeadersLayer {
297 #[must_use]
299 pub const fn new(config: SecurityHeadersConfig) -> Self {
300 Self { config }
301 }
302}
303
304impl<S> tower::Layer<S> for SecurityHeadersLayer {
305 type Service = SecurityHeadersMiddleware<S>;
306
307 fn layer(&self, inner: S) -> Self::Service {
308 SecurityHeadersMiddleware {
309 inner,
310 config: self.config.clone(),
311 }
312 }
313}
314
315#[derive(Clone)]
317pub struct SecurityHeadersMiddleware<S> {
318 inner: S,
319 config: SecurityHeadersConfig,
320}
321
322impl<S> tower::Service<Request<Body>> for SecurityHeadersMiddleware<S>
323where
324 S: tower::Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
325 S::Future: Send + 'static,
326{
327 type Response = S::Response;
328 type Error = S::Error;
329 type Future = std::pin::Pin<
330 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
331 >;
332
333 fn poll_ready(
334 &mut self,
335 cx: &mut std::task::Context<'_>,
336 ) -> std::task::Poll<Result<(), Self::Error>> {
337 self.inner.poll_ready(cx)
338 }
339
340 fn call(&mut self, request: Request<Body>) -> Self::Future {
341 let config = self.config.clone();
342 let future = self.inner.call(request);
343
344 Box::pin(async move {
345 let mut response = future.await?;
346 add_security_headers(&mut response, &config);
347 Ok(response)
348 })
349 }
350}
351
352fn add_security_headers(response: &mut Response<Body>, config: &SecurityHeadersConfig) {
354 let headers = response.headers_mut();
355
356 if let Some(frame_options) = &config.frame_options {
358 headers.insert("x-frame-options", frame_options.to_string().parse().unwrap());
359 }
360
361 if config.content_type_options {
363 headers.insert(
364 "x-content-type-options",
365 "nosniff".parse().unwrap(),
366 );
367 }
368
369 if let Some(block_mode) = config.xss_protection {
371 let value = if block_mode {
372 "1; mode=block"
373 } else {
374 "1"
375 };
376 headers.insert("x-xss-protection", value.parse().unwrap());
377 }
378
379 if let Some(hsts) = &config.hsts {
381 headers.insert(
382 header::STRICT_TRANSPORT_SECURITY,
383 hsts.to_string().parse().unwrap(),
384 );
385 }
386
387 if let Some(csp) = &config.csp {
389 headers.insert(
390 header::CONTENT_SECURITY_POLICY,
391 csp.parse().unwrap(),
392 );
393 }
394
395 if let Some(referrer_policy) = &config.referrer_policy {
397 headers.insert(
398 header::REFERER,
399 referrer_policy.to_string().parse().unwrap(),
400 );
401 }
402}
403
404pub async fn security_headers(
423 request: Request<Body>,
424 next: Next,
425 config: SecurityHeadersConfig,
426) -> impl IntoResponse {
427 let mut response = next.run(request).await;
428 add_security_headers(&mut response, &config);
429 response
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        response::IntoResponse,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    /// Minimal handler; the tests only inspect the headers added around it.
    async fn test_handler() -> impl IntoResponse {
        (StatusCode::OK, "Hello, World!")
    }

    /// Sends `GET /` through a router wrapped in the layer built from `config`
    /// and returns the response.
    async fn respond_with(config: SecurityHeadersConfig) -> axum::response::Response {
        let app = Router::new()
            .route("/", get(test_handler))
            .layer(SecurityHeadersLayer::new(config));

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();

        app.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn test_strict_config_headers() {
        let response = respond_with(SecurityHeadersConfig::strict()).await;
        let headers = response.headers();

        assert_eq!(headers.get("x-frame-options").unwrap(), "DENY");
        assert_eq!(headers.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(headers.get("x-xss-protection").unwrap(), "1; mode=block");
        assert!(headers.contains_key("strict-transport-security"));
        assert!(headers.contains_key("content-security-policy"));
    }

    #[tokio::test]
    async fn test_development_config_headers() {
        let response = respond_with(SecurityHeadersConfig::development()).await;
        let headers = response.headers();

        assert_eq!(headers.get("x-frame-options").unwrap(), "SAMEORIGIN");
        assert_eq!(headers.get("x-content-type-options").unwrap(), "nosniff");
        assert!(!headers.contains_key("x-xss-protection"));
        assert!(!headers.contains_key("strict-transport-security"));
        assert!(headers.contains_key("content-security-policy"));
    }

    #[tokio::test]
    async fn test_custom_config() {
        let config = SecurityHeadersConfig::custom()
            .with_frame_options(FrameOptions::SameOrigin)
            .with_content_type_options()
            .with_referrer_policy(ReferrerPolicy::NoReferrer);

        let response = respond_with(config).await;
        let headers = response.headers();

        assert_eq!(headers.get("x-frame-options").unwrap(), "SAMEORIGIN");
        assert_eq!(headers.get("x-content-type-options").unwrap(), "nosniff");
        assert!(!headers.contains_key("x-xss-protection"));
        assert!(!headers.contains_key("strict-transport-security"));
        assert!(!headers.contains_key("content-security-policy"));
    }

    #[test]
    fn test_hsts_config_display() {
        let strict = HstsConfig::strict();
        assert_eq!(
            strict.to_string(),
            "max-age=31536000; includeSubDomains; preload"
        );

        let moderate = HstsConfig::moderate();
        assert_eq!(moderate.to_string(), "max-age=31536000");
    }

    #[test]
    fn test_frame_options_display() {
        assert_eq!(FrameOptions::Deny.to_string(), "DENY");
        assert_eq!(FrameOptions::SameOrigin.to_string(), "SAMEORIGIN");
    }

    #[test]
    fn test_referrer_policy_display() {
        assert_eq!(ReferrerPolicy::NoReferrer.to_string(), "no-referrer");
        assert_eq!(
            ReferrerPolicy::StrictOriginWhenCrossOrigin.to_string(),
            "strict-origin-when-cross-origin"
        );
    }

    #[test]
    fn test_config_builder() {
        let config = SecurityHeadersConfig::custom()
            .with_frame_options(FrameOptions::Deny)
            .with_content_type_options()
            .with_xss_protection(true)
            .with_hsts(HstsConfig::strict())
            .with_csp("default-src 'self'".to_string())
            .with_referrer_policy(ReferrerPolicy::StrictOriginWhenCrossOrigin);

        assert_eq!(config.frame_options, Some(FrameOptions::Deny));
        assert!(config.content_type_options);
        assert_eq!(config.xss_protection, Some(true));
        assert!(config.hsts.is_some());
        assert!(config.csp.is_some());
        assert_eq!(
            config.referrer_policy,
            Some(ReferrerPolicy::StrictOriginWhenCrossOrigin)
        );
    }
}