netray_common/
security_headers.rs1use axum::extract::Request;
2use axum::middleware::Next;
3use axum::response::Response;
4
/// Configuration for [`security_headers_layer`].
///
/// [`Default`] yields no extra script sources, `/docs` as the relaxed-CSP path
/// prefix, and no `Permissions-Policy` header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SecurityHeadersConfig {
    /// Additional `script-src` sources (for example `https://cdn.jsdelivr.net`)
    /// that are allowed only on the relaxed-CSP path.
    ///
    /// Entries are inserted verbatim after `'self'`, separated by single spaces,
    /// so each entry should be exactly one CSP source expression.
    pub extra_script_src: Vec<String>,

    /// Path prefix identifying requests that receive the relaxed CSP; every other
    /// request gets the strict policy.
    pub relaxed_csp_path_prefix: String,

    /// When `true`, a `Permissions-Policy` header is added to every response.
    pub include_permissions_policy: bool,
}

impl Default for SecurityHeadersConfig {
    fn default() -> Self {
        Self {
            extra_script_src: Vec::new(),
            relaxed_csp_path_prefix: String::from("/docs"),
            include_permissions_policy: false,
        }
    }
}
29
30pub fn security_headers_layer(
44 config: SecurityHeadersConfig,
45) -> impl Fn(Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
46 + Clone
47 + Send
48 + 'static {
49 let strict_csp =
50 "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; connect-src 'self'; img-src 'self' data:; frame-ancestors 'none'".to_string();
51
52 let relaxed_csp = if config.extra_script_src.is_empty() {
53 strict_csp.clone()
54 } else {
55 let extra = config.extra_script_src.join(" ");
56 format!(
57 "default-src 'self'; script-src 'self' {extra}; style-src 'self' 'unsafe-inline'; connect-src 'self'; img-src 'self' data:; frame-ancestors 'none'"
58 )
59 };
60
61 let prefix = config.relaxed_csp_path_prefix;
62 let include_pp = config.include_permissions_policy;
63
64 move |request: Request, next: Next| {
65 let strict_csp = strict_csp.clone();
66 let relaxed_csp = relaxed_csp.clone();
67 let prefix = prefix.clone();
68
69 Box::pin(async move {
70 let is_relaxed_path = request.uri().path().starts_with(&prefix);
71
72 let mut response = next.run(request).await;
73 let headers = response.headers_mut();
74
75 let csp = if is_relaxed_path {
76 &relaxed_csp
77 } else {
78 &strict_csp
79 };
80 headers.insert(
81 axum::http::header::CONTENT_SECURITY_POLICY,
82 csp.parse().expect("valid CSP header value"),
83 );
84
85 headers.insert(
86 axum::http::header::X_CONTENT_TYPE_OPTIONS,
87 "nosniff".parse().expect("valid header value"),
88 );
89
90 headers.insert(
91 axum::http::header::X_FRAME_OPTIONS,
92 "DENY".parse().expect("valid header value"),
93 );
94
95 headers.insert(
96 axum::http::header::REFERRER_POLICY,
97 "strict-origin-when-cross-origin"
98 .parse()
99 .expect("valid header value"),
100 );
101
102 headers.insert(
103 axum::http::header::STRICT_TRANSPORT_SECURITY,
104 "max-age=31536000; includeSubDomains"
105 .parse()
106 .expect("valid header value"),
107 );
108
109 if include_pp {
110 headers.insert(
111 axum::http::HeaderName::from_static("permissions-policy"),
112 "geolocation=(), microphone=(), camera=(), payment=()"
113 .parse()
114 .expect("valid header value"),
115 );
116 }
117
118 response
119 })
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request as HttpRequest, StatusCode};
    use axum::middleware;
    use axum::routing::get;
    use axum::Router;
    use tower::ServiceExt;

    /// Trivial handler mounted on both the strict route and the relaxed route.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Sends one request for `path` through a router wrapped in the
    /// security-header middleware built from `config`, and returns the response.
    async fn make_response(config: SecurityHeadersConfig, path: &str) -> Response {
        let security = security_headers_layer(config);

        let request = HttpRequest::builder()
            .uri(path)
            .body(Body::empty())
            .unwrap();

        Router::new()
            .route("/test", get(ok_handler))
            .route("/docs/test", get(ok_handler))
            .layer(middleware::from_fn(move |req, next| {
                let run = security.clone();
                async move { run(req, next).await }
            }))
            .oneshot(request)
            .await
            .unwrap()
    }

    /// Returns the named header as text; panics if it is absent or not visible ASCII.
    fn header_text<'a>(response: &'a Response, name: &str) -> &'a str {
        response.headers().get(name).unwrap().to_str().unwrap()
    }

    /// Config that allows jsDelivr scripts on the relaxed path.
    fn cdn_config() -> SecurityHeadersConfig {
        SecurityHeadersConfig {
            extra_script_src: vec!["https://cdn.jsdelivr.net".to_string()],
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn sets_all_base_headers() {
        let response = make_response(SecurityHeadersConfig::default(), "/test").await;
        assert_eq!(response.status(), StatusCode::OK);

        let policy = header_text(&response, "content-security-policy");
        for directive in [
            "default-src 'self'",
            "script-src 'self'",
            "style-src 'self' 'unsafe-inline'",
            "frame-ancestors 'none'",
        ] {
            assert!(policy.contains(directive));
        }

        let expected = [
            ("x-content-type-options", "nosniff"),
            ("x-frame-options", "DENY"),
            ("referrer-policy", "strict-origin-when-cross-origin"),
            ("strict-transport-security", "max-age=31536000; includeSubDomains"),
        ];
        for (name, value) in expected {
            assert_eq!(response.headers().get(name).unwrap(), value);
        }
    }

    #[tokio::test]
    async fn no_permissions_policy_by_default() {
        let response = make_response(SecurityHeadersConfig::default(), "/test").await;
        assert!(!response.headers().contains_key("permissions-policy"));
    }

    #[tokio::test]
    async fn includes_permissions_policy_when_configured() {
        let config = SecurityHeadersConfig {
            include_permissions_policy: true,
            ..Default::default()
        };
        let response = make_response(config, "/test").await;

        let policy = response
            .headers()
            .get("permissions-policy")
            .expect("Permissions-Policy header present")
            .to_str()
            .unwrap();
        assert!(policy.contains("geolocation=()"));
        assert!(policy.contains("camera=()"));
    }

    #[tokio::test]
    async fn relaxed_csp_on_docs_path() {
        let response = make_response(cdn_config(), "/docs/test").await;
        let policy = header_text(&response, "content-security-policy");
        assert!(policy.contains("https://cdn.jsdelivr.net"));
    }

    #[tokio::test]
    async fn strict_csp_on_non_docs_path() {
        let response = make_response(cdn_config(), "/test").await;
        let policy = header_text(&response, "content-security-policy");
        assert!(!policy.contains("cdn.jsdelivr.net"));
    }
}