Skip to main content

matrixcode_tui/workflow/
layout.rs

1//! Workflow Layout Algorithm
2//!
3//! Computes DAG layout positions for rendering
4
5use matrixcode_core::workflow::WorkflowDef;
6use std::collections::HashMap;
7
/// Layout computation result.
///
/// Produced by [`compute_layout`]: nodes are arranged on a grid of layers,
/// one layer per row.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LayoutResult {
    /// Grid position of each node, keyed by node id, as `(row, col)`:
    /// `row` is the layer index and `col` the node's index within that layer.
    pub positions: HashMap<String, (usize, usize)>,
    /// Node ids grouped by layer (depth); `layers[row][col]` is the node at `(row, col)`.
    pub layers: Vec<Vec<String>>,
    /// Length of the longest layer, i.e. the number of grid columns.
    pub width: usize,
    /// Number of layers, i.e. the number of grid rows.
    pub height: usize,
}
18
19/// Compute layout from workflow definition
20pub fn compute_layout(def: &WorkflowDef) -> LayoutResult {
21    let mut positions = HashMap::new();
22    let mut layers: Vec<Vec<String>> = Vec::new();
23    let mut visited: HashMap<String, bool> = HashMap::new();
24
25    // Find start node
26    let start_node = def
27        .nodes
28        .iter()
29        .find(|n| n.node_type == matrixcode_core::workflow::NodeType::Start)
30        .map(|n| n.id.clone());
31
32    if let Some(start) = start_node {
33        // BFS layer assignment
34        let mut current_layer = vec![start.clone()];
35        visited.insert(start.clone(), true);
36
37        while !current_layer.is_empty() {
38            // Sort layer nodes for consistent ordering
39            let mut sorted_layer = current_layer.clone();
40            sorted_layer.sort();
41            layers.push(sorted_layer.clone());
42
43            // Record positions
44            for (col, node_id) in sorted_layer.iter().enumerate() {
45                positions.insert(node_id.clone(), (layers.len() - 1, col));
46            }
47
48            // Find next layer
49            let mut next_layer = Vec::new();
50            for node_id in &current_layer {
51                for edge in &def.edges {
52                    if &edge.from == node_id && !visited.contains_key(&edge.to) {
53                        // Check if target node exists
54                        if def.nodes.iter().any(|n| n.id == edge.to) {
55                            visited.insert(edge.to.clone(), true);
56                            next_layer.push(edge.to.clone());
57                        }
58                    }
59                }
60            }
61
62            current_layer = next_layer;
63        }
64
65        // Add any unvisited nodes to last layer (fallback)
66        for node in &def.nodes {
67            if !visited.contains_key(&node.id) {
68                if let Some(last_layer) = layers.last_mut() {
69                    last_layer.push(node.id.clone());
70                    let col = last_layer.len() - 1;
71                    positions.insert(node.id.clone(), (layers.len() - 1, col));
72                } else {
73                    layers.push(vec![node.id.clone()]);
74                    positions.insert(node.id.clone(), (0, 0));
75                }
76            }
77        }
78    }
79
80    let height = layers.len();
81    let width = layers.iter().map(|l| l.len()).max().unwrap_or(0);
82
83    LayoutResult {
84        positions,
85        layers,
86        width,
87        height,
88    }
89}
90
/// Calculate the grid points an edge occupies between two laid-out nodes.
///
/// `from_pos` and `to_pos` are `(row, col)` grid positions as produced by
/// [`compute_layout`]. Each grid cell is `node_width + spacing_x` wide and
/// `node_height + spacing_y` tall.
///
/// The route starts at the horizontal centre of the source node, `node_height`
/// rows below the top of its cell (just past its bottom edge). It ends at the
/// horizontal centre of the target node on the top row of its cell.
///
/// Returns `(x, y)` points forming a connected orthogonal route:
/// 1. a vertical leg to the row halfway between the endpoints;
/// 2. a horizontal leg to the target column;
/// 3. a vertical leg into the target.
///
/// Consecutive points differ by exactly one cell, stepping in whichever
/// direction the target lies, so edges that point left or to an earlier row
/// stay connected. A zero-length edge yields just the start point.
pub fn calculate_edge_path(
    from_pos: (usize, usize),
    to_pos: (usize, usize),
    node_width: usize,
    node_height: usize,
    spacing_x: usize,
    spacing_y: usize,
) -> Vec<(usize, usize)> {
    // Size of one grid cell: a node plus the gap that follows it.
    let cell_width = node_width + spacing_x;
    let cell_height = node_height + spacing_y;

    let start = (
        from_pos.1 * cell_width + node_width / 2,
        from_pos.0 * cell_height + node_height,
    );
    let end = (to_pos.1 * cell_width + node_width / 2, to_pos.0 * cell_height);

    // Bend on the row halfway between the endpoints. When both endpoints share
    // a column the horizontal leg is empty, giving a straight vertical line.
    let mid_y = (start.1 + end.1) / 2;
    let bend_from = (start.0, mid_y);
    let bend_to = (end.0, mid_y);

    let mut points = vec![start];
    push_edge_segment(&mut points, start, bend_from);
    push_edge_segment(&mut points, bend_from, bend_to);
    push_edge_segment(&mut points, bend_to, end);
    points
}

/// Append the cells of a segment from `from` (excluded) to `to` (included),
/// moving one cell at a time along x first and then along y, in either direction.
fn push_edge_segment(
    points: &mut Vec<(usize, usize)>,
    from: (usize, usize),
    to: (usize, usize),
) {
    let (mut x, mut y) = from;
    while x != to.0 {
        x = if x < to.0 { x + 1 } else { x - 1 };
        points.push((x, y));
    }
    while y != to.1 {
        y = if y < to.1 { y + 1 } else { y - 1 };
        points.push((x, y));
    }
}