Skip to main content

ai_agent/tools/
write.rs

1use crate::types::*;
2use std::fs;
3
/// Agent tool that writes text content to a file on disk, creating the file
/// or overwriting any existing contents.
///
/// The tool is stateless, so it is cheap to construct (`FileWriteTool::new()`
/// or `FileWriteTool::default()`) and to copy.
#[derive(Debug, Default, Clone, Copy)]
pub struct FileWriteTool;
5
6impl FileWriteTool {
7    pub fn new() -> Self {
8        Self
9    }
10
11    pub fn name(&self) -> &str {
12        "FileWrite"
13    }
14
15    pub fn description(&self) -> &str {
16        "Write content to files"
17    }
18
19    pub fn input_schema(&self) -> ToolInputSchema {
20        ToolInputSchema {
21            schema_type: "object".to_string(),
22            properties: serde_json::json!({
23                "path": {
24                    "type": "string",
25                    "description": "The file path to write to"
26                },
27                "content": {
28                    "type": "string",
29                    "description": "The content to write"
30                }
31            }),
32            required: Some(vec!["path".to_string(), "content".to_string()]),
33        }
34    }
35
36    pub async fn execute(
37        &self,
38        input: serde_json::Value,
39        context: &ToolContext,
40    ) -> Result<ToolResult, crate::error::AgentError> {
41        let path = input["path"]
42            .as_str()
43            .ok_or_else(|| crate::error::AgentError::Tool("path is required".to_string()))?;
44
45        let content = input["content"]
46            .as_str()
47            .ok_or_else(|| crate::error::AgentError::Tool("content is required".to_string()))?;
48
49        // Resolve relative paths using cwd from context
50        let path = if std::path::Path::new(path).is_relative() {
51            std::path::Path::new(&context.cwd).join(path)
52        } else {
53            std::path::Path::new(path).to_path_buf()
54        };
55
56        fs::write(&path, content).map_err(|e| crate::error::AgentError::Io(e))?;
57
58        Ok(ToolResult {
59            result_type: "text".to_string(),
60            tool_use_id: "".to_string(),
61            content: format!("Successfully wrote to {}", path.display()),
62            is_error: None,
63        })
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};

    /// A file path in the system temp directory that is removed on drop, so a
    /// failing assertion does not leave stale files behind.
    struct TempFile(PathBuf);

    impl TempFile {
        /// Builds a path whose name includes the process id, so concurrent test
        /// runs sharing one temp directory don't clobber each other's files.
        fn new(stem: &str) -> Self {
            let name = format!("{}_{}.txt", stem, std::process::id());
            let temp = TempFile(std::env::temp_dir().join(name));
            // Start clean in case an earlier aborted run reused this pid.
            let _ = std::fs::remove_file(&temp.0);
            temp
        }

        fn path(&self) -> &Path {
            &self.0
        }

        fn path_str(&self) -> &str {
            self.0.to_str().expect("temp path should be valid UTF-8")
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; the file may not exist if the test failed early.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_file_write_tool_name() {
        let tool = FileWriteTool::new();
        assert_eq!(tool.name(), "FileWrite");
    }

    #[test]
    fn test_file_write_tool_description_contains_write() {
        let tool = FileWriteTool::new();
        assert!(tool.description().to_lowercase().contains("write"));
    }

    #[test]
    fn test_file_write_tool_has_path_in_schema() {
        let tool = FileWriteTool::new();
        let schema = tool.input_schema();
        assert!(schema.properties.get("path").is_some());
    }

    #[test]
    fn test_file_write_tool_has_content_in_schema() {
        let tool = FileWriteTool::new();
        let schema = tool.input_schema();
        assert!(schema.properties.get("content").is_some());
    }

    #[tokio::test]
    async fn test_file_write_tool_creates_file() {
        let temp_file = TempFile::new("test_write_file");

        let tool = FileWriteTool::new();
        let input = serde_json::json!({
            "path": temp_file.path_str(),
            "content": "Test content"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());

        // Verify file was created with correct content.
        let read_content = std::fs::read_to_string(temp_file.path()).unwrap();
        assert_eq!(read_content, "Test content");
    }

    #[tokio::test]
    async fn test_file_write_tool_overwrites_existing_file() {
        let temp_file = TempFile::new("test_overwrite_file");
        std::fs::write(temp_file.path(), "Original content").unwrap();

        let tool = FileWriteTool::new();
        let input = serde_json::json!({
            "path": temp_file.path_str(),
            "content": "New content"
        });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(result.is_ok());

        // Verify file was overwritten.
        let read_content = std::fs::read_to_string(temp_file.path()).unwrap();
        assert_eq!(read_content, "New content");
    }

    #[tokio::test]
    async fn test_file_write_tool_requires_path() {
        let tool = FileWriteTool::new();
        let input = serde_json::json!({ "content": "Test content" });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(matches!(result, Err(crate::error::AgentError::Tool(_))));
    }

    #[tokio::test]
    async fn test_file_write_tool_requires_content() {
        let temp_file = TempFile::new("test_requires_content");

        let tool = FileWriteTool::new();
        let input = serde_json::json!({ "path": temp_file.path_str() });
        let context = ToolContext::default();

        let result = tool.execute(input, &context).await;
        assert!(matches!(result, Err(crate::error::AgentError::Tool(_))));
        // Validation fails before any I/O, so nothing should have been written.
        assert!(!temp_file.path().exists());
    }
}