1use super::node::{NodeType, WorkflowNode};
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet, VecDeque};
8use tracing::{debug, warn};
9
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
12pub enum EdgeType {
13 Normal,
15 Conditional(String),
17 Error,
19 Default,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct EdgeConfig {
26 pub from: String,
28 pub to: String,
30 pub edge_type: EdgeType,
32 pub label: Option<String>,
34}
35
36impl EdgeConfig {
37 pub fn new(from: &str, to: &str) -> Self {
38 Self {
39 from: from.to_string(),
40 to: to.to_string(),
41 edge_type: EdgeType::Normal,
42 label: None,
43 }
44 }
45
46 pub fn conditional(from: &str, to: &str, condition: &str) -> Self {
47 Self {
48 from: from.to_string(),
49 to: to.to_string(),
50 edge_type: EdgeType::Conditional(condition.to_string()),
51 label: Some(condition.to_string()),
52 }
53 }
54
55 pub fn error(from: &str, to: &str) -> Self {
56 Self {
57 from: from.to_string(),
58 to: to.to_string(),
59 edge_type: EdgeType::Error,
60 label: Some("error".to_string()),
61 }
62 }
63
64 pub fn default_edge(from: &str, to: &str) -> Self {
65 Self {
66 from: from.to_string(),
67 to: to.to_string(),
68 edge_type: EdgeType::Default,
69 label: Some("default".to_string()),
70 }
71 }
72
73 pub fn with_label(mut self, label: &str) -> Self {
74 self.label = Some(label.to_string());
75 self
76 }
77}
78
79pub struct WorkflowGraph {
81 pub id: String,
83 pub name: String,
85 pub description: String,
87 nodes: HashMap<String, WorkflowNode>,
89 edges: HashMap<String, Vec<EdgeConfig>>,
91 reverse_edges: HashMap<String, Vec<EdgeConfig>>,
93 start_node: Option<String>,
95 end_nodes: Vec<String>,
97}
98
99impl WorkflowGraph {
100 pub fn new(id: &str, name: &str) -> Self {
101 Self {
102 id: id.to_string(),
103 name: name.to_string(),
104 description: String::new(),
105 nodes: HashMap::new(),
106 edges: HashMap::new(),
107 reverse_edges: HashMap::new(),
108 start_node: None,
109 end_nodes: Vec::new(),
110 }
111 }
112
113 pub fn with_description(mut self, desc: &str) -> Self {
114 self.description = desc.to_string();
115 self
116 }
117
118 pub fn add_node(&mut self, node: WorkflowNode) -> &mut Self {
120 let node_id = node.id().to_string();
121
122 match node.node_type() {
124 NodeType::Start => {
125 self.start_node = Some(node_id.clone());
126 }
127 NodeType::End => {
128 self.end_nodes.push(node_id.clone());
129 }
130 _ => {}
131 }
132
133 self.nodes.insert(node_id.clone(), node);
134 self.edges.entry(node_id.clone()).or_default();
135 self.reverse_edges.entry(node_id).or_default();
136 self
137 }
138
139 pub fn add_edge(&mut self, edge: EdgeConfig) -> &mut Self {
141 let from = edge.from.clone();
142 let to = edge.to.clone();
143
144 self.edges.entry(from).or_default().push(edge.clone());
146
147 self.reverse_edges.entry(to).or_default().push(edge);
149
150 self
151 }
152
153 pub fn connect(&mut self, from: &str, to: &str) -> &mut Self {
155 self.add_edge(EdgeConfig::new(from, to))
156 }
157
158 pub fn connect_conditional(&mut self, from: &str, to: &str, condition: &str) -> &mut Self {
160 self.add_edge(EdgeConfig::conditional(from, to, condition))
161 }
162
163 pub fn get_node(&self, node_id: &str) -> Option<&WorkflowNode> {
165 self.nodes.get(node_id)
166 }
167
168 pub fn get_node_mut(&mut self, node_id: &str) -> Option<&mut WorkflowNode> {
170 self.nodes.get_mut(node_id)
171 }
172
173 pub fn node_ids(&self) -> Vec<&str> {
175 self.nodes.keys().map(|s| s.as_str()).collect()
176 }
177
178 pub fn node_count(&self) -> usize {
180 self.nodes.len()
181 }
182
183 pub fn edge_count(&self) -> usize {
185 self.edges.values().map(|e| e.len()).sum()
186 }
187
188 pub fn start_node(&self) -> Option<&str> {
190 self.start_node.as_deref()
191 }
192
193 pub fn end_nodes(&self) -> &[String] {
195 &self.end_nodes
196 }
197
198 pub fn get_outgoing_edges(&self, node_id: &str) -> &[EdgeConfig] {
200 self.edges.get(node_id).map(|v| v.as_slice()).unwrap_or(&[])
201 }
202
203 pub fn get_incoming_edges(&self, node_id: &str) -> &[EdgeConfig] {
205 self.reverse_edges
206 .get(node_id)
207 .map(|v| v.as_slice())
208 .unwrap_or(&[])
209 }
210
211 pub fn get_successors(&self, node_id: &str) -> Vec<&str> {
213 self.get_outgoing_edges(node_id)
214 .iter()
215 .map(|e| e.to.as_str())
216 .collect()
217 }
218
219 pub fn get_predecessors(&self, node_id: &str) -> Vec<&str> {
221 self.get_incoming_edges(node_id)
222 .iter()
223 .map(|e| e.from.as_str())
224 .collect()
225 }
226
227 pub fn get_next_node(&self, node_id: &str, condition: Option<&str>) -> Option<&str> {
229 let edges = self.get_outgoing_edges(node_id);
230
231 if let Some(cond) = condition {
233 for edge in edges {
234 if let EdgeType::Conditional(c) = &edge.edge_type
235 && c == cond
236 {
237 return Some(&edge.to);
238 }
239 }
240 }
241
242 for edge in edges {
244 if matches!(edge.edge_type, EdgeType::Default) {
245 return Some(&edge.to);
246 }
247 }
248
249 for edge in edges {
251 if matches!(edge.edge_type, EdgeType::Normal) {
252 return Some(&edge.to);
253 }
254 }
255
256 None
257 }
258
259 pub fn get_error_handler(&self, node_id: &str) -> Option<&str> {
261 let edges = self.get_outgoing_edges(node_id);
262 for edge in edges {
263 if matches!(edge.edge_type, EdgeType::Error) {
264 return Some(&edge.to);
265 }
266 }
267 None
268 }
269
270 pub fn topological_sort(&self) -> Result<Vec<String>, String> {
272 let mut in_degree: HashMap<&str, usize> = HashMap::new();
273 let mut queue: VecDeque<&str> = VecDeque::new();
274 let mut result: Vec<String> = Vec::new();
275
276 for node_id in self.nodes.keys() {
278 in_degree.insert(node_id, 0);
279 }
280 for edges in self.edges.values() {
281 for edge in edges {
282 *in_degree.entry(&edge.to).or_insert(0) += 1;
283 }
284 }
285
286 for (node_id, °ree) in &in_degree {
288 if degree == 0 {
289 queue.push_back(node_id);
290 }
291 }
292
293 while let Some(node_id) = queue.pop_front() {
295 result.push(node_id.to_string());
296
297 for edge in self.get_outgoing_edges(node_id) {
298 if let Some(degree) = in_degree.get_mut(edge.to.as_str()) {
299 *degree -= 1;
300 if *degree == 0 {
301 queue.push_back(&edge.to);
302 }
303 }
304 }
305 }
306
307 if result.len() != self.nodes.len() {
309 return Err("Graph contains a cycle".to_string());
310 }
311
312 Ok(result)
313 }
314
315 pub fn has_cycle(&self) -> bool {
317 self.topological_sort().is_err()
318 }
319
320 pub fn get_parallel_groups(&self) -> Vec<Vec<String>> {
322 let mut groups: Vec<Vec<String>> = Vec::new();
323 let mut in_degree: HashMap<&str, usize> = HashMap::new();
324 let mut remaining: HashSet<&str> = self.nodes.keys().map(|s| s.as_str()).collect();
325
326 for node_id in self.nodes.keys() {
328 in_degree.insert(node_id, 0);
329 }
330 for edges in self.edges.values() {
331 for edge in edges {
332 *in_degree.entry(&edge.to).or_insert(0) += 1;
333 }
334 }
335
336 while !remaining.is_empty() {
337 let ready: Vec<String> = remaining
339 .iter()
340 .filter(|&&node_id| in_degree.get(node_id).copied().unwrap_or(0) == 0)
341 .map(|&s| s.to_string())
342 .collect();
343
344 if ready.is_empty() {
345 warn!("Cycle detected in workflow graph");
346 break;
347 }
348
349 for node_id in &ready {
351 remaining.remove(node_id.as_str());
352 for edge in self.get_outgoing_edges(node_id) {
353 if let Some(degree) = in_degree.get_mut(edge.to.as_str()) {
354 *degree = degree.saturating_sub(1);
355 }
356 }
357 }
358
359 groups.push(ready);
360 }
361
362 groups
363 }
364
365 pub fn validate(&self) -> Result<(), Vec<String>> {
367 let mut errors: Vec<String> = Vec::new();
368
369 if self.start_node.is_none() {
371 errors.push("No start node found".to_string());
372 }
373
374 if self.end_nodes.is_empty() {
376 errors.push("No end node found".to_string());
377 }
378
379 for (from, edges) in &self.edges {
381 if !self.nodes.contains_key(from) {
382 errors.push(format!("Edge source node '{}' not found", from));
383 }
384 for edge in edges {
385 if !self.nodes.contains_key(&edge.to) {
386 errors.push(format!("Edge target node '{}' not found", edge.to));
387 }
388 }
389 }
390
391 for node_id in self.nodes.keys() {
393 if node_id != self.start_node.as_ref().unwrap_or(&String::new())
394 && self.get_incoming_edges(node_id).is_empty()
395 {
396 errors.push(format!("Node '{}' is unreachable", node_id));
397 }
398 }
399
400 if self.has_cycle() {
402 errors.push("Graph contains a cycle".to_string());
403 }
404
405 for (node_id, node) in &self.nodes {
407 if matches!(node.node_type(), NodeType::Parallel) {
408 debug!("Checking parallel node: {}", node_id);
410 }
411 }
412
413 if errors.is_empty() {
414 Ok(())
415 } else {
416 Err(errors)
417 }
418 }
419
420 pub fn find_all_paths(&self, from: &str, to: &str) -> Vec<Vec<String>> {
422 let mut paths: Vec<Vec<String>> = Vec::new();
423 let mut current_path: Vec<String> = Vec::new();
424 let mut visited: HashSet<String> = HashSet::new();
425
426 self.dfs_paths(from, to, &mut current_path, &mut visited, &mut paths);
427 paths
428 }
429
430 fn dfs_paths(
431 &self,
432 current: &str,
433 target: &str,
434 path: &mut Vec<String>,
435 visited: &mut HashSet<String>,
436 paths: &mut Vec<Vec<String>>,
437 ) {
438 path.push(current.to_string());
439 visited.insert(current.to_string());
440
441 if current == target {
442 paths.push(path.clone());
443 } else {
444 for edge in self.get_outgoing_edges(current) {
445 if !visited.contains(&edge.to) {
446 self.dfs_paths(&edge.to, target, path, visited, paths);
447 }
448 }
449 }
450
451 path.pop();
452 visited.remove(current);
453 }
454
455 pub fn to_dot(&self) -> String {
457 let mut dot = String::new();
458 dot.push_str(&format!("digraph \"{}\" {{\n", self.name));
459 dot.push_str(" rankdir=TB;\n");
460 dot.push_str(" node [shape=box];\n\n");
461
462 for (node_id, node) in &self.nodes {
464 let shape = match node.node_type() {
465 NodeType::Start => "ellipse",
466 NodeType::End => "ellipse",
467 NodeType::Condition => "diamond",
468 NodeType::Parallel => "parallelogram",
469 NodeType::Join => "parallelogram",
470 NodeType::Loop => "hexagon",
471 _ => "box",
472 };
473 let color = match node.node_type() {
474 NodeType::Start => "green",
475 NodeType::End => "red",
476 NodeType::Condition => "yellow",
477 NodeType::Parallel | NodeType::Join => "cyan",
478 _ => "white",
479 };
480 dot.push_str(&format!(
481 " \"{}\" [label=\"{}\\n({})\", shape={}, style=filled, fillcolor={}];\n",
482 node_id, node.config.name, node_id, shape, color
483 ));
484 }
485
486 dot.push('\n');
487
488 for (from, edges) in &self.edges {
490 for edge in edges {
491 let label = edge.label.as_deref().unwrap_or("");
492 let style = match edge.edge_type {
493 EdgeType::Normal => "solid",
494 EdgeType::Conditional(_) => "dashed",
495 EdgeType::Error => "dotted",
496 EdgeType::Default => "bold",
497 };
498 dot.push_str(&format!(
499 " \"{}\" -> \"{}\" [label=\"{}\", style={}];\n",
500 from, edge.to, label, style
501 ));
502 }
503 }
504
505 dot.push_str("}\n");
506 dot
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the linear workflow `start -> task1 -> task2 -> end`.
    fn create_test_graph() -> WorkflowGraph {
        let mut graph = WorkflowGraph::new("test", "Test Workflow");

        graph.add_node(WorkflowNode::start("start"));
        graph.add_node(WorkflowNode::task(
            "task1",
            "Task 1",
            |_ctx, input| async move { Ok(input) },
        ));
        graph.add_node(WorkflowNode::task(
            "task2",
            "Task 2",
            |_ctx, input| async move { Ok(input) },
        ));
        graph.add_node(WorkflowNode::end("end"));

        graph.connect("start", "task1");
        graph.connect("task1", "task2");
        graph.connect("task2", "end");

        graph
    }

    #[test]
    fn test_topological_sort() {
        let graph = create_test_graph();
        let sorted = graph.topological_sort().unwrap();

        // Every node must appear after all of its predecessors.
        let position_of = |id: &str| sorted.iter().position(|x| x == id).unwrap();

        assert!(position_of("start") < position_of("task1"));
        assert!(position_of("task1") < position_of("task2"));
        assert!(position_of("task2") < position_of("end"));
    }

    #[test]
    fn test_parallel_groups() {
        let mut graph = WorkflowGraph::new("test", "Test");

        graph.add_node(WorkflowNode::start("start"));
        graph.add_node(WorkflowNode::task("a", "A", |_ctx, input| async move {
            Ok(input)
        }));
        graph.add_node(WorkflowNode::task("b", "B", |_ctx, input| async move {
            Ok(input)
        }));
        graph.add_node(WorkflowNode::task("c", "C", |_ctx, input| async move {
            Ok(input)
        }));
        graph.add_node(WorkflowNode::end("end"));

        graph.connect("start", "a");
        graph.connect("start", "b");
        graph.connect("a", "c");
        graph.connect("b", "c");
        graph.connect("c", "end");

        let groups = graph.get_parallel_groups();

        // Expected layers: [start], [a, b] (any order), [c], [end].
        assert_eq!(groups.len(), 4);
        assert_eq!(groups[0], vec!["start"]);
        assert_eq!(groups[1].len(), 2);
        assert!(groups[1].contains(&"a".to_string()) && groups[1].contains(&"b".to_string()));
        assert_eq!(groups[2], vec!["c"]);
        assert_eq!(groups[3], vec!["end"]);
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = WorkflowGraph::new("test", "Test");

        graph.add_node(WorkflowNode::task("a", "A", |_ctx, input| async move {
            Ok(input)
        }));
        graph.add_node(WorkflowNode::task("b", "B", |_ctx, input| async move {
            Ok(input)
        }));
        graph.add_node(WorkflowNode::task("c", "C", |_ctx, input| async move {
            Ok(input)
        }));

        graph.connect("a", "b");
        graph.connect("b", "c");
        // Closing edge: a -> b -> c -> a.
        graph.connect("c", "a");

        assert!(graph.has_cycle());
    }

    #[test]
    fn test_find_paths() {
        let graph = create_test_graph();
        let paths = graph.find_all_paths("start", "end");

        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0], vec!["start", "task1", "task2", "end"]);
    }

    #[test]
    fn test_to_dot() {
        let graph = create_test_graph();
        let dot = graph.to_dot();

        assert!(dot.contains("digraph"));
        assert!(dot.contains("start"));
        assert!(dot.contains("end"));
        assert!(dot.contains("->"));
    }
}