Skip to main content

modo/middleware/
security_headers.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use http::{HeaderName, HeaderValue, Response};
6use serde::Deserialize;
7use tower::{Layer, Service};
8
9use crate::Error;
10
/// Configuration for security response headers.
///
/// Every field has a default (see the `Default` impl), and because the struct
/// is annotated with `#[serde(default)]`, any field missing from the
/// deserialized input falls back to that default.
///
/// `X-Frame-Options` and `Referrer-Policy` are always emitted, using the
/// configured strings. `X-Content-Type-Options` is emitted only when its flag
/// is `true`. The optional fields (`hsts_max_age`, `content_security_policy`,
/// `permissions_policy`) are `None` by default and their corresponding
/// headers are only added when set.
#[non_exhaustive]
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct SecurityHeadersConfig {
    /// When `true`, adds `X-Content-Type-Options: nosniff`.
    pub x_content_type_options: bool,
    /// Value for the `X-Frame-Options` header (e.g. `"DENY"`, `"SAMEORIGIN"`).
    ///
    /// The string is sent as-is; it is only checked for being a valid HTTP
    /// header value when the layer is built.
    pub x_frame_options: String,
    /// Value for the `Referrer-Policy` header.
    ///
    /// The string is sent as-is; it is only checked for being a valid HTTP
    /// header value when the layer is built.
    pub referrer_policy: String,
    /// When set, adds `Strict-Transport-Security: max-age=<value>`.
    ///
    /// Only the `max-age` directive is produced; no other HSTS directives
    /// are appended.
    pub hsts_max_age: Option<u64>,
    /// When set, adds the `Content-Security-Policy` header with this value.
    pub content_security_policy: Option<String>,
    /// When set, adds the `Permissions-Policy` header with this value.
    pub permissions_policy: Option<String>,
}
33
34impl Default for SecurityHeadersConfig {
35    fn default() -> Self {
36        Self {
37            x_content_type_options: true,
38            x_frame_options: "DENY".to_string(),
39            referrer_policy: "strict-origin-when-cross-origin".to_string(),
40            hsts_max_age: None,
41            content_security_policy: None,
42            permissions_policy: None,
43        }
44    }
45}
46
/// A [`Layer`] that adds configurable security headers to every response.
///
/// The header name/value pairs are computed once when the layer is built
/// from a [`SecurityHeadersConfig`]; each service produced by the layer
/// receives its own clone of that list.
#[derive(Clone)]
pub struct SecurityHeadersLayer {
    // Pre-validated header pairs, in the order they are applied to responses.
    headers: Vec<(HeaderName, HeaderValue)>,
}
52
53impl SecurityHeadersLayer {
54    fn from_config(config: &SecurityHeadersConfig) -> crate::Result<Self> {
55        let mut headers = Vec::new();
56
57        if config.x_content_type_options {
58            headers.push((
59                http::header::X_CONTENT_TYPE_OPTIONS,
60                HeaderValue::from_static("nosniff"),
61            ));
62        }
63
64        headers.push((
65            HeaderName::from_static("x-frame-options"),
66            HeaderValue::from_str(&config.x_frame_options)
67                .map_err(|_| Error::unprocessable_entity("invalid x-frame-options value"))?,
68        ));
69
70        headers.push((
71            HeaderName::from_static("referrer-policy"),
72            HeaderValue::from_str(&config.referrer_policy)
73                .map_err(|_| Error::unprocessable_entity("invalid referrer-policy value"))?,
74        ));
75
76        if let Some(max_age) = config.hsts_max_age {
77            headers.push((
78                http::header::STRICT_TRANSPORT_SECURITY,
79                HeaderValue::from_str(&format!("max-age={max_age}"))
80                    .map_err(|_| Error::unprocessable_entity("invalid hsts max-age value"))?,
81            ));
82        }
83
84        if let Some(ref csp) = config.content_security_policy {
85            headers.push((
86                HeaderName::from_static("content-security-policy"),
87                HeaderValue::from_str(csp).map_err(|_| {
88                    Error::unprocessable_entity("invalid content-security-policy value")
89                })?,
90            ));
91        }
92
93        if let Some(ref pp) = config.permissions_policy {
94            headers.push((
95                HeaderName::from_static("permissions-policy"),
96                HeaderValue::from_str(pp)
97                    .map_err(|_| Error::unprocessable_entity("invalid permissions-policy value"))?,
98            ));
99        }
100
101        Ok(Self { headers })
102    }
103}
104
105impl<S> Layer<S> for SecurityHeadersLayer {
106    type Service = SecurityHeadersService<S>;
107
108    fn layer(&self, inner: S) -> Self::Service {
109        SecurityHeadersService {
110            inner,
111            headers: self.headers.clone(),
112        }
113    }
114}
115
/// The [`Service`] produced by `SecurityHeadersLayer`.
///
/// Wraps an inner service and adds the configured security headers to every
/// successful response. A header that the inner service already set on the
/// response is left untouched.
#[derive(Clone)]
pub struct SecurityHeadersService<S> {
    // The wrapped service that actually handles the request.
    inner: S,
    // Header pairs to add to each response when not already present.
    headers: Vec<(HeaderName, HeaderValue)>,
}
124
125impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for SecurityHeadersService<S>
126where
127    S: Service<http::Request<ReqBody>, Response = Response<ResBody>>,
128    S::Future: Send + 'static,
129{
130    type Response = S::Response;
131    type Error = S::Error;
132    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
133
134    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
135        self.inner.poll_ready(cx)
136    }
137
138    fn call(&mut self, request: http::Request<ReqBody>) -> Self::Future {
139        let headers = self.headers.clone();
140        let future = self.inner.call(request);
141
142        Box::pin(async move {
143            let mut response = future.await?;
144            let resp_headers = response.headers_mut();
145            for (name, value) in headers {
146                if !resp_headers.contains_key(&name) {
147                    resp_headers.insert(name, value);
148                }
149            }
150            Ok(response)
151        })
152    }
153}
154
/// Returns a Tower layer that adds security headers to every response
/// based on the provided configuration.
///
/// The header values are validated once, when this function is called.
/// Services produced by the returned layer only add a header when the
/// response does not already contain one with the same name.
///
/// # Errors
///
/// Returns [`crate::Error`] if any configured header value contains invalid
/// HTTP header characters.
///
/// # Example
///
/// ```rust,no_run
/// use modo::middleware::{security_headers, SecurityHeadersConfig};
///
/// let config = SecurityHeadersConfig::default();
/// let layer = security_headers(&config).unwrap();
/// ```
pub fn security_headers(config: &SecurityHeadersConfig) -> crate::Result<SecurityHeadersLayer> {
    SecurityHeadersLayer::from_config(config)
}