1use std::path::PathBuf;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use agent_client_protocol_schema::{
10 Content, ContentBlock, Diff, TextContent, ToolCallContent, ToolCallLocation,
11 ToolCallUpdateFields, ToolKind,
12};
13
14use defect_agent::error::BoxError;
15use defect_agent::fs::{FsBackend, FsError};
16use defect_agent::tool::{
17 SafetyClass, Tool, ToolCallDescription, ToolContext, ToolError, ToolEvent, ToolSchema,
18 ToolStream,
19};
20use futures::future::BoxFuture;
21use futures::stream;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24
/// Largest `content` payload (in bytes, i.e. 10 MiB) that `write_file` accepts.
const MAX_WRITE_BYTES: usize = 10 << 20;
26
27pub struct WriteFileTool {
28 schema: ToolSchema,
29}
30
31impl WriteFileTool {
32 pub fn new() -> Self {
33 Self {
34 schema: ToolSchema {
35 name: "write_file".to_string(),
36 description: "Write a UTF-8 text file. \
37 Overwrites the file if it exists; creates it if it does not. \
38 Creates intermediate directories as needed. \
39 Path must be inside the workspace root."
40 .to_string(),
41 input_schema: json!({
42 "type": "object",
43 "properties": {
44 "path": {
45 "type": "string",
46 "description": "Absolute path or path relative to the session cwd."
47 },
48 "content": {
49 "type": "string",
50 "description": "Full UTF-8 text content. Replaces the file entirely."
51 }
52 },
53 "required": ["path", "content"]
54 }),
55 },
56 }
57 }
58}
59
60impl Default for WriteFileTool {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66#[derive(Debug, Deserialize)]
67struct WriteArgs {
68 path: String,
69 content: String,
70}
71
72#[derive(Debug, Serialize)]
73struct WriteFileOutput {
74 bytes_written: u64,
75 created: bool,
76 parent_existed: bool,
77}
78
79impl Tool for WriteFileTool {
80 fn schema(&self) -> &ToolSchema {
81 &self.schema
82 }
83
84 fn safety_hint(&self, _args: &serde_json::Value) -> SafetyClass {
85 SafetyClass::Mutating
86 }
87
88 fn describe<'a>(
89 &'a self,
90 args: &'a serde_json::Value,
91 ctx: ToolContext<'a>,
92 ) -> BoxFuture<'a, ToolCallDescription> {
93 Box::pin(async move {
94 let path = args.get("path").and_then(|v| v.as_str()).unwrap_or("");
95 let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
96
97 let title = if path.is_empty() {
98 "Write".to_string()
99 } else {
100 format!("Write {path}")
101 };
102 let mut fields = ToolCallUpdateFields::default();
103 fields.title = Some(title);
104 fields.kind = Some(ToolKind::Edit);
105 if !path.is_empty() {
106 fields.locations = Some(vec![ToolCallLocation::new(PathBuf::from(path))]);
107
108 let old = ctx.fs.read_text(PathBuf::from(path), None, None).await.ok();
114
115 fields.content = Some(vec![ToolCallContent::Diff(
116 Diff::new(PathBuf::from(path), content).old_text(old),
117 )]);
118 }
119 ToolCallDescription { fields }
120 })
121 }
122
123 fn execute(&self, args: serde_json::Value, ctx: ToolContext<'_>) -> ToolStream {
124 let cancel = ctx.cancel.clone();
125 let fs = ctx.fs.clone();
126 let cwd = ctx.cwd.to_path_buf();
127 let fut = async move { run_write(args, cancel, fs, &cwd).await };
128 let s: Pin<Box<dyn futures::Stream<Item = ToolEvent> + Send>> = Box::pin(stream::once(fut));
129 s
130 }
131}
132
133async fn run_write(
134 args: serde_json::Value,
135 cancel: tokio_util::sync::CancellationToken,
136 fs: Arc<dyn FsBackend>,
137 cwd: &std::path::Path,
138) -> ToolEvent {
139 let parsed: WriteArgs = match serde_json::from_value(args) {
140 Ok(v) => v,
141 Err(err) => return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(err))),
142 };
143
144 if parsed.content.len() > MAX_WRITE_BYTES {
145 return ToolEvent::Failed(ToolError::Execution(BoxError::new(FsError::TooLarge {
146 bytes: parsed.content.len() as u64,
147 limit: MAX_WRITE_BYTES as u64,
148 })));
149 }
150
151 let path = PathBuf::from(&parsed.path);
152
153 let abs_path = if path.is_absolute() {
156 path.clone()
157 } else {
158 cwd.join(&path)
159 };
160 let parent_existed = abs_path.parent().is_none_or(|p| p.is_dir());
161
162 let old = match fs.read_text(path.clone(), None, None).await {
164 Ok(t) => Some(t),
165 Err(FsError::NotFound(_)) => None,
166 Err(_) => None, };
168
169 let bytes_written = parsed.content.len() as u64;
170
171 let write_fut = fs.write_text(path.clone(), parsed.content.clone());
172 tokio::select! {
173 biased;
174 () = cancel.cancelled() => return ToolEvent::Failed(ToolError::Canceled),
175 r = write_fut => {
176 if let Err(e) = r {
177 return ToolEvent::Failed(map_fs_err(e));
178 }
179 }
180 }
181
182 let raw_output = serde_json::to_value(WriteFileOutput {
183 bytes_written,
184 created: old.is_none(),
185 parent_existed,
186 })
187 .unwrap_or(serde_json::Value::Null);
188
189 let diff = Diff::new(path, parsed.content).old_text(old);
190 let mut fields = ToolCallUpdateFields::default();
191 fields.content = Some(vec![
192 ToolCallContent::Diff(diff),
193 ToolCallContent::Content(Content::new(ContentBlock::Text(TextContent::new(format!(
196 "Wrote {bytes_written} bytes"
197 ))))),
198 ]);
199 fields.raw_output = Some(raw_output);
200 ToolEvent::Completed(fields)
201}
202
203fn map_fs_err(e: FsError) -> ToolError {
204 ToolError::Execution(BoxError::new(e))
205}