use std::borrow::Cow;
use rocket::http::{Header, uri::Uri, uncased::UncasedStr};
use helmet::time::Duration;
/// A security policy that is expressed as exactly one HTTP response header.
///
/// Implementors state which header they control through [`Policy::NAME`]
/// and build the full header (name plus value) in [`Policy::header`].
/// Every policy must be constructible via `Default` and be
/// `Send + Sync + 'static`.
pub trait Policy: 'static + Send + Sync + Default {
    /// Name of the HTTP header this policy controls.
    const NAME: &'static str;

    /// Renders this policy as a complete header.
    fn header(&self) -> Header<'static>;
}
crate trait SubPolicy: Send + Sync {
fn name(&self) -> &'static UncasedStr;
fn header(&self) -> Header<'static>;
}
// Blanket bridge: every `Policy` is automatically a `SubPolicy`. The name
// comes from the policy's `NAME` constant, and the header is delegated to
// `Policy::header`.
impl<P> SubPolicy for P
where
    P: Policy,
{
    fn name(&self) -> &'static UncasedStr {
        let raw_name: &'static str = P::NAME;
        raw_name.into()
    }

    fn header(&self) -> Header<'static> {
        <P as Policy>::header(self)
    }
}
// Implements `Policy` for `$policy`, using `$header_name` as its `NAME`.
//
// The generated `header` method just converts `&self` into a
// `Header<'static>`, so `$policy` must provide that conversion separately.
macro_rules! impl_policy {
    ($policy:ty, $header_name:expr) => {
        impl Policy for $policy {
            const NAME: &'static str = $header_name;

            fn header(&self) -> Header<'static> {
                Into::into(self)
            }
        }
    };
}
// Bind each built-in policy type to the HTTP header it controls
// (listed alphabetically by type).
impl_policy!(ExpectCt, "Expect-CT");
impl_policy!(Frame, "X-Frame-Options");
impl_policy!(Hsts, "Strict-Transport-Security");
impl_policy!(NoSniff, "X-Content-Type-Options");
impl_policy!(Referrer, "Referrer-Policy");
impl_policy!(XssFilter, "X-XSS-Protection");
/// Policy for the `Referrer-Policy` header.
///
/// Each variant renders as a single policy token (shown on the variant).
/// For example, [`Referrer::NoReferrer`] produces
/// `Referrer-Policy: no-referrer`. The default is [`Referrer::NoReferrer`].
pub enum Referrer {
    /// Renders as `no-referrer`.
    NoReferrer,
    /// Renders as `no-referrer-when-downgrade`.
    NoReferrerWhenDowngrade,
    /// Renders as `origin`.
    Origin,
    /// Renders as `origin-when-cross-origin`.
    OriginWhenCrossOrigin,
    /// Renders as `same-origin`.
    SameOrigin,
    /// Renders as `strict-origin`.
    StrictOrigin,
    /// Renders as `strict-origin-when-cross-origin`.
    StrictOriginWhenCrossOrigin,
    /// Renders as `unsafe-url`.
    UnsafeUrl,
}

impl Default for Referrer {
    /// Returns [`Referrer::NoReferrer`].
    fn default() -> Referrer {
        Referrer::NoReferrer
    }
}

// Implemented as `From` rather than a hand-written `Into` (clippy
// `from_over_into`). The std blanket impl still provides
// `Into<Header<'h>> for &Referrer`, so existing `.into()` callers are
// unaffected. The header lifetime stays generic, as before.
impl<'h, 'a> From<&'a Referrer> for Header<'h> {
    fn from(referrer: &'a Referrer) -> Header<'h> {
        let policy_string = match referrer {
            Referrer::NoReferrer => "no-referrer",
            Referrer::NoReferrerWhenDowngrade => "no-referrer-when-downgrade",
            Referrer::Origin => "origin",
            Referrer::OriginWhenCrossOrigin => "origin-when-cross-origin",
            Referrer::SameOrigin => "same-origin",
            Referrer::StrictOrigin => "strict-origin",
            Referrer::StrictOriginWhenCrossOrigin => "strict-origin-when-cross-origin",
            Referrer::UnsafeUrl => "unsafe-url",
        };

        Header::new(Referrer::NAME, policy_string)
    }
}
/// Policy for the `Expect-CT` header.
///
/// Every variant carries a `max-age` duration, rendered in whole seconds.
/// The default is `Enforce(30 days)`, which renders as
/// `max-age=2592000, enforce`.
pub enum ExpectCt {
    /// Renders as `max-age=<seconds>, enforce`.
    Enforce(Duration),
    /// Renders as `max-age=<seconds>, report-uri="<uri>"`.
    Report(Duration, Uri<'static>),
    /// Renders as `max-age=<seconds>, enforce, report-uri="<uri>"`.
    ReportAndEnforce(Duration, Uri<'static>),
}

impl Default for ExpectCt {
    /// Returns `ExpectCt::Enforce` with a 30-day `max-age`.
    fn default() -> ExpectCt {
        ExpectCt::Enforce(Duration::days(30))
    }
}

// `From` instead of a hand-written `Into` (clippy `from_over_into`). The std
// blanket impl keeps `.into()` working for existing callers.
impl<'a> From<&'a ExpectCt> for Header<'static> {
    fn from(expect_ct: &'a ExpectCt) -> Header<'static> {
        // The report URI is emitted as a quoted string. Both quoting arms use
        // the same escaped-quote form so the output format is easy to compare.
        let policy_string = match expect_ct {
            ExpectCt::Enforce(age) => format!("max-age={}, enforce", age.num_seconds()),
            ExpectCt::Report(age, uri) => {
                format!("max-age={}, report-uri=\"{}\"", age.num_seconds(), uri)
            }
            ExpectCt::ReportAndEnforce(age, uri) => {
                format!("max-age={}, enforce, report-uri=\"{}\"", age.num_seconds(), uri)
            }
        };

        Header::new(ExpectCt::NAME, policy_string)
    }
}
/// Policy for the `X-Content-Type-Options` header.
///
/// There is a single variant, which renders as
/// `X-Content-Type-Options: nosniff`.
pub enum NoSniff {
    /// Renders as `nosniff`.
    Enable,
}

impl Default for NoSniff {
    /// Returns [`NoSniff::Enable`].
    fn default() -> NoSniff {
        NoSniff::Enable
    }
}

// `From` instead of a hand-written `Into` (clippy `from_over_into`). The
// header lifetime stays generic, and the std blanket impl keeps `.into()`
// working.
impl<'h, 'a> From<&'a NoSniff> for Header<'h> {
    fn from(_: &'a NoSniff) -> Header<'h> {
        Header::new(NoSniff::NAME, "nosniff")
    }
}
/// Policy for the `Strict-Transport-Security` (HSTS) header.
///
/// Every variant carries a `max-age` duration, rendered in whole seconds.
/// The default is `Enable(52 weeks)`, which renders as `max-age=31449600`.
pub enum Hsts {
    /// Renders as `max-age=<seconds>`.
    Enable(Duration),
    /// Renders as `max-age=<seconds>; includeSubDomains`.
    IncludeSubDomains(Duration),
    /// Renders as `max-age=<seconds>; includeSubDomains; preload`.
    ///
    /// The HSTS preload list (hstspreload.org) requires `includeSubDomains`
    /// together with `preload`, so this variant always emits both. The list
    /// also requires a `max-age` of at least one year (31536000 seconds).
    /// This variant does not enforce that minimum; the caller must choose a
    /// suitable duration.
    Preload(Duration),
}

impl Default for Hsts {
    /// Returns `Hsts::Enable` with a 52-week `max-age`.
    fn default() -> Hsts {
        Hsts::Enable(Duration::weeks(52))
    }
}

// `From` instead of a hand-written `Into` (clippy `from_over_into`). The std
// blanket impl keeps `.into()` working for existing callers.
impl<'a> From<&'a Hsts> for Header<'static> {
    fn from(hsts: &'a Hsts) -> Header<'static> {
        let policy_string = match hsts {
            Hsts::Enable(age) => format!("max-age={}", age.num_seconds()),
            Hsts::IncludeSubDomains(age) => {
                format!("max-age={}; includeSubDomains", age.num_seconds())
            }
            // `preload` without `includeSubDomains` is rejected by the preload
            // list, so the previous `max-age=N; preload` output could never
            // serve its purpose.
            Hsts::Preload(age) => {
                format!("max-age={}; includeSubDomains; preload", age.num_seconds())
            }
        };

        Header::new(Hsts::NAME, policy_string)
    }
}
/// Policy for the `X-Frame-Options` header.
///
/// The default is [`Frame::SameOrigin`], which renders as
/// `X-Frame-Options: SAMEORIGIN`.
pub enum Frame {
    /// Renders as `DENY`.
    Deny,
    /// Renders as `SAMEORIGIN`.
    SameOrigin,
    /// Renders as `ALLOW-FROM <uri>`.
    ///
    /// NOTE(review): MDN documents `ALLOW-FROM` as obsolete and ignored by
    /// modern browsers. CSP `frame-ancestors` is the usual replacement.
    /// Confirm before relying on it.
    AllowFrom(Uri<'static>),
}

impl Default for Frame {
    /// Returns [`Frame::SameOrigin`].
    fn default() -> Frame {
        Frame::SameOrigin
    }
}

// `From` instead of a hand-written `Into` (clippy `from_over_into`). The std
// blanket impl keeps `.into()` working for existing callers.
impl<'a> From<&'a Frame> for Header<'static> {
    fn from(frame: &'a Frame) -> Header<'static> {
        // Fixed tokens stay borrowed. Only `AllowFrom` needs an allocation.
        let policy_string: Cow<'static, str> = match frame {
            Frame::Deny => "DENY".into(),
            Frame::SameOrigin => "SAMEORIGIN".into(),
            Frame::AllowFrom(uri) => format!("ALLOW-FROM {}", uri).into(),
        };

        Header::new(Frame::NAME, policy_string)
    }
}
/// Policy for the `X-XSS-Protection` header.
///
/// The default is [`XssFilter::Enable`], which renders as
/// `X-XSS-Protection: 1`.
pub enum XssFilter {
    /// Renders as `0`.
    Disable,
    /// Renders as `1`.
    Enable,
    /// Renders as `1; mode=block`.
    EnableBlock,
    /// Renders as `1; report=<uri>`.
    EnableReport(Uri<'static>),
}

impl Default for XssFilter {
    /// Returns [`XssFilter::Enable`].
    fn default() -> XssFilter {
        XssFilter::Enable
    }
}

// `From` instead of a hand-written `Into` (clippy `from_over_into`). The std
// blanket impl keeps `.into()` working for existing callers.
impl<'a> From<&'a XssFilter> for Header<'static> {
    fn from(filter: &'a XssFilter) -> Header<'static> {
        // Fixed tokens stay borrowed. Only `EnableReport` needs an allocation.
        let policy_string: Cow<'static, str> = match filter {
            XssFilter::Disable => "0".into(),
            XssFilter::Enable => "1".into(),
            XssFilter::EnableBlock => "1; mode=block".into(),
            XssFilter::EnableReport(uri) => format!("1; report={}", uri).into(),
        };

        Header::new(XssFilter::NAME, policy_string)
    }
}