Skip to main content

matrixcode_core/workflow/executors/
tool.rs

1//! Tool Executor
2//!
3//! 工具调用执行器,调用现有 Tools 执行任务节点。
4
5use anyhow::{Context, Result};
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use super::node_executor::NodeExecutor;
11use crate::tools::{Tool, ToolDefinition};
12use crate::workflow::context::WorkflowContext;
13use crate::workflow::def::NodeDef;
14use crate::workflow::template::TemplateRenderer;
15
/// Configuration for [`ToolExecutor`].
#[derive(Debug, Clone)]
pub struct ToolExecutorConfig {
    /// Whether tool call results should be logged.
    ///
    /// NOTE(review): `ToolExecutor::execute` in this file never reads this flag —
    /// confirm whether it is consumed elsewhere or is currently a no-op.
    pub log_results: bool,
    /// Whether a failed tool call is tolerated.
    ///
    /// When `true`, `execute` reports a failed tool call as a JSON object with
    /// `"success": false` instead of returning an error.
    pub allow_failure: bool,
}
24
25impl Default for ToolExecutorConfig {
26    fn default() -> Self {
27        Self {
28            log_results: true,
29            allow_failure: false,
30        }
31    }
32}
33
/// Workflow node executor that runs task nodes by invoking registered tools.
///
/// Tools are looked up by the name reported in their `definition()`.
pub struct ToolExecutor {
    /// Registered tools, keyed by the `name` field of each tool's definition.
    tools: HashMap<String, Arc<dyn Tool>>,
    /// Executor behaviour settings.
    config: ToolExecutorConfig,
    /// Renders templated string parameters against the workflow variables.
    template_renderer: TemplateRenderer,
}
45
46impl ToolExecutor {
47    /// 创建新的工具执行器
48    pub fn new(tools: Vec<Box<dyn Tool>>) -> Self {
49        let mut tool_map = HashMap::new();
50        for tool in tools {
51            let def = tool.definition();
52            tool_map.insert(def.name, Arc::from(tool));
53        }
54        Self {
55            tools: tool_map,
56            config: ToolExecutorConfig::default(),
57            template_renderer: TemplateRenderer::new(),
58        }
59    }
60
61    /// 使用配置创建工具执行器
62    pub fn with_config(tools: Vec<Box<dyn Tool>>, config: ToolExecutorConfig) -> Self {
63        let mut tool_map = HashMap::new();
64        for tool in tools {
65            let def = tool.definition();
66            tool_map.insert(def.name, Arc::from(tool));
67        }
68        Self {
69            tools: tool_map,
70            config,
71            template_renderer: TemplateRenderer::new(),
72        }
73    }
74
75    /// 注册单个工具
76    pub fn register_tool(&mut self, tool: Box<dyn Tool>) {
77        let def = tool.definition();
78        self.tools.insert(def.name, Arc::from(tool));
79    }
80
81    /// 渲染参数
82    pub fn render_params(
83        &self,
84        params: &HashMap<String, serde_json::Value>,
85        context: &WorkflowContext,
86    ) -> Result<serde_json::Value> {
87        let mut rendered = HashMap::new();
88        for (key, value) in params {
89            let rendered_value = if let serde_json::Value::String(s) = value {
90                let rendered_str = self.template_renderer.render(s, &context.variables)?;
91                serde_json::Value::String(rendered_str)
92            } else {
93                value.clone()
94            };
95            rendered.insert(key.clone(), rendered_value);
96        }
97        Ok(serde_json::Value::Object(rendered.into_iter().collect()))
98    }
99
100    /// 检查工具是否存在
101    pub fn has_tool(&self, name: &str) -> bool {
102        self.tools.contains_key(name)
103    }
104
105    /// 获取工具定义
106    pub fn get_tool_definition(&self, name: &str) -> Option<ToolDefinition> {
107        self.tools.get(name).map(|t| t.definition())
108    }
109
110    /// 获取所有工具定义
111    pub fn get_all_tool_definitions(&self) -> Vec<ToolDefinition> {
112        self.tools.values().map(|t| t.definition()).collect()
113    }
114}
115
116#[async_trait]
117impl NodeExecutor for ToolExecutor {
118    async fn execute(
119        &self,
120        node: &NodeDef,
121        context: &mut WorkflowContext,
122    ) -> Result<serde_json::Value> {
123        // 获取工具名称
124        let tool_name = node
125            .task
126            .as_ref()
127            .ok_or_else(|| anyhow::anyhow!("Tool executor requires a task name"))?;
128
129        // 查找工具
130        let tool = self
131            .tools
132            .get(tool_name)
133            .ok_or_else(|| anyhow::anyhow!("Tool '{}' not found", tool_name))?;
134
135        // 渲染参数
136        let params = self.render_params(&node.params, context)?;
137
138        // 执行工具
139        let result = tool
140            .execute(params.clone())
141            .await
142            .with_context(|| format!("Tool '{}' execution failed", tool_name));
143
144        // 处理结果
145        match result {
146            Ok(output_str) => {
147                // 尝试解析为 JSON
148                let output =
149                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(&output_str) {
150                        json
151                    } else {
152                        serde_json::json!({
153                            "result": output_str,
154                            "tool": tool_name,
155                        })
156                    };
157
158                // 更新上下文
159                if let serde_json::Value::Object(map) = &output {
160                    for (key, value) in map {
161                        context.set_variable(key.clone(), value.clone());
162                    }
163                }
164
165                Ok(output)
166            }
167            Err(e) => {
168                if self.config.allow_failure {
169                    Ok(serde_json::json!({
170                        "error": e.to_string(),
171                        "tool": tool_name,
172                        "success": false,
173                    }))
174                } else {
175                    Err(e)
176                }
177            }
178        }
179    }
180
181    fn name(&self) -> &str {
182        "tool_executor"
183    }
184}