1use crate::security::{Result, SecurityError, ValidationConfig};
2use axum::{
3 extract::{Request, State},
4 http::{header, HeaderMap, StatusCode},
5 middleware::Next,
6 response::Response,
7};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tracing::{debug, warn};
12use validator::{Validate, ValidationError};
13
14pub struct ValidationManager {
16 config: ValidationConfig,
17 sql_injection_patterns: Vec<Regex>,
18 xss_patterns: Vec<Regex>,
19 malicious_patterns: Vec<Regex>,
20}
21
22impl ValidationManager {
23 pub fn new(config: ValidationConfig) -> Result<Self> {
24 let mut manager = Self {
25 config,
26 sql_injection_patterns: Vec::new(),
27 xss_patterns: Vec::new(),
28 malicious_patterns: Vec::new(),
29 };
30
31 if manager.config.enabled {
32 manager.initialize_patterns()?;
33 }
34
35 Ok(manager)
36 }
37
38 fn initialize_patterns(&mut self) -> Result<()> {
39 if self.config.sql_injection_protection {
41 let sql_patterns = vec![
42 r"(?i)(union\s+select)",
43 r"(?i)(drop\s+table)",
44 r"(?i)(delete\s+from)",
45 r"(?i)(insert\s+into)",
46 r"(?i)(update\s+set)",
47 r"(?i)(alter\s+table)",
48 r"(?i)(create\s+table)",
49 r"(?i)(exec\s*\()",
50 r"(?i)(execute\s*\()",
51 r"(?i)(\'\s*or\s*\'\s*=\s*\')",
52 r"(?i)(\'\s*or\s*1\s*=\s*1)",
53 r"(?i)(\'\s*;\s*drop)",
54 r"(?i)(--\s*)",
55 r"(?i)(/\*.*\*/)",
56 r"(?i)(xp_cmdshell)",
57 r"(?i)(sp_executesql)",
58 ];
59
60 for pattern in sql_patterns {
61 self.sql_injection_patterns
62 .push(
63 Regex::new(pattern).map_err(|e| SecurityError::ValidationError {
64 message: format!("Failed to compile SQL injection pattern: {e}"),
65 })?,
66 );
67 }
68 }
69
70 if self.config.xss_protection {
72 let xss_patterns = vec![
73 r"(?i)<script[^>]*>",
74 r"(?i)</script>",
75 r"(?i)<iframe[^>]*>",
76 r"(?i)<object[^>]*>",
77 r"(?i)<embed[^>]*>",
78 r"(?i)<link[^>]*>",
79 r"(?i)<meta[^>]*>",
80 r"(?i)javascript:",
81 r"(?i)vbscript:",
82 r"(?i)onload\s*=",
83 r"(?i)onerror\s*=",
84 r"(?i)onclick\s*=",
85 r"(?i)onmouseover\s*=",
86 r"(?i)onfocus\s*=",
87 r"(?i)onblur\s*=",
88 r"(?i)onchange\s*=",
89 r"(?i)onsubmit\s*=",
90 r"(?i)expression\s*\(",
91 r"(?i)url\s*\(",
92 r"(?i)@import",
93 ];
94
95 for pattern in xss_patterns {
96 self.xss_patterns.push(Regex::new(pattern).map_err(|e| {
97 SecurityError::ValidationError {
98 message: format!("Failed to compile XSS pattern: {e}"),
99 }
100 })?);
101 }
102 }
103
104 let malicious_patterns = vec![
106 r"(?i)(\.\.\/){2,}", r"(?i)\.\.\\", r"(?i)\/etc\/passwd",
109 r"(?i)\/etc\/shadow",
110 r"(?i)\/proc\/",
111 r"(?i)c:\\windows\\",
112 r"(?i)cmd\.exe",
113 r"(?i)powershell\.exe",
114 r"(?i)bash\s*-c",
115 r"(?i)sh\s*-c",
116 r"(?i)\$\([^)]*\)", r"(?i)`[^`]*`", ];
119
120 for pattern in malicious_patterns {
121 self.malicious_patterns
122 .push(
123 Regex::new(pattern).map_err(|e| SecurityError::ValidationError {
124 message: format!("Failed to compile malicious pattern: {e}"),
125 })?,
126 );
127 }
128
129 debug!(
130 "Initialized validation patterns: {} SQL, {} XSS, {} malicious",
131 self.sql_injection_patterns.len(),
132 self.xss_patterns.len(),
133 self.malicious_patterns.len()
134 );
135
136 Ok(())
137 }
138
139 pub fn validate_input(&self, input: &str) -> Result<String> {
141 if !self.config.enabled {
142 return Ok(input.to_string());
143 }
144
145 if self.config.sql_injection_protection {
147 for pattern in &self.sql_injection_patterns {
148 if pattern.is_match(input) {
149 warn!("SQL injection attempt detected: {}", pattern.as_str());
150 return Err(SecurityError::ValidationError {
151 message: "Potential SQL injection detected".to_string(),
152 });
153 }
154 }
155 }
156
157 if self.config.xss_protection {
159 for pattern in &self.xss_patterns {
160 if pattern.is_match(input) {
161 warn!("XSS attempt detected: {}", pattern.as_str());
162 return Err(SecurityError::ValidationError {
163 message: "Potential XSS detected".to_string(),
164 });
165 }
166 }
167 }
168
169 for pattern in &self.malicious_patterns {
171 if pattern.is_match(input) {
172 warn!("Malicious pattern detected: {}", pattern.as_str());
173 return Err(SecurityError::ValidationError {
174 message: "Malicious content detected".to_string(),
175 });
176 }
177 }
178
179 if self.config.sanitize_input {
181 Ok(self.sanitize_input(input))
182 } else {
183 Ok(input.to_string())
184 }
185 }
186
187 fn sanitize_input(&self, input: &str) -> String {
189 let mut sanitized = input.to_string();
190
191 sanitized = sanitized.replace('\0', "");
193
194 sanitized = sanitized.replace('\r', "");
196 sanitized = sanitized.replace('\n', " ");
197 sanitized = sanitized.replace('\t', " ");
198
199 if self.config.xss_protection {
201 sanitized = sanitized.replace('<', "<");
202 sanitized = sanitized.replace('>', ">");
203 sanitized = sanitized.replace('"', """);
204 sanitized = sanitized.replace('\'', "'");
205 sanitized = sanitized.replace('&', "&");
206 }
207
208 if sanitized.len() > 10000 {
210 sanitized.truncate(10000);
211 sanitized.push_str("...");
212 }
213
214 sanitized
215 }
216
217 pub fn validate_json(&self, json_str: &str) -> Result<serde_json::Value> {
219 if !self.config.enabled {
220 return serde_json::from_str(json_str).map_err(|e| SecurityError::ValidationError {
221 message: format!("Invalid JSON: {e}"),
222 });
223 }
224
225 if json_str.len() > self.config.max_request_size as usize {
227 return Err(SecurityError::ValidationError {
228 message: "Request size exceeds maximum allowed".to_string(),
229 });
230 }
231
232 let json_value: serde_json::Value =
234 serde_json::from_str(json_str).map_err(|e| SecurityError::ValidationError {
235 message: format!("Invalid JSON: {e}"),
236 })?;
237
238 self.validate_json_value(&json_value)?;
240
241 Ok(json_value)
242 }
243
244 fn validate_json_value(&self, value: &serde_json::Value) -> Result<()> {
245 match value {
246 serde_json::Value::String(s) => {
247 self.validate_input(s)?;
248 }
249 serde_json::Value::Array(arr) => {
250 for item in arr {
251 self.validate_json_value(item)?;
252 }
253 }
254 serde_json::Value::Object(obj) => {
255 for (key, val) in obj {
256 self.validate_input(key)?;
257 self.validate_json_value(val)?;
258 }
259 }
260 _ => {} }
262 Ok(())
263 }
264
265 pub fn validate_headers(&self, headers: &HeaderMap) -> Result<()> {
267 if !self.config.enabled {
268 return Ok(());
269 }
270
271 if let Some(user_agent) = headers.get(header::USER_AGENT) {
273 if let Ok(ua_str) = user_agent.to_str() {
274 self.validate_input(ua_str)?;
275
276 let suspicious_patterns = vec![
278 r"(?i)(sqlmap|nmap|nikto|dirb|gobuster)",
279 r"(?i)(masscan|zap|burp|wget|curl)",
280 r"(?i)(python-requests|libwww-perl)",
281 ];
282
283 for pattern_str in suspicious_patterns {
284 let pattern = Regex::new(pattern_str).unwrap();
285 if pattern.is_match(ua_str) {
286 warn!("Suspicious user agent detected: {}", ua_str);
287 return Err(SecurityError::ValidationError {
288 message: "Suspicious user agent".to_string(),
289 });
290 }
291 }
292 }
293 }
294
295 if let Some(referer) = headers.get(header::REFERER) {
297 if let Ok(referer_str) = referer.to_str() {
298 self.validate_input(referer_str)?;
299 }
300 }
301
302 for (_name, value) in headers {
304 if let Ok(value_str) = value.to_str() {
305 self.validate_input(value_str)?;
306 }
307 }
308
309 Ok(())
310 }
311
312 pub fn validate_content_type(&self, content_type: Option<&str>) -> Result<()> {
314 if !self.config.enabled {
315 return Ok(());
316 }
317
318 let allowed_types = [
319 "application/json",
320 "application/x-www-form-urlencoded",
321 "text/plain",
322 "multipart/form-data",
323 ];
324
325 if let Some(ct) = content_type {
326 let ct_main = ct.split(';').next().unwrap_or(ct).trim();
327
328 if !allowed_types.contains(&ct_main) {
329 return Err(SecurityError::ValidationError {
330 message: format!("Content type not allowed: {ct_main}"),
331 });
332 }
333 }
334
335 Ok(())
336 }
337
338 pub fn is_enabled(&self) -> bool {
339 self.config.enabled
340 }
341
342 pub fn get_max_request_size(&self) -> u64 {
343 self.config.max_request_size
344 }
345}
346
/// Request payload checked declaratively through the `validator` derive.
/// Call `.validate()` to run the rules below.
///
/// Every field is optional. NOTE(review): `validator` is expected to skip
/// rules on `None` fields; confirm this for the crate version in use.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct ValidatedRequest {
    /// Free-form content; its length must be within 1..=1000 when present.
    #[validate(length(
        min = 1,
        max = 1000,
        message = "Content must be between 1 and 1000 characters"
    ))]
    pub content: Option<String>,

    /// Must be a syntactically valid e-mail address.
    #[validate(email(message = "Invalid email format"))]
    pub email: Option<String>,

    /// Must parse as a URL.
    #[validate(url(message = "Invalid URL format"))]
    pub url: Option<String>,

    /// Page size, 1 to 1000.
    #[validate(range(min = 1, max = 1000, message = "Limit must be between 1 and 1000"))]
    pub limit: Option<i32>,

    /// Starting offset; must not be negative.
    #[validate(range(min = 0, message = "Offset must be non-negative"))]
    pub offset: Option<i32>,

    /// Search query. Only its length (1..=1000) is checked; markup such as
    /// `<script>` is not rejected by this struct.
    #[validate(length(min = 1, max = 1000))]
    pub query: Option<String>,
}
372
373#[allow(dead_code)]
375fn validate_safe_string(value: &str) -> std::result::Result<(), ValidationError> {
376 if value.contains("<script") || value.contains("javascript:") || value.contains("../../") {
378 return Err(ValidationError::new("unsafe_content"));
379 }
380
381 Ok(())
382}
383
384pub async fn validation_middleware(
386 State(validator): State<Arc<ValidationManager>>,
387 headers: HeaderMap,
388 request: Request,
389 next: Next,
390) -> std::result::Result<Response, StatusCode> {
391 if !validator.is_enabled() {
392 return Ok(next.run(request).await);
393 }
394
395 if validator.validate_headers(&headers).is_err() {
397 warn!("Request validation failed: invalid headers");
398 return Err(StatusCode::BAD_REQUEST);
399 }
400
401 let content_type = headers
403 .get(header::CONTENT_TYPE)
404 .and_then(|ct| ct.to_str().ok());
405
406 if validator.validate_content_type(content_type).is_err() {
407 warn!("Request validation failed: invalid content type");
408 return Err(StatusCode::UNSUPPORTED_MEDIA_TYPE);
409 }
410
411 if let Some(content_length) = headers.get(header::CONTENT_LENGTH) {
413 if let Ok(length_str) = content_length.to_str() {
414 if let Ok(length) = length_str.parse::<u64>() {
415 if length > validator.get_max_request_size() {
416 warn!(
417 "Request validation failed: request too large ({} bytes)",
418 length
419 );
420 return Err(StatusCode::PAYLOAD_TOO_LARGE);
421 }
422 }
423 }
424 }
425
426 debug!("Request validation passed");
427 Ok(next.run(request).await)
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager for `config`; the built-in patterns must always compile.
    fn manager_with(config: ValidationConfig) -> ValidationManager {
        ValidationManager::new(config).expect("built-in validation patterns should compile")
    }

    /// Returns the message of a `ValidationError`, failing the test on any
    /// other outcome.
    fn validation_message<T: std::fmt::Debug>(
        result: std::result::Result<T, SecurityError>,
    ) -> String {
        match result {
            Err(SecurityError::ValidationError { message }) => message,
            other => panic!("expected a validation error, got {other:?}"),
        }
    }

    #[test]
    fn test_validation_manager_creation() {
        let manager = manager_with(ValidationConfig::default());
        assert!(manager.is_enabled());
    }

    #[test]
    fn test_sql_injection_detection() {
        let manager = manager_with(ValidationConfig {
            sql_injection_protection: true,
            ..ValidationConfig::default()
        });

        // A plain SELECT is not on the deny-list.
        assert!(manager
            .validate_input("SELECT name FROM users WHERE id = 1")
            .is_ok());

        let message = validation_message(manager.validate_input("'; DROP TABLE users; --"));
        assert!(message.contains("SQL injection"));
    }

    #[test]
    fn test_xss_detection() {
        let manager = manager_with(ValidationConfig {
            xss_protection: true,
            ..ValidationConfig::default()
        });

        assert!(manager.validate_input("Hello world!").is_ok());

        let message = validation_message(manager.validate_input("<script>alert('xss')</script>"));
        assert!(message.contains("XSS"));
    }

    #[test]
    fn test_input_sanitization() {
        let manager = manager_with(ValidationConfig {
            sanitize_input: true,
            xss_protection: true,
            ..ValidationConfig::default()
        });

        let result = manager
            .validate_input("Hello <world> & 'test' \"quote\"")
            .unwrap();
        assert_eq!(
            result,
            "Hello &lt;world&gt; &amp; &#x27;test&#x27; &quot;quote&quot;"
        );
    }

    #[test]
    fn test_malicious_pattern_detection() {
        let manager = manager_with(ValidationConfig::default());

        assert!(manager.validate_input("../../../etc/passwd").is_err());

        // No built-in pattern targets `rm -rf`, so this input is accepted.
        assert!(manager.validate_input("test; rm -rf /").is_ok());

        assert!(manager.validate_input("../../etc/shadow").is_err());
    }

    #[test]
    fn test_suspicious_user_agent() {
        let manager = manager_with(ValidationConfig::default());

        let mut scanner = HeaderMap::new();
        scanner.insert(
            header::USER_AGENT,
            axum::http::HeaderValue::from_static("sqlmap/1.7"),
        );
        assert!(manager.validate_headers(&scanner).is_err());

        let mut browser = HeaderMap::new();
        browser.insert(
            header::USER_AGENT,
            axum::http::HeaderValue::from_static("Mozilla/5.0"),
        );
        assert!(manager.validate_headers(&browser).is_ok());
    }

    #[test]
    fn test_json_validation() {
        let manager = manager_with(ValidationConfig::default());

        assert!(manager
            .validate_json(r#"{"name": "test", "value": 123}"#)
            .is_ok());
        assert!(manager
            .validate_json(r#"{"name": "test", "value":}"#)
            .is_err());
    }

    #[test]
    fn test_json_content_validation() {
        let manager = manager_with(ValidationConfig {
            xss_protection: true,
            ..ValidationConfig::default()
        });

        let json = r#"{"comment": "<script>alert('xss')</script>"}"#;
        assert!(manager.validate_json(json).is_err());
    }

    #[test]
    fn test_content_type_validation() {
        let manager = manager_with(ValidationConfig::default());

        assert!(manager
            .validate_content_type(Some("application/json"))
            .is_ok());
        assert!(manager
            .validate_content_type(Some("application/x-executable"))
            .is_err());
        assert!(manager
            .validate_content_type(Some("application/json; charset=utf-8"))
            .is_ok());
    }

    #[test]
    fn test_validation_disabled() {
        let manager = manager_with(ValidationConfig {
            enabled: false,
            ..ValidationConfig::default()
        });
        assert!(!manager.is_enabled());

        assert!(manager
            .validate_input("<script>alert('xss')</script>")
            .is_ok());
    }

    #[test]
    fn test_validated_request_struct() {
        let request = ValidatedRequest {
            content: Some("Hello world".to_string()),
            email: Some("test@example.com".to_string()),
            url: Some("https://example.com".to_string()),
            limit: Some(100),
            offset: Some(0),
            query: Some("safe query".to_string()),
        };

        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_validated_request_invalid() {
        let request = ValidatedRequest {
            content: Some(String::new()),           // below minimum length
            email: Some("invalid-email".to_string()), // not an e-mail
            url: Some("not-a-url".to_string()),     // not a URL
            limit: Some(2000),                      // above maximum
            offset: Some(-1),                       // negative
            query: Some("<script>alert('xss')</script>".to_string()),
        };

        let errors = request.validate().unwrap_err();
        assert!(!errors.field_errors().is_empty());
    }

    #[test]
    fn test_custom_validator() {
        assert!(validate_safe_string("This is a safe string").is_ok());
        assert!(validate_safe_string("<script>alert('test')</script>").is_err());
        assert!(validate_safe_string("../../etc/passwd").is_err());
    }

    #[test]
    fn test_request_size_limits() {
        let config = ValidationConfig {
            enabled: true,
            max_request_size: 1024,
            sanitize_input: true,
            xss_protection: true,
            sql_injection_protection: true,
        };

        let manager = manager_with(config);
        assert_eq!(manager.get_max_request_size(), 1024);

        let large_json = "x".repeat(2000);
        let json = format!(r#"{{"data": "{large_json}"}}"#);
        let message = validation_message(manager.validate_json(&json));
        assert!(message.contains("exceeds maximum"));
    }
}