rocket_community/shield/
shield.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use crate::fairing::{Fairing, Info, Kind};
5use crate::http::{uncased::UncasedStr, Header};
6use crate::shield::{Frame, Hsts, NoSniff, Permission, Policy};
7use crate::trace::{Trace, TraceAll};
8use crate::{Config, Orbit, Request, Response, Rocket};
9
10/// A [`Fairing`] that injects browser security and privacy headers into all
11/// outgoing responses.
12///
13/// # Usage
14///
15/// To use `Shield`, first construct an instance of it. To use the default
16/// set of headers, construct with [`Shield::default()`](#method.default).
17/// For an instance with no preset headers, use [`Shield::new()`]. To
18/// enable an additional header, use [`enable()`](Shield::enable()), and to
19/// disable a header, use [`disable()`](Shield::disable()):
20///
21/// ```rust
22/// # extern crate rocket_community as rocket;
23///
24/// use rocket::shield::Shield;
25/// use rocket::shield::{XssFilter, ExpectCt};
26///
27/// // A `Shield` with the default headers:
28/// let shield = Shield::default();
29///
30/// // A `Shield` with the default headers minus `XssFilter`:
31/// let shield = Shield::default().disable::<XssFilter>();
32///
33/// // A `Shield` with the default headers plus `ExpectCt`.
34/// let shield = Shield::default().enable(ExpectCt::default());
35///
36/// // A `Shield` with only `XssFilter` and `ExpectCt`.
37/// let shield = Shield::default()
38///     .enable(XssFilter::default())
39///     .enable(ExpectCt::default());
40/// ```
41///
42/// Then, attach the instance of `Shield` to your application's instance of
43/// `Rocket`:
44///
45/// ```rust
46/// # extern crate rocket_community as rocket;
47/// # use rocket::shield::Shield;
48/// # let shield = Shield::default();
49/// rocket::build()
50///     // ...
51///     .attach(shield)
52/// # ;
53/// ```
54///
55/// The fairing will inject all enabled headers into all outgoing responses
56/// _unless_ the response already contains a header with the same name. If it
57/// does contain the header, a warning is emitted, and the header is not
58/// overwritten.
59///
60/// # TLS and HSTS
61///
62/// If TLS is configured and enabled when the application is launched in a
63/// non-debug profile, HSTS is automatically enabled with its default policy and
64/// a warning is logged. To get rid of this warning, explicitly
65/// [`Shield::enable()`] an [`Hsts`] policy.
66pub struct Shield {
67    /// Enabled policies where the key is the header name.
68    policies: HashMap<&'static UncasedStr, Header<'static>>,
69    /// Whether to enforce HSTS even though the user didn't enable it.
70    force_hsts: AtomicBool,
71}
72
73impl Clone for Shield {
74    fn clone(&self) -> Self {
75        Self {
76            policies: self.policies.clone(),
77            force_hsts: AtomicBool::from(self.force_hsts.load(Ordering::Acquire)),
78        }
79    }
80}
81
82impl Default for Shield {
83    /// Returns a new `Shield` instance. See the [table] for a description
84    /// of the policies used by default.
85    ///
86    /// [table]: ./#supported-headers
87    ///
88    /// # Example
89    ///
90    /// ```rust
91    /// # extern crate rocket_community as rocket;
92    ///
93    /// use rocket::shield::Shield;
94    ///
95    /// let shield = Shield::default();
96    /// ```
97    fn default() -> Self {
98        Shield::new()
99            .enable(NoSniff::default())
100            .enable(Frame::default())
101            .enable(Permission::default())
102    }
103}
104
105impl Shield {
106    /// Returns an instance of `Shield` with no headers enabled.
107    ///
108    /// # Example
109    ///
110    /// ```rust
111    /// # extern crate rocket_community as rocket;
112    ///
113    /// use rocket::shield::Shield;
114    ///
115    /// let shield = Shield::new();
116    /// ```
117    pub fn new() -> Self {
118        Shield {
119            policies: HashMap::new(),
120            force_hsts: AtomicBool::new(false),
121        }
122    }
123
124    /// Enables the policy header `policy`.
125    ///
126    /// If the policy was previously enabled, the configuration is replaced
127    /// with that of `policy`.
128    ///
129    /// # Example
130    ///
131    /// ```rust
132    /// # extern crate rocket_community as rocket;
133    ///
134    /// use rocket::shield::Shield;
135    /// use rocket::shield::NoSniff;
136    ///
137    /// let shield = Shield::new().enable(NoSniff::default());
138    /// ```
139    pub fn enable<P: Policy>(mut self, policy: P) -> Self {
140        self.policies.insert(P::NAME.into(), policy.header());
141        self
142    }
143
144    /// Disables the policy header `policy`.
145    ///
146    /// # Example
147    ///
148    /// ```rust
149    /// # extern crate rocket_community as rocket;
150    ///
151    /// use rocket::shield::Shield;
152    /// use rocket::shield::NoSniff;
153    ///
154    /// let shield = Shield::default().disable::<NoSniff>();
155    /// ```
156    pub fn disable<P: Policy>(mut self) -> Self {
157        self.policies.remove(UncasedStr::new(P::NAME));
158        self
159    }
160
161    /// Returns `true` if the policy `P` is enabled.
162    ///
163    /// # Example
164    ///
165    /// ```rust
166    /// # extern crate rocket_community as rocket;
167    ///
168    /// use rocket::shield::Shield;
169    /// use rocket::shield::{Permission, NoSniff, Frame};
170    /// use rocket::shield::{Prefetch, ExpectCt, Referrer};
171    ///
172    /// let shield = Shield::default();
173    ///
174    /// assert!(shield.is_enabled::<NoSniff>());
175    /// assert!(shield.is_enabled::<Frame>());
176    /// assert!(shield.is_enabled::<Permission>());
177    ///
178    /// assert!(!shield.is_enabled::<Prefetch>());
179    /// assert!(!shield.is_enabled::<ExpectCt>());
180    /// assert!(!shield.is_enabled::<Referrer>());
181    /// ```
182    pub fn is_enabled<P: Policy>(&self) -> bool {
183        self.policies.contains_key(UncasedStr::new(P::NAME))
184    }
185}
186
187#[crate::async_trait]
188impl Fairing for Shield {
189    fn info(&self) -> Info {
190        Info {
191            name: "Shield",
192            kind: Kind::Liftoff | Kind::Response | Kind::Singleton,
193        }
194    }
195
196    async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
197        if self.policies.is_empty() {
198            return;
199        }
200
201        let force_hsts = rocket.endpoints().all(|v| v.is_tls())
202            && rocket.figment().profile() != Config::DEBUG_PROFILE
203            && !self.is_enabled::<Hsts>();
204
205        if force_hsts {
206            self.force_hsts.store(true, Ordering::Release);
207        }
208
209        span_info!("shield", policies = self.policies.len() => {
210            self.policies.values().trace_all_info();
211
212            if force_hsts {
213                warn!("Detected TLS-enabled liftoff without enabling HSTS.\n\
214                    Shield has enabled a default HSTS policy.\n\
215                    To remove this warning, configure an HSTS policy.");
216            }
217        })
218    }
219
220    async fn on_response<'r>(&self, _: &'r Request<'_>, response: &mut Response<'r>) {
221        // Set all of the headers in `self.policies` in `response` as long as
222        // the header is not already in the response.
223        for header in self.policies.values() {
224            if response.headers().contains(header.name()) {
225                span_warn!("shield", "shield refusing to overwrite existing response header" => {
226                    header.trace_warn();
227                });
228
229                continue;
230            }
231
232            response.set_header(header.clone());
233        }
234    }
235}