Skip to main content

mofa_kernel/agent/components/
tool.rs

1//! 工具组件
2//!
3//! 定义统一的工具接口,合并 ToolExecutor 和 ReActTool
4
5use crate::agent::context::AgentContext;
6use crate::agent::error::{AgentError, AgentResult};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// 统一工具 Trait
13///
14/// 合并了 ToolExecutor 和 ReActTool 的功能
15///
16/// # 示例
17///
18/// ```rust,ignore
19/// use mofa_kernel::agent::components::tool::{Tool, ToolInput, ToolResult, ToolMetadata};
20///
21/// struct Calculator;
22///
23/// #[async_trait]
24/// impl Tool for Calculator {
25///     fn name(&self) -> &str { "calculator" }
26///     fn description(&self) -> &str { "Perform arithmetic operations" }
27///     fn parameters_schema(&self) -> serde_json::Value {
28///         serde_json::json!({
29///             "type": "object",
30///             "properties": {
31///                 "operation": { "type": "string", "enum": ["add", "sub", "mul", "div"] },
32///                 "a": { "type": "number" },
33///                 "b": { "type": "number" }
34///             },
35///             "required": ["operation", "a", "b"]
36///         })
37///     }
38///
39///     async fn execute(&self, input: ToolInput, ctx: &CoreAgentContext) -> ToolResult {
40///         // Implementation
41///     }
42///
43///     fn metadata(&self) -> ToolMetadata {
44///         ToolMetadata::default()
45///     }
46/// }
47/// ```
48#[async_trait]
49pub trait Tool: Send + Sync {
50    /// 工具名称 (唯一标识符)
51    fn name(&self) -> &str;
52
53    /// 工具描述 (用于 LLM 理解)
54    fn description(&self) -> &str;
55
56    /// 参数 JSON Schema
57    fn parameters_schema(&self) -> serde_json::Value;
58
59    /// 执行工具
60    async fn execute(&self, input: ToolInput, ctx: &AgentContext) -> ToolResult;
61
62    /// 工具元数据
63    fn metadata(&self) -> ToolMetadata {
64        ToolMetadata::default()
65    }
66
67    /// 验证输入
68    fn validate_input(&self, input: &ToolInput) -> AgentResult<()> {
69        // 默认不做验证,子类可以覆盖
70        let _ = input;
71        Ok(())
72    }
73
74    /// 是否需要确认
75    fn requires_confirmation(&self) -> bool {
76        false
77    }
78
79    /// 转换为 LLM Tool 格式
80    fn to_llm_tool(&self) -> LLMTool {
81        LLMTool {
82            name: self.name().to_string(),
83            description: self.description().to_string(),
84            parameters: self.parameters_schema(),
85        }
86    }
87}
88
89/// 工具输入
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct ToolInput {
92    /// 结构化参数
93    pub arguments: serde_json::Value,
94    /// 原始输入 (可选)
95    pub raw_input: Option<String>,
96}
97
98impl ToolInput {
99    /// 从 JSON 参数创建
100    pub fn from_json(arguments: serde_json::Value) -> Self {
101        Self {
102            arguments,
103            raw_input: None,
104        }
105    }
106
107    /// 从原始字符串创建
108    pub fn from_raw(raw: impl Into<String>) -> Self {
109        let raw = raw.into();
110        Self {
111            arguments: serde_json::Value::String(raw.clone()),
112            raw_input: Some(raw),
113        }
114    }
115
116    /// 获取参数值
117    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
118        self.arguments
119            .get(key)
120            .and_then(|v| serde_json::from_value(v.clone()).ok())
121    }
122
123    /// 获取字符串参数
124    pub fn get_str(&self, key: &str) -> Option<&str> {
125        self.arguments.get(key).and_then(|v| v.as_str())
126    }
127
128    /// 获取数字参数
129    pub fn get_number(&self, key: &str) -> Option<f64> {
130        self.arguments.get(key).and_then(|v| v.as_f64())
131    }
132
133    /// 获取布尔参数
134    pub fn get_bool(&self, key: &str) -> Option<bool> {
135        self.arguments.get(key).and_then(|v| v.as_bool())
136    }
137}
138
139impl From<serde_json::Value> for ToolInput {
140    fn from(v: serde_json::Value) -> Self {
141        Self::from_json(v)
142    }
143}
144
145impl From<String> for ToolInput {
146    fn from(s: String) -> Self {
147        Self::from_raw(s)
148    }
149}
150
151impl From<&str> for ToolInput {
152    fn from(s: &str) -> Self {
153        Self::from_raw(s)
154    }
155}
156
157/// 工具执行结果
158#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct ToolResult {
160    /// 是否成功
161    pub success: bool,
162    /// 输出内容
163    pub output: serde_json::Value,
164    /// 错误信息 (如果失败)
165    pub error: Option<String>,
166    /// 额外元数据
167    pub metadata: HashMap<String, String>,
168}
169
170impl ToolResult {
171    /// 创建成功结果
172    pub fn success(output: serde_json::Value) -> Self {
173        Self {
174            success: true,
175            output,
176            error: None,
177            metadata: HashMap::new(),
178        }
179    }
180
181    /// 创建文本成功结果
182    pub fn success_text(text: impl Into<String>) -> Self {
183        Self::success(serde_json::Value::String(text.into()))
184    }
185
186    /// 创建失败结果
187    pub fn failure(error: impl Into<String>) -> Self {
188        Self {
189            success: false,
190            output: serde_json::Value::Null,
191            error: Some(error.into()),
192            metadata: HashMap::new(),
193        }
194    }
195
196    /// 添加元数据
197    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
198        self.metadata.insert(key.into(), value.into());
199        self
200    }
201
202    /// 获取文本输出
203    pub fn as_text(&self) -> Option<&str> {
204        self.output.as_str()
205    }
206
207    /// 转换为字符串
208    pub fn to_string_output(&self) -> String {
209        if self.success {
210            match &self.output {
211                serde_json::Value::String(s) => s.clone(),
212                v => v.to_string(),
213            }
214        } else {
215            format!(
216                "Error: {}",
217                self.error.as_deref().unwrap_or("Unknown error")
218            )
219        }
220    }
221}
222
223/// 工具元数据
224#[derive(Debug, Clone, Default, Serialize, Deserialize)]
225pub struct ToolMetadata {
226    /// 工具分类
227    pub category: Option<String>,
228    /// 工具标签
229    pub tags: Vec<String>,
230    /// 是否为危险操作
231    pub is_dangerous: bool,
232    /// 是否需要网络
233    pub requires_network: bool,
234    /// 是否需要文件系统访问
235    pub requires_filesystem: bool,
236    /// 自定义属性
237    pub custom: HashMap<String, serde_json::Value>,
238}
239
240impl ToolMetadata {
241    /// 创建新的元数据
242    pub fn new() -> Self {
243        Self::default()
244    }
245
246    /// 设置分类
247    pub fn with_category(mut self, category: impl Into<String>) -> Self {
248        self.category = Some(category.into());
249        self
250    }
251
252    /// 添加标签
253    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
254        self.tags.push(tag.into());
255        self
256    }
257
258    /// 标记为危险操作
259    pub fn dangerous(mut self) -> Self {
260        self.is_dangerous = true;
261        self
262    }
263
264    /// 标记需要网络
265    pub fn needs_network(mut self) -> Self {
266        self.requires_network = true;
267        self
268    }
269
270    /// 标记需要文件系统
271    pub fn needs_filesystem(mut self) -> Self {
272        self.requires_filesystem = true;
273        self
274    }
275}
276
277/// 工具描述符 (用于列表展示)
278#[derive(Debug, Clone, Serialize, Deserialize)]
279pub struct ToolDescriptor {
280    /// 工具名称
281    pub name: String,
282    /// 工具描述
283    pub description: String,
284    /// 参数 Schema
285    pub parameters_schema: serde_json::Value,
286    /// 元数据
287    pub metadata: ToolMetadata,
288}
289
290impl ToolDescriptor {
291    /// 从 Tool 创建描述符
292    pub fn from_tool(tool: &dyn Tool) -> Self {
293        Self {
294            name: tool.name().to_string(),
295            description: tool.description().to_string(),
296            parameters_schema: tool.parameters_schema(),
297            metadata: tool.metadata(),
298        }
299    }
300}
301
302/// LLM Tool 格式 (用于 API 调用)
303#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct LLMTool {
305    /// 工具名称
306    pub name: String,
307    /// 工具描述
308    pub description: String,
309    /// 参数 Schema
310    pub parameters: serde_json::Value,
311}
312
313// ============================================================================
314// 工具注册中心 Trait (接口仅在此定义)
315// ============================================================================
316
317/// 定义工具注册的接口,具体实现在 foundation 层。
318#[async_trait]
319pub trait ToolRegistry: Send + Sync {
320    /// 注册工具
321    fn register(&mut self, tool: Arc<dyn Tool>) -> AgentResult<()>;
322
323    /// 批量注册工具
324    fn register_all(&mut self, tools: Vec<Arc<dyn Tool>>) -> AgentResult<()> {
325        for tool in tools {
326            self.register(tool)?;
327        }
328        Ok(())
329    }
330
331    /// 获取工具
332    fn get(&self, name: &str) -> Option<Arc<dyn Tool>>;
333
334    /// 移除工具
335    fn unregister(&mut self, name: &str) -> AgentResult<bool>;
336
337    /// 列出所有工具
338    fn list(&self) -> Vec<ToolDescriptor>;
339
340    /// 列出所有工具名称
341    fn list_names(&self) -> Vec<String>;
342
343    /// 检查工具是否存在
344    fn contains(&self, name: &str) -> bool;
345
346    /// 获取工具数量
347    fn count(&self) -> usize;
348
349    /// 执行工具
350    async fn execute(
351        &self,
352        name: &str,
353        input: ToolInput,
354        ctx: &AgentContext,
355    ) -> AgentResult<ToolResult> {
356        let tool = self
357            .get(name)
358            .ok_or_else(|| AgentError::ToolNotFound(name.to_string()))?;
359        tool.validate_input(&input)?;
360        Ok(tool.execute(input, ctx).await)
361    }
362
363    /// 转换为 LLM Tools
364    fn to_llm_tools(&self) -> Vec<LLMTool> {
365        self.list()
366            .iter()
367            .map(|d| LLMTool {
368                name: d.name.clone(),
369                description: d.description.clone(),
370                parameters: d.parameters_schema.clone(),
371            })
372            .collect()
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_input_from_json() {
        let input = ToolInput::from_json(serde_json::json!({
            "name": "test",
            "count": 42
        }));

        assert_eq!(input.get_str("name"), Some("test"));
        assert_eq!(input.get_number("count"), Some(42.0));
    }

    #[test]
    fn test_tool_input_from_raw() {
        let input = ToolInput::from("echo hello");
        assert_eq!(input.raw_input.as_deref(), Some("echo hello"));
        assert_eq!(input.arguments, serde_json::json!("echo hello"));
        // Raw input is stored as a JSON string, not an object, so keyed lookups miss.
        assert_eq!(input.get_str("anything"), None);
    }

    #[test]
    fn test_tool_input_typed_get() {
        let input = ToolInput::from_json(serde_json::json!({ "tags": ["a", "b"], "n": 3 }));
        assert_eq!(
            input.get::<Vec<String>>("tags"),
            Some(vec!["a".to_string(), "b".to_string()])
        );
        assert_eq!(input.get::<u32>("n"), Some(3));
        assert_eq!(input.get::<u32>("tags"), None);
        assert_eq!(input.get::<u32>("missing"), None);
    }

    #[test]
    fn test_tool_result() {
        let success = ToolResult::success_text("OK");
        assert!(success.success);
        assert_eq!(success.as_text(), Some("OK"));

        let failure = ToolResult::failure("Something went wrong");
        assert!(!failure.success);
        assert!(failure.error.is_some());
    }

    #[test]
    fn test_tool_result_to_string_output() {
        assert_eq!(ToolResult::success_text("OK").to_string_output(), "OK");
        assert_eq!(
            ToolResult::success(serde_json::json!({ "a": 1 })).to_string_output(),
            r#"{"a":1}"#
        );
        assert_eq!(ToolResult::failure("boom").to_string_output(), "Error: boom");

        let mut no_message = ToolResult::failure("x");
        no_message.error = None;
        assert_eq!(no_message.to_string_output(), "Error: Unknown error");
    }

    #[test]
    fn test_tool_metadata_builder() {
        let meta = ToolMetadata::new()
            .with_category("math")
            .with_tag("a")
            .with_tag("b")
            .dangerous()
            .needs_network();

        assert_eq!(meta.category.as_deref(), Some("math"));
        assert_eq!(meta.tags, vec!["a", "b"]);
        assert!(meta.is_dangerous);
        assert!(meta.requires_network);
        assert!(!meta.requires_filesystem);
    }
}