1use crate::types::{Direction, Graph, NodeId, RenderOptions};
2use std::collections::{HashMap, VecDeque};
3
/// Narrowest width any node box may have (in characters), regardless of label length.
const MIN_NODE_WIDTH: usize = 5;
/// Height assigned to every node box.
const NODE_HEIGHT: usize = 3;
/// Gap along the x axis (between layer columns in LR/RL, between sibling nodes in TB/BT)
/// used unless `max_width` forces it to shrink.
const DEFAULT_HORIZONTAL_GAP: usize = 8;
/// Gap along the y axis; the layout never adjusts it.
const DEFAULT_VERTICAL_GAP: usize = 4;
/// Lower bound for the horizontal gap when it is shrunk to honour `max_width`.
const MIN_GAP: usize = 2;

/// Margin added around a subgraph's member nodes when computing its bounds.
const SUBGRAPH_PADDING: usize = 2;
11
12pub fn compute_layout(graph: &mut Graph) -> Vec<String> {
16 compute_layout_with_options(graph, &RenderOptions::default())
17}
18
19pub fn compute_layout_with_options(graph: &mut Graph, options: &RenderOptions) -> Vec<String> {
23 let mut warnings = Vec::new();
24
25 for node in graph.nodes.values_mut() {
27 node.width = (node.label.chars().count() + 2).max(MIN_NODE_WIDTH);
28 node.height = NODE_HEIGHT;
29 }
30
31 let layers = assign_layers(graph, &mut warnings);
33
34 let (h_gap, v_gap) = calculate_gaps(graph, &layers, options.max_width);
36
37 assign_coordinates_with_gaps(graph, &layers, h_gap, v_gap);
39
40 compute_subgraph_bounds(graph);
42
43 warnings
44}
45
46fn calculate_gaps(
48 graph: &Graph,
49 layers: &HashMap<NodeId, usize>,
50 max_width: Option<usize>,
51) -> (usize, usize) {
52 let max_width = match max_width {
53 Some(w) => w,
54 None => return (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP),
55 };
56
57 let mut layers_map: HashMap<usize, Vec<&NodeId>> = HashMap::new();
59 let mut max_layer = 0;
60
61 for (id, &layer) in layers {
62 layers_map.entry(layer).or_default().push(id);
63 max_layer = max_layer.max(layer);
64 }
65
66 if graph.direction.is_horizontal() {
68 let mut total_width = 0;
69 for l in 0..=max_layer {
70 let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
71 let layer_max_width = nodes_in_layer
72 .iter()
73 .filter_map(|id| graph.nodes.get(*id))
74 .map(|n| n.width)
75 .max()
76 .unwrap_or(0);
77 total_width += layer_max_width;
78 }
79 total_width += max_layer * DEFAULT_HORIZONTAL_GAP;
80
81 if total_width > max_width && max_layer > 0 {
83 let node_width = total_width - max_layer * DEFAULT_HORIZONTAL_GAP;
84 let available_for_gaps = max_width.saturating_sub(node_width);
85 let new_gap = (available_for_gaps / max_layer).max(MIN_GAP);
86 return (new_gap, DEFAULT_VERTICAL_GAP);
87 }
88 }
89
90 (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP)
91}
92
93fn compute_subgraph_bounds(graph: &mut Graph) {
95 for sg in &mut graph.subgraphs {
96 if sg.nodes.is_empty() {
97 continue;
98 }
99
100 let mut min_x = usize::MAX;
101 let mut min_y = usize::MAX;
102 let mut max_x = 0;
103 let mut max_y = 0;
104
105 for node_id in &sg.nodes {
106 if let Some(node) = graph.nodes.get(node_id) {
107 min_x = min_x.min(node.x);
108 min_y = min_y.min(node.y);
109 max_x = max_x.max(node.x + node.width);
110 max_y = max_y.max(node.y + node.height);
111 }
112 }
113
114 if min_x != usize::MAX {
115 sg.x = min_x.saturating_sub(SUBGRAPH_PADDING);
117 sg.y = min_y.saturating_sub(SUBGRAPH_PADDING + 1); sg.width = (max_x - min_x) + SUBGRAPH_PADDING * 2;
119 sg.height = (max_y - min_y) + SUBGRAPH_PADDING * 2 + 1;
120 }
121 }
122}
123
124fn assign_layers(graph: &Graph, warnings: &mut Vec<String>) -> HashMap<NodeId, usize> {
126 let mut node_layers: HashMap<NodeId, usize> = HashMap::new();
127 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
128
129 for id in graph.nodes.keys() {
131 in_degree.insert(id.clone(), 0);
132 node_layers.insert(id.clone(), 0);
133 }
134
135 for edge in &graph.edges {
137 *in_degree.entry(edge.to.clone()).or_insert(0) += 1;
138 }
139
140 let mut queue: VecDeque<NodeId> = VecDeque::new();
142 for (id, °ree) in &in_degree {
143 if degree == 0 {
144 queue.push_back(id.clone());
145 }
146 }
147
148 let mut processed = 0;
149 while let Some(u) = queue.pop_front() {
150 processed += 1;
151
152 let neighbors: Vec<NodeId> = graph
154 .edges
155 .iter()
156 .filter(|e| e.from == u)
157 .map(|e| e.to.clone())
158 .collect();
159
160 for v in neighbors {
161 let u_layer = *node_layers.get(&u).unwrap_or(&0);
163 let v_layer = node_layers.entry(v.clone()).or_insert(0);
164 *v_layer = (*v_layer).max(u_layer + 1);
165
166 if let Some(deg) = in_degree.get_mut(&v) {
168 *deg -= 1;
169 if *deg == 0 {
170 queue.push_back(v);
171 }
172 }
173 }
174 }
175
176 if processed < graph.nodes.len() {
178 warnings.push("Cycle detected in graph. Layout may be imperfect.".to_string());
179 }
180
181 node_layers
182}
183
184fn assign_coordinates_with_gaps(
186 graph: &mut Graph,
187 node_layers: &HashMap<NodeId, usize>,
188 h_gap: usize,
189 v_gap: usize,
190) {
191 let direction = graph.direction;
192
193 let mut layers_map: HashMap<usize, Vec<NodeId>> = HashMap::new();
195 let mut max_layer = 0;
196
197 for (id, &layer) in node_layers {
198 layers_map.entry(layer).or_default().push(id.clone());
199 max_layer = max_layer.max(layer);
200 }
201
202 let mut layer_widths: HashMap<usize, usize> = HashMap::new();
204 let mut layer_heights: HashMap<usize, usize> = HashMap::new();
205
206 for l in 0..=max_layer {
207 let nodes_in_layer = layers_map.get(&l).map(|v| v.as_slice()).unwrap_or(&[]);
208 let mut max_w = 0;
209 let mut max_h = 0;
210 let mut total_w = 0;
211 let mut total_h = 0;
212
213 for id in nodes_in_layer {
214 if let Some(node) = graph.nodes.get(id) {
215 max_w = max_w.max(node.width);
216 max_h = max_h.max(node.height);
217 total_w += node.width + h_gap;
218 total_h += node.height + v_gap;
219 }
220 }
221
222 if direction.is_horizontal() {
223 layer_widths.insert(l, max_w);
224 layer_heights.insert(l, total_h.saturating_sub(v_gap));
225 } else {
226 layer_widths.insert(l, total_w.saturating_sub(h_gap));
227 layer_heights.insert(l, max_h);
228 }
229 }
230
231 let max_total_width = layer_widths.values().copied().max().unwrap_or(0);
232 let max_total_height = layer_heights.values().copied().max().unwrap_or(0);
233
234 if direction.is_horizontal() {
235 let mut current_x = 0;
236 for l in 0..=max_layer {
237 let layer_idx = match direction {
238 Direction::LR => l,
239 Direction::RL => max_layer - l,
240 _ => l,
241 };
242
243 let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
244 let layer_h = *layer_heights.get(&layer_idx).unwrap_or(&0);
245 let mut start_y = (max_total_height.saturating_sub(layer_h)) / 2;
246
247 for id in nodes_in_layer {
248 if let Some(node) = graph.nodes.get_mut(&id) {
249 node.x = current_x;
250 node.y = start_y;
251 start_y += node.height + v_gap;
252 }
253 }
254
255 current_x += layer_widths.get(&layer_idx).unwrap_or(&0) + h_gap;
256 }
257 } else {
258 let mut current_y = 0;
259 for l in 0..=max_layer {
260 let layer_idx = match direction {
261 Direction::TB => l,
262 Direction::BT => max_layer - l,
263 _ => l,
264 };
265
266 let nodes_in_layer = layers_map.get(&layer_idx).cloned().unwrap_or_default();
267 let layer_w = *layer_widths.get(&layer_idx).unwrap_or(&0);
268 let mut start_x = (max_total_width.saturating_sub(layer_w)) / 2;
269
270 for id in nodes_in_layer {
271 if let Some(node) = graph.nodes.get_mut(&id) {
272 node.x = start_x;
273 node.y = current_y;
274 start_x += node.width + h_gap;
275 }
276 }
277
278 current_y += layer_heights.get(&layer_idx).unwrap_or(&0) + v_gap;
279 }
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_mermaid;

    /// Parses `src`, runs the default layout on it, and returns the graph with its warnings.
    fn laid_out(src: &str) -> (crate::types::Graph, Vec<String>) {
        let mut graph = parse_mermaid(src).unwrap();
        let warnings = compute_layout(&mut graph);
        (graph, warnings)
    }

    #[test]
    fn test_layout_lr() {
        let (graph, warnings) = laid_out("flowchart LR\nA --> B");
        let (a, b) = (graph.nodes.get("A").unwrap(), graph.nodes.get("B").unwrap());

        assert!(a.x < b.x);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_layout_tb() {
        let (graph, warnings) = laid_out("flowchart TB\nA --> B");
        let (a, b) = (graph.nodes.get("A").unwrap(), graph.nodes.get("B").unwrap());

        assert!(a.y < b.y);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_node_sizes() {
        let (graph, _) = laid_out("flowchart LR\nA[Hello World]");
        let node = graph.nodes.get("A").unwrap();

        assert_eq!(node.width, "Hello World".len() + 2);
        assert_eq!(node.height, NODE_HEIGHT);
    }

    #[test]
    fn test_cycle_produces_warning() {
        let (_, warnings) = laid_out("flowchart LR\nA --> B\nB --> C\nC --> A");

        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].contains("Cycle"));
    }

    #[test]
    fn test_acyclic_no_warning() {
        let (_, warnings) = laid_out("flowchart LR\nA --> B\nB --> C\nA --> C");

        assert!(warnings.is_empty());
    }
}