by_loco/controller/middleware/
secure_headers.rs

1//! Sets secure headers for your backend to promote security-by-default.
2//!
3//! This middleware applies secure HTTP headers, providing pre-defined presets
4//! (e.g., "github") and the ability to override or define custom headers.
5
use std::{
    collections::{BTreeMap, HashMap},
    sync::{Arc, OnceLock},
    task::{Context, Poll},
};
11
12use axum::{
13    body::Body,
14    http::{HeaderName, HeaderValue, Request},
15    response::Response,
16    Router as AXRouter,
17};
18use futures_util::future::BoxFuture;
19use serde::{Deserialize, Serialize};
20use serde_json::{self, json};
21use tower::{Layer, Service};
22
23use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Error, Result};
24
25static PRESETS: OnceLock<HashMap<String, BTreeMap<String, String>>> = OnceLock::new();
26fn get_presets() -> &'static HashMap<String, BTreeMap<String, String>> {
27    PRESETS.get_or_init(|| {
28        let json_data = include_str!("secure_headers.json");
29        serde_json::from_str(json_data).unwrap()
30    })
31}
32/// Sets a predefined or custom set of secure headers.
33///
34/// We recommend our `github` preset. Presets values are derived
35/// from the [secure_headers](https://github.com/github/secure_headers) Ruby
36/// library which Github (and originally Twitter) use.
37///
38/// To use a preset, in your `config/development.yaml`:
39///
40/// ```yaml
41/// middlewares:
42///   secure_headers:
43///     preset: github
44/// ```
45///
46/// You can also override individual headers on a given preset:
47///
48/// ```yaml
49/// middlewares:
50///   secure_headers:
51///     preset: github
52///     overrides:
53///       foo: bar
54/// ```
55///
56/// Or start from scratch:
57///
58///```yaml
59/// middlewares:
60///   secure_headers:
61///     preset: empty
62///     overrides:
63///       one: two
64/// ```
65///
66/// To support `htmx`, You can add the following override, to allow some inline
67/// running of scripts:
68///
69/// ```yaml
70/// secure_headers:
71///     preset: github
72///     overrides:
73///         # this allows you to use HTMX, and has unsafe-inline. Remove or consider in production
74///         "Content-Security-Policy": "default-src 'self' https:; font-src 'self' https: data:; img-src 'self' https: data:; object-src 'none'; script-src 'unsafe-inline' 'self' https:; style-src 'self' https: 'unsafe-inline'"
75/// ```
76///
77/// For the list of presets and their content look at [secure_headers.json](https://github.com/loco-rs/loco/blob/master/src/controller/middleware/secure_headers.rs)
78#[derive(Serialize, Deserialize, Debug, Clone)]
79pub struct SecureHeader {
80    #[serde(default)]
81    pub enable: bool,
82    #[serde(default = "default_preset")]
83    pub preset: String,
84    #[serde(default)]
85    pub overrides: Option<BTreeMap<String, String>>,
86}
87
88impl Default for SecureHeader {
89    fn default() -> Self {
90        serde_json::from_value(json!({})).unwrap()
91    }
92}
93
/// Preset name used when the configuration does not specify one.
fn default_preset() -> String {
    String::from("github")
}
97
98impl MiddlewareLayer for SecureHeader {
99    /// Returns the name of the middleware
100    fn name(&self) -> &'static str {
101        "secure_headers"
102    }
103
104    /// Returns whether the middleware is enabled or not
105    fn is_enabled(&self) -> bool {
106        self.enable
107    }
108
109    fn config(&self) -> serde_json::Result<serde_json::Value> {
110        serde_json::to_value(self)
111    }
112
113    /// Applies the secure headers layer to the application router
114    fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
115        Ok(app.layer(SecureHeaders::new(self)?))
116    }
117}
118
119impl SecureHeader {
120    /// Converts the configuration into a list of headers.
121    ///
122    /// Applies the preset headers and any custom overrides.
123    fn as_headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>> {
124        let mut headers = vec![];
125
126        let preset = &self.preset;
127        let p = get_presets().get(preset).ok_or_else(|| {
128            Error::Message(format!(
129                "secure_headers: a preset named `{preset}` does not exist"
130            ))
131        })?;
132
133        Self::push_headers(&mut headers, p)?;
134        if let Some(overrides) = &self.overrides {
135            Self::push_headers(&mut headers, overrides)?;
136        }
137        Ok(headers)
138    }
139
140    /// Helper function to push headers into a mutable vector.
141    ///
142    /// This function takes a map of header names and values, converting them
143    /// into valid HTTP headers and adding them to the provided `headers`
144    /// vector.
145    fn push_headers(
146        headers: &mut Vec<(HeaderName, HeaderValue)>,
147        hm: &BTreeMap<String, String>,
148    ) -> Result<()> {
149        for (k, v) in hm {
150            headers.push((
151                HeaderName::from_bytes(k.clone().as_bytes()).map_err(Box::from)?,
152                HeaderValue::from_str(v.clone().as_str()).map_err(Box::from)?,
153            ));
154        }
155        Ok(())
156    }
157}
158
159/// The [`SecureHeaders`] layer which wraps around the service and injects
160/// security headers
161#[derive(Clone, Debug)]
162pub struct SecureHeaders {
163    headers: Vec<(HeaderName, HeaderValue)>,
164}
165
166impl SecureHeaders {
167    /// Creates a new [`SecureHeaders`] instance with the provided
168    /// configuration.
169    ///
170    /// # Errors
171    /// Returns an error if any header values are invalid.
172    pub fn new(config: &SecureHeader) -> Result<Self> {
173        Ok(Self {
174            headers: config.as_headers()?,
175        })
176    }
177}
178
179impl<S> Layer<S> for SecureHeaders {
180    type Service = SecureHeadersMiddleware<S>;
181
182    /// Wraps the provided service with the secure headers middleware.
183    fn layer(&self, inner: S) -> Self::Service {
184        SecureHeadersMiddleware {
185            inner,
186            layer: self.clone(),
187        }
188    }
189}
190
191/// The secure headers middleware
192#[derive(Clone, Debug)]
193#[must_use]
194pub struct SecureHeadersMiddleware<S> {
195    inner: S,
196    layer: SecureHeaders,
197}
198
199impl<S> Service<Request<Body>> for SecureHeadersMiddleware<S>
200where
201    S: Service<Request<Body>, Response = Response> + Send + 'static,
202    S::Future: Send + 'static,
203{
204    type Response = S::Response;
205    type Error = S::Error;
206    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
207
208    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
209        self.inner.poll_ready(cx)
210    }
211
212    fn call(&mut self, request: Request<Body>) -> Self::Future {
213        let layer = self.layer.clone();
214        let future = self.inner.call(request);
215        Box::pin(async move {
216            let mut response: Response = future.await?;
217            let headers = response.headers_mut();
218            for (k, v) in &layer.headers {
219                headers.insert(k, v.clone());
220            }
221            Ok(response)
222        })
223    }
224}
225
#[cfg(test)]
mod tests {
    use axum::{
        http::{HeaderMap, Method},
        routing::get,
        Router,
    };
    use insta::assert_debug_snapshot;
    use tower::ServiceExt;

    use super::*;

    /// Flattens a `HeaderMap` into a sorted `name -> value` map so snapshots
    /// are deterministic. Values rejected by `HeaderValue::to_str` are
    /// recorded as empty strings.
    fn normalize_headers(headers: &HeaderMap) -> BTreeMap<String, String> {
        let mut normalized = BTreeMap::new();
        for (name, value) in headers {
            let text = value.to_str().unwrap_or("");
            normalized.insert(name.to_string(), text.to_string());
        }
        normalized
    }

    /// Sends a `GET /` through a router wrapped in [`SecureHeaders`] built
    /// from `config` and returns the response.
    async fn send_get(config: &SecureHeader) -> Response {
        let layer = SecureHeaders::new(config).unwrap();
        let app = Router::new().route("/", get(|| async {})).layer(layer);
        let request = Request::builder()
            .uri("/")
            .method(Method::GET)
            .body(Body::empty())
            .unwrap();
        app.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn can_set_headers() {
        let config = SecureHeader {
            enable: true,
            preset: "github".to_string(),
            overrides: None,
        };
        let response = send_get(&config).await;
        assert_debug_snapshot!(normalize_headers(response.headers()));
    }

    #[tokio::test]
    async fn can_override_headers() {
        let overrides = BTreeMap::from([
            ("X-Download-Options".to_string(), "foobar".to_string()),
            ("New-Header".to_string(), "baz".to_string()),
        ]);
        let config = SecureHeader {
            enable: true,
            preset: "github".to_string(),
            overrides: Some(overrides),
        };
        let response = send_get(&config).await;
        assert_debug_snapshot!(normalize_headers(response.headers()));
    }

    #[tokio::test]
    async fn default_is_github_preset() {
        let response = send_get(&SecureHeader::default()).await;
        assert_debug_snapshot!(normalize_headers(response.headers()));
    }
}