Skip to main content

graphs_tui/
layout.rs

1use crate::types::{Direction, Graph, NodeId, NodeShape, RenderOptions};
2use std::collections::{HashMap, VecDeque};
3
/// Minimum width of a node box; nodes with short labels are widened to this.
const MIN_NODE_WIDTH: usize = 5;
/// Height of a node box for every shape except `NodeShape::Cylinder`.
const NODE_HEIGHT: usize = 3;
/// Horizontal gap used unless a `max_width` constraint forces it smaller.
const DEFAULT_HORIZONTAL_GAP: usize = 8;
/// Vertical gap between stacked nodes / consecutive layers.
const DEFAULT_VERTICAL_GAP: usize = 4;
/// Lower bound for a horizontal gap that has been shrunk to fit `max_width`.
const MIN_GAP: usize = 2;

/// Padding placed around the member nodes when computing a subgraph's box.
const SUBGRAPH_PADDING: usize = 2;
11
12/// Compute layout for all nodes in the graph
13///
14/// Returns a list of warnings (e.g., cycle detected).
15pub fn compute_layout(graph: &mut Graph) -> Vec<String> {
16    compute_layout_with_options(graph, &RenderOptions::default())
17}
18
19/// Compute layout for all nodes with render options (considers max_width)
20///
21/// Returns a list of warnings (e.g., cycle detected).
22pub fn compute_layout_with_options(graph: &mut Graph, options: &RenderOptions) -> Vec<String> {
23    let mut warnings = Vec::new();
24
25    // 1. Compute node sizes (use chars().count() for proper Unicode handling)
26    for node in graph.nodes.values_mut() {
27        node.width = (node.label.chars().count() + 2).max(MIN_NODE_WIDTH);
28        node.height = NODE_HEIGHT;
29        if node.shape == NodeShape::Cylinder {
30            node.height = 5;
31        }
32    }
33
34    // 2. Topological layering
35    let layers = assign_layers(graph, &mut warnings);
36
37    // 3. Calculate gaps based on available width
38    let (h_gap, v_gap) = calculate_gaps(graph, &layers, options.max_width);
39
40    // 4. Position assignment based on direction with calculated gaps
41    assign_coordinates_with_gaps(graph, &layers, h_gap, v_gap);
42
43    // 5. Compute subgraph bounding boxes
44    compute_subgraph_bounds(graph);
45
46    warnings
47}
48
49/// Calculate adaptive gaps based on available width
50fn calculate_gaps(
51    graph: &Graph,
52    layers: &HashMap<NodeId, usize>,
53    max_width: Option<usize>,
54) -> (usize, usize) {
55    let max_width = match max_width {
56        Some(w) => w,
57        None => return (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP),
58    };
59
60    // Group nodes by layer
61    let mut layers_map: HashMap<usize, Vec<&NodeId>> = HashMap::new();
62    let mut max_layer = 0;
63
64    for (id, &layer) in layers {
65        layers_map.entry(layer).or_default().push(id);
66        max_layer = max_layer.max(layer);
67    }
68
69    // Calculate natural width with default gaps (for horizontal layouts)
70    if graph.direction.is_horizontal() {
71        let mut total_width = 0;
72        for l in 0..=max_layer {
73            let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
74            let layer_max_width = nodes_in_layer
75                .iter()
76                .filter_map(|id| graph.nodes.get(*id))
77                .map(|n| n.width)
78                .max()
79                .unwrap_or(0);
80            total_width += layer_max_width;
81        }
82        total_width += max_layer * DEFAULT_HORIZONTAL_GAP;
83
84        // If natural width exceeds max_width, reduce horizontal gap
85        if total_width > max_width && max_layer > 0 {
86            let node_width = total_width - max_layer * DEFAULT_HORIZONTAL_GAP;
87            let available_for_gaps = max_width.saturating_sub(node_width);
88            let new_gap = (available_for_gaps / max_layer).max(MIN_GAP);
89            return (new_gap, DEFAULT_VERTICAL_GAP);
90        }
91    }
92
93    (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP)
94}
95
96/// Compute bounding boxes for all subgraphs
97fn compute_subgraph_bounds(graph: &mut Graph) {
98    for sg in &mut graph.subgraphs {
99        if sg.nodes.is_empty() {
100            continue;
101        }
102
103        let mut min_x = usize::MAX;
104        let mut min_y = usize::MAX;
105        let mut max_x = 0;
106        let mut max_y = 0;
107
108        for node_id in &sg.nodes {
109            if let Some(node) = graph.nodes.get(node_id) {
110                min_x = min_x.min(node.x);
111                min_y = min_y.min(node.y);
112                max_x = max_x.max(node.x + node.width);
113                max_y = max_y.max(node.y + node.height);
114            }
115        }
116
117        if min_x != usize::MAX {
118            // Add padding around the subgraph
119            sg.x = min_x.saturating_sub(SUBGRAPH_PADDING);
120            sg.y = min_y.saturating_sub(SUBGRAPH_PADDING + 1); // Extra space for label
121            sg.width = (max_x - min_x) + SUBGRAPH_PADDING * 2;
122            sg.height = (max_y - min_y) + SUBGRAPH_PADDING * 2 + 1;
123        }
124    }
125}
126
127/// Assign layer numbers using Kahn's algorithm
128fn assign_layers(graph: &Graph, warnings: &mut Vec<String>) -> HashMap<NodeId, usize> {
129    let mut node_layers: HashMap<NodeId, usize> = HashMap::new();
130    let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
131
132    // Initialize
133    for id in graph.nodes.keys() {
134        in_degree.insert(id.clone(), 0);
135        node_layers.insert(id.clone(), 0);
136    }
137
138    // Count in-degrees
139    for edge in &graph.edges {
140        *in_degree.entry(edge.to.clone()).or_insert(0) += 1;
141    }
142
143    // Start with nodes that have no incoming edges
144    let mut queue: VecDeque<NodeId> = VecDeque::new();
145    for (id, &degree) in &in_degree {
146        if degree == 0 {
147            queue.push_back(id.clone());
148        }
149    }
150
151    let mut processed = 0;
152    while let Some(u) = queue.pop_front() {
153        processed += 1;
154
155        // Find all neighbors (nodes that u points to)
156        let neighbors: Vec<NodeId> = graph
157            .edges
158            .iter()
159            .filter(|e| e.from == u)
160            .map(|e| e.to.clone())
161            .collect();
162
163        for v in neighbors {
164            // Update layer to be at least one more than predecessor
165            let u_layer = *node_layers.get(&u).unwrap_or(&0);
166            let v_layer = node_layers.entry(v.clone()).or_insert(0);
167            *v_layer = (*v_layer).max(u_layer + 1);
168
169            // Decrement in-degree
170            if let Some(deg) = in_degree.get_mut(&v) {
171                *deg -= 1;
172                if *deg == 0 {
173                    queue.push_back(v);
174                }
175            }
176        }
177    }
178
179    // Check for cycles
180    if processed < graph.nodes.len() {
181        warnings.push("Cycle detected in graph. Layout may be imperfect.".to_string());
182    }
183
184    node_layers
185}
186
187/// Assign x,y coordinates based on layers and direction with configurable gaps
188fn assign_coordinates_with_gaps(
189    graph: &mut Graph,
190    node_layers: &HashMap<NodeId, usize>,
191    h_gap: usize,
192    v_gap: usize,
193) {
194    let direction = graph.direction;
195
196    // Group nodes by layer
197    let mut layers_map: HashMap<usize, Vec<NodeId>> = HashMap::new();
198    let mut max_layer = 0;
199
200    for (id, &layer) in node_layers {
201        layers_map.entry(layer).or_default().push(id.clone());
202        max_layer = max_layer.max(layer);
203    }
204
205    // Calculate layer dimensions
206    let mut layer_widths: HashMap<usize, usize> = HashMap::new();
207    let mut layer_heights: HashMap<usize, usize> = HashMap::new();
208
209    for l in 0..=max_layer {
210        let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
211        let mut max_w = 0;
212        let mut max_h = 0;
213        let mut total_w = 0;
214        let mut total_h = 0;
215
216        for id in nodes_in_layer {
217            if let Some(node) = graph.nodes.get(id) {
218                max_w = max_w.max(node.width);
219                max_h = max_h.max(node.height);
220                total_w += node.width + h_gap;
221                total_h += node.height + v_gap;
222            }
223        }
224
225        if direction.is_horizontal() {
226            layer_widths.insert(l, max_w);
227            layer_heights.insert(l, total_h.saturating_sub(v_gap));
228        } else {
229            layer_widths.insert(l, total_w.saturating_sub(h_gap));
230            layer_heights.insert(l, max_h);
231        }
232    }
233
234    let max_total_width = layer_widths.values().copied().max().unwrap_or(0);
235    let max_total_height = layer_heights.values().copied().max().unwrap_or(0);
236
237    if direction.is_horizontal() {
238        let mut current_x = 0;
239        for l in 0..=max_layer {
240            let layer_idx = match direction {
241                Direction::LR => l,
242                Direction::RL => max_layer - l,
243                _ => l,
244            };
245
246            let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
247            let layer_h = *layer_heights.get(&layer_idx).unwrap_or(&0);
248            let mut start_y = (max_total_height.saturating_sub(layer_h)) / 2;
249
250            for id in nodes_in_layer {
251                if let Some(node) = graph.nodes.get_mut(&id) {
252                    node.x = current_x;
253                    node.y = start_y;
254                    start_y += node.height + v_gap;
255                }
256            }
257
258            current_x += layer_widths.get(&layer_idx).unwrap_or(&0) + h_gap;
259        }
260    } else {
261        let mut current_y = 0;
262        for l in 0..=max_layer {
263            let layer_idx = match direction {
264                Direction::TB => l,
265                Direction::BT => max_layer - l,
266                _ => l,
267            };
268
269            let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
270            let layer_w = *layer_widths.get(&layer_idx).unwrap_or(&0);
271            let mut start_x = (max_total_width.saturating_sub(layer_w)) / 2;
272
273            for id in nodes_in_layer {
274                if let Some(node) = graph.nodes.get_mut(&id) {
275                    node.x = start_x;
276                    node.y = current_y;
277                    start_x += node.width + h_gap;
278                }
279            }
280
281            current_y += layer_heights.get(&layer_idx).unwrap_or(&0) + v_gap;
282        }
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_mermaid;

    #[test]
    fn test_layout_lr() {
        let mut graph = parse_mermaid("flowchart LR\nA --> B").unwrap();
        let warnings = compute_layout(&mut graph);

        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();

        assert!(a.x < b.x);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_layout_tb() {
        let mut graph = parse_mermaid("flowchart TB\nA --> B").unwrap();
        let warnings = compute_layout(&mut graph);

        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();

        assert!(a.y < b.y);
        assert!(warnings.is_empty());
    }

    /// RL must mirror LR: the source node ends up to the right of its target.
    #[test]
    fn test_layout_rl_reverses_layers() {
        let mut graph = parse_mermaid("flowchart RL\nA --> B").unwrap();
        let warnings = compute_layout(&mut graph);

        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();

        assert!(a.x > b.x);
        assert!(warnings.is_empty());
    }

    /// BT must mirror TB: the source node ends up below its target.
    #[test]
    fn test_layout_bt_reverses_layers() {
        let mut graph = parse_mermaid("flowchart BT\nA --> B").unwrap();
        let warnings = compute_layout(&mut graph);

        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();

        assert!(a.y > b.y);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_node_sizes() {
        let mut graph = parse_mermaid("flowchart LR\nA[Hello World]").unwrap();
        compute_layout(&mut graph);

        let a = graph.nodes.get("A").unwrap();
        assert_eq!(a.width, "Hello World".len() + 2);
        assert_eq!(a.height, NODE_HEIGHT);
    }

    /// Width must count characters, not UTF-8 bytes.
    #[test]
    fn test_node_width_counts_chars_not_bytes() {
        let label = "héllo wörld";
        let mut graph = parse_mermaid(&format!("flowchart LR\nA[{}]", label)).unwrap();
        compute_layout(&mut graph);

        let a = graph.nodes.get("A").unwrap();
        assert_eq!(a.width, label.chars().count() + 2);
        assert!(a.width < label.len() + 2);
    }

    /// A tight `max_width` shrinks the gap between layers of an LR layout.
    #[test]
    fn test_max_width_shrinks_horizontal_gap() {
        let mut graph = parse_mermaid("flowchart LR\nA --> B\nB --> C").unwrap();
        let mut options = RenderOptions::default();
        options.max_width = Some(20);
        let warnings = compute_layout_with_options(&mut graph, &options);
        assert!(warnings.is_empty());

        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();
        let gap = b.x - a.x - a.width;

        assert!(gap < DEFAULT_HORIZONTAL_GAP);
        assert!(gap >= MIN_GAP);
    }

    #[test]
    fn test_cycle_produces_warning() {
        let mut graph = parse_mermaid("flowchart LR\nA --> B\nB --> C\nC --> A").unwrap();
        let warnings = compute_layout(&mut graph);
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].contains("Cycle"));
    }

    #[test]
    fn test_acyclic_no_warning() {
        let mut graph = parse_mermaid("flowchart LR\nA --> B\nB --> C\nA --> C").unwrap();
        let warnings = compute_layout(&mut graph);
        assert!(warnings.is_empty());
    }
}