1use crate::ast::*;
5use std::collections::{HashMap, HashSet, VecDeque};
6
7#[derive(Debug, Clone)]
10pub struct LayoutResult {
11 pub nodes: Vec<LayoutNode>,
12 pub edges: Vec<LayoutEdge>,
13 pub groups: Vec<LayoutGroup>,
14 pub width: f64,
15 pub height: f64,
16}
17
18#[derive(Debug, Clone)]
19pub struct LayoutNode {
20 pub id: String,
21 pub x: f64,
22 pub y: f64,
23 pub width: f64,
24 pub height: f64,
25 pub node_type: NodeType,
26 pub label: String,
27 pub display_label: String,
28 pub tags: Vec<String>,
29}
30
31#[derive(Debug, Clone)]
32pub struct LayoutEdge {
33 pub from: String,
34 pub to: String,
35 pub points: Vec<(f64, f64)>,
36 pub label: Option<String>,
37 pub tags: Vec<String>,
38 pub arrow_kind: ArrowKind,
39}
40
/// A rectangular box enclosing the members of a group, possibly nested.
#[derive(Clone, Debug)]
pub struct LayoutGroup {
    /// Group title.
    pub label: String,
    /// Left edge of the box.
    pub x: f64,
    /// Top edge of the box (includes the header area).
    pub y: f64,
    /// Box width, including padding on both sides.
    pub width: f64,
    /// Box height, including padding and the header.
    pub height: f64,
    /// Tags attached to the group.
    pub tags: Vec<String>,
    /// Nesting level: 0 for top-level groups, +1 for each enclosing group.
    pub depth: usize,
    /// Boxes of the groups nested directly inside this one.
    pub children: Vec<LayoutGroup>,
}
52
/// Width of every node box.
const NODE_WIDTH: f64 = 170.0;
/// Height of a node box without tags.
const NODE_HEIGHT: f64 = 72.0;
/// Height of a node box that carries tags. The layout also uses it as the
/// per-node slot size along the height axis, so all nodes share one grid.
const NODE_HEIGHT_WITH_TAGS: f64 = 90.0;
/// Gap between a group's border and the boxes it encloses.
const GROUP_PADDING: f64 = 28.0;
/// Extra space reserved above a group's content (presumably for its label).
const GROUP_HEADER: f64 = 28.0;
60
61pub fn compute_layout(doc: &Document) -> LayoutResult {
64 let direction = doc.direction();
65 let spacing = doc.spacing();
66 let layer_gap = spacing.layer_gap();
67 let node_gap = spacing.node_gap();
68
69 let node_ids: Vec<String> = doc.nodes.iter().map(|n| n.id.clone()).collect();
71 let mut outgoing: HashMap<String, Vec<String>> = HashMap::new();
72 let mut incoming: HashMap<String, Vec<String>> = HashMap::new();
73 for id in &node_ids {
74 outgoing.entry(id.clone()).or_default();
75 incoming.entry(id.clone()).or_default();
76 }
77 for conn in &doc.connections {
78 if conn.arrow == ArrowKind::Blocked { continue; }
79 outgoing.entry(conn.from.clone()).or_default().push(conn.to.clone());
80 if conn.arrow == ArrowKind::Bidirectional {
81 outgoing.entry(conn.to.clone()).or_default().push(conn.from.clone());
82 incoming.entry(conn.from.clone()).or_default().push(conn.to.clone());
83 }
84 incoming.entry(conn.to.clone()).or_default().push(conn.from.clone());
85 }
86
87 let mut layers: HashMap<String, usize> = HashMap::new();
89 let sources: Vec<String> = node_ids.iter()
90 .filter(|id| incoming.get(id.as_str()).map(|v| v.is_empty()).unwrap_or(true))
91 .cloned()
92 .collect();
93
94 let seeds = if sources.is_empty() {
96 node_ids.iter().take(1).cloned().collect::<Vec<_>>()
97 } else {
98 sources
99 };
100
101 let mut queue: VecDeque<String> = VecDeque::new();
103 for seed in &seeds {
104 layers.insert(seed.clone(), 0);
105 queue.push_back(seed.clone());
106 }
107
108 while let Some(node) = queue.pop_front() {
109 let current_layer = *layers.get(&node).unwrap_or(&0);
110 if let Some(neighbors) = outgoing.get(&node) {
111 for next in neighbors {
112 let new_layer = current_layer + 1;
113 let existing = layers.get(next).copied().unwrap_or(0);
114 if new_layer > existing || !layers.contains_key(next) {
115 layers.insert(next.clone(), new_layer);
116 queue.push_back(next.clone());
117 }
118 }
119 }
120 }
121
122 for id in &node_ids {
124 layers.entry(id.clone()).or_insert(0);
125 }
126
127 let max_layer = layers.values().copied().max().unwrap_or(0);
129 let mut layer_nodes: Vec<Vec<String>> = vec![Vec::new(); max_layer + 1];
130 for (id, layer) in &layers {
131 layer_nodes[*layer].push(id.clone());
132 }
133
134 for layer in &mut layer_nodes {
137 layer.sort_by_key(|id| node_ids.iter().position(|n| n == id).unwrap_or(0));
138 }
139
140 for _iteration in 0..4 {
142 for l in 1..=max_layer {
143 let prev_layer = &layer_nodes[l - 1];
144 let prev_positions: HashMap<String, f64> = prev_layer.iter().enumerate()
145 .map(|(i, id)| (id.clone(), i as f64))
146 .collect();
147
148 let mut barycenters: Vec<(String, f64)> = layer_nodes[l].iter().map(|id| {
149 let neighbors = incoming.get(id).cloned().unwrap_or_default();
150 let positions: Vec<f64> = neighbors.iter()
151 .filter_map(|n| prev_positions.get(n).copied())
152 .collect();
153 let bc = if positions.is_empty() { f64::MAX } else {
154 positions.iter().sum::<f64>() / positions.len() as f64
155 };
156 (id.clone(), bc)
157 }).collect();
158
159 barycenters.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
160 layer_nodes[l] = barycenters.into_iter().map(|(id, _)| id).collect();
161 }
162 }
163
164 let node_map: HashMap<&str, &Node> = doc.nodes.iter().map(|n| (n.id.as_str(), n)).collect();
166 let mut layout_nodes: Vec<LayoutNode> = Vec::new();
167 let mut node_positions: HashMap<String, (f64, f64, f64, f64)> = HashMap::new();
168
169 let max_nodes_in_layer = layer_nodes.iter().map(|l| l.len()).max().unwrap_or(1);
171
172 for (layer_idx, nodes_in_layer) in layer_nodes.iter().enumerate() {
173 let n = nodes_in_layer.len();
174 for (pos_idx, node_id) in nodes_in_layer.iter().enumerate() {
175 let node = node_map.get(node_id.as_str());
176 let has_tags = node.map(|n| !n.tags.is_empty()).unwrap_or(false);
177 let h = if has_tags { NODE_HEIGHT_WITH_TAGS } else { NODE_HEIGHT };
178
179 let total_extent = n as f64 * h + (n as f64 - 1.0) * node_gap;
181 let max_extent = max_nodes_in_layer as f64 * NODE_HEIGHT_WITH_TAGS + (max_nodes_in_layer as f64 - 1.0) * node_gap;
182 let offset = (max_extent - total_extent) / 2.0;
183
184 let (x, y) = match direction {
185 Direction::Down => {
186 let x = offset + pos_idx as f64 * (NODE_WIDTH + node_gap);
187 let y = layer_idx as f64 * (NODE_HEIGHT_WITH_TAGS + layer_gap);
188 (x, y)
189 }
190 Direction::Right => {
191 let x = layer_idx as f64 * (NODE_WIDTH + layer_gap);
192 let y = offset + pos_idx as f64 * (NODE_HEIGHT_WITH_TAGS + node_gap);
193 (x, y)
194 }
195 };
196
197 let display_label = node
198 .map(|n| n.display_label().to_string())
199 .unwrap_or_else(|| node_id.clone());
200
201 let node_type = node.map(|n| n.node_type).unwrap_or(NodeType::Service);
202 let tags = node.map(|n| n.tags.clone()).unwrap_or_default();
203
204 layout_nodes.push(LayoutNode {
205 id: node_id.clone(),
206 x, y,
207 width: NODE_WIDTH,
208 height: h,
209 node_type,
210 label: node_id.clone(),
211 display_label,
212 tags,
213 });
214
215 node_positions.insert(node_id.clone(), (x, y, NODE_WIDTH, h));
216 }
217 }
218
219 let mut layout_edges: Vec<LayoutEdge> = Vec::new();
221 for conn in &doc.connections {
222 if let (Some(&(fx, fy, fw, fh)), Some(&(tx, ty, tw, th))) =
223 (node_positions.get(&conn.from), node_positions.get(&conn.to))
224 {
225 let from_center = (fx + fw / 2.0, fy + fh / 2.0);
226 let to_center = (tx + tw / 2.0, ty + th / 2.0);
227
228 let from_point = edge_point(fx, fy, fw, fh, to_center.0, to_center.1);
230 let to_point = edge_point(tx, ty, tw, th, from_center.0, from_center.1);
231
232 layout_edges.push(LayoutEdge {
233 from: conn.from.clone(),
234 to: conn.to.clone(),
235 points: vec![from_point, to_point],
236 label: conn.label.clone(),
237 tags: conn.tags.clone(),
238 arrow_kind: conn.arrow,
239 });
240 }
241 }
242
243 let layout_groups = compute_group_bounds(&doc.groups, &node_positions, 0);
245
246 let mut min_x = f64::MAX;
248 let mut min_y = f64::MAX;
249 let mut max_x = f64::MIN;
250 let mut max_y = f64::MIN;
251
252 for node in &layout_nodes {
253 min_x = min_x.min(node.x);
254 min_y = min_y.min(node.y);
255 max_x = max_x.max(node.x + node.width);
256 max_y = max_y.max(node.y + node.height);
257 }
258 for group in &layout_groups {
259 min_x = min_x.min(group.x);
260 min_y = min_y.min(group.y);
261 max_x = max_x.max(group.x + group.width);
262 max_y = max_y.max(group.y + group.height);
263 }
264
265 let pad = 40.0;
267 let offset_x = -min_x + pad;
268 let offset_y = -min_y + pad;
269
270 for node in &mut layout_nodes {
271 node.x += offset_x;
272 node.y += offset_y;
273 }
274 for edge in &mut layout_edges {
275 for point in &mut edge.points {
276 point.0 += offset_x;
277 point.1 += offset_y;
278 }
279 }
280 fn offset_groups(groups: &mut Vec<LayoutGroup>, ox: f64, oy: f64) {
281 for g in groups {
282 g.x += ox;
283 g.y += oy;
284 offset_groups(&mut g.children, ox, oy);
285 }
286 }
287 offset_groups(&mut Vec::new(), offset_x, offset_y);
288
289 let mut layout_groups = layout_groups;
291 fn offset_groups_in_place(groups: &mut [LayoutGroup], ox: f64, oy: f64) {
292 for g in groups.iter_mut() {
293 g.x += ox;
294 g.y += oy;
295 offset_groups_in_place(&mut g.children, ox, oy);
296 }
297 }
298 offset_groups_in_place(&mut layout_groups, offset_x, offset_y);
299
300 let width = (max_x - min_x) + pad * 2.0;
301 let height = (max_y - min_y) + pad * 2.0;
302
303 LayoutResult {
304 nodes: layout_nodes,
305 edges: layout_edges,
306 groups: layout_groups,
307 width: width.max(200.0),
308 height: height.max(200.0),
309 }
310}
311
312fn compute_group_bounds(
315 groups: &[Group],
316 positions: &HashMap<String, (f64, f64, f64, f64)>,
317 depth: usize,
318) -> Vec<LayoutGroup> {
319 let mut result = Vec::new();
320
321 for group in groups {
322 let mut member_ids: HashSet<String> = HashSet::new();
323 let mut child_groups = Vec::new();
324
325 collect_all_member_ids(group, &mut member_ids);
326
327 for member in &group.members {
329 if let GroupMember::Group(sub) = member {
330 let sub_bounds = compute_group_bounds(&[sub.clone()], positions, depth + 1);
331 child_groups.extend(sub_bounds);
332 }
333 }
334
335 let mut min_x = f64::MAX;
337 let mut min_y = f64::MAX;
338 let mut max_x = f64::MIN;
339 let mut max_y = f64::MIN;
340 let mut has_members = false;
341
342 for id in &member_ids {
343 if let Some(&(x, y, w, h)) = positions.get(id) {
344 min_x = min_x.min(x);
345 min_y = min_y.min(y);
346 max_x = max_x.max(x + w);
347 max_y = max_y.max(y + h);
348 has_members = true;
349 }
350 }
351
352 for cg in &child_groups {
354 min_x = min_x.min(cg.x);
355 min_y = min_y.min(cg.y);
356 max_x = max_x.max(cg.x + cg.width);
357 max_y = max_y.max(cg.y + cg.height);
358 has_members = true;
359 }
360
361 if has_members {
362 result.push(LayoutGroup {
363 label: group.label.clone(),
364 x: min_x - GROUP_PADDING,
365 y: min_y - GROUP_PADDING - GROUP_HEADER,
366 width: (max_x - min_x) + GROUP_PADDING * 2.0,
367 height: (max_y - min_y) + GROUP_PADDING * 2.0 + GROUP_HEADER,
368 tags: group.tags.clone(),
369 depth,
370 children: child_groups,
371 });
372 }
373 }
374
375 result
376}
377
378fn collect_all_member_ids(group: &Group, ids: &mut HashSet<String>) {
379 for member in &group.members {
380 match member {
381 GroupMember::NodeRef(id) => { ids.insert(id.clone()); }
382 GroupMember::NodeRefList(list) => { ids.extend(list.iter().cloned()); }
383 GroupMember::Node(n) => { ids.insert(n.id.clone()); }
384 GroupMember::Connection(c) => { ids.insert(c.from.clone()); ids.insert(c.to.clone()); }
385 GroupMember::Group(g) => { collect_all_member_ids(g, ids); }
386 }
387 }
388}
389
/// Returns where the ray from the centre of the rectangle `(rx, ry, rw, rh)`
/// (top-left corner plus size) towards the target `(tx, ty)` leaves the
/// rectangle's border.
///
/// If the target is within `0.001` of the centre on both axes, there is no
/// meaningful direction and the centre itself is returned.
fn edge_point(rx: f64, ry: f64, rw: f64, rh: f64, tx: f64, ty: f64) -> (f64, f64) {
    const EPS: f64 = 0.001;

    let cx = rx + rw / 2.0;
    let cy = ry + rh / 2.0;
    let dx = tx - cx;
    let dy = ty - cy;

    // The early return and the per-axis tests below must share one
    // threshold. Before, `< EPS` here and `> EPS` below left the case
    // |d| == EPS uncovered: both scales became f64::MAX and the point flew
    // off to ~1e305.
    if dx.abs() <= EPS && dy.abs() <= EPS {
        return (cx, cy);
    }

    // Scale factor that brings the direction vector to the vertical or the
    // horizontal border; the smaller one is hit first. At least one is
    // finite thanks to the check above.
    let scale_x = if dx.abs() > EPS { (rw / 2.0) / dx.abs() } else { f64::INFINITY };
    let scale_y = if dy.abs() > EPS { (rh / 2.0) / dy.abs() } else { f64::INFINITY };
    let scale = scale_x.min(scale_y);

    (cx + dx * scale, cy + dy * scale)
}