Skip to main content

codetether_agent/tool/
patch.rs

1//! Apply Patch Tool - Apply unified diff patches to files.
2
3use super::{Tool, ToolResult};
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use serde::Deserialize;
7use serde_json::{Value, json};
8use std::path::PathBuf;
9
/// Tool that applies unified-diff patches to files under a root directory.
///
/// Patch paths are joined onto `root` before files are read or written.
#[derive(Debug, Clone)]
pub struct ApplyPatchTool {
    /// Directory that patch paths are resolved against.
    root: PathBuf,
}
13
14impl Default for ApplyPatchTool {
15    fn default() -> Self {
16        Self::new()
17    }
18}
19
20impl ApplyPatchTool {
21    pub fn new() -> Self {
22        Self {
23            root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
24        }
25    }
26
27    #[allow(dead_code)]
28    pub fn with_root(root: PathBuf) -> Self {
29        Self { root }
30    }
31
32    fn parse_patch(&self, patch: &str) -> Result<Vec<PatchHunk>> {
33        let mut hunks = Vec::new();
34        let mut current_file: Option<String> = None;
35        let mut current_hunk: Option<HunkBuilder> = None;
36
37        for line in patch.lines() {
38            if line.starts_with("--- ") {
39                // Old file header, ignore for now
40            } else if line.starts_with("+++ ") {
41                // New file header
42                let path = line.strip_prefix("+++ ").unwrap_or("");
43                let path = path.strip_prefix("b/").unwrap_or(path);
44                let path = path.split('\t').next().unwrap_or(path);
45                current_file = Some(path.to_string());
46            } else if line.starts_with("@@ ") {
47                // Hunk header: @@ -start,count +start,count @@
48                if let Some(hunk) = current_hunk.take() {
49                    if let Some(file) = &current_file {
50                        hunks.push(hunk.build(file.clone()));
51                    }
52                }
53
54                let parts: Vec<&str> = line.split_whitespace().collect();
55                if parts.len() >= 3 {
56                    let old_range = parts[1].strip_prefix('-').unwrap_or(parts[1]);
57                    let old_start: usize = old_range
58                        .split(',')
59                        .next()
60                        .and_then(|s| s.parse().ok())
61                        .unwrap_or(1);
62
63                    current_hunk = Some(HunkBuilder {
64                        start_line: old_start,
65                        old_lines: Vec::new(),
66                        new_lines: Vec::new(),
67                    });
68                }
69            } else if let Some(ref mut hunk) = current_hunk {
70                if let Some(stripped) = line.strip_prefix('-') {
71                    hunk.old_lines.push(stripped.to_string());
72                } else if let Some(stripped) = line.strip_prefix('+') {
73                    hunk.new_lines.push(stripped.to_string());
74                } else if line.starts_with(' ') || line.is_empty() {
75                    let content = if line.is_empty() { "" } else { &line[1..] };
76                    hunk.old_lines.push(content.to_string());
77                    hunk.new_lines.push(content.to_string());
78                }
79            }
80        }
81
82        // Finalize last hunk
83        if let Some(hunk) = current_hunk {
84            if let Some(file) = &current_file {
85                hunks.push(hunk.build(file.clone()));
86            }
87        }
88
89        Ok(hunks)
90    }
91
92    fn apply_hunk(&self, content: &str, hunk: &PatchHunk) -> Result<String> {
93        let lines: Vec<&str> = content.lines().collect();
94        let mut result = Vec::new();
95
96        // Find matching location (fuzzy match)
97        let mut match_start = None;
98        for i in 0..=lines.len().saturating_sub(hunk.old_lines.len()) {
99            let mut matches = true;
100            for (j, old_line) in hunk.old_lines.iter().enumerate() {
101                if i + j >= lines.len() || lines[i + j].trim() != old_line.trim() {
102                    matches = false;
103                    break;
104                }
105            }
106            if matches {
107                match_start = Some(i);
108                break;
109            }
110        }
111
112        let match_start =
113            match_start.ok_or_else(|| anyhow::anyhow!("Could not find hunk location"))?;
114
115        // Build result
116        result.extend(lines[..match_start].iter().map(|s| s.to_string()));
117        result.extend(hunk.new_lines.clone());
118        result.extend(
119            lines[match_start + hunk.old_lines.len()..]
120                .iter()
121                .map(|s| s.to_string()),
122        );
123
124        Ok(result.join("\n"))
125    }
126}
127
/// Accumulator for one hunk while the patch text is being parsed.
struct HunkBuilder {
    /// Old-side start line as written in the `@@ -start,count` header.
    start_line: usize,
    /// Old-side text: context lines plus removed (`-`) lines, prefixes stripped.
    old_lines: Vec<String>,
    /// New-side text: context lines plus added (`+`) lines, prefixes stripped.
    new_lines: Vec<String>,
}
133
134impl HunkBuilder {
135    fn build(self, file: String) -> PatchHunk {
136        PatchHunk {
137            file,
138            start_line: self.start_line,
139            old_lines: self.old_lines,
140            new_lines: self.new_lines,
141        }
142    }
143}
144
/// A parsed hunk, tagged with the path of the file it targets.
#[derive(Debug)]
struct PatchHunk {
    /// Target path taken from the `+++` header (leading `b/` removed).
    file: String,
    /// Old-side start line from the `@@` header.
    start_line: usize,
    /// Lines expected in the file before the change (context + removals).
    old_lines: Vec<String>,
    /// Lines that replace them (context + additions).
    new_lines: Vec<String>,
}
152
153#[derive(Deserialize)]
154struct Params {
155    patch: String,
156    #[serde(default)]
157    dry_run: bool,
158}
159
160#[async_trait]
161impl Tool for ApplyPatchTool {
162    fn id(&self) -> &str {
163        "apply_patch"
164    }
165    fn name(&self) -> &str {
166        "Apply Patch"
167    }
168    fn description(&self) -> &str {
169        "Apply a unified diff patch to files in the workspace."
170    }
171    fn parameters(&self) -> Value {
172        json!({
173            "type": "object",
174            "properties": {
175                "patch": {"type": "string", "description": "Unified diff patch content"},
176                "dry_run": {"type": "boolean", "default": false, "description": "Preview without applying"}
177            },
178            "required": ["patch"]
179        })
180    }
181
182    async fn execute(&self, params: Value) -> Result<ToolResult> {
183        let p: Params = serde_json::from_value(params).context("Invalid params")?;
184
185        let hunks = self.parse_patch(&p.patch)?;
186
187        if hunks.is_empty() {
188            return Ok(ToolResult::error("No valid hunks found in patch"));
189        }
190
191        let mut results = Vec::new();
192        let mut files_modified = Vec::new();
193
194        // Group hunks by file
195        let mut by_file: std::collections::HashMap<String, Vec<&PatchHunk>> =
196            std::collections::HashMap::new();
197        for hunk in &hunks {
198            by_file.entry(hunk.file.clone()).or_default().push(hunk);
199        }
200
201        for (file, file_hunks) in by_file {
202            let path = self.root.join(&file);
203
204            let mut content = if path.exists() {
205                std::fs::read_to_string(&path).context(format!("Failed to read {}", file))?
206            } else {
207                String::new()
208            };
209
210            for hunk in file_hunks {
211                match self.apply_hunk(&content, hunk) {
212                    Ok(new_content) => {
213                        content = new_content;
214                        results.push(format!(
215                            "✓ Applied hunk to {} at line {}",
216                            file, hunk.start_line
217                        ));
218                    }
219                    Err(e) => {
220                        results.push(format!("✗ Failed to apply hunk to {}: {}", file, e));
221                    }
222                }
223            }
224
225            if !p.dry_run {
226                if let Some(parent) = path.parent() {
227                    std::fs::create_dir_all(parent)?;
228                }
229                std::fs::write(&path, &content)?;
230                files_modified.push(file);
231            }
232        }
233
234        let action = if p.dry_run {
235            "Would modify"
236        } else {
237            "Modified"
238        };
239        let summary = format!(
240            "{} {} files:\n{}",
241            action,
242            files_modified.len(),
243            results.join("\n")
244        );
245
246        Ok(ToolResult::success(summary).with_metadata("files", json!(files_modified)))
247    }
248}