Skip to main content

matrixcode_core/workflow/
engine.rs

1//! Workflow Engine - State Machine Implementation
2//!
3//! 工作流引擎,实现状态机的基础结构和主运行循环。
4
5use anyhow::{Context, Result};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::time::timeout;
10use super::context::{WorkflowContext};
11use super::def::{FailureStrategy, NodeDef, NodeType, WorkflowDef};
12use super::rule_engine::evaluate_expression;
13use super::template::TemplateRenderer;
14use super::executors::{NodeExecutor, ExecutorFactory};
15use crate::tools::toolproxy::{ProxyToolExecutor, ProxyToolDef};
16
17/// 任务执行器 trait
18#[async_trait::async_trait]
19pub trait TaskExecutor: Send + Sync {
20    /// 执行任务,返回输出数据
21    async fn execute(
22        &self,
23        task_name: &str,
24        params: &HashMap<String, serde_json::Value>,
25        context: &WorkflowContext,
26    ) -> Result<serde_json::Value>;
27}
28
29/// 工作流事件
30#[derive(Debug, Clone)]
31pub enum WorkflowEvent {
32    /// 工作流开始
33    Started,
34    /// 节点开始执行
35    NodeStarted { node_id: String },
36    /// 节点执行完成
37    NodeCompleted { node_id: String, output: Option<serde_json::Value> },
38    /// 节点执行失败
39    NodeFailed { node_id: String, error: String },
40    /// 节点跳过
41    NodeSkipped { node_id: String, reason: String },
42    /// 工作流完成
43    Completed,
44    /// 工作流失败
45    Failed { error: String },
46    /// 工作流暂停
47    Paused,
48    /// 工作流恢复
49    Resumed,
50}
51
52/// 事件监听器 trait
53pub trait EventListener: Send + Sync {
54    fn on_event(&self, event: WorkflowEvent);
55}
56
57/// 工作流引擎
58pub struct WorkflowEngine {
59    /// 工作流定义
60    definition: WorkflowDef,
61    /// 任务执行器(旧接口)
62    executor: Option<Arc<dyn TaskExecutor>>,
63    /// 节点执行器(新接口)
64    node_executors: HashMap<String, Arc<dyn NodeExecutor>>,
65    /// 执行器工厂
66    executor_factory: Option<ExecutorFactory>,
67    /// 代理工具执行器
68    proxy_executor: Option<Arc<dyn ProxyToolExecutor>>,
69    /// 代理工具定义列表
70    proxy_tool_defs: Vec<ProxyToolDef>,
71    /// 事件监听器
72    listeners: Vec<Box<dyn EventListener>>,
73    /// 模板渲染器
74    template_renderer: TemplateRenderer,
75}
76
77impl WorkflowEngine {
78    /// 创建新的工作流引擎
79    pub fn new(definition: WorkflowDef) -> Result<Self> {
80        definition.validate()
81            .with_context(|| "Invalid workflow definition")?;
82
83        Ok(Self {
84            definition,
85            executor: None,
86            node_executors: HashMap::new(),
87            executor_factory: None,
88            proxy_executor: None,
89            proxy_tool_defs: Vec::new(),
90            listeners: Vec::new(),
91            template_renderer: TemplateRenderer::new(),
92        })
93    }
94
95    /// 设置任务执行器(旧接口)
96    pub fn with_executor(mut self, executor: Arc<dyn TaskExecutor>) -> Self {
97        self.executor = Some(executor);
98        self
99    }
100
101    /// 设置执行器工厂
102    pub fn with_executor_factory(mut self, factory: ExecutorFactory) -> Self {
103        self.executor_factory = Some(factory);
104        self
105    }
106
107    /// 设置代理工具执行器
108    pub fn with_proxy_executor(mut self, executor: Arc<dyn ProxyToolExecutor>, tool_defs: Vec<ProxyToolDef>) -> Self {
109        self.proxy_executor = Some(executor);
110        self.proxy_tool_defs = tool_defs;
111        self
112    }
113
114    /// 注册节点执行器
115    pub fn register_node_executor(mut self, task_type: &str, executor: Arc<dyn NodeExecutor>) -> Self {
116        self.node_executors.insert(task_type.to_string(), executor);
117        self
118    }
119
120    /// 添加事件监听器
121    pub fn add_listener(&mut self, listener: Box<dyn EventListener>) {
122        self.listeners.push(listener);
123    }
124
125    /// 触发事件
126    fn emit_event(&self, event: WorkflowEvent) {
127        for listener in &self.listeners {
128            listener.on_event(event.clone());
129        }
130    }
131
132    /// 获取节点执行器
133    fn get_node_executor(&self, node: &NodeDef) -> Option<Arc<dyn NodeExecutor>> {
134        // 优先从注册的执行器中查找
135        if let Some(task) = &node.task
136            && let Some(executor) = self.node_executors.get(task) {
137                return Some(executor.clone());
138            }
139
140        // 检查是否是代理工具
141        if let Some(task) = &node.task
142            && self.proxy_tool_defs.iter().any(|t| t.definition.name == *task)
143                && let Some(executor) = &self.proxy_executor {
144                    return Some(Arc::new(super::executors::ProxyExecutor::new(
145                        executor.clone(),
146                        self.proxy_tool_defs.clone(),
147                    )));
148                }
149
150        // 根据节点类型选择默认执行器
151        match node.node_type {
152            NodeType::Task => {
153                // 尝试从工厂创建
154                if let Some(factory) = &self.executor_factory
155                    && let Some(task) = &node.task {
156                        // 根据任务名称推断执行器类型
157                        // ai / ai_* / claude* / gpt* 使用 AI 执行器
158                        let task_lower = task.to_lowercase();
159                        if task_lower == "ai" || task_lower.starts_with("ai_") || task_lower.starts_with("claude") || task_lower.starts_with("gpt") {
160                            return factory.create_ai_executor().ok();
161                        }
162                        // 默认使用工具执行器
163                        return Some(factory.create_tool_executor());
164                    }
165            }
166            NodeType::Condition => {
167                if let Some(factory) = &self.executor_factory {
168                    return Some(factory.create_condition_executor());
169                }
170            }
171            NodeType::Approval => {
172                // 审批节点使用特殊的验证执行器
173                if let Some(factory) = &self.executor_factory {
174                    return Some(factory.create_validate_executor());
175                }
176            }
177            _ => {}
178        }
179
180        None
181    }
182
183    /// 运行工作流
184    pub async fn run(&self, inputs: HashMap<String, serde_json::Value>) -> Result<WorkflowContext> {
185        // 创建上下文
186        let mut context = WorkflowContext::new(self.definition.id.clone(), inputs.clone());
187
188        // 验证必填输入
189        self.validate_inputs(&context)?;
190
191        // 初始化变量:先添加 inputs
192        for (key, value) in inputs {
193            context.set_variable(key.clone(), value.clone());
194        }
195
196        // 渲染并添加 workflow 定义的变量
197        let renderer = crate::workflow::template::TemplateRenderer::new();
198        for (key, value) in &self.definition.variables {
199            // 如果是字符串,渲染模板
200            let rendered_value = if let serde_json::Value::String(s) = value {
201                match renderer.render(s, &context.variables) {
202                    Ok(rendered) => serde_json::Value::String(rendered),
203                    Err(_) => value.clone(), // 渲染失败保持原值
204                }
205            } else {
206                value.clone()
207            };
208            context.set_variable(key.clone(), rendered_value);
209        }
210
211        // 开始工作流
212        context.start();
213        self.emit_event(WorkflowEvent::Started);
214
215        // 获取开始节点
216        let start_node = self.definition.get_start_node()
217            .ok_or_else(|| anyhow::anyhow!("No start node found"))?;
218
219        // 执行工作流
220        match self.execute_from_node(start_node, &mut context).await {
221            Ok(()) => {
222                context.complete();
223                self.emit_event(WorkflowEvent::Completed);
224            }
225            Err(e) => {
226                context.fail(e.to_string());
227                self.emit_event(WorkflowEvent::Failed { error: e.to_string() });
228            }
229        }
230
231        Ok(context)
232    }
233
234    /// 从指定节点开始执行
235    async fn execute_from_node(
236        &self,
237        node: &NodeDef,
238        context: &mut WorkflowContext,
239    ) -> Result<()> {
240        let mut current_node = Some(node);
241
242        while let Some(node) = current_node {
243            // 检查工作流状态
244            if !context.can_continue() {
245                break;
246            }
247
248            // 执行节点
249            match self.execute_node(node, context).await {
250                Ok(next_node_id) => {
251                    current_node = next_node_id
252                        .as_ref()
253                        .and_then(|id| self.definition.get_node(id));
254                }
255                Err(e) => {
256                    // 处理失败
257                    match &node.on_failure {
258                        FailureStrategy::Retry { max_attempts, interval_ms } => {
259                            let exec = context.get_or_create_node_execution(&node.id);
260                            if exec.retry_count < *max_attempts {
261                                exec.increment_retry();
262                                if let Some(interval) = interval_ms {
263                                    tokio::time::sleep(Duration::from_millis(*interval)).await;
264                                }
265                                continue; // 重试当前节点
266                            } else {
267                                return Err(e);
268                            }
269                        }
270                        FailureStrategy::Ignore => {
271                            // 忽略错误,标记节点为 Skipped 并继续执行下一个节点
272                            let exec = context.get_or_create_node_execution(&node.id);
273                            exec.skip();
274                            self.emit_event(WorkflowEvent::NodeSkipped {
275                                node_id: node.id.clone(),
276                                reason: e.to_string(),
277                            });
278                            let next = self.get_next_node(node, context)?;
279                            current_node = next
280                                .as_ref()
281                                .and_then(|id| self.definition.get_node(id));
282                        }
283                        FailureStrategy::Abort => {
284                            return Err(e);
285                        }
286                        FailureStrategy::Goto { target } => {
287                            current_node = self.definition.get_node(target);
288                        }
289                    }
290                }
291            }
292        }
293
294        Ok(())
295    }
296
297    /// 执行单个节点
298    async fn execute_node(
299        &self,
300        node: &NodeDef,
301        context: &mut WorkflowContext,
302    ) -> Result<Option<String>> {
303        // 创建执行记录
304        let execution = context.get_or_create_node_execution(&node.id);
305        execution.start();
306        self.emit_event(WorkflowEvent::NodeStarted { node_id: node.id.clone() });
307
308        // 设置当前节点
309        context.set_current_node(node.id.clone());
310
311        // 处理超时
312        let result = if let Some(timeout_ms) = node.timeout_ms {
313            timeout(
314                Duration::from_millis(timeout_ms),
315                self.execute_node_inner(node, context),
316            )
317            .await
318            .with_context(|| format!("Node '{}' timed out after {}ms", node.id, timeout_ms))?
319        } else {
320            self.execute_node_inner(node, context).await
321        };
322
323        match result {
324            Ok(output) => {
325                let exec = context.get_or_create_node_execution(&node.id);
326                exec.complete(output.clone());
327                self.emit_event(WorkflowEvent::NodeCompleted {
328                    node_id: node.id.clone(),
329                    output,
330                });
331
332                // 获取下一个节点
333                self.get_next_node(node, context)
334            }
335            Err(e) => {
336                let exec = context.get_or_create_node_execution(&node.id);
337                exec.fail(e.to_string());
338                self.emit_event(WorkflowEvent::NodeFailed {
339                    node_id: node.id.clone(),
340                    error: e.to_string(),
341                });
342                Err(e)
343            }
344        }
345    }
346
347    /// 节点内部执行逻辑
348    async fn execute_node_inner(
349        &self,
350        node: &NodeDef,
351        context: &mut WorkflowContext,
352    ) -> Result<Option<serde_json::Value>> {
353        match &node.node_type {
354            NodeType::Start => {
355                Ok(None)
356            }
357            NodeType::End => {
358                Ok(None)
359            }
360            NodeType::Task => {
361                self.execute_task(node, context).await
362            }
363            NodeType::Condition => {
364                self.execute_condition(node, context).await
365            }
366            NodeType::Parallel => {
367                self.execute_parallel(node, context).await
368            }
369            NodeType::SubWorkflow => {
370                self.execute_subworkflow(node, context).await
371            }
372            NodeType::Wait => {
373                self.execute_wait(node, context).await
374            }
375            NodeType::Approval => {
376                self.execute_approval(node, context).await
377            }
378        }
379    }
380
381    /// 执行任务节点
382    async fn execute_task(
383        &self,
384        node: &NodeDef,
385        context: &mut WorkflowContext,
386    ) -> Result<Option<serde_json::Value>> {
387        let task_name = node.task.as_ref()
388            .ok_or_else(|| anyhow::anyhow!("Task node '{}' has no task name", node.id))?;
389
390        // 渲染参数
391        let mut rendered_params = HashMap::new();
392        for (key, value) in &node.params {
393            if let serde_json::Value::String(s) = value {
394                let rendered = self.template_renderer.render(s, &context.variables)?;
395                rendered_params.insert(key.clone(), serde_json::Value::String(rendered));
396            } else {
397                rendered_params.insert(key.clone(), value.clone());
398            }
399        }
400
401        // 尝试使用新的 NodeExecutor 接口
402        if let Some(node_executor) = self.get_node_executor(node) {
403            let output = node_executor.execute(node, context).await?;
404            return Ok(Some(output));
405        }
406
407        // 回退到旧的 TaskExecutor 接口
408        if let Some(executor) = &self.executor {
409            let output = executor.execute(task_name, &rendered_params, context).await?;
410            Ok(Some(output))
411        } else {
412            // 无执行器,返回模拟输出
413            Ok(Some(serde_json::json!({ "task": task_name, "status": "completed" })))
414        }
415    }
416
417    /// 执行条件节点
418    async fn execute_condition(
419        &self,
420        node: &NodeDef,
421        context: &mut WorkflowContext,
422    ) -> Result<Option<serde_json::Value>> {
423        let branches = node.branches.as_ref()
424            .ok_or_else(|| anyhow::anyhow!("Condition node '{}' has no branches", node.id))?;
425
426        for branch in branches {
427            if evaluate_expression(&branch.condition, &context.variables)? {
428                // 找到匹配的分支,设置目标节点
429                return Ok(Some(serde_json::Value::String(branch.target.clone())));
430            }
431        }
432
433        // 没有匹配的分支
434        Ok(None)
435    }
436
437    /// 执行并行节点
438    async fn execute_parallel(
439        &self,
440        node: &NodeDef,
441        _context: &mut WorkflowContext,
442    ) -> Result<Option<serde_json::Value>> {
443        let branches = node.parallel_branches.as_ref()
444            .ok_or_else(|| anyhow::anyhow!("Parallel node '{}' has no branches", node.id))?;
445
446        // 并行执行所有分支
447        let mut outputs = Vec::new();
448        for branch in branches {
449            // 这里简化处理,实际应该并行执行
450            outputs.push(serde_json::json!({
451                "branch": branch.name,
452                "status": "completed"
453            }));
454        }
455
456        Ok(Some(serde_json::Value::Array(outputs)))
457    }
458
459    /// 执行子工作流
460    async fn execute_subworkflow(
461        &self,
462        node: &NodeDef,
463        _context: &mut WorkflowContext,
464    ) -> Result<Option<serde_json::Value>> {
465        let workflow_name = node.workflow.as_ref()
466            .ok_or_else(|| anyhow::anyhow!("SubWorkflow node '{}' has no workflow name", node.id))?;
467
468        // 这里简化处理,实际应该加载并执行子工作流
469        Ok(Some(serde_json::json!({
470            "workflow": workflow_name,
471            "status": "completed"
472        })))
473    }
474
475    /// 执行等待节点
476    async fn execute_wait(
477        &self,
478        node: &NodeDef,
479        _context: &mut WorkflowContext,
480    ) -> Result<Option<serde_json::Value>> {
481        let wait_ms = node.wait_ms.unwrap_or(0);
482        if wait_ms > 0 {
483            tokio::time::sleep(Duration::from_millis(wait_ms)).await;
484        }
485        Ok(None)
486    }
487
488    /// 执行审批节点
489    async fn execute_approval(
490        &self,
491        node: &NodeDef,
492        _context: &mut WorkflowContext,
493    ) -> Result<Option<serde_json::Value>> {
494        let approvers = node.approvers.as_ref()
495            .ok_or_else(|| anyhow::anyhow!("Approval node '{}' has no approvers", node.id))?;
496
497        // 这里简化处理,实际应该等待审批
498        Ok(Some(serde_json::json!({
499            "approvers": approvers,
500            "status": "pending_approval"
501        })))
502    }
503
504    /// 获取下一个节点
505    fn get_next_node(
506        &self,
507        node: &NodeDef,
508        context: &WorkflowContext,
509    ) -> Result<Option<String>> {
510        // 结束节点没有下一个节点
511        if node.node_type == NodeType::End {
512            return Ok(None);
513        }
514
515        // 获取输出边
516        let edges = self.definition.get_outgoing_edges(&node.id);
517
518        if edges.is_empty() {
519            return Ok(None);
520        }
521
522        // 条件节点从分支获取下一个节点
523        if node.node_type == NodeType::Condition {
524            let exec = context.get_node_execution(&node.id);
525            if let Some(exec) = exec
526                && let Some(serde_json::Value::String(target)) = &exec.output {
527                    return Ok(Some(target.clone()));
528                }
529        }
530
531        // 根据边条件选择下一个节点
532        for edge in edges {
533            if let Some(condition) = &edge.condition {
534                if evaluate_expression(condition, &context.variables)? {
535                    return Ok(Some(edge.to.clone()));
536                }
537            } else {
538                // 无条件的边,直接返回
539                return Ok(Some(edge.to.clone()));
540            }
541        }
542
543        // 没有匹配的边
544        Ok(None)
545    }
546
547    /// 验证输入参数
548    fn validate_inputs(&self, context: &WorkflowContext) -> Result<()> {
549        for input_def in &self.definition.inputs {
550            if input_def.required
551                && context.get_input(&input_def.name).is_none()
552                    && input_def.default.is_none() {
553                        anyhow::bail!("Required input '{}' is missing", input_def.name);
554                    }
555        }
556        Ok(())
557    }
558
559    /// 获取工作流定义
560    pub fn definition(&self) -> &WorkflowDef {
561        &self.definition
562    }
563}
564
565/// 默认任务执行器(用于测试)
566pub struct DefaultTaskExecutor;
567
568#[async_trait::async_trait]
569impl TaskExecutor for DefaultTaskExecutor {
570    async fn execute(
571        &self,
572        task_name: &str,
573        _params: &HashMap<String, serde_json::Value>,
574        _context: &WorkflowContext,
575    ) -> Result<serde_json::Value> {
576        Ok(serde_json::json!({
577            "task": task_name,
578            "status": "completed",
579            "output": null
580        }))
581    }
582}
583
#[cfg(test)]
mod tests {
    use super::super::context::WorkflowStatus;
    use super::super::def::EdgeDef;
    use super::*;

    /// Builds a node with the given identity and every optional setting left
    /// at its neutral value (`Abort` on failure, no timeout, no extras).
    fn make_node(id: &str, name: &str, node_type: NodeType, task: Option<&str>) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: task.map(String::from),
            params: HashMap::new(),
            on_failure: FailureStrategy::Abort,
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// Builds an unconditional, unlabeled edge.
    fn make_edge(id: &str, from: &str, to: &str) -> EdgeDef {
        EdgeDef {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            condition: None,
            label: None,
        }
    }

    /// Linear three-node workflow: start -> task1 -> end.
    fn create_simple_workflow() -> WorkflowDef {
        let nodes = vec![
            make_node("start", "Start", NodeType::Start, None),
            make_node("task1", "Task 1", NodeType::Task, Some("do_something")),
            make_node("end", "End", NodeType::End, None),
        ];
        let edges = vec![
            make_edge("e1", "start", "task1"),
            make_edge("e2", "task1", "end"),
        ];

        WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes,
            edges,
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::Abort,
            timeout_ms: None,
        }
    }

    /// Without any executor configured, the run still completes and visits
    /// all three nodes.
    #[tokio::test]
    async fn test_engine_run() {
        let engine = WorkflowEngine::new(create_simple_workflow()).unwrap();

        let context = engine.run(HashMap::new()).await.unwrap();

        assert_eq!(context.status, WorkflowStatus::Completed);
        assert_eq!(context.execution_path.len(), 3);
    }

    /// With the legacy default executor configured, the run completes.
    #[tokio::test]
    async fn test_engine_with_executor() {
        let engine = WorkflowEngine::new(create_simple_workflow())
            .unwrap()
            .with_executor(Arc::new(DefaultTaskExecutor));

        let context = engine.run(HashMap::new()).await.unwrap();

        assert_eq!(context.status, WorkflowStatus::Completed);
    }
}