1use crate::error::{HeliosError, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::io::{BufReader, BufWriter, Read, Write};
7use std::path::Path;
8use std::time::{SystemTime, UNIX_EPOCH};
9
/// Schema entry for a single tool parameter, as it appears under
/// `ParametersSchema::properties`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameter {
    /// JSON type name of the parameter (e.g. `"string"`, `"number"`); serialized as `"type"`.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Explanation of the parameter shown to the model.
    pub description: String,
    /// Whether the parameter must be supplied. Never serialized (`#[serde(skip)]`, so it is
    /// left as `None` when deserializing); `Tool::to_definition` lifts required names into
    /// `ParametersSchema::required` instead. `None` is treated as "not required" there.
    #[serde(skip)]
    pub required: Option<bool>,
}

/// Top-level tool entry of the shape `{"type": ..., "function": {...}}`.
/// `Tool::to_definition` always sets `tool_type` to `"function"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Serialized as `"type"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Name, description and parameter schema of the callable function.
    pub function: FunctionDefinition,
}

/// Callable function description embedded in a [`ToolDefinition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Tool name; `ToolRegistry` uses the same string as its lookup key.
    pub name: String,
    /// Human/model-readable description of what the tool does.
    pub description: String,
    /// JSON-schema object describing the accepted arguments.
    pub parameters: ParametersSchema,
}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ParametersSchema {
35 #[serde(rename = "type")]
36 pub schema_type: String,
37 pub properties: HashMap<String, ToolParameter>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub required: Option<Vec<String>>,
40}
41
/// Outcome of one tool invocation.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// `true` if the tool completed its task, `false` if it reports a failure.
    pub success: bool,
    /// Tool output on success, or the failure message otherwise.
    pub output: String,
}

impl ToolResult {
    /// Creates a successful result that carries `output`.
    pub fn success(output: impl Into<String>) -> Self {
        Self::from_parts(true, output.into())
    }

    /// Creates a failed result that carries `message` as its output.
    pub fn error(message: impl Into<String>) -> Self {
        Self::from_parts(false, message.into())
    }

    /// Shared constructor behind [`ToolResult::success`] and [`ToolResult::error`].
    fn from_parts(success: bool, output: String) -> Self {
        Self { success, output }
    }
}
63
64#[async_trait]
65pub trait Tool: Send + Sync {
66 fn name(&self) -> &str;
67 fn description(&self) -> &str;
68 fn parameters(&self) -> HashMap<String, ToolParameter>;
69 async fn execute(&self, args: Value) -> Result<ToolResult>;
70
71 fn to_definition(&self) -> ToolDefinition {
72 let required: Vec<String> = self
73 .parameters()
74 .iter()
75 .filter(|(_, param)| param.required.unwrap_or(false))
76 .map(|(name, _)| name.clone())
77 .collect();
78
79 ToolDefinition {
80 tool_type: "function".to_string(),
81 function: FunctionDefinition {
82 name: self.name().to_string(),
83 description: self.description().to_string(),
84 parameters: ParametersSchema {
85 schema_type: "object".to_string(),
86 properties: self.parameters(),
87 required: if required.is_empty() {
88 None
89 } else {
90 Some(required)
91 },
92 },
93 },
94 }
95 }
96}
97
98pub struct ToolRegistry {
99 tools: HashMap<String, Box<dyn Tool>>,
100}
101
102impl ToolRegistry {
103 pub fn new() -> Self {
104 Self {
105 tools: HashMap::new(),
106 }
107 }
108
109 pub fn register(&mut self, tool: Box<dyn Tool>) {
110 let name = tool.name().to_string();
111 self.tools.insert(name, tool);
112 }
113
114 pub fn get(&self, name: &str) -> Option<&dyn Tool> {
115 self.tools.get(name).map(|b| &**b)
116 }
117
118 pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
119 let tool = self
120 .tools
121 .get(name)
122 .ok_or_else(|| HeliosError::ToolError(format!("Tool '{}' not found", name)))?;
123
124 tool.execute(args).await
125 }
126
127 pub fn get_definitions(&self) -> Vec<ToolDefinition> {
128 self.tools
129 .values()
130 .map(|tool| tool.to_definition())
131 .collect()
132 }
133
134 pub fn list_tools(&self) -> Vec<String> {
135 self.tools.keys().cloned().collect()
136 }
137}
138
139impl Default for ToolRegistry {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145pub struct CalculatorTool;
148
149#[async_trait]
150impl Tool for CalculatorTool {
151 fn name(&self) -> &str {
152 "calculator"
153 }
154
155 fn description(&self) -> &str {
156 "Perform basic arithmetic operations. Supports +, -, *, / operations."
157 }
158
159 fn parameters(&self) -> HashMap<String, ToolParameter> {
160 let mut params = HashMap::new();
161 params.insert(
162 "expression".to_string(),
163 ToolParameter {
164 param_type: "string".to_string(),
165 description: "Mathematical expression to evaluate (e.g., '2 + 2')".to_string(),
166 required: Some(true),
167 },
168 );
169 params
170 }
171
172 async fn execute(&self, args: Value) -> Result<ToolResult> {
173 let expression = args
174 .get("expression")
175 .and_then(|v| v.as_str())
176 .ok_or_else(|| HeliosError::ToolError("Missing 'expression' parameter".to_string()))?;
177
178 let result = evaluate_expression(expression)?;
180 Ok(ToolResult::success(result.to_string()))
181 }
182}
183
184fn evaluate_expression(expr: &str) -> Result<f64> {
185 let expr = expr.replace(" ", "");
186
187 for op in &['*', '/', '+', '-'] {
189 if let Some(pos) = expr.rfind(*op) {
190 if pos == 0 {
191 continue; }
193 let left = &expr[..pos];
194 let right = &expr[pos + 1..];
195
196 let left_val = evaluate_expression(left)?;
197 let right_val = evaluate_expression(right)?;
198
199 return Ok(match op {
200 '+' => left_val + right_val,
201 '-' => left_val - right_val,
202 '*' => left_val * right_val,
203 '/' => {
204 if right_val == 0.0 {
205 return Err(HeliosError::ToolError("Division by zero".to_string()));
206 }
207 left_val / right_val
208 }
209 _ => unreachable!(),
210 });
211 }
212 }
213
214 expr.parse::<f64>()
215 .map_err(|_| HeliosError::ToolError(format!("Invalid expression: {}", expr)))
216}
217
218pub struct EchoTool;
219
220#[async_trait]
221impl Tool for EchoTool {
222 fn name(&self) -> &str {
223 "echo"
224 }
225
226 fn description(&self) -> &str {
227 "Echo back the provided message."
228 }
229
230 fn parameters(&self) -> HashMap<String, ToolParameter> {
231 let mut params = HashMap::new();
232 params.insert(
233 "message".to_string(),
234 ToolParameter {
235 param_type: "string".to_string(),
236 description: "The message to echo back".to_string(),
237 required: Some(true),
238 },
239 );
240 params
241 }
242
243 async fn execute(&self, args: Value) -> Result<ToolResult> {
244 let message = args
245 .get("message")
246 .and_then(|v| v.as_str())
247 .ok_or_else(|| HeliosError::ToolError("Missing 'message' parameter".to_string()))?;
248
249 Ok(ToolResult::success(format!("Echo: {}", message)))
250 }
251}
252
253pub struct FileSearchTool;
254
255#[async_trait]
256impl Tool for FileSearchTool {
257 fn name(&self) -> &str {
258 "file_search"
259 }
260
261 fn description(&self) -> &str {
262 "Search for files by name pattern or search for content within files. Can search recursively in directories."
263 }
264
265 fn parameters(&self) -> HashMap<String, ToolParameter> {
266 let mut params = HashMap::new();
267 params.insert(
268 "path".to_string(),
269 ToolParameter {
270 param_type: "string".to_string(),
271 description: "The directory path to search in (default: current directory)".to_string(),
272 required: Some(false),
273 },
274 );
275 params.insert(
276 "pattern".to_string(),
277 ToolParameter {
278 param_type: "string".to_string(),
279 description: "File name pattern to search for (supports wildcards like *.rs)".to_string(),
280 required: Some(false),
281 },
282 );
283 params.insert(
284 "content".to_string(),
285 ToolParameter {
286 param_type: "string".to_string(),
287 description: "Text content to search for within files".to_string(),
288 required: Some(false),
289 },
290 );
291 params.insert(
292 "max_results".to_string(),
293 ToolParameter {
294 param_type: "number".to_string(),
295 description: "Maximum number of results to return (default: 50)".to_string(),
296 required: Some(false),
297 },
298 );
299 params
300 }
301
302 async fn execute(&self, args: Value) -> Result<ToolResult> {
303 use walkdir::WalkDir;
304
305 let base_path = args
306 .get("path")
307 .and_then(|v| v.as_str())
308 .unwrap_or(".");
309
310 let pattern = args.get("pattern").and_then(|v| v.as_str());
311 let content_search = args.get("content").and_then(|v| v.as_str());
312 let max_results = args
313 .get("max_results")
314 .and_then(|v| v.as_u64())
315 .unwrap_or(50) as usize;
316
317 if pattern.is_none() && content_search.is_none() {
318 return Err(HeliosError::ToolError(
319 "Either 'pattern' or 'content' parameter is required".to_string(),
320 ));
321 }
322
323 let mut results = Vec::new();
324
325 let compiled_re = if let Some(pat) = pattern {
327 let re_pattern = pat
328 .replace(".", r"\.")
329 .replace("*", ".*")
330 .replace("?", ".");
331 match regex::Regex::new(&format!("^{}$", re_pattern)) {
332 Ok(re) => Some(re),
333 Err(e) => {
334 tracing::warn!(
335 "Invalid glob pattern '{}' ({}). Falling back to substring matching.",
336 pat,
337 e
338 );
339 None
340 }
341 }
342 } else {
343 None
344 };
345
346 for entry in WalkDir::new(base_path)
347 .max_depth(10)
348 .follow_links(false)
349 .into_iter()
350 .filter_map(|e| e.ok())
351 {
352 if results.len() >= max_results {
353 break;
354 }
355
356 let path = entry.path();
357
358 if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) {
360 if file_name.starts_with('.') ||
361 file_name == "target" ||
362 file_name == "node_modules" ||
363 file_name == "__pycache__" {
364 continue;
365 }
366 }
367
368 if let Some(pat) = pattern {
370 if path.is_file() {
371 if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) {
372 let is_match = if let Some(re) = &compiled_re {
373 re.is_match(file_name)
374 } else {
375 file_name.contains(pat)
376 };
377 if is_match {
378 results.push(format!("📄 {}", path.display()));
379 }
380 }
381 }
382 }
383
384 if let Some(search_term) = content_search {
386 if path.is_file() {
387 if let Ok(content) = std::fs::read_to_string(path) {
388 if content.contains(search_term) {
389 let matching_lines: Vec<(usize, &str)> = content
391 .lines()
392 .enumerate()
393 .filter(|(_, line)| line.contains(search_term))
394 .take(3) .collect();
396
397 if !matching_lines.is_empty() {
398 results.push(format!("📄 {} (found in {} lines)",
399 path.display(), matching_lines.len()));
400 for (line_num, line) in matching_lines {
401 results.push(format!(" Line {}: {}", line_num + 1, line.trim()));
402 }
403 }
404 }
405 }
406 }
407 }
408 }
409
410 if results.is_empty() {
411 Ok(ToolResult::success("No files found matching the criteria.".to_string()))
412 } else {
413 let output = format!(
414 "Found {} result(s):\n\n{}",
415 results.len(),
416 results.join("\n")
417 );
418 Ok(ToolResult::success(output))
419 }
420 }
421}
422
423pub struct FileReadTool;
426
427#[async_trait]
428impl Tool for FileReadTool {
429 fn name(&self) -> &str {
430 "file_read"
431 }
432
433 fn description(&self) -> &str {
434 "Read the contents of a file. Returns the full file content or specific lines."
435 }
436
437 fn parameters(&self) -> HashMap<String, ToolParameter> {
438 let mut params = HashMap::new();
439 params.insert(
440 "path".to_string(),
441 ToolParameter {
442 param_type: "string".to_string(),
443 description: "The file path to read".to_string(),
444 required: Some(true),
445 },
446 );
447 params.insert(
448 "start_line".to_string(),
449 ToolParameter {
450 param_type: "number".to_string(),
451 description: "Starting line number (1-indexed, optional)".to_string(),
452 required: Some(false),
453 },
454 );
455 params.insert(
456 "end_line".to_string(),
457 ToolParameter {
458 param_type: "number".to_string(),
459 description: "Ending line number (1-indexed, optional)".to_string(),
460 required: Some(false),
461 },
462 );
463 params
464 }
465
466 async fn execute(&self, args: Value) -> Result<ToolResult> {
467 let file_path = args
468 .get("path")
469 .and_then(|v| v.as_str())
470 .ok_or_else(|| HeliosError::ToolError("Missing 'path' parameter".to_string()))?;
471
472 let content = std::fs::read_to_string(file_path)
473 .map_err(|e| HeliosError::ToolError(format!("Failed to read file: {}", e)))?;
474
475 let start_line = args.get("start_line").and_then(|v| v.as_u64()).map(|n| n as usize);
476 let end_line = args.get("end_line").and_then(|v| v.as_u64()).map(|n| n as usize);
477
478 let output = if let (Some(start), Some(end)) = (start_line, end_line) {
479 let lines: Vec<&str> = content.lines().collect();
480 let start_idx = start.saturating_sub(1);
481 let end_idx = end.min(lines.len());
482
483 if start_idx >= lines.len() {
484 return Err(HeliosError::ToolError(format!(
485 "Start line {} is beyond file length ({})",
486 start, lines.len()
487 )));
488 }
489
490 let selected_lines = &lines[start_idx..end_idx];
491 format!(
492 "File: {} (lines {}-{}):\n\n{}",
493 file_path,
494 start,
495 end_idx,
496 selected_lines.join("\n")
497 )
498 } else {
499 format!("File: {}:\n\n{}", file_path, content)
500 };
501
502 Ok(ToolResult::success(output))
503 }
504}
505
506pub struct FileWriteTool;
507
508#[async_trait]
509impl Tool for FileWriteTool {
510 fn name(&self) -> &str {
511 "file_write"
512 }
513
514 fn description(&self) -> &str {
515 "Write content to a file. Creates new file or overwrites existing file."
516 }
517
518 fn parameters(&self) -> HashMap<String, ToolParameter> {
519 let mut params = HashMap::new();
520 params.insert(
521 "path".to_string(),
522 ToolParameter {
523 param_type: "string".to_string(),
524 description: "The file path to write to".to_string(),
525 required: Some(true),
526 },
527 );
528 params.insert(
529 "content".to_string(),
530 ToolParameter {
531 param_type: "string".to_string(),
532 description: "The content to write to the file".to_string(),
533 required: Some(true),
534 },
535 );
536 params
537 }
538
539 async fn execute(&self, args: Value) -> Result<ToolResult> {
540 let file_path = args
541 .get("path")
542 .and_then(|v| v.as_str())
543 .ok_or_else(|| HeliosError::ToolError("Missing 'path' parameter".to_string()))?;
544
545 let content = args
546 .get("content")
547 .and_then(|v| v.as_str())
548 .ok_or_else(|| HeliosError::ToolError("Missing 'content' parameter".to_string()))?;
549
550 if let Some(parent) = std::path::Path::new(file_path).parent() {
552 std::fs::create_dir_all(parent)
553 .map_err(|e| HeliosError::ToolError(format!("Failed to create directories: {}", e)))?;
554 }
555
556 std::fs::write(file_path, content)
557 .map_err(|e| HeliosError::ToolError(format!("Failed to write file: {}", e)))?;
558
559 Ok(ToolResult::success(format!(
560 "Successfully wrote {} bytes to {}",
561 content.len(),
562 file_path
563 )))
564 }
565}
566
567pub struct FileEditTool;
568
569#[async_trait]
570impl Tool for FileEditTool {
571 fn name(&self) -> &str {
572 "file_edit"
573 }
574
575 fn description(&self) -> &str {
576 "Edit a file by replacing specific text or lines. Use this to make targeted changes to existing files."
577 }
578
579 fn parameters(&self) -> HashMap<String, ToolParameter> {
580 let mut params = HashMap::new();
581 params.insert(
582 "path".to_string(),
583 ToolParameter {
584 param_type: "string".to_string(),
585 description: "The file path to edit".to_string(),
586 required: Some(true),
587 },
588 );
589 params.insert(
590 "find".to_string(),
591 ToolParameter {
592 param_type: "string".to_string(),
593 description: "The text to find and replace".to_string(),
594 required: Some(true),
595 },
596 );
597 params.insert(
598 "replace".to_string(),
599 ToolParameter {
600 param_type: "string".to_string(),
601 description: "The replacement text".to_string(),
602 required: Some(true),
603 },
604 );
605 params
606 }
607
608 async fn execute(&self, args: Value) -> Result<ToolResult> {
609 let file_path = args
610 .get("path")
611 .and_then(|v| v.as_str())
612 .ok_or_else(|| HeliosError::ToolError("Missing 'path' parameter".to_string()))?;
613
614 let find_text = args
615 .get("find")
616 .and_then(|v| v.as_str())
617 .ok_or_else(|| HeliosError::ToolError("Missing 'find' parameter".to_string()))?;
618
619 let replace_text = args
620 .get("replace")
621 .and_then(|v| v.as_str())
622 .ok_or_else(|| HeliosError::ToolError("Missing 'replace' parameter".to_string()))?;
623
624 if find_text.is_empty() {
625 return Err(HeliosError::ToolError("'find' parameter cannot be empty".to_string()));
626 }
627
628 let path = Path::new(file_path);
629 let parent = path.parent().ok_or_else(|| {
630 HeliosError::ToolError(format!("Invalid target path: {}", file_path))
631 })?;
632 let file_name = path.file_name().ok_or_else(|| {
633 HeliosError::ToolError(format!("Invalid target path: {}", file_path))
634 })?;
635
636 let pid = std::process::id();
638 let nanos = SystemTime::now()
639 .duration_since(UNIX_EPOCH)
640 .map_err(|e| HeliosError::ToolError(format!("Clock error: {}", e)))?
641 .as_nanos();
642 let tmp_name = format!("{}.tmp.{}.{}", file_name.to_string_lossy(), pid, nanos);
643 let tmp_path = parent.join(tmp_name);
644
645 let input_file = std::fs::File::open(&path)
647 .map_err(|e| HeliosError::ToolError(format!("Failed to open file for read: {}", e)))?;
648 let mut reader = BufReader::new(input_file);
649
650 let tmp_file = std::fs::File::create(&tmp_path).map_err(|e| {
651 HeliosError::ToolError(format!("Failed to create temp file {}: {}", tmp_path.display(), e))
652 })?;
653 let mut writer = BufWriter::new(&tmp_file);
654
655 let replaced_count = replace_streaming(
657 &mut reader,
658 &mut writer,
659 find_text.as_bytes(),
660 replace_text.as_bytes(),
661 )
662 .map_err(|e| HeliosError::ToolError(format!("I/O error while replacing: {}", e)))?;
663
664 writer.flush().map_err(|e| HeliosError::ToolError(format!("Failed to flush temp file: {}", e)))?;
666 tmp_file.sync_all().map_err(|e| HeliosError::ToolError(format!("Failed to sync temp file: {}", e)))?;
667
668 if let Ok(meta) = std::fs::metadata(&path) {
670 if let Err(e) = std::fs::set_permissions(&tmp_path, meta.permissions()) {
671 let _ = std::fs::remove_file(&tmp_path);
672 return Err(HeliosError::ToolError(format!("Failed to set permissions: {}", e)));
673 }
674 }
675
676 std::fs::rename(&tmp_path, &path).map_err(|e| {
678 let _ = std::fs::remove_file(&tmp_path);
679 HeliosError::ToolError(format!("Failed to replace original file: {}", e))
680 })?;
681
682 if replaced_count == 0 {
683 return Ok(ToolResult::error(format!(
684 "Text '{}' not found in file {}",
685 find_text, file_path
686 )));
687 }
688
689 Ok(ToolResult::success(format!(
690 "Successfully replaced {} occurrence(s) in {}",
691 replaced_count, file_path
692 )))
693 }
694}
695
/// Copies `reader` to `writer`, replacing every non-overlapping occurrence of `needle`
/// with `replacement`, matched left to right (the same semantics as `str::replace`).
///
/// Input is processed in fixed-size chunks, so memory use is bounded by the chunk size
/// plus `needle.len()` regardless of input size. Matches that straddle a chunk boundary,
/// including one that ends at the very end of the input, are found.
///
/// Returns the number of replacements made. An empty `needle` copies the input unchanged
/// and returns 0.
///
/// # Errors
/// Propagates I/O errors from `reader` or `writer`; reads failing with
/// `ErrorKind::Interrupted` are retried.
fn replace_streaming<R: Read, W: Write>(
    reader: &mut R,
    writer: &mut W,
    needle: &[u8],
    replacement: &[u8],
) -> std::io::Result<usize> {
    const CHUNK_SIZE: usize = 8192;

    if needle.is_empty() {
        std::io::copy(reader, writer)?;
        return Ok(0);
    }

    // A match that is only partially buffered must start within the last `keep` bytes.
    let keep = needle.len() - 1;
    let mut pending: Vec<u8> = Vec::with_capacity(CHUNK_SIZE + keep);
    let mut chunk = [0u8; CHUNK_SIZE];
    let mut replaced = 0usize;

    loop {
        let n = match reader.read(&mut chunk) {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        pending.extend_from_slice(&chunk[..n]);

        // Replace every match that is complete in the buffer.
        let (count, consumed) = write_through_last_match(writer, &pending, needle, replacement)?;
        replaced += count;

        // Bytes before `safe_end` can no longer begin a match: flush them and carry the rest.
        let safe_end = pending.len().saturating_sub(keep).max(consumed);
        writer.write_all(&pending[consumed..safe_end])?;
        pending.drain(..safe_end);
    }

    // The leftover is shorter than `needle`, so it is written through unchanged.
    replaced += write_with_replacements(writer, &pending, needle, replacement)?;
    Ok(replaced)
}

/// Writes `haystack` to `writer` with every non-overlapping `needle` replaced by
/// `replacement`, and returns the number of replacements.
/// An empty `needle` writes `haystack` unchanged and returns 0.
fn write_with_replacements<W: Write>(
    writer: &mut W,
    haystack: &[u8],
    needle: &[u8],
    replacement: &[u8],
) -> std::io::Result<usize> {
    let (count, consumed) = write_through_last_match(writer, haystack, needle, replacement)?;
    writer.write_all(&haystack[consumed..])?;
    Ok(count)
}

/// Writes `haystack` up to the end of its last non-overlapping `needle` match,
/// substituting `replacement` for each match.
///
/// Returns `(matches, consumed)`, where `consumed` is the offset just past the last match
/// (0 when there is none, or when `needle` is empty). Bytes from `consumed` on are NOT
/// written; the caller decides what to do with them.
fn write_through_last_match<W: Write>(
    writer: &mut W,
    haystack: &[u8],
    needle: &[u8],
    replacement: &[u8],
) -> std::io::Result<(usize, usize)> {
    // An empty needle would "match" at every position forever.
    if needle.is_empty() {
        return Ok((0, 0));
    }

    let mut count = 0usize;
    let mut consumed = 0usize;
    while let Some(pos) = find_subslice(&haystack[consumed..], needle) {
        let idx = consumed + pos;
        writer.write_all(&haystack[consumed..idx])?;
        writer.write_all(replacement)?;
        count += 1;
        consumed = idx + needle.len();
    }
    Ok((count, consumed))
}

/// Returns the offset of the first occurrence of `n` in `h`; an empty `n` matches at 0.
fn find_subslice(h: &[u8], n: &[u8]) -> Option<usize> {
    if n.is_empty() {
        return Some(0);
    }
    h.windows(n.len()).position(|w| w == n)
}
751
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::path::{Path, PathBuf};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Unique scratch directory under the system temp dir, removed on drop so it is
    /// cleaned up even when an assertion fails mid-test.
    struct TestDir(PathBuf);

    impl TestDir {
        fn new(prefix: &str) -> Self {
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            let dir = std::env::temp_dir().join(format!(
                "{}_{}_{}",
                prefix,
                std::process::id(),
                nanos
            ));
            std::fs::create_dir_all(&dir).unwrap();
            TestDir(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TestDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
    }

    #[tokio::test]
    async fn test_file_search_tool_glob_pattern_precompiled_regex() {
        let test_dir = TestDir::new("helios_fs_test");

        let file_rs = test_dir.path().join("a.rs");
        let file_txt = test_dir.path().join("b.txt");
        let subdir = test_dir.path().join("subdir");
        std::fs::create_dir_all(&subdir).unwrap();
        let file_sub_rs = subdir.join("mod.rs");
        std::fs::write(&file_rs, "fn main() {}\n").unwrap();
        std::fs::write(&file_txt, "hello\n").unwrap();
        std::fs::write(&file_sub_rs, "pub fn x() {}\n").unwrap();

        let tool = FileSearchTool;
        let args = json!({
            "path": test_dir.path().to_string_lossy(),
            "pattern": "*.rs",
            "max_results": 50
        });
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        let out = result.output;
        assert!(out.contains(&file_rs.to_string_lossy().to_string()));
        assert!(out.contains(&file_sub_rs.to_string_lossy().to_string()));
        assert!(!out.contains(&file_txt.to_string_lossy().to_string()));
    }

    #[tokio::test]
    async fn test_file_search_tool_invalid_pattern_fallback_contains() {
        let test_dir = TestDir::new("helios_fs_test_invalid");

        let special = test_dir.path().join("foo(bar).txt");
        std::fs::write(&special, "content\n").unwrap();

        let tool = FileSearchTool;
        let args = json!({
            "path": test_dir.path().to_string_lossy(),
            "pattern": "(",
            "max_results": 50
        });
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert!(result
            .output
            .contains(&special.to_string_lossy().to_string()));
    }

    #[tokio::test]
    async fn test_file_write_then_read_roundtrip() {
        let test_dir = TestDir::new("helios_rw_test");
        let file = test_dir.path().join("nested").join("out.txt");
        let file_str = file.to_string_lossy().to_string();

        let writer = FileWriteTool;
        let written = writer
            .execute(json!({"path": file_str.as_str(), "content": "one\ntwo\nthree\n"}))
            .await
            .unwrap();
        assert!(written.success);
        assert_eq!(
            written.output,
            format!("Successfully wrote 14 bytes to {}", file_str)
        );

        let reader = FileReadTool;
        let full = reader
            .execute(json!({"path": file_str.as_str()}))
            .await
            .unwrap();
        assert_eq!(
            full.output,
            format!("File: {}:\n\none\ntwo\nthree\n", file_str)
        );

        let range = reader
            .execute(json!({"path": file_str.as_str(), "start_line": 2, "end_line": 3}))
            .await
            .unwrap();
        assert_eq!(
            range.output,
            format!("File: {} (lines 2-3):\n\ntwo\nthree", file_str)
        );
    }

    #[tokio::test]
    async fn test_file_read_tool_missing_file() {
        let test_dir = TestDir::new("helios_read_missing_test");
        let missing = test_dir.path().join("does_not_exist.txt");
        let tool = FileReadTool;
        let result = tool
            .execute(json!({"path": missing.to_string_lossy()}))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_file_edit_tool_replaces_text() {
        let test_dir = TestDir::new("helios_edit_test");
        let file = test_dir.path().join("notes.txt");
        std::fs::write(&file, "alpha beta gamma delta\n").unwrap();
        let file_str = file.to_string_lossy().to_string();

        let tool = FileEditTool;
        let result = tool
            .execute(json!({"path": file_str.as_str(), "find": "beta", "replace": "BETA"}))
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(
            std::fs::read_to_string(&file).unwrap(),
            "alpha BETA gamma delta\n"
        );
        // No temporary files are left next to the edited file.
        assert_eq!(std::fs::read_dir(test_dir.path()).unwrap().count(), 1);
    }

    #[tokio::test]
    async fn test_file_edit_tool_reports_missing_text() {
        let test_dir = TestDir::new("helios_edit_missing_test");
        let file = test_dir.path().join("notes.txt");
        std::fs::write(&file, "alpha\n").unwrap();
        let file_str = file.to_string_lossy().to_string();

        let tool = FileEditTool;
        let result = tool
            .execute(json!({"path": file_str.as_str(), "find": "omega", "replace": "x"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "alpha\n");
        assert_eq!(std::fs::read_dir(test_dir.path()).unwrap().count(), 1);
    }

    #[tokio::test]
    async fn test_file_edit_tool_rejects_empty_find() {
        let tool = FileEditTool;
        let result = tool
            .execute(json!({"path": "unused.txt", "find": "", "replace": "x"}))
            .await;
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_result_error() {
        let result = ToolResult::error("test error");
        assert!(!result.success);
        assert_eq!(result.output, "test error");
    }

    #[tokio::test]
    async fn test_calculator_tool() {
        let tool = CalculatorTool;
        assert_eq!(tool.name(), "calculator");
        assert_eq!(
            tool.description(),
            "Perform basic arithmetic operations. Supports +, -, *, / operations."
        );

        let args = json!({"expression": "2 + 2"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "4");
    }

    #[tokio::test]
    async fn test_calculator_tool_multiplication() {
        let tool = CalculatorTool;
        let args = json!({"expression": "3 * 4"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "12");
    }

    #[tokio::test]
    async fn test_calculator_tool_division() {
        let tool = CalculatorTool;
        let args = json!({"expression": "8 / 2"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "4");
    }

    #[tokio::test]
    async fn test_calculator_tool_division_by_zero() {
        let tool = CalculatorTool;
        let args = json!({"expression": "8 / 0"});
        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_calculator_tool_invalid_expression() {
        let tool = CalculatorTool;
        let args = json!({"expression": "invalid"});
        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_echo_tool() {
        let tool = EchoTool;
        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), "Echo back the provided message.");

        let args = json!({"message": "Hello, world!"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "Echo: Hello, world!");
    }

    #[tokio::test]
    async fn test_echo_tool_missing_parameter() {
        let tool = EchoTool;
        let args = json!({});
        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_registry_new() {
        let registry = ToolRegistry::new();
        assert!(registry.tools.is_empty());
    }

    #[tokio::test]
    async fn test_tool_registry_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));

        let tool = registry.get("calculator");
        assert!(tool.is_some());
        assert_eq!(tool.unwrap().name(), "calculator");
    }

    #[tokio::test]
    async fn test_tool_registry_execute() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));

        let args = json!({"expression": "5 * 6"});
        let result = registry.execute("calculator", args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "30");
    }

    #[tokio::test]
    async fn test_tool_registry_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let args = json!({"expression": "5 * 6"});
        let result = registry.execute("nonexistent", args).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_registry_get_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));
        registry.register(Box::new(EchoTool));

        let definitions = registry.get_definitions();
        assert_eq!(definitions.len(), 2);

        let names: Vec<String> = definitions
            .iter()
            .map(|d| d.function.name.clone())
            .collect();
        assert!(names.contains(&"calculator".to_string()));
        assert!(names.contains(&"echo".to_string()));
    }

    #[test]
    fn test_tool_registry_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));
        registry.register(Box::new(EchoTool));

        let tools = registry.list_tools();
        assert_eq!(tools.len(), 2);
        assert!(tools.contains(&"calculator".to_string()));
        assert!(tools.contains(&"echo".to_string()));
    }
}