use crate::core::data::SankeyData;
/// Positioned rectangle for a single Sankey node, as produced by `layout_sankey`.
#[derive(Debug, Clone, PartialEq)]
pub struct SankeyNodeRect {
/// X coordinate at which the rectangle starts; the node spans `x..x + width`.
pub x: f64,
/// Y coordinate at which the rectangle starts; the node spans `y..y + height`.
pub y: f64,
/// Extent along the x axis (the `node_width` passed to `layout_sankey`).
pub width: f64,
/// Extent along the y axis, derived from the node's share of its column's flow.
pub height: f64,
/// Index of the node in the input data's node list.
pub index: usize,
/// Label copied from the input node.
pub label: String,
/// Colour string: the input node's own colour, or a `DEFAULT_PALETTE` entry
/// when the node does not specify one.
pub color: String,
}
/// Drawable band connecting two nodes, one per laid-out link.
#[derive(Debug, Clone, PartialEq)]
pub struct SankeyRibbon {
/// Closed outline of the band as a path-data string built from `M`, `C`, `L`
/// and `Z` commands, with coordinates formatted to two decimal places.
pub path: String,
/// Index of the node the ribbon leaves.
pub source: usize,
/// Index of the node the ribbon enters.
pub target: usize,
/// Value of the underlying link, copied unchanged.
pub value: f64,
/// Colour string: the link's own colour, or the source node's colour when the
/// link does not specify one.
pub color: String,
}
/// Complete output of `layout_sankey`: node rectangles plus link ribbons.
#[derive(Debug, Clone, PartialEq)]
pub struct SankeyLayoutResult {
/// One rectangle per input node, in the same order as the input nodes.
pub nodes: Vec<SankeyNodeRect>,
/// One ribbon per link whose endpoints are both valid node indices, ordered
/// by `(source, target)`.
pub links: Vec<SankeyRibbon>,
}
/// Fallback node colours (hex strings). A node without an explicit colour gets
/// the entry at `node_index % DEFAULT_PALETTE.len()`, so the palette repeats
/// once there are more nodes than entries.
const DEFAULT_PALETTE: &[&str] = &[
"#4e79a7", "#f28e2b", "#e15759", "#76b7b2", "#59a14f", "#edc948", "#b07aa1", "#ff9da7",
"#9c755f", "#bab0ac",
];
/// Lays out a Sankey diagram inside a `width` × `height` area.
///
/// The layout is computed in these steps:
///
/// 1. **Link validation.** Links whose `source` or `target` is not a valid
///    node index are ignored everywhere: they neither produce a ribbon nor
///    contribute to any node's flow or column assignment.
/// 2. **Columns.** Nodes without incoming links go into column 0. Every other
///    node takes the breadth-first distance from the nearest such root. Nodes
///    that are never reached (for example a cycle with no root) fall back to
///    column 0.
/// 3. **Horizontal placement.** Columns are spaced evenly so that the first
///    column starts at `x = 0` and the last column ends at `x = width`. With a
///    single column every node sits at `x = 0`.
/// 4. **Vertical placement.** Within a column, nodes are stacked along
///    increasing `y` in index order, separated by `node_gap`, with one
///    `node_gap` margin before the first and after the last node. Heights are
///    proportional to `max(inflow, outflow)` (treated as at least `1.0`).
/// 5. **Ribbons.** Links are processed in `(source, target)` order. Each ribbon
///    takes a slice of its source node proportional to its share of that
///    node's outflow, and a slice of its target node proportional to its share
///    of that node's inflow.
///
/// Minimum sizes are enforced (at least `10.0` of usable height per node in a
/// column, `4.0` per node, `1.0` per ribbon end), so with a very small `height`
/// or many tiny links the content can extend past the requested area.
///
/// # Arguments
///
/// * `data` - nodes and links to lay out.
/// * `width`, `height` - size of the drawing area.
/// * `node_width` - extent of every node rectangle along the x axis.
/// * `node_gap` - spacing between stacked nodes, also used as the margin before
///   the first and after the last node of a column.
///
/// # Returns
///
/// One [`SankeyNodeRect`] per input node (same order as `data.nodes`) and one
/// [`SankeyRibbon`] per valid link. An input without nodes yields an empty
/// result.
pub fn layout_sankey(
    data: &SankeyData,
    width: f64,
    height: f64,
    node_width: f64,
    node_gap: f64,
) -> SankeyLayoutResult {
    let n = data.nodes.len();
    if n == 0 {
        return SankeyLayoutResult {
            nodes: Vec::new(),
            links: Vec::new(),
        };
    }

    // Only links with both endpoints in range take part in the layout. Filtering
    // once keeps flow totals, column assignment and ribbons consistent with each
    // other; counting a link that is never drawn would leave unexplained space.
    let valid_links: Vec<usize> = data
        .links
        .iter()
        .enumerate()
        .filter(|(_, link)| link.source < n && link.target < n)
        .map(|(i, _)| i)
        .collect();

    // Single pass over the valid links: adjacency (in link order) and flow totals.
    let mut has_incoming = vec![false; n];
    let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
    let mut flow_in = vec![0.0_f64; n];
    let mut flow_out = vec![0.0_f64; n];
    for &li in &valid_links {
        let link = &data.links[li];
        has_incoming[link.target] = true;
        successors[link.source].push(link.target);
        flow_out[link.source] += link.value;
        flow_in[link.target] += link.value;
    }

    // Breadth-first column assignment starting from every root at once.
    let mut depth: Vec<Option<usize>> = vec![None; n];
    let mut queue = std::collections::VecDeque::with_capacity(n);
    for (i, &incoming) in has_incoming.iter().enumerate() {
        if !incoming {
            depth[i] = Some(0);
            queue.push_back(i);
        }
    }
    while let Some(node) = queue.pop_front() {
        // A queued node always has a depth; `map_or` merely avoids an unwrap.
        let next_depth = depth[node].map_or(0, |d| d + 1);
        for &target in &successors[node] {
            if depth[target].is_none() {
                depth[target] = Some(next_depth);
                queue.push_back(target);
            }
        }
    }
    // Unreached nodes (rootless cycles) are shown in the first column.
    let depth: Vec<usize> = depth.into_iter().map(|d| d.unwrap_or(0)).collect();

    let num_cols = depth.iter().copied().max().unwrap_or(0) + 1;
    let col_step = if num_cols > 1 {
        (width - node_width) / (num_cols - 1) as f64
    } else {
        0.0
    };
    let mut col_nodes: Vec<Vec<usize>> = vec![Vec::new(); num_cols];
    for (i, &d) in depth.iter().enumerate() {
        col_nodes[d].push(i);
    }

    // A node is as tall as the larger of its two sides; the floor of 1.0 keeps
    // isolated nodes visible and the per-column division well defined.
    let node_flow: Vec<f64> = flow_in
        .iter()
        .zip(&flow_out)
        .map(|(&inflow, &outflow)| inflow.max(outflow).max(1.0))
        .collect();

    // Stack the nodes of each column. The gap budget depends only on the number
    // of nodes in *this* column (margin + gaps between nodes + margin), so every
    // column spans the full height regardless of how many nodes exist elsewhere.
    let mut node_y = vec![0.0_f64; n];
    let mut node_h = vec![0.0_f64; n];
    for members in &col_nodes {
        if members.is_empty() {
            continue;
        }
        let count = members.len() as f64;
        let total_flow: f64 = members.iter().map(|&i| node_flow[i]).sum();
        let available = (height - node_gap * (count + 1.0)).max(10.0 * count);
        let mut y_cursor = node_gap;
        for &i in members {
            let h = (node_flow[i] / total_flow * available).max(4.0);
            node_y[i] = y_cursor;
            node_h[i] = h;
            y_cursor += h + node_gap;
        }
    }

    let node_rects: Vec<SankeyNodeRect> = data
        .nodes
        .iter()
        .enumerate()
        .map(|(i, node)| SankeyNodeRect {
            x: depth[i] as f64 * col_step,
            y: node_y[i],
            width: node_width,
            height: node_h[i],
            index: i,
            label: node.label.clone(),
            color: node
                .color
                .clone()
                .unwrap_or_else(|| DEFAULT_PALETTE[i % DEFAULT_PALETTE.len()].to_string()),
        })
        .collect();

    // Ribbons are attached in (source, target) order; the stable sort keeps
    // parallel links in their input order.
    let mut link_order = valid_links;
    link_order.sort_by_key(|&i| (data.links[i].source, data.links[i].target));

    // How much of each node's side has already been consumed by earlier ribbons.
    let mut src_y_used = vec![0.0_f64; n];
    let mut dst_y_used = vec![0.0_f64; n];
    let mut ribbons: Vec<SankeyRibbon> = Vec::with_capacity(link_order.len());
    for &li in &link_order {
        let link = &data.links[li];
        let src = &node_rects[link.source];
        let dst = &node_rects[link.target];

        let src_flow = flow_out[link.source].max(1.0);
        let dst_flow = flow_in[link.target].max(1.0);
        let ribbon_h_src = (link.value / src_flow * src.height).max(1.0);
        let ribbon_h_dst = (link.value / dst_flow * dst.height).max(1.0);

        let y0_top = src.y + src_y_used[link.source];
        let y0_bot = y0_top + ribbon_h_src;
        let y1_top = dst.y + dst_y_used[link.target];
        let y1_bot = y1_top + ribbon_h_dst;
        src_y_used[link.source] += ribbon_h_src;
        dst_y_used[link.target] += ribbon_h_dst;

        // The ribbon leaves the far edge of the source and enters the near edge
        // of the target; both curves share control points at the midpoint in x.
        let x0 = src.x + src.width;
        let x1 = dst.x;
        let cx = (x0 + x1) / 2.0;
        let path = format!(
            "M {x0:.2} {y0_top:.2} C {cx:.2} {y0_top:.2}, {cx:.2} {y1_top:.2}, {x1:.2} {y1_top:.2} \
             L {x1:.2} {y1_bot:.2} C {cx:.2} {y1_bot:.2}, {cx:.2} {y0_bot:.2}, {x0:.2} {y0_bot:.2} Z"
        );

        let color = link.color.clone().unwrap_or_else(|| src.color.clone());
        ribbons.push(SankeyRibbon {
            path,
            source: link.source,
            target: link.target,
            value: link.value,
            color,
        });
    }

    SankeyLayoutResult {
        nodes: node_rects,
        links: ribbons,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::data::{SankeyLink, SankeyNode};

    /// Canvas width used by every test.
    const WIDTH: f64 = 400.0;
    /// Canvas height used by every test.
    const HEIGHT: f64 = 300.0;
    /// Node rectangle width used by every test.
    const NODE_WIDTH: f64 = 20.0;
    /// Gap between stacked nodes used by every test.
    const NODE_GAP: f64 = 8.0;

    /// Builds a node with the given label and no explicit colour.
    fn node(label: &str) -> SankeyNode {
        SankeyNode {
            label: label.into(),
            color: None,
        }
    }

    /// Builds a link with no explicit colour.
    fn link(source: usize, target: usize, value: f64) -> SankeyLink {
        SankeyLink {
            source,
            target,
            value,
            color: None,
        }
    }

    /// Fixture: node A feeding B with 10 units and C with 20 units.
    fn make_simple() -> SankeyData {
        SankeyData {
            nodes: vec![node("A"), node("B"), node("C")],
            links: vec![link(0, 1, 10.0), link(0, 2, 20.0)],
        }
    }

    /// Runs the layout with the shared canvas parameters.
    fn run(data: &SankeyData) -> SankeyLayoutResult {
        layout_sankey(data, WIDTH, HEIGHT, NODE_WIDTH, NODE_GAP)
    }

    /// Every node and every link of the fixture appears in the output.
    #[test]
    fn test_layout_produces_correct_node_count() {
        let result = run(&make_simple());
        assert_eq!(result.nodes.len(), 3);
        assert_eq!(result.links.len(), 2);
    }

    /// The values carried by ribbons into B add up to the link values into B.
    #[test]
    fn test_conservation() {
        let data = make_simple();
        let result = run(&data);

        let mut link_to_b = 0.0_f64;
        for l in &data.links {
            if l.target == 1 {
                link_to_b += l.value;
            }
        }

        let mut ribbon_to_b = 0.0_f64;
        for r in &result.links {
            if r.target == 1 {
                ribbon_to_b += r.value;
            }
        }

        assert_eq!(
            link_to_b, ribbon_to_b,
            "ribbon values should match link values"
        );
    }

    /// Default (empty) data produces an empty layout.
    #[test]
    fn test_empty_sankey() {
        let result = run(&SankeyData::default());
        assert!(result.nodes.is_empty());
        assert!(result.links.is_empty());
    }
}