// hehe_tools/builtin/filesystem.rs

1use crate::error::{Result, ToolError};
2use crate::traits::{Tool, ToolOutput};
3use async_trait::async_trait;
4use hehe_core::{Context, ToolDefinition, ToolParameter};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::path::Path;
8use tokio::fs;
9
10pub struct ReadFileTool {
11    def: ToolDefinition,
12}
13
14impl ReadFileTool {
15    pub fn new() -> Self {
16        let def = ToolDefinition::new("read_file", "Read the contents of a file")
17            .with_required_param(
18                "path",
19                ToolParameter::string().with_description("Path to the file to read"),
20            )
21            .with_param(
22                "encoding",
23                ToolParameter::string()
24                    .with_description("File encoding (default: utf-8)")
25                    .with_default(Value::String("utf-8".into())),
26            );
27        Self { def }
28    }
29}
30
31impl Default for ReadFileTool {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37#[derive(Deserialize)]
38struct ReadFileInput {
39    path: String,
40    #[serde(default = "default_encoding")]
41    encoding: String,
42}
43
/// Serde default for `ReadFileInput::encoding`: the string `"utf-8"`.
fn default_encoding() -> String {
    String::from("utf-8")
}
47
48#[async_trait]
49impl Tool for ReadFileTool {
50    fn definition(&self) -> &ToolDefinition {
51        &self.def
52    }
53
54    async fn execute(&self, _ctx: &Context, input: Value) -> Result<ToolOutput> {
55        let input: ReadFileInput = serde_json::from_value(input)?;
56        
57        let path = Path::new(&input.path);
58        if !path.exists() {
59            return Ok(ToolOutput::error(format!("File not found: {}", input.path)));
60        }
61
62        match fs::read_to_string(path).await {
63            Ok(content) => {
64                let size = content.len();
65                Ok(ToolOutput::text(content)
66                    .with_metadata("path", &input.path)
67                    .with_metadata("size", size))
68            }
69            Err(e) => Ok(ToolOutput::error(format!("Failed to read file: {}", e))),
70        }
71    }
72}
73
74pub struct WriteFileTool {
75    def: ToolDefinition,
76}
77
78impl WriteFileTool {
79    pub fn new() -> Self {
80        let def = ToolDefinition::new("write_file", "Write content to a file")
81            .with_required_param(
82                "path",
83                ToolParameter::string().with_description("Path to the file to write"),
84            )
85            .with_required_param(
86                "content",
87                ToolParameter::string().with_description("Content to write"),
88            )
89            .with_param(
90                "append",
91                ToolParameter::boolean()
92                    .with_description("Append to file instead of overwriting")
93                    .with_default(Value::Bool(false)),
94            )
95            .dangerous();
96        Self { def }
97    }
98}
99
100impl Default for WriteFileTool {
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
106#[derive(Deserialize)]
107struct WriteFileInput {
108    path: String,
109    content: String,
110    #[serde(default)]
111    append: bool,
112}
113
114#[async_trait]
115impl Tool for WriteFileTool {
116    fn definition(&self) -> &ToolDefinition {
117        &self.def
118    }
119
120    async fn execute(&self, _ctx: &Context, input: Value) -> Result<ToolOutput> {
121        let input: WriteFileInput = serde_json::from_value(input)?;
122        
123        let path = Path::new(&input.path);
124        
125        if let Some(parent) = path.parent() {
126            if !parent.exists() {
127                if let Err(e) = fs::create_dir_all(parent).await {
128                    return Ok(ToolOutput::error(format!("Failed to create directory: {}", e)));
129                }
130            }
131        }
132
133        let result = if input.append {
134            let existing = fs::read_to_string(path).await.unwrap_or_default();
135            fs::write(path, format!("{}{}", existing, input.content)).await
136        } else {
137            fs::write(path, &input.content).await
138        };
139
140        match result {
141            Ok(_) => Ok(ToolOutput::text(format!("Successfully wrote to {}", input.path))
142                .with_metadata("path", &input.path)
143                .with_metadata("bytes_written", input.content.len())),
144            Err(e) => Ok(ToolOutput::error(format!("Failed to write file: {}", e))),
145        }
146    }
147}
148
149pub struct ListDirectoryTool {
150    def: ToolDefinition,
151}
152
153impl ListDirectoryTool {
154    pub fn new() -> Self {
155        let def = ToolDefinition::new("list_directory", "List contents of a directory")
156            .with_required_param(
157                "path",
158                ToolParameter::string().with_description("Path to the directory"),
159            )
160            .with_param(
161                "recursive",
162                ToolParameter::boolean()
163                    .with_description("List recursively")
164                    .with_default(Value::Bool(false)),
165            );
166        Self { def }
167    }
168}
169
170impl Default for ListDirectoryTool {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
176#[derive(Deserialize)]
177struct ListDirectoryInput {
178    path: String,
179    #[serde(default)]
180    recursive: bool,
181}
182
183#[derive(Serialize, Deserialize)]
184struct DirectoryEntry {
185    name: String,
186    path: String,
187    is_dir: bool,
188    size: Option<u64>,
189}
190
191#[async_trait]
192impl Tool for ListDirectoryTool {
193    fn definition(&self) -> &ToolDefinition {
194        &self.def
195    }
196
197    async fn execute(&self, _ctx: &Context, input: Value) -> Result<ToolOutput> {
198        let input: ListDirectoryInput = serde_json::from_value(input)?;
199        
200        let path = Path::new(&input.path);
201        if !path.exists() {
202            return Ok(ToolOutput::error(format!("Directory not found: {}", input.path)));
203        }
204        if !path.is_dir() {
205            return Ok(ToolOutput::error(format!("Not a directory: {}", input.path)));
206        }
207
208        let mut entries = Vec::new();
209        
210        if input.recursive {
211            collect_entries_recursive(path, &mut entries).await?;
212        } else {
213            let mut read_dir = fs::read_dir(path).await?;
214            while let Some(entry) = read_dir.next_entry().await? {
215                let metadata = entry.metadata().await?;
216                entries.push(DirectoryEntry {
217                    name: entry.file_name().to_string_lossy().to_string(),
218                    path: entry.path().to_string_lossy().to_string(),
219                    is_dir: metadata.is_dir(),
220                    size: if metadata.is_file() { Some(metadata.len()) } else { None },
221                });
222            }
223        }
224
225        entries.sort_by(|a, b| a.name.cmp(&b.name));
226        ToolOutput::json(&entries)
227    }
228}
229
230async fn collect_entries_recursive(path: &Path, entries: &mut Vec<DirectoryEntry>) -> Result<()> {
231    let mut read_dir = fs::read_dir(path).await?;
232    while let Some(entry) = read_dir.next_entry().await? {
233        let metadata = entry.metadata().await?;
234        let entry_data = DirectoryEntry {
235            name: entry.file_name().to_string_lossy().to_string(),
236            path: entry.path().to_string_lossy().to_string(),
237            is_dir: metadata.is_dir(),
238            size: if metadata.is_file() { Some(metadata.len()) } else { None },
239        };
240        entries.push(entry_data);
241
242        if metadata.is_dir() {
243            Box::pin(collect_entries_recursive(&entry.path(), entries)).await?;
244        }
245    }
246    Ok(())
247}
248
249pub struct SearchFilesTool {
250    def: ToolDefinition,
251}
252
253impl SearchFilesTool {
254    pub fn new() -> Self {
255        let def = ToolDefinition::new("search_files", "Search for files matching a pattern")
256            .with_required_param(
257                "pattern",
258                ToolParameter::string().with_description("Glob pattern to search for"),
259            )
260            .with_param(
261                "path",
262                ToolParameter::string()
263                    .with_description("Base path to search from")
264                    .with_default(Value::String(".".into())),
265            );
266        Self { def }
267    }
268}
269
270impl Default for SearchFilesTool {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276#[derive(Deserialize)]
277struct SearchFilesInput {
278    pattern: String,
279    #[serde(default = "default_path")]
280    path: String,
281}
282
/// Serde default for `SearchFilesInput::path`: the string `"."`.
fn default_path() -> String {
    String::from(".")
}
286
287#[async_trait]
288impl Tool for SearchFilesTool {
289    fn definition(&self) -> &ToolDefinition {
290        &self.def
291    }
292
293    async fn execute(&self, _ctx: &Context, input: Value) -> Result<ToolOutput> {
294        let input: SearchFilesInput = serde_json::from_value(input)?;
295        
296        let full_pattern = format!("{}/{}", input.path, input.pattern);
297        
298        let matches: Vec<String> = glob::glob(&full_pattern)
299            .map_err(|e| ToolError::invalid_input(format!("Invalid pattern: {}", e)))?
300            .filter_map(|r| r.ok())
301            .map(|p| p.to_string_lossy().to_string())
302            .collect();
303
304        ToolOutput::json(&matches)
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Executes `tool` with `args` under a fresh context and unwraps the
    /// outer `Result`, so a hard error fails the calling test.
    async fn run(tool: &impl Tool, args: Value) -> ToolOutput {
        let ctx = Context::new();
        tool.execute(&ctx, args).await.unwrap()
    }

    // Reading an existing file yields its exact content.
    #[tokio::test]
    async fn test_read_file() {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("test.txt");
        std::fs::write(&target, "Hello, World!").unwrap();

        let args = serde_json::json!({ "path": target.to_string_lossy() });
        let output = run(&ReadFileTool::new(), args).await;

        assert!(!output.is_error);
        assert_eq!(output.content, "Hello, World!");
    }

    // A missing file is reported through an error output, not an `Err`.
    #[tokio::test]
    async fn test_read_file_not_found() {
        let args = serde_json::json!({ "path": "/nonexistent/file.txt" });
        let output = run(&ReadFileTool::new(), args).await;

        assert!(output.is_error);
        assert!(output.content.contains("not found"));
    }

    // Writing without `append` creates the file with exactly the given content.
    #[tokio::test]
    async fn test_write_file() {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("output.txt");

        let args = serde_json::json!({
            "path": target.to_string_lossy(),
            "content": "Test content"
        });
        let output = run(&WriteFileTool::new(), args).await;
        assert!(!output.is_error);

        let written = std::fs::read_to_string(&target).unwrap();
        assert_eq!(written, "Test content");
    }

    // With `append: true` the new content follows what was already there.
    #[tokio::test]
    async fn test_write_file_append() {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("append.txt");
        std::fs::write(&target, "First").unwrap();

        let args = serde_json::json!({
            "path": target.to_string_lossy(),
            "content": "Second",
            "append": true
        });
        let output = run(&WriteFileTool::new(), args).await;
        assert!(!output.is_error);

        let written = std::fs::read_to_string(&target).unwrap();
        assert_eq!(written, "FirstSecond");
    }

    // A flat listing reports files and subdirectories alike.
    #[tokio::test]
    async fn test_list_directory() {
        let tmp = TempDir::new().unwrap();
        let root = tmp.path();
        std::fs::write(root.join("a.txt"), "a").unwrap();
        std::fs::write(root.join("b.txt"), "b").unwrap();
        std::fs::create_dir(root.join("subdir")).unwrap();

        let args = serde_json::json!({ "path": root.to_string_lossy() });
        let output = run(&ListDirectoryTool::new(), args).await;
        assert!(!output.is_error);

        let listed: Vec<DirectoryEntry> = serde_json::from_str(&output.content).unwrap();
        assert_eq!(listed.len(), 3);
    }

    // Only files matching the glob are returned.
    #[tokio::test]
    async fn test_search_files() {
        let tmp = TempDir::new().unwrap();
        let root = tmp.path();
        std::fs::write(root.join("test1.txt"), "a").unwrap();
        std::fs::write(root.join("test2.txt"), "b").unwrap();
        std::fs::write(root.join("other.md"), "c").unwrap();

        let args = serde_json::json!({
            "pattern": "*.txt",
            "path": root.to_string_lossy()
        });
        let output = run(&SearchFilesTool::new(), args).await;
        assert!(!output.is_error);

        let found: Vec<String> = serde_json::from_str(&output.content).unwrap();
        assert_eq!(found.len(), 2);
    }
}