llm_config_security/
input.rs

1//! Input validation and sanitization
2
3use crate::errors::{SecurityError, SecurityResult};
4use regex::Regex;
5use std::sync::OnceLock;
6
7/// SQL injection patterns
8static SQL_INJECTION_PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
9
10/// XSS patterns
11static XSS_PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
12
13/// Path traversal patterns
14static PATH_TRAVERSAL_PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
15
16/// Command injection patterns
17static COMMAND_INJECTION_PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
18
19/// LDAP injection patterns
20static LDAP_INJECTION_PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
21
22/// Initialize security patterns
23fn init_patterns() {
24    SQL_INJECTION_PATTERNS.get_or_init(|| {
25        vec![
26            Regex::new(r"(?i)(\bunion\b.*\bselect\b)").unwrap(),
27            Regex::new(r"(?i)(\bdrop\b.*\btable\b)").unwrap(),
28            Regex::new(r"(?i)(\binsert\b.*\binto\b)").unwrap(),
29            Regex::new(r"(?i)(\bdelete\b.*\bfrom\b)").unwrap(),
30            Regex::new(r"(?i)(\bupdate\b.*\bset\b)").unwrap(),
31            Regex::new(r"(?i)(;.*(--)|(#))").unwrap(),
32            Regex::new(r"(?i)('|(--)|;|/\*|\*/|@@|@)").unwrap(),
33            Regex::new(r"(?i)\bexec(\s|\+)+(s|x)p\w+").unwrap(),
34        ]
35    });
36
37    XSS_PATTERNS.get_or_init(|| {
38        vec![
39            Regex::new(r"(?i)<script[^>]*>.*?</script>").unwrap(),
40            Regex::new(r"(?i)javascript:").unwrap(),
41            Regex::new(r"(?i)on\w+\s*=").unwrap(),
42            Regex::new(r"(?i)<iframe").unwrap(),
43            Regex::new(r"(?i)<embed").unwrap(),
44            Regex::new(r"(?i)<object").unwrap(),
45            Regex::new(r"(?i)eval\(").unwrap(),
46            Regex::new(r"(?i)expression\(").unwrap(),
47        ]
48    });
49
50    PATH_TRAVERSAL_PATTERNS.get_or_init(|| {
51        vec![
52            Regex::new(r"\.\./").unwrap(),
53            Regex::new(r"\.\./").unwrap(),
54            Regex::new(r"%2e%2e/").unwrap(),
55            Regex::new(r"%2e%2e\\").unwrap(),
56            Regex::new(r"\.\.\\").unwrap(),
57        ]
58    });
59
60    COMMAND_INJECTION_PATTERNS.get_or_init(|| {
61        vec![
62            Regex::new(r"[;&|`$\n]").unwrap(),
63            Regex::new(r"\$\(.*\)").unwrap(),
64            Regex::new(r"`.*`").unwrap(),
65        ]
66    });
67
68    LDAP_INJECTION_PATTERNS.get_or_init(|| {
69        vec![
70            Regex::new(r"[*()\\]").unwrap(),
71            Regex::new(r"\x00").unwrap(),
72        ]
73    });
74}
75
/// Tunable limits and switches applied when input is validated and sanitized.
#[derive(Debug, Clone)]
pub struct SanitizationConfig {
    /// Upper bound on the accepted input length.
    pub max_length: usize,
    /// Whether special characters are permitted.
    pub allow_special_chars: bool,
    /// Whether HTML is permitted in the input.
    pub allow_html: bool,
    /// Whether surrounding whitespace is trimmed.
    pub trim_whitespace: bool,
}

impl Default for SanitizationConfig {
    /// Restrictive defaults: 1000-unit length cap, no special characters,
    /// no HTML, and whitespace trimming enabled.
    fn default() -> Self {
        SanitizationConfig {
            trim_whitespace: true,
            allow_html: false,
            allow_special_chars: false,
            max_length: 1000,
        }
    }
}
99
100/// Input validator for security
101pub struct InputValidator {
102    config: SanitizationConfig,
103}
104
105impl InputValidator {
106    /// Create a new input validator
107    pub fn new(config: SanitizationConfig) -> Self {
108        init_patterns();
109        Self { config }
110    }
111
112    /// Create with default configuration
113    pub fn default() -> Self {
114        Self::new(SanitizationConfig::default())
115    }
116
117    /// Validate and sanitize input
118    pub fn validate(&self, input: &str) -> SecurityResult<String> {
119        // Check length
120        if input.len() > self.config.max_length {
121            return Err(SecurityError::ValidationError(format!(
122                "Input exceeds maximum length of {} characters",
123                self.config.max_length
124            )));
125        }
126
127        // Detect SQL injection
128        if self.detect_sql_injection(input) {
129            return Err(SecurityError::SqlInjectionAttempt);
130        }
131
132        // Detect XSS
133        if self.detect_xss(input) {
134            return Err(SecurityError::XssAttempt);
135        }
136
137        // Detect path traversal
138        if self.detect_path_traversal(input) {
139            return Err(SecurityError::PathTraversalAttempt);
140        }
141
142        // Detect command injection
143        if self.detect_command_injection(input) {
144            return Err(SecurityError::CommandInjectionAttempt);
145        }
146
147        // Sanitize
148        let sanitized = self.sanitize(input);
149
150        Ok(sanitized)
151    }
152
153    /// Detect SQL injection attempts
154    fn detect_sql_injection(&self, input: &str) -> bool {
155        SQL_INJECTION_PATTERNS
156            .get()
157            .unwrap()
158            .iter()
159            .any(|pattern| pattern.is_match(input))
160    }
161
162    /// Detect XSS attempts
163    fn detect_xss(&self, input: &str) -> bool {
164        if !self.config.allow_html {
165            XSS_PATTERNS
166                .get()
167                .unwrap()
168                .iter()
169                .any(|pattern| pattern.is_match(input))
170        } else {
171            false
172        }
173    }
174
175    /// Detect path traversal attempts
176    fn detect_path_traversal(&self, input: &str) -> bool {
177        PATH_TRAVERSAL_PATTERNS
178            .get()
179            .unwrap()
180            .iter()
181            .any(|pattern| pattern.is_match(input))
182    }
183
184    /// Detect command injection attempts
185    fn detect_command_injection(&self, input: &str) -> bool {
186        COMMAND_INJECTION_PATTERNS
187            .get()
188            .unwrap()
189            .iter()
190            .any(|pattern| pattern.is_match(input))
191    }
192
193    /// Detect LDAP injection attempts
194    fn detect_ldap_injection(&self, input: &str) -> bool {
195        LDAP_INJECTION_PATTERNS
196            .get()
197            .unwrap()
198            .iter()
199            .any(|pattern| pattern.is_match(input))
200    }
201
202    /// Sanitize input
203    fn sanitize(&self, input: &str) -> String {
204        let mut sanitized = input.to_string();
205
206        // Trim whitespace
207        if self.config.trim_whitespace {
208            sanitized = sanitized.trim().to_string();
209        }
210
211        // Remove null bytes
212        sanitized = sanitized.replace('\0', "");
213
214        // HTML encode if not allowing HTML
215        if !self.config.allow_html {
216            sanitized = html_escape(&sanitized);
217        }
218
219        // Remove control characters
220        sanitized = sanitized
221            .chars()
222            .filter(|c| !c.is_control() || *c == '\n' || *c == '\r' || *c == '\t')
223            .collect();
224
225        sanitized
226    }
227
228    /// Validate an email address
229    pub fn validate_email(&self, email: &str) -> SecurityResult<String> {
230        let email_regex = Regex::new(
231            r"^[a-zA-Z0-9.!#$%&'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$"
232        ).unwrap();
233
234        if !email_regex.is_match(email) {
235            return Err(SecurityError::ValidationError(
236                "Invalid email format".to_string(),
237            ));
238        }
239
240        Ok(email.to_lowercase())
241    }
242
243    /// Validate a username
244    pub fn validate_username(&self, username: &str) -> SecurityResult<String> {
245        // Username: alphanumeric, underscore, hyphen, 3-30 characters
246        let username_regex = Regex::new(r"^[a-zA-Z0-9_-]{3,30}$").unwrap();
247
248        if !username_regex.is_match(username) {
249            return Err(SecurityError::ValidationError(
250                "Username must be 3-30 alphanumeric characters, underscore, or hyphen".to_string(),
251            ));
252        }
253
254        Ok(username.to_string())
255    }
256
257    /// Validate a configuration key
258    pub fn validate_config_key(&self, key: &str) -> SecurityResult<String> {
259        // Config key: alphanumeric, underscore, hyphen, dot, slash
260        let key_regex = Regex::new(r"^[a-zA-Z0-9_\-./]{1,200}$").unwrap();
261
262        if !key_regex.is_match(key) {
263            return Err(SecurityError::ValidationError(
264                "Invalid configuration key format".to_string(),
265            ));
266        }
267
268        // Additional checks
269        if self.detect_path_traversal(key) {
270            return Err(SecurityError::PathTraversalAttempt);
271        }
272
273        Ok(key.to_string())
274    }
275
276    /// Validate a URL
277    pub fn validate_url(&self, url: &str) -> SecurityResult<String> {
278        // Basic URL validation
279        let url_regex = Regex::new(r"^https?://[^\s/$.?#].[^\s]*$").unwrap();
280
281        if !url_regex.is_match(url) {
282            return Err(SecurityError::ValidationError(
283                "Invalid URL format".to_string(),
284            ));
285        }
286
287        // Check for suspicious patterns
288        if self.detect_xss(url) {
289            return Err(SecurityError::XssAttempt);
290        }
291
292        Ok(url.to_string())
293    }
294
295    /// Validate JSON input
296    pub fn validate_json(&self, json: &str) -> SecurityResult<serde_json::Value> {
297        // Check for suspicious patterns before parsing
298        if self.detect_xss(json) {
299            return Err(SecurityError::XssAttempt);
300        }
301
302        // Parse JSON
303        serde_json::from_str(json)
304            .map_err(|e| SecurityError::ValidationError(format!("Invalid JSON: {}", e)))
305    }
306}
307
/// Replaces HTML-significant characters with entity references.
///
/// `&`, `<`, `>`, `"`, `'` and `/` become `&amp;`, `&lt;`, `&gt;`, `&quot;`,
/// `&#x27;` and `&#x2F;` respectively; every other character is copied
/// through unchanged.
fn html_escape(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            '/' => escaped.push_str("&#x2F;"),
            other => escaped.push(other),
        }
    }
    escaped
}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a validator built from the default configuration.
    fn default_validator() -> InputValidator {
        InputValidator::default()
    }

    #[test]
    fn test_sql_injection_detection() {
        let validator = default_validator();

        // Classic injection payloads must be reported as SQL injection.
        for payload in ["' OR '1'='1", "'; DROP TABLE users; --"] {
            assert!(matches!(
                validator.validate(payload),
                Err(SecurityError::SqlInjectionAttempt)
            ));
        }

        // Benign text passes.
        assert!(validator.validate("normal text").is_ok());
    }

    #[test]
    fn test_xss_detection() {
        let validator = default_validator();

        // Both payloads also contain a single quote, which the generic SQL
        // token pattern matches, so these assertions require XSS to be
        // reported in preference to SQL injection.
        for payload in ["<script>alert('XSS')</script>", "javascript:alert('XSS')"] {
            assert!(matches!(
                validator.validate(payload),
                Err(SecurityError::XssAttempt)
            ));
        }

        // Benign text passes.
        assert!(validator.validate("normal text").is_ok());
    }

    #[test]
    fn test_path_traversal_detection() {
        let validator = default_validator();

        // Parent-directory hops are rejected.
        assert!(matches!(
            validator.validate("../../etc/passwd"),
            Err(SecurityError::PathTraversalAttempt)
        ));

        // Ordinary relative paths are accepted.
        assert!(validator.validate("normal/path").is_ok());
    }

    #[test]
    fn test_command_injection_detection() {
        let validator = default_validator();

        // The expected classification is command injection, not SQL
        // injection, even though `;` is common to both.
        for payload in ["test; rm -rf /", "$(malicious)"] {
            assert!(matches!(
                validator.validate(payload),
                Err(SecurityError::CommandInjectionAttempt)
            ));
        }

        // Benign text passes.
        assert!(validator.validate("normal text").is_ok());
    }

    #[test]
    fn test_email_validation() {
        let validator = default_validator();

        assert!(validator.validate_email("user@example.com").is_ok());
        for bad in ["invalid.email", "@example.com"] {
            assert!(validator.validate_email(bad).is_err());
        }
    }

    #[test]
    fn test_username_validation() {
        let validator = default_validator();

        for good in ["user123", "user-name_123"] {
            assert!(validator.validate_username(good).is_ok());
        }
        // "ab" is too short; "user@name" contains a disallowed character.
        for bad in ["ab", "user@name"] {
            assert!(validator.validate_username(bad).is_err());
        }
    }

    #[test]
    fn test_config_key_validation() {
        let validator = default_validator();

        for good in ["app/config/key", "app.config.key"] {
            assert!(validator.validate_config_key(good).is_ok());
        }
        assert!(validator.validate_config_key("../etc/passwd").is_err());
    }

    #[test]
    fn test_url_validation() {
        let validator = default_validator();

        for good in ["https://example.com", "http://example.com/path"] {
            assert!(validator.validate_url(good).is_ok());
        }
        for bad in ["invalid-url", "javascript:alert('XSS')"] {
            assert!(validator.validate_url(bad).is_err());
        }
    }

    #[test]
    fn test_json_validation() {
        let validator = default_validator();

        assert!(validator.validate_json(r#"{"key": "value"}"#).is_ok());
        for bad in [
            "invalid json",
            r#"{"xss": "<script>alert('XSS')</script>"}"#,
        ] {
            assert!(validator.validate_json(bad).is_err());
        }
    }

    #[test]
    fn test_sanitization() {
        let validator = default_validator();

        // Surrounding whitespace is trimmed.
        let trimmed = validator.validate("  test input  ").unwrap();
        assert_eq!(trimmed, "test input");

        // An unclosed tag is not flagged as XSS but is HTML-escaped.
        let escaped = validator.validate("test<script>").unwrap();
        assert!(escaped.contains("&lt;script&gt;"));
    }

    #[test]
    fn test_length_validation() {
        let validator = InputValidator::new(SanitizationConfig {
            max_length: 10,
            ..Default::default()
        });

        assert!(validator.validate("short").is_ok());
        assert!(validator.validate("this is way too long").is_err());
    }
}