1use crate::security::{Result, SecurityError, ValidationConfig};
2use axum::{
3 extract::{Request, State},
4 http::{header, HeaderMap, StatusCode},
5 middleware::Next,
6 response::Response,
7};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tracing::{debug, warn};
12use validator::{Validate, ValidationError};
13
14pub struct ValidationManager {
16 config: ValidationConfig,
17 sql_injection_patterns: Vec<Regex>,
18 xss_patterns: Vec<Regex>,
19 malicious_patterns: Vec<Regex>,
20}
21
22impl ValidationManager {
23 pub fn new(config: ValidationConfig) -> Result<Self> {
24 let mut manager = Self {
25 config,
26 sql_injection_patterns: Vec::new(),
27 xss_patterns: Vec::new(),
28 malicious_patterns: Vec::new(),
29 };
30
31 if manager.config.enabled {
32 manager.initialize_patterns()?;
33 }
34
35 Ok(manager)
36 }
37
38 fn initialize_patterns(&mut self) -> Result<()> {
39 if self.config.sql_injection_protection {
41 let sql_patterns = vec![
42 r"(?i)(union\s+select)",
43 r"(?i)(drop\s+table)",
44 r"(?i)(delete\s+from)",
45 r"(?i)(insert\s+into)",
46 r"(?i)(update\s+set)",
47 r"(?i)(alter\s+table)",
48 r"(?i)(create\s+table)",
49 r"(?i)(exec\s*\()",
50 r"(?i)(execute\s*\()",
51 r"(?i)(\'\s*or\s*\'\s*=\s*\')",
52 r"(?i)(\'\s*or\s*1\s*=\s*1)",
53 r"(?i)(\'\s*;\s*drop)",
54 r"(?i)(--\s*)",
55 r"(?i)(/\*.*\*/)",
56 r"(?i)(xp_cmdshell)",
57 r"(?i)(sp_executesql)",
58 ];
59
60 for pattern in sql_patterns {
61 self.sql_injection_patterns
62 .push(
63 Regex::new(pattern).map_err(|e| SecurityError::ValidationError {
64 message: format!("Failed to compile SQL injection pattern: {e}"),
65 })?,
66 );
67 }
68 }
69
70 if self.config.xss_protection {
72 let xss_patterns = vec![
73 r"(?i)<script[^>]*>",
74 r"(?i)</script>",
75 r"(?i)<iframe[^>]*>",
76 r"(?i)<object[^>]*>",
77 r"(?i)<embed[^>]*>",
78 r"(?i)<link[^>]*>",
79 r"(?i)<meta[^>]*>",
80 r"(?i)javascript:",
81 r"(?i)vbscript:",
82 r"(?i)onload\s*=",
83 r"(?i)onerror\s*=",
84 r"(?i)onclick\s*=",
85 r"(?i)onmouseover\s*=",
86 r"(?i)onfocus\s*=",
87 r"(?i)onblur\s*=",
88 r"(?i)onchange\s*=",
89 r"(?i)onsubmit\s*=",
90 r"(?i)expression\s*\(",
91 r"(?i)url\s*\(",
92 r"(?i)@import",
93 ];
94
95 for pattern in xss_patterns {
96 self.xss_patterns.push(Regex::new(pattern).map_err(|e| {
97 SecurityError::ValidationError {
98 message: format!("Failed to compile XSS pattern: {e}"),
99 }
100 })?);
101 }
102 }
103
104 let malicious_patterns = vec![
106 r"(?i)(\.\.\/){2,}", r"(?i)\.\.\\", r"(?i)\/etc\/passwd",
109 r"(?i)\/etc\/shadow",
110 r"(?i)\/proc\/",
111 r"(?i)c:\\windows\\",
112 r"(?i)cmd\.exe",
113 r"(?i)powershell\.exe",
114 r"(?i)bash\s*-c",
115 r"(?i)sh\s*-c",
116 r"(?i)\$\([^)]*\)", r"(?i)`[^`]*`", ];
119
120 for pattern in malicious_patterns {
121 self.malicious_patterns
122 .push(
123 Regex::new(pattern).map_err(|e| SecurityError::ValidationError {
124 message: format!("Failed to compile malicious pattern: {e}"),
125 })?,
126 );
127 }
128
129 debug!(
130 "Initialized validation patterns: {} SQL, {} XSS, {} malicious",
131 self.sql_injection_patterns.len(),
132 self.xss_patterns.len(),
133 self.malicious_patterns.len()
134 );
135
136 Ok(())
137 }
138
139 pub fn validate_input(&self, input: &str) -> Result<String> {
141 if !self.config.enabled {
142 return Ok(input.to_string());
143 }
144
145 if self.config.sql_injection_protection {
147 for pattern in &self.sql_injection_patterns {
148 if pattern.is_match(input) {
149 warn!("SQL injection attempt detected: {}", pattern.as_str());
150 return Err(SecurityError::ValidationError {
151 message: "Potential SQL injection detected".to_string(),
152 });
153 }
154 }
155 }
156
157 let skip_xss_check =
159 std::env::var("SKIP_XSS_CHECK").unwrap_or_else(|_| "false".to_string()) == "true";
160
161 if self.config.xss_protection && !skip_xss_check {
162 for pattern in &self.xss_patterns {
163 if pattern.is_match(input) {
164 warn!("XSS attempt detected: {}", pattern.as_str());
165 return Err(SecurityError::ValidationError {
166 message: "Potential XSS detected".to_string(),
167 });
168 }
169 }
170 }
171
172 for pattern in &self.malicious_patterns {
174 if pattern.is_match(input) {
175 warn!("Malicious pattern detected: {}", pattern.as_str());
176 return Err(SecurityError::ValidationError {
177 message: "Malicious content detected".to_string(),
178 });
179 }
180 }
181
182 if self.config.sanitize_input {
184 Ok(self.sanitize_input(input))
185 } else {
186 Ok(input.to_string())
187 }
188 }
189
190 fn sanitize_input(&self, input: &str) -> String {
192 let mut sanitized = input.to_string();
193
194 sanitized = sanitized.replace('\0', "");
196
197 sanitized = sanitized.replace('\r', "");
199 sanitized = sanitized.replace('\n', " ");
200 sanitized = sanitized.replace('\t', " ");
201
202 if self.config.xss_protection {
204 sanitized = sanitized.replace('&', "&");
206 sanitized = sanitized.replace('<', "<");
207 sanitized = sanitized.replace('>', ">");
208 sanitized = sanitized.replace('"', """);
209 sanitized = sanitized.replace('\'', "'");
210 }
211
212 if sanitized.len() > 10000 {
214 sanitized.truncate(10000);
215 sanitized.push_str("...");
216 }
217
218 sanitized
219 }
220
221 pub fn validate_json(&self, json_str: &str) -> Result<serde_json::Value> {
223 if !self.config.enabled {
224 return serde_json::from_str(json_str).map_err(|e| SecurityError::ValidationError {
225 message: format!("Invalid JSON: {e}"),
226 });
227 }
228
229 if json_str.len() > self.config.max_request_size as usize {
231 return Err(SecurityError::ValidationError {
232 message: "Request size exceeds maximum allowed".to_string(),
233 });
234 }
235
236 let json_value: serde_json::Value =
238 serde_json::from_str(json_str).map_err(|e| SecurityError::ValidationError {
239 message: format!("Invalid JSON: {e}"),
240 })?;
241
242 self.validate_json_value(&json_value)?;
244
245 Ok(json_value)
246 }
247
248 fn validate_json_value(&self, value: &serde_json::Value) -> Result<()> {
249 match value {
250 serde_json::Value::String(s) => {
251 self.validate_input(s)?;
252 }
253 serde_json::Value::Array(arr) => {
254 for item in arr {
255 self.validate_json_value(item)?;
256 }
257 }
258 serde_json::Value::Object(obj) => {
259 for (key, val) in obj {
260 self.validate_input(key)?;
261 self.validate_json_value(val)?;
262 }
263 }
264 _ => {} }
266 Ok(())
267 }
268
269 pub fn validate_headers(&self, headers: &HeaderMap) -> Result<()> {
271 if !self.config.enabled {
272 return Ok(());
273 }
274
275 if let Some(user_agent) = headers.get(header::USER_AGENT) {
277 if let Ok(ua_str) = user_agent.to_str() {
278 self.validate_input(ua_str)?;
279
280 let suspicious_patterns = vec![
282 r"(?i)(sqlmap|nmap|nikto|dirb|gobuster)",
283 r"(?i)(masscan|zap|burp|wget|curl)",
284 r"(?i)(python-requests|libwww-perl)",
285 ];
286
287 for pattern_str in suspicious_patterns {
288 let pattern = Regex::new(pattern_str).unwrap();
289 if pattern.is_match(ua_str) {
290 warn!("Suspicious user agent detected: {}", ua_str);
291 return Err(SecurityError::ValidationError {
292 message: "Suspicious user agent".to_string(),
293 });
294 }
295 }
296 }
297 }
298
299 if let Some(referer) = headers.get(header::REFERER) {
301 if let Ok(referer_str) = referer.to_str() {
302 self.validate_input(referer_str)?;
303 }
304 }
305
306 for (_name, value) in headers {
308 if let Ok(value_str) = value.to_str() {
309 self.validate_input(value_str)?;
310 }
311 }
312
313 Ok(())
314 }
315
316 pub fn validate_content_type(&self, content_type: Option<&str>) -> Result<()> {
318 if !self.config.enabled {
319 return Ok(());
320 }
321
322 let allowed_types = [
323 "application/json",
324 "application/x-www-form-urlencoded",
325 "text/plain",
326 "multipart/form-data",
327 ];
328
329 if let Some(ct) = content_type {
330 let ct_main = ct.split(';').next().unwrap_or(ct).trim();
331
332 if !allowed_types.contains(&ct_main) {
333 return Err(SecurityError::ValidationError {
334 message: format!("Content type not allowed: {ct_main}"),
335 });
336 }
337 }
338
339 Ok(())
340 }
341
342 pub fn is_enabled(&self) -> bool {
343 self.config.enabled
344 }
345
346 pub fn get_max_request_size(&self) -> u64 {
347 self.config.max_request_size
348 }
349}
350
/// Request parameters checked declaratively through the `validator` crate's
/// `Validate` derive; call `.validate()` to apply the rules below.
///
/// NOTE(review): every field is an `Option`; confirm against the `validator`
/// version in use that `None` fields are skipped rather than rejected.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct ValidatedRequest {
    /// Free-form content; `length` rule with bounds 1..=1000 (unit as defined
    /// by `validator`'s `length` check).
    #[validate(length(
        min = 1,
        max = 1000,
        message = "Content must be between 1 and 1000 characters"
    ))]
    pub content: Option<String>,

    /// Contact address; must pass `validator`'s email check.
    #[validate(email(message = "Invalid email format"))]
    pub email: Option<String>,

    /// Must pass `validator`'s URL check.
    #[validate(url(message = "Invalid URL format"))]
    pub url: Option<String>,

    /// Page size; accepted range 1..=1000.
    #[validate(range(min = 1, max = 1000, message = "Limit must be between 1 and 1000"))]
    pub limit: Option<i32>,

    /// Pagination offset; must be non-negative.
    #[validate(range(min = 0, message = "Offset must be non-negative"))]
    pub offset: Option<i32>,

    /// Search query; only its length is checked here — no content screening
    /// rule is attached to this field.
    #[validate(length(min = 1, max = 1000))]
    pub query: Option<String>,
}
376
377#[allow(dead_code)]
379fn validate_safe_string(value: &str) -> std::result::Result<(), ValidationError> {
380 if value.contains("<script") || value.contains("javascript:") || value.contains("../../") {
382 return Err(ValidationError::new("unsafe_content"));
383 }
384
385 Ok(())
386}
387
388pub async fn validation_middleware(
390 State(validator): State<Arc<ValidationManager>>,
391 headers: HeaderMap,
392 request: Request,
393 next: Next,
394) -> std::result::Result<Response, StatusCode> {
395 if !validator.is_enabled() {
396 return Ok(next.run(request).await);
397 }
398
399 if validator.validate_headers(&headers).is_err() {
401 warn!("Request validation failed: invalid headers");
402 return Err(StatusCode::BAD_REQUEST);
403 }
404
405 let content_type = headers
407 .get(header::CONTENT_TYPE)
408 .and_then(|ct| ct.to_str().ok());
409
410 if validator.validate_content_type(content_type).is_err() {
411 warn!("Request validation failed: invalid content type");
412 return Err(StatusCode::UNSUPPORTED_MEDIA_TYPE);
413 }
414
415 if let Some(content_length) = headers.get(header::CONTENT_LENGTH) {
417 if let Ok(length_str) = content_length.to_str() {
418 if let Ok(length) = length_str.parse::<u64>() {
419 if length > validator.get_max_request_size() {
420 warn!(
421 "Request validation failed: request too large ({} bytes)",
422 length
423 );
424 return Err(StatusCode::PAYLOAD_TOO_LARGE);
425 }
426 }
427 }
428 }
429
430 debug!("Request validation passed");
431 Ok(next.run(request).await)
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless `result` is a `SecurityError::ValidationError` whose
    /// message contains `needle`; any other outcome (including `Ok` or a
    /// different error variant) is a test failure rather than a silent pass.
    fn expect_validation_error<T: std::fmt::Debug>(
        result: std::result::Result<T, SecurityError>,
        needle: &str,
    ) {
        match result {
            Err(SecurityError::ValidationError { message }) => assert!(
                message.contains(needle),
                "error message {message:?} does not mention {needle:?}"
            ),
            other => panic!("expected a ValidationError mentioning {needle:?}, got {other:?}"),
        }
    }

    #[test]
    fn test_validation_manager_creation() {
        let manager = ValidationManager::new(ValidationConfig::default()).unwrap();
        assert!(manager.is_enabled());
    }

    #[test]
    fn test_sql_injection_detection() {
        let config = ValidationConfig {
            sql_injection_protection: true,
            ..Default::default()
        };
        let manager = ValidationManager::new(config).unwrap();

        // An ordinary read-only query is not flagged.
        assert!(manager
            .validate_input("SELECT name FROM users WHERE id = 1")
            .is_ok());

        // A classic stacked-query payload is.
        expect_validation_error(
            manager.validate_input("'; DROP TABLE users; --"),
            "SQL injection",
        );
    }

    #[test]
    fn test_xss_detection() {
        let config = ValidationConfig {
            xss_protection: true,
            ..Default::default()
        };
        let manager = ValidationManager::new(config).unwrap();

        assert!(manager.validate_input("Hello world!").is_ok());

        expect_validation_error(
            manager.validate_input("<script>alert('xss')</script>"),
            "XSS",
        );
    }

    #[test]
    fn test_input_sanitization() {
        let config = ValidationConfig {
            sanitize_input: true,
            xss_protection: true,
            ..Default::default()
        };
        let manager = ValidationManager::new(config).unwrap();

        let result = manager
            .validate_input("Hello <world> & 'test' \"quote\"")
            .unwrap();

        assert!(result.starts_with("Hello &lt;world&gt; &amp; "), "{result}");
        assert!(result.ends_with(" &quot;quote&quot;"), "{result}");
        // Every markup-significant character, apostrophes included, must have
        // been entity-encoded.
        assert!(
            !result.contains(|c| matches!(c, '<' | '>' | '"' | '\'')),
            "{result}"
        );
    }

    #[test]
    fn test_malicious_pattern_detection() {
        let manager = ValidationManager::new(ValidationConfig::default()).unwrap();

        assert!(manager.validate_input("../../../etc/passwd").is_err());

        // Plain shell text without a recognised signature passes.
        assert!(manager.validate_input("test; rm -rf /").is_ok());

        assert!(manager.validate_input("../../etc/shadow").is_err());
    }

    #[test]
    fn test_json_validation() {
        let manager = ValidationManager::new(ValidationConfig::default()).unwrap();

        let json = r#"{"name": "test", "value": 123}"#;
        assert!(manager.validate_json(json).is_ok());

        let invalid_json = r#"{"name": "test", "value":}"#;
        assert!(manager.validate_json(invalid_json).is_err());
    }

    #[test]
    fn test_json_content_validation() {
        let config = ValidationConfig {
            xss_protection: true,
            ..Default::default()
        };
        let manager = ValidationManager::new(config).unwrap();

        let json = r#"{"comment": "<script>alert('xss')</script>"}"#;
        assert!(manager.validate_json(json).is_err());
    }

    #[test]
    fn test_content_type_validation() {
        let manager = ValidationManager::new(ValidationConfig::default()).unwrap();

        assert!(manager
            .validate_content_type(Some("application/json"))
            .is_ok());
        assert!(manager
            .validate_content_type(Some("application/x-executable"))
            .is_err());
        assert!(manager
            .validate_content_type(Some("application/json; charset=utf-8"))
            .is_ok());
    }

    #[test]
    fn test_validation_disabled() {
        let config = ValidationConfig {
            enabled: false,
            ..Default::default()
        };
        let manager = ValidationManager::new(config).unwrap();
        assert!(!manager.is_enabled());

        assert!(manager
            .validate_input("<script>alert('xss')</script>")
            .is_ok());
    }

    #[test]
    fn test_validated_request_struct() {
        let request = ValidatedRequest {
            content: Some("Hello world".to_string()),
            email: Some("test@example.com".to_string()),
            url: Some("https://example.com".to_string()),
            limit: Some(100),
            offset: Some(0),
            query: Some("safe query".to_string()),
        };

        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_validated_request_invalid() {
        let request = ValidatedRequest {
            content: Some("".to_string()),         // below min length
            email: Some("invalid-email".to_string()), // not an email
            url: Some("not-a-url".to_string()),    // not a URL
            limit: Some(2000),                     // above max
            offset: Some(-1),                      // negative
            query: Some("<script>alert('xss')</script>".to_string()),
        };

        let validation_result = request.validate();
        assert!(validation_result.is_err());

        let errors = validation_result.unwrap_err();
        assert!(!errors.field_errors().is_empty());
    }

    #[test]
    fn test_custom_validator() {
        assert!(validate_safe_string("This is a safe string").is_ok());
        assert!(validate_safe_string("<script>alert('test')</script>").is_err());
        assert!(validate_safe_string("../../etc/passwd").is_err());
    }

    #[test]
    fn test_request_size_limits() {
        let config = ValidationConfig {
            enabled: true,
            max_request_size: 1024,
            sanitize_input: true,
            xss_protection: true,
            sql_injection_protection: true,
        };

        let manager = ValidationManager::new(config).unwrap();
        assert_eq!(manager.get_max_request_size(), 1024);

        let large_json = "x".repeat(2000);
        let json = format!(r#"{{"data": "{large_json}"}}"#);
        expect_validation_error(manager.validate_json(&json), "exceeds maximum");
    }
}