codetether_agent/tool/
patch.rs1use super::{Tool, ToolResult};
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use serde::Deserialize;
7use serde_json::{Value, json};
8use std::path::PathBuf;
9
/// Tool that applies unified-diff patches to files beneath a workspace root.
pub struct ApplyPatchTool {
    /// Directory that patch target paths are joined onto.
    root: PathBuf,
}
13
14impl Default for ApplyPatchTool {
15 fn default() -> Self {
16 Self::new()
17 }
18}
19
20impl ApplyPatchTool {
21 pub fn new() -> Self {
22 Self {
23 root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
24 }
25 }
26
27 #[allow(dead_code)]
28 pub fn with_root(root: PathBuf) -> Self {
29 Self { root }
30 }
31
32 fn parse_patch(&self, patch: &str) -> Result<Vec<PatchHunk>> {
33 let mut hunks = Vec::new();
34 let mut current_file: Option<String> = None;
35 let mut current_hunk: Option<HunkBuilder> = None;
36
37 for line in patch.lines() {
38 if line.starts_with("--- ") {
39 } else if line.starts_with("+++ ") {
41 let path = line.strip_prefix("+++ ").unwrap_or("");
43 let path = path.strip_prefix("b/").unwrap_or(path);
44 let path = path.split('\t').next().unwrap_or(path);
45 current_file = Some(path.to_string());
46 } else if line.starts_with("@@ ") {
47 if let Some(hunk) = current_hunk.take() {
49 if let Some(file) = ¤t_file {
50 hunks.push(hunk.build(file.clone()));
51 }
52 }
53
54 let parts: Vec<&str> = line.split_whitespace().collect();
55 if parts.len() >= 3 {
56 let old_range = parts[1].strip_prefix('-').unwrap_or(parts[1]);
57 let old_start: usize = old_range
58 .split(',')
59 .next()
60 .and_then(|s| s.parse().ok())
61 .unwrap_or(1);
62
63 current_hunk = Some(HunkBuilder {
64 start_line: old_start,
65 old_lines: Vec::new(),
66 new_lines: Vec::new(),
67 });
68 }
69 } else if let Some(ref mut hunk) = current_hunk {
70 if let Some(stripped) = line.strip_prefix('-') {
71 hunk.old_lines.push(stripped.to_string());
72 } else if let Some(stripped) = line.strip_prefix('+') {
73 hunk.new_lines.push(stripped.to_string());
74 } else if line.starts_with(' ') || line.is_empty() {
75 let content = if line.is_empty() { "" } else { &line[1..] };
76 hunk.old_lines.push(content.to_string());
77 hunk.new_lines.push(content.to_string());
78 }
79 }
80 }
81
82 if let Some(hunk) = current_hunk {
84 if let Some(file) = ¤t_file {
85 hunks.push(hunk.build(file.clone()));
86 }
87 }
88
89 Ok(hunks)
90 }
91
92 fn apply_hunk(&self, content: &str, hunk: &PatchHunk) -> Result<String> {
93 let lines: Vec<&str> = content.lines().collect();
94 let mut result = Vec::new();
95
96 let mut match_start = None;
98 for i in 0..=lines.len().saturating_sub(hunk.old_lines.len()) {
99 let mut matches = true;
100 for (j, old_line) in hunk.old_lines.iter().enumerate() {
101 if i + j >= lines.len() || lines[i + j].trim() != old_line.trim() {
102 matches = false;
103 break;
104 }
105 }
106 if matches {
107 match_start = Some(i);
108 break;
109 }
110 }
111
112 let match_start =
113 match_start.ok_or_else(|| anyhow::anyhow!("Could not find hunk location"))?;
114
115 result.extend(lines[..match_start].iter().map(|s| s.to_string()));
117 result.extend(hunk.new_lines.clone());
118 result.extend(
119 lines[match_start + hunk.old_lines.len()..]
120 .iter()
121 .map(|s| s.to_string()),
122 );
123
124 Ok(result.join("\n"))
125 }
126}
127
/// Accumulates one hunk's lines while a patch is parsed; the target file is
/// attached afterwards by `HunkBuilder::build`.
struct HunkBuilder {
    /// Old-side start line taken from the hunk's `@@` header.
    start_line: usize,
    /// Context and removed lines: the text expected in the file before patching.
    old_lines: Vec<String>,
    /// Context and added lines: the text that replaces `old_lines`.
    new_lines: Vec<String>,
}
133
134impl HunkBuilder {
135 fn build(self, file: String) -> PatchHunk {
136 PatchHunk {
137 file,
138 start_line: self.start_line,
139 old_lines: self.old_lines,
140 new_lines: self.new_lines,
141 }
142 }
143}
144
/// A fully parsed hunk: its target file plus the old and new text of the region.
#[derive(Debug)]
struct PatchHunk {
    /// Target path as written in the `+++ ` header (leading `b/` removed).
    file: String,
    /// Old-side start line from the hunk's `@@` header.
    start_line: usize,
    /// Context and removed lines expected in the file before patching.
    old_lines: Vec<String>,
    /// Context and added lines that replace `old_lines`.
    new_lines: Vec<String>,
}
152
153#[derive(Deserialize)]
154struct Params {
155 patch: String,
156 #[serde(default)]
157 dry_run: bool,
158}
159
160#[async_trait]
161impl Tool for ApplyPatchTool {
162 fn id(&self) -> &str {
163 "apply_patch"
164 }
165 fn name(&self) -> &str {
166 "Apply Patch"
167 }
168 fn description(&self) -> &str {
169 "Apply a unified diff patch to files in the workspace."
170 }
171 fn parameters(&self) -> Value {
172 json!({
173 "type": "object",
174 "properties": {
175 "patch": {"type": "string", "description": "Unified diff patch content"},
176 "dry_run": {"type": "boolean", "default": false, "description": "Preview without applying"}
177 },
178 "required": ["patch"]
179 })
180 }
181
182 async fn execute(&self, params: Value) -> Result<ToolResult> {
183 let p: Params = serde_json::from_value(params).context("Invalid params")?;
184
185 let hunks = self.parse_patch(&p.patch)?;
186
187 if hunks.is_empty() {
188 return Ok(ToolResult::error("No valid hunks found in patch"));
189 }
190
191 let mut results = Vec::new();
192 let mut files_modified = Vec::new();
193
194 let mut by_file: std::collections::HashMap<String, Vec<&PatchHunk>> =
196 std::collections::HashMap::new();
197 for hunk in &hunks {
198 by_file.entry(hunk.file.clone()).or_default().push(hunk);
199 }
200
201 for (file, file_hunks) in by_file {
202 let path = self.root.join(&file);
203
204 let mut content = if path.exists() {
205 std::fs::read_to_string(&path).context(format!("Failed to read {}", file))?
206 } else {
207 String::new()
208 };
209
210 for hunk in file_hunks {
211 match self.apply_hunk(&content, hunk) {
212 Ok(new_content) => {
213 content = new_content;
214 results.push(format!(
215 "✓ Applied hunk to {} at line {}",
216 file, hunk.start_line
217 ));
218 }
219 Err(e) => {
220 results.push(format!("✗ Failed to apply hunk to {}: {}", file, e));
221 }
222 }
223 }
224
225 if !p.dry_run {
226 if let Some(parent) = path.parent() {
227 std::fs::create_dir_all(parent)?;
228 }
229 std::fs::write(&path, &content)?;
230 files_modified.push(file);
231 }
232 }
233
234 let action = if p.dry_run {
235 "Would modify"
236 } else {
237 "Modified"
238 };
239 let summary = format!(
240 "{} {} files:\n{}",
241 action,
242 files_modified.len(),
243 results.join("\n")
244 );
245
246 Ok(ToolResult::success(summary).with_metadata("files", json!(files_modified)))
247 }
248}