1use crate::{Error, Result};
2use reqwest::header::{HeaderMap, HeaderValue};
3use std::collections::HashMap;
4use std::time::{SystemTime, UNIX_EPOCH};
5use tracing::{debug, warn, instrument};
6
/// Tunable security policy consumed by `SecurityManager`.
#[derive(Clone, Debug)]
pub struct SecurityConfig {
    /// When `true`, `SecurityManager::mask_api_key` obscures keys; when
    /// `false` that method returns them unchanged.
    pub mask_api_key: bool,
    /// Header name -> value pairs inserted by
    /// `SecurityManager::generate_security_headers`. The value `"generated"`
    /// is a placeholder that method may substitute for specific header names.
    pub security_headers: HashMap<String, String>,
    /// Only consulted by `SecurityManager::verify_tls_config` in this module,
    /// where it affects logging. NOTE(review): confirm whether pinning is
    /// enforced where the HTTP client is built.
    pub enable_cert_pinning: bool,
    /// Upper bound for `SecurityManager::validate_request_size` (presumably
    /// bytes — confirm with callers). `validate_response` checks response
    /// bodies against ten times this value.
    pub max_request_size: usize,
    /// NOTE(review): not read within this module; presumably gates request
    /// signing in the client — confirm.
    pub enable_request_signing: bool,
    /// NOTE(review): not read within this module; confirm where it is applied.
    pub security_timeout: std::time::Duration,
}

impl Default for SecurityConfig {
    /// Default policy:
    /// * API key masking on.
    /// * Certificate pinning and request signing off.
    /// * A request size limit of `1024 * 1024`.
    /// * A 10-second security timeout.
    /// * Two headers: `X-Request-ID` with the `"generated"` placeholder, and a
    ///   static `User-Agent` of `goldrush-sdk-rs`.
    ///
    /// NOTE(review): because the `User-Agent` default is not `"generated"`, the
    /// versioned user-agent branch of `generate_security_headers` is not used
    /// with this default — confirm that is intended.
    fn default() -> Self {
        let security_headers: HashMap<String, String> = [
            ("X-Request-ID", "generated"),
            ("User-Agent", "goldrush-sdk-rs"),
        ]
        .iter()
        .map(|&(name, value)| (name.to_owned(), value.to_owned()))
        .collect();

        Self {
            mask_api_key: true,
            security_headers,
            enable_cert_pinning: false,
            max_request_size: 1024 * 1024,
            enable_request_signing: false,
            security_timeout: std::time::Duration::from_secs(10),
        }
    }
}
40
41pub struct SecurityManager {
43 config: SecurityConfig,
44}
45
46impl SecurityManager {
47 pub fn new(config: SecurityConfig) -> Self {
48 Self { config }
49 }
50
51 #[instrument(skip(self))]
53 pub fn mask_api_key(&self, api_key: &str) -> String {
54 if !self.config.mask_api_key {
55 return api_key.to_string();
56 }
57
58 if api_key.len() < 8 {
59 return "*".repeat(api_key.len());
60 }
61
62 let prefix = &api_key[..4];
63 let suffix = &api_key[api_key.len()-4..];
64 format!("{}***{}", prefix, suffix)
65 }
66
67 #[instrument(skip(self))]
69 pub fn generate_security_headers(&self, request_id: &str) -> Result<HeaderMap> {
70 let mut headers = HeaderMap::new();
71
72 for (key, value) in &self.config.security_headers {
74 let header_name: reqwest::header::HeaderName = key.parse()
75 .map_err(|e| Error::Config(format!("Invalid header name '{}': {}", key, e)))?;
76
77 let header_value = if value == "generated" {
78 match key.as_str() {
79 "X-Request-ID" => HeaderValue::from_str(request_id)
80 .map_err(|e| Error::Config(format!("Invalid request ID: {}", e)))?,
81 "User-Agent" => HeaderValue::from_str(&format!("goldrush-sdk-rs/{}", env!("CARGO_PKG_VERSION")))
82 .map_err(|e| Error::Config(format!("Invalid user agent: {}", e)))?,
83 _ => HeaderValue::from_str(value)
84 .map_err(|e| Error::Config(format!("Invalid header value '{}': {}", value, e)))?,
85 }
86 } else {
87 HeaderValue::from_str(value)
88 .map_err(|e| Error::Config(format!("Invalid header value '{}': {}", value, e)))?
89 };
90
91 headers.insert(header_name, header_value);
92 }
93
94 headers.insert("X-Content-Type-Options", HeaderValue::from_static("nosniff"));
96 headers.insert("X-Frame-Options", HeaderValue::from_static("DENY"));
97 headers.insert("X-XSS-Protection", HeaderValue::from_static("1; mode=block"));
98
99 debug!("Generated {} security headers", headers.len());
100 Ok(headers)
101 }
102
103 #[instrument(skip(self), fields(size = %content_length))]
105 pub fn validate_request_size(&self, content_length: usize) -> Result<()> {
106 if content_length > self.config.max_request_size {
107 warn!(
108 size = %content_length,
109 max_size = %self.config.max_request_size,
110 "Request size exceeds maximum allowed"
111 );
112 return Err(Error::Config(format!(
113 "Request size {} exceeds maximum allowed size {}",
114 content_length, self.config.max_request_size
115 )));
116 }
117
118 debug!("Request size validation passed");
119 Ok(())
120 }
121
122 pub fn generate_timestamp(&self) -> u64 {
124 SystemTime::now()
125 .duration_since(UNIX_EPOCH)
126 .expect("Time went backwards")
127 .as_secs()
128 }
129
130 #[instrument(skip(self), fields(timestamp = %timestamp))]
132 pub fn validate_timestamp(&self, timestamp: u64, tolerance_secs: u64) -> Result<()> {
133 let current_time = self.generate_timestamp();
134 let time_diff = if current_time > timestamp {
135 current_time - timestamp
136 } else {
137 timestamp - current_time
138 };
139
140 if time_diff > tolerance_secs {
141 warn!(
142 timestamp = %timestamp,
143 current_time = %current_time,
144 diff = %time_diff,
145 tolerance = %tolerance_secs,
146 "Timestamp validation failed"
147 );
148 return Err(Error::Config("Request timestamp is outside acceptable range".to_string()));
149 }
150
151 debug!("Timestamp validation passed");
152 Ok(())
153 }
154
155 #[instrument(skip(self), fields(url = %url))]
157 pub fn sanitize_url(&self, url: &str) -> Result<String> {
158 let sanitized = url
160 .replace("../", "") .replace("..\\", "") .replace("<", "<") .replace(">", ">")
164 .replace("\"", """)
165 .replace("'", "'");
166
167 let suspicious_patterns = [
169 "javascript:", "data:", "vbscript:", "file:", "ftp:",
170 "mailto:", "news:", "gopher:", "ldap:", "telnet:",
171 ];
172
173 for pattern in &suspicious_patterns {
174 if sanitized.to_lowercase().contains(pattern) {
175 warn!(url = %url, pattern = %pattern, "Suspicious URL pattern detected");
176 return Err(Error::Config(format!("URL contains suspicious pattern: {}", pattern)));
177 }
178 }
179
180 debug!("URL sanitization completed");
181 Ok(sanitized)
182 }
183
184 #[instrument(skip(self, response_body), fields(size = %response_body.len()))]
186 pub fn validate_response(&self, response_body: &str) -> Result<()> {
187 if response_body.len() > self.config.max_request_size * 10 {
189 warn!(
190 size = %response_body.len(),
191 max_size = %self.config.max_request_size,
192 "Response size is unusually large"
193 );
194 return Err(Error::Config("Response size exceeds safety limits".to_string()));
195 }
196
197 let suspicious_scripts = [
199 "<script", "javascript:", "onclick=", "onerror=", "onload=",
200 "eval(", "setTimeout(", "setInterval(",
201 ];
202
203 for script in &suspicious_scripts {
204 if response_body.to_lowercase().contains(script) {
205 warn!(pattern = %script, "Potential script injection detected in response");
206 return Err(Error::Config("Response contains potentially malicious content".to_string()));
207 }
208 }
209
210 debug!("Response validation passed");
211 Ok(())
212 }
213
214 pub fn generate_nonce(&self) -> String {
216 use std::collections::hash_map::DefaultHasher;
217 use std::hash::{Hash, Hasher};
218
219 let mut hasher = DefaultHasher::new();
220 self.generate_timestamp().hash(&mut hasher);
221 std::thread::current().id().hash(&mut hasher);
222
223 format!("{:x}", hasher.finish())
224 }
225
226 #[instrument(skip(self, api_key, body), fields(method = %method, url = %url))]
228 pub fn create_request_signature(
229 &self,
230 method: &str,
231 url: &str,
232 api_key: &str,
233 body: &str,
234 timestamp: u64,
235 nonce: &str,
236 ) -> String {
237 use std::collections::hash_map::DefaultHasher;
238 use std::hash::{Hash, Hasher};
239
240 let signature_string = format!(
242 "{}|{}|{}|{}|{}|{}",
243 method.to_uppercase(),
244 url,
245 body,
246 api_key,
247 timestamp,
248 nonce
249 );
250
251 let mut hasher = DefaultHasher::new();
253 signature_string.hash(&mut hasher);
254
255 format!("{:x}", hasher.finish())
256 }
257
258 pub fn verify_tls_config(&self) -> Result<()> {
260 if self.config.enable_cert_pinning {
264 debug!("Certificate pinning is enabled");
265 }
267
268 debug!("TLS configuration verified");
269 Ok(())
270 }
271}
272
273#[derive(Debug, Clone)]
275pub struct SecurityContext {
276 pub request_id: String,
277 pub timestamp: u64,
278 pub nonce: String,
279 pub signature: Option<String>,
280 pub headers: HeaderMap,
281}
282
283impl SecurityContext {
284 pub fn new(request_id: String, security_manager: &SecurityManager) -> Result<Self> {
285 let timestamp = security_manager.generate_timestamp();
286 let nonce = security_manager.generate_nonce();
287 let headers = security_manager.generate_security_headers(&request_id)?;
288
289 Ok(Self {
290 request_id,
291 timestamp,
292 nonce,
293 signature: None,
294 headers,
295 })
296 }
297
298 pub fn with_signature(mut self, signature: String) -> Self {
300 self.signature = Some(signature);
301 self
302 }
303
304 pub fn is_signed(&self) -> bool {
306 self.signature.is_some()
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager built from the default configuration.
    fn default_manager() -> SecurityManager {
        SecurityManager::new(SecurityConfig::default())
    }

    #[test]
    fn test_api_key_masking() {
        let manager = default_manager();
        let masked = manager.mask_api_key("sk_test_1234567890abcdef");

        // Only the first and last four characters survive.
        assert_eq!(masked, "sk_t***cdef");
        assert!(!masked.contains("1234567890ab"));
    }

    #[test]
    fn test_url_sanitization() {
        let manager = default_manager();

        let cleaned = manager
            .sanitize_url("https://api.example.com/../admin/users")
            .unwrap();
        assert!(!cleaned.contains("../"));

        // Script schemes are rejected outright.
        assert!(manager.sanitize_url("javascript:alert('xss')").is_err());
    }

    #[test]
    fn test_request_size_validation() {
        let manager = SecurityManager::new(SecurityConfig {
            max_request_size: 1000,
            ..SecurityConfig::default()
        });

        assert!(manager.validate_request_size(500).is_ok());
        assert!(manager.validate_request_size(1500).is_err());
    }

    #[test]
    fn test_timestamp_validation() {
        let manager = default_manager();
        let now = manager.generate_timestamp();
        let tolerance = 60;

        assert!(manager.validate_timestamp(now, tolerance).is_ok());
        // An hour in either direction is far outside a one-minute window.
        assert!(manager.validate_timestamp(now - 3600, tolerance).is_err());
        assert!(manager.validate_timestamp(now + 3600, tolerance).is_err());
    }
}