matrixcode_core/tools/workflow/
run.rs1use crate::tools::{Tool, ToolDefinition};
6use crate::workflow::{WorkflowRegistry, WorkflowEngine, WorkflowPersistence, WorkflowStatus};
7use crate::workflow::executors::ExecutorFactory;
8use crate::providers::Provider;
9use crate::config::MatrixConfig;
10use anyhow::Result;
11use async_trait::async_trait;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::Arc;
15
16pub struct WorkflowRunTool {
18 provider: Option<Arc<dyn Provider>>,
20}
21
22impl WorkflowRunTool {
23 pub fn new() -> Self {
25 Self { provider: None }
26 }
27
28 pub fn with_provider(provider: Arc<dyn Provider>) -> Self {
30 Self { provider: Some(provider) }
31 }
32
33 fn get_provider(&self) -> Result<Arc<dyn Provider>> {
35 if let Some(p) = &self.provider {
36 log::info!("WorkflowRunTool: using injected provider for model {}", p.model_name());
37 Ok(p.clone())
38 } else {
39 log::info!("WorkflowRunTool: no injected provider, creating from config");
40 MatrixConfig::create_provider_from_env()
41 }
42 }
43}
44
45impl Default for WorkflowRunTool {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51#[async_trait]
52impl Tool for WorkflowRunTool {
53 fn definition(&self) -> ToolDefinition {
54 ToolDefinition {
55 name: "workflow_run".to_string(),
56 description: "执行指定的 workflow。传入 workflow ID 和可选的输入参数,workflow 会按定义的节点顺序执行。".to_string(),
57 parameters: serde_json::json!({
58 "type": "object",
59 "properties": {
60 "workflow_id": {
61 "type": "string",
62 "description": "要执行的 workflow ID。先用 workflow_discover 查看可用 ID。"
63 },
64 "inputs": {
65 "type": "object",
66 "description": "workflow 输入参数(JSON 对象)。键名必须匹配 workflow 的 required_inputs。"
67 }
68 },
69 "required": ["workflow_id"]
70 }),
71 ..Default::default()
72 }
73 }
74
75 async fn execute(&self, params: Value) -> Result<String> {
76 let workflow_id = params.get("workflow_id")
77 .and_then(|v| v.as_str())
78 .ok_or_else(|| anyhow::anyhow!("缺少 workflow_id 参数"))?;
79
80 let inputs: HashMap<String, Value> = params.get("inputs")
81 .and_then(|v| v.as_object())
82 .map(|m| m.clone().into_iter().collect())
83 .unwrap_or_default();
84
85 let project_path = std::env::current_dir().ok();
86 let registry = WorkflowRegistry::new(project_path.as_ref());
87
88 let workflow_def = registry.load_workflow(workflow_id)?
90 .ok_or_else(|| anyhow::anyhow!("Workflow '{}' 不存在。用 workflow_discover 查看可用列表。", workflow_id))?;
91
92 let provider = self.get_provider()?;
94
95 let factory = ExecutorFactory::new().with_provider(provider);
97 let engine = WorkflowEngine::new(workflow_def)?
98 .with_executor_factory(factory);
99
100 let context = engine.run(inputs).await?;
101
102 let persistence = WorkflowPersistence::new(project_path.as_ref());
104 if let Err(e) = persistence.save(&context) {
105 log::warn!("Failed to save workflow context: {}", e);
106 }
107
108 let status = if context.status == WorkflowStatus::Completed {
110 "✓ 完成".to_string()
111 } else if context.status == WorkflowStatus::Failed {
112 format!("❌ 失败: {}", context.error.unwrap_or_default())
113 } else {
114 format!("状态: {:?}", context.status)
115 };
116
117 Ok(format!(
118 "Workflow '{}' 执行结果:\n\n实例ID: {}\n节点执行: {} 个\n{}\n\n变量输出: {}",
119 workflow_id,
120 context.instance_id,
121 context.execution_path.len(),
122 status,
123 serde_json::to_string_pretty(&context.variables).unwrap_or_default()
124 ))
125 }
126}