Skip to main content

arc_lang/
layout.rs

//! Arc layout engine — simplified Sugiyama-style layered layout,
//! optimized for architecture diagrams (typically 5–40 nodes, hierarchical).
3
4use crate::ast::*;
5use std::collections::{HashMap, HashSet, VecDeque};
6
7// ── Public types ─────────────────────────────────────────────────
8
9#[derive(Debug, Clone)]
10pub struct LayoutResult {
11    pub nodes: Vec<LayoutNode>,
12    pub edges: Vec<LayoutEdge>,
13    pub groups: Vec<LayoutGroup>,
14    pub width: f64,
15    pub height: f64,
16}
17
18#[derive(Debug, Clone)]
19pub struct LayoutNode {
20    pub id: String,
21    pub x: f64,
22    pub y: f64,
23    pub width: f64,
24    pub height: f64,
25    pub node_type: NodeType,
26    pub label: String,
27    pub display_label: String,
28    pub tags: Vec<String>,
29}
30
31#[derive(Debug, Clone)]
32pub struct LayoutEdge {
33    pub from: String,
34    pub to: String,
35    pub points: Vec<(f64, f64)>,
36    pub label: Option<String>,
37    pub tags: Vec<String>,
38    pub arrow_kind: ArrowKind,
39}
40
/// The box drawn around a group's members, possibly containing nested groups.
#[derive(Clone, Debug)]
pub struct LayoutGroup {
    /// Group caption.
    pub label: String,
    /// Left edge of the box, padding included.
    pub x: f64,
    /// Top edge of the box, padding and header included.
    pub y: f64,
    /// Box width, padding included.
    pub width: f64,
    /// Box height, padding and header included.
    pub height: f64,
    /// Tags copied from the group declaration.
    pub tags: Vec<String>,
    /// Nesting level: 0 for top-level groups, +1 per level of nesting.
    pub depth: usize,
    /// Boxes of directly nested groups.
    pub children: Vec<LayoutGroup>,
}
52
53// ── Constants ────────────────────────────────────────────────────
54
/// Width of every node box.
const NODE_WIDTH: f64 = 170.0;
/// Height of a node box without tags.
const NODE_HEIGHT: f64 = 72.0;
/// Height of a node box that shows tags (18 units taller than a plain node);
/// also the per-node slot size along the layer axis.
const NODE_HEIGHT_WITH_TAGS: f64 = NODE_HEIGHT + 18.0;
/// Space between a group's border and its outermost member.
const GROUP_PADDING: f64 = 28.0;
/// Extra space reserved above a group's members for its caption.
const GROUP_HEADER: f64 = 28.0;
60
61// ── Layout computation ──────────────────────────────────────────
62
63pub fn compute_layout(doc: &Document) -> LayoutResult {
64    let direction = doc.direction();
65    let spacing = doc.spacing();
66    let layer_gap = spacing.layer_gap();
67    let node_gap = spacing.node_gap();
68
69    // Build adjacency info
70    let node_ids: Vec<String> = doc.nodes.iter().map(|n| n.id.clone()).collect();
71    let mut outgoing: HashMap<String, Vec<String>> = HashMap::new();
72    let mut incoming: HashMap<String, Vec<String>> = HashMap::new();
73    for id in &node_ids {
74        outgoing.entry(id.clone()).or_default();
75        incoming.entry(id.clone()).or_default();
76    }
77    for conn in &doc.connections {
78        if conn.arrow == ArrowKind::Blocked { continue; }
79        outgoing.entry(conn.from.clone()).or_default().push(conn.to.clone());
80        if conn.arrow == ArrowKind::Bidirectional {
81            outgoing.entry(conn.to.clone()).or_default().push(conn.from.clone());
82            incoming.entry(conn.from.clone()).or_default().push(conn.to.clone());
83        }
84        incoming.entry(conn.to.clone()).or_default().push(conn.from.clone());
85    }
86
87    // Step 1: Assign layers via BFS (longest path from sources)
88    let mut layers: HashMap<String, usize> = HashMap::new();
89    let sources: Vec<String> = node_ids.iter()
90        .filter(|id| incoming.get(id.as_str()).map(|v| v.is_empty()).unwrap_or(true))
91        .cloned()
92        .collect();
93
94    // If no sources (cycle), pick the first node
95    let seeds = if sources.is_empty() {
96        node_ids.iter().take(1).cloned().collect::<Vec<_>>()
97    } else {
98        sources
99    };
100
101    // BFS to assign layers
102    let mut queue: VecDeque<String> = VecDeque::new();
103    for seed in &seeds {
104        layers.insert(seed.clone(), 0);
105        queue.push_back(seed.clone());
106    }
107
108    while let Some(node) = queue.pop_front() {
109        let current_layer = *layers.get(&node).unwrap_or(&0);
110        if let Some(neighbors) = outgoing.get(&node) {
111            for next in neighbors {
112                let new_layer = current_layer + 1;
113                let existing = layers.get(next).copied().unwrap_or(0);
114                if new_layer > existing || !layers.contains_key(next) {
115                    layers.insert(next.clone(), new_layer);
116                    queue.push_back(next.clone());
117                }
118            }
119        }
120    }
121
122    // Ensure all nodes have a layer (disconnected nodes go to layer 0)
123    for id in &node_ids {
124        layers.entry(id.clone()).or_insert(0);
125    }
126
127    // Step 2: Group nodes by layer
128    let max_layer = layers.values().copied().max().unwrap_or(0);
129    let mut layer_nodes: Vec<Vec<String>> = vec![Vec::new(); max_layer + 1];
130    for (id, layer) in &layers {
131        layer_nodes[*layer].push(id.clone());
132    }
133
134    // Step 3: Order nodes within layers (barycenter heuristic)
135    // Initialize with document order
136    for layer in &mut layer_nodes {
137        layer.sort_by_key(|id| node_ids.iter().position(|n| n == id).unwrap_or(0));
138    }
139
140    // Barycenter iterations
141    for _iteration in 0..4 {
142        for l in 1..=max_layer {
143            let prev_layer = &layer_nodes[l - 1];
144            let prev_positions: HashMap<String, f64> = prev_layer.iter().enumerate()
145                .map(|(i, id)| (id.clone(), i as f64))
146                .collect();
147
148            let mut barycenters: Vec<(String, f64)> = layer_nodes[l].iter().map(|id| {
149                let neighbors = incoming.get(id).cloned().unwrap_or_default();
150                let positions: Vec<f64> = neighbors.iter()
151                    .filter_map(|n| prev_positions.get(n).copied())
152                    .collect();
153                let bc = if positions.is_empty() { f64::MAX } else {
154                    positions.iter().sum::<f64>() / positions.len() as f64
155                };
156                (id.clone(), bc)
157            }).collect();
158
159            barycenters.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
160            layer_nodes[l] = barycenters.into_iter().map(|(id, _)| id).collect();
161        }
162    }
163
164    // Step 4: Assign coordinates
165    let node_map: HashMap<&str, &Node> = doc.nodes.iter().map(|n| (n.id.as_str(), n)).collect();
166    let mut layout_nodes: Vec<LayoutNode> = Vec::new();
167    let mut node_positions: HashMap<String, (f64, f64, f64, f64)> = HashMap::new();
168
169    // Find max nodes in any layer for centering
170    let max_nodes_in_layer = layer_nodes.iter().map(|l| l.len()).max().unwrap_or(1);
171
172    for (layer_idx, nodes_in_layer) in layer_nodes.iter().enumerate() {
173        let n = nodes_in_layer.len();
174        for (pos_idx, node_id) in nodes_in_layer.iter().enumerate() {
175            let node = node_map.get(node_id.as_str());
176            let has_tags = node.map(|n| !n.tags.is_empty()).unwrap_or(false);
177            let h = if has_tags { NODE_HEIGHT_WITH_TAGS } else { NODE_HEIGHT };
178
179            // Center smaller layers
180            let total_extent = n as f64 * h + (n as f64 - 1.0) * node_gap;
181            let max_extent = max_nodes_in_layer as f64 * NODE_HEIGHT_WITH_TAGS + (max_nodes_in_layer as f64 - 1.0) * node_gap;
182            let offset = (max_extent - total_extent) / 2.0;
183
184            let (x, y) = match direction {
185                Direction::Down => {
186                    let x = offset + pos_idx as f64 * (NODE_WIDTH + node_gap);
187                    let y = layer_idx as f64 * (NODE_HEIGHT_WITH_TAGS + layer_gap);
188                    (x, y)
189                }
190                Direction::Right => {
191                    let x = layer_idx as f64 * (NODE_WIDTH + layer_gap);
192                    let y = offset + pos_idx as f64 * (NODE_HEIGHT_WITH_TAGS + node_gap);
193                    (x, y)
194                }
195            };
196
197            let display_label = node
198                .map(|n| n.display_label().to_string())
199                .unwrap_or_else(|| node_id.clone());
200
201            let node_type = node.map(|n| n.node_type).unwrap_or(NodeType::Service);
202            let tags = node.map(|n| n.tags.clone()).unwrap_or_default();
203
204            layout_nodes.push(LayoutNode {
205                id: node_id.clone(),
206                x, y,
207                width: NODE_WIDTH,
208                height: h,
209                node_type,
210                label: node_id.clone(),
211                display_label,
212                tags,
213            });
214
215            node_positions.insert(node_id.clone(), (x, y, NODE_WIDTH, h));
216        }
217    }
218
219    // Step 5: Route edges
220    let mut layout_edges: Vec<LayoutEdge> = Vec::new();
221    for conn in &doc.connections {
222        if let (Some(&(fx, fy, fw, fh)), Some(&(tx, ty, tw, th))) =
223            (node_positions.get(&conn.from), node_positions.get(&conn.to))
224        {
225            let from_center = (fx + fw / 2.0, fy + fh / 2.0);
226            let to_center = (tx + tw / 2.0, ty + th / 2.0);
227
228            // Find connection points on node boundaries
229            let from_point = edge_point(fx, fy, fw, fh, to_center.0, to_center.1);
230            let to_point = edge_point(tx, ty, tw, th, from_center.0, from_center.1);
231
232            layout_edges.push(LayoutEdge {
233                from: conn.from.clone(),
234                to: conn.to.clone(),
235                points: vec![from_point, to_point],
236                label: conn.label.clone(),
237                tags: conn.tags.clone(),
238                arrow_kind: conn.arrow,
239            });
240        }
241    }
242
243    // Step 6: Compute group bounds
244    let layout_groups = compute_group_bounds(&doc.groups, &node_positions, 0);
245
246    // Compute total canvas size
247    let mut min_x = f64::MAX;
248    let mut min_y = f64::MAX;
249    let mut max_x = f64::MIN;
250    let mut max_y = f64::MIN;
251
252    for node in &layout_nodes {
253        min_x = min_x.min(node.x);
254        min_y = min_y.min(node.y);
255        max_x = max_x.max(node.x + node.width);
256        max_y = max_y.max(node.y + node.height);
257    }
258    for group in &layout_groups {
259        min_x = min_x.min(group.x);
260        min_y = min_y.min(group.y);
261        max_x = max_x.max(group.x + group.width);
262        max_y = max_y.max(group.y + group.height);
263    }
264
265    // Normalize to positive coordinates with padding
266    let pad = 40.0;
267    let offset_x = -min_x + pad;
268    let offset_y = -min_y + pad;
269
270    for node in &mut layout_nodes {
271        node.x += offset_x;
272        node.y += offset_y;
273    }
274    for edge in &mut layout_edges {
275        for point in &mut edge.points {
276            point.0 += offset_x;
277            point.1 += offset_y;
278        }
279    }
280    fn offset_groups(groups: &mut Vec<LayoutGroup>, ox: f64, oy: f64) {
281        for g in groups {
282            g.x += ox;
283            g.y += oy;
284            offset_groups(&mut g.children, ox, oy);
285        }
286    }
287    offset_groups(&mut Vec::new(), offset_x, offset_y);
288
289    // Also update the layout_groups
290    let mut layout_groups = layout_groups;
291    fn offset_groups_in_place(groups: &mut [LayoutGroup], ox: f64, oy: f64) {
292        for g in groups.iter_mut() {
293            g.x += ox;
294            g.y += oy;
295            offset_groups_in_place(&mut g.children, ox, oy);
296        }
297    }
298    offset_groups_in_place(&mut layout_groups, offset_x, offset_y);
299
300    let width = (max_x - min_x) + pad * 2.0;
301    let height = (max_y - min_y) + pad * 2.0;
302
303    LayoutResult {
304        nodes: layout_nodes,
305        edges: layout_edges,
306        groups: layout_groups,
307        width: width.max(200.0),
308        height: height.max(200.0),
309    }
310}
311
312// ── Group bounds computation ────────────────────────────────────
313
314fn compute_group_bounds(
315    groups: &[Group],
316    positions: &HashMap<String, (f64, f64, f64, f64)>,
317    depth: usize,
318) -> Vec<LayoutGroup> {
319    let mut result = Vec::new();
320
321    for group in groups {
322        let mut member_ids: HashSet<String> = HashSet::new();
323        let mut child_groups = Vec::new();
324
325        collect_all_member_ids(group, &mut member_ids);
326
327        // Recurse into sub-groups
328        for member in &group.members {
329            if let GroupMember::Group(sub) = member {
330                let sub_bounds = compute_group_bounds(&[sub.clone()], positions, depth + 1);
331                child_groups.extend(sub_bounds);
332            }
333        }
334
335        // Compute bounding box of all members
336        let mut min_x = f64::MAX;
337        let mut min_y = f64::MAX;
338        let mut max_x = f64::MIN;
339        let mut max_y = f64::MIN;
340        let mut has_members = false;
341
342        for id in &member_ids {
343            if let Some(&(x, y, w, h)) = positions.get(id) {
344                min_x = min_x.min(x);
345                min_y = min_y.min(y);
346                max_x = max_x.max(x + w);
347                max_y = max_y.max(y + h);
348                has_members = true;
349            }
350        }
351
352        // Also include child group bounds
353        for cg in &child_groups {
354            min_x = min_x.min(cg.x);
355            min_y = min_y.min(cg.y);
356            max_x = max_x.max(cg.x + cg.width);
357            max_y = max_y.max(cg.y + cg.height);
358            has_members = true;
359        }
360
361        if has_members {
362            result.push(LayoutGroup {
363                label: group.label.clone(),
364                x: min_x - GROUP_PADDING,
365                y: min_y - GROUP_PADDING - GROUP_HEADER,
366                width: (max_x - min_x) + GROUP_PADDING * 2.0,
367                height: (max_y - min_y) + GROUP_PADDING * 2.0 + GROUP_HEADER,
368                tags: group.tags.clone(),
369                depth,
370                children: child_groups,
371            });
372        }
373    }
374
375    result
376}
377
378fn collect_all_member_ids(group: &Group, ids: &mut HashSet<String>) {
379    for member in &group.members {
380        match member {
381            GroupMember::NodeRef(id) => { ids.insert(id.clone()); }
382            GroupMember::NodeRefList(list) => { ids.extend(list.iter().cloned()); }
383            GroupMember::Node(n) => { ids.insert(n.id.clone()); }
384            GroupMember::Connection(c) => { ids.insert(c.from.clone()); ids.insert(c.to.clone()); }
385            GroupMember::Group(g) => { collect_all_member_ids(g, ids); }
386        }
387    }
388}
389
390// ── Edge point calculation ──────────────────────────────────────
391
/// Return where the ray from the centre of the rectangle `(rx, ry, rw, rh)`
/// toward `(tx, ty)` crosses the rectangle's border.
///
/// When the target is within 0.001 of the centre on both axes there is no
/// usable direction, and the centre itself is returned.
fn edge_point(rx: f64, ry: f64, rw: f64, rh: f64, tx: f64, ty: f64) -> (f64, f64) {
    const EPS: f64 = 0.001;

    let (cx, cy) = (rx + rw / 2.0, ry + rh / 2.0);
    let (dx, dy) = (tx - cx, ty - cy);

    if dx.abs() < EPS && dy.abs() < EPS {
        return (cx, cy);
    }

    // Factor that stretches the direction vector to reach one pair of sides;
    // a (near-)zero component never limits the stretch.
    let reach = |half: f64, delta: f64| {
        if delta.abs() > EPS {
            half / delta.abs()
        } else {
            f64::MAX
        }
    };
    let scale = reach(rw / 2.0, dx).min(reach(rh / 2.0, dy));

    (cx + dx * scale, cy + dy * scale)
}