use serde::Deserialize;
use crate::common::{AgentError, Result};
/// Arguments of a `git_patch` tool call, deserialized from the object
/// described by `definition()`'s `parameters` schema.
#[derive(Debug, Deserialize)]
pub struct GitPatchInput {
    /// Target file; absolute, or relative to the working directory passed to
    /// `execute`.
    pub path: String,
    /// Unified diff text to apply to `path`.
    pub patch: String,
}
/// Build the tool schema advertised for `git_patch`.
///
/// The value has the shape `{"type": "function", "function": {name,
/// description, parameters}}`, where `parameters` is a JSON Schema object
/// requiring the two string fields `path` and `patch` (mirroring
/// [`GitPatchInput`]).
pub fn definition() -> serde_json::Value {
    let path_param = serde_json::json!({
        "type": "string",
        "description": "Path to the file to patch"
    });
    let patch_param = serde_json::json!({
        "type": "string",
        "description": "Unified diff patch content. Must include --- a/path and +++ b/path headers and @@ hunk headers."
    });
    let parameters = serde_json::json!({
        "type": "object",
        "properties": {
            "path": path_param,
            "patch": patch_param
        },
        "required": ["path", "patch"]
    });

    serde_json::json!({
        "type": "function",
        "function": {
            "name": "git_patch",
            "description": "Apply a unified diff patch to a file. Use this for multi-hunk edits or when you need to modify several non-adjacent sections of a file at once. The patch must be in unified diff format with proper hunk headers (@@ -line,count +line,count @@).",
            "parameters": parameters
        }
    })
}
/// Apply `input.patch` to the file at `input.path`.
///
/// `input.path` is resolved against `working_dir` (see `resolve_path`). The
/// patch is normalized to carry exactly one `--- a/` / `+++ b/` header pair.
/// It is then applied with `git apply`, run in `working_dir`. If git rejects
/// it, the hunks are applied in-process by [`apply_manual`] instead.
///
/// # Errors
/// Returns `AgentError::InvalidArgument` in three cases:
/// - the file does not exist;
/// - the normalized patch contains no `@@` hunk header;
/// - both strategies fail (the message then carries both underlying errors).
pub async fn execute(input: GitPatchInput, working_dir: &str) -> Result<String> {
    let resolved_path = resolve_path(&input.path, working_dir);
    if !std::path::Path::new(&resolved_path).exists() {
        return Err(AgentError::InvalidArgument(format!(
            "File not found: {}",
            input.path
        )));
    }

    // git resolves header paths relative to its working directory, so an
    // absolute path in the headers can never match the file.
    let header = header_path(&input.path, &resolved_path, working_dir);
    let patch = normalize_patch(&input.patch, &header);
    if !patch.contains("@@") {
        // The example must itself be a valid diff: context lines carry a
        // leading space and the header counts match the body (3 old, 3 new).
        return Err(AgentError::InvalidArgument(
            concat!(
                "Invalid patch format: missing @@ hunk headers.\n",
                "Expected unified diff format like:\n",
                "--- a/file.rs\n",
                "+++ b/file.rs\n",
                "@@ -1,3 +1,3 @@\n",
                " context line\n",
                "-old line\n",
                "+new line\n",
                " context line",
            )
            .to_string(),
        ));
    }

    match apply_with_git(&patch, working_dir).await {
        Ok(msg) => Ok(msg),
        Err(git_err) => {
            tracing::warn!("git apply failed, trying manual patch: {git_err}");
            match apply_manual(&patch, &resolved_path).await {
                Ok(msg) => Ok(format!("{msg} (manual fallback)")),
                Err(manual_err) => Err(AgentError::InvalidArgument(format!(
                    "Patch failed.\n\
                     git apply: {git_err}\n\
                     manual: {manual_err}\n\n\
                     Tip: Use file_edit for single-location changes, \
                     or ensure your patch has correct line numbers and context."
                ))),
            }
        }
    }
}

/// Choose the path written into the patch headers so that `git apply` (run in
/// `working_dir`) can locate the file.
///
/// An absolute `input_path` whose resolved form lies inside `working_dir`
/// becomes a `/`-separated path relative to it. Every other input is returned
/// unchanged, as before.
fn header_path(input_path: &str, resolved_path: &str, working_dir: &str) -> String {
    use std::path::{Component, Path};

    if !Path::new(input_path).is_absolute() {
        return input_path.to_string();
    }
    let relative = Path::new(resolved_path)
        .strip_prefix(working_dir)
        .ok()
        .and_then(|rel| {
            // Only plain components are expected after lexical normalization;
            // anything else (or non-UTF-8) falls back to the caller's path.
            rel.components()
                .map(|component| match component {
                    Component::Normal(part) => part.to_str(),
                    _ => None,
                })
                .collect::<Option<Vec<_>>>()
        })
        .filter(|parts| !parts.is_empty());
    match relative {
        Some(parts) => parts.join("/"),
        None => input_path.to_string(),
    }
}
/// Apply `patch` by piping it to `git apply --verbose -` in `working_dir`.
///
/// On success, returns a message with the number of `@@` hunks in `patch`.
///
/// # Errors
/// - I/O errors from spawning git, writing its stdin, or waiting on it.
/// - `AgentError::Command` with git's trimmed stderr when git exits non-zero.
/// - The same error when git exits zero but reports `Skipped patch`. git does
///   this when the patched path falls outside the current subdirectory of a
///   repository; nothing has been written in that case.
async fn apply_with_git(patch: &str, working_dir: &str) -> Result<String> {
    use tokio::io::AsyncWriteExt;

    let mut child = tokio::process::Command::new("git")
        .args(["apply", "--verbose", "-"])
        .current_dir(working_dir)
        // Untranslated messages, so the "Skipped patch" check below is reliable.
        .env("LC_ALL", "C")
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()?;
    if let Some(mut stdin) = child.stdin.take() {
        stdin.write_all(patch.as_bytes()).await?;
        // Closing the pipe lets git see end of input.
        drop(stdin);
    }
    let output = child.wait_with_output().await?;

    let stderr = String::from_utf8_lossy(&output.stderr);
    let skipped = stderr
        .lines()
        .any(|line| line.starts_with("Skipped patch"));
    if !output.status.success() || skipped {
        return Err(AgentError::Command(stderr.trim().to_string()));
    }

    // `--verbose` prints per-file lines, not per-hunk ones; count hunks directly.
    let applied_hunks = patch.lines().filter(|line| line.starts_with("@@")).count();
    Ok(format!(
        "Patch applied successfully ({applied_hunks} hunks)"
    ))
}
async fn apply_manual(patch: &str, file_path: &str) -> Result<String> {
let content = tokio::fs::read_to_string(file_path).await?;
let lines: Vec<&str> = content.lines().collect();
let hunks = parse_hunks(patch)?;
if hunks.is_empty() {
return Err(AgentError::InvalidArgument(
"No valid hunks found in patch".to_string(),
));
}
let mut result_lines: Vec<String> = lines.iter().map(|l| l.to_string()).collect();
let mut sorted_hunks = hunks;
sorted_hunks.sort_by(|a, b| b.old_start.cmp(&a.old_start));
for hunk in &sorted_hunks {
let start = (hunk.old_start as usize).saturating_sub(1);
let end = start + hunk.old_count as usize;
if end > result_lines.len() {
return Err(AgentError::InvalidArgument(format!(
"Hunk at line {} extends beyond file (file has {} lines)",
hunk.old_start,
result_lines.len()
)));
}
let mut old_idx = start;
for line in &hunk.lines {
match line {
PatchLine::Context(text) | PatchLine::Remove(text) => {
if old_idx < result_lines.len() {
let actual = &result_lines[old_idx];
if actual.trim() != text.trim() {
return Err(AgentError::InvalidArgument(format!(
"Context mismatch at line {}: expected '{}', got '{}'",
old_idx + 1,
text,
actual,
)));
}
}
old_idx += 1;
}
PatchLine::Add(_) => {}
}
}
let mut new_lines: Vec<String> = Vec::new();
for line in &hunk.lines {
match line {
PatchLine::Context(text) | PatchLine::Add(text) => {
new_lines.push(text.to_string());
}
PatchLine::Remove(_) => {}
}
}
result_lines.splice(start..end, new_lines);
}
let new_content = result_lines.join("\n");
let final_content = if content.ends_with('\n') {
format!("{new_content}\n")
} else {
new_content
};
tokio::fs::write(file_path, &final_content).await?;
Ok(format!(
"Patch applied manually ({} hunks)",
sorted_hunks.len()
))
}
/// One `@@` hunk of a unified diff, reduced to what manual application needs.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Hunk {
    /// Old-file start line from the `@@ -start,count` header (1-based).
    old_start: u32,
    /// Old-file line count declared by the header (1 when the count is omitted).
    old_count: u32,
    /// Body lines, in patch order.
    lines: Vec<PatchLine>,
}

/// A single hunk body line, with its one-character diff prefix removed.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PatchLine {
    /// Line prefixed with ` ` (or a blank line): present in old and new file.
    Context(String),
    /// Line prefixed with `+`: present only in the new file.
    Add(String),
    /// Line prefixed with `-`: present only in the old file.
    Remove(String),
}
/// Parse the hunks of a unified diff into [`Hunk`]s, in patch order.
///
/// Line handling:
/// - Lines before the first `@@` header are ignored.
/// - A `--- ` line immediately followed by a `+++ ` line is a file-header pair
///   and is skipped. A lone `--- x` inside a hunk is the removal of a line
///   `-- x` (an SQL/Lua comment, say) and is kept.
/// - Body lines are classified by their first character (`+`, `-`, ` `).
/// - An empty line counts as blank context.
/// - Anything else (e.g. `\ No newline at end of file`) is ignored.
///
/// # Errors
/// Propagates `parse_hunk_header` errors for malformed `@@` lines.
fn parse_hunks(patch: &str) -> Result<Vec<Hunk>> {
    let mut hunks = Vec::new();
    let mut current: Option<Hunk> = None;
    let mut lines = patch.lines().peekable();

    while let Some(line) = lines.next() {
        if line.starts_with("@@") {
            hunks.extend(current.take());
            let (old_start, old_count) = parse_hunk_header(line)?;
            current = Some(Hunk {
                old_start,
                old_count,
                lines: Vec::new(),
            });
            continue;
        }
        // `next_if` consumes the `+++ ` half of the pair only when it is there.
        if line.starts_with("--- ")
            && lines.next_if(|next| next.starts_with("+++ ")).is_some()
        {
            continue;
        }
        if let Some(hunk) = current.as_mut() {
            if let Some(text) = line.strip_prefix('+') {
                hunk.lines.push(PatchLine::Add(text.to_string()));
            } else if let Some(text) = line.strip_prefix('-') {
                hunk.lines.push(PatchLine::Remove(text.to_string()));
            } else if let Some(text) = line.strip_prefix(' ') {
                hunk.lines.push(PatchLine::Context(text.to_string()));
            } else if line.is_empty() {
                hunk.lines.push(PatchLine::Context(String::new()));
            }
        }
    }
    hunks.extend(current);
    Ok(hunks)
}
/// Parse the old-file range from a unified-diff hunk header.
///
/// Accepts `@@ -start,count +start,count @@ [section]`, with or without a space
/// after the leading `@@`. The short form `-start` means a count of 1.
/// Returns `(old_start, old_count)`.
///
/// # Errors
/// Returns `AgentError::InvalidArgument` naming the header when:
/// - no new-side range follows the old range;
/// - the first range is a `+` (new-file) range;
/// - a number is missing or non-numeric.
fn parse_hunk_header(line: &str) -> Result<(u32, u32)> {
    let invalid = || AgentError::InvalidArgument(format!("Invalid hunk header: {}", line));
    // `u32::from_str` accepts a leading '+', which is never valid inside a range.
    let number = |text: &str| -> Result<u32> {
        if text.starts_with('+') {
            return Err(invalid());
        }
        text.parse::<u32>().map_err(|_| invalid())
    };

    let rest = line.strip_prefix("@@").unwrap_or(line);
    let mut ranges = rest.split_whitespace();
    let old_range = ranges.next().ok_or_else(invalid)?;
    // A header must also carry the new-side range, and the first range must
    // be the old one (otherwise e.g. "@@-1,3 +1,4" would read "+1,4").
    if ranges.next().is_none() || old_range.starts_with('+') {
        return Err(invalid());
    }
    let old_range = old_range.strip_prefix('-').unwrap_or(old_range);

    match old_range.split_once(',') {
        Some((start, count)) => Ok((number(start)?, number(count)?)),
        None => Ok((number(old_range)?, 1)),
    }
}
/// Rewrite `patch` to carry exactly one `--- a/{path}` / `+++ b/{path}`
/// header pair, replacing whatever headers the caller supplied.
///
/// Header handling:
/// - Before the first `@@` line, every `--- ` and `+++ ` line is a header and
///   is dropped.
/// - After it, only a `--- ` line immediately followed by a `+++ ` line is a
///   header pair. A lone `--- x` or `+++ x` there is a removed `-- x` or added
///   `++ x` line and is kept.
///
/// A trailing newline on `patch` is preserved.
fn normalize_patch(patch: &str, path: &str) -> String {
    let mut body: Vec<&str> = Vec::new();
    let mut in_hunks = false;
    let mut lines = patch.lines().peekable();

    while let Some(line) = lines.next() {
        let is_header = if line.starts_with("@@") {
            in_hunks = true;
            false
        } else if in_hunks {
            // `next_if` consumes the `+++ ` half of the pair only when present.
            line.starts_with("--- ")
                && lines.next_if(|next| next.starts_with("+++ ")).is_some()
        } else {
            line.starts_with("--- ") || line.starts_with("+++ ")
        };
        if !is_header {
            body.push(line);
        }
    }

    let mut normalized = format!("--- a/{path}\n+++ b/{path}\n{}", body.join("\n"));
    if patch.ends_with('\n') {
        normalized.push('\n');
    }
    normalized
}
/// Resolve a caller-supplied `path` against `working_dir` and return the
/// lexically normalized result as a `String`.
///
/// A path starting with `/` is taken as-is; any other path is joined onto
/// `working_dir`. Normalization is delegated to
/// `crate::agent::approval::normalize_path_lexical`. Non-UTF-8 components are
/// converted lossily.
fn resolve_path(path: &str, working_dir: &str) -> String {
    let is_absolute = path.starts_with('/');
    let candidate: std::path::PathBuf = if is_absolute {
        path.into()
    } else {
        std::path::Path::new(working_dir).join(path)
    };
    let normalized = crate::agent::approval::normalize_path_lexical(&candidate);
    normalized.to_string_lossy().into_owned()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Patch literals are built from explicit `\n`-terminated pieces so the
    // single-space prefix that marks a context line is visible and cannot be
    // lost to reformatting (a context line without it is ignored by the parser).

    #[test]
    fn test_parse_hunk_header() {
        let (start, count) = parse_hunk_header("@@ -10,5 +10,7 @@ fn foo()").unwrap();
        assert_eq!(start, 10);
        assert_eq!(count, 5);
    }

    #[test]
    fn test_parse_hunk_header_single_line() {
        let (start, count) = parse_hunk_header("@@ -5 +5,2 @@").unwrap();
        assert_eq!(start, 5);
        assert_eq!(count, 1);
    }

    #[test]
    fn test_parse_hunks() {
        let patch = concat!(
            "--- a/test.rs\n",
            "+++ b/test.rs\n",
            "@@ -1,3 +1,4 @@\n",
            " line one\n",
            "+inserted line\n",
            " line two\n",
            " line three\n",
            "@@ -10,2 +11,2 @@\n",
            "-old line ten\n",
            "+new line ten\n",
            " line eleven",
        );
        let hunks = parse_hunks(patch).unwrap();
        assert_eq!(hunks.len(), 2);
        assert_eq!(hunks[0].old_start, 1);
        assert_eq!(hunks[0].old_count, 3);
        assert_eq!(hunks[0].lines.len(), 4);
        assert_eq!(hunks[1].old_start, 10);
        assert_eq!(hunks[1].old_count, 2);
        assert_eq!(hunks[1].lines.len(), 3);
    }

    #[test]
    fn test_normalize_patch_adds_headers() {
        let patch = "@@ -1,3 +1,4 @@\n line 1\n+added\n line 2\n line 3";
        let normalized = normalize_patch(patch, "src/main.rs");
        assert!(normalized.starts_with("--- a/src/main.rs"));
        assert!(normalized.contains("+++ b/src/main.rs"));
    }

    #[test]
    fn test_normalize_patch_keeps_existing_headers() {
        let patch = "--- a/foo.rs\n+++ b/foo.rs\n@@ -1,1 +1,1 @@\n-old\n+new";
        let normalized = normalize_patch(patch, "foo.rs");
        assert_eq!(normalized, patch);
    }

    #[tokio::test]
    async fn test_manual_apply_single_hunk() {
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("test.txt");
        std::fs::write(&file_path, "line one\nline two\nline three\nline four\n").unwrap();
        let patch = concat!(
            "--- a/test.txt\n",
            "+++ b/test.txt\n",
            "@@ -2,2 +2,3 @@\n",
            " line two\n",
            "-line three\n",
            "+line THREE\n",
            "+inserted line",
        );
        let result = apply_manual(patch, file_path.to_str().unwrap()).await;
        assert!(result.is_ok(), "apply_manual failed: {:?}", result);
        let content = std::fs::read_to_string(&file_path).unwrap();
        assert!(content.contains("line THREE"));
        assert!(content.contains("inserted line"));
        assert!(!content.contains("line three"));
    }

    #[tokio::test]
    async fn test_manual_apply_multiple_hunks() {
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("multi.txt");
        std::fs::write(&file_path, "a\nb\nc\nd\ne\n").unwrap();
        let patch = concat!(
            "@@ -1,2 +1,2 @@\n",
            "-a\n",
            "+A\n",
            " b\n",
            "@@ -4,2 +4,2 @@\n",
            " d\n",
            "-e\n",
            "+E",
        );
        let result = apply_manual(patch, file_path.to_str().unwrap()).await;
        assert!(result.is_ok(), "apply_manual failed: {:?}", result);
        let content = std::fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "A\nb\nc\nd\nE\n");
    }
}