Skip to main content

astrid_tools/
write_file.rs

//! Write file tool — writes content to a file at an absolute path, creating parent directories as needed.
2
3use crate::{BuiltinTool, ToolContext, ToolError, ToolResult};
4use serde_json::Value;
5
/// Built-in tool for writing files.
///
/// This is a stateless unit struct. It derives the common standard traits so it can be
/// debug-printed, copied freely and default-constructed like other plain values.
#[derive(Debug, Clone, Copy, Default)]
pub struct WriteFileTool;
8
9#[async_trait::async_trait]
10impl BuiltinTool for WriteFileTool {
11    fn name(&self) -> &'static str {
12        "write_file"
13    }
14
15    fn description(&self) -> &'static str {
16        "Writes content to a file. Creates parent directories if they don't exist. \
17         Overwrites the file if it already exists."
18    }
19
20    fn input_schema(&self) -> Value {
21        serde_json::json!({
22            "type": "object",
23            "properties": {
24                "file_path": {
25                    "type": "string",
26                    "description": "Absolute path to the file to write"
27                },
28                "content": {
29                    "type": "string",
30                    "description": "The content to write to the file"
31                }
32            },
33            "required": ["file_path", "content"]
34        })
35    }
36
37    async fn execute(&self, args: Value, _ctx: &ToolContext) -> ToolResult {
38        let file_path = args
39            .get("file_path")
40            .and_then(Value::as_str)
41            .ok_or_else(|| ToolError::InvalidArguments("file_path is required".into()))?;
42
43        let content = args
44            .get("content")
45            .and_then(Value::as_str)
46            .ok_or_else(|| ToolError::InvalidArguments("content is required".into()))?;
47
48        let path = std::path::Path::new(file_path);
49        if !path.is_absolute() {
50            return Err(ToolError::InvalidArguments(
51                "file_path must be an absolute path".into(),
52            ));
53        }
54
55        // Create parent directories
56        if let Some(parent) = path.parent() {
57            tokio::fs::create_dir_all(parent).await?;
58        }
59
60        tokio::fs::write(path, content).await?;
61
62        let bytes = content.len();
63        Ok(format!("Wrote {bytes} bytes to {file_path}"))
64    }
65}
66
// Unit tests for `WriteFileTool`: successful writes, directory creation, overwriting,
// and every argument-validation failure path.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a context rooted at the system temp directory.
    fn ctx() -> ToolContext {
        ToolContext::new(std::env::temp_dir(), None)
    }

    /// Runs `WriteFileTool` with `args` against a fresh context.
    async fn run(args: Value) -> ToolResult {
        WriteFileTool.execute(args, &ctx()).await
    }

    #[tokio::test]
    async fn test_write_file_basic() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.txt");

        let result = run(serde_json::json!({
            "file_path": path.to_str().unwrap(),
            "content": "hello world"
        }))
        .await
        .unwrap();

        assert!(result.contains("11 bytes"));
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello world");
    }

    #[tokio::test]
    async fn test_write_file_creates_dirs() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("a").join("b").join("c").join("test.txt");

        run(serde_json::json!({
            "file_path": path.to_str().unwrap(),
            "content": "nested"
        }))
        .await
        .unwrap();

        assert_eq!(std::fs::read_to_string(&path).unwrap(), "nested");
    }

    #[tokio::test]
    async fn test_write_file_overwrites() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.txt");
        std::fs::write(&path, "old content").unwrap();

        run(serde_json::json!({
            "file_path": path.to_str().unwrap(),
            "content": "new content"
        }))
        .await
        .unwrap();

        assert_eq!(std::fs::read_to_string(&path).unwrap(), "new content");
    }

    /// Empty content is valid: the file is created (or truncated) and 0 bytes are reported.
    #[tokio::test]
    async fn test_write_file_empty_content() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("empty.txt");

        let result = run(serde_json::json!({
            "file_path": path.to_str().unwrap(),
            "content": ""
        }))
        .await
        .unwrap();

        assert!(result.contains("0 bytes"));
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "");
    }

    /// The reported size is the UTF-8 byte length, not the character count.
    #[tokio::test]
    async fn test_write_file_reports_bytes_not_chars() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("utf8.txt");

        // "héllo" is 5 chars but 6 bytes ('é' encodes as two bytes).
        let result = run(serde_json::json!({
            "file_path": path.to_str().unwrap(),
            "content": "héllo"
        }))
        .await
        .unwrap();

        assert!(result.contains("6 bytes"));
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "héllo");
    }

    #[tokio::test]
    async fn test_write_file_missing_args() {
        let result = run(serde_json::json!({"file_path": "/tmp/test.txt"})).await;
        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));

        let result = run(serde_json::json!({"content": "hello"})).await;
        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }

    /// Arguments present with the wrong JSON type are treated as missing.
    #[tokio::test]
    async fn test_write_file_non_string_args() {
        let result = run(serde_json::json!({"file_path": 42, "content": "hello"})).await;
        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));

        let result = run(serde_json::json!({"file_path": "/tmp/test.txt", "content": 42})).await;
        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }

    /// Relative paths are rejected with an argument error rather than written relative to CWD.
    #[tokio::test]
    async fn test_write_file_rejects_relative_path() {
        let result = run(serde_json::json!({
            "file_path": "relative/dir/test.txt",
            "content": "hello"
        }))
        .await;

        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }
}