1use crate::security::{Result, SecurityError, ValidationConfig};
2use axum::{
3 extract::{Request, State},
4 http::{header, HeaderMap, StatusCode},
5 middleware::Next,
6 response::Response,
7};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tracing::{debug, warn};
12use validator::{Validate, ValidationError};
13
14pub struct ValidationManager {
16 config: ValidationConfig,
17 sql_injection_patterns: Vec<Regex>,
18 xss_patterns: Vec<Regex>,
19 malicious_patterns: Vec<Regex>,
20}
21
22impl ValidationManager {
23 pub fn new(config: ValidationConfig) -> Result<Self> {
24 let mut manager = Self {
25 config,
26 sql_injection_patterns: Vec::new(),
27 xss_patterns: Vec::new(),
28 malicious_patterns: Vec::new(),
29 };
30
31 if manager.config.enabled {
32 manager.initialize_patterns()?;
33 }
34
35 Ok(manager)
36 }
37
38 fn initialize_patterns(&mut self) -> Result<()> {
39 if self.config.sql_injection_protection {
41 let sql_patterns = vec![
42 r"(?i)(union\s+select)",
43 r"(?i)(drop\s+table)",
44 r"(?i)(delete\s+from)",
45 r"(?i)(insert\s+into)",
46 r"(?i)(update\s+set)",
47 r"(?i)(alter\s+table)",
48 r"(?i)(create\s+table)",
49 r"(?i)(exec\s*\()",
50 r"(?i)(execute\s*\()",
51 r"(?i)(\'\s*or\s*\'\s*=\s*\')",
52 r"(?i)(\'\s*or\s*1\s*=\s*1)",
53 r"(?i)(\'\s*;\s*drop)",
54 r"(?i)(--\s*)",
55 r"(?i)(/\*.*\*/)",
56 r"(?i)(xp_cmdshell)",
57 r"(?i)(sp_executesql)",
58 ];
59
60 for pattern in sql_patterns {
61 self.sql_injection_patterns
62 .push(
63 Regex::new(pattern).map_err(|e| SecurityError::ValidationError {
64 message: format!("Failed to compile SQL injection pattern: {e}"),
65 })?,
66 );
67 }
68 }
69
70 if self.config.xss_protection {
72 let xss_patterns = vec![
73 r"(?i)<script[^>]*>",
74 r"(?i)</script>",
75 r"(?i)<iframe[^>]*>",
76 r"(?i)<object[^>]*>",
77 r"(?i)<embed[^>]*>",
78 r"(?i)<link[^>]*>",
79 r"(?i)<meta[^>]*>",
80 r"(?i)javascript:",
81 r"(?i)vbscript:",
82 r"(?i)onload\s*=",
83 r"(?i)onerror\s*=",
84 r"(?i)onclick\s*=",
85 r"(?i)onmouseover\s*=",
86 r"(?i)onfocus\s*=",
87 r"(?i)onblur\s*=",
88 r"(?i)onchange\s*=",
89 r"(?i)onsubmit\s*=",
90 r"(?i)expression\s*\(",
91 r"(?i)url\s*\(",
92 r"(?i)@import",
93 ];
94
95 for pattern in xss_patterns {
96 self.xss_patterns.push(Regex::new(pattern).map_err(|e| {
97 SecurityError::ValidationError {
98 message: format!("Failed to compile XSS pattern: {e}"),
99 }
100 })?);
101 }
102 }
103
104 let malicious_patterns = vec![
106 r"(?i)(\.\.\/){2,}", r"(?i)\.\.\\", r"(?i)\/etc\/passwd",
109 r"(?i)\/etc\/shadow",
110 r"(?i)\/proc\/",
111 r"(?i)c:\\windows\\",
112 r"(?i)cmd\.exe",
113 r"(?i)powershell\.exe",
114 r"(?i)bash\s*-c",
115 r"(?i)sh\s*-c",
116 r"(?i)\$\([^)]*\)", r"(?i)`[^`]*`", ];
119
120 for pattern in malicious_patterns {
121 self.malicious_patterns
122 .push(
123 Regex::new(pattern).map_err(|e| SecurityError::ValidationError {
124 message: format!("Failed to compile malicious pattern: {e}"),
125 })?,
126 );
127 }
128
129 debug!(
130 "Initialized validation patterns: {} SQL, {} XSS, {} malicious",
131 self.sql_injection_patterns.len(),
132 self.xss_patterns.len(),
133 self.malicious_patterns.len()
134 );
135
136 Ok(())
137 }
138
139 pub fn validate_input(&self, input: &str) -> Result<String> {
141 if !self.config.enabled {
142 return Ok(input.to_string());
143 }
144
145 if self.config.sql_injection_protection {
147 for pattern in &self.sql_injection_patterns {
148 if pattern.is_match(input) {
149 warn!("SQL injection attempt detected: {}", pattern.as_str());
150 return Err(SecurityError::ValidationError {
151 message: "Potential SQL injection detected".to_string(),
152 });
153 }
154 }
155 }
156
157 if self.config.xss_protection {
159 for pattern in &self.xss_patterns {
160 if pattern.is_match(input) {
161 warn!("XSS attempt detected: {}", pattern.as_str());
162 return Err(SecurityError::ValidationError {
163 message: "Potential XSS detected".to_string(),
164 });
165 }
166 }
167 }
168
169 for pattern in &self.malicious_patterns {
171 if pattern.is_match(input) {
172 warn!("Malicious pattern detected: {}", pattern.as_str());
173 return Err(SecurityError::ValidationError {
174 message: "Malicious content detected".to_string(),
175 });
176 }
177 }
178
179 if self.config.sanitize_input {
181 Ok(self.sanitize_input(input))
182 } else {
183 Ok(input.to_string())
184 }
185 }
186
187 fn sanitize_input(&self, input: &str) -> String {
189 let mut sanitized = input.to_string();
190
191 sanitized = sanitized.replace('\0', "");
193
194 sanitized = sanitized.replace('\r', "");
196 sanitized = sanitized.replace('\n', " ");
197 sanitized = sanitized.replace('\t', " ");
198
199 if self.config.xss_protection {
201 sanitized = sanitized.replace('&', "&");
203 sanitized = sanitized.replace('<', "<");
204 sanitized = sanitized.replace('>', ">");
205 sanitized = sanitized.replace('"', """);
206 sanitized = sanitized.replace('\'', "'");
207 }
208
209 if sanitized.len() > 10000 {
211 sanitized.truncate(10000);
212 sanitized.push_str("...");
213 }
214
215 sanitized
216 }
217
218 pub fn validate_json(&self, json_str: &str) -> Result<serde_json::Value> {
220 if !self.config.enabled {
221 return serde_json::from_str(json_str).map_err(|e| SecurityError::ValidationError {
222 message: format!("Invalid JSON: {e}"),
223 });
224 }
225
226 if json_str.len() > self.config.max_request_size as usize {
228 return Err(SecurityError::ValidationError {
229 message: "Request size exceeds maximum allowed".to_string(),
230 });
231 }
232
233 let json_value: serde_json::Value =
235 serde_json::from_str(json_str).map_err(|e| SecurityError::ValidationError {
236 message: format!("Invalid JSON: {e}"),
237 })?;
238
239 self.validate_json_value(&json_value)?;
241
242 Ok(json_value)
243 }
244
245 fn validate_json_value(&self, value: &serde_json::Value) -> Result<()> {
246 match value {
247 serde_json::Value::String(s) => {
248 self.validate_input(s)?;
249 }
250 serde_json::Value::Array(arr) => {
251 for item in arr {
252 self.validate_json_value(item)?;
253 }
254 }
255 serde_json::Value::Object(obj) => {
256 for (key, val) in obj {
257 self.validate_input(key)?;
258 self.validate_json_value(val)?;
259 }
260 }
261 _ => {} }
263 Ok(())
264 }
265
266 pub fn validate_headers(&self, headers: &HeaderMap) -> Result<()> {
268 if !self.config.enabled {
269 return Ok(());
270 }
271
272 if let Some(user_agent) = headers.get(header::USER_AGENT) {
274 if let Ok(ua_str) = user_agent.to_str() {
275 self.validate_input(ua_str)?;
276
277 let suspicious_patterns = vec![
279 r"(?i)(sqlmap|nmap|nikto|dirb|gobuster)",
280 r"(?i)(masscan|zap|burp|wget|curl)",
281 r"(?i)(python-requests|libwww-perl)",
282 ];
283
284 for pattern_str in suspicious_patterns {
285 let pattern = Regex::new(pattern_str).unwrap();
286 if pattern.is_match(ua_str) {
287 warn!("Suspicious user agent detected: {}", ua_str);
288 return Err(SecurityError::ValidationError {
289 message: "Suspicious user agent".to_string(),
290 });
291 }
292 }
293 }
294 }
295
296 if let Some(referer) = headers.get(header::REFERER) {
298 if let Ok(referer_str) = referer.to_str() {
299 self.validate_input(referer_str)?;
300 }
301 }
302
303 for (_name, value) in headers {
305 if let Ok(value_str) = value.to_str() {
306 self.validate_input(value_str)?;
307 }
308 }
309
310 Ok(())
311 }
312
313 pub fn validate_content_type(&self, content_type: Option<&str>) -> Result<()> {
315 if !self.config.enabled {
316 return Ok(());
317 }
318
319 let allowed_types = [
320 "application/json",
321 "application/x-www-form-urlencoded",
322 "text/plain",
323 "multipart/form-data",
324 ];
325
326 if let Some(ct) = content_type {
327 let ct_main = ct.split(';').next().unwrap_or(ct).trim();
328
329 if !allowed_types.contains(&ct_main) {
330 return Err(SecurityError::ValidationError {
331 message: format!("Content type not allowed: {ct_main}"),
332 });
333 }
334 }
335
336 Ok(())
337 }
338
339 pub fn is_enabled(&self) -> bool {
340 self.config.enabled
341 }
342
343 pub fn get_max_request_size(&self) -> u64 {
344 self.config.max_request_size
345 }
346}
347
348#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
350pub struct ValidatedRequest {
351 #[validate(length(
352 min = 1,
353 max = 1000,
354 message = "Content must be between 1 and 1000 characters"
355 ))]
356 pub content: Option<String>,
357
358 #[validate(email(message = "Invalid email format"))]
359 pub email: Option<String>,
360
361 #[validate(url(message = "Invalid URL format"))]
362 pub url: Option<String>,
363
364 #[validate(range(min = 1, max = 1000, message = "Limit must be between 1 and 1000"))]
365 pub limit: Option<i32>,
366
367 #[validate(range(min = 0, message = "Offset must be non-negative"))]
368 pub offset: Option<i32>,
369
370 #[validate(length(min = 1, max = 1000))]
371 pub query: Option<String>,
372}
373
374#[allow(dead_code)]
376fn validate_safe_string(value: &str) -> std::result::Result<(), ValidationError> {
377 if value.contains("<script") || value.contains("javascript:") || value.contains("../../") {
379 return Err(ValidationError::new("unsafe_content"));
380 }
381
382 Ok(())
383}
384
385pub async fn validation_middleware(
387 State(validator): State<Arc<ValidationManager>>,
388 headers: HeaderMap,
389 request: Request,
390 next: Next,
391) -> std::result::Result<Response, StatusCode> {
392 if !validator.is_enabled() {
393 return Ok(next.run(request).await);
394 }
395
396 if validator.validate_headers(&headers).is_err() {
398 warn!("Request validation failed: invalid headers");
399 return Err(StatusCode::BAD_REQUEST);
400 }
401
402 let content_type = headers
404 .get(header::CONTENT_TYPE)
405 .and_then(|ct| ct.to_str().ok());
406
407 if validator.validate_content_type(content_type).is_err() {
408 warn!("Request validation failed: invalid content type");
409 return Err(StatusCode::UNSUPPORTED_MEDIA_TYPE);
410 }
411
412 if let Some(content_length) = headers.get(header::CONTENT_LENGTH) {
414 if let Ok(length_str) = content_length.to_str() {
415 if let Ok(length) = length_str.parse::<u64>() {
416 if length > validator.get_max_request_size() {
417 warn!(
418 "Request validation failed: request too large ({} bytes)",
419 length
420 );
421 return Err(StatusCode::PAYLOAD_TOO_LARGE);
422 }
423 }
424 }
425 }
426
427 debug!("Request validation passed");
428 Ok(next.run(request).await)
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless `result` is a `SecurityError::ValidationError` whose message
    /// contains `needle` (the previous `if let` checks passed silently on any other outcome).
    fn assert_validation_error<T>(result: Result<T>, needle: &str) {
        match result {
            Err(SecurityError::ValidationError { message }) => assert!(
                message.contains(needle),
                "expected message containing {needle:?}, got {message:?}"
            ),
            _ => panic!("expected a ValidationError containing {needle:?}"),
        }
    }

    #[test]
    fn test_validation_manager_creation() {
        let config = ValidationConfig::default();
        let manager = ValidationManager::new(config).unwrap();
        assert!(manager.is_enabled());
    }

    #[test]
    fn test_sql_injection_detection() {
        let config = ValidationConfig {
            sql_injection_protection: true,
            ..ValidationConfig::default()
        };
        let manager = ValidationManager::new(config).unwrap();

        // A plain query without injection markers is accepted.
        let result = manager.validate_input("SELECT name FROM users WHERE id = 1");
        assert!(result.is_ok());

        let result = manager.validate_input("'; DROP TABLE users; --");
        assert_validation_error(result, "SQL injection");
    }

    #[test]
    fn test_xss_detection() {
        let config = ValidationConfig {
            xss_protection: true,
            ..ValidationConfig::default()
        };
        let manager = ValidationManager::new(config).unwrap();

        let result = manager.validate_input("Hello world!");
        assert!(result.is_ok());

        let result = manager.validate_input("<script>alert('xss')</script>");
        assert_validation_error(result, "XSS");
    }

    #[test]
    fn test_input_sanitization() {
        let config = ValidationConfig {
            sanitize_input: true,
            xss_protection: true,
            ..ValidationConfig::default()
        };
        let manager = ValidationManager::new(config).unwrap();

        let result = manager
            .validate_input("Hello <world> & 'test' \"quote\"")
            .unwrap();
        assert_eq!(
            result,
            "Hello &lt;world&gt; &amp; &#x27;test&#x27; &quot;quote&quot;"
        );
    }

    #[test]
    fn test_malicious_pattern_detection() {
        let config = ValidationConfig::default();
        let manager = ValidationManager::new(config).unwrap();

        let result = manager.validate_input("../../../etc/passwd");
        assert!(result.is_err());

        // Plain shell commands are not covered by the malicious-pattern list.
        let result = manager.validate_input("test; rm -rf /");
        assert!(result.is_ok());

        let result = manager.validate_input("../../etc/shadow");
        assert!(result.is_err());
    }

    #[test]
    fn test_json_validation() {
        let config = ValidationConfig::default();
        let manager = ValidationManager::new(config).unwrap();

        let json = r#"{"name": "test", "value": 123}"#;
        let result = manager.validate_json(json);
        assert!(result.is_ok());

        let invalid_json = r#"{"name": "test", "value":}"#;
        let result = manager.validate_json(invalid_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_json_content_validation() {
        let config = ValidationConfig {
            xss_protection: true,
            ..ValidationConfig::default()
        };
        let manager = ValidationManager::new(config).unwrap();

        let json = r#"{"comment": "<script>alert('xss')</script>"}"#;
        let result = manager.validate_json(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_content_type_validation() {
        let config = ValidationConfig::default();
        let manager = ValidationManager::new(config).unwrap();

        let result = manager.validate_content_type(Some("application/json"));
        assert!(result.is_ok());

        let result = manager.validate_content_type(Some("application/x-executable"));
        assert!(result.is_err());

        // Parameters after ';' are ignored.
        let result = manager.validate_content_type(Some("application/json; charset=utf-8"));
        assert!(result.is_ok());
    }

    #[test]
    fn test_validation_disabled() {
        let config = ValidationConfig {
            enabled: false,
            ..ValidationConfig::default()
        };
        let manager = ValidationManager::new(config).unwrap();
        assert!(!manager.is_enabled());

        let result = manager.validate_input("<script>alert('xss')</script>");
        assert!(result.is_ok());
    }

    #[test]
    fn test_validated_request_struct() {
        let request = ValidatedRequest {
            content: Some("Hello world".to_string()),
            email: Some("test@example.com".to_string()),
            url: Some("https://example.com".to_string()),
            limit: Some(100),
            offset: Some(0),
            query: Some("safe query".to_string()),
        };

        let validation_result = request.validate();
        assert!(validation_result.is_ok());
    }

    #[test]
    fn test_validated_request_invalid() {
        let request = ValidatedRequest {
            content: Some("".to_string()),
            email: Some("invalid-email".to_string()),
            url: Some("not-a-url".to_string()),
            limit: Some(2000),
            offset: Some(-1),
            query: Some("<script>alert('xss')</script>".to_string()),
        };

        let validation_result = request.validate();
        assert!(validation_result.is_err());

        let errors = validation_result.unwrap_err();
        assert!(!errors.field_errors().is_empty());
    }

    #[test]
    fn test_custom_validator() {
        let valid_result = validate_safe_string("This is a safe string");
        assert!(valid_result.is_ok());

        let invalid_result = validate_safe_string("<script>alert('test')</script>");
        assert!(invalid_result.is_err());

        let traversal_result = validate_safe_string("../../etc/passwd");
        assert!(traversal_result.is_err());
    }

    #[test]
    fn test_request_size_limits() {
        let config = ValidationConfig {
            enabled: true,
            max_request_size: 1024,
            sanitize_input: true,
            xss_protection: true,
            sql_injection_protection: true,
        };

        let manager = ValidationManager::new(config).unwrap();
        assert_eq!(manager.get_max_request_size(), 1024);

        let large_json = "x".repeat(2000);
        let json = format!(r#"{{"data": "{large_json}"}}"#);
        let result = manager.validate_json(&json);
        assert_validation_error(result, "exceeds maximum");
    }
}