use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use http::{HeaderName, HeaderValue, Response};
use serde::Deserialize;
use tower::{Layer, Service};
use crate::Error;
/// Configuration for the security-headers middleware built by [`security_headers`].
///
/// Thanks to `#[serde(default)]`, any key missing from the deserialized input
/// falls back to the value from [`SecurityHeadersConfig::default`].
///
/// Because the struct is `#[non_exhaustive]`, code outside this crate cannot use
/// a struct literal; start from `SecurityHeadersConfig::default()` and assign
/// fields, or deserialize it.
#[non_exhaustive]
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct SecurityHeadersConfig {
    /// When `true`, responses get `X-Content-Type-Options: nosniff`.
    /// Defaults to `true`.
    pub x_content_type_options: bool,
    /// Value for the `X-Frame-Options` header. Defaults to `"DENY"`.
    pub x_frame_options: String,
    /// Value for the `Referrer-Policy` header.
    /// Defaults to `"strict-origin-when-cross-origin"`.
    pub referrer_policy: String,
    /// When `Some(n)`, responses get `Strict-Transport-Security: max-age=n`
    /// (no other directives are added; browsers read `max-age` in seconds).
    /// Defaults to `None`, i.e. no HSTS header.
    pub hsts_max_age: Option<u64>,
    /// When set, used as the `Content-Security-Policy` header value.
    /// Defaults to `None`, i.e. no CSP header.
    pub content_security_policy: Option<String>,
    /// When set, used as the `Permissions-Policy` header value.
    /// Defaults to `None`, i.e. no Permissions-Policy header.
    pub permissions_policy: Option<String>,
}
impl Default for SecurityHeadersConfig {
    /// Enables `nosniff`, denies framing, uses the
    /// `strict-origin-when-cross-origin` referrer policy, and leaves HSTS,
    /// CSP and Permissions-Policy unset.
    fn default() -> Self {
        let x_frame_options = String::from("DENY");
        let referrer_policy = String::from("strict-origin-when-cross-origin");
        Self {
            x_content_type_options: true,
            x_frame_options,
            referrer_policy,
            hsts_max_age: None,
            content_security_policy: None,
            permissions_policy: None,
        }
    }
}
/// Tower [`Layer`] that adds a fixed, pre-validated set of security headers
/// to responses.
///
/// Build one with [`security_headers`]; every service it wraps receives its
/// own copy of the header list.
#[derive(Debug, Clone)]
pub struct SecurityHeadersLayer {
    /// Header name/value pairs, in the order they are applied to responses.
    headers: Vec<(HeaderName, HeaderValue)>,
}
impl SecurityHeadersLayer {
fn from_config(config: &SecurityHeadersConfig) -> crate::Result<Self> {
let mut headers = Vec::new();
if config.x_content_type_options {
headers.push((
http::header::X_CONTENT_TYPE_OPTIONS,
HeaderValue::from_static("nosniff"),
));
}
headers.push((
HeaderName::from_static("x-frame-options"),
HeaderValue::from_str(&config.x_frame_options)
.map_err(|_| Error::unprocessable_entity("invalid x-frame-options value"))?,
));
headers.push((
HeaderName::from_static("referrer-policy"),
HeaderValue::from_str(&config.referrer_policy)
.map_err(|_| Error::unprocessable_entity("invalid referrer-policy value"))?,
));
if let Some(max_age) = config.hsts_max_age {
headers.push((
http::header::STRICT_TRANSPORT_SECURITY,
HeaderValue::from_str(&format!("max-age={max_age}"))
.map_err(|_| Error::unprocessable_entity("invalid hsts max-age value"))?,
));
}
if let Some(ref csp) = config.content_security_policy {
headers.push((
HeaderName::from_static("content-security-policy"),
HeaderValue::from_str(csp).map_err(|_| {
Error::unprocessable_entity("invalid content-security-policy value")
})?,
));
}
if let Some(ref pp) = config.permissions_policy {
headers.push((
HeaderName::from_static("permissions-policy"),
HeaderValue::from_str(pp)
.map_err(|_| Error::unprocessable_entity("invalid permissions-policy value"))?,
));
}
Ok(Self { headers })
}
}
impl<S> Layer<S> for SecurityHeadersLayer {
    type Service = SecurityHeadersService<S>;

    /// Wraps `inner`, handing the new service its own copy of the header list.
    fn layer(&self, inner: S) -> Self::Service {
        let headers = self.headers.clone();
        SecurityHeadersService { headers, inner }
    }
}
/// Service produced by [`SecurityHeadersLayer`].
///
/// It forwards requests to `inner` and adds the configured headers to
/// successful responses that do not already carry them.
#[derive(Debug, Clone)]
pub struct SecurityHeadersService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Header name/value pairs copied from the layer, applied in order.
    headers: Vec<(HeaderName, HeaderValue)>,
}
impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for SecurityHeadersService<S>
where
    S: Service<http::Request<ReqBody>, Response = Response<ResBody>>,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Delegates readiness entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `request` to the inner service. Once it succeeds, adds each
    /// configured header that the response does not already contain.
    ///
    /// Headers already present on the inner service's response are left
    /// untouched. If the inner future resolves to `Err`, that error is returned
    /// as-is and no headers are added.
    fn call(&mut self, request: http::Request<ReqBody>) -> Self::Future {
        // The boxed future cannot borrow `self`, so it gets its own copy.
        let headers = self.headers.clone();
        let future = self.inner.call(request);
        Box::pin(async move {
            let mut response = future.await?;
            let resp_headers = response.headers_mut();
            for (name, value) in headers {
                // One lookup per header: insert only when no value exists yet
                // under `name`, so explicitly-set headers win.
                resp_headers.entry(name).or_insert(value);
            }
            Ok(response)
        })
    }
}
/// Builds a [`SecurityHeadersLayer`] from `config`.
///
/// # Errors
///
/// Fails if any configured header value cannot be encoded as an HTTP header
/// value.
pub fn security_headers(config: &SecurityHeadersConfig) -> crate::Result<SecurityHeadersLayer> {
    let layer = SecurityHeadersLayer::from_config(config)?;
    Ok(layer)
}