1use async_trait::async_trait;
2use serde_json::json;
3
4use super::{truncate_head, Tool, ToolContext, ToolOutput};
5use crate::config::WriteOverwritePolicy;
6use crate::error::Result;
7
8pub struct WriteTool;
9
10#[async_trait]
11impl Tool for WriteTool {
12 fn name(&self) -> &str {
13 "write"
14 }
15 fn label(&self) -> &str {
16 "Write File"
17 }
18 fn description(&self) -> &str {
19 "Create or overwrite a file. Creates parent dirs automatically."
20 }
21 fn parameters(&self) -> serde_json::Value {
22 json!({
23 "type": "object",
24 "properties": {
25 "path": { "type": "string" },
26 "content": { "type": "string" }
27 },
28 "required": ["path", "content"]
29 })
30 }
31 fn is_readonly(&self) -> bool {
32 false
33 }
34
35 async fn execute(
36 &self,
37 _call_id: &str,
38 params: serde_json::Value,
39 ctx: ToolContext,
40 ) -> Result<ToolOutput> {
41 let raw_path = params["path"].as_str().unwrap_or("");
42 let content = params["content"].as_str().unwrap_or("");
43
44 if raw_path.is_empty() {
45 return Ok(ToolOutput::error("Missing required parameter: path"));
46 }
47
48 let path = super::resolve_path(&ctx.cwd, raw_path);
49
50 if let Err(error) = ctx.check_write_path(&path) {
51 return Ok(ToolOutput::error(error));
52 }
53
54 if path.is_dir() {
55 return Ok(ToolOutput::error(format!(
56 "Path is a directory, not a file: {}",
57 path.display()
58 )));
59 }
60
61 let existed = path.exists();
62
63 let overwrite_check = if existed {
64 evaluate_overwrite_policy(&path, &ctx)
65 } else {
66 OverwriteCheck::default()
67 };
68 if let Some(error) = overwrite_check.error {
69 return Ok(ToolOutput::error(error));
70 }
71
72 if let Some(parent) = path.parent() {
74 tokio::fs::create_dir_all(parent).await?;
75 }
76
77 let checkpoint = if existed {
78 ctx.checkpoint_state.snapshot_paths(
79 std::slice::from_ref(&path),
80 Some(format!("write {}", path.display())),
81 )?
82 } else {
83 None
84 };
85
86 let normalized = if existed {
88 if let Ok(existing) = tokio::fs::read(&path).await {
89 let has_crlf = existing.windows(2).any(|w| w == b"\r\n");
90 if has_crlf {
91 let lf_content = content.replace("\r\n", "\n");
93 lf_content.replace('\n', "\r\n")
94 } else {
95 content.replace("\r\n", "\n")
97 }
98 } else {
99 content.replace("\r\n", "\n")
100 }
101 } else {
102 content.replace("\r\n", "\n")
103 };
104
105 let bytes_written = normalized.len();
106 tokio::fs::write(&path, &normalized).await?;
107
108 let action = if existed { "overwritten" } else { "created" };
109 let display = path.display().to_string();
110 let summary = format!("{display}: {bytes_written} bytes {action}");
111
112 const DISPLAY_MAX_LINES: usize = 40;
113 const DISPLAY_MAX_BYTES: usize = 8_000;
114 let display_source = normalized.replace("\r\n", "\n");
115 let display_result = truncate_head(&display_source, DISPLAY_MAX_LINES, DISPLAY_MAX_BYTES);
116 let display_content = display_result.content.trim_end_matches('\n').to_string();
117 let display_note = if display_result.truncated {
118 let note = format!(
119 "[output truncated: showing {}/{} lines, {}/{} bytes]",
120 display_result.output_lines,
121 display_result.total_lines,
122 display_result.output_bytes,
123 display_result.total_bytes,
124 );
125 if let Some(ref tf) = display_result.temp_file {
126 format!("{note} full output: {}", tf.display())
127 } else {
128 note
129 }
130 } else {
131 String::new()
132 };
133
134 let warnings = overwrite_check.warning_messages;
135 let warning_codes = overwrite_check.warning_codes;
136
137 let mut text = summary.clone();
138 for warning in &warnings {
139 text.push('\n');
140 text.push_str(warning);
141 }
142
143 Ok(ToolOutput {
144 content: vec![imp_llm::ContentBlock::Text { text }],
145 details: json!({
146 "action": action,
147 "path": display,
148 "bytes_written": bytes_written,
149 "line_ending": if normalized.contains("\r\n") { "crlf" } else { "lf" },
150 "created": !existed,
151 "overwritten": existed,
152 "checkpoint_id": checkpoint.as_ref().map(|c| c.id.clone()),
153 "checkpoint_label": checkpoint.as_ref().and_then(|c| c.label.clone()),
154 "summary": summary,
155 "warnings": warnings,
156 "warning_codes": warning_codes,
157 "overwrite_policy": ctx.config.write.overwrite_policy,
158 "display_content": display_content,
159 "display_note": display_note,
160 }),
161 is_error: false,
162 })
163 }
164}
165
/// Outcome of evaluating the write overwrite policy for a file that already exists.
///
/// The default value (no warnings, no error) lets the overwrite proceed silently.
#[derive(Debug, Default)]
struct OverwriteCheck {
    /// Human-readable warnings appended to the tool's text output. The write still proceeds.
    warning_messages: Vec<String>,
    /// Machine-readable codes pushed alongside each entry of `warning_messages`
    /// (e.g. `"unread_overwrite"`, `"stale_overwrite"`).
    warning_codes: Vec<&'static str>,
    /// When set, the overwrite is refused and this message becomes the tool error.
    error: Option<String>,
}
172
173fn evaluate_overwrite_policy(path: &std::path::Path, ctx: &ToolContext) -> OverwriteCheck {
174 let Ok(tracker) = ctx.file_tracker.lock() else {
175 return OverwriteCheck::default();
176 };
177
178 let was_read = tracker.was_read(path);
179 let is_stale = tracker.is_stale(path);
180 let policy = ctx.config.write.overwrite_policy;
181
182 if matches!(policy, WriteOverwritePolicy::Deny) {
183 return OverwriteCheck {
184 error: Some(format!(
185 "Overwriting existing files is disabled by write overwrite policy: {}",
186 path.display()
187 )),
188 ..OverwriteCheck::default()
189 };
190 }
191
192 if matches!(policy, WriteOverwritePolicy::RequireRead) && !was_read {
193 return OverwriteCheck {
194 error: Some(format!(
195 "Write overwrite policy requires reading the file before overwriting: {}",
196 path.display()
197 )),
198 ..OverwriteCheck::default()
199 };
200 }
201
202 if matches!(
203 policy,
204 WriteOverwritePolicy::RequireRead | WriteOverwritePolicy::BlockStale
205 ) && is_stale
206 {
207 return OverwriteCheck {
208 error: Some(format!(
209 "Write overwrite policy blocks overwriting stale files. Re-read before overwriting: {}",
210 path.display()
211 )),
212 ..OverwriteCheck::default()
213 };
214 }
215
216 let mut check = OverwriteCheck::default();
217 if !was_read {
218 check.warning_codes.push("unread_overwrite");
219 check.warning_messages.push(format!(
220 "Warning: overwriting {} without reading it first. Consider reading to verify current content.",
221 path.display()
222 ));
223 } else if is_stale {
224 check.warning_codes.push("stale_overwrite");
225 check.warning_messages.push(format!(
226 "Warning: {} was modified externally since last read. Re-read to verify current content.",
227 path.display()
228 ));
229 }
230
231 check
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolContext;
    use std::path::Path;
    use std::sync::Arc;

    /// Builds a `ToolContext` rooted at `dir` with default config and the default run policy.
    fn test_ctx(dir: &Path) -> ToolContext {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        ToolContext {
            cwd: dir.to_path_buf(),
            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: Arc::new(crate::ui::NullInterface),
            file_cache: Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(crate::tools::FileTracker::new())),
            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            config: Arc::new(crate::config::Config::default()),
            run_policy: Default::default(),
            supporting_provenance: Vec::new(),
        }
    }

    /// Like `test_ctx`, but with the given write overwrite policy.
    fn test_ctx_with_policy(dir: &Path, overwrite_policy: WriteOverwritePolicy) -> ToolContext {
        let mut ctx = test_ctx(dir);
        let mut config = crate::config::Config::default();
        config.write.overwrite_policy = overwrite_policy;
        ctx.config = Arc::new(config);
        ctx
    }

    /// Like `test_ctx`, but with the given run policy.
    fn test_ctx_with_run_policy(dir: &Path, run_policy: crate::policy::RunPolicy) -> ToolContext {
        let mut ctx = test_ctx(dir);
        ctx.run_policy = run_policy;
        ctx
    }

    #[tokio::test]
    async fn write_path_policy_allows_matching_file() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c-allow-write",
                serde_json::json!({"path": "CHANGELOG.md", "content": "updated"}),
                test_ctx_with_run_policy(
                    dir.path(),
                    crate::policy::RunPolicy::new().allow_write("CHANGELOG.md"),
                ),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(
            std::fs::read_to_string(dir.path().join("CHANGELOG.md")).unwrap(),
            "updated"
        );
    }

    #[tokio::test]
    async fn write_path_policy_blocks_unlisted_file() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c-deny-write",
                serde_json::json!({"path": "src/lib.rs", "content": "updated"}),
                test_ctx_with_run_policy(
                    dir.path(),
                    crate::policy::RunPolicy::new().allow_write("CHANGELOG.md"),
                ),
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result.text_content().unwrap().contains("write allowlist"));
        assert!(!dir.path().join("src/lib.rs").exists());
    }

    #[tokio::test]
    async fn write_path_policy_blocks_parent_traversal() {
        let dir = tempfile::tempdir().unwrap();
        let outside = tempfile::tempdir().unwrap();
        let relative =
            pathdiff::diff_paths(outside.path().join("CHANGELOG.md"), dir.path()).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c-traversal",
                serde_json::json!({"path": relative, "content": "updated"}),
                test_ctx_with_run_policy(
                    dir.path(),
                    crate::policy::RunPolicy::new().allow_write("CHANGELOG.md"),
                ),
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result
            .text_content()
            .unwrap()
            .contains("outside the worker root"));
        assert!(!outside.path().join("CHANGELOG.md").exists());
    }

    #[tokio::test]
    async fn write_path_policy_deny_overrides_allow() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c-deny-override",
                serde_json::json!({"path": "CHANGELOG.md", "content": "updated"}),
                test_ctx_with_run_policy(
                    dir.path(),
                    crate::policy::RunPolicy::new()
                        .allow_write("CHANGELOG.md")
                        .deny_write("CHANGELOG.md"),
                ),
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result.text_content().unwrap().contains("denylist"));
        assert!(!dir.path().join("CHANGELOG.md").exists());
    }

    #[tokio::test]
    async fn write_path_policy_glob_allows_matching_file() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::create_dir_all(dir.path().join("docs")).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c-glob-write",
                serde_json::json!({"path": "docs/CHANGELOG.md", "content": "updated"}),
                test_ctx_with_run_policy(
                    dir.path(),
                    crate::policy::RunPolicy::new().allow_write("docs/*.md"),
                ),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(
            std::fs::read_to_string(dir.path().join("docs/CHANGELOG.md")).unwrap(),
            "updated"
        );
    }

    #[tokio::test]
    async fn write_default_policy_warns_on_unread_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("existing.txt");
        std::fs::write(&file, "original").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c-warn",
                serde_json::json!({"path": "existing.txt", "content": "updated"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.details["warning_codes"][0], "unread_overwrite");
        assert_eq!(result.details["overwritten"], true);
        assert!(result.details["checkpoint_id"].as_str().is_some());
        assert!(result
            .text_content()
            .unwrap()
            .contains("without reading it first"));
    }

    #[tokio::test]
    async fn write_require_read_policy_blocks_unread_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("existing.txt");
        std::fs::write(&file, "original").unwrap();

        let ctx = test_ctx_with_policy(dir.path(), WriteOverwritePolicy::RequireRead);
        let checkpoint_state = ctx.checkpoint_state.clone();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c-block-unread",
                serde_json::json!({"path": "existing.txt", "content": "updated"}),
                ctx,
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result.text_content().unwrap().contains("requires reading"));
        assert_eq!(std::fs::read_to_string(file).unwrap(), "original");
        // A refused overwrite must not leave a checkpoint behind.
        assert_eq!(checkpoint_state.checkpoints().len(), 0);
    }

    #[tokio::test]
    async fn write_block_stale_policy_blocks_stale_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("existing.txt");
        std::fs::write(&file, "original").unwrap();

        let ctx = test_ctx_with_policy(dir.path(), WriteOverwritePolicy::BlockStale);
        let checkpoint_state = ctx.checkpoint_state.clone();
        ctx.file_tracker.lock().unwrap().record_read(&file);
        // Separate the recorded read from the external edit in time. This assumes staleness
        // is based on modification time; TODO confirm against FileTracker.
        std::thread::sleep(std::time::Duration::from_millis(5));
        std::fs::write(&file, "external").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c-block-stale",
                serde_json::json!({"path": "existing.txt", "content": "updated"}),
                ctx,
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result.text_content().unwrap().contains("stale"));
        assert_eq!(std::fs::read_to_string(file).unwrap(), "external");
        assert_eq!(checkpoint_state.checkpoints().len(), 0);
    }

    #[tokio::test]
    async fn write_new_file() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c1",
                serde_json::json!({"path": "new.txt", "content": "hello world"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let details = &result.details;
        assert_eq!(details["display_content"], "hello world");
        assert!(details["summary"]
            .as_str()
            .unwrap()
            .ends_with("new.txt: 11 bytes created"));
        let written = std::fs::read_to_string(dir.path().join("new.txt")).unwrap();
        assert_eq!(written, "hello world");
    }

    #[tokio::test]
    async fn write_creates_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c2",
                serde_json::json!({"path": "a/b/c/deep.txt", "content": "deep"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("a/b/c/deep.txt")).unwrap();
        assert_eq!(written, "deep");
    }

    #[tokio::test]
    async fn write_overwrite_creates_checkpoint_snapshot() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("existing.txt");
        std::fs::write(&file, "original").unwrap();

        let tool = WriteTool;
        let ctx = test_ctx(dir.path());
        let checkpoint_state = ctx.checkpoint_state.clone();

        let result = tool
            .execute(
                "c-overwrite",
                serde_json::json!({"path": "existing.txt", "content": "updated"}),
                ctx,
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(
            checkpoint_state.original(&file).as_deref(),
            Some("original")
        );
        let checkpoints = checkpoint_state.checkpoints();
        assert_eq!(checkpoints.len(), 1);
        assert!(checkpoints[0].files.contains(&file));
    }

    #[tokio::test]
    async fn write_empty_content() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c4",
                serde_json::json!({"path": "empty.txt", "content": ""}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("empty.txt")).unwrap();
        assert_eq!(written, "");
        assert_eq!(result.details["display_content"], "");
    }

    #[tokio::test]
    async fn write_missing_path_error() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c5",
                serde_json::json!({"content": "hello"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result
            .text_content()
            .unwrap()
            .contains("Missing required parameter: path"));
    }

    #[tokio::test]
    async fn write_preserves_crlf_on_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("crlf.txt");
        std::fs::write(&file, "line1\r\nline2\r\n").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c6",
                serde_json::json!({"path": "crlf.txt", "content": "new1\nnew2\n"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.details["line_ending"], "crlf");
        // Every LF must become exactly one CRLF. No bare LFs and no doubled CRs.
        let raw = std::fs::read(dir.path().join("crlf.txt")).unwrap();
        assert_eq!(raw, b"new1\r\nnew2\r\n");
    }

    #[tokio::test]
    async fn write_normalizes_crlf_content_to_lf_for_new_file() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c-lf",
                serde_json::json!({"path": "lf.txt", "content": "a\r\nb\r\n"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.details["line_ending"], "lf");
        assert_eq!(result.details["bytes_written"], 4);
        assert_eq!(std::fs::read(dir.path().join("lf.txt")).unwrap(), b"a\nb\n");
    }

    #[tokio::test]
    async fn write_deep_nested_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c7",
                serde_json::json!({"path": "x/y/z/w/v/deep.txt", "content": "deep content"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("x/y/z/w/v/deep.txt")).unwrap();
        assert_eq!(written, "deep content");
    }

    #[tokio::test]
    async fn write_overwrites_existing() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("exist.txt");
        std::fs::write(&file, "old content").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c3",
                serde_json::json!({"path": "exist.txt", "content": "new content"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.text_content().unwrap().contains("overwritten"));
        let written = std::fs::read_to_string(&file).unwrap();
        assert_eq!(written, "new content");
    }

    #[tokio::test]
    async fn write_includes_display_content_metadata() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c8",
                serde_json::json!({"path": "preview.rs", "content": "fn main() {\n    println!(\"hi\");\n}\n"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.details["path"]
            .as_str()
            .unwrap()
            .ends_with("preview.rs"));
        assert!(result.details["summary"]
            .as_str()
            .unwrap()
            .ends_with("preview.rs: 34 bytes created"));
        assert_eq!(
            result.details["display_content"],
            "fn main() {\n    println!(\"hi\");\n}"
        );
        assert_eq!(result.details["display_note"], "");
    }

    #[tokio::test]
    async fn write_display_content_truncates_large_content() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;
        let content = (0..100)
            .map(|i| format!("line {i}"))
            .collect::<Vec<_>>()
            .join("\n");

        let result = tool
            .execute(
                "c9",
                serde_json::json!({"path": "large.txt", "content": content}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let display_content = result.details["display_content"].as_str().unwrap();
        assert!(display_content.lines().count() <= 40);
        assert!(result.details["display_note"]
            .as_str()
            .unwrap()
            .contains("output truncated"));
    }
}