use crate::types::{DiagramWarning, Direction, Graph, NodeId, NodeShape, RenderOptions};
use std::collections::{HashMap, VecDeque};
/// Smallest width (in character cells) a node box may have, regardless of label length.
const MIN_NODE_WIDTH: usize = 5;
/// Height of a standard node box; cylinder nodes are given a taller height instead.
const NODE_HEIGHT: usize = 3;
/// Horizontal gap used between nodes/layers unless a width constraint shrinks it.
const DEFAULT_HORIZONTAL_GAP: usize = 8;
/// Vertical gap used between nodes/layers.
const DEFAULT_VERTICAL_GAP: usize = 4;
/// Lower bound for the horizontal gap when it is shrunk to honour `max_width`.
const MIN_GAP: usize = 2;
/// Space kept between a subgraph's member nodes and its border.
const SUBGRAPH_PADDING: usize = 2;
/// Lays out `graph` in place using the default [`RenderOptions`].
///
/// Returns any warnings produced while laying out (for example, a detected cycle).
pub fn compute_layout(graph: &mut Graph) -> Vec<DiagramWarning> {
    let defaults = RenderOptions::default();
    compute_layout_with_options(graph, &defaults)
}
/// Lays out `graph` in place: sizes every node, assigns layers, chooses gaps
/// (honouring `options.max_width`), places nodes and computes subgraph boxes.
///
/// Returns the warnings collected along the way (e.g. a detected cycle).
pub fn compute_layout_with_options(
    graph: &mut Graph,
    options: &RenderOptions,
) -> Vec<DiagramWarning> {
    // Cylinder shapes are drawn taller than the standard `NODE_HEIGHT`.
    const CYLINDER_HEIGHT: usize = 5;

    // Width is the label's character count plus two cells, but at least MIN_NODE_WIDTH.
    for node in graph.nodes.values_mut() {
        let label_cells = node.label.chars().count();
        node.width = MIN_NODE_WIDTH.max(label_cells + 2);
        node.height = if node.shape == NodeShape::Cylinder {
            CYLINDER_HEIGHT
        } else {
            NODE_HEIGHT
        };
    }

    let mut warnings = Vec::new();
    let layers = assign_layers(graph, &mut warnings);
    let (h_gap, v_gap) = calculate_gaps(graph, &layers, options.max_width);
    assign_coordinates_with_gaps(graph, &layers, h_gap, v_gap);
    compute_subgraph_bounds(graph);
    warnings
}
/// Chooses the horizontal and vertical gaps (in character cells) used to place nodes.
///
/// With no `max_width` the defaults (`DEFAULT_HORIZONTAL_GAP`, `DEFAULT_VERTICAL_GAP`)
/// are returned. Otherwise the horizontal gap is shrunk, never below `MIN_GAP` and
/// never above the default, so that the drawing fits within `max_width` where possible:
///
/// * Horizontal flows (LR/RL): layers are columns with one gap between each pair of
///   consecutive layers, so the drawing width is the sum of the widest node of each
///   layer plus `max_layer` gaps.
/// * Vertical flows (TB/BT): each layer is a row of side-by-side nodes, so the drawing
///   width is that of the widest row: its node widths plus one gap between neighbours.
///
/// The vertical gap is always `DEFAULT_VERTICAL_GAP`. When even `MIN_GAP` is not
/// enough (e.g. a single node wider than `max_width`) the drawing still overflows.
fn calculate_gaps(
    graph: &Graph,
    layers: &HashMap<NodeId, usize>,
    max_width: Option<usize>,
) -> (usize, usize) {
    let max_width = match max_width {
        Some(w) => w,
        None => return (DEFAULT_HORIZONTAL_GAP, DEFAULT_VERTICAL_GAP),
    };

    // Per-layer summary of the nodes present in `graph.nodes`:
    // (widest node, sum of node widths, node count).
    let mut per_layer: HashMap<usize, (usize, usize, usize)> = HashMap::new();
    // Highest layer index, including layers whose ids have no node: such layers
    // still consume a gap when coordinates are assigned.
    let mut max_layer = 0;
    for (id, &layer) in layers {
        max_layer = max_layer.max(layer);
        if let Some(node) = graph.nodes.get(id) {
            let (widest, total, count) = per_layer.entry(layer).or_insert((0, 0, 0));
            *widest = (*widest).max(node.width);
            *total += node.width;
            *count += 1;
        }
    }

    let h_gap = if graph.direction.is_horizontal() {
        let node_width: usize = per_layer.values().map(|&(widest, _, _)| widest).sum();
        let default_width = node_width + max_layer * DEFAULT_HORIZONTAL_GAP;
        if max_layer == 0 || default_width <= max_width {
            DEFAULT_HORIZONTAL_GAP
        } else {
            (max_width.saturating_sub(node_width) / max_layer).max(MIN_GAP)
        }
    } else {
        let overflows = per_layer.values().any(|&(_, total, count)| {
            total + count.saturating_sub(1) * DEFAULT_HORIZONTAL_GAP > max_width
        });
        if !overflows {
            DEFAULT_HORIZONTAL_GAP
        } else {
            // Largest uniform gap that lets every multi-node row fit; rows holding a
            // single node have no gaps, so shrinking cannot help them.
            per_layer
                .values()
                .filter(|&&(_, _, count)| count > 1)
                .map(|&(_, total, count)| max_width.saturating_sub(total) / (count - 1))
                .min()
                .unwrap_or(DEFAULT_HORIZONTAL_GAP)
                .max(MIN_GAP)
                .min(DEFAULT_HORIZONTAL_GAP)
        }
    };

    (h_gap, DEFAULT_VERTICAL_GAP)
}
/// Computes each subgraph's bounding box from the positions of its member nodes.
///
/// The box surrounds the members with `SUBGRAPH_PADDING` cells on the left, right
/// and bottom, and `SUBGRAPH_PADDING + 1` rows on top (the extra row is presumably
/// reserved for the subgraph title — TODO confirm against the renderer).
///
/// The left/top padding is clamped at 0 when members sit near the origin. The
/// right/bottom edges are computed from the members' far edges, so clamping no
/// longer pushes them outward.
///
/// Subgraphs with no members, or whose members are all missing from
/// `graph.nodes`, are left unchanged.
fn compute_subgraph_bounds(graph: &mut Graph) {
    // Borrow the node map separately so the subgraphs can be mutated alongside it.
    let nodes = &graph.nodes;
    for sg in &mut graph.subgraphs {
        // (min_x, min_y, max_x, max_y) over the members that exist.
        let mut bounds: Option<(usize, usize, usize, usize)> = None;
        for node in sg.nodes.iter().filter_map(|id| nodes.get(id)) {
            let right = node.x + node.width;
            let bottom = node.y + node.height;
            bounds = Some(match bounds {
                None => (node.x, node.y, right, bottom),
                Some((min_x, min_y, max_x, max_y)) => (
                    min_x.min(node.x),
                    min_y.min(node.y),
                    max_x.max(right),
                    max_y.max(bottom),
                ),
            });
        }

        if let Some((min_x, min_y, max_x, max_y)) = bounds {
            sg.x = min_x.saturating_sub(SUBGRAPH_PADDING);
            sg.y = min_y.saturating_sub(SUBGRAPH_PADDING + 1);
            // Derive the size from the far edges rather than from the (possibly
            // clamped) origin, so right/bottom padding stays SUBGRAPH_PADDING.
            sg.width = (max_x + SUBGRAPH_PADDING) - sg.x;
            sg.height = (max_y + SUBGRAPH_PADDING) - sg.y;
        }
    }
}
/// Assigns every node a layer using Kahn's topological sort with longest-path
/// layering: source nodes sit on layer 0, and every other node sits one layer
/// below the deepest of its predecessors.
///
/// Edges whose endpoints are not both present in `graph.nodes` are ignored, so a
/// dangling edge can neither create phantom layer entries nor be mistaken for a cycle.
///
/// If the sort cannot release every node (the graph has a cycle), the stuck nodes
/// (cycle members and anything downstream of them) keep the layer they had reached.
/// That layer is 0 when no predecessor was processed. A
/// `DiagramWarning::CycleDetected` listing those nodes, sorted, is pushed to `warnings`.
fn assign_layers(graph: &Graph, warnings: &mut Vec<DiagramWarning>) -> HashMap<NodeId, usize> {
    let node_count = graph.nodes.len();
    let mut node_layers: HashMap<NodeId, usize> = HashMap::with_capacity(node_count);
    let mut in_degree: HashMap<NodeId, usize> = HashMap::with_capacity(node_count);
    for id in graph.nodes.keys() {
        node_layers.insert(id.clone(), 0);
        in_degree.insert(id.clone(), 0);
    }

    // Build the adjacency list once so the traversal is O(V + E) instead of
    // rescanning every edge for each dequeued node.
    let mut successors: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
    for edge in &graph.edges {
        if !in_degree.contains_key(&edge.from) {
            continue;
        }
        match in_degree.get_mut(&edge.to) {
            Some(deg) => *deg += 1,
            None => continue,
        }
        successors
            .entry(edge.from.clone())
            .or_default()
            .push(edge.to.clone());
    }
    // Sorted successor lists keep the traversal order deterministic.
    for targets in successors.values_mut() {
        targets.sort();
    }

    let mut sources: Vec<&NodeId> = in_degree
        .iter()
        .filter(|(_, deg)| **deg == 0)
        .map(|(id, _)| id)
        .collect();
    sources.sort();
    let mut queue: VecDeque<NodeId> = sources.into_iter().cloned().collect();

    let mut processed = 0;
    while let Some(u) = queue.pop_front() {
        processed += 1;
        // `u` is final here: all of its predecessors have already been processed.
        let u_layer = node_layers.get(&u).copied().unwrap_or(0);
        for v in successors.get(&u).map(Vec::as_slice).unwrap_or(&[]) {
            if let Some(layer) = node_layers.get_mut(v) {
                *layer = (*layer).max(u_layer + 1);
            }
            if let Some(deg) = in_degree.get_mut(v) {
                *deg -= 1;
                if *deg == 0 {
                    queue.push_back(v.clone());
                }
            }
        }
    }

    if processed < node_count {
        let mut cycle_nodes: Vec<String> = in_degree
            .into_iter()
            .filter(|(_, deg)| *deg > 0)
            .map(|(id, _)| id)
            .collect();
        cycle_nodes.sort();
        warnings.push(DiagramWarning::CycleDetected { nodes: cycle_nodes });
    }

    node_layers
}
/// Assigns `x`/`y` coordinates to every node, one layer at a time.
///
/// For horizontal directions each layer is a column placed left to right (RL
/// walks the layers in reverse). Nodes in a column are stacked top to bottom with
/// `v_gap` between them, and the column is centred vertically against the tallest
/// column. Vertical directions mirror this: layers are rows placed top to bottom
/// (BT reversed), nodes are spread left to right with `h_gap`, and each row is
/// centred horizontally.
///
/// Within a layer nodes are placed in sorted id order. Ids in `node_layers`
/// that have no entry in `graph.nodes` are skipped.
fn assign_coordinates_with_gaps(
    graph: &mut Graph,
    node_layers: &HashMap<NodeId, usize>,
    h_gap: usize,
    v_gap: usize,
) {
    let direction = graph.direction;
    let horizontal = direction.is_horizontal();
    let reversed = if horizontal {
        matches!(direction, Direction::RL)
    } else {
        matches!(direction, Direction::BT)
    };

    // Highest layer in use (0 for an empty graph), then bucket ids by layer.
    let last_layer = node_layers.values().copied().max().unwrap_or(0);
    let mut buckets: Vec<Vec<NodeId>> = vec![Vec::new(); last_layer + 1];
    for (id, &layer) in node_layers {
        buckets[layer].push(id.clone());
    }
    for bucket in &mut buckets {
        bucket.sort();
    }

    // Extent of each layer: along the flow it is the largest member, across the
    // flow it is the members laid side by side with gaps between neighbours.
    let mut widths: Vec<usize> = Vec::with_capacity(buckets.len());
    let mut heights: Vec<usize> = Vec::with_capacity(buckets.len());
    for bucket in &buckets {
        let members: Vec<_> = bucket.iter().filter_map(|id| graph.nodes.get(id)).collect();
        let widest = members.iter().map(|n| n.width).max().unwrap_or(0);
        let tallest = members.iter().map(|n| n.height).max().unwrap_or(0);
        let spread_w: usize = members.iter().map(|n| n.width + h_gap).sum();
        let spread_h: usize = members.iter().map(|n| n.height + v_gap).sum();
        if horizontal {
            widths.push(widest);
            heights.push(spread_h.saturating_sub(v_gap));
        } else {
            widths.push(spread_w.saturating_sub(h_gap));
            heights.push(tallest);
        }
    }
    let overall_w = widths.iter().copied().max().unwrap_or(0);
    let overall_h = heights.iter().copied().max().unwrap_or(0);

    // Running position along the flow axis.
    let mut offset = 0;
    for step in 0..=last_layer {
        let layer = if reversed { last_layer - step } else { step };
        if horizontal {
            let mut y = overall_h.saturating_sub(heights[layer]) / 2;
            for id in &buckets[layer] {
                if let Some(node) = graph.nodes.get_mut(id) {
                    node.x = offset;
                    node.y = y;
                    y += node.height + v_gap;
                }
            }
            offset += widths[layer] + h_gap;
        } else {
            let mut x = overall_w.saturating_sub(widths[layer]) / 2;
            for id in &buckets[layer] {
                if let Some(node) = graph.nodes.get_mut(id) {
                    node.x = x;
                    node.y = offset;
                    x += node.width + h_gap;
                }
            }
            offset += heights[layer] + v_gap;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_mermaid;

    /// Parses `src` and runs the default layout, returning the graph and its warnings.
    fn run_layout(src: &str) -> (Graph, Vec<DiagramWarning>) {
        let mut graph = parse_mermaid(src).unwrap();
        let warnings = compute_layout(&mut graph);
        (graph, warnings)
    }

    #[test]
    fn test_layout_lr() {
        let (graph, warnings) = run_layout("flowchart LR\nA --> B");
        let first = graph.nodes.get("A").unwrap();
        let second = graph.nodes.get("B").unwrap();
        assert!(first.x < second.x, "A should be left of B in an LR layout");
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_layout_tb() {
        let (graph, warnings) = run_layout("flowchart TB\nA --> B");
        let first = graph.nodes.get("A").unwrap();
        let second = graph.nodes.get("B").unwrap();
        assert!(first.y < second.y, "A should be above B in a TB layout");
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_node_sizes() {
        let (graph, _) = run_layout("flowchart LR\nA[Hello World]");
        let node = graph.nodes.get("A").unwrap();
        assert_eq!(node.width, "Hello World".len() + 2);
        assert_eq!(node.height, NODE_HEIGHT);
    }

    #[test]
    fn test_cycle_produces_warning() {
        let (_, warnings) = run_layout("flowchart LR\nA --> B\nB --> C\nC --> A");
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].to_string().contains("Cycle"));
    }

    #[test]
    fn test_acyclic_no_warning() {
        let (_, warnings) = run_layout("flowchart LR\nA --> B\nB --> C\nA --> C");
        assert!(warnings.is_empty());
    }
}