// netray_common/security_headers.rs

1use axum::extract::Request;
2use axum::http::HeaderValue;
3use axum::middleware::Next;
4use axum::response::Response;
5
/// Settings that control which security headers `security_headers_layer` emits.
#[derive(Clone, Debug)]
pub struct SecurityHeadersConfig {
    /// Extra origins appended to the CSP `script-src` directive
    /// (for example `"https://cdn.jsdelivr.net"`).
    ///
    /// They only take effect for requests under `relaxed_csp_path_prefix`.
    pub extra_script_src: Vec<String>,

    /// Path prefix whose requests receive the relaxed CSP built with `extra_script_src`.
    ///
    /// Falls back to `"/docs"` when left empty.
    pub relaxed_csp_path_prefix: String,

    /// When `true`, a `Permissions-Policy` header is emitted as well.
    pub include_permissions_policy: bool,
}
20
21impl Default for SecurityHeadersConfig {
22    fn default() -> Self {
23        Self {
24            extra_script_src: Vec::new(),
25            relaxed_csp_path_prefix: "/docs".to_string(),
26            include_permissions_policy: false,
27        }
28    }
29}
30
31/// Build a security headers middleware function from the given config.
32///
33/// Returns an async closure suitable for `axum::middleware::from_fn`.
34///
35/// Headers applied:
36/// - `Content-Security-Policy`: Restricts resource loading to same origin.
37///   `style-src 'unsafe-inline'` is included for inline styles.
38///   Paths matching `relaxed_csp_path_prefix` get additional `script-src` origins.
39/// - `X-Content-Type-Options: nosniff`
40/// - `X-Frame-Options: DENY`
41/// - `Referrer-Policy: strict-origin-when-cross-origin`
42/// - `Strict-Transport-Security: max-age=31536000; includeSubDomains`
43/// - `Permissions-Policy` (optional, when `include_permissions_policy` is true)
44pub fn security_headers_layer(
45    config: SecurityHeadersConfig,
46) -> impl Fn(Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
47       + Clone
48       + Send
49       + 'static {
50    let valid_extra: Vec<String> = config
51        .extra_script_src
52        .into_iter()
53        .filter(|src| {
54            if src.contains(';') || src.contains('\n') || src.contains('\r') || src.is_empty() {
55                tracing::warn!(value = %src, "invalid extra_script_src entry skipped");
56                false
57            } else {
58                true
59            }
60        })
61        .collect();
62
63    let strict_csp =
64        "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; connect-src 'self'; img-src 'self' data:; frame-ancestors 'none'".to_string();
65
66    let relaxed_csp = if valid_extra.is_empty() {
67        strict_csp.clone()
68    } else {
69        let extra = valid_extra.join(" ");
70        format!(
71            "default-src 'self'; script-src 'self' {extra}; style-src 'self' 'unsafe-inline'; connect-src 'self'; img-src 'self' data:; frame-ancestors 'none'"
72        )
73    };
74
75    let strict_csp_val: HeaderValue = strict_csp.parse().expect("valid CSP header value");
76    let relaxed_csp_val: HeaderValue = relaxed_csp.parse().expect("valid CSP header value");
77    let nosniff: HeaderValue = "nosniff".parse().expect("valid header value");
78    let deny: HeaderValue = "DENY".parse().expect("valid header value");
79    let referrer: HeaderValue = "strict-origin-when-cross-origin"
80        .parse()
81        .expect("valid header value");
82    let hsts: HeaderValue = "max-age=31536000; includeSubDomains"
83        .parse()
84        .expect("valid header value");
85    let pp_val: Option<HeaderValue> = if config.include_permissions_policy {
86        Some(
87            "geolocation=(), microphone=(), camera=(), payment=()"
88                .parse()
89                .expect("valid header value"),
90        )
91    } else {
92        None
93    };
94
95    let prefix = config.relaxed_csp_path_prefix;
96    let prefix_with_slash = format!("{prefix}/");
97
98    move |request: Request, next: Next| {
99        let strict_csp_val = strict_csp_val.clone();
100        let relaxed_csp_val = relaxed_csp_val.clone();
101        let nosniff = nosniff.clone();
102        let deny = deny.clone();
103        let referrer = referrer.clone();
104        let hsts = hsts.clone();
105        let pp_val = pp_val.clone();
106        let prefix = prefix.clone();
107        let prefix_with_slash = prefix_with_slash.clone();
108
109        Box::pin(async move {
110            let path = request.uri().path();
111            let is_relaxed_path = path == prefix || path.starts_with(&prefix_with_slash);
112
113            let mut response = next.run(request).await;
114            let headers = response.headers_mut();
115
116            let csp = if is_relaxed_path {
117                relaxed_csp_val
118            } else {
119                strict_csp_val
120            };
121            headers.insert(axum::http::header::CONTENT_SECURITY_POLICY, csp);
122            headers.insert(axum::http::header::X_CONTENT_TYPE_OPTIONS, nosniff);
123            headers.insert(axum::http::header::X_FRAME_OPTIONS, deny);
124            headers.insert(axum::http::header::REFERRER_POLICY, referrer);
125            headers.insert(axum::http::header::STRICT_TRANSPORT_SECURITY, hsts);
126
127            if let Some(pp) = pp_val {
128                headers.insert(
129                    axum::http::HeaderName::from_static("permissions-policy"),
130                    pp,
131                );
132            }
133
134            response
135        })
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request as HttpRequest, StatusCode};
    use axum::middleware;
    use axum::routing::get;
    use axum::Router;
    use tower::ServiceExt;

    /// Handler whose body is irrelevant; the tests only inspect response headers.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Router serving every path in `routes` with `ok_handler`, wrapped in the security
    /// headers middleware built from `config`.
    fn build_app(config: SecurityHeadersConfig, routes: &[&str]) -> Router {
        let layer_fn = security_headers_layer(config);
        let mut router = Router::new();
        for &route in routes {
            router = router.route(route, get(ok_handler));
        }
        router.layer(middleware::from_fn(move |req, next| {
            let f = layer_fn.clone();
            async move { f(req, next).await }
        }))
    }

    /// Sends an empty-bodied request for `path` through `app`.
    async fn fetch(app: Router, path: &str) -> Response {
        let request = HttpRequest::builder()
            .uri(path)
            .body(Body::empty())
            .unwrap();
        app.oneshot(request).await.unwrap()
    }

    /// Requests `path` from an app exposing `/test` and `/docs/test`.
    async fn make_response(config: SecurityHeadersConfig, path: &str) -> Response {
        fetch(build_app(config, &["/test", "/docs/test"]), path).await
    }

    /// Reads header `name` as text, panicking if it is absent or not visible ASCII.
    fn header_str<'a>(response: &'a Response, name: &str) -> &'a str {
        response.headers().get(name).unwrap().to_str().unwrap()
    }

    #[tokio::test]
    async fn sets_all_base_headers() {
        let response = make_response(SecurityHeadersConfig::default(), "/test").await;
        assert_eq!(response.status(), StatusCode::OK);

        let csp = header_str(&response, "content-security-policy");
        for directive in [
            "default-src 'self'",
            "script-src 'self'",
            "style-src 'self' 'unsafe-inline'",
            "frame-ancestors 'none'",
        ] {
            assert!(csp.contains(directive));
        }

        for (name, expected) in [
            ("x-content-type-options", "nosniff"),
            ("x-frame-options", "DENY"),
            ("referrer-policy", "strict-origin-when-cross-origin"),
            (
                "strict-transport-security",
                "max-age=31536000; includeSubDomains",
            ),
        ] {
            assert_eq!(header_str(&response, name), expected);
        }
    }

    #[tokio::test]
    async fn no_permissions_policy_by_default() {
        let response = make_response(SecurityHeadersConfig::default(), "/test").await;
        assert!(!response.headers().contains_key("permissions-policy"));
    }

    #[tokio::test]
    async fn includes_permissions_policy_when_configured() {
        let config = SecurityHeadersConfig {
            include_permissions_policy: true,
            ..Default::default()
        };
        let response = make_response(config, "/test").await;
        let pp = response
            .headers()
            .get("permissions-policy")
            .expect("Permissions-Policy header present")
            .to_str()
            .unwrap();
        for feature in ["geolocation=()", "camera=()"] {
            assert!(pp.contains(feature));
        }
    }

    #[tokio::test]
    async fn relaxed_csp_on_docs_path() {
        let config = SecurityHeadersConfig {
            extra_script_src: vec!["https://cdn.jsdelivr.net".to_string()],
            ..Default::default()
        };
        let response = make_response(config, "/docs/test").await;
        let csp = header_str(&response, "content-security-policy");
        assert!(csp.contains("https://cdn.jsdelivr.net"));
    }

    #[tokio::test]
    async fn relaxed_csp_on_custom_prefix() {
        let config = SecurityHeadersConfig {
            extra_script_src: vec!["https://cdn.example.com".to_string()],
            relaxed_csp_path_prefix: "/api-docs".to_string(),
            ..Default::default()
        };
        let app = build_app(config, &["/api-docs/test", "/test"]);

        let relaxed = fetch(app.clone(), "/api-docs/test").await;
        let relaxed_csp = header_str(&relaxed, "content-security-policy");
        assert!(relaxed_csp.contains("https://cdn.example.com"));

        let strict = fetch(app, "/test").await;
        let strict_csp = header_str(&strict, "content-security-policy");
        assert!(!strict_csp.contains("cdn.example.com"));
    }

    #[tokio::test]
    async fn strict_csp_on_non_docs_path() {
        let config = SecurityHeadersConfig {
            extra_script_src: vec!["https://cdn.jsdelivr.net".to_string()],
            ..Default::default()
        };
        let response = make_response(config, "/test").await;
        let csp = header_str(&response, "content-security-policy");
        assert!(!csp.contains("cdn.jsdelivr.net"));
    }
}