use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::body::BoxBody;
/// A `tower` layer that writes a configurable set of security-related
/// headers onto every response produced by the wrapped service.
///
/// Start from [`SecureHeadersLayer::new`] (or `Default`) and adjust the
/// defaults with the builder methods. Header names are stored lowercase and
/// are unique: setting a name that is already configured replaces its value.
#[derive(Clone, Debug)]
pub struct SecureHeadersLayer {
    /// Configured `(name, value)` pairs, in insertion order.
    headers: Vec<(String, String)>,
}

impl SecureHeadersLayer {
    /// Creates a layer with this default header set:
    ///
    /// - `x-content-type-options: nosniff`
    /// - `x-frame-options: DENY`
    /// - `x-xss-protection: 0`
    /// - `referrer-policy: strict-origin-when-cross-origin`
    /// - `content-security-policy: default-src 'self'`
    /// - `permissions-policy: camera=(), microphone=(), geolocation=()`
    ///
    /// `strict-transport-security` is not set by default; enable it with
    /// [`SecureHeadersLayer::hsts`].
    #[must_use]
    pub fn new() -> Self {
        SecureHeadersLayer {
            headers: vec![
                ("x-content-type-options".to_string(), "nosniff".to_string()),
                ("x-frame-options".to_string(), "DENY".to_string()),
                ("x-xss-protection".to_string(), "0".to_string()),
                (
                    "referrer-policy".to_string(),
                    "strict-origin-when-cross-origin".to_string(),
                ),
                (
                    "content-security-policy".to_string(),
                    "default-src 'self'".to_string(),
                ),
                (
                    "permissions-policy".to_string(),
                    "camera=(), microphone=(), geolocation=()".to_string(),
                ),
            ],
        }
    }

    /// Sets `strict-transport-security` to
    /// `max-age=<max_age_secs>; includeSubDomains; preload`.
    ///
    /// The `includeSubDomains` and `preload` directives are always included.
    /// Calling this again replaces the previous value.
    #[must_use]
    pub fn hsts(mut self, max_age_secs: u64) -> Self {
        let value = format!("max-age={max_age_secs}; includeSubDomains; preload");
        self.set_header("strict-transport-security", value);
        self
    }

    /// Replaces the `x-frame-options` value (default `DENY`).
    #[must_use]
    pub fn frame_options(mut self, value: impl Into<String>) -> Self {
        self.set_header("x-frame-options", value.into());
        self
    }

    /// Sets the `content-security-policy` value (default `default-src 'self'`).
    ///
    /// If the header was removed by [`SecureHeadersLayer::disable_csp`],
    /// this adds it back.
    #[must_use]
    pub fn content_security_policy(mut self, value: impl Into<String>) -> Self {
        self.set_header("content-security-policy", value.into());
        self
    }

    /// Removes `content-security-policy` from the configured headers.
    #[must_use]
    pub fn disable_csp(mut self) -> Self {
        self.headers
            .retain(|(name, _)| name != "content-security-policy");
        self
    }

    /// Adds or replaces an arbitrary header.
    ///
    /// `name` is lowercased before it is stored, so names that differ only in
    /// ASCII case refer to the same entry. Neither `name` nor `value` is
    /// validated here: entries that are not valid HTTP header names/values are
    /// skipped when the layer wraps a service.
    #[must_use]
    pub fn custom(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let name = name.into().to_ascii_lowercase();
        let value = value.into();
        self.set_header(&name, value);
        self
    }

    /// Replaces the value of `name` if it is already configured, otherwise
    /// appends a new entry. `name` is compared exactly, so callers pass it
    /// lowercase to match the stored names.
    fn set_header(&mut self, name: &str, value: String) {
        if let Some(entry) = self.headers.iter_mut().find(|(n, _)| n == name) {
            entry.1 = value;
        } else {
            self.headers.push((name.to_string(), value));
        }
    }
}

impl Default for SecureHeadersLayer {
    /// Same as [`SecureHeadersLayer::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<S> tower_layer::Layer<S> for SecureHeadersLayer {
    type Service = SecureHeadersService<S>;

    /// Wraps `inner` in a [`SecureHeadersService`].
    ///
    /// The configured headers are converted to typed `http` headers once,
    /// here, and the result is shared via `Arc` by every clone of the
    /// returned service. Any entry whose name does not parse as a
    /// `HeaderName`, or whose value does not parse as a `HeaderValue`, is
    /// skipped silently.
    fn layer(&self, inner: S) -> Self::Service {
        let mut typed_headers = Vec::new();
        for (raw_name, raw_value) in &self.headers {
            // Keep only entries where both halves are valid HTTP header parts.
            if let (Ok(header_name), Ok(header_value)) = (
                http::HeaderName::from_bytes(raw_name.as_bytes()),
                http::HeaderValue::from_str(raw_value),
            ) {
                typed_headers.push((header_name, header_value));
            }
        }

        let shared = std::sync::Arc::new(typed_headers);
        SecureHeadersService {
            inner,
            headers: shared,
        }
    }
}
/// Service produced by `SecureHeadersLayer`: forwards each request to the
/// inner service and inserts the configured headers into its response.
#[derive(Clone)]
pub struct SecureHeadersService<S> {
    inner: S,
    /// Pre-parsed headers, shared (not copied) across clones of this service.
    headers: std::sync::Arc<Vec<(http::HeaderName, http::HeaderValue)>>,
}

impl<S, B> tower_service::Service<http::Request<B>> for SecureHeadersService<S>
where
    S: tower_service::Service<
            http::Request<B>,
            Response = http::Response<BoxBody>,
            Error = Infallible,
        > + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
    B: Send + 'static,
{
    type Response = http::Response<BoxBody>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Dispatches `req` to the inner service, then writes every configured
    /// header onto the response. `HeaderMap::insert` is used, so a value the
    /// inner service already set for the same name is overwritten.
    fn call(&mut self, req: http::Request<B>) -> Self::Future {
        // Call the exact instance that `poll_ready` just reported as ready.
        // A fresh clone of `inner` has not been polled for readiness, so
        // calling it would bypass backpressure (and can panic services that
        // require `poll_ready` before `call`). Calling synchronously here,
        // rather than inside the async block, also ensures the request is
        // handed off immediately after readiness was observed.
        let response_future = self.inner.call(req);
        let headers = std::sync::Arc::clone(&self.headers);
        Box::pin(async move {
            let mut resp = response_future.await?;
            let resp_headers = resp.headers_mut();
            for (name, value) in headers.iter() {
                resp_headers.insert(name.clone(), value.clone());
            }
            Ok(resp)
        })
    }
}