1use std::collections::HashMap;
2use abu_base::chat::ToolDefinition;
3use super::{Tool, ToolCallResult, ToolError, ToolResult};
4
5pub struct ToolRegister {
6 tools: HashMap<&'static str , Box<dyn Tool>>,
7}
8
9impl ToolRegister {
10 pub fn new() -> Self {
11 Self { tools: HashMap::new() }
12 }
13
14 pub fn init<I: Iterator<Item = Box<dyn Tool>>>(tools: I) -> Self {
15 Self {
16 tools: tools.map(|tool| (tool.name(), tool)).collect()
17 }
18 }
19
20 #[inline]
21 pub fn len(&self) -> usize {
22 self.tools.len()
23 }
24
25 pub fn get_tool(&self, name: &str) -> Option<&Box<dyn Tool>> {
26 self.tools.get(name)
27 }
28
29 pub fn add_tool<T: Tool + 'static>(&mut self, tool: T) {
30 let tool = Box::new(tool);
31 self.add_tool_box(tool);
32 }
33
34 pub fn add_tool_box(&mut self, tool: Box<dyn Tool>) {
35 let name = tool.name();
36 self.tools.insert(name, tool);
37 }
38
39 pub fn has_tool(&self, tool_name: &str) -> bool {
40 self.tools.contains_key(tool_name)
41 }
42
43 pub async fn execute(&self, name: String, arguments: serde_json::Value) -> ToolResult<ToolCallResult> {
45 let tool = self.get_tool(&name).ok_or_else(|| ToolError::ToolNotFound(name))?;
46 let value = tool.execute(arguments).await?;
47 Ok(value)
48 }
49
50 pub fn to_function_defines(&self) -> Vec<ToolDefinition> {
51 self.tools.iter().map(|(_, tool)| tool.to_function_define()).collect()
52 }
53
54 pub fn add_tools<I: Iterator<Item = Box<dyn Tool>>>(&mut self, tools: I) {
55 for tool in tools {
56 self.add_tool_box(tool);
57 }
58 }
59}