Skip to main content

defect_tools/fs/
write.rs

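//! `write_file` tool: writes a UTF-8 text file, replacing its entire content.
//!
//! The file is created if it does not already exist. Reads and writes go through the
//! session's `FsBackend`, and content larger than `MAX_WRITE_BYTES` is rejected.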
4
5use std::path::PathBuf;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use agent_client_protocol_schema::{
10    Content, ContentBlock, Diff, TextContent, ToolCallContent, ToolCallLocation,
11    ToolCallUpdateFields, ToolKind,
12};
13
14use defect_agent::error::BoxError;
15use defect_agent::fs::{FsBackend, FsError};
16use defect_agent::tool::{
17    SafetyClass, Tool, ToolCallDescription, ToolContext, ToolError, ToolEvent, ToolSchema,
18    ToolStream,
19};
20use futures::future::BoxFuture;
21use futures::stream;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24
/// Largest `content` payload (in bytes) a single `write_file` call accepts: 10 MiB.
const MAX_WRITE_BYTES: usize = 10 << 20;
26
27pub struct WriteFileTool {
28    schema: ToolSchema,
29}
30
31impl WriteFileTool {
32    pub fn new() -> Self {
33        Self {
34            schema: ToolSchema {
35                name: "write_file".to_string(),
36                description: "Write a UTF-8 text file. \
37                              Overwrites the file if it exists; creates it if it does not. \
38                              Creates intermediate directories as needed. \
39                              Path must be inside the workspace root."
40                    .to_string(),
41                input_schema: json!({
42                    "type": "object",
43                    "properties": {
44                        "path": {
45                            "type": "string",
46                            "description": "Absolute path or path relative to the session cwd."
47                        },
48                        "content": {
49                            "type": "string",
50                            "description": "Full UTF-8 text content. Replaces the file entirely."
51                        }
52                    },
53                    "required": ["path", "content"]
54                }),
55            },
56        }
57    }
58}
59
60impl Default for WriteFileTool {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66#[derive(Debug, Deserialize)]
67struct WriteArgs {
68    path: String,
69    content: String,
70}
71
72#[derive(Debug, Serialize)]
73struct WriteFileOutput {
74    bytes_written: u64,
75    created: bool,
76    parent_existed: bool,
77}
78
79impl Tool for WriteFileTool {
80    fn schema(&self) -> &ToolSchema {
81        &self.schema
82    }
83
84    fn safety_hint(&self, _args: &serde_json::Value) -> SafetyClass {
85        SafetyClass::Mutating
86    }
87
88    fn describe<'a>(
89        &'a self,
90        args: &'a serde_json::Value,
91        ctx: ToolContext<'a>,
92    ) -> BoxFuture<'a, ToolCallDescription> {
93        Box::pin(async move {
94            let path = args.get("path").and_then(|v| v.as_str()).unwrap_or("");
95            let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
96
97            let title = if path.is_empty() {
98                "Write".to_string()
99            } else {
100                format!("Write {path}")
101            };
102            let mut fields = ToolCallUpdateFields::default();
103            fields.title = Some(title);
104            fields.kind = Some(ToolKind::Edit);
105            if !path.is_empty() {
106                fields.locations = Some(vec![ToolCallLocation::new(PathBuf::from(path))]);
107
108                // Lightly read the old content during the `describe` phase so the
109                // authorization UI can render an exact old↔new diff. On failure, fall
110                // back to a "fresh" diff (old=None) — `describe` should not block
111                // ToolCall delivery due to IO jitter. NotFound is equivalent to a "create
112                // new file" path, where old content is None.
113                let old = ctx.fs.read_text(PathBuf::from(path), None, None).await.ok();
114
115                fields.content = Some(vec![ToolCallContent::Diff(
116                    Diff::new(PathBuf::from(path), content).old_text(old),
117                )]);
118            }
119            ToolCallDescription { fields }
120        })
121    }
122
123    fn execute(&self, args: serde_json::Value, ctx: ToolContext<'_>) -> ToolStream {
124        let cancel = ctx.cancel.clone();
125        let fs = ctx.fs.clone();
126        let cwd = ctx.cwd.to_path_buf();
127        let fut = async move { run_write(args, cancel, fs, &cwd).await };
128        let s: Pin<Box<dyn futures::Stream<Item = ToolEvent> + Send>> = Box::pin(stream::once(fut));
129        s
130    }
131}
132
133async fn run_write(
134    args: serde_json::Value,
135    cancel: tokio_util::sync::CancellationToken,
136    fs: Arc<dyn FsBackend>,
137    cwd: &std::path::Path,
138) -> ToolEvent {
139    let parsed: WriteArgs = match serde_json::from_value(args) {
140        Ok(v) => v,
141        Err(err) => return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(err))),
142    };
143
144    if parsed.content.len() > MAX_WRITE_BYTES {
145        return ToolEvent::Failed(ToolError::Execution(BoxError::new(FsError::TooLarge {
146            bytes: parsed.content.len() as u64,
147            limit: MAX_WRITE_BYTES as u64,
148        })));
149    }
150
151    let path = PathBuf::from(&parsed.path);
152
153    // Record whether the parent directory already existed before writing (best-effort,
154    // used to inform the LLM).
155    let abs_path = if path.is_absolute() {
156        path.clone()
157    } else {
158        cwd.join(&path)
159    };
160    let parent_existed = abs_path.parent().is_none_or(|p| p.is_dir());
161
162    // Best-effort read of old content for accurate diff and `created` detection
163    let old = match fs.read_text(path.clone(), None, None).await {
164        Ok(t) => Some(t),
165        Err(FsError::NotFound(_)) => None,
166        Err(_) => None, // On read failure, `created` stays `None`; the write step will report the specific error.
167    };
168
169    let bytes_written = parsed.content.len() as u64;
170
171    let write_fut = fs.write_text(path.clone(), parsed.content.clone());
172    tokio::select! {
173        biased;
174        () = cancel.cancelled() => return ToolEvent::Failed(ToolError::Canceled),
175        r = write_fut => {
176            if let Err(e) = r {
177                return ToolEvent::Failed(map_fs_err(e));
178            }
179        }
180    }
181
182    let raw_output = serde_json::to_value(WriteFileOutput {
183        bytes_written,
184        created: old.is_none(),
185        parent_existed,
186    })
187    .unwrap_or(serde_json::Value::Null);
188
189    let diff = Diff::new(path, parsed.content).old_text(old);
190    let mut fields = ToolCallUpdateFields::default();
191    fields.content = Some(vec![
192        ToolCallContent::Diff(diff),
193        // `turn.rs::extract_text` takes the first `Text` block as the `tool_result` —
194        // feeds a short summary to the LLM.
195        ToolCallContent::Content(Content::new(ContentBlock::Text(TextContent::new(format!(
196            "Wrote {bytes_written} bytes"
197        ))))),
198    ]);
199    fields.raw_output = Some(raw_output);
200    ToolEvent::Completed(fields)
201}
202
203fn map_fs_err(e: FsError) -> ToolError {
204    ToolError::Execution(BoxError::new(e))
205}