helios_engine/
tools.rs

1use crate::error::{HeliosError, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::io::{BufReader, BufWriter, Read, Write};
7use std::path::Path;
8use std::time::{SystemTime, UNIX_EPOCH};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ToolParameter {
12    #[serde(rename = "type")]
13    pub param_type: String,
14    pub description: String,
15    #[serde(skip)]
16    pub required: Option<bool>,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ToolDefinition {
21    #[serde(rename = "type")]
22    pub tool_type: String,
23    pub function: FunctionDefinition,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct FunctionDefinition {
28    pub name: String,
29    pub description: String,
30    pub parameters: ParametersSchema,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ParametersSchema {
35    #[serde(rename = "type")]
36    pub schema_type: String,
37    pub properties: HashMap<String, ToolParameter>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub required: Option<Vec<String>>,
40}
41
/// Outcome of one tool invocation.
///
/// Both successful and failed outcomes carry their text in `output`; `success`
/// tells them apart.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// `true` when the tool completed normally.
    pub success: bool,
    /// Result text on success, or the failure message otherwise.
    pub output: String,
}

impl ToolResult {
    /// Builds a successful outcome whose text is `output`.
    pub fn success(output: impl Into<String>) -> Self {
        Self::with_status(true, output.into())
    }

    /// Builds a failed outcome whose text is `message`.
    pub fn error(message: impl Into<String>) -> Self {
        Self::with_status(false, message.into())
    }

    /// Shared constructor behind [`ToolResult::success`] and [`ToolResult::error`].
    fn with_status(success: bool, output: String) -> Self {
        Self { success, output }
    }
}
63
64#[async_trait]
65pub trait Tool: Send + Sync {
66    fn name(&self) -> &str;
67    fn description(&self) -> &str;
68    fn parameters(&self) -> HashMap<String, ToolParameter>;
69    async fn execute(&self, args: Value) -> Result<ToolResult>;
70
71    fn to_definition(&self) -> ToolDefinition {
72        let required: Vec<String> = self
73            .parameters()
74            .iter()
75            .filter(|(_, param)| param.required.unwrap_or(false))
76            .map(|(name, _)| name.clone())
77            .collect();
78
79        ToolDefinition {
80            tool_type: "function".to_string(),
81            function: FunctionDefinition {
82                name: self.name().to_string(),
83                description: self.description().to_string(),
84                parameters: ParametersSchema {
85                    schema_type: "object".to_string(),
86                    properties: self.parameters(),
87                    required: if required.is_empty() {
88                        None
89                    } else {
90                        Some(required)
91                    },
92                },
93            },
94        }
95    }
96}
97
98pub struct ToolRegistry {
99    tools: HashMap<String, Box<dyn Tool>>,
100}
101
102impl ToolRegistry {
103    pub fn new() -> Self {
104        Self {
105            tools: HashMap::new(),
106        }
107    }
108
109    pub fn register(&mut self, tool: Box<dyn Tool>) {
110        let name = tool.name().to_string();
111        self.tools.insert(name, tool);
112    }
113
114    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
115        self.tools.get(name).map(|b| &**b)
116    }
117
118    pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
119        let tool = self
120            .tools
121            .get(name)
122            .ok_or_else(|| HeliosError::ToolError(format!("Tool '{}' not found", name)))?;
123
124        tool.execute(args).await
125    }
126
127    pub fn get_definitions(&self) -> Vec<ToolDefinition> {
128        self.tools
129            .values()
130            .map(|tool| tool.to_definition())
131            .collect()
132    }
133
134    pub fn list_tools(&self) -> Vec<String> {
135        self.tools.keys().cloned().collect()
136    }
137}
138
139impl Default for ToolRegistry {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145// Example built-in tools
146
147pub struct CalculatorTool;
148
149#[async_trait]
150impl Tool for CalculatorTool {
151    fn name(&self) -> &str {
152        "calculator"
153    }
154
155    fn description(&self) -> &str {
156        "Perform basic arithmetic operations. Supports +, -, *, / operations."
157    }
158
159    fn parameters(&self) -> HashMap<String, ToolParameter> {
160        let mut params = HashMap::new();
161        params.insert(
162            "expression".to_string(),
163            ToolParameter {
164                param_type: "string".to_string(),
165                description: "Mathematical expression to evaluate (e.g., '2 + 2')".to_string(),
166                required: Some(true),
167            },
168        );
169        params
170    }
171
172    async fn execute(&self, args: Value) -> Result<ToolResult> {
173        let expression = args
174            .get("expression")
175            .and_then(|v| v.as_str())
176            .ok_or_else(|| HeliosError::ToolError("Missing 'expression' parameter".to_string()))?;
177
178        // Simple expression evaluator
179        let result = evaluate_expression(expression)?;
180        Ok(ToolResult::success(result.to_string()))
181    }
182}
183
184fn evaluate_expression(expr: &str) -> Result<f64> {
185    let expr = expr.replace(" ", "");
186
187    // Simple parsing for basic operations
188    for op in &['*', '/', '+', '-'] {
189        if let Some(pos) = expr.rfind(*op) {
190            if pos == 0 {
191                continue; // Skip if operator is at the beginning (negative number)
192            }
193            let left = &expr[..pos];
194            let right = &expr[pos + 1..];
195
196            let left_val = evaluate_expression(left)?;
197            let right_val = evaluate_expression(right)?;
198
199            return Ok(match op {
200                '+' => left_val + right_val,
201                '-' => left_val - right_val,
202                '*' => left_val * right_val,
203                '/' => {
204                    if right_val == 0.0 {
205                        return Err(HeliosError::ToolError("Division by zero".to_string()));
206                    }
207                    left_val / right_val
208                }
209                _ => unreachable!(),
210            });
211        }
212    }
213
214    expr.parse::<f64>()
215        .map_err(|_| HeliosError::ToolError(format!("Invalid expression: {}", expr)))
216}
217
218pub struct EchoTool;
219
220#[async_trait]
221impl Tool for EchoTool {
222    fn name(&self) -> &str {
223        "echo"
224    }
225
226    fn description(&self) -> &str {
227        "Echo back the provided message."
228    }
229
230    fn parameters(&self) -> HashMap<String, ToolParameter> {
231        let mut params = HashMap::new();
232        params.insert(
233            "message".to_string(),
234            ToolParameter {
235                param_type: "string".to_string(),
236                description: "The message to echo back".to_string(),
237                required: Some(true),
238            },
239        );
240        params
241    }
242
243    async fn execute(&self, args: Value) -> Result<ToolResult> {
244        let message = args
245            .get("message")
246            .and_then(|v| v.as_str())
247            .ok_or_else(|| HeliosError::ToolError("Missing 'message' parameter".to_string()))?;
248
249        Ok(ToolResult::success(format!("Echo: {}", message)))
250    }
251}
252
253pub struct FileSearchTool;
254
255#[async_trait]
256impl Tool for FileSearchTool {
257    fn name(&self) -> &str {
258        "file_search"
259    }
260
261    fn description(&self) -> &str {
262        "Search for files by name pattern or search for content within files. Can search recursively in directories."
263    }
264
265    fn parameters(&self) -> HashMap<String, ToolParameter> {
266        let mut params = HashMap::new();
267        params.insert(
268            "path".to_string(),
269            ToolParameter {
270                param_type: "string".to_string(),
271                description: "The directory path to search in (default: current directory)".to_string(),
272                required: Some(false),
273            },
274        );
275        params.insert(
276            "pattern".to_string(),
277            ToolParameter {
278                param_type: "string".to_string(),
279                description: "File name pattern to search for (supports wildcards like *.rs)".to_string(),
280                required: Some(false),
281            },
282        );
283        params.insert(
284            "content".to_string(),
285            ToolParameter {
286                param_type: "string".to_string(),
287                description: "Text content to search for within files".to_string(),
288                required: Some(false),
289            },
290        );
291        params.insert(
292            "max_results".to_string(),
293            ToolParameter {
294                param_type: "number".to_string(),
295                description: "Maximum number of results to return (default: 50)".to_string(),
296                required: Some(false),
297            },
298        );
299        params
300    }
301
302    async fn execute(&self, args: Value) -> Result<ToolResult> {
303        use walkdir::WalkDir;
304
305        let base_path = args
306            .get("path")
307            .and_then(|v| v.as_str())
308            .unwrap_or(".");
309        
310        let pattern = args.get("pattern").and_then(|v| v.as_str());
311        let content_search = args.get("content").and_then(|v| v.as_str());
312        let max_results = args
313            .get("max_results")
314            .and_then(|v| v.as_u64())
315            .unwrap_or(50) as usize;
316
317        if pattern.is_none() && content_search.is_none() {
318            return Err(HeliosError::ToolError(
319                "Either 'pattern' or 'content' parameter is required".to_string(),
320            ));
321        }
322
323        let mut results = Vec::new();
324        
325        for entry in WalkDir::new(base_path)
326            .max_depth(10)
327            .follow_links(false)
328            .into_iter()
329            .filter_map(|e| e.ok())
330        {
331            if results.len() >= max_results {
332                break;
333            }
334
335            let path = entry.path();
336            
337            // Skip hidden files and common ignore directories
338            if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) {
339                if file_name.starts_with('.') || 
340                   file_name == "target" || 
341                   file_name == "node_modules" ||
342                   file_name == "__pycache__" {
343                    continue;
344                }
345            }
346
347            // Pattern matching for file names
348            if let Some(pat) = pattern {
349                if path.is_file() {
350                    if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) {
351                        if glob_match(file_name, pat) {
352                            results.push(format!("📄 {}", path.display()));
353                        }
354                    }
355                }
356            }
357
358            // Content search within files
359            if let Some(search_term) = content_search {
360                if path.is_file() {
361                    if let Ok(content) = std::fs::read_to_string(path) {
362                        if content.contains(search_term) {
363                            // Find line numbers where content appears
364                            let matching_lines: Vec<(usize, &str)> = content
365                                .lines()
366                                .enumerate()
367                                .filter(|(_, line)| line.contains(search_term))
368                                .take(3) // Show up to 3 matching lines per file
369                                .collect();
370                            
371                            if !matching_lines.is_empty() {
372                                results.push(format!("📄 {} (found in {} lines)", 
373                                    path.display(), matching_lines.len()));
374                                for (line_num, line) in matching_lines {
375                                    results.push(format!("  Line {}: {}", line_num + 1, line.trim()));
376                                }
377                            }
378                        }
379                    }
380                }
381            }
382        }
383
384        if results.is_empty() {
385            Ok(ToolResult::success("No files found matching the criteria.".to_string()))
386        } else {
387            let output = format!(
388                "Found {} result(s):\n\n{}",
389                results.len(),
390                results.join("\n")
391            );
392            Ok(ToolResult::success(output))
393        }
394    }
395}
396
// Simple glob matching helper

/// Returns whether all of `text` matches the shell-style wildcard `pattern`.
///
/// `*` matches any run of characters (including none) and `?` matches exactly
/// one character. Every other character matches only itself, including regex
/// metacharacters such as `(`, `+`, `[` or `$`. Matching is case-sensitive.
fn glob_match(text: &str, pattern: &str) -> bool {
    let text: Vec<char> = text.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();

    let (mut t, mut p) = (0usize, 0usize);
    // Most recent `*` in `pattern` and the text position it currently extends
    // to; on a mismatch that `*` is made to swallow one more character.
    let mut backtrack: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p) {
            Some(&'*') => {
                backtrack = Some((p, t));
                p += 1;
            }
            Some(&c) if c == '?' || c == text[t] => {
                p += 1;
                t += 1;
            }
            _ => match backtrack {
                Some((star_p, star_t)) => {
                    backtrack = Some((star_p, star_t + 1));
                    p = star_p + 1;
                    t = star_t + 1;
                }
                None => return false,
            },
        }
    }

    // All text consumed: whatever is left of the pattern must match nothing.
    pattern[p..].iter().all(|&c| c == '*')
}
410
411pub struct FileReadTool;
412
413#[async_trait]
414impl Tool for FileReadTool {
415    fn name(&self) -> &str {
416        "file_read"
417    }
418
419    fn description(&self) -> &str {
420        "Read the contents of a file. Returns the full file content or specific lines."
421    }
422
423    fn parameters(&self) -> HashMap<String, ToolParameter> {
424        let mut params = HashMap::new();
425        params.insert(
426            "path".to_string(),
427            ToolParameter {
428                param_type: "string".to_string(),
429                description: "The file path to read".to_string(),
430                required: Some(true),
431            },
432        );
433        params.insert(
434            "start_line".to_string(),
435            ToolParameter {
436                param_type: "number".to_string(),
437                description: "Starting line number (1-indexed, optional)".to_string(),
438                required: Some(false),
439            },
440        );
441        params.insert(
442            "end_line".to_string(),
443            ToolParameter {
444                param_type: "number".to_string(),
445                description: "Ending line number (1-indexed, optional)".to_string(),
446                required: Some(false),
447            },
448        );
449        params
450    }
451
452    async fn execute(&self, args: Value) -> Result<ToolResult> {
453        let file_path = args
454            .get("path")
455            .and_then(|v| v.as_str())
456            .ok_or_else(|| HeliosError::ToolError("Missing 'path' parameter".to_string()))?;
457
458        let content = std::fs::read_to_string(file_path)
459            .map_err(|e| HeliosError::ToolError(format!("Failed to read file: {}", e)))?;
460
461        let start_line = args.get("start_line").and_then(|v| v.as_u64()).map(|n| n as usize);
462        let end_line = args.get("end_line").and_then(|v| v.as_u64()).map(|n| n as usize);
463
464        let output = if let (Some(start), Some(end)) = (start_line, end_line) {
465            let lines: Vec<&str> = content.lines().collect();
466            let start_idx = start.saturating_sub(1);
467            let end_idx = end.min(lines.len());
468            
469            if start_idx >= lines.len() {
470                return Err(HeliosError::ToolError(format!(
471                    "Start line {} is beyond file length ({})",
472                    start, lines.len()
473                )));
474            }
475            
476            let selected_lines = &lines[start_idx..end_idx];
477            format!(
478                "File: {} (lines {}-{}):\n\n{}",
479                file_path,
480                start,
481                end_idx,
482                selected_lines.join("\n")
483            )
484        } else {
485            format!("File: {}:\n\n{}", file_path, content)
486        };
487
488        Ok(ToolResult::success(output))
489    }
490}
491
492pub struct FileWriteTool;
493
494#[async_trait]
495impl Tool for FileWriteTool {
496    fn name(&self) -> &str {
497        "file_write"
498    }
499
500    fn description(&self) -> &str {
501        "Write content to a file. Creates new file or overwrites existing file."
502    }
503
504    fn parameters(&self) -> HashMap<String, ToolParameter> {
505        let mut params = HashMap::new();
506        params.insert(
507            "path".to_string(),
508            ToolParameter {
509                param_type: "string".to_string(),
510                description: "The file path to write to".to_string(),
511                required: Some(true),
512            },
513        );
514        params.insert(
515            "content".to_string(),
516            ToolParameter {
517                param_type: "string".to_string(),
518                description: "The content to write to the file".to_string(),
519                required: Some(true),
520            },
521        );
522        params
523    }
524
525    async fn execute(&self, args: Value) -> Result<ToolResult> {
526        let file_path = args
527            .get("path")
528            .and_then(|v| v.as_str())
529            .ok_or_else(|| HeliosError::ToolError("Missing 'path' parameter".to_string()))?;
530
531        let content = args
532            .get("content")
533            .and_then(|v| v.as_str())
534            .ok_or_else(|| HeliosError::ToolError("Missing 'content' parameter".to_string()))?;
535
536        // Create parent directories if they don't exist
537        if let Some(parent) = std::path::Path::new(file_path).parent() {
538            std::fs::create_dir_all(parent)
539                .map_err(|e| HeliosError::ToolError(format!("Failed to create directories: {}", e)))?;
540        }
541
542        std::fs::write(file_path, content)
543            .map_err(|e| HeliosError::ToolError(format!("Failed to write file: {}", e)))?;
544
545        Ok(ToolResult::success(format!(
546            "Successfully wrote {} bytes to {}",
547            content.len(),
548            file_path
549        )))
550    }
551}
552
553pub struct FileEditTool;
554
555#[async_trait]
556impl Tool for FileEditTool {
557    fn name(&self) -> &str {
558        "file_edit"
559    }
560
561    fn description(&self) -> &str {
562        "Edit a file by replacing specific text or lines. Use this to make targeted changes to existing files."
563    }
564
565    fn parameters(&self) -> HashMap<String, ToolParameter> {
566        let mut params = HashMap::new();
567        params.insert(
568            "path".to_string(),
569            ToolParameter {
570                param_type: "string".to_string(),
571                description: "The file path to edit".to_string(),
572                required: Some(true),
573            },
574        );
575        params.insert(
576            "find".to_string(),
577            ToolParameter {
578                param_type: "string".to_string(),
579                description: "The text to find and replace".to_string(),
580                required: Some(true),
581            },
582        );
583        params.insert(
584            "replace".to_string(),
585            ToolParameter {
586                param_type: "string".to_string(),
587                description: "The replacement text".to_string(),
588                required: Some(true),
589            },
590        );
591        params
592    }
593
594    async fn execute(&self, args: Value) -> Result<ToolResult> {
595        let file_path = args
596            .get("path")
597            .and_then(|v| v.as_str())
598            .ok_or_else(|| HeliosError::ToolError("Missing 'path' parameter".to_string()))?;
599
600        let find_text = args
601            .get("find")
602            .and_then(|v| v.as_str())
603            .ok_or_else(|| HeliosError::ToolError("Missing 'find' parameter".to_string()))?;
604
605        let replace_text = args
606            .get("replace")
607            .and_then(|v| v.as_str())
608            .ok_or_else(|| HeliosError::ToolError("Missing 'replace' parameter".to_string()))?;
609
610        if find_text.is_empty() {
611            return Err(HeliosError::ToolError("'find' parameter cannot be empty".to_string()));
612        }
613
614        let path = Path::new(file_path);
615        let parent = path.parent().ok_or_else(|| {
616            HeliosError::ToolError(format!("Invalid target path: {}", file_path))
617        })?;
618        let file_name = path.file_name().ok_or_else(|| {
619            HeliosError::ToolError(format!("Invalid target path: {}", file_path))
620        })?;
621
622        // Build a temp file path in the same directory for atomic rename
623        let pid = std::process::id();
624        let nanos = SystemTime::now()
625            .duration_since(UNIX_EPOCH)
626            .map_err(|e| HeliosError::ToolError(format!("Clock error: {}", e)))?
627            .as_nanos();
628        let tmp_name = format!("{}.tmp.{}.{}", file_name.to_string_lossy(), pid, nanos);
629        let tmp_path = parent.join(tmp_name);
630
631        // Open files
632        let input_file = std::fs::File::open(&path)
633            .map_err(|e| HeliosError::ToolError(format!("Failed to open file for read: {}", e)))?;
634        let mut reader = BufReader::new(input_file);
635
636        let tmp_file = std::fs::File::create(&tmp_path).map_err(|e| {
637            HeliosError::ToolError(format!("Failed to create temp file {}: {}", tmp_path.display(), e))
638        })?;
639        let mut writer = BufWriter::new(&tmp_file);
640
641        // Streamed find/replace to avoid loading entire file into memory
642        let replaced_count = replace_streaming(
643            &mut reader,
644            &mut writer,
645            find_text.as_bytes(),
646            replace_text.as_bytes(),
647        )
648        .map_err(|e| HeliosError::ToolError(format!("I/O error while replacing: {}", e)))?;
649
650        // Ensure all data is flushed and synced before rename
651        writer.flush().map_err(|e| HeliosError::ToolError(format!("Failed to flush temp file: {}", e)))?;
652        tmp_file.sync_all().map_err(|e| HeliosError::ToolError(format!("Failed to sync temp file: {}", e)))?;
653
654        // Preserve permissions
655        if let Ok(meta) = std::fs::metadata(&path) {
656            if let Err(e) = std::fs::set_permissions(&tmp_path, meta.permissions()) {
657                let _ = std::fs::remove_file(&tmp_path);
658                return Err(HeliosError::ToolError(format!("Failed to set permissions: {}", e)));
659            }
660        }
661
662        // Atomic replace
663        std::fs::rename(&tmp_path, &path).map_err(|e| {
664            let _ = std::fs::remove_file(&tmp_path);
665            HeliosError::ToolError(format!("Failed to replace original file: {}", e))
666        })?;
667
668        if replaced_count == 0 {
669            return Ok(ToolResult::error(format!(
670                "Text '{}' not found in file {}",
671                find_text, file_path
672            )));
673        }
674
675        Ok(ToolResult::success(format!(
676            "Successfully replaced {} occurrence(s) in {}",
677            replaced_count, file_path
678        )))
679    }
680}
681
// Streamed replacement helpers

/// Copies `reader` to `writer`, replacing every non-overlapping occurrence of
/// `needle` with `replacement`, and returns the number of replacements.
///
/// Input is read in 8 KiB chunks, so memory use stays bounded regardless of
/// input size. Matches are found left to right exactly as a whole-buffer scan
/// would find them, including matches that straddle a chunk boundary or end at
/// EOF. An empty `needle` copies the input unchanged and reports 0.
///
/// # Errors
/// Propagates read and write errors; reads failing with `Interrupted` are retried.
fn replace_streaming<R: Read, W: Write>(
    reader: &mut R,
    writer: &mut W,
    needle: &[u8],
    replacement: &[u8],
) -> std::io::Result<usize> {
    let mut buf = [0u8; 8192];
    // Bytes read but not yet written. Between reads it holds at most
    // `needle.len() - 1` bytes: a possible prefix of a match that the next
    // chunk may complete.
    let mut pending: Vec<u8> = Vec::new();
    let keep = needle.len().saturating_sub(1);
    let mut replaced = 0usize;

    loop {
        let n = match reader.read(&mut buf) {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        pending.extend_from_slice(&buf[..n]);

        // Search the whole pending buffer first, then decide what is safe to emit.
        let (count, matched_end) = write_through_last_match(writer, &pending, needle, replacement)?;
        replaced += count;

        // Bytes before `flush_to` are already written or cannot start a match
        // that a later chunk completes.
        let flush_to = matched_end.max(pending.len().saturating_sub(keep));
        writer.write_all(&pending[matched_end..flush_to])?;
        pending.drain(..flush_to);
    }

    // `pending` was scanned after the last read and holds no complete match,
    // so this simply emits the remaining bytes.
    replaced += write_with_replacements(writer, &pending, needle, replacement)?;
    Ok(replaced)
}

/// Writes `haystack` up to the end of its last `needle` match, substituting
/// `replacement` for each non-overlapping match (scanning left to right).
/// Returns `(matches, end)`, where `end` is the index just past the last match
/// (0 when there is none); bytes from `end` onward are not written. An empty
/// `needle` matches nothing.
fn write_through_last_match<W: Write>(
    writer: &mut W,
    haystack: &[u8],
    needle: &[u8],
    replacement: &[u8],
) -> std::io::Result<(usize, usize)> {
    if needle.is_empty() {
        return Ok((0, 0));
    }

    let mut count = 0usize;
    let mut start = 0usize;
    while let Some(offset) = find_subslice(&haystack[start..], needle) {
        let idx = start + offset;
        writer.write_all(&haystack[start..idx])?;
        writer.write_all(replacement)?;
        count += 1;
        start = idx + needle.len();
    }
    Ok((count, start))
}

/// Writes all of `haystack`, replacing every non-overlapping `needle` with
/// `replacement`, and returns the number of replacements. An empty `needle`
/// writes `haystack` unchanged and returns 0.
fn write_with_replacements<W: Write>(
    writer: &mut W,
    haystack: &[u8],
    needle: &[u8],
    replacement: &[u8],
) -> std::io::Result<usize> {
    let (count, end) = write_through_last_match(writer, haystack, needle, replacement)?;
    writer.write_all(&haystack[end..])?;
    Ok(count)
}

/// Index of the first occurrence of `n` in `h`; an empty `n` matches at 0.
fn find_subslice(h: &[u8], n: &[u8]) -> Option<usize> {
    if n.is_empty() {
        return Some(0);
    }
    h.windows(n.len()).position(|w| w == n)
}
737
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs `CalculatorTool` on `expression` and returns the raw outcome.
    async fn calculate(expression: &str) -> crate::error::Result<ToolResult> {
        CalculatorTool
            .execute(json!({ "expression": expression }))
            .await
    }

    #[test]
    fn test_tool_result_success() {
        let ToolResult { success, output } = ToolResult::success("test output");
        assert!(success);
        assert_eq!(output, "test output");
    }

    #[test]
    fn test_tool_result_error() {
        let ToolResult { success, output } = ToolResult::error("test error");
        assert!(!success);
        assert_eq!(output, "test error");
    }

    #[tokio::test]
    async fn test_calculator_tool() {
        let tool = CalculatorTool;
        assert_eq!(tool.name(), "calculator");
        assert_eq!(
            tool.description(),
            "Perform basic arithmetic operations. Supports +, -, *, / operations."
        );

        let outcome = calculate("2 + 2").await.unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.output, "4");
    }

    #[tokio::test]
    async fn test_calculator_tool_multiplication() {
        let outcome = calculate("3 * 4").await.unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.output, "12");
    }

    #[tokio::test]
    async fn test_calculator_tool_division() {
        let outcome = calculate("8 / 2").await.unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.output, "4");
    }

    #[tokio::test]
    async fn test_calculator_tool_division_by_zero() {
        assert!(calculate("8 / 0").await.is_err());
    }

    #[tokio::test]
    async fn test_calculator_tool_invalid_expression() {
        assert!(calculate("invalid").await.is_err());
    }

    #[tokio::test]
    async fn test_echo_tool() {
        let echo = EchoTool;
        assert_eq!(echo.name(), "echo");
        assert_eq!(echo.description(), "Echo back the provided message.");

        let outcome = echo
            .execute(json!({ "message": "Hello, world!" }))
            .await
            .unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.output, "Echo: Hello, world!");
    }

    #[tokio::test]
    async fn test_echo_tool_missing_parameter() {
        assert!(EchoTool.execute(json!({})).await.is_err());
    }

    #[test]
    fn test_tool_registry_new() {
        let registry = ToolRegistry::new();
        assert!(registry.tools.is_empty());
    }

    #[test]
    fn test_tool_registry_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));

        let found = registry.get("calculator").expect("calculator should be registered");
        assert_eq!(found.name(), "calculator");
    }

    #[tokio::test]
    async fn test_tool_registry_execute() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));

        let outcome = registry
            .execute("calculator", json!({ "expression": "5 * 6" }))
            .await
            .unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.output, "30");
    }

    #[tokio::test]
    async fn test_tool_registry_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let outcome = registry
            .execute("nonexistent", json!({ "expression": "5 * 6" }))
            .await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_tool_registry_get_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));
        registry.register(Box::new(EchoTool));

        let names: Vec<String> = registry
            .get_definitions()
            .into_iter()
            .map(|definition| definition.function.name)
            .collect();
        assert_eq!(names.len(), 2);
        assert!(names.iter().any(|name| name == "calculator"));
        assert!(names.iter().any(|name| name == "echo"));
    }

    #[test]
    fn test_tool_registry_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));
        registry.register(Box::new(EchoTool));

        let listed = registry.list_tools();
        assert_eq!(listed.len(), 2);
        assert!(listed.iter().any(|name| name == "calculator"));
        assert!(listed.iter().any(|name| name == "echo"));
    }
}