Skip to main content

imp_core/tools/
write.rs

1use async_trait::async_trait;
2use serde_json::json;
3
4use super::{truncate_head, Tool, ToolContext, ToolOutput};
5use crate::error::Result;
6
7pub struct WriteTool;
8
9#[async_trait]
10impl Tool for WriteTool {
11    fn name(&self) -> &str {
12        "write"
13    }
14    fn label(&self) -> &str {
15        "Write File"
16    }
17    fn description(&self) -> &str {
18        "Create or overwrite a file. Creates parent dirs automatically."
19    }
20    fn parameters(&self) -> serde_json::Value {
21        json!({
22            "type": "object",
23            "properties": {
24                "path": { "type": "string" },
25                "content": { "type": "string" }
26            },
27            "required": ["path", "content"]
28        })
29    }
30    fn is_readonly(&self) -> bool {
31        false
32    }
33
34    async fn execute(
35        &self,
36        _call_id: &str,
37        params: serde_json::Value,
38        ctx: ToolContext,
39    ) -> Result<ToolOutput> {
40        let raw_path = params["path"].as_str().unwrap_or("");
41        let content = params["content"].as_str().unwrap_or("");
42
43        if raw_path.is_empty() {
44            return Ok(ToolOutput::error("Missing required parameter: path"));
45        }
46
47        let path = super::resolve_path(&ctx.cwd, raw_path);
48
49        let existed = path.exists();
50
51        // Check for unread or stale file — warn but don't block (only relevant for overwrites).
52        let tracker_warning = if existed {
53            let tracker = ctx.file_tracker.lock().ok();
54            match tracker {
55                Some(t) if !t.was_read(&path) => Some(format!(
56                    "Warning: editing {} without reading it first. Consider reading to verify current content.",
57                    path.display()
58                )),
59                Some(t) if t.is_stale(&path) => Some(format!(
60                    "Warning: {} was modified externally since last read. Re-read to verify current content.",
61                    path.display()
62                )),
63                _ => None,
64            }
65        } else {
66            None
67        };
68
69        // Create parent directories
70        if let Some(parent) = path.parent() {
71            tokio::fs::create_dir_all(parent).await?;
72        }
73
74        if existed {
75            ctx.checkpoint_state.snapshot_paths(
76                std::slice::from_ref(&path),
77                Some(format!("write {}", path.display())),
78            )?;
79        }
80
81        // Detect existing line endings to preserve them, default to LF for new files
82        let normalized = if existed {
83            if let Ok(existing) = tokio::fs::read(&path).await {
84                let has_crlf = existing.windows(2).any(|w| w == b"\r\n");
85                if has_crlf {
86                    // Preserve CRLF: ensure content uses CRLF
87                    let lf_content = content.replace("\r\n", "\n");
88                    lf_content.replace('\n', "\r\n")
89                } else {
90                    // LF or no newlines — ensure LF
91                    content.replace("\r\n", "\n")
92                }
93            } else {
94                content.replace("\r\n", "\n")
95            }
96        } else {
97            content.replace("\r\n", "\n")
98        };
99
100        let bytes_written = normalized.len();
101        tokio::fs::write(&path, &normalized).await?;
102
103        let action = if existed { "overwritten" } else { "created" };
104        let display = path.display().to_string();
105        let summary = format!("{display}: {bytes_written} bytes {action}");
106
107        const DISPLAY_MAX_LINES: usize = 40;
108        const DISPLAY_MAX_BYTES: usize = 8_000;
109        let display_source = normalized.replace("\r\n", "\n");
110        let display_result = truncate_head(&display_source, DISPLAY_MAX_LINES, DISPLAY_MAX_BYTES);
111        let display_content = display_result.content.trim_end_matches('\n').to_string();
112        let display_note = if display_result.truncated {
113            let note = format!(
114                "[output truncated: showing {}/{} lines, {}/{} bytes]",
115                display_result.output_lines,
116                display_result.total_lines,
117                display_result.output_bytes,
118                display_result.total_bytes,
119            );
120            if let Some(ref tf) = display_result.temp_file {
121                format!("{note} full output: {}", tf.display())
122            } else {
123                note
124            }
125        } else {
126            String::new()
127        };
128
129        let mut warnings = Vec::new();
130        if let Some(warning) = tracker_warning {
131            warnings.push(warning);
132        }
133
134        let mut text = summary.clone();
135        for warning in &warnings {
136            text.push('\n');
137            text.push_str(warning);
138        }
139
140        Ok(ToolOutput {
141            content: vec![imp_llm::ContentBlock::Text { text }],
142            details: json!({
143                "path": display,
144                "bytes": bytes_written,
145                "created": !existed,
146                "summary": summary,
147                "warnings": warnings,
148                "display_content": display_content,
149                "display_note": display_note,
150            }),
151            is_error: false,
152        })
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolContext;
    use std::path::Path;
    use std::sync::Arc;

    /// Builds a `ToolContext` rooted at `dir` with fresh, empty shared state.
    fn test_ctx(dir: &Path) -> ToolContext {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        ToolContext {
            cwd: dir.to_path_buf(),
            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: Arc::new(crate::ui::NullInterface),
            file_cache: Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(crate::tools::FileTracker::new())),
            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            config: Arc::new(crate::config::Config::default()),
        }
    }

    #[tokio::test]
    async fn write_new_file() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c1",
                serde_json::json!({"path": "new.txt", "content": "hello world"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let details = &result.details;
        assert_eq!(details["display_content"], "hello world");
        assert!(details["summary"]
            .as_str()
            .unwrap()
            .ends_with("new.txt: 11 bytes created"));
        let written = std::fs::read_to_string(dir.path().join("new.txt")).unwrap();
        assert_eq!(written, "hello world");
    }

    #[tokio::test]
    async fn write_creates_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c2",
                serde_json::json!({"path": "a/b/c/deep.txt", "content": "deep"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("a/b/c/deep.txt")).unwrap();
        assert_eq!(written, "deep");
    }

    #[tokio::test]
    async fn write_overwrite_creates_checkpoint_snapshot() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("existing.txt");
        std::fs::write(&file, "original").unwrap();

        let tool = WriteTool;
        let ctx = test_ctx(dir.path());
        let checkpoint_state = ctx.checkpoint_state.clone();

        let result = tool
            .execute(
                "c-overwrite",
                serde_json::json!({"path": "existing.txt", "content": "updated"}),
                ctx,
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(
            checkpoint_state.original(&file).as_deref(),
            Some("original")
        );
        let checkpoints = checkpoint_state.checkpoints();
        assert_eq!(checkpoints.len(), 1);
        assert!(checkpoints[0].files.contains(&file));
    }

    #[tokio::test]
    async fn write_empty_content() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c4",
                serde_json::json!({"path": "empty.txt", "content": ""}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("empty.txt")).unwrap();
        assert_eq!(written, "");
        assert_eq!(result.details["display_content"], "");
    }

    #[tokio::test]
    async fn write_missing_path_error() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c5",
                serde_json::json!({"content": "hello"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(result.is_error);
    }

    #[tokio::test]
    async fn write_preserves_crlf_on_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("crlf.txt");
        // Write a CRLF file first
        std::fs::write(&file, "line1\r\nline2\r\n").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c6",
                serde_json::json!({"path": "crlf.txt", "content": "new1\nnew2\n"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let raw = std::fs::read(dir.path().join("crlf.txt")).unwrap();
        // Every LF must be converted to CRLF since the original used CRLF,
        // not merely some of them.
        assert_eq!(raw, b"new1\r\nnew2\r\n".to_vec());
    }

    #[tokio::test]
    async fn write_keeps_lf_when_overwriting_lf_file() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("lf.txt");
        std::fs::write(&file, "old1\nold2\n").unwrap();

        let result = WriteTool
            .execute(
                "c10",
                serde_json::json!({"path": "lf.txt", "content": "new1\r\nnew2\r\n"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        // The existing file used LF, so incoming CRLF is normalized to LF.
        assert_eq!(std::fs::read(&file).unwrap(), b"new1\nnew2\n".to_vec());
    }

    #[tokio::test]
    async fn write_new_file_normalizes_crlf_to_lf() {
        let dir = tempfile::tempdir().unwrap();

        let result = WriteTool
            .execute(
                "c11",
                serde_json::json!({"path": "new_crlf.txt", "content": "a\r\nb\r\n"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let raw = std::fs::read(dir.path().join("new_crlf.txt")).unwrap();
        assert_eq!(raw, b"a\nb\n".to_vec());
        // Reported byte count reflects the normalized (LF) content.
        assert!(result.details["summary"]
            .as_str()
            .unwrap()
            .ends_with("new_crlf.txt: 4 bytes created"));
    }

    #[tokio::test]
    async fn write_new_file_has_no_warnings() {
        let dir = tempfile::tempdir().unwrap();

        let result = WriteTool
            .execute(
                "c12",
                serde_json::json!({"path": "fresh.txt", "content": "x"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.details["warnings"], serde_json::json!([]));
    }

    #[tokio::test]
    async fn write_overwrite_unread_file_warns() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("unread.txt");
        std::fs::write(&file, "before").unwrap();

        let result = WriteTool
            .execute(
                "c13",
                serde_json::json!({"path": "unread.txt", "content": "after"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        // The warning is advisory: the write still succeeds.
        assert!(!result.is_error);
        let warnings = result.details["warnings"].as_array().unwrap();
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0]
            .as_str()
            .unwrap()
            .contains("without reading it first"));
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "after");
    }

    #[tokio::test]
    async fn write_deep_nested_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c7",
                serde_json::json!({"path": "x/y/z/w/v/deep.txt", "content": "deep content"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let written = std::fs::read_to_string(dir.path().join("x/y/z/w/v/deep.txt")).unwrap();
        assert_eq!(written, "deep content");
    }

    #[tokio::test]
    async fn write_overwrites_existing() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("exist.txt");
        std::fs::write(&file, "old content").unwrap();

        let tool = WriteTool;
        let result = tool
            .execute(
                "c3",
                serde_json::json!({"path": "exist.txt", "content": "new content"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let text = result
            .content
            .iter()
            .find_map(|b| match b {
                imp_llm::ContentBlock::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .unwrap();
        assert!(text.contains("overwritten"));
        let written = std::fs::read_to_string(&file).unwrap();
        assert_eq!(written, "new content");
    }

    #[tokio::test]
    async fn write_includes_display_content_metadata() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;

        let result = tool
            .execute(
                "c8",
                serde_json::json!({"path": "preview.rs", "content": "fn main() {\n    println!(\"hi\");\n}\n"}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.details["path"]
            .as_str()
            .unwrap()
            .ends_with("preview.rs"));
        assert!(result.details["summary"]
            .as_str()
            .unwrap()
            .ends_with("preview.rs: 34 bytes created"));
        assert_eq!(
            result.details["display_content"],
            "fn main() {\n    println!(\"hi\");\n}"
        );
        assert_eq!(result.details["display_note"], "");
    }

    #[tokio::test]
    async fn write_display_content_truncates_large_content() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool;
        let content = (0..100)
            .map(|i| format!("line {i}"))
            .collect::<Vec<_>>()
            .join("\n");

        let result = tool
            .execute(
                "c9",
                serde_json::json!({"path": "large.txt", "content": content}),
                test_ctx(dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let display_content = result.details["display_content"].as_str().unwrap();
        assert!(display_content.lines().count() <= 40);
        assert!(result.details["display_note"]
            .as_str()
            .unwrap()
            .contains("output truncated"));
    }
}