cersei_tools/
apply_patch.rs1use super::*;
4use serde::Deserialize;
5use std::path::PathBuf;
6
7pub struct ApplyPatchTool;
8
9#[async_trait]
10impl Tool for ApplyPatchTool {
11 fn name(&self) -> &str {
12 "ApplyPatch"
13 }
14
15 fn description(&self) -> &str {
16 "Apply a unified diff patch to one or more files. The patch should be in standard \
17 unified diff format (as produced by `diff -u` or `git diff`). Supports creating \
18 new files and deleting files."
19 }
20
21 fn permission_level(&self) -> PermissionLevel {
22 PermissionLevel::Write
23 }
24 fn category(&self) -> ToolCategory {
25 ToolCategory::FileSystem
26 }
27
28 fn input_schema(&self) -> Value {
29 serde_json::json!({
30 "type": "object",
31 "properties": {
32 "patch": {
33 "type": "string",
34 "description": "Unified diff patch content"
35 }
36 },
37 "required": ["patch"]
38 })
39 }
40
41 async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
42 #[derive(Deserialize)]
43 struct Input {
44 patch: String,
45 }
46
47 let input: Input = match serde_json::from_value(input) {
48 Ok(i) => i,
49 Err(e) => return ToolResult::error(format!("Invalid input: {e}")),
50 };
51
52 match apply_unified_patch(&input.patch, &ctx.working_dir) {
53 Ok(files) => {
54 if files.is_empty() {
55 ToolResult::success("Patch applied (no files changed).")
56 } else {
57 ToolResult::success(format!(
58 "Patch applied to {} file(s):\n{}",
59 files.len(),
60 files
61 .iter()
62 .map(|f| format!(" {}", f.display()))
63 .collect::<Vec<_>>()
64 .join("\n")
65 ))
66 }
67 }
68 Err(e) => ToolResult::error(format!("Failed to apply patch: {e}")),
69 }
70 }
71}
72
73fn apply_unified_patch(
75 patch: &str,
76 working_dir: &std::path::Path,
77) -> std::result::Result<Vec<PathBuf>, String> {
78 let mut modified = Vec::new();
79 let mut current_file: Option<PathBuf> = None;
80 let mut original_lines: Vec<String> = Vec::new();
81 let mut hunks: Vec<Hunk> = Vec::new();
82
83 let lines: Vec<&str> = patch.lines().collect();
85 let mut i = 0;
86
87 while i < lines.len() {
88 let line = lines[i];
89
90 if line.starts_with("--- ") {
91 if let Some(ref file) = current_file {
93 apply_hunks(file, &original_lines, &hunks)?;
94 modified.push(file.clone());
95 }
96
97 i += 1;
99 if i >= lines.len() || !lines[i].starts_with("+++ ") {
100 return Err("Expected +++ line after ---".into());
101 }
102
103 let target = lines[i].strip_prefix("+++ ").unwrap_or(lines[i]);
104 let target = target.split('\t').next().unwrap_or(target); let target = target.strip_prefix("b/").unwrap_or(target); let file_path = working_dir.join(target);
108 original_lines = if file_path.exists() {
109 std::fs::read_to_string(&file_path)
110 .map_err(|e| format!("Cannot read {}: {e}", file_path.display()))?
111 .lines()
112 .map(String::from)
113 .collect()
114 } else {
115 Vec::new() };
117
118 current_file = Some(file_path);
119 hunks.clear();
120 i += 1;
121 continue;
122 }
123
124 if line.starts_with("@@ ") {
125 if let Some(hunk) = parse_hunk_header(line) {
126 let mut hunk_lines = Vec::new();
127 i += 1;
128 while i < lines.len()
129 && !lines[i].starts_with("@@ ")
130 && !lines[i].starts_with("--- ")
131 && !lines[i].starts_with("diff ")
132 {
133 hunk_lines.push(lines[i].to_string());
134 i += 1;
135 }
136 hunks.push(Hunk {
137 old_start: hunk.0,
138 old_count: hunk.1,
139 new_start: hunk.2,
140 new_count: hunk.3,
141 lines: hunk_lines,
142 });
143 continue;
144 }
145 }
146
147 i += 1;
148 }
149
150 if let Some(ref file) = current_file {
152 apply_hunks(file, &original_lines, &hunks)?;
153 modified.push(file.clone());
154 }
155
156 Ok(modified)
157}
158
/// One `@@ … @@` section of a unified diff.
///
/// The `old_*` fields describe the range in the file before the patch and the
/// `new_*` fields the range after it, as numbers taken from the hunk header
/// (line numbers are 1-based; `0` appears for an empty range at the start).
struct Hunk {
    /// First line of the old range.
    old_start: usize,
    /// Number of old lines the header says the hunk covers.
    old_count: usize,
    /// First line of the new range.
    new_start: usize,
    /// Number of new lines the header says the hunk produces.
    new_count: usize,
    /// Body lines as they appeared in the patch, normally still carrying their
    /// ` ` / `+` / `-` prefix.
    lines: Vec<String>,
}
166
/// Parse a unified-diff hunk header such as `@@ -12,5 +12,7 @@ fn main() {`.
///
/// Returns `(old_start, old_count, new_start, new_count)`. A range written
/// without a count (`-12`) has a count of 1. Any text after the closing `@@`
/// is ignored, and a missing closing `@@` is tolerated.
///
/// Returns `None` in these cases, so a malformed header is reported rather
/// than silently mapped to line 1:
/// - the line does not start with `@@ -`;
/// - the second range does not start with `+`;
/// - any start or count is not a non-negative integer.
fn parse_hunk_header(line: &str) -> Option<(usize, usize, usize, usize)> {
    // Split `start[,count]` into numbers; a bare `start` means a one-line range.
    fn parse_range(range: &str) -> Option<(usize, usize)> {
        match range.split_once(',') {
            Some((start, count)) => Some((start.parse().ok()?, count.parse().ok()?)),
            None => Some((range.parse().ok()?, 1)),
        }
    }

    let ranges = line.strip_prefix("@@ -")?;
    // Drop the closing `@@` (with or without a preceding space) and any section
    // heading that follows it.
    let ranges = ranges.split_once("@@").map_or(ranges, |(head, _)| head);
    let mut parts = ranges.split_whitespace();
    let old = parts.next()?;
    let new = parts.next()?.strip_prefix('+')?;

    let (old_start, old_count) = parse_range(old)?;
    let (new_start, new_count) = parse_range(new)?;
    Some((old_start, old_count, new_start, new_count))
}
189
190fn apply_hunks(
191 file: &std::path::Path,
192 original: &[String],
193 hunks: &[Hunk],
194) -> std::result::Result<(), String> {
195 let mut result = original.to_vec();
196 let mut offset: isize = 0;
197
198 for hunk in hunks {
199 let start = ((hunk.old_start as isize - 1) + offset).max(0) as usize;
200 let mut new_lines = Vec::new();
201 let mut old_removed = 0usize;
202
203 for line in &hunk.lines {
204 if let Some(content) = line.strip_prefix('+') {
205 new_lines.push(content.to_string());
206 } else if let Some(_) = line.strip_prefix('-') {
207 old_removed += 1;
208 } else if let Some(content) = line.strip_prefix(' ') {
209 new_lines.push(content.to_string());
210 old_removed += 1; } else {
212 new_lines.push(line.to_string());
214 old_removed += 1;
215 }
216 }
217
218 let end = (start + old_removed).min(result.len());
220 result.splice(start..end, new_lines.iter().cloned());
221 offset += new_lines.len() as isize - old_removed as isize;
222 }
223
224 if let Some(parent) = file.parent() {
226 std::fs::create_dir_all(parent).map_err(|e| format!("Cannot create directory: {e}"))?;
227 }
228 std::fs::write(file, result.join("\n") + "\n")
229 .map_err(|e| format!("Cannot write {}: {e}", file.display()))?;
230
231 Ok(())
232}