Skip to main content

graphs_tui/
layout.rs

1use crate::types::{Direction, Graph, NodeId, RenderOptions};
2use std::collections::{HashMap, VecDeque};
3
/// Smallest width, in columns, a node box is ever given.
const MIN_NODE_WIDTH: usize = 5;
/// Height, in rows, assigned to every node box.
const NODE_HEIGHT: usize = 3;
/// Horizontal spacing between nodes when no width limit forces it smaller.
const DEFAULT_HORIZONTAL_GAP: usize = 8;
/// Vertical spacing between nodes.
const DEFAULT_VERTICAL_GAP: usize = 4;
/// Lower bound for the horizontal gap after it is shrunk to fit `max_width`.
const MIN_GAP: usize = 2;

/// Padding added on every side of the nodes enclosed by a subgraph box.
const SUBGRAPH_PADDING: usize = 2;
11
12/// Compute layout for all nodes in the graph
13///
14/// Returns a list of warnings (e.g., cycle detected).
15pub fn compute_layout(graph: &mut Graph) -> Vec<String> {
16    compute_layout_with_options(graph, &RenderOptions::default())
17}
18
19/// Compute layout for all nodes with render options (considers max_width)
20///
21/// Returns a list of warnings (e.g., cycle detected).
22pub fn compute_layout_with_options(graph: &mut Graph, options: &RenderOptions) -> Vec<String> {
23    let mut warnings = Vec::new();
24
25    // 1. Compute node sizes (use chars().count() for proper Unicode handling)
26    for node in graph.nodes.values_mut() {
27        node.width = (node.label.chars().count() + 2).max(MIN_NODE_WIDTH);
28        node.height = NODE_HEIGHT;
29    }
30
31    // 2. Topological layering
32    let layers = assign_layers(graph, &mut warnings);
33
34    // 3. Calculate gaps based on available width
35    let (h_gap, v_gap) = calculate_gaps(graph, &layers, options.max_width);
36
37    // 4. Position assignment based on direction with calculated gaps
38    assign_coordinates_with_gaps(graph, &layers, h_gap, v_gap);
39
40    // 5. Compute subgraph bounding boxes
41    compute_subgraph_bounds(graph);
42
43    warnings
44}
45
46/// Calculate adaptive gaps based on available width
47fn calculate_gaps(
48    graph: &Graph,
49    layers: &HashMap<NodeId, usize>,
50    max_width: Option<usize>,
51) -> (usize, usize) {
52    let max_width = match max_width {
53        Some(w) => w,
54        None => return (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP),
55    };
56
57    // Group nodes by layer
58    let mut layers_map: HashMap<usize, Vec<&NodeId>> = HashMap::new();
59    let mut max_layer = 0;
60
61    for (id, &layer) in layers {
62        layers_map.entry(layer).or_default().push(id);
63        max_layer = max_layer.max(layer);
64    }
65
66    // Calculate natural width with default gaps (for horizontal layouts)
67    if graph.direction.is_horizontal() {
68        let mut total_width = 0;
69        for l in 0..=max_layer {
70            let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
71            let layer_max_width = nodes_in_layer
72                .iter()
73                .filter_map(|id| graph.nodes.get(*id))
74                .map(|n| n.width)
75                .max()
76                .unwrap_or(0);
77            total_width += layer_max_width;
78        }
79        total_width += max_layer * DEFAULT_HORIZONTAL_GAP;
80
81        // If natural width exceeds max_width, reduce horizontal gap
82        if total_width > max_width && max_layer > 0 {
83            let node_width = total_width - max_layer * DEFAULT_HORIZONTAL_GAP;
84            let available_for_gaps = max_width.saturating_sub(node_width);
85            let new_gap = (available_for_gaps / max_layer).max(MIN_GAP);
86            return (new_gap, DEFAULT_VERTICAL_GAP);
87        }
88    }
89
90    (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP)
91}
92
93/// Compute bounding boxes for all subgraphs
94fn compute_subgraph_bounds(graph: &mut Graph) {
95    for sg in &mut graph.subgraphs {
96        if sg.nodes.is_empty() {
97            continue;
98        }
99
100        let mut min_x = usize::MAX;
101        let mut min_y = usize::MAX;
102        let mut max_x = 0;
103        let mut max_y = 0;
104
105        for node_id in &sg.nodes {
106            if let Some(node) = graph.nodes.get(node_id) {
107                min_x = min_x.min(node.x);
108                min_y = min_y.min(node.y);
109                max_x = max_x.max(node.x + node.width);
110                max_y = max_y.max(node.y + node.height);
111            }
112        }
113
114        if min_x != usize::MAX {
115            // Add padding around the subgraph
116            sg.x = min_x.saturating_sub(SUBGRAPH_PADDING);
117            sg.y = min_y.saturating_sub(SUBGRAPH_PADDING + 1); // Extra space for label
118            sg.width = (max_x - min_x) + SUBGRAPH_PADDING * 2;
119            sg.height = (max_y - min_y) + SUBGRAPH_PADDING * 2 + 1;
120        }
121    }
122}
123
124/// Assign layer numbers using Kahn's algorithm
125fn assign_layers(graph: &Graph, warnings: &mut Vec<String>) -> HashMap<NodeId, usize> {
126    let mut node_layers: HashMap<NodeId, usize> = HashMap::new();
127    let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
128
129    // Initialize
130    for id in graph.nodes.keys() {
131        in_degree.insert(id.clone(), 0);
132        node_layers.insert(id.clone(), 0);
133    }
134
135    // Count in-degrees
136    for edge in &graph.edges {
137        *in_degree.entry(edge.to.clone()).or_insert(0) += 1;
138    }
139
140    // Start with nodes that have no incoming edges
141    let mut queue: VecDeque<NodeId> = VecDeque::new();
142    for (id, &degree) in &in_degree {
143        if degree == 0 {
144            queue.push_back(id.clone());
145        }
146    }
147
148    let mut processed = 0;
149    while let Some(u) = queue.pop_front() {
150        processed += 1;
151
152        // Find all neighbors (nodes that u points to)
153        let neighbors: Vec<NodeId> = graph
154            .edges
155            .iter()
156            .filter(|e| e.from == u)
157            .map(|e| e.to.clone())
158            .collect();
159
160        for v in neighbors {
161            // Update layer to be at least one more than predecessor
162            let u_layer = *node_layers.get(&u).unwrap_or(&0);
163            let v_layer = node_layers.entry(v.clone()).or_insert(0);
164            *v_layer = (*v_layer).max(u_layer + 1);
165
166            // Decrement in-degree
167            if let Some(deg) = in_degree.get_mut(&v) {
168                *deg -= 1;
169                if *deg == 0 {
170                    queue.push_back(v);
171                }
172            }
173        }
174    }
175
176    // Check for cycles
177    if processed < graph.nodes.len() {
178        warnings.push("Cycle detected in graph. Layout may be imperfect.".to_string());
179    }
180
181    node_layers
182}
183
184/// Assign x,y coordinates based on layers and direction with configurable gaps
185fn assign_coordinates_with_gaps(
186    graph: &mut Graph,
187    node_layers: &HashMap<NodeId, usize>,
188    h_gap: usize,
189    v_gap: usize,
190) {
191    let direction = graph.direction;
192
193    // Group nodes by layer
194    let mut layers_map: HashMap<usize, Vec<NodeId>> = HashMap::new();
195    let mut max_layer = 0;
196
197    for (id, &layer) in node_layers {
198        layers_map.entry(layer).or_default().push(id.clone());
199        max_layer = max_layer.max(layer);
200    }
201
202    // Calculate layer dimensions
203    let mut layer_widths: HashMap<usize, usize> = HashMap::new();
204    let mut layer_heights: HashMap<usize, usize> = HashMap::new();
205
206    for l in 0..=max_layer {
207        let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
208        let mut max_w = 0;
209        let mut max_h = 0;
210        let mut total_w = 0;
211        let mut total_h = 0;
212
213        for id in nodes_in_layer {
214            if let Some(node) = graph.nodes.get(id) {
215                max_w = max_w.max(node.width);
216                max_h = max_h.max(node.height);
217                total_w += node.width + h_gap;
218                total_h += node.height + v_gap;
219            }
220        }
221
222        if direction.is_horizontal() {
223            layer_widths.insert(l, max_w);
224            layer_heights.insert(l, total_h.saturating_sub(v_gap));
225        } else {
226            layer_widths.insert(l, total_w.saturating_sub(h_gap));
227            layer_heights.insert(l, max_h);
228        }
229    }
230
231    let max_total_width = layer_widths.values().copied().max().unwrap_or(0);
232    let max_total_height = layer_heights.values().copied().max().unwrap_or(0);
233
234    if direction.is_horizontal() {
235        let mut current_x = 0;
236        for l in 0..=max_layer {
237            let layer_idx = match direction {
238                Direction::LR => l,
239                Direction::RL => max_layer - l,
240                _ => l,
241            };
242
243            let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
244            let layer_h = *layer_heights.get(&layer_idx).unwrap_or(&0);
245            let mut start_y = (max_total_height.saturating_sub(layer_h)) / 2;
246
247            for id in nodes_in_layer {
248                if let Some(node) = graph.nodes.get_mut(&id) {
249                    node.x = current_x;
250                    node.y = start_y;
251                    start_y += node.height + v_gap;
252                }
253            }
254
255            current_x += layer_widths.get(&layer_idx).unwrap_or(&0) + h_gap;
256        }
257    } else {
258        let mut current_y = 0;
259        for l in 0..=max_layer {
260            let layer_idx = match direction {
261                Direction::TB => l,
262                Direction::BT => max_layer - l,
263                _ => l,
264            };
265
266            let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
267            let layer_w = *layer_widths.get(&layer_idx).unwrap_or(&0);
268            let mut start_x = (max_total_width.saturating_sub(layer_w)) / 2;
269
270            for id in nodes_in_layer {
271                if let Some(node) = graph.nodes.get_mut(&id) {
272                    node.x = start_x;
273                    node.y = current_y;
274                    start_x += node.width + h_gap;
275                }
276            }
277
278            current_y += layer_heights.get(&layer_idx).unwrap_or(&0) + v_gap;
279        }
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_mermaid;

    #[test]
    fn test_layout_lr() {
        let mut graph = parse_mermaid("flowchart LR\nA --> B").unwrap();
        let warnings = compute_layout(&mut graph);

        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();

        assert!(a.x < b.x);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_layout_rl() {
        let mut graph = parse_mermaid("flowchart RL\nA --> B").unwrap();
        let warnings = compute_layout(&mut graph);

        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();

        // RL mirrors LR: the source ends up to the right of its target.
        assert!(a.x > b.x);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_layout_tb() {
        let mut graph = parse_mermaid("flowchart TB\nA --> B").unwrap();
        let warnings = compute_layout(&mut graph);

        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();

        assert!(a.y < b.y);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_layout_bt() {
        let mut graph = parse_mermaid("flowchart BT\nA --> B").unwrap();
        let warnings = compute_layout(&mut graph);

        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();

        // BT mirrors TB: the source ends up below its target.
        assert!(a.y > b.y);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_siblings_share_a_layer() {
        let mut graph = parse_mermaid("flowchart TB\nA --> B\nA --> C").unwrap();
        compute_layout(&mut graph);

        let b = graph.nodes.get("B").unwrap();
        let c = graph.nodes.get("C").unwrap();

        assert_eq!(b.y, c.y);
        // Siblings sit side by side, one default gap apart, in either order.
        let (left, right) = if b.x < c.x { (b, c) } else { (c, b) };
        assert_eq!(right.x, left.x + left.width + DEFAULT_HORIZONTAL_GAP);
    }

    #[test]
    fn test_node_sizes() {
        let mut graph = parse_mermaid("flowchart LR\nA[Hello World]").unwrap();
        compute_layout(&mut graph);

        let a = graph.nodes.get("A").unwrap();
        assert_eq!(a.width, "Hello World".chars().count() + 2);
        assert_eq!(a.height, NODE_HEIGHT);
    }

    #[test]
    fn test_node_width_counts_chars_not_bytes() {
        let mut graph = parse_mermaid("flowchart LR\nA[héllo]").unwrap();
        compute_layout(&mut graph);

        // "héllo" is 5 chars but 6 bytes.
        let a = graph.nodes.get("A").unwrap();
        assert_eq!(a.width, 5 + 2);
    }

    #[test]
    fn test_short_label_uses_min_width() {
        let mut graph = parse_mermaid("flowchart LR\nA[x]").unwrap();
        compute_layout(&mut graph);

        assert_eq!(graph.nodes.get("A").unwrap().width, MIN_NODE_WIDTH);
    }

    #[test]
    fn test_default_horizontal_gap() {
        let mut graph = parse_mermaid("flowchart LR\nA --> B").unwrap();
        compute_layout(&mut graph);

        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();
        assert_eq!(b.x, a.x + a.width + DEFAULT_HORIZONTAL_GAP);
    }

    #[test]
    fn test_max_width_shrinks_horizontal_gap() {
        let mut graph = parse_mermaid("flowchart LR\nA --> B").unwrap();
        let mut options = RenderOptions::default();
        // Narrower than the two nodes alone, so the gap bottoms out at MIN_GAP.
        options.max_width = Some(10);
        let warnings = compute_layout_with_options(&mut graph, &options);

        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();
        assert_eq!(b.x, a.x + a.width + MIN_GAP);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_cycle_produces_warning() {
        let mut graph = parse_mermaid("flowchart LR\nA --> B\nB --> C\nC --> A").unwrap();
        let warnings = compute_layout(&mut graph);
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].contains("Cycle"));
    }

    #[test]
    fn test_acyclic_no_warning() {
        let mut graph = parse_mermaid("flowchart LR\nA --> B\nB --> C\nA --> C").unwrap();
        let warnings = compute_layout(&mut graph);
        assert!(warnings.is_empty());
    }
}