Skip to main content

matrixcode_core/workflow/
def.rs

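//! Workflow Definition Structures
//!
//! Core data structures that define a workflow: nodes, edges, types, and failure strategies.
//!
//! ## Execution Modes (Pipeline vs Parallel)
//!
//! **Pipeline** (streaming, no barrier):
//! - A task flows to the next stage as soon as it completes
//! - Does not wait for other tasks to finish
//! - Wall-clock = the slowest single task chain
//! - Use cases: batch file processing, streaming data transformation
//!
//! **Parallel** (parallel execution, with barrier):
//! - All tasks start in parallel
//! - All of them must complete before execution continues
//! - Use cases: multi-dimensional review, cross-source data collection, result aggregation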
17
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20
21/// 节点类型
22#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "snake_case")]
24pub enum NodeType {
25    /// 开始节点
26    Start,
27    /// 结束节点
28    End,
29    /// 任务节点
30    Task,
31    /// 条件分支节点
32    Condition,
33    /// 并行节点
34    Parallel,
35    /// 流式节点 (无屏障等待)
36    Pipeline,
37    /// 子工作流节点
38    SubWorkflow,
39    /// 等待节点
40    Wait,
41    /// 人工审批节点
42    Approval,
43}
44
/// Execution mode - controls the execution strategy of parallel branches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ExecutionMode {
    /// Pipeline: streaming, no barrier wait.
    /// Once task A completes it flows to the next stage immediately, without waiting for other tasks.
    /// Wall-clock = the slowest single task chain.
    /// Use cases: batch file processing, streaming data transformation.
    Pipeline,

    /// Parallel: parallel execution with a barrier wait (the default).
    /// All tasks must complete before the next stage can start.
    /// Use cases: multi-dimensional review, cross-source data collection, result aggregation.
    #[default]
    Parallel,
}
61
62impl ExecutionMode {
63    /// 是否需要屏障等待
64    pub fn has_barrier(&self) -> bool {
65        match self {
66            Self::Pipeline => false,
67            Self::Parallel => true,
68        }
69    }
70
71    /// 获取显示名称
72    pub fn display_name(&self) -> &'static str {
73        match self {
74            Self::Pipeline => "流式",
75            Self::Parallel => "并行",
76        }
77    }
78
79    /// 获取描述
80    pub fn description(&self) -> &'static str {
81        match self {
82            Self::Pipeline => "任务流转执行,不等待其他任务",
83            Self::Parallel => "所有任务并行执行,等待全部完成",
84        }
85    }
86}
87
88/// 失败策略类型
89#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
90#[serde(rename_all = "snake_case")]
91pub enum FailureStrategyType {
92    Retry,
93    Ignore,
94    Abort,
95    Goto,
96}
97
/// Failure strategy configuration (used for YAML parsing).
///
/// This is the flat, serializable form of `FailureStrategy`: one `type` tag plus
/// optional fields, each of which only applies to one strategy type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailureStrategyConfig {
    /// Strategy type; serialized as `type` and defaulting to `Abort` when absent.
    #[serde(rename = "type", default = "default_failure_strategy_type")]
    pub strategy_type: FailureStrategyType,
    /// Maximum number of retry attempts (retry only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_attempts: Option<u32>,
    /// Retry interval in milliseconds (retry only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub interval_ms: Option<u64>,
    /// Target node ID (goto only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target: Option<String>,
}
114
/// Serde default for `FailureStrategyConfig::strategy_type`: `Abort`.
fn default_failure_strategy_type() -> FailureStrategyType {
    FailureStrategyType::Abort
}
118
119impl From<FailureStrategyConfig> for FailureStrategy {
120    fn from(config: FailureStrategyConfig) -> Self {
121        match config.strategy_type {
122            FailureStrategyType::Retry => FailureStrategy::Retry {
123                max_attempts: config.max_attempts.unwrap_or(1),
124                interval_ms: config.interval_ms,
125            },
126            FailureStrategyType::Ignore => FailureStrategy::Ignore,
127            FailureStrategyType::Abort => FailureStrategy::Abort,
128            FailureStrategyType::Goto => FailureStrategy::Goto {
129                target: config.target.unwrap_or_default(),
130            },
131        }
132    }
133}
134
135impl From<FailureStrategy> for FailureStrategyConfig {
136    fn from(strategy: FailureStrategy) -> Self {
137        match strategy {
138            FailureStrategy::Retry {
139                max_attempts,
140                interval_ms,
141            } => FailureStrategyConfig {
142                strategy_type: FailureStrategyType::Retry,
143                max_attempts: Some(max_attempts),
144                interval_ms,
145                target: None,
146            },
147            FailureStrategy::Ignore => FailureStrategyConfig {
148                strategy_type: FailureStrategyType::Ignore,
149                max_attempts: None,
150                interval_ms: None,
151                target: None,
152            },
153            FailureStrategy::Abort => FailureStrategyConfig {
154                strategy_type: FailureStrategyType::Abort,
155                max_attempts: None,
156                interval_ms: None,
157                target: None,
158            },
159            FailureStrategy::Goto { target } => FailureStrategyConfig {
160                strategy_type: FailureStrategyType::Goto,
161                max_attempts: None,
162                interval_ms: None,
163                target: Some(target),
164            },
165        }
166    }
167}
168
/// Failure strategy.
///
/// `Abort` is the default variant.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum FailureStrategy {
    /// Retry.
    Retry {
        /// Maximum number of retry attempts.
        max_attempts: u32,
        /// Retry interval in milliseconds.
        interval_ms: Option<u64>,
    },
    /// Ignore the failure and continue.
    Ignore,
    /// Abort the workflow.
    #[default]
    Abort,
    /// Jump to a specified node.
    Goto {
        /// Target node ID.
        target: String,
    },
}
190
191impl Serialize for FailureStrategy {
192    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
193    where
194        S: serde::Serializer,
195    {
196        let config: FailureStrategyConfig = self.clone().into();
197        config.serialize(serializer)
198    }
199}
200
201impl<'de> Deserialize<'de> for FailureStrategy {
202    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203    where
204        D: serde::Deserializer<'de>,
205    {
206        let config: FailureStrategyConfig = FailureStrategyConfig::deserialize(deserializer)?;
207        Ok(config.into())
208    }
209}
210
/// Edge definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeDef {
    /// Edge ID; generated by `generate_edge_id` when absent from the input.
    #[serde(default = "generate_edge_id")]
    pub id: String,
    /// Source node ID.
    pub from: String,
    /// Target node ID.
    pub to: String,
    /// Condition expression (optional).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub condition: Option<String>,
    /// Edge label.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub label: Option<String>,
}
228
229fn generate_edge_id() -> String {
230    format!("edge_{}", uuid::Uuid::new_v4())
231}
232
/// Node definition.
///
/// A single struct covers every node type; most optional fields only apply to
/// one node type, as noted on each field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeDef {
    /// Node ID.
    pub id: String,
    /// Node type; serialized as `type`.
    #[serde(rename = "type")]
    pub node_type: NodeType,
    /// Node name.
    pub name: String,
    /// Node description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Task name (task nodes only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task: Option<String>,
    /// Task parameters.
    #[serde(default)]
    pub params: HashMap<String, serde_json::Value>,
    /// Failure strategy.
    #[serde(default)]
    pub on_failure: FailureStrategy,
    /// Timeout in milliseconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u64>,
    /// Conditional branches (condition nodes only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub branches: Option<Vec<BranchDef>>,
    /// Parallel branches (parallel and pipeline nodes only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_branches: Option<Vec<ParallelBranchDef>>,
    /// Execution mode (parallel/pipeline nodes only; inferred from the node type by default).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub execution_mode: Option<ExecutionMode>,
    /// Sub-workflow name (sub-workflow nodes only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub workflow: Option<String>,
    /// Wait time in milliseconds (wait nodes only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub wait_ms: Option<u64>,
    /// List of approvers (approval nodes only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub approvers: Option<Vec<String>>,
}
277
278impl NodeDef {
279    /// 获取节点的执行模式
280    /// - 对于 Pipeline 类型节点,返回 Pipeline 模式
281    /// - 对于 Parallel 类型节点,返回 Parallel 模式(或自定义模式)
282    /// - 其他类型节点返回 None
283    pub fn get_execution_mode(&self) -> Option<ExecutionMode> {
284        match self.node_type {
285            NodeType::Pipeline => Some(ExecutionMode::Pipeline),
286            NodeType::Parallel => Some(
287                self.execution_mode.unwrap_or(ExecutionMode::Parallel)
288            ),
289            _ => None,
290        }
291    }
292
293    /// 是否需要屏障等待
294    pub fn has_barrier(&self) -> bool {
295        self.get_execution_mode()
296            .map(|m| m.has_barrier())
297            .unwrap_or(false)
298    }
299}
300
/// Conditional branch definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchDef {
    /// Branch name.
    pub name: String,
    /// Condition expression.
    pub condition: String,
    /// Target node ID.
    pub target: String,
}
311
/// Parallel branch definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelBranchDef {
    /// Branch name.
    pub name: String,
    /// Nodes that make up the branch.
    pub nodes: Vec<NodeDef>,
    /// Execution mode (optional, defaults to Parallel).
    /// - Pipeline: streaming, no barrier wait
    /// - Parallel: parallel execution, waits for all tasks to complete
    #[serde(default)]
    pub mode: ExecutionMode,
}
325
326impl ParallelBranchDef {
327    /// 创建默认并行分支(Parallel 模式)
328    pub fn new(name: String, nodes: Vec<NodeDef>) -> Self {
329        Self {
330            name,
331            nodes,
332            mode: ExecutionMode::default(),
333        }
334    }
335
336    /// 创建 Pipeline 模式分支(流式处理)
337    pub fn pipeline(name: String, nodes: Vec<NodeDef>) -> Self {
338        Self {
339            name,
340            nodes,
341            mode: ExecutionMode::Pipeline,
342        }
343    }
344
345    /// 创建 Parallel 模式分支(有屏障等待)
346    pub fn parallel(name: String, nodes: Vec<NodeDef>) -> Self {
347        Self {
348            name,
349            nodes,
350            mode: ExecutionMode::Parallel,
351        }
352    }
353
354    /// 是否需要屏障等待
355    pub fn has_barrier(&self) -> bool {
356        self.mode.has_barrier()
357    }
358
359    /// 获取执行模式描述
360    pub fn mode_description(&self) -> &'static str {
361        self.mode.description()
362    }
363}
364
/// Workflow definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowDef {
    /// Workflow ID.
    pub id: String,
    /// Workflow name.
    pub name: String,
    /// Version; defaults to the value of `default_version` when absent.
    #[serde(default = "default_version")]
    pub version: String,
    /// Description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Input parameter definitions.
    #[serde(default)]
    pub inputs: Vec<InputDef>,
    /// Output parameter definitions.
    #[serde(default)]
    pub outputs: Vec<OutputDef>,
    /// Node list.
    pub nodes: Vec<NodeDef>,
    /// Edge list.
    #[serde(default)]
    pub edges: Vec<EdgeDef>,
    /// Global variables.
    #[serde(default)]
    pub variables: HashMap<String, serde_json::Value>,
    /// Default failure strategy.
    #[serde(default)]
    pub default_failure_strategy: FailureStrategy,
    /// Timeout in milliseconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u64>,
}
399
/// Serde default for `WorkflowDef::version`: `"1.0.0"`.
fn default_version() -> String {
    String::from("1.0.0")
}
403
/// Input parameter definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputDef {
    /// Parameter name.
    pub name: String,
    /// Parameter type; serialized as `type`, defaulting to the value of
    /// `default_input_type` when absent.
    #[serde(rename = "type", default = "default_input_type")]
    pub input_type: String,
    /// Whether the parameter is required.
    #[serde(default)]
    pub required: bool,
    /// Default value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<serde_json::Value>,
    /// Description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
422
/// Serde default for `InputDef::input_type`: `"string"`.
fn default_input_type() -> String {
    String::from("string")
}
426
/// Output parameter definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputDef {
    /// Parameter name.
    pub name: String,
    /// Value expression.
    pub value: String,
    /// Description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
438
439impl WorkflowDef {
440    /// 根据ID查找节点
441    pub fn get_node(&self, id: &str) -> Option<&NodeDef> {
442        self.nodes.iter().find(|n| n.id == id)
443    }
444
445    /// 获取开始节点
446    pub fn get_start_node(&self) -> Option<&NodeDef> {
447        self.nodes.iter().find(|n| n.node_type == NodeType::Start)
448    }
449
450    /// 获取结束节点
451    pub fn get_end_node(&self) -> Option<&NodeDef> {
452        self.nodes.iter().find(|n| n.node_type == NodeType::End)
453    }
454
455    /// 获取从指定节点出发的边
456    pub fn get_outgoing_edges(&self, node_id: &str) -> Vec<&EdgeDef> {
457        self.edges.iter().filter(|e| e.from == node_id).collect()
458    }
459
460    /// 验证工作流定义
461    pub fn validate(&self) -> anyhow::Result<()> {
462        // 检查必须有开始节点
463        if self.get_start_node().is_none() {
464            anyhow::bail!("Workflow must have a start node");
465        }
466
467        // 检查必须有结束节点
468        if self.get_end_node().is_none() {
469            anyhow::bail!("Workflow must have an end node");
470        }
471
472        // 检查节点ID唯一性
473        let mut node_ids = std::collections::HashSet::new();
474        for node in &self.nodes {
475            if !node_ids.insert(&node.id) {
476                anyhow::bail!("Duplicate node id: {}", node.id);
477            }
478        }
479
480        // 检查边引用的节点是否存在
481        for edge in &self.edges {
482            if !node_ids.contains(&edge.from) {
483                anyhow::bail!("Edge references unknown source node: {}", edge.from);
484            }
485            if !node_ids.contains(&edge.to) {
486                anyhow::bail!("Edge references unknown target node: {}", edge.to);
487            }
488        }
489
490        // 检查必填输入参数
491        for input in &self.inputs {
492            if input.required && input.default.is_none() {
493                // 必填参数没有默认值,需要在运行时提供
494            }
495        }
496
497        Ok(())
498    }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node with the given identity, the default failure strategy,
    /// empty params, and every optional field unset.
    fn bare_node(id: &str, node_type: NodeType, name: &str) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: None,
            params: HashMap::new(),
            on_failure: FailureStrategy::default(),
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            execution_mode: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    #[test]
    fn test_workflow_def_validation() {
        let start_to_end = EdgeDef {
            id: "e1".to_string(),
            from: "start".to_string(),
            to: "end".to_string(),
            condition: None,
            label: None,
        };

        let workflow = WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes: vec![
                bare_node("start", NodeType::Start, "Start"),
                bare_node("end", NodeType::End, "End"),
            ],
            edges: vec![start_to_end],
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::default(),
            timeout_ms: None,
        };

        assert!(workflow.validate().is_ok());
    }

    #[test]
    fn test_execution_mode_default() {
        let mode = ExecutionMode::default();
        assert_eq!(mode, ExecutionMode::Parallel);
        assert!(mode.has_barrier());
    }

    #[test]
    fn test_execution_mode_pipeline_no_barrier() {
        let mode = ExecutionMode::Pipeline;
        assert!(!mode.has_barrier());
        assert_eq!(mode.display_name(), "流式");
    }

    #[test]
    fn test_execution_mode_parallel_has_barrier() {
        let mode = ExecutionMode::Parallel;
        assert!(mode.has_barrier());
        assert_eq!(mode.display_name(), "并行");
    }

    #[test]
    fn test_parallel_branch_def_default_mode() {
        let branch = ParallelBranchDef::new("test".to_string(), vec![]);
        assert_eq!(branch.mode, ExecutionMode::Parallel);
        assert!(branch.has_barrier());
    }

    #[test]
    fn test_parallel_branch_def_pipeline_mode() {
        let branch = ParallelBranchDef::pipeline("test".to_string(), vec![]);
        assert_eq!(branch.mode, ExecutionMode::Pipeline);
        assert!(!branch.has_barrier());
    }

    #[test]
    fn test_parallel_branch_def_parallel_mode() {
        let branch = ParallelBranchDef::parallel("test".to_string(), vec![]);
        assert_eq!(branch.mode, ExecutionMode::Parallel);
        assert!(branch.has_barrier());
    }

    #[test]
    fn test_node_def_get_execution_mode_pipeline() {
        let node = bare_node("pipeline-node", NodeType::Pipeline, "Pipeline Node");
        assert_eq!(node.get_execution_mode(), Some(ExecutionMode::Pipeline));
        assert!(!node.has_barrier());
    }

    #[test]
    fn test_node_def_get_execution_mode_parallel() {
        let node = bare_node("parallel-node", NodeType::Parallel, "Parallel Node");
        assert_eq!(node.get_execution_mode(), Some(ExecutionMode::Parallel));
        assert!(node.has_barrier());
    }

    #[test]
    fn test_node_def_get_execution_mode_custom() {
        // Parallel node whose mode is overridden to Pipeline.
        let node = NodeDef {
            execution_mode: Some(ExecutionMode::Pipeline),
            ..bare_node("custom-node", NodeType::Parallel, "Custom Node")
        };
        assert_eq!(node.get_execution_mode(), Some(ExecutionMode::Pipeline));
        assert!(!node.has_barrier());
    }

    #[test]
    fn test_node_def_get_execution_mode_other_types() {
        let node = NodeDef {
            task: Some("do_something".to_string()),
            ..bare_node("task-node", NodeType::Task, "Task Node")
        };
        assert_eq!(node.get_execution_mode(), None);
    }
}