by_loco/controller/middleware/
secure_headers.rs1use std::{
7 collections::{BTreeMap, HashMap},
8 sync::OnceLock,
9 task::{Context, Poll},
10};
11
12use axum::{
13 body::Body,
14 http::{HeaderName, HeaderValue, Request},
15 response::Response,
16 Router as AXRouter,
17};
18use futures_util::future::BoxFuture;
19use serde::{Deserialize, Serialize};
20use serde_json::{self, json};
21use tower::{Layer, Service};
22
23use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Error, Result};
24
25static PRESETS: OnceLock<HashMap<String, BTreeMap<String, String>>> = OnceLock::new();
26fn get_presets() -> &'static HashMap<String, BTreeMap<String, String>> {
27 PRESETS.get_or_init(|| {
28 let json_data = include_str!("secure_headers.json");
29 serde_json::from_str(json_data).unwrap()
30 })
31}
32#[derive(Serialize, Deserialize, Debug, Clone)]
79pub struct SecureHeader {
80 #[serde(default)]
81 pub enable: bool,
82 #[serde(default = "default_preset")]
83 pub preset: String,
84 #[serde(default)]
85 pub overrides: Option<BTreeMap<String, String>>,
86}
87
88impl Default for SecureHeader {
89 fn default() -> Self {
90 serde_json::from_value(json!({})).unwrap()
91 }
92}
93
/// Serde default for `SecureHeader::preset`: the name of the preset used when
/// none is configured.
fn default_preset() -> String {
    String::from("github")
}
97
98impl MiddlewareLayer for SecureHeader {
99 fn name(&self) -> &'static str {
101 "secure_headers"
102 }
103
104 fn is_enabled(&self) -> bool {
106 self.enable
107 }
108
109 fn config(&self) -> serde_json::Result<serde_json::Value> {
110 serde_json::to_value(self)
111 }
112
113 fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
115 Ok(app.layer(SecureHeaders::new(self)?))
116 }
117}
118
119impl SecureHeader {
120 fn as_headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>> {
124 let mut headers = vec![];
125
126 let preset = &self.preset;
127 let p = get_presets().get(preset).ok_or_else(|| {
128 Error::Message(format!(
129 "secure_headers: a preset named `{preset}` does not exist"
130 ))
131 })?;
132
133 Self::push_headers(&mut headers, p)?;
134 if let Some(overrides) = &self.overrides {
135 Self::push_headers(&mut headers, overrides)?;
136 }
137 Ok(headers)
138 }
139
140 fn push_headers(
146 headers: &mut Vec<(HeaderName, HeaderValue)>,
147 hm: &BTreeMap<String, String>,
148 ) -> Result<()> {
149 for (k, v) in hm {
150 headers.push((
151 HeaderName::from_bytes(k.clone().as_bytes()).map_err(Box::from)?,
152 HeaderValue::from_str(v.clone().as_str()).map_err(Box::from)?,
153 ));
154 }
155 Ok(())
156 }
157}
158
159#[derive(Clone, Debug)]
162pub struct SecureHeaders {
163 headers: Vec<(HeaderName, HeaderValue)>,
164}
165
166impl SecureHeaders {
167 pub fn new(config: &SecureHeader) -> Result<Self> {
173 Ok(Self {
174 headers: config.as_headers()?,
175 })
176 }
177}
178
179impl<S> Layer<S> for SecureHeaders {
180 type Service = SecureHeadersMiddleware<S>;
181
182 fn layer(&self, inner: S) -> Self::Service {
184 SecureHeadersMiddleware {
185 inner,
186 layer: self.clone(),
187 }
188 }
189}
190
191#[derive(Clone, Debug)]
193#[must_use]
194pub struct SecureHeadersMiddleware<S> {
195 inner: S,
196 layer: SecureHeaders,
197}
198
199impl<S> Service<Request<Body>> for SecureHeadersMiddleware<S>
200where
201 S: Service<Request<Body>, Response = Response> + Send + 'static,
202 S::Future: Send + 'static,
203{
204 type Response = S::Response;
205 type Error = S::Error;
206 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
207
208 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
209 self.inner.poll_ready(cx)
210 }
211
212 fn call(&mut self, request: Request<Body>) -> Self::Future {
213 let layer = self.layer.clone();
214 let future = self.inner.call(request);
215 Box::pin(async move {
216 let mut response: Response = future.await?;
217 let headers = response.headers_mut();
218 for (k, v) in &layer.headers {
219 headers.insert(k, v.clone());
220 }
221 Ok(response)
222 })
223 }
224}
225
#[cfg(test)]
mod tests {

    use axum::{
        http::{HeaderMap, Method},
        routing::get,
        Router,
    };
    use insta::assert_debug_snapshot;
    use tower::ServiceExt;

    use super::*;

    /// Flattens a header map into sorted `name -> value` pairs so snapshots are
    /// stable.
    ///
    /// Values that cannot be rendered as a string are recorded as `""`. If a
    /// header name repeats, the last value seen is kept.
    fn normalize_headers(headers: &HeaderMap) -> BTreeMap<String, String> {
        let mut normalized = BTreeMap::new();
        for (name, value) in headers {
            normalized.insert(name.to_string(), value.to_str().unwrap_or("").to_string());
        }
        normalized
    }

    /// Builds a router with a single empty `GET /` route, wrapped in the
    /// secure-headers layer for `config`.
    fn app_for(config: &SecureHeader) -> Router {
        Router::new()
            .route("/", get(|| async {}))
            .layer(SecureHeaders::new(config).unwrap())
    }

    /// Sends `GET /` through `app` and returns the response.
    async fn get_root(app: Router) -> Response {
        let request = Request::builder()
            .uri("/")
            .method(Method::GET)
            .body(Body::empty())
            .unwrap();
        app.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn can_set_headers() {
        let config = SecureHeader {
            enable: true,
            preset: "github".to_string(),
            overrides: None,
        };
        let response = get_root(app_for(&config)).await;
        assert_debug_snapshot!(normalize_headers(response.headers()));
    }

    #[tokio::test]
    async fn can_override_headers() {
        let overrides = BTreeMap::from([
            ("X-Download-Options".to_string(), "foobar".to_string()),
            ("New-Header".to_string(), "baz".to_string()),
        ]);
        let config = SecureHeader {
            enable: true,
            preset: "github".to_string(),
            overrides: Some(overrides),
        };
        let response = get_root(app_for(&config)).await;
        assert_debug_snapshot!(normalize_headers(response.headers()));
    }

    #[tokio::test]
    async fn default_is_github_preset() {
        let response = get_root(app_for(&SecureHeader::default())).await;
        assert_debug_snapshot!(normalize_headers(response.headers()));
    }
}