1use std::io;
4use std::path::{Component, PathBuf};
5
/// Upper bound, in bytes, on the amount of data accepted by
/// `read_file_with_limit` and `read_stdin_with_limit` (10,000,000 bytes).
const MAX_FILE_SIZE: u64 = 10_000_000;

/// Upper bound on the byte length (`str::len`) of an expression accepted by
/// `validate_expression`.
const MAX_EXPRESSION_LENGTH: usize = 10_000;

/// Upper bound on the byte length (`str::len`) of a pattern accepted by
/// `validate_regex_pattern`.
const MAX_PATTERN_LENGTH: usize = 1_000;
14
/// Validates a user-supplied path and returns it unchanged as a relative `PathBuf`.
///
/// The path is accepted only if it is non-empty, relative, free of `..`
/// components, and — once symlinks are resolved — located inside the current
/// working directory. When the target does not exist yet, the nearest existing
/// ancestor is resolved instead, so a symlinked directory anywhere along the
/// path cannot be used to escape the working directory.
///
/// # Errors
///
/// * `InvalidInput` – `path` is empty or whitespace only.
/// * `PermissionDenied` – `path` is absolute, contains `..`, cannot be
///   resolved (for example a broken symlink), or resolves outside the current
///   working directory.
/// * Any error returned by `std::env::current_dir` or by canonicalizing the
///   working directory is propagated unchanged.
pub fn validate_file_path(path: &str) -> io::Result<PathBuf> {
    /// Builds the `PermissionDenied` error used for every policy violation.
    fn denied(message: &'static str) -> io::Error {
        io::Error::new(io::ErrorKind::PermissionDenied, message)
    }

    if path.trim().is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "Path cannot be empty",
        ));
    }

    let path_buf = PathBuf::from(path);

    if path_buf.is_absolute() {
        return Err(denied("Absolute paths are not allowed"));
    }

    if path_buf
        .components()
        .any(|component| matches!(component, Component::ParentDir))
    {
        return Err(denied("Path traversal (..) is not allowed"));
    }

    let cwd = std::env::current_dir()?;
    // Compare canonical against canonical: `canonicalize` may return a
    // different spelling of the same directory than `current_dir` (verbatim
    // `\\?\` prefixes on Windows, resolved symlinks elsewhere), which would
    // make a raw prefix comparison fail for perfectly valid paths.
    let canonical_cwd = cwd.canonicalize()?;
    let full_path = cwd.join(&path_buf);

    // Walk from the target itself up towards the root and resolve the first
    // entry that exists on disk. `symlink_metadata` does not follow links, so
    // a dangling symlink counts as "existing" and is rejected below when it
    // fails to canonicalize.
    for candidate in full_path.ancestors() {
        if candidate.symlink_metadata().is_err() {
            continue;
        }

        let resolved = candidate.canonicalize().map_err(|_| {
            denied("Path cannot be resolved (may be broken symlink or inaccessible)")
        })?;

        if !resolved.starts_with(&canonical_cwd) {
            return Err(denied("Path must be within current directory"));
        }
        break;
    }

    Ok(path_buf)
}
109
110pub fn validate_file_path_with_context(path: &str) -> io::Result<(PathBuf, PathBuf)> {
125 let validated_path = validate_file_path(path)?;
127
128 let validation_cwd = std::env::current_dir()?;
130
131 Ok((validated_path, validation_cwd))
132}
133
/// Re-checks, just before use, that a previously validated relative path is
/// still safe to access.
///
/// `path` is the relative path that was validated earlier and `validation_cwd`
/// is the working directory recorded at that time.
///
/// The check has three parts:
/// 1. the process working directory must still equal `validation_cwd`;
/// 2. `path` must not contain `..` and, joined onto the working directory,
///    must stay underneath it (this also rejects absolute paths);
/// 3. the nearest existing filesystem entry on the way to the target must
///    resolve, with symlinks followed, to a location inside the working
///    directory — this catches a symlink being swapped in after validation.
///
/// # Errors
///
/// Returns `PermissionDenied` when any of the checks fails. Errors from
/// `std::env::current_dir` or from canonicalizing the working directory are
/// propagated unchanged.
pub fn verify_path_still_valid(path: &PathBuf, validation_cwd: &PathBuf) -> io::Result<()> {
    /// Error returned for every containment failure.
    fn escaped() -> io::Error {
        io::Error::new(
            io::ErrorKind::PermissionDenied,
            "Path is no longer within current directory",
        )
    }

    let current_cwd = std::env::current_dir()?;
    if current_cwd != *validation_cwd {
        return Err(io::Error::new(
            io::ErrorKind::PermissionDenied,
            "Current working directory changed since path validation",
        ));
    }

    // `Path::starts_with` compares components literally and never collapses
    // `..`, so parent-directory components have to be rejected explicitly.
    if path
        .components()
        .any(|component| matches!(component, Component::ParentDir))
    {
        return Err(escaped());
    }

    // Joining an absolute `path` replaces the base entirely, which is what
    // this prefix test detects.
    let full_path = current_cwd.join(path);
    if !full_path.starts_with(&current_cwd) {
        return Err(escaped());
    }

    // Resolve symlinks on the nearest entry that exists on disk. A dangling
    // symlink is visible to `symlink_metadata` but fails to canonicalize and
    // is therefore treated as an escape.
    let canonical_cwd = current_cwd.canonicalize()?;
    for candidate in full_path.ancestors() {
        if candidate.symlink_metadata().is_err() {
            continue;
        }

        let still_inside = candidate
            .canonicalize()
            .map(|resolved| resolved.starts_with(&canonical_cwd))
            .unwrap_or(false);
        if !still_inside {
            return Err(escaped());
        }
        break;
    }

    Ok(())
}
170
/// Counts occurrences of `open_char` and `close_char` in `expr` that lie
/// outside quoted strings, returning `(open_count, close_count)`.
///
/// Scanning rules:
/// * a backslash causes the following character to be skipped entirely,
///   whether or not the scanner is inside a string;
/// * `"` and `'` both start a string; a string ends only at the same quote
///   character that started it;
/// * backslash and quote handling take precedence over `open_char` /
///   `close_char`, and `open_char` takes precedence over `close_char`.
fn count_balanced_with_string_awareness(
    expr: &str,
    open_char: char,
    close_char: char,
) -> (usize, usize) {
    // `Some(q)` while inside a string that was opened with quote `q`.
    let mut active_quote: Option<char> = None;
    let mut skip_following = false;
    let mut opened: usize = 0;
    let mut closed: usize = 0;

    for current in expr.chars() {
        if skip_following {
            skip_following = false;
            continue;
        }

        if current == '\\' {
            skip_following = true;
            continue;
        }

        if current == '"' || current == '\'' {
            active_quote = match active_quote {
                // The matching quote closes the string.
                Some(quote) if quote == current => None,
                // The other quote kind is ordinary string content.
                Some(quote) => Some(quote),
                None => Some(current),
            };
            continue;
        }

        if active_quote.is_some() {
            continue;
        }

        if current == open_char {
            opened += 1;
        } else if current == close_char {
            closed += 1;
        }
    }

    (opened, closed)
}
218
219pub fn validate_expression(expr: &str) -> Result<(), String> {
235 if expr.trim().is_empty() {
237 return Err("Expression cannot be empty".to_string());
238 }
239
240 if expr.len() > MAX_EXPRESSION_LENGTH {
242 return Err(format!(
243 "Expression too long (max {} characters, got {})",
244 MAX_EXPRESSION_LENGTH,
245 expr.len()
246 ));
247 }
248
249 let (paren_open, paren_close) = count_balanced_with_string_awareness(expr, '(', ')');
253 if paren_open != paren_close {
254 return Err(format!(
255 "Unbalanced parentheses: {} open, {} close",
256 paren_open, paren_close
257 ));
258 }
259
260 let (bracket_open, bracket_close) = count_balanced_with_string_awareness(expr, '[', ']');
262 if bracket_open != bracket_close {
263 return Err(format!(
264 "Unbalanced brackets: {} open, {} close",
265 bracket_open, bracket_close
266 ));
267 }
268
269 let dangerous_patterns = [
271 "DROP", "DELETE", "INSERT", "UPDATE", "EXEC", "EXECUTE", "SYSTEM", "BASH", "SH", "CMD.EXE",
272 ];
273
274 for pattern in &dangerous_patterns {
275 if expr.to_uppercase().contains(pattern) {
276 return Err(format!(
277 "Expression contains dangerous keyword: {}",
278 pattern
279 ));
280 }
281 }
282
283 if !expr.chars().all(|c| {
288 c.is_alphanumeric()
289 || c.is_whitespace()
290 || matches!(
291 c,
292 '.' | '_'
293 | '@'
294 | '('
295 | ')'
296 | '['
297 | ']'
298 | '{'
299 | '}'
300 | '='
301 | '<'
302 | '>'
303 | '!'
304 | '&'
305 | '|'
306 | '+'
307 | '-'
308 | '*'
309 | '/'
310 | '%'
311 | '^'
312 | '~'
313 | '?'
314 | '"'
315 | '\''
316 | ':'
317 | ','
318 | ';'
319 )
320 }) {
321 return Err(
322 "Expression contains invalid characters. Only alphanumeric, operators, and quotes allowed."
323 .to_string(),
324 );
325 }
326
327 Ok(())
328}
329
330pub fn validate_regex_pattern(pattern: &str) -> Result<(), String> {
345 if pattern.len() > MAX_PATTERN_LENGTH {
347 return Err(format!(
348 "Regex pattern too long (max {} characters)",
349 MAX_PATTERN_LENGTH
350 ));
351 }
352
353 match regex::Regex::new(pattern) {
355 Ok(_) => {}
356 Err(e) => {
357 return Err(format!("Invalid regex pattern: {}", e));
358 }
359 }
360
361 let has_nested_quantifiers = pattern.contains(")+")
366 || pattern.contains(")*")
367 || pattern.contains(")?")
368 || pattern.contains("]{2,}+")
369 || pattern.contains("]{2,}*")
370 || pattern.contains("]{2,}?")
371 || pattern.contains("}{2,}+")
372 || pattern.contains("}{2,}*");
373
374 if has_nested_quantifiers {
375 return Err(
376 "Regex pattern contains nested quantifiers that could cause ReDoS attack".to_string(),
377 );
378 }
379
380 let quantifier_chain_patterns = [
383 r"\+\s*\+", r"\*\s*\*", r"\+\s*\*", r"\*\s*\+", ];
388
389 for qc_pattern_str in &quantifier_chain_patterns {
390 if let Ok(qc_pattern) = regex::Regex::new(qc_pattern_str) {
391 if qc_pattern.is_match(pattern) {
392 return Err("Regex pattern contains chained quantifiers (ReDoS risk)".to_string());
393 }
394 }
395 }
396
397 if pattern.contains('|') {
400 if pattern.contains('*') || pattern.contains('+') {
402 if pattern.contains("(") && pattern.contains(")") {
404 if pattern.contains(")*") || pattern.contains(")+") || pattern.contains(")?") {
406 return Err(
407 "Regex pattern contains quantified alternation (high ReDoS risk)"
408 .to_string(),
409 );
410 }
411 }
412 }
413 }
414
415 if pattern.contains('|') && (pattern.contains('*') || pattern.contains('+')) {
417 eprintln!(
418 "⚠️ Warning: Regex contains alternation with quantifiers (potential ReDoS risk)"
419 );
420 }
421
422 Ok(())
423}
424
/// Rewrites `input` so that it contains no literal `*/` or `/*` sequence.
///
/// Every backslash is doubled first, then `*/` becomes `*\/` and `/*` becomes
/// `/\*`; finally leading and trailing whitespace is trimmed.
///
/// Deprecated: see the deprecation note on this item.
#[deprecated(
    since = "0.4.0",
    note = "Do not use - backslash escaping doesn't work in Rust comments. Use quote! macro instead."
)]
pub fn sanitize_for_comment(input: &str) -> String {
    // The order matters: backslashes must be doubled before the comment
    // delimiters are rewritten, otherwise the inserted backslashes would be
    // doubled as well.
    let backslashes_doubled = input.replace("\\", "\\\\");
    let closers_broken = backslashes_doubled.replace("*/", "*\\/");
    let openers_broken = closers_broken.replace("/*", "/\\*");
    openers_broken.trim().to_owned()
}
471
/// Escapes `input` for embedding between double quotes in a Rust string
/// literal.
///
/// Backslash, double quote, newline, carriage return and tab are replaced by
/// their two-character escape sequences (`\\`, `\"`, `\n`, `\r`, `\t`); every
/// other character is copied through unchanged.
pub fn escape_for_rust_string(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());

    for ch in input.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            other => escaped.push(other),
        }
    }

    escaped
}
487
488pub fn read_file_with_limit(path: &std::path::Path) -> io::Result<String> {
502 use std::fs::File;
503 use std::io::Read;
504
505 let file = File::open(path)?;
506 let metadata = file.metadata()?;
507
508 if metadata.len() > MAX_FILE_SIZE {
510 return Err(io::Error::new(
511 io::ErrorKind::InvalidData,
512 format!(
513 "File too large (max {} MB, got {} MB)",
514 MAX_FILE_SIZE / 1_000_000,
515 metadata.len() / 1_000_000
516 ),
517 ));
518 }
519
520 let mut buffer = String::new();
521 file.take(MAX_FILE_SIZE).read_to_string(&mut buffer)?;
522 Ok(buffer)
523}
524
525pub fn read_stdin_with_limit() -> io::Result<String> {
540 use std::io::Read;
541
542 let stdin = io::stdin();
543 let mut buffer = String::new();
544
545 stdin.take(MAX_FILE_SIZE).read_to_string(&mut buffer)?;
547
548 if buffer.len() as u64 == MAX_FILE_SIZE {
551 let mut test = [0u8; 1];
553 match std::io::stdin().read(&mut test) {
554 Ok(1) => {
555 return Err(io::Error::new(
557 io::ErrorKind::InvalidData,
558 format!("Input exceeds {} MB limit", MAX_FILE_SIZE / 1_000_000),
559 ));
560 }
561 _ => {
562 }
564 }
565 }
566
567 Ok(buffer)
568}
569
/// Unit tests for the validation, escaping and size-limited reading helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- validate_file_path ------------------------------------------------

    #[test]
    #[cfg(unix)]
    fn test_valid_relative_path() {
        let result = validate_file_path("output.rs");
        assert!(result.is_ok());
    }

    #[test]
    #[cfg(unix)]
    fn test_valid_nested_path() {
        let result = validate_file_path("target/debug/generated.rs");
        assert!(result.is_ok());
    }

    #[test]
    #[cfg(unix)]
    fn test_rejects_absolute_path_unix() {
        let result = validate_file_path("/etc/passwd");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Absolute paths are not allowed"));
    }

    #[test]
    fn test_rejects_path_traversal() {
        let result = validate_file_path("../../../etc/passwd");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Path traversal (..) is not allowed"));
    }

    #[test]
    fn test_rejects_single_parent_dir() {
        let result = validate_file_path("..");
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_empty_path() {
        let result = validate_file_path("");
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_whitespace_only_path() {
        let result = validate_file_path(" ");
        assert!(result.is_err());
    }

    // ---- validate_expression -----------------------------------------------

    #[test]
    fn test_valid_simple_expression() {
        let result = validate_expression("age >= 18");
        assert!(result.is_ok());
    }

    #[test]
    fn test_valid_complex_expression() {
        let result = validate_expression("(age >= 18) && (verified == true) || (admin == true)");
        assert!(result.is_ok());
    }

    #[test]
    fn test_rejects_empty_expression() {
        let result = validate_expression("");
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_whitespace_only_expression() {
        let result = validate_expression(" \n\t ");
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_expression_exceeding_max_length() {
        let long_expr = "a".repeat(MAX_EXPRESSION_LENGTH + 1);
        let result = validate_expression(&long_expr);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("too long"));
    }

    #[test]
    fn test_rejects_unbalanced_parentheses_open() {
        let result = validate_expression("(age >= 18");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Unbalanced parentheses"));
    }

    #[test]
    fn test_rejects_unbalanced_parentheses_close() {
        let result = validate_expression("age >= 18)");
        assert!(result.is_err());
    }

    #[test]
    fn test_parens_in_string_not_counted() {
        // Parentheses inside the quoted string must not affect the balance.
        let result = validate_expression(r#"name == "balance ( and )""#);
        assert!(result.is_ok());

        // The opening parenthesis outside the string is never closed.
        let result = validate_expression(r#"(name == "test""#);
        assert!(result.is_err());
    }

    #[test]
    fn test_brackets_in_string_not_counted() {
        // Brackets inside the quoted string must not affect the balance.
        let result = validate_expression(r#"name == "array[0]""#);
        assert!(result.is_ok());

        // The opening bracket outside any string is never closed.
        let result = validate_expression(r#"arr[0 == test"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_escaped_quotes_in_string_not_counted() {
        // NOTE(review): despite the name, this input contains no escaped quote;
        // it only checks that a single-quoted string is accepted.
        let result = validate_expression(r#"name == 'test with quote' && valid"#);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rejects_unbalanced_brackets() {
        let result = validate_expression("arr[0 == 5");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Unbalanced brackets"));
    }

    #[test]
    fn test_rejects_sql_injection_pattern_drop() {
        let result = validate_expression("drop table users");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("dangerous keyword"));
    }

    #[test]
    fn test_rejects_sql_injection_pattern_delete() {
        let result = validate_expression("delete from users where id = 1");
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_shell_command_pattern_bash() {
        let result = validate_expression("bash -c 'rm -rf /'");
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_invalid_characters() {
        let result = validate_expression("age >= 18 && `whoami`");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("invalid characters"));
    }

    // ---- validate_regex_pattern --------------------------------------------

    #[test]
    fn test_valid_simple_regex() {
        let result = validate_regex_pattern("[0-9]+");
        assert!(result.is_ok());
    }

    #[test]
    fn test_valid_email_regex() {
        let result = validate_regex_pattern(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}");
        assert!(result.is_ok());
    }

    #[test]
    fn test_rejects_invalid_regex() {
        let result = validate_regex_pattern("[0-9");
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_regex_exceeding_max_length() {
        let long_pattern = "a".repeat(MAX_PATTERN_LENGTH + 1);
        let result = validate_regex_pattern(&long_pattern);
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_nested_quantifiers_plus_plus() {
        let result = validate_regex_pattern("(a+)+");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("nested quantifiers"));
    }

    #[test]
    fn test_rejects_nested_quantifiers_star_plus() {
        let result = validate_regex_pattern("(a*)+");
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_nested_quantifiers_question_star() {
        let result = validate_regex_pattern("(a?)*");
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_quantifier_chains() {
        let result = validate_regex_pattern("a++");
        assert!(result.is_err());

        let result = validate_regex_pattern("a**");
        assert!(result.is_err());

        let result = validate_regex_pattern("a+*");
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_quantified_alternation() {
        // A quantified group is rejected regardless of which check reports it,
        // so only the rejection itself is asserted.
        let result = validate_regex_pattern("(a|b)*");
        assert!(result.is_err());
    }

    // ---- sanitize_for_comment / escape_for_rust_string -----------------------

    #[test]
    #[allow(deprecated)] // exercising the deprecated helper on purpose
    fn test_sanitize_comment_escapes_backslash() {
        let result = sanitize_for_comment("path\\to\\file");
        assert!(result.contains("\\\\"));
    }

    #[test]
    #[allow(deprecated)] // exercising the deprecated helper on purpose
    fn test_sanitize_comment_prevents_comment_breakout() {
        let result = sanitize_for_comment("test */ malicious");
        assert!(result.contains("*\\/"));
    }

    #[test]
    #[allow(deprecated)] // exercising the deprecated helper on purpose
    fn test_sanitize_comment_prevents_comment_break_in() {
        let result = sanitize_for_comment("test /* malicious");
        assert!(result.contains("/\\*"));
    }

    #[test]
    fn test_escape_for_rust_string_escapes_quotes() {
        let result = escape_for_rust_string(r#"test "quoted" value"#);
        assert!(result.contains("\\\""));
    }

    #[test]
    fn test_escape_for_rust_string_escapes_newlines() {
        let result = escape_for_rust_string("line1\nline2");
        assert!(result.contains("\\n"));
    }

    #[test]
    fn test_escape_for_rust_string_escapes_tabs() {
        let result = escape_for_rust_string("col1\tcol2");
        assert!(result.contains("\\t"));
    }

    // ---- read_file_with_limit ------------------------------------------------

    #[test]
    fn test_read_small_file_succeeds() {
        // The process id keeps concurrent test runs from sharing a file.
        let temp_file =
            std::env::temp_dir().join(format!("test_small_{}.txt", std::process::id()));
        std::fs::write(&temp_file, "small content").unwrap();

        let result = read_file_with_limit(&temp_file);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "small content");

        let _ = std::fs::remove_file(&temp_file);
    }

    #[test]
    fn test_read_file_exceeding_size_limit_fails() {
        // The process id keeps concurrent test runs from sharing a file.
        let temp_file =
            std::env::temp_dir().join(format!("test_large_{}.txt", std::process::id()));
        let large_content = "x".repeat((MAX_FILE_SIZE as usize) + 1);
        std::fs::write(&temp_file, large_content).unwrap();

        let result = read_file_with_limit(&temp_file);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("too large"));

        let _ = std::fs::remove_file(&temp_file);
    }

    #[test]
    fn test_read_nonexistent_file_fails() {
        let nonexistent = std::env::temp_dir().join("does_not_exist_xyz.txt");
        let result = read_file_with_limit(&nonexistent);
        assert!(result.is_err());
    }

    // ---- symlink handling ------------------------------------------------------

    #[test]
    fn test_broken_symlink_rejected() {
        let temp_dir = std::env::temp_dir().join("test_symlink_broken");
        let _ = std::fs::create_dir_all(&temp_dir);

        let symlink_path = temp_dir.join("broken_symlink");
        let _ = std::fs::remove_file(&symlink_path);

        #[cfg(unix)]
        {
            use std::os::unix::fs as unix_fs;
            let _ = unix_fs::symlink("/nonexistent/path", &symlink_path);

            // The symlink lives under the temp directory, so the argument is an
            // absolute path and must be rejected whether or not the link could
            // be created.
            let result = validate_file_path(symlink_path.to_str().unwrap());
            assert!(result.is_err());

            let _ = std::fs::remove_file(&symlink_path);
        }

        let _ = std::fs::remove_dir(&temp_dir);
    }

    #[test]
    fn test_unwrap_or_issue_fixed() {
        // TODO(review): this test has no body and asserts nothing; add a
        // regression check for the issue it refers to, or remove it.
    }

    // ---- validate_file_path_with_context / verify_path_still_valid -------------

    #[test]
    #[cfg(unix)]
    fn test_path_validation_with_context() {
        let result = validate_file_path_with_context("output.rs");
        assert!(result.is_ok());

        let (path, cwd) = result.unwrap();
        assert!(!path.is_absolute());

        let verify_result = verify_path_still_valid(&path, &cwd);
        assert!(verify_result.is_ok());
    }

    #[test]
    #[cfg(unix)]
    fn test_verify_path_rejects_cwd_change() {
        let (path, _original_cwd) = validate_file_path_with_context("output.rs").unwrap();

        // Simulate a changed working directory by passing a different one.
        let fake_cwd = std::path::PathBuf::from("/fake/different/path");
        let verify_result = verify_path_still_valid(&path, &fake_cwd);

        assert!(verify_result.is_err());
    }

    // ---- count_balanced_with_string_awareness ----------------------------------

    #[test]
    fn test_count_balanced_ignores_string_contents() {
        let (open, close) = count_balanced_with_string_awareness(
            r#"message == "Hello (world) and (stuff)""#,
            '(',
            ')',
        );
        assert_eq!(open, 0);
        assert_eq!(close, 0);
    }

    #[test]
    fn test_count_balanced_with_actual_parens() {
        let (open, close) = count_balanced_with_string_awareness(r#"(msg == "test")"#, '(', ')');
        assert_eq!(open, 1);
        assert_eq!(close, 1);
    }

    #[test]
    fn test_count_balanced_escaped_quotes() {
        let (open, close) =
            count_balanced_with_string_awareness(r#"(name == "test \" quote")"#, '(', ')');
        assert_eq!(open, 1);
        assert_eq!(close, 1);
    }
}