use std::collections::HashMap;
use unicode_width::UnicodeWidthStr;
use crate::layout::subgraph::SG_BORDER_PAD;
use crate::types::{Direction, Graph, NodeShape, Subgraph};
/// Extra spacing added between two neighbouring nodes of a layer for every
/// subgraph boundary that separates them (consumed by `sibling_gap`).
///
/// NOTE(review): presumably `SG_BORDER_PAD` cells of padding plus one cell for
/// the border itself; confirm against `crate::layout::subgraph`.
const SG_GAP_PER_BOUNDARY: usize = SG_BORDER_PAD + 1;
/// Spacing parameters consumed by `layout`.
///
/// Both values are expressed in the same grid units as the returned positions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LayoutConfig {
    /// Minimum gap between two consecutive layers along the flow axis.
    /// The gap actually used may be larger when edge labels need more room.
    pub layer_gap: usize,
    /// Base gap between neighbouring nodes inside one layer.
    pub node_gap: usize,
}

impl Default for LayoutConfig {
    /// Returns the stock spacing: `layer_gap = 6`, `node_gap = 2`.
    fn default() -> Self {
        Self {
            layer_gap: 6,
            node_gap: 2,
        }
    }
}
/// Grid coordinate of a node, stored as `(column, row)`, which is the order in
/// which `compute_positions` inserts it.
pub type GridPos = (usize, usize);
/// Computes a grid position for every node of `graph`.
///
/// The work is done in three stages: nodes are assigned to layers, ordered
/// inside each layer, and finally turned into `(column, row)` coordinates
/// using the spacing from `config`.
///
/// Returns a map from node id to position; a graph without nodes yields an
/// empty map.
pub fn layout(graph: &Graph, config: &LayoutConfig) -> HashMap<String, GridPos> {
    if graph.nodes.is_empty() {
        return HashMap::new();
    }

    let layer_of = assign_layers(graph);
    let layered_order = order_within_layers(graph, &layer_of);

    compute_positions(graph, &layered_order, config)
}
/// Returns `true` when exactly one of the two directions is horizontal, i.e.
/// the child flows along the other axis than its parent.
fn is_orthogonal(parent: Direction, child: Direction) -> bool {
    let parent_is_horizontal = parent.is_horizontal();
    let child_is_horizontal = child.is_horizontal();
    parent_is_horizontal != child_is_horizontal
}
/// Appends to `out` the node-id list of every subgraph reachable from `subs`
/// whose own direction is orthogonal to `parent_direction`.
///
/// Children are resolved by id in `all_subs` and visited recursively. The
/// same `parent_direction` is passed down unchanged, so nested subgraphs are
/// compared against it rather than against their enclosing subgraph.
/// Subgraphs without an explicit direction contribute nothing themselves but
/// are still descended into.
fn collect_orthogonal_sets<'a>(
    subs: &'a [Subgraph],
    all_subs: &'a [Subgraph],
    parent_direction: Direction,
    out: &mut Vec<Vec<String>>,
) {
    for sg in subs {
        if let Some(sg_dir) = sg.direction {
            if is_orthogonal(parent_direction, sg_dir) {
                out.push(sg.node_ids.clone());
            }
        }

        // Child ids that cannot be found in `all_subs` are silently skipped.
        let mut nested: Vec<Subgraph> = Vec::new();
        for child_id in &sg.subgraph_ids {
            if let Some(child) = all_subs.iter().find(|candidate| candidate.id == *child_id) {
                nested.push(child.clone());
            }
        }

        collect_orthogonal_sets(&nested, all_subs, parent_direction, out);
    }
}
fn orthogonal_node_sets(graph: &Graph) -> Vec<Vec<String>> {
let mut result = Vec::new();
collect_orthogonal_sets(
&graph.subgraphs,
&graph.subgraphs,
graph.direction,
&mut result,
);
result
}
/// Assigns every node a layer index using longest-path relaxation.
///
/// 1. Every node starts in layer 0.
/// 2. Edges are relaxed repeatedly so that each target sits at least one
///    layer after its source. The number of passes is capped at
///    `nodes.len() + 1`, which is what makes cyclic graphs terminate.
/// 3. If the graph has subgraphs flowing orthogonally to it, all members of
///    each such subgraph are pulled onto the smallest layer any of them
///    holds, and the relaxation is re-run while ignoring edges that point
///    *into* an orthogonal member, so the collapse is not undone.
///
/// Edge targets that are not declared nodes still receive an entry in the
/// returned map.
fn assign_layers(graph: &Graph) -> HashMap<String, usize> {
    let mut layer: HashMap<String, usize> = graph
        .nodes
        .iter()
        .map(|node| (node.id.clone(), 0))
        .collect();
    let max_passes = graph.nodes.len() + 1;

    relax_layers(graph, &mut layer, max_passes, |_| false);

    let ortho_sets = orthogonal_node_sets(graph);
    if ortho_sets.is_empty() {
        return layer;
    }

    // Collapse each orthogonal subgraph onto the lowest layer of its members.
    for set in &ortho_sets {
        let lowest = set
            .iter()
            .filter_map(|id| layer.get(id.as_str()).copied())
            .min();
        let Some(lowest) = lowest else {
            continue;
        };
        for id in set {
            if let Some(slot) = layer.get_mut(id.as_str()) {
                *slot = lowest;
            }
        }
    }

    // Push everything downstream of the collapsed sets forward again, while
    // leaving the collapsed members where they are.
    let all_ortho: std::collections::HashSet<&str> = ortho_sets
        .iter()
        .flat_map(|set| set.iter().map(String::as_str))
        .collect();
    relax_layers(graph, &mut layer, max_passes, |target| {
        all_ortho.contains(target)
    });

    layer
}

/// Runs at most `max_passes` relaxation passes over the edges, raising each
/// target to `layer(source) + 1` when it is currently lower. Edges whose
/// target satisfies `skip_target` are ignored. Stops early after a pass that
/// changed nothing.
fn relax_layers(
    graph: &Graph,
    layer: &mut HashMap<String, usize>,
    max_passes: usize,
    skip_target: impl Fn(&str) -> bool,
) {
    for _ in 0..max_passes {
        let mut changed = false;
        for edge in &graph.edges {
            if skip_target(edge.to.as_str()) {
                continue;
            }
            let candidate = layer.get(edge.from.as_str()).copied().unwrap_or(0) + 1;
            match layer.get_mut(edge.to.as_str()) {
                Some(current) => {
                    if candidate > *current {
                        *current = candidate;
                        changed = true;
                    }
                }
                None => {
                    // Undeclared target: only now is an owned key needed.
                    layer.insert(edge.to.clone(), candidate);
                    changed = true;
                }
            }
        }
        if !changed {
            break;
        }
    }
}
/// Groups nodes by layer and orders each layer to reduce edge crossings.
///
/// Nodes are bucketed by their entry in `layers` (in `graph.nodes` order),
/// then up to `MAX_PASSES` forward+backward barycenter sweeps are run. The
/// ordering with the fewest crossings seen so far is kept; the loop stops
/// once that count is zero or after `NO_IMPROVEMENT_CAP` consecutive sweeps
/// without improvement.
///
/// Finally, for every orthogonal subgraph with two or more members in the
/// same layer, those members are rearranged among the slots they already
/// occupy so that they follow a topological order of the edges between them.
/// If those edges form a cycle the members are left as they are.
///
/// # Panics
///
/// Panics if a node of `graph` has no entry in `layers`.
fn order_within_layers(graph: &Graph, layers: &HashMap<String, usize>) -> Vec<Vec<String>> {
    const MAX_PASSES: usize = 8;
    const NO_IMPROVEMENT_CAP: usize = 4;

    let max_layer = layers.values().copied().max().unwrap_or(0);
    let mut buckets: Vec<Vec<String>> = vec![Vec::new(); max_layer + 1];
    for node in &graph.nodes {
        buckets[layers[&node.id]].push(node.id.clone());
    }

    let mut successors: HashMap<&str, Vec<&str>> = HashMap::new();
    let mut predecessors: HashMap<&str, Vec<&str>> = HashMap::new();
    for edge in &graph.edges {
        successors
            .entry(edge.from.as_str())
            .or_default()
            .push(edge.to.as_str());
        predecessors
            .entry(edge.to.as_str())
            .or_default()
            .push(edge.from.as_str());
    }
    let node_layer: HashMap<&str, usize> =
        layers.iter().map(|(id, &l)| (id.as_str(), l)).collect();

    let mut best = buckets.clone();
    let mut best_crossings = count_crossings(graph, &node_layer, &best);
    let mut no_improvement = 0usize;
    for _ in 0..MAX_PASSES {
        sort_by_barycenter(&mut buckets, &predecessors, SweepDirection::Forward);
        sort_by_barycenter(&mut buckets, &successors, SweepDirection::Backward);
        let crossings = count_crossings(graph, &node_layer, &buckets);
        if crossings < best_crossings {
            best = buckets.clone();
            best_crossings = crossings;
            no_improvement = 0;
        } else {
            no_improvement += 1;
            if no_improvement >= NO_IMPROVEMENT_CAP {
                break;
            }
        }
        if best_crossings == 0 {
            break;
        }
    }

    let ortho_sets = orthogonal_node_sets(graph);
    for layer_nodes in &mut best {
        for set in &ortho_sets {
            // Slots of this layer currently held by members of the subgraph.
            let slots: Vec<usize> = layer_nodes
                .iter()
                .enumerate()
                .filter(|(_, id)| set.contains(id))
                .map(|(i, _)| i)
                .collect();
            if slots.len() <= 1 {
                continue;
            }
            let members: Vec<String> = slots.iter().map(|&i| layer_nodes[i].clone()).collect();
            if let Some(topo) = topo_order_in_group(graph, &members) {
                for (&pos, id) in slots.iter().zip(topo) {
                    layer_nodes[pos] = id;
                }
            }
        }
    }
    best
}

/// Topologically orders `members` using only the edges of `graph` that start
/// and end inside the group (Kahn's algorithm).
///
/// Ties are broken by the order of `members`, so the result is the same on
/// every run; seeding the queue from a `HashMap` would make it depend on the
/// map's randomized iteration order. Returns `None` when the internal edges
/// contain a cycle or when `members` contains duplicates.
fn topo_order_in_group(graph: &Graph, members: &[String]) -> Option<Vec<String>> {
    let member_set: std::collections::HashSet<&str> =
        members.iter().map(String::as_str).collect();
    if member_set.len() != members.len() {
        // Duplicate ids can never yield one output entry per slot.
        return None;
    }

    let mut successors: HashMap<&str, Vec<&str>> = HashMap::new();
    let mut in_degree: HashMap<&str, usize> =
        members.iter().map(|id| (id.as_str(), 0usize)).collect();
    for edge in &graph.edges {
        let (from, to) = (edge.from.as_str(), edge.to.as_str());
        if member_set.contains(from) && member_set.contains(to) {
            successors.entry(from).or_default().push(to);
            *in_degree.entry(to).or_default() += 1;
        }
    }

    // Seed with the roots in their current layer order.
    let mut queue: std::collections::VecDeque<&str> = members
        .iter()
        .map(String::as_str)
        .filter(|id| in_degree.get(*id).copied() == Some(0))
        .collect();

    let mut order: Vec<String> = Vec::with_capacity(members.len());
    while let Some(node) = queue.pop_front() {
        order.push(node.to_owned());
        if let Some(succs) = successors.get(node) {
            for &succ in succs {
                if let Some(remaining) = in_degree.get_mut(succ) {
                    *remaining = remaining.saturating_sub(1);
                    if *remaining == 0 {
                        queue.push_back(succ);
                    }
                }
            }
        }
    }

    (order.len() == members.len()).then_some(order)
}
/// Which neighbouring layer a barycenter sweep treats as the fixed reference.
#[derive(Copy, Clone)]
enum SweepDirection {
    /// Visit layers `1, 2, …` in increasing order; each layer is ordered
    /// against the layer before it.
    Forward,
    /// Visit layers `n-2, …, 0` in decreasing order; each layer is ordered
    /// against the layer after it.
    Backward,
}

/// Reorders every layer of `buckets` (except the first one visited's
/// reference) by the barycenter of its nodes' neighbours in the reference
/// layer.
///
/// * `neighbors` maps a node id to the ids it should be aligned with
///   (predecessors for a forward sweep, successors for a backward sweep).
/// * A node without neighbours keeps its current index as its key.
/// * A neighbour that is not in the reference layer contributes the node's
///   own current index to the average.
///
/// The sort is stable, so nodes with equal keys keep their relative order.
/// Fewer than two layers is a no-op.
fn sort_by_barycenter(
    buckets: &mut [Vec<String>],
    neighbors: &HashMap<&str, Vec<&str>>,
    dir: SweepDirection,
) {
    let num_layers = buckets.len();
    if num_layers < 2 {
        return;
    }

    for step in 1..num_layers {
        let (l, ref_layer) = match dir {
            SweepDirection::Forward => (step, step - 1),
            SweepDirection::Backward => (num_layers - 1 - step, num_layers - step),
        };

        // Compute the sort keys first so the shared borrows of `buckets` end
        // before the layer is taken apart below.
        let keys: Vec<f64> = {
            let ref_positions: HashMap<&str, f64> = buckets[ref_layer]
                .iter()
                .enumerate()
                .map(|(i, id)| (id.as_str(), i as f64))
                .collect();
            buckets[l]
                .iter()
                .enumerate()
                .map(|(i, id)| {
                    let own = i as f64;
                    match neighbors.get(id.as_str()) {
                        Some(neigh) if !neigh.is_empty() => {
                            let sum: f64 = neigh
                                .iter()
                                .map(|n| ref_positions.get(n).copied().unwrap_or(own))
                                .sum();
                            sum / neigh.len() as f64
                        }
                        _ => own,
                    }
                })
                .collect()
        };

        // Move the ids out instead of cloning them, sort, and put them back.
        let mut keyed: Vec<(String, f64)> = std::mem::take(&mut buckets[l])
            .into_iter()
            .zip(keys)
            .collect();
        keyed.sort_by(|a, b| a.1.total_cmp(&b.1));
        buckets[l] = keyed.into_iter().map(|(id, _)| id).collect();
    }
}
/// Counts pairs of edges that cross under the ordering given by `buckets`.
///
/// Two edges are compared only when they connect the same `(from layer,
/// to layer)` pair. They cross when their sources and their targets are in
/// strictly opposite relative order. Edges with an endpoint missing from
/// `node_layer` or from `buckets` are ignored.
fn count_crossings(
    graph: &Graph,
    node_layer: &HashMap<&str, usize>,
    buckets: &[Vec<String>],
) -> usize {
    // Position of every node inside its own layer.
    let rank: HashMap<&str, usize> = buckets
        .iter()
        .flat_map(|ids| ids.iter().enumerate().map(|(pos, id)| (id.as_str(), pos)))
        .collect();

    // (layer span, source rank, target rank) for every resolvable edge.
    let resolved: Vec<((usize, usize), usize, usize)> = graph
        .edges
        .iter()
        .filter_map(|edge| {
            let from = edge.from.as_str();
            let to = edge.to.as_str();
            let span = (*node_layer.get(from)?, *node_layer.get(to)?);
            Some((span, *rank.get(from)?, *rank.get(to)?))
        })
        .collect();

    resolved
        .iter()
        .enumerate()
        .map(|(i, &(span_a, from_a, to_a))| {
            resolved[i + 1..]
                .iter()
                .filter(|&&(span_b, from_b, to_b)| {
                    span_a == span_b
                        && ((from_a < from_b && to_a > to_b) || (from_a > from_b && to_a < to_b))
                })
                .count()
        })
        .sum()
}
/// Width of the box drawn for node `id`.
///
/// The width is the label width plus 4, plus a shape-dependent extra of 0, 2
/// or 4. An id that is not in the graph gets the fallback width 6.
fn node_box_width(graph: &Graph, id: &str) -> usize {
    let Some(node) = graph.node(id) else {
        return 6;
    };

    let inner = node.label_width() + 4;
    let shape_extra = match node.shape {
        NodeShape::Diamond
        | NodeShape::Cylinder
        | NodeShape::Rectangle
        | NodeShape::Rounded => 0,
        NodeShape::Circle
        | NodeShape::Stadium
        | NodeShape::Hexagon
        | NodeShape::Asymmetric
        | NodeShape::Subroutine
        | NodeShape::Parallelogram
        | NodeShape::Trapezoid => 2,
        NodeShape::DoubleCircle => 4,
    };
    inner + shape_extra
}
/// Height of the box drawn for node `id`.
///
/// Each shape has a base height (3 for most, 4 for cylinders, 5 for double
/// circles) to which one row is added per label line beyond the first. An id
/// that is not in the graph gets the fallback height 3.
fn node_box_height(graph: &Graph, id: &str) -> usize {
    let Some(node) = graph.node(id) else {
        return 3;
    };

    let base_rows = match node.shape {
        NodeShape::Cylinder => 4,
        NodeShape::DoubleCircle => 5,
        NodeShape::Diamond
        | NodeShape::Rectangle
        | NodeShape::Rounded
        | NodeShape::Circle
        | NodeShape::Stadium
        | NodeShape::Hexagon
        | NodeShape::Asymmetric
        | NodeShape::Parallelogram
        | NodeShape::Trapezoid
        | NodeShape::Subroutine => 3,
    };
    base_rows + node.label_line_count().saturating_sub(1)
}
/// Inverts the per-layer node lists into a `node id -> layer index` lookup.
///
/// If an id occurs more than once, the last occurrence wins.
fn build_node_layer_map(ordered: &[Vec<String>]) -> HashMap<&str, usize> {
    ordered
        .iter()
        .enumerate()
        .flat_map(|(layer_idx, ids)| ids.iter().map(move |id| (id.as_str(), layer_idx)))
        .collect()
}
/// Gap to leave between `layer_a` and `layer_b` so that the labels of the
/// edges running between them (in either direction) fit.
///
/// Returns the largest of:
/// * `default_gap`,
/// * the widest label's display width plus 2,
/// * `2 * (number of labelled edges) + 1`, which leaves room to stack them.
///
/// When no edge between the two layers carries a label, `default_gap` is
/// returned unchanged.
///
/// NOTE(review): an endpoint missing from `node_layer` is treated as being in
/// layer 0, unlike `count_crossings`, which skips such edges. Confirm that
/// this is intended.
fn label_gap(
    graph: &Graph,
    node_layer: &HashMap<&str, usize>,
    layer_a: usize,
    layer_b: usize,
    default_gap: usize,
) -> usize {
    // Only the count and the maximum are needed, so fold instead of
    // collecting and sorting the widths.
    let (count, max_lbl) = graph
        .edges
        .iter()
        .filter(|e| {
            let fl = node_layer.get(e.from.as_str()).copied().unwrap_or(0);
            let tl = node_layer.get(e.to.as_str()).copied().unwrap_or(0);
            (fl == layer_a && tl == layer_b) || (fl == layer_b && tl == layer_a)
        })
        .filter_map(|e| e.label.as_deref())
        .map(UnicodeWidthStr::width)
        .fold((0usize, 0usize), |(count, widest), w| {
            (count + 1, widest.max(w))
        });

    if count == 0 {
        return default_gap;
    }

    let needed_for_width = max_lbl + 2;
    let needed_for_stacking = count * 2 + 1;
    default_gap.max(needed_for_width).max(needed_for_stacking)
}
/// Builds a `child subgraph id -> parent subgraph id` lookup from the
/// `subgraph_ids` list of every subgraph in the graph.
///
/// If a child id is listed by several parents, the last one encountered wins.
fn build_subgraph_parent_map(graph: &Graph) -> HashMap<&str, &str> {
    graph
        .subgraphs
        .iter()
        .flat_map(|parent| {
            parent
                .subgraph_ids
                .iter()
                .map(move |child_id| (child_id.as_str(), parent.id.as_str()))
        })
        .collect()
}
/// Returns the subgraphs enclosing `node_id`, innermost first.
///
/// The chain starts with the subgraph `node_to_sg` assigns to the node and
/// then follows `parent_map` upwards. A node that belongs to no subgraph
/// yields an empty chain.
///
/// The walk stops as soon as a parent is already part of the chain, so a
/// malformed, cyclic `parent_map` cannot make it loop forever.
fn node_subgraph_chain<'a>(
    node_id: &str,
    node_to_sg: &'a HashMap<String, String>,
    parent_map: &'a HashMap<&'a str, &'a str>,
) -> Vec<&'a str> {
    let Some(innermost) = node_to_sg.get(node_id) else {
        return Vec::new();
    };

    let mut cur: &str = innermost.as_str();
    let mut chain = vec![cur];
    while let Some(&parent) = parent_map.get(cur) {
        // Nesting depth is small, so a linear scan is cheaper than a set.
        if chain.contains(&parent) {
            break;
        }
        chain.push(parent);
        cur = parent;
    }
    chain
}
/// Number of subgraph borders that lie between two nodes, given their
/// innermost-first subgraph chains.
///
/// The chains are compared from their outer ends; every subgraph beyond the
/// common outer part counts as one boundary.
fn subgraph_boundary_count(chain_a: &[&str], chain_b: &[&str]) -> usize {
    let shared = chain_a
        .iter()
        .rev()
        .zip(chain_b.iter().rev())
        .take_while(|(a, b)| a == b)
        .count();
    (chain_a.len() - shared) + (chain_b.len() - shared)
}
/// Gap to leave between two neighbouring nodes of one layer.
///
/// Starts from `base_gap` and adds `SG_GAP_PER_BOUNDARY` for every subgraph
/// boundary that separates `node_a` from `node_b`.
fn sibling_gap(
    node_a: &str,
    node_b: &str,
    node_to_sg: &HashMap<String, String>,
    parent_map: &HashMap<&str, &str>,
    base_gap: usize,
) -> usize {
    let boundaries = subgraph_boundary_count(
        &node_subgraph_chain(node_a, node_to_sg, parent_map),
        &node_subgraph_chain(node_b, node_to_sg, parent_map),
    );
    let boundary_padding = boundaries * SG_GAP_PER_BOUNDARY;
    base_gap + boundary_padding
}
/// Turns the ordered layers into `(column, row)` coordinates.
///
/// Layers advance along the flow axis (columns for left/right graphs, rows
/// for top/bottom graphs); the nodes of a layer are stacked along the other
/// axis starting at 0. Empty layers take no space.
///
/// * Between layers: the thickest node of the layer plus the gap from
///   `label_gap` (or `config.layer_gap` after the last layer).
/// * Between siblings: the node's own extent plus `config.node_gap`, plus any
///   extra `sibling_gap` adds for subgraph boundaries.
///
/// For `RightToLeft` / `BottomToTop` the flow-axis coordinate is finally
/// replaced by `max - coordinate`.
/// NOTE(review): this flips the stored coordinate only and does not take box
/// sizes into account; confirm that consumers expect that.
fn compute_positions(
    graph: &Graph,
    ordered: &[Vec<String>],
    config: &LayoutConfig,
) -> HashMap<String, GridPos> {
    let node_layer = build_node_layer_map(ordered);
    let node_to_sg = graph.node_to_subgraph();
    let sg_parent = build_subgraph_parent_map(graph);

    let flows_horizontally = match graph.direction {
        Direction::LeftToRight | Direction::RightToLeft => true,
        Direction::TopToBottom | Direction::BottomToTop => false,
    };

    let mut positions: HashMap<String, GridPos> = HashMap::new();
    let mut main_offset = 0usize;

    for (layer_idx, layer_nodes) in ordered.iter().enumerate() {
        if layer_nodes.is_empty() {
            continue;
        }

        // Size of the layer along the flow axis.
        let thickness = layer_nodes
            .iter()
            .map(|id| {
                if flows_horizontally {
                    node_box_width(graph, id)
                } else {
                    node_box_height(graph, id)
                }
            })
            .max()
            .unwrap_or(if flows_horizontally { 6 } else { 3 });

        let mut cross_offset = 0usize;
        let mut prev: Option<&str> = None;
        for id in layer_nodes {
            // Size of this node along the stacking axis.
            let extent = if flows_horizontally {
                node_box_height(graph, id)
            } else {
                node_box_width(graph, id)
            };
            if let Some(prev_id) = prev {
                // `node_gap` was already added after the previous node; only
                // the subgraph-boundary surplus is still missing.
                let widened = sibling_gap(prev_id, id, &node_to_sg, &sg_parent, config.node_gap);
                cross_offset += widened.saturating_sub(config.node_gap);
            }
            let pos = if flows_horizontally {
                (main_offset, cross_offset)
            } else {
                (cross_offset, main_offset)
            };
            positions.insert(id.clone(), pos);
            cross_offset += extent + config.node_gap;
            prev = Some(id.as_str());
        }

        let is_last_layer = layer_idx + 1 >= ordered.len();
        let gap = if is_last_layer {
            config.layer_gap
        } else {
            label_gap(
                graph,
                &node_layer,
                layer_idx,
                layer_idx + 1,
                config.layer_gap,
            )
        };
        main_offset += thickness + gap;
    }

    match graph.direction {
        Direction::RightToLeft => {
            let max_col = positions.values().map(|(c, _)| *c).max().unwrap_or(0);
            for (col, _) in positions.values_mut() {
                *col = max_col - *col;
            }
        }
        Direction::BottomToTop => {
            let max_row = positions.values().map(|(_, r)| *r).max().unwrap_or(0);
            for (_, row) in positions.values_mut() {
                *row = max_row - *row;
            }
        }
        Direction::LeftToRight | Direction::TopToBottom => {}
    }

    positions
}
// Unit tests for the public `layout` entry point.
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{Direction, Edge, Graph, Node, NodeShape};
// Builds the left-to-right chain A -> B -> C out of rectangle nodes.
fn simple_lr_graph() -> Graph {
let mut g = Graph::new(Direction::LeftToRight);
g.nodes.push(Node::new("A", "A", NodeShape::Rectangle));
g.nodes.push(Node::new("B", "B", NodeShape::Rectangle));
g.nodes.push(Node::new("C", "C", NodeShape::Rectangle));
g.edges.push(Edge::new("A", "B", None));
g.edges.push(Edge::new("B", "C", None));
g
}
// In a left-to-right graph each node of a chain sits in a later column than
// its predecessor.
#[test]
fn lr_nodes_have_increasing_columns() {
let g = simple_lr_graph();
let cfg = LayoutConfig::default();
let pos = layout(&g, &cfg);
assert!(pos["A"].0 < pos["B"].0);
assert!(pos["B"].0 < pos["C"].0);
}
// In a top-to-bottom graph the target of an edge sits in a later row than its
// source.
#[test]
fn td_nodes_have_increasing_rows() {
let mut g = Graph::new(Direction::TopToBottom);
g.nodes.push(Node::new("A", "A", NodeShape::Rectangle));
g.nodes.push(Node::new("B", "B", NodeShape::Rectangle));
g.edges.push(Edge::new("A", "B", None));
let cfg = LayoutConfig::default();
let pos = layout(&g, &cfg);
assert!(pos["A"].1 < pos["B"].1);
}
// A two-node cycle must not hang the layer assignment, and both nodes must
// still receive a position.
#[test]
fn cyclic_graph_terminates() {
let mut g = Graph::new(Direction::LeftToRight);
g.nodes.push(Node::new("A", "A", NodeShape::Rectangle));
g.nodes.push(Node::new("B", "B", NodeShape::Rectangle));
g.edges.push(Edge::new("A", "B", None));
g.edges.push(Edge::new("B", "A", None));
let cfg = LayoutConfig::default();
let pos = layout(&g, &cfg);
assert_eq!(pos.len(), 2);
}
// A lone node is placed at the grid origin.
#[test]
fn single_node_layout() {
let mut g = Graph::new(Direction::LeftToRight);
g.nodes.push(Node::new("A", "Alone", NodeShape::Rectangle));
let cfg = LayoutConfig::default();
let pos = layout(&g, &cfg);
assert_eq!(pos["A"], (0, 0));
}
}