codex_memory/security/
validation.rs

1use crate::security::{Result, SecurityError, ValidationConfig};
2use axum::{
3    extract::{Request, State},
4    http::{header, HeaderMap, StatusCode},
5    middleware::Next,
6    response::Response,
7};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tracing::{debug, warn};
12use validator::{Validate, ValidationError};
13
/// Input validation and sanitization manager.
///
/// Bundles the validation configuration with the compiled regex lists that
/// `validate_input` matches incoming text against.
pub struct ValidationManager {
    /// Switches for validation as a whole, the individual protections,
    /// sanitization, and the maximum request size.
    config: ValidationConfig,
    /// Compiled SQL-injection signatures; stays empty unless validation and
    /// `sql_injection_protection` are both enabled at construction time.
    sql_injection_patterns: Vec<Regex>,
    /// Compiled XSS signatures; stays empty unless validation and
    /// `xss_protection` are both enabled at construction time.
    xss_patterns: Vec<Regex>,
    /// Compiled path-traversal / command-execution signatures; filled whenever
    /// validation is enabled.
    malicious_patterns: Vec<Regex>,
}
21
22impl ValidationManager {
23    pub fn new(config: ValidationConfig) -> Result<Self> {
24        let mut manager = Self {
25            config,
26            sql_injection_patterns: Vec::new(),
27            xss_patterns: Vec::new(),
28            malicious_patterns: Vec::new(),
29        };
30
31        if manager.config.enabled {
32            manager.initialize_patterns()?;
33        }
34
35        Ok(manager)
36    }
37
38    fn initialize_patterns(&mut self) -> Result<()> {
39        // SQL injection patterns
40        if self.config.sql_injection_protection {
41            let sql_patterns = vec![
42                r"(?i)(union\s+select)",
43                r"(?i)(drop\s+table)",
44                r"(?i)(delete\s+from)",
45                r"(?i)(insert\s+into)",
46                r"(?i)(update\s+set)",
47                r"(?i)(alter\s+table)",
48                r"(?i)(create\s+table)",
49                r"(?i)(exec\s*\()",
50                r"(?i)(execute\s*\()",
51                r"(?i)(\'\s*or\s*\'\s*=\s*\')",
52                r"(?i)(\'\s*or\s*1\s*=\s*1)",
53                r"(?i)(\'\s*;\s*drop)",
54                r"(?i)(--\s*)",
55                r"(?i)(/\*.*\*/)",
56                r"(?i)(xp_cmdshell)",
57                r"(?i)(sp_executesql)",
58            ];
59
60            for pattern in sql_patterns {
61                self.sql_injection_patterns
62                    .push(
63                        Regex::new(pattern).map_err(|e| SecurityError::ValidationError {
64                            message: format!("Failed to compile SQL injection pattern: {e}"),
65                        })?,
66                    );
67            }
68        }
69
70        // XSS patterns
71        if self.config.xss_protection {
72            let xss_patterns = vec![
73                r"(?i)<script[^>]*>",
74                r"(?i)</script>",
75                r"(?i)<iframe[^>]*>",
76                r"(?i)<object[^>]*>",
77                r"(?i)<embed[^>]*>",
78                r"(?i)<link[^>]*>",
79                r"(?i)<meta[^>]*>",
80                r"(?i)javascript:",
81                r"(?i)vbscript:",
82                r"(?i)onload\s*=",
83                r"(?i)onerror\s*=",
84                r"(?i)onclick\s*=",
85                r"(?i)onmouseover\s*=",
86                r"(?i)onfocus\s*=",
87                r"(?i)onblur\s*=",
88                r"(?i)onchange\s*=",
89                r"(?i)onsubmit\s*=",
90                r"(?i)expression\s*\(",
91                r"(?i)url\s*\(",
92                r"(?i)@import",
93            ];
94
95            for pattern in xss_patterns {
96                self.xss_patterns.push(Regex::new(pattern).map_err(|e| {
97                    SecurityError::ValidationError {
98                        message: format!("Failed to compile XSS pattern: {e}"),
99                    }
100                })?);
101            }
102        }
103
104        // General malicious patterns
105        let malicious_patterns = vec![
106            r"(?i)(\.\.\/){2,}", // Path traversal
107            r"(?i)\.\.\\",       // Windows path traversal
108            r"(?i)\/etc\/passwd",
109            r"(?i)\/etc\/shadow",
110            r"(?i)\/proc\/",
111            r"(?i)c:\\windows\\",
112            r"(?i)cmd\.exe",
113            r"(?i)powershell\.exe",
114            r"(?i)bash\s*-c",
115            r"(?i)sh\s*-c",
116            r"(?i)\$\([^)]*\)", // Command substitution
117            r"(?i)`[^`]*`",     // Backtick command execution
118        ];
119
120        for pattern in malicious_patterns {
121            self.malicious_patterns
122                .push(
123                    Regex::new(pattern).map_err(|e| SecurityError::ValidationError {
124                        message: format!("Failed to compile malicious pattern: {e}"),
125                    })?,
126                );
127        }
128
129        debug!(
130            "Initialized validation patterns: {} SQL, {} XSS, {} malicious",
131            self.sql_injection_patterns.len(),
132            self.xss_patterns.len(),
133            self.malicious_patterns.len()
134        );
135
136        Ok(())
137    }
138
139    /// Validate and sanitize input string
140    pub fn validate_input(&self, input: &str) -> Result<String> {
141        if !self.config.enabled {
142            return Ok(input.to_string());
143        }
144
145        // Check for SQL injection
146        if self.config.sql_injection_protection {
147            for pattern in &self.sql_injection_patterns {
148                if pattern.is_match(input) {
149                    warn!("SQL injection attempt detected: {}", pattern.as_str());
150                    return Err(SecurityError::ValidationError {
151                        message: "Potential SQL injection detected".to_string(),
152                    });
153                }
154            }
155        }
156
157        // Check for XSS (can be disabled for testing)
158        let skip_xss_check =
159            std::env::var("SKIP_XSS_CHECK").unwrap_or_else(|_| "false".to_string()) == "true";
160
161        if self.config.xss_protection && !skip_xss_check {
162            for pattern in &self.xss_patterns {
163                if pattern.is_match(input) {
164                    warn!("XSS attempt detected: {}", pattern.as_str());
165                    return Err(SecurityError::ValidationError {
166                        message: "Potential XSS detected".to_string(),
167                    });
168                }
169            }
170        }
171
172        // Check for general malicious patterns
173        for pattern in &self.malicious_patterns {
174            if pattern.is_match(input) {
175                warn!("Malicious pattern detected: {}", pattern.as_str());
176                return Err(SecurityError::ValidationError {
177                    message: "Malicious content detected".to_string(),
178                });
179            }
180        }
181
182        // Sanitize if enabled
183        if self.config.sanitize_input {
184            Ok(self.sanitize_input(input))
185        } else {
186            Ok(input.to_string())
187        }
188    }
189
190    /// Sanitize input by removing or escaping dangerous characters
191    fn sanitize_input(&self, input: &str) -> String {
192        let mut sanitized = input.to_string();
193
194        // Remove null bytes
195        sanitized = sanitized.replace('\0', "");
196
197        // Remove or escape common dangerous characters
198        sanitized = sanitized.replace('\r', "");
199        sanitized = sanitized.replace('\n', " ");
200        sanitized = sanitized.replace('\t', " ");
201
202        // Escape HTML entities if XSS protection is enabled
203        if self.config.xss_protection {
204            // Replace & first to avoid double-escaping
205            sanitized = sanitized.replace('&', "&amp;");
206            sanitized = sanitized.replace('<', "&lt;");
207            sanitized = sanitized.replace('>', "&gt;");
208            sanitized = sanitized.replace('"', "&quot;");
209            sanitized = sanitized.replace('\'', "&#x27;");
210        }
211
212        // Limit length to prevent buffer overflow attacks
213        if sanitized.len() > 10000 {
214            sanitized.truncate(10000);
215            sanitized.push_str("...");
216        }
217
218        sanitized
219    }
220
221    /// Validate JSON payload
222    pub fn validate_json(&self, json_str: &str) -> Result<serde_json::Value> {
223        if !self.config.enabled {
224            return serde_json::from_str(json_str).map_err(|e| SecurityError::ValidationError {
225                message: format!("Invalid JSON: {e}"),
226            });
227        }
228
229        // Check JSON size
230        if json_str.len() > self.config.max_request_size as usize {
231            return Err(SecurityError::ValidationError {
232                message: "Request size exceeds maximum allowed".to_string(),
233            });
234        }
235
236        // Parse JSON
237        let json_value: serde_json::Value =
238            serde_json::from_str(json_str).map_err(|e| SecurityError::ValidationError {
239                message: format!("Invalid JSON: {e}"),
240            })?;
241
242        // Validate JSON content recursively
243        self.validate_json_value(&json_value)?;
244
245        Ok(json_value)
246    }
247
248    fn validate_json_value(&self, value: &serde_json::Value) -> Result<()> {
249        match value {
250            serde_json::Value::String(s) => {
251                self.validate_input(s)?;
252            }
253            serde_json::Value::Array(arr) => {
254                for item in arr {
255                    self.validate_json_value(item)?;
256                }
257            }
258            serde_json::Value::Object(obj) => {
259                for (key, val) in obj {
260                    self.validate_input(key)?;
261                    self.validate_json_value(val)?;
262                }
263            }
264            _ => {} // Numbers, booleans, null are safe
265        }
266        Ok(())
267    }
268
269    /// Validate HTTP headers
270    pub fn validate_headers(&self, headers: &HeaderMap) -> Result<()> {
271        if !self.config.enabled {
272            return Ok(());
273        }
274
275        // Check User-Agent header
276        if let Some(user_agent) = headers.get(header::USER_AGENT) {
277            if let Ok(ua_str) = user_agent.to_str() {
278                self.validate_input(ua_str)?;
279
280                // Check for suspicious user agents
281                let suspicious_patterns = vec![
282                    r"(?i)(sqlmap|nmap|nikto|dirb|gobuster)",
283                    r"(?i)(masscan|zap|burp|wget|curl)",
284                    r"(?i)(python-requests|libwww-perl)",
285                ];
286
287                for pattern_str in suspicious_patterns {
288                    let pattern = Regex::new(pattern_str).unwrap();
289                    if pattern.is_match(ua_str) {
290                        warn!("Suspicious user agent detected: {}", ua_str);
291                        return Err(SecurityError::ValidationError {
292                            message: "Suspicious user agent".to_string(),
293                        });
294                    }
295                }
296            }
297        }
298
299        // Check Referer header for common attacks
300        if let Some(referer) = headers.get(header::REFERER) {
301            if let Ok(referer_str) = referer.to_str() {
302                self.validate_input(referer_str)?;
303            }
304        }
305
306        // Check custom headers
307        for (_name, value) in headers {
308            if let Ok(value_str) = value.to_str() {
309                self.validate_input(value_str)?;
310            }
311        }
312
313        Ok(())
314    }
315
316    /// Check if content type is allowed
317    pub fn validate_content_type(&self, content_type: Option<&str>) -> Result<()> {
318        if !self.config.enabled {
319            return Ok(());
320        }
321
322        let allowed_types = [
323            "application/json",
324            "application/x-www-form-urlencoded",
325            "text/plain",
326            "multipart/form-data",
327        ];
328
329        if let Some(ct) = content_type {
330            let ct_main = ct.split(';').next().unwrap_or(ct).trim();
331
332            if !allowed_types.contains(&ct_main) {
333                return Err(SecurityError::ValidationError {
334                    message: format!("Content type not allowed: {ct_main}"),
335                });
336            }
337        }
338
339        Ok(())
340    }
341
    /// Returns `true` when validation is switched on in the configuration.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }
345
    /// Returns the configured maximum request size.
    ///
    /// Within this file the value is compared against the JSON payload length
    /// and the `Content-Length` header, i.e. it is treated as a byte count.
    pub fn get_max_request_size(&self) -> u64 {
        self.config.max_request_size
    }
349}
350
/// Request validation data for structured validation.
///
/// Every field is optional; the constraints below are declared through the
/// `validator` derive and are checked by calling `validate()` on the value.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct ValidatedRequest {
    /// Free-form content, constrained to a length of 1 to 1000.
    #[validate(length(
        min = 1,
        max = 1000,
        message = "Content must be between 1 and 1000 characters"
    ))]
    pub content: Option<String>,

    /// E-mail address, checked for e-mail syntax.
    #[validate(email(message = "Invalid email format"))]
    pub email: Option<String>,

    /// URL, checked for URL syntax.
    #[validate(url(message = "Invalid URL format"))]
    pub url: Option<String>,

    /// Result limit, constrained to the range 1 to 1000.
    #[validate(range(min = 1, max = 1000, message = "Limit must be between 1 and 1000"))]
    pub limit: Option<i32>,

    /// Result offset, constrained to be zero or greater.
    #[validate(range(min = 0, message = "Offset must be non-negative"))]
    pub offset: Option<i32>,

    /// Query text, constrained to a length of 1 to 1000. Only the length is
    /// checked here; no content screening is attached to this field.
    #[validate(length(min = 1, max = 1000))]
    pub query: Option<String>,
}
376
377/// Custom validator for safe strings
378#[allow(dead_code)]
379fn validate_safe_string(value: &str) -> std::result::Result<(), ValidationError> {
380    // Check for dangerous characters
381    if value.contains("<script") || value.contains("javascript:") || value.contains("../../") {
382        return Err(ValidationError::new("unsafe_content"));
383    }
384
385    Ok(())
386}
387
388/// Validation middleware for Axum
389pub async fn validation_middleware(
390    State(validator): State<Arc<ValidationManager>>,
391    headers: HeaderMap,
392    request: Request,
393    next: Next,
394) -> std::result::Result<Response, StatusCode> {
395    if !validator.is_enabled() {
396        return Ok(next.run(request).await);
397    }
398
399    // Validate headers
400    if validator.validate_headers(&headers).is_err() {
401        warn!("Request validation failed: invalid headers");
402        return Err(StatusCode::BAD_REQUEST);
403    }
404
405    // Validate content type
406    let content_type = headers
407        .get(header::CONTENT_TYPE)
408        .and_then(|ct| ct.to_str().ok());
409
410    if validator.validate_content_type(content_type).is_err() {
411        warn!("Request validation failed: invalid content type");
412        return Err(StatusCode::UNSUPPORTED_MEDIA_TYPE);
413    }
414
415    // Check request size
416    if let Some(content_length) = headers.get(header::CONTENT_LENGTH) {
417        if let Ok(length_str) = content_length.to_str() {
418            if let Ok(length) = length_str.parse::<u64>() {
419                if length > validator.get_max_request_size() {
420                    warn!(
421                        "Request validation failed: request too large ({} bytes)",
422                        length
423                    );
424                    return Err(StatusCode::PAYLOAD_TOO_LARGE);
425                }
426            }
427        }
428    }
429
430    debug!("Request validation passed");
431    Ok(next.run(request).await)
432}
433
// Unit tests for `ValidationManager`, `ValidatedRequest` and
// `validate_safe_string`.
#[cfg(test)]
mod tests {
    use super::*;

    // The default configuration is expected to have validation enabled.
    #[test]
    fn test_validation_manager_creation() {
        let config = ValidationConfig::default();
        let manager = ValidationManager::new(config).unwrap();
        assert!(manager.is_enabled());
    }

    // A plain SELECT matches none of the signatures; a classic
    // quote-and-DROP payload must be rejected as SQL injection.
    #[test]
    fn test_sql_injection_detection() {
        let mut config = ValidationConfig::default();
        config.sql_injection_protection = true;

        let manager = ValidationManager::new(config).unwrap();

        // Test valid input
        let result = manager.validate_input("SELECT name FROM users WHERE id = 1");
        assert!(result.is_ok());

        // Test SQL injection attempt
        let result = manager.validate_input("'; DROP TABLE users; --");
        assert!(result.is_err());

        if let Err(SecurityError::ValidationError { message }) = result {
            assert!(message.contains("SQL injection"));
        }
    }

    // NOTE: `validate_input` skips the XSS check when the SKIP_XSS_CHECK
    // environment variable is "true", so this test depends on it being unset.
    #[test]
    fn test_xss_detection() {
        let mut config = ValidationConfig::default();
        config.xss_protection = true;

        let manager = ValidationManager::new(config).unwrap();

        // Test valid input
        let result = manager.validate_input("Hello world!");
        assert!(result.is_ok());

        // Test XSS attempt
        let result = manager.validate_input("<script>alert('xss')</script>");
        assert!(result.is_err());

        if let Err(SecurityError::ValidationError { message }) = result {
            assert!(message.contains("XSS"));
        }
    }

    // Input that passes every signature check is HTML-escaped when both
    // sanitization and XSS protection are enabled.
    #[test]
    fn test_input_sanitization() {
        let mut config = ValidationConfig::default();
        config.sanitize_input = true;
        config.xss_protection = true;

        let manager = ValidationManager::new(config).unwrap();

        let result = manager
            .validate_input("Hello <world> & 'test' \"quote\"")
            .unwrap();
        assert_eq!(
            result,
            "Hello &lt;world&gt; &amp; &#x27;test&#x27; &quot;quote&quot;"
        );
    }

    // Exercises the always-on malicious-content signatures.
    #[test]
    fn test_malicious_pattern_detection() {
        let config = ValidationConfig::default();
        let manager = ValidationManager::new(config).unwrap();

        // Test path traversal
        let result = manager.validate_input("../../../etc/passwd");
        assert!(result.is_err());

        // Test command injection
        let result = manager.validate_input("test; rm -rf /");
        assert!(result.is_ok()); // This specific pattern isn't in our malicious patterns

        // Test directory traversal
        let result = manager.validate_input("../../etc/shadow");
        assert!(result.is_err());
    }

    // Well-formed JSON parses; syntactically broken JSON is reported as an error.
    #[test]
    fn test_json_validation() {
        let config = ValidationConfig::default();
        let manager = ValidationManager::new(config).unwrap();

        // Valid JSON
        let json = r#"{"name": "test", "value": 123}"#;
        let result = manager.validate_json(json);
        assert!(result.is_ok());

        // Invalid JSON
        let invalid_json = r#"{"name": "test", "value":}"#;
        let result = manager.validate_json(invalid_json);
        assert!(result.is_err());
    }

    // String values inside a JSON document are screened like plain input.
    #[test]
    fn test_json_content_validation() {
        let mut config = ValidationConfig::default();
        config.xss_protection = true;

        let manager = ValidationManager::new(config).unwrap();

        // JSON with XSS content
        let json = r#"{"comment": "<script>alert('xss')</script>"}"#;
        let result = manager.validate_json(json);
        assert!(result.is_err());
    }

    // Allow-listed media types pass, with or without parameters; others fail.
    #[test]
    fn test_content_type_validation() {
        let config = ValidationConfig::default();
        let manager = ValidationManager::new(config).unwrap();

        // Allowed content type
        let result = manager.validate_content_type(Some("application/json"));
        assert!(result.is_ok());

        // Not allowed content type
        let result = manager.validate_content_type(Some("application/x-executable"));
        assert!(result.is_err());

        // Content type with charset
        let result = manager.validate_content_type(Some("application/json; charset=utf-8"));
        assert!(result.is_ok());
    }

    // A disabled manager accepts everything unchanged.
    #[test]
    fn test_validation_disabled() {
        let mut config = ValidationConfig::default();
        config.enabled = false;

        let manager = ValidationManager::new(config).unwrap();
        assert!(!manager.is_enabled());

        // Should pass even with malicious content when disabled
        let result = manager.validate_input("<script>alert('xss')</script>");
        assert!(result.is_ok());
    }

    // Every field within its declared constraints validates successfully.
    #[test]
    fn test_validated_request_struct() {
        let request = ValidatedRequest {
            content: Some("Hello world".to_string()),
            email: Some("test@example.com".to_string()),
            url: Some("https://example.com".to_string()),
            limit: Some(100),
            offset: Some(0),
            query: Some("safe query".to_string()),
        };

        let validation_result = request.validate();
        assert!(validation_result.is_ok());
    }

    // Several fields violate their constraints, so validation must fail.
    #[test]
    fn test_validated_request_invalid() {
        let request = ValidatedRequest {
            content: Some("".to_string()),                            // Too short
            email: Some("invalid-email".to_string()),                 // Invalid email
            url: Some("not-a-url".to_string()),                       // Invalid URL
            limit: Some(2000),                                        // Too large
            offset: Some(-1),                                         // Negative
            query: Some("<script>alert('xss')</script>".to_string()), // Only length-checked by the struct
        };

        let validation_result = request.validate();
        assert!(validation_result.is_err());

        let errors = validation_result.unwrap_err();
        assert!(!errors.field_errors().is_empty());
    }

    // Direct checks of the stand-alone `validate_safe_string` helper.
    #[test]
    fn test_custom_validator() {
        let valid_result = validate_safe_string("This is a safe string");
        assert!(valid_result.is_ok());

        let invalid_result = validate_safe_string("<script>alert('test')</script>");
        assert!(invalid_result.is_err());

        let traversal_result = validate_safe_string("../../etc/passwd");
        assert!(traversal_result.is_err());
    }

    // A JSON payload larger than `max_request_size` is rejected before parsing.
    #[test]
    fn test_request_size_limits() {
        let config = ValidationConfig {
            enabled: true,
            max_request_size: 1024, // 1KB limit
            sanitize_input: true,
            xss_protection: true,
            sql_injection_protection: true,
        };

        let manager = ValidationManager::new(config).unwrap();
        assert_eq!(manager.get_max_request_size(), 1024);

        // Large JSON should fail
        let large_json = "x".repeat(2000);
        let json = format!(r#"{{"data": "{large_json}"}}"#);
        let result = manager.validate_json(&json);
        assert!(result.is_err());

        if let Err(SecurityError::ValidationError { message }) = result {
            assert!(message.contains("exceeds maximum"));
        }
    }
}
648}