Skip to main content

claude_agent/tools/
write.rs

1//! Write tool - creates or overwrites files with atomic operations.
2
3use async_trait::async_trait;
4use schemars::JsonSchema;
5use serde::Deserialize;
6
7use super::SchemaTool;
8use super::context::ExecutionContext;
9use crate::security::fs::SecureFileHandle;
10use crate::types::ToolResult;
11
12#[derive(Debug, Deserialize, JsonSchema)]
13#[schemars(deny_unknown_fields)]
14pub struct WriteInput {
15    /// The absolute path to the file to write (must be absolute, not relative)
16    pub file_path: String,
17    /// The content to write to the file
18    pub content: String,
19}
20
/// The `Write` tool: a stateless unit type whose behaviour is provided by its
/// [`SchemaTool`] implementation.
#[derive(Default, Clone, Copy, Debug)]
pub struct WriteTool;
23
24#[async_trait]
25impl SchemaTool for WriteTool {
26    type Input = WriteInput;
27
28    const NAME: &'static str = "Write";
29    const DESCRIPTION: &'static str = r#"Writes a file to the local filesystem.
30
31Usage:
32- This tool will overwrite the existing file if there is one at the provided path.
33- If this is an existing file, you MUST use the Read tool first to read the file's contents. This tool will fail if you did not read the file first.
34- ALWAYS prefer editing existing files in the codebase. NEVER write new files unless explicitly required.
35- NEVER proactively create documentation files (*.md) or README files. Only create documentation files if explicitly requested by the User.
36- Only use emojis if the user explicitly requests it. Avoid writing emojis to files unless asked."#;
37
38    async fn handle(&self, input: WriteInput, context: &ExecutionContext) -> ToolResult {
39        let path = match context.try_resolve_for(Self::NAME, &input.file_path) {
40            Ok(p) => p,
41            Err(e) => return e,
42        };
43
44        let content = input.content;
45        let content_len = content.len();
46        let display_path = path.as_path().display().to_string();
47
48        let result = tokio::task::spawn_blocking(move || {
49            let handle = SecureFileHandle::for_atomic_write(path)?;
50            handle.atomic_write(content.as_bytes())?;
51            Ok::<_, crate::security::SecurityError>(())
52        })
53        .await;
54
55        match result {
56            Ok(Ok(())) => ToolResult::success(format!(
57                "Successfully wrote {} bytes to {}",
58                content_len, display_path
59            )),
60            Ok(Err(e)) => ToolResult::error(format!("Failed to write file: {}", e)),
61            Err(e) => ToolResult::error(format!("Task failed: {}", e)),
62        }
63    }
64}
65
#[cfg(test)]
mod tests {
    //! Integration-style tests that drive `WriteTool` through the generic
    //! `Tool::execute` entry point against a throwaway temp directory.

    use super::*;
    use crate::tools::Tool;
    use tempfile::tempdir;
    use tokio::fs;

    /// A plain write into the sandbox root creates the file with the given content.
    #[tokio::test]
    async fn test_write_file() {
        let dir = tempdir().unwrap();
        // Canonicalize so the requested path matches the root even when the
        // system temp dir sits behind a symlink.
        let root = std::fs::canonicalize(dir.path()).unwrap();
        let file_path = root.join("test.txt");

        let test_context = ExecutionContext::from_path(&root).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                serde_json::json!({
                    "file_path": file_path.to_str().unwrap(),
                    "content": "Hello, World!"
                }),
                &test_context,
            )
            .await;

        assert!(!result.is_error());
        let content = fs::read_to_string(&file_path).await.unwrap();
        assert_eq!(content, "Hello, World!");
    }

    /// Missing parent directories are created as part of the write.
    #[tokio::test]
    async fn test_write_creates_directories() {
        let dir = tempdir().unwrap();
        let root = std::fs::canonicalize(dir.path()).unwrap();
        let file_path = root.join("subdir/nested/test.txt");

        let test_context = ExecutionContext::from_path(&root).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                serde_json::json!({
                    "file_path": file_path.to_str().unwrap(),
                    "content": "Nested content"
                }),
                &test_context,
            )
            .await;

        assert!(!result.is_error());
        assert!(file_path.exists());
    }

    /// A path outside the context root is rejected and nothing is written there.
    ///
    /// The target lives in a second, unrelated temp dir rather than a real
    /// system file: if path confinement ever regressed, the worst outcome is a
    /// stray file in a throwaway directory, and the `exists` check catches the
    /// regression even when the process would lack permission to touch system
    /// paths.
    #[tokio::test]
    async fn test_write_path_escape_blocked() {
        let dir = tempdir().unwrap();
        let root = std::fs::canonicalize(dir.path()).unwrap();

        let outside_dir = tempdir().unwrap();
        let outside_path = std::fs::canonicalize(outside_dir.path())
            .unwrap()
            .join("escaped.txt");

        let test_context = ExecutionContext::from_path(&root).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                serde_json::json!({
                    "file_path": outside_path.to_str().unwrap(),
                    "content": "bad"
                }),
                &test_context,
            )
            .await;

        assert!(result.is_error());
        assert!(
            !outside_path.exists(),
            "A rejected write must not create the file outside the root"
        );
    }

    /// Writing to an existing file replaces its content.
    #[tokio::test]
    async fn test_write_overwrites_existing() {
        let dir = tempdir().unwrap();
        let root = std::fs::canonicalize(dir.path()).unwrap();
        let file_path = root.join("test.txt");
        fs::write(&file_path, "original content").await.unwrap();

        let test_context = ExecutionContext::from_path(&root).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                serde_json::json!({
                    "file_path": file_path.to_str().unwrap(),
                    "content": "new content"
                }),
                &test_context,
            )
            .await;

        assert!(!result.is_error());
        let content = fs::read_to_string(&file_path).await.unwrap();
        assert_eq!(content, "new content");
    }

    /// Empty content is valid and yields an empty file.
    #[tokio::test]
    async fn test_write_empty_content() {
        let dir = tempdir().unwrap();
        let root = std::fs::canonicalize(dir.path()).unwrap();
        let file_path = root.join("empty.txt");

        let test_context = ExecutionContext::from_path(&root).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                serde_json::json!({
                    "file_path": file_path.to_str().unwrap(),
                    "content": ""
                }),
                &test_context,
            )
            .await;

        assert!(!result.is_error());
        let content = fs::read_to_string(&file_path).await.unwrap();
        assert_eq!(content, "");
    }

    /// Newlines (including the trailing one) round-trip unchanged.
    #[tokio::test]
    async fn test_write_multiline_content() {
        let dir = tempdir().unwrap();
        let root = std::fs::canonicalize(dir.path()).unwrap();
        let file_path = root.join("multi.txt");
        let content = "line 1\nline 2\nline 3\n";

        let test_context = ExecutionContext::from_path(&root).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                serde_json::json!({
                    "file_path": file_path.to_str().unwrap(),
                    "content": content
                }),
                &test_context,
            )
            .await;

        assert!(!result.is_error());
        let read_content = fs::read_to_string(&file_path).await.unwrap();
        assert_eq!(read_content, content);
    }

    /// After a successful write, no `.tmp`-named leftovers remain next to the target.
    #[tokio::test]
    async fn test_write_atomic_no_temp_files_remain() {
        let dir = tempdir().unwrap();
        let root = std::fs::canonicalize(dir.path()).unwrap();
        let file_path = root.join("atomic_test.txt");

        let test_context = ExecutionContext::from_path(&root).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                serde_json::json!({
                    "file_path": file_path.to_str().unwrap(),
                    "content": "atomic content"
                }),
                &test_context,
            )
            .await;

        assert!(!result.is_error());

        // Scan the directory lazily; there is no need to collect the entries first.
        let has_temp = std::fs::read_dir(&root).unwrap().any(|entry| {
            entry
                .unwrap()
                .file_name()
                .to_string_lossy()
                .contains(".tmp")
        });
        assert!(!has_temp, "Temporary files should be cleaned up");
    }

    /// Overwriting an existing file ends with the new content in place.
    ///
    /// NOTE(review): despite the name, this only observes the final state; it
    /// cannot see whether the original stayed intact while the write was in
    /// progress. Verifying that would need a hook into the write path.
    #[tokio::test]
    async fn test_write_atomic_preserves_original_until_complete() {
        let dir = tempdir().unwrap();
        let root = std::fs::canonicalize(dir.path()).unwrap();
        let file_path = root.join("preserve_test.txt");
        fs::write(&file_path, "original content").await.unwrap();

        let test_context = ExecutionContext::from_path(&root).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                serde_json::json!({
                    "file_path": file_path.to_str().unwrap(),
                    "content": "new content"
                }),
                &test_context,
            )
            .await;

        assert!(!result.is_error());
        let content = fs::read_to_string(&file_path).await.unwrap();
        assert_eq!(content, "new content");
    }
}