Skip to main content

mofa_foundation/workflow/
graph.rs

//! Workflow graph structure.
//!
//! Defines the directed graph of workflow nodes and the edges between them.

5use super::node::{NodeType, WorkflowNode};
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet, VecDeque};
8use tracing::{debug, warn};
9
10/// 边类型
11#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
12pub enum EdgeType {
13    /// 普通边(顺序执行)
14    Normal,
15    /// 条件边(条件为真时执行)
16    Conditional(String),
17    /// 错误边(发生错误时执行)
18    Error,
19    /// 默认边(无其他边匹配时执行)
20    Default,
21}
22
23/// 边配置
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct EdgeConfig {
26    /// 源节点 ID
27    pub from: String,
28    /// 目标节点 ID
29    pub to: String,
30    /// 边类型
31    pub edge_type: EdgeType,
32    /// 边标签(用于显示)
33    pub label: Option<String>,
34}
35
36impl EdgeConfig {
37    pub fn new(from: &str, to: &str) -> Self {
38        Self {
39            from: from.to_string(),
40            to: to.to_string(),
41            edge_type: EdgeType::Normal,
42            label: None,
43        }
44    }
45
46    pub fn conditional(from: &str, to: &str, condition: &str) -> Self {
47        Self {
48            from: from.to_string(),
49            to: to.to_string(),
50            edge_type: EdgeType::Conditional(condition.to_string()),
51            label: Some(condition.to_string()),
52        }
53    }
54
55    pub fn error(from: &str, to: &str) -> Self {
56        Self {
57            from: from.to_string(),
58            to: to.to_string(),
59            edge_type: EdgeType::Error,
60            label: Some("error".to_string()),
61        }
62    }
63
64    pub fn default_edge(from: &str, to: &str) -> Self {
65        Self {
66            from: from.to_string(),
67            to: to.to_string(),
68            edge_type: EdgeType::Default,
69            label: Some("default".to_string()),
70        }
71    }
72
73    pub fn with_label(mut self, label: &str) -> Self {
74        self.label = Some(label.to_string());
75        self
76    }
77}
78
79/// 工作流图
80pub struct WorkflowGraph {
81    /// 图 ID
82    pub id: String,
83    /// 图名称
84    pub name: String,
85    /// 图描述
86    pub description: String,
87    /// 节点映射
88    nodes: HashMap<String, WorkflowNode>,
89    /// 边列表(邻接表:源节点 ID -> 边列表)
90    edges: HashMap<String, Vec<EdgeConfig>>,
91    /// 反向边(用于查找入边)
92    reverse_edges: HashMap<String, Vec<EdgeConfig>>,
93    /// 开始节点 ID
94    start_node: Option<String>,
95    /// 结束节点 ID 列表(可能有多个)
96    end_nodes: Vec<String>,
97}
98
99impl WorkflowGraph {
100    pub fn new(id: &str, name: &str) -> Self {
101        Self {
102            id: id.to_string(),
103            name: name.to_string(),
104            description: String::new(),
105            nodes: HashMap::new(),
106            edges: HashMap::new(),
107            reverse_edges: HashMap::new(),
108            start_node: None,
109            end_nodes: Vec::new(),
110        }
111    }
112
113    pub fn with_description(mut self, desc: &str) -> Self {
114        self.description = desc.to_string();
115        self
116    }
117
118    /// 添加节点
119    pub fn add_node(&mut self, node: WorkflowNode) -> &mut Self {
120        let node_id = node.id().to_string();
121
122        // 自动检测开始和结束节点
123        match node.node_type() {
124            NodeType::Start => {
125                self.start_node = Some(node_id.clone());
126            }
127            NodeType::End => {
128                self.end_nodes.push(node_id.clone());
129            }
130            _ => {}
131        }
132
133        self.nodes.insert(node_id.clone(), node);
134        self.edges.entry(node_id.clone()).or_default();
135        self.reverse_edges.entry(node_id).or_default();
136        self
137    }
138
139    /// 添加边
140    pub fn add_edge(&mut self, edge: EdgeConfig) -> &mut Self {
141        let from = edge.from.clone();
142        let to = edge.to.clone();
143
144        // 添加正向边
145        self.edges.entry(from).or_default().push(edge.clone());
146
147        // 添加反向边
148        self.reverse_edges.entry(to).or_default().push(edge);
149
150        self
151    }
152
153    /// 添加普通边
154    pub fn connect(&mut self, from: &str, to: &str) -> &mut Self {
155        self.add_edge(EdgeConfig::new(from, to))
156    }
157
158    /// 添加条件边
159    pub fn connect_conditional(&mut self, from: &str, to: &str, condition: &str) -> &mut Self {
160        self.add_edge(EdgeConfig::conditional(from, to, condition))
161    }
162
163    /// 获取节点
164    pub fn get_node(&self, node_id: &str) -> Option<&WorkflowNode> {
165        self.nodes.get(node_id)
166    }
167
168    /// 获取可变节点
169    pub fn get_node_mut(&mut self, node_id: &str) -> Option<&mut WorkflowNode> {
170        self.nodes.get_mut(node_id)
171    }
172
173    /// 获取所有节点 ID
174    pub fn node_ids(&self) -> Vec<&str> {
175        self.nodes.keys().map(|s| s.as_str()).collect()
176    }
177
178    /// 获取节点数量
179    pub fn node_count(&self) -> usize {
180        self.nodes.len()
181    }
182
183    /// 获取边数量
184    pub fn edge_count(&self) -> usize {
185        self.edges.values().map(|e| e.len()).sum()
186    }
187
188    /// 获取开始节点
189    pub fn start_node(&self) -> Option<&str> {
190        self.start_node.as_deref()
191    }
192
193    /// 获取结束节点列表
194    pub fn end_nodes(&self) -> &[String] {
195        &self.end_nodes
196    }
197
198    /// 获取节点的出边
199    pub fn get_outgoing_edges(&self, node_id: &str) -> &[EdgeConfig] {
200        self.edges.get(node_id).map(|v| v.as_slice()).unwrap_or(&[])
201    }
202
203    /// 获取节点的入边
204    pub fn get_incoming_edges(&self, node_id: &str) -> &[EdgeConfig] {
205        self.reverse_edges
206            .get(node_id)
207            .map(|v| v.as_slice())
208            .unwrap_or(&[])
209    }
210
211    /// 获取节点的后继节点
212    pub fn get_successors(&self, node_id: &str) -> Vec<&str> {
213        self.get_outgoing_edges(node_id)
214            .iter()
215            .map(|e| e.to.as_str())
216            .collect()
217    }
218
219    /// 获取节点的前驱节点
220    pub fn get_predecessors(&self, node_id: &str) -> Vec<&str> {
221        self.get_incoming_edges(node_id)
222            .iter()
223            .map(|e| e.from.as_str())
224            .collect()
225    }
226
227    /// 获取满足条件的下一个节点
228    pub fn get_next_node(&self, node_id: &str, condition: Option<&str>) -> Option<&str> {
229        let edges = self.get_outgoing_edges(node_id);
230
231        // 优先匹配条件边
232        if let Some(cond) = condition {
233            for edge in edges {
234                if let EdgeType::Conditional(c) = &edge.edge_type
235                    && c == cond
236                {
237                    return Some(&edge.to);
238                }
239            }
240        }
241
242        // 其次匹配默认边
243        for edge in edges {
244            if matches!(edge.edge_type, EdgeType::Default) {
245                return Some(&edge.to);
246            }
247        }
248
249        // 最后匹配普通边
250        for edge in edges {
251            if matches!(edge.edge_type, EdgeType::Normal) {
252                return Some(&edge.to);
253            }
254        }
255
256        None
257    }
258
259    /// 获取错误处理节点
260    pub fn get_error_handler(&self, node_id: &str) -> Option<&str> {
261        let edges = self.get_outgoing_edges(node_id);
262        for edge in edges {
263            if matches!(edge.edge_type, EdgeType::Error) {
264                return Some(&edge.to);
265            }
266        }
267        None
268    }
269
270    /// 拓扑排序
271    pub fn topological_sort(&self) -> Result<Vec<String>, String> {
272        let mut in_degree: HashMap<&str, usize> = HashMap::new();
273        let mut queue: VecDeque<&str> = VecDeque::new();
274        let mut result: Vec<String> = Vec::new();
275
276        // 计算入度
277        for node_id in self.nodes.keys() {
278            in_degree.insert(node_id, 0);
279        }
280        for edges in self.edges.values() {
281            for edge in edges {
282                *in_degree.entry(&edge.to).or_insert(0) += 1;
283            }
284        }
285
286        // 入度为 0 的节点入队
287        for (node_id, &degree) in &in_degree {
288            if degree == 0 {
289                queue.push_back(node_id);
290            }
291        }
292
293        // BFS
294        while let Some(node_id) = queue.pop_front() {
295            result.push(node_id.to_string());
296
297            for edge in self.get_outgoing_edges(node_id) {
298                if let Some(degree) = in_degree.get_mut(edge.to.as_str()) {
299                    *degree -= 1;
300                    if *degree == 0 {
301                        queue.push_back(&edge.to);
302                    }
303                }
304            }
305        }
306
307        // 检查是否有环
308        if result.len() != self.nodes.len() {
309            return Err("Graph contains a cycle".to_string());
310        }
311
312        Ok(result)
313    }
314
315    /// 检测环
316    pub fn has_cycle(&self) -> bool {
317        self.topological_sort().is_err()
318    }
319
320    /// 获取可以并行执行的节点组
321    pub fn get_parallel_groups(&self) -> Vec<Vec<String>> {
322        let mut groups: Vec<Vec<String>> = Vec::new();
323        let mut in_degree: HashMap<&str, usize> = HashMap::new();
324        let mut remaining: HashSet<&str> = self.nodes.keys().map(|s| s.as_str()).collect();
325
326        // 计算入度
327        for node_id in self.nodes.keys() {
328            in_degree.insert(node_id, 0);
329        }
330        for edges in self.edges.values() {
331            for edge in edges {
332                *in_degree.entry(&edge.to).or_insert(0) += 1;
333            }
334        }
335
336        while !remaining.is_empty() {
337            // 找出当前入度为 0 的节点
338            let ready: Vec<String> = remaining
339                .iter()
340                .filter(|&&node_id| in_degree.get(node_id).copied().unwrap_or(0) == 0)
341                .map(|&s| s.to_string())
342                .collect();
343
344            if ready.is_empty() {
345                warn!("Cycle detected in workflow graph");
346                break;
347            }
348
349            // 更新入度
350            for node_id in &ready {
351                remaining.remove(node_id.as_str());
352                for edge in self.get_outgoing_edges(node_id) {
353                    if let Some(degree) = in_degree.get_mut(edge.to.as_str()) {
354                        *degree = degree.saturating_sub(1);
355                    }
356                }
357            }
358
359            groups.push(ready);
360        }
361
362        groups
363    }
364
365    /// 验证图的完整性
366    pub fn validate(&self) -> Result<(), Vec<String>> {
367        let mut errors: Vec<String> = Vec::new();
368
369        // 检查是否有开始节点
370        if self.start_node.is_none() {
371            errors.push("No start node found".to_string());
372        }
373
374        // 检查是否有结束节点
375        if self.end_nodes.is_empty() {
376            errors.push("No end node found".to_string());
377        }
378
379        // 检查边引用的节点是否存在
380        for (from, edges) in &self.edges {
381            if !self.nodes.contains_key(from) {
382                errors.push(format!("Edge source node '{}' not found", from));
383            }
384            for edge in edges {
385                if !self.nodes.contains_key(&edge.to) {
386                    errors.push(format!("Edge target node '{}' not found", edge.to));
387                }
388            }
389        }
390
391        // 检查是否有孤立节点
392        for node_id in self.nodes.keys() {
393            if node_id != self.start_node.as_ref().unwrap_or(&String::new())
394                && self.get_incoming_edges(node_id).is_empty()
395            {
396                errors.push(format!("Node '{}' is unreachable", node_id));
397            }
398        }
399
400        // 检查是否有环
401        if self.has_cycle() {
402            errors.push("Graph contains a cycle".to_string());
403        }
404
405        // 检查并行节点是否有对应的聚合节点
406        for (node_id, node) in &self.nodes {
407            if matches!(node.node_type(), NodeType::Parallel) {
408                // 检查每个分支是否最终汇聚
409                debug!("Checking parallel node: {}", node_id);
410            }
411        }
412
413        if errors.is_empty() {
414            Ok(())
415        } else {
416            Err(errors)
417        }
418    }
419
420    /// 获取从源到目标的所有路径
421    pub fn find_all_paths(&self, from: &str, to: &str) -> Vec<Vec<String>> {
422        let mut paths: Vec<Vec<String>> = Vec::new();
423        let mut current_path: Vec<String> = Vec::new();
424        let mut visited: HashSet<String> = HashSet::new();
425
426        self.dfs_paths(from, to, &mut current_path, &mut visited, &mut paths);
427        paths
428    }
429
430    fn dfs_paths(
431        &self,
432        current: &str,
433        target: &str,
434        path: &mut Vec<String>,
435        visited: &mut HashSet<String>,
436        paths: &mut Vec<Vec<String>>,
437    ) {
438        path.push(current.to_string());
439        visited.insert(current.to_string());
440
441        if current == target {
442            paths.push(path.clone());
443        } else {
444            for edge in self.get_outgoing_edges(current) {
445                if !visited.contains(&edge.to) {
446                    self.dfs_paths(&edge.to, target, path, visited, paths);
447                }
448            }
449        }
450
451        path.pop();
452        visited.remove(current);
453    }
454
455    /// 导出为 DOT 格式(用于可视化)
456    pub fn to_dot(&self) -> String {
457        let mut dot = String::new();
458        dot.push_str(&format!("digraph \"{}\" {{\n", self.name));
459        dot.push_str("  rankdir=TB;\n");
460        dot.push_str("  node [shape=box];\n\n");
461
462        // 节点
463        for (node_id, node) in &self.nodes {
464            let shape = match node.node_type() {
465                NodeType::Start => "ellipse",
466                NodeType::End => "ellipse",
467                NodeType::Condition => "diamond",
468                NodeType::Parallel => "parallelogram",
469                NodeType::Join => "parallelogram",
470                NodeType::Loop => "hexagon",
471                _ => "box",
472            };
473            let color = match node.node_type() {
474                NodeType::Start => "green",
475                NodeType::End => "red",
476                NodeType::Condition => "yellow",
477                NodeType::Parallel | NodeType::Join => "cyan",
478                _ => "white",
479            };
480            dot.push_str(&format!(
481                "  \"{}\" [label=\"{}\\n({})\", shape={}, style=filled, fillcolor={}];\n",
482                node_id, node.config.name, node_id, shape, color
483            ));
484        }
485
486        dot.push('\n');
487
488        // 边
489        for (from, edges) in &self.edges {
490            for edge in edges {
491                let label = edge.label.as_deref().unwrap_or("");
492                let style = match edge.edge_type {
493                    EdgeType::Normal => "solid",
494                    EdgeType::Conditional(_) => "dashed",
495                    EdgeType::Error => "dotted",
496                    EdgeType::Default => "bold",
497                };
498                dot.push_str(&format!(
499                    "  \"{}\" -> \"{}\" [label=\"{}\", style={}];\n",
500                    from, edge.to, label, style
501                ));
502            }
503        }
504
505        dot.push_str("}\n");
506        dot
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a task node that returns its input unchanged.
    fn passthrough(id: &'static str, name: &'static str) -> WorkflowNode {
        WorkflowNode::task(id, name, |_ctx, input| async move { Ok(input) })
    }

    /// Linear fixture: start -> task1 -> task2 -> end.
    fn create_test_graph() -> WorkflowGraph {
        let mut graph = WorkflowGraph::new("test", "Test Workflow");
        graph
            .add_node(WorkflowNode::start("start"))
            .add_node(passthrough("task1", "Task 1"))
            .add_node(passthrough("task2", "Task 2"))
            .add_node(WorkflowNode::end("end"));
        graph
            .connect("start", "task1")
            .connect("task1", "task2")
            .connect("task2", "end");
        graph
    }

    #[test]
    fn test_topological_sort() {
        let order = create_test_graph().topological_sort().unwrap();
        let index_of = |id: &str| order.iter().position(|x| x == id).unwrap();

        // Every node must come after its predecessor on the chain.
        assert!(index_of("start") < index_of("task1"));
        assert!(index_of("task1") < index_of("task2"));
        assert!(index_of("task2") < index_of("end"));
    }

    #[test]
    fn test_parallel_groups() {
        let mut graph = WorkflowGraph::new("test", "Test");
        graph.add_node(WorkflowNode::start("start"));
        for (id, name) in [("a", "A"), ("b", "B"), ("c", "C")] {
            graph.add_node(passthrough(id, name));
        }
        graph.add_node(WorkflowNode::end("end"));
        for (from, to) in [
            ("start", "a"),
            ("start", "b"),
            ("a", "c"),
            ("b", "c"),
            ("c", "end"),
        ] {
            graph.connect(from, to);
        }

        let levels = graph.get_parallel_groups();

        // Expected levels: [start], [a, b] (parallel), [c], [end].
        assert_eq!(levels.len(), 4);
        let second = &levels[1];
        assert!(second.contains(&"a".to_string()) && second.contains(&"b".to_string()));
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = WorkflowGraph::new("test", "Test");
        for (id, name) in [("a", "A"), ("b", "B"), ("c", "C")] {
            graph.add_node(passthrough(id, name));
        }
        graph
            .connect("a", "b")
            .connect("b", "c")
            .connect("c", "a"); // closes the cycle

        assert!(graph.has_cycle());
    }

    #[test]
    fn test_find_paths() {
        let found = create_test_graph().find_all_paths("start", "end");

        assert_eq!(found.len(), 1);
        assert_eq!(found[0], vec!["start", "task1", "task2", "end"]);
    }

    #[test]
    fn test_to_dot() {
        let rendered = create_test_graph().to_dot();

        for needle in ["digraph", "start", "end", "->"] {
            assert!(rendered.contains(needle), "DOT output is missing {needle:?}");
        }
    }
}