1use crate::types::{Direction, Graph, NodeId, RenderOptions};
2use std::collections::{HashMap, VecDeque};
3
/// Height of every node box, in rows.
const NODE_HEIGHT: usize = 3;
/// Narrowest node box, in columns, no matter how short the label is.
const MIN_NODE_WIDTH: usize = 5;

/// Horizontal spacing used when no `max_width` constraint forces it smaller.
const DEFAULT_HORIZONTAL_GAP: usize = 8;
/// Vertical spacing between layers / stacked nodes.
const DEFAULT_VERTICAL_GAP: usize = 4;
/// Lower bound for the horizontal gap when it is shrunk to honour `max_width`.
const MIN_GAP: usize = 2;

/// Padding placed around a subgraph's member nodes when computing its bounds.
const SUBGRAPH_PADDING: usize = 2;
11
12pub fn compute_layout(graph: &mut Graph) {
14 compute_layout_with_options(graph, &RenderOptions::default());
15}
16
17pub fn compute_layout_with_options(graph: &mut Graph, options: &RenderOptions) {
19 for node in graph.nodes.values_mut() {
21 node.width = (node.label.chars().count() + 2).max(MIN_NODE_WIDTH);
22 node.height = NODE_HEIGHT;
23 }
24
25 let layers = assign_layers(graph);
27
28 let (h_gap, v_gap) = calculate_gaps(graph, &layers, options.max_width);
30
31 assign_coordinates_with_gaps(graph, &layers, h_gap, v_gap);
33
34 compute_subgraph_bounds(graph);
36}
37
38fn calculate_gaps(
40 graph: &Graph,
41 layers: &HashMap<NodeId, usize>,
42 max_width: Option<usize>,
43) -> (usize, usize) {
44 let max_width = match max_width {
45 Some(w) => w,
46 None => return (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP),
47 };
48
49 let mut layers_map: HashMap<usize, Vec<&NodeId>> = HashMap::new();
51 let mut max_layer = 0;
52
53 for (id, &layer) in layers {
54 layers_map.entry(layer).or_default().push(id);
55 max_layer = max_layer.max(layer);
56 }
57
58 if graph.direction.is_horizontal() {
60 let mut total_width = 0;
61 for l in 0..=max_layer {
62 let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
63 let layer_max_width = nodes_in_layer
64 .iter()
65 .filter_map(|id| graph.nodes.get(*id))
66 .map(|n| n.width)
67 .max()
68 .unwrap_or(0);
69 total_width += layer_max_width;
70 }
71 total_width += max_layer * DEFAULT_HORIZONTAL_GAP;
72
73 if total_width > max_width && max_layer > 0 {
75 let node_width = total_width - max_layer * DEFAULT_HORIZONTAL_GAP;
76 let available_for_gaps = max_width.saturating_sub(node_width);
77 let new_gap = (available_for_gaps / max_layer).max(MIN_GAP);
78 return (new_gap, DEFAULT_VERTICAL_GAP);
79 }
80 }
81
82 (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP)
83}
84
85fn compute_subgraph_bounds(graph: &mut Graph) {
87 for sg in &mut graph.subgraphs {
88 if sg.nodes.is_empty() {
89 continue;
90 }
91
92 let mut min_x = usize::MAX;
93 let mut min_y = usize::MAX;
94 let mut max_x = 0;
95 let mut max_y = 0;
96
97 for node_id in &sg.nodes {
98 if let Some(node) = graph.nodes.get(node_id) {
99 min_x = min_x.min(node.x);
100 min_y = min_y.min(node.y);
101 max_x = max_x.max(node.x + node.width);
102 max_y = max_y.max(node.y + node.height);
103 }
104 }
105
106 if min_x != usize::MAX {
107 sg.x = min_x.saturating_sub(SUBGRAPH_PADDING);
109 sg.y = min_y.saturating_sub(SUBGRAPH_PADDING + 1); sg.width = (max_x - min_x) + SUBGRAPH_PADDING * 2;
111 sg.height = (max_y - min_y) + SUBGRAPH_PADDING * 2 + 1;
112 }
113 }
114}
115
116fn assign_layers(graph: &Graph) -> HashMap<NodeId, usize> {
118 let mut node_layers: HashMap<NodeId, usize> = HashMap::new();
119 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
120
121 for id in graph.nodes.keys() {
123 in_degree.insert(id.clone(), 0);
124 node_layers.insert(id.clone(), 0);
125 }
126
127 for edge in &graph.edges {
129 *in_degree.entry(edge.to.clone()).or_insert(0) += 1;
130 }
131
132 let mut queue: VecDeque<NodeId> = VecDeque::new();
134 for (id, °ree) in &in_degree {
135 if degree == 0 {
136 queue.push_back(id.clone());
137 }
138 }
139
140 let mut processed = 0;
141 while let Some(u) = queue.pop_front() {
142 processed += 1;
143
144 let neighbors: Vec<NodeId> = graph
146 .edges
147 .iter()
148 .filter(|e| e.from == u)
149 .map(|e| e.to.clone())
150 .collect();
151
152 for v in neighbors {
153 let u_layer = *node_layers.get(&u).unwrap_or(&0);
155 let v_layer = node_layers.entry(v.clone()).or_insert(0);
156 *v_layer = (*v_layer).max(u_layer + 1);
157
158 if let Some(deg) = in_degree.get_mut(&v) {
160 *deg -= 1;
161 if *deg == 0 {
162 queue.push_back(v);
163 }
164 }
165 }
166 }
167
168 if processed < graph.nodes.len() {
170 eprintln!("Warning: Cycle detected in graph. Layout may be imperfect.");
171 }
172
173 node_layers
174}
175
176fn assign_coordinates_with_gaps(
178 graph: &mut Graph,
179 node_layers: &HashMap<NodeId, usize>,
180 h_gap: usize,
181 v_gap: usize,
182) {
183 let direction = graph.direction;
184
185 let mut layers_map: HashMap<usize, Vec<NodeId>> = HashMap::new();
187 let mut max_layer = 0;
188
189 for (id, &layer) in node_layers {
190 layers_map.entry(layer).or_default().push(id.clone());
191 max_layer = max_layer.max(layer);
192 }
193
194 let mut layer_widths: HashMap<usize, usize> = HashMap::new();
196 let mut layer_heights: HashMap<usize, usize> = HashMap::new();
197
198 for l in 0..=max_layer {
199 let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
200 let mut max_w = 0;
201 let mut max_h = 0;
202 let mut total_w = 0;
203 let mut total_h = 0;
204
205 for id in nodes_in_layer {
206 if let Some(node) = graph.nodes.get(id) {
207 max_w = max_w.max(node.width);
208 max_h = max_h.max(node.height);
209 total_w += node.width + h_gap;
210 total_h += node.height + v_gap;
211 }
212 }
213
214 if direction.is_horizontal() {
215 layer_widths.insert(l, max_w);
216 layer_heights.insert(l, total_h.saturating_sub(v_gap));
217 } else {
218 layer_widths.insert(l, total_w.saturating_sub(h_gap));
219 layer_heights.insert(l, max_h);
220 }
221 }
222
223 let max_total_width = layer_widths.values().copied().max().unwrap_or(0);
224 let max_total_height = layer_heights.values().copied().max().unwrap_or(0);
225
226 if direction.is_horizontal() {
227 let mut current_x = 0;
228 for l in 0..=max_layer {
229 let layer_idx = match direction {
230 Direction::LR => l,
231 Direction::RL => max_layer - l,
232 _ => l,
233 };
234
235 let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
236 let layer_h = *layer_heights.get(&layer_idx).unwrap_or(&0);
237 let mut start_y = (max_total_height.saturating_sub(layer_h)) / 2;
238
239 for id in nodes_in_layer {
240 if let Some(node) = graph.nodes.get_mut(&id) {
241 node.x = current_x;
242 node.y = start_y;
243 start_y += node.height + v_gap;
244 }
245 }
246
247 current_x += layer_widths.get(&layer_idx).unwrap_or(&0) + h_gap;
248 }
249 } else {
250 let mut current_y = 0;
251 for l in 0..=max_layer {
252 let layer_idx = match direction {
253 Direction::TB => l,
254 Direction::BT => max_layer - l,
255 _ => l,
256 };
257
258 let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
259 let layer_w = *layer_widths.get(&layer_idx).unwrap_or(&0);
260 let mut start_x = (max_total_width.saturating_sub(layer_w)) / 2;
261
262 for id in nodes_in_layer {
263 if let Some(node) = graph.nodes.get_mut(&id) {
264 node.x = start_x;
265 node.y = current_y;
266 start_x += node.width + h_gap;
267 }
268 }
269
270 current_y += layer_heights.get(&layer_idx).unwrap_or(&0) + v_gap;
271 }
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_mermaid;

    /// Parses `src` and lays it out with the default options.
    fn layout(src: &str) -> Graph {
        let mut graph = parse_mermaid(src).unwrap();
        compute_layout(&mut graph);
        graph
    }

    #[test]
    fn test_layout_lr() {
        let graph = layout("flowchart LR\nA --> B");
        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();
        assert!(a.x < b.x);
    }

    #[test]
    fn test_layout_rl() {
        let graph = layout("flowchart RL\nA --> B");
        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();
        assert!(b.x < a.x);
    }

    #[test]
    fn test_layout_tb() {
        let graph = layout("flowchart TB\nA --> B");
        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();
        assert!(a.y < b.y);
    }

    #[test]
    fn test_layout_bt() {
        let graph = layout("flowchart BT\nA --> B");
        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();
        assert!(b.y < a.y);
    }

    #[test]
    fn test_node_sizes() {
        let graph = layout("flowchart LR\nA[Hello World]");
        let a = graph.nodes.get("A").unwrap();
        // The implementation measures labels in chars, not bytes.
        assert_eq!(a.width, "Hello World".chars().count() + 2);
        assert_eq!(a.height, NODE_HEIGHT);
    }

    #[test]
    fn test_max_width_shrinks_layer_gap() {
        let mut graph = parse_mermaid("flowchart LR\nA --> B\nB --> C").unwrap();
        let mut options = RenderOptions::default();
        // Far narrower than the nodes alone, so the gap bottoms out at MIN_GAP.
        options.max_width = Some(10);
        compute_layout_with_options(&mut graph, &options);

        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();
        assert_eq!(b.x - (a.x + a.width), MIN_GAP);
    }
}