use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
/// Settings that control how HSTS (`Strict-Transport-Security`) information
/// is honored, cached and forwarded.
///
/// Each field can be overridden from the environment via
/// `HstsPolicy::from_env` (`DERUSTED_HSTS_*` variables).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HstsPolicy {
/// When `false`, `HstsManager::is_hsts_domain` always answers `false` and
/// `HstsManager::add_from_header` records nothing.
pub honor_hsts: bool,
/// When `true`, `HstsManager::process_response_headers` removes the
/// `strict-transport-security` header from responses (testing mode).
pub strip_hsts_headers: bool,
/// Maximum cache age setting; defaults to `31536000` and is read from
/// `DERUSTED_HSTS_MAX_AGE`.
/// NOTE(review): presumably seconds (31536000 s = 365 days) — confirm.
pub max_cache_age: u64,
/// When `true`, `HstsManager::with_policy` seeds the cache with the
/// built-in preload list.
pub use_preload_list: bool,
}
impl Default for HstsPolicy {
fn default() -> Self {
Self {
honor_hsts: true, strip_hsts_headers: false, max_cache_age: 31536000, use_preload_list: true, }
}
}
impl HstsPolicy {
pub fn honor() -> Self {
Self::default()
}
pub fn strip_for_testing() -> Self {
Self {
honor_hsts: false,
strip_hsts_headers: true,
..Default::default()
}
}
pub fn from_env() -> Self {
let mut policy = Self::default();
if let Ok(val) = std::env::var("DERUSTED_HSTS_HONOR") {
policy.honor_hsts = val.parse().unwrap_or(true);
}
if let Ok(val) = std::env::var("DERUSTED_HSTS_STRIP") {
policy.strip_hsts_headers = val.parse().unwrap_or(false);
}
if let Ok(val) = std::env::var("DERUSTED_HSTS_MAX_AGE") {
policy.max_cache_age = val.parse().unwrap_or(31536000);
}
if let Ok(val) = std::env::var("DERUSTED_HSTS_PRELOAD") {
policy.use_preload_list = val.parse().unwrap_or(true);
}
policy
}
}
/// One cached HSTS record, stored per domain in the manager's cache.
#[derive(Debug, Clone)]
struct HstsEntry {
/// Unix timestamp (as produced by `chrono::Utc::now().timestamp()`) after
/// which a non-preloaded entry is no longer treated as active. Preloaded
/// entries are created with `i64::MAX`.
expires_at: i64,
/// Whether the entry also covers subdomains of the stored domain.
include_subdomains: bool,
/// `true` for entries seeded from the built-in preload list; such entries
/// are treated as active regardless of `expires_at` and survive cleanup.
preloaded: bool,
}
/// Tracks which domains are HSTS-protected, according to an [`HstsPolicy`].
pub struct HstsManager {
// Policy consulted by every public operation.
policy: HstsPolicy,
// Domain name -> HSTS record, shared behind an async read/write lock.
cache: Arc<RwLock<HashMap<String, HstsEntry>>>,
}
impl HstsManager {
    /// Creates a manager governed by [`HstsPolicy::default`].
    pub fn new() -> Self {
        Self::with_policy(HstsPolicy::default())
    }

    /// Creates a manager governed by `policy`.
    ///
    /// When `policy.use_preload_list` is set, the cache is seeded with the
    /// built-in preload list before the manager is returned.
    pub fn with_policy(policy: HstsPolicy) -> Self {
        let mut cache = HashMap::new();
        if policy.use_preload_list {
            Self::load_preload_list(&mut cache);
        }
        Self {
            policy,
            cache: Arc::new(RwLock::new(cache)),
        }
    }

    /// Creates a manager whose policy is read from the environment
    /// (see [`HstsPolicy::from_env`]).
    pub fn from_env() -> Self {
        Self::with_policy(HstsPolicy::from_env())
    }

    /// Returns `true` when `domain` is currently HSTS-protected.
    ///
    /// A domain is protected when it has an active cache entry of its own, or
    /// when any ancestor domain (down to two labels) has an active entry with
    /// `includeSubDomains`. Host names are compared case-insensitively and a
    /// single trailing dot is ignored. Always returns `false` when the policy
    /// does not honor HSTS.
    pub async fn is_hsts_domain(&self, domain: &str) -> bool {
        if !self.policy.honor_hsts {
            return false;
        }

        let host = Self::normalize_host(domain);
        // Read the clock once so every entry is judged against the same instant.
        let now = chrono::Utc::now().timestamp();
        let cache = self.cache.read().await;

        if let Some(entry) = cache.get(host.as_str()) {
            if Self::is_active(entry, now) {
                debug!(domain = %domain, "Domain is HSTS-protected");
                return true;
            }
        }

        // Walk up the label hierarchy looking for an includeSubDomains entry.
        let mut current = host.as_str();
        while let Some(parent) = Self::parent_domain(current) {
            if let Some(entry) = cache.get(parent) {
                if entry.include_subdomains && Self::is_active(entry, now) {
                    debug!(
                        domain = %domain,
                        parent = %parent,
                        "Domain is HSTS-protected via parent"
                    );
                    return true;
                }
            }
            current = parent;
        }
        false
    }

    /// Records (or removes) the HSTS state of `domain` from a
    /// `Strict-Transport-Security` header value.
    ///
    /// * A header without a parseable `max-age` directive is ignored; it does
    ///   not touch an existing entry.
    /// * `max-age=0` removes the cached entry for `domain`.
    /// * Otherwise an entry is stored whose lifetime is `max-age`, capped at
    ///   `policy.max_cache_age`.
    /// * Preloaded entries are never removed or replaced by a header.
    ///
    /// Does nothing when the policy does not honor HSTS.
    pub async fn add_from_header(&self, domain: &str, header_value: &str) {
        if !self.policy.honor_hsts {
            return;
        }

        let max_age = match Self::max_age_directive(header_value) {
            Some(age) => age,
            None => {
                // A malformed header must not be mistaken for `max-age=0`,
                // otherwise it would silently drop existing protection.
                warn!(
                    domain = %domain,
                    header = %header_value,
                    "Ignoring HSTS header without a valid max-age directive"
                );
                return;
            }
        };

        let key = Self::normalize_host(domain);
        let mut cache = self.cache.write().await;

        if matches!(cache.get(&key), Some(existing) if existing.preloaded) {
            // The preloaded entry never expires and covers subdomains, so a
            // header could only weaken it.
            debug!(domain = %domain, "Ignoring HSTS header for preloaded domain");
            return;
        }

        if max_age == 0 {
            warn!(
                domain = %domain,
                header = %header_value,
                "HSTS header with max-age=0, removing from cache"
            );
            cache.remove(&key);
            return;
        }

        let include_subdomains = Self::has_include_subdomains(header_value);
        let effective_max_age = max_age.min(self.policy.max_cache_age);
        // Clamp before converting so the cast cannot wrap negative, and
        // saturate so the addition cannot overflow.
        let ttl = effective_max_age.min(i64::MAX as u64) as i64;
        let expires_at = chrono::Utc::now().timestamp().saturating_add(ttl);

        info!(
            domain = %domain,
            max_age = effective_max_age,
            include_subdomains = include_subdomains,
            "Added HSTS entry"
        );
        cache.insert(
            key,
            HstsEntry {
                expires_at,
                include_subdomains,
                preloaded: false,
            },
        );
    }

    /// Removes the `Strict-Transport-Security` header from `headers` when the
    /// policy asks for HSTS headers to be stripped. The header name is matched
    /// case-insensitively. Other headers are left untouched.
    pub fn process_response_headers(&self, headers: &mut HashMap<String, String>) {
        if !self.policy.strip_hsts_headers {
            return;
        }
        let before = headers.len();
        headers.retain(|name, _| !name.eq_ignore_ascii_case("strict-transport-security"));
        if headers.len() != before {
            warn!("HSTS header stripped (testing mode)");
        }
    }

    /// Returns `domain` without its left-most label, or `None` when the
    /// remainder would be a single label (e.g. `example.com` -> `None`).
    fn parent_domain(domain: &str) -> Option<&str> {
        let first_dot = domain.find('.')?;
        let rest = &domain[first_dot + 1..];
        if rest.contains('.') {
            Some(rest)
        } else {
            None
        }
    }

    /// Returns the `max-age` value of an HSTS header, or `0` when no valid
    /// `max-age` directive is present.
    fn parse_max_age(header_value: &str) -> u64 {
        Self::max_age_directive(header_value).unwrap_or(0)
    }

    /// Returns the value of the first `max-age` directive that parses as an
    /// unsigned integer, or `None` when there is none.
    ///
    /// The directive name is matched case-insensitively, whitespace around
    /// `=` is tolerated and a quoted value (`max-age="60"`) is accepted.
    fn max_age_directive(header_value: &str) -> Option<u64> {
        header_value.split(';').find_map(|directive| {
            let mut parts = directive.splitn(2, '=');
            let name = parts.next()?.trim();
            let value = parts.next()?.trim();
            if !name.eq_ignore_ascii_case("max-age") {
                return None;
            }
            value.trim_matches('"').parse::<u64>().ok()
        })
    }

    /// Returns `true` when the header carries an `includeSubDomains`
    /// directive (matched case-insensitively as a whole directive).
    fn has_include_subdomains(header_value: &str) -> bool {
        header_value
            .split(';')
            .any(|directive| directive.trim().eq_ignore_ascii_case("includeSubDomains"))
    }

    /// Canonical cache key for a host: one trailing dot removed, ASCII
    /// lower-cased.
    fn normalize_host(domain: &str) -> String {
        domain
            .strip_suffix('.')
            .unwrap_or(domain)
            .to_ascii_lowercase()
    }

    /// An entry is active when it is preloaded or has not yet expired at `now`.
    fn is_active(entry: &HstsEntry, now: i64) -> bool {
        entry.preloaded || entry.expires_at > now
    }

    /// Inserts the built-in preload domains into `cache` as never-expiring
    /// entries that cover subdomains.
    fn load_preload_list(cache: &mut HashMap<String, HstsEntry>) {
        // Small hard-coded list; not the full browser preload list.
        const PRELOAD_DOMAINS: &[&str] = &[
            "google.com",
            "gmail.com",
            "youtube.com",
            "facebook.com",
            "twitter.com",
            "github.com",
            "wikipedia.org",
            "cloudflare.com",
            "amazon.com",
            "apple.com",
            "microsoft.com",
            "netflix.com",
            "linkedin.com",
            "reddit.com",
            "instagram.com",
            "paypal.com",
            "dropbox.com",
            "stackoverflow.com",
            "zoom.us",
            "slack.com",
        ];

        cache.extend(PRELOAD_DOMAINS.iter().map(|domain| {
            (
                (*domain).to_owned(),
                HstsEntry {
                    expires_at: i64::MAX,
                    include_subdomains: true,
                    preloaded: true,
                },
            )
        }));
        info!(count = PRELOAD_DOMAINS.len(), "Loaded HSTS preload list");
    }

    /// Returns the policy this manager was built with.
    pub fn policy(&self) -> &HstsPolicy {
        &self.policy
    }

    /// Returns the number of cached entries, preloaded ones included.
    pub async fn cache_size(&self) -> usize {
        self.cache.read().await.len()
    }

    /// Drops every non-preloaded entry whose expiry time has passed.
    pub async fn cleanup_expired(&self) {
        let now = chrono::Utc::now().timestamp();
        let mut cache = self.cache.write().await;
        cache.retain(|domain, entry| {
            let keep = Self::is_active(entry, now);
            if !keep {
                debug!(domain = %domain, "HSTS entry expired");
            }
            keep
        });
    }
}
impl Default for HstsManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a dynamic (non-preloaded) entry that does not cover subdomains.
    fn dynamic_entry(expires_at: i64) -> HstsEntry {
        HstsEntry {
            expires_at,
            include_subdomains: false,
            preloaded: false,
        }
    }

    // The default policy honors HSTS, keeps headers and uses the preload list.
    #[tokio::test]
    async fn test_hsts_policy_defaults() {
        let HstsPolicy {
            honor_hsts,
            strip_hsts_headers,
            max_cache_age,
            use_preload_list,
        } = HstsPolicy::default();
        assert!(honor_hsts);
        assert!(!strip_hsts_headers);
        assert_eq!(max_cache_age, 31536000);
        assert!(use_preload_list);
    }

    // The testing preset disables HSTS and enables header stripping.
    #[tokio::test]
    async fn test_hsts_policy_strip_for_testing() {
        let HstsPolicy {
            honor_hsts,
            strip_hsts_headers,
            ..
        } = HstsPolicy::strip_for_testing();
        assert!(!honor_hsts);
        assert!(strip_hsts_headers);
    }

    // Preloaded domains are protected out of the box; others are not.
    #[tokio::test]
    async fn test_hsts_manager_preload_list() {
        let manager = HstsManager::new();
        for preloaded in ["google.com", "github.com", "facebook.com"].iter() {
            assert!(manager.is_hsts_domain(preloaded).await);
        }
        assert!(!manager.is_hsts_domain("example.com").await);
    }

    // A header with includeSubDomains protects the host and its subdomain.
    #[tokio::test]
    async fn test_hsts_manager_add_from_header() {
        let manager = HstsManager::new();
        manager
            .add_from_header("example.com", "max-age=31536000; includeSubDomains")
            .await;
        for host in ["example.com", "sub.example.com"].iter() {
            assert!(manager.is_hsts_domain(host).await);
        }
    }

    // max-age=0 removes a previously recorded entry.
    #[tokio::test]
    async fn test_hsts_manager_max_age_zero() {
        let manager = HstsManager::new();

        manager
            .add_from_header("example.com", "max-age=31536000")
            .await;
        let protected_before = manager.is_hsts_domain("example.com").await;
        assert!(protected_before);

        manager.add_from_header("example.com", "max-age=0").await;
        let protected_after = manager.is_hsts_domain("example.com").await;
        assert!(!protected_after);
    }

    // With honor_hsts off, even preloaded domains are reported unprotected.
    #[tokio::test]
    async fn test_hsts_manager_disabled() {
        let manager = HstsManager::with_policy(HstsPolicy {
            honor_hsts: false,
            ..HstsPolicy::default()
        });
        assert!(!manager.is_hsts_domain("google.com").await);
    }

    // Stripping removes only the HSTS header.
    #[tokio::test]
    async fn test_hsts_strip_headers() {
        let manager = HstsManager::with_policy(HstsPolicy::strip_for_testing());
        let mut headers: HashMap<String, String> = [
            ("strict-transport-security", "max-age=31536000"),
            ("content-type", "text/html"),
        ]
        .iter()
        .map(|(name, value)| (name.to_string(), value.to_string()))
        .collect();

        manager.process_response_headers(&mut headers);

        assert!(!headers.contains_key("strict-transport-security"));
        assert!(headers.contains_key("content-type"));
    }

    // max-age extraction, including the zero fallback for invalid input.
    #[test]
    fn test_parse_max_age() {
        let cases: [(&str, u64); 4] = [
            ("max-age=31536000", 31536000),
            ("max-age=0", 0),
            ("max-age=31536000; includeSubDomains", 31536000),
            ("invalid", 0),
        ];
        for (header, expected) in cases.iter() {
            assert_eq!(HstsManager::parse_max_age(header), *expected);
        }
    }

    // Parent lookup strips one label and stops at two-label domains.
    #[test]
    fn test_parent_domain() {
        let cases: [(&str, Option<&str>); 3] = [
            ("sub.example.com", Some("example.com")),
            ("deep.sub.example.com", Some("sub.example.com")),
            ("example.com", None),
        ];
        for (domain, expected) in cases.iter() {
            assert_eq!(HstsManager::parent_domain(domain), *expected);
        }
    }

    // Cleanup drops expired dynamic entries and keeps live ones.
    #[tokio::test]
    async fn test_cleanup_expired() {
        let manager = HstsManager::new();
        {
            // Scope the write guard so it is released before cleanup runs.
            let mut cache = manager.cache.write().await;
            cache.insert("expired.com".to_string(), dynamic_entry(0));
            cache.insert(
                "valid.com".to_string(),
                dynamic_entry(chrono::Utc::now().timestamp() + 3600),
            );
        }

        manager.cleanup_expired().await;

        let cache = manager.cache.read().await;
        assert!(!cache.contains_key("expired.com"));
        assert!(cache.contains_key("valid.com"));
    }

    // includeSubDomains applies at every depth below the recorded host.
    #[tokio::test]
    async fn test_hsts_includesubdomains_multi_level() {
        let manager = HstsManager::new();
        manager
            .add_from_header("example.com", "max-age=31536000; includeSubDomains")
            .await;

        assert!(
            manager.is_hsts_domain("foo.bar.example.com").await,
            "foo.bar.example.com should be protected via example.com includeSubDomains"
        );
        let hosts = [
            "example.com",
            "bar.example.com",
            "foo.bar.example.com",
            "baz.foo.bar.example.com",
        ];
        for host in hosts.iter() {
            assert!(manager.is_hsts_domain(host).await);
        }
    }
}