use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use state::InitCell;
use yansi::Paint;
use crate::{Rocket, Request, Response, Orbit, Config};
use crate::fairing::{Fairing, Info, Kind};
use crate::http::{Header, uncased::UncasedStr};
use crate::log::PaintExt;
use crate::shield::*;
/// A fairing that attaches a configurable set of policy headers to outgoing
/// responses.
///
/// Policies are added with `enable` and removed with `disable`; the headers
/// they produce are rendered lazily and cached.
pub struct Shield {
/// Enabled policies, keyed by each policy's `NAME` (compared uncased).
policies: HashMap<&'static UncasedStr, Box<dyn SubPolicy>>,
/// Set during liftoff when TLS is enabled, the profile is not the debug
/// profile, and no `Hsts` policy has been enabled; when set, a default
/// `Hsts` header is appended to the rendered headers.
force_hsts: AtomicBool,
/// Lazily rendered header list; replaced with an empty cell whenever a
/// policy is enabled or disabled.
rendered: InitCell<Vec<Header<'static>>>,
}
impl Default for Shield {
    /// Builds a `Shield` with the default `NoSniff`, `Frame`, and
    /// `Permission` policies enabled, in that order.
    fn default() -> Self {
        let shield = Self::new();
        let shield = shield.enable(NoSniff::default());
        let shield = shield.enable(Frame::default());
        shield.enable(Permission::default())
    }
}
impl Shield {
pub fn new() -> Self {
Shield {
policies: HashMap::new(),
force_hsts: AtomicBool::new(false),
rendered: InitCell::new(),
}
}
pub fn enable<P: Policy>(mut self, policy: P) -> Self {
self.rendered = InitCell::new();
self.policies.insert(P::NAME.into(), Box::new(policy));
self
}
pub fn disable<P: Policy>(mut self) -> Self {
self.rendered = InitCell::new();
self.policies.remove(UncasedStr::new(P::NAME));
self
}
pub fn is_enabled<P: Policy>(&self) -> bool {
self.policies.contains_key(UncasedStr::new(P::NAME))
}
fn headers(&self) -> &[Header<'static>] {
self.rendered.get_or_init(|| {
let mut headers: Vec<_> = self.policies.values()
.map(|p| p.header())
.collect();
if self.force_hsts.load(Ordering::Acquire) {
headers.push(Policy::header(&Hsts::default()));
}
headers
})
}
}
#[crate::async_trait]
impl Fairing for Shield {
    /// Describes this fairing: a singleton named "Shield" that hooks liftoff
    /// and responses.
    fn info(&self) -> Info {
        let kind = Kind::Liftoff | Kind::Response | Kind::Singleton;
        Info { name: "Shield", kind }
    }

    /// Decides whether a default HSTS policy must be forced, then logs the
    /// headers Shield will emit.
    ///
    /// HSTS is forced when TLS is enabled, the active profile is not the
    /// debug profile, and no `Hsts` policy has been enabled.
    async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
        let force_hsts = rocket.config().tls_enabled()
            && rocket.figment().profile() != Config::DEBUG_PROFILE
            && !self.is_enabled::<Hsts>();

        // The flag must be stored before the headers are rendered below so
        // that the rendered list includes the forced HSTS header.
        if force_hsts {
            self.force_hsts.store(true, Ordering::Release);
        }

        let headers = self.headers();
        if headers.is_empty() {
            return;
        }

        info!("{}{}:", "🛡️ ".emoji(), "Shield".magenta());
        for rendered in headers {
            info_!("{}: {}", rendered.name(), rendered.value().primary());
        }

        if force_hsts {
            warn_!("Detected TLS-enabled liftoff without enabling HSTS.");
            warn_!("Shield has enabled a default HSTS policy.");
            info_!("To remove this warning, configure an HSTS policy.");
        }
    }

    /// Adds every rendered header to `response`, except those whose name is
    /// already present in the response; for those a warning is logged and
    /// the existing header is left untouched.
    async fn on_response<'r>(&self, _: &'r Request<'_>, response: &mut Response<'r>) {
        for policy_header in self.headers() {
            let already_set = response.headers().contains(policy_header.name());
            if already_set {
                warn!("Shield: response contains a '{}' header.", policy_header.name());
                warn_!("Refusing to overwrite existing header.");
            } else {
                response.set_header(policy_header.clone());
            }
        }
    }
}