use std::collections::HashMap;
use unicode_width::UnicodeWidthStr;
use crate::{
layout::{Grid, grid::arrow, layered::GridPos},
types::{Direction, Graph, Node, NodeShape},
};
/// Columns added to each side of a label's display width when sizing a node
/// box. For plain boxes the outermost of these columns is the border itself,
/// since only the cells strictly inside the box are blanked.
const LABEL_PADDING: usize = 2;
/// On-grid footprint of a single node, derived from its label and shape.
#[derive(Debug, Clone, Copy)]
struct NodeGeom {
    /// Total width of the node's box in grid columns, border included.
    pub width: usize,
    /// Total height of the node's box in grid rows, border included.
    pub height: usize,
    /// Row offset (from the box's top edge) on which the label is written.
    pub text_row: usize,
}

impl NodeGeom {
    /// Sizes the box for `node` from the display width of its label.
    ///
    /// Every shape starts from the label width plus `LABEL_PADDING` on each
    /// side. Diamonds then add 4 columns and are 5 rows tall with text on
    /// row 2. Circles add 2 columns. All shapes other than diamonds are
    /// 3 rows tall with text on row 1.
    fn for_node(node: &Node) -> Self {
        let padded = UnicodeWidthStr::width(node.label.as_str()) + 2 * LABEL_PADDING;
        let (extra_w, height, text_row) = match node.shape {
            NodeShape::Diamond => (4, 5, 2),
            NodeShape::Circle => (2, 3, 1),
            _ => (0, 3, 1),
        };
        NodeGeom {
            width: padded + extra_w,
            height,
            text_row,
        }
    }

    /// Column offset of the box's horizontal middle (`width` halved, rounded down).
    fn cx(self) -> usize {
        self.width / 2
    }

    /// Row offset of the box's vertical middle (`height` halved, rounded down).
    fn cy(self) -> usize {
        self.height / 2
    }
}
/// A grid cell where an edge starts or ends; `exit_point` / `entry_point`
/// place it one cell outside a node's box.
#[derive(Debug, Clone, Copy)]
struct Attach {
    /// Grid column of the attach cell.
    pub col: usize,
    /// Grid row of the attach cell.
    pub row: usize,
}
/// Cell where an edge leaving the node at `pos` (size `geom`) begins when the
/// graph flows in direction `dir`.
///
/// The point is centred on the downstream side of the box, one cell beyond
/// it. For right-to-left and bottom-to-top flows the coordinate is clamped
/// at 0 by `saturating_sub`.
fn exit_point(pos: GridPos, geom: NodeGeom, dir: Direction) -> Attach {
    let (left, top) = pos;
    let (col, row) = match dir {
        Direction::LeftToRight => (left + geom.width, top + geom.cy()),
        Direction::RightToLeft => (left.saturating_sub(1), top + geom.cy()),
        Direction::TopToBottom => (left + geom.cx(), top + geom.height),
        Direction::BottomToTop => (left + geom.cx(), top.saturating_sub(1)),
    };
    Attach { col, row }
}
fn entry_point(pos: GridPos, geom: NodeGeom, dir: Direction) -> Attach {
let (c, r) = pos;
match dir {
Direction::LeftToRight => Attach {
col: c.saturating_sub(1), row: r + geom.cy(),
},
Direction::RightToLeft => Attach {
col: c + geom.width,
row: r + geom.cy(),
},
Direction::TopToBottom => Attach {
col: c + geom.cx(),
row: r.saturating_sub(1),
},
Direction::BottomToTop => Attach {
col: c + geom.cx(),
row: r + geom.height,
},
}
}
/// Arrowhead glyph placed at the end of every edge for flow direction `dir`;
/// it always points along the flow.
fn tip_char(dir: Direction) -> char {
    match dir {
        Direction::TopToBottom => arrow::DOWN,
        Direction::BottomToTop => arrow::UP,
        Direction::LeftToRight => arrow::RIGHT,
        Direction::RightToLeft => arrow::LEFT,
    }
}
/// Computes the `(width, height)` of a grid large enough for every node that
/// has both a position and a geometry.
///
/// Each node contributes its far (right/bottom) edge plus a fixed margin.
/// The result is never smaller than 1x1, even for an empty graph.
fn grid_size(
    graph: &Graph,
    positions: &HashMap<String, GridPos>,
    geoms: &HashMap<String, NodeGeom>,
) -> (usize, usize) {
    // Spare cells past each node's right/bottom edge; edge attach points lie
    // outside node boxes, so the grid must extend beyond the last node.
    const MARGIN: usize = 4;
    let (cols, rows) = graph
        .nodes
        .iter()
        .filter_map(|node| Some((*positions.get(&node.id)?, *geoms.get(&node.id)?)))
        .fold((0usize, 0usize), |(cols, rows), ((c, r), g)| {
            (cols.max(c + g.width + MARGIN), rows.max(r + g.height + MARGIN))
        });
    (cols.max(1), rows.max(1))
}
pub fn render(graph: &Graph, positions: &HashMap<String, GridPos>) -> String {
let geoms: HashMap<String, NodeGeom> = graph
.nodes
.iter()
.map(|n| (n.id.clone(), NodeGeom::for_node(n)))
.collect();
let (width, height) = grid_size(graph, positions, &geoms);
let mut grid = Grid::new(width, height);
for edge in &graph.edges {
let (Some(&from_pos), Some(&to_pos)) = (positions.get(&edge.from), positions.get(&edge.to))
else {
continue;
};
let (Some(&from_geom), Some(&to_geom)) = (geoms.get(&edge.from), geoms.get(&edge.to))
else {
continue;
};
let src = exit_point(from_pos, from_geom, graph.direction);
let dst = entry_point(to_pos, to_geom, graph.direction);
let tip = tip_char(graph.direction);
let horizontal_first = graph.direction.is_horizontal();
grid.draw_manhattan(src.col, src.row, dst.col, dst.row, horizontal_first, tip);
if let Some(ref lbl) = edge.label {
place_edge_label(&mut grid, src, dst, lbl, graph.direction);
}
}
for node in &graph.nodes {
let Some(&pos) = positions.get(&node.id) else {
continue;
};
let Some(&geom) = geoms.get(&node.id) else {
continue;
};
draw_node_box(&mut grid, node, pos, geom);
}
for node in &graph.nodes {
let Some(&pos) = positions.get(&node.id) else {
continue;
};
let Some(&geom) = geoms.get(&node.id) else {
continue;
};
draw_label_centred(&mut grid, node, pos, geom);
}
grid.render()
}
/// Draws the outline of `node` at `pos` with size `geom`, after blanking the
/// cells strictly inside the outline.
///
/// Rectangles use a plain box, rounded nodes and circles a rounded box, and
/// diamonds a diamond. Circles additionally get `(` and `)` one column inside
/// the left and right borders on the middle row.
fn draw_node_box(grid: &mut Grid, node: &Node, pos: GridPos, geom: NodeGeom) {
    let (left, top) = pos;
    let (w, h) = (geom.width, geom.height);

    // Clear the interior so anything drawn earlier underneath (e.g. edge
    // segments) does not show through the node.
    let inner_cols = (left + 1)..(left + w.saturating_sub(1));
    for y in (top + 1)..(top + h.saturating_sub(1)) {
        for x in inner_cols.clone() {
            grid.set(x, y, ' ');
        }
    }

    match node.shape {
        NodeShape::Rectangle => {
            grid.draw_box(left, top, w, h);
        }
        NodeShape::Diamond => {
            grid.draw_diamond(left, top, w, h);
        }
        NodeShape::Rounded | NodeShape::Circle => {
            grid.draw_rounded_box(left, top, w, h);
        }
    }

    if matches!(node.shape, NodeShape::Circle) {
        let mid = top + geom.cy();
        grid.set(left + 1, mid, '(');
        grid.set(left + w - 2, mid, ')');
    }
}
/// Writes `node`'s label on row `geom.text_row` of its box, horizontally
/// centred.
///
/// Diamonds are inset by `height / 2` columns on each side before centring.
/// Other shapes centre within the columns inside the border, and a label
/// wider than that space starts right after the left border.
fn draw_label_centred(grid: &mut Grid, node: &Node, pos: GridPos, geom: NodeGeom) {
    let (left, top) = pos;
    let label_w = UnicodeWidthStr::width(node.label.as_str());
    let text_col = if matches!(node.shape, NodeShape::Diamond) {
        let inset = geom.height / 2;
        let slack = geom
            .width
            .saturating_sub(inset * 2)
            .saturating_sub(label_w);
        left + inset + slack / 2
    } else {
        let interior_w = geom.width.saturating_sub(2);
        // Zero slack when the label does not fit: it then starts at the
        // first interior column.
        let slack = interior_w.saturating_sub(label_w);
        left + 1 + slack / 2
    };
    grid.write_text(text_col, top + geom.text_row, &node.label);
}
fn place_edge_label(grid: &mut Grid, src: Attach, dst: Attach, label: &str, dir: Direction) {
let (lbl_col, lbl_row) = match dir {
Direction::LeftToRight | Direction::RightToLeft => {
let mid_col = (src.col + dst.col) / 2;
let row = if dst.row < src.row {
src.row.saturating_sub(1)
} else if dst.row > src.row {
src.row + 1
} else {
src.row.saturating_sub(1)
};
(mid_col, row)
}
Direction::TopToBottom | Direction::BottomToTop => {
let mid_row = (src.row + dst.row) / 2;
let col = if dst.col > src.col {
src.col.saturating_sub(1)
} else {
src.col + 1
};
(col, mid_row)
}
};
grid.write_text(lbl_col, lbl_row, label);
}
/// Smoke tests running the full parse -> layout -> render pipeline.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        layout::layered::{LayoutConfig, layout},
        parser,
    };
    /// Parses `src`, lays it out with `LayoutConfig::default()`, and renders
    /// the result. Panics if parsing fails.
    fn render_diagram(src: &str) -> String {
        let graph = parser::parse(src).unwrap();
        let positions = layout(&graph, &LayoutConfig::default());
        render(&graph, &positions)
    }
    /// A left-to-right diagram must show both node labels.
    #[test]
    fn lr_output_contains_node_labels() {
        let out = render_diagram("graph LR\nA[Start] --> B[End]");
        assert!(out.contains("Start"), "missing 'Start' in:\n{out}");
        assert!(out.contains("End"), "missing 'End' in:\n{out}");
    }
    /// A top-down diagram must show both node labels.
    #[test]
    fn td_output_contains_node_labels() {
        let out = render_diagram("graph TD\nA[Top] --> B[Bottom]");
        assert!(out.contains("Top"), "missing 'Top' in:\n{out}");
        assert!(out.contains("Bottom"), "missing 'Bottom' in:\n{out}");
    }
}