1use crate::types::{Direction, Graph, NodeId, NodeShape, RenderOptions};
2use std::collections::{HashMap, VecDeque};
3
// ---- Node sizing (character cells) ----

/// Smallest width a node box may have, however short its label.
const MIN_NODE_WIDTH: usize = 5;
/// Height of a standard node box (some shapes are given more rows).
const NODE_HEIGHT: usize = 3;

// ---- Spacing ----

/// Horizontal gap used unless a width limit forces a smaller one.
const DEFAULT_HORIZONTAL_GAP: usize = 8;
/// Vertical gap between stacked nodes or layers; never adjusted.
const DEFAULT_VERTICAL_GAP: usize = 4;
/// Floor for a horizontal gap that has been shrunk to honour a width limit.
const MIN_GAP: usize = 2;

/// Margin kept between a subgraph's border and the nodes it encloses.
const SUBGRAPH_PADDING: usize = 2;
11
12pub fn compute_layout(graph: &mut Graph) -> Vec<String> {
16 compute_layout_with_options(graph, &RenderOptions::default())
17}
18
19pub fn compute_layout_with_options(graph: &mut Graph, options: &RenderOptions) -> Vec<String> {
23 let mut warnings = Vec::new();
24
25 for node in graph.nodes.values_mut() {
27 node.width = (node.label.chars().count() + 2).max(MIN_NODE_WIDTH);
28 node.height = NODE_HEIGHT;
29 if node.shape == NodeShape::Cylinder {
30 node.height = 5;
31 }
32 }
33
34 let layers = assign_layers(graph, &mut warnings);
36
37 let (h_gap, v_gap) = calculate_gaps(graph, &layers, options.max_width);
39
40 assign_coordinates_with_gaps(graph, &layers, h_gap, v_gap);
42
43 compute_subgraph_bounds(graph);
45
46 warnings
47}
48
49fn calculate_gaps(
51 graph: &Graph,
52 layers: &HashMap<NodeId, usize>,
53 max_width: Option<usize>,
54) -> (usize, usize) {
55 let max_width = match max_width {
56 Some(w) => w,
57 None => return (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP),
58 };
59
60 let mut layers_map: HashMap<usize, Vec<&NodeId>> = HashMap::new();
62 let mut max_layer = 0;
63
64 for (id, &layer) in layers {
65 layers_map.entry(layer).or_default().push(id);
66 max_layer = max_layer.max(layer);
67 }
68
69 if graph.direction.is_horizontal() {
71 let mut total_width = 0;
72 for l in 0..=max_layer {
73 let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
74 let layer_max_width = nodes_in_layer
75 .iter()
76 .filter_map(|id| graph.nodes.get(*id))
77 .map(|n| n.width)
78 .max()
79 .unwrap_or(0);
80 total_width += layer_max_width;
81 }
82 total_width += max_layer * DEFAULT_HORIZONTAL_GAP;
83
84 if total_width > max_width && max_layer > 0 {
86 let node_width = total_width - max_layer * DEFAULT_HORIZONTAL_GAP;
87 let available_for_gaps = max_width.saturating_sub(node_width);
88 let new_gap = (available_for_gaps / max_layer).max(MIN_GAP);
89 return (new_gap, DEFAULT_VERTICAL_GAP);
90 }
91 }
92
93 (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP)
94}
95
96fn compute_subgraph_bounds(graph: &mut Graph) {
98 for sg in &mut graph.subgraphs {
99 if sg.nodes.is_empty() {
100 continue;
101 }
102
103 let mut min_x = usize::MAX;
104 let mut min_y = usize::MAX;
105 let mut max_x = 0;
106 let mut max_y = 0;
107
108 for node_id in &sg.nodes {
109 if let Some(node) = graph.nodes.get(node_id) {
110 min_x = min_x.min(node.x);
111 min_y = min_y.min(node.y);
112 max_x = max_x.max(node.x + node.width);
113 max_y = max_y.max(node.y + node.height);
114 }
115 }
116
117 if min_x != usize::MAX {
118 sg.x = min_x.saturating_sub(SUBGRAPH_PADDING);
120 sg.y = min_y.saturating_sub(SUBGRAPH_PADDING + 1); sg.width = (max_x - min_x) + SUBGRAPH_PADDING * 2;
122 sg.height = (max_y - min_y) + SUBGRAPH_PADDING * 2 + 1;
123 }
124 }
125}
126
127fn assign_layers(graph: &Graph, warnings: &mut Vec<String>) -> HashMap<NodeId, usize> {
129 let mut node_layers: HashMap<NodeId, usize> = HashMap::new();
130 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
131
132 for id in graph.nodes.keys() {
134 in_degree.insert(id.clone(), 0);
135 node_layers.insert(id.clone(), 0);
136 }
137
138 for edge in &graph.edges {
140 *in_degree.entry(edge.to.clone()).or_insert(0) += 1;
141 }
142
143 let mut queue: VecDeque<NodeId> = VecDeque::new();
145 for (id, °ree) in &in_degree {
146 if degree == 0 {
147 queue.push_back(id.clone());
148 }
149 }
150
151 let mut processed = 0;
152 while let Some(u) = queue.pop_front() {
153 processed += 1;
154
155 let neighbors: Vec<NodeId> = graph
157 .edges
158 .iter()
159 .filter(|e| e.from == u)
160 .map(|e| e.to.clone())
161 .collect();
162
163 for v in neighbors {
164 let u_layer = *node_layers.get(&u).unwrap_or(&0);
166 let v_layer = node_layers.entry(v.clone()).or_insert(0);
167 *v_layer = (*v_layer).max(u_layer + 1);
168
169 if let Some(deg) = in_degree.get_mut(&v) {
171 *deg -= 1;
172 if *deg == 0 {
173 queue.push_back(v);
174 }
175 }
176 }
177 }
178
179 if processed < graph.nodes.len() {
181 warnings.push("Cycle detected in graph. Layout may be imperfect.".to_string());
182 }
183
184 node_layers
185}
186
187fn assign_coordinates_with_gaps(
189 graph: &mut Graph,
190 node_layers: &HashMap<NodeId, usize>,
191 h_gap: usize,
192 v_gap: usize,
193) {
194 let direction = graph.direction;
195
196 let mut layers_map: HashMap<usize, Vec<NodeId>> = HashMap::new();
198 let mut max_layer = 0;
199
200 for (id, &layer) in node_layers {
201 layers_map.entry(layer).or_default().push(id.clone());
202 max_layer = max_layer.max(layer);
203 }
204
205 let mut layer_widths: HashMap<usize, usize> = HashMap::new();
207 let mut layer_heights: HashMap<usize, usize> = HashMap::new();
208
209 for l in 0..=max_layer {
210 let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
211 let mut max_w = 0;
212 let mut max_h = 0;
213 let mut total_w = 0;
214 let mut total_h = 0;
215
216 for id in nodes_in_layer {
217 if let Some(node) = graph.nodes.get(id) {
218 max_w = max_w.max(node.width);
219 max_h = max_h.max(node.height);
220 total_w += node.width + h_gap;
221 total_h += node.height + v_gap;
222 }
223 }
224
225 if direction.is_horizontal() {
226 layer_widths.insert(l, max_w);
227 layer_heights.insert(l, total_h.saturating_sub(v_gap));
228 } else {
229 layer_widths.insert(l, total_w.saturating_sub(h_gap));
230 layer_heights.insert(l, max_h);
231 }
232 }
233
234 let max_total_width = layer_widths.values().copied().max().unwrap_or(0);
235 let max_total_height = layer_heights.values().copied().max().unwrap_or(0);
236
237 if direction.is_horizontal() {
238 let mut current_x = 0;
239 for l in 0..=max_layer {
240 let layer_idx = match direction {
241 Direction::LR => l,
242 Direction::RL => max_layer - l,
243 _ => l,
244 };
245
246 let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
247 let layer_h = *layer_heights.get(&layer_idx).unwrap_or(&0);
248 let mut start_y = (max_total_height.saturating_sub(layer_h)) / 2;
249
250 for id in nodes_in_layer {
251 if let Some(node) = graph.nodes.get_mut(&id) {
252 node.x = current_x;
253 node.y = start_y;
254 start_y += node.height + v_gap;
255 }
256 }
257
258 current_x += layer_widths.get(&layer_idx).unwrap_or(&0) + h_gap;
259 }
260 } else {
261 let mut current_y = 0;
262 for l in 0..=max_layer {
263 let layer_idx = match direction {
264 Direction::TB => l,
265 Direction::BT => max_layer - l,
266 _ => l,
267 };
268
269 let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
270 let layer_w = *layer_widths.get(&layer_idx).unwrap_or(&0);
271 let mut start_x = (max_total_width.saturating_sub(layer_w)) / 2;
272
273 for id in nodes_in_layer {
274 if let Some(node) = graph.nodes.get_mut(&id) {
275 node.x = start_x;
276 node.y = current_y;
277 start_x += node.width + h_gap;
278 }
279 }
280
281 current_y += layer_heights.get(&layer_idx).unwrap_or(&0) + v_gap;
282 }
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_mermaid;

    /// Parses `source` as Mermaid, runs the default layout, and returns the laid-out
    /// graph together with the layout warnings.
    fn lay_out(source: &str) -> (Graph, Vec<String>) {
        let mut graph = parse_mermaid(source).unwrap();
        let warnings = compute_layout(&mut graph);
        (graph, warnings)
    }

    #[test]
    fn test_layout_lr() {
        let (graph, warnings) = lay_out("flowchart LR\nA --> B");
        let (a, b) = (graph.nodes.get("A").unwrap(), graph.nodes.get("B").unwrap());

        assert!(a.x < b.x);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_layout_tb() {
        let (graph, warnings) = lay_out("flowchart TB\nA --> B");
        let (a, b) = (graph.nodes.get("A").unwrap(), graph.nodes.get("B").unwrap());

        assert!(a.y < b.y);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_node_sizes() {
        let (graph, _) = lay_out("flowchart LR\nA[Hello World]");
        let node = graph.nodes.get("A").unwrap();

        assert_eq!(node.width, "Hello World".len() + 2);
        assert_eq!(node.height, NODE_HEIGHT);
    }

    #[test]
    fn test_cycle_produces_warning() {
        let (_, warnings) = lay_out("flowchart LR\nA --> B\nB --> C\nC --> A");

        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].contains("Cycle"));
    }

    #[test]
    fn test_acyclic_no_warning() {
        let (_, warnings) = lay_out("flowchart LR\nA --> B\nB --> C\nA --> C");

        assert!(warnings.is_empty());
    }
}