use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use rocket::http::uncased::UncasedStr;
use rocket::fairing::{Fairing, Info, Kind};
use rocket::{Request, Response, Rocket};
use helmet::*;
/// A [`Fairing`] that applies a set of header policies to outgoing responses.
///
/// Each enabled policy contributes at most one header per response. A header
/// already present on the response is never overwritten.
pub struct SpaceHelmet {
    // Enabled policies keyed by their header name; `UncasedStr` keys make
    // lookups by name case-insensitive, so at most one policy per header.
    policies: HashMap<&'static UncasedStr, Box<dyn SubPolicy>>,
    // Set during launch when TLS is enabled outside the dev environment and no
    // HSTS policy was configured; while set, a default HSTS header is added to
    // responses that lack one.
    force_hsts: AtomicBool,
}
impl Default for SpaceHelmet {
    /// Builds a helmet with the `NoSniff`, `Frame` and `XssFilter` policies
    /// enabled, each in its default configuration.
    ///
    /// Use `SpaceHelmet::new()` instead to start from an empty policy set.
    fn default() -> Self {
        let helmet = Self::new();
        let helmet = helmet.enable(NoSniff::default());
        let helmet = helmet.enable(Frame::default());
        helmet.enable(XssFilter::default())
    }
}
impl SpaceHelmet {
    /// Returns a helmet with no policies enabled.
    ///
    /// Unlike `SpaceHelmet::default()`, no headers are set until policies are
    /// added with [`SpaceHelmet::enable()`].
    pub fn new() -> Self {
        SpaceHelmet {
            policies: HashMap::new(),
            force_hsts: AtomicBool::new(false),
        }
    }

    /// Enables `policy` and returns the updated helmet.
    ///
    /// Policies are keyed case-insensitively by `P::NAME`. Enabling a policy
    /// whose name is already present replaces the previously enabled one.
    #[must_use = "`enable` consumes the helmet and returns the updated one"]
    pub fn enable<P: Policy>(mut self, policy: P) -> Self {
        self.policies.insert(P::NAME.into(), Box::new(policy));
        self
    }

    /// Disables the policy of type `P`, if enabled, and returns the updated helmet.
    #[must_use = "`disable` consumes the helmet and returns the updated one"]
    pub fn disable<P: Policy>(mut self) -> Self {
        self.policies.remove(UncasedStr::new(P::NAME));
        self
    }

    /// Returns `true` if a policy named `P::NAME` (case-insensitively) is enabled.
    pub fn is_enabled<P: Policy>(&self) -> bool {
        self.policies.contains_key(UncasedStr::new(P::NAME))
    }

    /// Sets the header of every enabled policy on `response`.
    ///
    /// A header that the response already carries is left untouched, and a
    /// warning is logged instead. When `force_hsts` is set, a default `Hsts`
    /// header is also added, again only if the response lacks one.
    fn apply(&self, response: &mut Response) {
        for policy in self.policies.values() {
            let name = policy.name();
            if response.headers().contains(name.as_str()) {
                warn!("Space Helmet: response contains a '{}' header.", name);
                warn_!("Refusing to overwrite existing header.");
                continue;
            }

            response.set_header(policy.header());
        }

        // `force_hsts` is a standalone flag with no other data published
        // alongside it, so `Relaxed` ordering is sufficient.
        let needs_hsts = self.force_hsts.load(Ordering::Relaxed)
            && !response.headers().contains(Hsts::NAME);
        if needs_hsts {
            response.set_header(&Hsts::default());
        }
    }
}
impl Fairing for SpaceHelmet {
    /// Registers the fairing as "Space Helmet" for both response and launch events.
    fn info(&self) -> Info {
        let kind = Kind::Response | Kind::Launch;
        Info { name: "Space Helmet", kind }
    }

    /// Applies the enabled header policies to each outgoing response.
    fn on_response(&self, _request: &Request, response: &mut Response) {
        self.apply(response);
    }

    /// Turns on a default HSTS header when the application launches with TLS
    /// outside the dev environment but no HSTS policy was configured.
    fn on_launch(&self, rocket: &Rocket) {
        let config = rocket.config();
        let tls_outside_dev = config.tls_enabled() && !config.environment.is_dev();
        if !tls_outside_dev || self.is_enabled::<Hsts>() {
            return;
        }

        warn_!("Space Helmet: deploying with TLS without enabling HSTS.");
        warn_!("Enabling default HSTS policy.");
        info_!("To disable this warning, configure an HSTS policy.");
        self.force_hsts.store(true, Ordering::Relaxed);
    }
}