codex_memory/security/
validation.rs

1use crate::security::{Result, SecurityError, ValidationConfig};
2use axum::{
3    extract::{Request, State},
4    http::{header, HeaderMap, StatusCode},
5    middleware::Next,
6    response::Response,
7};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tracing::{debug, warn};
12use validator::{Validate, ValidationError};
13
14/// Input validation and sanitization manager
15pub struct ValidationManager {
16    config: ValidationConfig,
17    sql_injection_patterns: Vec<Regex>,
18    xss_patterns: Vec<Regex>,
19    malicious_patterns: Vec<Regex>,
20}
21
22impl ValidationManager {
23    pub fn new(config: ValidationConfig) -> Result<Self> {
24        let mut manager = Self {
25            config,
26            sql_injection_patterns: Vec::new(),
27            xss_patterns: Vec::new(),
28            malicious_patterns: Vec::new(),
29        };
30
31        if manager.config.enabled {
32            manager.initialize_patterns()?;
33        }
34
35        Ok(manager)
36    }
37
38    fn initialize_patterns(&mut self) -> Result<()> {
39        // SQL injection patterns
40        if self.config.sql_injection_protection {
41            let sql_patterns = vec![
42                r"(?i)(union\s+select)",
43                r"(?i)(drop\s+table)",
44                r"(?i)(delete\s+from)",
45                r"(?i)(insert\s+into)",
46                r"(?i)(update\s+set)",
47                r"(?i)(alter\s+table)",
48                r"(?i)(create\s+table)",
49                r"(?i)(exec\s*\()",
50                r"(?i)(execute\s*\()",
51                r"(?i)(\'\s*or\s*\'\s*=\s*\')",
52                r"(?i)(\'\s*or\s*1\s*=\s*1)",
53                r"(?i)(\'\s*;\s*drop)",
54                r"(?i)(--\s*)",
55                r"(?i)(/\*.*\*/)",
56                r"(?i)(xp_cmdshell)",
57                r"(?i)(sp_executesql)",
58            ];
59
60            for pattern in sql_patterns {
61                self.sql_injection_patterns
62                    .push(
63                        Regex::new(pattern).map_err(|e| SecurityError::ValidationError {
64                            message: format!("Failed to compile SQL injection pattern: {e}"),
65                        })?,
66                    );
67            }
68        }
69
70        // XSS patterns
71        if self.config.xss_protection {
72            let xss_patterns = vec![
73                r"(?i)<script[^>]*>",
74                r"(?i)</script>",
75                r"(?i)<iframe[^>]*>",
76                r"(?i)<object[^>]*>",
77                r"(?i)<embed[^>]*>",
78                r"(?i)<link[^>]*>",
79                r"(?i)<meta[^>]*>",
80                r"(?i)javascript:",
81                r"(?i)vbscript:",
82                r"(?i)onload\s*=",
83                r"(?i)onerror\s*=",
84                r"(?i)onclick\s*=",
85                r"(?i)onmouseover\s*=",
86                r"(?i)onfocus\s*=",
87                r"(?i)onblur\s*=",
88                r"(?i)onchange\s*=",
89                r"(?i)onsubmit\s*=",
90                r"(?i)expression\s*\(",
91                r"(?i)url\s*\(",
92                r"(?i)@import",
93            ];
94
95            for pattern in xss_patterns {
96                self.xss_patterns.push(Regex::new(pattern).map_err(|e| {
97                    SecurityError::ValidationError {
98                        message: format!("Failed to compile XSS pattern: {e}"),
99                    }
100                })?);
101            }
102        }
103
104        // General malicious patterns
105        let malicious_patterns = vec![
106            r"(?i)(\.\.\/){2,}", // Path traversal
107            r"(?i)\.\.\\",       // Windows path traversal
108            r"(?i)\/etc\/passwd",
109            r"(?i)\/etc\/shadow",
110            r"(?i)\/proc\/",
111            r"(?i)c:\\windows\\",
112            r"(?i)cmd\.exe",
113            r"(?i)powershell\.exe",
114            r"(?i)bash\s*-c",
115            r"(?i)sh\s*-c",
116            r"(?i)\$\([^)]*\)", // Command substitution
117            r"(?i)`[^`]*`",     // Backtick command execution
118        ];
119
120        for pattern in malicious_patterns {
121            self.malicious_patterns
122                .push(
123                    Regex::new(pattern).map_err(|e| SecurityError::ValidationError {
124                        message: format!("Failed to compile malicious pattern: {e}"),
125                    })?,
126                );
127        }
128
129        debug!(
130            "Initialized validation patterns: {} SQL, {} XSS, {} malicious",
131            self.sql_injection_patterns.len(),
132            self.xss_patterns.len(),
133            self.malicious_patterns.len()
134        );
135
136        Ok(())
137    }
138
139    /// Validate and sanitize input string
140    pub fn validate_input(&self, input: &str) -> Result<String> {
141        if !self.config.enabled {
142            return Ok(input.to_string());
143        }
144
145        // Check for SQL injection
146        if self.config.sql_injection_protection {
147            for pattern in &self.sql_injection_patterns {
148                if pattern.is_match(input) {
149                    warn!("SQL injection attempt detected: {}", pattern.as_str());
150                    return Err(SecurityError::ValidationError {
151                        message: "Potential SQL injection detected".to_string(),
152                    });
153                }
154            }
155        }
156
157        // Check for XSS
158        if self.config.xss_protection {
159            for pattern in &self.xss_patterns {
160                if pattern.is_match(input) {
161                    warn!("XSS attempt detected: {}", pattern.as_str());
162                    return Err(SecurityError::ValidationError {
163                        message: "Potential XSS detected".to_string(),
164                    });
165                }
166            }
167        }
168
169        // Check for general malicious patterns
170        for pattern in &self.malicious_patterns {
171            if pattern.is_match(input) {
172                warn!("Malicious pattern detected: {}", pattern.as_str());
173                return Err(SecurityError::ValidationError {
174                    message: "Malicious content detected".to_string(),
175                });
176            }
177        }
178
179        // Sanitize if enabled
180        if self.config.sanitize_input {
181            Ok(self.sanitize_input(input))
182        } else {
183            Ok(input.to_string())
184        }
185    }
186
187    /// Sanitize input by removing or escaping dangerous characters
188    fn sanitize_input(&self, input: &str) -> String {
189        let mut sanitized = input.to_string();
190
191        // Remove null bytes
192        sanitized = sanitized.replace('\0', "");
193
194        // Remove or escape common dangerous characters
195        sanitized = sanitized.replace('\r', "");
196        sanitized = sanitized.replace('\n', " ");
197        sanitized = sanitized.replace('\t', " ");
198
199        // Escape HTML entities if XSS protection is enabled
200        if self.config.xss_protection {
201            // Replace & first to avoid double-escaping
202            sanitized = sanitized.replace('&', "&amp;");
203            sanitized = sanitized.replace('<', "&lt;");
204            sanitized = sanitized.replace('>', "&gt;");
205            sanitized = sanitized.replace('"', "&quot;");
206            sanitized = sanitized.replace('\'', "&#x27;");
207        }
208
209        // Limit length to prevent buffer overflow attacks
210        if sanitized.len() > 10000 {
211            sanitized.truncate(10000);
212            sanitized.push_str("...");
213        }
214
215        sanitized
216    }
217
218    /// Validate JSON payload
219    pub fn validate_json(&self, json_str: &str) -> Result<serde_json::Value> {
220        if !self.config.enabled {
221            return serde_json::from_str(json_str).map_err(|e| SecurityError::ValidationError {
222                message: format!("Invalid JSON: {e}"),
223            });
224        }
225
226        // Check JSON size
227        if json_str.len() > self.config.max_request_size as usize {
228            return Err(SecurityError::ValidationError {
229                message: "Request size exceeds maximum allowed".to_string(),
230            });
231        }
232
233        // Parse JSON
234        let json_value: serde_json::Value =
235            serde_json::from_str(json_str).map_err(|e| SecurityError::ValidationError {
236                message: format!("Invalid JSON: {e}"),
237            })?;
238
239        // Validate JSON content recursively
240        self.validate_json_value(&json_value)?;
241
242        Ok(json_value)
243    }
244
245    fn validate_json_value(&self, value: &serde_json::Value) -> Result<()> {
246        match value {
247            serde_json::Value::String(s) => {
248                self.validate_input(s)?;
249            }
250            serde_json::Value::Array(arr) => {
251                for item in arr {
252                    self.validate_json_value(item)?;
253                }
254            }
255            serde_json::Value::Object(obj) => {
256                for (key, val) in obj {
257                    self.validate_input(key)?;
258                    self.validate_json_value(val)?;
259                }
260            }
261            _ => {} // Numbers, booleans, null are safe
262        }
263        Ok(())
264    }
265
266    /// Validate HTTP headers
267    pub fn validate_headers(&self, headers: &HeaderMap) -> Result<()> {
268        if !self.config.enabled {
269            return Ok(());
270        }
271
272        // Check User-Agent header
273        if let Some(user_agent) = headers.get(header::USER_AGENT) {
274            if let Ok(ua_str) = user_agent.to_str() {
275                self.validate_input(ua_str)?;
276
277                // Check for suspicious user agents
278                let suspicious_patterns = vec![
279                    r"(?i)(sqlmap|nmap|nikto|dirb|gobuster)",
280                    r"(?i)(masscan|zap|burp|wget|curl)",
281                    r"(?i)(python-requests|libwww-perl)",
282                ];
283
284                for pattern_str in suspicious_patterns {
285                    let pattern = Regex::new(pattern_str).unwrap();
286                    if pattern.is_match(ua_str) {
287                        warn!("Suspicious user agent detected: {}", ua_str);
288                        return Err(SecurityError::ValidationError {
289                            message: "Suspicious user agent".to_string(),
290                        });
291                    }
292                }
293            }
294        }
295
296        // Check Referer header for common attacks
297        if let Some(referer) = headers.get(header::REFERER) {
298            if let Ok(referer_str) = referer.to_str() {
299                self.validate_input(referer_str)?;
300            }
301        }
302
303        // Check custom headers
304        for (_name, value) in headers {
305            if let Ok(value_str) = value.to_str() {
306                self.validate_input(value_str)?;
307            }
308        }
309
310        Ok(())
311    }
312
313    /// Check if content type is allowed
314    pub fn validate_content_type(&self, content_type: Option<&str>) -> Result<()> {
315        if !self.config.enabled {
316            return Ok(());
317        }
318
319        let allowed_types = [
320            "application/json",
321            "application/x-www-form-urlencoded",
322            "text/plain",
323            "multipart/form-data",
324        ];
325
326        if let Some(ct) = content_type {
327            let ct_main = ct.split(';').next().unwrap_or(ct).trim();
328
329            if !allowed_types.contains(&ct_main) {
330                return Err(SecurityError::ValidationError {
331                    message: format!("Content type not allowed: {ct_main}"),
332                });
333            }
334        }
335
336        Ok(())
337    }
338
339    pub fn is_enabled(&self) -> bool {
340        self.config.enabled
341    }
342
343    pub fn get_max_request_size(&self) -> u64 {
344        self.config.max_request_size
345    }
346}
347
348/// Request validation data for structured validation
349#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
350pub struct ValidatedRequest {
351    #[validate(length(
352        min = 1,
353        max = 1000,
354        message = "Content must be between 1 and 1000 characters"
355    ))]
356    pub content: Option<String>,
357
358    #[validate(email(message = "Invalid email format"))]
359    pub email: Option<String>,
360
361    #[validate(url(message = "Invalid URL format"))]
362    pub url: Option<String>,
363
364    #[validate(range(min = 1, max = 1000, message = "Limit must be between 1 and 1000"))]
365    pub limit: Option<i32>,
366
367    #[validate(range(min = 0, message = "Offset must be non-negative"))]
368    pub offset: Option<i32>,
369
370    #[validate(length(min = 1, max = 1000))]
371    pub query: Option<String>,
372}
373
374/// Custom validator for safe strings
375#[allow(dead_code)]
376fn validate_safe_string(value: &str) -> std::result::Result<(), ValidationError> {
377    // Check for dangerous characters
378    if value.contains("<script") || value.contains("javascript:") || value.contains("../../") {
379        return Err(ValidationError::new("unsafe_content"));
380    }
381
382    Ok(())
383}
384
385/// Validation middleware for Axum
386pub async fn validation_middleware(
387    State(validator): State<Arc<ValidationManager>>,
388    headers: HeaderMap,
389    request: Request,
390    next: Next,
391) -> std::result::Result<Response, StatusCode> {
392    if !validator.is_enabled() {
393        return Ok(next.run(request).await);
394    }
395
396    // Validate headers
397    if validator.validate_headers(&headers).is_err() {
398        warn!("Request validation failed: invalid headers");
399        return Err(StatusCode::BAD_REQUEST);
400    }
401
402    // Validate content type
403    let content_type = headers
404        .get(header::CONTENT_TYPE)
405        .and_then(|ct| ct.to_str().ok());
406
407    if validator.validate_content_type(content_type).is_err() {
408        warn!("Request validation failed: invalid content type");
409        return Err(StatusCode::UNSUPPORTED_MEDIA_TYPE);
410    }
411
412    // Check request size
413    if let Some(content_length) = headers.get(header::CONTENT_LENGTH) {
414        if let Ok(length_str) = content_length.to_str() {
415            if let Ok(length) = length_str.parse::<u64>() {
416                if length > validator.get_max_request_size() {
417                    warn!(
418                        "Request validation failed: request too large ({} bytes)",
419                        length
420                    );
421                    return Err(StatusCode::PAYLOAD_TOO_LARGE);
422                }
423            }
424        }
425    }
426
427    debug!("Request validation passed");
428    Ok(next.run(request).await)
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager from `config`, panicking if construction fails.
    fn manager_for(config: ValidationConfig) -> ValidationManager {
        ValidationManager::new(config).unwrap()
    }

    /// Builds a manager from the default configuration.
    fn default_manager() -> ValidationManager {
        manager_for(ValidationConfig::default())
    }

    #[test]
    fn test_validation_manager_creation() {
        let manager = default_manager();
        assert!(manager.is_enabled());
    }

    #[test]
    fn test_sql_injection_detection() {
        let mut cfg = ValidationConfig::default();
        cfg.sql_injection_protection = true;
        let manager = manager_for(cfg);

        // A plain query passes.
        let accepted = manager.validate_input("SELECT name FROM users WHERE id = 1");
        assert!(accepted.is_ok());

        // A classic injection payload is rejected.
        let rejected = manager.validate_input("'; DROP TABLE users; --");
        assert!(rejected.is_err());

        if let Err(SecurityError::ValidationError { message }) = rejected {
            assert!(message.contains("SQL injection"));
        }
    }

    #[test]
    fn test_xss_detection() {
        let mut cfg = ValidationConfig::default();
        cfg.xss_protection = true;
        let manager = manager_for(cfg);

        // Harmless text passes.
        let accepted = manager.validate_input("Hello world!");
        assert!(accepted.is_ok());

        // A script tag is rejected.
        let rejected = manager.validate_input("<script>alert('xss')</script>");
        assert!(rejected.is_err());

        if let Err(SecurityError::ValidationError { message }) = rejected {
            assert!(message.contains("XSS"));
        }
    }

    #[test]
    fn test_input_sanitization() {
        let mut cfg = ValidationConfig::default();
        cfg.sanitize_input = true;
        cfg.xss_protection = true;
        let manager = manager_for(cfg);

        let sanitized = manager
            .validate_input("Hello <world> & 'test' \"quote\"")
            .unwrap();

        assert_eq!(
            sanitized,
            "Hello &lt;world&gt; &amp; &#x27;test&#x27; &quot;quote&quot;"
        );
    }

    #[test]
    fn test_malicious_pattern_detection() {
        let manager = default_manager();

        // Path traversal is rejected.
        assert!(manager.validate_input("../../../etc/passwd").is_err());

        // This command-injection shape is not covered by the malicious patterns.
        assert!(manager.validate_input("test; rm -rf /").is_ok());

        // Directory traversal is rejected.
        assert!(manager.validate_input("../../etc/shadow").is_err());
    }

    #[test]
    fn test_json_validation() {
        let manager = default_manager();

        // Well-formed JSON parses.
        let well_formed = r#"{"name": "test", "value": 123}"#;
        assert!(manager.validate_json(well_formed).is_ok());

        // Malformed JSON is rejected.
        let malformed = r#"{"name": "test", "value":}"#;
        assert!(manager.validate_json(malformed).is_err());
    }

    #[test]
    fn test_json_content_validation() {
        let mut cfg = ValidationConfig::default();
        cfg.xss_protection = true;
        let manager = manager_for(cfg);

        // A string value carrying a script tag fails content validation.
        let payload = r#"{"comment": "<script>alert('xss')</script>"}"#;
        assert!(manager.validate_json(payload).is_err());
    }

    #[test]
    fn test_content_type_validation() {
        let manager = default_manager();

        // Allowed media type.
        assert!(manager
            .validate_content_type(Some("application/json"))
            .is_ok());

        // Media type outside the allow-list.
        assert!(manager
            .validate_content_type(Some("application/x-executable"))
            .is_err());

        // Parameters after `;` are ignored.
        assert!(manager
            .validate_content_type(Some("application/json; charset=utf-8"))
            .is_ok());
    }

    #[test]
    fn test_validation_disabled() {
        let mut cfg = ValidationConfig::default();
        cfg.enabled = false;
        let manager = manager_for(cfg);

        assert!(!manager.is_enabled());

        // Disabled validation lets even hostile-looking content through.
        let outcome = manager.validate_input("<script>alert('xss')</script>");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validated_request_struct() {
        let request = ValidatedRequest {
            content: Some("Hello world".to_string()),
            email: Some("test@example.com".to_string()),
            url: Some("https://example.com".to_string()),
            limit: Some(100),
            offset: Some(0),
            query: Some("safe query".to_string()),
        };

        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_validated_request_invalid() {
        let request = ValidatedRequest {
            // Empty: below the minimum length.
            content: Some("".to_string()),
            // Not an email address.
            email: Some("invalid-email".to_string()),
            // Not a URL.
            url: Some("not-a-url".to_string()),
            // Above the maximum.
            limit: Some(2000),
            // Negative.
            offset: Some(-1),
            // Markup-like text; `query` is only length-checked, so this field
            // is not what makes the request invalid.
            query: Some("<script>alert('xss')</script>".to_string()),
        };

        let outcome = request.validate();
        assert!(outcome.is_err());

        let errors = outcome.unwrap_err();
        assert!(!errors.field_errors().is_empty());
    }

    #[test]
    fn test_custom_validator() {
        assert!(validate_safe_string("This is a safe string").is_ok());
        assert!(validate_safe_string("<script>alert('test')</script>").is_err());
        assert!(validate_safe_string("../../etc/passwd").is_err());
    }

    #[test]
    fn test_request_size_limits() {
        let manager = manager_for(ValidationConfig {
            enabled: true,
            max_request_size: 1024, // 1KB limit
            sanitize_input: true,
            xss_protection: true,
            sql_injection_protection: true,
        });
        assert_eq!(manager.get_max_request_size(), 1024);

        // A payload well over the limit is rejected.
        let large_json = "x".repeat(2000);
        let json = format!(r#"{{"data": "{large_json}"}}"#);
        let outcome = manager.validate_json(&json);
        assert!(outcome.is_err());

        if let Err(SecurityError::ValidationError { message }) = outcome {
            assert!(message.contains("exceeds maximum"));
        }
    }
}