1use std::io;
6use std::path::PathBuf;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use agent_client_protocol_schema::{
11 Content, ContentBlock, Diff, TextContent, ToolCallContent, ToolCallLocation,
12 ToolCallUpdateFields, ToolKind,
13};
14use defect_agent::error::BoxError;
15use defect_agent::fs::{FsBackend, FsError};
16use defect_agent::tool::{
17 SafetyClass, Tool, ToolCallDescription, ToolContext, ToolError, ToolEvent, ToolSchema,
18 ToolStream,
19};
20use futures::future::BoxFuture;
21use futures::stream;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24
25pub struct EditFileTool {
26 schema: ToolSchema,
27}
28
29impl EditFileTool {
30 pub fn new() -> Self {
31 Self {
32 schema: ToolSchema {
33 name: "edit_file".to_string(),
34 description: "Replace a string in a UTF-8 text file. \
35 Performs an exact string replacement; \
36 fails if `old_string` is not found, or if it appears multiple times \
37 unless `replace_all` is true. \
38 Path must be inside the workspace root."
39 .to_string(),
40 input_schema: json!({
41 "type": "object",
42 "properties": {
43 "path": {
44 "type": "string",
45 "description": "Absolute path or path relative to the session cwd."
46 },
47 "old_string": {
48 "type": "string",
49 "description": "Exact text to replace. Must match a unique substring \
50 unless `replace_all` is true. Empty string is rejected."
51 },
52 "new_string": {
53 "type": "string",
54 "description": "Replacement text. Must differ from old_string."
55 },
56 "replace_all": {
57 "type": "boolean",
58 "description": "When true, replace every occurrence; when false (default), \
59 require old_string to appear exactly once.",
60 "default": false
61 }
62 },
63 "required": ["path", "old_string", "new_string"]
64 }),
65 },
66 }
67 }
68}
69
70impl Default for EditFileTool {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76#[derive(Debug, Deserialize)]
77struct EditArgs {
78 path: String,
79 old_string: String,
80 new_string: String,
81 #[serde(default)]
82 replace_all: bool,
83}
84
85#[derive(Debug, Serialize)]
86struct EditFileOutput {
87 matches_replaced: u32,
88 bytes_before: u64,
89 bytes_after: u64,
90}
91
92impl Tool for EditFileTool {
93 fn schema(&self) -> &ToolSchema {
94 &self.schema
95 }
96
97 fn safety_hint(&self, _args: &serde_json::Value) -> SafetyClass {
98 SafetyClass::Mutating
99 }
100
101 fn describe<'a>(
102 &'a self,
103 args: &'a serde_json::Value,
104 _ctx: ToolContext<'a>,
105 ) -> BoxFuture<'a, ToolCallDescription> {
106 Box::pin(async move {
107 let path = args.get("path").and_then(|v| v.as_str()).unwrap_or("");
108 let title = if path.is_empty() {
109 "Edit".to_string()
110 } else {
111 format!("Edit {path}")
112 };
113 let mut fields = ToolCallUpdateFields::default();
114 fields.title = Some(title);
115 fields.kind = Some(ToolKind::Edit);
116 if !path.is_empty() {
117 fields.locations = Some(vec![ToolCallLocation::new(PathBuf::from(path))]);
118 }
119 ToolCallDescription { fields }
120 })
121 }
122
123 fn execute(&self, args: serde_json::Value, ctx: ToolContext<'_>) -> ToolStream {
124 let cancel = ctx.cancel.clone();
125 let fs = ctx.fs.clone();
126 let fut = async move { run_edit(args, cancel, fs).await };
127 let s: Pin<Box<dyn futures::Stream<Item = ToolEvent> + Send>> = Box::pin(stream::once(fut));
128 s
129 }
130}
131
132async fn run_edit(
133 args: serde_json::Value,
134 cancel: tokio_util::sync::CancellationToken,
135 fs: Arc<dyn FsBackend>,
136) -> ToolEvent {
137 let parsed: EditArgs = match serde_json::from_value(args) {
138 Ok(v) => v,
139 Err(err) => return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(err))),
140 };
141
142 if parsed.old_string.is_empty() {
143 return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(arg_err(
144 "old_string must not be empty",
145 ))));
146 }
147 if parsed.old_string == parsed.new_string {
148 return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(arg_err(
149 "old_string and new_string must differ",
150 ))));
151 }
152
153 let path = PathBuf::from(&parsed.path);
154
155 let read_fut = fs.read_text(path.clone(), None, None);
156 let old_content = tokio::select! {
157 biased;
158 () = cancel.cancelled() => return ToolEvent::Failed(ToolError::Canceled),
159 r = read_fut => match r {
160 Ok(t) => t,
161 Err(e) => return ToolEvent::Failed(map_fs_err(e)),
162 },
163 };
164
165 let baseline_fp = fs.fingerprint(path.clone()).await.ok();
174
175 let (new_content, matches_replaced) = match apply_edit(
176 &old_content,
177 &parsed.old_string,
178 &parsed.new_string,
179 parsed.replace_all,
180 ) {
181 Ok(v) => v,
182 Err(EditOutcome::NotFound) => {
183 let msg = if whitespace_insensitive_block_count(&old_content, &parsed.old_string) > 0 {
190 "old_string not found. A block matching it except for leading/trailing \
191 whitespace exists — the indentation differs. Re-read the file and copy the \
192 exact whitespace, or use replace_all if it is intentionally repeated."
193 } else {
194 "old_string not found"
195 };
196 return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(arg_err(msg))));
197 }
198 Err(EditOutcome::Ambiguous(n)) => {
199 return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(arg_err(&format!(
200 "old_string matched {n} times; add unique context or set replace_all"
201 )))));
202 }
203 };
204
205 let bytes_before = old_content.len() as u64;
206 let bytes_after = new_content.len() as u64;
207
208 if let Some(baseline) = baseline_fp {
212 match fs.fingerprint(path.clone()).await {
213 Ok(current) if current != baseline => {
214 return ToolEvent::Failed(map_fs_err(FsError::Conflict(path)));
215 }
216 _ => {}
219 }
220 }
221
222 let write_fut = fs.write_text(path.clone(), new_content.clone());
223 tokio::select! {
224 biased;
225 () = cancel.cancelled() => return ToolEvent::Failed(ToolError::Canceled),
226 r = write_fut => {
227 if let Err(e) = r {
228 return ToolEvent::Failed(map_fs_err(e));
229 }
230 }
231 }
232
233 let raw_output = serde_json::to_value(EditFileOutput {
234 matches_replaced,
235 bytes_before,
236 bytes_after,
237 })
238 .unwrap_or(serde_json::Value::Null);
239
240 let diff = Diff::new(path, new_content).old_text(Some(old_content));
241 let mut fields = ToolCallUpdateFields::default();
242 fields.content = Some(vec![
243 ToolCallContent::Diff(diff),
244 ToolCallContent::Content(Content::new(ContentBlock::Text(TextContent::new(format!(
245 "Replaced {matches_replaced} occurrence(s)"
246 ))))),
247 ]);
248 fields.raw_output = Some(raw_output);
249 ToolEvent::Completed(fields)
250}
251
/// Reason `apply_edit` could not produce an edited text.
enum EditOutcome {
    /// `old_string` does not occur in the text.
    NotFound,
    /// `old_string` occurs this many times while `replace_all` is false.
    Ambiguous(u32),
}
257
258fn apply_edit(
259 text: &str,
260 old: &str,
261 new: &str,
262 replace_all: bool,
263) -> Result<(String, u32), EditOutcome> {
264 if replace_all {
265 let count = text.matches(old).count() as u32;
266 if count == 0 {
267 return Err(EditOutcome::NotFound);
268 }
269 Ok((text.replace(old, new), count))
270 } else {
271 let count = text.matches(old).count();
272 match count {
273 0 => Err(EditOutcome::NotFound),
274 1 => Ok((text.replacen(old, new, 1), 1)),
275 n => Err(EditOutcome::Ambiguous(n as u32)),
276 }
277 }
278}
279
/// Counts the windows of consecutive lines in `text` that equal the lines of
/// `needle`, with every line on both sides trimmed of leading and trailing
/// whitespace.
///
/// Returns 0 when `needle` has no lines or has more lines than `text`; in the
/// latter case `windows` simply yields nothing.
fn whitespace_insensitive_block_count(text: &str, needle: &str) -> usize {
    let wanted: Vec<&str> = needle.lines().map(str::trim).collect();
    // `windows(0)` would panic, so an empty needle must bail out here.
    if wanted.is_empty() {
        return 0;
    }
    let haystack: Vec<&str> = text.lines().map(str::trim).collect();
    haystack
        .windows(wanted.len())
        .filter(|&window| window == wanted.as_slice())
        .count()
}
297
298fn map_fs_err(e: FsError) -> ToolError {
299 ToolError::Execution(BoxError::new(e))
300}
301
/// Builds an `io::Error` of kind `InvalidInput` whose message is `msg`.
///
/// The error is boxed into `ToolError::InvalidArgs` when an argument fails
/// validation.
fn arg_err(msg: &str) -> io::Error {
    let owned: String = msg.to_owned();
    io::Error::new(io::ErrorKind::InvalidInput, owned)
}