api_tools/server/axum/layers/security_headers.rs

//! Tower layer that sets standard security headers (CSP, HSTS, etc.) on responses.
2
3use axum::{
4    body::Body,
5    extract::Request,
6    http::{HeaderName, HeaderValue, header},
7    response::Response,
8};
9use futures::future::BoxFuture;
10use std::task::{Context, Poll};
11use tower::{Layer, Service};
12
13/// Configuration for security headers
14#[derive(Clone, Debug)]
15pub struct SecurityHeadersConfig {
16    pub content_security_policy: HeaderValue,
17    pub strict_transport_security: HeaderValue,
18    pub x_content_type_options: HeaderValue,
19    pub x_frame_options: HeaderValue,
20    pub x_xss_protection: HeaderValue,
21    pub referrer_policy: HeaderValue,
22    pub permissions_policy: HeaderValue,
23}
24
25impl Default for SecurityHeadersConfig {
26    fn default() -> Self {
27        SecurityHeadersConfig {
28            content_security_policy: HeaderValue::from_static("default-src 'self';"),
29            strict_transport_security: HeaderValue::from_static("max-age=31536000; includeSubDomains; preload"),
30            x_content_type_options: HeaderValue::from_static("nosniff"),
31            x_frame_options: HeaderValue::from_static("DENY"),
32            x_xss_protection: HeaderValue::from_static("1; mode=block"),
33            referrer_policy: HeaderValue::from_static("no-referrer"),
34            permissions_policy: HeaderValue::from_static("geolocation=(self), microphone=(), camera=()"),
35        }
36    }
37}
38
39#[derive(Clone)]
40pub struct SecurityHeadersLayer {
41    pub config: SecurityHeadersConfig,
42}
43
44impl SecurityHeadersLayer {
45    /// Create a new `SecurityLayer`
46    pub fn new(config: SecurityHeadersConfig) -> Self {
47        Self { config }
48    }
49}
50
51impl<S> Layer<S> for SecurityHeadersLayer {
52    type Service = SecurityHeadersMiddleware<S>;
53
54    fn layer(&self, inner: S) -> Self::Service {
55        SecurityHeadersMiddleware {
56            inner,
57            config: self.config.clone(),
58        }
59    }
60}
61
62#[derive(Clone)]
63pub struct SecurityHeadersMiddleware<S> {
64    inner: S,
65    config: SecurityHeadersConfig,
66}
67
68impl<S> Service<Request<Body>> for SecurityHeadersMiddleware<S>
69where
70    S: Service<Request<Body>, Response = Response> + Send + Clone + 'static,
71    S::Future: Send + 'static,
72{
73    type Response = S::Response;
74    type Error = S::Error;
75    // `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
76    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
77
78    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
79        self.inner.poll_ready(cx)
80    }
81
82    fn call(&mut self, request: Request<Body>) -> Self::Future {
83        let config = self.config.clone();
84        let future = self.inner.call(request);
85
86        Box::pin(async move {
87            let mut response: Response = future.await?;
88
89            let headers = response.headers_mut();
90            headers.insert(header::CONTENT_SECURITY_POLICY, config.content_security_policy);
91            headers.insert(header::STRICT_TRANSPORT_SECURITY, config.strict_transport_security);
92            headers.insert(header::X_CONTENT_TYPE_OPTIONS, config.x_content_type_options);
93            headers.insert(header::X_FRAME_OPTIONS, config.x_frame_options);
94            headers.insert(header::X_XSS_PROTECTION, config.x_xss_protection);
95            headers.insert(header::REFERRER_POLICY, config.referrer_policy);
96            headers.insert(HeaderName::from_static("permissions-policy"), config.permissions_policy);
97
98            Ok(response)
99        })
100    }
101}