1use std::collections::{HashMap, HashSet};
31use std::path::Path;
32
33use serde::{Deserialize, Serialize};
34use streaming_iterator::StreamingIterator;
35use tree_sitter::{Node, Query, QueryCursor, Tree};
36
37use crate::error::{Result, BrrrError};
38use crate::lang::LanguageRegistry;
39
/// Severity level attached to a SQL injection finding.
///
/// Within this module only `Critical` (query text built by interpolation,
/// concatenation or formatting) and `Medium` (a bare variable passed as the
/// query) are assigned by the analyzers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Severity {
    /// Displayed as `CRITICAL`.
    Critical,
    /// Displayed as `HIGH`.
    High,
    /// Displayed as `MEDIUM`.
    Medium,
    /// Displayed as `LOW`.
    Low,
}
56
57impl std::fmt::Display for Severity {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 match self {
60 Severity::Critical => write!(f, "CRITICAL"),
61 Severity::High => write!(f, "HIGH"),
62 Severity::Medium => write!(f, "MEDIUM"),
63 Severity::Low => write!(f, "LOW"),
64 }
65 }
66}
67
/// Source span of a flagged call.
///
/// The analyzers in this module fill the numeric fields from the call node's
/// start/end positions with one added, so all four values are 1-based.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Location {
    /// Path of the scanned file, as it was passed to the scanner.
    pub file: String,
    /// 1-based line on which the call starts.
    pub line: usize,
    /// 1-based column at which the call starts.
    pub column: usize,
    /// 1-based line on which the call ends.
    pub end_line: usize,
    /// 1-based column at which the call ends.
    pub end_column: usize,
}
82
/// Category of the database API a flagged call was matched against.
///
/// The mapping from callee names to variants lives in the
/// `PYTHON_SQL_SINKS` and `TYPESCRIPT_SQL_SINKS` tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SqlSinkType {
    /// `execute`-style calls (e.g. `execute`, `executemany`, `db.exec`).
    Execute,
    /// `query`-style calls (e.g. `query`, `pool.query`, `sequelize.query`).
    Query,
    /// Raw SQL helpers (e.g. `raw`, `extra`, `knex.raw`).
    Raw,
    /// Prisma raw helpers (`$queryRaw`, `$executeRaw` and their `Unsafe` forms).
    PrismaRaw,
    /// `session.execute`.
    SessionExecute,
    /// The `text` constructor.
    TextConstruct,
    /// Not assigned by either built-in sink table.
    Other,
}
101
102impl std::fmt::Display for SqlSinkType {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 match self {
105 SqlSinkType::Execute => write!(f, "execute"),
106 SqlSinkType::Query => write!(f, "query"),
107 SqlSinkType::Raw => write!(f, "raw"),
108 SqlSinkType::PrismaRaw => write!(f, "prisma_raw"),
109 SqlSinkType::SessionExecute => write!(f, "session_execute"),
110 SqlSinkType::TextConstruct => write!(f, "text_construct"),
111 SqlSinkType::Other => write!(f, "other"),
112 }
113 }
114}
115
/// The way a flagged query argument was constructed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnsafePattern {
    /// `+` concatenation where at least one operand is a string literal.
    StringConcatenation,
    /// Python f-string containing at least one `{...}` field.
    FStringInterpolation,
    /// Python `%` formatting expression.
    PercentFormat,
    /// Python `.format(...)` call.
    DotFormat,
    /// JavaScript/TypeScript template string with at least one substitution.
    TemplateLiteral,
    /// A bare identifier passed as the query argument.
    NonLiteralArgument,
}
132
133impl std::fmt::Display for UnsafePattern {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 match self {
136 UnsafePattern::StringConcatenation => write!(f, "string_concatenation"),
137 UnsafePattern::FStringInterpolation => write!(f, "f_string_interpolation"),
138 UnsafePattern::PercentFormat => write!(f, "percent_format"),
139 UnsafePattern::DotFormat => write!(f, "dot_format"),
140 UnsafePattern::TemplateLiteral => write!(f, "template_literal"),
141 UnsafePattern::NonLiteralArgument => write!(f, "non_literal_argument"),
142 }
143 }
144}
145
/// One suspected SQL injection, describing the flagged call and why it was flagged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SQLInjectionFinding {
    /// Span of the flagged call expression.
    pub location: Location,
    /// Severity chosen by the expression analyzer.
    pub severity: Severity,
    /// Category of the sink the callee matched.
    pub sink_function: SqlSinkType,
    /// Callee text that matched a sink: `object.method` for method calls,
    /// the bare name for plain function calls.
    pub sink_expression: String,
    /// Index of the flagged argument; the analyzers in this module always
    /// inspect the first positional argument and store `0`.
    pub tainted_param: usize,
    /// How the query argument was built.
    pub pattern: UnsafePattern,
    /// Heuristic confidence; this module assigns 0.95, 0.9 or 0.6 depending
    /// on the pattern.
    pub confidence: f64,
    /// Source text of the whole call expression.
    pub code_snippet: String,
    /// Names of the variables found in the query argument.
    pub tainted_variables: Vec<String>,
    /// Human-readable explanation of the finding.
    pub description: String,
    /// Suggested fix, worded for the language that was scanned.
    pub remediation: String,
}
172
/// Aggregated outcome of a directory scan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScanResult {
    /// Every finding from every scanned file, in traversal order.
    pub findings: Vec<SQLInjectionFinding>,
    /// Number of files whose extension matched and that were handed to the
    /// scanner (counted even if scanning that file failed).
    pub files_scanned: usize,
    /// NOTE(review): `scan_directory` increments this by the number of
    /// findings per file, so it currently always equals `findings.len()`.
    pub sinks_found: usize,
    /// Number of findings per severity, keyed by the severity's display
    /// label (e.g. `"CRITICAL"`).
    pub severity_counts: HashMap<String, usize>,
    /// The language filter that was requested, or `"mixed"` when none was.
    pub language: String,
}
187
/// Callee names treated as SQL sinks in Python code, with their category.
///
/// `scan_python` looks a call up by its bare method/function name first and
/// then by its `object.method` source text, so the unqualified entries
/// (`execute`, `raw`, ...) match regardless of the receiver.
const PYTHON_SQL_SINKS: &[(&str, SqlSinkType)] = &[
    ("execute", SqlSinkType::Execute),
    ("executemany", SqlSinkType::Execute),
    ("executescript", SqlSinkType::Execute),
    ("connection.execute", SqlSinkType::Execute),
    ("engine.execute", SqlSinkType::Execute),
    ("session.execute", SqlSinkType::SessionExecute),
    ("text", SqlSinkType::TextConstruct),
    ("raw", SqlSinkType::Raw),
    ("extra", SqlSinkType::Raw),
    ("cursor.execute", SqlSinkType::Execute),
    ("conn.execute", SqlSinkType::Execute),
    ("pool.execute", SqlSinkType::Execute),
    ("db.execute", SqlSinkType::Execute),
];
214
/// Callee names treated as SQL sinks in TypeScript/JavaScript code, with
/// their category.
///
/// `scan_typescript` looks a call up by its bare property/function name
/// first and then by its `object.property` source text, so the unqualified
/// entries (`query`, `execute`, `raw`, ...) match regardless of the receiver.
const TYPESCRIPT_SQL_SINKS: &[(&str, SqlSinkType)] = &[
    ("query", SqlSinkType::Query),
    ("execute", SqlSinkType::Execute),
    ("raw", SqlSinkType::Raw),
    ("knex.raw", SqlSinkType::Raw),
    ("$queryRaw", SqlSinkType::PrismaRaw),
    ("$executeRaw", SqlSinkType::PrismaRaw),
    ("$queryRawUnsafe", SqlSinkType::PrismaRaw),
    ("$executeRawUnsafe", SqlSinkType::PrismaRaw),
    ("createQueryRunner", SqlSinkType::Query),
    ("manager.query", SqlSinkType::Query),
    ("sequelize.query", SqlSinkType::Query),
    ("pool.query", SqlSinkType::Query),
    ("client.query", SqlSinkType::Query),
    ("connection.query", SqlSinkType::Query),
    ("connection.execute", SqlSinkType::Execute),
    ("db.exec", SqlSinkType::Execute),
    ("db.prepare", SqlSinkType::Execute),
];
243
/// Syntax-based SQL injection detector for Python and TypeScript/JavaScript.
///
/// Holds one name-to-category lookup map per language, filled from the
/// built-in sink tables by `new`.
pub struct SqlInjectionDetector {
    /// Sink names recognised in Python sources.
    python_sinks: HashMap<String, SqlSinkType>,
    /// Sink names recognised in TypeScript/JavaScript sources.
    typescript_sinks: HashMap<String, SqlSinkType>,
}
255
256impl Default for SqlInjectionDetector {
257 fn default() -> Self {
258 Self::new()
259 }
260}
261
262impl SqlInjectionDetector {
263 pub fn new() -> Self {
265 let mut python_sinks = HashMap::new();
266 for (name, sink_type) in PYTHON_SQL_SINKS {
267 python_sinks.insert((*name).to_string(), *sink_type);
268 }
269
270 let mut typescript_sinks = HashMap::new();
271 for (name, sink_type) in TYPESCRIPT_SQL_SINKS {
272 typescript_sinks.insert((*name).to_string(), *sink_type);
273 }
274
275 Self {
276 python_sinks,
277 typescript_sinks,
278 }
279 }
280
281 pub fn scan_file(&self, file_path: &str) -> Result<Vec<SQLInjectionFinding>> {
291 let path = Path::new(file_path);
292 let registry = LanguageRegistry::global();
293
294 let lang = registry.detect_language(path).ok_or_else(|| {
295 BrrrError::UnsupportedLanguage(
296 path.extension()
297 .and_then(|e| e.to_str())
298 .unwrap_or("unknown")
299 .to_string(),
300 )
301 })?;
302
303 let source = std::fs::read(path).map_err(|e| BrrrError::io_with_path(e, path))?;
304 let mut parser = lang.parser_for_path(path)?;
305 let tree = parser.parse(&source, None).ok_or_else(|| BrrrError::Parse {
306 file: file_path.to_string(),
307 message: "Failed to parse file".to_string(),
308 })?;
309
310 let lang_name = lang.name();
311 match lang_name {
312 "python" => self.scan_python(&tree, &source, file_path),
313 "typescript" | "javascript" => self.scan_typescript(&tree, &source, file_path),
314 _ => Ok(vec![]), }
316 }
317
318 pub fn scan_directory(&self, dir_path: &str, language: Option<&str>) -> Result<ScanResult> {
329 let path = Path::new(dir_path);
330 if !path.is_dir() {
331 return Err(BrrrError::InvalidArgument(format!(
332 "Not a directory: {}",
333 dir_path
334 )));
335 }
336
337 let mut findings = Vec::new();
338 let mut files_scanned = 0;
339 let mut sinks_found = 0;
340
341 let mut builder = ignore::WalkBuilder::new(path);
343 builder.add_custom_ignore_filename(".brrrignore");
344 builder.hidden(true);
345
346 let extensions: HashSet<&str> = match language {
347 Some("python") => ["py"].iter().copied().collect(),
348 Some("typescript") => ["ts", "tsx", "js", "jsx", "mjs", "cjs"]
349 .iter()
350 .copied()
351 .collect(),
352 Some("javascript") => ["js", "jsx", "mjs", "cjs"].iter().copied().collect(),
353 _ => ["py", "ts", "tsx", "js", "jsx", "mjs", "cjs"]
354 .iter()
355 .copied()
356 .collect(),
357 };
358
359 for entry in builder.build().flatten() {
360 let entry_path = entry.path();
361 if !entry_path.is_file() {
362 continue;
363 }
364
365 let ext = entry_path
366 .extension()
367 .and_then(|e| e.to_str())
368 .unwrap_or("");
369 if !extensions.contains(ext) {
370 continue;
371 }
372
373 files_scanned += 1;
374
375 if let Ok(file_findings) = self.scan_file(entry_path.to_str().unwrap_or("")) {
376 sinks_found += file_findings.len();
377 findings.extend(file_findings);
378 }
379 }
380
381 let mut severity_counts: HashMap<String, usize> = HashMap::new();
383 for finding in &findings {
384 *severity_counts
385 .entry(finding.severity.to_string())
386 .or_insert(0) += 1;
387 }
388
389 let detected_lang = language.unwrap_or("mixed").to_string();
390
391 Ok(ScanResult {
392 findings,
393 files_scanned,
394 sinks_found,
395 severity_counts,
396 language: detected_lang,
397 })
398 }
399
400 fn scan_python(
406 &self,
407 tree: &Tree,
408 source: &[u8],
409 file_path: &str,
410 ) -> Result<Vec<SQLInjectionFinding>> {
411 let mut findings = Vec::new();
412
413 let query_str = r#"
415 (call
416 function: [
417 (identifier) @func_name
418 (attribute
419 object: (_) @obj
420 attribute: (identifier) @method_name)
421 ]
422 arguments: (argument_list) @args
423 ) @call
424 "#;
425
426 let ts_lang = tree.language();
427 let query = Query::new(&ts_lang, query_str).map_err(|e| {
428 BrrrError::TreeSitter(format!("Failed to create Python query: {}", e))
429 })?;
430
431 let mut cursor = QueryCursor::new();
432 let mut matches = cursor.matches(&query, tree.root_node(), source);
433
434 let func_name_idx = query.capture_index_for_name("func_name");
435 let method_name_idx = query.capture_index_for_name("method_name");
436 let obj_idx = query.capture_index_for_name("obj");
437 let args_idx = query.capture_index_for_name("args");
438 let call_idx = query.capture_index_for_name("call");
439
440 while let Some(match_) = matches.next() {
441 let call_node: Option<Node> = match call_idx {
443 Some(idx) => match_.captures.iter().find(|c| c.index == idx).map(|c| c.node),
444 None => None,
445 };
446
447 let func_name: Option<&str> = func_name_idx.and_then(|idx| {
449 match_
450 .captures
451 .iter()
452 .find(|c| c.index == idx)
453 .map(|c| self.node_text(c.node, source))
454 });
455
456 let method_name: Option<&str> = method_name_idx.and_then(|idx| {
457 match_
458 .captures
459 .iter()
460 .find(|c| c.index == idx)
461 .map(|c| self.node_text(c.node, source))
462 });
463
464 let obj_text: Option<&str> = obj_idx.and_then(|idx| {
465 match_
466 .captures
467 .iter()
468 .find(|c| c.index == idx)
469 .map(|c| self.node_text(c.node, source))
470 });
471
472 let args_node: Option<Node> = args_idx.and_then(|idx| {
473 match_
474 .captures
475 .iter()
476 .find(|c| c.index == idx)
477 .map(|c| c.node)
478 });
479
480 let (sink_name, sink_type) = if let Some(method) = method_name {
482 let full_name = if let Some(obj) = obj_text {
483 format!("{}.{}", obj, method)
484 } else {
485 method.to_string()
486 };
487
488 if let Some(sink_type) = self.python_sinks.get(method) {
490 (full_name, *sink_type)
491 } else if let Some(sink_type) = self.python_sinks.get(&full_name) {
492 (full_name, *sink_type)
493 } else {
494 continue;
495 }
496 } else if let Some(func) = func_name {
497 if let Some(sink_type) = self.python_sinks.get(func) {
498 (func.to_string(), *sink_type)
499 } else {
500 continue;
501 }
502 } else {
503 continue;
504 };
505
506 if let (Some(call_node), Some(args_node)) = (call_node, args_node) {
508 if let Some(finding) = self.analyze_python_call_args(
509 call_node, args_node, source, file_path, &sink_name, sink_type,
510 ) {
511 findings.push(finding);
512 }
513 }
514 }
515
516 Ok(findings)
517 }
518
519 fn analyze_python_call_args(
521 &self,
522 call_node: Node,
523 args_node: Node,
524 source: &[u8],
525 file_path: &str,
526 sink_name: &str,
527 sink_type: SqlSinkType,
528 ) -> Option<SQLInjectionFinding> {
529 let first_arg = self.get_first_python_arg(args_node)?;
531
532 if self.has_python_params(args_node, source) {
535 let query_text = self.node_text(first_arg, source);
537 if query_text.contains('?')
538 || query_text.contains("%s")
539 || query_text.contains("%(")
540 || query_text.contains(':')
541 {
542 return None; }
544 }
545
546 let (pattern, severity, confidence, tainted_vars) =
548 self.analyze_python_expression(first_arg, source)?;
549
550 let code_snippet = self.node_text(call_node, source).to_string();
551 let location = Location {
552 file: file_path.to_string(),
553 line: call_node.start_position().row + 1,
554 column: call_node.start_position().column + 1,
555 end_line: call_node.end_position().row + 1,
556 end_column: call_node.end_position().column + 1,
557 };
558
559 let description = self.generate_description(&pattern, sink_name, &tainted_vars);
560 let remediation = self.generate_remediation(&pattern, "python");
561
562 Some(SQLInjectionFinding {
563 location,
564 severity,
565 sink_function: sink_type,
566 sink_expression: sink_name.to_string(),
567 tainted_param: 0,
568 pattern,
569 confidence,
570 code_snippet,
571 tainted_variables: tainted_vars,
572 description,
573 remediation,
574 })
575 }
576
577 fn get_first_python_arg<'a>(&self, args_node: Node<'a>) -> Option<Node<'a>> {
579 let mut cursor = args_node.walk();
580 for child in args_node.children(&mut cursor) {
581 match child.kind() {
582 "(" | ")" | "," => continue,
583 "keyword_argument" => continue, _ => return Some(child),
585 }
586 }
587 None
588 }
589
590 fn has_python_params(&self, args_node: Node, _source: &[u8]) -> bool {
592 let mut positional_args = Vec::new();
593 let mut cursor = args_node.walk();
594
595 for child in args_node.children(&mut cursor) {
596 match child.kind() {
597 "(" | ")" | "," => continue,
598 "keyword_argument" => continue,
599 _ => positional_args.push(child),
600 }
601 }
602
603 if positional_args.len() >= 2 {
605 let second_arg = positional_args[1];
606 matches!(
607 second_arg.kind(),
608 "tuple" | "list" | "dictionary" | "identifier"
609 )
610 } else {
611 false
612 }
613 }
614
615 fn analyze_python_expression(
619 &self,
620 node: Node,
621 source: &[u8],
622 ) -> Option<(UnsafePattern, Severity, f64, Vec<String>)> {
623 match node.kind() {
624 "string" => {
626 let text = self.node_text(node, source);
627 if text.starts_with("f\"") || text.starts_with("f'") {
628 let vars = self.extract_fstring_variables(text);
630 if !vars.is_empty() {
631 return Some((
632 UnsafePattern::FStringInterpolation,
633 Severity::Critical,
634 0.95,
635 vars,
636 ));
637 }
638 }
639 None
640 }
641
642 "binary_operator" => {
644 let op_node = node.child_by_field_name("operator")?;
645 let op = self.node_text(op_node, source);
646
647 if op == "+" {
648 let left = node.child_by_field_name("left")?;
650 let right = node.child_by_field_name("right")?;
651
652 let left_is_string = self.is_string_literal(left, source);
653 let right_is_string = self.is_string_literal(right, source);
654
655 if left_is_string || right_is_string {
656 let vars = self.collect_variables(node, source);
657 return Some((
658 UnsafePattern::StringConcatenation,
659 Severity::Critical,
660 0.9,
661 vars,
662 ));
663 }
664 } else if op == "%" {
665 let vars = self.collect_variables(node, source);
667 return Some((UnsafePattern::PercentFormat, Severity::Critical, 0.9, vars));
668 }
669 None
670 }
671
672 "call" => {
674 if let Some(func) = node.child_by_field_name("function") {
676 if func.kind() == "attribute" {
677 if let Some(attr) = func.child_by_field_name("attribute") {
678 if self.node_text(attr, source) == "format" {
679 let vars = self.collect_call_args(node, source);
680 return Some((
681 UnsafePattern::DotFormat,
682 Severity::Critical,
683 0.9,
684 vars,
685 ));
686 }
687 }
688 }
689 }
690 None
691 }
692
693 "identifier" => {
695 let var_name = self.node_text(node, source).to_string();
696 Some((
698 UnsafePattern::NonLiteralArgument,
699 Severity::Medium,
700 0.6,
701 vec![var_name],
702 ))
703 }
704
705 "concatenated_string" => {
707 let text = self.node_text(node, source);
708 if text.contains("f\"") || text.contains("f'") {
709 let vars = self.extract_fstring_variables(text);
710 if !vars.is_empty() {
711 return Some((
712 UnsafePattern::FStringInterpolation,
713 Severity::Critical,
714 0.95,
715 vars,
716 ));
717 }
718 }
719 None
720 }
721
722 _ => None,
723 }
724 }
725
726 fn scan_typescript(
732 &self,
733 tree: &Tree,
734 source: &[u8],
735 file_path: &str,
736 ) -> Result<Vec<SQLInjectionFinding>> {
737 let mut findings = Vec::new();
738
739 let query_str = r#"
741 (call_expression
742 function: [
743 (identifier) @func_name
744 (member_expression
745 object: (_) @obj
746 property: (property_identifier) @method_name)
747 ]
748 arguments: (arguments) @args
749 ) @call
750 "#;
751
752 let ts_lang = tree.language();
753 let query = Query::new(&ts_lang, query_str).map_err(|e| {
754 BrrrError::TreeSitter(format!("Failed to create TypeScript query: {}", e))
755 })?;
756
757 let mut cursor = QueryCursor::new();
758 let mut matches = cursor.matches(&query, tree.root_node(), source);
759
760 let func_name_idx = query.capture_index_for_name("func_name");
761 let method_name_idx = query.capture_index_for_name("method_name");
762 let obj_idx = query.capture_index_for_name("obj");
763 let args_idx = query.capture_index_for_name("args");
764 let call_idx = query.capture_index_for_name("call");
765
766 while let Some(match_) = matches.next() {
767 let call_node: Option<Node> = match call_idx {
768 Some(idx) => match_.captures.iter().find(|c| c.index == idx).map(|c| c.node),
769 None => None,
770 };
771
772 let func_name: Option<&str> = func_name_idx.and_then(|idx| {
773 match_
774 .captures
775 .iter()
776 .find(|c| c.index == idx)
777 .map(|c| self.node_text(c.node, source))
778 });
779
780 let method_name: Option<&str> = method_name_idx.and_then(|idx| {
781 match_
782 .captures
783 .iter()
784 .find(|c| c.index == idx)
785 .map(|c| self.node_text(c.node, source))
786 });
787
788 let obj_text: Option<&str> = obj_idx.and_then(|idx| {
789 match_
790 .captures
791 .iter()
792 .find(|c| c.index == idx)
793 .map(|c| self.node_text(c.node, source))
794 });
795
796 let args_node: Option<Node> = args_idx.and_then(|idx| {
797 match_
798 .captures
799 .iter()
800 .find(|c| c.index == idx)
801 .map(|c| c.node)
802 });
803
804 let (sink_name, sink_type) = if let Some(method) = method_name {
806 let full_name = if let Some(obj) = obj_text {
807 format!("{}.{}", obj, method)
808 } else {
809 method.to_string()
810 };
811
812 if let Some(sink_type) = self.typescript_sinks.get(method) {
813 (full_name, *sink_type)
814 } else if let Some(sink_type) = self.typescript_sinks.get(&full_name) {
815 (full_name, *sink_type)
816 } else {
817 continue;
818 }
819 } else if let Some(func) = func_name {
820 if let Some(sink_type) = self.typescript_sinks.get(func) {
821 (func.to_string(), *sink_type)
822 } else {
823 continue;
824 }
825 } else {
826 continue;
827 };
828
829 if let (Some(call_node), Some(args_node)) = (call_node, args_node) {
831 if let Some(finding) = self.analyze_typescript_call_args(
832 call_node, args_node, source, file_path, &sink_name, sink_type,
833 ) {
834 findings.push(finding);
835 }
836 }
837 }
838
839 Ok(findings)
840 }
841
842 fn analyze_typescript_call_args(
844 &self,
845 call_node: Node,
846 args_node: Node,
847 source: &[u8],
848 file_path: &str,
849 sink_name: &str,
850 sink_type: SqlSinkType,
851 ) -> Option<SQLInjectionFinding> {
852 let first_arg = self.get_first_typescript_arg(args_node)?;
854
855 if self.has_typescript_params(args_node, source) {
857 let query_text = self.node_text(first_arg, source);
858 if query_text.contains('$')
860 || query_text.contains('?')
861 || query_text.contains(':')
862 {
863 return None; }
865 }
866
867 let (pattern, severity, confidence, tainted_vars) =
869 self.analyze_typescript_expression(first_arg, source)?;
870
871 let code_snippet = self.node_text(call_node, source).to_string();
872 let location = Location {
873 file: file_path.to_string(),
874 line: call_node.start_position().row + 1,
875 column: call_node.start_position().column + 1,
876 end_line: call_node.end_position().row + 1,
877 end_column: call_node.end_position().column + 1,
878 };
879
880 let description = self.generate_description(&pattern, sink_name, &tainted_vars);
881 let remediation = self.generate_remediation(&pattern, "typescript");
882
883 Some(SQLInjectionFinding {
884 location,
885 severity,
886 sink_function: sink_type,
887 sink_expression: sink_name.to_string(),
888 tainted_param: 0,
889 pattern,
890 confidence,
891 code_snippet,
892 tainted_variables: tainted_vars,
893 description,
894 remediation,
895 })
896 }
897
898 fn get_first_typescript_arg<'a>(&self, args_node: Node<'a>) -> Option<Node<'a>> {
900 let mut cursor = args_node.walk();
901 for child in args_node.children(&mut cursor) {
902 match child.kind() {
903 "(" | ")" | "," => continue,
904 _ => return Some(child),
905 }
906 }
907 None
908 }
909
910 fn has_typescript_params(&self, args_node: Node, _source: &[u8]) -> bool {
912 let mut positional_args = Vec::new();
913 let mut cursor = args_node.walk();
914
915 for child in args_node.children(&mut cursor) {
916 match child.kind() {
917 "(" | ")" | "," => continue,
918 _ => positional_args.push(child),
919 }
920 }
921
922 if positional_args.len() >= 2 {
924 let second_arg = positional_args[1];
925 matches!(second_arg.kind(), "array" | "identifier")
926 } else {
927 false
928 }
929 }
930
931 fn analyze_typescript_expression(
933 &self,
934 node: Node,
935 source: &[u8],
936 ) -> Option<(UnsafePattern, Severity, f64, Vec<String>)> {
937 match node.kind() {
938 "template_string" => {
940 let mut cursor = node.walk();
942 let mut has_substitution = false;
943 let mut vars = Vec::new();
944
945 for child in node.children(&mut cursor) {
946 if child.kind() == "template_substitution" {
947 has_substitution = true;
948 vars.extend(self.collect_variables(child, source));
949 }
950 }
951
952 if has_substitution {
953 return Some((
954 UnsafePattern::TemplateLiteral,
955 Severity::Critical,
956 0.95,
957 vars,
958 ));
959 }
960 None
961 }
962
963 "binary_expression" => {
965 let op_node = node
966 .children(&mut node.walk())
967 .find(|c| c.kind() == "+" || c.kind() == "binary_operator")?;
968 let op = self.node_text(op_node, source);
969
970 if op == "+" {
971 let left = node.child(0)?;
972 let right = node.child(2)?;
973
974 let left_is_string = self.is_string_literal(left, source);
975 let right_is_string = self.is_string_literal(right, source);
976
977 if left_is_string || right_is_string {
978 let vars = self.collect_variables(node, source);
979 return Some((
980 UnsafePattern::StringConcatenation,
981 Severity::Critical,
982 0.9,
983 vars,
984 ));
985 }
986 }
987 None
988 }
989
990 "identifier" => {
992 let var_name = self.node_text(node, source).to_string();
993 Some((
994 UnsafePattern::NonLiteralArgument,
995 Severity::Medium,
996 0.6,
997 vec![var_name],
998 ))
999 }
1000
1001 _ => None,
1002 }
1003 }
1004
1005 fn node_text<'a>(&self, node: Node, source: &'a [u8]) -> &'a str {
1011 std::str::from_utf8(&source[node.start_byte()..node.end_byte()]).unwrap_or("")
1012 }
1013
1014 fn is_string_literal(&self, node: Node, source: &[u8]) -> bool {
1016 let text = self.node_text(node, source);
1017 matches!(node.kind(), "string" | "string_literal" | "template_string")
1018 || text.starts_with('"')
1019 || text.starts_with('\'')
1020 || text.starts_with('`')
1021 }
1022
1023 fn extract_fstring_variables(&self, text: &str) -> Vec<String> {
1028 let bytes = text.as_bytes();
1029
1030 if bytes.len() < 64 {
1032 return self.extract_fstring_variables_scalar(text);
1033 }
1034
1035 let open_positions = Self::find_byte_positions_simd(bytes, b'{');
1037 let close_positions = Self::find_byte_positions_simd(bytes, b'}');
1038
1039 if open_positions.is_empty() || close_positions.is_empty() {
1041 return Vec::new();
1042 }
1043
1044 self.extract_vars_from_positions(bytes, &open_positions, &close_positions)
1046 }
1047
1048 #[cfg(target_arch = "x86_64")]
1052 fn find_byte_positions_simd(haystack: &[u8], needle: u8) -> Vec<usize> {
1053 use std::arch::x86_64::{
1054 __m256i, _mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8, _mm256_set1_epi8,
1055 };
1056
1057 let len = haystack.len();
1058 let mut positions = Vec::with_capacity(len / 32 + 1);
1060
1061 if !std::arch::is_x86_feature_detected!("avx2") {
1063 for (i, &b) in haystack.iter().enumerate() {
1065 if b == needle {
1066 positions.push(i);
1067 }
1068 }
1069 return positions;
1070 }
1071
1072 unsafe {
1074 let needle_vec: __m256i = _mm256_set1_epi8(needle as i8);
1075 let mut offset = 0;
1076
1077 while offset + 32 <= len {
1079 let chunk_ptr = haystack.as_ptr().add(offset) as *const __m256i;
1080 let chunk: __m256i = _mm256_loadu_si256(chunk_ptr);
1081 let cmp: __m256i = _mm256_cmpeq_epi8(chunk, needle_vec);
1082 let mask = _mm256_movemask_epi8(cmp) as u32;
1083
1084 if mask != 0 {
1086 let mut m = mask;
1087 while m != 0 {
1088 let bit_pos = m.trailing_zeros() as usize;
1089 positions.push(offset + bit_pos);
1090 m &= m - 1; }
1092 }
1093 offset += 32;
1094 }
1095
1096 for i in offset..len {
1098 if *haystack.get_unchecked(i) == needle {
1099 positions.push(i);
1100 }
1101 }
1102 }
1103
1104 positions
1105 }
1106
1107 #[cfg(not(target_arch = "x86_64"))]
1109 fn find_byte_positions_simd(haystack: &[u8], needle: u8) -> Vec<usize> {
1110 haystack
1111 .iter()
1112 .enumerate()
1113 .filter_map(|(i, &b)| if b == needle { Some(i) } else { None })
1114 .collect()
1115 }
1116
1117 fn extract_vars_from_positions(
1121 &self,
1122 bytes: &[u8],
1123 opens: &[usize],
1124 closes: &[usize],
1125 ) -> Vec<String> {
1126 let mut vars = Vec::with_capacity(opens.len().min(closes.len()));
1127 let mut open_idx = 0;
1128 let mut close_idx = 0;
1129
1130 while open_idx < opens.len() && close_idx < closes.len() {
1131 let open_pos = opens[open_idx];
1132 let close_pos = closes[close_idx];
1133
1134 if close_pos <= open_pos {
1136 close_idx += 1;
1137 continue;
1138 }
1139
1140 if open_pos + 1 < bytes.len() && bytes[open_pos + 1] == b'{' {
1142 open_idx += 2; continue;
1144 }
1145
1146 let content = &bytes[open_pos + 1..close_pos];
1148
1149 if !content.is_empty() {
1151 if let Ok(var_str) = std::str::from_utf8(content) {
1152 let var_name = var_str
1154 .split([':', '!', '.'])
1155 .next()
1156 .unwrap_or(var_str)
1157 .trim();
1158
1159 if !var_name.is_empty() {
1160 vars.push(var_name.to_string());
1161 }
1162 }
1163 }
1164
1165 open_idx += 1;
1166 close_idx += 1;
1167 }
1168
1169 vars
1170 }
1171
1172 fn extract_fstring_variables_scalar(&self, text: &str) -> Vec<String> {
1174 let mut vars = Vec::new();
1175 let mut in_brace = false;
1176 let mut current_var = String::new();
1177
1178 for ch in text.chars() {
1179 if ch == '{' && !in_brace {
1180 in_brace = true;
1181 current_var.clear();
1182 } else if ch == '}' && in_brace {
1183 in_brace = false;
1184 if !current_var.is_empty() && !current_var.starts_with('{') {
1185 let var_name = current_var
1187 .split([':', '!', '.'])
1188 .next()
1189 .unwrap_or(¤t_var)
1190 .trim();
1191 if !var_name.is_empty() {
1192 vars.push(var_name.to_string());
1193 }
1194 }
1195 } else if in_brace {
1196 current_var.push(ch);
1197 }
1198 }
1199
1200 vars
1201 }
1202
1203 fn collect_variables(&self, node: Node, source: &[u8]) -> Vec<String> {
1205 let mut vars = Vec::new();
1206 self.collect_variables_recursive(node, source, &mut vars);
1207 vars.sort();
1208 vars.dedup();
1209 vars
1210 }
1211
1212 fn collect_variables_recursive(&self, node: Node, source: &[u8], vars: &mut Vec<String>) {
1213 if node.kind() == "identifier" {
1214 let name = self.node_text(node, source).to_string();
1215 if !["True", "False", "None", "self", "cls"].contains(&name.as_str()) {
1217 vars.push(name);
1218 }
1219 }
1220
1221 let mut cursor = node.walk();
1222 for child in node.children(&mut cursor) {
1223 self.collect_variables_recursive(child, source, vars);
1224 }
1225 }
1226
1227 fn collect_call_args(&self, node: Node, source: &[u8]) -> Vec<String> {
1229 let mut vars = Vec::new();
1230 if let Some(args) = node.child_by_field_name("arguments") {
1231 self.collect_variables_recursive(args, source, &mut vars);
1232 }
1233 vars.sort();
1234 vars.dedup();
1235 vars
1236 }
1237
1238 fn generate_description(
1240 &self,
1241 pattern: &UnsafePattern,
1242 sink_name: &str,
1243 vars: &[String],
1244 ) -> String {
1245 let var_list = if vars.is_empty() {
1246 "unknown variable".to_string()
1247 } else {
1248 vars.join(", ")
1249 };
1250
1251 match pattern {
1252 UnsafePattern::StringConcatenation => {
1253 format!(
1254 "SQL injection via string concatenation in {}(). Variables {} are concatenated into the query string.",
1255 sink_name, var_list
1256 )
1257 }
1258 UnsafePattern::FStringInterpolation => {
1259 format!(
1260 "SQL injection via f-string interpolation in {}(). Variables {} are interpolated into the query.",
1261 sink_name, var_list
1262 )
1263 }
1264 UnsafePattern::PercentFormat => {
1265 format!(
1266 "SQL injection via percent formatting in {}(). Variables {} are formatted into the query.",
1267 sink_name, var_list
1268 )
1269 }
1270 UnsafePattern::DotFormat => {
1271 format!(
1272 "SQL injection via .format() in {}(). Variables {} are formatted into the query.",
1273 sink_name, var_list
1274 )
1275 }
1276 UnsafePattern::TemplateLiteral => {
1277 format!(
1278 "SQL injection via template literal in {}(). Variables {} are interpolated into the query.",
1279 sink_name, var_list
1280 )
1281 }
1282 UnsafePattern::NonLiteralArgument => {
1283 format!(
1284 "Potential SQL injection in {}(). Variable {} is passed directly to the query.",
1285 sink_name, var_list
1286 )
1287 }
1288 }
1289 }
1290
1291 fn generate_remediation(&self, pattern: &UnsafePattern, language: &str) -> String {
1293 match (pattern, language) {
1294 (_, "python") => {
1295 "Use parameterized queries with placeholders:\n\
1296 cursor.execute(\"SELECT * FROM users WHERE id = ?\", (user_id,))\n\
1297 Or use SQLAlchemy ORM methods with proper escaping."
1298 .to_string()
1299 }
1300 (_, "typescript" | "javascript") => {
1301 "Use parameterized queries with placeholders:\n\
1302 db.query(\"SELECT * FROM users WHERE id = $1\", [userId])\n\
1303 Or use an ORM like Prisma, TypeORM, or Knex with proper parameter binding."
1304 .to_string()
1305 }
1306 _ => "Use parameterized queries instead of string interpolation.".to_string(),
1307 }
1308 }
1309}
1310
1311#[cfg(test)]
1316mod tests {
1317 use super::*;
1318
1319 fn create_temp_file(content: &str, extension: &str) -> tempfile::NamedTempFile {
1320 use std::io::Write;
1321 let mut file = tempfile::Builder::new()
1322 .suffix(extension)
1323 .tempfile()
1324 .expect("Failed to create temp file");
1325 file.write_all(content.as_bytes())
1326 .expect("Failed to write temp file");
1327 file
1328 }
1329
1330 #[test]
1335 fn test_python_fstring_injection() {
1336 let source = r#"
1337import sqlite3
1338conn = sqlite3.connect('test.db')
1339cursor = conn.cursor()
1340
1341def get_user(user_id):
1342 cursor.execute(f"SELECT * FROM users WHERE id = {user_id}")
1343 return cursor.fetchone()
1344"#;
1345 let file = create_temp_file(source, ".py");
1346 let detector = SqlInjectionDetector::new();
1347 let findings = detector
1348 .scan_file(file.path().to_str().unwrap())
1349 .expect("Scan should succeed");
1350
1351 assert!(!findings.is_empty(), "Should detect f-string injection");
1352 let finding = &findings[0];
1353 assert_eq!(finding.pattern, UnsafePattern::FStringInterpolation);
1354 assert_eq!(finding.severity, Severity::Critical);
1355 assert!(finding.tainted_variables.contains(&"user_id".to_string()));
1356 }
1357
1358 #[test]
1359 fn test_python_string_concat_injection() {
1360 let source = r#"
1361import sqlite3
1362conn = sqlite3.connect('test.db')
1363cursor = conn.cursor()
1364
1365def get_user(user_id):
1366 query = "SELECT * FROM users WHERE id = " + user_id
1367 cursor.execute(query)
1368 return cursor.fetchone()
1369"#;
1370 let file = create_temp_file(source, ".py");
1371 let detector = SqlInjectionDetector::new();
1372 let findings = detector
1373 .scan_file(file.path().to_str().unwrap())
1374 .expect("Scan should succeed");
1375
1376 assert!(!findings.is_empty(), "Should detect variable injection");
1378 }
1379
/// Scans a Python snippet that formats the SQL text with the `%` operator
/// inside the `cursor.execute` call, and checks that the first finding is
/// classified as `PercentFormat`.
#[test]
fn test_python_percent_format_injection() {
    let python_src = r#"
import sqlite3
conn = sqlite3.connect('test.db')
cursor = conn.cursor()

def get_user(user_id):
 cursor.execute("SELECT * FROM users WHERE id = %s" % user_id)
 return cursor.fetchone()
"#;
    let temp = create_temp_file(python_src, ".py");
    let scanner = SqlInjectionDetector::new();
    let path = temp.path().to_str().unwrap();
    let reported = scanner.scan_file(path).expect("Scan should succeed");

    assert!(!reported.is_empty(), "Should detect percent format injection");

    let first = &reported[0];
    assert_eq!(first.pattern, UnsafePattern::PercentFormat);
}
1401
/// Scans a Python snippet that uses a `?` placeholder with a separate
/// parameter tuple; the scan must succeed and report no findings.
#[test]
fn test_python_safe_parameterized_query() {
    let python_src = r#"
import sqlite3
conn = sqlite3.connect('test.db')
cursor = conn.cursor()

def get_user(user_id):
 cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
 return cursor.fetchone()
"#;
    let temp = create_temp_file(python_src, ".py");
    let scanner = SqlInjectionDetector::new();
    let path = temp.path().to_str().unwrap();
    let reported = scanner.scan_file(path).expect("Scan should succeed");

    assert!(
        reported.is_empty(),
        "Should NOT detect safe parameterized query"
    );
}
1424
/// Scans a Python snippet whose `cursor.execute` argument is a plain string
/// literal; the scan must succeed and report no findings.
#[test]
fn test_python_safe_literal_query() {
    let python_src = r#"
import sqlite3
conn = sqlite3.connect('test.db')
cursor = conn.cursor()

def get_all_users():
 cursor.execute("SELECT * FROM users")
 return cursor.fetchall()
"#;
    let temp = create_temp_file(python_src, ".py");
    let scanner = SqlInjectionDetector::new();
    let path = temp.path().to_str().unwrap();
    let reported = scanner.scan_file(path).expect("Scan should succeed");

    assert!(reported.is_empty(), "Should NOT detect safe literal query");
}
1444
/// Scans a TypeScript snippet that passes a template literal interpolating
/// `userId` to `pool.query`, and checks that the first finding is a critical
/// `TemplateLiteral`.
#[test]
fn test_typescript_template_literal_injection() {
    let ts_src = r#"
import { Pool } from 'pg';
const pool = new Pool();

async function getUser(userId: string) {
 const result = await pool.query(`SELECT * FROM users WHERE id = ${userId}`);
 return result.rows[0];
}
"#;
    let temp = create_temp_file(ts_src, ".ts");
    let scanner = SqlInjectionDetector::new();
    let path = temp.path().to_str().unwrap();
    let reported = scanner.scan_file(path).expect("Scan should succeed");

    assert!(
        !reported.is_empty(),
        "Should detect template literal injection"
    );

    let first = &reported[0];
    assert_eq!(first.pattern, UnsafePattern::TemplateLiteral);
    assert_eq!(first.severity, Severity::Critical);
}
1474
/// Scans a TypeScript snippet where the query is built with `+` and then
/// passed to `pool.query` as a variable; at least one finding is expected.
/// The specific pattern reported is intentionally not asserted here.
#[test]
fn test_typescript_string_concat_injection() {
    let ts_src = r#"
import { Pool } from 'pg';
const pool = new Pool();

async function getUser(userId: string) {
 const query = "SELECT * FROM users WHERE id = " + userId;
 const result = await pool.query(query);
 return result.rows[0];
}
"#;
    let temp = create_temp_file(ts_src, ".ts");
    let scanner = SqlInjectionDetector::new();
    let path = temp.path().to_str().unwrap();
    let reported = scanner.scan_file(path).expect("Scan should succeed");

    assert!(!reported.is_empty(), "Should detect variable injection");
}
1496
/// Scans a TypeScript snippet that uses a `$1` placeholder with a separate
/// values array; the scan must succeed and report no findings.
#[test]
fn test_typescript_safe_parameterized_query() {
    let ts_src = r#"
import { Pool } from 'pg';
const pool = new Pool();

async function getUser(userId: string) {
 const result = await pool.query("SELECT * FROM users WHERE id = $1", [userId]);
 return result.rows[0];
}
"#;
    let temp = create_temp_file(ts_src, ".ts");
    let scanner = SqlInjectionDetector::new();
    let path = temp.path().to_str().unwrap();
    let reported = scanner.scan_file(path).expect("Scan should succeed");

    assert!(
        reported.is_empty(),
        "Should NOT detect safe parameterized query"
    );
}
1519
/// Scans a TypeScript snippet whose `pool.query` argument is a plain string
/// literal; the scan must succeed and report no findings.
#[test]
fn test_typescript_safe_literal_query() {
    let ts_src = r#"
import { Pool } from 'pg';
const pool = new Pool();

async function getAllUsers() {
 const result = await pool.query("SELECT * FROM users");
 return result.rows;
}
"#;
    let temp = create_temp_file(ts_src, ".ts");
    let scanner = SqlInjectionDetector::new();
    let path = temp.path().to_str().unwrap();
    let reported = scanner.scan_file(path).expect("Scan should succeed");

    assert!(reported.is_empty(), "Should NOT detect safe literal query");
}
1539
/// Scans a TypeScript snippet that calls `prisma.$queryRaw` with a template
/// literal interpolating `userId`; at least one finding is expected.
#[test]
fn test_typescript_prisma_raw_injection() {
    let ts_src = r#"
import { PrismaClient } from '@prisma/client';
const prisma = new PrismaClient();

async function getUser(userId: string) {
 return prisma.$queryRaw(`SELECT * FROM users WHERE id = ${userId}`);
}
"#;
    let temp = create_temp_file(ts_src, ".ts");
    let scanner = SqlInjectionDetector::new();
    let path = temp.path().to_str().unwrap();
    let reported = scanner.scan_file(path).expect("Scan should succeed");

    assert!(!reported.is_empty(), "Should detect Prisma raw query injection");
}
1558
/// Checks `extract_fstring_variables` against three f-string inputs: a single
/// placeholder, several placeholders (order preserved), and a placeholder
/// carrying a format spec (`{x:.2f}`), for which only `x` is expected.
#[test]
fn test_extract_fstring_variables() {
    let detector = SqlInjectionDetector::new();

    // (f-string source text, names expected back in order)
    let cases = [
        (
            r#"f"SELECT * FROM users WHERE id = {user_id}""#,
            vec!["user_id"],
        ),
        (
            r#"f"SELECT * FROM {table} WHERE {col} = {val}""#,
            vec!["table", "col", "val"],
        ),
        (r#"f"value: {x:.2f}""#, vec!["x"]),
    ];

    for (fstring, expected) in cases {
        let vars = detector.extract_fstring_variables(fstring);
        assert_eq!(vars, expected);
    }
}
1576
/// Verifies that each `Severity` variant renders as its upper-case label.
#[test]
fn test_severity_display() {
    let expectations = [
        (Severity::Critical, "CRITICAL"),
        (Severity::High, "HIGH"),
        (Severity::Medium, "MEDIUM"),
        (Severity::Low, "LOW"),
    ];

    for (severity, label) in expectations {
        assert_eq!(severity.to_string(), label);
    }
}
1584
/// Verifies the snake_case rendering of three `UnsafePattern` variants:
/// string concatenation, f-string interpolation and template literal.
#[test]
fn test_pattern_display() {
    let expectations = [
        (UnsafePattern::StringConcatenation, "string_concatenation"),
        (UnsafePattern::FStringInterpolation, "f_string_interpolation"),
        (UnsafePattern::TemplateLiteral, "template_literal"),
    ];

    for (pattern, label) in expectations {
        assert_eq!(pattern.to_string(), label);
    }
}
1600
/// Builds a `ScanResult` by hand and verifies that every field reads back
/// exactly as constructed, including both severity buckets and the absence
/// of a bucket that was never inserted.
#[test]
fn test_scan_result_counts() {
    let result = ScanResult {
        findings: vec![],
        files_scanned: 10,
        sinks_found: 5,
        severity_counts: [("CRITICAL".to_string(), 2), ("HIGH".to_string(), 3)]
            .into_iter()
            .collect(),
        language: "python".to_string(),
    };

    // Scalar and collection fields.
    assert!(result.findings.is_empty());
    assert_eq!(result.files_scanned, 10);
    assert_eq!(result.sinks_found, 5);
    assert_eq!(result.language, "python");

    // Severity histogram: exactly the two inserted buckets, nothing else.
    assert_eq!(result.severity_counts.len(), 2);
    assert_eq!(result.severity_counts.get("CRITICAL"), Some(&2));
    assert_eq!(result.severity_counts.get("HIGH"), Some(&3));
    assert_eq!(result.severity_counts.get("MEDIUM"), None);
}
1617}