use anyhow::{Context, Result};
use rma_common::Finding;
use std::collections::{HashMap, HashSet};
use std::io::{self, BufRead};
use std::path::PathBuf;
use std::process::Command;
/// Changed lines grouped per file: each path maps to the set of line numbers
/// collected from the new-file side (`+start,count`) of the diff hunks that
/// touch that file.
pub type ChangedLines = HashMap<PathBuf, HashSet<usize>>;
/// Collects the lines that differ from `base_ref` by asking git for a diff.
///
/// Runs `git diff --unified=0 <base_ref>` with `project_root` as the working
/// directory and feeds the output to [`parse_unified_diff`], so the returned
/// paths are joined onto `project_root`.
///
/// If the first invocation exits unsuccessfully, `git fetch origin` is run
/// once (its result is deliberately ignored) and the diff is retried.
///
/// # Errors
///
/// Returns an error when the `git` process cannot be spawned, or when the
/// retried diff still exits unsuccessfully (the error then carries git's
/// stderr output).
pub fn get_changed_lines_from_git(project_root: &PathBuf, base_ref: &str) -> Result<ChangedLines> {
    // Single definition of the diff invocation; only the error context differs
    // between the first attempt and the retry.
    let run_diff = |failure_context: &'static str| {
        Command::new("git")
            .args(["diff", "--unified=0", base_ref])
            .current_dir(project_root)
            .output()
            .context(failure_context)
    };

    let mut diff_output = run_diff("Failed to run git diff")?;

    if !diff_output.status.success() {
        // Best effort: the base ref may simply not be available locally yet.
        let _ = Command::new("git")
            .args(["fetch", "origin"])
            .current_dir(project_root)
            .output();

        diff_output = run_diff("Failed to run git diff after fetch")?;
        if !diff_output.status.success() {
            anyhow::bail!(
                "git diff failed: {}",
                String::from_utf8_lossy(&diff_output.stderr)
            );
        }
    }

    let patch = String::from_utf8_lossy(&diff_output.stdout);
    parse_unified_diff(&patch, Some(project_root))
}
/// Reads a unified diff from standard input and parses it.
///
/// Input is consumed line by line until end of input; every line is
/// re-terminated with `\n` before the text is handed to
/// [`parse_unified_diff`] without a project root, so paths are kept exactly
/// as the diff spells them.
///
/// # Errors
///
/// Returns an error if a line cannot be read from stdin.
pub fn get_changed_lines_from_stdin() -> Result<ChangedLines> {
    let patch = io::stdin().lock().lines().try_fold(
        String::new(),
        |mut collected, next_line| -> Result<String> {
            let text = next_line.context("Failed to read line from stdin")?;
            collected.push_str(&text);
            collected.push('\n');
            Ok(collected)
        },
    )?;

    parse_unified_diff(&patch, None)
}
/// Parses unified-diff text into a map from file path to changed line numbers.
///
/// For the file named by the most recent `+++ ` header, every hunk header
/// (`@@ -a,b +c,d @@`) contributes the whole new-side range `c .. c + d` to
/// that file's set. Because the full range is recorded, context lines of a
/// diff generated with context are included too. A hunk with a zero new-side
/// count still creates an (empty) entry for the file.
///
/// File header handling:
/// * anything from the first tab onwards is dropped (GNU diff appends a
///   timestamp there; git appends a bare tab for names containing spaces);
/// * one leading `b/` is stripped from the path;
/// * a `/dev/null` target (deleted file) clears the current file, so its
///   hunks are ignored;
/// * when `project_root` is `Some`, the path is joined onto it, otherwise it
///   is used as written.
///
/// Hunk bodies are skipped by counting the new-side lines each hunk header
/// declares. This keeps an added line whose own text begins with `++ `
/// (rendered as `+++ ...` in the diff) from being mistaken for a file header.
/// If a body line with an unexpected prefix shows up before the declared
/// count is reached, the hunk is considered finished and normal header
/// detection resumes with that line.
///
/// Paths that the diff producer emitted in quoted form are not unquoted.
///
/// # Errors
///
/// The current implementation never fails; unparsable hunk headers are
/// skipped. The `Result` return type is kept for the callers' sake.
pub fn parse_unified_diff(diff_text: &str, project_root: Option<&PathBuf>) -> Result<ChangedLines> {
    let mut changed_lines: ChangedLines = HashMap::new();
    let mut current_file: Option<PathBuf> = None;
    // New-side body lines (added or context) still expected in the current hunk.
    let mut pending_new_lines: usize = 0;

    for line in diff_text.lines() {
        if pending_new_lines > 0 {
            match line.bytes().next() {
                // Added line, context line, or a context line whose lone
                // leading space was stripped by some tool.
                Some(b'+') | Some(b' ') | None => {
                    pending_new_lines -= 1;
                    continue;
                }
                // Removed line or the "\ No newline at end of file" marker:
                // part of the body, but not counted on the new side.
                Some(b'-') | Some(b'\\') => continue,
                // Hunk was shorter than declared; resynchronise on this line.
                Some(_) => pending_new_lines = 0,
            }
        }

        if let Some(target) = line.strip_prefix("+++ ") {
            // Cut off a trailing "\t<timestamp>" or bare "\t".
            let target = target.split('\t').next().unwrap_or(target);
            let path_str = target.strip_prefix("b/").unwrap_or(target);

            current_file = if path_str == "/dev/null" {
                None
            } else {
                Some(match project_root {
                    Some(root) => root.join(path_str),
                    None => PathBuf::from(path_str),
                })
            };
        } else if line.starts_with("@@ ") {
            if let Some((new_start, new_count)) = parse_hunk_header(line) {
                // Track the body even when the file is being ignored.
                pending_new_lines = new_count;

                if let Some(file) = &current_file {
                    // saturating_add: a hostile header must not overflow.
                    changed_lines
                        .entry(file.clone())
                        .or_default()
                        .extend(new_start..new_start.saturating_add(new_count));
                }
            }
        }
    }

    Ok(changed_lines)
}
/// Extracts the new-side `(start, count)` pair from a unified-diff hunk header.
///
/// Looks at the text between the first and second `@@` markers and takes the
/// first whitespace-separated token that begins with `+`. `+start,count`
/// yields both numbers; a bare `+start` yields a count of `1`.
///
/// Returns `None` when the line contains no `@@` marker, when no `+` token is
/// present, or when a number fails to parse as `usize`.
fn parse_hunk_header(line: &str) -> Option<(usize, usize)> {
    // e.g. " -10,3 +10,5 " for "@@ -10,3 +10,5 @@ fn main() {"
    let ranges = line.split("@@").nth(1)?;

    let new_side = ranges
        .split_whitespace()
        .find_map(|token| token.strip_prefix('+'))?;

    match new_side.split_once(',') {
        Some((start_text, count_text)) => {
            let start = start_text.parse::<usize>().ok()?;
            let count = count_text.parse::<usize>().ok()?;
            Some((start, count))
        }
        None => new_side.parse::<usize>().ok().map(|start| (start, 1)),
    }
}
/// Keeps only the findings whose line span touches at least one changed line.
///
/// A finding is retained when any line in `start_line..=end_line` is present
/// in the changed-line set of a matching file. Relative order of the
/// retained findings is preserved.
///
/// File matching:
/// 1. If the finding's path is itself a key of `changed_lines`, only that
///    entry is consulted.
/// 2. Otherwise every entry is tried, and a changed path matches when either
///    path ends with the other (component-wise) or both share the same final
///    file name.
///
/// NOTE(review): the file-name-only rule in step 2 also matches unrelated
/// files that merely share a name (e.g. two different `mod.rs`); confirm that
/// this leniency is intended.
pub fn filter_findings_by_diff(
    findings: Vec<Finding>,
    changed_lines: &ChangedLines,
) -> Vec<Finding> {
    let mut retained = findings;

    retained.retain(|finding| {
        let location = &finding.location;

        // True when the finding's span shares a line with `lines`.
        let overlaps = |lines: &HashSet<usize>| {
            (location.start_line..=location.end_line).any(|line_no| lines.contains(&line_no))
        };

        match changed_lines.get(&location.file) {
            Some(lines) => overlaps(lines),
            None => {
                let base_name = location.file.file_name();
                changed_lines.iter().any(|(candidate, lines)| {
                    let same_file = candidate == &location.file
                        || candidate.ends_with(&location.file)
                        || location.file.ends_with(candidate)
                        || (base_name.is_some() && candidate.file_name() == base_name);
                    same_file && overlaps(lines)
                })
            }
        }
    });

    retained
}
/// Reports whether `file_path` corresponds to any file listed in `changed_lines`.
///
/// The path counts as changed when it is a key of the map, or when some key
/// and `file_path` are suffixes of one another (component-wise), or when they
/// share the same final file name. Line numbers are not consulted.
pub fn is_file_changed(file_path: &PathBuf, changed_lines: &ChangedLines) -> bool {
    let base_name = file_path.file_name();

    changed_lines.contains_key(file_path)
        || changed_lines.keys().any(|candidate| {
            candidate.ends_with(file_path)
                || file_path.ends_with(candidate)
                || (base_name.is_some() && candidate.file_name() == base_name)
        })
}
#[cfg(test)]
mod tests {
// Unit tests for hunk-header parsing, unified-diff parsing, and the
// diff-based filtering helpers defined in the parent module.
use super::*;
use rma_common::{Language, Severity, SourceLocation};
// Both ranges carry an explicit count; the new-side (start, count) is returned.
#[test]
fn test_parse_hunk_header_basic() {
let result = parse_hunk_header("@@ -10,3 +10,5 @@");
assert_eq!(result, Some((10, 5)));
}
// A range without ",count" is reported with a count of 1.
#[test]
fn test_parse_hunk_header_no_count() {
let result = parse_hunk_header("@@ -10 +10 @@");
assert_eq!(result, Some((10, 1)));
}
// Text following the closing "@@" does not affect the parsed range.
#[test]
fn test_parse_hunk_header_with_context() {
let result = parse_hunk_header("@@ -10,3 +10,5 @@ fn example() {");
assert_eq!(result, Some((10, 5)));
}
// Old side has a zero count; only the new-side range is reported.
#[test]
fn test_parse_hunk_header_single_line_new_count() {
let result = parse_hunk_header("@@ -5,0 +6,2 @@");
assert_eq!(result, Some((6, 2)));
}
// A zero new-side count is returned as-is rather than treated as an error.
#[test]
fn test_parse_hunk_header_deletion_only() {
let result = parse_hunk_header("@@ -10,3 +10,0 @@");
assert_eq!(result, Some((10, 0)));
}
// One file, one hunk: the "b/" prefix is stripped from the path and the
// whole new-side range 10..=14 is recorded.
#[test]
fn test_parse_unified_diff_simple() {
let diff = r#"diff --git a/src/main.rs b/src/main.rs
index abc123..def456 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -10,3 +10,5 @@ fn main() {
unchanged
+added line 1
+added line 2
unchanged
"#;
let result = parse_unified_diff(diff, None).unwrap();
assert!(result.contains_key(&PathBuf::from("src/main.rs")));
let lines = result.get(&PathBuf::from("src/main.rs")).unwrap();
assert!(lines.contains(&10));
assert!(lines.contains(&11));
assert!(lines.contains(&12));
assert!(lines.contains(&13));
assert!(lines.contains(&14));
}
// Ranges from several hunks of the same file accumulate in one set.
#[test]
fn test_parse_unified_diff_multiple_hunks() {
let diff = r#"diff --git a/src/lib.rs b/src/lib.rs
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,2 +5,3 @@
line
+new line at 6
line
@@ -20,1 +21,2 @@
old
+new line at 22
"#;
let result = parse_unified_diff(diff, None).unwrap();
let lines = result.get(&PathBuf::from("src/lib.rs")).unwrap();
assert!(lines.contains(&5));
assert!(lines.contains(&6));
assert!(lines.contains(&7));
assert!(lines.contains(&21));
assert!(lines.contains(&22));
}
// New file ("--- /dev/null"): every line of the declared range 1..=10 is
// marked, based on the hunk header alone.
#[test]
fn test_parse_unified_diff_new_file() {
let diff = r#"diff --git a/src/new_file.rs b/src/new_file.rs
new file mode 100644
index 0000000..abc123
--- /dev/null
+++ b/src/new_file.rs
@@ -0,0 +1,10 @@
+fn new_function() {
+ println!("hello");
+}
"#;
let result = parse_unified_diff(diff, None).unwrap();
assert!(result.contains_key(&PathBuf::from("src/new_file.rs")));
let lines = result.get(&PathBuf::from("src/new_file.rs")).unwrap();
for i in 1..=10 {
assert!(lines.contains(&i), "Line {} should be marked as changed", i);
}
}
// Deleted file ("+++ /dev/null"): no entry is created for the old path.
#[test]
fn test_parse_unified_diff_deleted_file() {
let diff = r#"diff --git a/src/old_file.rs b/src/old_file.rs
deleted file mode 100644
index abc123..0000000
--- a/src/old_file.rs
+++ /dev/null
@@ -1,10 +0,0 @@
-fn old_function() {
- println!("goodbye");
-}
"#;
let result = parse_unified_diff(diff, None).unwrap();
assert!(!result.contains_key(&PathBuf::from("src/old_file.rs")));
}
// Renamed file: changes are recorded under the new name from the "+++" line.
#[test]
fn test_parse_unified_diff_renamed_file() {
let diff = r#"diff --git a/src/old_name.rs b/src/new_name.rs
similarity index 95%
rename from src/old_name.rs
rename to src/new_name.rs
index abc123..def456 100644
--- a/src/old_name.rs
+++ b/src/new_name.rs
@@ -5,1 +5,2 @@
unchanged
+added in renamed file
"#;
let result = parse_unified_diff(diff, None).unwrap();
assert!(result.contains_key(&PathBuf::from("src/new_name.rs")));
let lines = result.get(&PathBuf::from("src/new_name.rs")).unwrap();
assert!(lines.contains(&5));
assert!(lines.contains(&6));
}
// Findings located on changed lines survive filtering, in their original order.
#[test]
fn test_filter_findings_by_diff_keeps_changed() {
let mut changed_lines = ChangedLines::new();
changed_lines.insert(
PathBuf::from("src/main.rs"),
vec![10, 11, 12].into_iter().collect(),
);
// In order: on changed line 10 (kept), on unchanged line 5 (dropped),
// spanning changed lines 11-12 (kept).
let findings = vec![
create_test_finding("src/main.rs", 10, 10), create_test_finding("src/main.rs", 5, 5), create_test_finding("src/main.rs", 11, 12), ];
let filtered = filter_findings_by_diff(findings, &changed_lines);
assert_eq!(filtered.len(), 2);
assert_eq!(filtered[0].location.start_line, 10);
assert_eq!(filtered[1].location.start_line, 11);
}
// Findings on untouched lines, or in a file absent from the diff, are dropped.
#[test]
fn test_filter_findings_by_diff_removes_unchanged() {
let mut changed_lines = ChangedLines::new();
changed_lines.insert(
PathBuf::from("src/main.rs"),
vec![10, 11].into_iter().collect(),
);
// In order: before the changed lines, after the changed lines, and in a
// different file whose line number happens to coincide.
let findings = vec![
create_test_finding("src/main.rs", 5, 5), create_test_finding("src/main.rs", 20, 25), create_test_finding("src/other.rs", 10, 10), ];
let filtered = filter_findings_by_diff(findings, &changed_lines);
assert!(filtered.is_empty());
}
// A finding is kept when its span merely overlaps the changed lines.
#[test]
fn test_filter_findings_partial_overlap() {
let mut changed_lines = ChangedLines::new();
changed_lines.insert(
PathBuf::from("src/main.rs"),
vec![10, 11, 12].into_iter().collect(),
);
// Span 8-10 overlaps at its end (line 10); span 12-15 overlaps at its
// start (line 12).
let findings = vec![
create_test_finding("src/main.rs", 8, 10), create_test_finding("src/main.rs", 12, 15), ];
let filtered = filter_findings_by_diff(findings, &changed_lines);
assert_eq!(filtered.len(), 2); }
// Exact path is reported as changed; a path with a different file name is not.
#[test]
fn test_is_file_changed() {
let mut changed_lines = ChangedLines::new();
changed_lines.insert(PathBuf::from("src/main.rs"), vec![10].into_iter().collect());
assert!(is_file_changed(
&PathBuf::from("src/main.rs"),
&changed_lines
));
assert!(!is_file_changed(
&PathBuf::from("src/other.rs"),
&changed_lines
));
}
// A relative finding path matches an absolute changed path that ends with it.
#[test]
fn test_path_matching_relative_absolute() {
let mut changed_lines = ChangedLines::new();
changed_lines.insert(
PathBuf::from("/project/src/main.rs"),
vec![10].into_iter().collect(),
);
let finding = create_test_finding("src/main.rs", 10, 10);
let filtered = filter_findings_by_diff(vec![finding], &changed_lines);
assert_eq!(filtered.len(), 1);
}
/// Builds a `Finding` for `file` spanning `start_line..=end_line`; every
/// other field is filled with a fixed placeholder value.
fn create_test_finding(file: &str, start_line: usize, end_line: usize) -> Finding {
Finding {
id: format!("test-{}-{}", file, start_line),
rule_id: "test-rule".to_string(),
message: "Test finding".to_string(),
severity: Severity::Warning,
location: SourceLocation {
file: PathBuf::from(file),
start_line,
start_column: 1,
end_line,
end_column: 1,
},
language: Language::Rust,
snippet: None,
suggestion: None,
fix: None,
confidence: rma_common::Confidence::Medium,
category: rma_common::FindingCategory::Security,
fingerprint: None,
properties: None,
occurrence_count: None,
additional_locations: None,
}
}
}