codetether_agent/tool/
patch.rs1use anyhow::{Context, Result};
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::{json, Value};
7use std::path::PathBuf;
8use super::{Tool, ToolResult};
9
/// Tool that applies unified-diff patches to files beneath a workspace root.
pub struct ApplyPatchTool {
    /// Directory that file paths named in a patch are joined onto.
    root: PathBuf,
}
13
14impl Default for ApplyPatchTool {
15 fn default() -> Self { Self::new() }
16}
17
18impl ApplyPatchTool {
19 pub fn new() -> Self {
20 Self { root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")) }
21 }
22
23 #[allow(dead_code)]
24 pub fn with_root(root: PathBuf) -> Self {
25 Self { root }
26 }
27
28 fn parse_patch(&self, patch: &str) -> Result<Vec<PatchHunk>> {
29 let mut hunks = Vec::new();
30 let mut current_file: Option<String> = None;
31 let mut current_hunk: Option<HunkBuilder> = None;
32
33 for line in patch.lines() {
34 if line.starts_with("--- ") {
35 } else if line.starts_with("+++ ") {
37 let path = line.strip_prefix("+++ ").unwrap_or("");
39 let path = path.strip_prefix("b/").unwrap_or(path);
40 let path = path.split('\t').next().unwrap_or(path);
41 current_file = Some(path.to_string());
42 } else if line.starts_with("@@ ") {
43 if let Some(hunk) = current_hunk.take() {
45 if let Some(file) = ¤t_file {
46 hunks.push(hunk.build(file.clone()));
47 }
48 }
49
50 let parts: Vec<&str> = line.split_whitespace().collect();
51 if parts.len() >= 3 {
52 let old_range = parts[1].strip_prefix('-').unwrap_or(parts[1]);
53 let old_start: usize = old_range.split(',').next()
54 .and_then(|s| s.parse().ok()).unwrap_or(1);
55
56 current_hunk = Some(HunkBuilder {
57 start_line: old_start,
58 old_lines: Vec::new(),
59 new_lines: Vec::new(),
60 });
61 }
62 } else if let Some(ref mut hunk) = current_hunk {
63 if let Some(stripped) = line.strip_prefix('-') {
64 hunk.old_lines.push(stripped.to_string());
65 } else if let Some(stripped) = line.strip_prefix('+') {
66 hunk.new_lines.push(stripped.to_string());
67 } else if line.starts_with(' ') || line.is_empty() {
68 let content = if line.is_empty() { "" } else { &line[1..] };
69 hunk.old_lines.push(content.to_string());
70 hunk.new_lines.push(content.to_string());
71 }
72 }
73 }
74
75 if let Some(hunk) = current_hunk {
77 if let Some(file) = ¤t_file {
78 hunks.push(hunk.build(file.clone()));
79 }
80 }
81
82 Ok(hunks)
83 }
84
85 fn apply_hunk(&self, content: &str, hunk: &PatchHunk) -> Result<String> {
86 let lines: Vec<&str> = content.lines().collect();
87 let mut result = Vec::new();
88
89 let mut match_start = None;
91 for i in 0..=lines.len().saturating_sub(hunk.old_lines.len()) {
92 let mut matches = true;
93 for (j, old_line) in hunk.old_lines.iter().enumerate() {
94 if i + j >= lines.len() || lines[i + j].trim() != old_line.trim() {
95 matches = false;
96 break;
97 }
98 }
99 if matches {
100 match_start = Some(i);
101 break;
102 }
103 }
104
105 let match_start = match_start.ok_or_else(|| anyhow::anyhow!("Could not find hunk location"))?;
106
107 result.extend(lines[..match_start].iter().map(|s| s.to_string()));
109 result.extend(hunk.new_lines.clone());
110 result.extend(lines[match_start + hunk.old_lines.len()..].iter().map(|s| s.to_string()));
111
112 Ok(result.join("\n"))
113 }
114}
115
/// Mutable accumulator for one hunk while the patch text is being parsed.
struct HunkBuilder {
    /// Old-file start line taken from the `@@ -N,… @@` header (1 if unparsable).
    start_line: usize,
    /// Context and `-` lines: the text expected in the file before patching.
    old_lines: Vec<String>,
    /// Context and `+` lines: the text that replaces `old_lines`.
    new_lines: Vec<String>,
}
121
122impl HunkBuilder {
123 fn build(self, file: String) -> PatchHunk {
124 PatchHunk {
125 file,
126 start_line: self.start_line,
127 old_lines: self.old_lines,
128 new_lines: self.new_lines,
129 }
130 }
131}
132
/// A parsed hunk ready to be applied to `file`.
#[derive(Debug)]
struct PatchHunk {
    /// Target path, relative to the tool's root, taken from the `+++ ` header.
    file: String,
    /// Old-file start line from the hunk header; also echoed in the result report.
    start_line: usize,
    /// Lines that must be present (context and removals).
    old_lines: Vec<String>,
    /// Lines that take their place (context and additions).
    new_lines: Vec<String>,
}
140
141#[derive(Deserialize)]
142struct Params {
143 patch: String,
144 #[serde(default)]
145 dry_run: bool,
146}
147
148#[async_trait]
149impl Tool for ApplyPatchTool {
150 fn id(&self) -> &str { "apply_patch" }
151 fn name(&self) -> &str { "Apply Patch" }
152 fn description(&self) -> &str { "Apply a unified diff patch to files in the workspace." }
153 fn parameters(&self) -> Value {
154 json!({
155 "type": "object",
156 "properties": {
157 "patch": {"type": "string", "description": "Unified diff patch content"},
158 "dry_run": {"type": "boolean", "default": false, "description": "Preview without applying"}
159 },
160 "required": ["patch"]
161 })
162 }
163
164 async fn execute(&self, params: Value) -> Result<ToolResult> {
165 let p: Params = serde_json::from_value(params).context("Invalid params")?;
166
167 let hunks = self.parse_patch(&p.patch)?;
168
169 if hunks.is_empty() {
170 return Ok(ToolResult::error("No valid hunks found in patch"));
171 }
172
173 let mut results = Vec::new();
174 let mut files_modified = Vec::new();
175
176 let mut by_file: std::collections::HashMap<String, Vec<&PatchHunk>> = std::collections::HashMap::new();
178 for hunk in &hunks {
179 by_file.entry(hunk.file.clone()).or_default().push(hunk);
180 }
181
182 for (file, file_hunks) in by_file {
183 let path = self.root.join(&file);
184
185 let mut content = if path.exists() {
186 std::fs::read_to_string(&path).context(format!("Failed to read {}", file))?
187 } else {
188 String::new()
189 };
190
191 for hunk in file_hunks {
192 match self.apply_hunk(&content, hunk) {
193 Ok(new_content) => {
194 content = new_content;
195 results.push(format!("✓ Applied hunk to {} at line {}", file, hunk.start_line));
196 }
197 Err(e) => {
198 results.push(format!("✗ Failed to apply hunk to {}: {}", file, e));
199 }
200 }
201 }
202
203 if !p.dry_run {
204 if let Some(parent) = path.parent() {
205 std::fs::create_dir_all(parent)?;
206 }
207 std::fs::write(&path, &content)?;
208 files_modified.push(file);
209 }
210 }
211
212 let action = if p.dry_run { "Would modify" } else { "Modified" };
213 let summary = format!("{} {} files:\n{}", action, files_modified.len(), results.join("\n"));
214
215 Ok(ToolResult::success(summary).with_metadata("files", json!(files_modified)))
216 }
217}