llm_git/
patch.rs

1use std::process::Command;
2
3use crate::{
4   error::{CommitGenError, Result},
5   types::{ChangeGroup, FileChange, HunkSelector},
6};
7
/// One `@@` hunk parsed out of a single file's unified diff.
#[derive(Debug, Clone)]
struct ParsedHunk {
   /// The hunk's `@@ ... @@` header line, exactly as it appeared in the diff.
   header:         String,
   /// First old-file line covered by the hunk, as written in the header.
   #[allow(dead_code, reason = "Useful metadata for future enhancements")]
   old_start:      usize,
   /// Number of old-file lines the hunk spans (1 when the header omits the count).
   #[allow(dead_code, reason = "Useful metadata for future enhancements")]
   old_count:      usize,
   /// First new-file line covered by the hunk, as written in the header.
   #[allow(dead_code, reason = "Useful metadata for future enhancements")]
   new_start:      usize,
   /// Number of new-file lines the hunk spans (1 when the header omits the count).
   #[allow(dead_code, reason = "Useful metadata for future enhancements")]
   new_count:      usize,
   /// All lines of the hunk: the header line first, then its body lines in order.
   lines:          Vec<String>,
   /// Inclusive `(start, end)` span of old-file lines; `(start, start)` when the hunk
   /// removes nothing (`old_count == 0`).
   old_line_range: (usize, usize),
}
23
24/// Create a patch for specific files
25pub fn create_patch_for_files(files: &[String], dir: &str) -> Result<String> {
26   let output = Command::new("git")
27      .arg("diff")
28      .arg("HEAD")
29      .arg("--")
30      .args(files)
31      .current_dir(dir)
32      .output()
33      .map_err(|e| CommitGenError::GitError(format!("Failed to create patch: {e}")))?;
34
35   if !output.status.success() {
36      let stderr = String::from_utf8_lossy(&output.stderr);
37      return Err(CommitGenError::GitError(format!("git diff failed: {stderr}")));
38   }
39
40   Ok(String::from_utf8_lossy(&output.stdout).to_string())
41}
42
43/// Apply patch to staging area
44pub fn apply_patch_to_index(patch: &str, dir: &str) -> Result<()> {
45   let mut child = Command::new("git")
46      .args(["apply", "--cached"])
47      .current_dir(dir)
48      .stdin(std::process::Stdio::piped())
49      .stdout(std::process::Stdio::piped())
50      .stderr(std::process::Stdio::piped())
51      .spawn()
52      .map_err(|e| CommitGenError::GitError(format!("Failed to spawn git apply: {e}")))?;
53
54   if let Some(mut stdin) = child.stdin.take() {
55      use std::io::Write;
56      stdin
57         .write_all(patch.as_bytes())
58         .map_err(|e| CommitGenError::GitError(format!("Failed to write patch: {e}")))?;
59   }
60
61   let output = child
62      .wait_with_output()
63      .map_err(|e| CommitGenError::GitError(format!("Failed to wait for git apply: {e}")))?;
64
65   if !output.status.success() {
66      let stderr = String::from_utf8_lossy(&output.stderr);
67      return Err(CommitGenError::GitError(format!("git apply --cached failed: {stderr}")));
68   }
69
70   Ok(())
71}
72
73/// Stage specific files (simpler alternative to patch application)
74pub fn stage_files(files: &[String], dir: &str) -> Result<()> {
75   if files.is_empty() {
76      return Ok(());
77   }
78
79   let output = Command::new("git")
80      .arg("add")
81      .arg("--")
82      .args(files)
83      .current_dir(dir)
84      .output()
85      .map_err(|e| CommitGenError::GitError(format!("Failed to stage files: {e}")))?;
86
87   if !output.status.success() {
88      let stderr = String::from_utf8_lossy(&output.stderr);
89      return Err(CommitGenError::GitError(format!("git add failed: {stderr}")));
90   }
91
92   Ok(())
93}
94
95/// Reset staging area
96pub fn reset_staging(dir: &str) -> Result<()> {
97   let output = Command::new("git")
98      .args(["reset", "HEAD"])
99      .current_dir(dir)
100      .output()
101      .map_err(|e| CommitGenError::GitError(format!("Failed to reset staging: {e}")))?;
102
103   if !output.status.success() {
104      let stderr = String::from_utf8_lossy(&output.stderr);
105      return Err(CommitGenError::GitError(format!("git reset HEAD failed: {stderr}")));
106   }
107
108   Ok(())
109}
110
/// Parse a unified-diff hunk header into `(old_start, old_count, new_start, new_count)`.
///
/// Accepted shape (surrounding whitespace ignored):
/// `@@ -old_start[,old_count] +new_start[,new_count] @@[ section text]`.
/// A range without a comma has an implicit count of 1.
///
/// Returns `None` when the header does not begin with `@@`, has no closing `@@`, has
/// fewer than two ranges, or a range lacks its `-`/`+` prefix or is not numeric.
fn parse_hunk_header(header: &str) -> Option<(usize, usize, usize, usize)> {
   // "start,count" or bare "start" (count defaults to 1).
   fn parse_range(range: &str) -> Option<(usize, usize)> {
      match range.split_once(',') {
         Some((start, count)) => Some((start.parse().ok()?, count.parse().ok()?)),
         None => Some((range.parse().ok()?, 1)),
      }
   }

   let after_open = header.trim().strip_prefix("@@")?;
   let (ranges, _section) = after_open.split_once("@@")?;

   let mut fields = ranges.split_whitespace();
   let old_range = fields.next()?.strip_prefix('-')?;
   let new_range = fields.next()?.strip_prefix('+')?;

   let (old_start, old_count) = parse_range(old_range)?;
   let (new_start, new_count) = parse_range(new_range)?;

   Some((old_start, old_count, new_start, new_count))
}
154
155/// Parse all hunks from a file's diff
156fn parse_file_hunks(file_diff: &str) -> Vec<ParsedHunk> {
157   let mut hunks = Vec::new();
158   let mut in_header = true;
159   let mut current_hunk: Option<ParsedHunk> = None;
160
161   for line in file_diff.lines() {
162      if in_header {
163         if line.starts_with("+++") {
164            in_header = false;
165         }
166         continue;
167      }
168
169      if line.starts_with("@@ ") {
170         // Save previous hunk
171         if let Some(hunk) = current_hunk.take() {
172            hunks.push(hunk);
173         }
174
175         // Parse new hunk header
176         if let Some((old_start, old_count, new_start, new_count)) = parse_hunk_header(line) {
177            let old_end = if old_count == 0 {
178               old_start
179            } else {
180               old_start + old_count - 1
181            };
182
183            current_hunk = Some(ParsedHunk {
184               header: line.to_string(),
185               old_start,
186               old_count,
187               new_start,
188               new_count,
189               lines: vec![line.to_string()],
190               old_line_range: (old_start, old_end),
191            });
192         }
193      } else if let Some(hunk) = &mut current_hunk {
194         hunk.lines.push(line.to_string());
195      }
196   }
197
198   // Don't forget the last hunk
199   if let Some(hunk) = current_hunk {
200      hunks.push(hunk);
201   }
202
203   hunks
204}
205
206/// Map line range to hunks that overlap with it
207fn find_hunks_for_line_range(hunks: &[ParsedHunk], start: usize, end: usize) -> Vec<String> {
208   hunks
209      .iter()
210      .filter(|hunk| {
211         // Check if line range overlaps with hunk's old line range
212         let (hunk_start, hunk_end) = hunk.old_line_range;
213         !(end < hunk_start || start > hunk_end)
214      })
215      .map(|hunk| hunk.header.clone())
216      .collect()
217}
218
219/// Convert `HunkSelectors` to actual hunk headers deterministically
220fn resolve_selectors_to_headers(
221   full_diff: &str,
222   file_path: &str,
223   selectors: &[HunkSelector],
224) -> Result<Vec<String>> {
225   // Extract file diff
226   let file_diff = extract_file_diff(full_diff, file_path)?;
227
228   // Parse all hunks from the file
229   let hunks = parse_file_hunks(&file_diff);
230
231   let mut headers = Vec::new();
232
233   for selector in selectors {
234      match selector {
235         HunkSelector::All => {
236            // Return all hunk headers
237            return Ok(hunks.iter().map(|h| h.header.clone()).collect());
238         },
239         HunkSelector::Lines { start, end } => {
240            // Find hunks that overlap with this line range
241            let matching = find_hunks_for_line_range(&hunks, *start, *end);
242            if matching.is_empty() {
243               // Check if there are any nearby hunks to suggest
244               let nearby: Vec<_> = hunks
245                  .iter()
246                  .map(|h| {
247                     let (hunk_start, hunk_end) = h.old_line_range;
248                     let distance = if *end < hunk_start {
249                        hunk_start - *end
250                     } else {
251                        (*start).saturating_sub(hunk_end)
252                     };
253                     (distance, hunk_start, hunk_end)
254                  })
255                  .filter(|(dist, ..)| *dist > 0 && *dist < 20)
256                  .collect();
257
258               let hint = if nearby.is_empty() {
259                  String::new()
260               } else {
261                  let (_, nearest_start, nearest_end) =
262                     nearby.iter().min_by_key(|(dist, ..)| dist).unwrap();
263                  format!(" (nearest hunk: lines {nearest_start}-{nearest_end})")
264               };
265
266               return Err(CommitGenError::Other(format!(
267                  "No changes found in lines {start}-{end} of {file_path}. These lines may be \
268                   context (unchanged) rather than modifications{hint}"
269               )));
270            }
271            headers.extend(matching);
272         },
273         HunkSelector::Search { pattern } => {
274            // If it looks like a hunk header, try to match it directly
275            if pattern.starts_with("@@") {
276               let normalized_pattern = normalize_hunk_header(pattern);
277               let matching: Vec<String> = hunks
278                  .iter()
279                  .filter(|h| normalize_hunk_header(&h.header) == normalized_pattern)
280                  .map(|h| h.header.clone())
281                  .collect();
282
283               if matching.is_empty() {
284                  return Err(CommitGenError::Other(format!(
285                     "Hunk header not found: {pattern} in {file_path}"
286                  )));
287               }
288               headers.extend(matching);
289            } else {
290               // Search for pattern in hunk lines
291               let matching: Vec<String> = hunks
292                  .iter()
293                  .filter(|h| h.lines.iter().any(|line| line.contains(pattern)))
294                  .map(|h| h.header.clone())
295                  .collect();
296
297               if matching.is_empty() {
298                  return Err(CommitGenError::Other(format!(
299                     "Pattern '{pattern}' not found in any hunk in {file_path}"
300                  )));
301               }
302               headers.extend(matching);
303            }
304         },
305      }
306   }
307
308   // Deduplicate headers while preserving order
309   let mut seen = std::collections::HashSet::new();
310   Ok(headers
311      .into_iter()
312      .filter(|h| seen.insert(h.clone()))
313      .collect())
314}
315
316/// Extract specific hunks from a full diff for a file
317fn extract_hunks_for_file(
318   full_diff: &str,
319   file_path: &str,
320   hunk_headers: &[String],
321) -> Result<String> {
322   // If "ALL", return entire file diff
323   if hunk_headers.len() == 1 && hunk_headers[0] == "ALL" {
324      return extract_file_diff(full_diff, file_path);
325   }
326
327   let file_diff = extract_file_diff(full_diff, file_path)?;
328   let mut result = String::new();
329   let mut in_header = true;
330   let mut current_hunk = String::new();
331   let mut current_hunk_header = String::new();
332   let mut include_current = false;
333
334   for line in file_diff.lines() {
335      if in_header {
336         result.push_str(line);
337         result.push('\n');
338         if line.starts_with("+++") {
339            in_header = false;
340         }
341      } else if line.starts_with("@@ ") {
342         // Save previous hunk if we were including it
343         if include_current && !current_hunk.is_empty() {
344            result.push_str(&current_hunk);
345         }
346
347         // Start new hunk
348         current_hunk_header = line.to_string();
349         current_hunk = format!("{line}\n");
350
351         // Check if this hunk should be included
352         include_current = hunk_headers.iter().any(|h| {
353            // Normalize comparison - just compare the numeric parts
354            normalize_hunk_header(h) == normalize_hunk_header(&current_hunk_header)
355         });
356      } else {
357         current_hunk.push_str(line);
358         current_hunk.push('\n');
359      }
360   }
361
362   // Don't forget the last hunk
363   if include_current && !current_hunk.is_empty() {
364      result.push_str(&current_hunk);
365   }
366
367   if result
368      .lines()
369      .filter(|l| !l.starts_with("---") && !l.starts_with("+++") && !l.starts_with("diff "))
370      .count()
371      == 0
372   {
373      return Err(CommitGenError::Other(format!(
374         "No hunks found for {file_path} with headers {hunk_headers:?}"
375      )));
376   }
377
378   Ok(result)
379}
380
/// Reduce a hunk header to a canonical key for fuzzy equality checks.
///
/// Takes the text between the first `@@` and the next `@@`. If there is no closing
/// marker, everything after the first `@@` is used; if there is no `@@` at all, the
/// whole trimmed input is used. The result keeps only ASCII digits, `,`, `-` and `+`,
/// so whitespace and trailing section context never affect the comparison.
fn normalize_hunk_header(header: &str) -> String {
   let trimmed = header.trim();

   let ranges = match trimmed.split_once("@@") {
      Some((_, rest)) => rest.split_once("@@").map_or(rest, |(inner, _)| inner),
      None => trimmed,
   };

   let mut key = String::new();
   for ch in ranges.chars() {
      if ch.is_ascii_digit() || matches!(ch, ',' | '-' | '+') {
         key.push(ch);
      }
   }
   key
}
405
406/// Extract the diff for a specific file from a full diff
407fn extract_file_diff(full_diff: &str, file_path: &str) -> Result<String> {
408   let mut result = String::new();
409   let mut in_file = false;
410   let mut found = false;
411
412   for line in full_diff.lines() {
413      if line.starts_with("diff --git") {
414         // Check if this is our file
415         if line.contains(&format!("b/{file_path}")) || line.ends_with(&format!(" b/{file_path}")) {
416            in_file = true;
417            found = true;
418            result.push_str(line);
419            result.push('\n');
420         } else {
421            in_file = false;
422         }
423      } else if in_file {
424         result.push_str(line);
425         result.push('\n');
426      }
427   }
428
429   if !found {
430      return Err(CommitGenError::Other(format!("File {file_path} not found in diff")));
431   }
432
433   Ok(result)
434}
435
436/// Create a patch for specific file changes with hunk selection
437pub fn create_patch_for_changes(full_diff: &str, changes: &[FileChange]) -> Result<String> {
438   let mut patch = String::new();
439
440   for change in changes {
441      // Resolve selectors to actual hunk headers
442      let hunk_headers = resolve_selectors_to_headers(full_diff, &change.path, &change.hunks)?;
443      let file_patch = extract_hunks_for_file(full_diff, &change.path, &hunk_headers)?;
444      patch.push_str(&file_patch);
445   }
446
447   Ok(patch)
448}
449
450/// Stage changes for a specific group (hunk-aware).
451/// The `full_diff` argument must be taken before any compose commits run so the
452/// recorded hunk headers remain stable across groups.
453pub fn stage_group_changes(group: &ChangeGroup, dir: &str, full_diff: &str) -> Result<()> {
454   let mut full_files = Vec::new();
455   let mut partial_changes = Vec::new();
456
457   for change in &group.changes {
458      // Check if all selectors are "All" variant
459      let is_all = change.hunks.len() == 1 && matches!(change.hunks[0], HunkSelector::All);
460
461      if is_all {
462         full_files.push(change.path.clone());
463      } else {
464         partial_changes.push(change.clone());
465      }
466   }
467
468   if !full_files.is_empty() {
469      // Deduplicate to avoid redundant git add calls
470      full_files.sort();
471      full_files.dedup();
472      stage_files(&full_files, dir)?;
473   }
474
475   if partial_changes.is_empty() {
476      return Ok(());
477   }
478
479   let patch = create_patch_for_changes(full_diff, &partial_changes)?;
480   apply_patch_to_index(&patch, dir)
481}