1pub mod header;
42
43use std::future::Future;
44use std::pin::Pin;
45use std::task::{Context, Poll};
46
47use futures::ready;
48use http::header::{AsHeaderName, HeaderName, InvalidHeaderValue};
49use http::{HeaderMap, HeaderValue, Request, Response};
50use pin_project_lite::pin_project;
51use tower_layer::Layer;
52use tower_service::Service;
53
54use crate::header::{
55 ContentSecurityPolicy, CrossOriginEmbedderPolicy, CrossOriginOpenerPolicy,
56 CrossOriginResourcePolicy, ExpectCt, OriginAgentCluster, ReferrerPolicy,
57 StrictTransportSecurity, XContentTypeOptions, XDnsPrefetchControl, XDownloadOptions,
58 XFrameOptions, XPermittedCrossDomainPolicies, XXSSProtection,
59};
60
61pub trait IntoHeader {
62 fn header_name(&self) -> HeaderName;
63 fn header_value(&self) -> Result<HeaderValue, InvalidHeaderValue>;
64}
65
66#[derive(Debug, Clone)]
68pub struct HelmetLayer {
69 headers: HeaderMap,
70}
71
72impl HelmetLayer {
73 pub fn blank() -> Self {
75 Self {
76 headers: HeaderMap::new(),
77 }
78 }
79
80 pub fn with_defaults() -> Self {
82 let mut layer = Self::blank();
83 layer
84 .enable(ContentSecurityPolicy::default())
85 .enable(CrossOriginEmbedderPolicy::default())
86 .enable(CrossOriginOpenerPolicy::default())
87 .enable(CrossOriginResourcePolicy::default())
88 .enable(ExpectCt::default())
89 .enable(OriginAgentCluster::default())
90 .enable(ReferrerPolicy::default())
91 .enable(StrictTransportSecurity::default())
92 .enable(XContentTypeOptions::default())
93 .enable(XDnsPrefetchControl::default())
94 .enable(XDownloadOptions::default())
95 .enable(XFrameOptions::default())
96 .enable(XPermittedCrossDomainPolicies::default())
97 .enable(XXSSProtection::default());
98
99 layer
100 }
101
102 pub fn enable(&mut self, h: impl IntoHeader) -> &mut Self {
103 self.headers
104 .insert(h.header_name(), h.header_value().unwrap());
105 self
106 }
107
108 pub fn remove<K>(&mut self, key: K) -> &mut Self
109 where
110 K: AsHeaderName,
111 {
112 self.headers.remove(key);
113 self
114 }
115}
116
117impl<S> Layer<S> for HelmetLayer {
118 type Service = HelmetService<S>;
119
120 fn layer(&self, service: S) -> Self::Service {
121 HelmetService {
122 inner: service,
123 headers: self.headers.clone(),
124 }
125 }
126}
127
128#[derive(Debug, Clone)]
129pub struct HelmetService<S> {
130 inner: S,
131 headers: HeaderMap,
132}
133
134impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for HelmetService<S>
135where
136 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
137{
138 type Response = S::Response;
139 type Error = S::Error;
140 type Future = ResponseFuture<S::Future>;
141
142 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
143 self.inner.poll_ready(cx)
144 }
145
146 fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
147 ResponseFuture {
148 future: self.inner.call(request),
149 headers: self.headers.clone(),
150 }
151 }
152}
153
154pin_project! {
155 #[derive(Debug)]
157 pub struct ResponseFuture<F> {
158 #[pin]
159 future: F,
160
161 headers: HeaderMap,
162 }
163}
164
165impl<F, ResBody, E> Future for ResponseFuture<F>
166where
167 F: Future<Output = Result<Response<ResBody>, E>>,
168{
169 type Output = F::Output;
170
171 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
172 let this = self.project();
173 let mut res: Response<ResBody> = ready!(this.future.poll(cx)?);
174 let headers = res.headers_mut();
175
176 for (name, value) in this.headers {
177 headers.insert(name, value.clone());
178 }
179
180 Poll::Ready(Ok(res))
181 }
182}