Skip to main content

mofa_foundation/agent/tools/
adapters.rs

//! Tool adapters.
//!
//! Convenient helpers for turning functions and closures into tools.
4
5use async_trait::async_trait;
6use mofa_kernel::agent::Tool;
7use mofa_kernel::agent::components::tool::{ToolInput, ToolMetadata, ToolResult};
8use mofa_kernel::agent::context::AgentContext;
9use std::future::Future;
10use std::pin::Pin;
11
12/// 函数工具
13///
14/// 从函数创建工具
15///
16/// # 示例
17///
18/// ```rust,ignore
19/// use mofa_foundation::agent::tools::FunctionTool;
20///
21/// async fn my_tool_fn(input: ToolInput, ctx: &AgentContext) -> ToolResult {
22///     let message = input.get_str("message").unwrap_or("default");
23///     ToolResult::success_text(format!("Processed: {}", message))
24/// }
25///
26/// let tool = FunctionTool::new(
27///     "my_tool",
28///     "A custom tool",
29///     serde_json::json!({
30///         "type": "object",
31///         "properties": {
32///             "message": { "type": "string" }
33///         }
34///     }),
35///     my_tool_fn,
36/// );
37/// ```
38pub struct FunctionTool<F>
39where
40    F: Fn(ToolInput, &AgentContext) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>>
41        + Send
42        + Sync,
43{
44    name: String,
45    description: String,
46    parameters_schema: serde_json::Value,
47    handler: F,
48    metadata: ToolMetadata,
49}
50
51impl<F> FunctionTool<F>
52where
53    F: Fn(ToolInput, &AgentContext) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>>
54        + Send
55        + Sync,
56{
57    /// 创建新的函数工具
58    pub fn new(
59        name: impl Into<String>,
60        description: impl Into<String>,
61        parameters_schema: serde_json::Value,
62        handler: F,
63    ) -> Self {
64        Self {
65            name: name.into(),
66            description: description.into(),
67            parameters_schema,
68            handler,
69            metadata: ToolMetadata::default(),
70        }
71    }
72
73    /// 设置元数据
74    pub fn with_metadata(mut self, metadata: ToolMetadata) -> Self {
75        self.metadata = metadata;
76        self
77    }
78}
79
80#[async_trait]
81impl<F> Tool for FunctionTool<F>
82where
83    F: Fn(ToolInput, &AgentContext) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>>
84        + Send
85        + Sync,
86{
87    fn name(&self) -> &str {
88        &self.name
89    }
90
91    fn description(&self) -> &str {
92        &self.description
93    }
94
95    fn parameters_schema(&self) -> serde_json::Value {
96        self.parameters_schema.clone()
97    }
98
99    async fn execute(&self, input: ToolInput, ctx: &AgentContext) -> ToolResult {
100        (self.handler)(input, ctx).await
101    }
102
103    fn metadata(&self) -> ToolMetadata {
104        self.metadata.clone()
105    }
106}
107
108/// 闭包工具
109///
110/// 使用闭包创建简单工具
111///
112/// # 示例
113///
114/// ```rust,ignore
115/// use mofa_foundation::agent::tools::ClosureTool;
116///
117/// let tool = ClosureTool::new(
118///     "add",
119///     "Add two numbers",
120///     |input| {
121///         let a = input.get_number("a").unwrap_or(0.0);
122///         let b = input.get_number("b").unwrap_or(0.0);
123///         ToolResult::success_text(format!("{}", a + b))
124///     },
125/// );
126/// ```
127pub struct ClosureTool<F>
128where
129    F: Fn(ToolInput) -> ToolResult + Send + Sync,
130{
131    name: String,
132    description: String,
133    parameters_schema: serde_json::Value,
134    handler: F,
135    metadata: ToolMetadata,
136}
137
138impl<F> ClosureTool<F>
139where
140    F: Fn(ToolInput) -> ToolResult + Send + Sync,
141{
142    /// 创建新的闭包工具
143    pub fn new(name: impl Into<String>, description: impl Into<String>, handler: F) -> Self {
144        Self {
145            name: name.into(),
146            description: description.into(),
147            parameters_schema: serde_json::json!({
148                "type": "object",
149                "properties": {}
150            }),
151            handler,
152            metadata: ToolMetadata::default(),
153        }
154    }
155
156    /// 设置参数 Schema
157    pub fn with_schema(mut self, schema: serde_json::Value) -> Self {
158        self.parameters_schema = schema;
159        self
160    }
161
162    /// 设置元数据
163    pub fn with_metadata(mut self, metadata: ToolMetadata) -> Self {
164        self.metadata = metadata;
165        self
166    }
167}
168
169#[async_trait]
170impl<F> Tool for ClosureTool<F>
171where
172    F: Fn(ToolInput) -> ToolResult + Send + Sync,
173{
174    fn name(&self) -> &str {
175        &self.name
176    }
177
178    fn description(&self) -> &str {
179        &self.description
180    }
181
182    fn parameters_schema(&self) -> serde_json::Value {
183        self.parameters_schema.clone()
184    }
185
186    async fn execute(&self, input: ToolInput, _ctx: &AgentContext) -> ToolResult {
187        (self.handler)(input)
188    }
189
190    fn metadata(&self) -> ToolMetadata {
191        self.metadata.clone()
192    }
193}
194
// ============================================================================
// Convenience macros for creating tools
// ============================================================================
198
/// Builds a `ClosureTool` from a synchronous handler.
///
/// Two forms are accepted, each with an optional trailing comma:
///
/// - `simple_tool!(name, description, handler)` keeps the default empty object schema.
/// - `simple_tool!(name, description, schema, handler)` also applies `with_schema(schema)`.
///
/// # Examples
///
/// ```rust,ignore
/// let tool = mofa_foundation::simple_tool!(
///     "echo",
///     "Echo the `text` argument",
///     serde_json::json!({
///         "type": "object",
///         "properties": { "text": { "type": "string" } },
///         "required": ["text"]
///     }),
///     |input| ToolResult::success_text(input.get_str("text").unwrap_or("").to_string()),
/// );
/// ```
#[macro_export]
macro_rules! simple_tool {
    ($name:expr, $desc:expr, $handler:expr $(,)?) => {
        $crate::agent::tools::ClosureTool::new($name, $desc, $handler)
    };
    ($name:expr, $desc:expr, $schema:expr, $handler:expr $(,)?) => {
        $crate::agent::tools::ClosureTool::new($name, $desc, $handler).with_schema($schema)
    };
}
209
// ============================================================================
// Built-in tool collection
// ============================================================================
213
/// Namespace for the built-in tools provided by this module.
///
/// The type carries no state. Each associated function builds a fresh tool. The
/// derives let the marker be printed, copied and default-constructed like any
/// other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct BuiltinTools;
216
217impl BuiltinTools {
218    /// 创建计算器工具
219    pub fn calculator() -> impl Tool {
220        ClosureTool::new(
221            "calculator",
222            "Perform basic arithmetic operations",
223            |input| {
224                let operation = input.get_str("operation").unwrap_or("add");
225                let a = input.get_number("a").unwrap_or(0.0);
226                let b = input.get_number("b").unwrap_or(0.0);
227
228                let result = match operation {
229                    "add" => a + b,
230                    "sub" => a - b,
231                    "mul" => a * b,
232                    "div" => {
233                        if b == 0.0 {
234                            return ToolResult::failure("Division by zero");
235                        }
236                        a / b
237                    }
238                    _ => return ToolResult::failure(format!("Unknown operation: {}", operation)),
239                };
240
241                ToolResult::success_text(format!("{}", result))
242            },
243        )
244        .with_schema(serde_json::json!({
245            "type": "object",
246            "properties": {
247                "operation": {
248                    "type": "string",
249                    "enum": ["add", "sub", "mul", "div"],
250                    "description": "The arithmetic operation to perform"
251                },
252                "a": {
253                    "type": "number",
254                    "description": "First operand"
255                },
256                "b": {
257                    "type": "number",
258                    "description": "Second operand"
259                }
260            },
261            "required": ["operation", "a", "b"]
262        }))
263    }
264
265    /// 创建当前时间工具
266    pub fn current_time() -> impl Tool {
267        ClosureTool::new("current_time", "Get the current date and time", |_input| {
268            let now = std::time::SystemTime::now()
269                .duration_since(std::time::UNIX_EPOCH)
270                .unwrap_or_default()
271                .as_secs();
272
273            ToolResult::success(serde_json::json!({
274                "timestamp": now,
275                "formatted": format!("Unix timestamp: {}", now)
276            }))
277        })
278    }
279
280    /// 创建 JSON 解析工具
281    pub fn json_parser() -> impl Tool {
282        ClosureTool::new(
283            "json_parser",
284            "Parse JSON string into structured data",
285            |input| {
286                let json_str = match input.get_str("json") {
287                    Some(s) => s,
288                    None => return ToolResult::failure("No JSON string provided"),
289                };
290
291                match serde_json::from_str::<serde_json::Value>(json_str) {
292                    Ok(parsed) => ToolResult::success(parsed),
293                    Err(e) => ToolResult::failure(format!("Failed to parse JSON: {}", e)),
294                }
295            },
296        )
297        .with_schema(serde_json::json!({
298            "type": "object",
299            "properties": {
300                "json": {
301                    "type": "string",
302                    "description": "The JSON string to parse"
303                }
304            },
305            "required": ["json"]
306        }))
307    }
308
309    /// 创建字符串处理工具
310    pub fn string_utils() -> impl Tool {
311        ClosureTool::new("string_utils", "String manipulation utilities", |input| {
312            let operation = input.get_str("operation").unwrap_or("length");
313            let text = input.get_str("text").unwrap_or("");
314
315            let result = match operation {
316                "length" => serde_json::json!({ "length": text.len() }),
317                "upper" => serde_json::json!({ "result": text.to_uppercase() }),
318                "lower" => serde_json::json!({ "result": text.to_lowercase() }),
319                "trim" => serde_json::json!({ "result": text.trim() }),
320                "reverse" => {
321                    serde_json::json!({ "result": text.chars().rev().collect::<String>() })
322                }
323                "word_count" => serde_json::json!({ "count": text.split_whitespace().count() }),
324                _ => return ToolResult::failure(format!("Unknown operation: {}", operation)),
325            };
326
327            ToolResult::success(result)
328        })
329        .with_schema(serde_json::json!({
330            "type": "object",
331            "properties": {
332                "operation": {
333                    "type": "string",
334                    "enum": ["length", "upper", "lower", "trim", "reverse", "word_count"],
335                    "description": "The string operation to perform"
336                },
337                "text": {
338                    "type": "string",
339                    "description": "The text to process"
340                }
341            },
342            "required": ["operation", "text"]
343        }))
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler used to exercise `FunctionTool`: echoes the `message` field.
    ///
    /// Written as a plain `fn` returning a boxed future, because that is the
    /// shape `FunctionTool`'s handler bound requires.
    fn echo_handler(
        input: ToolInput,
        _ctx: &AgentContext,
    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
        Box::pin(async move {
            let msg = input.get_str("message").unwrap_or("default");
            ToolResult::success_text(format!("Echo: {}", msg))
        })
    }

    #[tokio::test]
    async fn test_closure_tool() {
        let tool = ClosureTool::new("test", "Test tool", |input| {
            let msg = input.get_str("message").unwrap_or("default");
            ToolResult::success_text(format!("Got: {}", msg))
        });

        let ctx = AgentContext::new("test");
        let input = ToolInput::from_json(serde_json::json!({"message": "hello"}));

        let result = tool.execute(input, &ctx).await;
        assert!(result.success);
        assert_eq!(result.as_text(), Some("Got: hello"));
    }

    #[test]
    fn test_closure_tool_defaults() {
        let tool = ClosureTool::new("noop", "Does nothing", |_input| {
            ToolResult::success_text(String::from("ok"))
        });

        assert_eq!(tool.name(), "noop");
        assert_eq!(tool.description(), "Does nothing");
        // Without `with_schema`, the tool advertises an empty object schema.
        assert_eq!(
            tool.parameters_schema(),
            serde_json::json!({ "type": "object", "properties": {} })
        );
    }

    #[tokio::test]
    async fn test_function_tool() {
        let tool = FunctionTool::new(
            "echo",
            "Echo tool",
            serde_json::json!({
                "type": "object",
                "properties": { "message": { "type": "string" } }
            }),
            echo_handler,
        );
        let ctx = AgentContext::new("test");
        let input = ToolInput::from_json(serde_json::json!({"message": "hi"}));

        let result = tool.execute(input, &ctx).await;
        assert!(result.success);
        assert_eq!(result.as_text(), Some("Echo: hi"));
        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), "Echo tool");
    }

    #[tokio::test]
    async fn test_calculator_tool() {
        let tool = BuiltinTools::calculator();
        let ctx = AgentContext::new("test");

        // Test addition with integer JSON operands
        let input = ToolInput::from_json(serde_json::json!({
            "operation": "add",
            "a": 5,
            "b": 3
        }));
        let result = tool.execute(input, &ctx).await;
        assert!(result.success);
        assert_eq!(result.as_text(), Some("8"));

        // Remaining operations: (operation, a, b, expected text)
        let cases = [
            ("sub", 5.0, 3.0, "2"),
            ("mul", 5.0, 3.0, "15"),
            ("div", 10.0, 4.0, "2.5"),
        ];
        for &(operation, a, b, expected) in cases.iter() {
            let input = ToolInput::from_json(serde_json::json!({
                "operation": operation,
                "a": a,
                "b": b
            }));
            let result = tool.execute(input, &ctx).await;
            assert!(result.success, "operation {} failed", operation);
            assert_eq!(result.as_text(), Some(expected), "operation {}", operation);
        }

        // Test division by zero
        let input = ToolInput::from_json(serde_json::json!({
            "operation": "div",
            "a": 10,
            "b": 0
        }));
        let result = tool.execute(input, &ctx).await;
        assert!(!result.success);

        // Unknown operations are rejected
        let input = ToolInput::from_json(serde_json::json!({
            "operation": "pow",
            "a": 2,
            "b": 3
        }));
        let result = tool.execute(input, &ctx).await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn test_string_utils_tool() {
        let tool = BuiltinTools::string_utils();
        let ctx = AgentContext::new("test");

        let input = ToolInput::from_json(serde_json::json!({
            "operation": "upper",
            "text": "hello world"
        }));
        let result = tool.execute(input, &ctx).await;
        assert!(result.success);
        assert_eq!(result.output["result"], "HELLO WORLD");

        let input = ToolInput::from_json(serde_json::json!({
            "operation": "lower",
            "text": "Hello World"
        }));
        let result = tool.execute(input, &ctx).await;
        assert!(result.success);
        assert_eq!(result.output["result"], "hello world");

        let input = ToolInput::from_json(serde_json::json!({
            "operation": "length",
            "text": "hello"
        }));
        let result = tool.execute(input, &ctx).await;
        assert!(result.success);
        assert_eq!(result.output["length"], 5);

        let input = ToolInput::from_json(serde_json::json!({
            "operation": "word_count",
            "text": "a b  c"
        }));
        let result = tool.execute(input, &ctx).await;
        assert!(result.success);
        assert_eq!(result.output["count"], 3);

        let input = ToolInput::from_json(serde_json::json!({
            "operation": "shout",
            "text": "x"
        }));
        let result = tool.execute(input, &ctx).await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn test_json_parser_tool() {
        let tool = BuiltinTools::json_parser();
        let ctx = AgentContext::new("test");

        let input = ToolInput::from_json(serde_json::json!({ "json": "{\"k\": [1, 2]}" }));
        let result = tool.execute(input, &ctx).await;
        assert!(result.success);
        assert_eq!(result.output, serde_json::json!({ "k": [1, 2] }));

        let input = ToolInput::from_json(serde_json::json!({ "json": "{not json" }));
        let result = tool.execute(input, &ctx).await;
        assert!(!result.success);

        let input = ToolInput::from_json(serde_json::json!({}));
        let result = tool.execute(input, &ctx).await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn test_current_time_tool() {
        let tool = BuiltinTools::current_time();
        let ctx = AgentContext::new("test");

        let result = tool.execute(ToolInput::from_json(serde_json::json!({})), &ctx).await;
        assert!(result.success);
        assert!(result.output["timestamp"].as_u64().is_some());
        assert!(result.output["formatted"]
            .as_str()
            .map_or(false, |s| s.starts_with("Unix timestamp: ")));
    }
}
406}