// ai_agent/tools/write.rs

1use crate::types::*;
2use std::fs;
3
/// Agent tool that writes text content to a file on disk.
///
/// The tool carries no state, so it is `Copy` and `Default`; `FileWriteTool::new()`
/// and `FileWriteTool::default()` are equivalent.
#[derive(Debug, Clone, Copy, Default)]
pub struct FileWriteTool;
5
6impl FileWriteTool {
7    pub fn new() -> Self {
8        Self
9    }
10
11    pub fn name(&self) -> &str {
12        "Write"
13    }
14
15    pub fn description(&self) -> &str {
16        "Write content to files"
17    }
18
19    pub fn user_facing_name(&self, _input: Option<&serde_json::Value>) -> String {
20        "Write".to_string()
21    }
22
23    pub fn get_tool_use_summary(&self, input: Option<&serde_json::Value>) -> Option<String> {
24        input.and_then(|inp| inp["file_path"].as_str().map(String::from))
25    }
26
27    pub fn render_tool_result_message(
28        &self,
29        content: &serde_json::Value,
30    ) -> Option<String> {
31        let text = content["content"].as_str()?;
32        Some(text.to_string())
33    }
34
35    pub fn input_schema(&self) -> ToolInputSchema {
36        ToolInputSchema {
37            schema_type: "object".to_string(),
38            properties: serde_json::json!({
39                "file_path": {
40                    "type": "string",
41                    "description": "The absolute path to the file to write (must be absolute, not relative)"
42                },
43                "content": {
44                    "type": "string",
45                    "description": "The content to write to the file"
46                }
47            }),
48            required: Some(vec!["file_path".to_string(), "content".to_string()]),
49        }
50    }
51
52    pub async fn execute(
53        &self,
54        input: serde_json::Value,
55        context: &ToolContext,
56    ) -> Result<ToolResult, crate::error::AgentError> {
57        let path = input["file_path"]
58            .as_str()
59            .ok_or_else(|| crate::error::AgentError::Tool("file_path is required".to_string()))?;
60
61        let content = input["content"]
62            .as_str()
63            .ok_or_else(|| crate::error::AgentError::Tool("content is required".to_string()))?;
64
65        // Resolve relative paths using cwd from context
66        let path = if std::path::Path::new(path).is_relative() {
67            std::path::Path::new(&context.cwd).join(path)
68        } else {
69            std::path::Path::new(path).to_path_buf()
70        };
71
72        fs::write(&path, content).map_err(|e| crate::error::AgentError::Io(e))?;
73
74        Ok(ToolResult {
75            result_type: "text".to_string(),
76            tool_use_id: "".to_string(),
77            content: format!("Successfully wrote to {}", path.display()),
78            is_error: None,
79            was_persisted: None,
80        })
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Temp-file path unique to this process and test name, removed on drop.
    ///
    /// Unique names keep concurrent test runs from clobbering each other's files.
    /// Dropping guarantees cleanup even when an assertion panics.
    struct TempPath(PathBuf);

    impl TempPath {
        fn new(name: &str) -> Self {
            Self(std::env::temp_dir().join(format!(
                "ai_agent_write_{}_{}",
                std::process::id(),
                name
            )))
        }

        fn path_str(&self) -> &str {
            self.0.to_str().expect("temp dir path should be valid UTF-8")
        }
    }

    impl Drop for TempPath {
        fn drop(&mut self) {
            // Best-effort cleanup: the file may legitimately not exist.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_file_write_tool_name() {
        let tool = FileWriteTool::new();
        assert_eq!(tool.name(), "Write");
    }

    #[test]
    fn test_file_write_tool_description_contains_write() {
        let tool = FileWriteTool::new();
        assert!(tool.description().to_lowercase().contains("write"));
    }

    #[test]
    fn test_file_write_tool_has_path_in_schema() {
        let tool = FileWriteTool::new();
        let schema = tool.input_schema();
        assert!(schema.properties.get("file_path").is_some());
    }

    #[test]
    fn test_file_write_tool_has_content_in_schema() {
        let tool = FileWriteTool::new();
        let schema = tool.input_schema();
        assert!(schema.properties.get("content").is_some());
    }

    #[test]
    fn test_file_write_tool_use_summary_is_file_path() {
        let tool = FileWriteTool::new();
        let input = serde_json::json!({ "file_path": "/tmp/example.txt", "content": "x" });
        assert_eq!(
            tool.get_tool_use_summary(Some(&input)),
            Some("/tmp/example.txt".to_string())
        );
        assert_eq!(tool.get_tool_use_summary(None), None);
    }

    #[test]
    fn test_file_write_tool_render_result_message() {
        let tool = FileWriteTool::new();
        let with_content = serde_json::json!({ "content": "done" });
        assert_eq!(
            tool.render_tool_result_message(&with_content),
            Some("done".to_string())
        );
        assert_eq!(tool.render_tool_result_message(&serde_json::json!({})), None);
    }

    #[tokio::test]
    async fn test_file_write_tool_creates_file() {
        let temp_file = TempPath::new("creates_file.txt");

        let tool = FileWriteTool::new();
        let input = serde_json::json!({
            "file_path": temp_file.path_str(),
            "content": "Test content"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        let result = match result {
            Ok(result) => result,
            Err(_) => panic!("execute returned an error"),
        };
        assert!(result.content.starts_with("Successfully wrote to "));

        // Verify file was created with correct content
        let read_content = std::fs::read_to_string(&temp_file.0).unwrap();
        assert_eq!(read_content, "Test content");
    }

    #[tokio::test]
    async fn test_file_write_tool_overwrites_existing_file() {
        let temp_file = TempPath::new("overwrite_file.txt");
        std::fs::write(&temp_file.0, "Original content").unwrap();

        let tool = FileWriteTool::new();
        let input = serde_json::json!({
            "file_path": temp_file.path_str(),
            "content": "New content"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok(), "execute returned an error");

        // Verify file was overwritten
        let read_content = std::fs::read_to_string(&temp_file.0).unwrap();
        assert_eq!(read_content, "New content");
    }

    #[tokio::test]
    async fn test_file_write_tool_missing_file_path_is_tool_error() {
        let tool = FileWriteTool::new();
        let input = serde_json::json!({ "content": "orphan" });
        let result = tool.execute(input, &ToolContext::default()).await;
        assert!(matches!(result, Err(crate::error::AgentError::Tool(_))));
    }

    #[tokio::test]
    async fn test_file_write_tool_missing_content_is_tool_error() {
        let temp_file = TempPath::new("missing_content.txt");
        let tool = FileWriteTool::new();
        let input = serde_json::json!({ "file_path": temp_file.path_str() });
        let result = tool.execute(input, &ToolContext::default()).await;
        assert!(matches!(result, Err(crate::error::AgentError::Tool(_))));
        assert!(!temp_file.0.exists(), "no file should be written on error");
    }
}