// codex_memory/security/validation.rs

1use crate::security::{Result, SecurityError, ValidationConfig};
2use axum::{
3    extract::{Request, State},
4    http::{header, HeaderMap, StatusCode},
5    middleware::Next,
6    response::Response,
7};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tracing::{debug, warn};
12use validator::{Validate, ValidationError};
13
14/// Input validation and sanitization manager
15pub struct ValidationManager {
16    config: ValidationConfig,
17    sql_injection_patterns: Vec<Regex>,
18    xss_patterns: Vec<Regex>,
19    malicious_patterns: Vec<Regex>,
20}
21
22impl ValidationManager {
23    pub fn new(config: ValidationConfig) -> Result<Self> {
24        let mut manager = Self {
25            config,
26            sql_injection_patterns: Vec::new(),
27            xss_patterns: Vec::new(),
28            malicious_patterns: Vec::new(),
29        };
30
31        if manager.config.enabled {
32            manager.initialize_patterns()?;
33        }
34
35        Ok(manager)
36    }
37
38    fn initialize_patterns(&mut self) -> Result<()> {
39        // SQL injection patterns
40        if self.config.sql_injection_protection {
41            let sql_patterns = vec![
42                r"(?i)(union\s+select)",
43                r"(?i)(drop\s+table)",
44                r"(?i)(delete\s+from)",
45                r"(?i)(insert\s+into)",
46                r"(?i)(update\s+set)",
47                r"(?i)(alter\s+table)",
48                r"(?i)(create\s+table)",
49                r"(?i)(exec\s*\()",
50                r"(?i)(execute\s*\()",
51                r"(?i)(\'\s*or\s*\'\s*=\s*\')",
52                r"(?i)(\'\s*or\s*1\s*=\s*1)",
53                r"(?i)(\'\s*;\s*drop)",
54                r"(?i)(--\s*)",
55                r"(?i)(/\*.*\*/)",
56                r"(?i)(xp_cmdshell)",
57                r"(?i)(sp_executesql)",
58            ];
59
60            for pattern in sql_patterns {
61                self.sql_injection_patterns
62                    .push(
63                        Regex::new(pattern).map_err(|e| SecurityError::ValidationError {
64                            message: format!("Failed to compile SQL injection pattern: {e}"),
65                        })?,
66                    );
67            }
68        }
69
70        // XSS patterns
71        if self.config.xss_protection {
72            let xss_patterns = vec![
73                r"(?i)<script[^>]*>",
74                r"(?i)</script>",
75                r"(?i)<iframe[^>]*>",
76                r"(?i)<object[^>]*>",
77                r"(?i)<embed[^>]*>",
78                r"(?i)<link[^>]*>",
79                r"(?i)<meta[^>]*>",
80                r"(?i)javascript:",
81                r"(?i)vbscript:",
82                r"(?i)onload\s*=",
83                r"(?i)onerror\s*=",
84                r"(?i)onclick\s*=",
85                r"(?i)onmouseover\s*=",
86                r"(?i)onfocus\s*=",
87                r"(?i)onblur\s*=",
88                r"(?i)onchange\s*=",
89                r"(?i)onsubmit\s*=",
90                r"(?i)expression\s*\(",
91                r"(?i)url\s*\(",
92                r"(?i)@import",
93            ];
94
95            for pattern in xss_patterns {
96                self.xss_patterns.push(Regex::new(pattern).map_err(|e| {
97                    SecurityError::ValidationError {
98                        message: format!("Failed to compile XSS pattern: {e}"),
99                    }
100                })?);
101            }
102        }
103
104        // General malicious patterns
105        let malicious_patterns = vec![
106            r"(?i)(\.\.\/){2,}", // Path traversal
107            r"(?i)\.\.\\",       // Windows path traversal
108            r"(?i)\/etc\/passwd",
109            r"(?i)\/etc\/shadow",
110            r"(?i)\/proc\/",
111            r"(?i)c:\\windows\\",
112            r"(?i)cmd\.exe",
113            r"(?i)powershell\.exe",
114            r"(?i)bash\s*-c",
115            r"(?i)sh\s*-c",
116            r"(?i)\$\([^)]*\)", // Command substitution
117            r"(?i)`[^`]*`",     // Backtick command execution
118        ];
119
120        for pattern in malicious_patterns {
121            self.malicious_patterns
122                .push(
123                    Regex::new(pattern).map_err(|e| SecurityError::ValidationError {
124                        message: format!("Failed to compile malicious pattern: {e}"),
125                    })?,
126                );
127        }
128
129        debug!(
130            "Initialized validation patterns: {} SQL, {} XSS, {} malicious",
131            self.sql_injection_patterns.len(),
132            self.xss_patterns.len(),
133            self.malicious_patterns.len()
134        );
135
136        Ok(())
137    }
138
139    /// Validate and sanitize input string
140    pub fn validate_input(&self, input: &str) -> Result<String> {
141        if !self.config.enabled {
142            return Ok(input.to_string());
143        }
144
145        // Check for SQL injection
146        if self.config.sql_injection_protection {
147            for pattern in &self.sql_injection_patterns {
148                if pattern.is_match(input) {
149                    warn!("SQL injection attempt detected: {}", pattern.as_str());
150                    return Err(SecurityError::ValidationError {
151                        message: "Potential SQL injection detected".to_string(),
152                    });
153                }
154            }
155        }
156
157        // Check for XSS
158        if self.config.xss_protection {
159            for pattern in &self.xss_patterns {
160                if pattern.is_match(input) {
161                    warn!("XSS attempt detected: {}", pattern.as_str());
162                    return Err(SecurityError::ValidationError {
163                        message: "Potential XSS detected".to_string(),
164                    });
165                }
166            }
167        }
168
169        // Check for general malicious patterns
170        for pattern in &self.malicious_patterns {
171            if pattern.is_match(input) {
172                warn!("Malicious pattern detected: {}", pattern.as_str());
173                return Err(SecurityError::ValidationError {
174                    message: "Malicious content detected".to_string(),
175                });
176            }
177        }
178
179        // Sanitize if enabled
180        if self.config.sanitize_input {
181            Ok(self.sanitize_input(input))
182        } else {
183            Ok(input.to_string())
184        }
185    }
186
187    /// Sanitize input by removing or escaping dangerous characters
188    fn sanitize_input(&self, input: &str) -> String {
189        let mut sanitized = input.to_string();
190
191        // Remove null bytes
192        sanitized = sanitized.replace('\0', "");
193
194        // Remove or escape common dangerous characters
195        sanitized = sanitized.replace('\r', "");
196        sanitized = sanitized.replace('\n', " ");
197        sanitized = sanitized.replace('\t', " ");
198
199        // Escape HTML entities if XSS protection is enabled
200        if self.config.xss_protection {
201            sanitized = sanitized.replace('<', "&lt;");
202            sanitized = sanitized.replace('>', "&gt;");
203            sanitized = sanitized.replace('"', "&quot;");
204            sanitized = sanitized.replace('\'', "&#x27;");
205            sanitized = sanitized.replace('&', "&amp;");
206        }
207
208        // Limit length to prevent buffer overflow attacks
209        if sanitized.len() > 10000 {
210            sanitized.truncate(10000);
211            sanitized.push_str("...");
212        }
213
214        sanitized
215    }
216
217    /// Validate JSON payload
218    pub fn validate_json(&self, json_str: &str) -> Result<serde_json::Value> {
219        if !self.config.enabled {
220            return serde_json::from_str(json_str).map_err(|e| SecurityError::ValidationError {
221                message: format!("Invalid JSON: {e}"),
222            });
223        }
224
225        // Check JSON size
226        if json_str.len() > self.config.max_request_size as usize {
227            return Err(SecurityError::ValidationError {
228                message: "Request size exceeds maximum allowed".to_string(),
229            });
230        }
231
232        // Parse JSON
233        let json_value: serde_json::Value =
234            serde_json::from_str(json_str).map_err(|e| SecurityError::ValidationError {
235                message: format!("Invalid JSON: {e}"),
236            })?;
237
238        // Validate JSON content recursively
239        self.validate_json_value(&json_value)?;
240
241        Ok(json_value)
242    }
243
244    fn validate_json_value(&self, value: &serde_json::Value) -> Result<()> {
245        match value {
246            serde_json::Value::String(s) => {
247                self.validate_input(s)?;
248            }
249            serde_json::Value::Array(arr) => {
250                for item in arr {
251                    self.validate_json_value(item)?;
252                }
253            }
254            serde_json::Value::Object(obj) => {
255                for (key, val) in obj {
256                    self.validate_input(key)?;
257                    self.validate_json_value(val)?;
258                }
259            }
260            _ => {} // Numbers, booleans, null are safe
261        }
262        Ok(())
263    }
264
265    /// Validate HTTP headers
266    pub fn validate_headers(&self, headers: &HeaderMap) -> Result<()> {
267        if !self.config.enabled {
268            return Ok(());
269        }
270
271        // Check User-Agent header
272        if let Some(user_agent) = headers.get(header::USER_AGENT) {
273            if let Ok(ua_str) = user_agent.to_str() {
274                self.validate_input(ua_str)?;
275
276                // Check for suspicious user agents
277                let suspicious_patterns = vec![
278                    r"(?i)(sqlmap|nmap|nikto|dirb|gobuster)",
279                    r"(?i)(masscan|zap|burp|wget|curl)",
280                    r"(?i)(python-requests|libwww-perl)",
281                ];
282
283                for pattern_str in suspicious_patterns {
284                    let pattern = Regex::new(pattern_str).unwrap();
285                    if pattern.is_match(ua_str) {
286                        warn!("Suspicious user agent detected: {}", ua_str);
287                        return Err(SecurityError::ValidationError {
288                            message: "Suspicious user agent".to_string(),
289                        });
290                    }
291                }
292            }
293        }
294
295        // Check Referer header for common attacks
296        if let Some(referer) = headers.get(header::REFERER) {
297            if let Ok(referer_str) = referer.to_str() {
298                self.validate_input(referer_str)?;
299            }
300        }
301
302        // Check custom headers
303        for (_name, value) in headers {
304            if let Ok(value_str) = value.to_str() {
305                self.validate_input(value_str)?;
306            }
307        }
308
309        Ok(())
310    }
311
312    /// Check if content type is allowed
313    pub fn validate_content_type(&self, content_type: Option<&str>) -> Result<()> {
314        if !self.config.enabled {
315            return Ok(());
316        }
317
318        let allowed_types = [
319            "application/json",
320            "application/x-www-form-urlencoded",
321            "text/plain",
322            "multipart/form-data",
323        ];
324
325        if let Some(ct) = content_type {
326            let ct_main = ct.split(';').next().unwrap_or(ct).trim();
327
328            if !allowed_types.contains(&ct_main) {
329                return Err(SecurityError::ValidationError {
330                    message: format!("Content type not allowed: {ct_main}"),
331                });
332            }
333        }
334
335        Ok(())
336    }
337
338    pub fn is_enabled(&self) -> bool {
339        self.config.enabled
340    }
341
342    pub fn get_max_request_size(&self) -> u64 {
343        self.config.max_request_size
344    }
345}
346
347/// Request validation data for structured validation
348#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
349pub struct ValidatedRequest {
350    #[validate(length(
351        min = 1,
352        max = 1000,
353        message = "Content must be between 1 and 1000 characters"
354    ))]
355    pub content: Option<String>,
356
357    #[validate(email(message = "Invalid email format"))]
358    pub email: Option<String>,
359
360    #[validate(url(message = "Invalid URL format"))]
361    pub url: Option<String>,
362
363    #[validate(range(min = 1, max = 1000, message = "Limit must be between 1 and 1000"))]
364    pub limit: Option<i32>,
365
366    #[validate(range(min = 0, message = "Offset must be non-negative"))]
367    pub offset: Option<i32>,
368
369    #[validate(length(min = 1, max = 1000))]
370    pub query: Option<String>,
371}
372
373/// Custom validator for safe strings
374#[allow(dead_code)]
375fn validate_safe_string(value: &str) -> std::result::Result<(), ValidationError> {
376    // Check for dangerous characters
377    if value.contains("<script") || value.contains("javascript:") || value.contains("../../") {
378        return Err(ValidationError::new("unsafe_content"));
379    }
380
381    Ok(())
382}
383
384/// Validation middleware for Axum
385pub async fn validation_middleware(
386    State(validator): State<Arc<ValidationManager>>,
387    headers: HeaderMap,
388    request: Request,
389    next: Next,
390) -> std::result::Result<Response, StatusCode> {
391    if !validator.is_enabled() {
392        return Ok(next.run(request).await);
393    }
394
395    // Validate headers
396    if validator.validate_headers(&headers).is_err() {
397        warn!("Request validation failed: invalid headers");
398        return Err(StatusCode::BAD_REQUEST);
399    }
400
401    // Validate content type
402    let content_type = headers
403        .get(header::CONTENT_TYPE)
404        .and_then(|ct| ct.to_str().ok());
405
406    if validator.validate_content_type(content_type).is_err() {
407        warn!("Request validation failed: invalid content type");
408        return Err(StatusCode::UNSUPPORTED_MEDIA_TYPE);
409    }
410
411    // Check request size
412    if let Some(content_length) = headers.get(header::CONTENT_LENGTH) {
413        if let Ok(length_str) = content_length.to_str() {
414            if let Ok(length) = length_str.parse::<u64>() {
415                if length > validator.get_max_request_size() {
416                    warn!(
417                        "Request validation failed: request too large ({} bytes)",
418                        length
419                    );
420                    return Err(StatusCode::PAYLOAD_TOO_LARGE);
421                }
422            }
423        }
424    }
425
426    debug!("Request validation passed");
427    Ok(next.run(request).await)
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager from `config`, panicking if construction fails.
    fn manager_for(config: ValidationConfig) -> ValidationManager {
        ValidationManager::new(config).unwrap()
    }

    /// Builds a manager from the default configuration.
    fn default_manager() -> ValidationManager {
        manager_for(ValidationConfig::default())
    }

    #[test]
    fn test_validation_manager_creation() {
        assert!(default_manager().is_enabled());
    }

    #[test]
    fn test_sql_injection_detection() {
        let manager = manager_for(ValidationConfig {
            sql_injection_protection: true,
            ..ValidationConfig::default()
        });

        // A plain query is not flagged.
        let benign = manager.validate_input("SELECT name FROM users WHERE id = 1");
        assert!(benign.is_ok());

        // A classic injection payload is rejected.
        let attack = manager.validate_input("'; DROP TABLE users; --");
        assert!(attack.is_err());

        if let Err(SecurityError::ValidationError { message }) = attack {
            assert!(message.contains("SQL injection"));
        }
    }

    #[test]
    fn test_xss_detection() {
        let manager = manager_for(ValidationConfig {
            xss_protection: true,
            ..ValidationConfig::default()
        });

        // Ordinary text passes.
        let benign = manager.validate_input("Hello world!");
        assert!(benign.is_ok());

        // A script tag is rejected.
        let attack = manager.validate_input("<script>alert('xss')</script>");
        assert!(attack.is_err());

        if let Err(SecurityError::ValidationError { message }) = attack {
            assert!(message.contains("XSS"));
        }
    }

    #[test]
    fn test_input_sanitization() {
        let manager = manager_for(ValidationConfig {
            sanitize_input: true,
            xss_protection: true,
            ..ValidationConfig::default()
        });

        let sanitized = manager
            .validate_input("Hello <world> & 'test' \"quote\"")
            .unwrap();
        assert_eq!(
            sanitized,
            "Hello &lt;world&gt; &amp; &#x27;test&#x27; &quot;quote&quot;"
        );
    }

    #[test]
    fn test_malicious_pattern_detection() {
        let manager = default_manager();

        // Path traversal is rejected.
        let traversal = manager.validate_input("../../../etc/passwd");
        assert!(traversal.is_err());

        // This specific command-injection form isn't in our malicious patterns.
        let chained_command = manager.validate_input("test; rm -rf /");
        assert!(chained_command.is_ok());

        // Directory traversal towards a sensitive file is rejected.
        let shadow = manager.validate_input("../../etc/shadow");
        assert!(shadow.is_err());
    }

    #[test]
    fn test_json_validation() {
        let manager = default_manager();

        // Well-formed JSON is accepted.
        let well_formed = r#"{"name": "test", "value": 123}"#;
        assert!(manager.validate_json(well_formed).is_ok());

        // Malformed JSON is rejected.
        let malformed = r#"{"name": "test", "value":}"#;
        assert!(manager.validate_json(malformed).is_err());
    }

    #[test]
    fn test_json_content_validation() {
        let manager = manager_for(ValidationConfig {
            xss_protection: true,
            ..ValidationConfig::default()
        });

        // JSON carrying XSS content in a string value is rejected.
        let payload = r#"{"comment": "<script>alert('xss')</script>"}"#;
        assert!(manager.validate_json(payload).is_err());
    }

    #[test]
    fn test_content_type_validation() {
        let manager = default_manager();

        // Allowed content type
        let allowed = manager.validate_content_type(Some("application/json"));
        assert!(allowed.is_ok());

        // Not allowed content type
        let refused = manager.validate_content_type(Some("application/x-executable"));
        assert!(refused.is_err());

        // Content type with charset
        let with_charset = manager.validate_content_type(Some("application/json; charset=utf-8"));
        assert!(with_charset.is_ok());
    }

    #[test]
    fn test_validation_disabled() {
        let manager = manager_for(ValidationConfig {
            enabled: false,
            ..ValidationConfig::default()
        });
        assert!(!manager.is_enabled());

        // Should pass even with malicious content when disabled
        let outcome = manager.validate_input("<script>alert('xss')</script>");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validated_request_struct() {
        let request = ValidatedRequest {
            content: Some("Hello world".to_string()),
            email: Some("test@example.com".to_string()),
            url: Some("https://example.com".to_string()),
            limit: Some(100),
            offset: Some(0),
            query: Some("safe query".to_string()),
        };

        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_validated_request_invalid() {
        let request = ValidatedRequest {
            // Too short
            content: Some("".to_string()),
            // Invalid email
            email: Some("invalid-email".to_string()),
            // Invalid URL
            url: Some("not-a-url".to_string()),
            // Too large
            limit: Some(2000),
            // Negative
            offset: Some(-1),
            // Unsafe content
            query: Some("<script>alert('xss')</script>".to_string()),
        };

        let outcome = request.validate();
        assert!(outcome.is_err());

        let errors = outcome.unwrap_err();
        assert!(!errors.field_errors().is_empty());
    }

    #[test]
    fn test_custom_validator() {
        assert!(validate_safe_string("This is a safe string").is_ok());
        assert!(validate_safe_string("<script>alert('test')</script>").is_err());
        assert!(validate_safe_string("../../etc/passwd").is_err());
    }

    #[test]
    fn test_request_size_limits() {
        let manager = manager_for(ValidationConfig {
            enabled: true,
            max_request_size: 1024, // 1KB limit
            sanitize_input: true,
            xss_protection: true,
            sql_injection_protection: true,
        });
        assert_eq!(manager.get_max_request_size(), 1024);

        // Large JSON should fail
        let filler = "x".repeat(2000);
        let oversized = format!(r#"{{"data": "{filler}"}}"#);
        let outcome = manager.validate_json(&oversized);
        assert!(outcome.is_err());

        if let Err(SecurityError::ValidationError { message }) = outcome {
            assert!(message.contains("exceeds maximum"));
        }
    }
}