Skip to main content

codetether_agent/tool/
patch.rs

1//! Apply Patch Tool - Apply unified diff patches to files.
2
3use anyhow::{Context, Result};
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::{json, Value};
7use std::path::PathBuf;
8use super::{Tool, ToolResult};
9
/// Tool that applies unified-diff patches to files under a workspace root.
///
/// File paths named in a patch are resolved relative to `root`.
#[derive(Debug, Clone)]
pub struct ApplyPatchTool {
    /// Directory that patch file paths are joined onto.
    root: PathBuf,
}
13
14impl Default for ApplyPatchTool {
15    fn default() -> Self { Self::new() }
16}
17
18impl ApplyPatchTool {
19    pub fn new() -> Self {
20        Self { root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")) }
21    }
22
23    #[allow(dead_code)]
24    pub fn with_root(root: PathBuf) -> Self {
25        Self { root }
26    }
27
28    fn parse_patch(&self, patch: &str) -> Result<Vec<PatchHunk>> {
29        let mut hunks = Vec::new();
30        let mut current_file: Option<String> = None;
31        let mut current_hunk: Option<HunkBuilder> = None;
32
33        for line in patch.lines() {
34            if line.starts_with("--- ") {
35                // Old file header, ignore for now
36            } else if line.starts_with("+++ ") {
37                // New file header
38                let path = line.strip_prefix("+++ ").unwrap_or("");
39                let path = path.strip_prefix("b/").unwrap_or(path);
40                let path = path.split('\t').next().unwrap_or(path);
41                current_file = Some(path.to_string());
42            } else if line.starts_with("@@ ") {
43                // Hunk header: @@ -start,count +start,count @@
44                if let Some(hunk) = current_hunk.take() {
45                    if let Some(file) = &current_file {
46                        hunks.push(hunk.build(file.clone()));
47                    }
48                }
49                
50                let parts: Vec<&str> = line.split_whitespace().collect();
51                if parts.len() >= 3 {
52                    let old_range = parts[1].strip_prefix('-').unwrap_or(parts[1]);
53                    let old_start: usize = old_range.split(',').next()
54                        .and_then(|s| s.parse().ok()).unwrap_or(1);
55                    
56                    current_hunk = Some(HunkBuilder {
57                        start_line: old_start,
58                        old_lines: Vec::new(),
59                        new_lines: Vec::new(),
60                    });
61                }
62            } else if let Some(ref mut hunk) = current_hunk {
63                if let Some(stripped) = line.strip_prefix('-') {
64                    hunk.old_lines.push(stripped.to_string());
65                } else if let Some(stripped) = line.strip_prefix('+') {
66                    hunk.new_lines.push(stripped.to_string());
67                } else if line.starts_with(' ') || line.is_empty() {
68                    let content = if line.is_empty() { "" } else { &line[1..] };
69                    hunk.old_lines.push(content.to_string());
70                    hunk.new_lines.push(content.to_string());
71                }
72            }
73        }
74        
75        // Finalize last hunk
76        if let Some(hunk) = current_hunk {
77            if let Some(file) = &current_file {
78                hunks.push(hunk.build(file.clone()));
79            }
80        }
81        
82        Ok(hunks)
83    }
84
85    fn apply_hunk(&self, content: &str, hunk: &PatchHunk) -> Result<String> {
86        let lines: Vec<&str> = content.lines().collect();
87        let mut result = Vec::new();
88        
89        // Find matching location (fuzzy match)
90        let mut match_start = None;
91        for i in 0..=lines.len().saturating_sub(hunk.old_lines.len()) {
92            let mut matches = true;
93            for (j, old_line) in hunk.old_lines.iter().enumerate() {
94                if i + j >= lines.len() || lines[i + j].trim() != old_line.trim() {
95                    matches = false;
96                    break;
97                }
98            }
99            if matches {
100                match_start = Some(i);
101                break;
102            }
103        }
104        
105        let match_start = match_start.ok_or_else(|| anyhow::anyhow!("Could not find hunk location"))?;
106        
107        // Build result
108        result.extend(lines[..match_start].iter().map(|s| s.to_string()));
109        result.extend(hunk.new_lines.clone());
110        result.extend(lines[match_start + hunk.old_lines.len()..].iter().map(|s| s.to_string()));
111        
112        Ok(result.join("\n"))
113    }
114}
115
/// Accumulates the lines of a single hunk while the patch is being parsed.
struct HunkBuilder {
    /// Start line taken from the old-file range of the `@@` header.
    start_line: usize,
    /// Lines the hunk expects to find: context plus removed (`-`) lines.
    old_lines: Vec<String>,
    /// Lines that replace them: context plus added (`+`) lines.
    new_lines: Vec<String>,
}
121
122impl HunkBuilder {
123    fn build(self, file: String) -> PatchHunk {
124        PatchHunk {
125            file,
126            start_line: self.start_line,
127            old_lines: self.old_lines,
128            new_lines: self.new_lines,
129        }
130    }
131}
132
/// A parsed hunk bound to the file it targets.
#[derive(Debug)]
struct PatchHunk {
    /// Target path taken from the patch's `+++` header.
    file: String,
    /// Start line from the old-file range of the hunk header.
    start_line: usize,
    /// Context and removed lines expected in the current file.
    old_lines: Vec<String>,
    /// Context and added lines that replace `old_lines`.
    new_lines: Vec<String>,
}
140
141#[derive(Deserialize)]
142struct Params {
143    patch: String,
144    #[serde(default)]
145    dry_run: bool,
146}
147
148#[async_trait]
149impl Tool for ApplyPatchTool {
150    fn id(&self) -> &str { "apply_patch" }
151    fn name(&self) -> &str { "Apply Patch" }
152    fn description(&self) -> &str { "Apply a unified diff patch to files in the workspace." }
153    fn parameters(&self) -> Value {
154        json!({
155            "type": "object",
156            "properties": {
157                "patch": {"type": "string", "description": "Unified diff patch content"},
158                "dry_run": {"type": "boolean", "default": false, "description": "Preview without applying"}
159            },
160            "required": ["patch"]
161        })
162    }
163
164    async fn execute(&self, params: Value) -> Result<ToolResult> {
165        let p: Params = serde_json::from_value(params).context("Invalid params")?;
166        
167        let hunks = self.parse_patch(&p.patch)?;
168        
169        if hunks.is_empty() {
170            return Ok(ToolResult::error("No valid hunks found in patch"));
171        }
172        
173        let mut results = Vec::new();
174        let mut files_modified = Vec::new();
175        
176        // Group hunks by file
177        let mut by_file: std::collections::HashMap<String, Vec<&PatchHunk>> = std::collections::HashMap::new();
178        for hunk in &hunks {
179            by_file.entry(hunk.file.clone()).or_default().push(hunk);
180        }
181        
182        for (file, file_hunks) in by_file {
183            let path = self.root.join(&file);
184            
185            let mut content = if path.exists() {
186                std::fs::read_to_string(&path).context(format!("Failed to read {}", file))?
187            } else {
188                String::new()
189            };
190            
191            for hunk in file_hunks {
192                match self.apply_hunk(&content, hunk) {
193                    Ok(new_content) => {
194                        content = new_content;
195                        results.push(format!("✓ Applied hunk to {} at line {}", file, hunk.start_line));
196                    }
197                    Err(e) => {
198                        results.push(format!("✗ Failed to apply hunk to {}: {}", file, e));
199                    }
200                }
201            }
202            
203            if !p.dry_run {
204                if let Some(parent) = path.parent() {
205                    std::fs::create_dir_all(parent)?;
206                }
207                std::fs::write(&path, &content)?;
208                files_modified.push(file);
209            }
210        }
211        
212        let action = if p.dry_run { "Would modify" } else { "Modified" };
213        let summary = format!("{} {} files:\n{}", action, files_modified.len(), results.join("\n"));
214        
215        Ok(ToolResult::success(summary).with_metadata("files", json!(files_modified)))
216    }
217}