Skip to main content

oxi_agent/tools/
write.rs

//! Write file tool.
//!
//! Supports:
//! - Creating parent directories if they don't exist
//! - Append mode (`append=true`)
//! - Line count reporting
//! - Diff-style output preview (first/last few lines for large files)
//! - File mutation queue for serialized writes (concurrent safety)
//! - Output truncation for very large content
9use super::file_mutation_queue::global_mutation_queue;
10use super::path_security::PathGuard;
11use super::truncate::{self, TruncationOptions};
12use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
13use async_trait::async_trait;
14use serde_json::{json, Value};
15use std::path::{Path, PathBuf};
16use tokio::fs;
17use tokio::sync::oneshot;
18
/// Number of leading lines shown in the head/tail content preview.
const PREVIEW_HEAD_LINES: usize = 5;
/// Number of trailing lines shown in the head/tail content preview.
const PREVIEW_TAIL_LINES: usize = 5;
/// Line count above which the preview switches from the full content to a
/// head/tail excerpt.
const PREVIEW_THRESHOLD_LINES: usize = 20;

// Compile-time guard: the head/tail excerpt is only produced for content with
// more than `PREVIEW_THRESHOLD_LINES` lines, and the "lines omitted" count is
// computed by subtracting the head and tail sizes. Keeping the threshold at
// least as large as head + tail guarantees that subtraction cannot underflow
// and that the head and tail excerpts never overlap.
const _: () = assert!(PREVIEW_THRESHOLD_LINES >= PREVIEW_HEAD_LINES + PREVIEW_TAIL_LINES);
24
/// Agent tool that writes (or appends) text content to a file.
///
/// When `root_dir` is set it is used as the root directory for path
/// validation; otherwise the root is taken from the `ToolContext` passed to
/// `execute`.
#[derive(Debug, Clone)]
pub struct WriteTool {
    /// Optional explicit root directory; `None` defers to the `ToolContext`.
    root_dir: Option<PathBuf>,
}
29
30impl WriteTool {
31    /// Create with no explicit root (uses ToolContext.workspace_dir at runtime).
32    pub fn new() -> Self {
33        Self { root_dir: None }
34    }
35
36    /// Create with a specific working directory (overrides ToolContext).
37    pub fn with_cwd(cwd: PathBuf) -> Self {
38        Self {
39            root_dir: Some(cwd),
40        }
41    }
42
43    /// Build a human-readable preview of the content that was written.
44    /// For small files, shows everything. For large files, shows first/last few lines.
45    fn build_content_preview(content: &str, total_lines: usize) -> String {
46        if total_lines <= PREVIEW_THRESHOLD_LINES {
47            return content.to_string();
48        }
49
50        let lines: Vec<&str> = content.lines().collect();
51        let head: Vec<&str> = lines.iter().copied().take(PREVIEW_HEAD_LINES).collect();
52        let tail: Vec<&str> = lines
53            .iter()
54            .copied()
55            .rev()
56            .take(PREVIEW_TAIL_LINES)
57            .rev()
58            .collect();
59
60        let omitted = total_lines - PREVIEW_HEAD_LINES - PREVIEW_TAIL_LINES;
61
62        format!(
63            "{}\n\n... [{} lines omitted] ...\n\n{}",
64            head.join("\n"),
65            omitted,
66            tail.join("\n")
67        )
68    }
69
70    /// Core write implementation — runs inside the mutation queue lock.
71    async fn write_file_impl(
72        root_dir: &Path,
73        path: &str,
74        content: &str,
75        append: bool,
76    ) -> Result<String, ToolError> {
77        // Security: validate path with PathGuard
78        let guard = PathGuard::new(root_dir);
79        let file_path = guard
80            .validate_traversal(Path::new(path))
81            .map_err(|e| e.to_string())?;
82
83        // Ensure parent directory exists (create if missing)
84        if let Some(parent) = file_path.parent() {
85            // Only try to create if the parent is non-empty (e.g. not just "")
86            if !parent.as_os_str().is_empty() {
87                fs::create_dir_all(parent)
88                    .await
89                    .map_err(|e| format!("Cannot create parent directory: {}", e))?;
90            }
91        }
92
93        // Check if file already existed before write (for reporting)
94        let existed = file_path.exists();
95
96        // Perform the write through the mutation queue for serialized access
97        let content_owned = content.to_string();
98        let result = global_mutation_queue()
99            .with_queue(&file_path, || async {
100                if append {
101                    let mut file = tokio::fs::OpenOptions::new()
102                        .create(true)
103                        .append(true)
104                        .open(&file_path)
105                        .await
106                        .map_err(|e| format!("Cannot open file for append: {}", e))?;
107                    use tokio::io::AsyncWriteExt;
108                    file.write_all(content_owned.as_bytes())
109                        .await
110                        .map_err(|e| format!("Cannot write file: {}", e))?;
111                    file.flush()
112                        .await
113                        .map_err(|e| format!("Cannot flush file: {}", e))?;
114                } else {
115                    fs::write(&file_path, &content_owned)
116                        .await
117                        .map_err(|e| format!("Cannot write file: {}", e))?;
118                }
119                Ok::<(), ToolError>(())
120            })
121            .await;
122
123        result?;
124
125        let total_lines = content.lines().count();
126        let total_bytes = content.len();
127        let action = if append { "Appended" } else { "Wrote" };
128        let status = if existed && !append {
129            " (overwritten)"
130        } else if append && existed {
131            " (appended)"
132        } else if !existed {
133            " (new file)"
134        } else {
135            ""
136        };
137
138        // Build result with preview
139        let preview = Self::build_content_preview(content, total_lines);
140
141        // Truncate the preview if very large
142        let truncation_opts = TruncationOptions {
143            max_lines: Some(50),
144            max_bytes: Some(4 * 1024),
145        };
146        let truncated = truncate::truncate_head(&preview, &truncation_opts);
147
148        let mut msg = format!(
149            "{} {} lines ({} bytes) to {}{}\n",
150            action, total_lines, total_bytes, path, status
151        );
152
153        msg.push_str(&format!("--- Content Preview ---\n{}", truncated.content));
154
155        if truncated.truncated {
156            msg.push_str(&format!(
157                "\n[Output truncated: {} total lines, {} total bytes]",
158                truncated.total_lines, truncated.total_bytes
159            ));
160        }
161
162        Ok(msg)
163    }
164}
165
166impl Default for WriteTool {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
172#[async_trait]
173impl AgentTool for WriteTool {
174    fn name(&self) -> &str {
175        "write"
176    }
177
178    fn label(&self) -> &str {
179        "Write File"
180    }
181
182    fn essential(&self) -> bool {
183        true
184    }
185    fn description(&self) -> &str {
186        "Write content to a file, creating parent directories as needed. Existing files will be overwritten. Use append=true to append to existing files."
187    }
188
189    fn parameters_schema(&self) -> Value {
190        json!({
191            "type": "object",
192            "properties": {
193                "path": {
194                    "type": "string",
195                    "description": "The path to the file to write"
196                },
197                "content": {
198                    "type": "string",
199                    "description": "The content to write to the file"
200                },
201                "append": {
202                    "type": "boolean",
203                    "description": "If true, append to existing file instead of overwriting",
204                    "default": false
205                }
206            },
207            "required": ["path", "content"]
208        })
209    }
210
211    async fn execute(
212        &self,
213        _tool_call_id: &str,
214        params: Value,
215        _signal: Option<oneshot::Receiver<()>>,
216        ctx: &ToolContext,
217    ) -> Result<AgentToolResult, ToolError> {
218        let path = params
219            .get("path")
220            .and_then(|v| v.as_str())
221            .ok_or_else(|| "Missing required parameter: path".to_string())?;
222
223        let content = params
224            .get("content")
225            .and_then(|v| v.as_str())
226            .ok_or_else(|| "Missing required parameter: content".to_string())?;
227
228        let append = params
229            .get("append")
230            .and_then(|v| v.as_bool())
231            .unwrap_or(false);
232
233        // Use root_dir if set, else ctx.root()
234        let root = self.root_dir.as_deref().unwrap_or(ctx.root());
235
236        match Self::write_file_impl(root, path, content, append).await {
237            Ok(msg) => Ok(AgentToolResult::success(msg)),
238            Err(e) => Ok(AgentToolResult::error(e)),
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build `count` lines of the form `line 1`, `line 2`, … joined by `\n`
    /// (no trailing newline).
    fn numbered_lines(count: usize) -> String {
        (1..=count)
            .map(|i| format!("line {}", i))
            .collect::<Vec<_>>()
            .join("\n")
    }

    #[test]
    fn test_build_content_preview_small() {
        let content = "line1\nline2\nline3";
        let preview = WriteTool::build_content_preview(content, 3);
        assert_eq!(preview, content);
    }

    #[test]
    fn test_build_content_preview_large() {
        let content = numbered_lines(30);
        let preview = WriteTool::build_content_preview(&content, 30);

        // Pin the exact layout: substring checks such as `contains("line 1")`
        // would also match "line 10" and never verify the omitted count.
        let expected = "line 1\nline 2\nline 3\nline 4\nline 5\
                        \n\n... [20 lines omitted] ...\n\n\
                        line 26\nline 27\nline 28\nline 29\nline 30";
        assert_eq!(preview, expected);
        assert!(!preview.contains("line 10")); // middle should be omitted
    }

    #[test]
    fn test_build_content_preview_exact_threshold() {
        let content = numbered_lines(20);
        let preview = WriteTool::build_content_preview(&content, 20);
        // At threshold, should show full content
        assert_eq!(preview, content);
    }

    #[test]
    fn test_build_content_preview_one_over_threshold() {
        let content = numbered_lines(21);
        let preview = WriteTool::build_content_preview(&content, 21);
        // Over threshold, should show head/tail: 21 - 5 - 5 = 11 lines omitted
        assert!(preview.contains("lines omitted"));
        assert!(preview.contains("[11 lines omitted]"));
        assert!(preview.starts_with("line 1\nline 2\n"));
        assert!(preview.ends_with("line 20\nline 21"));
    }

    #[tokio::test]
    async fn test_write_new_file() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("test.txt");
        let path_str = path.to_str().unwrap();

        let result =
            WriteTool::write_file_impl(Path::new("."), path_str, "hello world\nline 2", false)
                .await;
        assert!(result.is_ok());

        let written = std::fs::read_to_string(&path).unwrap();
        assert_eq!(written, "hello world\nline 2");

        let msg = result.unwrap();
        assert!(msg.contains("2 lines"));
        assert!(msg.contains("new file"));
    }

    #[tokio::test]
    async fn test_write_creates_parent_dirs() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("a/b/c/test.txt");
        let path_str = path.to_str().unwrap();

        let result =
            WriteTool::write_file_impl(Path::new("."), path_str, "deep nested", false).await;
        assert!(result.is_ok());

        let written = std::fs::read_to_string(&path).unwrap();
        assert_eq!(written, "deep nested");
    }

    #[tokio::test]
    async fn test_write_overwrites_existing() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("test.txt");
        let path_str = path.to_str().unwrap();

        // Create initial file
        std::fs::write(&path, "old content").unwrap();

        let result =
            WriteTool::write_file_impl(Path::new("."), path_str, "new content", false).await;
        assert!(result.is_ok());

        let written = std::fs::read_to_string(&path).unwrap();
        assert_eq!(written, "new content");

        let msg = result.unwrap();
        assert!(msg.contains("overwritten"));
    }

    #[tokio::test]
    async fn test_write_append_mode() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("test.txt");
        let path_str = path.to_str().unwrap();

        // Write initial content
        WriteTool::write_file_impl(Path::new("."), path_str, "line 1\n", false)
            .await
            .unwrap();

        // Append to it
        let result = WriteTool::write_file_impl(Path::new("."), path_str, "line 2\n", true).await;
        assert!(result.is_ok());

        let written = std::fs::read_to_string(&path).unwrap();
        assert_eq!(written, "line 1\nline 2\n");

        let msg = result.unwrap();
        assert!(msg.contains("Appended"));
    }

    #[tokio::test]
    async fn test_write_append_to_nonexistent() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("new.txt");
        let path_str = path.to_str().unwrap();

        let result =
            WriteTool::write_file_impl(Path::new("."), path_str, "appended content", true).await;
        assert!(result.is_ok());

        let written = std::fs::read_to_string(&path).unwrap();
        assert_eq!(written, "appended content");
    }

    #[tokio::test]
    async fn test_write_path_traversal_blocked() {
        let result =
            WriteTool::write_file_impl(Path::new("."), "../../etc/passwd", "hack", false).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Path traversal"));
    }

    #[tokio::test]
    async fn test_write_empty_content() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("empty.txt");
        let path_str = path.to_str().unwrap();

        let result = WriteTool::write_file_impl(Path::new("."), path_str, "", false).await;
        assert!(result.is_ok());

        let written = std::fs::read_to_string(&path).unwrap();
        assert_eq!(written, "");

        let msg = result.unwrap();
        assert!(msg.contains("0 lines"));
    }

    #[tokio::test]
    async fn test_write_large_file_has_preview() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("large.txt");
        let path_str = path.to_str().unwrap();

        let content = numbered_lines(100);

        let result = WriteTool::write_file_impl(Path::new("."), path_str, &content, false).await;
        assert!(result.is_ok());

        let msg = result.unwrap();
        assert!(msg.contains("100 lines"));
        assert!(msg.contains("Content Preview"));
        // The head/tail preview (5 + 5 lines plus the marker) is well below
        // the truncation limits, so the omission marker must be present.
        assert!(msg.contains("[90 lines omitted]"));
    }

    #[tokio::test]
    async fn test_execute_via_tool_trait() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("trait_test.txt");
        let path_str = path.to_str().unwrap().to_string();

        let tool = WriteTool::new();
        let params = json!({
            "path": path_str,
            "content": "via trait"
        });

        let result = tool
            .execute("test-id", params, None, &ToolContext::default())
            .await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.success);
        assert!(tool_result.output.contains("via trait"));

        let written = std::fs::read_to_string(&path).unwrap();
        assert_eq!(written, "via trait");
    }

    #[tokio::test]
    async fn test_execute_missing_path_param() {
        let tool = WriteTool::new();
        let params = json!({
            "content": "no path"
        });

        let result = tool
            .execute("test-id", params, None, &ToolContext::default())
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("path"));
    }

    #[tokio::test]
    async fn test_execute_missing_content_param() {
        let tool = WriteTool::new();
        let params = json!({
            "path": "/tmp/test.txt"
        });

        let result = tool
            .execute("test-id", params, None, &ToolContext::default())
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("content"));
    }

    #[tokio::test]
    async fn test_execute_append_via_trait() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("append_trait.txt");
        let path_str = path.to_str().unwrap().to_string();

        let tool = WriteTool::new();

        // First write
        let params = json!({
            "path": &path_str,
            "content": "first "
        });
        tool.execute("test-id-1", params, None, &ToolContext::default())
            .await
            .unwrap();

        // Append
        let params = json!({
            "path": &path_str,
            "content": "second",
            "append": true
        });
        let result = tool
            .execute("test-id-2", params, None, &ToolContext::default())
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Appended"));

        let written = std::fs::read_to_string(&path).unwrap();
        assert_eq!(written, "first second");
    }

    #[test]
    fn test_default_impl() {
        let tool = WriteTool::default();
        assert_eq!(tool.name(), "write");
        assert_eq!(tool.label(), "Write File");
    }

    #[test]
    fn test_parameters_schema_required_fields() {
        let tool = WriteTool::new();
        let schema = tool.parameters_schema();
        let required = schema.get("required").unwrap().as_array().unwrap();
        assert!(required.contains(&json!("path")));
        assert!(required.contains(&json!("content")));
        assert!(!required.contains(&json!("append"))); // append is optional
    }
}