use serde::Deserialize;
/// Security settings, split into a `headers` section and a `csrf` section.
///
/// Both sections are optional when deserializing: a missing `headers` or
/// `csrf` entry falls back to that section type's `Default` implementation.
#[derive(Debug, Default, Deserialize)]
pub struct SecurityConfig {
/// Header settings; see [`HeadersConfig`] for the per-field defaults.
#[serde(default)]
pub headers: HeadersConfig,
/// CSRF settings; see [`CsrfConfig`] for the per-field defaults.
#[serde(default)]
pub csrf: CsrfConfig,
}
/// Settings for the `headers` section of [`SecurityConfig`].
///
/// Every field is optional when deserializing; the default noted on each
/// field is the same value the `Default` impl for this type produces.
///
/// NOTE(review): the field names suggest each one controls the HTTP response
/// header of the same name — confirm against the code that consumes this
/// struct.
#[derive(Debug, Deserialize)]
// Four independent `bool` switches live in this struct, which is what the
// lint below flags; they are kept as plain flags on purpose.
#[allow(clippy::struct_excessive_bools)]
pub struct HeadersConfig {
/// Defaults to `"DENY"` when omitted.
#[serde(default = "default_x_frame_options")]
pub x_frame_options: String,
/// Defaults to `true` when omitted.
#[serde(default = "default_true")]
pub x_content_type_options: bool,
/// Defaults to `true` when omitted.
#[serde(default = "default_true")]
pub xss_protection: bool,
/// Defaults to `false` when omitted.
#[serde(default)]
pub strict_transport_security: bool,
/// Defaults to `31_536_000` when omitted (365 days, given the `_secs` unit).
#[serde(default = "default_hsts_max_age")]
pub hsts_max_age_secs: u64,
/// Defaults to `true` when omitted.
#[serde(default = "default_true")]
pub hsts_include_subdomains: bool,
/// Defaults to the empty string when omitted.
// NOTE(review): presumably an empty value means "do not send this header" —
// confirm with the consumer.
#[serde(default)]
pub content_security_policy: String,
/// Defaults to `"strict-origin-when-cross-origin"` when omitted.
#[serde(default = "default_referrer_policy")]
pub referrer_policy: String,
/// Defaults to the empty string when omitted.
#[serde(default)]
pub permissions_policy: String,
}
impl Default for HeadersConfig {
fn default() -> Self {
Self {
x_frame_options: default_x_frame_options(),
x_content_type_options: true,
xss_protection: true,
strict_transport_security: false,
hsts_max_age_secs: default_hsts_max_age(),
hsts_include_subdomains: true,
content_security_policy: String::new(),
referrer_policy: default_referrer_policy(),
permissions_policy: String::new(),
}
}
}
/// Settings for the `csrf` section of [`SecurityConfig`].
///
/// Every field is optional when deserializing; the default noted on each
/// field is the same value the `Default` impl for this type produces.
#[derive(Debug, Deserialize)]
pub struct CsrfConfig {
/// Defaults to `false` when omitted.
#[serde(default)]
pub enabled: bool,
/// Defaults to `"X-CSRF-Token"` when omitted.
// NOTE(review): presumably the request header the token is read from —
// confirm with the consumer.
#[serde(default = "default_csrf_header")]
pub token_header: String,
/// Defaults to `"_csrf"` when omitted.
#[serde(default = "default_csrf_field")]
pub form_field: String,
/// Defaults to `"autumn-csrf"` when omitted.
#[serde(default = "default_csrf_cookie")]
pub cookie_name: String,
/// Defaults to `GET`, `HEAD`, `OPTIONS` and `TRACE` when omitted.
// NOTE(review): presumably methods listed here are exempt from the CSRF
// check — confirm with the consumer.
#[serde(default = "default_safe_methods")]
pub safe_methods: Vec<String>,
}
/// Builds the same values serde fills in when every `CsrfConfig` field is
/// omitted from the input.
impl Default for CsrfConfig {
    fn default() -> Self {
        let token_header = default_csrf_header();
        let form_field = default_csrf_field();
        let cookie_name = default_csrf_cookie();
        let safe_methods = default_safe_methods();

        Self {
            // Matches the `false` that `#[serde(default)]` yields for `enabled`.
            enabled: false,
            token_header,
            form_field,
            cookie_name,
            safe_methods,
        }
    }
}
/// Serde default helper that always yields `true`.
///
/// Referenced through `#[serde(default = "default_true")]`, which takes a
/// function path rather than a literal value.
const fn default_true() -> bool {
true
}
/// Default for `HeadersConfig::x_frame_options`: the string `"DENY"`.
fn default_x_frame_options() -> String {
    String::from("DENY")
}
/// Default for `HeadersConfig::hsts_max_age_secs`: 365 days in seconds.
const fn default_hsts_max_age() -> u64 {
    const SECS_PER_DAY: u64 = 24 * 60 * 60;
    365 * SECS_PER_DAY
}
/// Default for `HeadersConfig::referrer_policy`:
/// `"strict-origin-when-cross-origin"`.
fn default_referrer_policy() -> String {
    String::from("strict-origin-when-cross-origin")
}
/// Default for `CsrfConfig::token_header`: `"X-CSRF-Token"`.
fn default_csrf_header() -> String {
    String::from("X-CSRF-Token")
}
/// Default for `CsrfConfig::form_field`: `"_csrf"`.
fn default_csrf_field() -> String {
    String::from("_csrf")
}
/// Default for `CsrfConfig::cookie_name`: `"autumn-csrf"`.
fn default_csrf_cookie() -> String {
    String::from("autumn-csrf")
}
/// Default for `CsrfConfig::safe_methods`: `GET`, `HEAD`, `OPTIONS` and
/// `TRACE`, in that order (the methods RFC 9110 classifies as safe).
fn default_safe_methods() -> Vec<String> {
    ["GET", "HEAD", "OPTIONS", "TRACE"]
        .iter()
        .copied()
        .map(String::from)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The headers section of a default `SecurityConfig` carries the expected
    /// built-in values.
    #[test]
    fn security_config_defaults() {
        let SecurityConfig { headers, .. } = SecurityConfig::default();

        assert_eq!(headers.x_frame_options, "DENY");
        assert!(headers.x_content_type_options);
        assert!(headers.xss_protection);
        assert!(!headers.strict_transport_security);
        assert_eq!(headers.hsts_max_age_secs, 31_536_000);
        assert!(headers.content_security_policy.is_empty());
        assert_eq!(
            headers.referrer_policy,
            "strict-origin-when-cross-origin"
        );
    }

    /// `CsrfConfig::default()` is disabled and uses the built-in names.
    #[test]
    fn csrf_config_defaults() {
        let csrf = CsrfConfig::default();

        assert!(!csrf.enabled);
        assert_eq!(csrf.token_header, "X-CSRF-Token");
        assert_eq!(csrf.form_field, "_csrf");
        assert_eq!(csrf.cookie_name, "autumn-csrf");
        assert_eq!(csrf.safe_methods.len(), 4);
    }

    /// Explicit header values are honoured; omitted booleans keep their
    /// `true` defaults.
    #[test]
    fn headers_config_deserialize() {
        let input = r#"
x_frame_options = "SAMEORIGIN"
strict_transport_security = true
content_security_policy = "default-src 'self'"
"#;
        let headers: HeadersConfig = toml::from_str(input).unwrap();

        // Values supplied in the document.
        assert_eq!(headers.x_frame_options, "SAMEORIGIN");
        assert!(headers.strict_transport_security);
        assert_eq!(headers.content_security_policy, "default-src 'self'");
        // Values left to the serde defaults.
        assert!(headers.x_content_type_options);
        assert!(headers.xss_protection);
    }

    /// A partial CSRF table overrides only the keys it names.
    #[test]
    fn csrf_config_deserialize() {
        let input = r#"
enabled = true
token_header = "X-XSRF-Token"
"#;
        let csrf: CsrfConfig = toml::from_str(input).unwrap();

        assert!(csrf.enabled);
        assert_eq!(csrf.token_header, "X-XSRF-Token");
        // Not present in the document, so the default applies.
        assert_eq!(csrf.form_field, "_csrf");
    }

    /// Both sections are read from their own tables of one document.
    #[test]
    fn full_security_config_deserialize() {
        let input = r#"
[headers]
x_frame_options = "DENY"
strict_transport_security = true
[csrf]
enabled = true
"#;
        let security: SecurityConfig = toml::from_str(input).unwrap();
        let SecurityConfig { headers, csrf } = security;

        assert_eq!(headers.x_frame_options, "DENY");
        assert!(headers.strict_transport_security);
        assert!(csrf.enabled);
    }
}