axum_helmet/
lib.rs

//! Helmet middleware for axum.
//!
//! Wraps a [`helmet_core::Helmet`] configuration in a tower layer that adds the
//! configured security headers to responses produced by the wrapped service.
//!
//! # Example
//!
//! ```no_run
//! use axum::{routing::get, Router};
//! use axum_helmet::{Helmet, HelmetLayer};
//!
//! #[tokio::main]
//! async fn main() {
//!     let app = Router::new()
//!         .route("/", get(|| async { "Hello, world!" }))
//!         .layer(HelmetLayer::new(
//!             Helmet::new()
//!                 .add(helmet_core::XContentTypeOptions::nosniff())
//!                 .add(helmet_core::XFrameOptions::same_origin())
//!                 .add(helmet_core::XXSSProtection::on().mode_block()),
//!         ));
//!
//!     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
//!     axum::serve(listener, app).await.unwrap();
//! }
//! ```
25use std::{
26    future::Future,
27    pin::Pin,
28    task::{ready, Context, Poll},
29};
30
31use http::{header::HeaderName, HeaderMap, HeaderValue, Request, Response};
32use pin_project_lite::pin_project;
33use tower_service::Service;
34
35use helmet_core::Helmet as HelmetCore;
36
37// re-export helmet_core::* for convenience
38pub use helmet_core::*;
39
40/// Create a [`tower::layer::Layer`] that adds helmet headers to responses.
41/// See [`helmet_core::Helmet`] for more details.
42///
43/// # Example
44///
45/// ```no_run
46/// use axum::{routing::get, Router};
47/// use axum_helmet::{Helmet, HelmetLayer};
48///
49/// #[tokio::main]
50/// async fn main() {
51///     let app = Router::new()
52///         .route("/", get(|| async { "Hello, world!" }))
53///         .layer(HelmetLayer::new(
54///             Helmet::new()
55///                 .add(helmet_core::XContentTypeOptions::nosniff())
56///                 .add(helmet_core::XFrameOptions::same_origin())
57///                 .add(helmet_core::XXSSProtection::on().mode_block()),
58///         ));
59///
60///     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
61///     axum::serve(listener, app).await.unwrap();
62/// }
63/// ```
64#[derive(Clone)]
65pub struct HelmetLayer {
66    headers: HeaderMap,
67}
68
69impl HelmetLayer {
70    pub fn new(core: HelmetCore) -> Self {
71        let headers = core
72            .headers
73            .iter()
74            .map(|header| {
75                (
76                    HeaderName::try_from(header.0).expect("invalid header name"),
77                    HeaderValue::try_from(&header.1).expect("invalid header value"),
78                )
79            })
80            .collect();
81        Self { headers }
82    }
83}
84
85impl<S> tower::layer::Layer<S> for HelmetLayer {
86    type Service = HelmetInner<S>;
87
88    fn layer(&self, inner: S) -> Self::Service {
89        let header_map = self.headers.clone();
90
91        HelmetInner { header_map, inner }
92    }
93}
94
95#[derive(Clone)]
96pub struct HelmetInner<S> {
97    header_map: HeaderMap,
98    inner: S,
99}
100
101impl<S> HelmetInner<S> {
102    pub fn new(inner: S) -> Self {
103        let header_map = HeaderMap::new();
104
105        Self { header_map, inner }
106    }
107}
108
109impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for HelmetInner<S>
110where
111    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
112    ResBody: Default,
113{
114    type Response = S::Response;
115    type Error = S::Error;
116    type Future = ResponseFuture<S::Future>;
117
118    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
119        self.inner.poll_ready(cx)
120    }
121
122    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
123        ResponseFuture {
124            future: self.inner.call(req),
125            headers: self.header_map.clone(),
126        }
127    }
128}
129
130pin_project! {
131    /// Response future for [`SetResponseHeader`].
132    #[derive(Debug)]
133    pub struct ResponseFuture<F> {
134        #[pin]
135        future: F,
136        headers: HeaderMap,
137    }
138}
139
140impl<F, ResBody, E> Future for ResponseFuture<F>
141where
142    F: Future<Output = Result<Response<ResBody>, E>>,
143{
144    type Output = F::Output;
145
146    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
147        let this = self.project();
148        let mut res = ready!(this.future.poll(cx)?);
149
150        res.headers_mut()
151            .extend(this.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
152
153        Poll::Ready(Ok(res))
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    use axum::{routing::get, Router};
    use axum_test::TestServer;
    use http::{header, HeaderValue};

    /// Builds a single-route app wrapped in a `HelmetLayer` carrying three headers.
    fn helmet_app() -> Router {
        Router::new()
            .route("/", get(|| async { "Hello, world!" }))
            .layer(HelmetLayer::new(
                Helmet::new()
                    .add(helmet_core::XContentTypeOptions::nosniff())
                    .add(helmet_core::XFrameOptions::same_origin())
                    .add(helmet_core::XXSSProtection::on().mode_block()),
            ))
    }

    #[tokio::test]
    async fn test_helmet() {
        let server = TestServer::new(helmet_app()).expect("failed to create test server");
        let response = server.get("/").await;

        assert_eq!(response.status_code(), 200);

        let expected = [
            (header::X_CONTENT_TYPE_OPTIONS, "nosniff"),
            (header::X_FRAME_OPTIONS, "SAMEORIGIN"),
            (header::X_XSS_PROTECTION, "1; mode=block"),
        ];
        for (name, value) in expected {
            assert_eq!(
                response.headers().get(&name),
                Some(&HeaderValue::from_static(value)),
                "unexpected value for header {}",
                name
            );
        }
    }
}