use crate::text::display_width;
use crate::types::{
DiagramWarning, Direction, Graph, NodeId, NodeShape, RenderOptions, TableField,
};
use std::collections::{HashMap, HashSet, VecDeque};
/// Smallest width a node box is given, even when its padded label is narrower.
const MIN_NODE_WIDTH: usize = 5;
/// Height of a node whose label is a single line (multi-line labels use line count + 2).
const NODE_HEIGHT: usize = 3;
/// Floor applied when `calculate_gaps` shrinks the horizontal gap to honor `max_width`.
const MIN_GAP: usize = 2;
/// Padding added on each side of a subgraph's contents when its bounds are computed
/// (the top side gets one additional row).
const SUBGRAPH_PADDING: usize = 2;
/// Lays out `graph` in place using `RenderOptions::default()`.
///
/// Convenience wrapper around `compute_layout_with_options`; returns the warnings it produces.
pub fn compute_layout(graph: &mut Graph) -> Vec<DiagramWarning> {
    let defaults = RenderOptions::default();
    compute_layout_with_options(graph, &defaults)
}
pub fn compute_layout_with_options(
graph: &mut Graph,
options: &RenderOptions,
) -> Vec<DiagramWarning> {
let mut warnings = Vec::new();
let text_padding = options.border_padding * 2;
for node in graph.nodes.values_mut() {
let lines: Vec<&str> = node.label.split('\n').collect();
let max_line_width = lines.iter().map(|l| display_width(l)).max().unwrap_or(0);
node.width = (max_line_width + text_padding).max(MIN_NODE_WIDTH);
let line_count = lines.len();
node.height = if line_count > 1 {
line_count + 2 } else {
NODE_HEIGHT
};
if node.shape == NodeShape::Cylinder {
node.height = node.height.max(5);
}
if node.shape == NodeShape::Person {
node.height = node.height.max(5);
node.width = node.width.max(7);
}
if node.shape == NodeShape::Cloud {
node.width += 4;
node.height += 2;
}
if node.shape == NodeShape::Document {
node.height += 1;
}
if node.shape == NodeShape::Table && !node.fields.is_empty() {
for field in &node.fields {
let field_len = format_field_width(field);
node.width = node.width.max(field_len + 2 + text_padding); }
node.height = 3 + node.fields.len(); }
}
let layers = assign_layers(graph, &mut warnings);
let (h_gap, v_gap) = calculate_gaps(graph, &layers, options);
assign_coordinates_with_gaps(graph, &layers, h_gap, v_gap);
compute_subgraph_bounds(graph);
warnings
}
/// Chooses the `(horizontal, vertical)` gaps between nodes.
///
/// Starts from `options.padding_x` / `options.padding_y`. When `options.max_width` is set
/// and the direction is horizontal, the summed per-layer maximum node widths plus the gaps
/// between layers are compared against `max_width`. If they overflow, the horizontal gap is
/// shrunk to share the remaining space. The shrunk gap is floored at `MIN_GAP` but never
/// exceeds the requested `padding_x`, so compression can never make the diagram wider.
/// Vertical directions ignore `max_width`.
fn calculate_gaps(
    graph: &Graph,
    layers: &HashMap<NodeId, usize>,
    options: &RenderOptions,
) -> (usize, usize) {
    let h_gap = options.padding_x;
    let v_gap = options.padding_y;
    let max_width = match options.max_width {
        Some(w) => w,
        None => return (h_gap, v_gap),
    };
    if !graph.direction.is_horizontal() {
        return (h_gap, v_gap);
    }

    // Widest node per layer; layers advance along x in horizontal layouts.
    let mut layer_widths: HashMap<usize, usize> = HashMap::new();
    for (id, &layer) in layers {
        let width = graph.nodes.get(id).map_or(0, |n| n.width);
        let widest = layer_widths.entry(layer).or_insert(0);
        *widest = (*widest).max(width);
    }

    // Layers are numbered 0..=max_layer, so there are `max_layer` gaps between them.
    let gap_count = layer_widths.keys().copied().max().unwrap_or(0);
    let node_width: usize = layer_widths.values().sum();
    let total_width = node_width + gap_count * h_gap;
    if total_width > max_width && gap_count > 0 {
        let available_for_gaps = max_width.saturating_sub(node_width);
        // `.min(h_gap)`: when the caller asked for a gap below MIN_GAP, flooring at MIN_GAP
        // would widen the layout that we are trying to shrink.
        let new_gap = (available_for_gaps / gap_count).max(MIN_GAP).min(h_gap);
        return (new_gap, v_gap);
    }
    (h_gap, v_gap)
}
/// Display width of a table field's text.
///
/// The width is the name, plus `2 + width(type)` when a type is present, plus
/// `1 + width(abbreviation)` when a constraint is present.
///
/// Every part is measured with `display_width`, so free-form constraints containing
/// non-ASCII or wide characters are counted in the same units as the name and type.
/// NOTE(review): the extra `2` and `1` presumably cover separator characters drawn by the
/// table renderer — confirm against it.
fn format_field_width(field: &TableField) -> usize {
    let name_width = display_width(&field.name);
    let type_width = field
        .type_info
        .as_ref()
        .map_or(0, |ti| 2 + display_width(ti));
    let constraint_width = field
        .constraint
        .as_ref()
        .map_or(0, |c| 1 + display_width(&constraint_abbrev(c)));
    name_width + type_width + constraint_width
}
/// Bracketed short form of a table-field constraint.
///
/// Known constraint names map to two-letter codes (`primary_key` → `[PK]`,
/// `foreign_key` → `[FK]`, `unique` → `[UQ]`, `not_null` → `[NN]`); any other
/// value is wrapped in brackets unchanged.
fn constraint_abbrev(constraint: &str) -> String {
    let short = match constraint {
        "primary_key" => "PK",
        "foreign_key" => "FK",
        "unique" => "UQ",
        "not_null" => "NN",
        unknown => unknown,
    };
    format!("[{}]", short)
}
fn compute_subgraph_bounds(graph: &mut Graph) {
let sg_count = graph.subgraphs.len();
let sg_ids: Vec<String> = graph.subgraphs.iter().map(|sg| sg.id.clone()).collect();
let sg_parents: Vec<Option<String>> =
graph.subgraphs.iter().map(|sg| sg.parent.clone()).collect();
let mut has_children: std::collections::HashSet<String> = std::collections::HashSet::new();
for p in sg_parents.iter().flatten() {
has_children.insert(p.clone());
}
let mut processed: std::collections::HashSet<String> = std::collections::HashSet::new();
for _ in 0..sg_count + 1 {
for i in 0..sg_count {
let sg_id = &sg_ids[i];
if processed.contains(sg_id) {
continue;
}
let all_children_done = sg_ids.iter().enumerate().all(|(j, child_id)| {
if sg_parents[j].as_ref() == Some(sg_id) {
processed.contains(child_id)
} else {
true
}
});
if !all_children_done {
continue;
}
let sg = &graph.subgraphs[i];
let mut min_x = usize::MAX;
let mut min_y = usize::MAX;
let mut max_x = 0;
let mut max_y = 0;
for node_id in &sg.nodes {
if let Some(node) = graph.nodes.get(node_id) {
min_x = min_x.min(node.x);
min_y = min_y.min(node.y);
max_x = max_x.max(node.x + node.width);
max_y = max_y.max(node.y + node.height);
}
}
for (j, _child_id) in sg_ids.iter().enumerate() {
if sg_parents[j].as_ref() == Some(sg_id) {
let child = &graph.subgraphs[j];
if child.width > 0 && child.height > 0 {
min_x = min_x.min(child.x);
min_y = min_y.min(child.y);
max_x = max_x.max(child.x + child.width);
max_y = max_y.max(child.y + child.height);
}
}
}
if min_x != usize::MAX {
graph.subgraphs[i].x = min_x.saturating_sub(SUBGRAPH_PADDING);
graph.subgraphs[i].y = min_y.saturating_sub(SUBGRAPH_PADDING + 1);
graph.subgraphs[i].width = (max_x - min_x) + SUBGRAPH_PADDING * 2;
graph.subgraphs[i].height = (max_y - min_y) + SUBGRAPH_PADDING * 2 + 1;
}
processed.insert(sg_id.clone());
}
if processed.len() == sg_count {
break;
}
}
}
/// Assigns every node a layer index (its rank along the flow direction).
///
/// Uses Kahn's topological ordering, seeded with the zero-in-degree nodes in sorted id
/// order. When a node `u` is dequeued, each not-yet-processed target `v` of an edge `u -> v`
/// gets `layer(v) = max(layer(v), layer(u) + 1)`. Parallel edges between the same pair of
/// nodes count once toward in-degree, matching the once-per-neighbour decrement, so
/// duplicated edges do not stall the ordering.
///
/// If the ordering stalls because of a cycle, one stalled node is forced into the queue. The
/// chosen node is the one whose first outgoing edge appears earliest in `graph.edges`, with
/// ties broken by id. Edges pointing back at already-processed nodes are then ignored. Nodes
/// that lie on a directed cycle are reported once in a `DiagramWarning::CycleDetected`,
/// sorted by id.
fn assign_layers(graph: &Graph, warnings: &mut Vec<DiagramWarning>) -> HashMap<NodeId, usize> {
    let mut node_layers: HashMap<NodeId, usize> = HashMap::new();
    let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
    let mut processed: HashSet<NodeId> = HashSet::new();
    for id in graph.nodes.keys() {
        in_degree.insert(id.clone(), 0);
        node_layers.insert(id.clone(), 0);
    }
    // Count each distinct (from, to) pair once. The relaxation loop below decrements a target
    // once per distinct predecessor, so counting parallel edges would leave it stuck.
    let mut seen_pairs: HashSet<(&str, &str)> = HashSet::new();
    for edge in &graph.edges {
        if seen_pairs.insert((edge.from.as_str(), edge.to.as_str())) {
            *in_degree.entry(edge.to.clone()).or_insert(0) += 1;
        }
    }
    // Index of each node's first outgoing edge; decides which stalled node is forced.
    let mut first_from_idx: HashMap<&str, usize> = HashMap::new();
    for (i, edge) in graph.edges.iter().enumerate() {
        first_from_idx.entry(edge.from.as_str()).or_insert(i);
    }
    let mut queue: VecDeque<NodeId> = VecDeque::new();
    let mut zero_in: Vec<&NodeId> = in_degree
        .iter()
        .filter(|(_, d)| **d == 0)
        .map(|(id, _)| id)
        .collect();
    zero_in.sort();
    queue.extend(zero_in.into_iter().cloned());

    let total = graph.nodes.len();
    let mut all_cycle_nodes: HashSet<String> = HashSet::new();
    loop {
        while let Some(u) = queue.pop_front() {
            if processed.contains(&u) {
                continue;
            }
            processed.insert(u.clone());
            // `u` is processed and neighbours exclude processed nodes, so its layer is fixed here.
            let u_layer = *node_layers.get(&u).unwrap_or(&0);
            let mut neighbors: Vec<NodeId> = graph
                .edges
                .iter()
                .filter(|e| e.from == u && !processed.contains(&e.to))
                .map(|e| e.to.clone())
                .collect();
            neighbors.sort();
            neighbors.dedup();
            for v in &neighbors {
                let v_layer = node_layers.entry(v.clone()).or_insert(0);
                *v_layer = (*v_layer).max(u_layer + 1);
                if let Some(deg) = in_degree.get_mut(v) {
                    *deg = deg.saturating_sub(1);
                    if *deg == 0 {
                        queue.push_back(v.clone());
                    }
                }
            }
        }
        if processed.len() >= total {
            break;
        }
        // Stalled: every remaining node still waits on an unprocessed predecessor.
        let mut stuck: Vec<NodeId> = in_degree
            .keys()
            .filter(|id| !processed.contains(*id))
            .cloned()
            .collect();
        let stuck_set: HashSet<&str> = stuck.iter().map(|s| s.as_str()).collect();
        all_cycle_nodes.extend(nodes_on_cycles(graph, &stuck_set));
        stuck.sort_by(|a, b| {
            let fa = first_from_idx
                .get(a.as_str())
                .copied()
                .unwrap_or(usize::MAX);
            let fb = first_from_idx
                .get(b.as_str())
                .copied()
                .unwrap_or(usize::MAX);
            fa.cmp(&fb).then(a.cmp(b))
        });
        if let Some(forced) = stuck.first() {
            in_degree.insert(forced.clone(), 0);
            queue.push_back(forced.clone());
        }
    }
    if !all_cycle_nodes.is_empty() {
        let mut cycle_nodes: Vec<String> = all_cycle_nodes.into_iter().collect();
        cycle_nodes.sort();
        warnings.push(DiagramWarning::CycleDetected { nodes: cycle_nodes });
    }
    node_layers
}

/// Returns the members of `candidates` that lie on a directed cycle made only of candidates
/// (self-loops included).
///
/// A node is on such a cycle iff it can be reached again from one of its own successors. This
/// excludes nodes that merely sit downstream of a cycle.
fn nodes_on_cycles(graph: &Graph, candidates: &HashSet<&str>) -> Vec<String> {
    // Adjacency restricted to the candidate subgraph.
    let mut successors: HashMap<&str, Vec<&str>> = HashMap::new();
    for e in &graph.edges {
        let (from, to) = (e.from.as_str(), e.to.as_str());
        if candidates.contains(from) && candidates.contains(to) {
            successors.entry(from).or_default().push(to);
        }
    }
    let mut on_cycle = Vec::new();
    for &start in candidates {
        let mut seen: HashSet<&str> = HashSet::new();
        let mut stack: Vec<&str> = successors.get(start).cloned().unwrap_or_default();
        while let Some(n) = stack.pop() {
            if n == start {
                on_cycle.push(start.to_string());
                break;
            }
            if seen.insert(n) {
                if let Some(next) = successors.get(n) {
                    stack.extend(next.iter().copied());
                }
            }
        }
    }
    on_cycle
}
/// Places every node at concrete `(x, y)` coordinates once layers are known.
///
/// Layers advance along the flow ("main") axis: x for horizontal directions, y otherwise.
/// They are visited in increasing layer order, or decreasing order for `RL` (horizontal) and
/// `BT` (vertical). Each layer occupies its largest main-axis node extent, followed by the
/// main gap (`h_gap` horizontally, `v_gap` vertically). Within a layer, nodes are stacked
/// along the other ("cross") axis in sorted id order, separated by the cross gap. Each stack
/// is centred within the largest cross-axis extent of any layer.
fn assign_coordinates_with_gaps(
    graph: &mut Graph,
    node_layers: &HashMap<NodeId, usize>,
    h_gap: usize,
    v_gap: usize,
) {
    let horizontal = graph.direction.is_horizontal();
    let reversed = if horizontal {
        matches!(graph.direction, Direction::RL)
    } else {
        matches!(graph.direction, Direction::BT)
    };
    let (main_gap, cross_gap) = if horizontal {
        (h_gap, v_gap)
    } else {
        (v_gap, h_gap)
    };
    // Splits a node's (width, height) into (main-axis extent, cross-axis extent).
    let split_extent = |width: usize, height: usize| {
        if horizontal {
            (width, height)
        } else {
            (height, width)
        }
    };

    let mut by_layer: HashMap<usize, Vec<NodeId>> = HashMap::new();
    for (id, &layer) in node_layers {
        by_layer.entry(layer).or_default().push(id.clone());
    }
    for ids in by_layer.values_mut() {
        ids.sort();
    }
    let last_layer = by_layer.keys().copied().max().unwrap_or(0);

    // Per layer: (largest main-axis extent, total cross-axis extent including inner gaps).
    let mut layer_sizes: Vec<(usize, usize)> = Vec::with_capacity(last_layer + 1);
    for layer in 0..=last_layer {
        let mut main_size: usize = 0;
        let mut stacked: usize = 0;
        for id in by_layer.get(&layer).into_iter().flatten() {
            if let Some(node) = graph.nodes.get(id) {
                let (main, cross) = split_extent(node.width, node.height);
                main_size = main_size.max(main);
                stacked += cross + cross_gap;
            }
        }
        layer_sizes.push((main_size, stacked.saturating_sub(cross_gap)));
    }
    let widest_cross = layer_sizes
        .iter()
        .map(|&(_, cross)| cross)
        .max()
        .unwrap_or(0);

    let mut main_pos: usize = 0;
    for step in 0..=last_layer {
        let layer = if reversed { last_layer - step } else { step };
        let (main_size, cross_size) = layer_sizes[layer];
        let mut cross_pos = widest_cross.saturating_sub(cross_size) / 2;
        for id in by_layer.get(&layer).into_iter().flatten() {
            if let Some(node) = graph.nodes.get_mut(id) {
                let (_, cross) = split_extent(node.width, node.height);
                if horizontal {
                    node.x = main_pos;
                    node.y = cross_pos;
                } else {
                    node.x = cross_pos;
                    node.y = main_pos;
                }
                cross_pos += cross + cross_gap;
            }
        }
        main_pos += main_size + main_gap;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_mermaid;

    /// Parses `source` and lays it out with the default options.
    fn layout(source: &str) -> (Graph, Vec<DiagramWarning>) {
        let mut graph = parse_mermaid(source).unwrap();
        let warnings = compute_layout(&mut graph);
        (graph, warnings)
    }

    /// Parses `source` and lays it out with `options`, ignoring warnings.
    fn layout_with(source: &str, options: &RenderOptions) -> Graph {
        let mut graph = parse_mermaid(source).unwrap();
        compute_layout_with_options(&mut graph, options);
        graph
    }

    #[test]
    fn test_layout_lr() {
        let (graph, warnings) = layout("flowchart LR\nA --> B");
        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();
        assert!(a.x < b.x);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_layout_tb() {
        let (graph, warnings) = layout("flowchart TB\nA --> B");
        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();
        assert!(a.y < b.y);
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_node_sizes() {
        let (graph, _) = layout("flowchart LR\nA[Hello World]");
        let a = graph.nodes.get("A").unwrap();
        assert_eq!(a.width, "Hello World".len() + 2);
        assert_eq!(a.height, NODE_HEIGHT);
    }

    #[test]
    fn test_cycle_produces_warning() {
        let (_, warnings) = layout("flowchart LR\nA --> B\nB --> C\nC --> A");
        assert_eq!(warnings.len(), 1);
        assert!(warnings[0].to_string().contains("Cycle"));
    }

    #[test]
    fn test_acyclic_no_warning() {
        let (_, warnings) = layout("flowchart LR\nA --> B\nB --> C\nA --> C");
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_custom_padding() {
        let options = RenderOptions {
            padding_x: 20,
            padding_y: 10,
            ..Default::default()
        };
        let graph = layout_with("flowchart LR\nA --> B", &options);
        let a = graph.nodes.get("A").unwrap();
        let b = graph.nodes.get("B").unwrap();
        assert!(b.x - (a.x + a.width) >= 20);
    }

    #[test]
    fn test_border_padding_affects_width() {
        let narrow = RenderOptions {
            border_padding: 1,
            ..Default::default()
        };
        let wide = RenderOptions {
            border_padding: 3,
            ..Default::default()
        };
        let narrow_graph = layout_with("flowchart LR\nA[Test]", &narrow);
        let wide_graph = layout_with("flowchart LR\nA[Test]", &wide);
        let narrow_width = narrow_graph.nodes.get("A").unwrap().width;
        let wide_width = wide_graph.nodes.get("A").unwrap().width;
        assert!(wide_width > narrow_width);
    }
}