// netray_common/security_headers.rs

1use axum::extract::Request;
2use axum::middleware::Next;
3use axum::response::Response;
4
/// Configuration for the security headers middleware.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SecurityHeadersConfig {
    /// Additional `script-src` origins to include in CSP (e.g. `"https://cdn.jsdelivr.net"`).
    /// Applied to paths matching `relaxed_csp_path_prefix`.
    ///
    /// Entries are appended verbatim (space-separated) after `'self'`, so they must come
    /// from trusted configuration and contain only characters valid in an HTTP header value.
    pub extra_script_src: Vec<String>,

    /// Path prefix that triggers the relaxed CSP with `extra_script_src`.
    /// Defaults to `"/docs"` if empty.
    ///
    /// Compared with a plain string-prefix test against the request path, so `"/docs"`
    /// also matches paths such as `"/docs-old"`.
    pub relaxed_csp_path_prefix: String,

    /// Whether to include the `Permissions-Policy` header.
    pub include_permissions_policy: bool,
}
19
20impl Default for SecurityHeadersConfig {
21    fn default() -> Self {
22        Self {
23            extra_script_src: Vec::new(),
24            relaxed_csp_path_prefix: "/docs".to_string(),
25            include_permissions_policy: false,
26        }
27    }
28}
29
30/// Build a security headers middleware function from the given config.
31///
32/// Returns an async closure suitable for `axum::middleware::from_fn`.
33///
34/// Headers applied:
35/// - `Content-Security-Policy`: Restricts resource loading to same origin.
36///   `style-src 'unsafe-inline'` is included for inline styles.
37///   Paths matching `relaxed_csp_path_prefix` get additional `script-src` origins.
38/// - `X-Content-Type-Options: nosniff`
39/// - `X-Frame-Options: DENY`
40/// - `Referrer-Policy: strict-origin-when-cross-origin`
41/// - `Strict-Transport-Security: max-age=31536000; includeSubDomains`
42/// - `Permissions-Policy` (optional, when `include_permissions_policy` is true)
43pub fn security_headers_layer(
44    config: SecurityHeadersConfig,
45) -> impl Fn(Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
46       + Clone
47       + Send
48       + 'static {
49    let strict_csp =
50        "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; connect-src 'self'; img-src 'self' data:; frame-ancestors 'none'".to_string();
51
52    let relaxed_csp = if config.extra_script_src.is_empty() {
53        strict_csp.clone()
54    } else {
55        let extra = config.extra_script_src.join(" ");
56        format!(
57            "default-src 'self'; script-src 'self' {extra}; style-src 'self' 'unsafe-inline'; connect-src 'self'; img-src 'self' data:; frame-ancestors 'none'"
58        )
59    };
60
61    let prefix = config.relaxed_csp_path_prefix;
62    let include_pp = config.include_permissions_policy;
63
64    move |request: Request, next: Next| {
65        let strict_csp = strict_csp.clone();
66        let relaxed_csp = relaxed_csp.clone();
67        let prefix = prefix.clone();
68
69        Box::pin(async move {
70            let is_relaxed_path = request.uri().path().starts_with(&prefix);
71
72            let mut response = next.run(request).await;
73            let headers = response.headers_mut();
74
75            let csp = if is_relaxed_path {
76                &relaxed_csp
77            } else {
78                &strict_csp
79            };
80            headers.insert(
81                axum::http::header::CONTENT_SECURITY_POLICY,
82                csp.parse().expect("valid CSP header value"),
83            );
84
85            headers.insert(
86                axum::http::header::X_CONTENT_TYPE_OPTIONS,
87                "nosniff".parse().expect("valid header value"),
88            );
89
90            headers.insert(
91                axum::http::header::X_FRAME_OPTIONS,
92                "DENY".parse().expect("valid header value"),
93            );
94
95            headers.insert(
96                axum::http::header::REFERRER_POLICY,
97                "strict-origin-when-cross-origin"
98                    .parse()
99                    .expect("valid header value"),
100            );
101
102            headers.insert(
103                axum::http::header::STRICT_TRANSPORT_SECURITY,
104                "max-age=31536000; includeSubDomains"
105                    .parse()
106                    .expect("valid header value"),
107            );
108
109            if include_pp {
110                headers.insert(
111                    axum::http::HeaderName::from_static("permissions-policy"),
112                    "geolocation=(), microphone=(), camera=(), payment=()"
113                        .parse()
114                        .expect("valid header value"),
115                );
116            }
117
118            response
119        })
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request as HttpRequest, StatusCode};
    use axum::middleware;
    use axum::routing::get;
    use axum::Router;
    use tower::ServiceExt;

    /// Extra script origin used by the relaxed-CSP tests.
    const CDN: &str = "https://cdn.jsdelivr.net";

    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Send one GET request for `path` through a router wrapped in the middleware.
    async fn make_response(config: SecurityHeadersConfig, path: &str) -> Response {
        let middleware_fn = security_headers_layer(config);
        let router = Router::new()
            .route("/test", get(ok_handler))
            .route("/docs/test", get(ok_handler))
            .layer(middleware::from_fn(move |req, next| {
                let handler = middleware_fn.clone();
                async move { handler(req, next).await }
            }));

        let request = HttpRequest::builder().uri(path).body(Body::empty()).unwrap();
        router.oneshot(request).await.unwrap()
    }

    /// Read a response header as `&str`, panicking if it is absent or not visible ASCII.
    fn header_str<'a>(response: &'a Response, name: &str) -> &'a str {
        response.headers().get(name).unwrap().to_str().unwrap()
    }

    /// Config that allows the CDN origin on relaxed paths.
    fn cdn_config() -> SecurityHeadersConfig {
        SecurityHeadersConfig {
            extra_script_src: vec![CDN.to_string()],
            ..SecurityHeadersConfig::default()
        }
    }

    #[tokio::test]
    async fn sets_all_base_headers() {
        let response = make_response(SecurityHeadersConfig::default(), "/test").await;
        assert_eq!(response.status(), StatusCode::OK);

        let csp = header_str(&response, "content-security-policy");
        for directive in [
            "default-src 'self'",
            "script-src 'self'",
            "style-src 'self' 'unsafe-inline'",
            "frame-ancestors 'none'",
        ] {
            assert!(csp.contains(directive), "CSP `{csp}` lacks `{directive}`");
        }

        let expected = [
            ("x-content-type-options", "nosniff"),
            ("x-frame-options", "DENY"),
            ("referrer-policy", "strict-origin-when-cross-origin"),
            ("strict-transport-security", "max-age=31536000; includeSubDomains"),
        ];
        for (name, value) in expected {
            assert_eq!(header_str(&response, name), value, "unexpected `{name}` header");
        }
    }

    #[tokio::test]
    async fn no_permissions_policy_by_default() {
        let response = make_response(SecurityHeadersConfig::default(), "/test").await;
        assert!(!response.headers().contains_key("permissions-policy"));
    }

    #[tokio::test]
    async fn includes_permissions_policy_when_configured() {
        let config = SecurityHeadersConfig {
            include_permissions_policy: true,
            ..SecurityHeadersConfig::default()
        };
        let response = make_response(config, "/test").await;

        let policy = response
            .headers()
            .get("permissions-policy")
            .expect("Permissions-Policy header present")
            .to_str()
            .unwrap();
        for feature in ["geolocation=()", "camera=()"] {
            assert!(policy.contains(feature));
        }
    }

    #[tokio::test]
    async fn relaxed_csp_on_docs_path() {
        let response = make_response(cdn_config(), "/docs/test").await;
        assert!(header_str(&response, "content-security-policy").contains(CDN));
    }

    #[tokio::test]
    async fn strict_csp_on_non_docs_path() {
        let response = make_response(cdn_config(), "/test").await;
        let csp = header_str(&response, "content-security-policy");
        assert!(!csp.contains("cdn.jsdelivr.net"));
    }
}