Skip to main content

matrixcode_core/workflow/
engine.rs

1//! Workflow Engine - State Machine Implementation
2//!
3//! 工作流引擎,实现状态机的基础结构和主运行循环。
4
5use super::context::WorkflowContext;
6use super::def::{FailureStrategy, NodeDef, NodeType, WorkflowDef};
7use super::executors::{ExecutorFactory, NodeExecutor};
8use super::rule_engine::evaluate_expression;
9use super::template::TemplateRenderer;
10use crate::tools::toolproxy::{ProxyToolDef, ProxyToolExecutor};
11use anyhow::{Context, Result};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::time::timeout;
16
/// Task executor trait (the legacy execution interface used by the engine).
///
/// Implementors must be `Send + Sync` so they can be shared behind an `Arc`.
#[async_trait::async_trait]
pub trait TaskExecutor: Send + Sync {
    /// Executes the named task and returns its output data.
    ///
    /// * `task_name` - name of the task to run.
    /// * `params` - parameters for the task, keyed by parameter name.
    /// * `context` - read-only view of the current workflow context.
    async fn execute(
        &self,
        task_name: &str,
        params: &HashMap<String, serde_json::Value>,
        context: &WorkflowContext,
    ) -> Result<serde_json::Value>;
}
28
/// Workflow lifecycle events delivered to every registered [`EventListener`].
#[derive(Debug, Clone)]
pub enum WorkflowEvent {
    /// The workflow has started.
    Started,
    /// A node has started executing.
    NodeStarted { node_id: String },
    /// A node finished executing successfully; `output` carries its result, if any.
    NodeCompleted {
        node_id: String,
        output: Option<serde_json::Value>,
    },
    /// A node failed; `error` is the stringified failure.
    NodeFailed { node_id: String, error: String },
    /// A node was skipped; `reason` explains why.
    NodeSkipped { node_id: String, reason: String },
    /// The workflow completed.
    Completed,
    /// The workflow failed; `error` is the stringified failure.
    Failed { error: String },
    /// The workflow was paused.
    Paused,
    /// The workflow was resumed.
    Resumed,
}
54
/// Event listener trait.
///
/// The engine calls `on_event` synchronously for each emitted [`WorkflowEvent`],
/// handing every listener its own clone of the event.
pub trait EventListener: Send + Sync {
    /// Receives one workflow event.
    fn on_event(&self, event: WorkflowEvent);
}
59
/// Workflow engine: holds a workflow definition plus the executors and
/// listeners used to run it.
pub struct WorkflowEngine {
    /// Workflow definition being executed.
    definition: WorkflowDef,
    /// Task executor (legacy interface); used as a fallback for task nodes.
    executor: Option<Arc<dyn TaskExecutor>>,
    /// Node executors (new interface), keyed by task name.
    node_executors: HashMap<String, Arc<dyn NodeExecutor>>,
    /// Factory used to create default executors by node type.
    executor_factory: Option<ExecutorFactory>,
    /// Proxy tool executor.
    proxy_executor: Option<Arc<dyn ProxyToolExecutor>>,
    /// Definitions of the available proxy tools.
    proxy_tool_defs: Vec<ProxyToolDef>,
    /// Registered event listeners.
    listeners: Vec<Box<dyn EventListener>>,
    /// Template renderer used for string parameters.
    template_renderer: TemplateRenderer,
}
79
80impl WorkflowEngine {
81    /// 创建新的工作流引擎
82    pub fn new(definition: WorkflowDef) -> Result<Self> {
83        definition
84            .validate()
85            .with_context(|| "Invalid workflow definition")?;
86
87        Ok(Self {
88            definition,
89            executor: None,
90            node_executors: HashMap::new(),
91            executor_factory: None,
92            proxy_executor: None,
93            proxy_tool_defs: Vec::new(),
94            listeners: Vec::new(),
95            template_renderer: TemplateRenderer::new(),
96        })
97    }
98
99    /// 设置任务执行器(旧接口)
100    pub fn with_executor(mut self, executor: Arc<dyn TaskExecutor>) -> Self {
101        self.executor = Some(executor);
102        self
103    }
104
105    /// 设置执行器工厂
106    pub fn with_executor_factory(mut self, factory: ExecutorFactory) -> Self {
107        self.executor_factory = Some(factory);
108        self
109    }
110
111    /// 设置代理工具执行器
112    pub fn with_proxy_executor(
113        mut self,
114        executor: Arc<dyn ProxyToolExecutor>,
115        tool_defs: Vec<ProxyToolDef>,
116    ) -> Self {
117        self.proxy_executor = Some(executor);
118        self.proxy_tool_defs = tool_defs;
119        self
120    }
121
122    /// 注册节点执行器
123    pub fn register_node_executor(
124        mut self,
125        task_type: &str,
126        executor: Arc<dyn NodeExecutor>,
127    ) -> Self {
128        self.node_executors.insert(task_type.to_string(), executor);
129        self
130    }
131
    /// Adds an event listener.
    ///
    /// The listener is appended to the listener list, which `emit_event`
    /// walks in order when delivering events.
    pub fn add_listener(&mut self, listener: Box<dyn EventListener>) {
        self.listeners.push(listener);
    }
136
137    /// 触发事件
138    fn emit_event(&self, event: WorkflowEvent) {
139        for listener in &self.listeners {
140            listener.on_event(event.clone());
141        }
142    }
143
144    /// 获取节点执行器
145    fn get_node_executor(&self, node: &NodeDef) -> Option<Arc<dyn NodeExecutor>> {
146        // 优先从注册的执行器中查找
147        if let Some(task) = &node.task
148            && let Some(executor) = self.node_executors.get(task)
149        {
150            return Some(executor.clone());
151        }
152
153        // 检查是否是代理工具
154        if let Some(task) = &node.task
155            && self
156                .proxy_tool_defs
157                .iter()
158                .any(|t| t.definition.name == *task)
159            && let Some(executor) = &self.proxy_executor
160        {
161            return Some(Arc::new(super::executors::ProxyExecutor::new(
162                executor.clone(),
163                self.proxy_tool_defs.clone(),
164            )));
165        }
166
167        // 根据节点类型选择默认执行器
168        match node.node_type {
169            NodeType::Task => {
170                // 尝试从工厂创建
171                if let Some(factory) = &self.executor_factory
172                    && let Some(task) = &node.task
173                {
174                    // 根据任务名称推断执行器类型
175                    // ai / ai_* / claude* / gpt* 使用 AI 执行器
176                    let task_lower = task.to_lowercase();
177                    if task_lower == "ai"
178                        || task_lower.starts_with("ai_")
179                        || task_lower.starts_with("claude")
180                        || task_lower.starts_with("gpt")
181                    {
182                        return factory.create_ai_executor().ok();
183                    }
184                    // 默认使用工具执行器
185                    return Some(factory.create_tool_executor());
186                }
187            }
188            NodeType::Condition => {
189                if let Some(factory) = &self.executor_factory {
190                    return Some(factory.create_condition_executor());
191                }
192            }
193            NodeType::Approval => {
194                // 审批节点使用特殊的验证执行器
195                if let Some(factory) = &self.executor_factory {
196                    return Some(factory.create_validate_executor());
197                }
198            }
199            _ => {}
200        }
201
202        None
203    }
204
205    /// 运行工作流
206    pub async fn run(&self, inputs: HashMap<String, serde_json::Value>) -> Result<WorkflowContext> {
207        // 创建上下文
208        let mut context = WorkflowContext::new(self.definition.id.clone(), inputs.clone());
209
210        // 验证必填输入
211        self.validate_inputs(&context)?;
212
213        // 初始化变量:先添加 inputs
214        for (key, value) in inputs {
215            context.set_variable(key.clone(), value.clone());
216        }
217
218        // 渲染并添加 workflow 定义的变量
219        let renderer = crate::workflow::template::TemplateRenderer::new();
220        for (key, value) in &self.definition.variables {
221            // 如果是字符串,渲染模板
222            let rendered_value = if let serde_json::Value::String(s) = value {
223                match renderer.render(s, &context.variables) {
224                    Ok(rendered) => serde_json::Value::String(rendered),
225                    Err(_) => value.clone(), // 渲染失败保持原值
226                }
227            } else {
228                value.clone()
229            };
230            context.set_variable(key.clone(), rendered_value);
231        }
232
233        // 开始工作流
234        context.start();
235        self.emit_event(WorkflowEvent::Started);
236
237        // 获取开始节点
238        let start_node = self
239            .definition
240            .get_start_node()
241            .ok_or_else(|| anyhow::anyhow!("No start node found"))?;
242
243        // 执行工作流
244        match self.execute_from_node(start_node, &mut context).await {
245            Ok(()) => {
246                context.complete();
247                self.emit_event(WorkflowEvent::Completed);
248            }
249            Err(e) => {
250                context.fail(e.to_string());
251                self.emit_event(WorkflowEvent::Failed {
252                    error: e.to_string(),
253                });
254            }
255        }
256
257        Ok(context)
258    }
259
260    /// 从指定节点开始执行
261    async fn execute_from_node(&self, node: &NodeDef, context: &mut WorkflowContext) -> Result<()> {
262        let mut current_node = Some(node);
263
264        while let Some(node) = current_node {
265            // 检查工作流状态
266            if !context.can_continue() {
267                break;
268            }
269
270            // 执行节点
271            match self.execute_node(node, context).await {
272                Ok(next_node_id) => {
273                    current_node = next_node_id
274                        .as_ref()
275                        .and_then(|id| self.definition.get_node(id));
276                }
277                Err(e) => {
278                    // 处理失败
279                    match &node.on_failure {
280                        FailureStrategy::Retry {
281                            max_attempts,
282                            interval_ms,
283                        } => {
284                            let exec = context.get_or_create_node_execution(&node.id);
285                            if exec.retry_count < *max_attempts {
286                                exec.increment_retry();
287                                if let Some(interval) = interval_ms {
288                                    tokio::time::sleep(Duration::from_millis(*interval)).await;
289                                }
290                                continue; // 重试当前节点
291                            } else {
292                                return Err(e);
293                            }
294                        }
295                        FailureStrategy::Ignore => {
296                            // 忽略错误,标记节点为 Skipped 并继续执行下一个节点
297                            let exec = context.get_or_create_node_execution(&node.id);
298                            exec.skip();
299                            self.emit_event(WorkflowEvent::NodeSkipped {
300                                node_id: node.id.clone(),
301                                reason: e.to_string(),
302                            });
303                            let next = self.get_next_node(node, context)?;
304                            current_node =
305                                next.as_ref().and_then(|id| self.definition.get_node(id));
306                        }
307                        FailureStrategy::Abort => {
308                            return Err(e);
309                        }
310                        FailureStrategy::Goto { target } => {
311                            current_node = self.definition.get_node(target);
312                        }
313                    }
314                }
315            }
316        }
317
318        Ok(())
319    }
320
321    /// 执行单个节点
322    async fn execute_node(
323        &self,
324        node: &NodeDef,
325        context: &mut WorkflowContext,
326    ) -> Result<Option<String>> {
327        // 创建执行记录
328        let execution = context.get_or_create_node_execution(&node.id);
329        execution.start();
330        self.emit_event(WorkflowEvent::NodeStarted {
331            node_id: node.id.clone(),
332        });
333
334        // 设置当前节点
335        context.set_current_node(node.id.clone());
336
337        // 处理超时
338        let result = if let Some(timeout_ms) = node.timeout_ms {
339            timeout(
340                Duration::from_millis(timeout_ms),
341                self.execute_node_inner(node, context),
342            )
343            .await
344            .with_context(|| format!("Node '{}' timed out after {}ms", node.id, timeout_ms))?
345        } else {
346            self.execute_node_inner(node, context).await
347        };
348
349        match result {
350            Ok(output) => {
351                let exec = context.get_or_create_node_execution(&node.id);
352                exec.complete(output.clone());
353                self.emit_event(WorkflowEvent::NodeCompleted {
354                    node_id: node.id.clone(),
355                    output,
356                });
357
358                // 获取下一个节点
359                self.get_next_node(node, context)
360            }
361            Err(e) => {
362                let exec = context.get_or_create_node_execution(&node.id);
363                exec.fail(e.to_string());
364                self.emit_event(WorkflowEvent::NodeFailed {
365                    node_id: node.id.clone(),
366                    error: e.to_string(),
367                });
368                Err(e)
369            }
370        }
371    }
372
373    /// 节点内部执行逻辑
374    async fn execute_node_inner(
375        &self,
376        node: &NodeDef,
377        context: &mut WorkflowContext,
378    ) -> Result<Option<serde_json::Value>> {
379        match &node.node_type {
380            NodeType::Start => Ok(None),
381            NodeType::End => Ok(None),
382            NodeType::Task => self.execute_task(node, context).await,
383            NodeType::Condition => self.execute_condition(node, context).await,
384            NodeType::Parallel => self.execute_parallel(node, context).await,
385            NodeType::SubWorkflow => self.execute_subworkflow(node, context).await,
386            NodeType::Wait => self.execute_wait(node, context).await,
387            NodeType::Approval => self.execute_approval(node, context).await,
388        }
389    }
390
391    /// 执行任务节点
392    async fn execute_task(
393        &self,
394        node: &NodeDef,
395        context: &mut WorkflowContext,
396    ) -> Result<Option<serde_json::Value>> {
397        let task_name = node
398            .task
399            .as_ref()
400            .ok_or_else(|| anyhow::anyhow!("Task node '{}' has no task name", node.id))?;
401
402        // 渲染参数
403        let mut rendered_params = HashMap::new();
404        for (key, value) in &node.params {
405            if let serde_json::Value::String(s) = value {
406                let rendered = self.template_renderer.render(s, &context.variables)?;
407                rendered_params.insert(key.clone(), serde_json::Value::String(rendered));
408            } else {
409                rendered_params.insert(key.clone(), value.clone());
410            }
411        }
412
413        // 尝试使用新的 NodeExecutor 接口
414        if let Some(node_executor) = self.get_node_executor(node) {
415            let output = node_executor.execute(node, context).await?;
416            return Ok(Some(output));
417        }
418
419        // 回退到旧的 TaskExecutor 接口
420        if let Some(executor) = &self.executor {
421            let output = executor
422                .execute(task_name, &rendered_params, context)
423                .await?;
424            Ok(Some(output))
425        } else {
426            // 无执行器,返回模拟输出
427            Ok(Some(
428                serde_json::json!({ "task": task_name, "status": "completed" }),
429            ))
430        }
431    }
432
433    /// 执行条件节点
434    async fn execute_condition(
435        &self,
436        node: &NodeDef,
437        context: &mut WorkflowContext,
438    ) -> Result<Option<serde_json::Value>> {
439        let branches = node
440            .branches
441            .as_ref()
442            .ok_or_else(|| anyhow::anyhow!("Condition node '{}' has no branches", node.id))?;
443
444        for branch in branches {
445            if evaluate_expression(&branch.condition, &context.variables)? {
446                // 找到匹配的分支,设置目标节点
447                return Ok(Some(serde_json::Value::String(branch.target.clone())));
448            }
449        }
450
451        // 没有匹配的分支
452        Ok(None)
453    }
454
455    /// 执行并行节点
456    async fn execute_parallel(
457        &self,
458        node: &NodeDef,
459        _context: &mut WorkflowContext,
460    ) -> Result<Option<serde_json::Value>> {
461        let branches = node
462            .parallel_branches
463            .as_ref()
464            .ok_or_else(|| anyhow::anyhow!("Parallel node '{}' has no branches", node.id))?;
465
466        // 并行执行所有分支
467        let mut outputs = Vec::new();
468        for branch in branches {
469            // 这里简化处理,实际应该并行执行
470            outputs.push(serde_json::json!({
471                "branch": branch.name,
472                "status": "completed"
473            }));
474        }
475
476        Ok(Some(serde_json::Value::Array(outputs)))
477    }
478
479    /// 执行子工作流
480    async fn execute_subworkflow(
481        &self,
482        node: &NodeDef,
483        _context: &mut WorkflowContext,
484    ) -> Result<Option<serde_json::Value>> {
485        let workflow_name = node.workflow.as_ref().ok_or_else(|| {
486            anyhow::anyhow!("SubWorkflow node '{}' has no workflow name", node.id)
487        })?;
488
489        // 这里简化处理,实际应该加载并执行子工作流
490        Ok(Some(serde_json::json!({
491            "workflow": workflow_name,
492            "status": "completed"
493        })))
494    }
495
496    /// 执行等待节点
497    async fn execute_wait(
498        &self,
499        node: &NodeDef,
500        _context: &mut WorkflowContext,
501    ) -> Result<Option<serde_json::Value>> {
502        let wait_ms = node.wait_ms.unwrap_or(0);
503        if wait_ms > 0 {
504            tokio::time::sleep(Duration::from_millis(wait_ms)).await;
505        }
506        Ok(None)
507    }
508
509    /// 执行审批节点
510    async fn execute_approval(
511        &self,
512        node: &NodeDef,
513        _context: &mut WorkflowContext,
514    ) -> Result<Option<serde_json::Value>> {
515        let approvers = node
516            .approvers
517            .as_ref()
518            .ok_or_else(|| anyhow::anyhow!("Approval node '{}' has no approvers", node.id))?;
519
520        // 这里简化处理,实际应该等待审批
521        Ok(Some(serde_json::json!({
522            "approvers": approvers,
523            "status": "pending_approval"
524        })))
525    }
526
527    /// 获取下一个节点
528    fn get_next_node(&self, node: &NodeDef, context: &WorkflowContext) -> Result<Option<String>> {
529        // 结束节点没有下一个节点
530        if node.node_type == NodeType::End {
531            return Ok(None);
532        }
533
534        // 获取输出边
535        let edges = self.definition.get_outgoing_edges(&node.id);
536
537        if edges.is_empty() {
538            return Ok(None);
539        }
540
541        // 条件节点从分支获取下一个节点
542        if node.node_type == NodeType::Condition {
543            let exec = context.get_node_execution(&node.id);
544            if let Some(exec) = exec
545                && let Some(serde_json::Value::String(target)) = &exec.output
546            {
547                return Ok(Some(target.clone()));
548            }
549        }
550
551        // 根据边条件选择下一个节点
552        for edge in edges {
553            if let Some(condition) = &edge.condition {
554                if evaluate_expression(condition, &context.variables)? {
555                    return Ok(Some(edge.to.clone()));
556                }
557            } else {
558                // 无条件的边,直接返回
559                return Ok(Some(edge.to.clone()));
560            }
561        }
562
563        // 没有匹配的边
564        Ok(None)
565    }
566
567    /// 验证输入参数
568    fn validate_inputs(&self, context: &WorkflowContext) -> Result<()> {
569        for input_def in &self.definition.inputs {
570            if input_def.required
571                && context.get_input(&input_def.name).is_none()
572                && input_def.default.is_none()
573            {
574                anyhow::bail!("Required input '{}' is missing", input_def.name);
575            }
576        }
577        Ok(())
578    }
579
    /// Returns a reference to the workflow definition this engine was created with.
    pub fn definition(&self) -> &WorkflowDef {
        &self.definition
    }
584}
585
/// Default task executor (intended for tests).
pub struct DefaultTaskExecutor;
588
#[async_trait::async_trait]
impl TaskExecutor for DefaultTaskExecutor {
    /// Always succeeds without doing any work: ignores the parameters and the
    /// context and reports the task as completed with a `null` output.
    async fn execute(
        &self,
        task_name: &str,
        _params: &HashMap<String, serde_json::Value>,
        _context: &WorkflowContext,
    ) -> Result<serde_json::Value> {
        Ok(serde_json::json!({
            "task": task_name,
            "status": "completed",
            "output": null
        }))
    }
}
604
#[cfg(test)]
mod tests {
    use super::super::context::WorkflowStatus;
    use super::super::def::EdgeDef;
    use super::*;

    /// Builds a node with the given identity, the `Abort` failure strategy
    /// and every optional field left unset.
    fn make_node(id: &str, node_type: NodeType, name: &str, task: Option<&str>) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: task.map(|t| t.to_string()),
            params: HashMap::new(),
            on_failure: FailureStrategy::Abort,
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// Builds an unconditional, unlabeled edge.
    fn make_edge(id: &str, from: &str, to: &str) -> EdgeDef {
        EdgeDef {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            condition: None,
            label: None,
        }
    }

    /// Linear three-node workflow: start -> task1 -> end.
    fn create_simple_workflow() -> WorkflowDef {
        let nodes = vec![
            make_node("start", NodeType::Start, "Start", None),
            make_node("task1", NodeType::Task, "Task 1", Some("do_something")),
            make_node("end", NodeType::End, "End", None),
        ];
        let edges = vec![
            make_edge("e1", "start", "task1"),
            make_edge("e2", "task1", "end"),
        ];

        WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes,
            edges,
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::Abort,
            timeout_ms: None,
        }
    }

    /// Without any executor the engine still completes and visits all three nodes.
    #[tokio::test]
    async fn test_engine_run() {
        let engine = WorkflowEngine::new(create_simple_workflow()).unwrap();

        let context = engine.run(HashMap::new()).await.unwrap();

        assert_eq!(context.status, WorkflowStatus::Completed);
        assert_eq!(context.execution_path.len(), 3);
    }

    /// With the legacy default executor installed the workflow completes.
    #[tokio::test]
    async fn test_engine_with_executor() {
        let engine = WorkflowEngine::new(create_simple_workflow())
            .unwrap()
            .with_executor(Arc::new(DefaultTaskExecutor));

        let context = engine.run(HashMap::new()).await.unwrap();

        assert_eq!(context.status, WorkflowStatus::Completed);
    }
}
709        let context = engine.run(inputs).await.unwrap();
710
711        assert_eq!(context.status, WorkflowStatus::Completed);
712    }
713}