#![deny(unsafe_code)]
use axum::http::HeaderValue;
use regex::RegexSet;
use std::collections::HashMap;
use std::fmt::{Debug, Display, Formatter};
/// Name of a Content-Security-Policy directive.
///
/// The variant declaration order is significant: the derived `Ord` follows it, and
/// `CspHeaderBuilder::finish` sorts directives with that ordering, so reordering the
/// variants changes the order in which directives appear in a rendered header.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Copy, Ord, PartialOrd)]
pub enum CspDirectiveType {
    ChildSrc,
    ConnectSrc,
    DefaultSrc,
    FontSrc,
    FrameSrc,
    ImgSrc,
    ManifestSrc,
    MediaSrc,
    ObjectSrc,
    PrefetchSrc,
    ScriptSource,
    ScriptSourceElem,
    StyleSource,
    StyleSourceElem,
    WorkerSource,
    BaseUri,
    Sandbox,
    FormAction,
    FrameAncestors,
    NavigateTo,
    ReportUri,
    ReportTo,
    RequireTrustedTypesFor,
    TrustedTypes,
    UpgradeInsecureRequests,
}

impl AsRef<str> for CspDirectiveType {
    /// The directive name exactly as it is written inside a policy, e.g. `script-src`.
    fn as_ref(&self) -> &str {
        match self {
            Self::ChildSrc => "child-src",
            Self::ConnectSrc => "connect-src",
            Self::DefaultSrc => "default-src",
            Self::FontSrc => "font-src",
            Self::FrameSrc => "frame-src",
            Self::ImgSrc => "img-src",
            Self::ManifestSrc => "manifest-src",
            Self::MediaSrc => "media-src",
            Self::ObjectSrc => "object-src",
            Self::PrefetchSrc => "prefetch-src",
            Self::ScriptSource => "script-src",
            Self::ScriptSourceElem => "script-src-elem",
            Self::StyleSource => "style-src",
            Self::StyleSourceElem => "style-src-elem",
            Self::WorkerSource => "worker-src",
            Self::BaseUri => "base-uri",
            Self::Sandbox => "sandbox",
            Self::FormAction => "form-action",
            Self::FrameAncestors => "frame-ancestors",
            Self::NavigateTo => "navigate-to",
            Self::ReportUri => "report-uri",
            Self::ReportTo => "report-to",
            Self::RequireTrustedTypesFor => "require-trusted-types-for",
            Self::TrustedTypes => "trusted-types",
            Self::UpgradeInsecureRequests => "upgrade-insecure-requests",
        }
    }
}

impl Display for CspDirectiveType {
    /// Writes the same text as [`AsRef::as_ref`].
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_ref())
    }
}

impl From<CspDirectiveType> for String {
    /// Produces an owned copy of the directive name.
    fn from(input: CspDirectiveType) -> String {
        let name: &str = input.as_ref();
        name.to_owned()
    }
}
/// A single CSP directive: a directive name together with the values that follow it.
#[derive(Debug, Clone)]
pub struct CspDirective {
    /// Which directive this is (rendered as its name, e.g. `default-src`).
    pub directive_type: CspDirectiveType,
    /// Values rendered after the directive name, in vector order, separated by spaces.
    /// May be empty for valueless directives.
    pub values: Vec<CspValue>,
}
impl CspDirective {
    /// Builds a directive of kind `directive_type` whose values are `values`,
    /// kept in the order given.
    #[must_use]
    pub fn from(directive_type: CspDirectiveType, values: Vec<CspValue>) -> Self {
        Self {
            directive_type,
            values,
        }
    }

    /// The fallback directive `default-src 'self'`.
    pub fn default_self() -> Self {
        let only_self = vec![CspValue::SelfSite];
        Self::from(CspDirectiveType::DefaultSrc, only_self)
    }
}
impl Display for CspDirective {
    /// Renders `name value1 value2 …`; a directive without values renders as just its name.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.directive_type.as_ref())?;
        // Every value is preceded by exactly one space, so no trailing space is produced.
        for value in &self.values {
            write!(f, " {}", String::from(value.clone()))?;
        }
        Ok(())
    }
}
impl From<CspDirective> for HeaderValue {
    /// Converts a single directive into a header value using its `Display` rendering.
    ///
    /// # Panics
    /// Panics if the rendered directive is rejected by `HeaderValue::from_str`.
    fn from(input: CspDirective) -> HeaderValue {
        let rendered = input.to_string();
        HeaderValue::from_str(&rendered)
            .unwrap_or_else(|e| panic!("Failed to build HeaderValue from CspDirective: {}", e))
    }
}
/// Pairs a set of regex patterns (presumably request URLs/paths — confirm against
/// callers) with the CSP directives that belong to them.
#[derive(Clone, Debug)]
pub struct CspUrlMatcher {
    /// Patterns consulted by `is_match`.
    pub matcher: RegexSet,
    /// Directives serialized, in order, when this matcher is converted to a `HeaderValue`.
    pub directives: Vec<CspDirective>,
}
impl CspUrlMatcher {
#[must_use]
pub fn new(matcher: RegexSet) -> Self {
Self {
matcher,
directives: vec![],
}
}
pub fn with_directive(&mut self, directive: CspDirective) -> &mut Self {
self.directives.push(directive);
self
}
pub fn is_match(&self, text: &str) -> bool {
self.matcher.is_match(text)
}
pub fn default_all_self() -> Self {
Self {
matcher: RegexSet::new([r#".*"#]).unwrap(),
directives: vec![CspDirective {
directive_type: CspDirectiveType::DefaultSrc,
values: vec![CspValue::SelfSite],
}],
}
}
pub fn default_self(matcher: RegexSet) -> Self {
Self {
matcher,
directives: vec![CspDirective {
directive_type: CspDirectiveType::DefaultSrc,
values: vec![CspValue::SelfSite],
}],
}
}
}
impl From<CspUrlMatcher> for HeaderValue {
    /// Renders all of the matcher's directives into one header value.
    ///
    /// Each directive is rendered through its `Display` impl and terminated with `;`.
    /// Directives are separated by a single space, e.g.
    /// `default-src 'self'; img-src https:;`. A matcher with no directives yields an
    /// empty header value.
    ///
    /// # Panics
    /// Panics if the rendered policy is rejected by `HeaderValue::from_str` (it only
    /// accepts visible ASCII, so e.g. a control or non-ASCII character in a `Host` value).
    fn from(input: CspUrlMatcher) -> HeaderValue {
        // Reuse `Display for CspDirective` so there is a single serialization of a directive.
        let policy = input
            .directives
            .iter()
            .map(|directive| format!("{directive};"))
            .collect::<Vec<String>>()
            .join(" ");
        HeaderValue::from_str(&policy)
            .expect("Failed to build HeaderValue from CspUrlMatcher directives")
    }
}
/// A single value (source expression or keyword) that can follow a CSP directive name.
///
/// The variant declaration order is significant: the derived `Ord` follows it and is
/// used to sort values when a header is rendered by `CspHeaderBuilder::finish`.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum CspValue {
    /// `'none'`
    None,
    /// `'self'`
    SelfSite,
    /// `'strict-dynamic'`
    StrictDynamic,
    /// `'report-sample'`
    ReportSample,
    /// `'unsafe-inline'`
    UnsafeInline,
    /// `'unsafe-eval'`
    UnsafeEval,
    /// `'unsafe-hashes'`
    UnsafeHashes,
    /// `'unsafe-allow-redirects'`
    UnsafeAllowRedirects,
    /// A host source, emitted verbatim.
    Host { value: String },
    /// `https:`
    SchemeHttps,
    /// `http:`
    SchemeHttp,
    /// `data:`
    SchemeData,
    /// Any other scheme source, emitted verbatim (include the trailing `:`).
    SchemeOther { value: String },
    /// A nonce; `value` is the bare nonce, without the `nonce-` prefix or quotes.
    Nonce { value: String },
    /// A SHA-256 hash; `value` is the bare base64 digest, without prefix or quotes.
    Sha256 { value: String },
    /// A SHA-384 hash; `value` is the bare base64 digest, without prefix or quotes.
    Sha384 { value: String },
    /// A SHA-512 hash; `value` is the bare base64 digest, without prefix or quotes.
    Sha512 { value: String },
}

impl From<CspValue> for String {
    /// Renders the value exactly as it must appear inside a policy.
    ///
    /// Keywords, nonces and hashes are wrapped in single quotes, as the CSP grammar
    /// requires. An unquoted `nonce-…`/`sha256-…` token is not recognised as a nonce or
    /// hash by browsers: it is either parsed as a host source or dropped as invalid, so
    /// the allow-listing would silently never match.
    fn from(input: CspValue) -> String {
        match input {
            CspValue::None => "'none'".to_string(),
            CspValue::SelfSite => "'self'".to_string(),
            CspValue::StrictDynamic => "'strict-dynamic'".to_string(),
            CspValue::ReportSample => "'report-sample'".to_string(),
            CspValue::UnsafeInline => "'unsafe-inline'".to_string(),
            CspValue::UnsafeEval => "'unsafe-eval'".to_string(),
            CspValue::UnsafeHashes => "'unsafe-hashes'".to_string(),
            CspValue::UnsafeAllowRedirects => "'unsafe-allow-redirects'".to_string(),
            CspValue::SchemeHttps => "https:".to_string(),
            CspValue::SchemeHttp => "http:".to_string(),
            CspValue::SchemeData => "data:".to_string(),
            // Already owned; hand the String over without copying it.
            CspValue::Host { value } | CspValue::SchemeOther { value } => value,
            CspValue::Nonce { value } => format!("'nonce-{value}'"),
            CspValue::Sha256 { value } => format!("'sha256-{value}'"),
            CspValue::Sha384 { value } => format!("'sha384-{value}'"),
            CspValue::Sha512 { value } => format!("'sha512-{value}'"),
        }
    }
}
/// Accumulates CSP directives and their values, then renders them into a header value
/// with `finish`.
#[derive(Clone, Debug, Default)]
pub struct CspHeaderBuilder {
    /// Values collected per directive. `add` skips values that are already present;
    /// entries inserted directly into this public map are not deduplicated.
    pub directive_map: HashMap<CspDirectiveType, Vec<CspValue>>,
}
impl CspHeaderBuilder {
    /// Creates an empty builder with no directives.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `values` to `directive`, registering the directive if it is not present yet.
    ///
    /// Values already recorded for the directive are skipped, so repeated calls do not
    /// produce duplicates. Passing an empty `values` still registers the directive, which
    /// is how valueless directives such as `upgrade-insecure-requests` get emitted.
    #[must_use]
    pub fn add(mut self, directive: CspDirectiveType, values: Vec<CspValue>) -> Self {
        // One hash lookup for the whole batch instead of two lookups + unwraps per value.
        let existing = self.directive_map.entry(directive).or_default();
        for val in values {
            if !existing.contains(&val) {
                existing.push(val);
            }
        }
        self
    }

    /// Renders the collected directives into a single header value.
    ///
    /// Directives are emitted in `CspDirectiveType` order and each directive's values in
    /// `CspValue` order, joined with `; `. The output is therefore deterministic
    /// regardless of insertion order or `HashMap` iteration order.
    ///
    /// # Panics
    /// Panics if the rendered policy is rejected by `HeaderValue::from_str` (it only
    /// accepts visible ASCII, so e.g. a control or non-ASCII character in a `Host` value).
    pub fn finish(self) -> HeaderValue {
        // The builder is consumed, so move the value vectors out rather than cloning them.
        let mut entries: Vec<(CspDirectiveType, Vec<CspValue>)> =
            self.directive_map.into_iter().collect();
        // Keys are unique, so an unstable sort yields the same order as a stable one.
        entries.sort_unstable_by_key(|(directive, _)| *directive);

        let directive_strings: Vec<String> = entries
            .into_iter()
            .map(|(directive, mut values)| {
                values.sort();
                let mut directive_string = directive.to_string();
                for val in values {
                    directive_string.push(' ');
                    directive_string.push_str(&String::from(val));
                }
                // Drop any trailing whitespace a blank value (e.g. an empty `Host`) leaves.
                directive_string.trim().to_string()
            })
            .collect();

        HeaderValue::from_str(&directive_strings.join("; "))
            .expect("Failed to build header value from directive strings")
    }
}