// netray_common/security_headers.rs
use axum::extract::Request;
use axum::http::HeaderValue;
use axum::middleware::Next;
use axum::response::Response;

/// Settings consumed by [`security_headers_layer`].
///
/// The layer applies one Content-Security-Policy to most paths and a second,
/// "relaxed" policy to a single path prefix; this struct controls how the two differ
/// and which optional headers are emitted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SecurityHeadersConfig {
    /// Extra sources appended to the `script-src` directive of the relaxed policy.
    ///
    /// Each entry is inserted verbatim after `'self'`. Invalid entries (for example
    /// empty ones, or ones containing `;` or line breaks) are skipped with a warning
    /// when the layer is built.
    pub extra_script_src: Vec<String>,

    /// Path prefix that receives the relaxed policy.
    ///
    /// A request is treated as relaxed when its path equals this prefix or lies
    /// beneath it (`<prefix>/…`). The default is `/docs`.
    pub relaxed_csp_path_prefix: String,

    /// Whether a `Permissions-Policy` header is added to every response.
    pub include_permissions_policy: bool,
}

21impl Default for SecurityHeadersConfig {
22 fn default() -> Self {
23 Self {
24 extra_script_src: Vec::new(),
25 relaxed_csp_path_prefix: "/docs".to_string(),
26 include_permissions_policy: false,
27 }
28 }
29}
30
31pub fn security_headers_layer(
45 config: SecurityHeadersConfig,
46) -> impl Fn(Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
47 + Clone
48 + Send
49 + 'static {
50 let valid_extra: Vec<String> = config
51 .extra_script_src
52 .into_iter()
53 .filter(|src| {
54 if src.contains(';') || src.contains('\n') || src.contains('\r') || src.is_empty() {
55 tracing::warn!(value = %src, "invalid extra_script_src entry skipped");
56 false
57 } else {
58 true
59 }
60 })
61 .collect();
62
63 let strict_csp =
64 "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; connect-src 'self'; img-src 'self' data:; frame-ancestors 'none'".to_string();
65
66 let relaxed_csp = if valid_extra.is_empty() {
67 strict_csp.clone()
68 } else {
69 let extra = valid_extra.join(" ");
70 format!(
71 "default-src 'self'; script-src 'self' {extra}; style-src 'self' 'unsafe-inline'; connect-src 'self'; img-src 'self' data:; frame-ancestors 'none'"
72 )
73 };
74
75 let strict_csp_val: HeaderValue = strict_csp.parse().expect("valid CSP header value");
76 let relaxed_csp_val: HeaderValue = relaxed_csp.parse().expect("valid CSP header value");
77 let nosniff: HeaderValue = "nosniff".parse().expect("valid header value");
78 let deny: HeaderValue = "DENY".parse().expect("valid header value");
79 let referrer: HeaderValue = "strict-origin-when-cross-origin"
80 .parse()
81 .expect("valid header value");
82 let hsts: HeaderValue = "max-age=31536000; includeSubDomains"
83 .parse()
84 .expect("valid header value");
85 let pp_val: Option<HeaderValue> = if config.include_permissions_policy {
86 Some(
87 "geolocation=(), microphone=(), camera=(), payment=()"
88 .parse()
89 .expect("valid header value"),
90 )
91 } else {
92 None
93 };
94
95 let prefix = config.relaxed_csp_path_prefix;
96 let prefix_with_slash = format!("{prefix}/");
97
98 move |request: Request, next: Next| {
99 let strict_csp_val = strict_csp_val.clone();
100 let relaxed_csp_val = relaxed_csp_val.clone();
101 let nosniff = nosniff.clone();
102 let deny = deny.clone();
103 let referrer = referrer.clone();
104 let hsts = hsts.clone();
105 let pp_val = pp_val.clone();
106 let prefix = prefix.clone();
107 let prefix_with_slash = prefix_with_slash.clone();
108
109 Box::pin(async move {
110 let path = request.uri().path();
111 let is_relaxed_path = path == prefix || path.starts_with(&prefix_with_slash);
112
113 let mut response = next.run(request).await;
114 let headers = response.headers_mut();
115
116 let csp = if is_relaxed_path {
117 relaxed_csp_val
118 } else {
119 strict_csp_val
120 };
121 headers.insert(axum::http::header::CONTENT_SECURITY_POLICY, csp);
122 headers.insert(axum::http::header::X_CONTENT_TYPE_OPTIONS, nosniff);
123 headers.insert(axum::http::header::X_FRAME_OPTIONS, deny);
124 headers.insert(axum::http::header::REFERRER_POLICY, referrer);
125 headers.insert(axum::http::header::STRICT_TRANSPORT_SECURITY, hsts);
126
127 if let Some(pp) = pp_val {
128 headers.insert(
129 axum::http::HeaderName::from_static("permissions-policy"),
130 pp,
131 );
132 }
133
134 response
135 })
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request as HttpRequest, StatusCode};
    use axum::middleware;
    use axum::routing::get;
    use axum::Router;
    use tower::ServiceExt;

    /// Routes served by [`make_response`].
    const DEFAULT_ROUTES: &[&str] = &["/test", "/docs", "/docs/test", "/docsify"];

    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Builds a router that serves `ok_handler` on every path in `routes`, wrapped
    /// in the security-headers middleware configured by `config`.
    fn make_app(config: SecurityHeadersConfig, routes: &[&str]) -> Router {
        let layer_fn = security_headers_layer(config);
        routes
            .iter()
            .fold(Router::new(), |router, &path| {
                router.route(path, get(ok_handler))
            })
            .layer(middleware::from_fn(move |req, next| {
                let f = layer_fn.clone();
                async move { f(req, next).await }
            }))
    }

    /// Sends a body-less GET for `path` through `app` and returns the response.
    async fn send(app: Router, path: &str) -> Response {
        let request = HttpRequest::builder()
            .uri(path)
            .body(Body::empty())
            .unwrap();

        app.oneshot(request).await.unwrap()
    }

    /// Convenience wrapper: default routes, one request.
    async fn make_response(config: SecurityHeadersConfig, path: &str) -> Response {
        send(make_app(config, DEFAULT_ROUTES), path).await
    }

    /// Extracts the `Content-Security-Policy` header as text.
    fn csp_of(response: &Response) -> &str {
        response
            .headers()
            .get("content-security-policy")
            .expect("Content-Security-Policy header present")
            .to_str()
            .expect("Content-Security-Policy header is valid text")
    }

    #[tokio::test]
    async fn sets_all_base_headers() {
        let response = make_response(SecurityHeadersConfig::default(), "/test").await;

        assert_eq!(response.status(), StatusCode::OK);

        let csp = csp_of(&response);
        assert!(csp.contains("default-src 'self'"));
        assert!(csp.contains("script-src 'self'"));
        assert!(csp.contains("style-src 'self' 'unsafe-inline'"));
        assert!(csp.contains("frame-ancestors 'none'"));

        assert_eq!(
            response.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
        assert_eq!(response.headers().get("x-frame-options").unwrap(), "DENY");
        assert_eq!(
            response.headers().get("referrer-policy").unwrap(),
            "strict-origin-when-cross-origin"
        );
        assert_eq!(
            response.headers().get("strict-transport-security").unwrap(),
            "max-age=31536000; includeSubDomains"
        );
    }

    #[tokio::test]
    async fn no_permissions_policy_by_default() {
        let response = make_response(SecurityHeadersConfig::default(), "/test").await;
        assert!(response.headers().get("permissions-policy").is_none());
    }

    #[tokio::test]
    async fn includes_permissions_policy_when_configured() {
        let config = SecurityHeadersConfig {
            include_permissions_policy: true,
            ..Default::default()
        };
        let response = make_response(config, "/test").await;
        let pp = response
            .headers()
            .get("permissions-policy")
            .expect("Permissions-Policy header present")
            .to_str()
            .unwrap();
        assert!(pp.contains("geolocation=()"));
        assert!(pp.contains("camera=()"));
    }

    #[tokio::test]
    async fn relaxed_csp_on_docs_path() {
        let config = SecurityHeadersConfig {
            extra_script_src: vec!["https://cdn.jsdelivr.net".to_string()],
            ..Default::default()
        };
        let response = make_response(config, "/docs/test").await;
        assert!(csp_of(&response).contains("https://cdn.jsdelivr.net"));
    }

    #[tokio::test]
    async fn relaxed_csp_on_exact_prefix() {
        let config = SecurityHeadersConfig {
            extra_script_src: vec!["https://cdn.jsdelivr.net".to_string()],
            ..Default::default()
        };
        let response = make_response(config, "/docs").await;
        assert!(csp_of(&response).contains("https://cdn.jsdelivr.net"));
    }

    #[tokio::test]
    async fn relaxed_csp_on_custom_prefix() {
        let config = SecurityHeadersConfig {
            extra_script_src: vec!["https://cdn.example.com".to_string()],
            relaxed_csp_path_prefix: "/api-docs".to_string(),
            ..Default::default()
        };
        let app = make_app(config, &["/api-docs/test", "/test"]);

        let response = send(app.clone(), "/api-docs/test").await;
        assert!(csp_of(&response).contains("https://cdn.example.com"));

        let response = send(app, "/test").await;
        assert!(!csp_of(&response).contains("cdn.example.com"));
    }

    #[tokio::test]
    async fn strict_csp_on_non_docs_path() {
        let config = SecurityHeadersConfig {
            extra_script_src: vec!["https://cdn.jsdelivr.net".to_string()],
            ..Default::default()
        };
        let response = make_response(config, "/test").await;
        assert!(!csp_of(&response).contains("cdn.jsdelivr.net"));
    }

    #[tokio::test]
    async fn strict_csp_on_path_that_only_shares_prefix_text() {
        // "/docsify" starts with the characters "/docs" but is not beneath "/docs/".
        let config = SecurityHeadersConfig {
            extra_script_src: vec!["https://cdn.jsdelivr.net".to_string()],
            ..Default::default()
        };
        let response = make_response(config, "/docsify").await;
        assert!(!csp_of(&response).contains("cdn.jsdelivr.net"));
    }

    #[tokio::test]
    async fn skips_invalid_extra_script_src_entries() {
        let config = SecurityHeadersConfig {
            extra_script_src: vec![
                "https://evil.example; object-src *".to_string(),
                String::new(),
                "https://cdn.example.com".to_string(),
            ],
            ..Default::default()
        };
        let response = make_response(config, "/docs/test").await;
        let csp = csp_of(&response);
        assert!(csp.contains("https://cdn.example.com"));
        assert!(!csp.contains("evil.example"));
        assert!(!csp.contains("object-src"));
    }
}