use rustapi_core::{
middleware::{BoxedNext, MiddlewareLayer},
Request, Response,
};
use std::future::Future;
use std::pin::Pin;
/// Which security headers [`SecurityHeadersLayer`] writes onto responses, and
/// with what values.
///
/// `bool` fields toggle headers with a fixed value; `Option` fields omit their
/// header entirely when `None`.
#[derive(Debug, Clone)]
pub struct SecurityHeadersConfig {
    /// When `true`, emit `X-Content-Type-Options: nosniff`.
    pub x_content_type_options: bool,
    /// Value for `X-Frame-Options`; `None` omits the header.
    pub x_frame_options: Option<XFrameOptions>,
    /// When `true`, emit `X-XSS-Protection: 1; mode=block`.
    pub x_xss_protection: bool,
    /// Settings for `Strict-Transport-Security`; `None` omits the header.
    pub hsts: Option<HstsConfig>,
    /// Raw `Content-Security-Policy` value; `None` omits the header.
    /// A string that is not a valid HTTP header value is skipped.
    pub csp: Option<String>,
    /// Value for `Referrer-Policy`; `None` omits the header.
    pub referrer_policy: Option<ReferrerPolicy>,
    /// Raw `Permissions-Policy` value; `None` omits the header.
    /// A string that is not a valid HTTP header value is skipped.
    pub permissions_policy: Option<String>,
}
/// Allowed values for the `X-Frame-Options` response header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum XFrameOptions {
    /// Serialized as `DENY`.
    Deny,
    /// Serialized as `SAMEORIGIN`.
    SameOrigin,
}

impl XFrameOptions {
    /// Returns the exact token written into the `X-Frame-Options` header.
    fn as_str(&self) -> &'static str {
        match *self {
            XFrameOptions::Deny => "DENY",
            XFrameOptions::SameOrigin => "SAMEORIGIN",
        }
    }
}
/// Settings for the `Strict-Transport-Security` response header.
#[derive(Debug, Clone)]
pub struct HstsConfig {
    /// Value of the `max-age` directive (seconds, per RFC 6797).
    pub max_age: u32,
    /// When `true`, append the `includeSubDomains` directive.
    pub include_subdomains: bool,
    /// When `true`, append the `preload` directive.
    pub preload: bool,
}

impl HstsConfig {
    /// Renders the header value: `max-age=<n>` followed by each enabled
    /// optional directive, in a fixed order, separated by `"; "`.
    fn to_header_value(&self) -> String {
        let optional_directives = [
            (self.include_subdomains, "includeSubDomains"),
            (self.preload, "preload"),
        ];
        optional_directives
            .iter()
            .filter(|(enabled, _)| *enabled)
            .fold(
                format!("max-age={}", self.max_age),
                |mut rendered, (_, directive)| {
                    rendered.push_str("; ");
                    rendered.push_str(directive);
                    rendered
                },
            )
    }
}
/// Allowed values for the `Referrer-Policy` response header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReferrerPolicy {
    /// Serialized as `no-referrer`.
    NoReferrer,
    /// Serialized as `no-referrer-when-downgrade`.
    NoReferrerWhenDowngrade,
    /// Serialized as `origin`.
    Origin,
    /// Serialized as `origin-when-cross-origin`.
    OriginWhenCrossOrigin,
    /// Serialized as `same-origin`.
    SameOrigin,
    /// Serialized as `strict-origin`.
    StrictOrigin,
    /// Serialized as `strict-origin-when-cross-origin`.
    StrictOriginWhenCrossOrigin,
    /// Serialized as `unsafe-url`.
    UnsafeUrl,
}

impl ReferrerPolicy {
    /// Returns the exact token written into the `Referrer-Policy` header.
    fn as_str(&self) -> &'static str {
        match *self {
            ReferrerPolicy::NoReferrer => "no-referrer",
            ReferrerPolicy::NoReferrerWhenDowngrade => "no-referrer-when-downgrade",
            ReferrerPolicy::Origin => "origin",
            ReferrerPolicy::OriginWhenCrossOrigin => "origin-when-cross-origin",
            ReferrerPolicy::SameOrigin => "same-origin",
            ReferrerPolicy::StrictOrigin => "strict-origin",
            ReferrerPolicy::StrictOriginWhenCrossOrigin => "strict-origin-when-cross-origin",
            ReferrerPolicy::UnsafeUrl => "unsafe-url",
        }
    }
}
impl Default for SecurityHeadersConfig {
fn default() -> Self {
Self {
x_content_type_options: true,
x_frame_options: Some(XFrameOptions::Deny),
x_xss_protection: true,
hsts: Some(HstsConfig {
max_age: 31536000, include_subdomains: true,
preload: false,
}),
csp: Some("default-src 'self'".to_string()),
referrer_policy: Some(ReferrerPolicy::StrictOriginWhenCrossOrigin),
permissions_policy: Some("geolocation=(), microphone=(), camera=()".to_string()),
}
}
}
/// Middleware layer that adds security-related headers to every response.
///
/// Start from [`SecurityHeadersLayer::new`] (the [`SecurityHeadersConfig`]
/// defaults), [`SecurityHeadersLayer::strict`], or
/// [`SecurityHeadersLayer::with_config`], then refine with the builder methods.
/// Every header can be overridden or removed individually.
#[derive(Clone)]
pub struct SecurityHeadersLayer {
    config: SecurityHeadersConfig,
}

impl SecurityHeadersLayer {
    /// Creates a layer using [`SecurityHeadersConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(SecurityHeadersConfig::default())
    }

    /// Creates a layer that applies exactly `config`.
    ///
    /// Use this when the configuration is assembled elsewhere rather than
    /// through the builder methods.
    #[must_use]
    pub fn with_config(config: SecurityHeadersConfig) -> Self {
        Self { config }
    }

    /// Returns the configuration this layer applies to responses.
    #[must_use]
    pub fn config(&self) -> &SecurityHeadersConfig {
        &self.config
    }

    /// Creates a layer with a tighter preset than [`SecurityHeadersLayer::new`].
    ///
    /// The preset uses:
    /// - two-year HSTS with `includeSubDomains` and `preload`;
    /// - a restrictive CSP;
    /// - `no-referrer`;
    /// - a permissions policy that also disables payment and USB.
    #[must_use]
    pub fn strict() -> Self {
        Self::with_config(SecurityHeadersConfig {
            x_content_type_options: true,
            x_frame_options: Some(XFrameOptions::Deny),
            x_xss_protection: true,
            hsts: Some(HstsConfig {
                // Two years (2 x 365 days) in seconds.
                max_age: 63072000,
                include_subdomains: true,
                preload: true,
            }),
            // NOTE(review): style-src still allows 'unsafe-inline'.
            csp: Some(
                "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self'; connect-src 'self'; frame-ancestors 'none'"
                    .to_string(),
            ),
            referrer_policy: Some(ReferrerPolicy::NoReferrer),
            permissions_policy: Some(
                "geolocation=(), microphone=(), camera=(), payment=(), usb=()".to_string(),
            ),
        })
    }

    /// Stops emitting `X-Content-Type-Options`.
    #[must_use]
    pub fn without_content_type_options(mut self) -> Self {
        self.config.x_content_type_options = false;
        self
    }

    /// Sets the `X-Frame-Options` value.
    #[must_use]
    pub fn x_frame_options(mut self, options: XFrameOptions) -> Self {
        self.config.x_frame_options = Some(options);
        self
    }

    /// Stops emitting `X-Frame-Options`.
    #[must_use]
    pub fn without_x_frame_options(mut self) -> Self {
        self.config.x_frame_options = None;
        self
    }

    /// Stops emitting `X-XSS-Protection`.
    #[must_use]
    pub fn without_xss_protection(mut self) -> Self {
        self.config.x_xss_protection = false;
        self
    }

    /// Replaces the `Strict-Transport-Security` settings.
    #[must_use]
    pub fn hsts(mut self, config: HstsConfig) -> Self {
        self.config.hsts = Some(config);
        self
    }

    /// Stops emitting `Strict-Transport-Security`.
    #[must_use]
    pub fn without_hsts(mut self) -> Self {
        self.config.hsts = None;
        self
    }

    /// Sets the raw `Content-Security-Policy` value.
    #[must_use]
    pub fn csp(mut self, policy: impl Into<String>) -> Self {
        self.config.csp = Some(policy.into());
        self
    }

    /// Stops emitting `Content-Security-Policy`.
    #[must_use]
    pub fn without_csp(mut self) -> Self {
        self.config.csp = None;
        self
    }

    /// Sets the `Referrer-Policy` value.
    #[must_use]
    pub fn referrer_policy(mut self, policy: ReferrerPolicy) -> Self {
        self.config.referrer_policy = Some(policy);
        self
    }

    /// Stops emitting `Referrer-Policy`.
    #[must_use]
    pub fn without_referrer_policy(mut self) -> Self {
        self.config.referrer_policy = None;
        self
    }

    /// Sets the raw `Permissions-Policy` value.
    #[must_use]
    pub fn permissions_policy(mut self, policy: impl Into<String>) -> Self {
        self.config.permissions_policy = Some(policy.into());
        self
    }

    /// Stops emitting `Permissions-Policy`.
    #[must_use]
    pub fn without_permissions_policy(mut self) -> Self {
        self.config.permissions_policy = None;
        self
    }
}

impl Default for SecurityHeadersLayer {
    /// Same as [`SecurityHeadersLayer::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl MiddlewareLayer for SecurityHeadersLayer {
fn call(
&self,
req: Request,
next: BoxedNext,
) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
let config = self.config.clone();
Box::pin(async move {
let mut response = next(req).await;
let headers = response.headers_mut();
if config.x_content_type_options {
headers.insert(
http::header::HeaderName::from_static("x-content-type-options"),
http::header::HeaderValue::from_static("nosniff"),
);
}
if let Some(frame_options) = config.x_frame_options {
headers.insert(
http::header::HeaderName::from_static("x-frame-options"),
http::header::HeaderValue::from_static(frame_options.as_str()),
);
}
if config.x_xss_protection {
headers.insert(
http::header::HeaderName::from_static("x-xss-protection"),
http::header::HeaderValue::from_static("1; mode=block"),
);
}
if let Some(hsts) = config.hsts {
if let Ok(value) = http::header::HeaderValue::from_str(&hsts.to_header_value()) {
headers.insert(
http::header::HeaderName::from_static("strict-transport-security"),
value,
);
}
}
if let Some(csp) = config.csp {
if let Ok(value) = http::header::HeaderValue::from_str(&csp) {
headers.insert(
http::header::HeaderName::from_static("content-security-policy"),
value,
);
}
}
if let Some(referrer_policy) = config.referrer_policy {
headers.insert(
http::header::HeaderName::from_static("referrer-policy"),
http::header::HeaderValue::from_static(referrer_policy.as_str()),
);
}
if let Some(permissions) = config.permissions_policy {
if let Ok(value) = http::header::HeaderValue::from_str(&permissions) {
headers.insert(
http::header::HeaderName::from_static("permissions-policy"),
value,
);
}
}
response
})
}
fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
Box::new(self.clone())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use rustapi_core::ResponseBody;
    use std::sync::Arc;

    /// Every header name the layer can emit.
    const ALL_HEADERS: [&str; 7] = [
        "x-content-type-options",
        "x-frame-options",
        "x-xss-protection",
        "strict-transport-security",
        "content-security-policy",
        "referrer-policy",
        "permissions-policy",
    ];

    /// Inner handler that always answers `200 OK` without headers of its own.
    fn ok_next() -> BoxedNext {
        Arc::new(|_req: Request| {
            Box::pin(async {
                http::Response::builder()
                    .status(200)
                    .body(ResponseBody::new(Bytes::from("OK")))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }

    /// A bare `GET /` request.
    fn get_root() -> Request {
        Request::from_http_request(
            http::Request::builder()
                .method("GET")
                .uri("/")
                .body(())
                .unwrap(),
            Bytes::new(),
        )
    }

    /// Runs `layer` around [`ok_next`] and returns the decorated response.
    async fn run(layer: &SecurityHeadersLayer) -> Response {
        layer.call(get_root(), ok_next()).await
    }

    /// Reads a header as text, or `None` if it is absent.
    fn header<'a>(response: &'a Response, name: &str) -> Option<&'a str> {
        response
            .headers()
            .get(name)
            .map(|value| value.to_str().expect("header value should be visible ASCII"))
    }

    #[tokio::test]
    async fn security_headers_added_to_response() {
        let response = run(&SecurityHeadersLayer::new()).await;
        assert_eq!(header(&response, "x-content-type-options"), Some("nosniff"));
        assert_eq!(header(&response, "x-frame-options"), Some("DENY"));
        assert_eq!(header(&response, "x-xss-protection"), Some("1; mode=block"));
        assert_eq!(
            header(&response, "strict-transport-security"),
            Some("max-age=31536000; includeSubDomains")
        );
        assert_eq!(
            header(&response, "content-security-policy"),
            Some("default-src 'self'")
        );
        assert_eq!(
            header(&response, "referrer-policy"),
            Some("strict-origin-when-cross-origin")
        );
        assert_eq!(
            header(&response, "permissions-policy"),
            Some("geolocation=(), microphone=(), camera=()")
        );
    }

    #[tokio::test]
    async fn strict_mode_adds_all_headers() {
        let response = run(&SecurityHeadersLayer::strict()).await;
        for name in ALL_HEADERS.iter() {
            assert!(response.headers().contains_key(*name), "missing {}", name);
        }
        let hsts = header(&response, "strict-transport-security").unwrap();
        assert!(hsts.contains("preload"));
        assert!(hsts.contains("includeSubDomains"));
        assert!(hsts.starts_with("max-age=63072000"));
        assert_eq!(header(&response, "referrer-policy"), Some("no-referrer"));
    }

    #[tokio::test]
    async fn builder_methods_override_and_remove_headers() {
        let layer = SecurityHeadersLayer::new()
            .without_content_type_options()
            .x_frame_options(XFrameOptions::SameOrigin)
            .without_hsts()
            .without_csp()
            .referrer_policy(ReferrerPolicy::NoReferrer)
            .permissions_policy("camera=()");
        let response = run(&layer).await;
        assert!(header(&response, "x-content-type-options").is_none());
        assert_eq!(header(&response, "x-frame-options"), Some("SAMEORIGIN"));
        assert!(header(&response, "strict-transport-security").is_none());
        assert!(header(&response, "content-security-policy").is_none());
        assert_eq!(header(&response, "referrer-policy"), Some("no-referrer"));
        assert_eq!(header(&response, "permissions-policy"), Some("camera=()"));
        // Settings that were not touched keep their defaults.
        assert_eq!(header(&response, "x-xss-protection"), Some("1; mode=block"));
    }

    #[tokio::test]
    async fn without_x_frame_options_omits_header() {
        let response = run(&SecurityHeadersLayer::new().without_x_frame_options()).await;
        assert!(!response.headers().contains_key("x-frame-options"));
    }

    #[tokio::test]
    async fn invalid_policy_strings_are_skipped() {
        let layer = SecurityHeadersLayer::new()
            .csp("default-src 'self'\n")
            .permissions_policy("camera=()\r\n");
        let response = run(&layer).await;
        assert!(!response.headers().contains_key("content-security-policy"));
        assert!(!response.headers().contains_key("permissions-policy"));
        // Skipping an invalid header must not affect the others.
        assert_eq!(header(&response, "x-frame-options"), Some("DENY"));
    }

    #[test]
    fn hsts_config_formats_correctly() {
        let hsts = HstsConfig {
            max_age: 31536000,
            include_subdomains: true,
            preload: true,
        };
        assert_eq!(
            hsts.to_header_value(),
            "max-age=31536000; includeSubDomains; preload"
        );
    }

    #[test]
    fn hsts_config_omits_disabled_directives() {
        let hsts = HstsConfig {
            max_age: 0,
            include_subdomains: false,
            preload: false,
        };
        assert_eq!(hsts.to_header_value(), "max-age=0");
    }
}