llm_config_security/
input.rs1use crate::errors::{SecurityError, SecurityResult};
4use regex::Regex;
5use std::sync::OnceLock;
6
7static SQL_INJECTION_PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
9
10static XSS_PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
12
13static PATH_TRAVERSAL_PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
15
16static COMMAND_INJECTION_PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
18
19static LDAP_INJECTION_PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
21
22fn init_patterns() {
24 SQL_INJECTION_PATTERNS.get_or_init(|| {
25 vec![
26 Regex::new(r"(?i)(\bunion\b.*\bselect\b)").unwrap(),
27 Regex::new(r"(?i)(\bdrop\b.*\btable\b)").unwrap(),
28 Regex::new(r"(?i)(\binsert\b.*\binto\b)").unwrap(),
29 Regex::new(r"(?i)(\bdelete\b.*\bfrom\b)").unwrap(),
30 Regex::new(r"(?i)(\bupdate\b.*\bset\b)").unwrap(),
31 Regex::new(r"(?i)(;.*(--)|(#))").unwrap(),
32 Regex::new(r"(?i)('|(--)|;|/\*|\*/|@@|@)").unwrap(),
33 Regex::new(r"(?i)\bexec(\s|\+)+(s|x)p\w+").unwrap(),
34 ]
35 });
36
37 XSS_PATTERNS.get_or_init(|| {
38 vec![
39 Regex::new(r"(?i)<script[^>]*>.*?</script>").unwrap(),
40 Regex::new(r"(?i)javascript:").unwrap(),
41 Regex::new(r"(?i)on\w+\s*=").unwrap(),
42 Regex::new(r"(?i)<iframe").unwrap(),
43 Regex::new(r"(?i)<embed").unwrap(),
44 Regex::new(r"(?i)<object").unwrap(),
45 Regex::new(r"(?i)eval\(").unwrap(),
46 Regex::new(r"(?i)expression\(").unwrap(),
47 ]
48 });
49
50 PATH_TRAVERSAL_PATTERNS.get_or_init(|| {
51 vec![
52 Regex::new(r"\.\./").unwrap(),
53 Regex::new(r"\.\./").unwrap(),
54 Regex::new(r"%2e%2e/").unwrap(),
55 Regex::new(r"%2e%2e\\").unwrap(),
56 Regex::new(r"\.\.\\").unwrap(),
57 ]
58 });
59
60 COMMAND_INJECTION_PATTERNS.get_or_init(|| {
61 vec![
62 Regex::new(r"[;&|`$\n]").unwrap(),
63 Regex::new(r"\$\(.*\)").unwrap(),
64 Regex::new(r"`.*`").unwrap(),
65 ]
66 });
67
68 LDAP_INJECTION_PATTERNS.get_or_init(|| {
69 vec![
70 Regex::new(r"[*()\\]").unwrap(),
71 Regex::new(r"\x00").unwrap(),
72 ]
73 });
74}
75
/// Options that control how [`InputValidator`] checks and cleans its input.
#[derive(Debug, Clone)]
pub struct SanitizationConfig {
    /// Upper bound on the accepted input length; `validate` rejects longer input.
    pub max_length: usize,
    /// NOTE(review): no code visible in this file reads this flag; confirm
    /// whether it is consumed elsewhere or reserved for future use.
    pub allow_special_chars: bool,
    /// When `true`, XSS detection and HTML escaping are both skipped.
    pub allow_html: bool,
    /// When `true`, leading and trailing whitespace is stripped during sanitization.
    pub trim_whitespace: bool,
}

impl Default for SanitizationConfig {
    /// Returns the strict configuration: a 1000-unit length limit, no special
    /// characters, no HTML, and whitespace trimming enabled.
    fn default() -> Self {
        SanitizationConfig {
            trim_whitespace: true,
            allow_html: false,
            allow_special_chars: false,
            max_length: 1000,
        }
    }
}
99
100pub struct InputValidator {
102 config: SanitizationConfig,
103}
104
105impl InputValidator {
106 pub fn new(config: SanitizationConfig) -> Self {
108 init_patterns();
109 Self { config }
110 }
111
112 pub fn default() -> Self {
114 Self::new(SanitizationConfig::default())
115 }
116
117 pub fn validate(&self, input: &str) -> SecurityResult<String> {
119 if input.len() > self.config.max_length {
121 return Err(SecurityError::ValidationError(format!(
122 "Input exceeds maximum length of {} characters",
123 self.config.max_length
124 )));
125 }
126
127 if self.detect_sql_injection(input) {
129 return Err(SecurityError::SqlInjectionAttempt);
130 }
131
132 if self.detect_xss(input) {
134 return Err(SecurityError::XssAttempt);
135 }
136
137 if self.detect_path_traversal(input) {
139 return Err(SecurityError::PathTraversalAttempt);
140 }
141
142 if self.detect_command_injection(input) {
144 return Err(SecurityError::CommandInjectionAttempt);
145 }
146
147 let sanitized = self.sanitize(input);
149
150 Ok(sanitized)
151 }
152
153 fn detect_sql_injection(&self, input: &str) -> bool {
155 SQL_INJECTION_PATTERNS
156 .get()
157 .unwrap()
158 .iter()
159 .any(|pattern| pattern.is_match(input))
160 }
161
162 fn detect_xss(&self, input: &str) -> bool {
164 if !self.config.allow_html {
165 XSS_PATTERNS
166 .get()
167 .unwrap()
168 .iter()
169 .any(|pattern| pattern.is_match(input))
170 } else {
171 false
172 }
173 }
174
175 fn detect_path_traversal(&self, input: &str) -> bool {
177 PATH_TRAVERSAL_PATTERNS
178 .get()
179 .unwrap()
180 .iter()
181 .any(|pattern| pattern.is_match(input))
182 }
183
184 fn detect_command_injection(&self, input: &str) -> bool {
186 COMMAND_INJECTION_PATTERNS
187 .get()
188 .unwrap()
189 .iter()
190 .any(|pattern| pattern.is_match(input))
191 }
192
193 fn detect_ldap_injection(&self, input: &str) -> bool {
195 LDAP_INJECTION_PATTERNS
196 .get()
197 .unwrap()
198 .iter()
199 .any(|pattern| pattern.is_match(input))
200 }
201
202 fn sanitize(&self, input: &str) -> String {
204 let mut sanitized = input.to_string();
205
206 if self.config.trim_whitespace {
208 sanitized = sanitized.trim().to_string();
209 }
210
211 sanitized = sanitized.replace('\0', "");
213
214 if !self.config.allow_html {
216 sanitized = html_escape(&sanitized);
217 }
218
219 sanitized = sanitized
221 .chars()
222 .filter(|c| !c.is_control() || *c == '\n' || *c == '\r' || *c == '\t')
223 .collect();
224
225 sanitized
226 }
227
228 pub fn validate_email(&self, email: &str) -> SecurityResult<String> {
230 let email_regex = Regex::new(
231 r"^[a-zA-Z0-9.!#$%&'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$"
232 ).unwrap();
233
234 if !email_regex.is_match(email) {
235 return Err(SecurityError::ValidationError(
236 "Invalid email format".to_string(),
237 ));
238 }
239
240 Ok(email.to_lowercase())
241 }
242
243 pub fn validate_username(&self, username: &str) -> SecurityResult<String> {
245 let username_regex = Regex::new(r"^[a-zA-Z0-9_-]{3,30}$").unwrap();
247
248 if !username_regex.is_match(username) {
249 return Err(SecurityError::ValidationError(
250 "Username must be 3-30 alphanumeric characters, underscore, or hyphen".to_string(),
251 ));
252 }
253
254 Ok(username.to_string())
255 }
256
257 pub fn validate_config_key(&self, key: &str) -> SecurityResult<String> {
259 let key_regex = Regex::new(r"^[a-zA-Z0-9_\-./]{1,200}$").unwrap();
261
262 if !key_regex.is_match(key) {
263 return Err(SecurityError::ValidationError(
264 "Invalid configuration key format".to_string(),
265 ));
266 }
267
268 if self.detect_path_traversal(key) {
270 return Err(SecurityError::PathTraversalAttempt);
271 }
272
273 Ok(key.to_string())
274 }
275
276 pub fn validate_url(&self, url: &str) -> SecurityResult<String> {
278 let url_regex = Regex::new(r"^https?://[^\s/$.?#].[^\s]*$").unwrap();
280
281 if !url_regex.is_match(url) {
282 return Err(SecurityError::ValidationError(
283 "Invalid URL format".to_string(),
284 ));
285 }
286
287 if self.detect_xss(url) {
289 return Err(SecurityError::XssAttempt);
290 }
291
292 Ok(url.to_string())
293 }
294
295 pub fn validate_json(&self, json: &str) -> SecurityResult<serde_json::Value> {
297 if self.detect_xss(json) {
299 return Err(SecurityError::XssAttempt);
300 }
301
302 serde_json::from_str(json)
304 .map_err(|e| SecurityError::ValidationError(format!("Invalid JSON: {}", e)))
305 }
306}
307
/// Escapes the characters that are significant in HTML.
///
/// Replaces `&`, `<`, `>`, `"`, `'` and `/` with their entity forms and copies
/// every other character unchanged. The input is processed in a single pass,
/// so the `&` that starts an emitted entity is never escaped a second time.
fn html_escape(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            '/' => escaped.push_str("&#x2F;"),
            other => escaped.push(other),
        }
    }
    escaped
}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Quote-based and statement-based SQL payloads are rejected.
    #[test]
    fn test_sql_injection_detection() {
        let validator = InputValidator::default();

        assert!(validator
            .validate("' OR '1'='1")
            .is_err_and(|e| matches!(e, SecurityError::SqlInjectionAttempt)));

        assert!(validator
            .validate("'; DROP TABLE users; --")
            .is_err_and(|e| matches!(e, SecurityError::SqlInjectionAttempt)));

        assert!(validator.validate("normal text").is_ok());
    }

    /// Script tags and `javascript:` URLs are rejected as XSS.
    #[test]
    fn test_xss_detection() {
        let validator = InputValidator::default();

        assert!(validator
            .validate("<script>alert('XSS')</script>")
            .is_err_and(|e| matches!(e, SecurityError::XssAttempt)));

        assert!(validator
            .validate("javascript:alert('XSS')")
            .is_err_and(|e| matches!(e, SecurityError::XssAttempt)));

        assert!(validator.validate("normal text").is_ok());
    }

    /// Parent-directory sequences are rejected; ordinary paths pass.
    #[test]
    fn test_path_traversal_detection() {
        let validator = InputValidator::default();

        assert!(validator
            .validate("../../etc/passwd")
            .is_err_and(|e| matches!(e, SecurityError::PathTraversalAttempt)));

        assert!(validator.validate("normal/path").is_ok());
    }

    /// Shell metacharacters and command substitution are rejected.
    #[test]
    fn test_command_injection_detection() {
        let validator = InputValidator::default();

        assert!(validator
            .validate("test; rm -rf /")
            .is_err_and(|e| matches!(e, SecurityError::CommandInjectionAttempt)));

        assert!(validator
            .validate("$(malicious)")
            .is_err_and(|e| matches!(e, SecurityError::CommandInjectionAttempt)));

        assert!(validator.validate("normal text").is_ok());
    }

    #[test]
    fn test_email_validation() {
        let validator = InputValidator::default();

        assert!(validator.validate_email("user@example.com").is_ok());
        assert!(validator.validate_email("invalid.email").is_err());
        assert!(validator.validate_email("@example.com").is_err());
    }

    #[test]
    fn test_username_validation() {
        let validator = InputValidator::default();

        assert!(validator.validate_username("user123").is_ok());
        assert!(validator.validate_username("user-name_123").is_ok());
        // Shorter than the three-character minimum.
        assert!(validator.validate_username("ab").is_err());
        // `@` is outside the allowed character set.
        assert!(validator.validate_username("user@name").is_err());
    }

    #[test]
    fn test_config_key_validation() {
        let validator = InputValidator::default();

        assert!(validator.validate_config_key("app/config/key").is_ok());
        assert!(validator.validate_config_key("app.config.key").is_ok());
        assert!(validator.validate_config_key("../etc/passwd").is_err());
    }

    #[test]
    fn test_url_validation() {
        let validator = InputValidator::default();

        assert!(validator.validate_url("https://example.com").is_ok());
        assert!(validator.validate_url("http://example.com/path").is_ok());
        assert!(validator.validate_url("invalid-url").is_err());
        assert!(validator
            .validate_url("javascript:alert('XSS')")
            .is_err());
    }

    #[test]
    fn test_json_validation() {
        let validator = InputValidator::default();

        assert!(validator.validate_json(r#"{"key": "value"}"#).is_ok());
        assert!(validator.validate_json("invalid json").is_err());
        assert!(validator
            .validate_json(r#"{"xss": "<script>alert('XSS')</script>"}"#)
            .is_err());
    }

    /// Sanitization trims whitespace and HTML-escapes markup that passed detection.
    #[test]
    fn test_sanitization() {
        let validator = InputValidator::default();

        let result = validator.validate(" test input ").unwrap();
        assert_eq!(result, "test input");

        // An unclosed `<script>` tag is not flagged as XSS, so it reaches the
        // sanitizer and must come back escaped, never as raw markup.
        let result = validator.validate("test<script>").unwrap();
        assert!(!result.contains("<script>"));
        assert!(result.contains("&lt;script&gt;"));
    }

    #[test]
    fn test_length_validation() {
        let config = SanitizationConfig {
            max_length: 10,
            ..Default::default()
        };
        let validator = InputValidator::new(config);

        assert!(validator.validate("short").is_ok());
        assert!(validator.validate("this is way too long").is_err());
    }
}