use std::collections::HashMap;
use unicode_width::UnicodeWidthStr;
use crate::layout::subgraph::{SG_BORDER_PAD, parallel_label_extra};
use crate::types::{BarOrientation, Direction, Graph, NodeShape, Subgraph};
/// Extra gap cells added between two adjacent nodes of a layer for every
/// subgraph boundary that separates them: the subgraph border padding plus
/// one more cell.
const SG_GAP_PER_BOUNDARY: usize = SG_BORDER_PAD + 1;
/// Tunable spacing and backend selection for the layout pass.
#[derive(Debug, Clone, Copy)]
pub struct LayoutConfig {
/// Minimum gap, in grid cells, left between consecutive layers along the
/// flow axis; label-bearing edges may widen it.
pub layer_gap: usize,
/// Base gap, in grid cells, between neighbouring nodes inside one layer.
pub node_gap: usize,
/// Which layout backend to use.
// NOTE(review): the functions in this file do not read this field —
// presumably the dispatch happens in the caller; confirm.
pub backend: LayoutBackend,
}
impl Default for LayoutConfig {
fn default() -> Self {
Self {
layer_gap: 6,
node_gap: 2,
backend: LayoutBackend::default(),
}
}
}
/// Selects the algorithm family used to lay out a graph.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum LayoutBackend {
/// The default backend.
// NOTE(review): presumably the layered layout implemented in this file — confirm.
#[default]
Native,
/// Alternative backend; it is only declared here, its implementation is
/// not visible in this file.
Sugiyama,
}
impl LayoutConfig {
pub const fn with_gaps(layer_gap: usize, node_gap: usize) -> Self {
Self {
layer_gap,
node_gap,
backend: LayoutBackend::Native,
}
}
}
/// A grid coordinate stored as `(column, row)`.
pub type GridPos = (usize, usize);
/// Intermediate routing points computed for one edge that spans more than
/// one layer.
#[derive(Debug, Clone)]
pub struct EdgeWaypoints {
/// Index of the edge inside `Graph::edges`.
pub edge_idx: usize,
/// Waypoints in order from the source side to the target side, at most one
/// per intermediate layer.
pub waypoints: Vec<GridPos>,
}
/// Output of [`layout`]: node coordinates plus waypoints for long edges.
#[derive(Debug, Clone, Default)]
pub struct LayoutResult {
/// Grid position assigned to each node, keyed by node id.
pub positions: HashMap<String, GridPos>,
/// Waypoints for edges that skip at least one layer; edges without
/// waypoints have no entry.
pub edge_waypoints: Vec<EdgeWaypoints>,
}
/// Lays out `graph` in four steps: layer assignment, ordering inside each
/// layer, coordinate assignment, and waypoint computation for edges that
/// skip layers.
///
/// Returns an empty [`LayoutResult`] when the graph has no nodes. Only the
/// gap fields of `config` are forwarded to the coordinate step.
pub fn layout(graph: &Graph, config: &LayoutConfig) -> LayoutResult {
    if graph.nodes.is_empty() {
        return LayoutResult::default();
    }

    let layers = assign_layers(graph);

    // The ordering step works on owned `(from, to)` id pairs.
    let mut edges: Vec<(String, String)> = Vec::with_capacity(graph.edges.len());
    for edge in &graph.edges {
        edges.push((edge.from.clone(), edge.to.clone()));
    }

    let ordered = order_within_layers(graph, &layers, &edges);
    let positions = compute_positions(graph, &ordered, config);
    let edge_waypoints = compute_edge_waypoints(graph, &layers, &positions);

    LayoutResult {
        positions,
        edge_waypoints,
    }
}
/// Computes routing waypoints for every forward edge that spans two or more
/// layers.
///
/// For each intermediate layer the waypoint sits on that layer's anchor along
/// the flow axis. On the perpendicular axis it is linearly interpolated
/// between the source and target positions, then pushed off any cell range
/// occupied by a node of that layer. Edges between adjacent layers, edges
/// pointing backwards, and edges whose endpoints have no layer or position
/// produce no entry.
fn compute_edge_waypoints(
    graph: &Graph,
    layers: &HashMap<String, usize>,
    positions: &HashMap<String, GridPos>,
) -> Vec<EdgeWaypoints> {
    let anchors = layer_axis_anchors(graph, layers, positions);
    let occupied_by_layer = layer_perpendicular_ranges(graph, layers, positions);

    let mut routed: Vec<EdgeWaypoints> = Vec::new();
    for (edge_idx, edge) in graph.edges.iter().enumerate() {
        let (Some(&from_layer), Some(&to_layer)) = (layers.get(&edge.from), layers.get(&edge.to))
        else {
            continue;
        };
        // Only edges that skip at least one layer need intermediate points.
        if to_layer <= from_layer + 1 {
            continue;
        }
        let (Some(&src), Some(&tgt)) = (positions.get(&edge.from), positions.get(&edge.to)) else {
            continue;
        };

        let span = to_layer - from_layer;
        let waypoints: Vec<GridPos> = (from_layer + 1..to_layer)
            .filter_map(|mid_layer| {
                // Layers without any positioned node have no anchor; skip them.
                let anchor = *anchors.get(&mid_layer)?;
                let blocked: &[(usize, usize)] = match occupied_by_layer.get(&mid_layer) {
                    Some(ranges) => ranges.as_slice(),
                    None => &[],
                };
                // Fraction of the way from source to target at this layer.
                let frac = (mid_layer - from_layer) as f64 / span as f64;
                let point = match graph.direction {
                    Direction::LeftToRight | Direction::RightToLeft => {
                        let ideal_row = interpolate(src.1, tgt.1, frac);
                        (anchor, nearest_clear(ideal_row, blocked))
                    }
                    Direction::TopToBottom | Direction::BottomToTop => {
                        let ideal_col = interpolate(src.0, tgt.0, frac);
                        (nearest_clear(ideal_col, blocked), anchor)
                    }
                };
                Some(point)
            })
            .collect();

        if !waypoints.is_empty() {
            routed.push(EdgeWaypoints {
                edge_idx,
                waypoints,
            });
        }
    }
    routed
}
/// Returns, for each layer, the flow-axis coordinate waypoints should use
/// when passing through that layer.
///
/// The anchor is the flow-axis position of a node in the layer plus half of
/// that node's size along the flow axis (width for horizontal flows, height
/// for vertical flows). The node used is the first one of the layer in
/// `graph.nodes` declaration order, which keeps the result deterministic;
/// walking the `layers` hash map instead would pick an arbitrary node, and
/// nodes of different sizes give different anchors.
///
/// Nodes without a layer or a position are ignored.
// NOTE(review): assumes every positioned id is declared in `graph.nodes`,
// which holds for positions produced by `compute_positions` — confirm if
// another producer is added.
fn layer_axis_anchors(
    graph: &Graph,
    layers: &HashMap<String, usize>,
    positions: &HashMap<String, GridPos>,
) -> HashMap<usize, usize> {
    let mut out: HashMap<usize, usize> = HashMap::new();
    for node in &graph.nodes {
        let id = node.id.as_str();
        let (Some(&layer), Some(&pos)) = (layers.get(id), positions.get(id)) else {
            continue;
        };
        let (anchor, half_size) = match graph.direction {
            Direction::LeftToRight | Direction::RightToLeft => {
                (pos.0, node_box_width(graph, id) / 2)
            }
            Direction::TopToBottom | Direction::BottomToTop => {
                (pos.1, node_box_height(graph, id) / 2)
            }
        };
        // First declared node of the layer wins.
        out.entry(layer).or_insert(anchor + half_size);
    }
    out
}
/// Collects, per layer, the inclusive cell ranges occupied by nodes along the
/// axis perpendicular to the flow (rows for horizontal flows, columns for
/// vertical flows).
///
/// Each layer's ranges are returned sorted in ascending order. The map is
/// filled by walking a hash map, whose iteration order is arbitrary, and
/// `nearest_clear` can return different cells for different range orders
/// (for example with adjacent ranges), so sorting keeps waypoint placement
/// reproducible between runs.
///
/// Nodes without a position and nodes whose size is zero contribute nothing.
fn layer_perpendicular_ranges(
    graph: &Graph,
    layers: &HashMap<String, usize>,
    positions: &HashMap<String, GridPos>,
) -> HashMap<usize, Vec<(usize, usize)>> {
    let mut out: HashMap<usize, Vec<(usize, usize)>> = HashMap::new();
    for (id, &layer) in layers {
        let Some(&pos) = positions.get(id) else {
            continue;
        };
        let (start, size) = match graph.direction {
            Direction::LeftToRight | Direction::RightToLeft => (pos.1, node_box_height(graph, id)),
            Direction::TopToBottom | Direction::BottomToTop => (pos.0, node_box_width(graph, id)),
        };
        if size == 0 {
            continue;
        }
        out.entry(layer)
            .or_default()
            .push((start, start + size - 1));
    }
    // Canonical order: independent of hash-map iteration order.
    for ranges in out.values_mut() {
        ranges.sort_unstable();
    }
    out
}
/// Moves `target` off any inclusive range in `occupied`.
///
/// When the current cell lies inside a range it jumps to the cell just before
/// the range start or just after the range end, whichever is strictly closer;
/// ties, and ranges starting at 0, go to the cell after the end. Ranges are
/// re-scanned until a pass moves nothing or `occupied.len() + 2` passes have
/// run, so the returned cell may still be occupied if the pass budget runs out.
fn nearest_clear(target: usize, occupied: &[(usize, usize)]) -> usize {
    let pass_limit = occupied.len() + 2;
    let mut cell = target;
    for _ in 0..pass_limit {
        let mut displaced = false;
        for &(lo, hi) in occupied {
            if !(lo..=hi).contains(&cell) {
                continue;
            }
            let below = hi + 1;
            cell = match lo.checked_sub(1) {
                Some(above) if cell - above < below - cell => above,
                _ => below,
            };
            displaced = true;
        }
        if !displaced {
            break;
        }
    }
    cell
}
/// Linearly interpolates between `a` and `b` and rounds to the nearest cell.
///
/// `frac` is clamped to `[0.0, 1.0]`, so 0 yields `a` and 1 yields `b`.
fn interpolate(a: usize, b: usize, frac: f64) -> usize {
    let t = frac.clamp(0.0, 1.0);
    let start = a as f64;
    let end = b as f64;
    let blended = start + (end - start) * t;
    blended.round() as usize
}
/// True when exactly one of the two directions is horizontal, i.e. the child
/// flows at a right angle to the parent.
fn is_orthogonal(parent: Direction, child: Direction) -> bool {
    parent.is_horizontal() ^ child.is_horizontal()
}
/// Appends to `out` the node-id list of every subgraph in `subs` (and,
/// recursively, of their nested subgraphs) whose own direction is orthogonal
/// to `parent_direction`.
///
/// Nested subgraphs are resolved by id against `all_subs`; ids that match
/// nothing are skipped. The same `parent_direction` is used at every nesting
/// level.
// NOTE(review): there is no cycle guard — a subgraph that (transitively)
// lists itself in `subgraph_ids` would recurse without bound; confirm the
// parser rules this out.
fn collect_orthogonal_sets<'a>(
    subs: &'a [Subgraph],
    all_subs: &'a [Subgraph],
    parent_direction: Direction,
    out: &mut Vec<Vec<String>>,
) {
    for sg in subs {
        if let Some(sg_dir) = sg.direction {
            if is_orthogonal(parent_direction, sg_dir) {
                out.push(sg.node_ids.clone());
            }
        }

        // Resolve the nested subgraph ids to owned copies so they can be
        // passed down as a slice.
        let mut children: Vec<Subgraph> = Vec::new();
        for child_id in sg.subgraph_ids.iter() {
            if let Some(child) = all_subs.iter().find(|candidate| &candidate.id == child_id) {
                children.push(child.clone());
            }
        }
        collect_orthogonal_sets(&children, all_subs, parent_direction, out);
    }
}
fn orthogonal_node_sets(graph: &Graph) -> Vec<Vec<String>> {
let mut result = Vec::new();
collect_orthogonal_sets(
&graph.subgraphs,
&graph.subgraphs,
graph.direction,
&mut result,
);
result
}
/// Assigns every node a layer index along the flow axis.
///
/// Layers are found by longest-path relaxation: each edge forces its target
/// at least one layer after its source. The relaxation is capped at
/// `nodes.len() + 1` passes so cyclic graphs terminate.
///
/// Nodes belonging to a subgraph whose direction is orthogonal to the graph's
/// are then collapsed onto the smallest layer found among that subgraph's
/// nodes, and the relaxation is re-run for all edges whose target is not in
/// such a subgraph.
///
/// Every declared node gets an entry; an edge target that is not a declared
/// node also gets one.
fn assign_layers(graph: &Graph) -> HashMap<String, usize> {
    let mut layer: HashMap<String, usize> = graph
        .nodes
        .iter()
        .map(|node| (node.id.clone(), 0))
        .collect();
    let max_iter = graph.nodes.len() + 1;

    relax_layers(graph, &mut layer, max_iter, None);

    let ortho_sets = orthogonal_node_sets(graph);
    if ortho_sets.is_empty() {
        return layer;
    }

    let all_ortho: std::collections::HashSet<&str> = ortho_sets
        .iter()
        .flat_map(|s| s.iter().map(String::as_str))
        .collect();

    // Collapse each orthogonal subgraph onto the smallest layer among its nodes.
    for set in &ortho_sets {
        let present: Vec<&str> = set
            .iter()
            .map(String::as_str)
            .filter(|id| layer.contains_key(*id))
            .collect();
        if present.is_empty() {
            continue;
        }
        let min_layer = present.iter().map(|id| layer[*id]).min().unwrap_or(0);
        for id in &present {
            layer.insert((*id).to_owned(), min_layer);
        }
    }

    // Push everything downstream of the collapsed nodes back into place,
    // leaving the collapsed nodes themselves where they are.
    relax_layers(graph, &mut layer, max_iter, Some(&all_ortho));
    layer
}

/// Repeatedly raises each edge target to at least `source layer + 1` until a
/// pass changes nothing or `max_iter` passes have run. Edges whose target is
/// in `frozen` are skipped.
fn relax_layers(
    graph: &Graph,
    layer: &mut HashMap<String, usize>,
    max_iter: usize,
    frozen: Option<&std::collections::HashSet<&str>>,
) {
    for _ in 0..max_iter {
        let mut changed = false;
        for edge in &graph.edges {
            if frozen.is_some_and(|set| set.contains(edge.to.as_str())) {
                continue;
            }
            let from_layer = layer.get(edge.from.as_str()).copied().unwrap_or(0);
            let to_layer = layer.entry(edge.to.clone()).or_insert(0);
            if from_layer + 1 > *to_layer {
                *to_layer = from_layer + 1;
                changed = true;
            }
        }
        if !changed {
            break;
        }
    }
}
/// Orders the nodes inside each layer so as to reduce edge crossings.
///
/// Nodes start in declaration order. Up to `MAX_PASSES` rounds then run a
/// forward sweep, a backward sweep (alternating barycenter and median keys
/// between rounds) and an adjacent-swap transpose pass. The ordering with the
/// fewest crossings is kept; the search stops early once no crossings remain
/// or after `NO_IMPROVEMENT_CAP` rounds without improvement.
///
/// Finally, nodes of an orthogonal subgraph that share a layer are rearranged
/// among the slots they already occupy so that they follow a topological
/// order of the edges between them. The topological sort is seeded in the
/// nodes' current layer order, so the result is deterministic; if those edges
/// contain a cycle the slots are left untouched.
///
/// Returns one `Vec` of node ids per layer, indexed by layer number.
fn order_within_layers(
    graph: &Graph,
    layers: &HashMap<String, usize>,
    edges: &[(String, String)],
) -> Vec<Vec<String>> {
    let max_layer = layers.values().copied().max().unwrap_or(0);
    let num_layers = max_layer + 1;
    let mut buckets: Vec<Vec<String>> = vec![Vec::new(); num_layers];
    for node in &graph.nodes {
        if let Some(&l) = layers.get(&node.id) {
            buckets[l].push(node.id.clone());
        }
    }

    let mut successors: HashMap<&str, Vec<&str>> = HashMap::new();
    let mut predecessors: HashMap<&str, Vec<&str>> = HashMap::new();
    for (from, to) in edges {
        successors
            .entry(from.as_str())
            .or_default()
            .push(to.as_str());
        predecessors
            .entry(to.as_str())
            .or_default()
            .push(from.as_str());
    }
    let node_layer: HashMap<&str, usize> = layers.iter().map(|(id, &l)| (id.as_str(), l)).collect();

    const MAX_PASSES: usize = 8;
    const NO_IMPROVEMENT_CAP: usize = 4;
    let mut best = buckets.clone();
    let mut best_crossings = count_crossings(edges, &node_layer, &best);
    let mut no_improvement = 0usize;
    let metrics = [SortMetric::Barycenter, SortMetric::Median];
    for pass in 0..MAX_PASSES {
        let metric = metrics[pass % metrics.len()];
        sort_by_metric(&mut buckets, &predecessors, SweepDirection::Forward, metric);
        sort_by_metric(&mut buckets, &successors, SweepDirection::Backward, metric);
        transpose_pass(&mut buckets, edges, &node_layer);
        let c = count_crossings(edges, &node_layer, &buckets);
        if c < best_crossings {
            best = buckets.clone();
            best_crossings = c;
            no_improvement = 0;
        } else {
            no_improvement += 1;
            if no_improvement >= NO_IMPROVEMENT_CAP {
                break;
            }
        }
        if best_crossings == 0 {
            break;
        }
    }

    let ortho_sets = orthogonal_node_sets(graph);
    if !ortho_sets.is_empty() {
        for layer_nodes in &mut best {
            for set in &ortho_sets {
                // Slots of this layer currently held by members of the set.
                let in_layer: Vec<usize> = layer_nodes
                    .iter()
                    .enumerate()
                    .filter(|(_, id)| set.contains(id))
                    .map(|(i, _)| i)
                    .collect();
                if in_layer.len() <= 1 {
                    continue;
                }
                let internal_ids: Vec<String> =
                    in_layer.iter().map(|&i| layer_nodes[i].clone()).collect();
                let internal_set: std::collections::HashSet<&str> =
                    internal_ids.iter().map(String::as_str).collect();

                // Sub-graph induced by the edges between the members.
                let mut inner_successors: HashMap<&str, Vec<&str>> =
                    internal_set.iter().map(|&n| (n, Vec::new())).collect();
                let mut in_degree: HashMap<&str, usize> =
                    internal_set.iter().map(|&n| (n, 0usize)).collect();
                for edge in &graph.edges {
                    if internal_set.contains(edge.from.as_str())
                        && internal_set.contains(edge.to.as_str())
                    {
                        inner_successors
                            .entry(edge.from.as_str())
                            .or_default()
                            .push(edge.to.as_str());
                        *in_degree.entry(edge.to.as_str()).or_default() += 1;
                    }
                }

                // Seed the queue in current layer order (each id once) rather
                // than in hash-map iteration order, which is arbitrary.
                let mut seeded: std::collections::HashSet<&str> =
                    std::collections::HashSet::new();
                let mut queue: std::collections::VecDeque<&str> = internal_ids
                    .iter()
                    .map(String::as_str)
                    .filter(|&n| in_degree.get(n).copied() == Some(0) && seeded.insert(n))
                    .collect();

                let mut topo: Vec<String> = Vec::new();
                while let Some(node) = queue.pop_front() {
                    topo.push(node.to_owned());
                    let succs: Vec<&str> = inner_successors.get(node).cloned().unwrap_or_default();
                    for succ in succs {
                        let d = in_degree.entry(succ).or_default();
                        *d = d.saturating_sub(1);
                        if *d == 0 {
                            queue.push_back(succ);
                        }
                    }
                }

                // A shorter result means a cycle: keep the existing order.
                if topo.len() == in_layer.len() {
                    for (slot, &pos) in in_layer.iter().enumerate() {
                        layer_nodes[pos] = topo[slot].clone();
                    }
                }
            }
        }
    }
    best
}
/// Direction in which `sort_by_metric` walks the layers.
#[derive(Copy, Clone)]
enum SweepDirection {
/// Layers 1..n in increasing order, each sorted against the layer before it.
Forward,
/// Layers n-2..0 in decreasing order, each sorted against the layer after it.
Backward,
}
/// Key used by `sort_by_metric` to place a node relative to its neighbours in
/// the reference layer.
#[derive(Copy, Clone)]
enum SortMetric {
/// Arithmetic mean of the neighbour positions.
Barycenter,
/// Median of the neighbour positions.
Median,
}
/// Re-sorts each layer by the barycenter or median of its neighbours'
/// positions in the adjacent reference layer.
///
/// A forward sweep visits layers `1..n` and uses layer `l - 1` as reference;
/// a backward sweep visits layers `n-2..=0` in reverse and uses layer `l + 1`.
/// `neighbors` maps a node id to the ids consulted for it. A neighbour that
/// is not in the reference layer counts as sitting at the node's own current
/// index, and a node with no neighbours keeps its current index as key. The
/// sort is stable, so equal keys preserve the existing order.
fn sort_by_metric(
    buckets: &mut [Vec<String>],
    neighbors: &HashMap<&str, Vec<&str>>,
    dir: SweepDirection,
    metric: SortMetric,
) {
    let num_layers = buckets.len();
    if num_layers < 2 {
        return;
    }

    let sweep: Vec<usize> = match dir {
        SweepDirection::Forward => (1..num_layers).collect(),
        SweepDirection::Backward => (0..num_layers - 1).rev().collect(),
    };

    for l in sweep {
        let ref_layer = match dir {
            SweepDirection::Forward => l - 1,
            SweepDirection::Backward => l + 1,
        };
        let ref_positions: HashMap<&str, f64> = buckets[ref_layer]
            .iter()
            .enumerate()
            .map(|(i, id)| (id.as_str(), i as f64))
            .collect();

        let mut keyed: Vec<(String, f64)> = Vec::with_capacity(buckets[l].len());
        for (i, id) in buckets[l].iter().enumerate() {
            let own_index = i as f64;
            let mut samples: Vec<f64> = match neighbors.get(id.as_str()) {
                Some(ns) => ns
                    .iter()
                    .map(|n| ref_positions.get(n).copied().unwrap_or(own_index))
                    .collect(),
                None => Vec::new(),
            };
            let key = if samples.is_empty() {
                own_index
            } else {
                match metric {
                    SortMetric::Barycenter => {
                        let sum: f64 = samples.iter().sum();
                        sum / samples.len() as f64
                    }
                    SortMetric::Median => {
                        samples.sort_by(|a, b| {
                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                        });
                        median_of_sorted(&samples)
                    }
                }
            };
            keyed.push((id.clone(), key));
        }

        keyed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        buckets[l] = keyed.into_iter().map(|(id, _)| id).collect();
    }
}
/// Median of an ascending-sorted, non-empty slice: the middle element for odd
/// lengths, the mean of the two middle elements for even lengths.
///
/// # Panics
///
/// Panics on an empty slice.
fn median_of_sorted(sorted: &[f64]) -> f64 {
    debug_assert!(!sorted.is_empty(), "median of empty slice is undefined");
    let mid = sorted.len() / 2;
    if sorted.len().is_multiple_of(2) {
        let lower = sorted[mid - 1];
        let upper = sorted[mid];
        (lower + upper) / 2.0
    } else {
        sorted[mid]
    }
}
/// Greedy local improvement: tries swapping each pair of adjacent nodes in
/// every layer and keeps a swap only when it strictly lowers the total
/// crossing count.
///
/// Runs at most four rounds over all layers and stops after the first round
/// that improves nothing. Returns `true` if any swap was kept.
fn transpose_pass(
    buckets: &mut [Vec<String>],
    edges: &[(String, String)],
    node_layer: &HashMap<&str, usize>,
) -> bool {
    const MAX_ROUNDS: usize = 4;

    let mut best = count_crossings(edges, node_layer, buckets);
    let mut improved_overall = false;

    for _ in 0..MAX_ROUNDS {
        let mut improved_this_round = false;
        for layer_idx in 0..buckets.len() {
            let pair_count = buckets[layer_idx].len().saturating_sub(1);
            for i in 0..pair_count {
                buckets[layer_idx].swap(i, i + 1);
                let candidate = count_crossings(edges, node_layer, buckets);
                if candidate < best {
                    best = candidate;
                    improved_this_round = true;
                } else {
                    // Not better: undo the swap.
                    buckets[layer_idx].swap(i, i + 1);
                }
            }
        }
        if !improved_this_round {
            break;
        }
        improved_overall = true;
    }
    improved_overall
}
/// Counts pairs of edges that cross under the ordering in `buckets`.
///
/// Two edges are compared only when they connect the same
/// `(source layer, target layer)` pair. They cross when their sources and
/// their targets are ordered in opposite ways; edges sharing a source or a
/// target never count. Edges with an endpoint missing from `node_layer` or
/// from `buckets` are ignored.
fn count_crossings(
    edges: &[(String, String)],
    node_layer: &HashMap<&str, usize>,
    buckets: &[Vec<String>],
) -> usize {
    // Index of every node inside its own layer.
    let rank: HashMap<&str, usize> = buckets
        .iter()
        .flat_map(|layer_nodes| {
            layer_nodes
                .iter()
                .enumerate()
                .map(|(i, id)| (id.as_str(), i))
        })
        .collect();

    // ((source layer, target layer), source rank, target rank) per usable edge.
    let mut resolved: Vec<((usize, usize), usize, usize)> = Vec::with_capacity(edges.len());
    for (from, to) in edges {
        let lookups = (
            node_layer.get(from.as_str()),
            node_layer.get(to.as_str()),
            rank.get(from.as_str()),
            rank.get(to.as_str()),
        );
        if let (Some(&fl), Some(&tl), Some(&fr), Some(&tr)) = lookups {
            resolved.push(((fl, tl), fr, tr));
        }
    }

    let mut total = 0usize;
    for (i, &(span_a, from_a, to_a)) in resolved.iter().enumerate() {
        total += resolved[i + 1..]
            .iter()
            .filter(|&&(span_b, from_b, to_b)| {
                span_a == span_b
                    && from_a != from_b
                    && to_a != to_b
                    && (from_a < from_b) != (to_a < to_b)
            })
            .count();
    }
    total
}
/// Width, in grid cells, of the box drawn for node `id`.
///
/// The base width is the label width plus 4; some shapes add 2 or 4 more
/// cells. Bars have a fixed width (5 when horizontal, 1 when vertical).
/// Returns 6 when the graph has no node with that id.
pub(crate) fn node_box_width(graph: &Graph, id: &str) -> usize {
    let Some(node) = graph.node(id) else {
        return 6;
    };
    let inner = node.label_width() + 4;
    match node.shape {
        NodeShape::Bar(BarOrientation::Horizontal) => 5,
        NodeShape::Bar(BarOrientation::Vertical) => 1,
        NodeShape::Diamond
        | NodeShape::Cylinder
        | NodeShape::Rectangle
        | NodeShape::Rounded
        | NodeShape::Note => inner,
        NodeShape::Circle
        | NodeShape::Stadium
        | NodeShape::Hexagon
        | NodeShape::Asymmetric
        | NodeShape::Subroutine
        | NodeShape::Parallelogram
        | NodeShape::Trapezoid => inner + 2,
        NodeShape::DoubleCircle => inner + 4,
    }
}
/// Height, in grid cells, of the box drawn for node `id`.
///
/// Most shapes are 3 rows tall, cylinders 4 and double circles 5, each plus
/// one row per label line beyond the first. Bars have a fixed height (5 when
/// vertical, 1 when horizontal). Returns 3 when the graph has no node with
/// that id.
pub(crate) fn node_box_height(graph: &Graph, id: &str) -> usize {
    let Some(node) = graph.node(id) else {
        return 3;
    };
    let extra = node.label_line_count().saturating_sub(1);
    let base = match node.shape {
        NodeShape::Bar(BarOrientation::Vertical) => return 5,
        NodeShape::Bar(BarOrientation::Horizontal) => return 1,
        NodeShape::Cylinder => 4,
        NodeShape::DoubleCircle => 5,
        NodeShape::Diamond
        | NodeShape::Rectangle
        | NodeShape::Rounded
        | NodeShape::Circle
        | NodeShape::Stadium
        | NodeShape::Hexagon
        | NodeShape::Asymmetric
        | NodeShape::Parallelogram
        | NodeShape::Trapezoid
        | NodeShape::Subroutine
        | NodeShape::Note => 3,
    };
    base + extra
}
/// Inverts the per-layer ordering into a map from node id to layer index.
/// If an id appears more than once, its last occurrence wins.
fn build_node_layer_map(ordered: &[Vec<String>]) -> HashMap<&str, usize> {
    ordered
        .iter()
        .enumerate()
        .flat_map(|(layer_idx, ids)| ids.iter().map(move |id| (id.as_str(), layer_idx)))
        .collect()
}
/// Computes the gap to leave between `layer_a` and `layer_b` so that the
/// labels of the edges running between them fit.
///
/// Only labelled edges connecting the two layers (in either direction) are
/// considered. The result is the largest of:
/// * `default_gap`,
/// * the widest label plus 2, plus extra room when two or more edges of one
///   parallel group carry labels in this gap (one additional `widest + 2`
///   slot per extra edge),
/// * `2 * label_count + 1`, the room needed to stack the labels.
///
/// Returns `default_gap` unchanged when no labelled edge crosses the gap.
// NOTE(review): an edge endpoint missing from `node_layer` is treated as
// being in layer 0 — confirm that is intended for undeclared nodes.
fn label_gap(
    graph: &Graph,
    node_layer: &HashMap<&str, usize>,
    layer_a: usize,
    layer_b: usize,
    default_gap: usize,
    parallel_groups: &[Vec<usize>],
) -> usize {
    // (edge index, label display width) for every labelled edge in this gap.
    let crossings: Vec<(usize, usize)> = graph
        .edges
        .iter()
        .enumerate()
        .filter(|(_, e)| {
            let fl = node_layer.get(e.from.as_str()).copied().unwrap_or(0);
            let tl = node_layer.get(e.to.as_str()).copied().unwrap_or(0);
            (fl == layer_a && tl == layer_b) || (fl == layer_b && tl == layer_a)
        })
        .filter_map(|(i, e)| e.label.as_deref().map(|l| (i, UnicodeWidthStr::width(l))))
        .collect();
    if crossings.is_empty() {
        return default_gap;
    }

    let max_lbl = crossings.iter().map(|&(_, w)| w).max().unwrap_or(0);
    let needed_for_width = max_lbl + 2;
    let needed_for_stacking = crossings.len() * 2 + 1;

    let parallel_extra = parallel_groups
        .iter()
        .filter_map(|group| {
            let count_in_gap = group
                .iter()
                .filter(|&&edge_idx| crossings.iter().any(|&(i, _)| i == edge_idx))
                .count();
            if count_in_gap < 2 {
                return None;
            }
            Some((count_in_gap - 1) * (max_lbl + 2))
        })
        .max()
        .unwrap_or(0);

    default_gap
        .max(needed_for_width + parallel_extra)
        .max(needed_for_stacking)
}
/// Maps each nested subgraph id to the id of the subgraph that lists it in
/// `subgraph_ids`. If several subgraphs list the same child, the last one in
/// `graph.subgraphs` wins.
fn build_subgraph_parent_map(graph: &Graph) -> HashMap<&str, &str> {
    graph
        .subgraphs
        .iter()
        .flat_map(|parent| {
            parent
                .subgraph_ids
                .iter()
                .map(move |child_id| (child_id.as_str(), parent.id.as_str()))
        })
        .collect()
}
/// Returns the chain of subgraph ids enclosing `node_id`, innermost first and
/// outermost last. The chain is empty when the node is in no subgraph.
///
/// The walk follows `parent_map` upwards and stops at the first subgraph with
/// no parent. It also stops if a parent is already in the chain, so a
/// malformed (cyclic) parent map cannot loop forever.
fn node_subgraph_chain<'a>(
    node_id: &str,
    node_to_sg: &'a HashMap<String, String>,
    parent_map: &'a HashMap<&'a str, &'a str>,
) -> Vec<&'a str> {
    let mut chain = Vec::new();
    let Some(sg_id) = node_to_sg.get(node_id) else {
        return chain;
    };
    let mut cur: &str = sg_id.as_str();
    chain.push(cur);
    while let Some(&parent) = parent_map.get(cur) {
        // Nesting depth is small, so a linear membership check is fine.
        if chain.contains(&parent) {
            break;
        }
        chain.push(parent);
        cur = parent;
    }
    chain
}
/// Counts the subgraph borders separating two nodes, given each node's chain
/// of enclosing subgraphs (innermost first).
///
/// Ancestors shared at the outer end of both chains are not crossed; every
/// remaining subgraph on either side counts as one boundary.
fn subgraph_boundary_count(chain_a: &[&str], chain_b: &[&str]) -> usize {
    // Compare from the outermost subgraph inwards until the chains diverge.
    let shared = chain_a
        .iter()
        .rev()
        .zip(chain_b.iter().rev())
        .take_while(|(a, b)| a == b)
        .count();
    (chain_a.len() - shared) + (chain_b.len() - shared)
}
/// Gap to leave between two neighbouring nodes of a layer: `base_gap` plus
/// `SG_GAP_PER_BOUNDARY` cells for every subgraph boundary lying between them.
fn sibling_gap(
    node_a: &str,
    node_b: &str,
    node_to_sg: &HashMap<String, String>,
    parent_map: &HashMap<&str, &str>,
    base_gap: usize,
) -> usize {
    let enclosing_a = node_subgraph_chain(node_a, node_to_sg, parent_map);
    let enclosing_b = node_subgraph_chain(node_b, node_to_sg, parent_map);
    let crossed = subgraph_boundary_count(&enclosing_a, &enclosing_b);
    base_gap + crossed * SG_GAP_PER_BOUNDARY
}
/// Largest extra width reported by `parallel_label_extra` among the subgraphs
/// that contain at least one node of `layer_nodes`; 0 when there is none.
fn layer_parallel_label_extra_width(
graph: &Graph,
layer_nodes: &[String],
node_to_sg: &HashMap<String, String>,
) -> usize {
// `true` selects the width component of the (width, height) pair.
layer_parallel_label_extra(graph, layer_nodes, node_to_sg, true)
}
/// Largest extra height reported by `parallel_label_extra` among the
/// subgraphs that contain at least one node of `layer_nodes`; 0 when there is
/// none.
fn layer_parallel_label_extra_height(
graph: &Graph,
layer_nodes: &[String],
node_to_sg: &HashMap<String, String>,
) -> usize {
// `false` selects the height component of the (width, height) pair.
layer_parallel_label_extra(graph, layer_nodes, node_to_sg, false)
}
/// Shared implementation of the width/height variants above.
///
/// Looks up the subgraph directly containing each node of the layer (each
/// subgraph is evaluated once), asks `parallel_label_extra` for its
/// `(width, height)` extra, and returns the maximum of the selected
/// component: width when `axis_w` is true, height otherwise. Nodes outside
/// any subgraph and unknown subgraph ids are ignored; the result is 0 when
/// nothing contributes.
fn layer_parallel_label_extra(
    graph: &Graph,
    layer_nodes: &[String],
    node_to_sg: &HashMap<String, String>,
    axis_w: bool,
) -> usize {
    let mut visited: std::collections::HashSet<&str> = std::collections::HashSet::new();
    layer_nodes
        .iter()
        .filter_map(|nid| node_to_sg.get(nid))
        .filter(|sg_id| visited.insert(sg_id.as_str()))
        .filter_map(|sg_id| graph.find_subgraph(sg_id))
        .map(|sg| {
            let (w, h) = parallel_label_extra(graph, sg);
            if axis_w {
                w
            } else {
                h
            }
        })
        .max()
        .unwrap_or(0)
}
/// Turns the per-layer ordering into `(column, row)` grid positions.
///
/// Layers advance along the flow axis (columns for horizontal flows, rows for
/// vertical flows) and the nodes of a layer are stacked along the other axis
/// starting at 0. Each layer is as thick as its largest node plus any extra
/// room required for parallel-edge labels. It is followed by a gap of at
/// least `config.layer_gap`, widened by `label_gap` when edge labels need
/// more. Neighbouring nodes are separated by `config.node_gap` plus extra
/// cells for each subgraph boundary between them. Empty layers take no space.
///
/// For `RightToLeft` and `BottomToTop` the flow-axis coordinate is mirrored
/// as `max - coordinate` once everything is placed.
fn compute_positions(
    graph: &Graph,
    ordered: &[Vec<String>],
    config: &LayoutConfig,
) -> HashMap<String, GridPos> {
    let mut positions: HashMap<String, GridPos> = HashMap::new();
    let node_layer = build_node_layer_map(ordered);
    let parallel_groups = graph.parallel_edge_groups();
    let node_to_sg = graph.node_to_subgraph();
    let sg_parent = build_subgraph_parent_map(graph);
    match graph.direction {
        Direction::LeftToRight | Direction::RightToLeft => {
            let mut col = 0usize;
            for (layer_idx, layer_nodes) in ordered.iter().enumerate() {
                if layer_nodes.is_empty() {
                    continue;
                }
                let base_layer_width = layer_nodes
                    .iter()
                    .map(|id| node_box_width(graph, id))
                    .max()
                    .unwrap_or(6);
                let extra_w = layer_parallel_label_extra_width(graph, layer_nodes, &node_to_sg);
                let layer_width = base_layer_width + extra_w;
                let mut row = 0usize;
                let mut prev: Option<&str> = None;
                for id in layer_nodes {
                    let h = node_box_height(graph, id);
                    if let Some(prev_id) = prev {
                        let gap =
                            sibling_gap(prev_id, id, &node_to_sg, &sg_parent, config.node_gap);
                        // The base gap was already added after the previous
                        // node; add only the subgraph-boundary surplus.
                        row += gap.saturating_sub(config.node_gap);
                    }
                    positions.insert(id.clone(), (col, row));
                    row += h + config.node_gap;
                    prev = Some(id.as_str());
                }
                let gap = if layer_idx + 1 < ordered.len() {
                    label_gap(
                        graph,
                        &node_layer,
                        layer_idx,
                        layer_idx + 1,
                        config.layer_gap,
                        &parallel_groups,
                    )
                } else {
                    config.layer_gap
                };
                col += layer_width + gap;
            }
            if graph.direction == Direction::RightToLeft {
                // Mirror the columns so the first layer ends up on the right.
                let max_col = positions.values().map(|(c, _)| *c).max().unwrap_or(0);
                for (col, _) in positions.values_mut() {
                    *col = max_col - *col;
                }
            }
        }
        Direction::TopToBottom | Direction::BottomToTop => {
            let mut row = 0usize;
            for (layer_idx, layer_nodes) in ordered.iter().enumerate() {
                if layer_nodes.is_empty() {
                    continue;
                }
                let base_layer_height = layer_nodes
                    .iter()
                    .map(|id| node_box_height(graph, id))
                    .max()
                    .unwrap_or(3);
                let extra_h = layer_parallel_label_extra_height(graph, layer_nodes, &node_to_sg);
                let layer_height = base_layer_height + extra_h;
                let mut col = 0usize;
                let mut prev: Option<&str> = None;
                for id in layer_nodes {
                    let w = node_box_width(graph, id);
                    if let Some(prev_id) = prev {
                        let gap =
                            sibling_gap(prev_id, id, &node_to_sg, &sg_parent, config.node_gap);
                        // The base gap was already added after the previous
                        // node; add only the subgraph-boundary surplus.
                        col += gap.saturating_sub(config.node_gap);
                    }
                    positions.insert(id.clone(), (col, row));
                    col += w + config.node_gap;
                    prev = Some(id.as_str());
                }
                let gap = if layer_idx + 1 < ordered.len() {
                    label_gap(
                        graph,
                        &node_layer,
                        layer_idx,
                        layer_idx + 1,
                        config.layer_gap,
                        &parallel_groups,
                    )
                } else {
                    config.layer_gap
                };
                row += layer_height + gap;
            }
            if graph.direction == Direction::BottomToTop {
                // Mirror the rows so the first layer ends up at the bottom.
                let max_row = positions.values().map(|(_, r)| *r).max().unwrap_or(0);
                for (_, row) in positions.values_mut() {
                    *row = max_row - *row;
                }
            }
        }
    }
    positions
}
// Unit tests for the layout pipeline and its pure helpers.
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{Direction, Edge, Graph, Node, NodeShape};
// Fixture: left-to-right chain A -> B -> C of rectangle nodes.
fn simple_lr_graph() -> Graph {
let mut g = Graph::new(Direction::LeftToRight);
g.nodes.push(Node::new("A", "A", NodeShape::Rectangle));
g.nodes.push(Node::new("B", "B", NodeShape::Rectangle));
g.nodes.push(Node::new("C", "C", NodeShape::Rectangle));
g.edges.push(Edge::new("A", "B", None));
g.edges.push(Edge::new("B", "C", None));
g
}
// In a left-to-right chain, each successor sits in a later column.
#[test]
fn lr_nodes_have_increasing_columns() {
let g = simple_lr_graph();
let cfg = LayoutConfig::default();
let pos = layout(&g, &cfg).positions;
assert!(pos["A"].0 < pos["B"].0);
assert!(pos["B"].0 < pos["C"].0);
}
// In a top-to-bottom graph, the successor sits in a later row.
#[test]
fn td_nodes_have_increasing_rows() {
let mut g = Graph::new(Direction::TopToBottom);
g.nodes.push(Node::new("A", "A", NodeShape::Rectangle));
g.nodes.push(Node::new("B", "B", NodeShape::Rectangle));
g.edges.push(Edge::new("A", "B", None));
let cfg = LayoutConfig::default();
let pos = layout(&g, &cfg).positions;
assert!(pos["A"].1 < pos["B"].1);
}
// A two-node cycle must not hang layer assignment; both nodes get placed.
#[test]
fn cyclic_graph_terminates() {
let mut g = Graph::new(Direction::LeftToRight);
g.nodes.push(Node::new("A", "A", NodeShape::Rectangle));
g.nodes.push(Node::new("B", "B", NodeShape::Rectangle));
g.edges.push(Edge::new("A", "B", None));
g.edges.push(Edge::new("B", "A", None));
let cfg = LayoutConfig::default();
let pos = layout(&g, &cfg).positions;
assert_eq!(pos.len(), 2);
}
// A lone node is placed at the grid origin.
#[test]
fn single_node_layout() {
let mut g = Graph::new(Direction::LeftToRight);
g.nodes.push(Node::new("A", "Alone", NodeShape::Rectangle));
let cfg = LayoutConfig::default();
let pos = layout(&g, &cfg).positions;
assert_eq!(pos["A"], (0, 0));
}
// Edges between adjacent layers need no intermediate routing points.
#[test]
fn short_edges_get_no_waypoints() {
let g = simple_lr_graph();
let result = layout(&g, &LayoutConfig::default());
assert!(result.edge_waypoints.is_empty());
}
// The A -> D shortcut over the chain A -> B -> C -> D gets one waypoint per
// skipped layer, and the first waypoint avoids the rows occupied by B.
#[test]
fn long_edge_gets_waypoint_per_intermediate_layer() {
let mut g = Graph::new(Direction::LeftToRight);
g.nodes.push(Node::new("A", "A", NodeShape::Rectangle));
g.nodes.push(Node::new("B", "B", NodeShape::Rectangle));
g.nodes.push(Node::new("C", "C", NodeShape::Rectangle));
g.nodes.push(Node::new("D", "D", NodeShape::Rectangle));
g.edges.push(Edge::new("A", "B", None));
g.edges.push(Edge::new("B", "C", None));
g.edges.push(Edge::new("C", "D", None));
g.edges.push(Edge::new("A", "D", None));
let result = layout(&g, &LayoutConfig::default());
assert_eq!(result.edge_waypoints.len(), 1);
let w = &result.edge_waypoints[0];
assert_eq!(w.edge_idx, 3);
assert_eq!(w.waypoints.len(), 2, "two intermediate layers (B and C)");
let pos_b = result.positions["B"];
let h_b = node_box_height(&g, "B");
let row_inside_b = (pos_b.1)..(pos_b.1 + h_b);
assert!(
!row_inside_b.contains(&w.waypoints[0].1),
"waypoint row {} should not be inside B's row range {row_inside_b:?}",
w.waypoints[0].1
);
}
// An edge pointing from a later layer back to an earlier one is not routed.
#[test]
fn back_edges_get_no_waypoints() {
let mut g = Graph::new(Direction::LeftToRight);
g.nodes.push(Node::new("A", "A", NodeShape::Rectangle));
g.nodes.push(Node::new("B", "B", NodeShape::Rectangle));
g.nodes.push(Node::new("C", "C", NodeShape::Rectangle));
g.edges.push(Edge::new("A", "B", None));
g.edges.push(Edge::new("B", "C", None));
g.edges.push(Edge::new("C", "A", None));
let result = layout(&g, &LayoutConfig::default());
assert!(result.edge_waypoints.is_empty());
}
// A target outside every occupied range is returned unchanged.
#[test]
fn nearest_clear_no_overlap_returns_target() {
assert_eq!(nearest_clear(5, &[(10, 12)]), 5);
assert_eq!(nearest_clear(20, &[(10, 12)]), 20);
assert_eq!(nearest_clear(5, &[]), 5);
}
// A target inside a range moves to the closer free neighbour; the exact
// middle (11) is a tie and goes past the end of the range.
#[test]
fn nearest_clear_snaps_off_overlap_to_closer_boundary() {
assert_eq!(nearest_clear(11, &[(10, 12)]), 13);
assert_eq!(nearest_clear(12, &[(10, 12)]), 13);
assert_eq!(nearest_clear(10, &[(10, 12)]), 9);
}
// A range starting at 0 has no free cell before it, so the result is after it.
#[test]
fn nearest_clear_top_edge_pushes_down() {
assert_eq!(nearest_clear(1, &[(0, 4)]), 5);
}
// The single free cell between two ranges is found.
#[test]
fn nearest_clear_handles_consecutive_ranges() {
assert_eq!(nearest_clear(1, &[(0, 2), (4, 6)]), 3);
}
// Odd-length input: the middle element is the median.
#[test]
fn median_of_sorted_picks_middle() {
assert_eq!(median_of_sorted(&[1.0, 2.0, 3.0]), 2.0);
assert_eq!(median_of_sorted(&[5.0]), 5.0);
}
// Even-length input: the two middle elements are averaged.
#[test]
fn median_of_sorted_averages_two_middle_for_even_length() {
assert_eq!(median_of_sorted(&[1.0, 2.0, 3.0, 4.0]), 2.5);
assert_eq!(median_of_sorted(&[1.0, 1.0, 5.0, 5.0]), 3.0);
}
// One outlier (100) drags the mean above 25 while the median stays at 1.5.
#[test]
fn median_resists_outliers_better_than_barycenter() {
let xs = [0.0, 1.0, 2.0, 100.0]; let median = median_of_sorted(&xs);
let barycenter: f64 = xs.iter().sum::<f64>() / xs.len() as f64;
assert!((median - 1.5).abs() < 0.01); assert!(barycenter > 25.0); }
// Layer 1 starts as [D, C], so A->C and B->D cross; one swap fixes it.
#[test]
fn transpose_swaps_when_it_reduces_crossings() {
let mut buckets = vec![
vec!["A".to_string(), "B".to_string()],
vec!["D".to_string(), "C".to_string()],
];
let edges = vec![
("A".to_string(), "C".to_string()),
("B".to_string(), "D".to_string()),
];
let mut node_layer: HashMap<&str, usize> = HashMap::new();
node_layer.insert("A", 0);
node_layer.insert("B", 0);
node_layer.insert("C", 1);
node_layer.insert("D", 1);
let before = count_crossings(&edges, &node_layer, &buckets);
assert_eq!(before, 1, "scenario should start with 1 crossing");
let improved = transpose_pass(&mut buckets, &edges, &node_layer);
let after = count_crossings(&edges, &node_layer, &buckets);
assert!(improved, "transpose should report improvement");
assert_eq!(after, 0, "crossing should be eliminated by the swap");
}
// With no crossings to remove, the ordering is left as it was.
#[test]
fn transpose_leaves_already_optimal_orderings_alone() {
let mut buckets = vec![
vec!["A".to_string(), "B".to_string()],
vec!["C".to_string(), "D".to_string()],
];
let edges = vec![
("A".to_string(), "C".to_string()),
("B".to_string(), "D".to_string()),
];
let mut node_layer: HashMap<&str, usize> = HashMap::new();
node_layer.insert("A", 0);
node_layer.insert("B", 0);
node_layer.insert("C", 1);
node_layer.insert("D", 1);
let improved = transpose_pass(&mut buckets, &edges, &node_layer);
assert!(!improved, "no swap should be reported when already optimal");
assert_eq!(buckets[1], vec!["C".to_string(), "D".to_string()]);
}
}