Skip to main content

matrixcode_core/tools/workflow/
run.rs

//! Workflow Execution Tool
//!
//! Lets the AI execute a workflow on demand (让 AI 执行工作流).
4
5use crate::config::MatrixConfig;
6use crate::providers::Provider;
7use crate::tools::{Tool, ToolDefinition};
8use crate::workflow::executors::ExecutorFactory;
9use crate::workflow::{WorkflowEngine, WorkflowPersistence, WorkflowRegistry, WorkflowStatus};
10use anyhow::Result;
11use async_trait::async_trait;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::Arc;
15
16/// Tool to run a workflow
17pub struct WorkflowRunTool {
18    /// Provider 实例(可选,用于 AI-powered 工具)
19    provider: Option<Arc<dyn Provider>>,
20}
21
22impl WorkflowRunTool {
23    /// 创建新的 WorkflowRunTool(无 Provider)
24    pub fn new() -> Self {
25        Self { provider: None }
26    }
27
28    /// 创建带 Provider 的 WorkflowRunTool
29    pub fn with_provider(provider: Arc<dyn Provider>) -> Self {
30        Self {
31            provider: Some(provider),
32        }
33    }
34
35    /// Get provider - from instance or create from config
36    fn get_provider(&self) -> Result<Arc<dyn Provider>> {
37        if let Some(p) = &self.provider {
38            log::info!(
39                "WorkflowRunTool: using injected provider for model {}",
40                p.model_name()
41            );
42            Ok(p.clone())
43        } else {
44            log::info!("WorkflowRunTool: no injected provider, creating from config");
45            MatrixConfig::create_provider_from_env()
46        }
47    }
48}
49
50impl Default for WorkflowRunTool {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56#[async_trait]
57impl Tool for WorkflowRunTool {
58    fn definition(&self) -> ToolDefinition {
59        ToolDefinition {
60            name: "workflow_run".to_string(),
61            description: "执行指定的 workflow。传入 workflow ID 和可选的输入参数,workflow 会按定义的节点顺序执行。".to_string(),
62            parameters: serde_json::json!({
63                "type": "object",
64                "properties": {
65                    "workflow_id": {
66                        "type": "string",
67                        "description": "要执行的 workflow ID。先用 workflow_discover 查看可用 ID。"
68                    },
69                    "inputs": {
70                        "type": "object",
71                        "description": "workflow 输入参数(JSON 对象)。键名必须匹配 workflow 的 required_inputs。"
72                    }
73                },
74                "required": ["workflow_id"]
75            }),
76            ..Default::default()
77        }
78    }
79
80    async fn execute(&self, params: Value) -> Result<String> {
81        let workflow_id = params
82            .get("workflow_id")
83            .and_then(|v| v.as_str())
84            .ok_or_else(|| anyhow::anyhow!("缺少 workflow_id 参数"))?;
85
86        let inputs: HashMap<String, Value> = params
87            .get("inputs")
88            .and_then(|v| v.as_object())
89            .map(|m| m.clone().into_iter().collect())
90            .unwrap_or_default();
91
92        let project_path = std::env::current_dir().ok();
93        let registry = WorkflowRegistry::new(project_path.as_ref());
94
95        // Load workflow
96        let workflow_def = registry.load_workflow(workflow_id)?.ok_or_else(|| {
97            anyhow::anyhow!(
98                "Workflow '{}' 不存在。用 workflow_discover 查看可用列表。",
99                workflow_id
100            )
101        })?;
102
103        // Get provider
104        let provider = self.get_provider()?;
105
106        // Create engine with executor factory
107        let factory = ExecutorFactory::new().with_provider(provider);
108        let engine = WorkflowEngine::new(workflow_def)?.with_executor_factory(factory);
109
110        let context = engine.run(inputs).await?;
111
112        // Save context
113        let persistence = WorkflowPersistence::new(project_path.as_ref());
114        if let Err(e) = persistence.save(&context) {
115            log::warn!("Failed to save workflow context: {}", e);
116        }
117
118        // Build result
119        let status = if context.status == WorkflowStatus::Completed {
120            "✓ 完成".to_string()
121        } else if context.status == WorkflowStatus::Failed {
122            format!("❌ 失败: {}", context.error.unwrap_or_default())
123        } else {
124            format!("状态: {:?}", context.status)
125        };
126
127        Ok(format!(
128            "Workflow '{}' 执行结果:\n\n实例ID: {}\n节点执行: {} 个\n{}\n\n变量输出: {}",
129            workflow_id,
130            context.instance_id,
131            context.execution_path.len(),
132            status,
133            serde_json::to_string_pretty(&context.variables).unwrap_or_default()
134        ))
135    }
136}