use unicode_width::UnicodeWidthStr;
use crate::er::{AttributeKey, Cardinality, ErDiagram, Relationship};
use crate::render::box_table::{NAME_PAD, grid_to_string, pad_right, put, put_str};
/// Canvas width budget used by `render` when no `max_width` is supplied.
const DEFAULT_MAX_WIDTH: usize = 80;
/// Minimum horizontal gap, in columns, between adjacent grid columns.
const MIN_ENTITY_GAP: usize = 4;
/// Rows occupied by an entity's header: top border, name row, and the divider
/// (which is the bottom border when the entity has no attributes).
const HEADER_ROWS: usize = 3;
/// Blank rows between consecutive grid rows; cross-row relationship labels are
/// placed within these rows.
const ROW_GAP: usize = 3;
/// Renders an ER diagram as box-drawing text.
///
/// Entities are laid out row-major on a grid whose column count is chosen by
/// `decide_cols` against `max_width` (default `DEFAULT_MAX_WIDTH`).
/// Relationships between entities in the same grid row are drawn as a horizontal
/// connector on the entities' name row. Relationships between different grid rows
/// are routed along a vertical spine at the right edge of the canvas.
/// Self-relationships and relationships naming unknown entities are skipped.
///
/// Returns an empty string when the diagram has no entities.
pub fn render(chart: &ErDiagram, max_width: Option<usize>) -> String {
    if chart.entities.is_empty() {
        return String::new();
    }
    let budget = max_width.unwrap_or(DEFAULT_MAX_WIDTH);
    let entity_widths: Vec<usize> = chart.entities.iter().map(entity_box_width).collect();
    let entity_heights: Vec<usize> = chart.entities.iter().map(entity_box_height).collect();
    let n = chart.entities.len();
    let n_cols = decide_cols(n, &entity_widths, budget);
    // Row-major placement: entity `i` sits at (row, column) = (i / n_cols, i % n_cols).
    let entity_grid_pos: Vec<(usize, usize)> = (0..n).map(|i| (i / n_cols, i % n_cols)).collect();
    let n_rows = n.div_ceil(n_cols);
    // Each grid column is as wide as its widest entity.
    let col_widths: Vec<usize> = (0..n_cols)
        .map(|gc| {
            (0..n)
                .filter(|&i| entity_grid_pos[i].1 == gc)
                .map(|i| entity_widths[i])
                .max()
                .unwrap_or(0)
        })
        .collect();
    // Each grid row is as tall as its tallest entity.
    let row_heights: Vec<usize> = (0..n_rows)
        .map(|gr| {
            (0..n)
                .filter(|&i| entity_grid_pos[i].0 == gr)
                .map(|i| entity_heights[i])
                .max()
                .unwrap_or(HEADER_ROWS)
        })
        .collect();
    let intra_row_pair_gaps = compute_intra_row_pair_gaps(chart, &entity_grid_pos, n_cols);
    let entity_left: Vec<usize> = compute_entity_left(
        n,
        &entity_grid_pos,
        &col_widths,
        n_cols,
        &intra_row_pair_gaps,
    );
    // Reserve one row above the first grid row for same-row labels, but only if any label exists.
    let has_labels = chart
        .relationships
        .iter()
        .any(|r| r.label.as_deref().is_some_and(|s| !s.is_empty()));
    let top_pad: usize = if has_labels { 1 } else { 0 };
    let entity_top: Vec<usize> = compute_entity_top(n, &entity_grid_pos, &row_heights, top_pad);
    let canvas_width =
        compute_canvas_width(n, chart, &entity_grid_pos, &entity_left, &entity_widths);
    let canvas_height = {
        let total_entity_h: usize = row_heights.iter().sum();
        let gaps = if n_rows > 1 {
            (n_rows - 1) * ROW_GAP
        } else {
            0
        };
        top_pad + total_entity_h + gaps
    };
    let mut grid: Vec<Vec<char>> = vec![vec![' '; canvas_width.max(1)]; canvas_height.max(1)];
    for (i, entity) in chart.entities.iter().enumerate() {
        let left = entity_left[i];
        let right = left + entity_widths[i] - 1;
        draw_entity_box(&mut grid, entity_top[i], left, right, entity);
    }
    // Column ranges already taken by cross-row labels, keyed by canvas row.
    let mut used_label_ranges: std::collections::HashMap<usize, Vec<(usize, usize)>> =
        std::collections::HashMap::new();
    for rel in &chart.relationships {
        let (Some(from_idx), Some(to_idx)) =
            (chart.entity_index(&rel.from), chart.entity_index(&rel.to))
        else {
            continue;
        };
        if from_idx == to_idx {
            continue;
        }
        let from_grid_row = entity_grid_pos[from_idx].0;
        let to_grid_row = entity_grid_pos[to_idx].0;
        if from_grid_row == to_grid_row {
            draw_relationship_line(
                &mut grid,
                entity_top[from_idx],
                entity_left[from_idx],
                entity_widths[from_idx],
                entity_left[to_idx],
                entity_widths[to_idx],
                rel,
                top_pad,
            );
        } else {
            let from_is_rightmost = is_rightmost_in_row(from_idx, &entity_grid_pos, n_cols, n);
            let to_is_rightmost = is_rightmost_in_row(to_idx, &entity_grid_pos, n_cols, n);
            // The router treats `top + height` of the upper endpoint as the first blank
            // row below it when placing the label. Pass the height of the whole grid
            // row rather than the box's own height. Otherwise, a box shorter than its
            // row would get its label drawn inside the row, over taller neighbouring boxes.
            let from_row_height = row_heights[from_grid_row];
            let to_row_height = row_heights[to_grid_row];
            draw_cross_row_relationship(
                &mut grid,
                entity_top[from_idx],
                from_row_height,
                entity_left[from_idx],
                entity_widths[from_idx],
                entity_top[to_idx],
                to_row_height,
                entity_left[to_idx],
                entity_widths[to_idx],
                rel,
                canvas_width,
                from_is_rightmost,
                to_is_rightmost,
                &mut used_label_ranges,
            );
        }
    }
    grid_to_string(&grid)
}
fn decide_cols(n: usize, entity_widths: &[usize], budget: usize) -> usize {
if n <= 1 {
return 1;
}
let single_row_width: usize = entity_widths.iter().sum::<usize>() + MIN_ENTITY_GAP * (n - 1);
if single_row_width <= budget {
return n; }
let cols = (n as f64).sqrt().ceil() as usize;
cols.max(1)
}
/// Returns the left x-coordinate of every entity box.
///
/// An entity at grid cell `(row, col)` starts after the widths of columns
/// `0..col` plus the gap following each of those columns. The gap comes from
/// `intra_row_pair_gaps[row]`, falling back to `MIN_ENTITY_GAP` when missing.
/// Entities whose row has no gap vector, or whose column is outside both
/// `n_cols` and `col_widths`, are left at 0.
fn compute_entity_left(
    n: usize,
    entity_grid_pos: &[(usize, usize)],
    col_widths: &[usize],
    n_cols: usize,
    intra_row_pair_gaps: &[Vec<usize>],
) -> Vec<usize> {
    let usable_cols = n_cols.min(col_widths.len());
    entity_grid_pos[..n]
        .iter()
        .map(|&(row, col)| {
            let Some(row_gaps) = intra_row_pair_gaps.get(row) else {
                return 0;
            };
            if col >= usable_cols {
                return 0;
            }
            (0..col)
                .map(|k| col_widths[k] + row_gaps.get(k).copied().unwrap_or(MIN_ENTITY_GAP))
                .sum()
        })
        .collect()
}
/// Returns the top y-coordinate of every entity box.
///
/// Grid rows are stacked from `top_pad` downward, each row as tall as its
/// `row_heights` entry, with `ROW_GAP` blank rows between consecutive rows.
/// Every entity takes the top of the grid row it belongs to.
fn compute_entity_top(
    n: usize,
    entity_grid_pos: &[(usize, usize)],
    row_heights: &[usize],
    top_pad: usize,
) -> Vec<usize> {
    let mut row_starts = Vec::with_capacity(row_heights.len());
    let mut cursor = top_pad;
    for (k, &height) in row_heights.iter().enumerate() {
        if k > 0 {
            cursor += ROW_GAP;
        }
        row_starts.push(cursor);
        cursor += height;
    }
    entity_grid_pos[..n]
        .iter()
        .map(|&(row, _)| row_starts[row])
        .collect()
}
fn compute_canvas_width(
n: usize,
chart: &ErDiagram,
entity_grid_pos: &[(usize, usize)],
entity_left: &[usize],
entity_widths: &[usize],
) -> usize {
let rightmost_entity = (0..n)
.map(|i| entity_left[i] + entity_widths[i])
.max()
.unwrap_or(0);
let needs_spine = chart.relationships.iter().any(|rel| {
let Some(fi) = chart.entity_index(&rel.from) else {
return false;
};
let Some(ti) = chart.entity_index(&rel.to) else {
return false;
};
fi != ti && entity_grid_pos[fi].0 != entity_grid_pos[ti].0
});
rightmost_entity + if needs_spine { 2 } else { 0 }
}
/// Computes, per grid row, the horizontal gap to leave after each column.
///
/// `result[row][col]` is the gap between column `col` and `col + 1` in that row.
/// Every gap starts at `MIN_ENTITY_GAP`. Take each relationship whose two distinct,
/// known entities share a grid row: every gap between their columns is widened to
/// at least `max(label width, 2) + 4` columns, with a missing label counting as width 0.
fn compute_intra_row_pair_gaps(
    chart: &ErDiagram,
    entity_grid_pos: &[(usize, usize)],
    n_cols: usize,
) -> Vec<Vec<usize>> {
    let n_rows = entity_grid_pos.iter().map(|p| p.0).max().unwrap_or(0) + 1;
    // Every row vector is created at its final length; widening below only mutates
    // existing slots, so no later padding pass is needed.
    let mut gaps: Vec<Vec<usize>> = vec![vec![MIN_ENTITY_GAP; n_cols.saturating_sub(1)]; n_rows];
    for rel in &chart.relationships {
        let (Some(from_idx), Some(to_idx)) =
            (chart.entity_index(&rel.from), chart.entity_index(&rel.to))
        else {
            continue;
        };
        if from_idx == to_idx {
            continue;
        }
        let (from_gr, from_gc) = entity_grid_pos[from_idx];
        let (to_gr, to_gc) = entity_grid_pos[to_idx];
        if from_gr != to_gr {
            // Cross-row relationships are routed along the right-hand spine instead.
            continue;
        }
        let (lo_gc, hi_gc) = (from_gc.min(to_gc), from_gc.max(to_gc));
        let label_w = rel.label.as_deref().map_or(0, |s| s.width());
        let needed = label_w.max(2) + 4;
        for gap in gaps[from_gr].iter_mut().take(hi_gc).skip(lo_gc) {
            *gap = (*gap).max(needed);
        }
    }
    gaps
}
fn entity_box_width(entity: &crate::er::Entity) -> usize {
let header_w = entity.name.width() + 2 * NAME_PAD + 2;
if entity.attributes.is_empty() {
return header_w;
}
let cols = attr_columns(entity);
let attr_w = 2 * NAME_PAD + cols.type_w + 1 + cols.name_w + 1 + cols.keys_w + 2;
attr_w.max(header_w)
}
/// Total height of an entity box in rows.
///
/// An entity without attributes is just the `HEADER_ROWS` header. Otherwise the
/// header is followed by one row per attribute and a closing bottom border.
fn entity_box_height(entity: &crate::er::Entity) -> usize {
    match entity.attributes.len() {
        0 => HEADER_ROWS,
        attr_rows => HEADER_ROWS + attr_rows + 1,
    }
}
/// Display widths of the three attribute-table columns of one entity, each the
/// maximum over that entity's attributes (computed by `attr_columns`).
struct AttrColumns {
    /// Widest attribute type name.
    type_w: usize,
    /// Widest attribute name.
    name_w: usize,
    /// Widest formatted key list (as produced by `format_keys`, e.g. `PK` or `FK,UK`).
    keys_w: usize,
}
/// Measures the widest type, name, and key-list cell across an entity's attributes.
/// All widths are zero when the entity has no attributes.
fn attr_columns(entity: &crate::er::Entity) -> AttrColumns {
    let empty = AttrColumns {
        type_w: 0,
        name_w: 0,
        keys_w: 0,
    };
    entity.attributes.iter().fold(empty, |acc, attr| AttrColumns {
        type_w: acc.type_w.max(attr.type_name.width()),
        name_w: acc.name_w.max(attr.name.width()),
        keys_w: acc.keys_w.max(format_keys(&attr.keys).width()),
    })
}
/// Formats attribute keys as a comma-separated list of `PK` / `FK` / `UK`.
/// Returns an empty string when there are no keys.
fn format_keys(keys: &[AttributeKey]) -> String {
    let mut joined = String::new();
    for (pos, key) in keys.iter().enumerate() {
        if pos > 0 {
            joined.push(',');
        }
        joined.push_str(match key {
            AttributeKey::PrimaryKey => "PK",
            AttributeKey::ForeignKey => "FK",
            AttributeKey::UniqueKey => "UK",
        });
    }
    joined
}
fn draw_entity_box(
grid: &mut [Vec<char>],
entity_top: usize,
left: usize,
right: usize,
entity: &crate::er::Entity,
) {
let interior_w = right - left - 1;
let name_w = entity.name.width();
let name_start = left + 1 + (interior_w.saturating_sub(name_w)) / 2;
put(grid, entity_top, left, '┌');
for c in (left + 1)..right {
put(grid, entity_top, c, '─');
}
put(grid, entity_top, right, '┐');
put(grid, entity_top + 1, left, '│');
put_str(grid, entity_top + 1, name_start, &entity.name);
put(grid, entity_top + 1, right, '│');
if entity.attributes.is_empty() {
put(grid, entity_top + 2, left, '└');
for c in (left + 1)..right {
put(grid, entity_top + 2, c, '─');
}
put(grid, entity_top + 2, right, '┘');
return;
}
put(grid, entity_top + 2, left, '├');
for c in (left + 1)..right {
put(grid, entity_top + 2, c, '─');
}
put(grid, entity_top + 2, right, '┤');
let cols = attr_columns(entity);
for (i, attr) in entity.attributes.iter().enumerate() {
let row = entity_top + HEADER_ROWS + i;
put(grid, row, left, '│');
let mut col = left + 1 + NAME_PAD;
put_str(grid, row, col, &pad_right(&attr.type_name, cols.type_w));
col += cols.type_w + 1;
put_str(grid, row, col, &pad_right(&attr.name, cols.name_w));
col += cols.name_w + 1;
let keys_str = format_keys(&attr.keys);
put_str(grid, row, col, &pad_right(&keys_str, cols.keys_w));
put(grid, row, right, '│');
}
let bottom = entity_top + HEADER_ROWS + entity.attributes.len();
put(grid, bottom, left, '└');
for c in (left + 1)..right {
put(grid, bottom, c, '─');
}
put(grid, bottom, right, '┘');
}
/// Draws a relationship between two entities in the same grid row as a horizontal
/// connector on their name row (`entity_top + 1`).
///
/// The connector fills the cells strictly between the two facing borders. Each end
/// cell holds the cardinality glyph of the nearer entity. Solid relationships use
/// `─` and mark the facing borders with tees that branch toward the connector;
/// dashed ones use `┄` and leave the borders untouched. Nothing is drawn when the
/// gap has fewer than two cells.
///
/// When `top_pad` is non-zero and the relationship has a non-empty label, the label
/// goes on the row directly above the boxes. It is centred over the gap, or starts
/// at the gap's first cell when it is wider than the gap.
#[allow(clippy::too_many_arguments)]
fn draw_relationship_line(
    grid: &mut [Vec<char>],
    entity_top: usize,
    from_left: usize,
    from_width: usize,
    to_left: usize,
    to_width: usize,
    rel: &Relationship,
    top_pad: usize,
) {
    let line_row = entity_top + 1;
    let from_right_border = from_left + from_width - 1;
    let to_left_border = to_left;
    let from_left_border = from_left;
    let to_right_border = to_left + to_width - 1;
    let going_right = from_right_border < to_left_border;
    // `left_box_right` / `right_box_left` are the two facing borders; the connector
    // occupies `line_lo..=line_hi` between them.
    let (left_box_right, right_box_left, source_at_left, line_lo, line_hi) = if going_right {
        (
            from_right_border,
            to_left_border,
            true,
            from_right_border + 1,
            to_left_border.saturating_sub(1),
        )
    } else {
        (
            to_right_border,
            from_left_border,
            false,
            to_right_border + 1,
            from_left_border.saturating_sub(1),
        )
    };
    if line_hi <= line_lo {
        return;
    }
    let dashed = rel.line_style.is_dashed();
    let line_glyph = if dashed { '┄' } else { '─' };
    if !dashed {
        // Each tee branches toward the connector: `├` on the left box's right
        // border, `┤` on the right box's left border.
        put(grid, line_row, left_box_right, '├');
        put(grid, line_row, right_box_left, '┤');
    }
    for c in line_lo..=line_hi {
        put(grid, line_row, c, line_glyph);
    }
    let (lo_card, hi_card) = if source_at_left {
        (rel.from_cardinality, rel.to_cardinality)
    } else {
        (rel.to_cardinality, rel.from_cardinality)
    };
    put(grid, line_row, line_lo, cardinality_glyph(lo_card));
    put(grid, line_row, line_hi, cardinality_glyph(hi_card));
    if top_pad == 0 {
        return;
    }
    if let Some(label) = &rel.label
        && !label.is_empty()
    {
        let label_w = label.width();
        let gap_w = line_hi.saturating_sub(line_lo) + 1;
        let label_row = if entity_top >= top_pad {
            entity_top - 1
        } else {
            return;
        };
        if gap_w >= label_w {
            let offset = (gap_w - label_w) / 2;
            put_str(grid, label_row, line_lo + offset, label);
        } else {
            put_str(grid, label_row, line_lo, label);
        }
    }
}
/// Routes a relationship between entities in different grid rows along a vertical
/// spine in the last canvas column (`canvas_width - 1`).
///
/// At each end, the entity's name row gets a cardinality glyph just right of the
/// box. If the box is the right-most in its grid row, a horizontal fill follows,
/// ending in a corner on the spine. A vertical run connects the two corners. Solid
/// relationships mark each exit point with a `├` tee on the box's right border;
/// dashed ones use `┄`/`┆` and leave the borders untouched.
///
/// A non-empty label is right-aligned just left of the spine, in the blank rows
/// starting at `top + height` of the upper endpoint. It goes in the first of the
/// `ROW_GAP` rows whose recorded ranges in `used_label_ranges` it does not overlap,
/// otherwise in the first such row. The chosen range is then recorded.
#[allow(clippy::too_many_arguments)]
fn draw_cross_row_relationship(
    grid: &mut [Vec<char>],
    from_top: usize,
    from_height: usize,
    from_left: usize,
    from_width: usize,
    to_top: usize,
    to_height: usize,
    to_left: usize,
    to_width: usize,
    rel: &Relationship,
    canvas_width: usize,
    from_is_rightmost: bool,
    to_is_rightmost: bool,
    used_label_ranges: &mut std::collections::HashMap<usize, Vec<(usize, usize)>>,
) {
    let Some(spine_col) = canvas_width.checked_sub(1) else {
        return;
    };
    let dashed = rel.line_style.is_dashed();
    let vert_glyph = if dashed { '┆' } else { '│' };
    let fill_glyph = if dashed { '┄' } else { '─' };
    let from_row = from_top + 1;
    let to_row = to_top + 1;
    let downward = from_row < to_row;
    let from_right_border = from_left + from_width - 1;
    let to_right_border = to_left + to_width - 1;

    // Source end.
    if from_right_border < spine_col {
        if !dashed {
            // The connector leaves the box to the right, so the tee branches right,
            // exactly as at the target end below.
            put(grid, from_row, from_right_border, '├');
        }
        let card_col = from_right_border + 1;
        put(
            grid,
            from_row,
            card_col,
            cardinality_glyph(rel.from_cardinality),
        );
        // Only fill toward the spine when no other box sits to the right in this row.
        if from_is_rightmost {
            for c in (card_col + 1)..spine_col {
                put(grid, from_row, c, fill_glyph);
            }
        }
        put(grid, from_row, spine_col, if downward { '┐' } else { '┘' });
    } else {
        put(
            grid,
            from_row,
            from_right_border,
            cardinality_glyph(rel.from_cardinality),
        );
    }

    // Vertical run between the two corners (exclusive).
    let (vert_lo, vert_hi) = if downward {
        (from_row + 1, to_row)
    } else {
        (to_row + 1, from_row)
    };
    for r in vert_lo..vert_hi {
        put(grid, r, spine_col, vert_glyph);
    }

    // Target end.
    if to_right_border < spine_col {
        put(grid, to_row, spine_col, if downward { '┘' } else { '┐' });
        let card_col = to_right_border + 1;
        put(
            grid,
            to_row,
            card_col,
            cardinality_glyph(rel.to_cardinality),
        );
        if to_is_rightmost {
            for c in (card_col + 1)..spine_col {
                put(grid, to_row, c, fill_glyph);
            }
        }
        if !dashed {
            put(grid, to_row, to_right_border, '├');
        }
    } else {
        put(
            grid,
            to_row,
            to_right_border,
            cardinality_glyph(rel.to_cardinality),
        );
    }

    if let Some(label) = &rel.label
        && !label.is_empty()
        && from_row != to_row
    {
        let first_gap_row = if downward {
            from_top + from_height
        } else {
            to_top + to_height
        };
        let label_w = label.width();
        let label_col = spine_col.saturating_sub(label_w + 1);
        let label_end = label_col + label_w;
        let chosen_row = (0..ROW_GAP)
            .map(|offset| first_gap_row + offset)
            .find(|row| {
                used_label_ranges.get(row).is_none_or(|ranges| {
                    !ranges.iter().any(|&(s, e)| s < label_end && label_col < e)
                })
            })
            .unwrap_or(first_gap_row);
        used_label_ranges
            .entry(chosen_row)
            .or_default()
            .push((label_col, label_end));
        put_str(grid, chosen_row, label_col, label);
    }
}
/// Reports whether entity `idx` is the last one placed in its grid row.
///
/// This holds for the final entity overall, for a degenerate zero-column layout,
/// and whenever the next entity (row-major order) starts a different grid row.
fn is_rightmost_in_row(
    idx: usize,
    entity_grid_pos: &[(usize, usize)],
    n_cols: usize,
    n: usize,
) -> bool {
    let is_last_entity = idx + 1 >= n;
    if is_last_entity || n_cols == 0 {
        return true;
    }
    entity_grid_pos[idx].0 != entity_grid_pos[idx + 1].0
}
/// Maps a relationship-end cardinality to the single glyph drawn at that end:
/// `1` exactly one, `?` zero or one, `+` one or many, `*` zero or many.
fn cardinality_glyph(c: Cardinality) -> char {
    match c {
        // Ends whose minimum is zero.
        Cardinality::ZeroOrOne => '?',
        Cardinality::ZeroOrMany => '*',
        // Ends whose minimum is one.
        Cardinality::ExactlyOne => '1',
        Cardinality::OneOrMany => '+',
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::er::parse;

    #[test]
    fn renders_two_entities_with_relationship() {
        let chart = parse("erDiagram\nCUSTOMER ||--o{ ORDER : places").unwrap();
        let out = render(&chart, None);
        assert!(out.contains("CUSTOMER"));
        assert!(out.contains("ORDER"));
        assert!(out.contains('1'));
        assert!(out.contains('*'));
        assert!(out.contains("places"));
    }

    #[test]
    fn renders_isolated_entity_with_attributes() {
        let chart = parse("erDiagram\nCUSTOMER {\n string name\n string email PK\n}").unwrap();
        let out = render(&chart, None);
        assert!(out.contains("CUSTOMER"));
        assert!(out.contains("string"));
        assert!(out.contains("email"));
        assert!(out.contains("PK"));
    }

    #[test]
    fn renders_dashed_line_for_non_identifying() {
        let chart = parse("erDiagram\nA ||..o{ B").unwrap();
        let out = render(&chart, None);
        assert!(out.contains('┄'), "expected dashed line in:\n{out}");
    }

    #[test]
    fn cardinality_glyph_table_is_distinct() {
        let glyphs = [
            cardinality_glyph(Cardinality::ExactlyOne),
            cardinality_glyph(Cardinality::ZeroOrOne),
            cardinality_glyph(Cardinality::OneOrMany),
            cardinality_glyph(Cardinality::ZeroOrMany),
        ];
        let unique: std::collections::HashSet<_> = glyphs.iter().collect();
        assert_eq!(unique.len(), 4, "cardinality glyphs must be unique");
    }

    #[test]
    fn format_keys_handles_zero_one_and_multiple() {
        assert_eq!(format_keys(&[]), "");
        assert_eq!(format_keys(&[AttributeKey::PrimaryKey]), "PK");
        assert_eq!(
            format_keys(&[AttributeKey::ForeignKey, AttributeKey::UniqueKey]),
            "FK,UK"
        );
    }

    /// Builds a chain `E0 ||--o{ E1 : rel`, `E1 ||--o{ E2 : rel`, … over `n` entities
    /// (entities are only declared through relationships, so `n <= 1` yields none).
    fn make_bare_entities_src(n: usize) -> String {
        let mut src = "erDiagram\n".to_string();
        for i in 1..n {
            src.push_str(&format!("E{} ||--o{{ E{i} : rel\n", i - 1));
        }
        src
    }

    #[test]
    fn small_er_diagram_uses_single_row() {
        let src = make_bare_entities_src(4);
        let chart = parse(&src).unwrap();
        let out = render(&chart, None);
        for i in 0..4 {
            assert!(out.contains(&format!("E{i}")), "E{i} missing from output");
        }
        let top_border_rows = out.lines().filter(|l| l.contains('┌')).count();
        assert_eq!(
            top_border_rows, 1,
            "expected 1 top-border row for 4 entities, got {top_border_rows}"
        );
    }

    #[test]
    fn wide_er_diagram_wraps_to_grid() {
        let src = make_bare_entities_src(8);
        let chart = parse(&src).unwrap();
        let out = render(&chart, Some(30));
        for i in 0..8 {
            assert!(out.contains(&format!("E{i}")), "E{i} missing from:\n{out}");
        }
        let top_border_rows = out.lines().filter(|l| l.contains('┌')).count();
        assert!(
            top_border_rows > 1,
            "expected multiple top-border rows for 8 entities in 30 cols, got {top_border_rows}"
        );
    }

    #[test]
    fn cross_row_relationship_routes_correctly() {
        let src = "erDiagram
E0 ||--o{ E1 : a
E1 ||--o{ E2 : b
E2 ||--o{ E3 : c
E3 ||--o{ E4 : d
E4 ||--o{ E5 : e
E5 ||--o{ E6 : f
E6 ||--o{ E7 : g";
        let chart = parse(src).unwrap();
        let out = render(&chart, Some(30));
        assert!(out.contains("E0"), "E0 missing");
        assert!(out.contains("E4"), "E4 missing");
        // Every entity box already contains `│`, `┐` and `┘`, so their mere presence
        // proves nothing about routing. A spine segment shows up as a vertical glyph
        // on a row with no box-drawing border at all, i.e. a gap row between grid rows.
        let has_spine_row = out.lines().any(|l| {
            let has_vert = l.contains('│') || l.contains('┆');
            let has_box = l.contains('┌')
                || l.contains('├')
                || l.contains('└')
                || l.contains('─')
                || l.contains('┤');
            has_vert && !has_box
        });
        assert!(
            has_spine_row,
            "no vertical routing spine found between grid rows in:\n{out}"
        );
    }

    #[test]
    fn small_diagram_has_no_right_spine() {
        let src = "erDiagram
A ||--|| B : rel";
        let chart = parse(src).unwrap();
        let out = render(&chart, Some(20));
        let spine_in_gap = out.lines().any(|l| {
            let has_vert = l.contains('│') || l.contains('┆');
            let has_box = l.contains('┌')
                || l.contains('├')
                || l.contains('└')
                || l.contains('─')
                || l.contains('┤');
            has_vert && !has_box
        });
        assert!(
            !spine_in_gap,
            "intra-row-only diagram should not have spine-only rows, got:\n{out}"
        );
        assert!(out.contains('A'), "entity A missing from:\n{out}");
        assert!(out.contains('B'), "entity B missing from:\n{out}");
    }

    #[test]
    fn cross_row_target_alone_in_row_has_horizontal_stub_to_spine() {
        let src = "erDiagram
CUSTOMER ||--o{ ORDER : places
ORDER ||--|{ ITEM : contains
PRODUCT ||--o{ ITEM : describes
CATEGORY ||--o{ PRODUCT : groups
ACCOUNT ||--|| CUSTOMER : owns
INVOICE ||--|{ ORDER : bills
CUSTOMER { int id PK string name }
ORDER { int id PK int customerId FK }
PRODUCT { int id PK string name int categoryId FK }
CATEGORY { int id PK string label }
ACCOUNT { int id PK }
INVOICE { int id PK }
ITEM { int orderId FK int productId FK }";
        let chart = parse(src).unwrap();
        let out = render(&chart, None);
        let invoice_row = out
            .lines()
            .find(|l| l.contains("INVOICE") && l.contains('│'))
            .unwrap_or_else(|| panic!("INVOICE name row not found in:\n{out}"));
        let trimmed = invoice_row.trim_end();
        assert!(
            trimmed.ends_with('┘') || trimmed.ends_with('┐'),
            "INVOICE row should end with a spine corner glyph (┘ or ┐), got: {trimmed:?}"
        );
        let card_pos = trimmed
            .find('1')
            .expect("INVOICE cardinality glyph `1` missing");
        let corner_pos = trimmed.rfind('┘').or_else(|| trimmed.rfind('┐')).unwrap();
        let gap = &trimmed[card_pos + 1..corner_pos];
        assert!(
            gap.contains('─'),
            "expected `─` stub between INVOICE cardinality `1` and spine corner, got gap: {gap:?}\nfull row: {trimmed:?}"
        );
    }

    #[test]
    fn cross_row_labels_in_same_gap_row_do_not_overlap() {
        let src = "erDiagram
CUSTOMER ||--o{ ORDER : places
ORDER ||--|{ ITEM : contains
PRODUCT ||--o{ ITEM : describes
CATEGORY ||--o{ PRODUCT : groups
ACCOUNT ||--|| CUSTOMER : owns
INVOICE ||--|{ ORDER : bills
CUSTOMER { int id PK string name }
ORDER { int id PK int customerId FK }
PRODUCT { int id PK string name int categoryId FK }
CATEGORY { int id PK string label }
ACCOUNT { int id PK }
INVOICE { int id PK }
ITEM { int orderId FK int productId FK }";
        let chart = parse(src).unwrap();
        let out = render(&chart, None);
        assert!(
            out.contains("describes"),
            "label 'describes' was clobbered by an overlapping label:\n{out}"
        );
        assert!(
            out.contains("bills"),
            "label 'bills' was clobbered by an overlapping label:\n{out}"
        );
        assert!(
            !out.contains("descbills") && !out.contains("billsescribes"),
            "two labels collided into a single token:\n{out}"
        );
    }

    #[test]
    fn grid_honours_max_width_budget() {
        let src = make_bare_entities_src(8);
        let chart = parse(&src).unwrap();
        let out = render(&chart, Some(50));
        // The column count is chosen before label-widened gaps and the two-column
        // routing spine are added, so allow some slack over the nominal budget.
        for (line_no, line) in out.lines().enumerate() {
            let w = line.width();
            assert!(
                w <= 60,
                "line {line_no} is {w} chars wide (budget 50), content: {line:?}"
            );
        }
    }
}