Skip to main content

cersei_tools/
apply_patch.rs

//! ApplyPatch tool: apply unified diff patches to files.
2
3use super::*;
4use serde::Deserialize;
5use std::path::PathBuf;
6
7pub struct ApplyPatchTool;
8
9#[async_trait]
10impl Tool for ApplyPatchTool {
11    fn name(&self) -> &str {
12        "ApplyPatch"
13    }
14
15    fn description(&self) -> &str {
16        "Apply a unified diff patch to one or more files. The patch should be in standard \
17         unified diff format (as produced by `diff -u` or `git diff`). Supports creating \
18         new files and deleting files."
19    }
20
21    fn permission_level(&self) -> PermissionLevel {
22        PermissionLevel::Write
23    }
24    fn category(&self) -> ToolCategory {
25        ToolCategory::FileSystem
26    }
27
28    fn input_schema(&self) -> Value {
29        serde_json::json!({
30            "type": "object",
31            "properties": {
32                "patch": {
33                    "type": "string",
34                    "description": "Unified diff patch content"
35                }
36            },
37            "required": ["patch"]
38        })
39    }
40
41    async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
42        #[derive(Deserialize)]
43        struct Input {
44            patch: String,
45        }
46
47        let input: Input = match serde_json::from_value(input) {
48            Ok(i) => i,
49            Err(e) => return ToolResult::error(format!("Invalid input: {e}")),
50        };
51
52        match apply_unified_patch(&input.patch, &ctx.working_dir) {
53            Ok(files) => {
54                if files.is_empty() {
55                    ToolResult::success("Patch applied (no files changed).")
56                } else {
57                    ToolResult::success(format!(
58                        "Patch applied to {} file(s):\n{}",
59                        files.len(),
60                        files
61                            .iter()
62                            .map(|f| format!("  {}", f.display()))
63                            .collect::<Vec<_>>()
64                            .join("\n")
65                    ))
66                }
67            }
68            Err(e) => ToolResult::error(format!("Failed to apply patch: {e}")),
69        }
70    }
71}
72
/// Apply a unified diff patch to files under `working_dir`.
///
/// `patch` may contain several file sections. Each section is a `--- old` /
/// `+++ new` header followed by `@@` hunks, as produced by `diff -u` or
/// `git diff`. Git's `a/`/`b/` prefixes and tab-separated timestamps are
/// stripped. Paths are joined onto `working_dir`, and the `+++` path is the one
/// read and written. Two exceptions: `--- /dev/null` creates the `+++` file,
/// and `+++ /dev/null` deletes the file named by `---`.
///
/// Every hunk's context and removed lines are checked against the file before
/// anything is replaced. A hunk may be found at a shifted position, but a hunk
/// that matches nowhere is an error. All sections are planned in memory first.
/// Files are only written or deleted once every hunk has applied, so a
/// malformed or mismatching patch leaves the files untouched. An I/O error in
/// the final write phase can still leave earlier files written.
///
/// A file containing `\r\n` is written back with `\r\n` line endings. Whether
/// the file ends with a newline is kept unless a
/// `\ No newline at end of file` marker says otherwise.
///
/// NOTE(review): paths are not confined to `working_dir`; `Path::join` honours
/// absolute paths and `..`, as before. This is presumably left to the tool's
/// permission checks — confirm.
///
/// Returns each file written or deleted, once, in the order first seen.
///
/// # Errors
/// Returns a human-readable message in these cases:
/// - the patch is malformed;
/// - a hunk does not match the file;
/// - a create or delete conflicts with what is on disk;
/// - a file cannot be read, written, or deleted.
fn apply_unified_patch(
    patch: &str,
    working_dir: &std::path::Path,
) -> std::result::Result<Vec<PathBuf>, String> {
    let sections = parse_patch(patch)?;

    // Phase 1: plan. `staged` holds one entry per touched path, in first-seen
    // order, with its final contents (`None` = delete). A later section for the
    // same path builds on the staged contents rather than re-reading the disk.
    let mut staged: Vec<(PathBuf, Option<String>)> = Vec::new();
    for section in &sections {
        if let Some((path, contents)) = plan_section(section, working_dir, &staged)? {
            match staged.iter().position(|entry| entry.0 == path) {
                Some(index) => staged[index].1 = contents,
                None => staged.push((path, contents)),
            }
        }
    }

    // Phase 2: every hunk applied cleanly, so commit the results to disk.
    for (path, contents) in &staged {
        match contents {
            Some(text) => {
                if let Some(parent) = path.parent() {
                    std::fs::create_dir_all(parent)
                        .map_err(|e| format!("Cannot create directory: {e}"))?;
                }
                std::fs::write(path, text)
                    .map_err(|e| format!("Cannot write {}: {e}", path.display()))?;
            }
            // Created and then deleted within this same patch: nothing is on disk.
            None if !path.exists() => {}
            None => std::fs::remove_file(path)
                .map_err(|e| format!("Cannot delete {}: {e}", path.display()))?,
        }
    }

    Ok(staged.into_iter().map(|(path, _)| path).collect())
}

/// Path that diff tools put in a file header instead of a real path when the
/// file is being created (`--- /dev/null`) or deleted (`+++ /dev/null`).
const DEV_NULL: &str = "/dev/null";

/// One `--- old` / `+++ new` section of a patch together with its hunks.
struct FilePatch {
    /// Path from the `---` header (git `a/` prefix stripped); `None` for `/dev/null`.
    old_path: Option<String>,
    /// Path from the `+++` header (git `b/` prefix stripped); `None` for `/dev/null`.
    new_path: Option<String>,
    hunks: Vec<Hunk>,
}

/// One `@@ -old_start,old_count +new_start,new_count @@` hunk.
struct Hunk {
    old_start: usize,
    old_count: usize,
    new_start: usize,
    new_count: usize,
    /// Body lines exactly as they appear in the patch, prefix included.
    lines: Vec<String>,
}

/// A hunk split into two sides: `old` holds the lines it expects to find
/// (context + removed) and `new` the lines it leaves behind (context + added).
/// The flags record which sides a `\ No newline at end of file` marker applied to.
struct HunkSides<'a> {
    old: Vec<&'a str>,
    new: Vec<&'a str>,
    old_missing_newline: bool,
    new_missing_newline: bool,
}

impl Hunk {
    /// Split the body into its old and new sides.
    fn sides(&self) -> HunkSides<'_> {
        let mut sides = HunkSides {
            old: Vec::new(),
            new: Vec::new(),
            old_missing_newline: false,
            new_missing_newline: false,
        };
        // Prefix of the most recent non-marker line: a `\` marker describes that line.
        let mut previous = b' ';
        for line in &self.lines {
            let (prefix, content) = match line.as_bytes().first() {
                Some(&prefix @ (b' ' | b'-' | b'+' | b'\\')) => (prefix, &line[1..]),
                // Tolerate context lines whose leading space was stripped, e.g. by an
                // editor trimming trailing whitespace from blank context lines.
                _ => (b' ', line.as_str()),
            };
            match prefix {
                b'\\' => match previous {
                    b'-' => sides.old_missing_newline = true,
                    b'+' => sides.new_missing_newline = true,
                    _ => {
                        sides.old_missing_newline = true;
                        sides.new_missing_newline = true;
                    }
                },
                b'-' => sides.old.push(content),
                b'+' => sides.new.push(content),
                _ => {
                    sides.old.push(content);
                    sides.new.push(content);
                }
            }
            if prefix != b'\\' {
                previous = prefix;
            }
        }
        sides
    }
}

/// Split patch text into file sections and their hunks.
///
/// A hunk body is delimited by the line counts in its `@@` header. As a result,
/// a removed line such as `-- comment` (written `--- comment` in the patch) is
/// not taken for a file header, and trailing blank lines are not taken for
/// context. Headers that over-count are tolerated: the body also ends at the
/// next `@@`, `diff`, or `---`/`+++` header, or at end of input. Headers that
/// under-count are rejected when a `+`/`-` line follows the hunk, because
/// silently dropping that line would lose part of the change.
fn parse_patch(patch: &str) -> std::result::Result<Vec<FilePatch>, String> {
    let lines: Vec<&str> = patch.lines().collect();
    let mut sections: Vec<FilePatch> = Vec::new();
    // True from the end of a hunk until the next `diff` line, file header, or
    // `git format-patch` signature separator (`-- `).
    let mut after_hunk = false;
    let mut i = 0;

    while i < lines.len() {
        let line = lines[i];

        if is_file_header(&lines, i) {
            sections.push(FilePatch {
                old_path: header_path(line, "--- ", "a/"),
                new_path: header_path(lines[i + 1], "+++ ", "b/"),
                hunks: Vec::new(),
            });
            after_hunk = false;
            i += 2;
            continue;
        }
        if line.starts_with("--- ") {
            return Err("Expected +++ line after ---".into());
        }

        if line.starts_with("@@ ") {
            let (old_start, old_count, new_start, new_count) = parse_hunk_header(line)
                .ok_or_else(|| format!("Malformed hunk header: {line}"))?;
            let section = sections
                .last_mut()
                .ok_or_else(|| format!("Hunk before any ---/+++ file header: {line}"))?;

            i += 1;
            let body_start = i;
            let (mut old_left, mut new_left) = (old_count, new_count);
            while i < lines.len() {
                let body_line = lines[i];
                if body_line.starts_with('\\') {
                    // "\ No newline at end of file" annotates the line before it.
                    i += 1;
                    continue;
                }
                if (old_left == 0 && new_left == 0)
                    || body_line.starts_with("@@ ")
                    || body_line.starts_with("diff ")
                    || starts_next_file(&lines, i)
                {
                    break;
                }
                match body_line.as_bytes().first() {
                    Some(b'-') => old_left = old_left.saturating_sub(1),
                    Some(b'+') => new_left = new_left.saturating_sub(1),
                    _ => {
                        old_left = old_left.saturating_sub(1);
                        new_left = new_left.saturating_sub(1);
                    }
                }
                i += 1;
            }

            let mut body: Vec<String> = lines[body_start..i]
                .iter()
                .map(|l| (*l).to_string())
                .collect();
            if old_left > 0 || new_left > 0 {
                // The header over-counted, so trailing blank lines are far more likely
                // to be separators than context lines that lost their leading space.
                while matches!(body.last(), Some(last) if last.is_empty()) {
                    body.pop();
                }
            }
            section.hunks.push(Hunk {
                old_start,
                old_count,
                new_start,
                new_count,
                lines: body,
            });
            after_hunk = true;
            continue;
        }

        let is_change_line = line.starts_with('+') || (line.starts_with('-') && line != "-- ");
        if after_hunk && is_change_line {
            return Err(format!(
                "Change line outside any hunk (does the hunk header count its lines correctly?): {line}"
            ));
        }
        if line.starts_with("diff ") || line == "-- " {
            after_hunk = false;
        }
        i += 1;
    }

    Ok(sections)
}

/// True if `lines[i]` is a `--- …` line immediately followed by a `+++ …` line.
fn is_file_header(lines: &[&str], i: usize) -> bool {
    lines[i].starts_with("--- ")
        && matches!(lines.get(i + 1), Some(next) if next.starts_with("+++ "))
}

/// Inside a hunk body, a `---`/`+++` pair may also be a removed `-- …` line
/// followed by an added `++ …` line. Only treat it as the next file's header
/// when a hunk header (or end of input) follows it.
fn starts_next_file(lines: &[&str], i: usize) -> bool {
    is_file_header(lines, i)
        && match lines.get(i + 2) {
            Some(after) => after.starts_with("@@ "),
            None => true,
        }
}

/// Extract the path from a `--- ` / `+++ ` header line. Drops any
/// tab-separated timestamp, trailing whitespace, and the git-style prefix
/// (`a/` or `b/`). Returns `None` for `/dev/null`.
fn header_path(line: &str, marker: &str, git_prefix: &str) -> Option<String> {
    let raw = line.strip_prefix(marker).unwrap_or(line);
    let raw = raw.split('\t').next().unwrap_or(raw).trim_end();
    if raw == DEV_NULL {
        return None;
    }
    Some(raw.strip_prefix(git_prefix).unwrap_or(raw).to_string())
}

/// Parse `@@ -old_start[,old_count] +new_start[,new_count] @@ [section]` into
/// `(old_start, old_count, new_start, new_count)`.
///
/// An omitted count means 1, and the closing `@@` may be missing. Returns
/// `None` if either range is absent or is not a valid number.
fn parse_hunk_header(line: &str) -> Option<(usize, usize, usize, usize)> {
    let rest = line.strip_prefix("@@ ")?;
    let ranges = rest.split("@@").next().unwrap_or(rest);
    let mut fields = ranges.split_whitespace();
    let old = fields.next()?.strip_prefix('-')?;
    let new = fields.next()?.strip_prefix('+')?;

    let parse_range = |range: &str| -> Option<(usize, usize)> {
        match range.split_once(',') {
            Some((start, count)) => Some((start.parse().ok()?, count.parse().ok()?)),
            None => Some((range.parse().ok()?, 1)),
        }
    };

    let (old_start, old_count) = parse_range(old)?;
    let (new_start, new_count) = parse_range(new)?;
    Some((old_start, old_count, new_start, new_count))
}

/// Plan one file section.
///
/// Returns the absolute path the section touches and that file's new contents
/// (`None` = delete). Returns `Ok(None)` for a plain header with no hunks.
/// `staged` holds contents already planned by earlier sections.
fn plan_section(
    section: &FilePatch,
    working_dir: &std::path::Path,
    staged: &[(PathBuf, Option<String>)],
) -> std::result::Result<Option<(PathBuf, Option<String>)>, String> {
    let relative = match (&section.old_path, &section.new_path) {
        (_, Some(new)) => new,
        (Some(old), None) => old,
        (None, None) => return Err("File header has /dev/null on both sides".into()),
    };
    if relative.is_empty() {
        return Err("File header has an empty path".into());
    }
    let is_creation = section.old_path.is_none();
    let is_deletion = section.new_path.is_none();
    if section.hunks.is_empty() && !is_creation && !is_deletion {
        return Ok(None);
    }

    let path = working_dir.join(relative);
    let current = match staged.iter().find(|entry| entry.0 == path) {
        Some((_, staged_contents)) => staged_contents.clone(),
        None if path.exists() => Some(
            std::fs::read_to_string(&path)
                .map_err(|e| format!("Cannot read {}: {e}", path.display()))?,
        ),
        None => None,
    };

    if is_creation && matches!(current.as_deref(), Some(text) if !text.is_empty()) {
        return Err(format!("Cannot create {}: file already exists", path.display()));
    }
    if is_deletion && current.is_none() {
        return Err(format!("Cannot delete {}: file does not exist", path.display()));
    }

    // A missing target is patched as if it were empty, as the tool always did.
    // Any hunk that carries context will then fail to match instead of corrupting.
    let text = current.unwrap_or_default();
    let original: Vec<String> = text.lines().map(String::from).collect();
    let (patched, newline_override) = splice_hunks(&original, &section.hunks)
        .map_err(|e| format!("{}: {e}", path.display()))?;

    if is_deletion {
        if !patched.is_empty() {
            return Err(format!(
                "Patch deletes {} but its hunks do not remove every line",
                path.display()
            ));
        }
        return Ok(Some((path, None)));
    }

    let eol = if text.contains("\r\n") { "\r\n" } else { "\n" };
    let ends_with_newline = newline_override.unwrap_or(text.is_empty() || text.ends_with('\n'));
    let mut rendered = patched.join(eol);
    if ends_with_newline && !patched.is_empty() {
        rendered.push_str(eol);
    }
    Ok(Some((path, Some(rendered))))
}

/// Apply `hunks` in order to `original`, verifying each hunk's context and
/// removed lines before replacing them.
///
/// Each hunk is first tried where its header says, shifted by the net effect
/// of earlier hunks. If the lines there differ, the nearest matching position
/// after the previous hunk is used, and that drift carries over to later hunks.
///
/// Returns the patched lines. The second value is `Some(ends_with_newline)`
/// when a `\ No newline at end of file` marker decided the final newline, and
/// `None` otherwise.
fn splice_hunks(
    original: &[String],
    hunks: &[Hunk],
) -> std::result::Result<(Vec<String>, Option<bool>), String> {
    let mut result = original.to_vec();
    // Current index of original line 0 relative to where the header expects it.
    let mut offset: isize = 0;
    // Hunks may not overlap: the next one must start after the previous one's output.
    let mut floor = 0usize;
    let mut trailing_newline = None;

    for hunk in hunks {
        let sides = hunk.sides();
        // For a pure insertion (`-N,0`) the header names the line *after which* to
        // insert, so its 0-based index is N itself; otherwise it is N - 1.
        let anchor = if hunk.old_count == 0 {
            hunk.old_start
        } else {
            hunk.old_start.saturating_sub(1)
        };
        let expected = (anchor as isize + offset).max(floor as isize) as usize;
        let at = find_match(&result, &sides.old, expected, floor).ok_or_else(|| {
            format!(
                "hunk @@ -{},{} +{},{} @@ does not match the file contents",
                hunk.old_start, hunk.old_count, hunk.new_start, hunk.new_count
            )
        })?;

        result.splice(
            at..at + sides.old.len(),
            sides.new.iter().map(|s| (*s).to_string()),
        );
        offset = at as isize - anchor as isize + sides.new.len() as isize
            - sides.old.len() as isize;
        floor = at + sides.new.len();

        if sides.new_missing_newline {
            trailing_newline = Some(false);
        } else if sides.old_missing_newline {
            trailing_newline = Some(true);
        }
    }

    Ok((result, trailing_newline))
}

/// Find where `old` occurs in `lines`, preferring the start index closest to
/// `expected` and never starting before `floor`.
///
/// An empty `old` (pure insertion) carries no context to search for, so it
/// matches only at `expected` itself, provided that is within the file.
fn find_match(lines: &[String], old: &[&str], expected: usize, floor: usize) -> Option<usize> {
    if old.is_empty() {
        return (expected <= lines.len()).then_some(expected);
    }
    let last = lines.len().checked_sub(old.len())?;
    if floor > last {
        return None;
    }

    let matches_at = |start: usize| {
        lines[start..start + old.len()]
            .iter()
            .zip(old)
            .all(|(have, want)| have.as_str() == *want)
    };

    // Probe expected, expected-1, expected+1, expected-2, ... within [floor, last].
    let expected = expected.clamp(floor, last);
    for distance in 0..=(last - floor) {
        let before = expected.checked_sub(distance).filter(|&p| p >= floor);
        let after = Some(expected + distance).filter(|&p| distance > 0 && p <= last);
        if before.is_none() && after.is_none() {
            break;
        }
        if let Some(found) = before.into_iter().chain(after).find(|&p| matches_at(p)) {
            return Some(found);
        }
    }
    None
}
228    std::fs::write(file, result.join("\n") + "\n")
229        .map_err(|e| format!("Cannot write {}: {e}", file.display()))?;
230
231    Ok(())
232}