1use std::{
26 future::Future,
27 pin::Pin,
28 task::{ready, Context, Poll},
29};
30
31use http::{header::HeaderName, HeaderMap, HeaderValue, Request, Response};
32use pin_project_lite::pin_project;
33use tower_service::Service;
34
35use helmet_core::Helmet as HelmetCore;
36
37pub use helmet_core::*;
39
40#[derive(Clone)]
65pub struct HelmetLayer {
66 headers: HeaderMap,
67}
68
69impl HelmetLayer {
70 pub fn new(core: HelmetCore) -> Self {
71 let headers = core
72 .headers
73 .iter()
74 .map(|header| {
75 (
76 HeaderName::try_from(header.0).expect("invalid header name"),
77 HeaderValue::try_from(&header.1).expect("invalid header value"),
78 )
79 })
80 .collect();
81 Self { headers }
82 }
83}
84
85impl<S> tower::layer::Layer<S> for HelmetLayer {
86 type Service = HelmetInner<S>;
87
88 fn layer(&self, inner: S) -> Self::Service {
89 let header_map = self.headers.clone();
90
91 HelmetInner { header_map, inner }
92 }
93}
94
95#[derive(Clone)]
96pub struct HelmetInner<S> {
97 header_map: HeaderMap,
98 inner: S,
99}
100
101impl<S> HelmetInner<S> {
102 pub fn new(inner: S) -> Self {
103 let header_map = HeaderMap::new();
104
105 Self { header_map, inner }
106 }
107}
108
109impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for HelmetInner<S>
110where
111 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
112 ResBody: Default,
113{
114 type Response = S::Response;
115 type Error = S::Error;
116 type Future = ResponseFuture<S::Future>;
117
118 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
119 self.inner.poll_ready(cx)
120 }
121
122 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
123 ResponseFuture {
124 future: self.inner.call(req),
125 headers: self.header_map.clone(),
126 }
127 }
128}
129
pin_project! {
    /// Response future returned by [`HelmetInner`]'s `Service::call`.
    ///
    /// It drives the inner service's future and applies the stored headers to the
    /// response once that future resolves with `Ok`.
    #[derive(Debug)]
    pub struct ResponseFuture<F> {
        // Structurally pinned: the inner future is polled through `Pin`.
        #[pin]
        future: F,
        // Headers to apply to the response. Not pinned, so projected as `&mut`.
        headers: HeaderMap,
    }
}
139
140impl<F, ResBody, E> Future for ResponseFuture<F>
141where
142 F: Future<Output = Result<Response<ResBody>, E>>,
143{
144 type Output = F::Output;
145
146 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
147 let this = self.project();
148 let mut res = ready!(this.future.poll(cx)?);
149
150 res.headers_mut()
151 .extend(this.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
152
153 Poll::Ready(Ok(res))
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    use axum::{routing::get, Router};
    use axum_test::TestServer;
    use http::{header, HeaderValue};

    /// End-to-end check: a route wrapped in `HelmetLayer` still answers normally and
    /// carries each configured security header with its exact value.
    #[tokio::test]
    async fn test_helmet() {
        let helmet = Helmet::new()
            .add(helmet_core::XContentTypeOptions::nosniff())
            .add(helmet_core::XFrameOptions::same_origin())
            .add(helmet_core::XXSSProtection::on().mode_block());

        let app = Router::new()
            .route("/", get(|| async { "Hello, world!" }))
            .layer(HelmetLayer::new(helmet));

        let server = TestServer::new(app).expect("failed to create test server");
        let res = server.get("/").await;

        assert_eq!(res.status_code(), 200);

        // Header name paired with the exact value the configured policy emits.
        let expected = [
            (header::X_CONTENT_TYPE_OPTIONS, "nosniff"),
            (header::X_FRAME_OPTIONS, "SAMEORIGIN"),
            (header::X_XSS_PROTECTION, "1; mode=block"),
        ];
        for (name, value) in expected {
            assert_eq!(
                res.headers().get(&name),
                Some(&HeaderValue::from_static(value)),
                "unexpected value for header `{}`",
                name
            );
        }
    }
}