Skip to main content

aperion_shield/hooks/
check_staged.rs

1//! `aperion-shield --check-staged` — run the engine over the lines a
2//! commit is about to add or modify, refuse the commit if any line
3//! trips a Block-severity rule.
4//!
5//! ## What it inspects
6//!
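//! Only file types that historically generate destructive ops are
//! inspected (matched case-insensitively on the path):
//!
//! | File type                                                | Scope mapped to                              |
//! |----------------------------------------------------------|----------------------------------------------|
//! | `.sql`                                                   | `execute_sql` tool-call                      |
//! | `.sh`, `.bash`, `.zsh`, `Makefile`, `Justfile`           | `shell` tool-call                            |
//! | `Dockerfile` (and any file name starting `Dockerfile`)   | `shell` tool-call (every line, not only RUN) |
//! | source code: `.py`, `.js`, `.ts`, `.jsx`, `.tsx`, `.rs`, `.go`, `.rb`, `.java`, `.kt`, `.swift`, `.cs` | `llm_response` text scope |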
15//!
16//! Files outside this whitelist are skipped; we deliberately do NOT lint
17//! every README, every JSON config, every test fixture. The cost of a
18//! false positive in a pre-commit hook is *very high* (it stops the
19//! developer cold and trains them to `--no-verify`); the cost of a
20//! false negative is bounded (the call still has to execute somewhere
21//! and Shield's MCP path will catch it). So we err on precision.
22//!
23//! ## What it skips (intentional)
24//!
25//! - Removed files & deleted lines. We're protecting against newly
26//!   *introduced* destructive code, not historical deletion.
27//! - Binary blobs.
28//! - Files larger than 256 KB (a heuristic — agent-generated
29//!   migrations and shell scripts are tiny; oversize is almost always
30//!   data).
31//! - Lines that are pure whitespace or pure comments.
32//!
33//! ## Exit codes
34//!
//! Matches the documented hook contract — see `docs/hooks.md`:
//!
//! | Code | Meaning                                                                 |
//! |------|-------------------------------------------------------------------------|
//! | 0    | No blocking matches. Commit proceeds.                                   |
//! | 1    | At least one Block-severity match. Commit refused.                      |
//! | 2    | At least one Approval-severity match. Pre-commit can't prompt, so this is surfaced as a refusal with a note. |
//! | 3    | Operational error (git not on PATH, not in a repo, ...).                |
//!
//! `CheckStagedReport::exit_code` only produces 0–2; code 3 corresponds to
//! `run` returning an `Err` and is left to the CLI dispatcher.
44//!
//! `SHIELD_HOOKS_DISABLE=1` is honoured by the hook script, which exits 0
//! before invoking this check at all. This module never reads the variable
//! itself, so running `--check-staged` by hand (e.g. when debugging) always
//! performs the scan.
49
50use anyhow::{anyhow, Context, Result};
51use serde_json::json;
52use std::collections::BTreeMap;
53use std::path::PathBuf;
54use std::process::Command;
55
56use crate::engine::Engine;
57use crate::{decide, Adjustments, BurstDetector, Decision, WorkspaceContext};
58
/// Size cap (256 KiB) above which a staged file is not scanned at all:
/// migrations and shell scripts are tiny, so anything bigger is almost
/// always data.
const MAX_FILE_SIZE_BYTES: u64 = 1024 * 256;
60
61/// One scanner finding, surfaced to the user in the pre-commit error
62/// banner. Kept granular so we can group by rule_id for readability.
63#[derive(Debug, Clone)]
64pub struct StagedFinding {
65    pub file: String,
66    pub line_no: usize,
67    pub line: String,
68    pub rule_id: String,
69    pub severity: String,
70    pub decision: String,
71    pub reason: String,
72    pub safer_alternative: Option<String>,
73}
74
75/// Aggregate of a `--check-staged` run, returned to the CLI dispatcher.
76#[derive(Debug, Default)]
77pub struct CheckStagedReport {
78    pub files_scanned: usize,
79    pub lines_scanned: usize,
80    pub findings: Vec<StagedFinding>,
81    /// Highest decision class we saw (None if nothing matched at all).
82    pub worst_decision: Option<Decision>,
83}
84
85impl CheckStagedReport {
86    /// Decide the process exit code per the table at the top of the
87    /// file. Caller maps the `u8` to `std::process::exit`.
88    pub fn exit_code(&self) -> u8 {
89        match &self.worst_decision {
90            Some(d) if d.is_blocking() => 1,
91            Some(Decision::Approval { .. }) => 2,
92            _ => 0,
93        }
94    }
95
96    /// Group findings by rule id for the human-facing banner. Keys are
97    /// sorted to give a stable display order.
98    pub fn group_by_rule(&self) -> BTreeMap<String, Vec<&StagedFinding>> {
99        let mut out: BTreeMap<String, Vec<&StagedFinding>> = BTreeMap::new();
100        for f in &self.findings {
101            out.entry(f.rule_id.clone()).or_default().push(f);
102        }
103        out
104    }
105}
106
107/// Top-level entrypoint. Walks the staged diff in `repo_root`, evaluates
108/// every added/modified line through `engine`, returns the aggregated
109/// report. Runs synchronously — git invocations are cheap and the corpus
110/// is small (~hundreds of lines at most for a normal commit).
111pub fn run(repo_root: &std::path::Path, engine: &Engine, workspace_root: Option<&std::path::Path>) -> Result<CheckStagedReport> {
112    if !is_inside_git_repo(repo_root)? {
113        return Err(anyhow!(
114            "--check-staged must be run inside a git repository (got {})",
115            repo_root.display()
116        ));
117    }
118
119    let staged_files = list_staged_files(repo_root)?;
120
121    // Set up the adaptive layer the same way `--check` does so workspace
122    // probes and burst detection behave consistently across modes.
123    // Decision memory is irrelevant for a one-shot pre-commit run -- we
124    // skip allocating it entirely so a developer's stale ~/.aperion-shield
125    // state never colours commit-time verdicts.
126    let workspace = match workspace_root {
127        Some(p) => WorkspaceContext::probe_at(&engine.policy, p),
128        None => WorkspaceContext::probe_at(&engine.policy, repo_root),
129    };
130    let burst = BurstDetector::new(engine.policy.burst_detector.clone());
131
132    let mut report = CheckStagedReport::default();
133
134    for staged in staged_files {
135        if !is_inspectable(&staged.path) {
136            continue;
137        }
138        let added = match list_added_lines(repo_root, &staged.path) {
139            Ok(v) => v,
140            Err(e) => {
141                // Don't fail the whole hook because of one unreadable
142                // file -- log and continue.
143                eprintln!(
144                    "[shield-check-staged] skipping {}: {}",
145                    staged.path, e
146                );
147                continue;
148            }
149        };
150        if added.is_empty() {
151            continue;
152        }
153        report.files_scanned += 1;
154
155        let kind = classify_file(&staged.path);
156
157        for AddedLine { line_no, content } in added {
158            if content.trim().is_empty() {
159                continue;
160            }
161            if is_pure_comment(&content, kind) {
162                continue;
163            }
164            report.lines_scanned += 1;
165
166            let (eval, _scope) = evaluate_line(engine, kind, &content, &workspace, &burst);
167            let decision = decide(&eval);
168            match decision {
169                Decision::Allow => continue,
170                Decision::Warn { .. }
171                | Decision::Approval { .. }
172                | Decision::Block { .. }
173                | Decision::IdentityVerification { .. } => {
174                    // Pick the dominant rule match for surfacing.
175                    let primary = eval
176                        .matches
177                        .iter()
178                        .max_by(|a, b| {
179                            a.severity.cmp(&b.severity).then(a.points.cmp(&b.points))
180                        })
181                        .cloned();
182                    let (rule_id, severity, reason, safer) = match primary {
183                        Some(m) => (
184                            m.rule_id.clone(),
185                            format!("{:?}", m.severity),
186                            m.reason.clone(),
187                            m.safer_alternative.clone(),
188                        ),
189                        None => (
190                            "shield.unknown".into(),
191                            "Medium".into(),
192                            "matched without an attributable rule".into(),
193                            None,
194                        ),
195                    };
196                    let dec_label = decision.label().to_string();
197                    report.findings.push(StagedFinding {
198                        file: staged.path.clone(),
199                        line_no,
200                        line: content,
201                        rule_id,
202                        severity,
203                        decision: dec_label,
204                        reason,
205                        safer_alternative: safer,
206                    });
207                    if report
208                        .worst_decision
209                        .as_ref()
210                        .map(|d| (severity_rank(&decision)) > severity_rank(d))
211                        .unwrap_or(true)
212                    {
213                        report.worst_decision = Some(decision.clone());
214                    }
215                }
216            }
217        }
218    }
219
220    Ok(report)
221}
222
/// Whether files at `path` are linted at all: only kinds that map to an
/// engine scope (SQL, shell-like, general code) are; everything classified
/// as `FileKind::Other` is skipped for precision (see the module docs).
fn is_inspectable(path: &str) -> bool {
    classify_file(path) != FileKind::Other
}

/// How a file's added lines are presented to the engine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FileKind {
    /// `.sql` files — lines become `execute_sql` tool calls.
    Sql,
    /// Shell scripts, `Makefile`, `Justfile`, `Dockerfile*` — lines become
    /// `shell` tool calls.
    Shell,
    /// General code (Python, JS, TS, Rust, ...). Lines pass through
    /// the `llm_response` scope so the `text:` rules in the shieldset
    /// fire on agent-generated comments + obvious destructive snippets.
    Code,
    /// Anything else; never inspected.
    Other,
}

/// Map a repo-relative path to its [`FileKind`], case-insensitively.
/// Suffix checks run against the whole lowercased path; the name checks
/// (`makefile`, `justfile`, `dockerfile*`) against its final component.
/// SQL wins over shell, which wins over code.
fn classify_file(path: &str) -> FileKind {
    const SHELL_SUFFIXES: [&str; 3] = [".sh", ".bash", ".zsh"];
    const CODE_SUFFIXES: [&str; 12] = [
        ".py", ".js", ".ts", ".jsx", ".tsx", ".rs", ".go", ".rb", ".java", ".kt", ".swift", ".cs",
    ];

    let lowered = path.to_lowercase();
    let file_name = std::path::Path::new(&lowered)
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or("");
    let ends_with_any =
        |suffixes: &[&str]| suffixes.iter().any(|suffix| lowered.ends_with(*suffix));

    if lowered.ends_with(".sql") {
        FileKind::Sql
    } else if ends_with_any(&SHELL_SUFFIXES[..])
        || matches!(file_name, "makefile" | "justfile")
        || file_name.starts_with("dockerfile")
    {
        FileKind::Shell
    } else if ends_with_any(&CODE_SUFFIXES[..]) {
        FileKind::Code
    } else {
        FileKind::Other
    }
}
276
277fn evaluate_line(
278    engine: &Engine,
279    kind: FileKind,
280    line: &str,
281    workspace: &WorkspaceContext,
282    burst: &BurstDetector,
283) -> (crate::engine::Evaluation, &'static str) {
284    let adj = Adjustments {
285        workspace_is_prod: workspace.is_prod,
286        burst_in_progress: burst.in_burst(),
287        ..Default::default()
288    };
289    match kind {
290        FileKind::Sql => {
291            let canonical = json!({"name": "execute_sql", "arguments": {"query": line}});
292            (
293                engine.evaluate("execute_sql", &canonical, adj),
294                "tool_call",
295            )
296        }
297        FileKind::Shell => {
298            let canonical = json!({"name": "shell", "arguments": {"command": line}});
299            (engine.evaluate("shell", &canonical, adj), "tool_call")
300        }
301        FileKind::Code | FileKind::Other => (engine.evaluate_text(line, adj), "llm_response"),
302    }
303}
304
305fn is_pure_comment(line: &str, kind: FileKind) -> bool {
306    let trimmed = line.trim_start();
307    match kind {
308        FileKind::Sql => trimmed.starts_with("--"),
309        FileKind::Shell => trimmed.starts_with('#'),
310        FileKind::Code => {
311            trimmed.starts_with("//")
312                || trimmed.starts_with('#')
313                || trimmed.starts_with("/*")
314                || trimmed.starts_with('*')
315        }
316        FileKind::Other => false,
317    }
318}
319
320fn severity_rank(d: &Decision) -> u8 {
321    match d {
322        Decision::Allow => 0,
323        Decision::Warn { .. } => 1,
324        Decision::IdentityVerification { .. } => 2,
325        Decision::Approval { .. } => 3,
326        Decision::Block { .. } => 4,
327    }
328}
329
330// ─────────────────────────────────────────────────────────────────────
331// git plumbing — shell out and parse, deliberately no libgit2 dep
332// ─────────────────────────────────────────────────────────────────────
333
/// One path from the staged file list (`git diff --cached --name-only -z`).
#[derive(Debug)]
struct StagedFile {
    /// Repo-root-relative path, forward-slash-separated.
    /// Decoded lossily from git's bytes, so non-UTF-8 names may not
    /// round-trip.
    path: String,
}

/// One added (`+`) line from a staged diff hunk.
#[derive(Debug)]
struct AddedLine {
    /// 1-based line number of this line in the staged (post-image) file.
    line_no: usize,
    /// Line text with the leading `+` marker stripped.
    content: String,
}
345
346fn is_inside_git_repo(repo_root: &std::path::Path) -> Result<bool> {
347    let out = Command::new("git")
348        .args(["rev-parse", "--is-inside-work-tree"])
349        .current_dir(repo_root)
350        .output()
351        .with_context(|| "couldn't invoke `git rev-parse`; is git installed?")?;
352    Ok(out.status.success()
353        && String::from_utf8_lossy(&out.stdout).trim() == "true")
354}
355
356fn list_staged_files(repo_root: &std::path::Path) -> Result<Vec<StagedFile>> {
357    // `--cached` = index vs HEAD, `--diff-filter=AM` = Added + Modified
358    // (we skip Deletions, Renames-only, Copies). `--name-only` is the
359    // fastest path.
360    let out = Command::new("git")
361        .args([
362            "diff",
363            "--cached",
364            "--diff-filter=AM",
365            "--name-only",
366            "-z",
367        ])
368        .current_dir(repo_root)
369        .output()
370        .with_context(|| "git diff --cached failed")?;
371    if !out.status.success() {
372        return Err(anyhow!(
373            "git diff --cached exited {}: {}",
374            out.status,
375            String::from_utf8_lossy(&out.stderr).trim()
376        ));
377    }
378    let mut staged = Vec::new();
379    for chunk in out.stdout.split(|b| *b == 0) {
380        if chunk.is_empty() {
381            continue;
382        }
383        let path = String::from_utf8_lossy(chunk).to_string();
384        // Filter on the index's actual blob size; oversize binaries
385        // never make it to the engine.
386        if blob_oversize(repo_root, &path) {
387            continue;
388        }
389        staged.push(StagedFile { path });
390    }
391    Ok(staged)
392}
393
394fn blob_oversize(repo_root: &std::path::Path, rel_path: &str) -> bool {
395    let on_disk = PathBuf::from(rel_path);
396    let full = repo_root.join(&on_disk);
397    full.metadata()
398        .map(|m| m.len() > MAX_FILE_SIZE_BYTES)
399        .unwrap_or(false)
400}
401
402/// Walk `git diff --cached -U0 -- <path>` and yield every line that
403/// starts with `+` (but not `+++` -- that's the header) along with its
404/// post-image line number.
405fn list_added_lines(
406    repo_root: &std::path::Path,
407    rel_path: &str,
408) -> Result<Vec<AddedLine>> {
409    let out = Command::new("git")
410        .args([
411            "diff",
412            "--cached",
413            "-U0",
414            "--no-color",
415            "--",
416            rel_path,
417        ])
418        .current_dir(repo_root)
419        .output()
420        .with_context(|| format!("git diff --cached -U0 -- {} failed", rel_path))?;
421    if !out.status.success() {
422        return Err(anyhow!(
423            "git diff for {} exited {}: {}",
424            rel_path,
425            out.status,
426            String::from_utf8_lossy(&out.stderr).trim()
427        ));
428    }
429    let text = String::from_utf8_lossy(&out.stdout).to_string();
430    Ok(parse_unified_diff_added(&text))
431}
432
433/// Pure-string parser so this is unit-testable without spawning git.
434/// Handles standard unified diff hunks of the form
435/// `@@ -<a>,<b> +<c>,<d> @@` (with the comma+count optional).
436fn parse_unified_diff_added(diff: &str) -> Vec<AddedLine> {
437    let mut out = Vec::new();
438    let mut cur_line_no: usize = 0;
439    let mut in_hunk = false;
440    for raw in diff.lines() {
441        if raw.starts_with("@@ ") {
442            in_hunk = false;
443            if let Some(plus) = extract_plus_start(raw) {
444                cur_line_no = plus;
445                in_hunk = true;
446            }
447            continue;
448        }
449        if !in_hunk {
450            continue;
451        }
452        if raw.starts_with("+++") || raw.starts_with("---") {
453            continue;
454        }
455        if let Some(rest) = raw.strip_prefix('+') {
456            out.push(AddedLine {
457                line_no: cur_line_no,
458                content: rest.to_string(),
459            });
460            cur_line_no += 1;
461        } else if raw.starts_with(' ') {
462            cur_line_no += 1;
463        }
464        // '-' lines: don't advance the post-image counter.
465    }
466    out
467}
468
469/// Pull the post-image start line number out of a `@@ -X,Y +A,B @@` header.
470fn extract_plus_start(header: &str) -> Option<usize> {
471    let plus = header.find('+')?;
472    let after = &header[plus + 1..];
473    let end = after.find(|c: char| !(c.is_ascii_digit() || c == ',')).unwrap_or(after.len());
474    let nums = &after[..end];
475    let first = nums.split(',').next()?;
476    first.parse::<usize>().ok()
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn classifies_file_extensions_correctly() {
        let cases = [
            ("migrations/2026.sql", FileKind::Sql),
            ("scripts/cleanup.sh", FileKind::Shell),
            ("Makefile", FileKind::Shell),
            ("Dockerfile", FileKind::Shell),
            ("dockerfile.prod", FileKind::Shell),
            ("src/main.py", FileKind::Code),
            ("README.md", FileKind::Other),
            ("data/dump.json", FileKind::Other),
        ];
        for (path, expected) in cases {
            assert_eq!(classify_file(path), expected, "classify_file({:?})", path);
        }
    }

    #[test]
    fn comment_filter_respects_language() {
        // (line, file kind, expected `is_pure_comment` verdict)
        let cases = [
            ("-- drop table users", FileKind::Sql, true),
            // `#` does not start a SQL comment.
            ("# drop table users", FileKind::Sql, false),
            ("# rm -rf /", FileKind::Shell, true),
            ("// rm -rf /", FileKind::Code, true),
            ("rm -rf /", FileKind::Shell, false),
        ];
        for (line, kind, expected) in cases {
            assert_eq!(
                is_pure_comment(line, kind),
                expected,
                "is_pure_comment({:?}, {:?})",
                line,
                kind
            );
        }
    }

    #[test]
    fn diff_parser_extracts_added_lines_with_correct_numbers() {
        let diff = concat!(
            "diff --git a/x.sql b/x.sql\n",
            "--- a/x.sql\n",
            "+++ b/x.sql\n",
            "@@ -0,0 +1,3 @@\n",
            "+DROP DATABASE prod;\n",
            "+TRUNCATE users;\n",
            "+SELECT 1;\n",
            "@@ -10,1 +10,2 @@\n",
            "-old line\n",
            "+new line A\n",
            "+new line B\n",
        );
        let parsed = parse_unified_diff_added(diff);
        let got: Vec<(usize, &str)> = parsed
            .iter()
            .map(|added| (added.line_no, added.content.as_str()))
            .collect();
        assert_eq!(
            got,
            [
                (1, "DROP DATABASE prod;"),
                (2, "TRUNCATE users;"),
                (3, "SELECT 1;"),
                (10, "new line A"),
                (11, "new line B"),
            ]
        );
    }

    #[test]
    fn diff_parser_ignores_headers_and_minus_lines() {
        let diff = concat!(
            "diff --git a/y.sh b/y.sh\n",
            "--- /dev/null\n",
            "+++ b/y.sh\n",
            "@@ -0,0 +1,1 @@\n",
            "+rm -rf /\n",
        );
        let parsed = parse_unified_diff_added(diff);
        let got: Vec<(usize, &str)> = parsed
            .iter()
            .map(|added| (added.line_no, added.content.as_str()))
            .collect();
        assert_eq!(got, [(1, "rm -rf /")]);
    }

    #[test]
    fn plus_start_handles_both_short_and_long_headers() {
        let cases = [
            ("@@ -0,0 +1,3 @@", Some(1)),
            ("@@ -10 +10 @@", Some(10)),
            ("@@ -0,0 +42 @@ context", Some(42)),
        ];
        for (header, expected) in cases {
            assert_eq!(extract_plus_start(header), expected, "header {:?}", header);
        }
    }
}