use super::constants::*;
use super::parser::{LinkColor, NodeAlignment, SankeyDiagram};
use super::templates::{self, build_css, esc};
use crate::theme::Theme;
/// Returns the `TABLEAU10` palette colour for `idx`, wrapping around once the
/// index runs past the end of the palette.
fn tableau_color_by_index(idx: usize) -> &'static str {
    let slot = idx % TABLEAU10.len();
    TABLEAU10[slot]
}
/// Working state for one node while the Sankey layout is computed.
#[derive(Debug, Clone)]
struct LayoutNode {
/// Node id from the parsed diagram; `render` also uses it as the label text.
id: String,
/// Position of the node in `diag.nodes`.
// Only written, never read in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
index: usize,
/// BFS level reached by following outgoing links; 0 for nodes with no
/// incoming links.
depth: usize,
/// BFS level reached by following incoming links; 0 for nodes with no
/// outgoing links.
height: usize,
/// Column the node is drawn in, chosen by the configured `NodeAlignment`
/// and clamped to the last column.
layer: usize,
/// Larger of the summed outgoing and the summed incoming link values.
value: f64,
/// Left edge in pixels.
x0: f64,
/// Right edge: `x0 + node_width`.
x1: f64,
/// Top edge in pixels.
y0: f64,
/// Bottom edge; `y1 - y0` is `value` scaled by the layout's `ky`.
y1: f64,
/// Indices into the layout's link list of links leaving this node.
source_links: Vec<usize>,
/// Indices into the layout's link list of links entering this node.
target_links: Vec<usize>,
}
/// Working state for one link: its endpoints, flow and computed geometry.
#[derive(Debug, Clone)]
struct LayoutLink {
    /// Index of the node the link leaves.
    source: usize,
    /// Index of the node the link enters.
    target: usize,
    /// Flow carried by the link.
    value: f64,
    /// Stroke width in pixels (`value` scaled by the layout's `ky`).
    width: f64,
    /// Vertical centre of the link where it leaves `source`.
    y0: f64,
    /// Vertical centre of the link where it enters `target`.
    y1: f64,
    /// Position of the link in the parsed input; used as a sort tie-breaker.
    index: usize,
}
/// Result of `compute_layout`: fully positioned nodes and links.
struct SankeyLayout {
/// Nodes in the same order as `diag.nodes`.
nodes: Vec<LayoutNode>,
/// Links in the same order as `diag.links`; `LayoutNode::source_links` and
/// `LayoutNode::target_links` index into this vector.
links: Vec<LayoutLink>,
/// Number of columns (maximum depth + 1; 0 for an empty diagram).
// Not read by `render`, hence the `dead_code` allowance.
#[allow(dead_code)]
num_columns: usize,
}
/// Computes node and link geometry for a Sankey diagram.
///
/// Follows the d3-sankey layout algorithm:
/// 1. node values (max of in/out flow);
/// 2. depth/height via breadth-first sweeps;
/// 3. column assignment according to `alignment`;
/// 4. initial top-down stacking per column;
/// 5. `ITERATIONS` rounds of relaxation that pull nodes toward the weighted
///    position of their neighbours;
/// 6. per-link vertical offsets.
///
/// * `width` / `height` - drawing area in pixels.
/// * `node_width` - horizontal thickness of every node rectangle.
/// * `node_padding` - requested vertical gap between nodes of one column.
///   It is shrunk when the tallest column would not otherwise fit.
///
/// Link endpoints whose id is not in `diag.nodes` are attached to node 0.
/// Cycles cannot hang the layout, because the depth/height sweeps stop after
/// `n` rounds.
fn compute_layout(
    diag: &SankeyDiagram,
    width: f64,
    height: f64,
    node_width: f64,
    node_padding: f64,
    alignment: &NodeAlignment,
) -> SankeyLayout {
    // Number of relaxation rounds (d3-sankey's default `iterations(6)`).
    const ITERATIONS: usize = 6;

    let n = diag.nodes.len();
    if n == 0 {
        return SankeyLayout {
            nodes: vec![],
            links: vec![],
            num_columns: 0,
        };
    }

    // id -> position in `diag.nodes`. Duplicate ids resolve to the last
    // occurrence, because later entries overwrite earlier ones.
    let node_index: std::collections::HashMap<&str, usize> = diag
        .nodes
        .iter()
        .enumerate()
        .map(|(i, node)| (node.id.as_str(), i))
        .collect();

    let mut layout_nodes: Vec<LayoutNode> = diag
        .nodes
        .iter()
        .enumerate()
        .map(|(i, node)| LayoutNode {
            id: node.id.clone(),
            index: i,
            depth: 0,
            height: 0,
            layer: 0,
            value: 0.0,
            x0: 0.0,
            x1: 0.0,
            y0: 0.0,
            y1: 0.0,
            source_links: vec![],
            target_links: vec![],
        })
        .collect();

    // Unknown endpoint ids fall back to node 0.
    let lookup = |id: &str| node_index.get(id).copied().unwrap_or(0);
    let mut layout_links: Vec<LayoutLink> = diag
        .links
        .iter()
        .enumerate()
        .map(|(i, link)| LayoutLink {
            source: lookup(link.source.as_str()),
            target: lookup(link.target.as_str()),
            value: link.value,
            width: 0.0,
            y0: 0.0,
            y1: 0.0,
            index: i,
        })
        .collect();

    for (li, link) in layout_links.iter().enumerate() {
        layout_nodes[link.source].source_links.push(li);
        layout_nodes[link.target].target_links.push(li);
    }

    // A node's value is the larger of its total outflow and total inflow.
    for node in layout_nodes.iter_mut() {
        let out_sum: f64 = node
            .source_links
            .iter()
            .map(|&li| layout_links[li].value)
            .sum();
        let in_sum: f64 = node
            .target_links
            .iter()
            .map(|&li| layout_links[li].value)
            .sum();
        node.value = out_sum.max(in_sum);
    }

    let depths = sweep_levels(&layout_nodes, &layout_links, true);
    let heights = sweep_levels(&layout_nodes, &layout_links, false);
    for ((node, depth), node_height) in layout_nodes.iter_mut().zip(depths).zip(heights) {
        node.depth = depth;
        node.height = node_height;
    }

    let max_depth = layout_nodes.iter().map(|node| node.depth).max().unwrap_or(0);
    let num_columns = max_depth + 1;
    let kx = if num_columns > 1 {
        (width - node_width) / (num_columns as f64 - 1.0)
    } else {
        0.0
    };
    // Layers are computed up front because `Center` alignment reads the
    // depth of neighbouring nodes.
    let layers: Vec<usize> = layout_nodes
        .iter()
        .map(|node| aligned_layer(node, &layout_nodes, &layout_links, alignment, num_columns))
        .collect();
    for (node, layer) in layout_nodes.iter_mut().zip(layers) {
        node.layer = layer;
        node.x0 = kx * layer as f64;
        node.x1 = node.x0 + node_width;
    }

    let mut columns: Vec<Vec<usize>> = vec![Vec::new(); num_columns];
    for (i, node) in layout_nodes.iter().enumerate() {
        columns[node.layer].push(i);
    }

    // Vertical gap between nodes: the requested padding, shrunk if the
    // most populated column could not otherwise fit in `height`.
    let max_col_size = columns.iter().map(Vec::len).max().unwrap_or(1);
    let py = if max_col_size > 1 {
        node_padding.min(height / (max_col_size as f64 - 1.0))
    } else {
        node_padding
    };

    // Pixels per unit of flow: the largest scale at which every column fits.
    let ky = columns
        .iter()
        .filter(|c| !c.is_empty())
        .map(|c| {
            let total: f64 = c.iter().map(|&i| layout_nodes[i].value).sum();
            let avail = height - (c.len() as f64 - 1.0) * py;
            if total > 0.0 {
                avail / total
            } else {
                f64::MAX
            }
        })
        .fold(f64::MAX, f64::min);
    // No column carried any flow: fall back to a unit scale.
    let ky = if ky == f64::MAX { 1.0 } else { ky };

    for link in layout_links.iter_mut() {
        link.width = link.value * ky;
    }

    // Initial placement: stack each column from the top, then distribute the
    // leftover space evenly over the gaps above, between and below nodes.
    for col_nodes in &columns {
        if col_nodes.is_empty() {
            continue;
        }
        let mut y = 0.0_f64;
        for &i in col_nodes {
            let node = &mut layout_nodes[i];
            node.y0 = y;
            node.y1 = y + node.value * ky;
            y = node.y1 + py;
        }
        let slack = (height - y + py) / (col_nodes.len() as f64 + 1.0);
        for (k, &i) in col_nodes.iter().enumerate() {
            let offset = slack * (k as f64 + 1.0);
            layout_nodes[i].y0 += offset;
            layout_nodes[i].y1 += offset;
        }
        reorder_links_for_col(col_nodes, &mut layout_nodes, &layout_links);
    }

    for iteration in 0..ITERATIONS {
        let alpha = 0.99_f64.powi(iteration as i32);
        let beta = (1.0 - alpha).max((iteration as f64 + 1.0) / ITERATIONS as f64);
        relax_right_to_left(&mut layout_nodes, &layout_links, &mut columns, alpha, beta, height, py);
        relax_left_to_right(&mut layout_nodes, &layout_links, &mut columns, alpha, beta, height, py);
    }

    compute_link_breadths(&mut layout_nodes, &mut layout_links);
    SankeyLayout {
        nodes: layout_nodes,
        links: layout_links,
        num_columns,
    }
}

/// Breadth-first level assignment shared by the depth and height passes.
///
/// Every node starts on the frontier at level 0. Each round, the frontier
/// advances along outgoing links (`downstream == true`) or incoming links
/// (`downstream == false`). A node's result is the last round it was on the
/// frontier, which in an acyclic graph is its longest path from a source
/// (or to a sink). The sweep stops after `nodes.len()` rounds so a cycle
/// cannot loop forever.
fn sweep_levels(nodes: &[LayoutNode], links: &[LayoutLink], downstream: bool) -> Vec<usize> {
    let n = nodes.len();
    let mut levels = vec![0usize; n];
    let mut frontier: Vec<usize> = (0..n).collect();
    let mut level = 0usize;
    while !frontier.is_empty() {
        for &ni in &frontier {
            levels[ni] = level;
        }
        level += 1;
        if level > n {
            break;
        }
        let mut queued = vec![false; n];
        let mut next = Vec::new();
        for &ni in &frontier {
            let edges = if downstream {
                &nodes[ni].source_links
            } else {
                &nodes[ni].target_links
            };
            for &li in edges {
                let neighbour = if downstream {
                    links[li].target
                } else {
                    links[li].source
                };
                if !queued[neighbour] {
                    queued[neighbour] = true;
                    next.push(neighbour);
                }
            }
        }
        frontier = next;
    }
    levels
}

/// Column index for `node` under `alignment`, clamped to `num_columns - 1`.
///
/// Mirrors d3-sankey's `sankeyLeft`, `sankeyRight`, `sankeyCenter` and
/// `sankeyJustify`.
/// * `Center` keeps nodes with incoming links at their depth. It places pure
///   source nodes one column before their nearest target (column 0 for
///   isolated nodes).
/// * `Justify` pushes sink nodes to the last column.
fn aligned_layer(
    node: &LayoutNode,
    nodes: &[LayoutNode],
    links: &[LayoutLink],
    alignment: &NodeAlignment,
    num_columns: usize,
) -> usize {
    let last = num_columns - 1;
    let layer = match alignment {
        NodeAlignment::Left => node.depth,
        NodeAlignment::Right => last.saturating_sub(node.height),
        NodeAlignment::Center => {
            if node.target_links.is_empty() {
                node.source_links
                    .iter()
                    .map(|&li| nodes[links[li].target].depth)
                    .min()
                    .map_or(0, |d| d.saturating_sub(1))
            } else {
                node.depth
            }
        }
        NodeAlignment::Justify => {
            if node.source_links.is_empty() {
                last
            } else {
                node.depth
            }
        }
    };
    layer.min(last)
}

/// Stable sort of a column's node indices by their current top edge.
fn sort_column_by_y0(column: &mut [usize], nodes: &[LayoutNode]) {
    column.sort_by(|&a, &b| {
        nodes[a]
            .y0
            .partial_cmp(&nodes[b].y0)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
}

/// Value-weighted mean of the ideal `y0` suggested by each outgoing link of
/// `ni`. Returns `None` unless the total weight is positive; NaN weights are
/// skipped too, matching d3-sankey's `!(w > 0)` check.
fn ideal_y0_from_outgoing(nodes: &[LayoutNode], links: &[LayoutLink], ni: usize, py: f64) -> Option<f64> {
    let (y, w) = nodes[ni]
        .source_links
        .iter()
        .fold((0.0_f64, 0.0_f64), |(y, w), &li| {
            let tgt = links[li].target;
            let v = links[li].value * (nodes[tgt].layer as f64 - nodes[ni].layer as f64);
            (y + source_top(nodes, links, ni, tgt, py) * v, w + v)
        });
    (w > 0.0).then_some(y / w)
}

/// Value-weighted mean of the ideal `y0` suggested by each incoming link of
/// `ni`. Returns `None` unless the total weight is positive.
fn ideal_y0_from_incoming(nodes: &[LayoutNode], links: &[LayoutLink], ni: usize, py: f64) -> Option<f64> {
    let (y, w) = nodes[ni]
        .target_links
        .iter()
        .fold((0.0_f64, 0.0_f64), |(y, w), &li| {
            let src = links[li].source;
            let v = links[li].value * (nodes[ni].layer as f64 - nodes[src].layer as f64);
            (y + target_top(nodes, links, src, ni, py) * v, w + v)
        });
    (w > 0.0).then_some(y / w)
}

/// One right-to-left relaxation sweep.
///
/// Visits every column except the last, from right to left. Each node moves
/// `alpha` of the way toward the position suggested by its outgoing links.
/// The column is then re-sorted and de-overlapped with strength `beta`.
fn relax_right_to_left(
    nodes: &mut [LayoutNode],
    links: &[LayoutLink],
    columns: &mut [Vec<usize>],
    alpha: f64,
    beta: f64,
    height: f64,
    py: f64,
) {
    for ci in (0..columns.len()).rev().skip(1) {
        for &ni in &columns[ci] {
            if let Some(ideal) = ideal_y0_from_outgoing(nodes, links, ni, py) {
                let node = &mut nodes[ni];
                let dy = (ideal - node.y0) * alpha;
                node.y0 += dy;
                node.y1 += dy;
                reorder_node_links(ni, nodes, links);
            }
        }
        sort_column_by_y0(&mut columns[ci], nodes);
        resolve_collisions(nodes, &columns[ci], beta, height, py);
    }
}

/// One left-to-right relaxation sweep.
///
/// Visits every column except the first. Each node moves `alpha` of the way
/// toward the position suggested by its incoming links. The column is then
/// re-sorted and de-overlapped with strength `beta`.
fn relax_left_to_right(
    nodes: &mut [LayoutNode],
    links: &[LayoutLink],
    columns: &mut [Vec<usize>],
    alpha: f64,
    beta: f64,
    height: f64,
    py: f64,
) {
    for column in columns.iter_mut().skip(1) {
        for &ni in column.iter() {
            if let Some(ideal) = ideal_y0_from_incoming(nodes, links, ni, py) {
                let node = &mut nodes[ni];
                let dy = (ideal - node.y0) * alpha;
                node.y0 += dy;
                node.y1 += dy;
                reorder_node_links(ni, nodes, links);
            }
        }
        sort_column_by_y0(column, nodes);
        resolve_collisions(nodes, column, beta, height, py);
    }
}
/// For every node in `col_nodes`, sorts its links by the vertical position of
/// the node at the other end, breaking ties by the link's input order:
/// * outgoing links by their target's `y0`;
/// * incoming links by their source's `y0`.
fn reorder_links_for_col(col_nodes: &[usize], nodes: &mut [LayoutNode], links: &[LayoutLink]) {
    use std::cmp::Ordering;

    for &node_idx in col_nodes {
        // Sorting reads only `y0` of other nodes, so the list can be moved
        // out of the node, sorted, and put back.
        let mut outgoing = std::mem::take(&mut nodes[node_idx].source_links);
        outgoing.sort_by(|&lhs, &rhs| {
            nodes[links[lhs].target]
                .y0
                .partial_cmp(&nodes[links[rhs].target].y0)
                .unwrap_or(Ordering::Equal)
                .then(links[lhs].index.cmp(&links[rhs].index))
        });
        nodes[node_idx].source_links = outgoing;

        let mut incoming = std::mem::take(&mut nodes[node_idx].target_links);
        incoming.sort_by(|&lhs, &rhs| {
            nodes[links[lhs].source]
                .y0
                .partial_cmp(&nodes[links[rhs].source].y0)
                .unwrap_or(Ordering::Equal)
                .then(links[lhs].index.cmp(&links[rhs].index))
        });
        nodes[node_idx].target_links = incoming;
    }
}
/// After node `ni` has moved, re-sorts the link lists of its neighbours.
///
/// * Each downstream node gets its incoming links ordered by source `y0`.
/// * Each upstream node gets its outgoing links ordered by target `y0`.
///
/// Ties are broken by the links' input order.
fn reorder_node_links(ni: usize, nodes: &mut [LayoutNode], links: &[LayoutLink]) {
    use std::cmp::Ordering;

    let downstream: Vec<usize> = nodes[ni]
        .source_links
        .iter()
        .map(|&li| links[li].target)
        .collect();
    for neighbour in downstream {
        let mut incoming = std::mem::take(&mut nodes[neighbour].target_links);
        incoming.sort_by(|&a, &b| {
            nodes[links[a].source]
                .y0
                .partial_cmp(&nodes[links[b].source].y0)
                .unwrap_or(Ordering::Equal)
                .then(links[a].index.cmp(&links[b].index))
        });
        nodes[neighbour].target_links = incoming;
    }

    // Read after the first pass, as the original order of operations does.
    let upstream: Vec<usize> = nodes[ni]
        .target_links
        .iter()
        .map(|&li| links[li].source)
        .collect();
    for neighbour in upstream {
        let mut outgoing = std::mem::take(&mut nodes[neighbour].source_links);
        outgoing.sort_by(|&a, &b| {
            nodes[links[a].target]
                .y0
                .partial_cmp(&nodes[links[b].target].y0)
                .unwrap_or(Ordering::Equal)
                .then(links[a].index.cmp(&links[b].index))
        });
        nodes[neighbour].source_links = outgoing;
    }
}
/// The `y0` for `tgt_ni` that would make the link from `src_ni` run level.
///
/// Starts from the source's top, centred for its spread of outgoing links.
/// Adds the widths (plus padding) of the source's outgoing links that come
/// before this one, then subtracts the widths of the target's incoming links
/// that come before it.
fn target_top(
    nodes: &[LayoutNode],
    links: &[LayoutLink],
    src_ni: usize,
    tgt_ni: usize,
    py: f64,
) -> f64 {
    let source = &nodes[src_ni];
    let target = &nodes[tgt_ni];
    let spread = (source.source_links.len() as f64 - 1.0) * py / 2.0;
    let below_earlier_outgoing = source
        .source_links
        .iter()
        .map(|&li| &links[li])
        .take_while(|link| link.target != tgt_ni)
        .fold(source.y0 - spread, |y, link| y + (link.width + py));
    target
        .target_links
        .iter()
        .map(|&li| &links[li])
        .take_while(|link| link.source != src_ni)
        .fold(below_earlier_outgoing, |y, link| y - link.width)
}
/// The `y0` for `src_ni` that would make its link to `tgt_ni` run level.
///
/// Starts from the target's top, centred for its spread of incoming links.
/// Adds the widths (plus padding) of the target's incoming links that come
/// before this one, then subtracts the widths of the source's outgoing links
/// that come before it.
fn source_top(
    nodes: &[LayoutNode],
    links: &[LayoutLink],
    src_ni: usize,
    tgt_ni: usize,
    py: f64,
) -> f64 {
    let source = &nodes[src_ni];
    let target = &nodes[tgt_ni];
    let spread = (target.target_links.len() as f64 - 1.0) * py / 2.0;
    let below_earlier_incoming = target
        .target_links
        .iter()
        .map(|&li| &links[li])
        .take_while(|link| link.source != src_ni)
        .fold(target.y0 - spread, |y, link| y + (link.width + py));
    source
        .source_links
        .iter()
        .map(|&li| &links[li])
        .take_while(|link| link.target != tgt_ni)
        .fold(below_earlier_incoming, |y, link| y - link.width)
}
/// Removes vertical overlaps within one column (sorted by `y0`).
///
/// Pushes nodes above the middle node upward and nodes below it downward.
/// Then it sweeps upward from `height` and finally downward from 0.
/// `alpha` damps every move. Empty columns are left alone.
fn resolve_collisions(
    nodes: &mut [LayoutNode],
    col_nodes: &[usize],
    alpha: f64,
    height: f64,
    py: f64,
) {
    let mid = col_nodes.len() / 2;
    let subject = match col_nodes.get(mid) {
        Some(&ni) => ni,
        None => return,
    };
    resolve_bottom_to_top(nodes, col_nodes, subject, mid, alpha, py);
    resolve_top_to_bottom(nodes, col_nodes, subject, mid, alpha, py, height);
    resolve_bottom_to_top_from_end(nodes, col_nodes, alpha, py, height);
    resolve_top_to_bottom_from_start(nodes, col_nodes, alpha, py);
}
/// Walks upward from just above `col_nodes[start]`, lifting each earlier node
/// whose bottom edge overlaps the running limit.
///
/// Each lift is damped by `alpha`. A node pushed above 0 is clamped to the
/// top while keeping its height. `_subject_idx` is unused.
fn resolve_bottom_to_top(
    nodes: &mut [LayoutNode],
    col_nodes: &[usize],
    _subject_idx: usize,
    start: usize,
    alpha: f64,
    py: f64,
) {
    if start == 0 {
        return;
    }
    let mut limit = nodes[col_nodes[start]].y0 - py;
    for &ni in col_nodes[..start].iter().rev() {
        let node = &mut nodes[ni];
        let lift = (node.y1 - limit) * alpha;
        if lift > 1e-6 {
            let span = node.y1 - node.y0;
            node.y0 -= lift;
            node.y1 -= lift;
            if node.y0 < 0.0 {
                node.y0 = 0.0;
                node.y1 = span;
            }
        }
        limit = node.y0 - py;
    }
}
/// Walks downward from just below `col_nodes[start]`, pushing each later node
/// whose top edge overlaps the running limit. Each push is damped by `alpha`.
///
/// `_subject_idx` and `_height` are unused. They are kept so the signature
/// matches the caller in `resolve_collisions`.
fn resolve_top_to_bottom(
    nodes: &mut [LayoutNode],
    col_nodes: &[usize],
    _subject_idx: usize,
    start: usize,
    alpha: f64,
    py: f64,
    _height: f64,
) {
    if start + 1 >= col_nodes.len() {
        return;
    }
    let mut y = nodes[col_nodes[start]].y1 + py;
    for &ni in &col_nodes[start + 1..] {
        let dy = (y - nodes[ni].y0) * alpha;
        if dy > 1e-6 {
            nodes[ni].y0 += dy;
            nodes[ni].y1 += dy;
        }
        y = nodes[ni].y1 + py;
    }
}
/// Sweeps the whole column bottom-up, starting from `height`.
///
/// Lifts any node whose bottom edge lies below the running limit, damped by
/// `alpha`. A node pushed above 0 is clamped to the top while keeping its
/// height.
fn resolve_bottom_to_top_from_end(
    nodes: &mut [LayoutNode],
    col_nodes: &[usize],
    alpha: f64,
    py: f64,
    height: f64,
) {
    let mut limit = height;
    for &ni in col_nodes.iter().rev() {
        let node = &mut nodes[ni];
        let lift = (node.y1 - limit) * alpha;
        if lift > 1e-6 {
            let span = node.y1 - node.y0;
            node.y0 -= lift;
            node.y1 -= lift;
            if node.y0 < 0.0 {
                node.y0 = 0.0;
                node.y1 = span;
            }
        }
        limit = node.y0 - py;
    }
}
/// Sweeps the whole column top-down, starting from 0.
///
/// Pushes any node whose top edge lies above the running limit downward,
/// damped by `alpha`. This separates nodes that earlier passes clamped
/// together at the top.
fn resolve_top_to_bottom_from_start(
    nodes: &mut [LayoutNode],
    col_nodes: &[usize],
    alpha: f64,
    py: f64,
) {
    let mut y = 0.0_f64;
    for &ni in col_nodes {
        let dy = (y - nodes[ni].y0) * alpha;
        if dy > 1e-6 {
            nodes[ni].y0 += dy;
            nodes[ni].y1 += dy;
        }
        y = nodes[ni].y1 + py;
    }
}
/// Assigns each link its vertical centre at both ends.
///
/// Links are stacked in list order from each node's top edge. Outgoing links
/// set `link.y0`; incoming links set `link.y1`.
///
/// `nodes` is only read. It stays `&mut` so the existing call site keeps
/// compiling unchanged.
fn compute_link_breadths(nodes: &mut [LayoutNode], links: &mut [LayoutLink]) {
    for node in nodes.iter() {
        let mut y0 = node.y0;
        for &li in &node.source_links {
            let link = &mut links[li];
            link.y0 = y0 + link.width / 2.0;
            y0 += link.width;
        }
        let mut y1 = node.y0;
        for &li in &node.target_links {
            let link = &mut links[li];
            link.y1 = y1 + link.width / 2.0;
            y1 += link.width;
        }
    }
}
/// SVG path data for a link: a cubic Bézier from `(src_x1, y0)` to
/// `(tgt_x0, y1)`.
///
/// Both control points sit at the horizontal midpoint, at the start and end
/// heights respectively. Coordinates use two decimals.
fn sankey_link_path(src_x1: f64, tgt_x0: f64, y0: f64, y1: f64) -> String {
    let mid = 0.5 * (src_x1 + tgt_x0);
    format!("M{src_x1:.2},{y0:.2} C{mid:.2},{y0:.2} {mid:.2},{y1:.2} {tgt_x0:.2},{y1:.2}")
}
/// Layer of the node with the largest strictly positive `value`.
///
/// The first such node wins ties. Returns 0 when no node has a positive
/// value.
fn find_central_node_layer(nodes: &[LayoutNode]) -> usize {
    let (_, layer) = nodes
        .iter()
        .fold((0.0_f64, 0usize), |(best_value, best_layer), node| {
            if node.value > best_value {
                (node.value, node.layer)
            } else {
                (best_value, best_layer)
            }
        });
    layer
}
/// Label x-coordinate and `text-anchor` for `node`.
///
/// * Left edge in the left half of the canvas: start-anchored label
///   `LABEL_OFFSET` to the right of the node.
/// * Otherwise: end-anchored label `LABEL_OFFSET` to the left of it.
fn label_position(node: &LayoutNode, width: f64) -> (f64, &'static str) {
    let midline = width / 2.0;
    if node.x0 < midline {
        return (node.x1 + LABEL_OFFSET, "start");
    }
    (node.x0 - LABEL_OFFSET, "end")
}
/// Alternative label placement relative to `central_layer`, presumably the
/// result of `find_central_node_layer`.
///
/// * Nodes left of that layer: end-anchored label 6px before the node.
/// * All others: start-anchored label 6px after it.
// Not called anywhere in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
fn label_position_outlined(node: &LayoutNode, central_layer: usize) -> (f64, &'static str) {
    const GAP: f64 = 6.0;
    if node.layer >= central_layer {
        (node.x1 + GAP, "start")
    } else {
        (node.x0 - GAP, "end")
    }
}
/// Renders a parsed Sankey diagram to a standalone SVG document string.
///
/// Emitted in order, each element joined with `\n`:
/// 1. the root `<svg>` and its `<style>`;
/// 2. one `<g>` per node rectangle;
/// 3. the node labels;
/// 4. gradient `<defs>` (only for `LinkColor::Gradient`);
/// 5. the link paths.
///
/// An empty diagram short-circuits to `templates::svg_empty`.
pub fn render(diag: &SankeyDiagram, theme: Theme) -> String {
    let vars = theme.resolve();
    let ff = vars.font_family;
    let svg_id = SVG_ID;
    let conf = &diag.config;
    let (width, height) = (conf.width, conf.height);
    let show_values = conf.show_values;
    // Extra 15px between nodes when values are shown; each label then spans
    // two lines (see the `\n` in the label format below).
    let value_line_allowance = if show_values { 15.0 } else { 0.0 };
    let node_padding = conf.node_padding + value_line_allowance;
    let prefix = &conf.prefix;
    let suffix = &conf.suffix;

    if diag.nodes.is_empty() {
        return templates::svg_empty(svg_id, width, height);
    }

    let layout = compute_layout(
        diag,
        width,
        height,
        conf.node_width,
        node_padding,
        &conf.node_alignment,
    );
    let nodes = &layout.nodes;
    let links = &layout.links;
    // Not used below; it would feed `label_position_outlined`, which is not
    // wired up yet.
    let _central_layer = find_central_node_layer(nodes);

    // Grow the canvas so labels centred on low nodes are not clipped.
    // NOTE(review): presumably a 14px font, with the 0.35em baseline shift
    // used for `dy` below plus ~0.217em of descent. Confirm.
    let label_bottom_offset = 14.0 * 0.35 + 14.0 * 0.217;
    let actual_height = nodes
        .iter()
        .map(|node| (node.y0 + node.y1) / 2.0 + label_bottom_offset)
        .fold(height, f64::max);

    let css = build_css(svg_id, ff);
    let mut parts: Vec<String> = vec![
        templates::svg_root(svg_id, width, actual_height),
        format!("<style>{}</style>", css),
        r#"<g class="nodes">"#.to_string(),
    ];

    for (i, node) in nodes.iter().enumerate() {
        parts.extend([
            templates::node_group(i, node.x0, node.y0),
            templates::node_rect(node.y1 - node.y0, node.x1 - node.x0, tableau_color_by_index(i)),
            "</g>".to_string(),
        ]);
    }
    parts.push("</g>".to_string());

    parts.push(format!(
        r#"<g class="node-labels" font-size="{}">"#,
        LABEL_FONT_SIZE_ATTR
    ));
    for node in nodes {
        let label = if show_values {
            let rounded = (node.value * 100.0).round() / 100.0;
            format!("{}\n{}{}{}", node.id, prefix, rounded, suffix)
        } else {
            node.id.clone()
        };
        let (lx, anchor) = label_position(node, width);
        let ly = (node.y1 + node.y0) / 2.0;
        let text_content = esc(&label);
        parts.push(templates::node_label_text(
            lx,
            ly,
            "0.35em",
            anchor,
            ff,
            &text_content,
        ));
    }
    parts.push("</g>".to_string());

    let link_color_mode = &conf.link_color;
    if *link_color_mode == LinkColor::Gradient {
        parts.push("<defs>".to_string());
        for (li, link) in links.iter().enumerate() {
            let (from, to) = (&nodes[link.source], &nodes[link.target]);
            parts.push(templates::linear_gradient(
                li,
                from.x1,
                to.x0,
                tableau_color_by_index(link.source),
                tableau_color_by_index(link.target),
            ));
        }
        parts.push("</defs>".to_string());
    }

    parts.push(r#"<g class="links" fill="none" stroke-opacity="0.5">"#.to_string());
    for (li, link) in links.iter().enumerate() {
        let (from, to) = (&nodes[link.source], &nodes[link.target]);
        let stroke = match link_color_mode {
            LinkColor::Gradient => format!("url(#lg-{})", li),
            LinkColor::Source => tableau_color_by_index(link.source).to_string(),
            LinkColor::Target => tableau_color_by_index(link.target).to_string(),
            LinkColor::Custom(c) => c.clone(),
        };
        let path_d = sankey_link_path(from.x1, to.x0, link.y0, link.y1);
        parts.push(templates::link_path(&path_d, &stroke, link.width.max(1.0)));
    }
    parts.extend(["</g>".to_string(), "</svg>".to_string()]);

    parts.join("\n")
}
/// Unit and snapshot tests for the Sankey renderer.
#[cfg(test)]
mod tests {
use super::*;
use crate::diagrams::sankey::parser;
/// Smoke test: a small flow renders an `<svg>` with node and link groups.
#[test]
fn basic_render_produces_svg() {
let input = "sankey-beta\nA,B,10\nA,C,20\nB,D,5\n";
let diag = parser::parse(input).diagram;
let svg = render(&diag, Theme::Default);
assert!(svg.contains("<svg"), "missing <svg");
assert!(svg.contains("class=\"nodes\""));
assert!(svg.contains("class=\"links\""));
}
/// The label text "A" appears in the output. Several encodings are accepted
/// because the exact label markup comes from `templates::node_label_text`.
#[test]
fn node_labels_present() {
let input = "sankey-beta\nA,B,10\nA,C,20\n";
let diag = parser::parse(input).diagram;
let svg = render(&diag, Theme::Default);
assert!(
svg.contains(">A<")
|| svg.contains(">A</")
|| svg.contains("A</tspan>")
|| svg.contains(">A\n")
);
}
/// A diagram with no links still yields an `<svg>` (the `svg_empty` path).
#[test]
fn empty_sankey_produces_svg() {
let input = "sankey-beta\n";
let diag = parser::parse(input).diagram;
let svg = render(&diag, Theme::Default);
assert!(svg.contains("<svg"));
}
/// Path data starts at the source point, uses a cubic curve, and ends at the
/// target point.
#[test]
fn link_path_cubic_bezier() {
let path = sankey_link_path(100.0, 200.0, 50.0, 80.0);
assert!(path.starts_with("M100.00,50.00"));
assert!(path.contains('C'));
assert!(path.ends_with("200.00,80.00"));
}
/// Frontmatter `sankey` config is honoured.
// NOTE(review): this only checks that "800" and "500" appear somewhere in the
// SVG; a stricter test would match the root width/height attributes.
#[test]
fn frontmatter_config_used() {
let input = "---\nconfig:\n sankey:\n showValues: false\n width: 800\n height: 500\n---\nsankey-beta\nA,B,10\n";
let diag = parser::parse(input).diagram;
let svg = render(&diag, Theme::Default);
assert!(svg.contains("800"));
assert!(svg.contains("500"));
}
/// Gradient definitions are emitted for the default configuration.
// NOTE(review): relies on the parser's default `link_color` being
// `LinkColor::Gradient` — confirm against the parser.
#[test]
fn gradient_defs_present() {
let input = "sankey-beta\nA,B,10\n";
let diag = parser::parse(input).diagram;
let svg = render(&diag, Theme::Default);
assert!(svg.contains("linearGradient"));
}
/// Renders a diamond-shaped flow.
// NOTE(review): despite its name this only checks that the nodes group is
// emitted; it does not inspect column positions.
#[test]
fn column_assignment_justify() {
let input = "sankey-beta\nA,B,10\nA,C,20\nB,D,5\nC,D,15\n";
let diag = parser::parse(input).diagram;
let svg = render(&diag, Theme::Default);
assert!(svg.contains("class=\"nodes\""));
}
/// Full-output snapshot. Floats are normalised first, presumably to avoid
/// churn from tiny formatting differences.
#[test]
fn snapshot_default_theme() {
let input = "sankey-beta\nCoal,Power,50\nGas,Power,30\nNuclear,Power,20\nPower,Homes,40\nPower,Industry,60";
let diag = parser::parse(input).diagram;
let svg = render(&diag, crate::theme::Theme::Default);
insta::assert_snapshot!(crate::svg::normalize_floats(&svg));
}
}