// matrixcode_tui/workflow/layout.rs

//! Workflow Layout Algorithm
//!
//! Computes DAG layout positions for rendering

use matrixcode_core::workflow::WorkflowDef;
use std::collections::{HashMap, HashSet};

/// Result of `compute_layout`: where each workflow node sits on a layered grid.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LayoutResult {
    /// Grid position of each node id as `(row, col)`, where `row` is the
    /// node's layer index and `col` its index within that layer.
    pub positions: HashMap<String, (usize, usize)>,
    /// Node ids grouped by layer (depth), in order of increasing depth.
    pub layers: Vec<Vec<String>>,
    /// Grid width: the number of nodes in the widest layer.
    pub width: usize,
    /// Grid height: the number of layers.
    pub height: usize,
}
18
19/// Compute layout from workflow definition
20pub fn compute_layout(def: &WorkflowDef) -> LayoutResult {
21    let mut positions = HashMap::new();
22    let mut layers: Vec<Vec<String>> = Vec::new();
23    let mut visited: HashMap<String, bool> = HashMap::new();
24
25    // Find start node
26    let start_node = def.nodes.iter()
27        .find(|n| n.node_type == matrixcode_core::workflow::NodeType::Start)
28        .map(|n| n.id.clone());
29
30    if let Some(start) = start_node {
31        // BFS layer assignment
32        let mut current_layer = vec![start.clone()];
33        visited.insert(start.clone(), true);
34
35        while !current_layer.is_empty() {
36            // Sort layer nodes for consistent ordering
37            let mut sorted_layer = current_layer.clone();
38            sorted_layer.sort();
39            layers.push(sorted_layer.clone());
40
41            // Record positions
42            for (col, node_id) in sorted_layer.iter().enumerate() {
43                positions.insert(node_id.clone(), (layers.len() - 1, col));
44            }
45
46            // Find next layer
47            let mut next_layer = Vec::new();
48            for node_id in &current_layer {
49                for edge in &def.edges {
50                    if &edge.from == node_id && !visited.contains_key(&edge.to) {
51                        // Check if target node exists
52                        if def.nodes.iter().any(|n| n.id == edge.to) {
53                            visited.insert(edge.to.clone(), true);
54                            next_layer.push(edge.to.clone());
55                        }
56                    }
57                }
58            }
59
60            current_layer = next_layer;
61        }
62
63        // Add any unvisited nodes to last layer (fallback)
64        for node in &def.nodes {
65            if !visited.contains_key(&node.id) {
66                if let Some(last_layer) = layers.last_mut() {
67                    last_layer.push(node.id.clone());
68                    let col = last_layer.len() - 1;
69                    positions.insert(node.id.clone(), (layers.len() - 1, col));
70                } else {
71                    layers.push(vec![node.id.clone()]);
72                    positions.insert(node.id.clone(), (0, 0));
73                }
74            }
75        }
76    }
77
78    let height = layers.len();
79    let width = layers.iter().map(|l| l.len()).max().unwrap_or(0);
80
81    LayoutResult {
82        positions,
83        layers,
84        width,
85        height,
86    }
87}
88
89/// Calculate edge path points
90pub fn calculate_edge_path(
91    from_pos: (usize, usize),
92    to_pos: (usize, usize),
93    node_width: usize,
94    node_height: usize,
95    spacing_x: usize,
96    spacing_y: usize,
97) -> Vec<(usize, usize)> {
98    let from_x = from_pos.1 * (node_width + spacing_x) + node_width / 2;
99    let from_y = from_pos.0 * (node_height + spacing_y) + node_height;
100    let to_x = to_pos.1 * (node_width + spacing_x) + node_width / 2;
101    let to_y = to_pos.0 * (node_height + spacing_y);
102
103    // Simple vertical path
104    let mut points = Vec::new();
105    points.push((from_x, from_y));
106
107    // If same column, direct vertical line
108    if from_x == to_x {
109        for y in from_y + 1..to_y {
110            points.push((from_x, y));
111        }
112    } else {
113        // Different columns: need horizontal segment
114        let mid_y = (from_y + to_y) / 2;
115        for y in from_y + 1..mid_y {
116            points.push((from_x, y));
117        }
118        for x in from_x..to_x {
119            points.push((x, mid_y));
120        }
121        for y in mid_y..to_y {
122            points.push((to_x, y));
123        }
124    }
125
126    points.push((to_x, to_y));
127    points
128}