Skip to main content

oxidite_middleware/
security_headers.rs

1use oxidite_core::{OxiditeRequest, OxiditeResponse, Error as CoreError};
2use tower::{Service, Layer};
3use std::task::{Context, Poll};
4use std::future::Future;
5use std::pin::Pin;
6
7/// Security headers middleware
8#[derive(Clone)]
9pub struct SecurityHeadersMiddleware<S> {
10    inner: S,
11    config: SecurityHeadersConfig,
12}
13
14#[derive(Clone, Debug)]
15pub struct SecurityHeadersConfig {
16    pub csp: Option<String>,
17    pub hsts_max_age: Option<u64>,
18    pub frame_options: FrameOptions,
19    pub content_type_options: bool,
20    pub xss_protection: bool,
21    pub referrer_policy: Option<String>,
22}
23
/// Policy for the `X-Frame-Options` response header.
///
/// This is a plain fieldless enum, so it is `Copy` and comparable. Callers can
/// therefore test a configuration with `==` or store it without cloning.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FrameOptions {
    /// Send `X-Frame-Options: DENY`.
    Deny,
    /// Send `X-Frame-Options: SAMEORIGIN`.
    SameOrigin,
    /// Send no `X-Frame-Options` header, leaving framing unrestricted by it.
    Allow,
}
30
31impl Default for SecurityHeadersConfig {
32    fn default() -> Self {
33        Self {
34            csp: Some("default-src 'self'".to_string()),
35            hsts_max_age: Some(31536000), // 1 year
36            frame_options: FrameOptions::SameOrigin,
37            content_type_options: true,
38            xss_protection: true,
39            referrer_policy: Some("strict-origin-when-cross-origin".to_string()),
40        }
41    }
42}
43
44impl<S> SecurityHeadersMiddleware<S> {
45    pub fn new(inner: S, config: SecurityHeadersConfig) -> Self {
46        Self { inner, config }
47    }
48}
49
50impl<S> Service<OxiditeRequest> for SecurityHeadersMiddleware<S>
51where
52    S: Service<OxiditeRequest, Response = OxiditeResponse, Error = CoreError> + Clone + Send + 'static,
53    S::Future: Send + 'static,
54{
55    type Response = S::Response;
56    type Error = S::Error;
57    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
58
59    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
60        self.inner.poll_ready(cx)
61    }
62
63    fn call(&mut self, req: OxiditeRequest) -> Self::Future {
64        let config = self.config.clone();
65        let mut inner = self.inner.clone();
66
67        Box::pin(async move {
68            let mut response = inner.call(req).await?;
69            let headers = response.headers_mut();
70
71            // Content-Security-Policy
72            if let Some(csp) = &config.csp {
73                if let Ok(value) = csp.parse() {
74                    headers.insert("content-security-policy", value);
75                }
76            }
77
78            // Strict-Transport-Security
79            if let Some(max_age) = config.hsts_max_age {
80                let hsts = format!("max-age={}; includeSubDomains", max_age);
81                if let Ok(value) = hsts.parse() {
82                    headers.insert("strict-transport-security", value);
83                }
84            }
85
86            // X-Frame-Options
87            let frame_option = match config.frame_options {
88                FrameOptions::Deny => "DENY",
89                FrameOptions::SameOrigin => "SAMEORIGIN",
90                FrameOptions::Allow => return Ok(response),
91            };
92            if let Ok(value) = frame_option.parse() {
93                headers.insert("x-frame-options", value);
94            }
95
96            // X-Content-Type-Options
97            if config.content_type_options {
98                if let Ok(value) = "nosniff".parse() {
99                    headers.insert("x-content-type-options", value);
100                }
101            }
102
103            // X-XSS-Protection
104            if config.xss_protection {
105                if let Ok(value) = "1; mode=block".parse() {
106                    headers.insert("x-xss-protection", value);
107                }
108            }
109
110            // Referrer-Policy
111            if let Some(policy) = &config.referrer_policy {
112                if let Ok(value) = policy.parse() {
113                    headers.insert("referrer-policy", value);
114                }
115            }
116
117            Ok(response)
118        })
119    }
120}
121
122/// Layer for security headers middleware
123#[derive(Clone)]
124pub struct SecurityHeadersLayer {
125    config: SecurityHeadersConfig,
126}
127
128impl SecurityHeadersLayer {
129    pub fn new(config: SecurityHeadersConfig) -> Self {
130        Self { config }
131    }
132
133    pub fn with_defaults() -> Self {
134        Self {
135            config: SecurityHeadersConfig::default(),
136        }
137    }
138}
139
140impl<S> Layer<S> for SecurityHeadersLayer {
141    type Service = SecurityHeadersMiddleware<S>;
142
143    fn layer(&self, inner: S) -> Self::Service {
144        SecurityHeadersMiddleware::new(inner, self.config.clone())
145    }
146}