Skip to main content

graphs_tui/
layout.rs

1use crate::types::{DiagramWarning, Direction, Graph, NodeId, NodeShape, RenderOptions};
2use std::collections::{HashMap, HashSet, VecDeque};
3
/// Minimum width of a node box, however short its label (labels are sized in chars plus 2).
const MIN_NODE_WIDTH: usize = 5;
/// Height of a regular node box (cylinder nodes are given a taller box during layout).
const NODE_HEIGHT: usize = 3;
/// Default horizontal gap; `calculate_gaps` may shrink it to fit a `max_width` budget.
const DEFAULT_HORIZONTAL_GAP: usize = 8;
/// Default vertical gap between stacked nodes / layers.
const DEFAULT_VERTICAL_GAP: usize = 4;
/// Lower bound for a horizontal gap that was shrunk to fit `max_width`.
const MIN_GAP: usize = 2;

/// Padding between a subgraph's member nodes and its border (one extra row is
/// added on top for the subgraph label).
const SUBGRAPH_PADDING: usize = 2;
11
12/// Compute layout for all nodes in the graph
13///
14/// Returns a list of warnings (e.g., cycle detected).
15pub fn compute_layout(graph: &mut Graph) -> Vec<DiagramWarning> {
16    compute_layout_with_options(graph, &RenderOptions::default())
17}
18
19/// Compute layout for all nodes with render options (considers max_width)
20///
21/// Returns a list of warnings (e.g., cycle detected).
22pub fn compute_layout_with_options(
23    graph: &mut Graph,
24    options: &RenderOptions,
25) -> Vec<DiagramWarning> {
26    let mut warnings = Vec::new();
27
28    // 1. Compute node sizes (use chars().count() for proper Unicode handling)
29    for node in graph.nodes.values_mut() {
30        node.width = (node.label.chars().count() + 2).max(MIN_NODE_WIDTH);
31        node.height = NODE_HEIGHT;
32        if node.shape == NodeShape::Cylinder {
33            node.height = 5;
34        }
35    }
36
37    // 2. Topological layering
38    let layers = assign_layers(graph, &mut warnings);
39
40    // 3. Calculate gaps based on available width
41    let (h_gap, v_gap) = calculate_gaps(graph, &layers, options.max_width);
42
43    // 4. Position assignment based on direction with calculated gaps
44    assign_coordinates_with_gaps(graph, &layers, h_gap, v_gap);
45
46    // 5. Compute subgraph bounding boxes
47    compute_subgraph_bounds(graph);
48
49    warnings
50}
51
52/// Calculate adaptive gaps based on available width
53fn calculate_gaps(
54    graph: &Graph,
55    layers: &HashMap<NodeId, usize>,
56    max_width: Option<usize>,
57) -> (usize, usize) {
58    let max_width = match max_width {
59        Some(w) => w,
60        None => return (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP),
61    };
62
63    // Group nodes by layer (sorted for determinism)
64    let mut layers_map: HashMap<usize, Vec<&NodeId>> = HashMap::new();
65    let mut max_layer = 0;
66
67    for (id, &layer) in layers {
68        layers_map.entry(layer).or_default().push(id);
69        max_layer = max_layer.max(layer);
70    }
71    for nodes in layers_map.values_mut() {
72        nodes.sort();
73    }
74
75    // Calculate natural width with default gaps (for horizontal layouts)
76    if graph.direction.is_horizontal() {
77        let mut total_width = 0;
78        for l in 0..=max_layer {
79            let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
80            let layer_max_width = nodes_in_layer
81                .iter()
82                .filter_map(|id| graph.nodes.get(*id))
83                .map(|n| n.width)
84                .max()
85                .unwrap_or(0);
86            total_width += layer_max_width;
87        }
88        total_width += max_layer * DEFAULT_HORIZONTAL_GAP;
89
90        // If natural width exceeds max_width, reduce horizontal gap
91        if total_width > max_width && max_layer > 0 {
92            let node_width = total_width - max_layer * DEFAULT_HORIZONTAL_GAP;
93            let available_for_gaps = max_width.saturating_sub(node_width);
94            let new_gap = (available_for_gaps / max_layer).max(MIN_GAP);
95            return (new_gap, DEFAULT_VERTICAL_GAP);
96        }
97    }
98
99    (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP)
100}
101
102/// Compute bounding boxes for all subgraphs
103fn compute_subgraph_bounds(graph: &mut Graph) {
104    for sg in &mut graph.subgraphs {
105        if sg.nodes.is_empty() {
106            continue;
107        }
108
109        let mut min_x = usize::MAX;
110        let mut min_y = usize::MAX;
111        let mut max_x = 0;
112        let mut max_y = 0;
113
114        for node_id in &sg.nodes {
115            if let Some(node) = graph.nodes.get(node_id) {
116                min_x = min_x.min(node.x);
117                min_y = min_y.min(node.y);
118                max_x = max_x.max(node.x + node.width);
119                max_y = max_y.max(node.y + node.height);
120            }
121        }
122
123        if min_x != usize::MAX {
124            // Add padding around the subgraph
125            sg.x = min_x.saturating_sub(SUBGRAPH_PADDING);
126            sg.y = min_y.saturating_sub(SUBGRAPH_PADDING + 1); // Extra space for label
127            sg.width = (max_x - min_x) + SUBGRAPH_PADDING * 2;
128            sg.height = (max_y - min_y) + SUBGRAPH_PADDING * 2 + 1;
129        }
130    }
131}
132
133/// Assign layer numbers using Kahn's algorithm with cycle-breaking.
134///
135/// Standard Kahn's processes nodes with in_degree=0. When the queue empties
136/// but unprocessed nodes remain, a cycle exists. We force-process the stuck
137/// node that appears earliest as a "from" in the edge list (preserving the
138/// user's intended flow direction), then continue Kahn's.
139fn assign_layers(graph: &Graph, warnings: &mut Vec<DiagramWarning>) -> HashMap<NodeId, usize> {
140    let mut node_layers: HashMap<NodeId, usize> = HashMap::new();
141    let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
142    let mut processed: HashSet<NodeId> = HashSet::new();
143
144    // Initialize
145    for id in graph.nodes.keys() {
146        in_degree.insert(id.clone(), 0);
147        node_layers.insert(id.clone(), 0);
148    }
149
150    // Count in-degrees
151    for edge in &graph.edges {
152        *in_degree.entry(edge.to.clone()).or_insert(0) += 1;
153    }
154
155    // Build first-appearance-as-from index for deterministic cycle breaking.
156    // Nodes that appear earlier as edge sources are treated as more "source-like"
157    // when breaking cycles.
158    let mut first_from_idx: HashMap<&str, usize> = HashMap::new();
159    for (i, edge) in graph.edges.iter().enumerate() {
160        first_from_idx.entry(edge.from.as_str()).or_insert(i);
161    }
162
163    // Start with nodes that have no incoming edges (sorted for determinism)
164    let mut queue: VecDeque<NodeId> = VecDeque::new();
165    let mut zero_in: Vec<&NodeId> = in_degree
166        .iter()
167        .filter(|(_, &deg)| deg == 0)
168        .map(|(id, _)| id)
169        .collect();
170    zero_in.sort();
171    for id in zero_in {
172        queue.push_back(id.clone());
173    }
174
175    let total = graph.nodes.len();
176    let mut all_cycle_nodes: HashSet<String> = HashSet::new();
177
178    loop {
179        // Standard Kahn's processing
180        while let Some(u) = queue.pop_front() {
181            if processed.contains(&u) {
182                continue;
183            }
184            processed.insert(u.clone());
185
186            // Find neighbors, skipping already-processed nodes
187            let mut neighbors: Vec<NodeId> = graph
188                .edges
189                .iter()
190                .filter(|e| e.from == u && !processed.contains(&e.to))
191                .map(|e| e.to.clone())
192                .collect();
193            neighbors.sort();
194            neighbors.dedup();
195
196            for v in &neighbors {
197                let u_layer = *node_layers.get(&u).unwrap_or(&0);
198                let v_layer = node_layers.entry(v.clone()).or_insert(0);
199                *v_layer = (*v_layer).max(u_layer + 1);
200
201                if let Some(deg) = in_degree.get_mut(v) {
202                    *deg = deg.saturating_sub(1);
203                    if *deg == 0 {
204                        queue.push_back(v.clone());
205                    }
206                }
207            }
208        }
209
210        if processed.len() >= total {
211            break;
212        }
213
214        // Cycle detected — collect stuck nodes
215        let mut stuck: Vec<NodeId> = in_degree
216            .iter()
217            .filter(|(id, _)| !processed.contains(*id))
218            .map(|(id, _)| id.clone())
219            .collect();
220
221        // Record only nodes that have outgoing edges to other stuck nodes
222        // (actual cycle participants, not just downstream nodes)
223        let stuck_set: HashSet<&str> = stuck.iter().map(|s| s.as_str()).collect();
224        for n in &stuck {
225            let has_outgoing_to_stuck = graph
226                .edges
227                .iter()
228                .any(|e| e.from == *n && stuck_set.contains(e.to.as_str()));
229            if has_outgoing_to_stuck {
230                all_cycle_nodes.insert(n.clone());
231            }
232        }
233
234        // Force-process the stuck node that appears earliest as an edge source
235        stuck.sort_by(|a, b| {
236            let fa = first_from_idx.get(a.as_str()).copied().unwrap_or(usize::MAX);
237            let fb = first_from_idx.get(b.as_str()).copied().unwrap_or(usize::MAX);
238            fa.cmp(&fb).then(a.cmp(b))
239        });
240
241        if let Some(forced) = stuck.first() {
242            in_degree.insert(forced.clone(), 0);
243            queue.push_back(forced.clone());
244        }
245    }
246
247    if !all_cycle_nodes.is_empty() {
248        let mut cycle_nodes: Vec<String> = all_cycle_nodes.into_iter().collect();
249        cycle_nodes.sort();
250        warnings.push(DiagramWarning::CycleDetected { nodes: cycle_nodes });
251    }
252
253    node_layers
254}
255
256/// Assign x,y coordinates based on layers and direction with configurable gaps
257fn assign_coordinates_with_gaps(
258    graph: &mut Graph,
259    node_layers: &HashMap<NodeId, usize>,
260    h_gap: usize,
261    v_gap: usize,
262) {
263    let direction = graph.direction;
264
265    // Group nodes by layer, sort within each layer for determinism
266    let mut layers_map: HashMap<usize, Vec<NodeId>> = HashMap::new();
267    let mut max_layer = 0;
268
269    for (id, &layer) in node_layers {
270        layers_map.entry(layer).or_default().push(id.clone());
271        max_layer = max_layer.max(layer);
272    }
273    for nodes in layers_map.values_mut() {
274        nodes.sort();
275    }
276
277    // Calculate layer dimensions
278    let mut layer_widths: HashMap<usize, usize> = HashMap::new();
279    let mut layer_heights: HashMap<usize, usize> = HashMap::new();
280
281    for l in 0..=max_layer {
282        let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
283        let mut max_w = 0;
284        let mut max_h = 0;
285        let mut total_w = 0;
286        let mut total_h = 0;
287
288        for id in nodes_in_layer {
289            if let Some(node) = graph.nodes.get(id) {
290                max_w = max_w.max(node.width);
291                max_h = max_h.max(node.height);
292                total_w += node.width + h_gap;
293                total_h += node.height + v_gap;
294            }
295        }
296
297        if direction.is_horizontal() {
298            layer_widths.insert(l, max_w);
299            layer_heights.insert(l, total_h.saturating_sub(v_gap));
300        } else {
301            layer_widths.insert(l, total_w.saturating_sub(h_gap));
302            layer_heights.insert(l, max_h);
303        }
304    }
305
306    let max_total_width = layer_widths.values().copied().max().unwrap_or(0);
307    let max_total_height = layer_heights.values().copied().max().unwrap_or(0);
308
309    if direction.is_horizontal() {
310        let mut current_x = 0;
311        for l in 0..=max_layer {
312            let layer_idx = match direction {
313                Direction::LR => l,
314                Direction::RL => max_layer - l,
315                _ => l,
316            };
317
318            let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
319            let layer_h = *layer_heights.get(&layer_idx).unwrap_or(&0);
320            let mut start_y = (max_total_height.saturating_sub(layer_h)) / 2;
321
322            for id in nodes_in_layer {
323                if let Some(node) = graph.nodes.get_mut(&id) {
324                    node.x = current_x;
325                    node.y = start_y;
326                    start_y += node.height + v_gap;
327                }
328            }
329
330            current_x += layer_widths.get(&layer_idx).unwrap_or(&0) + h_gap;
331        }
332    } else {
333        let mut current_y = 0;
334        for l in 0..=max_layer {
335            let layer_idx = match direction {
336                Direction::TB => l,
337                Direction::BT => max_layer - l,
338                _ => l,
339            };
340
341            let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
342            let layer_w = *layer_widths.get(&layer_idx).unwrap_or(&0);
343            let mut start_x = (max_total_width.saturating_sub(layer_w)) / 2;
344
345            for id in nodes_in_layer {
346                if let Some(node) = graph.nodes.get_mut(&id) {
347                    node.x = start_x;
348                    node.y = current_y;
349                    start_x += node.width + h_gap;
350                }
351            }
352
353            current_y += layer_heights.get(&layer_idx).unwrap_or(&0) + v_gap;
354        }
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_mermaid;

    /// Parse a Mermaid snippet and lay it out with default options.
    fn laid_out(source: &str) -> (Graph, Vec<DiagramWarning>) {
        let mut graph = parse_mermaid(source).unwrap();
        let warnings = compute_layout(&mut graph);
        (graph, warnings)
    }

    #[test]
    fn test_layout_lr() {
        let (graph, warnings) = laid_out("flowchart LR\nA --> B");
        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();

        // Left-to-right: the edge source is placed left of its target.
        assert!(a.x < b.x);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_layout_tb() {
        let (graph, warnings) = laid_out("flowchart TB\nA --> B");
        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();

        // Top-to-bottom: the edge source is placed above its target.
        assert!(a.y < b.y);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_node_sizes() {
        let (graph, _) = laid_out("flowchart LR\nA[Hello World]");
        let a = graph.nodes.get("A").unwrap();

        // Width is the label length plus 2; height is the standard node height.
        assert_eq!(a.width, "Hello World".len() + 2);
        assert_eq!(a.height, NODE_HEIGHT);
    }

    #[test]
    fn test_cycle_produces_warning() {
        let (_, warnings) = laid_out("flowchart LR\nA --> B\nB --> C\nC --> A");
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].to_string().contains("Cycle"));
    }

    #[test]
    fn test_acyclic_no_warning() {
        let (_, warnings) = laid_out("flowchart LR\nA --> B\nB --> C\nA --> C");
        assert!(warnings.is_empty());
    }
}