go_brrr/security/injection/
sql.rs

1//! SQL Injection detection for Python and TypeScript.
2//!
3//! Detects potential SQL injection vulnerabilities by analyzing:
4//! - String concatenation in SQL queries
5//! - Format string interpolation
6//! - f-string interpolation (Python)
7//! - Template literal interpolation (TypeScript)
8//! - Non-literal arguments to SQL execution functions
9//!
10//! # Architecture
11//!
12//! The detector works in three phases:
13//! 1. **Sink Detection**: Find SQL execution calls (execute, query, raw, etc.)
14//! 2. **Pattern Analysis**: Check if query argument uses unsafe string construction
15//! 3. **Taint Tracking**: Trace data flow from user inputs to SQL sinks (via DFG)
16//!
17//! # Safe Patterns (Not Flagged)
18//!
19//! - Parameterized queries: `cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))`
20//! - ORM methods with proper escaping: `User.objects.filter(id=user_id)`
21//! - Literal string queries: `cursor.execute("SELECT * FROM users")`
22//!
23//! # Unsafe Patterns (Flagged)
24//!
25//! - String concatenation: `cursor.execute("SELECT * FROM users WHERE id = " + user_id)`
26//! - f-strings: `cursor.execute(f"SELECT * FROM users WHERE id = {user_id}")`
27//! - Format strings: `cursor.execute("SELECT * FROM users WHERE id = %s" % user_id)`
//! - Template literals: ``db.query(`SELECT * FROM users WHERE id = ${userId}`)``
29
30use std::collections::{HashMap, HashSet};
31use std::path::Path;
32
33use serde::{Deserialize, Serialize};
34use streaming_iterator::StreamingIterator;
35use tree_sitter::{Node, Query, QueryCursor, Tree};
36
37use crate::error::{Result, BrrrError};
38use crate::lang::LanguageRegistry;
39
40// =============================================================================
41// Types
42// =============================================================================
43
44/// Severity level for SQL injection findings.
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
46pub enum Severity {
47    /// Direct SQL injection with high confidence (e.g., string concat in execute)
48    Critical,
49    /// Likely SQL injection (e.g., f-string in query method)
50    High,
51    /// Possible SQL injection (e.g., variable passed to query, needs review)
52    Medium,
53    /// Informational finding (e.g., dynamic query construction detected)
54    Low,
55}
56
57impl std::fmt::Display for Severity {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        match self {
60            Severity::Critical => write!(f, "CRITICAL"),
61            Severity::High => write!(f, "HIGH"),
62            Severity::Medium => write!(f, "MEDIUM"),
63            Severity::Low => write!(f, "LOW"),
64        }
65    }
66}
67
68/// Location of a finding in source code.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct Location {
71    /// File path
72    pub file: String,
73    /// Line number (1-indexed)
74    pub line: usize,
75    /// Column number (1-indexed)
76    pub column: usize,
77    /// End line number (1-indexed)
78    pub end_line: usize,
79    /// End column number (1-indexed)
80    pub end_column: usize,
81}
82
83/// Type of SQL sink function.
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
85pub enum SqlSinkType {
86    /// cursor.execute(), connection.execute()
87    Execute,
88    /// db.query(), pool.query()
89    Query,
90    /// engine.raw(), knex.raw()
91    Raw,
92    /// $queryRaw, $executeRaw (Prisma)
93    PrismaRaw,
94    /// session.execute() (SQLAlchemy)
95    SessionExecute,
96    /// Text() for raw SQL in SQLAlchemy
97    TextConstruct,
98    /// Other sink type
99    Other,
100}
101
102impl std::fmt::Display for SqlSinkType {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        match self {
105            SqlSinkType::Execute => write!(f, "execute"),
106            SqlSinkType::Query => write!(f, "query"),
107            SqlSinkType::Raw => write!(f, "raw"),
108            SqlSinkType::PrismaRaw => write!(f, "prisma_raw"),
109            SqlSinkType::SessionExecute => write!(f, "session_execute"),
110            SqlSinkType::TextConstruct => write!(f, "text_construct"),
111            SqlSinkType::Other => write!(f, "other"),
112        }
113    }
114}
115
116/// Type of unsafe pattern detected.
117#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
118pub enum UnsafePattern {
119    /// String concatenation: "SELECT ... " + var
120    StringConcatenation,
121    /// Python f-string: f"SELECT ... {var}"
122    FStringInterpolation,
123    /// Python format: "SELECT ... %s" % var
124    PercentFormat,
125    /// Python .format(): "SELECT ... {}".format(var)
126    DotFormat,
127    /// JavaScript template literal: `SELECT ... ${var}`
128    TemplateLiteral,
129    /// Variable passed directly (not a literal)
130    NonLiteralArgument,
131}
132
133impl std::fmt::Display for UnsafePattern {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        match self {
136            UnsafePattern::StringConcatenation => write!(f, "string_concatenation"),
137            UnsafePattern::FStringInterpolation => write!(f, "f_string_interpolation"),
138            UnsafePattern::PercentFormat => write!(f, "percent_format"),
139            UnsafePattern::DotFormat => write!(f, "dot_format"),
140            UnsafePattern::TemplateLiteral => write!(f, "template_literal"),
141            UnsafePattern::NonLiteralArgument => write!(f, "non_literal_argument"),
142        }
143    }
144}
145
146/// A SQL injection finding.
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct SQLInjectionFinding {
149    /// Location in source code
150    pub location: Location,
151    /// Severity level
152    pub severity: Severity,
153    /// Type of SQL sink function
154    pub sink_function: SqlSinkType,
155    /// Full sink call expression (e.g., "cursor.execute")
156    pub sink_expression: String,
157    /// Which parameter is tainted (0-indexed)
158    pub tainted_param: usize,
159    /// The unsafe pattern detected
160    pub pattern: UnsafePattern,
161    /// Confidence score (0.0 to 1.0)
162    pub confidence: f64,
163    /// Code snippet showing the vulnerable code
164    pub code_snippet: String,
165    /// Variables involved in the taint chain
166    pub tainted_variables: Vec<String>,
167    /// Human-readable description
168    pub description: String,
169    /// Suggested fix
170    pub remediation: String,
171}
172
173/// Summary of SQL injection scan results.
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct ScanResult {
176    /// All findings
177    pub findings: Vec<SQLInjectionFinding>,
178    /// Number of files scanned
179    pub files_scanned: usize,
180    /// Number of SQL sinks found
181    pub sinks_found: usize,
182    /// Counts by severity
183    pub severity_counts: HashMap<String, usize>,
184    /// Language detected
185    pub language: String,
186}
187
188// =============================================================================
189// SQL Sink Definitions
190// =============================================================================
191
192/// Known SQL sinks for Python.
193const PYTHON_SQL_SINKS: &[(&str, SqlSinkType)] = &[
194    // Database cursor methods
195    ("execute", SqlSinkType::Execute),
196    ("executemany", SqlSinkType::Execute),
197    ("executescript", SqlSinkType::Execute),
198    // Connection methods
199    ("connection.execute", SqlSinkType::Execute),
200    // SQLAlchemy
201    ("engine.execute", SqlSinkType::Execute),
202    ("session.execute", SqlSinkType::SessionExecute),
203    ("text", SqlSinkType::TextConstruct),
204    // Django
205    ("raw", SqlSinkType::Raw),
206    ("extra", SqlSinkType::Raw),
207    // psycopg2/asyncpg
208    ("cursor.execute", SqlSinkType::Execute),
209    ("conn.execute", SqlSinkType::Execute),
210    ("pool.execute", SqlSinkType::Execute),
211    // aiosqlite
212    ("db.execute", SqlSinkType::Execute),
213];
214
215/// Known SQL sinks for TypeScript/JavaScript.
216const TYPESCRIPT_SQL_SINKS: &[(&str, SqlSinkType)] = &[
217    // Generic query methods
218    ("query", SqlSinkType::Query),
219    ("execute", SqlSinkType::Execute),
220    // Knex
221    ("raw", SqlSinkType::Raw),
222    ("knex.raw", SqlSinkType::Raw),
223    // Prisma
224    ("$queryRaw", SqlSinkType::PrismaRaw),
225    ("$executeRaw", SqlSinkType::PrismaRaw),
226    ("$queryRawUnsafe", SqlSinkType::PrismaRaw),
227    ("$executeRawUnsafe", SqlSinkType::PrismaRaw),
228    // TypeORM
229    ("createQueryRunner", SqlSinkType::Query),
230    ("manager.query", SqlSinkType::Query),
231    // Sequelize
232    ("sequelize.query", SqlSinkType::Query),
233    // node-postgres
234    ("pool.query", SqlSinkType::Query),
235    ("client.query", SqlSinkType::Query),
236    // mysql2
237    ("connection.query", SqlSinkType::Query),
238    ("connection.execute", SqlSinkType::Execute),
239    // better-sqlite3
240    ("db.exec", SqlSinkType::Execute),
241    ("db.prepare", SqlSinkType::Execute),
242];
243
244// =============================================================================
245// Detector Implementation
246// =============================================================================
247
248/// SQL Injection detector for multiple languages.
249pub struct SqlInjectionDetector {
250    /// Python SQL sink patterns (method_name -> sink_type)
251    python_sinks: HashMap<String, SqlSinkType>,
252    /// TypeScript SQL sink patterns
253    typescript_sinks: HashMap<String, SqlSinkType>,
254}
255
256impl Default for SqlInjectionDetector {
257    fn default() -> Self {
258        Self::new()
259    }
260}
261
262impl SqlInjectionDetector {
263    /// Create a new SQL injection detector.
264    pub fn new() -> Self {
265        let mut python_sinks = HashMap::new();
266        for (name, sink_type) in PYTHON_SQL_SINKS {
267            python_sinks.insert((*name).to_string(), *sink_type);
268        }
269
270        let mut typescript_sinks = HashMap::new();
271        for (name, sink_type) in TYPESCRIPT_SQL_SINKS {
272            typescript_sinks.insert((*name).to_string(), *sink_type);
273        }
274
275        Self {
276            python_sinks,
277            typescript_sinks,
278        }
279    }
280
281    /// Scan a file for SQL injection vulnerabilities.
282    ///
283    /// # Arguments
284    ///
285    /// * `file_path` - Path to the source file
286    ///
287    /// # Returns
288    ///
289    /// Vector of SQL injection findings
290    pub fn scan_file(&self, file_path: &str) -> Result<Vec<SQLInjectionFinding>> {
291        let path = Path::new(file_path);
292        let registry = LanguageRegistry::global();
293
294        let lang = registry.detect_language(path).ok_or_else(|| {
295            BrrrError::UnsupportedLanguage(
296                path.extension()
297                    .and_then(|e| e.to_str())
298                    .unwrap_or("unknown")
299                    .to_string(),
300            )
301        })?;
302
303        let source = std::fs::read(path).map_err(|e| BrrrError::io_with_path(e, path))?;
304        let mut parser = lang.parser_for_path(path)?;
305        let tree = parser.parse(&source, None).ok_or_else(|| BrrrError::Parse {
306            file: file_path.to_string(),
307            message: "Failed to parse file".to_string(),
308        })?;
309
310        let lang_name = lang.name();
311        match lang_name {
312            "python" => self.scan_python(&tree, &source, file_path),
313            "typescript" | "javascript" => self.scan_typescript(&tree, &source, file_path),
314            _ => Ok(vec![]), // Other languages not yet supported
315        }
316    }
317
318    /// Scan a directory for SQL injection vulnerabilities.
319    ///
320    /// # Arguments
321    ///
322    /// * `dir_path` - Path to the directory
323    /// * `language` - Optional language filter
324    ///
325    /// # Returns
326    ///
327    /// Scan result with all findings
328    pub fn scan_directory(&self, dir_path: &str, language: Option<&str>) -> Result<ScanResult> {
329        let path = Path::new(dir_path);
330        if !path.is_dir() {
331            return Err(BrrrError::InvalidArgument(format!(
332                "Not a directory: {}",
333                dir_path
334            )));
335        }
336
337        let mut findings = Vec::new();
338        let mut files_scanned = 0;
339        let mut sinks_found = 0;
340
341        // Walk directory respecting .gitignore and .brrrignore
342        let mut builder = ignore::WalkBuilder::new(path);
343        builder.add_custom_ignore_filename(".brrrignore");
344        builder.hidden(true);
345
346        let extensions: HashSet<&str> = match language {
347            Some("python") => ["py"].iter().copied().collect(),
348            Some("typescript") => ["ts", "tsx", "js", "jsx", "mjs", "cjs"]
349                .iter()
350                .copied()
351                .collect(),
352            Some("javascript") => ["js", "jsx", "mjs", "cjs"].iter().copied().collect(),
353            _ => ["py", "ts", "tsx", "js", "jsx", "mjs", "cjs"]
354                .iter()
355                .copied()
356                .collect(),
357        };
358
359        for entry in builder.build().flatten() {
360            let entry_path = entry.path();
361            if !entry_path.is_file() {
362                continue;
363            }
364
365            let ext = entry_path
366                .extension()
367                .and_then(|e| e.to_str())
368                .unwrap_or("");
369            if !extensions.contains(ext) {
370                continue;
371            }
372
373            files_scanned += 1;
374
375            if let Ok(file_findings) = self.scan_file(entry_path.to_str().unwrap_or("")) {
376                sinks_found += file_findings.len();
377                findings.extend(file_findings);
378            }
379        }
380
381        // Count by severity
382        let mut severity_counts: HashMap<String, usize> = HashMap::new();
383        for finding in &findings {
384            *severity_counts
385                .entry(finding.severity.to_string())
386                .or_insert(0) += 1;
387        }
388
389        let detected_lang = language.unwrap_or("mixed").to_string();
390
391        Ok(ScanResult {
392            findings,
393            files_scanned,
394            sinks_found,
395            severity_counts,
396            language: detected_lang,
397        })
398    }
399
400    // =========================================================================
401    // Python Analysis
402    // =========================================================================
403
404    /// Scan Python source for SQL injection.
405    fn scan_python(
406        &self,
407        tree: &Tree,
408        source: &[u8],
409        file_path: &str,
410    ) -> Result<Vec<SQLInjectionFinding>> {
411        let mut findings = Vec::new();
412
413        // Query for call expressions
414        let query_str = r#"
415            (call
416                function: [
417                    (identifier) @func_name
418                    (attribute
419                        object: (_) @obj
420                        attribute: (identifier) @method_name)
421                ]
422                arguments: (argument_list) @args
423            ) @call
424        "#;
425
426        let ts_lang = tree.language();
427        let query = Query::new(&ts_lang, query_str).map_err(|e| {
428            BrrrError::TreeSitter(format!("Failed to create Python query: {}", e))
429        })?;
430
431        let mut cursor = QueryCursor::new();
432        let mut matches = cursor.matches(&query, tree.root_node(), source);
433
434        let func_name_idx = query.capture_index_for_name("func_name");
435        let method_name_idx = query.capture_index_for_name("method_name");
436        let obj_idx = query.capture_index_for_name("obj");
437        let args_idx = query.capture_index_for_name("args");
438        let call_idx = query.capture_index_for_name("call");
439
440        while let Some(match_) = matches.next() {
441            // Get call node
442            let call_node: Option<Node> = match call_idx {
443                Some(idx) => match_.captures.iter().find(|c| c.index == idx).map(|c| c.node),
444                None => None,
445            };
446
447            // Get function/method name
448            let func_name: Option<&str> = func_name_idx.and_then(|idx| {
449                match_
450                    .captures
451                    .iter()
452                    .find(|c| c.index == idx)
453                    .map(|c| self.node_text(c.node, source))
454            });
455
456            let method_name: Option<&str> = method_name_idx.and_then(|idx| {
457                match_
458                    .captures
459                    .iter()
460                    .find(|c| c.index == idx)
461                    .map(|c| self.node_text(c.node, source))
462            });
463
464            let obj_text: Option<&str> = obj_idx.and_then(|idx| {
465                match_
466                    .captures
467                    .iter()
468                    .find(|c| c.index == idx)
469                    .map(|c| self.node_text(c.node, source))
470            });
471
472            let args_node: Option<Node> = args_idx.and_then(|idx| {
473                match_
474                    .captures
475                    .iter()
476                    .find(|c| c.index == idx)
477                    .map(|c| c.node)
478            });
479
480            // Determine if this is a SQL sink
481            let (sink_name, sink_type) = if let Some(method) = method_name {
482                let full_name = if let Some(obj) = obj_text {
483                    format!("{}.{}", obj, method)
484                } else {
485                    method.to_string()
486                };
487
488                // Check if method is a known sink
489                if let Some(sink_type) = self.python_sinks.get(method) {
490                    (full_name, *sink_type)
491                } else if let Some(sink_type) = self.python_sinks.get(&full_name) {
492                    (full_name, *sink_type)
493                } else {
494                    continue;
495                }
496            } else if let Some(func) = func_name {
497                if let Some(sink_type) = self.python_sinks.get(func) {
498                    (func.to_string(), *sink_type)
499                } else {
500                    continue;
501                }
502            } else {
503                continue;
504            };
505
506            // Analyze arguments for unsafe patterns
507            if let (Some(call_node), Some(args_node)) = (call_node, args_node) {
508                if let Some(finding) = self.analyze_python_call_args(
509                    call_node, args_node, source, file_path, &sink_name, sink_type,
510                ) {
511                    findings.push(finding);
512                }
513            }
514        }
515
516        Ok(findings)
517    }
518
519    /// Analyze Python call arguments for SQL injection patterns.
520    fn analyze_python_call_args(
521        &self,
522        call_node: Node,
523        args_node: Node,
524        source: &[u8],
525        file_path: &str,
526        sink_name: &str,
527        sink_type: SqlSinkType,
528    ) -> Option<SQLInjectionFinding> {
529        // Get first argument (the SQL query)
530        let first_arg = self.get_first_python_arg(args_node)?;
531
532        // Check for parameterized query (safe pattern)
533        // Look for tuple as second argument: execute("...", (param,))
534        if self.has_python_params(args_node, source) {
535            // Check if query uses parameterized placeholders (?, %s, :name)
536            let query_text = self.node_text(first_arg, source);
537            if query_text.contains('?')
538                || query_text.contains("%s")
539                || query_text.contains("%(")
540                || query_text.contains(':')
541            {
542                return None; // Safe parameterized query
543            }
544        }
545
546        // Analyze the first argument for unsafe patterns
547        let (pattern, severity, confidence, tainted_vars) =
548            self.analyze_python_expression(first_arg, source)?;
549
550        let code_snippet = self.node_text(call_node, source).to_string();
551        let location = Location {
552            file: file_path.to_string(),
553            line: call_node.start_position().row + 1,
554            column: call_node.start_position().column + 1,
555            end_line: call_node.end_position().row + 1,
556            end_column: call_node.end_position().column + 1,
557        };
558
559        let description = self.generate_description(&pattern, sink_name, &tainted_vars);
560        let remediation = self.generate_remediation(&pattern, "python");
561
562        Some(SQLInjectionFinding {
563            location,
564            severity,
565            sink_function: sink_type,
566            sink_expression: sink_name.to_string(),
567            tainted_param: 0,
568            pattern,
569            confidence,
570            code_snippet,
571            tainted_variables: tainted_vars,
572            description,
573            remediation,
574        })
575    }
576
577    /// Get the first positional argument from Python argument list.
578    fn get_first_python_arg<'a>(&self, args_node: Node<'a>) -> Option<Node<'a>> {
579        let mut cursor = args_node.walk();
580        for child in args_node.children(&mut cursor) {
581            match child.kind() {
582                "(" | ")" | "," => continue,
583                "keyword_argument" => continue, // Skip keyword args
584                _ => return Some(child),
585            }
586        }
587        None
588    }
589
590    /// Check if Python call has parameter arguments (tuple/list as second arg).
591    fn has_python_params(&self, args_node: Node, _source: &[u8]) -> bool {
592        let mut positional_args = Vec::new();
593        let mut cursor = args_node.walk();
594
595        for child in args_node.children(&mut cursor) {
596            match child.kind() {
597                "(" | ")" | "," => continue,
598                "keyword_argument" => continue,
599                _ => positional_args.push(child),
600            }
601        }
602
603        // Check if there's a second argument that looks like params
604        if positional_args.len() >= 2 {
605            let second_arg = positional_args[1];
606            matches!(
607                second_arg.kind(),
608                "tuple" | "list" | "dictionary" | "identifier"
609            )
610        } else {
611            false
612        }
613    }
614
615    /// Analyze a Python expression for unsafe patterns.
616    ///
617    /// Returns (pattern, severity, confidence, tainted_variables)
618    fn analyze_python_expression(
619        &self,
620        node: Node,
621        source: &[u8],
622    ) -> Option<(UnsafePattern, Severity, f64, Vec<String>)> {
623        match node.kind() {
624            // f-string: f"SELECT ... {var}"
625            "string" => {
626                let text = self.node_text(node, source);
627                if text.starts_with("f\"") || text.starts_with("f'") {
628                    // Extract interpolated variables
629                    let vars = self.extract_fstring_variables(text);
630                    if !vars.is_empty() {
631                        return Some((
632                            UnsafePattern::FStringInterpolation,
633                            Severity::Critical,
634                            0.95,
635                            vars,
636                        ));
637                    }
638                }
639                None
640            }
641
642            // Binary operator: "SELECT ... " + var
643            "binary_operator" => {
644                let op_node = node.child_by_field_name("operator")?;
645                let op = self.node_text(op_node, source);
646
647                if op == "+" {
648                    // String concatenation
649                    let left = node.child_by_field_name("left")?;
650                    let right = node.child_by_field_name("right")?;
651
652                    let left_is_string = self.is_string_literal(left, source);
653                    let right_is_string = self.is_string_literal(right, source);
654
655                    if left_is_string || right_is_string {
656                        let vars = self.collect_variables(node, source);
657                        return Some((
658                            UnsafePattern::StringConcatenation,
659                            Severity::Critical,
660                            0.9,
661                            vars,
662                        ));
663                    }
664                } else if op == "%" {
665                    // Percent format: "SELECT ... %s" % var
666                    let vars = self.collect_variables(node, source);
667                    return Some((UnsafePattern::PercentFormat, Severity::Critical, 0.9, vars));
668                }
669                None
670            }
671
672            // Method call: "SELECT ... {}".format(var)
673            "call" => {
674                // Check for .format() calls
675                if let Some(func) = node.child_by_field_name("function") {
676                    if func.kind() == "attribute" {
677                        if let Some(attr) = func.child_by_field_name("attribute") {
678                            if self.node_text(attr, source) == "format" {
679                                let vars = self.collect_call_args(node, source);
680                                return Some((
681                                    UnsafePattern::DotFormat,
682                                    Severity::Critical,
683                                    0.9,
684                                    vars,
685                                ));
686                            }
687                        }
688                    }
689                }
690                None
691            }
692
693            // Identifier: variable passed directly
694            "identifier" => {
695                let var_name = self.node_text(node, source).to_string();
696                // Lower severity for variable - might be a constant
697                Some((
698                    UnsafePattern::NonLiteralArgument,
699                    Severity::Medium,
700                    0.6,
701                    vec![var_name],
702                ))
703            }
704
705            // Concatenated string: handle multi-part strings
706            "concatenated_string" => {
707                let text = self.node_text(node, source);
708                if text.contains("f\"") || text.contains("f'") {
709                    let vars = self.extract_fstring_variables(text);
710                    if !vars.is_empty() {
711                        return Some((
712                            UnsafePattern::FStringInterpolation,
713                            Severity::Critical,
714                            0.95,
715                            vars,
716                        ));
717                    }
718                }
719                None
720            }
721
722            _ => None,
723        }
724    }
725
726    // =========================================================================
727    // TypeScript/JavaScript Analysis
728    // =========================================================================
729
730    /// Scan TypeScript/JavaScript source for SQL injection.
731    fn scan_typescript(
732        &self,
733        tree: &Tree,
734        source: &[u8],
735        file_path: &str,
736    ) -> Result<Vec<SQLInjectionFinding>> {
737        let mut findings = Vec::new();
738
739        // Query for call expressions
740        let query_str = r#"
741            (call_expression
742                function: [
743                    (identifier) @func_name
744                    (member_expression
745                        object: (_) @obj
746                        property: (property_identifier) @method_name)
747                ]
748                arguments: (arguments) @args
749            ) @call
750        "#;
751
752        let ts_lang = tree.language();
753        let query = Query::new(&ts_lang, query_str).map_err(|e| {
754            BrrrError::TreeSitter(format!("Failed to create TypeScript query: {}", e))
755        })?;
756
757        let mut cursor = QueryCursor::new();
758        let mut matches = cursor.matches(&query, tree.root_node(), source);
759
760        let func_name_idx = query.capture_index_for_name("func_name");
761        let method_name_idx = query.capture_index_for_name("method_name");
762        let obj_idx = query.capture_index_for_name("obj");
763        let args_idx = query.capture_index_for_name("args");
764        let call_idx = query.capture_index_for_name("call");
765
766        while let Some(match_) = matches.next() {
767            let call_node: Option<Node> = match call_idx {
768                Some(idx) => match_.captures.iter().find(|c| c.index == idx).map(|c| c.node),
769                None => None,
770            };
771
772            let func_name: Option<&str> = func_name_idx.and_then(|idx| {
773                match_
774                    .captures
775                    .iter()
776                    .find(|c| c.index == idx)
777                    .map(|c| self.node_text(c.node, source))
778            });
779
780            let method_name: Option<&str> = method_name_idx.and_then(|idx| {
781                match_
782                    .captures
783                    .iter()
784                    .find(|c| c.index == idx)
785                    .map(|c| self.node_text(c.node, source))
786            });
787
788            let obj_text: Option<&str> = obj_idx.and_then(|idx| {
789                match_
790                    .captures
791                    .iter()
792                    .find(|c| c.index == idx)
793                    .map(|c| self.node_text(c.node, source))
794            });
795
796            let args_node: Option<Node> = args_idx.and_then(|idx| {
797                match_
798                    .captures
799                    .iter()
800                    .find(|c| c.index == idx)
801                    .map(|c| c.node)
802            });
803
804            // Determine if this is a SQL sink
805            let (sink_name, sink_type) = if let Some(method) = method_name {
806                let full_name = if let Some(obj) = obj_text {
807                    format!("{}.{}", obj, method)
808                } else {
809                    method.to_string()
810                };
811
812                if let Some(sink_type) = self.typescript_sinks.get(method) {
813                    (full_name, *sink_type)
814                } else if let Some(sink_type) = self.typescript_sinks.get(&full_name) {
815                    (full_name, *sink_type)
816                } else {
817                    continue;
818                }
819            } else if let Some(func) = func_name {
820                if let Some(sink_type) = self.typescript_sinks.get(func) {
821                    (func.to_string(), *sink_type)
822                } else {
823                    continue;
824                }
825            } else {
826                continue;
827            };
828
829            // Analyze arguments
830            if let (Some(call_node), Some(args_node)) = (call_node, args_node) {
831                if let Some(finding) = self.analyze_typescript_call_args(
832                    call_node, args_node, source, file_path, &sink_name, sink_type,
833                ) {
834                    findings.push(finding);
835                }
836            }
837        }
838
839        Ok(findings)
840    }
841
842    /// Analyze TypeScript call arguments for SQL injection patterns.
843    fn analyze_typescript_call_args(
844        &self,
845        call_node: Node,
846        args_node: Node,
847        source: &[u8],
848        file_path: &str,
849        sink_name: &str,
850        sink_type: SqlSinkType,
851    ) -> Option<SQLInjectionFinding> {
852        // Get first argument
853        let first_arg = self.get_first_typescript_arg(args_node)?;
854
855        // Check for parameterized query with array as second arg
856        if self.has_typescript_params(args_node, source) {
857            let query_text = self.node_text(first_arg, source);
858            // Check for parameterized placeholders ($1, ?, :name)
859            if query_text.contains('$')
860                || query_text.contains('?')
861                || query_text.contains(':')
862            {
863                return None; // Safe parameterized query
864            }
865        }
866
867        // Analyze expression
868        let (pattern, severity, confidence, tainted_vars) =
869            self.analyze_typescript_expression(first_arg, source)?;
870
871        let code_snippet = self.node_text(call_node, source).to_string();
872        let location = Location {
873            file: file_path.to_string(),
874            line: call_node.start_position().row + 1,
875            column: call_node.start_position().column + 1,
876            end_line: call_node.end_position().row + 1,
877            end_column: call_node.end_position().column + 1,
878        };
879
880        let description = self.generate_description(&pattern, sink_name, &tainted_vars);
881        let remediation = self.generate_remediation(&pattern, "typescript");
882
883        Some(SQLInjectionFinding {
884            location,
885            severity,
886            sink_function: sink_type,
887            sink_expression: sink_name.to_string(),
888            tainted_param: 0,
889            pattern,
890            confidence,
891            code_snippet,
892            tainted_variables: tainted_vars,
893            description,
894            remediation,
895        })
896    }
897
898    /// Get first positional argument from TypeScript arguments.
899    fn get_first_typescript_arg<'a>(&self, args_node: Node<'a>) -> Option<Node<'a>> {
900        let mut cursor = args_node.walk();
901        for child in args_node.children(&mut cursor) {
902            match child.kind() {
903                "(" | ")" | "," => continue,
904                _ => return Some(child),
905            }
906        }
907        None
908    }
909
910    /// Check if TypeScript call has parameter arguments.
911    fn has_typescript_params(&self, args_node: Node, _source: &[u8]) -> bool {
912        let mut positional_args = Vec::new();
913        let mut cursor = args_node.walk();
914
915        for child in args_node.children(&mut cursor) {
916            match child.kind() {
917                "(" | ")" | "," => continue,
918                _ => positional_args.push(child),
919            }
920        }
921
922        // Check if there's an array as second argument
923        if positional_args.len() >= 2 {
924            let second_arg = positional_args[1];
925            matches!(second_arg.kind(), "array" | "identifier")
926        } else {
927            false
928        }
929    }
930
931    /// Analyze a TypeScript expression for unsafe patterns.
932    fn analyze_typescript_expression(
933        &self,
934        node: Node,
935        source: &[u8],
936    ) -> Option<(UnsafePattern, Severity, f64, Vec<String>)> {
937        match node.kind() {
938            // Template literal: `SELECT ... ${var}`
939            "template_string" => {
940                // Check for interpolation
941                let mut cursor = node.walk();
942                let mut has_substitution = false;
943                let mut vars = Vec::new();
944
945                for child in node.children(&mut cursor) {
946                    if child.kind() == "template_substitution" {
947                        has_substitution = true;
948                        vars.extend(self.collect_variables(child, source));
949                    }
950                }
951
952                if has_substitution {
953                    return Some((
954                        UnsafePattern::TemplateLiteral,
955                        Severity::Critical,
956                        0.95,
957                        vars,
958                    ));
959                }
960                None
961            }
962
963            // Binary expression: "SELECT ... " + var
964            "binary_expression" => {
965                let op_node = node
966                    .children(&mut node.walk())
967                    .find(|c| c.kind() == "+" || c.kind() == "binary_operator")?;
968                let op = self.node_text(op_node, source);
969
970                if op == "+" {
971                    let left = node.child(0)?;
972                    let right = node.child(2)?;
973
974                    let left_is_string = self.is_string_literal(left, source);
975                    let right_is_string = self.is_string_literal(right, source);
976
977                    if left_is_string || right_is_string {
978                        let vars = self.collect_variables(node, source);
979                        return Some((
980                            UnsafePattern::StringConcatenation,
981                            Severity::Critical,
982                            0.9,
983                            vars,
984                        ));
985                    }
986                }
987                None
988            }
989
990            // Identifier: variable passed directly
991            "identifier" => {
992                let var_name = self.node_text(node, source).to_string();
993                Some((
994                    UnsafePattern::NonLiteralArgument,
995                    Severity::Medium,
996                    0.6,
997                    vec![var_name],
998                ))
999            }
1000
1001            _ => None,
1002        }
1003    }
1004
1005    // =========================================================================
1006    // Helper Methods
1007    // =========================================================================
1008
1009    /// Get text from a node.
1010    fn node_text<'a>(&self, node: Node, source: &'a [u8]) -> &'a str {
1011        std::str::from_utf8(&source[node.start_byte()..node.end_byte()]).unwrap_or("")
1012    }
1013
1014    /// Check if node is a string literal.
1015    fn is_string_literal(&self, node: Node, source: &[u8]) -> bool {
1016        let text = self.node_text(node, source);
1017        matches!(node.kind(), "string" | "string_literal" | "template_string")
1018            || text.starts_with('"')
1019            || text.starts_with('\'')
1020            || text.starts_with('`')
1021    }
1022
1023    /// Extract variables from f-string.
1024    ///
1025    /// Uses SIMD (AVX2) to find `{` and `}` positions in parallel when available,
1026    /// falling back to scalar processing for small strings or non-x86 platforms.
1027    fn extract_fstring_variables(&self, text: &str) -> Vec<String> {
1028        let bytes = text.as_bytes();
1029
1030        // For very short strings, use scalar path directly
1031        if bytes.len() < 64 {
1032            return self.extract_fstring_variables_scalar(text);
1033        }
1034
1035        // Find all { and } positions using SIMD
1036        let open_positions = Self::find_byte_positions_simd(bytes, b'{');
1037        let close_positions = Self::find_byte_positions_simd(bytes, b'}');
1038
1039        // If no braces found, return early
1040        if open_positions.is_empty() || close_positions.is_empty() {
1041            return Vec::new();
1042        }
1043
1044        // Process matched pairs
1045        self.extract_vars_from_positions(bytes, &open_positions, &close_positions)
1046    }
1047
1048    /// SIMD-accelerated byte position finder using AVX2 (u8x32).
1049    ///
1050    /// Scans 32 bytes at a time, returning all positions where `needle` occurs.
1051    #[cfg(target_arch = "x86_64")]
1052    fn find_byte_positions_simd(haystack: &[u8], needle: u8) -> Vec<usize> {
1053        use std::arch::x86_64::{
1054            __m256i, _mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8, _mm256_set1_epi8,
1055        };
1056
1057        let len = haystack.len();
1058        // Pre-allocate with reasonable capacity (expect ~1 match per 32 bytes on average)
1059        let mut positions = Vec::with_capacity(len / 32 + 1);
1060
1061        // SAFETY: We check for AVX2 support at runtime
1062        if !std::arch::is_x86_feature_detected!("avx2") {
1063            // Fallback to scalar for non-AVX2 CPUs
1064            for (i, &b) in haystack.iter().enumerate() {
1065                if b == needle {
1066                    positions.push(i);
1067                }
1068            }
1069            return positions;
1070        }
1071
1072        // SAFETY: AVX2 is available (checked above), pointers are valid
1073        unsafe {
1074            let needle_vec: __m256i = _mm256_set1_epi8(needle as i8);
1075            let mut offset = 0;
1076
1077            // Process 32 bytes at a time
1078            while offset + 32 <= len {
1079                let chunk_ptr = haystack.as_ptr().add(offset) as *const __m256i;
1080                let chunk: __m256i = _mm256_loadu_si256(chunk_ptr);
1081                let cmp: __m256i = _mm256_cmpeq_epi8(chunk, needle_vec);
1082                let mask = _mm256_movemask_epi8(cmp) as u32;
1083
1084                // Extract positions from bitmask
1085                if mask != 0 {
1086                    let mut m = mask;
1087                    while m != 0 {
1088                        let bit_pos = m.trailing_zeros() as usize;
1089                        positions.push(offset + bit_pos);
1090                        m &= m - 1; // Clear lowest set bit
1091                    }
1092                }
1093                offset += 32;
1094            }
1095
1096            // Handle remaining bytes (< 32) with scalar
1097            for i in offset..len {
1098                if *haystack.get_unchecked(i) == needle {
1099                    positions.push(i);
1100                }
1101            }
1102        }
1103
1104        positions
1105    }
1106
1107    /// Fallback for non-x86_64 platforms.
1108    #[cfg(not(target_arch = "x86_64"))]
1109    fn find_byte_positions_simd(haystack: &[u8], needle: u8) -> Vec<usize> {
1110        haystack
1111            .iter()
1112            .enumerate()
1113            .filter_map(|(i, &b)| if b == needle { Some(i) } else { None })
1114            .collect()
1115    }
1116
1117    /// Extract variables from pre-computed brace positions.
1118    ///
1119    /// Matches opening `{` with closing `}` respecting nesting (escaped `{{`).
1120    fn extract_vars_from_positions(
1121        &self,
1122        bytes: &[u8],
1123        opens: &[usize],
1124        closes: &[usize],
1125    ) -> Vec<String> {
1126        let mut vars = Vec::with_capacity(opens.len().min(closes.len()));
1127        let mut open_idx = 0;
1128        let mut close_idx = 0;
1129
1130        while open_idx < opens.len() && close_idx < closes.len() {
1131            let open_pos = opens[open_idx];
1132            let close_pos = closes[close_idx];
1133
1134            // Skip if close comes before open
1135            if close_pos <= open_pos {
1136                close_idx += 1;
1137                continue;
1138            }
1139
1140            // Check for escaped brace `{{` - skip if next char is also `{`
1141            if open_pos + 1 < bytes.len() && bytes[open_pos + 1] == b'{' {
1142                open_idx += 2; // Skip both `{` of `{{`
1143                continue;
1144            }
1145
1146            // Extract content between braces
1147            let content = &bytes[open_pos + 1..close_pos];
1148
1149            // Skip empty content
1150            if !content.is_empty() {
1151                if let Ok(var_str) = std::str::from_utf8(content) {
1152                    // Extract just the variable name (strip formatting specs like :, !, .)
1153                    let var_name = var_str
1154                        .split([':', '!', '.'])
1155                        .next()
1156                        .unwrap_or(var_str)
1157                        .trim();
1158
1159                    if !var_name.is_empty() {
1160                        vars.push(var_name.to_string());
1161                    }
1162                }
1163            }
1164
1165            open_idx += 1;
1166            close_idx += 1;
1167        }
1168
1169        vars
1170    }
1171
1172    /// Scalar fallback for small strings or non-SIMD paths.
1173    fn extract_fstring_variables_scalar(&self, text: &str) -> Vec<String> {
1174        let mut vars = Vec::new();
1175        let mut in_brace = false;
1176        let mut current_var = String::new();
1177
1178        for ch in text.chars() {
1179            if ch == '{' && !in_brace {
1180                in_brace = true;
1181                current_var.clear();
1182            } else if ch == '}' && in_brace {
1183                in_brace = false;
1184                if !current_var.is_empty() && !current_var.starts_with('{') {
1185                    // Extract just the variable name (strip formatting specs)
1186                    let var_name = current_var
1187                        .split([':', '!', '.'])
1188                        .next()
1189                        .unwrap_or(&current_var)
1190                        .trim();
1191                    if !var_name.is_empty() {
1192                        vars.push(var_name.to_string());
1193                    }
1194                }
1195            } else if in_brace {
1196                current_var.push(ch);
1197            }
1198        }
1199
1200        vars
1201    }
1202
1203    /// Collect all identifier variables from a node tree.
1204    fn collect_variables(&self, node: Node, source: &[u8]) -> Vec<String> {
1205        let mut vars = Vec::new();
1206        self.collect_variables_recursive(node, source, &mut vars);
1207        vars.sort();
1208        vars.dedup();
1209        vars
1210    }
1211
1212    fn collect_variables_recursive(&self, node: Node, source: &[u8], vars: &mut Vec<String>) {
1213        if node.kind() == "identifier" {
1214            let name = self.node_text(node, source).to_string();
1215            // Filter out common non-user-input identifiers
1216            if !["True", "False", "None", "self", "cls"].contains(&name.as_str()) {
1217                vars.push(name);
1218            }
1219        }
1220
1221        let mut cursor = node.walk();
1222        for child in node.children(&mut cursor) {
1223            self.collect_variables_recursive(child, source, vars);
1224        }
1225    }
1226
1227    /// Collect arguments from a call node.
1228    fn collect_call_args(&self, node: Node, source: &[u8]) -> Vec<String> {
1229        let mut vars = Vec::new();
1230        if let Some(args) = node.child_by_field_name("arguments") {
1231            self.collect_variables_recursive(args, source, &mut vars);
1232        }
1233        vars.sort();
1234        vars.dedup();
1235        vars
1236    }
1237
1238    /// Generate human-readable description.
1239    fn generate_description(
1240        &self,
1241        pattern: &UnsafePattern,
1242        sink_name: &str,
1243        vars: &[String],
1244    ) -> String {
1245        let var_list = if vars.is_empty() {
1246            "unknown variable".to_string()
1247        } else {
1248            vars.join(", ")
1249        };
1250
1251        match pattern {
1252            UnsafePattern::StringConcatenation => {
1253                format!(
1254                    "SQL injection via string concatenation in {}(). Variables {} are concatenated into the query string.",
1255                    sink_name, var_list
1256                )
1257            }
1258            UnsafePattern::FStringInterpolation => {
1259                format!(
1260                    "SQL injection via f-string interpolation in {}(). Variables {} are interpolated into the query.",
1261                    sink_name, var_list
1262                )
1263            }
1264            UnsafePattern::PercentFormat => {
1265                format!(
1266                    "SQL injection via percent formatting in {}(). Variables {} are formatted into the query.",
1267                    sink_name, var_list
1268                )
1269            }
1270            UnsafePattern::DotFormat => {
1271                format!(
1272                    "SQL injection via .format() in {}(). Variables {} are formatted into the query.",
1273                    sink_name, var_list
1274                )
1275            }
1276            UnsafePattern::TemplateLiteral => {
1277                format!(
1278                    "SQL injection via template literal in {}(). Variables {} are interpolated into the query.",
1279                    sink_name, var_list
1280                )
1281            }
1282            UnsafePattern::NonLiteralArgument => {
1283                format!(
1284                    "Potential SQL injection in {}(). Variable {} is passed directly to the query.",
1285                    sink_name, var_list
1286                )
1287            }
1288        }
1289    }
1290
1291    /// Generate remediation advice.
1292    fn generate_remediation(&self, pattern: &UnsafePattern, language: &str) -> String {
1293        match (pattern, language) {
1294            (_, "python") => {
1295                "Use parameterized queries with placeholders:\n\
1296                 cursor.execute(\"SELECT * FROM users WHERE id = ?\", (user_id,))\n\
1297                 Or use SQLAlchemy ORM methods with proper escaping."
1298                    .to_string()
1299            }
1300            (_, "typescript" | "javascript") => {
1301                "Use parameterized queries with placeholders:\n\
1302                 db.query(\"SELECT * FROM users WHERE id = $1\", [userId])\n\
1303                 Or use an ORM like Prisma, TypeORM, or Knex with proper parameter binding."
1304                    .to_string()
1305            }
1306            _ => "Use parameterized queries instead of string interpolation.".to_string(),
1307        }
1308    }
1309}
1310
1311// =============================================================================
1312// Tests
1313// =============================================================================
1314
#[cfg(test)]
mod tests {
    use super::*;

    /// Write `content` to a fresh temporary file whose name ends in `extension`.
    fn create_temp_file(content: &str, extension: &str) -> tempfile::NamedTempFile {
        use std::io::Write;
        let mut file = tempfile::Builder::new()
            .suffix(extension)
            .tempfile()
            .expect("Failed to create temp file");
        file.write_all(content.as_bytes())
            .expect("Failed to write temp file");
        file
    }

    /// Scan `content` as a file with the given extension and return the findings.
    ///
    /// The temporary file lives until the scan has finished.
    fn scan_source(content: &str, extension: &str) -> Vec<SQLInjectionFinding> {
        let file = create_temp_file(content, extension);
        let path = file
            .path()
            .to_str()
            .expect("Temp file path should be valid UTF-8");
        SqlInjectionDetector::new()
            .scan_file(path)
            .expect("Scan should succeed")
    }

    // =========================================================================
    // Python Tests
    // =========================================================================

    #[test]
    fn test_python_fstring_injection() {
        let findings = scan_source(
            r#"
import sqlite3
conn = sqlite3.connect('test.db')
cursor = conn.cursor()

def get_user(user_id):
    cursor.execute(f"SELECT * FROM users WHERE id = {user_id}")
    return cursor.fetchone()
"#,
            ".py",
        );

        assert!(!findings.is_empty(), "Should detect f-string injection");
        let finding = &findings[0];
        assert_eq!(finding.pattern, UnsafePattern::FStringInterpolation);
        assert_eq!(finding.severity, Severity::Critical);
        assert!(finding.tainted_variables.contains(&"user_id".to_string()));
    }

    #[test]
    fn test_python_string_concat_injection() {
        let findings = scan_source(
            r#"
import sqlite3
conn = sqlite3.connect('test.db')
cursor = conn.cursor()

def get_user(user_id):
    query = "SELECT * FROM users WHERE id = " + user_id
    cursor.execute(query)
    return cursor.fetchone()
"#,
            ".py",
        );

        // The detector finds the variable being passed to execute
        assert!(!findings.is_empty(), "Should detect variable injection");
    }

    #[test]
    fn test_python_percent_format_injection() {
        let findings = scan_source(
            r#"
import sqlite3
conn = sqlite3.connect('test.db')
cursor = conn.cursor()

def get_user(user_id):
    cursor.execute("SELECT * FROM users WHERE id = %s" % user_id)
    return cursor.fetchone()
"#,
            ".py",
        );

        assert!(!findings.is_empty(), "Should detect percent format injection");
        let finding = &findings[0];
        assert_eq!(finding.pattern, UnsafePattern::PercentFormat);
    }

    #[test]
    fn test_python_safe_parameterized_query() {
        let findings = scan_source(
            r#"
import sqlite3
conn = sqlite3.connect('test.db')
cursor = conn.cursor()

def get_user(user_id):
    cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
    return cursor.fetchone()
"#,
            ".py",
        );

        assert!(
            findings.is_empty(),
            "Should NOT detect safe parameterized query"
        );
    }

    #[test]
    fn test_python_safe_literal_query() {
        let findings = scan_source(
            r#"
import sqlite3
conn = sqlite3.connect('test.db')
cursor = conn.cursor()

def get_all_users():
    cursor.execute("SELECT * FROM users")
    return cursor.fetchall()
"#,
            ".py",
        );

        assert!(findings.is_empty(), "Should NOT detect safe literal query");
    }

    // =========================================================================
    // TypeScript Tests
    // =========================================================================

    #[test]
    fn test_typescript_template_literal_injection() {
        let findings = scan_source(
            r#"
import { Pool } from 'pg';
const pool = new Pool();

async function getUser(userId: string) {
    const result = await pool.query(`SELECT * FROM users WHERE id = ${userId}`);
    return result.rows[0];
}
"#,
            ".ts",
        );

        assert!(
            !findings.is_empty(),
            "Should detect template literal injection"
        );
        let finding = &findings[0];
        assert_eq!(finding.pattern, UnsafePattern::TemplateLiteral);
        assert_eq!(finding.severity, Severity::Critical);
    }

    #[test]
    fn test_typescript_string_concat_injection() {
        let findings = scan_source(
            r#"
import { Pool } from 'pg';
const pool = new Pool();

async function getUser(userId: string) {
    const query = "SELECT * FROM users WHERE id = " + userId;
    const result = await pool.query(query);
    return result.rows[0];
}
"#,
            ".ts",
        );

        // Should detect the variable being passed
        assert!(!findings.is_empty(), "Should detect variable injection");
    }

    #[test]
    fn test_typescript_safe_parameterized_query() {
        let findings = scan_source(
            r#"
import { Pool } from 'pg';
const pool = new Pool();

async function getUser(userId: string) {
    const result = await pool.query("SELECT * FROM users WHERE id = $1", [userId]);
    return result.rows[0];
}
"#,
            ".ts",
        );

        assert!(
            findings.is_empty(),
            "Should NOT detect safe parameterized query"
        );
    }

    #[test]
    fn test_typescript_safe_literal_query() {
        let findings = scan_source(
            r#"
import { Pool } from 'pg';
const pool = new Pool();

async function getAllUsers() {
    const result = await pool.query("SELECT * FROM users");
    return result.rows;
}
"#,
            ".ts",
        );

        assert!(findings.is_empty(), "Should NOT detect safe literal query");
    }

    #[test]
    fn test_typescript_prisma_raw_injection() {
        let findings = scan_source(
            r#"
import { PrismaClient } from '@prisma/client';
const prisma = new PrismaClient();

async function getUser(userId: string) {
    return prisma.$queryRaw(`SELECT * FROM users WHERE id = ${userId}`);
}
"#,
            ".ts",
        );

        assert!(!findings.is_empty(), "Should detect Prisma raw query injection");
    }

    // =========================================================================
    // Utility Tests
    // =========================================================================

    #[test]
    fn test_extract_fstring_variables() {
        let detector = SqlInjectionDetector::new();

        let vars = detector.extract_fstring_variables(r#"f"SELECT * FROM users WHERE id = {user_id}""#);
        assert_eq!(vars, vec!["user_id"]);

        let vars = detector.extract_fstring_variables(r#"f"SELECT * FROM {table} WHERE {col} = {val}""#);
        assert_eq!(vars, vec!["table", "col", "val"]);

        let vars = detector.extract_fstring_variables(r#"f"value: {x:.2f}""#);
        assert_eq!(vars, vec!["x"]);
    }

    #[test]
    fn test_extract_fstring_variables_long_input() {
        // Inputs of 64 bytes or more take the brace-position path rather than
        // the character scanner; both must agree on simple fields.
        let detector = SqlInjectionDetector::new();
        let query = r#"f"SELECT id, name, email, created_at FROM {table} WHERE {col} = {val} ORDER BY created_at DESC""#;
        assert!(query.len() >= 64, "Input must be long enough for the position path");

        let vars = detector.extract_fstring_variables(query);
        assert_eq!(vars, vec!["table", "col", "val"]);
    }

    #[test]
    fn test_severity_display() {
        assert_eq!(Severity::Critical.to_string(), "CRITICAL");
        assert_eq!(Severity::High.to_string(), "HIGH");
        assert_eq!(Severity::Medium.to_string(), "MEDIUM");
        assert_eq!(Severity::Low.to_string(), "LOW");
    }

    #[test]
    fn test_pattern_display() {
        assert_eq!(
            UnsafePattern::StringConcatenation.to_string(),
            "string_concatenation"
        );
        assert_eq!(
            UnsafePattern::FStringInterpolation.to_string(),
            "f_string_interpolation"
        );
        assert_eq!(
            UnsafePattern::TemplateLiteral.to_string(),
            "template_literal"
        );
    }

    #[test]
    fn test_scan_result_counts() {
        let result = ScanResult {
            findings: vec![],
            files_scanned: 10,
            sinks_found: 5,
            severity_counts: [("CRITICAL".to_string(), 2), ("HIGH".to_string(), 3)]
                .into_iter()
                .collect(),
            language: "python".to_string(),
        };

        assert_eq!(result.files_scanned, 10);
        assert_eq!(result.sinks_found, 5);
        assert_eq!(result.severity_counts.get("CRITICAL"), Some(&2));
    }
}