Skip to main content

code_baseline/
git_diff.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::ops::RangeInclusive;
4use std::path::PathBuf;
5use std::process::Command;
6
/// Failures that can occur while asking git for diff information.
#[derive(Debug)]
pub enum GitDiffError {
    /// The `git` process could not be started.
    GitNotFound,
    /// The repository-root query exited unsuccessfully.
    NotARepo,
    /// The named base ref could not be resolved; carries the requested name.
    BaseRefNotFound(String),
    /// A git invocation exited with a failure status; carries its stderr.
    CommandFailed(String),
}

impl fmt::Display for GitDiffError {
    /// Renders the human-readable message for each failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::GitNotFound => f.write_str("git is not installed or not in PATH"),
            Self::NotARepo => f.write_str("not inside a git repository"),
            Self::BaseRefNotFound(name) => write!(
                f,
                "base ref '{}' not found (try fetching it first)",
                name
            ),
            Self::CommandFailed(detail) => write!(f, "git command failed: {}", detail),
        }
    }
}

// No underlying source error is carried, so the default trait methods suffice.
impl std::error::Error for GitDiffError {}
29
/// Changed files and line ranges from a git diff.
#[derive(Debug)]
pub struct DiffInfo {
    /// Map of relative file path to list of changed line ranges.
    ///
    /// A file may be present with an empty list (for example when every hunk
    /// touching it is a pure deletion).
    pub changed_lines: HashMap<PathBuf, Vec<RangeInclusive<usize>>>,
}

impl DiffInfo {
    /// Returns `true` if `path` appears in the diff at all.
    ///
    /// Accepts anything path-like (`&PathBuf`, `&Path`, `&str`, ...), so
    /// callers holding a borrowed path do not need to allocate a `PathBuf`
    /// just to perform the lookup.
    pub fn has_file(&self, path: impl AsRef<std::path::Path>) -> bool {
        self.changed_lines.contains_key(path.as_ref())
    }

    /// Check if a specific line in a file is within a changed range.
    ///
    /// Returns `false` when the file is not part of the diff or when none of
    /// its recorded ranges contains `line`.
    pub fn has_line(&self, path: impl AsRef<std::path::Path>, line: usize) -> bool {
        self.changed_lines
            .get(path.as_ref())
            .map_or(false, |ranges| ranges.iter().any(|r| r.contains(&line)))
    }
}
50
/// Picks the base ref to diff against.
///
/// Consults the pull/merge-request target-branch variables of GitHub Actions,
/// GitLab CI and Bitbucket Pipelines, in that order, and returns the first one
/// that is set to a non-empty value. When none qualifies the result is
/// `"main"`.
pub fn detect_base_ref() -> String {
    // Listed in priority order: GitHub Actions, GitLab CI, Bitbucket Pipelines.
    const CI_TARGET_BRANCH_VARS: [&str; 3] = [
        "GITHUB_BASE_REF",
        "CI_MERGE_REQUEST_TARGET_BRANCH_NAME",
        "BITBUCKET_PR_DESTINATION_BRANCH",
    ];

    CI_TARGET_BRANCH_VARS
        .iter()
        // Unset (or non-unicode) variables are skipped.
        .filter_map(|name| std::env::var(name).ok())
        // A variable that is set but empty does not count.
        .find(|value| !value.is_empty())
        .unwrap_or_else(|| "main".to_string())
}
73
74/// Get the repository root directory.
75pub fn repo_root() -> Result<PathBuf, GitDiffError> {
76    let output = Command::new("git")
77        .args(["rev-parse", "--show-toplevel"])
78        .output()
79        .map_err(|_| GitDiffError::GitNotFound)?;
80
81    if !output.status.success() {
82        return Err(GitDiffError::NotARepo);
83    }
84
85    let root = String::from_utf8_lossy(&output.stdout).trim().to_string();
86    Ok(PathBuf::from(root))
87}
88
89/// Parse a git diff to extract changed files and their changed line ranges.
90///
91/// Uses triple-dot diff (`base...HEAD`) for correct merge-base comparison.
92/// Only includes Added, Copied, Modified, Renamed files (`--diff-filter=ACMR`).
93pub fn diff_info(base_ref: &str) -> Result<DiffInfo, GitDiffError> {
94    // Ensure we're in a git repo
95    repo_root()?;
96
97    // Try the base ref directly, then with origin/ prefix
98    let effective_base = resolve_base_ref(base_ref)?;
99
100    let output = Command::new("git")
101        .args([
102            "diff",
103            "-U0",
104            "--diff-filter=ACMR",
105            &format!("{}...HEAD", effective_base),
106        ])
107        .output()
108        .map_err(|_| GitDiffError::GitNotFound)?;
109
110    if !output.status.success() {
111        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
112        return Err(GitDiffError::CommandFailed(stderr));
113    }
114
115    let diff_text = String::from_utf8_lossy(&output.stdout);
116    Ok(parse_diff(&diff_text))
117}
118
119/// Resolve a base ref, trying the ref directly then with origin/ prefix.
120/// For shallow clones, attempts a fetch first.
121fn resolve_base_ref(base_ref: &str) -> Result<String, GitDiffError> {
122    // Try the ref directly
123    if ref_exists(base_ref) {
124        return Ok(base_ref.to_string());
125    }
126
127    // Try with origin/ prefix
128    let with_origin = format!("origin/{}", base_ref);
129    if ref_exists(&with_origin) {
130        return Ok(with_origin);
131    }
132
133    // Attempt shallow fetch and retry
134    let _ = Command::new("git")
135        .args(["fetch", "--depth=1", "origin", base_ref])
136        .output();
137
138    if ref_exists(&with_origin) {
139        return Ok(with_origin);
140    }
141
142    if ref_exists(base_ref) {
143        return Ok(base_ref.to_string());
144    }
145
146    Err(GitDiffError::BaseRefNotFound(base_ref.to_string()))
147}
148
/// Reports whether `git rev-parse --verify <r>` succeeds.
///
/// Any failure to run git at all is treated the same as "does not exist".
fn ref_exists(r: &str) -> bool {
    let probe = Command::new("git")
        .args(["rev-parse", "--verify", r])
        .output();

    match probe {
        Ok(finished) => finished.status.success(),
        Err(_) => false,
    }
}
156
157/// Parse unified diff output into a DiffInfo.
158fn parse_diff(diff_text: &str) -> DiffInfo {
159    let mut changed_lines: HashMap<PathBuf, Vec<RangeInclusive<usize>>> = HashMap::new();
160    let mut current_file: Option<PathBuf> = None;
161
162    for line in diff_text.lines() {
163        // Detect file path from +++ line
164        if let Some(path) = line.strip_prefix("+++ b/") {
165            current_file = Some(PathBuf::from(path));
166            changed_lines
167                .entry(PathBuf::from(path))
168                .or_insert_with(Vec::new);
169            continue;
170        }
171
172        // Parse hunk header: @@ -old_start,old_count +new_start,new_count @@
173        if line.starts_with("@@") {
174            if let Some(ref file) = current_file {
175                if let Some(range) = parse_hunk_header(line) {
176                    changed_lines.entry(file.clone()).or_default().push(range);
177                }
178            }
179        }
180    }
181
182    DiffInfo { changed_lines }
183}
184
/// Extracts the new-side line range from a hunk header such as
/// `@@ -10,3 +15,4 @@`.
///
/// The text after the first `+`, up to the next space or `@`, is read as
/// `start[,count]`:
///
/// * `start,count` covers lines `start..=start+count-1`;
/// * a `count` of 0 is a pure deletion and yields `None`;
/// * a bare `start` means a single line (`count` = 1).
///
/// Returns `None` as well when there is no `+` or the numbers do not parse.
fn parse_hunk_header(line: &str) -> Option<RangeInclusive<usize>> {
    let (_, after_plus) = line.split_once('+')?;

    // `split` always yields at least one piece, so the fallback is never hit.
    let spec = after_plus
        .split(|c: char| c == ' ' || c == '@')
        .next()
        .unwrap_or(after_plus);

    match spec.split_once(',') {
        Some((start_text, count_text)) => {
            let start: usize = start_text.parse().ok()?;
            let count: usize = count_text.parse().ok()?;
            if count == 0 {
                // Nothing exists on the new side for a pure deletion.
                None
            } else {
                Some(start..=start + count - 1)
            }
        }
        None => {
            let start: usize = spec.parse().ok()?;
            Some(start..=start)
        }
    }
}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// A new-side spec without a count denotes exactly one line.
    #[test]
    fn parse_hunk_single_line() {
        assert_eq!(parse_hunk_header("@@ -10,0 +15 @@"), Some(15..=15));
    }

    /// `+start,count` spans `count` lines beginning at `start`.
    #[test]
    fn parse_hunk_multi_line() {
        assert_eq!(parse_hunk_header("@@ -10,3 +15,4 @@"), Some(15..=18));
    }

    /// A new-side count of zero means nothing was added.
    #[test]
    fn parse_hunk_pure_deletion() {
        assert_eq!(parse_hunk_header("@@ -10,3 +14,0 @@"), None);
    }

    /// Trailing section text after the closing `@@` is ignored.
    #[test]
    fn parse_hunk_with_context() {
        assert_eq!(
            parse_hunk_header("@@ -1,5 +1,7 @@ fn main() {"),
            Some(1..=7)
        );
    }

    /// A two-file diff yields one range per file.
    #[test]
    fn parse_diff_full() {
        let diff = "\
diff --git a/src/foo.rs b/src/foo.rs
index abc..def 100644
--- a/src/foo.rs
+++ b/src/foo.rs
@@ -1,3 +1,5 @@
+new line 1
+new line 2
 existing
diff --git a/src/bar.rs b/src/bar.rs
new file mode 100644
--- /dev/null
+++ b/src/bar.rs
@@ -0,0 +1,10 @@
+all new file
";
        let info = parse_diff(diff);
        let foo = PathBuf::from("src/foo.rs");
        let bar = PathBuf::from("src/bar.rs");

        let foo_ranges = info
            .changed_lines
            .get(&foo)
            .expect("src/foo.rs should be present");
        assert_eq!(foo_ranges.len(), 1);
        assert_eq!(foo_ranges[0], 1..=5);

        let bar_ranges = info
            .changed_lines
            .get(&bar)
            .expect("src/bar.rs should be present");
        assert_eq!(bar_ranges.len(), 1);
        assert_eq!(bar_ranges[0], 1..=10);
    }

    /// File and line lookups against a hand-built `DiffInfo`.
    #[test]
    fn diff_info_has_file_and_line() {
        let main_rs = PathBuf::from("src/main.rs");
        let other_rs = PathBuf::from("src/other.rs");
        let info = DiffInfo {
            changed_lines: std::iter::once((main_rs.clone(), vec![5..=10, 20..=25])).collect(),
        };

        assert!(info.has_file(&main_rs));
        assert!(!info.has_file(&other_rs));

        assert!(info.has_line(&main_rs, 7));
        assert!(info.has_line(&main_rs, 20));
        // Line 15 falls in the gap between the two ranges.
        assert!(!info.has_line(&main_rs, 15));
    }

    /// The detected base ref is never empty.
    ///
    /// The exact value depends on the environment: "main" locally, or the
    /// target branch when a CI variable is set.
    #[test]
    fn detect_base_ref_defaults_to_main() {
        assert!(!detect_base_ref().is_empty());
    }
}