Skip to main content

matrixcode_core/workflow/
def.rs

//! Workflow definition structures.
//!
//! Defines the core data structures of a workflow: nodes, edges, node types and failure strategies.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8/// 节点类型
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum NodeType {
12    /// 开始节点
13    Start,
14    /// 结束节点
15    End,
16    /// 任务节点
17    Task,
18    /// 条件分支节点
19    Condition,
20    /// 并行节点
21    Parallel,
22    /// 子工作流节点
23    SubWorkflow,
24    /// 等待节点
25    Wait,
26    /// 人工审批节点
27    Approval,
28}
29
30/// 失败策略类型
31#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum FailureStrategyType {
34    Retry,
35    Ignore,
36    Abort,
37    Goto,
38}
39
40/// 失败策略配置(用于 YAML 解析)
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct FailureStrategyConfig {
43    /// 策略类型
44    #[serde(rename = "type", default = "default_failure_strategy_type")]
45    pub strategy_type: FailureStrategyType,
46    /// 最大重试次数(仅 retry)
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub max_attempts: Option<u32>,
49    /// 重试间隔(毫秒,仅 retry)
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub interval_ms: Option<u64>,
52    /// 目标节点ID(仅 goto)
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub target: Option<String>,
55}
56
57fn default_failure_strategy_type() -> FailureStrategyType {
58    FailureStrategyType::Abort
59}
60
61impl From<FailureStrategyConfig> for FailureStrategy {
62    fn from(config: FailureStrategyConfig) -> Self {
63        match config.strategy_type {
64            FailureStrategyType::Retry => FailureStrategy::Retry {
65                max_attempts: config.max_attempts.unwrap_or(1),
66                interval_ms: config.interval_ms,
67            },
68            FailureStrategyType::Ignore => FailureStrategy::Ignore,
69            FailureStrategyType::Abort => FailureStrategy::Abort,
70            FailureStrategyType::Goto => FailureStrategy::Goto {
71                target: config.target.unwrap_or_default(),
72            },
73        }
74    }
75}
76
77impl From<FailureStrategy> for FailureStrategyConfig {
78    fn from(strategy: FailureStrategy) -> Self {
79        match strategy {
80            FailureStrategy::Retry { max_attempts, interval_ms } => FailureStrategyConfig {
81                strategy_type: FailureStrategyType::Retry,
82                max_attempts: Some(max_attempts),
83                interval_ms,
84                target: None,
85            },
86            FailureStrategy::Ignore => FailureStrategyConfig {
87                strategy_type: FailureStrategyType::Ignore,
88                max_attempts: None,
89                interval_ms: None,
90                target: None,
91            },
92            FailureStrategy::Abort => FailureStrategyConfig {
93                strategy_type: FailureStrategyType::Abort,
94                max_attempts: None,
95                interval_ms: None,
96                target: None,
97            },
98            FailureStrategy::Goto { target } => FailureStrategyConfig {
99                strategy_type: FailureStrategyType::Goto,
100                max_attempts: None,
101                interval_ms: None,
102                target: Some(target),
103            },
104        }
105    }
106}
107
/// Failure strategy of a node or of a whole workflow.
///
/// The default is [`FailureStrategy::Abort`]. (De)serialization goes through
/// [`FailureStrategyConfig`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum FailureStrategy {
    /// Retry.
    Retry {
        /// Maximum number of attempts.
        max_attempts: u32,
        /// Delay between attempts in milliseconds.
        interval_ms: Option<u64>,
    },
    /// Ignore the failure and continue.
    Ignore,
    /// Abort the workflow.
    #[default]
    Abort,
    /// Jump to the given node.
    Goto {
        /// ID of the target node.
        target: String,
    },
}
130
131
132impl Serialize for FailureStrategy {
133    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134    where
135        S: serde::Serializer,
136    {
137        let config: FailureStrategyConfig = self.clone().into();
138        config.serialize(serializer)
139    }
140}
141
142impl<'de> Deserialize<'de> for FailureStrategy {
143    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
144    where
145        D: serde::Deserializer<'de>,
146    {
147        let config: FailureStrategyConfig = FailureStrategyConfig::deserialize(deserializer)?;
148        Ok(config.into())
149    }
150}
151
152/// 边定义
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct EdgeDef {
155    /// 边ID
156    #[serde(default = "generate_edge_id")]
157    pub id: String,
158    /// 源节点ID
159    pub from: String,
160    /// 目标节点ID
161    pub to: String,
162    /// 条件表达式(可选)
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub condition: Option<String>,
165    /// 边标签
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub label: Option<String>,
168}
169
170fn generate_edge_id() -> String {
171    format!("edge_{}", uuid::Uuid::new_v4())
172}
173
174/// 节点定义
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct NodeDef {
177    /// 节点ID
178    pub id: String,
179    /// 节点类型
180    #[serde(rename = "type")]
181    pub node_type: NodeType,
182    /// 节点名称
183    pub name: String,
184    /// 节点描述
185    #[serde(skip_serializing_if = "Option::is_none")]
186    pub description: Option<String>,
187    /// 任务名称(仅任务节点)
188    #[serde(skip_serializing_if = "Option::is_none")]
189    pub task: Option<String>,
190    /// 任务参数
191    #[serde(default)]
192    pub params: HashMap<String, serde_json::Value>,
193    /// 失败策略
194    #[serde(default)]
195    pub on_failure: FailureStrategy,
196    /// 超时时间(毫秒)
197    #[serde(skip_serializing_if = "Option::is_none")]
198    pub timeout_ms: Option<u64>,
199    /// 条件分支(仅条件节点)
200    #[serde(skip_serializing_if = "Option::is_none")]
201    pub branches: Option<Vec<BranchDef>>,
202    /// 并行分支(仅并行节点)
203    #[serde(skip_serializing_if = "Option::is_none")]
204    pub parallel_branches: Option<Vec<ParallelBranchDef>>,
205    /// 子工作流名称(仅子工作流节点)
206    #[serde(skip_serializing_if = "Option::is_none")]
207    pub workflow: Option<String>,
208    /// 等待时间(毫秒,仅等待节点)
209    #[serde(skip_serializing_if = "Option::is_none")]
210    pub wait_ms: Option<u64>,
211    /// 审批人列表(仅审批节点)
212    #[serde(skip_serializing_if = "Option::is_none")]
213    pub approvers: Option<Vec<String>>,
214}
215
216/// 条件分支定义
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct BranchDef {
219    /// 分支名称
220    pub name: String,
221    /// 条件表达式
222    pub condition: String,
223    /// 目标节点ID
224    pub target: String,
225}
226
227/// 并行分支定义
228#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct ParallelBranchDef {
230    /// 分支名称
231    pub name: String,
232    /// 分支节点列表
233    pub nodes: Vec<NodeDef>,
234}
235
236/// 工作流定义
237#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct WorkflowDef {
239    /// 工作流ID
240    pub id: String,
241    /// 工作流名称
242    pub name: String,
243    /// 版本
244    #[serde(default = "default_version")]
245    pub version: String,
246    /// 描述
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub description: Option<String>,
249    /// 输入参数定义
250    #[serde(default)]
251    pub inputs: Vec<InputDef>,
252    /// 输出参数定义
253    #[serde(default)]
254    pub outputs: Vec<OutputDef>,
255    /// 节点列表
256    pub nodes: Vec<NodeDef>,
257    /// 边列表
258    #[serde(default)]
259    pub edges: Vec<EdgeDef>,
260    /// 全局变量
261    #[serde(default)]
262    pub variables: HashMap<String, serde_json::Value>,
263    /// 默认失败策略
264    #[serde(default)]
265    pub default_failure_strategy: FailureStrategy,
266    /// 超时时间(毫秒)
267    #[serde(skip_serializing_if = "Option::is_none")]
268    pub timeout_ms: Option<u64>,
269}
270
/// Serde default for [`WorkflowDef::version`].
fn default_version() -> String {
    String::from("1.0.0")
}
274
275/// 输入参数定义
276#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct InputDef {
278    /// 参数名
279    pub name: String,
280    /// 参数类型
281    #[serde(rename = "type", default = "default_input_type")]
282    pub input_type: String,
283    /// 是否必填
284    #[serde(default)]
285    pub required: bool,
286    /// 默认值
287    #[serde(skip_serializing_if = "Option::is_none")]
288    pub default: Option<serde_json::Value>,
289    /// 描述
290    #[serde(skip_serializing_if = "Option::is_none")]
291    pub description: Option<String>,
292}
293
/// Serde default for [`InputDef::input_type`].
fn default_input_type() -> String {
    "string".to_owned()
}
297
298/// 输出参数定义
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub struct OutputDef {
301    /// 参数名
302    pub name: String,
303    /// 值表达式
304    pub value: String,
305    /// 描述
306    #[serde(skip_serializing_if = "Option::is_none")]
307    pub description: Option<String>,
308}
309
310impl WorkflowDef {
311    /// 根据ID查找节点
312    pub fn get_node(&self, id: &str) -> Option<&NodeDef> {
313        self.nodes.iter().find(|n| n.id == id)
314    }
315
316    /// 获取开始节点
317    pub fn get_start_node(&self) -> Option<&NodeDef> {
318        self.nodes.iter().find(|n| n.node_type == NodeType::Start)
319    }
320
321    /// 获取结束节点
322    pub fn get_end_node(&self) -> Option<&NodeDef> {
323        self.nodes.iter().find(|n| n.node_type == NodeType::End)
324    }
325
326    /// 获取从指定节点出发的边
327    pub fn get_outgoing_edges(&self, node_id: &str) -> Vec<&EdgeDef> {
328        self.edges.iter().filter(|e| e.from == node_id).collect()
329    }
330
331    /// 验证工作流定义
332    pub fn validate(&self) -> anyhow::Result<()> {
333        // 检查必须有开始节点
334        if self.get_start_node().is_none() {
335            anyhow::bail!("Workflow must have a start node");
336        }
337
338        // 检查必须有结束节点
339        if self.get_end_node().is_none() {
340            anyhow::bail!("Workflow must have an end node");
341        }
342
343        // 检查节点ID唯一性
344        let mut node_ids = std::collections::HashSet::new();
345        for node in &self.nodes {
346            if !node_ids.insert(&node.id) {
347                anyhow::bail!("Duplicate node id: {}", node.id);
348            }
349        }
350
351        // 检查边引用的节点是否存在
352        for edge in &self.edges {
353            if !node_ids.contains(&edge.from) {
354                anyhow::bail!("Edge references unknown source node: {}", edge.from);
355            }
356            if !node_ids.contains(&edge.to) {
357                anyhow::bail!("Edge references unknown target node: {}", edge.to);
358            }
359        }
360
361        // 检查必填输入参数
362        for input in &self.inputs {
363            if input.required && input.default.is_none() {
364                // 必填参数没有默认值,需要在运行时提供
365            }
366        }
367
368        Ok(())
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node with the given ID, name and type.
    /// All optional fields are empty and the failure strategy is the default.
    fn node(id: &str, name: &str, node_type: NodeType) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: None,
            params: HashMap::new(),
            on_failure: FailureStrategy::default(),
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// Builds an edge with no condition and no label.
    fn edge(id: &str, from: &str, to: &str) -> EdgeDef {
        EdgeDef {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            condition: None,
            label: None,
        }
    }

    /// Builds a workflow with fixed metadata around the given nodes and edges.
    fn workflow_with(nodes: Vec<NodeDef>, edges: Vec<EdgeDef>) -> WorkflowDef {
        WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes,
            edges,
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::default(),
            timeout_ms: None,
        }
    }

    /// Minimal valid workflow: start -> end.
    fn start_end_workflow() -> WorkflowDef {
        workflow_with(
            vec![
                node("start", "Start", NodeType::Start),
                node("end", "End", NodeType::End),
            ],
            vec![edge("e1", "start", "end")],
        )
    }

    #[test]
    fn test_workflow_def_validation() {
        assert!(start_end_workflow().validate().is_ok());
    }

    #[test]
    fn validation_rejects_missing_start_node() {
        let workflow = workflow_with(vec![node("end", "End", NodeType::End)], vec![]);
        assert!(workflow.validate().is_err());
    }

    #[test]
    fn validation_rejects_missing_end_node() {
        let workflow = workflow_with(vec![node("start", "Start", NodeType::Start)], vec![]);
        assert!(workflow.validate().is_err());
    }

    #[test]
    fn validation_rejects_duplicate_node_ids() {
        let mut workflow = start_end_workflow();
        workflow.nodes.push(node("end", "Another End", NodeType::Task));
        assert!(workflow.validate().is_err());
    }

    #[test]
    fn validation_rejects_edge_to_unknown_node() {
        let mut workflow = start_end_workflow();
        workflow.edges.push(edge("e2", "start", "missing"));
        assert!(workflow.validate().is_err());
    }

    #[test]
    fn failure_strategy_defaults_to_abort_when_fields_are_omitted() {
        let strategy: FailureStrategy = serde_json::from_str("{}").unwrap();
        assert_eq!(strategy, FailureStrategy::Abort);
    }

    #[test]
    fn failure_strategy_round_trips_through_json() {
        let strategy = FailureStrategy::Retry {
            max_attempts: 3,
            interval_ms: Some(500),
        };
        let json = serde_json::to_string(&strategy).unwrap();
        let parsed: FailureStrategy = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, strategy);
    }
}