use std::{collections::HashMap, time::Duration};
use ic_http_certification::HeaderField;
/// `Cache-Control` header values used for served assets.
///
/// The [`Default`] values are `public, max-age=31536000, immutable` for static assets
/// and `public, no-cache, no-store` for dynamic assets. Override a single field with
/// struct-update syntax, e.g.
/// `CacheControl { static_assets: "public, max-age=3600".into(), ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheControl {
    /// `Cache-Control` value for static assets.
    pub static_assets: String,
    /// `Cache-Control` value for dynamic assets.
    pub dynamic_assets: String,
}

impl Default for CacheControl {
    fn default() -> Self {
        Self {
            // 31_536_000 s = 365 days; `immutable` additionally asks caches not to revalidate.
            static_assets: "public, max-age=31536000, immutable".into(),
            // `no-store` keeps dynamic responses out of caches entirely.
            dynamic_assets: "public, no-cache, no-store".into(),
        }
    }
}
/// Time-to-live (TTL) settings, resolved per request path by [`CacheConfig::effective_ttl`].
///
/// The derived [`Default`] has no fallback TTL and no per-route overrides.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheConfig {
    /// TTL for paths that have no entry in `per_route_ttl`; `None` when there is no fallback.
    pub default_ttl: Option<Duration>,
    /// Per-path TTL overrides. Keys are matched exactly against the path passed to
    /// [`CacheConfig::effective_ttl`]: no normalisation of case, trailing slashes or queries.
    pub per_route_ttl: HashMap<String, Duration>,
}

impl CacheConfig {
    /// Returns the TTL that applies to `path`.
    ///
    /// An exact entry in `per_route_ttl` wins. Otherwise `default_ttl` is returned, which
    /// may itself be `None`.
    pub fn effective_ttl(&self, path: &str) -> Option<Duration> {
        // `Option<Duration>` is `Copy`, so the eager `or` fallback costs nothing.
        self.per_route_ttl.get(path).copied().or(self.default_ttl)
    }
}
/// Security-related HTTP response headers, one optional value per header.
///
/// Each `Some` field becomes one `(name, value)` pair in [`SecurityHeaders::to_headers`];
/// `None` fields are omitted. Start from a preset and adjust individual fields with
/// struct-update syntax. The presets are [`SecurityHeaders::strict`],
/// [`SecurityHeaders::permissive`] (also the [`Default`]) and [`SecurityHeaders::none`].
/// No preset sets `csp`. `X-XSS-Protection` has no field, so it is never emitted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SecurityHeaders {
    /// `Strict-Transport-Security`.
    pub hsts: Option<String>,
    /// `Content-Security-Policy`; `None` in every preset, so it must be set explicitly.
    pub csp: Option<String>,
    /// `X-Content-Type-Options`.
    pub content_type_options: Option<String>,
    /// `X-Frame-Options`.
    pub frame_options: Option<String>,
    /// `Referrer-Policy`.
    pub referrer_policy: Option<String>,
    /// `Permissions-Policy`.
    pub permissions_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy`.
    pub coep: Option<String>,
    /// `Cross-Origin-Opener-Policy`.
    pub coop: Option<String>,
    /// `Cross-Origin-Resource-Policy`.
    pub corp: Option<String>,
    /// `X-DNS-Prefetch-Control`.
    pub dns_prefetch_control: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies`.
    pub permitted_cross_domain_policies: Option<String>,
}

impl SecurityHeaders {
    /// Most restrictive preset.
    ///
    /// Forbids framing (`DENY`) and sends no referrer. It disables the listed browser
    /// features through `Permissions-Policy` and sets `Cross-Origin-Embedder-Policy:
    /// require-corp` with same-origin opener and resource policies. With `require-corp`,
    /// browsers refuse cross-origin subresources that do not opt in via CORP or CORS.
    pub fn strict() -> Self {
        Self {
            hsts: Some("max-age=31536000; includeSubDomains".into()),
            csp: None,
            content_type_options: Some("nosniff".into()),
            frame_options: Some("DENY".into()),
            referrer_policy: Some("no-referrer".into()),
            // NOTE(review): `interest-cohort` belongs to the withdrawn FLoC proposal; browsers
            // that no longer recognise it may log a console warning. Kept as-is for
            // compatibility; confirm it is still wanted.
            permissions_policy: Some(
                "accelerometer=(), camera=(), geolocation=(), gyroscope=(), \
                 magnetometer=(), microphone=(), payment=(), usb=(), interest-cohort=()"
                    .into(),
            ),
            coep: Some("require-corp".into()),
            coop: Some("same-origin".into()),
            corp: Some("same-origin".into()),
            dns_prefetch_control: Some("off".into()),
            permitted_cross_domain_policies: Some("none".into()),
        }
    }

    /// Less restrictive preset, and the [`Default`].
    ///
    /// Allows same-origin framing and uses `strict-origin-when-cross-origin` referrers. It
    /// lets other origins load resources (`Cross-Origin-Resource-Policy: cross-origin`).
    /// `Permissions-Policy`, COEP, CSP and DNS-prefetch control are left unset.
    pub fn permissive() -> Self {
        Self {
            hsts: Some("max-age=31536000; includeSubDomains".into()),
            csp: None,
            content_type_options: Some("nosniff".into()),
            frame_options: Some("SAMEORIGIN".into()),
            referrer_policy: Some("strict-origin-when-cross-origin".into()),
            permissions_policy: None,
            coep: None,
            coop: Some("same-origin-allow-popups".into()),
            corp: Some("cross-origin".into()),
            dns_prefetch_control: None,
            permitted_cross_domain_policies: Some("none".into()),
        }
    }

    /// Empty preset: every field is `None`, so [`SecurityHeaders::to_headers`] returns nothing.
    pub fn none() -> Self {
        Self {
            hsts: None,
            csp: None,
            content_type_options: None,
            frame_options: None,
            referrer_policy: None,
            permissions_policy: None,
            coep: None,
            coop: None,
            corp: None,
            dns_prefetch_control: None,
            permitted_cross_domain_policies: None,
        }
    }

    /// Renders every `Some` field as a `(lowercase-name, value)` pair.
    ///
    /// The output order is fixed: `strict-transport-security` first and
    /// `x-permitted-cross-domain-policies` last. Two configurations with equal fields
    /// therefore render to equal vectors.
    pub fn to_headers(&self) -> Vec<HeaderField> {
        // One row per field; the row order is the output order.
        let table: [(&str, &Option<String>); 11] = [
            ("strict-transport-security", &self.hsts),
            ("content-security-policy", &self.csp),
            ("x-content-type-options", &self.content_type_options),
            ("x-frame-options", &self.frame_options),
            ("referrer-policy", &self.referrer_policy),
            ("permissions-policy", &self.permissions_policy),
            ("cross-origin-embedder-policy", &self.coep),
            ("cross-origin-opener-policy", &self.coop),
            ("cross-origin-resource-policy", &self.corp),
            ("x-dns-prefetch-control", &self.dns_prefetch_control),
            (
                "x-permitted-cross-domain-policies",
                &self.permitted_cross_domain_policies,
            ),
        ];
        table
            .iter()
            .filter_map(|&(name, value)| value.as_ref().map(|v| (name.to_string(), v.clone())))
            .collect()
    }
}

impl Default for SecurityHeaders {
    /// Same as [`SecurityHeaders::permissive`].
    fn default() -> Self {
        Self::permissive()
    }
}
/// Response-header and caching configuration for served assets.
#[derive(Default)]
pub struct AssetConfig {
    /// Lowest-precedence header layer used by [`AssetConfig::merged_headers`].
    pub security_headers: SecurityHeaders,
    /// `Cache-Control` values for static and dynamic assets.
    pub cache_control: CacheControl,
    /// TTL configuration.
    pub cache_config: CacheConfig,
    /// Extra headers that replace same-named security headers in
    /// [`AssetConfig::merged_headers`].
    pub custom_headers: Vec<HeaderField>,
}

impl AssetConfig {
    /// Builds the final response header list from three layers, lowest precedence first:
    ///
    /// 1. `security_headers.to_headers()`,
    /// 2. `custom_headers`,
    /// 3. `additional_headers`.
    ///
    /// Header names are compared ASCII case-insensitively, because HTTP field names are
    /// case-insensitive tokens. A name present in a layer removes *every* header with that
    /// name from the layers below it. All entries of the overriding layer are kept, so
    /// repeatable fields such as `set-cookie` or `link` may appear several times. Surviving
    /// lower-layer headers keep their relative order and precede the overriding layer's
    /// headers, which also keep their own order.
    pub fn merged_headers(&self, additional_headers: Vec<HeaderField>) -> Vec<HeaderField> {
        let mut merged = self.security_headers.to_headers();
        Self::apply_override_layer(&mut merged, self.custom_headers.clone());
        Self::apply_override_layer(&mut merged, additional_headers);
        merged
    }

    /// Drops every entry of `merged` whose name appears in `layer`, then appends `layer`.
    ///
    /// The whole layer is applied at once, not entry by entry. This keeps repeated names
    /// within a layer from knocking each other out.
    fn apply_override_layer(merged: &mut Vec<HeaderField>, layer: Vec<HeaderField>) {
        if layer.is_empty() {
            return;
        }
        merged.retain(|(name, _)| {
            !layer
                .iter()
                .any(|(override_name, _)| override_name.eq_ignore_ascii_case(name))
        });
        merged.extend(layer);
    }
}
// Unit tests for the security-header presets, header-merge precedence, and cache settings.
#[cfg(test)]
mod tests {
use super::*;
// `strict` emits every header except CSP, with the locked-down values asserted below.
#[test]
fn strict_produces_expected_headers() {
let headers = SecurityHeaders::strict().to_headers();
let names: Vec<&str> = headers.iter().map(|(k, _)| k.as_str()).collect();
assert!(names.contains(&"strict-transport-security"));
assert!(names.contains(&"x-content-type-options"));
assert!(names.contains(&"x-frame-options"));
assert!(names.contains(&"referrer-policy"));
assert!(names.contains(&"permissions-policy"));
assert!(names.contains(&"cross-origin-embedder-policy"));
assert!(names.contains(&"cross-origin-opener-policy"));
assert!(names.contains(&"cross-origin-resource-policy"));
assert!(names.contains(&"x-dns-prefetch-control"));
assert!(names.contains(&"x-permitted-cross-domain-policies"));
assert!(!names.contains(&"content-security-policy"));
// Looks up a header value by exact (lowercase) name; panics if it is missing.
let find =
|name: &str| -> String { headers.iter().find(|(k, _)| k == name).unwrap().1.clone() };
assert_eq!(find("x-frame-options"), "DENY");
assert_eq!(find("referrer-policy"), "no-referrer");
assert_eq!(find("cross-origin-embedder-policy"), "require-corp");
assert_eq!(find("cross-origin-opener-policy"), "same-origin");
assert_eq!(find("cross-origin-resource-policy"), "same-origin");
assert_eq!(find("x-dns-prefetch-control"), "off");
assert_eq!(find("x-permitted-cross-domain-policies"), "none");
}
// `permissive` omits Permissions-Policy, COEP, CSP and DNS-prefetch control, and relaxes
// the framing, referrer, COOP and CORP values.
#[test]
fn permissive_produces_expected_headers() {
let headers = SecurityHeaders::permissive().to_headers();
let names: Vec<&str> = headers.iter().map(|(k, _)| k.as_str()).collect();
assert!(names.contains(&"strict-transport-security"));
assert!(names.contains(&"x-content-type-options"));
assert!(names.contains(&"x-frame-options"));
assert!(names.contains(&"referrer-policy"));
assert!(names.contains(&"cross-origin-opener-policy"));
assert!(names.contains(&"cross-origin-resource-policy"));
assert!(names.contains(&"x-permitted-cross-domain-policies"));
assert!(!names.contains(&"permissions-policy"));
assert!(!names.contains(&"cross-origin-embedder-policy"));
assert!(!names.contains(&"content-security-policy"));
assert!(!names.contains(&"x-dns-prefetch-control"));
let find =
|name: &str| -> String { headers.iter().find(|(k, _)| k == name).unwrap().1.clone() };
assert_eq!(find("x-frame-options"), "SAMEORIGIN");
assert_eq!(find("referrer-policy"), "strict-origin-when-cross-origin");
assert_eq!(
find("cross-origin-opener-policy"),
"same-origin-allow-popups"
);
assert_eq!(find("cross-origin-resource-policy"), "cross-origin");
}
// `none` is the empty preset.
#[test]
fn none_produces_zero_headers() {
let headers = SecurityHeaders::none().to_headers();
assert!(headers.is_empty());
}
// A custom header replaces the same-named security header instead of duplicating it.
#[test]
fn custom_headers_override_security_headers() {
let config = AssetConfig {
security_headers: SecurityHeaders::strict(),
cache_control: CacheControl::default(),
cache_config: CacheConfig::default(),
custom_headers: vec![("x-frame-options".to_string(), "SAMEORIGIN".to_string())],
};
let merged = config.merged_headers(vec![]);
let frame_opts: Vec<_> = merged
.iter()
.filter(|(k, _)| k == "x-frame-options")
.collect();
assert_eq!(frame_opts.len(), 1);
assert_eq!(frame_opts[0].1, "SAMEORIGIN");
}
// Call-site headers beat both custom and security headers, and the name match ignores case
// ("X-Frame-Options" vs "x-frame-options").
#[test]
fn additional_headers_override_custom_and_security() {
let config = AssetConfig {
security_headers: SecurityHeaders::strict(),
cache_control: CacheControl::default(),
cache_config: CacheConfig::default(),
custom_headers: vec![("x-frame-options".to_string(), "SAMEORIGIN".to_string())],
};
let merged = config.merged_headers(vec![(
"X-Frame-Options".to_string(),
"ALLOW-FROM https://example.com".to_string(),
)]);
let frame_opts: Vec<_> = merged
.iter()
.filter(|(k, _)| k.to_lowercase() == "x-frame-options")
.collect();
assert_eq!(frame_opts.len(), 1);
assert_eq!(frame_opts[0].1, "ALLOW-FROM https://example.com");
}
// No preset may emit X-XSS-Protection, in any letter case.
#[test]
fn xss_protection_never_set() {
for headers in [
SecurityHeaders::strict().to_headers(),
SecurityHeaders::permissive().to_headers(),
SecurityHeaders::none().to_headers(),
] {
assert!(
headers
.iter()
.all(|(k, _)| k.to_lowercase() != "x-xss-protection"),
"X-XSS-Protection should never be set by any preset"
);
}
}
// `SecurityHeaders::default()` must render exactly like `permissive()`, order included.
#[test]
fn default_is_permissive() {
let default_headers = SecurityHeaders::default().to_headers();
let permissive_headers = SecurityHeaders::permissive().to_headers();
assert_eq!(default_headers, permissive_headers);
}
// Pins the default Cache-Control strings.
#[test]
fn default_cache_control_reproduces_current_behavior() {
let cc = CacheControl::default();
assert_eq!(cc.static_assets, "public, max-age=31536000, immutable");
assert_eq!(cc.dynamic_assets, "public, no-cache, no-store");
}
// Overriding `static_assets` via struct-update syntax leaves `dynamic_assets` at its default.
#[test]
fn custom_static_cache_control() {
let cc = CacheControl {
static_assets: "public, max-age=3600".into(),
..CacheControl::default()
};
assert_eq!(cc.static_assets, "public, max-age=3600");
assert_eq!(cc.dynamic_assets, "public, no-cache, no-store");
}
// Overriding `dynamic_assets` via struct-update syntax leaves `static_assets` at its default.
#[test]
fn custom_dynamic_cache_control() {
let cc = CacheControl {
dynamic_assets: "public, max-age=600".into(),
..CacheControl::default()
};
assert_eq!(cc.dynamic_assets, "public, max-age=600");
assert_eq!(cc.static_assets, "public, max-age=31536000, immutable");
}
// `AssetConfig::default()` carries `CacheControl::default()`.
#[test]
fn asset_config_default_includes_default_cache_control() {
let config = AssetConfig::default();
assert_eq!(
config.cache_control.static_assets,
"public, max-age=31536000, immutable"
);
assert_eq!(
config.cache_control.dynamic_assets,
"public, no-cache, no-store"
);
}
// `CacheConfig::default()` has no fallback TTL and no per-route entries.
#[test]
fn cache_config_default_has_none_and_empty() {
let cc = CacheConfig::default();
assert!(cc.default_ttl.is_none());
assert!(cc.per_route_ttl.is_empty());
}
// Route-specific TTLs win; other paths fall back to `default_ttl`, or `None` without one.
#[test]
fn per_route_ttl_overrides_default_ttl() {
let cc = CacheConfig {
default_ttl: Some(Duration::from_secs(3600)),
per_route_ttl: HashMap::from([("/posts/1".to_string(), Duration::from_secs(60))]),
};
assert_eq!(cc.effective_ttl("/posts/1"), Some(Duration::from_secs(60)));
assert_eq!(cc.effective_ttl("/about"), Some(Duration::from_secs(3600)));
let cc_no_default = CacheConfig {
default_ttl: None,
per_route_ttl: HashMap::from([("/posts/1".to_string(), Duration::from_secs(60))]),
};
assert_eq!(
cc_no_default.effective_ttl("/posts/1"),
Some(Duration::from_secs(60))
);
assert_eq!(cc_no_default.effective_ttl("/about"), None);
}
}