rocket_community/shield/shield.rs
1use std::collections::HashMap;
2use std::sync::atomic::{AtomicBool, Ordering};
3
4use crate::fairing::{Fairing, Info, Kind};
5use crate::http::{uncased::UncasedStr, Header};
6use crate::shield::{Frame, Hsts, NoSniff, Permission, Policy};
7use crate::trace::{Trace, TraceAll};
8use crate::{Config, Orbit, Request, Response, Rocket};
9
10/// A [`Fairing`] that injects browser security and privacy headers into all
11/// outgoing responses.
12///
13/// # Usage
14///
15/// To use `Shield`, first construct an instance of it. To use the default
16/// set of headers, construct with [`Shield::default()`](#method.default).
17/// For an instance with no preset headers, use [`Shield::new()`]. To
18/// enable an additional header, use [`enable()`](Shield::enable()), and to
19/// disable a header, use [`disable()`](Shield::disable()):
20///
21/// ```rust
22/// # extern crate rocket_community as rocket;
23///
24/// use rocket::shield::Shield;
25/// use rocket::shield::{XssFilter, ExpectCt};
26///
27/// // A `Shield` with the default headers:
28/// let shield = Shield::default();
29///
30/// // A `Shield` with the default headers minus `XssFilter`:
31/// let shield = Shield::default().disable::<XssFilter>();
32///
33/// // A `Shield` with the default headers plus `ExpectCt`.
34/// let shield = Shield::default().enable(ExpectCt::default());
35///
36/// // A `Shield` with only `XssFilter` and `ExpectCt`.
/// let shield = Shield::new()
38/// .enable(XssFilter::default())
39/// .enable(ExpectCt::default());
40/// ```
41///
42/// Then, attach the instance of `Shield` to your application's instance of
43/// `Rocket`:
44///
45/// ```rust
46/// # extern crate rocket_community as rocket;
47/// # use rocket::shield::Shield;
48/// # let shield = Shield::default();
49/// rocket::build()
50/// // ...
51/// .attach(shield)
52/// # ;
53/// ```
54///
55/// The fairing will inject all enabled headers into all outgoing responses
56/// _unless_ the response already contains a header with the same name. If it
57/// does contain the header, a warning is emitted, and the header is not
58/// overwritten.
59///
60/// # TLS and HSTS
61///
62/// If TLS is configured and enabled when the application is launched in a
63/// non-debug profile, HSTS is automatically enabled with its default policy and
64/// a warning is logged. To get rid of this warning, explicitly
65/// [`Shield::enable()`] an [`Hsts`] policy.
66pub struct Shield {
67 /// Enabled policies where the key is the header name.
68 policies: HashMap<&'static UncasedStr, Header<'static>>,
69 /// Whether to enforce HSTS even though the user didn't enable it.
70 force_hsts: AtomicBool,
71}
72
73impl Clone for Shield {
74 fn clone(&self) -> Self {
75 Self {
76 policies: self.policies.clone(),
77 force_hsts: AtomicBool::from(self.force_hsts.load(Ordering::Acquire)),
78 }
79 }
80}
81
82impl Default for Shield {
83 /// Returns a new `Shield` instance. See the [table] for a description
84 /// of the policies used by default.
85 ///
86 /// [table]: ./#supported-headers
87 ///
88 /// # Example
89 ///
90 /// ```rust
91 /// # extern crate rocket_community as rocket;
92 ///
93 /// use rocket::shield::Shield;
94 ///
95 /// let shield = Shield::default();
96 /// ```
97 fn default() -> Self {
98 Shield::new()
99 .enable(NoSniff::default())
100 .enable(Frame::default())
101 .enable(Permission::default())
102 }
103}
104
105impl Shield {
106 /// Returns an instance of `Shield` with no headers enabled.
107 ///
108 /// # Example
109 ///
110 /// ```rust
111 /// # extern crate rocket_community as rocket;
112 ///
113 /// use rocket::shield::Shield;
114 ///
115 /// let shield = Shield::new();
116 /// ```
117 pub fn new() -> Self {
118 Shield {
119 policies: HashMap::new(),
120 force_hsts: AtomicBool::new(false),
121 }
122 }
123
124 /// Enables the policy header `policy`.
125 ///
126 /// If the policy was previously enabled, the configuration is replaced
127 /// with that of `policy`.
128 ///
129 /// # Example
130 ///
131 /// ```rust
132 /// # extern crate rocket_community as rocket;
133 ///
134 /// use rocket::shield::Shield;
135 /// use rocket::shield::NoSniff;
136 ///
137 /// let shield = Shield::new().enable(NoSniff::default());
138 /// ```
139 pub fn enable<P: Policy>(mut self, policy: P) -> Self {
140 self.policies.insert(P::NAME.into(), policy.header());
141 self
142 }
143
144 /// Disables the policy header `policy`.
145 ///
146 /// # Example
147 ///
148 /// ```rust
149 /// # extern crate rocket_community as rocket;
150 ///
151 /// use rocket::shield::Shield;
152 /// use rocket::shield::NoSniff;
153 ///
154 /// let shield = Shield::default().disable::<NoSniff>();
155 /// ```
156 pub fn disable<P: Policy>(mut self) -> Self {
157 self.policies.remove(UncasedStr::new(P::NAME));
158 self
159 }
160
161 /// Returns `true` if the policy `P` is enabled.
162 ///
163 /// # Example
164 ///
165 /// ```rust
166 /// # extern crate rocket_community as rocket;
167 ///
168 /// use rocket::shield::Shield;
169 /// use rocket::shield::{Permission, NoSniff, Frame};
170 /// use rocket::shield::{Prefetch, ExpectCt, Referrer};
171 ///
172 /// let shield = Shield::default();
173 ///
174 /// assert!(shield.is_enabled::<NoSniff>());
175 /// assert!(shield.is_enabled::<Frame>());
176 /// assert!(shield.is_enabled::<Permission>());
177 ///
178 /// assert!(!shield.is_enabled::<Prefetch>());
179 /// assert!(!shield.is_enabled::<ExpectCt>());
180 /// assert!(!shield.is_enabled::<Referrer>());
181 /// ```
182 pub fn is_enabled<P: Policy>(&self) -> bool {
183 self.policies.contains_key(UncasedStr::new(P::NAME))
184 }
185}
186
187#[crate::async_trait]
188impl Fairing for Shield {
189 fn info(&self) -> Info {
190 Info {
191 name: "Shield",
192 kind: Kind::Liftoff | Kind::Response | Kind::Singleton,
193 }
194 }
195
196 async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
197 if self.policies.is_empty() {
198 return;
199 }
200
201 let force_hsts = rocket.endpoints().all(|v| v.is_tls())
202 && rocket.figment().profile() != Config::DEBUG_PROFILE
203 && !self.is_enabled::<Hsts>();
204
205 if force_hsts {
206 self.force_hsts.store(true, Ordering::Release);
207 }
208
209 span_info!("shield", policies = self.policies.len() => {
210 self.policies.values().trace_all_info();
211
212 if force_hsts {
213 warn!("Detected TLS-enabled liftoff without enabling HSTS.\n\
214 Shield has enabled a default HSTS policy.\n\
215 To remove this warning, configure an HSTS policy.");
216 }
217 })
218 }
219
220 async fn on_response<'r>(&self, _: &'r Request<'_>, response: &mut Response<'r>) {
221 // Set all of the headers in `self.policies` in `response` as long as
222 // the header is not already in the response.
223 for header in self.policies.values() {
224 if response.headers().contains(header.name()) {
225 span_warn!("shield", "shield refusing to overwrite existing response header" => {
226 header.trace_warn();
227 });
228
229 continue;
230 }
231
232 response.set_header(header.clone());
233 }
234 }
235}