1use async_trait::async_trait;
2use serde_json::json;
3
4use super::{truncate_head, Tool, ToolContext, ToolOutput};
5use crate::error::Result;
6
7pub struct WriteTool;
8
9#[async_trait]
10impl Tool for WriteTool {
11 fn name(&self) -> &str {
12 "write"
13 }
14 fn label(&self) -> &str {
15 "Write File"
16 }
17 fn description(&self) -> &str {
18 "Create or overwrite a file. Creates parent dirs automatically."
19 }
20 fn parameters(&self) -> serde_json::Value {
21 json!({
22 "type": "object",
23 "properties": {
24 "path": { "type": "string" },
25 "content": { "type": "string" }
26 },
27 "required": ["path", "content"]
28 })
29 }
30 fn is_readonly(&self) -> bool {
31 false
32 }
33
34 async fn execute(
35 &self,
36 _call_id: &str,
37 params: serde_json::Value,
38 ctx: ToolContext,
39 ) -> Result<ToolOutput> {
40 let raw_path = params["path"].as_str().unwrap_or("");
41 let content = params["content"].as_str().unwrap_or("");
42
43 if raw_path.is_empty() {
44 return Ok(ToolOutput::error("Missing required parameter: path"));
45 }
46
47 let path = super::resolve_path(&ctx.cwd, raw_path);
48
49 let existed = path.exists();
50
51 let tracker_warning = if existed {
53 let tracker = ctx.file_tracker.lock().ok();
54 match tracker {
55 Some(t) if !t.was_read(&path) => Some(format!(
56 "Warning: editing {} without reading it first. Consider reading to verify current content.",
57 path.display()
58 )),
59 Some(t) if t.is_stale(&path) => Some(format!(
60 "Warning: {} was modified externally since last read. Re-read to verify current content.",
61 path.display()
62 )),
63 _ => None,
64 }
65 } else {
66 None
67 };
68
69 if let Some(parent) = path.parent() {
71 tokio::fs::create_dir_all(parent).await?;
72 }
73
74 if existed {
75 ctx.checkpoint_state.snapshot_paths(
76 std::slice::from_ref(&path),
77 Some(format!("write {}", path.display())),
78 )?;
79 }
80
81 let normalized = if existed {
83 if let Ok(existing) = tokio::fs::read(&path).await {
84 let has_crlf = existing.windows(2).any(|w| w == b"\r\n");
85 if has_crlf {
86 let lf_content = content.replace("\r\n", "\n");
88 lf_content.replace('\n', "\r\n")
89 } else {
90 content.replace("\r\n", "\n")
92 }
93 } else {
94 content.replace("\r\n", "\n")
95 }
96 } else {
97 content.replace("\r\n", "\n")
98 };
99
100 let bytes_written = normalized.len();
101 tokio::fs::write(&path, &normalized).await?;
102
103 let action = if existed { "overwritten" } else { "created" };
104 let display = path.display().to_string();
105 let summary = format!("{display}: {bytes_written} bytes {action}");
106
107 const DISPLAY_MAX_LINES: usize = 40;
108 const DISPLAY_MAX_BYTES: usize = 8_000;
109 let display_source = normalized.replace("\r\n", "\n");
110 let display_result = truncate_head(&display_source, DISPLAY_MAX_LINES, DISPLAY_MAX_BYTES);
111 let display_content = display_result.content.trim_end_matches('\n').to_string();
112 let display_note = if display_result.truncated {
113 let note = format!(
114 "[output truncated: showing {}/{} lines, {}/{} bytes]",
115 display_result.output_lines,
116 display_result.total_lines,
117 display_result.output_bytes,
118 display_result.total_bytes,
119 );
120 if let Some(ref tf) = display_result.temp_file {
121 format!("{note} full output: {}", tf.display())
122 } else {
123 note
124 }
125 } else {
126 String::new()
127 };
128
129 let mut warnings = Vec::new();
130 if let Some(warning) = tracker_warning {
131 warnings.push(warning);
132 }
133
134 let mut text = summary.clone();
135 for warning in &warnings {
136 text.push('\n');
137 text.push_str(warning);
138 }
139
140 Ok(ToolOutput {
141 content: vec![imp_llm::ContentBlock::Text { text }],
142 details: json!({
143 "path": display,
144 "bytes": bytes_written,
145 "created": !existed,
146 "summary": summary,
147 "warnings": warnings,
148 "display_content": display_content,
149 "display_note": display_note,
150 }),
151 is_error: false,
152 })
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolContext;
    use std::path::Path;
    use std::sync::Arc;

    /// Builds a `ToolContext` rooted at `dir` with fresh, empty shared state.
    fn test_ctx(dir: &Path) -> ToolContext {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        ToolContext {
            cwd: dir.to_path_buf(),
            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: Arc::new(crate::ui::NullInterface),
            file_cache: Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(crate::tools::FileTracker::new())),
            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            config: Arc::new(crate::config::Config::default()),
        }
    }

    /// Runs `WriteTool` with `params` in a fresh context rooted at `dir`.
    async fn run(dir: &Path, params: serde_json::Value) -> ToolOutput {
        WriteTool
            .execute("test-call", params, test_ctx(dir))
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn write_new_file() {
        let dir = tempfile::tempdir().unwrap();

        let result = run(
            dir.path(),
            serde_json::json!({"path": "new.txt", "content": "hello world"}),
        )
        .await;

        assert!(!result.is_error);
        let details = &result.details;
        assert_eq!(details["display_content"], "hello world");
        assert!(details["summary"]
            .as_str()
            .unwrap()
            .ends_with("new.txt: 11 bytes created"));
        let written = std::fs::read_to_string(dir.path().join("new.txt")).unwrap();
        assert_eq!(written, "hello world");
    }

    #[tokio::test]
    async fn write_creates_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();

        let result = run(
            dir.path(),
            serde_json::json!({"path": "a/b/c/deep.txt", "content": "deep"}),
        )
        .await;

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("a/b/c/deep.txt")).unwrap();
        assert_eq!(written, "deep");
    }

    #[tokio::test]
    async fn write_overwrite_creates_checkpoint_snapshot() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("existing.txt");
        std::fs::write(&file, "original").unwrap();

        // Keep a handle on the checkpoint state before the context is moved.
        let ctx = test_ctx(dir.path());
        let checkpoint_state = ctx.checkpoint_state.clone();

        let result = WriteTool
            .execute(
                "c-overwrite",
                serde_json::json!({"path": "existing.txt", "content": "updated"}),
                ctx,
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(
            checkpoint_state.original(&file).as_deref(),
            Some("original")
        );
        let checkpoints = checkpoint_state.checkpoints();
        assert_eq!(checkpoints.len(), 1);
        assert!(checkpoints[0].files.contains(&file));
    }

    #[tokio::test]
    async fn write_empty_content() {
        let dir = tempfile::tempdir().unwrap();

        let result = run(
            dir.path(),
            serde_json::json!({"path": "empty.txt", "content": ""}),
        )
        .await;

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("empty.txt")).unwrap();
        assert_eq!(written, "");
        assert_eq!(result.details["display_content"], "");
    }

    #[tokio::test]
    async fn write_missing_path_error() {
        let dir = tempfile::tempdir().unwrap();

        let result = run(dir.path(), serde_json::json!({"content": "hello"})).await;

        assert!(result.is_error);
    }

    #[tokio::test]
    async fn write_preserves_crlf_on_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("crlf.txt");
        std::fs::write(&file, "line1\r\nline2\r\n").unwrap();

        let result = run(
            dir.path(),
            serde_json::json!({"path": "crlf.txt", "content": "new1\nnew2\n"}),
        )
        .await;

        assert!(!result.is_error);
        // Every line must be converted, not merely "some CRLF somewhere".
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "new1\r\nnew2\r\n");
        assert!(result.details["summary"]
            .as_str()
            .unwrap()
            .ends_with("crlf.txt: 12 bytes overwritten"));
        // The preview is always rendered with LF endings.
        assert_eq!(result.details["display_content"], "new1\nnew2");
    }

    #[tokio::test]
    async fn write_crlf_content_into_crlf_file_is_not_doubled() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("crlf2.txt");
        std::fs::write(&file, "a\r\nb\r\n").unwrap();

        let result = run(
            dir.path(),
            serde_json::json!({"path": "crlf2.txt", "content": "x\r\ny\r\n"}),
        )
        .await;

        assert!(!result.is_error);
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "x\r\ny\r\n");
    }

    #[tokio::test]
    async fn write_normalizes_crlf_for_new_file() {
        let dir = tempfile::tempdir().unwrap();

        let result = run(
            dir.path(),
            serde_json::json!({"path": "norm.txt", "content": "a\r\nb\r\n"}),
        )
        .await;

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("norm.txt")).unwrap();
        assert_eq!(written, "a\nb\n");
        assert!(result.details["summary"]
            .as_str()
            .unwrap()
            .ends_with("norm.txt: 4 bytes created"));
    }

    #[tokio::test]
    async fn write_keeps_lf_when_existing_file_is_lf() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("lf.txt");
        std::fs::write(&file, "old\n").unwrap();

        let result = run(
            dir.path(),
            serde_json::json!({"path": "lf.txt", "content": "x\r\ny\n"}),
        )
        .await;

        assert!(!result.is_error);
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "x\ny\n");
    }

    #[tokio::test]
    async fn write_deep_nested_dirs() {
        let dir = tempfile::tempdir().unwrap();

        let result = run(
            dir.path(),
            serde_json::json!({"path": "x/y/z/w/v/deep.txt", "content": "deep content"}),
        )
        .await;

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("x/y/z/w/v/deep.txt")).unwrap();
        assert_eq!(written, "deep content");
    }

    #[tokio::test]
    async fn write_overwrites_existing() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("exist.txt");
        std::fs::write(&file, "old content").unwrap();

        let result = run(
            dir.path(),
            serde_json::json!({"path": "exist.txt", "content": "new content"}),
        )
        .await;

        assert!(!result.is_error);
        let text = result
            .content
            .iter()
            .find_map(|b| match b {
                imp_llm::ContentBlock::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .unwrap();
        assert!(text.contains("overwritten"));
        let written = std::fs::read_to_string(&file).unwrap();
        assert_eq!(written, "new content");
    }

    #[tokio::test]
    async fn write_includes_display_content_metadata() {
        let dir = tempfile::tempdir().unwrap();

        let result = run(
            dir.path(),
            serde_json::json!({
                "path": "preview.rs",
                "content": "fn main() {\n    println!(\"hi\");\n}\n"
            }),
        )
        .await;

        assert!(!result.is_error);
        assert!(result.details["path"]
            .as_str()
            .unwrap()
            .ends_with("preview.rs"));
        assert!(result.details["summary"]
            .as_str()
            .unwrap()
            .ends_with("preview.rs: 34 bytes created"));
        assert_eq!(
            result.details["display_content"],
            "fn main() {\n    println!(\"hi\");\n}"
        );
        assert_eq!(result.details["display_note"], "");
    }

    #[tokio::test]
    async fn write_display_content_truncates_large_content() {
        let dir = tempfile::tempdir().unwrap();
        let content = (0..100)
            .map(|i| format!("line {i}"))
            .collect::<Vec<_>>()
            .join("\n");

        let result = run(
            dir.path(),
            serde_json::json!({"path": "large.txt", "content": content}),
        )
        .await;

        assert!(!result.is_error);
        let display_content = result.details["display_content"].as_str().unwrap();
        assert!(display_content.lines().count() <= 40);
        assert!(result.details["display_note"]
            .as_str()
            .unwrap()
            .contains("output truncated"));
    }
}