tower_helmet/
lib.rs

1//! # Overview
2//!
3//! `tower-helmet` helps you secure your tower server by setting various HTTP headers. _It's not a
4//! silver bullet_, but it can help!
5//!
6//! You can find a list of all available headers under the [header] module. By default (with
7//! [HelmetLayer::with_defaults]) **all of them** are enabled. Please take a good look at
8//! [ContentSecurityPolicy]. Most of the time you will need to adapt this one to your needs.
9//!
10//! # Examples
11//!
12//! ```
13//! use std::collections::HashMap;
14//!
15//! use tower_helmet::header::{ContentSecurityPolicy, ExpectCt, XFrameOptions};
16//! use tower_helmet::HelmetLayer;
17//!
18//! // default layer with all security headers active
19//! let layer = HelmetLayer::with_defaults();
20//!
21//! // default layer with csp customizations applied
22//! let mut directives = HashMap::new();
23//! directives.insert("default-src", vec!["'self'", "https://example.com"]);
24//! directives.insert("img-src", vec!["'self'", "data:", "https://example.com"]);
25//! directives.insert(
26//!     "script-src",
27//!     vec!["'self'", "'unsafe-inline'", "https://example.com"],
28//! );
29//! let csp = ContentSecurityPolicy {
30//!     directives,
31//!     ..Default::default()
32//! };
33//!
//! let mut layer = HelmetLayer::with_defaults();
//! layer.enable(csp);
//!
//! // completely blank layer, selectively enable and add headers
//! let mut layer = HelmetLayer::blank();
//! layer
//!     .enable(XFrameOptions::SameOrigin)
//!     .enable(ExpectCt::default());
40//! ```
41pub mod header;
42
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures::ready;
use http::header::{AsHeaderName, HeaderName, InvalidHeaderValue};
use http::{HeaderMap, HeaderValue, Request, Response};
use pin_project_lite::pin_project;
use tower_layer::Layer;
use tower_service::Service;
53
54use crate::header::{
55    ContentSecurityPolicy, CrossOriginEmbedderPolicy, CrossOriginOpenerPolicy,
56    CrossOriginResourcePolicy, ExpectCt, OriginAgentCluster, ReferrerPolicy,
57    StrictTransportSecurity, XContentTypeOptions, XDnsPrefetchControl, XDownloadOptions,
58    XFrameOptions, XPermittedCrossDomainPolicies, XXSSProtection,
59};
60
61pub trait IntoHeader {
62    fn header_name(&self) -> HeaderName;
63    fn header_value(&self) -> Result<HeaderValue, InvalidHeaderValue>;
64}
65
66/// HelmetLayer
67#[derive(Debug, Clone)]
68pub struct HelmetLayer {
69    headers: HeaderMap,
70}
71
72impl HelmetLayer {
73    /// Helmet without any headers added in by default. See [`enable`] for enabling headers.
74    pub fn blank() -> Self {
75        Self {
76            headers: HeaderMap::new(),
77        }
78    }
79
80    /// Helmet with most of the headers already added with the base configuration.
81    pub fn with_defaults() -> Self {
82        let mut layer = Self::blank();
83        layer
84            .enable(ContentSecurityPolicy::default())
85            .enable(CrossOriginEmbedderPolicy::default())
86            .enable(CrossOriginOpenerPolicy::default())
87            .enable(CrossOriginResourcePolicy::default())
88            .enable(ExpectCt::default())
89            .enable(OriginAgentCluster::default())
90            .enable(ReferrerPolicy::default())
91            .enable(StrictTransportSecurity::default())
92            .enable(XContentTypeOptions::default())
93            .enable(XDnsPrefetchControl::default())
94            .enable(XDownloadOptions::default())
95            .enable(XFrameOptions::default())
96            .enable(XPermittedCrossDomainPolicies::default())
97            .enable(XXSSProtection::default());
98
99        layer
100    }
101
102    pub fn enable(&mut self, h: impl IntoHeader) -> &mut Self {
103        self.headers
104            .insert(h.header_name(), h.header_value().unwrap());
105        self
106    }
107
108    pub fn remove<K>(&mut self, key: K) -> &mut Self
109    where
110        K: AsHeaderName,
111    {
112        self.headers.remove(key);
113        self
114    }
115}
116
117impl<S> Layer<S> for HelmetLayer {
118    type Service = HelmetService<S>;
119
120    fn layer(&self, service: S) -> Self::Service {
121        HelmetService {
122            inner: service,
123            headers: self.headers.clone(),
124        }
125    }
126}
127
128#[derive(Debug, Clone)]
129pub struct HelmetService<S> {
130    inner: S,
131    headers: HeaderMap,
132}
133
134impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for HelmetService<S>
135where
136    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
137{
138    type Response = S::Response;
139    type Error = S::Error;
140    type Future = ResponseFuture<S::Future>;
141
142    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
143        self.inner.poll_ready(cx)
144    }
145
146    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
147        ResponseFuture {
148            future: self.inner.call(request),
149            headers: self.headers.clone(),
150        }
151    }
152}
153
154pin_project! {
155    /// Response future for [`HelmetService`].
156    #[derive(Debug)]
157    pub struct ResponseFuture<F> {
158        #[pin]
159        future: F,
160
161        headers: HeaderMap,
162    }
163}
164
165impl<F, ResBody, E> Future for ResponseFuture<F>
166where
167    F: Future<Output = Result<Response<ResBody>, E>>,
168{
169    type Output = F::Output;
170
171    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
172        let this = self.project();
173        let mut res: Response<ResBody> = ready!(this.future.poll(cx)?);
174        let headers = res.headers_mut();
175
176        for (name, value) in this.headers {
177            headers.insert(name, value.clone());
178        }
179
180        Poll::Ready(Ok(res))
181    }
182}