// imp_core/tools/write.rs

1use async_trait::async_trait;
2use serde_json::json;
3
4use super::{truncate_head, Tool, ToolContext, ToolOutput};
5use crate::config::WriteOverwritePolicy;
6use crate::error::Result;
7
/// The `write` tool: creates a file, or overwrites an existing one, with the
/// provided `content`, creating any missing parent directories first.
///
/// Stateless unit struct; everything a call needs arrives through the
/// `params` JSON and the `ToolContext` passed to `execute`.
pub struct WriteTool;
9
10#[async_trait]
11impl Tool for WriteTool {
12    fn name(&self) -> &str {
13        "write"
14    }
15    fn label(&self) -> &str {
16        "Write File"
17    }
18    fn description(&self) -> &str {
19        "Create or overwrite a file. Creates parent dirs automatically."
20    }
21    fn parameters(&self) -> serde_json::Value {
22        json!({
23            "type": "object",
24            "properties": {
25                "path": { "type": "string" },
26                "content": { "type": "string" }
27            },
28            "required": ["path", "content"]
29        })
30    }
31    fn is_readonly(&self) -> bool {
32        false
33    }
34
35    async fn execute(
36        &self,
37        _call_id: &str,
38        params: serde_json::Value,
39        ctx: ToolContext,
40    ) -> Result<ToolOutput> {
41        let raw_path = params["path"].as_str().unwrap_or("");
42        let content = params["content"].as_str().unwrap_or("");
43
44        if raw_path.is_empty() {
45            return Ok(ToolOutput::error("Missing required parameter: path"));
46        }
47
48        let path = super::resolve_path(&ctx.cwd, raw_path);
49
50        if let Err(error) = ctx.check_write_path(&path) {
51            return Ok(ToolOutput::error(error));
52        }
53
54        if path.is_dir() {
55            return Ok(ToolOutput::error(format!(
56                "Path is a directory, not a file: {}",
57                path.display()
58            )));
59        }
60
61        let existed = path.exists();
62
63        let overwrite_check = if existed {
64            evaluate_overwrite_policy(&path, &ctx)
65        } else {
66            OverwriteCheck::default()
67        };
68        if let Some(error) = overwrite_check.error {
69            return Ok(ToolOutput::error(error));
70        }
71
72        // Create parent directories
73        if let Some(parent) = path.parent() {
74            tokio::fs::create_dir_all(parent).await?;
75        }
76
77        let checkpoint = if existed {
78            ctx.checkpoint_state.snapshot_paths(
79                std::slice::from_ref(&path),
80                Some(format!("write {}", path.display())),
81            )?
82        } else {
83            None
84        };
85
86        // Detect existing line endings to preserve them, default to LF for new files
87        let normalized = if existed {
88            if let Ok(existing) = tokio::fs::read(&path).await {
89                let has_crlf = existing.windows(2).any(|w| w == b"\r\n");
90                if has_crlf {
91                    // Preserve CRLF: ensure content uses CRLF
92                    let lf_content = content.replace("\r\n", "\n");
93                    lf_content.replace('\n', "\r\n")
94                } else {
95                    // LF or no newlines — ensure LF
96                    content.replace("\r\n", "\n")
97                }
98            } else {
99                content.replace("\r\n", "\n")
100            }
101        } else {
102            content.replace("\r\n", "\n")
103        };
104
105        let bytes_written = normalized.len();
106        tokio::fs::write(&path, &normalized).await?;
107
108        let action = if existed { "overwritten" } else { "created" };
109        let display = path.display().to_string();
110        let summary = format!("{display}: {bytes_written} bytes {action}");
111
112        const DISPLAY_MAX_LINES: usize = 40;
113        const DISPLAY_MAX_BYTES: usize = 8_000;
114        let display_source = normalized.replace("\r\n", "\n");
115        let display_result = truncate_head(&display_source, DISPLAY_MAX_LINES, DISPLAY_MAX_BYTES);
116        let display_content = display_result.content.trim_end_matches('\n').to_string();
117        let display_note = if display_result.truncated {
118            let note = format!(
119                "[output truncated: showing {}/{} lines, {}/{} bytes]",
120                display_result.output_lines,
121                display_result.total_lines,
122                display_result.output_bytes,
123                display_result.total_bytes,
124            );
125            if let Some(ref tf) = display_result.temp_file {
126                format!("{note} full output: {}", tf.display())
127            } else {
128                note
129            }
130        } else {
131            String::new()
132        };
133
134        let warnings = overwrite_check.warning_messages;
135        let warning_codes = overwrite_check.warning_codes;
136
137        let mut text = summary.clone();
138        for warning in &warnings {
139            text.push('\n');
140            text.push_str(warning);
141        }
142
143        Ok(ToolOutput {
144            content: vec![imp_llm::ContentBlock::Text { text }],
145            details: json!({
146                "action": action,
147                "path": display,
148                "bytes_written": bytes_written,
149                "line_ending": if normalized.contains("\r\n") { "crlf" } else { "lf" },
150                "created": !existed,
151                "overwritten": existed,
152                "checkpoint_id": checkpoint.as_ref().map(|c| c.id.clone()),
153                "checkpoint_label": checkpoint.as_ref().and_then(|c| c.label.clone()),
154                "summary": summary,
155                "warnings": warnings,
156                "warning_codes": warning_codes,
157                "overwrite_policy": ctx.config.write.overwrite_policy,
158                "display_content": display_content,
159                "display_note": display_note,
160            }),
161            is_error: false,
162        })
163    }
164}
165
/// Outcome of applying the write overwrite policy to an existing file.
///
/// The derived `Default` is the "nothing to report" outcome: no error and no
/// warnings.
#[derive(Default)]
struct OverwriteCheck {
    /// When set, the overwrite is refused and this message is returned to the
    /// caller as the tool error.
    error: Option<String>,
    /// Human-readable warnings, appended line by line to the tool's text.
    warning_messages: Vec<String>,
    /// Machine-readable codes for the warnings (e.g. `"unread_overwrite"`).
    warning_codes: Vec<&'static str>,
}
172
173fn evaluate_overwrite_policy(path: &std::path::Path, ctx: &ToolContext) -> OverwriteCheck {
174    let Ok(tracker) = ctx.file_tracker.lock() else {
175        return OverwriteCheck::default();
176    };
177
178    let was_read = tracker.was_read(path);
179    let is_stale = tracker.is_stale(path);
180    let policy = ctx.config.write.overwrite_policy;
181
182    if matches!(policy, WriteOverwritePolicy::Deny) {
183        return OverwriteCheck {
184            error: Some(format!(
185                "Overwriting existing files is disabled by write overwrite policy: {}",
186                path.display()
187            )),
188            ..OverwriteCheck::default()
189        };
190    }
191
192    if matches!(policy, WriteOverwritePolicy::RequireRead) && !was_read {
193        return OverwriteCheck {
194            error: Some(format!(
195                "Write overwrite policy requires reading the file before overwriting: {}",
196                path.display()
197            )),
198            ..OverwriteCheck::default()
199        };
200    }
201
202    if matches!(
203        policy,
204        WriteOverwritePolicy::RequireRead | WriteOverwritePolicy::BlockStale
205    ) && is_stale
206    {
207        return OverwriteCheck {
208            error: Some(format!(
209                "Write overwrite policy blocks overwriting stale files. Re-read before overwriting: {}",
210                path.display()
211            )),
212            ..OverwriteCheck::default()
213        };
214    }
215
216    let mut check = OverwriteCheck::default();
217    if !was_read {
218        check.warning_codes.push("unread_overwrite");
219        check.warning_messages.push(format!(
220            "Warning: overwriting {} without reading it first. Consider reading to verify current content.",
221            path.display()
222        ));
223    } else if is_stale {
224        check.warning_codes.push("stale_overwrite");
225        check.warning_messages.push(format!(
226            "Warning: {} was modified externally since last read. Re-read to verify current content.",
227            path.display()
228        ));
229    }
230
231    check
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolContext;
    use std::path::Path;
    use std::sync::Arc;

    /// Builds a `ToolContext` rooted at `dir`. It uses the default config and
    /// default run policy, plus newly constructed caches, tracker and
    /// checkpoint state.
    fn test_ctx(dir: &Path) -> ToolContext {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        ToolContext {
            cwd: dir.to_path_buf(),
            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: Arc::new(crate::ui::NullInterface),
            file_cache: Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(crate::tools::FileTracker::new())),
            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            config: Arc::new(crate::config::Config::default()),
            run_policy: Default::default(),
            supporting_provenance: Vec::new(),
        }
    }

    /// Like `test_ctx`, but with the given write overwrite policy.
    fn test_ctx_with_policy(dir: &Path, overwrite_policy: WriteOverwritePolicy) -> ToolContext {
        let mut ctx = test_ctx(dir);
        let mut config = crate::config::Config::default();
        config.write.overwrite_policy = overwrite_policy;
        ctx.config = Arc::new(config);
        ctx
    }

    /// Like `test_ctx`, but with the given run policy (write allow/deny lists).
    fn test_ctx_with_run_policy(dir: &Path, run_policy: crate::policy::RunPolicy) -> ToolContext {
        let mut ctx = test_ctx(dir);
        ctx.run_policy = run_policy;
        ctx
    }

    #[tokio::test]
    async fn write_path_policy_allows_matching_file() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c-allow-write",
                serde_json::json!({"path": "CHANGELOG.md", "content": "updated"}),
                test_ctx_with_run_policy(
                    dir.path(),
                    crate::policy::RunPolicy::new().allow_write("CHANGELOG.md"),
                ),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(
            std::fs::read_to_string(dir.path().join("CHANGELOG.md")).unwrap(),
            "updated"
        );
    }

    #[tokio::test]
    async fn write_path_policy_blocks_unlisted_file() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c-deny-write",
                serde_json::json!({"path": "src/lib.rs", "content": "updated"}),
                test_ctx_with_run_policy(
                    dir.path(),
                    crate::policy::RunPolicy::new().allow_write("CHANGELOG.md"),
                ),
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result.text_content().unwrap().contains("write allowlist"));
        assert!(!dir.path().join("src/lib.rs").exists());
    }

    #[tokio::test]
    async fn write_path_policy_blocks_parent_traversal() {
        let dir = tempfile::tempdir().unwrap();
        let outside = tempfile::tempdir().unwrap();
        let relative =
            pathdiff::diff_paths(outside.path().join("CHANGELOG.md"), dir.path()).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c-traversal",
                serde_json::json!({"path": relative, "content": "updated"}),
                test_ctx_with_run_policy(
                    dir.path(),
                    crate::policy::RunPolicy::new().allow_write("CHANGELOG.md"),
                ),
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result
            .text_content()
            .unwrap()
            .contains("outside the worker root"));
        assert!(!outside.path().join("CHANGELOG.md").exists());
    }

    #[tokio::test]
    async fn write_path_policy_deny_overrides_allow() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c-deny-override",
                serde_json::json!({"path": "CHANGELOG.md", "content": "updated"}),
                test_ctx_with_run_policy(
                    dir.path(),
                    crate::policy::RunPolicy::new()
                        .allow_write("CHANGELOG.md")
                        .deny_write("CHANGELOG.md"),
                ),
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result.text_content().unwrap().contains("denylist"));
        assert!(!dir.path().join("CHANGELOG.md").exists());
    }

    #[tokio::test]
    async fn write_path_policy_glob_allows_matching_file() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::create_dir_all(dir.path().join("docs")).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c-glob-write",
                serde_json::json!({"path": "docs/CHANGELOG.md", "content": "updated"}),
                test_ctx_with_run_policy(
                    dir.path(),
                    crate::policy::RunPolicy::new().allow_write("docs/*.md"),
                ),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(
            std::fs::read_to_string(dir.path().join("docs/CHANGELOG.md")).unwrap(),
            "updated"
        );
    }

    #[tokio::test]
    async fn write_default_policy_warns_on_unread_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("existing.txt");
        std::fs::write(&file, "original").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c-warn",
                serde_json::json!({"path": "existing.txt", "content": "updated"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.details["warning_codes"][0], "unread_overwrite");
        assert_eq!(result.details["overwritten"], true);
        assert!(result.details["checkpoint_id"].as_str().is_some());
    }

    #[tokio::test]
    async fn write_require_read_policy_blocks_unread_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("existing.txt");
        std::fs::write(&file, "original").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c-block-unread",
                serde_json::json!({"path": "existing.txt", "content": "updated"}),
                test_ctx_with_policy(dir.path(), WriteOverwritePolicy::RequireRead),
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert_eq!(std::fs::read_to_string(file).unwrap(), "original");
    }

    #[tokio::test]
    async fn write_require_read_policy_allows_overwrite_after_read() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("existing.txt");
        std::fs::write(&file, "original").unwrap();

        let ctx = test_ctx_with_policy(dir.path(), WriteOverwritePolicy::RequireRead);
        ctx.file_tracker.lock().unwrap().record_read(&file);

        let tool = WriteTool;
        let result = tool
            .execute(
                "c-allow-read",
                serde_json::json!({"path": "existing.txt", "content": "updated"}),
                ctx,
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(std::fs::read_to_string(file).unwrap(), "updated");
        assert!(result.details["warning_codes"]
            .as_array()
            .unwrap()
            .is_empty());
    }

    #[tokio::test]
    async fn write_block_stale_policy_blocks_stale_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("existing.txt");
        std::fs::write(&file, "original").unwrap();

        let ctx = test_ctx_with_policy(dir.path(), WriteOverwritePolicy::BlockStale);
        ctx.file_tracker.lock().unwrap().record_read(&file);
        // Presumably gives the external write a distinguishable modification
        // time from the recorded read; confirm against FileTracker::is_stale.
        std::thread::sleep(std::time::Duration::from_millis(5));
        std::fs::write(&file, "external").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c-block-stale",
                serde_json::json!({"path": "existing.txt", "content": "updated"}),
                ctx,
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert_eq!(std::fs::read_to_string(file).unwrap(), "external");
    }

    #[tokio::test]
    async fn write_deny_policy_blocks_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("existing.txt");
        std::fs::write(&file, "original").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c-deny-overwrite",
                serde_json::json!({"path": "existing.txt", "content": "updated"}),
                test_ctx_with_policy(dir.path(), WriteOverwritePolicy::Deny),
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert_eq!(std::fs::read_to_string(file).unwrap(), "original");
    }

    #[tokio::test]
    async fn write_deny_policy_allows_creating_new_file() {
        let dir = tempfile::tempdir().unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c-deny-new",
                serde_json::json!({"path": "fresh.txt", "content": "hello"}),
                test_ctx_with_policy(dir.path(), WriteOverwritePolicy::Deny),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.details["created"], true);
        assert_eq!(
            std::fs::read_to_string(dir.path().join("fresh.txt")).unwrap(),
            "hello"
        );
    }

    #[tokio::test]
    async fn write_new_file() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c1",
                serde_json::json!({"path": "new.txt", "content": "hello world"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let details = &result.details;
        assert_eq!(details["display_content"], "hello world");
        assert!(details["summary"]
            .as_str()
            .unwrap()
            .ends_with("new.txt: 11 bytes created"));
        let written = std::fs::read_to_string(dir.path().join("new.txt")).unwrap();
        assert_eq!(written, "hello world");
    }

    #[tokio::test]
    async fn write_creates_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c2",
                serde_json::json!({"path": "a/b/c/deep.txt", "content": "deep"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("a/b/c/deep.txt")).unwrap();
        assert_eq!(written, "deep");
    }

    #[tokio::test]
    async fn write_overwrite_creates_checkpoint_snapshot() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("existing.txt");
        std::fs::write(&file, "original").unwrap();

        let tool = WriteTool;
        let ctx = test_ctx(dir.path());
        let checkpoint_state = ctx.checkpoint_state.clone();

        let result = tool
            .execute(
                "c-overwrite",
                serde_json::json!({"path": "existing.txt", "content": "updated"}),
                ctx,
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(
            checkpoint_state.original(&file).as_deref(),
            Some("original")
        );
        let checkpoints = checkpoint_state.checkpoints();
        assert_eq!(checkpoints.len(), 1);
        assert!(checkpoints[0].files.contains(&file));
    }

    #[tokio::test]
    async fn write_empty_content() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c4",
                serde_json::json!({"path": "empty.txt", "content": ""}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("empty.txt")).unwrap();
        assert_eq!(written, "");
        assert_eq!(result.details["display_content"], "");
    }

    #[tokio::test]
    async fn write_missing_path_error() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c5",
                serde_json::json!({"content": "hello"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(result.is_error);
    }

    #[tokio::test]
    async fn write_rejects_directory_path() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::create_dir_all(dir.path().join("subdir")).unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c-dir",
                serde_json::json!({"path": "subdir", "content": "x"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result.text_content().unwrap().contains("directory"));
        assert!(dir.path().join("subdir").is_dir());
    }

    #[tokio::test]
    async fn write_preserves_crlf_on_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("crlf.txt");
        // Seed a CRLF file first.
        std::fs::write(&file, "line1\r\nline2\r\n").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c6",
                serde_json::json!({"path": "crlf.txt", "content": "new1\nnew2\n"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        // Every LF in the new content must become CRLF, not just some of them.
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "new1\r\nnew2\r\n");
        assert_eq!(result.details["line_ending"], "crlf");
        // The byte count reflects what landed on disk, after conversion.
        assert_eq!(result.details["bytes_written"], 12);
    }

    #[tokio::test]
    async fn write_does_not_double_existing_crlf() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("crlf.txt");
        std::fs::write(&file, "a\r\nb\r\n").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c-crlf-in-crlf",
                serde_json::json!({"path": "crlf.txt", "content": "x\r\ny\r\n"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "x\r\ny\r\n");
    }

    #[tokio::test]
    async fn write_lf_file_normalizes_crlf_content() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("lf.txt");
        std::fs::write(&file, "a\nb\n").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c-crlf-in-lf",
                serde_json::json!({"path": "lf.txt", "content": "x\r\ny\r\n"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "x\ny\n");
        assert_eq!(result.details["line_ending"], "lf");
    }

    #[tokio::test]
    async fn write_deep_nested_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c7",
                serde_json::json!({"path": "x/y/z/w/v/deep.txt", "content": "deep content"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("x/y/z/w/v/deep.txt")).unwrap();
        assert_eq!(written, "deep content");
    }

    #[tokio::test]
    async fn write_overwrites_existing() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("exist.txt");
        std::fs::write(&file, "old content").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c3",
                serde_json::json!({"path": "exist.txt", "content": "new content"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let text = result.text_content().unwrap();
        assert!(text.contains("overwritten"));
        let written = std::fs::read_to_string(&file).unwrap();
        assert_eq!(written, "new content");
    }

    #[tokio::test]
    async fn write_includes_display_content_metadata() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c8",
                serde_json::json!({"path": "preview.rs", "content": "fn main() {\n    println!(\"hi\");\n}\n"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.details["path"]
            .as_str()
            .unwrap()
            .ends_with("preview.rs"));
        assert!(result.details["summary"]
            .as_str()
            .unwrap()
            .ends_with("preview.rs: 34 bytes created"));
        assert_eq!(
            result.details["display_content"],
            "fn main() {\n    println!(\"hi\");\n}"
        );
        assert_eq!(result.details["display_note"], "");
    }

    #[tokio::test]
    async fn write_display_content_truncates_large_content() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;
        let content = (0..100)
            .map(|i| format!("line {i}"))
            .collect::<Vec<_>>()
            .join("\n");

        let result = tool
            .execute(
                "c9",
                serde_json::json!({"path": "large.txt", "content": content}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let display_content = result.details["display_content"].as_str().unwrap();
        assert!(display_content.lines().count() <= 40);
        assert!(result.details["display_note"]
            .as_str()
            .unwrap()
            .contains("output truncated"));
    }
}