1use std::io;
6use std::path::PathBuf;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use agent_client_protocol_schema::{
11 Content, ContentBlock, Diff, TextContent, ToolCallContent, ToolCallLocation,
12 ToolCallUpdateFields, ToolKind,
13};
14use defect_agent::error::BoxError;
15use defect_agent::fs::{FsBackend, FsError};
16use defect_agent::tool::{
17 SafetyClass, Tool, ToolCallDescription, ToolContext, ToolError, ToolEvent, ToolSchema,
18 ToolStream,
19};
20use futures::future::BoxFuture;
21use futures::stream;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24
25pub struct EditFileTool {
26 schema: ToolSchema,
27}
28
29impl EditFileTool {
30 pub fn new() -> Self {
31 Self {
32 schema: ToolSchema {
33 name: "edit_file".to_string(),
34 description: "Replace a string in a UTF-8 text file. \
35 Performs an exact string replacement; \
36 fails if `old_string` is not found, or if it appears multiple times \
37 unless `replace_all` is true. \
38 Path must be inside the workspace root."
39 .to_string(),
40 input_schema: json!({
41 "type": "object",
42 "properties": {
43 "path": {
44 "type": "string",
45 "description": "Absolute path or path relative to the session cwd."
46 },
47 "old_string": {
48 "type": "string",
49 "description": "Exact text to replace. Must match a unique substring \
50 unless `replace_all` is true. Empty string is rejected."
51 },
52 "new_string": {
53 "type": "string",
54 "description": "Replacement text. Must differ from old_string."
55 },
56 "replace_all": {
57 "type": "boolean",
58 "description": "When true, replace every occurrence; when false (default), \
59 require old_string to appear exactly once.",
60 "default": false
61 }
62 },
63 "required": ["path", "old_string", "new_string"]
64 }),
65 },
66 }
67 }
68}
69
70impl Default for EditFileTool {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76#[derive(Debug, Deserialize)]
77struct EditArgs {
78 path: String,
79 old_string: String,
80 new_string: String,
81 #[serde(default)]
82 replace_all: bool,
83}
84
85#[derive(Debug, Serialize)]
86struct EditFileOutput {
87 matches_replaced: u32,
88 bytes_before: u64,
89 bytes_after: u64,
90}
91
92impl Tool for EditFileTool {
93 fn schema(&self) -> &ToolSchema {
94 &self.schema
95 }
96
97 fn safety_hint(&self, _args: &serde_json::Value) -> SafetyClass {
98 SafetyClass::Mutating
99 }
100
101 fn describe<'a>(
102 &'a self,
103 args: &'a serde_json::Value,
104 _ctx: ToolContext<'a>,
105 ) -> BoxFuture<'a, ToolCallDescription> {
106 Box::pin(async move {
107 let path = args.get("path").and_then(|v| v.as_str()).unwrap_or("");
108 let title = if path.is_empty() {
109 "Edit".to_string()
110 } else {
111 format!("Edit {path}")
112 };
113 let mut fields = ToolCallUpdateFields::default();
114 fields.title = Some(title);
115 fields.kind = Some(ToolKind::Edit);
116 if !path.is_empty() {
117 fields.locations = Some(vec![ToolCallLocation::new(PathBuf::from(path))]);
118 }
119 ToolCallDescription { fields }
120 })
121 }
122
123 fn execute(&self, args: serde_json::Value, ctx: ToolContext<'_>) -> ToolStream {
124 let cancel = ctx.cancel.clone();
125 let fs = ctx.fs.clone();
126 let fut = async move { run_edit(args, cancel, fs).await };
127 let s: Pin<Box<dyn futures::Stream<Item = ToolEvent> + Send>> = Box::pin(stream::once(fut));
128 s
129 }
130}
131
132async fn run_edit(
133 args: serde_json::Value,
134 cancel: tokio_util::sync::CancellationToken,
135 fs: Arc<dyn FsBackend>,
136) -> ToolEvent {
137 let parsed: EditArgs = match serde_json::from_value(args) {
138 Ok(v) => v,
139 Err(err) => return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(err))),
140 };
141
142 if parsed.old_string.is_empty() {
143 return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(arg_err(
144 "old_string must not be empty",
145 ))));
146 }
147 if parsed.old_string == parsed.new_string {
148 return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(arg_err(
149 "old_string and new_string must differ",
150 ))));
151 }
152
153 let path = PathBuf::from(&parsed.path);
154
155 let read_fut = fs.read_text(path.clone(), None, None);
156 let old_content = tokio::select! {
157 biased;
158 () = cancel.cancelled() => return ToolEvent::Failed(ToolError::Canceled),
159 r = read_fut => match r {
160 Ok(t) => t,
161 Err(e) => return ToolEvent::Failed(map_fs_err(e)),
162 },
163 };
164
165 let baseline_fp = fs.fingerprint(path.clone()).await.ok();
174
175 let (new_content, matches_replaced) = match apply_edit(
176 &old_content,
177 &parsed.old_string,
178 &parsed.new_string,
179 parsed.replace_all,
180 ) {
181 Ok(v) => v,
182 Err(EditOutcome::NotFound) => {
183 return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(arg_err(
184 "old_string not found",
185 ))));
186 }
187 Err(EditOutcome::Ambiguous(n)) => {
188 return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(arg_err(&format!(
189 "old_string matched {n} times; add unique context or set replace_all"
190 )))));
191 }
192 };
193
194 let bytes_before = old_content.len() as u64;
195 let bytes_after = new_content.len() as u64;
196
197 if let Some(baseline) = baseline_fp {
201 match fs.fingerprint(path.clone()).await {
202 Ok(current) if current != baseline => {
203 return ToolEvent::Failed(map_fs_err(FsError::Conflict(path)));
204 }
205 _ => {}
208 }
209 }
210
211 let write_fut = fs.write_text(path.clone(), new_content.clone());
212 tokio::select! {
213 biased;
214 () = cancel.cancelled() => return ToolEvent::Failed(ToolError::Canceled),
215 r = write_fut => {
216 if let Err(e) = r {
217 return ToolEvent::Failed(map_fs_err(e));
218 }
219 }
220 }
221
222 let raw_output = serde_json::to_value(EditFileOutput {
223 matches_replaced,
224 bytes_before,
225 bytes_after,
226 })
227 .unwrap_or(serde_json::Value::Null);
228
229 let diff = Diff::new(path, new_content).old_text(Some(old_content));
230 let mut fields = ToolCallUpdateFields::default();
231 fields.content = Some(vec![
232 ToolCallContent::Diff(diff),
233 ToolCallContent::Content(Content::new(ContentBlock::Text(TextContent::new(format!(
234 "Replaced {matches_replaced} occurrence(s)"
235 ))))),
236 ]);
237 fields.raw_output = Some(raw_output);
238 ToolEvent::Completed(fields)
239}
240
/// Why [`apply_edit`] declined to produce new content.
#[derive(Debug, PartialEq, Eq)]
enum EditOutcome {
    /// `old` does not occur in the text.
    NotFound,
    /// A unique match was required but `old` occurs this many times
    /// (overlapping occurrences counted separately; saturates at `u32::MAX`).
    Ambiguous(u32),
}

/// Replaces `old` with `new` in `text`, returning the new text and the number of
/// replacements made.
///
/// With `replace_all`, every non-overlapping occurrence (left to right, as
/// [`str::replace`] does) is replaced.
///
/// Without it, `old` must start at exactly one position. Overlapping occurrences
/// count separately, so `"aa"` in `"aaa"` is ambiguous rather than silently edited
/// at the first position, which might not be the one the caller meant.
///
/// Callers should reject an empty `old`: it matches at every char boundary.
///
/// # Errors
/// - [`EditOutcome::NotFound`] if `old` does not occur.
/// - [`EditOutcome::Ambiguous`] if `replace_all` is false and `old` occurs more than once.
fn apply_edit(
    text: &str,
    old: &str,
    new: &str,
    replace_all: bool,
) -> Result<(String, u32), EditOutcome> {
    // Counts every position where `needle` starts, overlapping ones included.
    fn count_occurrences(haystack: &str, needle: &str) -> usize {
        let mut count = 0;
        let mut from = 0;
        while let Some(offset) = haystack[from..].find(needle) {
            count += 1;
            let at = from + offset;
            // Step past one whole char so the next slice starts on a char boundary.
            match haystack[at..].chars().next() {
                Some(c) => from = at + c.len_utf8(),
                None => break,
            }
        }
        count
    }

    // Saturate rather than silently truncate if a count ever exceeds `u32::MAX`.
    let to_u32 = |n: usize| u32::try_from(n).unwrap_or(u32::MAX);

    if replace_all {
        // Report the same non-overlapping count that `str::replace` acts on.
        let count = text.matches(old).count();
        if count == 0 {
            Err(EditOutcome::NotFound)
        } else {
            Ok((text.replace(old, new), to_u32(count)))
        }
    } else {
        match count_occurrences(text, old) {
            0 => Err(EditOutcome::NotFound),
            1 => Ok((text.replacen(old, new, 1), 1)),
            n => Err(EditOutcome::Ambiguous(to_u32(n))),
        }
    }
}
268
269fn map_fs_err(e: FsError) -> ToolError {
270 ToolError::Execution(BoxError::new(e))
271}
272
/// Builds an `InvalidInput` I/O error whose message is `msg`.
fn arg_err(msg: &str) -> io::Error {
    let message = String::from(msg);
    io::Error::new(io::ErrorKind::InvalidInput, message)
}