api_tools/server/axum/layers/
security_headers.rs1use axum::{
4 body::Body,
5 extract::Request,
6 http::{HeaderName, HeaderValue, header},
7 response::Response,
8};
9use futures::future::BoxFuture;
10use std::task::{Context, Poll};
11use tower::{Layer, Service};
12
13#[derive(Clone, Debug)]
15pub struct SecurityHeadersConfig {
16 pub content_security_policy: HeaderValue,
17 pub strict_transport_security: HeaderValue,
18 pub x_content_type_options: HeaderValue,
19 pub x_frame_options: HeaderValue,
20 pub x_xss_protection: HeaderValue,
21 pub referrer_policy: HeaderValue,
22 pub permissions_policy: HeaderValue,
23}
24
25impl Default for SecurityHeadersConfig {
26 fn default() -> Self {
27 SecurityHeadersConfig {
28 content_security_policy: HeaderValue::from_static("default-src 'self';"),
29 strict_transport_security: HeaderValue::from_static("max-age=31536000; includeSubDomains; preload"),
30 x_content_type_options: HeaderValue::from_static("nosniff"),
31 x_frame_options: HeaderValue::from_static("DENY"),
32 x_xss_protection: HeaderValue::from_static("1; mode=block"),
33 referrer_policy: HeaderValue::from_static("no-referrer"),
34 permissions_policy: HeaderValue::from_static("geolocation=(self), microphone=(), camera=()"),
35 }
36 }
37}
38
39#[derive(Clone)]
40pub struct SecurityHeadersLayer {
41 pub config: SecurityHeadersConfig,
42}
43
44impl SecurityHeadersLayer {
45 pub fn new(config: SecurityHeadersConfig) -> Self {
47 Self { config }
48 }
49}
50
51impl<S> Layer<S> for SecurityHeadersLayer {
52 type Service = SecurityHeadersMiddleware<S>;
53
54 fn layer(&self, inner: S) -> Self::Service {
55 SecurityHeadersMiddleware {
56 inner,
57 config: self.config.clone(),
58 }
59 }
60}
61
62#[derive(Clone)]
63pub struct SecurityHeadersMiddleware<S> {
64 inner: S,
65 config: SecurityHeadersConfig,
66}
67
68impl<S> Service<Request<Body>> for SecurityHeadersMiddleware<S>
69where
70 S: Service<Request<Body>, Response = Response> + Send + Clone + 'static,
71 S::Future: Send + 'static,
72{
73 type Response = S::Response;
74 type Error = S::Error;
75 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
77
78 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
79 self.inner.poll_ready(cx)
80 }
81
82 fn call(&mut self, request: Request<Body>) -> Self::Future {
83 let config = self.config.clone();
84 let future = self.inner.call(request);
85
86 Box::pin(async move {
87 let mut response: Response = future.await?;
88
89 let headers = response.headers_mut();
90 headers.insert(header::CONTENT_SECURITY_POLICY, config.content_security_policy);
91 headers.insert(header::STRICT_TRANSPORT_SECURITY, config.strict_transport_security);
92 headers.insert(header::X_CONTENT_TYPE_OPTIONS, config.x_content_type_options);
93 headers.insert(header::X_FRAME_OPTIONS, config.x_frame_options);
94 headers.insert(header::X_XSS_PROTECTION, config.x_xss_protection);
95 headers.insert(header::REFERRER_POLICY, config.referrer_policy);
96 headers.insert(HeaderName::from_static("permissions-policy"), config.permissions_policy);
97
98 Ok(response)
99 })
100 }
101}