// schema_risk/impact.rs
//! Query impact detection.
//!
//! Scans source files in a given directory for SQL string literals and ORM
//! query patterns that reference tables or columns being modified by the
//! migration.  Reports which files contain queries likely affected by the
//! pending schema change.
//!
//! Uses `rayon` for parallel directory traversal.

10use rayon::prelude::*;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::ffi::OsStr;
14use std::path::{Path, PathBuf};
15
// ─────────────────────────────────────────────
// Public types
// ─────────────────────────────────────────────

/// A single source file that references schema objects touched by the
/// pending migration, together with the evidence lines.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImpactedFile {
    /// Relative path from the scan root
    pub path: String,
    /// Tables mentioned in this file that overlap with the migration
    pub tables_referenced: Vec<String>,
    /// Columns mentioned in this file that overlap with the migration's
    /// dropped / renamed / type-changed columns
    pub columns_referenced: Vec<String>,
    /// Relevant lines of code (file:line → snippet)
    pub hits: Vec<QueryHit>,
}
32
/// One matching line inside a scanned file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryHit {
    /// 1-based line number within the file.
    pub line: usize,
    /// The trimmed line content, truncated to 200 characters.
    pub snippet: String,
    /// Heuristic classification of how the identifier is referenced.
    pub match_type: MatchType,
}
39
/// Heuristic classification of a query hit, assigned by `classify_match`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MatchType {
    /// Plain SQL string literal containing the table/column name
    SqlLiteral,
    /// ORM query builder reference (Sequelize, Prisma, SQLAlchemy, Diesel…)
    OrmReference,
    /// An `include:` / `select:` key that contains the column name
    FieldReference,
}
49
/// Aggregated result of an [`ImpactScanner::scan`] run over a source tree.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ImpactReport {
    /// Number of source files scanned
    pub files_scanned: usize,
    /// Files that reference affected schema objects
    pub impacted_files: Vec<ImpactedFile>,
    /// Table → list of files that reference it
    pub table_file_map: HashMap<String, Vec<String>>,
    /// Column → list of files that reference it
    pub column_file_map: HashMap<String, Vec<String>>,
}
61
// ─────────────────────────────────────────────
// Scanner
// ─────────────────────────────────────────────

/// Source file extensions we want to look inside.
///
/// Compared case-sensitively against `Path::extension()` during directory
/// traversal, so e.g. an uppercase `.SQL` extension would not match.
const SOURCE_EXTENSIONS: &[&str] = &[
    "rs", "go", "py", "js", "ts", "jsx", "tsx", "rb", "java", "cs", "php", "sql", "graphql",
];
70
/// Scans a source tree for references to the schema objects a migration
/// touches. Identifiers are normalized to lowercase at construction and
/// matched via lowercase substring search.
pub struct ImpactScanner {
    /// Tables to look for (lowercased)
    tables: Vec<String>,
    /// Columns to look for (lowercased)
    columns: Vec<String>,
}
77
78impl ImpactScanner {
79    /// Create a scanner that **skips identifiers shorter than 4 characters**
80    /// to avoid false positives (B-03 fix).
81    ///
82    /// Use [`new_scan_short`] if you need to include short identifiers.
83    pub fn new(tables: Vec<String>, columns: Vec<String>) -> Self {
84        Self::new_with_options(tables, columns, true)
85    }
86
87    /// Create a scanner that includes all short identifiers (opt-in via `--scan-short-names`).
88    pub fn new_scan_short(tables: Vec<String>, columns: Vec<String>) -> Self {
89        Self::new_with_options(tables, columns, false)
90    }
91
92    fn new_with_options(tables: Vec<String>, columns: Vec<String>, skip_short: bool) -> Self {
93        let filter = |idents: Vec<String>| -> Vec<String> {
94            idents
95                .into_iter()
96                .filter(|s| !skip_short || s.chars().count() >= 4)
97                .map(|s| s.to_lowercase())
98                .collect()
99        };
100        Self {
101            tables: filter(tables),
102            columns: filter(columns),
103        }
104    }
105
106    /// Walk `root_dir` recursively, scan all source files in parallel, return
107    /// an `ImpactReport`.
108    pub fn scan(&self, root_dir: &Path) -> ImpactReport {
109        // Collect all source file paths first
110        let paths = collect_source_files(root_dir);
111        let total = paths.len();
112
113        let impacted_files: Vec<ImpactedFile> = paths
114            .par_iter()
115            .filter_map(|path| self.scan_file(path))
116            .collect();
117
118        // Build lookup maps
119        let mut table_file_map: HashMap<String, Vec<String>> = HashMap::new();
120        let mut column_file_map: HashMap<String, Vec<String>> = HashMap::new();
121        for f in &impacted_files {
122            for t in &f.tables_referenced {
123                table_file_map
124                    .entry(t.clone())
125                    .or_default()
126                    .push(f.path.clone());
127            }
128            for c in &f.columns_referenced {
129                column_file_map
130                    .entry(c.clone())
131                    .or_default()
132                    .push(f.path.clone());
133            }
134        }
135
136        ImpactReport {
137            files_scanned: total,
138            impacted_files,
139            table_file_map,
140            column_file_map,
141        }
142    }
143
144    // ── Per-file scan ─────────────────────────────────────────────────────
145
146    fn scan_file(&self, path: &Path) -> Option<ImpactedFile> {
147        let content = std::fs::read_to_string(path).ok()?;
148        let content_lower = content.to_lowercase();
149
150        let mut tables_found: Vec<String> = Vec::new();
151        let mut columns_found: Vec<String> = Vec::new();
152        let mut hits: Vec<QueryHit> = Vec::new();
153
154        for (line_idx, line) in content.lines().enumerate() {
155            let line_lower = line.to_lowercase();
156
157            for table in &self.tables {
158                if line_lower.contains(table.as_str()) {
159                    if !tables_found.contains(table) {
160                        tables_found.push(table.clone());
161                    }
162                    let match_type = classify_match(&line_lower, table);
163                    hits.push(QueryHit {
164                        line: line_idx + 1,
165                        snippet: line.trim().chars().take(200).collect(),
166                        match_type,
167                    });
168                }
169            }
170
171            for col in &self.columns {
172                if line_lower.contains(col.as_str())
173                    && !content_lower.contains(&format!("-- {}", col))
174                {
175                    if !columns_found.contains(col) {
176                        columns_found.push(col.clone());
177                    }
178                    // Avoid duplicate hits on the same line
179                    if !hits.iter().any(|h| h.line == line_idx + 1) {
180                        let match_type = classify_match(&line_lower, col);
181                        hits.push(QueryHit {
182                            line: line_idx + 1,
183                            snippet: line.trim().chars().take(200).collect(),
184                            match_type,
185                        });
186                    }
187                }
188            }
189        }
190
191        if tables_found.is_empty() && columns_found.is_empty() {
192            return None;
193        }
194
195        let rel_path = path.to_string_lossy().to_string();
196
197        Some(ImpactedFile {
198            path: rel_path,
199            tables_referenced: tables_found,
200            columns_referenced: columns_found,
201            hits,
202        })
203    }
204}
205
// ── Classify what kind of reference this line contains ───────────────────
207
208fn classify_match(line: &str, token: &str) -> MatchType {
209    // ORM patterns
210    let orm_patterns = [
211        "select(",
212        "where(",
213        "findone",
214        "findall",
215        "findmany",
216        "create(",
217        "update(",
218        "delete(",
219        "include:",
220        "prisma.",
221        "model.",
222        ".query(",
223        "execute(",
224        "from(",
225        "join(",
226        "diesel::",
227        "querybuilder",
228        "activerecord",
229        "sqlalchemy",
230    ];
231
232    let field_patterns = ["include:", "select:", "fields:", "columns:", "attributes:"];
233
234    if field_patterns.iter().any(|p| line.contains(p)) {
235        return MatchType::FieldReference;
236    }
237
238    if orm_patterns.iter().any(|p| line.contains(p)) {
239        return MatchType::OrmReference;
240    }
241
242    // Raw SQL string heuristic: the token appears between quotes or after FROM/JOIN/INTO
243    let sql_keywords = ["from ", "join ", "into ", "update ", "\"", "'", "`"];
244    if sql_keywords.iter().any(|k| {
245        if let Some(pos) = line.find(k) {
246            line[pos..].contains(token)
247        } else {
248            false
249        }
250    }) {
251        return MatchType::SqlLiteral;
252    }
253
254    MatchType::OrmReference
255}
256
// ── Collect all source files under a directory ───────────────────────────
258
259fn collect_source_files(root: &Path) -> Vec<PathBuf> {
260    let mut files = Vec::new();
261    collect_recursive(root, &mut files);
262    files
263}
264
265fn collect_recursive(dir: &Path, out: &mut Vec<PathBuf>) {
266    let Ok(entries) = std::fs::read_dir(dir) else {
267        return;
268    };
269
270    for entry in entries.flatten() {
271        let path = entry.path();
272
273        // Skip hidden dirs and common build/vendor dirs
274        let name = path.file_name().and_then(OsStr::to_str).unwrap_or("");
275        if name.starts_with('.')
276            || matches!(
277                name,
278                "node_modules" | "target" | "dist" | "build" | "vendor" | "__pycache__" | ".git"
279            )
280        {
281            continue;
282        }
283
284        if path.is_dir() {
285            collect_recursive(&path, out);
286        } else if let Some(ext) = path.extension().and_then(OsStr::to_str) {
287            if SOURCE_EXTENSIONS.contains(&ext) {
288                out.push(path);
289            }
290        }
291    }
292}