oxidite_middleware/
security_headers.rs1use oxidite_core::{OxiditeRequest, OxiditeResponse, Error as CoreError};
2use tower::{Service, Layer};
3use std::task::{Context, Poll};
4use std::future::Future;
5use std::pin::Pin;
6
7#[derive(Clone)]
9pub struct SecurityHeadersMiddleware<S> {
10 inner: S,
11 config: SecurityHeadersConfig,
12}
13
14#[derive(Clone, Debug)]
15pub struct SecurityHeadersConfig {
16 pub csp: Option<String>,
17 pub hsts_max_age: Option<u64>,
18 pub frame_options: FrameOptions,
19 pub content_type_options: bool,
20 pub xss_protection: bool,
21 pub referrer_policy: Option<String>,
22}
23
/// Policy emitted in the `X-Frame-Options` response header.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FrameOptions {
    /// `X-Frame-Options: DENY` — the page may never be rendered inside a frame.
    Deny,
    /// `X-Frame-Options: SAMEORIGIN` — only same-origin pages may frame it.
    SameOrigin,
    /// No `X-Frame-Options` header is sent.
    Allow,
}
30
31impl Default for SecurityHeadersConfig {
32 fn default() -> Self {
33 Self {
34 csp: Some("default-src 'self'".to_string()),
35 hsts_max_age: Some(31536000), frame_options: FrameOptions::SameOrigin,
37 content_type_options: true,
38 xss_protection: true,
39 referrer_policy: Some("strict-origin-when-cross-origin".to_string()),
40 }
41 }
42}
43
44impl<S> SecurityHeadersMiddleware<S> {
45 pub fn new(inner: S, config: SecurityHeadersConfig) -> Self {
46 Self { inner, config }
47 }
48}
49
50impl<S> Service<OxiditeRequest> for SecurityHeadersMiddleware<S>
51where
52 S: Service<OxiditeRequest, Response = OxiditeResponse, Error = CoreError> + Clone + Send + 'static,
53 S::Future: Send + 'static,
54{
55 type Response = S::Response;
56 type Error = S::Error;
57 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
58
59 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
60 self.inner.poll_ready(cx)
61 }
62
63 fn call(&mut self, req: OxiditeRequest) -> Self::Future {
64 let config = self.config.clone();
65 let mut inner = self.inner.clone();
66
67 Box::pin(async move {
68 let mut response = inner.call(req).await?;
69 let headers = response.headers_mut();
70
71 if let Some(csp) = &config.csp {
73 if let Ok(value) = csp.parse() {
74 headers.insert("content-security-policy", value);
75 }
76 }
77
78 if let Some(max_age) = config.hsts_max_age {
80 let hsts = format!("max-age={}; includeSubDomains", max_age);
81 if let Ok(value) = hsts.parse() {
82 headers.insert("strict-transport-security", value);
83 }
84 }
85
86 let frame_option = match config.frame_options {
88 FrameOptions::Deny => "DENY",
89 FrameOptions::SameOrigin => "SAMEORIGIN",
90 FrameOptions::Allow => return Ok(response),
91 };
92 if let Ok(value) = frame_option.parse() {
93 headers.insert("x-frame-options", value);
94 }
95
96 if config.content_type_options {
98 if let Ok(value) = "nosniff".parse() {
99 headers.insert("x-content-type-options", value);
100 }
101 }
102
103 if config.xss_protection {
105 if let Ok(value) = "1; mode=block".parse() {
106 headers.insert("x-xss-protection", value);
107 }
108 }
109
110 if let Some(policy) = &config.referrer_policy {
112 if let Ok(value) = policy.parse() {
113 headers.insert("referrer-policy", value);
114 }
115 }
116
117 Ok(response)
118 })
119 }
120}
121
122#[derive(Clone)]
124pub struct SecurityHeadersLayer {
125 config: SecurityHeadersConfig,
126}
127
128impl SecurityHeadersLayer {
129 pub fn new(config: SecurityHeadersConfig) -> Self {
130 Self { config }
131 }
132
133 pub fn with_defaults() -> Self {
134 Self {
135 config: SecurityHeadersConfig::default(),
136 }
137 }
138}
139
140impl<S> Layer<S> for SecurityHeadersLayer {
141 type Service = SecurityHeadersMiddleware<S>;
142
143 fn layer(&self, inner: S) -> Self::Service {
144 SecurityHeadersMiddleware::new(inner, self.config.clone())
145 }
146}