modo/middleware/
security_headers.rs1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use http::{HeaderName, HeaderValue, Response};
6use serde::Deserialize;
7use tower::{Layer, Service};
8
9use crate::Error;
10
11#[non_exhaustive]
17#[derive(Debug, Clone, Deserialize)]
18#[serde(default)]
19pub struct SecurityHeadersConfig {
20 pub x_content_type_options: bool,
22 pub x_frame_options: String,
24 pub referrer_policy: String,
26 pub hsts_max_age: Option<u64>,
28 pub content_security_policy: Option<String>,
30 pub permissions_policy: Option<String>,
32}
33
34impl Default for SecurityHeadersConfig {
35 fn default() -> Self {
36 Self {
37 x_content_type_options: true,
38 x_frame_options: "DENY".to_string(),
39 referrer_policy: "strict-origin-when-cross-origin".to_string(),
40 hsts_max_age: None,
41 content_security_policy: None,
42 permissions_policy: None,
43 }
44 }
45}
46
47#[derive(Clone)]
49pub struct SecurityHeadersLayer {
50 headers: Vec<(HeaderName, HeaderValue)>,
51}
52
53impl SecurityHeadersLayer {
54 fn from_config(config: &SecurityHeadersConfig) -> crate::Result<Self> {
55 let mut headers = Vec::new();
56
57 if config.x_content_type_options {
58 headers.push((
59 http::header::X_CONTENT_TYPE_OPTIONS,
60 HeaderValue::from_static("nosniff"),
61 ));
62 }
63
64 headers.push((
65 HeaderName::from_static("x-frame-options"),
66 HeaderValue::from_str(&config.x_frame_options)
67 .map_err(|_| Error::unprocessable_entity("invalid x-frame-options value"))?,
68 ));
69
70 headers.push((
71 HeaderName::from_static("referrer-policy"),
72 HeaderValue::from_str(&config.referrer_policy)
73 .map_err(|_| Error::unprocessable_entity("invalid referrer-policy value"))?,
74 ));
75
76 if let Some(max_age) = config.hsts_max_age {
77 headers.push((
78 http::header::STRICT_TRANSPORT_SECURITY,
79 HeaderValue::from_str(&format!("max-age={max_age}"))
80 .map_err(|_| Error::unprocessable_entity("invalid hsts max-age value"))?,
81 ));
82 }
83
84 if let Some(ref csp) = config.content_security_policy {
85 headers.push((
86 HeaderName::from_static("content-security-policy"),
87 HeaderValue::from_str(csp).map_err(|_| {
88 Error::unprocessable_entity("invalid content-security-policy value")
89 })?,
90 ));
91 }
92
93 if let Some(ref pp) = config.permissions_policy {
94 headers.push((
95 HeaderName::from_static("permissions-policy"),
96 HeaderValue::from_str(pp)
97 .map_err(|_| Error::unprocessable_entity("invalid permissions-policy value"))?,
98 ));
99 }
100
101 Ok(Self { headers })
102 }
103}
104
105impl<S> Layer<S> for SecurityHeadersLayer {
106 type Service = SecurityHeadersService<S>;
107
108 fn layer(&self, inner: S) -> Self::Service {
109 SecurityHeadersService {
110 inner,
111 headers: self.headers.clone(),
112 }
113 }
114}
115
116#[derive(Clone)]
120pub struct SecurityHeadersService<S> {
121 inner: S,
122 headers: Vec<(HeaderName, HeaderValue)>,
123}
124
125impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for SecurityHeadersService<S>
126where
127 S: Service<http::Request<ReqBody>, Response = Response<ResBody>>,
128 S::Future: Send + 'static,
129{
130 type Response = S::Response;
131 type Error = S::Error;
132 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
133
134 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
135 self.inner.poll_ready(cx)
136 }
137
138 fn call(&mut self, request: http::Request<ReqBody>) -> Self::Future {
139 let headers = self.headers.clone();
140 let future = self.inner.call(request);
141
142 Box::pin(async move {
143 let mut response = future.await?;
144 let resp_headers = response.headers_mut();
145 for (name, value) in headers {
146 if !resp_headers.contains_key(&name) {
147 resp_headers.insert(name, value);
148 }
149 }
150 Ok(response)
151 })
152 }
153}
154
155pub fn security_headers(config: &SecurityHeadersConfig) -> crate::Result<SecurityHeadersLayer> {
172 SecurityHeadersLayer::from_config(config)
173}