Skip to main content

enact_core/tool/
filesystem.rs

1//! File system tools for reading and writing files
2
3use crate::tool::Tool;
4use async_trait::async_trait;
5use serde_json::json;
6use std::path::PathBuf;
7
/// Largest file, in bytes, that the `file_read` tool will load (10 MiB).
const MAX_FILE_SIZE_BYTES: u64 = 10 << 20;
9
/// Expands a leading `~` (on its own, or followed by `/` or `\`) to the current
/// user's home directory.
///
/// The home directory is read from `HOME`, falling back to `USERPROFILE` (the
/// Windows convention) when `HOME` is unset or empty. This lookup order is the
/// same on every platform. Surrounding whitespace in `path` is trimmed first.
///
/// Only the first `~/` (or `~\`) is consumed, so `~/~/x` maps to `<home>/~/x`.
/// Extra separators right after the tilde (e.g. `~//etc`) are dropped so the
/// result always stays under the home directory. Paths that do not start with
/// `~`, `~user`-style paths, and every path when no home directory can be
/// determined are returned unchanged (apart from trimming).
fn expand_tilde(path: &str) -> PathBuf {
    let path = path.trim();

    // Work out what follows the tilde; anything that is not `~`, `~/...` or
    // `~\...` is returned untouched without consulting the environment.
    let rest = if path == "~" {
        ""
    } else if let Some(rest) = path.strip_prefix("~/").or_else(|| path.strip_prefix("~\\")) {
        rest
    } else {
        return PathBuf::from(path);
    };

    // An empty variable is treated as unset; joining onto "" would silently
    // produce a relative path instead of a home-based one.
    let home = std::env::var("HOME")
        .ok()
        .filter(|h| !h.is_empty())
        .or_else(|| std::env::var("USERPROFILE").ok().filter(|h| !h.is_empty()));

    let home = match home {
        Some(home) => PathBuf::from(home),
        None => return PathBuf::from(path),
    };

    // `PathBuf::join` with a rooted path replaces the base entirely, so strip
    // any extra leading separators to keep the result inside `home`.
    let rest = rest.trim_start_matches(|c: char| c == '/' || c == '\\');
    if rest.is_empty() {
        home
    } else {
        home.join(rest)
    }
}
31
/// Stateless tool that reads file contents from the workspace.
///
/// Construct it with [`FileReadTool::new`] or `FileReadTool::default()`; both
/// produce the same zero-sized value.
#[derive(Default)]
pub struct FileReadTool;

impl FileReadTool {
    /// Creates a new `FileReadTool`.
    pub fn new() -> Self {
        FileReadTool
    }
}
46
47#[async_trait]
48impl Tool for FileReadTool {
49    fn name(&self) -> &str {
50        "file_read"
51    }
52
53    fn description(&self) -> &str {
54        "Read the contents of a file in the workspace"
55    }
56
57    fn parameters_schema(&self) -> serde_json::Value {
58        json!({
59            "type": "object",
60            "properties": {
61                "path": {
62                    "type": "string",
63                    "description": "Path to the file. Supports ~ for home directory (e.g. ~/.enact/config.yaml). Relative or absolute paths allowed."
64                }
65            },
66            "required": ["path"]
67        })
68    }
69
70    fn requires_network(&self) -> bool {
71        false
72    }
73
74    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
75        let path_str = args
76            .get("path")
77            .and_then(|v| v.as_str())
78            .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?;
79
80        let path = expand_tilde(path_str);
81
82        // Security: Prevent directory traversal
83        if path
84            .components()
85            .any(|c| matches!(c, std::path::Component::ParentDir))
86        {
87            anyhow::bail!("Path cannot contain '..' (directory traversal not allowed)");
88        }
89
90        // Check file exists
91        if !path.exists() {
92            anyhow::bail!("File not found: {}", path.display());
93        }
94
95        // Check it's a file
96        if !path.is_file() {
97            anyhow::bail!("Path is not a file: {}", path.display());
98        }
99
100        // Check file size
101        let metadata = tokio::fs::metadata(&path).await?;
102        if metadata.len() > MAX_FILE_SIZE_BYTES {
103            anyhow::bail!(
104                "File too large: {} bytes (max: {} bytes)",
105                metadata.len(),
106                MAX_FILE_SIZE_BYTES
107            );
108        }
109
110        let content = tokio::fs::read_to_string(&path).await?;
111
112        Ok(json!({
113            "success": true,
114            "content": content,
115            "path": path.to_string_lossy().to_string(),
116            "size": metadata.len()
117        }))
118    }
119}
120
/// Stateless tool that writes content to a file in the workspace.
///
/// Construct it with [`FileWriteTool::new`] or `FileWriteTool::default()`;
/// both produce the same zero-sized value.
#[derive(Default)]
pub struct FileWriteTool;

impl FileWriteTool {
    /// Creates a new `FileWriteTool`.
    pub fn new() -> Self {
        FileWriteTool
    }
}
135
136#[async_trait]
137impl Tool for FileWriteTool {
138    fn name(&self) -> &str {
139        "file_write"
140    }
141
142    fn description(&self) -> &str {
143        "Write content to a file in the workspace (creates or overwrites)"
144    }
145
146    fn parameters_schema(&self) -> serde_json::Value {
147        json!({
148            "type": "object",
149            "properties": {
150                "path": {
151                    "type": "string",
152                    "description": "Relative path to the file within the workspace"
153                },
154                "content": {
155                    "type": "string",
156                    "description": "Content to write to the file"
157                }
158            },
159            "required": ["path", "content"]
160        })
161    }
162
163    fn requires_network(&self) -> bool {
164        false
165    }
166
167    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
168        let path_str = args
169            .get("path")
170            .and_then(|v| v.as_str())
171            .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?;
172
173        let content = args
174            .get("content")
175            .and_then(|v| v.as_str())
176            .ok_or_else(|| anyhow::anyhow!("Missing 'content' parameter"))?;
177
178        let path = expand_tilde(path_str);
179
180        // Security: Prevent directory traversal
181        if path
182            .components()
183            .any(|c| matches!(c, std::path::Component::ParentDir))
184        {
185            anyhow::bail!("Path cannot contain '..' (directory traversal not allowed)");
186        }
187
188        // Create parent directories if needed
189        if let Some(parent) = path.parent() {
190            tokio::fs::create_dir_all(parent).await?;
191        }
192
193        tokio::fs::write(&path, content).await?;
194
195        let metadata = tokio::fs::metadata(&path).await?;
196
197        Ok(json!({
198            "success": true,
199            "path": path.to_string_lossy().to_string(),
200            "size": metadata.len(),
201            "message": "File written successfully"
202        }))
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a path in the OS temp directory that is unique to this process
    /// and `name`. This keeps the tests portable (no Unix-only `/tmp`) and stops
    /// concurrent test runs from clobbering each other's files.
    fn temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("enact_fs_test_{}_{}", std::process::id(), name))
    }

    /// Renders `path` as the string value the tools expect in `args["path"]`.
    fn arg(path: &std::path::Path) -> String {
        path.to_string_lossy().into_owned()
    }

    #[tokio::test]
    async fn test_file_read_success() {
        let path = temp_path("read.txt");
        let test_content = "Hello, World!";
        tokio::fs::write(&path, test_content).await.unwrap();

        let result = FileReadTool::new()
            .execute(json!({ "path": arg(&path) }))
            .await;
        // Clean up before asserting so a failing assertion doesn't leak the file.
        tokio::fs::remove_file(&path).await.ok();

        let result = result.unwrap();
        assert_eq!(result["success"], true);
        assert_eq!(result["content"], test_content);
        assert_eq!(result["size"], test_content.len());
    }

    #[tokio::test]
    async fn test_file_write_success() {
        let path = temp_path("write.txt");

        let result = FileWriteTool::new()
            .execute(json!({
                "path": arg(&path),
                "content": "Test content"
            }))
            .await;
        let written = tokio::fs::read_to_string(&path).await;
        tokio::fs::remove_file(&path).await.ok();

        let result = result.unwrap();
        assert_eq!(result["success"], true);
        assert_eq!(result["size"], "Test content".len());
        assert_eq!(written.unwrap(), "Test content");
    }

    #[tokio::test]
    async fn test_file_write_creates_parent_dirs() {
        let dir = temp_path("nested");
        let path = dir.join("a").join("b.txt");

        let result = FileWriteTool::new()
            .execute(json!({ "path": arg(&path), "content": "nested" }))
            .await;
        let written = tokio::fs::read_to_string(&path).await;
        tokio::fs::remove_dir_all(&dir).await.ok();

        assert_eq!(result.unwrap()["success"], true);
        assert_eq!(written.unwrap(), "nested");
    }

    #[tokio::test]
    async fn test_file_read_not_found() {
        let path = temp_path("does_not_exist.txt");
        tokio::fs::remove_file(&path).await.ok();

        let result = FileReadTool::new()
            .execute(json!({ "path": arg(&path) }))
            .await;
        assert!(result.unwrap_err().to_string().contains("File not found"));
    }

    #[tokio::test]
    async fn test_file_read_rejects_directory() {
        let dir = std::env::temp_dir();
        let result = FileReadTool::new()
            .execute(json!({ "path": arg(&dir) }))
            .await;
        assert!(result.unwrap_err().to_string().contains("not a file"));
    }

    #[tokio::test]
    async fn test_file_read_traversal_prevention() {
        let tool = FileReadTool::new();
        let result = tool.execute(json!({"path": "../etc/passwd"})).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("directory traversal"));
    }

    #[tokio::test]
    async fn test_file_write_traversal_prevention() {
        let tool = FileWriteTool::new();
        let result = tool
            .execute(json!({
                "path": "../enact_traversal_test.txt",
                "content": "should not be written"
            }))
            .await;
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("directory traversal"));
    }

    #[tokio::test]
    async fn test_missing_parameters() {
        let read_err = FileReadTool::new().execute(json!({})).await.unwrap_err();
        assert!(read_err.to_string().contains("Missing 'path' parameter"));

        let write_err = FileWriteTool::new()
            .execute(json!({ "path": "unused.txt" }))
            .await
            .unwrap_err();
        assert!(write_err.to_string().contains("Missing 'content' parameter"));
    }

    #[tokio::test]
    async fn test_file_read_expands_tilde() {
        let home = std::env::var("HOME")
            .or_else(|_| std::env::var("USERPROFILE"))
            .expect("HOME or USERPROFILE");
        // Per-process file name so parallel runs don't race on the same file.
        let file_name = format!(".enact_file_read_tilde_test_{}", std::process::id());
        let test_file = std::path::PathBuf::from(&home).join(&file_name);
        let test_content = "tilde expansion works";

        tokio::fs::write(&test_file, test_content).await.unwrap();

        let result = FileReadTool::new()
            .execute(json!({ "path": format!("~/{}", file_name) }))
            .await;
        tokio::fs::remove_file(&test_file).await.ok();

        let result = result.unwrap();
        assert_eq!(result["success"], true);
        assert_eq!(result["content"], test_content);
    }
}