Skip to main content

schema_risk/
impact.rs

1//! Query impact detection.
2//!
3//! Scans source files in a given directory for SQL string literals and ORM
4//! query patterns that reference tables or columns being modified by the
5//! migration.  Reports which files contain queries likely affected by the
6//! pending schema change.
7//!
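//! Files are discovered with a sequential directory walk and then scanned in
//! parallel with `rayon`. The module also provides [`SqlExtractor`], which
//! pulls embedded SQL strings out of ORM and raw-query call sites.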
9
10use rayon::prelude::*;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::ffi::OsStr;
14use std::path::{Path, PathBuf};
15
16// ─────────────────────────────────────────────
17// Public types
18// ─────────────────────────────────────────────
19
/// A source file that mentions at least one tracked table or column, together
/// with the lines where the mentions occur.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImpactedFile {
    /// Path of the file as discovered by the directory walk: the scan root
    /// joined with the file's location beneath it (the root prefix is NOT
    /// stripped, so this is only "relative" if the scan root was relative).
    pub path: String,
    /// Tracked tables (lowercased) that appear in this file
    pub tables_referenced: Vec<String>,
    /// Tracked columns (lowercased) that appear in this file — presumably the
    /// migration's dropped / renamed / type-changed columns, as chosen by the
    /// caller that built the scanner
    pub columns_referenced: Vec<String>,
    /// Lines of this file that mention a tracked identifier, each with a
    /// 1-based line number and a trimmed snippet
    pub hits: Vec<QueryHit>,
}
32
/// One line of an [`ImpactedFile`] that mentions a tracked identifier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryHit {
    /// 1-based line number within the file
    pub line: usize,
    /// The line with surrounding whitespace trimmed, truncated to at most
    /// 200 characters
    pub snippet: String,
    /// Heuristic classification of how the identifier is used on this line
    pub match_type: MatchType,
}
39
/// Heuristic classification of how a tracked identifier appears on a line.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MatchType {
    /// Plain SQL text: the identifier appears after a quote character or after
    /// a `FROM` / `JOIN` / `INTO` / `UPDATE` keyword on the line
    SqlLiteral,
    /// ORM query builder reference (Sequelize, Prisma, SQLAlchemy, Diesel…);
    /// also the fallback when no other heuristic recognises the line
    OrmReference,
    /// The line contains a field-selection key such as `include:`, `select:`,
    /// `fields:`, `columns:` or `attributes:`
    FieldReference,
}
49
/// Aggregate result of scanning a directory tree with an [`ImpactScanner`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ImpactReport {
    /// Number of source files found under the scan root (including files that
    /// could not be read as UTF-8 text and were therefore skipped)
    pub files_scanned: usize,
    /// Files that reference affected schema objects
    pub impacted_files: Vec<ImpactedFile>,
    /// Table → paths (as in `ImpactedFile::path`) of files that reference it
    pub table_file_map: HashMap<String, Vec<String>>,
    /// Column → paths (as in `ImpactedFile::path`) of files that reference it
    pub column_file_map: HashMap<String, Vec<String>>,
}
61
62// ─────────────────────────────────────────────
63// Scanner
64// ─────────────────────────────────────────────
65
/// File extensions (without the leading dot, compared case-sensitively) of
/// the files collected for scanning.
const SOURCE_EXTENSIONS: &[&str] = &[
    // Backend / systems languages
    "rs", "go", "py",
    // JavaScript family
    "js", "ts", "jsx", "tsx",
    // Other application languages
    "rb", "java", "cs", "php",
    // Query languages
    "sql", "graphql",
];
70
/// Searches source files for lines that mention a fixed set of table and
/// column identifiers (case-insensitive substring match).
pub struct ImpactScanner {
    /// Tables to look for (lowercased; short names may have been filtered out
    /// depending on which constructor was used)
    tables: Vec<String>,
    /// Columns to look for (lowercased; short names may have been filtered out
    /// depending on which constructor was used)
    columns: Vec<String>,
}
77
78impl ImpactScanner {
79    /// Create a scanner that **skips identifiers shorter than 4 characters**
80    /// to avoid false positives (B-03 fix).
81    ///
82    /// Use [`new_scan_short`] if you need to include short identifiers.
83    pub fn new(tables: Vec<String>, columns: Vec<String>) -> Self {
84        Self::new_with_options(tables, columns, true)
85    }
86
87    /// Create a scanner that includes all short identifiers (opt-in via `--scan-short-names`).
88    pub fn new_scan_short(tables: Vec<String>, columns: Vec<String>) -> Self {
89        Self::new_with_options(tables, columns, false)
90    }
91
92    fn new_with_options(tables: Vec<String>, columns: Vec<String>, skip_short: bool) -> Self {
93        let filter = |idents: Vec<String>| -> Vec<String> {
94            idents
95                .into_iter()
96                .filter(|s| !skip_short || s.chars().count() >= 4)
97                .map(|s| s.to_lowercase())
98                .collect()
99        };
100        Self {
101            tables: filter(tables),
102            columns: filter(columns),
103        }
104    }
105
106    /// Walk `root_dir` recursively, scan all source files in parallel, return
107    /// an `ImpactReport`.
108    pub fn scan(&self, root_dir: &Path) -> ImpactReport {
109        // Collect all source file paths first
110        let paths = collect_source_files(root_dir);
111        let total = paths.len();
112
113        let impacted_files: Vec<ImpactedFile> = paths
114            .par_iter()
115            .filter_map(|path| self.scan_file(path))
116            .collect();
117
118        // Build lookup maps
119        let mut table_file_map: HashMap<String, Vec<String>> = HashMap::new();
120        let mut column_file_map: HashMap<String, Vec<String>> = HashMap::new();
121        for f in &impacted_files {
122            for t in &f.tables_referenced {
123                table_file_map
124                    .entry(t.clone())
125                    .or_default()
126                    .push(f.path.clone());
127            }
128            for c in &f.columns_referenced {
129                column_file_map
130                    .entry(c.clone())
131                    .or_default()
132                    .push(f.path.clone());
133            }
134        }
135
136        ImpactReport {
137            files_scanned: total,
138            impacted_files,
139            table_file_map,
140            column_file_map,
141        }
142    }
143
144    // ── Per-file scan ─────────────────────────────────────────────────────
145
146    fn scan_file(&self, path: &Path) -> Option<ImpactedFile> {
147        let content = std::fs::read_to_string(path).ok()?;
148        let content_lower = content.to_lowercase();
149
150        let mut tables_found: Vec<String> = Vec::new();
151        let mut columns_found: Vec<String> = Vec::new();
152        let mut hits: Vec<QueryHit> = Vec::new();
153
154        for (line_idx, line) in content.lines().enumerate() {
155            let line_lower = line.to_lowercase();
156
157            for table in &self.tables {
158                if line_lower.contains(table.as_str()) {
159                    if !tables_found.contains(table) {
160                        tables_found.push(table.clone());
161                    }
162                    let match_type = classify_match(&line_lower, table);
163                    hits.push(QueryHit {
164                        line: line_idx + 1,
165                        snippet: line.trim().chars().take(200).collect(),
166                        match_type,
167                    });
168                }
169            }
170
171            for col in &self.columns {
172                if line_lower.contains(col.as_str())
173                    && !content_lower.contains(&format!("-- {}", col))
174                {
175                    if !columns_found.contains(col) {
176                        columns_found.push(col.clone());
177                    }
178                    // Avoid duplicate hits on the same line
179                    if !hits.iter().any(|h| h.line == line_idx + 1) {
180                        let match_type = classify_match(&line_lower, col);
181                        hits.push(QueryHit {
182                            line: line_idx + 1,
183                            snippet: line.trim().chars().take(200).collect(),
184                            match_type,
185                        });
186                    }
187                }
188            }
189        }
190
191        if tables_found.is_empty() && columns_found.is_empty() {
192            return None;
193        }
194
195        let rel_path = path.to_string_lossy().to_string();
196
197        Some(ImpactedFile {
198            path: rel_path,
199            tables_referenced: tables_found,
200            columns_referenced: columns_found,
201            hits,
202        })
203    }
204}
205
206// ── Classify what kind of reference this line contains ───────────────────
207
208fn classify_match(line: &str, token: &str) -> MatchType {
209    // ORM patterns
210    let orm_patterns = [
211        "select(",
212        "where(",
213        "findone",
214        "findall",
215        "findmany",
216        "create(",
217        "update(",
218        "delete(",
219        "include:",
220        "prisma.",
221        "model.",
222        ".query(",
223        "execute(",
224        "from(",
225        "join(",
226        "diesel::",
227        "querybuilder",
228        "activerecord",
229        "sqlalchemy",
230    ];
231
232    let field_patterns = ["include:", "select:", "fields:", "columns:", "attributes:"];
233
234    if field_patterns.iter().any(|p| line.contains(p)) {
235        return MatchType::FieldReference;
236    }
237
238    if orm_patterns.iter().any(|p| line.contains(p)) {
239        return MatchType::OrmReference;
240    }
241
242    // Raw SQL string heuristic: the token appears between quotes or after FROM/JOIN/INTO
243    let sql_keywords = ["from ", "join ", "into ", "update ", "\"", "'", "`"];
244    if sql_keywords.iter().any(|k| {
245        if let Some(pos) = line.find(k) {
246            line[pos..].contains(token)
247        } else {
248            false
249        }
250    }) {
251        return MatchType::SqlLiteral;
252    }
253
254    MatchType::OrmReference
255}
256
257// ── Collect all source files under a directory ───────────────────────────
258
259fn collect_source_files(root: &Path) -> Vec<PathBuf> {
260    let mut files = Vec::new();
261    collect_recursive(root, &mut files);
262    files
263}
264
265fn collect_recursive(dir: &Path, out: &mut Vec<PathBuf>) {
266    let Ok(entries) = std::fs::read_dir(dir) else {
267        return;
268    };
269
270    for entry in entries.flatten() {
271        let path = entry.path();
272
273        // Skip hidden dirs and common build/vendor dirs
274        let name = path.file_name().and_then(OsStr::to_str).unwrap_or("");
275        if name.starts_with('.')
276            || matches!(
277                name,
278                "node_modules" | "target" | "dist" | "build" | "vendor" | "__pycache__" | ".git"
279            )
280        {
281            continue;
282        }
283
284        if path.is_dir() {
285            collect_recursive(&path, out);
286        } else if let Some(ext) = path.extension().and_then(OsStr::to_str) {
287            if SOURCE_EXTENSIONS.contains(&ext) {
288                out.push(path);
289            }
290        }
291    }
292}
293
294// ─────────────────────────────────────────────────────────────────────────────
295// SQL Extraction from Source Code
296// ─────────────────────────────────────────────────────────────────────────────
297
/// SQL extracted from source code (not from a .sql file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedSql {
    /// The source file path, as given to the extractor (lossily converted to
    /// UTF-8).
    pub source_file: String,
    /// 1-based line number where the SQL was found.
    pub line: usize,
    /// 0-based byte offset within the line at which the captured SQL text
    /// starts (measured before trimming), if known.
    pub column: Option<usize>,
    /// The extracted SQL string, with surrounding whitespace trimmed.
    pub sql: String,
    /// The ORM/framework context of the pattern that matched.
    pub context: SqlContext,
    /// Confidence score (0.0 - 1.0), taken from the pattern that matched.
    pub confidence: f32,
}
314
/// The framework / call style in which a piece of SQL was found.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum SqlContext {
    /// Raw SQL string literal (generic): a quoted string starting with an
    /// upper-case SQL keyword, matched in any file type.
    RawSql,
    /// Prisma $queryRaw / $executeRaw / Prisma.sql.
    PrismaRaw,
    /// TypeORM query / createQueryBuilder.
    TypeOrm,
    /// Sequelize raw query.
    Sequelize,
    /// SQLAlchemy text() / execute(). The built-in patterns also use this for
    /// Django-style `cursor.execute(...)` and `.raw(...)` calls.
    SqlAlchemy,
    /// GORM Raw / Exec (and `db.Query` in Go).
    Gorm,
    /// Diesel sql_query.
    Diesel,
    /// Entity Framework FromSqlRaw / ExecuteSqlRaw / SqlQuery<T>.
    EntityFramework,
    /// Laravel DB::raw / DB::statement / DB::select|insert|update|delete.
    Eloquent,
    /// Rails ActiveRecord execute / exec_query / find_by_sql.
    ActiveRecord,
    /// Unknown / generic context (none of the built-in patterns currently
    /// produce it).
    Unknown,
}
341
342impl std::fmt::Display for SqlContext {
343    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        match self {
345            SqlContext::RawSql => write!(f, "Raw SQL"),
346            SqlContext::PrismaRaw => write!(f, "Prisma"),
347            SqlContext::TypeOrm => write!(f, "TypeORM"),
348            SqlContext::Sequelize => write!(f, "Sequelize"),
349            SqlContext::SqlAlchemy => write!(f, "SQLAlchemy"),
350            SqlContext::Gorm => write!(f, "GORM"),
351            SqlContext::Diesel => write!(f, "Diesel"),
352            SqlContext::EntityFramework => write!(f, "Entity Framework"),
353            SqlContext::Eloquent => write!(f, "Eloquent"),
354            SqlContext::ActiveRecord => write!(f, "ActiveRecord"),
355            SqlContext::Unknown => write!(f, "Unknown"),
356        }
357    }
358}
359
/// Result of scanning a codebase for SQL.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SqlExtractionReport {
    /// Number of source files found under the scan root (unreadable files
    /// included).
    pub files_scanned: usize,
    /// All extracted SQL statements.
    pub extracted: Vec<ExtractedSql>,
    /// Subset of `extracted` whose SQL contains DROP, TRUNCATE, DELETE, ALTER
    /// or CREATE INDEX (case-insensitive keyword check).
    pub dangerous: Vec<ExtractedSql>,
    /// Number of extracted statements per context, keyed by the context's
    /// `Display` name.
    pub by_context: HashMap<String, usize>,
}
372
/// SQL extraction engine: finds SQL strings embedded in application source
/// code using per-framework regex patterns.
pub struct SqlExtractor {
    /// Compiled regex patterns for SQL extraction, tried in order.
    patterns: Vec<SqlExtractionPattern>,
}
378
/// One extraction rule: a regex plus the metadata attached to its matches.
struct SqlExtractionPattern {
    /// Compiled pattern, applied to one source line at a time.
    regex: regex::Regex,
    /// Context recorded on every match of this pattern.
    context: SqlContext,
    /// File extensions this pattern applies to ("*" = all).
    extensions: Vec<&'static str>,
    /// Group index that captures the SQL (0 = whole match).
    capture_group: usize,
    /// Base confidence score.
    confidence: f32,
}
389
390impl Default for SqlExtractor {
391    fn default() -> Self {
392        Self::new()
393    }
394}
395
396impl SqlExtractor {
397    /// Create a new SQL extractor with default patterns.
398    pub fn new() -> Self {
399        let patterns = Self::build_patterns();
400        Self { patterns }
401    }
402
403    fn build_patterns() -> Vec<SqlExtractionPattern> {
404        let mut patterns = Vec::new();
405
406        // Helper to add patterns
407        let mut add =
408            |pattern: &str, ctx: SqlContext, exts: &[&'static str], group: usize, conf: f32| {
409                if let Ok(re) = regex::Regex::new(pattern) {
410                    patterns.push(SqlExtractionPattern {
411                        regex: re,
412                        context: ctx,
413                        extensions: exts.to_vec(),
414                        capture_group: group,
415                        confidence: conf,
416                    });
417                }
418            };
419
420        // ── Prisma (JavaScript/TypeScript) ───────────────────────────────────
421        add(
422            r#"\$queryRaw\s*`([^`]+)`"#,
423            SqlContext::PrismaRaw,
424            &["ts", "js", "tsx", "jsx"],
425            1,
426            0.95,
427        );
428        add(
429            r#"\$executeRaw\s*`([^`]+)`"#,
430            SqlContext::PrismaRaw,
431            &["ts", "js", "tsx", "jsx"],
432            1,
433            0.95,
434        );
435        add(
436            r#"Prisma\.sql\s*`([^`]+)`"#,
437            SqlContext::PrismaRaw,
438            &["ts", "js", "tsx", "jsx"],
439            1,
440            0.9,
441        );
442
443        // ── TypeORM (JavaScript/TypeScript) ──────────────────────────────────
444        add(
445            r#"\.query\s*\(\s*["'`]([^"'`]+)["'`]"#,
446            SqlContext::TypeOrm,
447            &["ts", "js", "tsx", "jsx"],
448            1,
449            0.85,
450        );
451        add(
452            r#"createQueryBuilder\s*\(\s*["']([^"']+)["']"#,
453            SqlContext::TypeOrm,
454            &["ts", "js", "tsx", "jsx"],
455            1,
456            0.8,
457        );
458        add(
459            r#"\.createQueryRunner\(\)\.query\s*\(\s*["'`]([^"'`]+)["'`]"#,
460            SqlContext::TypeOrm,
461            &["ts", "js"],
462            1,
463            0.9,
464        );
465
466        // ── Sequelize (JavaScript/TypeScript) ────────────────────────────────
467        add(
468            r#"sequelize\.query\s*\(\s*["'`]([^"'`]+)["'`]"#,
469            SqlContext::Sequelize,
470            &["ts", "js", "tsx", "jsx"],
471            1,
472            0.9,
473        );
474        add(
475            r#"QueryTypes\.\w+.*["'`]([^"'`]+)["'`]"#,
476            SqlContext::Sequelize,
477            &["ts", "js"],
478            1,
479            0.85,
480        );
481
482        // ── SQLAlchemy (Python) ──────────────────────────────────────────────
483        add(
484            r#"text\s*\(\s*["']([^"']+)["']"#,
485            SqlContext::SqlAlchemy,
486            &["py"],
487            1,
488            0.9,
489        );
490        add(
491            r#"execute\s*\(\s*["']([^"']+)["']"#,
492            SqlContext::SqlAlchemy,
493            &["py"],
494            1,
495            0.85,
496        );
497        add(
498            r#"session\.execute\s*\(\s*["']([^"']+)["']"#,
499            SqlContext::SqlAlchemy,
500            &["py"],
501            1,
502            0.9,
503        );
504        add(
505            r#"connection\.execute\s*\(\s*["']([^"']+)["']"#,
506            SqlContext::SqlAlchemy,
507            &["py"],
508            1,
509            0.9,
510        );
511
512        // ── Django (Python) ──────────────────────────────────────────────────
513        add(
514            r#"cursor\.execute\s*\(\s*["']([^"']+)["']"#,
515            SqlContext::SqlAlchemy,
516            &["py"],
517            1,
518            0.9,
519        );
520        add(
521            r#"\.raw\s*\(\s*["']([^"']+)["']"#,
522            SqlContext::SqlAlchemy,
523            &["py"],
524            1,
525            0.85,
526        );
527
528        // ── GORM (Go) ────────────────────────────────────────────────────────
529        add(
530            r#"\.Raw\s*\(\s*["'`]([^"'`]+)["'`]"#,
531            SqlContext::Gorm,
532            &["go"],
533            1,
534            0.9,
535        );
536        add(
537            r#"\.Exec\s*\(\s*["'`]([^"'`]+)["'`]"#,
538            SqlContext::Gorm,
539            &["go"],
540            1,
541            0.9,
542        );
543        add(
544            r#"db\.Query\s*\(\s*["'`]([^"'`]+)["'`]"#,
545            SqlContext::Gorm,
546            &["go"],
547            1,
548            0.85,
549        );
550
551        // ── Diesel (Rust) ────────────────────────────────────────────────────
552        add(
553            r#"sql_query\s*\(\s*["']([^"']+)["']"#,
554            SqlContext::Diesel,
555            &["rs"],
556            1,
557            0.9,
558        );
559        add(
560            r#"diesel::sql_query\s*\(\s*["']([^"']+)["']"#,
561            SqlContext::Diesel,
562            &["rs"],
563            1,
564            0.95,
565        );
566
567        // ── Entity Framework (C#) ────────────────────────────────────────────
568        add(
569            r#"\.FromSqlRaw\s*\(\s*["']([^"']+)["']"#,
570            SqlContext::EntityFramework,
571            &["cs"],
572            1,
573            0.9,
574        );
575        add(
576            r#"\.ExecuteSqlRaw\s*\(\s*["']([^"']+)["']"#,
577            SqlContext::EntityFramework,
578            &["cs"],
579            1,
580            0.9,
581        );
582        add(
583            r#"SqlQuery<[^>]+>\s*\(\s*["']([^"']+)["']"#,
584            SqlContext::EntityFramework,
585            &["cs"],
586            1,
587            0.85,
588        );
589
590        // ── Laravel Eloquent (PHP) ───────────────────────────────────────────
591        add(
592            r#"DB::raw\s*\(\s*["']([^"']+)["']"#,
593            SqlContext::Eloquent,
594            &["php"],
595            1,
596            0.9,
597        );
598        add(
599            r#"DB::statement\s*\(\s*["']([^"']+)["']"#,
600            SqlContext::Eloquent,
601            &["php"],
602            1,
603            0.9,
604        );
605        add(
606            r#"DB::select\s*\(\s*["']([^"']+)["']"#,
607            SqlContext::Eloquent,
608            &["php"],
609            1,
610            0.85,
611        );
612        add(
613            r#"DB::insert\s*\(\s*["']([^"']+)["']"#,
614            SqlContext::Eloquent,
615            &["php"],
616            1,
617            0.85,
618        );
619        add(
620            r#"DB::update\s*\(\s*["']([^"']+)["']"#,
621            SqlContext::Eloquent,
622            &["php"],
623            1,
624            0.85,
625        );
626        add(
627            r#"DB::delete\s*\(\s*["']([^"']+)["']"#,
628            SqlContext::Eloquent,
629            &["php"],
630            1,
631            0.85,
632        );
633
634        // ── Rails ActiveRecord (Ruby) ────────────────────────────────────────
635        add(
636            r#"execute\s*\(\s*["']([^"']+)["']"#,
637            SqlContext::ActiveRecord,
638            &["rb"],
639            1,
640            0.85,
641        );
642        add(
643            r#"exec_query\s*\(\s*["']([^"']+)["']"#,
644            SqlContext::ActiveRecord,
645            &["rb"],
646            1,
647            0.9,
648        );
649        add(
650            r#"connection\.execute\s*\(\s*["']([^"']+)["']"#,
651            SqlContext::ActiveRecord,
652            &["rb"],
653            1,
654            0.9,
655        );
656        add(
657            r#"find_by_sql\s*\(\s*["']([^"']+)["']"#,
658            SqlContext::ActiveRecord,
659            &["rb"],
660            1,
661            0.9,
662        );
663
664        // ── Generic SQL string patterns (lower confidence) ───────────────────
665        // These match SQL keywords in string literals
666        add(
667            r#"["'`]((?:SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP|TRUNCATE)\s+[^"'`]{10,})["'`]"#,
668            SqlContext::RawSql,
669            &["*"],
670            1,
671            0.7,
672        );
673
674        patterns
675    }
676
677    /// Extract SQL from a single file.
678    pub fn extract_from_file(&self, path: &Path) -> Vec<ExtractedSql> {
679        let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
680
681        let content = match std::fs::read_to_string(path) {
682            Ok(c) => c,
683            Err(_) => return vec![],
684        };
685
686        let path_str = path.to_string_lossy().to_string();
687        let mut results = Vec::new();
688
689        for (line_idx, line) in content.lines().enumerate() {
690            for pattern in &self.patterns {
691                // Check if this pattern applies to this file type
692                if !pattern.extensions.contains(&"*") && !pattern.extensions.contains(&ext) {
693                    continue;
694                }
695
696                for cap in pattern.regex.captures_iter(line) {
697                    let sql = if pattern.capture_group > 0 {
698                        cap.get(pattern.capture_group)
699                            .map(|m| m.as_str())
700                            .unwrap_or("")
701                    } else {
702                        cap.get(0).map(|m| m.as_str()).unwrap_or("")
703                    };
704
705                    let sql = sql.trim().to_string();
706
707                    // Skip empty or very short matches
708                    if sql.len() < 5 {
709                        continue;
710                    }
711
712                    // Verify it looks like SQL
713                    if !Self::looks_like_sql(&sql) {
714                        continue;
715                    }
716
717                    results.push(ExtractedSql {
718                        source_file: path_str.clone(),
719                        line: line_idx + 1,
720                        column: cap.get(1).map(|m| m.start()),
721                        sql,
722                        context: pattern.context.clone(),
723                        confidence: pattern.confidence,
724                    });
725                }
726            }
727        }
728
729        results
730    }
731
732    /// Check if a string looks like SQL.
733    fn looks_like_sql(s: &str) -> bool {
734        let upper = s.to_uppercase();
735        let sql_keywords = [
736            "SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP", "TRUNCATE", "FROM",
737            "WHERE", "JOIN", "TABLE", "INDEX", "COLUMN",
738        ];
739        sql_keywords.iter().any(|kw| upper.contains(kw))
740    }
741
742    /// Check if SQL is dangerous (DDL or deletes).
743    fn is_dangerous_sql(sql: &str) -> bool {
744        let upper = sql.to_uppercase();
745        upper.contains("DROP ")
746            || upper.contains("TRUNCATE ")
747            || upper.contains("DELETE ")
748            || upper.contains("ALTER ")
749            || upper.contains("CREATE INDEX")
750    }
751
752    /// Scan a directory for SQL in source code.
753    pub fn scan_directory(&self, root: &Path) -> SqlExtractionReport {
754        let files = collect_source_files(root);
755        let total = files.len();
756
757        let extracted: Vec<ExtractedSql> = files
758            .par_iter()
759            .flat_map(|path| self.extract_from_file(path))
760            .collect();
761
762        // Separate dangerous SQL
763        let dangerous: Vec<ExtractedSql> = extracted
764            .iter()
765            .filter(|e| Self::is_dangerous_sql(&e.sql))
766            .cloned()
767            .collect();
768
769        // Count by context
770        let mut by_context: HashMap<String, usize> = HashMap::new();
771        for e in &extracted {
772            *by_context.entry(e.context.to_string()).or_insert(0) += 1;
773        }
774
775        SqlExtractionReport {
776            files_scanned: total,
777            extracted,
778            dangerous,
779            by_context,
780        }
781    }
782}
783
784// ─────────────────────────────────────────────────────────────────────────────
785// Tests
786// ─────────────────────────────────────────────────────────────────────────────
787
#[cfg(test)]
mod tests {
    use super::*;

    /// Write `code` to `<temp>/<dir_name>/<file_name>` and return the
    /// directory (for cleanup) together with the file path.
    fn write_fixture(
        dir_name: &str,
        file_name: &str,
        code: &str,
    ) -> (std::path::PathBuf, std::path::PathBuf) {
        let dir = std::env::temp_dir().join(dir_name);
        let _ = std::fs::create_dir_all(&dir);
        let file = dir.join(file_name);
        std::fs::write(&file, code).unwrap();
        (dir, file)
    }

    #[test]
    fn test_sql_extractor_prisma() {
        let (dir, file) = write_fixture(
            "schema-risk-test-prisma",
            "test.ts",
            r#"const result = await prisma.$queryRaw`SELECT * FROM users WHERE id = ${id}`;"#,
        );

        let found = SqlExtractor::new().extract_from_file(&file);
        assert!(!found.is_empty());
        let first = &found[0];
        assert_eq!(first.context, SqlContext::PrismaRaw);
        assert!(first.sql.contains("SELECT"));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_sql_extractor_raw_sql() {
        let (dir, file) = write_fixture(
            "schema-risk-test-raw",
            "test.js",
            r#"const query = "SELECT * FROM users WHERE active = true";"#,
        );

        let found = SqlExtractor::new().extract_from_file(&file);
        assert!(!found.is_empty());
        assert!(found[0].sql.contains("SELECT"));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_dangerous_sql_detection() {
        let dangerous = [
            "DROP TABLE users",
            "DELETE FROM users WHERE id = 1",
            "TRUNCATE TABLE sessions",
            "ALTER TABLE users ADD COLUMN age INT",
        ];
        for sql in dangerous {
            assert!(SqlExtractor::is_dangerous_sql(sql), "expected dangerous: {}", sql);
        }

        let safe = ["SELECT * FROM users", "INSERT INTO users (name) VALUES ('test')"];
        for sql in safe {
            assert!(!SqlExtractor::is_dangerous_sql(sql), "expected safe: {}", sql);
        }
    }

    #[test]
    fn test_looks_like_sql() {
        let sql_like = [
            "SELECT * FROM users",
            "INSERT INTO users (name) VALUES ('test')",
            "DROP TABLE users",
        ];
        for s in sql_like {
            assert!(SqlExtractor::looks_like_sql(s), "expected SQL: {}", s);
        }

        for s in ["Hello world", "const x = 5"] {
            assert!(!SqlExtractor::looks_like_sql(s), "expected non-SQL: {}", s);
        }
    }
}
855}