Skip to main content

graphs_tui/
layout.rs

1use crate::types::{Direction, Graph, NodeId, RenderOptions};
2use std::collections::{HashMap, VecDeque};
3
/// Narrowest a node box may be, however short its label is.
const MIN_NODE_WIDTH: usize = 5;
/// Height given to every node box.
const NODE_HEIGHT: usize = 3;

/// Default spacing on the x axis: between layers in horizontal (`LR`/`RL`)
/// flows and between sibling nodes in vertical (`TB`/`BT`) flows. The gap
/// calculation may shrink it, never below `MIN_GAP`, when a `max_width` is set.
const DEFAULT_HORIZONTAL_GAP: usize = 8;
/// Default spacing on the y axis: between layers in vertical flows and
/// between sibling nodes in horizontal flows.
const DEFAULT_VERTICAL_GAP: usize = 4;
/// Lower bound for the horizontal gap when the layout is shrunk to fit a width.
const MIN_GAP: usize = 2;

/// Cells kept between a subgraph's frame and its member nodes on every side;
/// the frame reserves one additional row on top for the subgraph label.
const SUBGRAPH_PADDING: usize = 2;
11
12/// Compute layout for all nodes in the graph
13pub fn compute_layout(graph: &mut Graph) {
14    compute_layout_with_options(graph, &RenderOptions::default());
15}
16
17/// Compute layout for all nodes with render options (considers max_width)
18pub fn compute_layout_with_options(graph: &mut Graph, options: &RenderOptions) {
19    // 1. Compute node sizes (use chars().count() for proper Unicode handling)
20    for node in graph.nodes.values_mut() {
21        node.width = (node.label.chars().count() + 2).max(MIN_NODE_WIDTH);
22        node.height = NODE_HEIGHT;
23    }
24
25    // 2. Topological layering
26    let layers = assign_layers(graph);
27
28    // 3. Calculate gaps based on available width
29    let (h_gap, v_gap) = calculate_gaps(graph, &layers, options.max_width);
30
31    // 4. Position assignment based on direction with calculated gaps
32    assign_coordinates_with_gaps(graph, &layers, h_gap, v_gap);
33
34    // 5. Compute subgraph bounding boxes
35    compute_subgraph_bounds(graph);
36}
37
38/// Calculate adaptive gaps based on available width
39fn calculate_gaps(
40    graph: &Graph,
41    layers: &HashMap<NodeId, usize>,
42    max_width: Option<usize>,
43) -> (usize, usize) {
44    let max_width = match max_width {
45        Some(w) => w,
46        None => return (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP),
47    };
48
49    // Group nodes by layer
50    let mut layers_map: HashMap<usize, Vec<&NodeId>> = HashMap::new();
51    let mut max_layer = 0;
52
53    for (id, &layer) in layers {
54        layers_map.entry(layer).or_default().push(id);
55        max_layer = max_layer.max(layer);
56    }
57
58    // Calculate natural width with default gaps (for horizontal layouts)
59    if graph.direction.is_horizontal() {
60        let mut total_width = 0;
61        for l in 0..=max_layer {
62            let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
63            let layer_max_width = nodes_in_layer
64                .iter()
65                .filter_map(|id| graph.nodes.get(*id))
66                .map(|n| n.width)
67                .max()
68                .unwrap_or(0);
69            total_width += layer_max_width;
70        }
71        total_width += max_layer * DEFAULT_HORIZONTAL_GAP;
72
73        // If natural width exceeds max_width, reduce horizontal gap
74        if total_width > max_width && max_layer > 0 {
75            let node_width = total_width - max_layer * DEFAULT_HORIZONTAL_GAP;
76            let available_for_gaps = max_width.saturating_sub(node_width);
77            let new_gap = (available_for_gaps / max_layer).max(MIN_GAP);
78            return (new_gap, DEFAULT_VERTICAL_GAP);
79        }
80    }
81
82    (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP)
83}
84
85/// Compute bounding boxes for all subgraphs
86fn compute_subgraph_bounds(graph: &mut Graph) {
87    for sg in &mut graph.subgraphs {
88        if sg.nodes.is_empty() {
89            continue;
90        }
91
92        let mut min_x = usize::MAX;
93        let mut min_y = usize::MAX;
94        let mut max_x = 0;
95        let mut max_y = 0;
96
97        for node_id in &sg.nodes {
98            if let Some(node) = graph.nodes.get(node_id) {
99                min_x = min_x.min(node.x);
100                min_y = min_y.min(node.y);
101                max_x = max_x.max(node.x + node.width);
102                max_y = max_y.max(node.y + node.height);
103            }
104        }
105
106        if min_x != usize::MAX {
107            // Add padding around the subgraph
108            sg.x = min_x.saturating_sub(SUBGRAPH_PADDING);
109            sg.y = min_y.saturating_sub(SUBGRAPH_PADDING + 1); // Extra space for label
110            sg.width = (max_x - min_x) + SUBGRAPH_PADDING * 2;
111            sg.height = (max_y - min_y) + SUBGRAPH_PADDING * 2 + 1;
112        }
113    }
114}
115
116/// Assign layer numbers using Kahn's algorithm
117fn assign_layers(graph: &Graph) -> HashMap<NodeId, usize> {
118    let mut node_layers: HashMap<NodeId, usize> = HashMap::new();
119    let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
120
121    // Initialize
122    for id in graph.nodes.keys() {
123        in_degree.insert(id.clone(), 0);
124        node_layers.insert(id.clone(), 0);
125    }
126
127    // Count in-degrees
128    for edge in &graph.edges {
129        *in_degree.entry(edge.to.clone()).or_insert(0) += 1;
130    }
131
132    // Start with nodes that have no incoming edges
133    let mut queue: VecDeque<NodeId> = VecDeque::new();
134    for (id, &degree) in &in_degree {
135        if degree == 0 {
136            queue.push_back(id.clone());
137        }
138    }
139
140    let mut processed = 0;
141    while let Some(u) = queue.pop_front() {
142        processed += 1;
143
144        // Find all neighbors (nodes that u points to)
145        let neighbors: Vec<NodeId> = graph
146            .edges
147            .iter()
148            .filter(|e| e.from == u)
149            .map(|e| e.to.clone())
150            .collect();
151
152        for v in neighbors {
153            // Update layer to be at least one more than predecessor
154            let u_layer = *node_layers.get(&u).unwrap_or(&0);
155            let v_layer = node_layers.entry(v.clone()).or_insert(0);
156            *v_layer = (*v_layer).max(u_layer + 1);
157
158            // Decrement in-degree
159            if let Some(deg) = in_degree.get_mut(&v) {
160                *deg -= 1;
161                if *deg == 0 {
162                    queue.push_back(v);
163                }
164            }
165        }
166    }
167
168    // Check for cycles
169    if processed < graph.nodes.len() {
170        eprintln!("Warning: Cycle detected in graph. Layout may be imperfect.");
171    }
172
173    node_layers
174}
175
176/// Assign x,y coordinates based on layers and direction with configurable gaps
177fn assign_coordinates_with_gaps(
178    graph: &mut Graph,
179    node_layers: &HashMap<NodeId, usize>,
180    h_gap: usize,
181    v_gap: usize,
182) {
183    let direction = graph.direction;
184
185    // Group nodes by layer
186    let mut layers_map: HashMap<usize, Vec<NodeId>> = HashMap::new();
187    let mut max_layer = 0;
188
189    for (id, &layer) in node_layers {
190        layers_map.entry(layer).or_default().push(id.clone());
191        max_layer = max_layer.max(layer);
192    }
193
194    // Calculate layer dimensions
195    let mut layer_widths: HashMap<usize, usize> = HashMap::new();
196    let mut layer_heights: HashMap<usize, usize> = HashMap::new();
197
198    for l in 0..=max_layer {
199        let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
200        let mut max_w = 0;
201        let mut max_h = 0;
202        let mut total_w = 0;
203        let mut total_h = 0;
204
205        for id in nodes_in_layer {
206            if let Some(node) = graph.nodes.get(id) {
207                max_w = max_w.max(node.width);
208                max_h = max_h.max(node.height);
209                total_w += node.width + h_gap;
210                total_h += node.height + v_gap;
211            }
212        }
213
214        if direction.is_horizontal() {
215            layer_widths.insert(l, max_w);
216            layer_heights.insert(l, total_h.saturating_sub(v_gap));
217        } else {
218            layer_widths.insert(l, total_w.saturating_sub(h_gap));
219            layer_heights.insert(l, max_h);
220        }
221    }
222
223    let max_total_width = layer_widths.values().copied().max().unwrap_or(0);
224    let max_total_height = layer_heights.values().copied().max().unwrap_or(0);
225
226    if direction.is_horizontal() {
227        let mut current_x = 0;
228        for l in 0..=max_layer {
229            let layer_idx = match direction {
230                Direction::LR => l,
231                Direction::RL => max_layer - l,
232                _ => l,
233            };
234
235            let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
236            let layer_h = *layer_heights.get(&layer_idx).unwrap_or(&0);
237            let mut start_y = (max_total_height.saturating_sub(layer_h)) / 2;
238
239            for id in nodes_in_layer {
240                if let Some(node) = graph.nodes.get_mut(&id) {
241                    node.x = current_x;
242                    node.y = start_y;
243                    start_y += node.height + v_gap;
244                }
245            }
246
247            current_x += layer_widths.get(&layer_idx).unwrap_or(&0) + h_gap;
248        }
249    } else {
250        let mut current_y = 0;
251        for l in 0..=max_layer {
252            let layer_idx = match direction {
253                Direction::TB => l,
254                Direction::BT => max_layer - l,
255                _ => l,
256            };
257
258            let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
259            let layer_w = *layer_widths.get(&layer_idx).unwrap_or(&0);
260            let mut start_x = (max_total_width.saturating_sub(layer_w)) / 2;
261
262            for id in nodes_in_layer {
263                if let Some(node) = graph.nodes.get_mut(&id) {
264                    node.x = start_x;
265                    node.y = current_y;
266                    start_x += node.width + h_gap;
267                }
268            }
269
270            current_y += layer_heights.get(&layer_idx).unwrap_or(&0) + v_gap;
271        }
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_mermaid;

    /// Parse a Mermaid snippet and run the default layout over it.
    fn laid_out(src: &str) -> Graph {
        let mut graph = parse_mermaid(src).unwrap();
        compute_layout(&mut graph);
        graph
    }

    #[test]
    fn test_layout_lr() {
        let graph = laid_out("flowchart LR\nA --> B");
        let source = graph.nodes.get("A").unwrap();
        let target = graph.nodes.get("B").unwrap();

        // Left-to-right flow: the source sits left of its target.
        assert!(source.x < target.x);
    }

    #[test]
    fn test_layout_tb() {
        let graph = laid_out("flowchart TB\nA --> B");
        let source = graph.nodes.get("A").unwrap();
        let target = graph.nodes.get("B").unwrap();

        // Top-to-bottom flow: the source sits above its target.
        assert!(source.y < target.y);
    }

    #[test]
    fn test_node_sizes() {
        let graph = laid_out("flowchart LR\nA[Hello World]");
        let node = graph.nodes.get("A").unwrap();

        let label = "Hello World";
        assert_eq!(node.width, label.chars().count() + 2);
        assert_eq!(node.height, NODE_HEIGHT);
    }
}