Skip to main content

matrixcode_core/tools/workflow/
run.rs

1//! Workflow Execution Tool
2//!
3//! 让 AI 执行工作流
4
5use crate::tools::{Tool, ToolDefinition};
6use crate::workflow::{WorkflowRegistry, WorkflowEngine, WorkflowPersistence, WorkflowStatus};
7use crate::workflow::executors::ExecutorFactory;
8use crate::providers::Provider;
9use crate::config::MatrixConfig;
10use anyhow::Result;
11use async_trait::async_trait;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::Arc;
15
16/// Tool to run a workflow
17pub struct WorkflowRunTool {
18    /// Provider 实例(可选,用于 AI-powered 工具)
19    provider: Option<Arc<dyn Provider>>,
20}
21
22impl WorkflowRunTool {
23    /// 创建新的 WorkflowRunTool(无 Provider)
24    pub fn new() -> Self {
25        Self { provider: None }
26    }
27
28    /// 创建带 Provider 的 WorkflowRunTool
29    pub fn with_provider(provider: Arc<dyn Provider>) -> Self {
30        Self { provider: Some(provider) }
31    }
32
33    /// Get provider - from instance or create from config
34    fn get_provider(&self) -> Result<Arc<dyn Provider>> {
35        if let Some(p) = &self.provider {
36            log::info!("WorkflowRunTool: using injected provider for model {}", p.model_name());
37            Ok(p.clone())
38        } else {
39            log::info!("WorkflowRunTool: no injected provider, creating from config");
40            MatrixConfig::create_provider_from_env()
41        }
42    }
43}
44
45impl Default for WorkflowRunTool {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51#[async_trait]
52impl Tool for WorkflowRunTool {
53    fn definition(&self) -> ToolDefinition {
54        ToolDefinition {
55            name: "workflow_run".to_string(),
56            description: "执行指定的 workflow。传入 workflow ID 和可选的输入参数,workflow 会按定义的节点顺序执行。".to_string(),
57            parameters: serde_json::json!({
58                "type": "object",
59                "properties": {
60                    "workflow_id": {
61                        "type": "string",
62                        "description": "要执行的 workflow ID。先用 workflow_discover 查看可用 ID。"
63                    },
64                    "inputs": {
65                        "type": "object",
66                        "description": "workflow 输入参数(JSON 对象)。键名必须匹配 workflow 的 required_inputs。"
67                    }
68                },
69                "required": ["workflow_id"]
70            }),
71            ..Default::default()
72        }
73    }
74
75    async fn execute(&self, params: Value) -> Result<String> {
76        let workflow_id = params.get("workflow_id")
77            .and_then(|v| v.as_str())
78            .ok_or_else(|| anyhow::anyhow!("缺少 workflow_id 参数"))?;
79
80        let inputs: HashMap<String, Value> = params.get("inputs")
81            .and_then(|v| v.as_object())
82            .map(|m| m.clone().into_iter().collect())
83            .unwrap_or_default();
84
85        let project_path = std::env::current_dir().ok();
86        let registry = WorkflowRegistry::new(project_path.as_ref());
87
88        // Load workflow
89        let workflow_def = registry.load_workflow(workflow_id)?
90            .ok_or_else(|| anyhow::anyhow!("Workflow '{}' 不存在。用 workflow_discover 查看可用列表。", workflow_id))?;
91
92        // Get provider
93        let provider = self.get_provider()?;
94
95        // Create engine with executor factory
96        let factory = ExecutorFactory::new().with_provider(provider);
97        let engine = WorkflowEngine::new(workflow_def)?
98            .with_executor_factory(factory);
99
100        let context = engine.run(inputs).await?;
101
102        // Save context
103        let persistence = WorkflowPersistence::new(project_path.as_ref());
104        if let Err(e) = persistence.save(&context) {
105            log::warn!("Failed to save workflow context: {}", e);
106        }
107
108        // Build result
109        let status = if context.status == WorkflowStatus::Completed {
110            "✓ 完成".to_string()
111        } else if context.status == WorkflowStatus::Failed {
112            format!("❌ 失败: {}", context.error.unwrap_or_default())
113        } else {
114            format!("状态: {:?}", context.status)
115        };
116
117        Ok(format!(
118            "Workflow '{}' 执行结果:\n\n实例ID: {}\n节点执行: {} 个\n{}\n\n变量输出: {}",
119            workflow_id,
120            context.instance_id,
121            context.execution_path.len(),
122            status,
123            serde_json::to_string_pretty(&context.variables).unwrap_or_default()
124        ))
125    }
126}