1use crate::types::{DiagramWarning, Direction, Graph, NodeId, NodeShape, RenderOptions};
2use std::collections::{HashMap, HashSet, VecDeque};
3
/// Lower bound on a node's width; shorter labels are widened to this value.
const MIN_NODE_WIDTH: usize = 5;
/// Height given to every node that is not a cylinder.
const NODE_HEIGHT: usize = 3;
/// Horizontal gap used when no maximum width is set or the layout already fits.
const DEFAULT_HORIZONTAL_GAP: usize = 8;
/// Vertical gap; `calculate_gaps` always returns this value unchanged.
const DEFAULT_VERTICAL_GAP: usize = 4;
/// Floor for the horizontal gap when it is shrunk to honour a maximum width.
const MIN_GAP: usize = 2;

/// Padding added around a subgraph's member nodes when computing its bounds.
const SUBGRAPH_PADDING: usize = 2;
11
12pub fn compute_layout(graph: &mut Graph) -> Vec<DiagramWarning> {
16 compute_layout_with_options(graph, &RenderOptions::default())
17}
18
19pub fn compute_layout_with_options(
23 graph: &mut Graph,
24 options: &RenderOptions,
25) -> Vec<DiagramWarning> {
26 let mut warnings = Vec::new();
27
28 for node in graph.nodes.values_mut() {
30 node.width = (node.label.chars().count() + 2).max(MIN_NODE_WIDTH);
31 node.height = NODE_HEIGHT;
32 if node.shape == NodeShape::Cylinder {
33 node.height = 5;
34 }
35 }
36
37 let layers = assign_layers(graph, &mut warnings);
39
40 let (h_gap, v_gap) = calculate_gaps(graph, &layers, options.max_width);
42
43 assign_coordinates_with_gaps(graph, &layers, h_gap, v_gap);
45
46 compute_subgraph_bounds(graph);
48
49 warnings
50}
51
52fn calculate_gaps(
54 graph: &Graph,
55 layers: &HashMap<NodeId, usize>,
56 max_width: Option<usize>,
57) -> (usize, usize) {
58 let max_width = match max_width {
59 Some(w) => w,
60 None => return (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP),
61 };
62
63 let mut layers_map: HashMap<usize, Vec<&NodeId>> = HashMap::new();
65 let mut max_layer = 0;
66
67 for (id, &layer) in layers {
68 layers_map.entry(layer).or_default().push(id);
69 max_layer = max_layer.max(layer);
70 }
71 for nodes in layers_map.values_mut() {
72 nodes.sort();
73 }
74
75 if graph.direction.is_horizontal() {
77 let mut total_width = 0;
78 for l in 0..=max_layer {
79 let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
80 let layer_max_width = nodes_in_layer
81 .iter()
82 .filter_map(|id| graph.nodes.get(*id))
83 .map(|n| n.width)
84 .max()
85 .unwrap_or(0);
86 total_width += layer_max_width;
87 }
88 total_width += max_layer * DEFAULT_HORIZONTAL_GAP;
89
90 if total_width > max_width && max_layer > 0 {
92 let node_width = total_width - max_layer * DEFAULT_HORIZONTAL_GAP;
93 let available_for_gaps = max_width.saturating_sub(node_width);
94 let new_gap = (available_for_gaps / max_layer).max(MIN_GAP);
95 return (new_gap, DEFAULT_VERTICAL_GAP);
96 }
97 }
98
99 (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP)
100}
101
102fn compute_subgraph_bounds(graph: &mut Graph) {
104 for sg in &mut graph.subgraphs {
105 if sg.nodes.is_empty() {
106 continue;
107 }
108
109 let mut min_x = usize::MAX;
110 let mut min_y = usize::MAX;
111 let mut max_x = 0;
112 let mut max_y = 0;
113
114 for node_id in &sg.nodes {
115 if let Some(node) = graph.nodes.get(node_id) {
116 min_x = min_x.min(node.x);
117 min_y = min_y.min(node.y);
118 max_x = max_x.max(node.x + node.width);
119 max_y = max_y.max(node.y + node.height);
120 }
121 }
122
123 if min_x != usize::MAX {
124 sg.x = min_x.saturating_sub(SUBGRAPH_PADDING);
126 sg.y = min_y.saturating_sub(SUBGRAPH_PADDING + 1); sg.width = (max_x - min_x) + SUBGRAPH_PADDING * 2;
128 sg.height = (max_y - min_y) + SUBGRAPH_PADDING * 2 + 1;
129 }
130 }
131}
132
133fn assign_layers(graph: &Graph, warnings: &mut Vec<DiagramWarning>) -> HashMap<NodeId, usize> {
140 let mut node_layers: HashMap<NodeId, usize> = HashMap::new();
141 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
142 let mut processed: HashSet<NodeId> = HashSet::new();
143
144 for id in graph.nodes.keys() {
146 in_degree.insert(id.clone(), 0);
147 node_layers.insert(id.clone(), 0);
148 }
149
150 for edge in &graph.edges {
152 *in_degree.entry(edge.to.clone()).or_insert(0) += 1;
153 }
154
155 let mut first_from_idx: HashMap<&str, usize> = HashMap::new();
159 for (i, edge) in graph.edges.iter().enumerate() {
160 first_from_idx.entry(edge.from.as_str()).or_insert(i);
161 }
162
163 let mut queue: VecDeque<NodeId> = VecDeque::new();
165 let mut zero_in: Vec<&NodeId> = in_degree
166 .iter()
167 .filter(|(_, °)| deg == 0)
168 .map(|(id, _)| id)
169 .collect();
170 zero_in.sort();
171 for id in zero_in {
172 queue.push_back(id.clone());
173 }
174
175 let total = graph.nodes.len();
176 let mut all_cycle_nodes: HashSet<String> = HashSet::new();
177
178 loop {
179 while let Some(u) = queue.pop_front() {
181 if processed.contains(&u) {
182 continue;
183 }
184 processed.insert(u.clone());
185
186 let mut neighbors: Vec<NodeId> = graph
188 .edges
189 .iter()
190 .filter(|e| e.from == u && !processed.contains(&e.to))
191 .map(|e| e.to.clone())
192 .collect();
193 neighbors.sort();
194 neighbors.dedup();
195
196 for v in &neighbors {
197 let u_layer = *node_layers.get(&u).unwrap_or(&0);
198 let v_layer = node_layers.entry(v.clone()).or_insert(0);
199 *v_layer = (*v_layer).max(u_layer + 1);
200
201 if let Some(deg) = in_degree.get_mut(v) {
202 *deg = deg.saturating_sub(1);
203 if *deg == 0 {
204 queue.push_back(v.clone());
205 }
206 }
207 }
208 }
209
210 if processed.len() >= total {
211 break;
212 }
213
214 let mut stuck: Vec<NodeId> = in_degree
216 .iter()
217 .filter(|(id, _)| !processed.contains(*id))
218 .map(|(id, _)| id.clone())
219 .collect();
220
221 let stuck_set: HashSet<&str> = stuck.iter().map(|s| s.as_str()).collect();
224 for n in &stuck {
225 let has_outgoing_to_stuck = graph
226 .edges
227 .iter()
228 .any(|e| e.from == *n && stuck_set.contains(e.to.as_str()));
229 if has_outgoing_to_stuck {
230 all_cycle_nodes.insert(n.clone());
231 }
232 }
233
234 stuck.sort_by(|a, b| {
236 let fa = first_from_idx.get(a.as_str()).copied().unwrap_or(usize::MAX);
237 let fb = first_from_idx.get(b.as_str()).copied().unwrap_or(usize::MAX);
238 fa.cmp(&fb).then(a.cmp(b))
239 });
240
241 if let Some(forced) = stuck.first() {
242 in_degree.insert(forced.clone(), 0);
243 queue.push_back(forced.clone());
244 }
245 }
246
247 if !all_cycle_nodes.is_empty() {
248 let mut cycle_nodes: Vec<String> = all_cycle_nodes.into_iter().collect();
249 cycle_nodes.sort();
250 warnings.push(DiagramWarning::CycleDetected { nodes: cycle_nodes });
251 }
252
253 node_layers
254}
255
256fn assign_coordinates_with_gaps(
258 graph: &mut Graph,
259 node_layers: &HashMap<NodeId, usize>,
260 h_gap: usize,
261 v_gap: usize,
262) {
263 let direction = graph.direction;
264
265 let mut layers_map: HashMap<usize, Vec<NodeId>> = HashMap::new();
267 let mut max_layer = 0;
268
269 for (id, &layer) in node_layers {
270 layers_map.entry(layer).or_default().push(id.clone());
271 max_layer = max_layer.max(layer);
272 }
273 for nodes in layers_map.values_mut() {
274 nodes.sort();
275 }
276
277 let mut layer_widths: HashMap<usize, usize> = HashMap::new();
279 let mut layer_heights: HashMap<usize, usize> = HashMap::new();
280
281 for l in 0..=max_layer {
282 let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
283 let mut max_w = 0;
284 let mut max_h = 0;
285 let mut total_w = 0;
286 let mut total_h = 0;
287
288 for id in nodes_in_layer {
289 if let Some(node) = graph.nodes.get(id) {
290 max_w = max_w.max(node.width);
291 max_h = max_h.max(node.height);
292 total_w += node.width + h_gap;
293 total_h += node.height + v_gap;
294 }
295 }
296
297 if direction.is_horizontal() {
298 layer_widths.insert(l, max_w);
299 layer_heights.insert(l, total_h.saturating_sub(v_gap));
300 } else {
301 layer_widths.insert(l, total_w.saturating_sub(h_gap));
302 layer_heights.insert(l, max_h);
303 }
304 }
305
306 let max_total_width = layer_widths.values().copied().max().unwrap_or(0);
307 let max_total_height = layer_heights.values().copied().max().unwrap_or(0);
308
309 if direction.is_horizontal() {
310 let mut current_x = 0;
311 for l in 0..=max_layer {
312 let layer_idx = match direction {
313 Direction::LR => l,
314 Direction::RL => max_layer - l,
315 _ => l,
316 };
317
318 let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
319 let layer_h = *layer_heights.get(&layer_idx).unwrap_or(&0);
320 let mut start_y = (max_total_height.saturating_sub(layer_h)) / 2;
321
322 for id in nodes_in_layer {
323 if let Some(node) = graph.nodes.get_mut(&id) {
324 node.x = current_x;
325 node.y = start_y;
326 start_y += node.height + v_gap;
327 }
328 }
329
330 current_x += layer_widths.get(&layer_idx).unwrap_or(&0) + h_gap;
331 }
332 } else {
333 let mut current_y = 0;
334 for l in 0..=max_layer {
335 let layer_idx = match direction {
336 Direction::TB => l,
337 Direction::BT => max_layer - l,
338 _ => l,
339 };
340
341 let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
342 let layer_w = *layer_widths.get(&layer_idx).unwrap_or(&0);
343 let mut start_x = (max_total_width.saturating_sub(layer_w)) / 2;
344
345 for id in nodes_in_layer {
346 if let Some(node) = graph.nodes.get_mut(&id) {
347 node.x = start_x;
348 node.y = current_y;
349 start_x += node.width + h_gap;
350 }
351 }
352
353 current_y += layer_heights.get(&layer_idx).unwrap_or(&0) + v_gap;
354 }
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_mermaid;

    /// Returns the `(x, y)` position of node `id`, panicking if it is absent.
    fn position(graph: &Graph, id: &str) -> (usize, usize) {
        let node = graph.nodes.get(id).unwrap();
        (node.x, node.y)
    }

    /// In a left-to-right chart the source sits left of its target.
    #[test]
    fn test_layout_lr() {
        let mut diagram = parse_mermaid("flowchart LR\nA --> B").unwrap();
        let warnings = compute_layout(&mut diagram);

        let (a_x, _) = position(&diagram, "A");
        let (b_x, _) = position(&diagram, "B");

        assert!(a_x < b_x);
        assert!(warnings.is_empty());
    }

    /// In a top-to-bottom chart the source sits above its target.
    #[test]
    fn test_layout_tb() {
        let mut diagram = parse_mermaid("flowchart TB\nA --> B").unwrap();
        let warnings = compute_layout(&mut diagram);

        let (_, a_y) = position(&diagram, "A");
        let (_, b_y) = position(&diagram, "B");

        assert!(a_y < b_y);
        assert!(warnings.is_empty());
    }

    /// Node width is the label length plus two; height is the default.
    #[test]
    fn test_node_sizes() {
        let mut diagram = parse_mermaid("flowchart LR\nA[Hello World]").unwrap();
        compute_layout(&mut diagram);

        let node = diagram.nodes.get("A").unwrap();
        assert_eq!(node.width, "Hello World".len() + 2);
        assert_eq!(node.height, NODE_HEIGHT);
    }

    /// A three-node cycle yields exactly one cycle warning.
    #[test]
    fn test_cycle_produces_warning() {
        let mut diagram = parse_mermaid("flowchart LR\nA --> B\nB --> C\nC --> A").unwrap();
        let warnings = compute_layout(&mut diagram);

        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].to_string().contains("Cycle"));
    }

    /// A diamond-shaped DAG must not be reported as cyclic.
    #[test]
    fn test_acyclic_no_warning() {
        let mut diagram = parse_mermaid("flowchart LR\nA --> B\nB --> C\nA --> C").unwrap();
        let warnings = compute_layout(&mut diagram);

        assert!(warnings.is_empty());
    }
}
411}