Skip to main content

abu_tool/
register.rs

1use std::collections::HashMap;
2use abu_base::chat::ToolDefinition;
3use super::{Tool, ToolCallResult, ToolError, ToolResult};
4
5pub struct ToolRegister {
6    tools: HashMap<&'static str , Box<dyn Tool>>,
7}
8
9impl ToolRegister {
10    pub fn new() -> Self {
11        Self { tools: HashMap::new() }
12    }
13
14    pub fn init<I: Iterator<Item = Box<dyn Tool>>>(tools: I) -> Self {
15        Self {
16            tools: tools.map(|tool| (tool.name(), tool)).collect()
17        }
18    }
19
20    #[inline]
21    pub fn len(&self) -> usize {
22        self.tools.len()
23    }
24
25    pub fn get_tool(&self, name: &str) -> Option<&Box<dyn Tool>> {
26        self.tools.get(name)
27    }
28
29    pub fn add_tool<T: Tool + 'static>(&mut self, tool: T) {
30        let tool = Box::new(tool);
31        self.add_tool_box(tool);
32    }
33
34    pub fn add_tool_box(&mut self, tool: Box<dyn Tool>) {
35        let name = tool.name();
36        self.tools.insert(name, tool);
37    }
38
39    pub fn has_tool(&self, tool_name: &str) -> bool {
40        self.tools.contains_key(tool_name)
41    }
42
43    /// Return tool execute error if tool inner error
44    pub async fn execute(&self, name: &str, arguments: serde_json::Value) -> ToolResult<ToolCallResult> {
45        let tool = self.get_tool(name).ok_or_else(|| ToolError::ToolNotFound(name.to_string()))?;
46        let value = tool.execute(arguments).await?;
47        Ok(value)
48    }
49
50    pub fn to_function_defines(&self) -> Vec<ToolDefinition> {
51        self.tools.iter().map(|(_, tool)| tool.to_function_define()).collect()
52    }
53
54    pub fn add_tools<I: Iterator<Item = Box<dyn Tool>>>(&mut self, tools: I) {
55        for tool in tools {
56            self.add_tool_box(tool);
57        }
58    }
59}