1use anyhow::{anyhow, Context, Result};
51use serde_json::json;
52use std::collections::BTreeMap;
53use std::path::PathBuf;
54use std::process::Command;
55
56use crate::engine::Engine;
57use crate::{decide, Adjustments, BurstDetector, Decision, WorkspaceContext};
58
/// Size limit, in bytes, above which a staged file is left out of the scan
/// entirely (see `blob_oversize`). Equal to 256 KiB.
const MAX_FILE_SIZE_BYTES: u64 = 256 << 10;
60
/// One added line of the staged diff that the engine did not simply allow.
///
/// Built by `run` for every non-`Allow` decision. The rule attribution comes
/// from the line's highest-severity match (ties broken by points).
#[derive(Debug, Clone)]
pub struct StagedFinding {
    /// Path of the file as printed by `git diff --cached --name-only`.
    pub file: String,
    /// Line number of the added line in the staged version of `file`,
    /// counted from the enclosing hunk header's `+start`.
    pub line_no: usize,
    /// The added line's text, without the leading `+` diff marker.
    pub line: String,
    /// Identifier of the attributed rule, or `"shield.unknown"` when the
    /// evaluation carried no matches.
    pub rule_id: String,
    /// `Debug` rendering of the attributed match's severity, or `"Medium"`
    /// when no match could be attributed.
    pub severity: String,
    /// The decision's label (`Decision::label`).
    pub decision: String,
    /// Reason text of the attributed match.
    pub reason: String,
    /// Safer alternative offered by the attributed match, if any.
    pub safer_alternative: Option<String>,
}
74
/// Aggregate result of scanning the staged diff with `run`.
#[derive(Debug, Default)]
pub struct CheckStagedReport {
    /// Number of inspectable files that contributed at least one added line.
    pub files_scanned: usize,
    /// Number of added lines actually evaluated; blank lines and pure
    /// comment lines are not counted.
    pub lines_scanned: usize,
    /// Every non-`Allow` result, in the order it was encountered.
    pub findings: Vec<StagedFinding>,
    /// Highest-ranked decision among `findings` (ranked by `severity_rank`),
    /// or `None` when nothing was flagged.
    pub worst_decision: Option<Decision>,
}
84
85impl CheckStagedReport {
86 pub fn exit_code(&self) -> u8 {
89 match &self.worst_decision {
90 Some(d) if d.is_blocking() => 1,
91 Some(Decision::Approval { .. }) => 2,
92 _ => 0,
93 }
94 }
95
96 pub fn group_by_rule(&self) -> BTreeMap<String, Vec<&StagedFinding>> {
99 let mut out: BTreeMap<String, Vec<&StagedFinding>> = BTreeMap::new();
100 for f in &self.findings {
101 out.entry(f.rule_id.clone()).or_default().push(f);
102 }
103 out
104 }
105}
106
107pub fn run(repo_root: &std::path::Path, engine: &Engine, workspace_root: Option<&std::path::Path>) -> Result<CheckStagedReport> {
112 if !is_inside_git_repo(repo_root)? {
113 return Err(anyhow!(
114 "--check-staged must be run inside a git repository (got {})",
115 repo_root.display()
116 ));
117 }
118
119 let staged_files = list_staged_files(repo_root)?;
120
121 let workspace = match workspace_root {
127 Some(p) => WorkspaceContext::probe_at(&engine.policy, p),
128 None => WorkspaceContext::probe_at(&engine.policy, repo_root),
129 };
130 let burst = BurstDetector::new(engine.policy.burst_detector.clone());
131
132 let mut report = CheckStagedReport::default();
133
134 for staged in staged_files {
135 if !is_inspectable(&staged.path) {
136 continue;
137 }
138 let added = match list_added_lines(repo_root, &staged.path) {
139 Ok(v) => v,
140 Err(e) => {
141 eprintln!(
144 "[shield-check-staged] skipping {}: {}",
145 staged.path, e
146 );
147 continue;
148 }
149 };
150 if added.is_empty() {
151 continue;
152 }
153 report.files_scanned += 1;
154
155 let kind = classify_file(&staged.path);
156
157 for AddedLine { line_no, content } in added {
158 if content.trim().is_empty() {
159 continue;
160 }
161 if is_pure_comment(&content, kind) {
162 continue;
163 }
164 report.lines_scanned += 1;
165
166 let (eval, _scope) = evaluate_line(engine, kind, &content, &workspace, &burst);
167 let decision = decide(&eval);
168 match decision {
169 Decision::Allow => continue,
170 Decision::Warn { .. }
171 | Decision::Approval { .. }
172 | Decision::Block { .. }
173 | Decision::IdentityVerification { .. } => {
174 let primary = eval
176 .matches
177 .iter()
178 .max_by(|a, b| {
179 a.severity.cmp(&b.severity).then(a.points.cmp(&b.points))
180 })
181 .cloned();
182 let (rule_id, severity, reason, safer) = match primary {
183 Some(m) => (
184 m.rule_id.clone(),
185 format!("{:?}", m.severity),
186 m.reason.clone(),
187 m.safer_alternative.clone(),
188 ),
189 None => (
190 "shield.unknown".into(),
191 "Medium".into(),
192 "matched without an attributable rule".into(),
193 None,
194 ),
195 };
196 let dec_label = decision.label().to_string();
197 report.findings.push(StagedFinding {
198 file: staged.path.clone(),
199 line_no,
200 line: content,
201 rule_id,
202 severity,
203 decision: dec_label,
204 reason,
205 safer_alternative: safer,
206 });
207 if report
208 .worst_decision
209 .as_ref()
210 .map(|d| (severity_rank(&decision)) > severity_rank(d))
211 .unwrap_or(true)
212 {
213 report.worst_decision = Some(decision.clone());
214 }
215 }
216 }
217 }
218 }
219
220 Ok(report)
221}
222
223fn is_inspectable(path: &str) -> bool {
226 matches!(classify_file(path), FileKind::Sql | FileKind::Shell | FileKind::Code)
227}
228
/// Coarse file category derived from a path by `classify_file`.
///
/// It selects both the comment syntax used by `is_pure_comment` and how
/// `evaluate_line` presents a line to the engine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FileKind {
    /// `*.sql`; lines are evaluated as an `execute_sql` tool call.
    Sql,
    /// Shell scripts plus Makefile / Justfile / Dockerfile*; lines are
    /// evaluated as a `shell` tool call.
    Shell,
    /// Source files with one of the extensions listed in `classify_file`;
    /// lines are evaluated as free text.
    Code,
    /// Anything else; `is_inspectable` rejects it, so `run` skips it.
    Other,
}
239
240fn classify_file(path: &str) -> FileKind {
241 let lower = path.to_lowercase();
242 let basename = std::path::Path::new(&lower)
243 .file_name()
244 .and_then(|s| s.to_str())
245 .unwrap_or("");
246
247 if lower.ends_with(".sql") {
248 return FileKind::Sql;
249 }
250 if lower.ends_with(".sh")
251 || lower.ends_with(".bash")
252 || lower.ends_with(".zsh")
253 || basename == "makefile"
254 || basename.starts_with("dockerfile")
255 || basename == "justfile"
256 {
257 return FileKind::Shell;
258 }
259 if lower.ends_with(".py")
260 || lower.ends_with(".js")
261 || lower.ends_with(".ts")
262 || lower.ends_with(".jsx")
263 || lower.ends_with(".tsx")
264 || lower.ends_with(".rs")
265 || lower.ends_with(".go")
266 || lower.ends_with(".rb")
267 || lower.ends_with(".java")
268 || lower.ends_with(".kt")
269 || lower.ends_with(".swift")
270 || lower.ends_with(".cs")
271 {
272 return FileKind::Code;
273 }
274 FileKind::Other
275}
276
277fn evaluate_line(
278 engine: &Engine,
279 kind: FileKind,
280 line: &str,
281 workspace: &WorkspaceContext,
282 burst: &BurstDetector,
283) -> (crate::engine::Evaluation, &'static str) {
284 let adj = Adjustments {
285 workspace_is_prod: workspace.is_prod,
286 burst_in_progress: burst.in_burst(),
287 ..Default::default()
288 };
289 match kind {
290 FileKind::Sql => {
291 let canonical = json!({"name": "execute_sql", "arguments": {"query": line}});
292 (
293 engine.evaluate("execute_sql", &canonical, adj),
294 "tool_call",
295 )
296 }
297 FileKind::Shell => {
298 let canonical = json!({"name": "shell", "arguments": {"command": line}});
299 (engine.evaluate("shell", &canonical, adj), "tool_call")
300 }
301 FileKind::Code | FileKind::Other => (engine.evaluate_text(line, adj), "llm_response"),
302 }
303}
304
305fn is_pure_comment(line: &str, kind: FileKind) -> bool {
306 let trimmed = line.trim_start();
307 match kind {
308 FileKind::Sql => trimmed.starts_with("--"),
309 FileKind::Shell => trimmed.starts_with('#'),
310 FileKind::Code => {
311 trimmed.starts_with("//")
312 || trimmed.starts_with('#')
313 || trimmed.starts_with("/*")
314 || trimmed.starts_with('*')
315 }
316 FileKind::Other => false,
317 }
318}
319
320fn severity_rank(d: &Decision) -> u8 {
321 match d {
322 Decision::Allow => 0,
323 Decision::Warn { .. } => 1,
324 Decision::IdentityVerification { .. } => 2,
325 Decision::Approval { .. } => 3,
326 Decision::Block { .. } => 4,
327 }
328}
329
/// A file entry produced by `list_staged_files`.
#[derive(Debug)]
struct StagedFile {
    /// Path as printed by `git diff --cached --name-only -z`, converted to
    /// UTF-8 lossily.
    path: String,
}
339
/// One `+` line extracted from a unified-diff hunk.
#[derive(Debug)]
struct AddedLine {
    /// Line number in the new (staged) version of the file, counted from
    /// the hunk header's `+start`.
    line_no: usize,
    /// The line's text with the leading `+` removed.
    content: String,
}
345
346fn is_inside_git_repo(repo_root: &std::path::Path) -> Result<bool> {
347 let out = Command::new("git")
348 .args(["rev-parse", "--is-inside-work-tree"])
349 .current_dir(repo_root)
350 .output()
351 .with_context(|| "couldn't invoke `git rev-parse`; is git installed?")?;
352 Ok(out.status.success()
353 && String::from_utf8_lossy(&out.stdout).trim() == "true")
354}
355
356fn list_staged_files(repo_root: &std::path::Path) -> Result<Vec<StagedFile>> {
357 let out = Command::new("git")
361 .args([
362 "diff",
363 "--cached",
364 "--diff-filter=AM",
365 "--name-only",
366 "-z",
367 ])
368 .current_dir(repo_root)
369 .output()
370 .with_context(|| "git diff --cached failed")?;
371 if !out.status.success() {
372 return Err(anyhow!(
373 "git diff --cached exited {}: {}",
374 out.status,
375 String::from_utf8_lossy(&out.stderr).trim()
376 ));
377 }
378 let mut staged = Vec::new();
379 for chunk in out.stdout.split(|b| *b == 0) {
380 if chunk.is_empty() {
381 continue;
382 }
383 let path = String::from_utf8_lossy(chunk).to_string();
384 if blob_oversize(repo_root, &path) {
387 continue;
388 }
389 staged.push(StagedFile { path });
390 }
391 Ok(staged)
392}
393
394fn blob_oversize(repo_root: &std::path::Path, rel_path: &str) -> bool {
395 let on_disk = PathBuf::from(rel_path);
396 let full = repo_root.join(&on_disk);
397 full.metadata()
398 .map(|m| m.len() > MAX_FILE_SIZE_BYTES)
399 .unwrap_or(false)
400}
401
402fn list_added_lines(
406 repo_root: &std::path::Path,
407 rel_path: &str,
408) -> Result<Vec<AddedLine>> {
409 let out = Command::new("git")
410 .args([
411 "diff",
412 "--cached",
413 "-U0",
414 "--no-color",
415 "--",
416 rel_path,
417 ])
418 .current_dir(repo_root)
419 .output()
420 .with_context(|| format!("git diff --cached -U0 -- {} failed", rel_path))?;
421 if !out.status.success() {
422 return Err(anyhow!(
423 "git diff for {} exited {}: {}",
424 rel_path,
425 out.status,
426 String::from_utf8_lossy(&out.stderr).trim()
427 ));
428 }
429 let text = String::from_utf8_lossy(&out.stdout).to_string();
430 Ok(parse_unified_diff_added(&text))
431}
432
433fn parse_unified_diff_added(diff: &str) -> Vec<AddedLine> {
437 let mut out = Vec::new();
438 let mut cur_line_no: usize = 0;
439 let mut in_hunk = false;
440 for raw in diff.lines() {
441 if raw.starts_with("@@ ") {
442 in_hunk = false;
443 if let Some(plus) = extract_plus_start(raw) {
444 cur_line_no = plus;
445 in_hunk = true;
446 }
447 continue;
448 }
449 if !in_hunk {
450 continue;
451 }
452 if raw.starts_with("+++") || raw.starts_with("---") {
453 continue;
454 }
455 if let Some(rest) = raw.strip_prefix('+') {
456 out.push(AddedLine {
457 line_no: cur_line_no,
458 content: rest.to_string(),
459 });
460 cur_line_no += 1;
461 } else if raw.starts_with(' ') {
462 cur_line_no += 1;
463 }
464 }
466 out
467}
468
/// Returns the starting line of the new-file range in a unified-diff hunk
/// header (the `c` in `@@ -a,b +c,d @@`).
///
/// Returns `None` when the header has no `+` or no digits follow it.
fn extract_plus_start(header: &str) -> Option<usize> {
    let (_, after_plus) = header.split_once('+')?;
    // The optional ",count" suffix is never part of the start, so only the
    // leading run of ASCII digits matters.
    let digits_end = after_plus
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(after_plus.len());
    after_plus[..digits_end].parse().ok()
}
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Extension and well-known-basename classification.
    #[test]
    fn classifies_file_extensions_correctly() {
        assert_eq!(classify_file("migrations/2026.sql"), FileKind::Sql);
        assert_eq!(classify_file("scripts/cleanup.sh"), FileKind::Shell);
        assert_eq!(classify_file("Makefile"), FileKind::Shell);
        assert_eq!(classify_file("Dockerfile"), FileKind::Shell);
        assert_eq!(classify_file("dockerfile.prod"), FileKind::Shell);
        assert_eq!(classify_file("src/main.py"), FileKind::Code);
        assert_eq!(classify_file("README.md"), FileKind::Other);
        assert_eq!(classify_file("data/dump.json"), FileKind::Other);
    }

    /// Classification is case-insensitive and covers the other shell names.
    #[test]
    fn classification_ignores_case_and_covers_other_shell_files() {
        assert_eq!(classify_file("db/Seed.SQL"), FileKind::Sql);
        assert_eq!(classify_file("scripts/deploy.BASH"), FileKind::Shell);
        assert_eq!(classify_file("Justfile"), FileKind::Shell);
        assert_eq!(classify_file("web/App.tsx"), FileKind::Code);
        assert_eq!(classify_file("notes.txt"), FileKind::Other);
    }

    /// Only SQL, shell and code files are scanned.
    #[test]
    fn only_sql_shell_and_code_are_inspectable() {
        assert!(is_inspectable("q.sql"));
        assert!(is_inspectable("run.sh"));
        assert!(is_inspectable("lib.rs"));
        assert!(!is_inspectable("README.md"));
    }

    /// Comment markers are language-specific; `Other` has none.
    #[test]
    fn comment_filter_respects_language() {
        assert!(is_pure_comment("-- drop table users", FileKind::Sql));
        assert!(!is_pure_comment("# drop table users", FileKind::Sql));
        assert!(is_pure_comment("# rm -rf /", FileKind::Shell));
        assert!(is_pure_comment("// rm -rf /", FileKind::Code));
        assert!(!is_pure_comment("rm -rf /", FileKind::Shell));
        assert!(is_pure_comment("    -- indented", FileKind::Sql));
        assert!(!is_pure_comment("# heading", FileKind::Other));
    }

    /// Line numbers restart at each hunk's `+start`; removed lines don't count.
    #[test]
    fn diff_parser_extracts_added_lines_with_correct_numbers() {
        let diff = r#"diff --git a/x.sql b/x.sql
--- a/x.sql
+++ b/x.sql
@@ -0,0 +1,3 @@
+DROP DATABASE prod;
+TRUNCATE users;
+SELECT 1;
@@ -10,1 +10,2 @@
-old line
+new line A
+new line B
"#;
        let lines = parse_unified_diff_added(diff);
        assert_eq!(lines.len(), 5);
        assert_eq!(lines[0].line_no, 1);
        assert_eq!(lines[0].content, "DROP DATABASE prod;");
        assert_eq!(lines[1].line_no, 2);
        assert_eq!(lines[2].line_no, 3);
        assert_eq!(lines[3].line_no, 10);
        assert_eq!(lines[3].content, "new line A");
        assert_eq!(lines[4].line_no, 11);
    }

    /// File headers before the first hunk are not reported as added lines.
    #[test]
    fn diff_parser_ignores_headers_and_minus_lines() {
        let diff = r#"diff --git a/y.sh b/y.sh
--- /dev/null
+++ b/y.sh
@@ -0,0 +1,1 @@
+rm -rf /
"#;
        let lines = parse_unified_diff_added(diff);
        assert_eq!(lines.len(), 1);
        assert_eq!(lines[0].content, "rm -rf /");
        assert_eq!(lines[0].line_no, 1);
    }

    /// Deletion-only hunks and "no newline" markers don't disturb numbering.
    #[test]
    fn diff_parser_tracks_numbers_across_deletion_only_hunks() {
        let diff = "@@ -3,2 +2,0 @@\n-a\n-b\n@@ -10,0 +9,1 @@\n+c\n\\ No newline at end of file\n";
        let lines = parse_unified_diff_added(diff);
        assert_eq!(lines.len(), 1);
        assert_eq!(lines[0].line_no, 9);
        assert_eq!(lines[0].content, "c");
    }

    /// Both `+start,count` and bare `+start` header forms are accepted.
    #[test]
    fn plus_start_handles_both_short_and_long_headers() {
        assert_eq!(extract_plus_start("@@ -0,0 +1,3 @@"), Some(1));
        assert_eq!(extract_plus_start("@@ -10 +10 @@"), Some(10));
        assert_eq!(extract_plus_start("@@ -0,0 +42 @@ context"), Some(42));
        assert_eq!(extract_plus_start("@@ -1,2 @@"), None);
    }
}
549}