Skip to main content

matrixcode_core/workflow/
engine.rs

//! Workflow Engine - State Machine Implementation
//!
//! The workflow engine: provides the basic state-machine structure and the main run loop.
4
5use super::context::WorkflowContext;
6use super::def::{ExecutionMode, FailureStrategy, NodeDef, NodeType, WorkflowDef};
7use super::executors::{ExecutorFactory, NodeExecutor};
8use super::rule_engine::evaluate_expression;
9use super::template::TemplateRenderer;
10use crate::tools::toolproxy::{ProxyToolDef, ProxyToolExecutor};
11use anyhow::{Context, Result};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::time::timeout;
16
17/// 任务执行器 trait
18#[async_trait::async_trait]
19pub trait TaskExecutor: Send + Sync {
20    /// 执行任务,返回输出数据
21    async fn execute(
22        &self,
23        task_name: &str,
24        params: &HashMap<String, serde_json::Value>,
25        context: &WorkflowContext,
26    ) -> Result<serde_json::Value>;
27}
28
29/// 工作流事件
30#[derive(Debug, Clone)]
31pub enum WorkflowEvent {
32    /// 工作流开始
33    Started,
34    /// 节点开始执行
35    NodeStarted { node_id: String },
36    /// 节点执行完成
37    NodeCompleted {
38        node_id: String,
39        output: Option<serde_json::Value>,
40    },
41    /// 节点执行失败
42    NodeFailed { node_id: String, error: String },
43    /// 节点跳过
44    NodeSkipped { node_id: String, reason: String },
45    /// 工作流完成
46    Completed,
47    /// 工作流失败
48    Failed { error: String },
49    /// 工作流暂停
50    Paused,
51    /// 工作流恢复
52    Resumed,
53}
54
55/// 事件监听器 trait
56pub trait EventListener: Send + Sync {
57    fn on_event(&self, event: WorkflowEvent);
58}
59
60/// 工作流引擎
61pub struct WorkflowEngine {
62    /// 工作流定义
63    definition: WorkflowDef,
64    /// 任务执行器(旧接口)
65    executor: Option<Arc<dyn TaskExecutor>>,
66    /// 节点执行器(新接口)
67    node_executors: HashMap<String, Arc<dyn NodeExecutor>>,
68    /// 执行器工厂
69    executor_factory: Option<ExecutorFactory>,
70    /// 代理工具执行器
71    proxy_executor: Option<Arc<dyn ProxyToolExecutor>>,
72    /// 代理工具定义列表
73    proxy_tool_defs: Vec<ProxyToolDef>,
74    /// 事件监听器
75    listeners: Vec<Box<dyn EventListener>>,
76    /// 模板渲染器
77    template_renderer: TemplateRenderer,
78}
79
80impl WorkflowEngine {
81    /// 创建新的工作流引擎
82    pub fn new(definition: WorkflowDef) -> Result<Self> {
83        definition
84            .validate()
85            .with_context(|| "Invalid workflow definition")?;
86
87        Ok(Self {
88            definition,
89            executor: None,
90            node_executors: HashMap::new(),
91            executor_factory: None,
92            proxy_executor: None,
93            proxy_tool_defs: Vec::new(),
94            listeners: Vec::new(),
95            template_renderer: TemplateRenderer::new(),
96        })
97    }
98
99    /// 设置任务执行器(旧接口)
100    pub fn with_executor(mut self, executor: Arc<dyn TaskExecutor>) -> Self {
101        self.executor = Some(executor);
102        self
103    }
104
105    /// 设置执行器工厂
106    pub fn with_executor_factory(mut self, factory: ExecutorFactory) -> Self {
107        self.executor_factory = Some(factory);
108        self
109    }
110
111    /// 设置代理工具执行器
112    pub fn with_proxy_executor(
113        mut self,
114        executor: Arc<dyn ProxyToolExecutor>,
115        tool_defs: Vec<ProxyToolDef>,
116    ) -> Self {
117        self.proxy_executor = Some(executor);
118        self.proxy_tool_defs = tool_defs;
119        self
120    }
121
122    /// 注册节点执行器
123    pub fn register_node_executor(
124        mut self,
125        task_type: &str,
126        executor: Arc<dyn NodeExecutor>,
127    ) -> Self {
128        self.node_executors.insert(task_type.to_string(), executor);
129        self
130    }
131
132    /// 添加事件监听器
133    pub fn add_listener(&mut self, listener: Box<dyn EventListener>) {
134        self.listeners.push(listener);
135    }
136
137    /// 触发事件
138    fn emit_event(&self, event: WorkflowEvent) {
139        for listener in &self.listeners {
140            listener.on_event(event.clone());
141        }
142    }
143
144    /// 获取节点执行器
145    fn get_node_executor(&self, node: &NodeDef) -> Option<Arc<dyn NodeExecutor>> {
146        // 优先从注册的执行器中查找
147        if let Some(task) = &node.task
148            && let Some(executor) = self.node_executors.get(task)
149        {
150            return Some(executor.clone());
151        }
152
153        // 检查是否是代理工具
154        if let Some(task) = &node.task
155            && self
156                .proxy_tool_defs
157                .iter()
158                .any(|t| t.definition.name == *task)
159            && let Some(executor) = &self.proxy_executor
160        {
161            return Some(Arc::new(super::executors::ProxyExecutor::new(
162                executor.clone(),
163                self.proxy_tool_defs.clone(),
164            )));
165        }
166
167        // 根据节点类型选择默认执行器
168        match node.node_type {
169            NodeType::Task => {
170                // 尝试从工厂创建
171                if let Some(factory) = &self.executor_factory
172                    && let Some(task) = &node.task
173                {
174                    // 根据任务名称推断执行器类型
175                    // ai / ai_* / claude* / gpt* 使用 AI 执行器
176                    let task_lower = task.to_lowercase();
177                    if task_lower == "ai"
178                        || task_lower.starts_with("ai_")
179                        || task_lower.starts_with("claude")
180                        || task_lower.starts_with("gpt")
181                    {
182                        return factory.create_ai_executor().ok();
183                    }
184                    // 默认使用工具执行器
185                    return Some(factory.create_tool_executor());
186                }
187            }
188            NodeType::Condition => {
189                if let Some(factory) = &self.executor_factory {
190                    return Some(factory.create_condition_executor());
191                }
192            }
193            NodeType::Approval => {
194                // 审批节点使用特殊的验证执行器
195                if let Some(factory) = &self.executor_factory {
196                    return Some(factory.create_validate_executor());
197                }
198            }
199            _ => {}
200        }
201
202        None
203    }
204
205    /// 运行工作流
206    pub async fn run(&self, inputs: HashMap<String, serde_json::Value>) -> Result<WorkflowContext> {
207        // 创建上下文
208        let mut context = WorkflowContext::new(self.definition.id.clone(), inputs.clone());
209
210        // 验证必填输入
211        self.validate_inputs(&context)?;
212
213        // 初始化变量:先添加 inputs
214        for (key, value) in inputs {
215            context.set_variable(key.clone(), value.clone());
216        }
217
218        // 渲染并添加 workflow 定义的变量
219        let renderer = crate::workflow::template::TemplateRenderer::new();
220        for (key, value) in &self.definition.variables {
221            // 如果是字符串,渲染模板
222            let rendered_value = if let serde_json::Value::String(s) = value {
223                match renderer.render(s, &context.variables) {
224                    Ok(rendered) => serde_json::Value::String(rendered),
225                    Err(_) => value.clone(), // 渲染失败保持原值
226                }
227            } else {
228                value.clone()
229            };
230            context.set_variable(key.clone(), rendered_value);
231        }
232
233        // 开始工作流
234        context.start();
235        self.emit_event(WorkflowEvent::Started);
236
237        // 获取开始节点
238        let start_node = self
239            .definition
240            .get_start_node()
241            .ok_or_else(|| anyhow::anyhow!("No start node found"))?;
242
243        // 执行工作流
244        match self.execute_from_node(start_node, &mut context).await {
245            Ok(()) => {
246                context.complete();
247                self.emit_event(WorkflowEvent::Completed);
248            }
249            Err(e) => {
250                context.fail(e.to_string());
251                self.emit_event(WorkflowEvent::Failed {
252                    error: e.to_string(),
253                });
254            }
255        }
256
257        Ok(context)
258    }
259
260    /// 从指定节点开始执行
261    async fn execute_from_node(&self, node: &NodeDef, context: &mut WorkflowContext) -> Result<()> {
262        let mut current_node = Some(node);
263
264        while let Some(node) = current_node {
265            // 检查工作流状态
266            if !context.can_continue() {
267                break;
268            }
269
270            // 执行节点
271            match self.execute_node(node, context).await {
272                Ok(next_node_id) => {
273                    current_node = next_node_id
274                        .as_ref()
275                        .and_then(|id| self.definition.get_node(id));
276                }
277                Err(e) => {
278                    // 处理失败
279                    match &node.on_failure {
280                        FailureStrategy::Retry {
281                            max_attempts,
282                            interval_ms,
283                        } => {
284                            let exec = context.get_or_create_node_execution(&node.id);
285                            if exec.retry_count < *max_attempts {
286                                exec.increment_retry();
287                                if let Some(interval) = interval_ms {
288                                    tokio::time::sleep(Duration::from_millis(*interval)).await;
289                                }
290                                continue; // 重试当前节点
291                            } else {
292                                return Err(e);
293                            }
294                        }
295                        FailureStrategy::Ignore => {
296                            // 忽略错误,标记节点为 Skipped 并继续执行下一个节点
297                            let exec = context.get_or_create_node_execution(&node.id);
298                            exec.skip();
299                            self.emit_event(WorkflowEvent::NodeSkipped {
300                                node_id: node.id.clone(),
301                                reason: e.to_string(),
302                            });
303                            let next = self.get_next_node(node, context)?;
304                            current_node =
305                                next.as_ref().and_then(|id| self.definition.get_node(id));
306                        }
307                        FailureStrategy::Abort => {
308                            return Err(e);
309                        }
310                        FailureStrategy::Goto { target } => {
311                            current_node = self.definition.get_node(target);
312                        }
313                    }
314                }
315            }
316        }
317
318        Ok(())
319    }
320
321    /// 执行单个节点
322    async fn execute_node(
323        &self,
324        node: &NodeDef,
325        context: &mut WorkflowContext,
326    ) -> Result<Option<String>> {
327        // 创建执行记录
328        let execution = context.get_or_create_node_execution(&node.id);
329        execution.start();
330        self.emit_event(WorkflowEvent::NodeStarted {
331            node_id: node.id.clone(),
332        });
333
334        // 设置当前节点
335        context.set_current_node(node.id.clone());
336
337        // 处理超时
338        let result = if let Some(timeout_ms) = node.timeout_ms {
339            timeout(
340                Duration::from_millis(timeout_ms),
341                self.execute_node_inner(node, context),
342            )
343            .await
344            .with_context(|| format!("Node '{}' timed out after {}ms", node.id, timeout_ms))?
345        } else {
346            self.execute_node_inner(node, context).await
347        };
348
349        match result {
350            Ok(output) => {
351                let exec = context.get_or_create_node_execution(&node.id);
352                exec.complete(output.clone());
353                self.emit_event(WorkflowEvent::NodeCompleted {
354                    node_id: node.id.clone(),
355                    output,
356                });
357
358                // 获取下一个节点
359                self.get_next_node(node, context)
360            }
361            Err(e) => {
362                let exec = context.get_or_create_node_execution(&node.id);
363                exec.fail(e.to_string());
364                self.emit_event(WorkflowEvent::NodeFailed {
365                    node_id: node.id.clone(),
366                    error: e.to_string(),
367                });
368                Err(e)
369            }
370        }
371    }
372
373    /// 节点内部执行逻辑
374    async fn execute_node_inner(
375        &self,
376        node: &NodeDef,
377        context: &mut WorkflowContext,
378    ) -> Result<Option<serde_json::Value>> {
379        match &node.node_type {
380            NodeType::Start => Ok(None),
381            NodeType::End => Ok(None),
382            NodeType::Task => self.execute_task(node, context).await,
383            NodeType::Condition => self.execute_condition(node, context).await,
384            NodeType::Parallel => self.execute_parallel(node, context).await,
385            NodeType::Pipeline => self.execute_pipeline(node, context).await,
386            NodeType::SubWorkflow => self.execute_subworkflow(node, context).await,
387            NodeType::Wait => self.execute_wait(node, context).await,
388            NodeType::Approval => self.execute_approval(node, context).await,
389        }
390    }
391
392    /// 执行任务节点
393    async fn execute_task(
394        &self,
395        node: &NodeDef,
396        context: &mut WorkflowContext,
397    ) -> Result<Option<serde_json::Value>> {
398        let task_name = node
399            .task
400            .as_ref()
401            .ok_or_else(|| anyhow::anyhow!("Task node '{}' has no task name", node.id))?;
402
403        // 渲染参数
404        let mut rendered_params = HashMap::new();
405        for (key, value) in &node.params {
406            if let serde_json::Value::String(s) = value {
407                let rendered = self.template_renderer.render(s, &context.variables)?;
408                rendered_params.insert(key.clone(), serde_json::Value::String(rendered));
409            } else {
410                rendered_params.insert(key.clone(), value.clone());
411            }
412        }
413
414        // 尝试使用新的 NodeExecutor 接口
415        if let Some(node_executor) = self.get_node_executor(node) {
416            let output = node_executor.execute(node, context).await?;
417            return Ok(Some(output));
418        }
419
420        // 回退到旧的 TaskExecutor 接口
421        if let Some(executor) = &self.executor {
422            let output = executor
423                .execute(task_name, &rendered_params, context)
424                .await?;
425            Ok(Some(output))
426        } else {
427            // 无执行器,返回模拟输出
428            Ok(Some(
429                serde_json::json!({ "task": task_name, "status": "completed" }),
430            ))
431        }
432    }
433
434    /// 执行条件节点
435    async fn execute_condition(
436        &self,
437        node: &NodeDef,
438        context: &mut WorkflowContext,
439    ) -> Result<Option<serde_json::Value>> {
440        let branches = node
441            .branches
442            .as_ref()
443            .ok_or_else(|| anyhow::anyhow!("Condition node '{}' has no branches", node.id))?;
444
445        for branch in branches {
446            if evaluate_expression(&branch.condition, &context.variables)? {
447                // 找到匹配的分支,设置目标节点
448                return Ok(Some(serde_json::Value::String(branch.target.clone())));
449            }
450        }
451
452        // 没有匹配的分支
453        Ok(None)
454    }
455
456    /// 执行并行节点
457    async fn execute_parallel(
458        &self,
459        node: &NodeDef,
460        _context: &mut WorkflowContext,
461    ) -> Result<Option<serde_json::Value>> {
462        let branches = node
463            .parallel_branches
464            .as_ref()
465            .ok_or_else(|| anyhow::anyhow!("Parallel node '{}' has no branches", node.id))?;
466
467        // 并行执行所有分支
468        let mut outputs = Vec::new();
469        for branch in branches {
470            // 这里简化处理,实际应该并行执行
471            outputs.push(serde_json::json!({
472                "branch": branch.name,
473                "status": "completed"
474            }));
475        }
476
477        Ok(Some(serde_json::Value::Array(outputs)))
478    }
479
480    /// 执行流式节点 (Pipeline 模式)
481    ///
482    /// Pipeline 模式特点:
483    /// - 任务完成后立即流转到下一阶段
484    /// - 不等待其他任务完成(无屏障)
485    /// - Wall-clock = 最慢的单任务链
486    async fn execute_pipeline(
487        &self,
488        node: &NodeDef,
489        _context: &mut WorkflowContext,
490    ) -> Result<Option<serde_json::Value>> {
491        let branches = node
492            .parallel_branches
493            .as_ref()
494            .ok_or_else(|| anyhow::anyhow!("Pipeline node '{}' has no branches", node.id))?;
495
496        // Pipeline 模式:流式处理,无屏障等待
497        // 每个分支独立流转,不等待其他分支完成
498        let mut outputs = Vec::new();
499        for branch in branches {
500            // 验证分支的执行模式
501            if branch.mode != ExecutionMode::Pipeline {
502                log::warn!(
503                    "Pipeline node '{}' branch '{}' has mode '{}', expected Pipeline",
504                    node.id, branch.name, branch.mode.display_name()
505                );
506            }
507
508            // 流式处理:立即返回,不等待
509            outputs.push(serde_json::json!({
510                "branch": branch.name,
511                "mode": "pipeline",
512                "status": "streaming",
513                "has_barrier": false
514            }));
515        }
516
517        log::info!(
518            "Pipeline node '{}' executing {} branches in streaming mode (no barrier)",
519            node.id,
520            branches.len()
521        );
522
523        Ok(Some(serde_json::Value::Array(outputs)))
524    }
525
526    /// 执行子工作流
527    async fn execute_subworkflow(
528        &self,
529        node: &NodeDef,
530        _context: &mut WorkflowContext,
531    ) -> Result<Option<serde_json::Value>> {
532        let workflow_name = node.workflow.as_ref().ok_or_else(|| {
533            anyhow::anyhow!("SubWorkflow node '{}' has no workflow name", node.id)
534        })?;
535
536        // 这里简化处理,实际应该加载并执行子工作流
537        Ok(Some(serde_json::json!({
538            "workflow": workflow_name,
539            "status": "completed"
540        })))
541    }
542
543    /// 执行等待节点
544    async fn execute_wait(
545        &self,
546        node: &NodeDef,
547        _context: &mut WorkflowContext,
548    ) -> Result<Option<serde_json::Value>> {
549        let wait_ms = node.wait_ms.unwrap_or(0);
550        if wait_ms > 0 {
551            tokio::time::sleep(Duration::from_millis(wait_ms)).await;
552        }
553        Ok(None)
554    }
555
556    /// 执行审批节点
557    async fn execute_approval(
558        &self,
559        node: &NodeDef,
560        _context: &mut WorkflowContext,
561    ) -> Result<Option<serde_json::Value>> {
562        let approvers = node
563            .approvers
564            .as_ref()
565            .ok_or_else(|| anyhow::anyhow!("Approval node '{}' has no approvers", node.id))?;
566
567        // 这里简化处理,实际应该等待审批
568        Ok(Some(serde_json::json!({
569            "approvers": approvers,
570            "status": "pending_approval"
571        })))
572    }
573
574    /// 获取下一个节点
575    fn get_next_node(&self, node: &NodeDef, context: &WorkflowContext) -> Result<Option<String>> {
576        // 结束节点没有下一个节点
577        if node.node_type == NodeType::End {
578            return Ok(None);
579        }
580
581        // 获取输出边
582        let edges = self.definition.get_outgoing_edges(&node.id);
583
584        if edges.is_empty() {
585            return Ok(None);
586        }
587
588        // 条件节点从分支获取下一个节点
589        if node.node_type == NodeType::Condition {
590            let exec = context.get_node_execution(&node.id);
591            if let Some(exec) = exec
592                && let Some(serde_json::Value::String(target)) = &exec.output
593            {
594                return Ok(Some(target.clone()));
595            }
596        }
597
598        // 根据边条件选择下一个节点
599        for edge in edges {
600            if let Some(condition) = &edge.condition {
601                if evaluate_expression(condition, &context.variables)? {
602                    return Ok(Some(edge.to.clone()));
603                }
604            } else {
605                // 无条件的边,直接返回
606                return Ok(Some(edge.to.clone()));
607            }
608        }
609
610        // 没有匹配的边
611        Ok(None)
612    }
613
614    /// 验证输入参数
615    fn validate_inputs(&self, context: &WorkflowContext) -> Result<()> {
616        for input_def in &self.definition.inputs {
617            if input_def.required
618                && context.get_input(&input_def.name).is_none()
619                && input_def.default.is_none()
620            {
621                anyhow::bail!("Required input '{}' is missing", input_def.name);
622            }
623        }
624        Ok(())
625    }
626
627    /// 获取工作流定义
628    pub fn definition(&self) -> &WorkflowDef {
629        &self.definition
630    }
631}
632
633/// 默认任务执行器(用于测试)
634pub struct DefaultTaskExecutor;
635
636#[async_trait::async_trait]
637impl TaskExecutor for DefaultTaskExecutor {
638    async fn execute(
639        &self,
640        task_name: &str,
641        _params: &HashMap<String, serde_json::Value>,
642        _context: &WorkflowContext,
643    ) -> Result<serde_json::Value> {
644        Ok(serde_json::json!({
645            "task": task_name,
646            "status": "completed",
647            "output": null
648        }))
649    }
650}
651
#[cfg(test)]
mod tests {
    use super::super::context::WorkflowStatus;
    use super::super::def::EdgeDef;
    use super::*;

    /// Builds a node with the given identity and every optional setting left unset.
    fn bare_node(id: &str, name: &str, node_type: NodeType, task: Option<&str>) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: task.map(String::from),
            params: HashMap::new(),
            on_failure: FailureStrategy::Abort,
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            execution_mode: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// Builds an unconditional, unlabeled edge.
    fn plain_edge(id: &str, from: &str, to: &str) -> EdgeDef {
        EdgeDef {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            condition: None,
            label: None,
        }
    }

    /// Linear workflow: start -> task1 -> end.
    fn create_simple_workflow() -> WorkflowDef {
        let nodes = vec![
            bare_node("start", "Start", NodeType::Start, None),
            bare_node("task1", "Task 1", NodeType::Task, Some("do_something")),
            bare_node("end", "End", NodeType::End, None),
        ];
        let edges = vec![
            plain_edge("e1", "start", "task1"),
            plain_edge("e2", "task1", "end"),
        ];

        WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes,
            edges,
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::Abort,
            timeout_ms: None,
        }
    }

    /// Without any executor, the engine completes and visits all three nodes.
    #[tokio::test]
    async fn test_engine_run() {
        let engine = WorkflowEngine::new(create_simple_workflow()).unwrap();

        let context = engine.run(HashMap::new()).await.unwrap();

        assert_eq!(context.status, WorkflowStatus::Completed);
        assert_eq!(context.execution_path.len(), 3);
    }

    /// With the legacy default executor configured, the workflow completes.
    #[tokio::test]
    async fn test_engine_with_executor() {
        let executor = Arc::new(DefaultTaskExecutor);
        let engine = WorkflowEngine::new(create_simple_workflow())
            .unwrap()
            .with_executor(executor);

        let context = engine.run(HashMap::new()).await.unwrap();

        assert_eq!(context.status, WorkflowStatus::Completed);
    }
}
763}