Skip to main content

fraiseql_core/security/
headers.rs

1//! Security header enforcement.
2
3use std::collections::HashMap;
4
/// Security headers configuration.
///
/// Holds a set of HTTP response headers as name/value pairs. Instances are
/// created through [`Default`], [`SecurityHeaders::production`] or
/// [`SecurityHeaders::development`], and can then be adjusted with the
/// mutating methods on this type.
///
/// `Clone`, `PartialEq` and `Eq` are derived so a configured set can be
/// copied per response and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SecurityHeaders {
    /// Header name mapped to header value; each name is stored at most once.
    headers: HashMap<String, String>,
}
10
11impl Default for SecurityHeaders {
12    /// Create default security headers
13    fn default() -> Self {
14        let mut headers = HashMap::new();
15
16        // Prevent XSS
17        headers.insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
18
19        // Prevent MIME sniffing
20        headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
21
22        // Prevent clickjacking
23        headers.insert("X-Frame-Options".to_string(), "DENY".to_string());
24
25        // Referrer policy
26        headers
27            .insert("Referrer-Policy".to_string(), "strict-origin-when-cross-origin".to_string());
28
29        // Permissions policy
30        headers.insert(
31            "Permissions-Policy".to_string(),
32            "geolocation=(), microphone=(), camera=()".to_string(),
33        );
34
35        Self { headers }
36    }
37}
38
39impl SecurityHeaders {
40    /// Create production-grade security headers
41    #[must_use]
42    pub fn production() -> Self {
43        let mut headers = Self::default().headers;
44
45        // Stricter CSP for production
46        headers.insert(
47            "Content-Security-Policy".to_string(),
48            "default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data: https:; font-src 'self'; connect-src 'self'; frame-ancestors 'none'".to_string(),
49        );
50
51        // HSTS with preload
52        headers.insert(
53            "Strict-Transport-Security".to_string(),
54            "max-age=63072000; includeSubDomains; preload".to_string(),
55        );
56
57        Self { headers }
58    }
59
60    /// Get headers as Vec for HTTP response
61    #[must_use]
62    pub fn to_vec(&self) -> Vec<(String, String)> {
63        self.headers.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
64    }
65
66    /// Add custom header
67    pub fn add(&mut self, name: String, value: String) {
68        self.headers.insert(name, value);
69    }
70
71    /// Remove header
72    pub fn remove(&mut self, name: &str) {
73        self.headers.remove(name);
74    }
75
76    /// Get header value
77    #[must_use]
78    pub fn get(&self, name: &str) -> Option<&String> {
79        self.headers.get(name)
80    }
81
82    /// Check if header exists
83    #[must_use]
84    pub fn has(&self, name: &str) -> bool {
85        self.headers.contains_key(name)
86    }
87
88    /// Get all header names
89    #[must_use]
90    pub fn names(&self) -> Vec<String> {
91        self.headers.keys().cloned().collect()
92    }
93
94    /// Merge with another `SecurityHeaders` instance
95    pub fn merge(&mut self, other: &Self) {
96        for (key, value) in &other.headers {
97            self.headers.insert(key.clone(), value.clone());
98        }
99    }
100
101    /// Create headers for development environment
102    #[must_use]
103    pub fn development() -> Self {
104        let mut headers = Self::default().headers;
105
106        // More permissive CSP for development
107        headers.insert(
108            "Content-Security-Policy".to_string(),
109            "default-src 'self' 'unsafe-inline' 'unsafe-eval'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https: http:; font-src 'self' data:; connect-src 'self' ws: wss: http: https:".to_string(),
110        );
111
112        Self { headers }
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every header name in `expected` is present in `headers`.
    fn assert_has_all(headers: &SecurityHeaders, expected: &[&str]) {
        for name in expected {
            assert!(headers.has(name), "missing header: {}", name);
        }
    }

    #[test]
    fn test_default_security_headers() {
        let baseline = SecurityHeaders::default();
        assert_has_all(
            &baseline,
            &[
                "X-XSS-Protection",
                "X-Content-Type-Options",
                "X-Frame-Options",
                "Referrer-Policy",
                "Permissions-Policy",
            ],
        );
    }

    #[test]
    fn test_production_security_headers() {
        let prod = SecurityHeaders::production();
        // The last entry is not set by `production()` itself; it should be
        // inherited from the default set.
        assert_has_all(
            &prod,
            &["Content-Security-Policy", "Strict-Transport-Security", "X-XSS-Protection"],
        );
    }

    #[test]
    fn test_custom_header_operations() {
        let mut set = SecurityHeaders::default();

        // Adding makes the value retrievable under the same name.
        set.add(String::from("X-Custom-Header"), String::from("custom-value"));
        assert_eq!(set.get("X-Custom-Header").map(String::as_str), Some("custom-value"));

        // Removing makes it disappear again.
        set.remove("X-Custom-Header");
        assert!(!set.has("X-Custom-Header"));
    }

    #[test]
    fn test_header_merge() {
        let mut target = SecurityHeaders::default();
        let mut source = SecurityHeaders::default();

        source.add(String::from("X-Custom"), String::from("value"));
        target.merge(&source);

        // The header that only existed in `source` is now in `target`.
        assert!(target.has("X-Custom"));
        assert_eq!(target.get("X-Custom").map(String::as_str), Some("value"));
    }
}