codetether_agent/tool/
patch.rs1use super::{Tool, ToolResult};
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use serde_json::{Value, json};
7use std::path::PathBuf;
8
/// Tool that applies unified-diff patches to files on disk.
///
/// Paths named in a patch are resolved relative to `root`.
#[derive(Debug, Clone)]
pub struct ApplyPatchTool {
    /// Directory that patch file paths are joined onto.
    root: PathBuf,
}
12
13impl Default for ApplyPatchTool {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
19impl ApplyPatchTool {
20 pub fn new() -> Self {
21 Self {
22 root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
23 }
24 }
25
26 #[allow(dead_code)]
27 pub fn with_root(root: PathBuf) -> Self {
28 Self { root }
29 }
30
31 fn parse_patch(&self, patch: &str) -> Result<Vec<PatchHunk>> {
32 let mut hunks = Vec::new();
33 let mut current_file: Option<String> = None;
34 let mut current_hunk: Option<HunkBuilder> = None;
35
36 for line in patch.lines() {
37 if line.starts_with("--- ") {
38 } else if line.starts_with("+++ ") {
40 let path = line.strip_prefix("+++ ").unwrap_or("");
42 let path = path.strip_prefix("b/").unwrap_or(path);
43 let path = path.split('\t').next().unwrap_or(path);
44 current_file = Some(path.to_string());
45 } else if line.starts_with("@@ ") {
46 if let Some(hunk) = current_hunk.take() {
48 if let Some(file) = ¤t_file {
49 hunks.push(hunk.build(file.clone()));
50 }
51 }
52
53 let parts: Vec<&str> = line.split_whitespace().collect();
54 if parts.len() >= 3 {
55 let old_range = parts[1].strip_prefix('-').unwrap_or(parts[1]);
56 let old_start: usize = old_range
57 .split(',')
58 .next()
59 .and_then(|s| s.parse().ok())
60 .unwrap_or(1);
61
62 current_hunk = Some(HunkBuilder {
63 start_line: old_start,
64 old_lines: Vec::new(),
65 new_lines: Vec::new(),
66 });
67 }
68 } else if let Some(ref mut hunk) = current_hunk {
69 if let Some(stripped) = line.strip_prefix('-') {
70 hunk.old_lines.push(stripped.to_string());
71 } else if let Some(stripped) = line.strip_prefix('+') {
72 hunk.new_lines.push(stripped.to_string());
73 } else if line.starts_with(' ') || line.is_empty() {
74 let content = if line.is_empty() { "" } else { &line[1..] };
75 hunk.old_lines.push(content.to_string());
76 hunk.new_lines.push(content.to_string());
77 }
78 }
79 }
80
81 if let Some(hunk) = current_hunk {
83 if let Some(file) = ¤t_file {
84 hunks.push(hunk.build(file.clone()));
85 }
86 }
87
88 Ok(hunks)
89 }
90
91 fn apply_hunk(&self, content: &str, hunk: &PatchHunk) -> Result<String> {
92 let lines: Vec<&str> = content.lines().collect();
93 let mut result = Vec::new();
94
95 let mut match_start = None;
97 for i in 0..=lines.len().saturating_sub(hunk.old_lines.len()) {
98 let mut matches = true;
99 for (j, old_line) in hunk.old_lines.iter().enumerate() {
100 if i + j >= lines.len() || lines[i + j].trim() != old_line.trim() {
101 matches = false;
102 break;
103 }
104 }
105 if matches {
106 match_start = Some(i);
107 break;
108 }
109 }
110
111 let match_start =
112 match_start.ok_or_else(|| anyhow::anyhow!("Could not find hunk location"))?;
113
114 result.extend(lines[..match_start].iter().map(|s| s.to_string()));
116 result.extend(hunk.new_lines.clone());
117 result.extend(
118 lines[match_start + hunk.old_lines.len()..]
119 .iter()
120 .map(|s| s.to_string()),
121 );
122
123 Ok(result.join("\n"))
124 }
125}
126
/// Accumulates the lines of a single hunk while a patch is being parsed.
struct HunkBuilder {
    /// Starting line number taken from the old-file side of the `@@` header.
    start_line: usize,
    /// Text expected in the file before patching (removed and context lines).
    old_lines: Vec<String>,
    /// Text that replaces `old_lines` (added and context lines).
    new_lines: Vec<String>,
}
132
133impl HunkBuilder {
134 fn build(self, file: String) -> PatchHunk {
135 PatchHunk {
136 file,
137 start_line: self.start_line,
138 old_lines: self.old_lines,
139 new_lines: self.new_lines,
140 }
141 }
142}
143
/// One parsed hunk of a unified diff, bound to the file it applies to.
#[derive(Debug)]
struct PatchHunk {
    /// Target file path as written in the patch's `+++` header.
    file: String,
    /// Starting line number from the old-file side of the `@@` header.
    start_line: usize,
    /// Lines expected in the file before patching (removed and context lines).
    old_lines: Vec<String>,
    /// Lines that replace `old_lines` (added and context lines).
    new_lines: Vec<String>,
}
151
152#[async_trait]
153impl Tool for ApplyPatchTool {
154 fn id(&self) -> &str {
155 "apply_patch"
156 }
157 fn name(&self) -> &str {
158 "Apply Patch"
159 }
160 fn description(&self) -> &str {
161 "Apply a unified diff patch to files in the workspace."
162 }
163 fn parameters(&self) -> Value {
164 json!({
165 "type": "object",
166 "properties": {
167 "patch": {"type": "string", "description": "Unified diff patch content"},
168 "dry_run": {"type": "boolean", "default": false, "description": "Preview without applying"}
169 },
170 "required": ["patch"]
171 })
172 }
173
174 async fn execute(&self, params: Value) -> Result<ToolResult> {
175 let patch = match params.get("patch").and_then(|v| v.as_str()) {
176 Some(s) if !s.is_empty() => s.to_string(),
177 _ => {
178 return Ok(ToolResult::structured_error(
179 "MISSING_FIELD",
180 "apply_patch",
181 "patch is required and must be a non-empty string containing a unified diff",
182 Some(vec!["patch"]),
183 Some(json!({
184 "patch": "--- a/file.rs\n+++ b/file.rs\n@@ -1,3 +1,3 @@\n line1\n-old line\n+new line\n line3",
185 "dry_run": false
186 })),
187 ));
188 }
189 };
190 let dry_run = params
191 .get("dry_run")
192 .and_then(|v| v.as_bool())
193 .unwrap_or(false);
194
195 let hunks = self.parse_patch(&patch)?;
196
197 if hunks.is_empty() {
198 return Ok(ToolResult::structured_error(
199 "PARSE_ERROR",
200 "apply_patch",
201 "No valid hunks found in patch. Make sure the patch is in unified diff format with proper --- a/, +++ b/, and @@ headers.",
202 None,
203 Some(json!({
204 "expected_format": "--- a/path/to/file\n+++ b/path/to/file\n@@ -start,count +start,count @@\n context line\n-removed line\n+added line\n context line",
205 "hint": "Lines starting with - are removed, + are added, space are context"
206 })),
207 ));
208 }
209
210 let mut results = Vec::new();
211 let mut files_modified = Vec::new();
212
213 let mut by_file: std::collections::HashMap<String, Vec<&PatchHunk>> =
215 std::collections::HashMap::new();
216 for hunk in &hunks {
217 by_file.entry(hunk.file.clone()).or_default().push(hunk);
218 }
219
220 for (file, file_hunks) in by_file {
221 let path = self.root.join(&file);
222
223 let mut content = if path.exists() {
224 std::fs::read_to_string(&path).context(format!("Failed to read {}", file))?
225 } else {
226 String::new()
227 };
228
229 for hunk in file_hunks {
230 match self.apply_hunk(&content, hunk) {
231 Ok(new_content) => {
232 content = new_content;
233 results.push(format!(
234 "✓ Applied hunk to {} at line {}",
235 file, hunk.start_line
236 ));
237 }
238 Err(e) => {
239 results.push(format!("✗ Failed to apply hunk to {}: {}", file, e));
240 }
241 }
242 }
243
244 if !dry_run {
245 if let Some(parent) = path.parent() {
246 std::fs::create_dir_all(parent)?;
247 }
248 std::fs::write(&path, &content)?;
249 files_modified.push(file);
250 }
251 }
252
253 let action = if dry_run { "Would modify" } else { "Modified" };
254 let summary = format!(
255 "{} {} files:\n{}",
256 action,
257 files_modified.len(),
258 results.join("\n")
259 );
260
261 Ok(ToolResult::success(summary).with_metadata("files", json!(files_modified)))
262 }
263}