Skip to main content

matrixcode_core/workflow/executors/
tool.rs

1//! Tool Executor
2//!
3//! 工具调用执行器,调用现有 Tools 执行任务节点。
4
5use anyhow::{Context, Result};
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::tools::{Tool, ToolDefinition};
11use crate::workflow::context::WorkflowContext;
12use crate::workflow::def::NodeDef;
13use crate::workflow::template::TemplateRenderer;
14use super::node_executor::NodeExecutor;
15
/// Settings that control how [`ToolExecutor`] treats tool calls.
#[derive(Clone, Debug)]
pub struct ToolExecutorConfig {
    /// Whether tool call results should be recorded.
    ///
    /// NOTE(review): this flag is not read by `ToolExecutor`'s `execute` in this file.
    pub log_results: bool,
    /// Whether a failing tool call is tolerated. When set, the failure is reported as a
    /// JSON value instead of being propagated as an error.
    pub allow_failure: bool,
}

impl Default for ToolExecutorConfig {
    /// Returns a config that records results and treats tool failures as hard errors.
    fn default() -> Self {
        ToolExecutorConfig {
            allow_failure: false,
            log_results: true,
        }
    }
}
33
34/// 工具执行器
35///
36/// 调用现有的 Tools 系统执行任务节点。
37pub struct ToolExecutor {
38    /// 工具集合
39    tools: HashMap<String, Arc<dyn Tool>>,
40    /// 配置
41    config: ToolExecutorConfig,
42    /// 模板渲染器
43    template_renderer: TemplateRenderer,
44}
45
46impl ToolExecutor {
47    /// 创建新的工具执行器
48    pub fn new(tools: Vec<Box<dyn Tool>>) -> Self {
49        let mut tool_map = HashMap::new();
50        for tool in tools {
51            let def = tool.definition();
52            tool_map.insert(def.name, Arc::from(tool));
53        }
54        Self {
55            tools: tool_map,
56            config: ToolExecutorConfig::default(),
57            template_renderer: TemplateRenderer::new(),
58        }
59    }
60
61    /// 使用配置创建工具执行器
62    pub fn with_config(tools: Vec<Box<dyn Tool>>, config: ToolExecutorConfig) -> Self {
63        let mut tool_map = HashMap::new();
64        for tool in tools {
65            let def = tool.definition();
66            tool_map.insert(def.name, Arc::from(tool));
67        }
68        Self {
69            tools: tool_map,
70            config,
71            template_renderer: TemplateRenderer::new(),
72        }
73    }
74
75    /// 注册单个工具
76    pub fn register_tool(&mut self, tool: Box<dyn Tool>) {
77        let def = tool.definition();
78        self.tools.insert(def.name, Arc::from(tool));
79    }
80
81    /// 渲染参数
82    pub fn render_params(
83        &self,
84        params: &HashMap<String, serde_json::Value>,
85        context: &WorkflowContext,
86    ) -> Result<serde_json::Value> {
87        let mut rendered = HashMap::new();
88        for (key, value) in params {
89            let rendered_value = if let serde_json::Value::String(s) = value {
90                let rendered_str = self.template_renderer.render(s, &context.variables)?;
91                serde_json::Value::String(rendered_str)
92            } else {
93                value.clone()
94            };
95            rendered.insert(key.clone(), rendered_value);
96        }
97        Ok(serde_json::Value::Object(rendered.into_iter().collect()))
98    }
99
100    /// 检查工具是否存在
101    pub fn has_tool(&self, name: &str) -> bool {
102        self.tools.contains_key(name)
103    }
104
105    /// 获取工具定义
106    pub fn get_tool_definition(&self, name: &str) -> Option<ToolDefinition> {
107        self.tools.get(name).map(|t| t.definition())
108    }
109
110    /// 获取所有工具定义
111    pub fn get_all_tool_definitions(&self) -> Vec<ToolDefinition> {
112        self.tools.values().map(|t| t.definition()).collect()
113    }
114}
115
116#[async_trait]
117impl NodeExecutor for ToolExecutor {
118    async fn execute(
119        &self,
120        node: &NodeDef,
121        context: &mut WorkflowContext,
122    ) -> Result<serde_json::Value> {
123        // 获取工具名称
124        let tool_name = node.task.as_ref()
125            .ok_or_else(|| anyhow::anyhow!("Tool executor requires a task name"))?;
126
127        // 查找工具
128        let tool = self.tools.get(tool_name)
129            .ok_or_else(|| anyhow::anyhow!("Tool '{}' not found", tool_name))?;
130
131        // 渲染参数
132        let params = self.render_params(&node.params, context)?;
133
134        // 执行工具
135        let result = tool.execute(params.clone())
136            .await
137            .with_context(|| format!("Tool '{}' execution failed", tool_name));
138
139        // 处理结果
140        match result {
141            Ok(output_str) => {
142                // 尝试解析为 JSON
143                let output = if let Ok(json) = serde_json::from_str::<serde_json::Value>(&output_str) {
144                    json
145                } else {
146                    serde_json::json!({
147                        "result": output_str,
148                        "tool": tool_name,
149                    })
150                };
151
152                // 更新上下文
153                if let serde_json::Value::Object(map) = &output {
154                    for (key, value) in map {
155                        context.set_variable(key.clone(), value.clone());
156                    }
157                }
158
159                Ok(output)
160            }
161            Err(e) => {
162                if self.config.allow_failure {
163                    Ok(serde_json::json!({
164                        "error": e.to_string(),
165                        "tool": tool_name,
166                        "success": false,
167                    }))
168                } else {
169                    Err(e)
170                }
171            }
172        }
173    }
174
175    fn name(&self) -> &str {
176        "tool_executor"
177    }
178}