use std::collections::HashMap;
/// A set of HTTP response security headers, keyed by header name.
///
/// Build one with [`SecurityHeaders::default`], [`SecurityHeaders::production`] or
/// [`SecurityHeaders::development`], then adjust it with `add`, `remove` or `merge`.
/// `Clone` allows deriving per-route variants from a shared base set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SecurityHeaders {
    /// Header name -> header value.
    headers: HashMap<String, String>,
}
impl Default for SecurityHeaders {
    /// Baseline security headers shared by every environment.
    ///
    /// Sets `X-XSS-Protection`, `X-Content-Type-Options`, `X-Frame-Options`,
    /// `Referrer-Policy` and `Permissions-Policy`. No `Content-Security-Policy` or
    /// `Strict-Transport-Security` is included here.
    fn default() -> Self {
        let headers: HashMap<String, String> = [
            // The legacy XSS auditor is gone from modern browsers. Where it survives, its
            // "1; mode=block" behaviour can itself be abused, so OWASP recommends
            // explicitly disabling it ("0") and relying on Content-Security-Policy.
            ("X-XSS-Protection", "0"),
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            ("Referrer-Policy", "strict-origin-when-cross-origin"),
            ("Permissions-Policy", "geolocation=(), microphone=(), camera=()"),
        ]
        .iter()
        .map(|&(name, value)| (name.to_string(), value.to_string()))
        .collect();
        Self { headers }
    }
}
impl SecurityHeaders {
    /// Hardened header set for production.
    ///
    /// Consists of the [`Default`] headers plus:
    /// - a strict `Content-Security-Policy` (same-origin only, no framing);
    /// - a two-year `Strict-Transport-Security` with `includeSubDomains; preload`.
    #[must_use]
    pub fn production() -> Self {
        let mut headers = Self::default();
        headers.add(
            "Content-Security-Policy".to_string(),
            "default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data: https:; font-src 'self'; connect-src 'self'; frame-ancestors 'none'".to_string(),
        );
        headers.add(
            "Strict-Transport-Security".to_string(),
            "max-age=63072000; includeSubDomains; preload".to_string(),
        );
        headers
    }

    /// Relaxed header set for local development.
    ///
    /// Consists of the [`Default`] headers plus a permissive `Content-Security-Policy`.
    /// That policy allows inline/eval scripts and `ws:`/`http:` connections, presumably
    /// for dev-server tooling. No `Strict-Transport-Security` is set, so browsers are
    /// not pinned to HTTPS for the development host.
    #[must_use]
    pub fn development() -> Self {
        let mut headers = Self::default();
        headers.add(
            "Content-Security-Policy".to_string(),
            "default-src 'self' 'unsafe-inline' 'unsafe-eval'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https: http:; font-src 'self' data:; connect-src 'self' ws: wss: http: https:".to_string(),
        );
        headers
    }

    /// Returns every `(name, value)` pair, sorted by header name.
    ///
    /// The backing `HashMap` iterates in an unspecified order. Sorting makes the output
    /// deterministic, so emitted responses and snapshot tests are reproducible.
    #[must_use]
    pub fn to_vec(&self) -> Vec<(String, String)> {
        let mut pairs: Vec<(String, String)> = self
            .headers
            .iter()
            .map(|(name, value)| (name.clone(), value.clone()))
            .collect();
        // Names are unique map keys, so an unstable sort on the name alone is exact.
        pairs.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        pairs
    }

    /// Sets header `name` to `value`.
    ///
    /// Any existing entry whose name matches case-insensitively is replaced, because
    /// HTTP field names are case-insensitive (RFC 9110 §5.1). The stored name takes
    /// the casing passed here.
    pub fn add(&mut self, name: String, value: String) {
        // Without this, adding "x-frame-options" next to a stored "X-Frame-Options"
        // would emit the same header twice with conflicting values.
        self.headers
            .retain(|existing, _| !existing.eq_ignore_ascii_case(&name));
        self.headers.insert(name, value);
    }

    /// Removes header `name`, matched case-insensitively; a no-op if it is absent.
    pub fn remove(&mut self, name: &str) {
        self.headers
            .retain(|existing, _| !existing.eq_ignore_ascii_case(name));
    }

    /// Looks up the value of header `name`, matched case-insensitively.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&String> {
        // Fast path: an exact-case hit avoids the linear scan in the common case.
        self.headers.get(name).or_else(|| {
            self.headers
                .iter()
                .find(|(key, _)| key.eq_ignore_ascii_case(name))
                .map(|(_, value)| value)
        })
    }

    /// Returns `true` if header `name` is present, matched case-insensitively.
    #[must_use]
    pub fn has(&self, name: &str) -> bool {
        self.get(name).is_some()
    }

    /// Returns all header names, sorted, with each name's stored casing.
    #[must_use]
    pub fn names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.headers.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Copies every header from `other` into `self`.
    ///
    /// When names clash (compared case-insensitively), the value from `other` wins.
    pub fn merge(&mut self, other: &Self) {
        for (name, value) in &other.headers {
            self.add(name.clone(), value.clone());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_security_headers() {
        let headers = SecurityHeaders::default();
        assert!(headers.has("X-XSS-Protection"));
        assert!(headers.has("X-Content-Type-Options"));
        assert!(headers.has("X-Frame-Options"));
        assert!(headers.has("Referrer-Policy"));
        assert!(headers.has("Permissions-Policy"));
    }

    #[test]
    fn test_production_security_headers() {
        let headers = SecurityHeaders::production();
        assert!(headers.has("Content-Security-Policy"));
        assert!(headers.has("Strict-Transport-Security"));
        assert!(headers.has("X-XSS-Protection"));
    }

    #[test]
    fn test_production_csp_forbids_framing_and_inline() {
        let headers = SecurityHeaders::production();
        let csp = headers
            .get("Content-Security-Policy")
            .expect("production() sets a CSP");
        assert!(csp.contains("frame-ancestors 'none'"));
        assert!(!csp.contains("'unsafe-inline'"));
    }

    #[test]
    fn test_development_security_headers() {
        let headers = SecurityHeaders::development();
        assert!(headers.has("Content-Security-Policy"));
        // development() deliberately omits HSTS.
        assert!(!headers.has("Strict-Transport-Security"));
        let csp = headers
            .get("Content-Security-Policy")
            .expect("development() sets a CSP");
        assert!(csp.contains("'unsafe-eval'"));
    }

    #[test]
    fn test_custom_header_operations() {
        let mut headers = SecurityHeaders::default();
        headers.add("X-Custom-Header".to_string(), "custom-value".to_string());
        assert_eq!(headers.get("X-Custom-Header"), Some(&"custom-value".to_string()));
        headers.remove("X-Custom-Header");
        assert!(!headers.has("X-Custom-Header"));
    }

    #[test]
    fn test_remove_missing_header_is_noop() {
        let mut headers = SecurityHeaders::default();
        let before = headers.names().len();
        headers.remove("X-Does-Not-Exist");
        assert_eq!(headers.names().len(), before);
    }

    #[test]
    fn test_to_vec_and_names_agree() {
        let headers = SecurityHeaders::production();
        let pairs = headers.to_vec();
        let names = headers.names();
        assert_eq!(pairs.len(), names.len());
        for (name, value) in &pairs {
            assert!(names.contains(name));
            assert_eq!(headers.get(name), Some(value));
        }
    }

    #[test]
    fn test_header_merge() {
        let mut headers1 = SecurityHeaders::default();
        let mut headers2 = SecurityHeaders::default();
        headers2.add("X-Custom".to_string(), "value".to_string());
        headers1.merge(&headers2);
        assert!(headers1.has("X-Custom"));
        assert_eq!(headers1.get("X-Custom"), Some(&"value".to_string()));
    }

    #[test]
    fn test_merge_overrides_existing_value() {
        let mut base = SecurityHeaders::default();
        let mut overrides = SecurityHeaders::default();
        overrides.add("X-Frame-Options".to_string(), "SAMEORIGIN".to_string());
        base.merge(&overrides);
        assert_eq!(base.get("X-Frame-Options"), Some(&"SAMEORIGIN".to_string()));
    }
}