1use crate::error::{HeliosError, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::io::{BufReader, BufWriter, Read, Write};
7use std::path::Path;
8use std::time::{SystemTime, UNIX_EPOCH};
9
/// A single named parameter inside a tool's `properties` schema map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameter {
    /// JSON type name of the parameter, serialized under the key `"type"`
    /// (the tools in this file use `"string"` and `"number"`).
    #[serde(rename = "type")]
    pub param_type: String,
    /// Human-readable description of the parameter.
    pub description: String,
    /// Whether the parameter is mandatory. It is never serialized or
    /// deserialized (`#[serde(skip)]`, so it is `None` after deserialization);
    /// `Tool::to_definition` reads it to build the schema-level `required` list.
    #[serde(skip)]
    pub required: Option<bool>,
}
18
/// Serializable description of a tool, shaped as `{"type": ..., "function": {...}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Serialized under the key `"type"`; `Tool::to_definition` always sets `"function"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Name, description and parameter schema of the tool's function.
    pub function: FunctionDefinition,
}
25
/// The `function` part of a [`ToolDefinition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Tool name, as returned by `Tool::name`.
    pub name: String,
    /// Tool description, as returned by `Tool::description`.
    pub description: String,
    /// Object schema describing the accepted arguments.
    pub parameters: ParametersSchema,
}
32
/// Object schema describing a function's parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParametersSchema {
    /// Serialized under the key `"type"`; `Tool::to_definition` sets `"object"`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Parameter schemas keyed by parameter name. Because this is a `HashMap`,
    /// the serialized key order is unspecified.
    pub properties: HashMap<String, ToolParameter>,
    /// Names of mandatory parameters; the key is omitted from the serialized
    /// output when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}
41
/// Outcome of running a tool: a success flag plus the text handed back to the caller.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// `true` when the tool completed its task.
    pub success: bool,
    /// The tool's output on success, or the failure message otherwise.
    pub output: String,
}

impl ToolResult {
    /// Builds a successful result carrying `output`.
    pub fn success(output: impl Into<String>) -> Self {
        Self::with_status(true, output)
    }

    /// Builds a failed result carrying `message`.
    pub fn error(message: impl Into<String>) -> Self {
        Self::with_status(false, message)
    }

    /// Shared constructor behind [`ToolResult::success`] and [`ToolResult::error`].
    fn with_status(success: bool, text: impl Into<String>) -> Self {
        ToolResult {
            output: text.into(),
            success,
        }
    }
}
63
64#[async_trait]
65pub trait Tool: Send + Sync {
66 fn name(&self) -> &str;
67 fn description(&self) -> &str;
68 fn parameters(&self) -> HashMap<String, ToolParameter>;
69 async fn execute(&self, args: Value) -> Result<ToolResult>;
70
71 fn to_definition(&self) -> ToolDefinition {
72 let required: Vec<String> = self
73 .parameters()
74 .iter()
75 .filter(|(_, param)| param.required.unwrap_or(false))
76 .map(|(name, _)| name.clone())
77 .collect();
78
79 ToolDefinition {
80 tool_type: "function".to_string(),
81 function: FunctionDefinition {
82 name: self.name().to_string(),
83 description: self.description().to_string(),
84 parameters: ParametersSchema {
85 schema_type: "object".to_string(),
86 properties: self.parameters(),
87 required: if required.is_empty() {
88 None
89 } else {
90 Some(required)
91 },
92 },
93 },
94 }
95 }
96}
97
/// A collection of tools keyed by the name each reported at registration time.
pub struct ToolRegistry {
    tools: HashMap<String, Box<dyn Tool>>,
}
101
102impl ToolRegistry {
103 pub fn new() -> Self {
104 Self {
105 tools: HashMap::new(),
106 }
107 }
108
109 pub fn register(&mut self, tool: Box<dyn Tool>) {
110 let name = tool.name().to_string();
111 self.tools.insert(name, tool);
112 }
113
114 pub fn get(&self, name: &str) -> Option<&dyn Tool> {
115 self.tools.get(name).map(|b| &**b)
116 }
117
118 pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
119 let tool = self
120 .tools
121 .get(name)
122 .ok_or_else(|| HeliosError::ToolError(format!("Tool '{}' not found", name)))?;
123
124 tool.execute(args).await
125 }
126
127 pub fn get_definitions(&self) -> Vec<ToolDefinition> {
128 self.tools
129 .values()
130 .map(|tool| tool.to_definition())
131 .collect()
132 }
133
134 pub fn list_tools(&self) -> Vec<String> {
135 self.tools.keys().cloned().collect()
136 }
137}
138
139impl Default for ToolRegistry {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145pub struct CalculatorTool;
148
149#[async_trait]
150impl Tool for CalculatorTool {
151 fn name(&self) -> &str {
152 "calculator"
153 }
154
155 fn description(&self) -> &str {
156 "Perform basic arithmetic operations. Supports +, -, *, / operations."
157 }
158
159 fn parameters(&self) -> HashMap<String, ToolParameter> {
160 let mut params = HashMap::new();
161 params.insert(
162 "expression".to_string(),
163 ToolParameter {
164 param_type: "string".to_string(),
165 description: "Mathematical expression to evaluate (e.g., '2 + 2')".to_string(),
166 required: Some(true),
167 },
168 );
169 params
170 }
171
172 async fn execute(&self, args: Value) -> Result<ToolResult> {
173 let expression = args
174 .get("expression")
175 .and_then(|v| v.as_str())
176 .ok_or_else(|| HeliosError::ToolError("Missing 'expression' parameter".to_string()))?;
177
178 let result = evaluate_expression(expression)?;
180 Ok(ToolResult::success(result.to_string()))
181 }
182}
183
184fn evaluate_expression(expr: &str) -> Result<f64> {
185 let expr = expr.replace(" ", "");
186
187 for op in &['*', '/', '+', '-'] {
189 if let Some(pos) = expr.rfind(*op) {
190 if pos == 0 {
191 continue; }
193 let left = &expr[..pos];
194 let right = &expr[pos + 1..];
195
196 let left_val = evaluate_expression(left)?;
197 let right_val = evaluate_expression(right)?;
198
199 return Ok(match op {
200 '+' => left_val + right_val,
201 '-' => left_val - right_val,
202 '*' => left_val * right_val,
203 '/' => {
204 if right_val == 0.0 {
205 return Err(HeliosError::ToolError("Division by zero".to_string()));
206 }
207 left_val / right_val
208 }
209 _ => unreachable!(),
210 });
211 }
212 }
213
214 expr.parse::<f64>()
215 .map_err(|_| HeliosError::ToolError(format!("Invalid expression: {}", expr)))
216}
217
218pub struct EchoTool;
219
220#[async_trait]
221impl Tool for EchoTool {
222 fn name(&self) -> &str {
223 "echo"
224 }
225
226 fn description(&self) -> &str {
227 "Echo back the provided message."
228 }
229
230 fn parameters(&self) -> HashMap<String, ToolParameter> {
231 let mut params = HashMap::new();
232 params.insert(
233 "message".to_string(),
234 ToolParameter {
235 param_type: "string".to_string(),
236 description: "The message to echo back".to_string(),
237 required: Some(true),
238 },
239 );
240 params
241 }
242
243 async fn execute(&self, args: Value) -> Result<ToolResult> {
244 let message = args
245 .get("message")
246 .and_then(|v| v.as_str())
247 .ok_or_else(|| HeliosError::ToolError("Missing 'message' parameter".to_string()))?;
248
249 Ok(ToolResult::success(format!("Echo: {}", message)))
250 }
251}
252
253pub struct FileSearchTool;
254
255#[async_trait]
256impl Tool for FileSearchTool {
257 fn name(&self) -> &str {
258 "file_search"
259 }
260
261 fn description(&self) -> &str {
262 "Search for files by name pattern or search for content within files. Can search recursively in directories."
263 }
264
265 fn parameters(&self) -> HashMap<String, ToolParameter> {
266 let mut params = HashMap::new();
267 params.insert(
268 "path".to_string(),
269 ToolParameter {
270 param_type: "string".to_string(),
271 description: "The directory path to search in (default: current directory)".to_string(),
272 required: Some(false),
273 },
274 );
275 params.insert(
276 "pattern".to_string(),
277 ToolParameter {
278 param_type: "string".to_string(),
279 description: "File name pattern to search for (supports wildcards like *.rs)".to_string(),
280 required: Some(false),
281 },
282 );
283 params.insert(
284 "content".to_string(),
285 ToolParameter {
286 param_type: "string".to_string(),
287 description: "Text content to search for within files".to_string(),
288 required: Some(false),
289 },
290 );
291 params.insert(
292 "max_results".to_string(),
293 ToolParameter {
294 param_type: "number".to_string(),
295 description: "Maximum number of results to return (default: 50)".to_string(),
296 required: Some(false),
297 },
298 );
299 params
300 }
301
302 async fn execute(&self, args: Value) -> Result<ToolResult> {
303 use walkdir::WalkDir;
304
305 let base_path = args
306 .get("path")
307 .and_then(|v| v.as_str())
308 .unwrap_or(".");
309
310 let pattern = args.get("pattern").and_then(|v| v.as_str());
311 let content_search = args.get("content").and_then(|v| v.as_str());
312 let max_results = args
313 .get("max_results")
314 .and_then(|v| v.as_u64())
315 .unwrap_or(50) as usize;
316
317 if pattern.is_none() && content_search.is_none() {
318 return Err(HeliosError::ToolError(
319 "Either 'pattern' or 'content' parameter is required".to_string(),
320 ));
321 }
322
323 let mut results = Vec::new();
324
325 for entry in WalkDir::new(base_path)
326 .max_depth(10)
327 .follow_links(false)
328 .into_iter()
329 .filter_map(|e| e.ok())
330 {
331 if results.len() >= max_results {
332 break;
333 }
334
335 let path = entry.path();
336
337 if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) {
339 if file_name.starts_with('.') ||
340 file_name == "target" ||
341 file_name == "node_modules" ||
342 file_name == "__pycache__" {
343 continue;
344 }
345 }
346
347 if let Some(pat) = pattern {
349 if path.is_file() {
350 if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) {
351 if glob_match(file_name, pat) {
352 results.push(format!("📄 {}", path.display()));
353 }
354 }
355 }
356 }
357
358 if let Some(search_term) = content_search {
360 if path.is_file() {
361 if let Ok(content) = std::fs::read_to_string(path) {
362 if content.contains(search_term) {
363 let matching_lines: Vec<(usize, &str)> = content
365 .lines()
366 .enumerate()
367 .filter(|(_, line)| line.contains(search_term))
368 .take(3) .collect();
370
371 if !matching_lines.is_empty() {
372 results.push(format!("📄 {} (found in {} lines)",
373 path.display(), matching_lines.len()));
374 for (line_num, line) in matching_lines {
375 results.push(format!(" Line {}: {}", line_num + 1, line.trim()));
376 }
377 }
378 }
379 }
380 }
381 }
382 }
383
384 if results.is_empty() {
385 Ok(ToolResult::success("No files found matching the criteria.".to_string()))
386 } else {
387 let output = format!(
388 "Found {} result(s):\n\n{}",
389 results.len(),
390 results.join("\n")
391 );
392 Ok(ToolResult::success(output))
393 }
394 }
395}
396
/// Returns `true` if all of `text` matches the shell-style glob `pattern`.
///
/// `*` matches any run of characters (including none) and `?` matches exactly
/// one character. Every other character matches itself literally, so names
/// containing `.`, `+`, `(`, `[` and so on need no escaping. Matching is
/// case-sensitive and works on Unicode scalar values.
fn glob_match(text: &str, pattern: &str) -> bool {
    let text: Vec<char> = text.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();

    let (mut t, mut p) = (0usize, 0usize);
    // (pattern index just after the most recent `*`, text index that `*` currently
    // extends to). Lets the matcher backtrack by making the star swallow one more char.
    let mut backtrack: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p) {
            Some(&'*') => {
                backtrack = Some((p + 1, t));
                p += 1;
            }
            Some(&c) if c == '?' || c == text[t] => {
                t += 1;
                p += 1;
            }
            _ => match backtrack {
                Some((after_star, star_end)) => {
                    p = after_star;
                    t = star_end + 1;
                    backtrack = Some((after_star, star_end + 1));
                }
                None => return false,
            },
        }
    }

    // Text exhausted: only trailing `*`s may remain in the pattern.
    pattern[p..].iter().all(|&c| c == '*')
}
410
411pub struct FileReadTool;
412
413#[async_trait]
414impl Tool for FileReadTool {
415 fn name(&self) -> &str {
416 "file_read"
417 }
418
419 fn description(&self) -> &str {
420 "Read the contents of a file. Returns the full file content or specific lines."
421 }
422
423 fn parameters(&self) -> HashMap<String, ToolParameter> {
424 let mut params = HashMap::new();
425 params.insert(
426 "path".to_string(),
427 ToolParameter {
428 param_type: "string".to_string(),
429 description: "The file path to read".to_string(),
430 required: Some(true),
431 },
432 );
433 params.insert(
434 "start_line".to_string(),
435 ToolParameter {
436 param_type: "number".to_string(),
437 description: "Starting line number (1-indexed, optional)".to_string(),
438 required: Some(false),
439 },
440 );
441 params.insert(
442 "end_line".to_string(),
443 ToolParameter {
444 param_type: "number".to_string(),
445 description: "Ending line number (1-indexed, optional)".to_string(),
446 required: Some(false),
447 },
448 );
449 params
450 }
451
452 async fn execute(&self, args: Value) -> Result<ToolResult> {
453 let file_path = args
454 .get("path")
455 .and_then(|v| v.as_str())
456 .ok_or_else(|| HeliosError::ToolError("Missing 'path' parameter".to_string()))?;
457
458 let content = std::fs::read_to_string(file_path)
459 .map_err(|e| HeliosError::ToolError(format!("Failed to read file: {}", e)))?;
460
461 let start_line = args.get("start_line").and_then(|v| v.as_u64()).map(|n| n as usize);
462 let end_line = args.get("end_line").and_then(|v| v.as_u64()).map(|n| n as usize);
463
464 let output = if let (Some(start), Some(end)) = (start_line, end_line) {
465 let lines: Vec<&str> = content.lines().collect();
466 let start_idx = start.saturating_sub(1);
467 let end_idx = end.min(lines.len());
468
469 if start_idx >= lines.len() {
470 return Err(HeliosError::ToolError(format!(
471 "Start line {} is beyond file length ({})",
472 start, lines.len()
473 )));
474 }
475
476 let selected_lines = &lines[start_idx..end_idx];
477 format!(
478 "File: {} (lines {}-{}):\n\n{}",
479 file_path,
480 start,
481 end_idx,
482 selected_lines.join("\n")
483 )
484 } else {
485 format!("File: {}:\n\n{}", file_path, content)
486 };
487
488 Ok(ToolResult::success(output))
489 }
490}
491
492pub struct FileWriteTool;
493
494#[async_trait]
495impl Tool for FileWriteTool {
496 fn name(&self) -> &str {
497 "file_write"
498 }
499
500 fn description(&self) -> &str {
501 "Write content to a file. Creates new file or overwrites existing file."
502 }
503
504 fn parameters(&self) -> HashMap<String, ToolParameter> {
505 let mut params = HashMap::new();
506 params.insert(
507 "path".to_string(),
508 ToolParameter {
509 param_type: "string".to_string(),
510 description: "The file path to write to".to_string(),
511 required: Some(true),
512 },
513 );
514 params.insert(
515 "content".to_string(),
516 ToolParameter {
517 param_type: "string".to_string(),
518 description: "The content to write to the file".to_string(),
519 required: Some(true),
520 },
521 );
522 params
523 }
524
525 async fn execute(&self, args: Value) -> Result<ToolResult> {
526 let file_path = args
527 .get("path")
528 .and_then(|v| v.as_str())
529 .ok_or_else(|| HeliosError::ToolError("Missing 'path' parameter".to_string()))?;
530
531 let content = args
532 .get("content")
533 .and_then(|v| v.as_str())
534 .ok_or_else(|| HeliosError::ToolError("Missing 'content' parameter".to_string()))?;
535
536 if let Some(parent) = std::path::Path::new(file_path).parent() {
538 std::fs::create_dir_all(parent)
539 .map_err(|e| HeliosError::ToolError(format!("Failed to create directories: {}", e)))?;
540 }
541
542 std::fs::write(file_path, content)
543 .map_err(|e| HeliosError::ToolError(format!("Failed to write file: {}", e)))?;
544
545 Ok(ToolResult::success(format!(
546 "Successfully wrote {} bytes to {}",
547 content.len(),
548 file_path
549 )))
550 }
551}
552
553pub struct FileEditTool;
554
555#[async_trait]
556impl Tool for FileEditTool {
557 fn name(&self) -> &str {
558 "file_edit"
559 }
560
561 fn description(&self) -> &str {
562 "Edit a file by replacing specific text or lines. Use this to make targeted changes to existing files."
563 }
564
565 fn parameters(&self) -> HashMap<String, ToolParameter> {
566 let mut params = HashMap::new();
567 params.insert(
568 "path".to_string(),
569 ToolParameter {
570 param_type: "string".to_string(),
571 description: "The file path to edit".to_string(),
572 required: Some(true),
573 },
574 );
575 params.insert(
576 "find".to_string(),
577 ToolParameter {
578 param_type: "string".to_string(),
579 description: "The text to find and replace".to_string(),
580 required: Some(true),
581 },
582 );
583 params.insert(
584 "replace".to_string(),
585 ToolParameter {
586 param_type: "string".to_string(),
587 description: "The replacement text".to_string(),
588 required: Some(true),
589 },
590 );
591 params
592 }
593
594 async fn execute(&self, args: Value) -> Result<ToolResult> {
595 let file_path = args
596 .get("path")
597 .and_then(|v| v.as_str())
598 .ok_or_else(|| HeliosError::ToolError("Missing 'path' parameter".to_string()))?;
599
600 let find_text = args
601 .get("find")
602 .and_then(|v| v.as_str())
603 .ok_or_else(|| HeliosError::ToolError("Missing 'find' parameter".to_string()))?;
604
605 let replace_text = args
606 .get("replace")
607 .and_then(|v| v.as_str())
608 .ok_or_else(|| HeliosError::ToolError("Missing 'replace' parameter".to_string()))?;
609
610 if find_text.is_empty() {
611 return Err(HeliosError::ToolError("'find' parameter cannot be empty".to_string()));
612 }
613
614 let path = Path::new(file_path);
615 let parent = path.parent().ok_or_else(|| {
616 HeliosError::ToolError(format!("Invalid target path: {}", file_path))
617 })?;
618 let file_name = path.file_name().ok_or_else(|| {
619 HeliosError::ToolError(format!("Invalid target path: {}", file_path))
620 })?;
621
622 let pid = std::process::id();
624 let nanos = SystemTime::now()
625 .duration_since(UNIX_EPOCH)
626 .map_err(|e| HeliosError::ToolError(format!("Clock error: {}", e)))?
627 .as_nanos();
628 let tmp_name = format!("{}.tmp.{}.{}", file_name.to_string_lossy(), pid, nanos);
629 let tmp_path = parent.join(tmp_name);
630
631 let input_file = std::fs::File::open(&path)
633 .map_err(|e| HeliosError::ToolError(format!("Failed to open file for read: {}", e)))?;
634 let mut reader = BufReader::new(input_file);
635
636 let tmp_file = std::fs::File::create(&tmp_path).map_err(|e| {
637 HeliosError::ToolError(format!("Failed to create temp file {}: {}", tmp_path.display(), e))
638 })?;
639 let mut writer = BufWriter::new(&tmp_file);
640
641 let replaced_count = replace_streaming(
643 &mut reader,
644 &mut writer,
645 find_text.as_bytes(),
646 replace_text.as_bytes(),
647 )
648 .map_err(|e| HeliosError::ToolError(format!("I/O error while replacing: {}", e)))?;
649
650 writer.flush().map_err(|e| HeliosError::ToolError(format!("Failed to flush temp file: {}", e)))?;
652 tmp_file.sync_all().map_err(|e| HeliosError::ToolError(format!("Failed to sync temp file: {}", e)))?;
653
654 if let Ok(meta) = std::fs::metadata(&path) {
656 if let Err(e) = std::fs::set_permissions(&tmp_path, meta.permissions()) {
657 let _ = std::fs::remove_file(&tmp_path);
658 return Err(HeliosError::ToolError(format!("Failed to set permissions: {}", e)));
659 }
660 }
661
662 std::fs::rename(&tmp_path, &path).map_err(|e| {
664 let _ = std::fs::remove_file(&tmp_path);
665 HeliosError::ToolError(format!("Failed to replace original file: {}", e))
666 })?;
667
668 if replaced_count == 0 {
669 return Ok(ToolResult::error(format!(
670 "Text '{}' not found in file {}",
671 find_text, file_path
672 )));
673 }
674
675 Ok(ToolResult::success(format!(
676 "Successfully replaced {} occurrence(s) in {}",
677 replaced_count, file_path
678 )))
679 }
680}
681
682fn replace_streaming<R: Read, W: Write>(reader: &mut R, writer: &mut W, needle: &[u8], replacement: &[u8]) -> std::io::Result<usize> {
684 let mut replaced = 0usize;
685 let mut carry: Vec<u8> = Vec::new();
686 let mut buf = [0u8; 8192];
687
688 let tail = if needle.len() > 1 { needle.len() - 1 } else { 0 };
689
690 loop {
691 let n = reader.read(&mut buf)?;
692 if n == 0 {
693 break;
694 }
695
696 let mut combined = Vec::with_capacity(carry.len() + n);
697 combined.extend_from_slice(&carry);
698 combined.extend_from_slice(&buf[..n]);
699
700 let process_len = combined.len().saturating_sub(tail);
701 let (to_process, new_carry) = combined.split_at(process_len);
702 replaced += write_with_replacements(writer, to_process, needle, replacement)?;
703 carry.clear();
704 carry.extend_from_slice(new_carry);
705 }
706
707 replaced += write_with_replacements(writer, &carry, needle, replacement)?;
709 Ok(replaced)
710}
711
712fn write_with_replacements<W: Write>(writer: &mut W, haystack: &[u8], needle: &[u8], replacement: &[u8]) -> std::io::Result<usize> {
713 if needle.is_empty() {
714 writer.write_all(haystack)?;
715 return Ok(0);
716 }
717
718 let mut count = 0usize;
719 let mut i = 0usize;
720 while let Some(pos) = find_subslice(&haystack[i..], needle) {
721 let idx = i + pos;
722 writer.write_all(&haystack[i..idx])?;
723 writer.write_all(replacement)?;
724 count += 1;
725 i = idx + needle.len();
726 }
727 writer.write_all(&haystack[i..])?;
728 Ok(count)
729}
730
/// Returns the byte offset of the first occurrence of `n` in `h`, or `None`.
/// An empty needle matches at offset 0.
fn find_subslice(h: &[u8], n: &[u8]) -> Option<usize> {
    match n.len() {
        // `windows(0)` would panic, so the empty needle is answered up front.
        0 => Some(0),
        width => h.windows(width).position(|window| window == n),
    }
}
737
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ToolResult constructors set the flag and keep the text verbatim.
    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
    }

    #[test]
    fn test_tool_result_error() {
        let result = ToolResult::error("test error");
        assert!(!result.success);
        assert_eq!(result.output, "test error");
    }

    // Calculator metadata plus a basic addition. An integral f64 result is
    // expected to render without a decimal point ("4", not "4.0").
    #[tokio::test]
    async fn test_calculator_tool() {
        let tool = CalculatorTool;
        assert_eq!(tool.name(), "calculator");
        assert_eq!(
            tool.description(),
            "Perform basic arithmetic operations. Supports +, -, *, / operations."
        );

        let args = json!({"expression": "2 + 2"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "4");
    }

    #[tokio::test]
    async fn test_calculator_tool_multiplication() {
        let tool = CalculatorTool;
        let args = json!({"expression": "3 * 4"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "12");
    }

    #[tokio::test]
    async fn test_calculator_tool_division() {
        let tool = CalculatorTool;
        let args = json!({"expression": "8 / 2"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "4");
    }

    // Division by zero surfaces as an Err, not as a failed ToolResult.
    #[tokio::test]
    async fn test_calculator_tool_division_by_zero() {
        let tool = CalculatorTool;
        let args = json!({"expression": "8 / 0"});
        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_calculator_tool_invalid_expression() {
        let tool = CalculatorTool;
        let args = json!({"expression": "invalid"});
        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_echo_tool() {
        let tool = EchoTool;
        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), "Echo back the provided message.");

        let args = json!({"message": "Hello, world!"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "Echo: Hello, world!");
    }

    // A missing required argument is reported as an Err.
    #[tokio::test]
    async fn test_echo_tool_missing_parameter() {
        let tool = EchoTool;
        let args = json!({});
        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    // Inspects the private `tools` map directly (allowed from this child module).
    #[test]
    fn test_tool_registry_new() {
        let registry = ToolRegistry::new();
        assert!(registry.tools.is_empty());
    }

    #[tokio::test]
    async fn test_tool_registry_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));

        let tool = registry.get("calculator");
        assert!(tool.is_some());
        assert_eq!(tool.unwrap().name(), "calculator");
    }

    #[tokio::test]
    async fn test_tool_registry_execute() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));

        let args = json!({"expression": "5 * 6"});
        let result = registry.execute("calculator", args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "30");
    }

    #[tokio::test]
    async fn test_tool_registry_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let args = json!({"expression": "5 * 6"});
        let result = registry.execute("nonexistent", args).await;
        assert!(result.is_err());
    }

    // Only membership is checked, so these tests do not depend on list order.
    #[test]
    fn test_tool_registry_get_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));
        registry.register(Box::new(EchoTool));

        let definitions = registry.get_definitions();
        assert_eq!(definitions.len(), 2);

        let names: Vec<String> = definitions
            .iter()
            .map(|d| d.function.name.clone())
            .collect();
        assert!(names.contains(&"calculator".to_string()));
        assert!(names.contains(&"echo".to_string()));
    }

    #[test]
    fn test_tool_registry_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));
        registry.register(Box::new(EchoTool));

        let tools = registry.list_tools();
        assert_eq!(tools.len(), 2);
        assert!(tools.contains(&"calculator".to_string()));
        assert!(tools.contains(&"echo".to_string()));
    }
}