use std::borrow::Cow;
use std::fmt;
use indexmap::IndexMap;
use time::Duration;
use crate::http::{uncased::Uncased, uri::Absolute, Header};
/// A security policy that can be rendered as a single HTTP response header.
///
/// Every policy type in this module implements this trait through
/// `impl_policy!`, which sets `NAME` and forwards `header()` to the type's
/// `From<&T> for Header<'static>` implementation.
pub trait Policy: Default + Send + Sync + 'static {
/// The header name this policy sets, e.g. `"X-Frame-Options"`.
const NAME: &'static str;
/// Returns the header for this policy. The implementations in this module
/// all name the returned header `Self::NAME`.
fn header(&self) -> Header<'static>;
}
// Implements `Policy` for `$ty` using `$header` as the header name. The
// header value comes from the type's `From<&$ty> for Header<'static>` impl.
macro_rules! impl_policy {
    ($ty:ty, $header:expr) => {
        impl Policy for $ty {
            const NAME: &'static str = $header;

            fn header(&self) -> Header<'static> {
                Header::from(self)
            }
        }
    };
}
// Header names for every policy type in this module.
impl_policy!(ExpectCt, "Expect-CT");
impl_policy!(Frame, "X-Frame-Options");
impl_policy!(Hsts, "Strict-Transport-Security");
impl_policy!(NoSniff, "X-Content-Type-Options");
impl_policy!(Permission, "Permissions-Policy");
impl_policy!(Prefetch, "X-DNS-Prefetch-Control");
impl_policy!(Referrer, "Referrer-Policy");
impl_policy!(XssFilter, "X-XSS-Protection");
/// The `Referrer-Policy` header. Each variant renders as the directive token
/// shown on it. The default is [`Referrer::NoReferrer`].
pub enum Referrer {
/// Renders as `no-referrer`.
NoReferrer,
/// Renders as `no-referrer-when-downgrade`.
NoReferrerWhenDowngrade,
/// Renders as `origin`.
Origin,
/// Renders as `origin-when-cross-origin`.
OriginWhenCrossOrigin,
/// Renders as `same-origin`.
SameOrigin,
/// Renders as `strict-origin`.
StrictOrigin,
/// Renders as `strict-origin-when-cross-origin`.
StrictOriginWhenCrossOrigin,
/// Renders as `unsafe-url`.
UnsafeUrl,
}
impl Default for Referrer {
fn default() -> Referrer {
Referrer::NoReferrer
}
}
impl From<&Referrer> for Header<'static> {
fn from(referrer: &Referrer) -> Self {
let policy_string = match referrer {
Referrer::NoReferrer => "no-referrer",
Referrer::NoReferrerWhenDowngrade => "no-referrer-when-downgrade",
Referrer::Origin => "origin",
Referrer::OriginWhenCrossOrigin => "origin-when-cross-origin",
Referrer::SameOrigin => "same-origin",
Referrer::StrictOrigin => "strict-origin",
Referrer::StrictOriginWhenCrossOrigin => "strict-origin-when-cross-origin",
Referrer::UnsafeUrl => "unsafe-url",
};
Header::new(Referrer::NAME, policy_string)
}
}
/// The `Expect-CT` header. Every variant carries the `max-age` duration, which
/// is rendered in whole seconds. The default is `Enforce` with 30 days.
pub enum ExpectCt {
/// Renders as `max-age=<secs>, enforce`.
Enforce(Duration),
/// Renders as `max-age=<secs>, report-uri="<uri>"`.
Report(Duration, Absolute<'static>),
/// Renders as `max-age=<secs>, enforce, report-uri="<uri>"`.
ReportAndEnforce(Duration, Absolute<'static>),
}
impl Default for ExpectCt {
    /// Defaults to enforcing with a `max-age` of 30 days.
    fn default() -> Self {
        let thirty_days = Duration::days(30);
        Self::Enforce(thirty_days)
    }
}
impl From<&ExpectCt> for Header<'static> {
    /// Renders `Expect-CT` as `max-age=<secs>`, followed by `, enforce` when
    /// the variant enforces and by `, report-uri="<uri>"` when it reports.
    fn from(expect: &ExpectCt) -> Self {
        use std::fmt::Write;

        let (age, enforce, report_uri) = match expect {
            ExpectCt::Enforce(age) => (age, true, None),
            ExpectCt::Report(age, uri) => (age, false, Some(uri)),
            ExpectCt::ReportAndEnforce(age, uri) => (age, true, Some(uri)),
        };

        let mut value = format!("max-age={}", age.whole_seconds());
        if enforce {
            value.push_str(", enforce");
        }
        if let Some(uri) = report_uri {
            // Writing into a `String` cannot fail.
            let _ = write!(value, ", report-uri=\"{}\"", uri);
        }

        Header::new(ExpectCt::NAME, value)
    }
}
/// The `X-Content-Type-Options` header. Its only variant renders as `nosniff`.
pub enum NoSniff {
/// Renders as `nosniff`. This is also the default.
Enable,
}
impl Default for NoSniff {
fn default() -> NoSniff {
NoSniff::Enable
}
}
impl From<&NoSniff> for Header<'static> {
fn from(_: &NoSniff) -> Self {
Header::new(NoSniff::NAME, "nosniff")
}
}
/// The `Strict-Transport-Security` header. Every variant carries the `max-age`
/// duration, which is rendered in whole seconds. The default is `Enable` with
/// 365 days, which renders as `max-age=31536000`.
#[derive(PartialEq, Copy, Clone)]
pub enum Hsts {
/// Renders as `max-age=<secs>`.
Enable(Duration),
/// Renders as `max-age=<secs>; includeSubDomains`.
IncludeSubDomains(Duration),
/// Renders as `max-age=<secs>; includeSubDomains; preload`. The rendered
/// `max-age` is raised to at least 31536000 seconds (one year).
Preload(Duration),
}
impl Default for Hsts {
    /// Defaults to `Enable` with a `max-age` of 365 days (31536000 seconds).
    fn default() -> Self {
        let one_year = Duration::days(365);
        Self::Enable(one_year)
    }
}
impl From<&Hsts> for Header<'static> {
    /// Renders `Strict-Transport-Security` as `max-age=<secs>` plus any
    /// directives implied by the variant. The default policy is served from a
    /// prebuilt static header, so no string is formatted for it.
    fn from(hsts: &Hsts) -> Self {
        if *hsts == Hsts::default() {
            static DEFAULT: Header<'static> = Header {
                name: Uncased::from_borrowed(Hsts::NAME),
                value: Cow::Borrowed("max-age=31536000"),
            };
            return DEFAULT.clone();
        }

        // `preload` requires a max-age of at least one year.
        const ONE_YEAR: Duration = Duration::seconds(31536000);

        let (max_age, directives) = match *hsts {
            Hsts::Enable(age) => (age, ""),
            Hsts::IncludeSubDomains(age) => (age, "; includeSubDomains"),
            Hsts::Preload(age) => (age.max(ONE_YEAR), "; includeSubDomains; preload"),
        };

        let value = format!("max-age={}{}", max_age.whole_seconds(), directives);
        Header::new(Hsts::NAME, value)
    }
}
/// The `X-Frame-Options` header. The default is [`Frame::SameOrigin`].
pub enum Frame {
/// Renders as `DENY`.
Deny,
/// Renders as `SAMEORIGIN`.
SameOrigin,
}
impl Default for Frame {
fn default() -> Frame {
Frame::SameOrigin
}
}
impl From<&Frame> for Header<'static> {
fn from(frame: &Frame) -> Self {
let policy_string: &'static str = match frame {
Frame::Deny => "DENY",
Frame::SameOrigin => "SAMEORIGIN",
};
Header::new(Frame::NAME, policy_string)
}
}
/// The `X-XSS-Protection` header. The default is [`XssFilter::Enable`].
// NOTE(review): modern browsers have removed their XSS auditors, and `"0"` is
// often recommended now. Confirm that the `Enable` default is still intended.
pub enum XssFilter {
/// Renders as `0`.
Disable,
/// Renders as `1`.
Enable,
/// Renders as `1; mode=block`.
EnableBlock,
}
impl Default for XssFilter {
fn default() -> XssFilter {
XssFilter::Enable
}
}
impl From<&XssFilter> for Header<'static> {
fn from(filter: &XssFilter) -> Self {
let policy_string: &'static str = match filter {
XssFilter::Disable => "0",
XssFilter::Enable => "1",
XssFilter::EnableBlock => "1; mode=block",
};
Header::new(XssFilter::NAME, policy_string)
}
}
/// The `X-DNS-Prefetch-Control` header. The default is [`Prefetch::Off`].
#[derive(Default)]
pub enum Prefetch {
/// Renders as `on`.
On,
/// Renders as `off`. This is the default.
#[default]
Off,
}
impl From<&Prefetch> for Header<'static> {
fn from(prefetch: &Prefetch) -> Self {
let policy_string = match prefetch {
Prefetch::On => "on",
Prefetch::Off => "off",
};
Header::new(Prefetch::NAME, policy_string)
}
}
/// The `Permissions-Policy` header. It maps each [`Feature`] to its allow list.
///
/// An empty allow list blocks the feature (rendered `feature=()`). Entries are
/// rendered in map iteration order, joined by `, `. The default blocks only
/// [`Feature::InterestCohort`].
#[derive(PartialEq, Clone)]
pub struct Permission(IndexMap<Feature, Vec<Allow>>);
impl Default for Permission {
    /// Blocks `interest-cohort` and nothing else (`interest-cohort=()`).
    fn default() -> Self {
        Self::blocked(Feature::InterestCohort)
    }
}
impl Permission {
    /// Returns a policy in which only `feature` is set, allowed for the
    /// entries in `allow`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Permission::allow()`].
    pub fn allowed<L>(feature: Feature, allow: L) -> Self
    where
        L: IntoIterator<Item = Allow>,
    {
        Permission(IndexMap::new()).allow(feature, allow)
    }

    /// Returns a policy in which only `feature` is set, and it is blocked
    /// (rendered as `feature=()`).
    pub fn blocked(feature: Feature) -> Self {
        Permission(IndexMap::new()).block(feature)
    }

    /// Sets the allow list for `feature`. The list is stored with
    /// `IndexMap::insert`, so it replaces any list previously set for
    /// `feature`.
    ///
    /// If `allow` contains [`Allow::Any`], the stored list is collapsed to just
    /// `[Allow::Any]`.
    ///
    /// # Panics
    ///
    /// Panics if any [`Allow::Origin`] URI has no authority, or has an
    /// authority with an empty host. Such a URI cannot be rendered as a
    /// `"scheme://host[:port]"` origin.
    pub fn allow<L>(mut self, feature: Feature, allow: L) -> Self
    where
        L: IntoIterator<Item = Allow>,
    {
        let mut allow: Vec<Allow> = allow.into_iter().collect();
        for entry in &allow {
            if let Allow::Origin(uri) = entry {
                let has_host = matches!(uri.authority(), Some(auth) if !auth.host().is_empty());
                if !has_host {
                    panic!(
                        "invalid `Permission` origin `{}` for feature `{}`: an \
                         `Allow::Origin` URI must have an authority with a non-empty host",
                        uri, feature
                    );
                }
            }
        }

        // `*` already allows every origin, so the other entries are redundant.
        if allow.contains(&Allow::Any) {
            allow = vec![Allow::Any];
        }

        self.0.insert(feature, allow);
        self
    }

    /// Blocks `feature` by setting an empty allow list for it (rendered as
    /// `feature=()`). This replaces any list previously set for `feature`.
    pub fn block(mut self, feature: Feature) -> Self {
        self.0.insert(feature, vec![]);
        self
    }

    /// Returns the allow list for `feature`, or `None` if `feature` is not set.
    /// An empty slice means the feature is blocked.
    pub fn get(&self, feature: Feature) -> Option<&[Allow]> {
        self.0.get(&feature).map(Vec::as_slice)
    }

    /// Iterates over every `(feature, allow list)` pair in map order.
    pub fn iter(&self) -> impl Iterator<Item = (Feature, &[Allow])> {
        self.0.iter().map(|(feature, list)| (*feature, list.as_slice()))
    }
}
impl From<&Permission> for Header<'static> {
    /// Renders `Permissions-Policy` as `feature=(allow allow ...)` entries
    /// joined by `, `. The default policy is served from a prebuilt static
    /// header.
    fn from(perm: &Permission) -> Self {
        if *perm == Permission::default() {
            static DEFAULT: Header<'static> = Header {
                name: Uncased::from_borrowed(Permission::NAME),
                value: Cow::Borrowed("interest-cohort=()"),
            };
            return DEFAULT.clone();
        }

        let mut value = String::new();
        for (index, (feature, allowed)) in perm.0.iter().enumerate() {
            if index > 0 {
                value.push_str(", ");
            }

            value.push_str(feature.as_str());
            value.push_str("=(");
            for (position, entry) in allowed.iter().enumerate() {
                if position > 0 {
                    value.push(' ');
                }
                value.push_str(&entry.rendered());
            }
            value.push(')');
        }

        Header::new(Permission::NAME, value)
    }
}
/// One entry in a [`Permission`] allow list.
// `Origin` holds a full `Absolute<'static>` URI while the other variants are
// unit variants, which is why clippy's size lint is silenced here.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, PartialEq, Clone)]
pub enum Allow {
/// A specific origin, rendered as `"scheme://host[:port]"` (quoted).
Origin(Absolute<'static>),
/// Any origin, rendered as `*`.
Any,
/// The document's own origin, rendered as `self`.
This,
}
impl Allow {
fn rendered(&self) -> Cow<'static, str> {
match self {
Allow::Origin(uri) => {
let mut string = String::with_capacity(32);
string.push('"');
string.push_str(uri.scheme());
if let Some(auth) = uri.authority() {
use std::fmt::Write;
let _ = write!(string, "://{}", auth.host());
if let Some(port) = auth.port() {
let _ = write!(string, ":{}", port);
}
}
string.push('"');
string.into()
}
Allow::Any => "*".into(),
Allow::This => "self".into(),
}
}
}
impl IntoIterator for Allow {
type Item = Self;
type IntoIter = std::iter::Once<Self>;
fn into_iter(self) -> Self::IntoIter {
std::iter::once(self)
}
}
/// A browser feature that can be controlled by the `Permissions-Policy`
/// header. Each variant's directive name is returned by [`Feature::as_str()`].
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
#[non_exhaustive]
pub enum Feature {
    Accelerometer,
    AmbientLightSensor,
    Autoplay,
    Battery,
    Camera,
    CrossOriginIsolated,
    Displaycapture,
    DocumentDomain,
    EncryptedMedia,
    ExecutionWhileNotRendered,
    ExecutionWhileOutOfviewport,
    Fullscreen,
    Geolocation,
    Gyroscope,
    Magnetometer,
    Microphone,
    Midi,
    NavigationOverride,
    Payment,
    PictureInPicture,
    PublickeyCredentialsGet,
    ScreenWakeLock,
    SyncXhr,
    Usb,
    WebShare,
    XrSpatialTracking,
    ClipboardRead,
    ClipboardWrite,
    Gamepad,
    SpeakerSelection,
    InterestCohort,
    ConversionMeasurement,
    FocusWithoutUserActivation,
    Hid,
    IdleDetection,
    Serial,
    SyncScript,
    TrustTokenRedemption,
    VerticalScroll,
}

impl Feature {
    /// Returns the kebab-case directive name used for this feature in the
    /// `Permissions-Policy` header, e.g. `"display-capture"`.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Accelerometer => "accelerometer",
            Self::AmbientLightSensor => "ambient-light-sensor",
            Self::Autoplay => "autoplay",
            Self::Battery => "battery",
            Self::Camera => "camera",
            Self::CrossOriginIsolated => "cross-origin-isolated",
            Self::Displaycapture => "display-capture",
            Self::DocumentDomain => "document-domain",
            Self::EncryptedMedia => "encrypted-media",
            Self::ExecutionWhileNotRendered => "execution-while-not-rendered",
            Self::ExecutionWhileOutOfviewport => "execution-while-out-of-viewport",
            Self::Fullscreen => "fullscreen",
            Self::Geolocation => "geolocation",
            Self::Gyroscope => "gyroscope",
            Self::Magnetometer => "magnetometer",
            Self::Microphone => "microphone",
            Self::Midi => "midi",
            Self::NavigationOverride => "navigation-override",
            Self::Payment => "payment",
            Self::PictureInPicture => "picture-in-picture",
            Self::PublickeyCredentialsGet => "publickey-credentials-get",
            Self::ScreenWakeLock => "screen-wake-lock",
            Self::SyncXhr => "sync-xhr",
            Self::Usb => "usb",
            Self::WebShare => "web-share",
            Self::XrSpatialTracking => "xr-spatial-tracking",
            Self::ClipboardRead => "clipboard-read",
            Self::ClipboardWrite => "clipboard-write",
            Self::Gamepad => "gamepad",
            Self::SpeakerSelection => "speaker-selection",
            Self::InterestCohort => "interest-cohort",
            Self::ConversionMeasurement => "conversion-measurement",
            Self::FocusWithoutUserActivation => "focus-without-user-activation",
            Self::Hid => "hid",
            Self::IdleDetection => "idle-detection",
            Self::Serial => "serial",
            Self::SyncScript => "sync-script",
            Self::TrustTokenRedemption => "trust-token-redemption",
            Self::VerticalScroll => "vertical-scroll",
        }
    }
}

impl fmt::Display for Feature {
    /// Writes the directive name. Width, fill and alignment flags are honored,
    /// exactly as when formatting the `&str` itself.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}