// matrixcode_core/workflow/def.rs

//! Workflow definition structures.
//!
//! Core data structures for describing a workflow: nodes, edges, node types
//! and failure strategies.

use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};

8/// 节点类型
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum NodeType {
12    /// 开始节点
13    Start,
14    /// 结束节点
15    End,
16    /// 任务节点
17    Task,
18    /// 条件分支节点
19    Condition,
20    /// 并行节点
21    Parallel,
22    /// 子工作流节点
23    SubWorkflow,
24    /// 等待节点
25    Wait,
26    /// 人工审批节点
27    Approval,
28}
29
30/// 失败策略类型
31#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum FailureStrategyType {
34    Retry,
35    Ignore,
36    Abort,
37    Goto,
38}
39
40/// 失败策略配置(用于 YAML 解析)
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct FailureStrategyConfig {
43    /// 策略类型
44    #[serde(rename = "type", default = "default_failure_strategy_type")]
45    pub strategy_type: FailureStrategyType,
46    /// 最大重试次数(仅 retry)
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub max_attempts: Option<u32>,
49    /// 重试间隔(毫秒,仅 retry)
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub interval_ms: Option<u64>,
52    /// 目标节点ID(仅 goto)
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub target: Option<String>,
55}
56
57fn default_failure_strategy_type() -> FailureStrategyType {
58    FailureStrategyType::Abort
59}
60
61impl From<FailureStrategyConfig> for FailureStrategy {
62    fn from(config: FailureStrategyConfig) -> Self {
63        match config.strategy_type {
64            FailureStrategyType::Retry => FailureStrategy::Retry {
65                max_attempts: config.max_attempts.unwrap_or(1),
66                interval_ms: config.interval_ms,
67            },
68            FailureStrategyType::Ignore => FailureStrategy::Ignore,
69            FailureStrategyType::Abort => FailureStrategy::Abort,
70            FailureStrategyType::Goto => FailureStrategy::Goto {
71                target: config.target.unwrap_or_default(),
72            },
73        }
74    }
75}
76
77impl From<FailureStrategy> for FailureStrategyConfig {
78    fn from(strategy: FailureStrategy) -> Self {
79        match strategy {
80            FailureStrategy::Retry {
81                max_attempts,
82                interval_ms,
83            } => FailureStrategyConfig {
84                strategy_type: FailureStrategyType::Retry,
85                max_attempts: Some(max_attempts),
86                interval_ms,
87                target: None,
88            },
89            FailureStrategy::Ignore => FailureStrategyConfig {
90                strategy_type: FailureStrategyType::Ignore,
91                max_attempts: None,
92                interval_ms: None,
93                target: None,
94            },
95            FailureStrategy::Abort => FailureStrategyConfig {
96                strategy_type: FailureStrategyType::Abort,
97                max_attempts: None,
98                interval_ms: None,
99                target: None,
100            },
101            FailureStrategy::Goto { target } => FailureStrategyConfig {
102                strategy_type: FailureStrategyType::Goto,
103                max_attempts: None,
104                interval_ms: None,
105                target: Some(target),
106            },
107        }
108    }
109}
110
/// What to do when a node fails.
///
/// (De)serialized through `FailureStrategyConfig`, so in YAML it appears as a
/// map with a `type` key, e.g. `{ type: retry, max_attempts: 3 }`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum FailureStrategy {
    /// Retry the failed node.
    Retry {
        /// Maximum number of retry attempts.
        max_attempts: u32,
        /// Delay between retries in milliseconds, if any.
        interval_ms: Option<u64>,
    },
    /// Ignore the failure and continue.
    Ignore,
    /// Abort the workflow. This is the default.
    #[default]
    Abort,
    /// Jump to another node.
    Goto {
        /// ID of the node to jump to.
        target: String,
    },
}

133impl Serialize for FailureStrategy {
134    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
135    where
136        S: serde::Serializer,
137    {
138        let config: FailureStrategyConfig = self.clone().into();
139        config.serialize(serializer)
140    }
141}
142
143impl<'de> Deserialize<'de> for FailureStrategy {
144    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
145    where
146        D: serde::Deserializer<'de>,
147    {
148        let config: FailureStrategyConfig = FailureStrategyConfig::deserialize(deserializer)?;
149        Ok(config.into())
150    }
151}
152
153/// 边定义
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct EdgeDef {
156    /// 边ID
157    #[serde(default = "generate_edge_id")]
158    pub id: String,
159    /// 源节点ID
160    pub from: String,
161    /// 目标节点ID
162    pub to: String,
163    /// 条件表达式(可选)
164    #[serde(skip_serializing_if = "Option::is_none")]
165    pub condition: Option<String>,
166    /// 边标签
167    #[serde(skip_serializing_if = "Option::is_none")]
168    pub label: Option<String>,
169}
170
171fn generate_edge_id() -> String {
172    format!("edge_{}", uuid::Uuid::new_v4())
173}
174
175/// 节点定义
176#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct NodeDef {
178    /// 节点ID
179    pub id: String,
180    /// 节点类型
181    #[serde(rename = "type")]
182    pub node_type: NodeType,
183    /// 节点名称
184    pub name: String,
185    /// 节点描述
186    #[serde(skip_serializing_if = "Option::is_none")]
187    pub description: Option<String>,
188    /// 任务名称(仅任务节点)
189    #[serde(skip_serializing_if = "Option::is_none")]
190    pub task: Option<String>,
191    /// 任务参数
192    #[serde(default)]
193    pub params: HashMap<String, serde_json::Value>,
194    /// 失败策略
195    #[serde(default)]
196    pub on_failure: FailureStrategy,
197    /// 超时时间(毫秒)
198    #[serde(skip_serializing_if = "Option::is_none")]
199    pub timeout_ms: Option<u64>,
200    /// 条件分支(仅条件节点)
201    #[serde(skip_serializing_if = "Option::is_none")]
202    pub branches: Option<Vec<BranchDef>>,
203    /// 并行分支(仅并行节点)
204    #[serde(skip_serializing_if = "Option::is_none")]
205    pub parallel_branches: Option<Vec<ParallelBranchDef>>,
206    /// 子工作流名称(仅子工作流节点)
207    #[serde(skip_serializing_if = "Option::is_none")]
208    pub workflow: Option<String>,
209    /// 等待时间(毫秒,仅等待节点)
210    #[serde(skip_serializing_if = "Option::is_none")]
211    pub wait_ms: Option<u64>,
212    /// 审批人列表(仅审批节点)
213    #[serde(skip_serializing_if = "Option::is_none")]
214    pub approvers: Option<Vec<String>>,
215}
216
217/// 条件分支定义
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct BranchDef {
220    /// 分支名称
221    pub name: String,
222    /// 条件表达式
223    pub condition: String,
224    /// 目标节点ID
225    pub target: String,
226}
227
228/// 并行分支定义
229#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct ParallelBranchDef {
231    /// 分支名称
232    pub name: String,
233    /// 分支节点列表
234    pub nodes: Vec<NodeDef>,
235}
236
237/// 工作流定义
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct WorkflowDef {
240    /// 工作流ID
241    pub id: String,
242    /// 工作流名称
243    pub name: String,
244    /// 版本
245    #[serde(default = "default_version")]
246    pub version: String,
247    /// 描述
248    #[serde(skip_serializing_if = "Option::is_none")]
249    pub description: Option<String>,
250    /// 输入参数定义
251    #[serde(default)]
252    pub inputs: Vec<InputDef>,
253    /// 输出参数定义
254    #[serde(default)]
255    pub outputs: Vec<OutputDef>,
256    /// 节点列表
257    pub nodes: Vec<NodeDef>,
258    /// 边列表
259    #[serde(default)]
260    pub edges: Vec<EdgeDef>,
261    /// 全局变量
262    #[serde(default)]
263    pub variables: HashMap<String, serde_json::Value>,
264    /// 默认失败策略
265    #[serde(default)]
266    pub default_failure_strategy: FailureStrategy,
267    /// 超时时间(毫秒)
268    #[serde(skip_serializing_if = "Option::is_none")]
269    pub timeout_ms: Option<u64>,
270}
271
/// Serde default for `WorkflowDef::version`.
fn default_version() -> String {
    String::from("1.0.0")
}

276/// 输入参数定义
277#[derive(Debug, Clone, Serialize, Deserialize)]
278pub struct InputDef {
279    /// 参数名
280    pub name: String,
281    /// 参数类型
282    #[serde(rename = "type", default = "default_input_type")]
283    pub input_type: String,
284    /// 是否必填
285    #[serde(default)]
286    pub required: bool,
287    /// 默认值
288    #[serde(skip_serializing_if = "Option::is_none")]
289    pub default: Option<serde_json::Value>,
290    /// 描述
291    #[serde(skip_serializing_if = "Option::is_none")]
292    pub description: Option<String>,
293}
294
/// Serde default for `InputDef::input_type`.
fn default_input_type() -> String {
    String::from("string")
}

299/// 输出参数定义
300#[derive(Debug, Clone, Serialize, Deserialize)]
301pub struct OutputDef {
302    /// 参数名
303    pub name: String,
304    /// 值表达式
305    pub value: String,
306    /// 描述
307    #[serde(skip_serializing_if = "Option::is_none")]
308    pub description: Option<String>,
309}
310
311impl WorkflowDef {
312    /// 根据ID查找节点
313    pub fn get_node(&self, id: &str) -> Option<&NodeDef> {
314        self.nodes.iter().find(|n| n.id == id)
315    }
316
317    /// 获取开始节点
318    pub fn get_start_node(&self) -> Option<&NodeDef> {
319        self.nodes.iter().find(|n| n.node_type == NodeType::Start)
320    }
321
322    /// 获取结束节点
323    pub fn get_end_node(&self) -> Option<&NodeDef> {
324        self.nodes.iter().find(|n| n.node_type == NodeType::End)
325    }
326
327    /// 获取从指定节点出发的边
328    pub fn get_outgoing_edges(&self, node_id: &str) -> Vec<&EdgeDef> {
329        self.edges.iter().filter(|e| e.from == node_id).collect()
330    }
331
332    /// 验证工作流定义
333    pub fn validate(&self) -> anyhow::Result<()> {
334        // 检查必须有开始节点
335        if self.get_start_node().is_none() {
336            anyhow::bail!("Workflow must have a start node");
337        }
338
339        // 检查必须有结束节点
340        if self.get_end_node().is_none() {
341            anyhow::bail!("Workflow must have an end node");
342        }
343
344        // 检查节点ID唯一性
345        let mut node_ids = std::collections::HashSet::new();
346        for node in &self.nodes {
347            if !node_ids.insert(&node.id) {
348                anyhow::bail!("Duplicate node id: {}", node.id);
349            }
350        }
351
352        // 检查边引用的节点是否存在
353        for edge in &self.edges {
354            if !node_ids.contains(&edge.from) {
355                anyhow::bail!("Edge references unknown source node: {}", edge.from);
356            }
357            if !node_ids.contains(&edge.to) {
358                anyhow::bail!("Edge references unknown target node: {}", edge.to);
359            }
360        }
361
362        // 检查必填输入参数
363        for input in &self.inputs {
364            if input.required && input.default.is_none() {
365                // 必填参数没有默认值,需要在运行时提供
366            }
367        }
368
369        Ok(())
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node of the given type with every optional field unset.
    fn node(id: &str, node_type: NodeType) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: id.to_string(),
            description: None,
            task: None,
            params: HashMap::new(),
            on_failure: FailureStrategy::default(),
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// Builds an unconditional, unlabeled edge.
    fn edge(id: &str, from: &str, to: &str) -> EdgeDef {
        EdgeDef {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            condition: None,
            label: None,
        }
    }

    /// A minimal valid workflow: `start -> end`.
    fn minimal_workflow() -> WorkflowDef {
        WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes: vec![node("start", NodeType::Start), node("end", NodeType::End)],
            edges: vec![edge("e1", "start", "end")],
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::default(),
            timeout_ms: None,
        }
    }

    #[test]
    fn test_workflow_def_validation() {
        assert!(minimal_workflow().validate().is_ok());
    }

    #[test]
    fn validation_requires_start_and_end_nodes() {
        let mut wf = minimal_workflow();
        wf.nodes.retain(|n| n.node_type != NodeType::Start);
        wf.edges.clear();
        assert!(wf.validate().is_err());

        let mut wf = minimal_workflow();
        wf.nodes.retain(|n| n.node_type != NodeType::End);
        wf.edges.clear();
        assert!(wf.validate().is_err());
    }

    #[test]
    fn validation_rejects_duplicate_node_ids() {
        let mut wf = minimal_workflow();
        wf.nodes.push(node("start", NodeType::Task));
        assert!(wf.validate().is_err());
    }

    #[test]
    fn validation_rejects_dangling_edges() {
        let mut wf = minimal_workflow();
        wf.edges.push(edge("e2", "start", "missing"));
        assert!(wf.validate().is_err());

        let mut wf = minimal_workflow();
        wf.edges.push(edge("e2", "missing", "end"));
        assert!(wf.validate().is_err());
    }

    #[test]
    fn lookup_helpers() {
        let wf = minimal_workflow();
        assert_eq!(wf.get_start_node().map(|n| n.id.as_str()), Some("start"));
        assert_eq!(wf.get_end_node().map(|n| n.id.as_str()), Some("end"));
        assert_eq!(wf.get_node("end").map(|n| n.id.as_str()), Some("end"));
        assert!(wf.get_node("missing").is_none());
        assert_eq!(wf.get_outgoing_edges("start").len(), 1);
        assert!(wf.get_outgoing_edges("end").is_empty());
    }

    #[test]
    fn failure_strategy_serde_shape() {
        let retry: FailureStrategy = serde_json::from_str(r#"{"type":"retry"}"#).unwrap();
        assert_eq!(
            retry,
            FailureStrategy::Retry {
                max_attempts: 1,
                interval_ms: None
            }
        );

        let untyped: FailureStrategy = serde_json::from_str("{}").unwrap();
        assert_eq!(untyped, FailureStrategy::Abort);

        let goto = FailureStrategy::Goto {
            target: "end".to_string(),
        };
        let json = serde_json::to_string(&goto).unwrap();
        assert_eq!(json, r#"{"type":"goto","target":"end"}"#);
    }
}