Skip to main content

codetether_agent/tool/
patch.rs

1//! Apply Patch Tool - Apply unified diff patches to files.
2
3use super::{Tool, ToolResult};
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use serde_json::{Value, json};
7use std::path::PathBuf;
8
/// Tool that applies unified-diff patches to files located beneath `root`.
#[derive(Debug)]
pub struct ApplyPatchTool {
    /// Directory that the file paths named in a patch are joined onto.
    root: PathBuf,
}
12
13impl Default for ApplyPatchTool {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl ApplyPatchTool {
20    pub fn new() -> Self {
21        Self {
22            root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
23        }
24    }
25
26    #[allow(dead_code)]
27    pub fn with_root(root: PathBuf) -> Self {
28        Self { root }
29    }
30
31    fn parse_patch(&self, patch: &str) -> Result<Vec<PatchHunk>> {
32        let mut hunks = Vec::new();
33        let mut current_file: Option<String> = None;
34        let mut current_hunk: Option<HunkBuilder> = None;
35
36        for line in patch.lines() {
37            if line.starts_with("--- ") {
38                // Old file header, ignore for now
39            } else if line.starts_with("+++ ") {
40                // New file header
41                let path = line.strip_prefix("+++ ").unwrap_or("");
42                let path = path.strip_prefix("b/").unwrap_or(path);
43                let path = path.split('\t').next().unwrap_or(path);
44                current_file = Some(path.to_string());
45            } else if line.starts_with("@@ ") {
46                // Hunk header: @@ -start,count +start,count @@
47                if let Some(hunk) = current_hunk.take() {
48                    if let Some(file) = &current_file {
49                        hunks.push(hunk.build(file.clone()));
50                    }
51                }
52
53                let parts: Vec<&str> = line.split_whitespace().collect();
54                if parts.len() >= 3 {
55                    let old_range = parts[1].strip_prefix('-').unwrap_or(parts[1]);
56                    let old_start: usize = old_range
57                        .split(',')
58                        .next()
59                        .and_then(|s| s.parse().ok())
60                        .unwrap_or(1);
61
62                    current_hunk = Some(HunkBuilder {
63                        start_line: old_start,
64                        old_lines: Vec::new(),
65                        new_lines: Vec::new(),
66                    });
67                }
68            } else if let Some(ref mut hunk) = current_hunk {
69                if let Some(stripped) = line.strip_prefix('-') {
70                    hunk.old_lines.push(stripped.to_string());
71                } else if let Some(stripped) = line.strip_prefix('+') {
72                    hunk.new_lines.push(stripped.to_string());
73                } else if line.starts_with(' ') || line.is_empty() {
74                    let content = if line.is_empty() { "" } else { &line[1..] };
75                    hunk.old_lines.push(content.to_string());
76                    hunk.new_lines.push(content.to_string());
77                }
78            }
79        }
80
81        // Finalize last hunk
82        if let Some(hunk) = current_hunk {
83            if let Some(file) = &current_file {
84                hunks.push(hunk.build(file.clone()));
85            }
86        }
87
88        Ok(hunks)
89    }
90
91    fn apply_hunk(&self, content: &str, hunk: &PatchHunk) -> Result<String> {
92        let lines: Vec<&str> = content.lines().collect();
93        let mut result = Vec::new();
94
95        // Find matching location (fuzzy match)
96        let mut match_start = None;
97        for i in 0..=lines.len().saturating_sub(hunk.old_lines.len()) {
98            let mut matches = true;
99            for (j, old_line) in hunk.old_lines.iter().enumerate() {
100                if i + j >= lines.len() || lines[i + j].trim() != old_line.trim() {
101                    matches = false;
102                    break;
103                }
104            }
105            if matches {
106                match_start = Some(i);
107                break;
108            }
109        }
110
111        let match_start =
112            match_start.ok_or_else(|| anyhow::anyhow!("Could not find hunk location"))?;
113
114        // Build result
115        result.extend(lines[..match_start].iter().map(|s| s.to_string()));
116        result.extend(hunk.new_lines.clone());
117        result.extend(
118            lines[match_start + hunk.old_lines.len()..]
119                .iter()
120                .map(|s| s.to_string()),
121        );
122
123        Ok(result.join("\n"))
124    }
125}
126
/// Accumulates one hunk while the patch is being parsed. The target file is attached
/// when the hunk is finished with [`HunkBuilder::build`].
struct HunkBuilder {
    /// Old-file start line taken from the `@@ -start,...` header.
    start_line: usize,
    /// Context and removed lines: the text expected in the file before patching.
    old_lines: Vec<String>,
    /// Context and added lines: the text that replaces `old_lines`.
    new_lines: Vec<String>,
}

impl HunkBuilder {
    /// Consumes the builder and tags its lines with the path of the file they apply to.
    fn build(self, file: String) -> PatchHunk {
        let HunkBuilder {
            start_line,
            old_lines,
            new_lines,
        } = self;
        PatchHunk {
            file,
            start_line,
            old_lines,
            new_lines,
        }
    }
}

/// A fully parsed hunk, ready to be applied to `file`.
#[derive(Debug)]
struct PatchHunk {
    /// Target path as written in the `+++ ` header, minus any `b/` prefix.
    file: String,
    /// Old-file start line taken from the hunk header.
    start_line: usize,
    /// Lines expected in the file: context plus removals.
    old_lines: Vec<String>,
    /// Replacement lines: context plus additions.
    new_lines: Vec<String>,
}
151
152#[async_trait]
153impl Tool for ApplyPatchTool {
154    fn id(&self) -> &str {
155        "apply_patch"
156    }
157    fn name(&self) -> &str {
158        "Apply Patch"
159    }
160    fn description(&self) -> &str {
161        "Apply a unified diff patch to files in the workspace."
162    }
163    fn parameters(&self) -> Value {
164        json!({
165            "type": "object",
166            "properties": {
167                "patch": {"type": "string", "description": "Unified diff patch content"},
168                "dry_run": {"type": "boolean", "default": false, "description": "Preview without applying"}
169            },
170            "required": ["patch"]
171        })
172    }
173
174    async fn execute(&self, params: Value) -> Result<ToolResult> {
175        let patch = match params.get("patch").and_then(|v| v.as_str()) {
176            Some(s) if !s.is_empty() => s.to_string(),
177            _ => {
178                return Ok(ToolResult::structured_error(
179                    "MISSING_FIELD",
180                    "apply_patch",
181                    "patch is required and must be a non-empty string containing a unified diff",
182                    Some(vec!["patch"]),
183                    Some(json!({
184                        "patch": "--- a/file.rs\n+++ b/file.rs\n@@ -1,3 +1,3 @@\n line1\n-old line\n+new line\n line3",
185                        "dry_run": false
186                    })),
187                ));
188            }
189        };
190        let dry_run = params
191            .get("dry_run")
192            .and_then(|v| v.as_bool())
193            .unwrap_or(false);
194
195        let hunks = self.parse_patch(&patch)?;
196
197        if hunks.is_empty() {
198            return Ok(ToolResult::structured_error(
199                "PARSE_ERROR",
200                "apply_patch",
201                "No valid hunks found in patch. Make sure the patch is in unified diff format with proper --- a/, +++ b/, and @@ headers.",
202                None,
203                Some(json!({
204                    "expected_format": "--- a/path/to/file\n+++ b/path/to/file\n@@ -start,count +start,count @@\n context line\n-removed line\n+added line\n context line",
205                    "hint": "Lines starting with - are removed, + are added, space are context"
206                })),
207            ));
208        }
209
210        let mut results = Vec::new();
211        let mut files_modified = Vec::new();
212
213        // Group hunks by file
214        let mut by_file: std::collections::HashMap<String, Vec<&PatchHunk>> =
215            std::collections::HashMap::new();
216        for hunk in &hunks {
217            by_file.entry(hunk.file.clone()).or_default().push(hunk);
218        }
219
220        for (file, file_hunks) in by_file {
221            let path = self.root.join(&file);
222
223            let mut content = if path.exists() {
224                std::fs::read_to_string(&path).context(format!("Failed to read {}", file))?
225            } else {
226                String::new()
227            };
228
229            for hunk in file_hunks {
230                match self.apply_hunk(&content, hunk) {
231                    Ok(new_content) => {
232                        content = new_content;
233                        results.push(format!(
234                            "✓ Applied hunk to {} at line {}",
235                            file, hunk.start_line
236                        ));
237                    }
238                    Err(e) => {
239                        results.push(format!("✗ Failed to apply hunk to {}: {}", file, e));
240                    }
241                }
242            }
243
244            if !dry_run {
245                if let Some(parent) = path.parent() {
246                    std::fs::create_dir_all(parent)?;
247                }
248                std::fs::write(&path, &content)?;
249                files_modified.push(file);
250            }
251        }
252
253        let action = if dry_run { "Would modify" } else { "Modified" };
254        let summary = format!(
255            "{} {} files:\n{}",
256            action,
257            files_modified.len(),
258            results.join("\n")
259        );
260
261        Ok(ToolResult::success(summary).with_metadata("files", json!(files_modified)))
262    }
263}