use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use crate::fairing::{Fairing, Info, Kind};
use crate::http::{uncased::UncasedStr, Header};
use crate::shield::{Frame, Hsts, NoSniff, Permission, Policy};
use crate::trace::{Trace, TraceAll};
use crate::{Config, Orbit, Request, Response, Rocket};
/// A [`Fairing`] that attaches security-policy headers to outgoing responses.
///
/// Policies are stored keyed by their case-insensitive policy name, so enabling
/// a policy whose name is already present replaces the previous header.
pub struct Shield {
    /// Enabled policy headers, keyed by each policy's `NAME`.
    policies: HashMap<&'static UncasedStr, Header<'static>>,
    /// Set to `true` in `on_liftoff` when every endpoint is TLS-enabled, the
    /// profile is not the debug profile, and no `Hsts` policy is enabled.
    force_hsts: AtomicBool,
}
impl Clone for Shield {
fn clone(&self) -> Self {
Self {
policies: self.policies.clone(),
force_hsts: AtomicBool::from(self.force_hsts.load(Ordering::Acquire)),
}
}
}
/// The default shield starts empty and then enables the default `NoSniff`,
/// `Frame`, and `Permission` policies, in that order.
impl Default for Shield {
    fn default() -> Self {
        let shield = Shield::new();
        let shield = shield.enable(NoSniff::default());
        let shield = shield.enable(Frame::default());
        shield.enable(Permission::default())
    }
}
impl Shield {
    /// Creates a shield with no policies enabled and HSTS forcing off.
    pub fn new() -> Self {
        Self {
            policies: HashMap::new(),
            force_hsts: AtomicBool::new(false),
        }
    }

    /// Enables `policy`, replacing any previously enabled policy that shares
    /// its `NAME`.
    pub fn enable<P: Policy>(mut self, policy: P) -> Self {
        let name = UncasedStr::new(P::NAME);
        let header = policy.header();
        self.policies.insert(name, header);
        self
    }

    /// Disables the policy `P`; a no-op when `P` is not enabled.
    pub fn disable<P: Policy>(mut self) -> Self {
        let name = UncasedStr::new(P::NAME);
        let _ = self.policies.remove(name);
        self
    }

    /// Returns `true` if a policy with `P`'s `NAME` is currently enabled.
    pub fn is_enabled<P: Policy>(&self) -> bool {
        let name = UncasedStr::new(P::NAME);
        self.policies.contains_key(name)
    }
}
#[crate::async_trait]
impl Fairing for Shield {
fn info(&self) -> Info {
Info {
name: "Shield",
kind: Kind::Liftoff | Kind::Response | Kind::Singleton,
}
}
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
if self.policies.is_empty() {
return;
}
let force_hsts = rocket.endpoints().all(|v| v.is_tls())
&& rocket.figment().profile() != Config::DEBUG_PROFILE
&& !self.is_enabled::<Hsts>();
if force_hsts {
self.force_hsts.store(true, Ordering::Release);
}
span_info!("shield", policies = self.policies.len() => {
self.policies.values().trace_all_info();
if force_hsts {
warn!("Detected TLS-enabled liftoff without enabling HSTS.\n\
Shield has enabled a default HSTS policy.\n\
To remove this warning, configure an HSTS policy.");
}
})
}
async fn on_response<'r>(&self, _: &'r Request<'_>, response: &mut Response<'r>) {
for header in self.policies.values() {
if response.headers().contains(header.name()) {
span_warn!("shield", "shield refusing to overwrite existing response header" => {
header.trace_warn();
});
continue;
}
response.set_header(header.clone());
}
}
}