use std::path::Path;

use async_trait::async_trait;
use hehe_core::{Context, ToolDefinition, ToolParameter};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::fs;
use tokio::io::AsyncWriteExt;

use crate::error::{Result, ToolError};
use crate::traits::{Tool, ToolOutput};
9
10pub struct ReadFileTool {
11 def: ToolDefinition,
12}
13
14impl ReadFileTool {
15 pub fn new() -> Self {
16 let def = ToolDefinition::new("read_file", "Read the contents of a file")
17 .with_required_param(
18 "path",
19 ToolParameter::string().with_description("Path to the file to read"),
20 )
21 .with_param(
22 "encoding",
23 ToolParameter::string()
24 .with_description("File encoding (default: utf-8)")
25 .with_default(Value::String("utf-8".into())),
26 );
27 Self { def }
28 }
29}
30
31impl Default for ReadFileTool {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
/// Arguments accepted by `read_file`, deserialized from the tool-call JSON.
#[derive(Deserialize)]
struct ReadFileInput {
    /// Path of the file to read.
    path: String,
    /// Requested text encoding; `default_encoding()` ("utf-8") is used when omitted.
    #[serde(default = "default_encoding")]
    encoding: String,
}
43
/// Encoding assumed when the `encoding` argument is omitted.
fn default_encoding() -> String {
    String::from("utf-8")
}
47
48#[async_trait]
49impl Tool for ReadFileTool {
50 fn definition(&self) -> &ToolDefinition {
51 &self.def
52 }
53
54 async fn execute(&self, _ctx: &Context, input: Value) -> Result<ToolOutput> {
55 let input: ReadFileInput = serde_json::from_value(input)?;
56
57 let path = Path::new(&input.path);
58 if !path.exists() {
59 return Ok(ToolOutput::error(format!("File not found: {}", input.path)));
60 }
61
62 match fs::read_to_string(path).await {
63 Ok(content) => {
64 let size = content.len();
65 Ok(ToolOutput::text(content)
66 .with_metadata("path", &input.path)
67 .with_metadata("size", size))
68 }
69 Err(e) => Ok(ToolOutput::error(format!("Failed to read file: {}", e))),
70 }
71 }
72}
73
74pub struct WriteFileTool {
75 def: ToolDefinition,
76}
77
78impl WriteFileTool {
79 pub fn new() -> Self {
80 let def = ToolDefinition::new("write_file", "Write content to a file")
81 .with_required_param(
82 "path",
83 ToolParameter::string().with_description("Path to the file to write"),
84 )
85 .with_required_param(
86 "content",
87 ToolParameter::string().with_description("Content to write"),
88 )
89 .with_param(
90 "append",
91 ToolParameter::boolean()
92 .with_description("Append to file instead of overwriting")
93 .with_default(Value::Bool(false)),
94 )
95 .dangerous();
96 Self { def }
97 }
98}
99
100impl Default for WriteFileTool {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
/// Arguments accepted by `write_file`, deserialized from the tool-call JSON.
#[derive(Deserialize)]
struct WriteFileInput {
    /// Destination file path.
    path: String,
    /// Text to write.
    content: String,
    /// When `true`, add `content` after the existing data instead of overwriting;
    /// defaults to `false` when omitted.
    #[serde(default)]
    append: bool,
}
113
114#[async_trait]
115impl Tool for WriteFileTool {
116 fn definition(&self) -> &ToolDefinition {
117 &self.def
118 }
119
120 async fn execute(&self, _ctx: &Context, input: Value) -> Result<ToolOutput> {
121 let input: WriteFileInput = serde_json::from_value(input)?;
122
123 let path = Path::new(&input.path);
124
125 if let Some(parent) = path.parent() {
126 if !parent.exists() {
127 if let Err(e) = fs::create_dir_all(parent).await {
128 return Ok(ToolOutput::error(format!("Failed to create directory: {}", e)));
129 }
130 }
131 }
132
133 let result = if input.append {
134 let existing = fs::read_to_string(path).await.unwrap_or_default();
135 fs::write(path, format!("{}{}", existing, input.content)).await
136 } else {
137 fs::write(path, &input.content).await
138 };
139
140 match result {
141 Ok(_) => Ok(ToolOutput::text(format!("Successfully wrote to {}", input.path))
142 .with_metadata("path", &input.path)
143 .with_metadata("bytes_written", input.content.len())),
144 Err(e) => Ok(ToolOutput::error(format!("Failed to write file: {}", e))),
145 }
146 }
147}
148
149pub struct ListDirectoryTool {
150 def: ToolDefinition,
151}
152
153impl ListDirectoryTool {
154 pub fn new() -> Self {
155 let def = ToolDefinition::new("list_directory", "List contents of a directory")
156 .with_required_param(
157 "path",
158 ToolParameter::string().with_description("Path to the directory"),
159 )
160 .with_param(
161 "recursive",
162 ToolParameter::boolean()
163 .with_description("List recursively")
164 .with_default(Value::Bool(false)),
165 );
166 Self { def }
167 }
168}
169
170impl Default for ListDirectoryTool {
171 fn default() -> Self {
172 Self::new()
173 }
174}
175
/// Arguments accepted by `list_directory`, deserialized from the tool-call JSON.
#[derive(Deserialize)]
struct ListDirectoryInput {
    /// Directory to list.
    path: String,
    /// When `true`, descend into subdirectories; defaults to `false` when omitted.
    #[serde(default)]
    recursive: bool,
}
182
/// One entry of a directory listing, serialized into the `list_directory` JSON output.
#[derive(Serialize, Deserialize)]
struct DirectoryEntry {
    /// File name of the entry, lossily converted to UTF-8.
    name: String,
    /// Full path of the entry, lossily converted to UTF-8.
    path: String,
    /// Whether the entry's metadata reports a directory.
    is_dir: bool,
    /// Length in bytes for regular files; `None` for directories and any other entry kind.
    size: Option<u64>,
}
190
191#[async_trait]
192impl Tool for ListDirectoryTool {
193 fn definition(&self) -> &ToolDefinition {
194 &self.def
195 }
196
197 async fn execute(&self, _ctx: &Context, input: Value) -> Result<ToolOutput> {
198 let input: ListDirectoryInput = serde_json::from_value(input)?;
199
200 let path = Path::new(&input.path);
201 if !path.exists() {
202 return Ok(ToolOutput::error(format!("Directory not found: {}", input.path)));
203 }
204 if !path.is_dir() {
205 return Ok(ToolOutput::error(format!("Not a directory: {}", input.path)));
206 }
207
208 let mut entries = Vec::new();
209
210 if input.recursive {
211 collect_entries_recursive(path, &mut entries).await?;
212 } else {
213 let mut read_dir = fs::read_dir(path).await?;
214 while let Some(entry) = read_dir.next_entry().await? {
215 let metadata = entry.metadata().await?;
216 entries.push(DirectoryEntry {
217 name: entry.file_name().to_string_lossy().to_string(),
218 path: entry.path().to_string_lossy().to_string(),
219 is_dir: metadata.is_dir(),
220 size: if metadata.is_file() { Some(metadata.len()) } else { None },
221 });
222 }
223 }
224
225 entries.sort_by(|a, b| a.name.cmp(&b.name));
226 ToolOutput::json(&entries)
227 }
228}
229
230async fn collect_entries_recursive(path: &Path, entries: &mut Vec<DirectoryEntry>) -> Result<()> {
231 let mut read_dir = fs::read_dir(path).await?;
232 while let Some(entry) = read_dir.next_entry().await? {
233 let metadata = entry.metadata().await?;
234 let entry_data = DirectoryEntry {
235 name: entry.file_name().to_string_lossy().to_string(),
236 path: entry.path().to_string_lossy().to_string(),
237 is_dir: metadata.is_dir(),
238 size: if metadata.is_file() { Some(metadata.len()) } else { None },
239 };
240 entries.push(entry_data);
241
242 if metadata.is_dir() {
243 Box::pin(collect_entries_recursive(&entry.path(), entries)).await?;
244 }
245 }
246 Ok(())
247}
248
249pub struct SearchFilesTool {
250 def: ToolDefinition,
251}
252
253impl SearchFilesTool {
254 pub fn new() -> Self {
255 let def = ToolDefinition::new("search_files", "Search for files matching a pattern")
256 .with_required_param(
257 "pattern",
258 ToolParameter::string().with_description("Glob pattern to search for"),
259 )
260 .with_param(
261 "path",
262 ToolParameter::string()
263 .with_description("Base path to search from")
264 .with_default(Value::String(".".into())),
265 );
266 Self { def }
267 }
268}
269
270impl Default for SearchFilesTool {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
/// Arguments accepted by `search_files`, deserialized from the tool-call JSON.
#[derive(Deserialize)]
struct SearchFilesInput {
    /// Glob pattern, interpreted relative to `path`.
    pattern: String,
    /// Base directory to search from; `default_path()` (".") is used when omitted.
    #[serde(default = "default_path")]
    path: String,
}
282
/// Base directory searched when `path` is omitted: ".", i.e. relative to the working directory.
fn default_path() -> String {
    String::from(".")
}
286
287#[async_trait]
288impl Tool for SearchFilesTool {
289 fn definition(&self) -> &ToolDefinition {
290 &self.def
291 }
292
293 async fn execute(&self, _ctx: &Context, input: Value) -> Result<ToolOutput> {
294 let input: SearchFilesInput = serde_json::from_value(input)?;
295
296 let full_pattern = format!("{}/{}", input.path, input.pattern);
297
298 let matches: Vec<String> = glob::glob(&full_pattern)
299 .map_err(|e| ToolError::invalid_input(format!("Invalid pattern: {}", e)))?
300 .filter_map(|r| r.ok())
301 .map(|p| p.to_string_lossy().to_string())
302 .collect();
303
304 ToolOutput::json(&matches)
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Reading an existing file returns its exact contents.
    #[tokio::test]
    async fn test_read_file() {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("test.txt");
        std::fs::write(&target, "Hello, World!").unwrap();

        let request = serde_json::json!({ "path": target.to_string_lossy() });
        let result = ReadFileTool::new()
            .execute(&Context::new(), request)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(result.content, "Hello, World!");
    }

    /// A missing file is reported as an error output, not an `Err`.
    #[tokio::test]
    async fn test_read_file_not_found() {
        let request = serde_json::json!({ "path": "/nonexistent/file.txt" });
        let result = ReadFileTool::new()
            .execute(&Context::new(), request)
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result.content.contains("not found"));
    }

    /// Writing creates the file with the given content.
    #[tokio::test]
    async fn test_write_file() {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("output.txt");

        let request = serde_json::json!({
            "path": target.to_string_lossy(),
            "content": "Test content"
        });
        let result = WriteFileTool::new()
            .execute(&Context::new(), request)
            .await
            .unwrap();
        assert!(!result.is_error);

        assert_eq!(std::fs::read_to_string(&target).unwrap(), "Test content");
    }

    /// Append mode keeps the existing content and adds the new content after it.
    #[tokio::test]
    async fn test_write_file_append() {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("append.txt");
        std::fs::write(&target, "First").unwrap();

        let request = serde_json::json!({
            "path": target.to_string_lossy(),
            "content": "Second",
            "append": true
        });
        let result = WriteFileTool::new()
            .execute(&Context::new(), request)
            .await
            .unwrap();
        assert!(!result.is_error);

        assert_eq!(std::fs::read_to_string(&target).unwrap(), "FirstSecond");
    }

    /// A flat listing includes both files and subdirectories.
    #[tokio::test]
    async fn test_list_directory() {
        let tmp = TempDir::new().unwrap();
        for (name, body) in [("a.txt", "a"), ("b.txt", "b")] {
            std::fs::write(tmp.path().join(name), body).unwrap();
        }
        std::fs::create_dir(tmp.path().join("subdir")).unwrap();

        let request = serde_json::json!({ "path": tmp.path().to_string_lossy() });
        let result = ListDirectoryTool::new()
            .execute(&Context::new(), request)
            .await
            .unwrap();
        assert!(!result.is_error);

        let listed: Vec<DirectoryEntry> = serde_json::from_str(&result.content).unwrap();
        assert_eq!(listed.len(), 3);
    }

    /// Only files matching the glob are returned.
    #[tokio::test]
    async fn test_search_files() {
        let tmp = TempDir::new().unwrap();
        for (name, body) in [("test1.txt", "a"), ("test2.txt", "b"), ("other.md", "c")] {
            std::fs::write(tmp.path().join(name), body).unwrap();
        }

        let request = serde_json::json!({
            "pattern": "*.txt",
            "path": tmp.path().to_string_lossy()
        });
        let result = SearchFilesTool::new()
            .execute(&Context::new(), request)
            .await
            .unwrap();
        assert!(!result.is_error);

        let found: Vec<String> = serde_json::from_str(&result.content).unwrap();
        assert_eq!(found.len(), 2);
    }
}