Skip to main content

sh_layer2/
tool_registry.rs

//! # Tool Registry
//!
//! 工具注册和发现机制。(Tool registration and discovery.)

5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::types::{Layer2Error, Layer2Result, ToolResult};
11
12/// 工具元数据
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolMeta {
15    pub name: String,
16    pub description: String,
17    pub parameters: serde_json::Value,
18    pub required: Vec<String>,
19}
20
21/// 工具调用请求
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ToolRequest {
24    pub tool_call_id: String,
25    pub name: String,
26    pub arguments: serde_json::Value,
27}
28
29/// 工具定义(OpenAI 格式)
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct ToolDefinition {
32    pub r#type: String,
33    pub function: FunctionDefinition,
34}
35
36/// 函数定义
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct FunctionDefinition {
39    pub name: String,
40    pub description: String,
41    pub parameters: serde_json::Value,
42}
43
44/// 工具接口
45///
46/// 所有工具必须实现此接口。
47#[async_trait]
48pub trait Tool: Send + Sync {
49    /// 获取工具名称
50    fn name(&self) -> &str;
51
52    /// 获取工具描述
53    fn description(&self) -> &str;
54
55    /// 获取参数 schema
56    fn parameters(&self) -> serde_json::Value;
57
58    /// 执行工具
59    async fn execute(&self, args: &str) -> Layer2Result<ToolResult>;
60
61    /// 验证参数
62    fn validate_args(&self, _args: &serde_json::Value) -> Layer2Result<bool> {
63        // 默认实现:总是返回 true
64        Ok(true)
65    }
66}
67
68/// 工具注册接口
69#[async_trait]
70pub trait ToolRegistryTrait: Send + Sync {
71    /// 注册工具
72    fn register(&self, tool: Box<dyn Tool>) -> Layer2Result<()>;
73
74    /// 注销工具
75    fn unregister(&self, name: &str) -> Layer2Result<bool>;
76
77    /// 获取工具
78    fn get(&self, name: &str) -> Option<Arc<dyn Tool>>;
79
80    /// 检查工具是否存在
81    fn exists(&self, name: &str) -> bool;
82
83    /// 列出所有工具名称
84    fn list(&self) -> Vec<String>;
85
86    /// 获取所有工具定义(OpenAI 格式)
87    fn definitions(&self) -> Vec<ToolDefinition>;
88
89    /// 执行工具
90    async fn execute(&self, name: &str, args: &str) -> Layer2Result<ToolResult>;
91
92    /// 获取工具数量
93    fn count(&self) -> usize;
94}
95
96/// 工具注册表实现
97pub struct ToolRegistry {
98    tools: parking_lot::RwLock<HashMap<String, Arc<dyn Tool>>>,
99}
100
101impl ToolRegistry {
102    pub fn new() -> Self {
103        Self {
104            tools: parking_lot::RwLock::new(HashMap::new()),
105        }
106    }
107
108    /// 创建带内置工具的注册表
109    pub fn with_builtin_tools() -> Self {
110        Self::new()
111        // 内置工具将在 Layer 3 实现
112    }
113}
114
115impl Default for ToolRegistry {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
121#[async_trait]
122impl ToolRegistryTrait for ToolRegistry {
123    fn register(&self, tool: Box<dyn Tool>) -> Layer2Result<()> {
124        let mut tools = self.tools.write();
125        let name = tool.name().to_string();
126        tools.insert(name, Arc::from(tool));
127        Ok(())
128    }
129
130    fn unregister(&self, name: &str) -> Layer2Result<bool> {
131        let mut tools = self.tools.write();
132        Ok(tools.remove(name).is_some())
133    }
134
135    fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
136        let tools = self.tools.read();
137        tools.get(name).cloned()
138    }
139
140    fn exists(&self, name: &str) -> bool {
141        let tools = self.tools.read();
142        tools.contains_key(name)
143    }
144
145    fn list(&self) -> Vec<String> {
146        let tools = self.tools.read();
147        tools.keys().cloned().collect()
148    }
149
150    fn definitions(&self) -> Vec<ToolDefinition> {
151        let tools = self.tools.read();
152        tools
153            .values()
154            .map(|tool| ToolDefinition {
155                r#type: "function".to_string(),
156                function: FunctionDefinition {
157                    name: tool.name().to_string(),
158                    description: tool.description().to_string(),
159                    parameters: tool.parameters(),
160                },
161            })
162            .collect()
163    }
164
165    async fn execute(&self, name: &str, args: &str) -> Layer2Result<ToolResult> {
166        let tool = self
167            .get(name)
168            .ok_or_else(|| Layer2Error::ToolNotFound(name.to_string()))?;
169
170        tool.execute(args).await
171    }
172
173    fn count(&self) -> usize {
174        let tools = self.tools.read();
175        tools.len()
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool used to exercise registry bookkeeping; never executed.
    struct StubTool {
        name: &'static str,
    }

    #[async_trait]
    impl Tool for StubTool {
        fn name(&self) -> &str {
            self.name
        }

        fn description(&self) -> &str {
            "stub tool for registry tests"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({ "type": "object", "properties": {} })
        }

        async fn execute(&self, _args: &str) -> Layer2Result<ToolResult> {
            unimplemented!("StubTool is never executed in these tests")
        }
    }

    /// Boxes a stub tool with the given name, ready for `register`.
    fn stub(name: &'static str) -> Box<dyn Tool> {
        Box::new(StubTool { name })
    }

    #[test]
    fn test_tool_registry_creation() {
        let registry = ToolRegistry::new();
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_tool_registry_list() {
        let registry = ToolRegistry::new();
        let list = registry.list();
        assert!(list.is_empty());
    }

    #[test]
    fn test_register_get_and_unregister() {
        let registry = ToolRegistry::new();
        assert!(registry.register(stub("echo")).is_ok());
        assert_eq!(registry.count(), 1);
        assert!(registry.exists("echo"));
        assert_eq!(
            registry.get("echo").map(|tool| tool.name().to_string()),
            Some("echo".to_string())
        );
        assert!(registry.get("missing").is_none());

        assert!(matches!(registry.unregister("echo"), Ok(true)));
        // A second removal finds nothing.
        assert!(matches!(registry.unregister("echo"), Ok(false)));
        assert!(!registry.exists("echo"));
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_register_same_name_replaces_existing() {
        let registry = ToolRegistry::new();
        assert!(registry.register(stub("echo")).is_ok());
        assert!(registry.register(stub("echo")).is_ok());
        assert_eq!(registry.count(), 1);
        assert_eq!(registry.list(), vec!["echo".to_string()]);
    }

    #[test]
    fn test_definitions_use_openai_function_shape() {
        let registry = ToolRegistry::new();
        assert!(registry.register(stub("search")).is_ok());

        let defs = registry.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].r#type, "function");
        assert_eq!(defs[0].function.name, "search");
        assert_eq!(defs[0].function.description, "stub tool for registry tests");

        // The raw identifier `r#type` must serialize under the plain key `type`.
        let json = serde_json::to_value(&defs[0]).expect("ToolDefinition serializes");
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "search");
    }
}