Skip to main content

git_surgeon/
patch.rs

1use anyhow::{Context, Result};
2
3use crate::diff::DiffHunk;
4
/// How `apply_patch` / `apply_patch_to_index` should feed a patch to `git apply`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApplyMode {
    /// Apply the patch to the index only (`git apply --cached`).
    Stage,
    /// Revert the patch in the index only (`git apply --cached --reverse`).
    Unstage,
    /// Revert the patch in the working tree (`git apply --reverse`).
    Discard,
}
10
11/// Slice a hunk to only include changes within the given 1-based line range.
12/// Lines outside the range have their changes neutralized:
13/// - excluded '+' lines are dropped
14/// - excluded '-' lines become context (the deletion is kept)
15///
16/// Context lines are always preserved for patch validity.
17pub fn slice_hunk(hunk: &DiffHunk, start: usize, end: usize, reverse: bool) -> Result<DiffHunk> {
18    slice_hunk_multi(hunk, &[(start, end)], reverse)
19}
20
21/// Slice a hunk keeping changes from any of the given 1-based line ranges.
22pub fn slice_hunk_multi(
23    hunk: &DiffHunk,
24    ranges: &[(usize, usize)],
25    reverse: bool,
26) -> Result<DiffHunk> {
27    let in_any_range = |idx: usize| ranges.iter().any(|(s, e)| idx >= *s && idx <= *e);
28
29    let mut new_lines = Vec::new();
30    for (i, line) in hunk.lines.iter().enumerate() {
31        let idx = i + 1;
32        let in_range = in_any_range(idx);
33
34        if let Some(rest) = line.strip_prefix('+') {
35            if in_range {
36                new_lines.push(line.clone());
37            } else if reverse {
38                new_lines.push(format!(" {}", rest));
39            }
40        } else if let Some(rest) = line.strip_prefix('-') {
41            if in_range {
42                new_lines.push(line.clone());
43            } else if !reverse {
44                new_lines.push(format!(" {}", rest));
45            }
46        } else {
47            new_lines.push(line.clone());
48        }
49    }
50
51    let old_count = new_lines
52        .iter()
53        .filter(|l| l.starts_with('-') || l.starts_with(' '))
54        .count();
55    let new_count = new_lines
56        .iter()
57        .filter(|l| l.starts_with('+') || l.starts_with(' '))
58        .count();
59
60    let (old_start, new_start) = parse_hunk_starts(&hunk.header)?;
61
62    let func_ctx = hunk
63        .header
64        .find("@@ ")
65        .and_then(|s| {
66            let rest = &hunk.header[s + 3..];
67            rest.find("@@").map(|e| &rest[e + 2..])
68        })
69        .unwrap_or("");
70
71    let new_header = format!(
72        "@@ -{},{} +{},{} @@{}",
73        old_start, old_count, new_start, new_count, func_ctx
74    );
75
76    Ok(DiffHunk {
77        file: hunk.file.clone(),
78        old_file: hunk.old_file.clone(),
79        new_file: hunk.new_file.clone(),
80        file_header: hunk.file_header.clone(),
81        header: new_header,
82        lines: new_lines,
83        unsupported_metadata: hunk.unsupported_metadata.clone(),
84    })
85}
86
87fn parse_hunk_starts(header: &str) -> Result<(usize, usize)> {
88    let content = header
89        .trim_start_matches("@@ ")
90        .split(" @@")
91        .next()
92        .ok_or_else(|| anyhow::anyhow!("invalid hunk header"))?;
93    let mut parts = content.split_whitespace();
94    let old_start: usize = parts
95        .next()
96        .and_then(|s| s.strip_prefix('-'))
97        .and_then(|s| s.split(',').next())
98        .and_then(|s| s.parse().ok())
99        .ok_or_else(|| anyhow::anyhow!("cannot parse old start from header"))?;
100    let new_start: usize = parts
101        .next()
102        .and_then(|s| s.strip_prefix('+'))
103        .and_then(|s| s.split(',').next())
104        .and_then(|s| s.parse().ok())
105        .ok_or_else(|| anyhow::anyhow!("cannot parse new start from header"))?;
106    Ok((old_start, new_start))
107}
108
109/// Slice a hunk using picked/selected state masks (for split command).
110///
111/// This builds a patch that correctly accounts for previously picked lines:
112/// - '+' lines: selected -> keep, already picked -> context, else drop
113/// - '-' lines: selected -> keep, already picked -> drop, else context
114/// - context: always keep
115///
116/// Both `picked` and `selected` are masks over hunk.lines (same length).
117pub fn slice_hunk_with_state(
118    hunk: &DiffHunk,
119    picked: &[bool],
120    selected: &[bool],
121) -> Result<DiffHunk> {
122    if picked.len() != hunk.lines.len() || selected.len() != hunk.lines.len() {
123        anyhow::bail!(
124            "state mask length mismatch: hunk has {} lines, picked {}, selected {}",
125            hunk.lines.len(),
126            picked.len(),
127            selected.len()
128        );
129    }
130
131    let mut new_lines = Vec::new();
132    for (i, line) in hunk.lines.iter().enumerate() {
133        let already_picked = picked[i];
134        let want = selected[i];
135
136        if let Some(rest) = line.strip_prefix('+') {
137            if want {
138                // Selected: include as addition
139                new_lines.push(line.clone());
140            } else if already_picked {
141                // Previously picked: now exists in index, becomes context
142                new_lines.push(format!(" {}", rest));
143            }
144            // else: not picked yet, not selected -> drop (doesn't exist in index)
145        } else if let Some(rest) = line.strip_prefix('-') {
146            if want {
147                // Selected: include as deletion
148                new_lines.push(line.clone());
149            } else if !already_picked {
150                // Not picked yet: line still exists in index, becomes context
151                new_lines.push(format!(" {}", rest));
152            }
153            // else: already picked (removed) -> drop (line no longer in index)
154        } else {
155            // Context line: always keep
156            new_lines.push(line.clone());
157        }
158    }
159
160    let old_count = new_lines
161        .iter()
162        .filter(|l| l.starts_with('-') || l.starts_with(' '))
163        .count();
164    let new_count = new_lines
165        .iter()
166        .filter(|l| l.starts_with('+') || l.starts_with(' '))
167        .count();
168
169    let (old_start, new_start) = parse_hunk_starts(&hunk.header)?;
170
171    let func_ctx = hunk
172        .header
173        .find("@@ ")
174        .and_then(|s| {
175            let rest = &hunk.header[s + 3..];
176            rest.find("@@").map(|e| &rest[e + 2..])
177        })
178        .unwrap_or("");
179
180    let new_header = format!(
181        "@@ -{},{} +{},{} @@{}",
182        old_start, old_count, new_start, new_count, func_ctx
183    );
184
185    Ok(DiffHunk {
186        file: hunk.file.clone(),
187        old_file: hunk.old_file.clone(),
188        new_file: hunk.new_file.clone(),
189        file_header: hunk.file_header.clone(),
190        header: new_header,
191        lines: new_lines,
192        unsupported_metadata: hunk.unsupported_metadata.clone(),
193    })
194}
195
196/// Reconstruct a minimal unified diff patch for a single hunk.
197pub fn build_patch(hunk: &DiffHunk) -> String {
198    let mut patch = String::new();
199    patch.push_str(&hunk.file_header);
200    patch.push('\n');
201    patch.push_str(&hunk.header);
202    patch.push('\n');
203    for line in &hunk.lines {
204        patch.push_str(line);
205        patch.push('\n');
206    }
207    patch
208}
209
210/// Apply a patch using git apply.
211pub fn apply_patch(patch: &str, mode: &ApplyMode) -> Result<()> {
212    apply_patch_impl(patch, mode, None)
213}
214
215/// Apply a patch using git apply against a specific index file.
216pub fn apply_patch_to_index(
217    patch: &str,
218    mode: &ApplyMode,
219    index_path: &std::path::Path,
220) -> Result<()> {
221    apply_patch_impl(patch, mode, Some(index_path))
222}
223
224fn apply_patch_impl(
225    patch: &str,
226    mode: &ApplyMode,
227    index_path: Option<&std::path::Path>,
228) -> Result<()> {
229    use std::io::Write;
230    use std::process::{Command, Stdio};
231
232    let mut cmd = Command::new("git");
233    cmd.arg("apply");
234
235    if let Some(idx) = index_path {
236        cmd.env("GIT_INDEX_FILE", idx);
237    }
238
239    match mode {
240        ApplyMode::Stage => {
241            cmd.arg("--cached");
242        }
243        ApplyMode::Unstage => {
244            cmd.arg("--cached").arg("--reverse");
245        }
246        ApplyMode::Discard => {
247            cmd.arg("--reverse");
248        }
249    }
250
251    cmd.stdin(Stdio::piped());
252    let mut child = cmd.spawn().context("failed to run git apply")?;
253    child.stdin.as_mut().unwrap().write_all(patch.as_bytes())?;
254    let output = child.wait_with_output()?;
255
256    if !output.status.success() {
257        anyhow::bail!(
258            "git apply failed: {}",
259            String::from_utf8_lossy(&output.stderr)
260        );
261    }
262
263    Ok(())
264}