use super::Canvas;
use super::canvas;
use super::color::*;
use super::common::*;
use super::component::*;
use super::params::*;
use super::theme::{DEFAULT_Y_AXIS_WIDTH, Theme, get_default_theme_name, get_theme};
use super::util::*;
use crate::charts::measure_text_width_family;
use charts_rs_derive::Chart;
use std::sync::Arc;
/// A node of a Sankey diagram.
///
/// Declaring nodes explicitly is optional: any name referenced by a
/// [`SankeyLink`] becomes a node automatically when the chart is laid out.
#[derive(Clone, Debug, Default)]
pub struct SankeyNode {
    /// Node name; links refer to nodes by this name and it is used as the label.
    pub name: String,
    /// Explicit fill color; when `None` a color from the series palette is used.
    pub color: Option<Color>,
}
impl From<&str> for SankeyNode {
    /// Builds an uncolored node with the given name.
    fn from(name: &str) -> Self {
        Self {
            name: name.to_owned(),
            ..Self::default()
        }
    }
}
/// A weighted flow between two nodes, identified by node name.
#[derive(Clone, Debug, Default)]
pub struct SankeyLink {
    /// Name of the node the flow starts from.
    pub source: String,
    /// Name of the node the flow ends at.
    pub target: String,
    /// Flow magnitude; links with a non-positive value are not drawn.
    pub value: f32,
}
impl From<(&str, &str, f32)> for SankeyLink {
    /// Builds a link from a `(source, target, value)` tuple.
    fn from((source, target, value): (&str, &str, f32)) -> Self {
        Self {
            source: source.to_owned(),
            target: target.to_owned(),
            value,
        }
    }
}
/// Working state of one node while the layout is being computed.
struct LayoutNode {
    /// Node name, used for the label.
    name: String,
    /// Fill color of the node bar (explicit color or palette fallback).
    color: Color,
    /// Indices (into the layout link list) of links ending at this node.
    in_links: Vec<usize>,
    /// Indices (into the layout link list) of links starting at this node.
    out_links: Vec<usize>,
    /// Larger of the node's total inflow and total outflow.
    value: f32,
    /// Column index; column 0 is the leftmost.
    layer: usize,
    /// Left edge of the node bar.
    x: f32,
    /// Top edge of the node bar.
    y: f32,
    /// Height of the node bar (`value` scaled to the content height).
    dy: f32,
}
/// Working state of one link while the layout is being computed.
struct LayoutLink {
    /// Index of the node the flow leaves.
    source: usize,
    /// Offset of the link's top edge below the source node's top edge.
    sy: f32,
    /// Index of the node the flow enters.
    target: usize,
    /// Offset of the link's top edge below the target node's top edge.
    ty: f32,
    /// Flow magnitude as supplied by the caller.
    value: f32,
    /// Drawn thickness: `value` multiplied by the vertical scale.
    width: f32,
}
/// Appends points sampled along one curved edge of a Sankey link to `out`.
///
/// The curve is a cubic Bézier from `(x0, y0)` to `(x1, y1)` whose two
/// control points sit at the horizontal midpoint, at the start and end
/// heights respectively. The tangents at both ends are therefore horizontal,
/// giving the S-shaped edge used for links.
///
/// `max(segments, 1) + 1` evenly spaced (in curve parameter `t`) points are
/// appended, so both endpoints are always included. A `segments` of 0 is
/// treated as 1 to avoid computing `0 / 0` (NaN) for the parameter.
/// Existing contents of `out` are kept.
fn sample_link_edge(x0: f32, y0: f32, x1: f32, y1: f32, segments: usize, out: &mut Vec<Point>) {
    // Guard against a NaN curve parameter when no subdivision is requested.
    let segments = segments.max(1);
    let xm = (x0 + x1) / 2.0;
    out.reserve(segments + 1);
    for i in 0..=segments {
        let t = i as f32 / segments as f32;
        let mt = 1.0 - t;
        // Bernstein basis polynomials of degree 3.
        let b0 = mt * mt * mt;
        let b1 = 3.0 * mt * mt * t;
        let b2 = 3.0 * mt * t * t;
        let b3 = t * t * t;
        // Control points: (x0, y0), (xm, y0), (xm, y1), (x1, y1).
        let x = b0 * x0 + b1 * xm + b2 * xm + b3 * x1;
        let y = b0 * y0 + b1 * y0 + b2 * y1 + b3 * y1;
        out.push((x, y).into());
    }
}
/// Sankey (flow) diagram: named nodes arranged in columns and connected by
/// links whose thickness is proportional to their value.
///
/// NOTE(review): the common chart fields used by this module (`width`,
/// `height`, `font_family`, title, ...) are presumably injected by
/// `chart_common_fields` — confirm against the macro.
#[charts_rs_derive::chart_common_fields]
#[derive(Clone, Debug, Default, Chart)]
pub struct SankeyChart {
    // The undocumented axis, grid, stroke, symbol, smooth and fill fields
    // below are not read by the Sankey layout or rendering in this file;
    // presumably they exist for the `Chart` derive / theme filling — TODO confirm.
    pub x_axis_data: Vec<String>,
    pub x_axis_height: f32,
    pub x_axis_stroke_color: Color,
    pub x_axis_font_size: f32,
    pub x_axis_font_color: Color,
    pub x_axis_font_weight: Option<String>,
    pub x_axis_name_gap: f32,
    pub x_axis_name_rotate: f32,
    pub x_axis_margin: Option<Box>,
    pub x_axis_hidden: bool,
    pub x_boundary_gap: Option<bool>,
    pub y_axis_hidden: bool,
    y_axis_configs: Vec<YAxisConfig>,
    grid_stroke_color: Color,
    grid_stroke_width: f32,
    pub series_stroke_width: f32,
    /// Color of the node labels.
    pub series_label_font_color: Color,
    /// Font size of the node labels; labels are rendered at no less than 10.
    pub series_label_font_size: f32,
    /// Font weight of the node labels.
    pub series_label_font_weight: Option<String>,
    /// Label template. When empty the bare node name is drawn; otherwise it
    /// is formatted via `LabelOption::format` with the node name, the node
    /// value and its percentage of the total.
    pub series_label_formatter: String,
    /// Palette for nodes without an explicit color, indexed by node order.
    pub series_colors: Vec<Color>,
    pub series_symbol: Option<Symbol>,
    pub series_smooth: bool,
    pub series_fill: bool,
    /// Explicitly declared nodes. Optional: names referenced only by `links`
    /// are added automatically, and repeated names after the first are ignored.
    pub nodes: Vec<SankeyNode>,
    /// Flows between nodes, referenced by name. Links with a non-positive
    /// value and self-loops are not drawn.
    pub links: Vec<SankeyLink>,
    /// Width of each node bar; the constructors replace non-positive values with 16.
    pub node_width: f32,
    /// Vertical space between nodes of the same column; the constructors
    /// replace non-positive values with 8.
    pub node_gap: f32,
    /// Opacity of the link fills; the constructors replace non-positive values
    /// with 0.45 and cap the value at 1.
    pub link_opacity: f32,
    /// When set, node bars and links grow in from the left and labels fade in
    /// via CSS animations; `delay` is multiplied by the column index.
    pub animation: Option<AnimationConfig>,
    /// Column alignment: `"justify"` moves nodes without outgoing links to the
    /// last column, `"right"` places nodes by their distance to the sinks, and
    /// anything else (including `None`) places nodes by distance from the sources.
    pub node_align: Option<String>,
    /// Fill links with a linear gradient from the source node color to the
    /// target node color instead of the flat source color.
    pub link_gradient: bool,
}
/// Number of relaxation passes used to untangle node positions in `layout`.
const RELAX_ITERATIONS: usize = 32;
/// Number of straight segments used to approximate each curved link edge.
const LINK_SEGMENTS: usize = 24;
impl SankeyChart {
    /// Normalises the Sankey-specific sizing options.
    ///
    /// Non-positive `node_width`, `node_gap` and `link_opacity` fall back to
    /// 16.0, 8.0 and 0.45 respectively; `link_opacity` is then capped at 1.0.
    fn fill_default(&mut self) {
        if self.node_width <= 0.0 {
            self.node_width = 16.0;
        }
        if self.node_gap <= 0.0 {
            self.node_gap = 8.0;
        }
        if self.link_opacity <= 0.0 {
            self.link_opacity = 0.45;
        }
        self.link_opacity = self.link_opacity.min(1.0);
    }

    /// Creates a Sankey chart styled with the default theme.
    ///
    /// `nodes` may be empty: every name referenced by a link becomes a node
    /// automatically. Declaring a node explicitly lets it carry a custom color.
    pub fn new(nodes: Vec<SankeyNode>, links: Vec<SankeyLink>) -> SankeyChart {
        SankeyChart::new_with_theme(nodes, links, &get_default_theme_name())
    }

    /// Creates a Sankey chart styled with the theme named `theme`.
    pub fn new_with_theme(
        nodes: Vec<SankeyNode>,
        links: Vec<SankeyLink>,
        theme: &str,
    ) -> SankeyChart {
        let mut c = SankeyChart {
            nodes,
            links,
            ..Default::default()
        };
        c.fill_theme(get_theme(theme));
        c.fill_default();
        c
    }

    /// Builds a Sankey chart from a JSON option string.
    ///
    /// Besides the common options handled by `fill_option`, these keys are read:
    /// - `nodes`: array of `{ "name", "color"? }`; entries without a name are skipped.
    /// - `links`: array of `{ "source", "target", "value"? }`; entries missing an
    ///   endpoint are skipped and a missing value becomes 0.
    /// - `node_width`, `node_gap`, `link_opacity`: numbers.
    /// - `node_align`: string (`"justify"`, `"right"`, anything else aligns left).
    /// - `link_gradient`: boolean.
    /// - `animation`: `{ "duration"?, "easing"?, "delay"? }` (times in ms); a
    ///   `null` value leaves animation disabled.
    ///
    /// # Errors
    /// Propagates any error returned by `fill_option`.
    pub fn from_json(json: &str) -> canvas::Result<SankeyChart> {
        let mut c = SankeyChart {
            ..Default::default()
        };
        let value = c.fill_option(json)?;
        if let Some(arr) = value.get("nodes").and_then(|v| v.as_array()) {
            c.nodes = arr
                .iter()
                .filter_map(|item| {
                    let name = get_string_from_value(item, "name").unwrap_or_default();
                    if name.is_empty() {
                        return None;
                    }
                    Some(SankeyNode {
                        name,
                        color: get_color_from_value(item, "color"),
                    })
                })
                .collect();
        }
        if let Some(arr) = value.get("links").and_then(|v| v.as_array()) {
            c.links = arr
                .iter()
                .filter_map(|item| {
                    let source = get_string_from_value(item, "source").unwrap_or_default();
                    let target = get_string_from_value(item, "target").unwrap_or_default();
                    if source.is_empty() || target.is_empty() {
                        return None;
                    }
                    Some(SankeyLink {
                        source,
                        target,
                        value: get_f32_from_value(item, "value").unwrap_or_default(),
                    })
                })
                .collect();
        }
        if let Some(v) = get_f32_from_value(&value, "node_width") {
            c.node_width = v;
        }
        if let Some(v) = get_f32_from_value(&value, "node_gap") {
            c.node_gap = v;
        }
        if let Some(v) = get_f32_from_value(&value, "link_opacity") {
            c.link_opacity = v;
        }
        if let Some(s) = get_string_from_value(&value, "node_align") {
            c.node_align = Some(s);
        }
        if let Some(b) = get_bool_from_value(&value, "link_gradient") {
            c.link_gradient = b;
        }
        if let Some(anim) = value.get("animation")
            && !anim.is_null()
        {
            let mut config = AnimationConfig::default();
            if let Some(d) = get_usize_from_value(anim, "duration") {
                config.duration = d as u32;
            }
            if let Some(e) = get_string_from_value(anim, "easing") {
                config.easing = e;
            }
            if let Some(d) = get_usize_from_value(anim, "delay") {
                config.delay = d as u32;
            }
            c.animation = Some(config);
        }
        c.fill_default();
        Ok(c)
    }

    /// Computes node positions and link offsets for a `cw` x `ch` content area.
    ///
    /// Closely mirrors the classic d3-sankey (v3) layout: nodes are assigned to
    /// columns, scaled so the fullest column fits the height, and then
    /// iteratively relaxed toward their neighbours while overlaps are resolved.
    ///
    /// Returns `None` when there is nothing to draw (no drawable link) or when
    /// the gaps of some column alone exceed the available height.
    fn layout(&self, cw: f32, ch: f32) -> Option<(Vec<LayoutNode>, Vec<LayoutLink>)> {
        // Unique node names in first-seen order; explicitly declared nodes come
        // first and keep their optional color.
        let mut names: Vec<String> = vec![];
        let mut explicit_color: Vec<Option<Color>> = vec![];
        for n in &self.nodes {
            if !names.contains(&n.name) {
                names.push(n.name.clone());
                explicit_color.push(n.color);
            }
        }
        let index_of = |names: &[String], name: &str| names.iter().position(|n| n == name);

        // Resolve links to node indices, registering implicit nodes on the way.
        let mut links: Vec<LayoutLink> = vec![];
        for link in &self.links {
            // Only finite, strictly positive flows are drawable: NaN would turn
            // every derived coordinate into NaN, and an infinite value would
            // collapse the vertical scale to zero.
            if !link.value.is_finite() || link.value <= 0.0 {
                continue;
            }
            for name in [&link.source, &link.target] {
                if index_of(&names, name).is_none() {
                    names.push(name.clone());
                    explicit_color.push(None);
                }
            }
            // Both names were registered just above, so these lookups succeed.
            let source = index_of(&names, &link.source)?;
            let target = index_of(&names, &link.target)?;
            // A self-loop cannot be drawn between two columns.
            if source == target {
                continue;
            }
            links.push(LayoutLink {
                source,
                target,
                value: link.value,
                width: 0.0,
                sy: 0.0,
                ty: 0.0,
            });
        }
        if links.is_empty() || names.is_empty() {
            return None;
        }

        // Materialise layout nodes; nodes without an explicit color take the
        // palette entry matching their index.
        let node_count = names.len();
        let mut nodes: Vec<LayoutNode> = names
            .iter()
            .enumerate()
            .map(|(i, name)| {
                let color = explicit_color[i].unwrap_or_else(|| get_color(&self.series_colors, i));
                LayoutNode {
                    name: name.clone(),
                    color,
                    value: 0.0,
                    layer: 0,
                    x: 0.0,
                    y: 0.0,
                    dy: 0.0,
                    in_links: vec![],
                    out_links: vec![],
                }
            })
            .collect();

        // Wire adjacency lists; a node's value is max(total inflow, total outflow).
        let mut in_sum = vec![0.0_f32; node_count];
        let mut out_sum = vec![0.0_f32; node_count];
        for (li, link) in links.iter().enumerate() {
            nodes[link.source].out_links.push(li);
            nodes[link.target].in_links.push(li);
            out_sum[link.source] += link.value;
            in_sum[link.target] += link.value;
        }
        for (i, node) in nodes.iter_mut().enumerate() {
            node.value = in_sum[i].max(out_sum[i]);
        }

        // Column = longest path from any source. The pass count is bounded by
        // node_count so that cyclic input still terminates.
        for _ in 0..node_count {
            let mut changed = false;
            for link in &links {
                let want = nodes[link.source].layer + 1;
                if nodes[link.target].layer < want {
                    nodes[link.target].layer = want;
                    changed = true;
                }
            }
            if !changed {
                break;
            }
        }
        let layer_count = nodes.iter().map(|n| n.layer).max().unwrap_or(0) + 1;
        match self.node_align.as_deref() {
            Some("justify") => {
                // Sinks are pushed to the last column.
                for node in nodes.iter_mut() {
                    if node.out_links.is_empty() {
                        node.layer = layer_count - 1;
                    }
                }
            }
            Some("right") => {
                // Longest path to any sink, measured from the right edge.
                let mut backward = vec![0usize; node_count];
                for _ in 0..node_count {
                    let mut changed = false;
                    for link in &links {
                        let want = backward[link.target] + 1;
                        if backward[link.source] < want {
                            backward[link.source] = want;
                            changed = true;
                        }
                    }
                    if !changed {
                        break;
                    }
                }
                for (i, node) in nodes.iter_mut().enumerate() {
                    // On cyclic input the bounded passes do not converge and the
                    // backward depth can exceed the forward one; saturate rather
                    // than underflow (which would panic or index out of bounds).
                    node.layer = (layer_count - 1).saturating_sub(backward[i]);
                }
            }
            _ => {}
        }

        // Columns are spread evenly over the width, the last one flush right.
        for node in nodes.iter_mut() {
            node.x = if layer_count <= 1 {
                0.0
            } else {
                node.layer as f32 / (layer_count - 1) as f32 * (cw - self.node_width)
            };
        }
        let mut layers_nodes: Vec<Vec<usize>> = vec![vec![]; layer_count];
        for (i, node) in nodes.iter().enumerate() {
            layers_nodes[node.layer].push(i);
        }

        // Vertical scale: the largest factor at which every column (values plus
        // the gaps between its nodes) fits into `ch`.
        let mut ky = f32::MAX;
        for layer in &layers_nodes {
            let sum: f32 = layer.iter().map(|&i| nodes[i].value).sum();
            if sum <= 0.0 {
                continue;
            }
            let avail = ch - (layer.len() as f32 - 1.0) * self.node_gap;
            if avail <= 0.0 {
                return None;
            }
            ky = ky.min(avail / sum);
        }
        if !ky.is_finite() || ky <= 0.0 {
            return None;
        }
        for node in nodes.iter_mut() {
            node.dy = node.value * ky;
        }
        for link in links.iter_mut() {
            link.width = link.value * ky;
        }

        // Seed y with the position inside the column, then alternately relax
        // nodes toward their neighbours and push overlapping nodes apart.
        for layer in &layers_nodes {
            for (i, &ni) in layer.iter().enumerate() {
                nodes[ni].y = i as f32;
            }
        }
        resolve_collisions(&mut nodes, &layers_nodes, ch, self.node_gap);
        let mut alpha = 1.0_f32;
        for _ in 0..RELAX_ITERATIONS {
            alpha *= 0.99;
            relax(&mut nodes, &links, &layers_nodes, alpha, true);
            resolve_collisions(&mut nodes, &layers_nodes, ch, self.node_gap);
            relax(&mut nodes, &links, &layers_nodes, alpha, false);
            resolve_collisions(&mut nodes, &layers_nodes, ch, self.node_gap);
        }

        // Stack link endpoints on each node, ordered by the vertical position
        // of the node at the other end to reduce crossings.
        for ni in 0..node_count {
            let mut outs = nodes[ni].out_links.clone();
            outs.sort_by(|&a, &b| {
                center(&nodes[links[a].target])
                    .partial_cmp(&center(&nodes[links[b].target]))
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            let mut sy = 0.0;
            for li in outs {
                links[li].sy = sy;
                sy += links[li].width;
            }
            let mut ins = nodes[ni].in_links.clone();
            ins.sort_by(|&a, &b| {
                center(&nodes[links[a].source])
                    .partial_cmp(&center(&nodes[links[b].source]))
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            let mut ty = 0.0;
            for li in ins {
                links[li].ty = ty;
                ty += links[li].width;
            }
        }
        Some((nodes, links))
    }

    /// Renders the chart to an SVG string.
    ///
    /// Draws the header, then, in the area below it: links as filled polygons,
    /// node bars on top of them, and finally node labels. Labels of nodes in
    /// the left half are drawn to the right of the bar, the others to the left.
    /// If the content area is empty or the layout yields nothing, only what
    /// `render_header` drew is emitted.
    ///
    /// # Errors
    /// Propagates errors from the canvas serialisation.
    pub fn svg(&self) -> canvas::Result<String> {
        let mut c = Canvas::new_width_xy(self.width, self.height, self.x, self.y);
        let axis_top = self.render_header(&mut c);
        let mut content = c.child(Box {
            top: axis_top,
            ..Default::default()
        });
        let cw = content.width();
        let ch = content.height();
        if cw <= 0.0 || ch <= 0.0 {
            return c.svg();
        }
        let Some((nodes, links)) = self.layout(cw, ch) else {
            return c.svg();
        };

        // Denominator for label percentages: the total flow entering the graph
        // (sum over nodes without incoming links). This is independent of the
        // column alignment; it falls back to the sum of all node values when
        // there is no such node (fully cyclic input).
        let grand_total: f32 = {
            let inflow: f32 = nodes
                .iter()
                .filter(|n| n.in_links.is_empty())
                .map(|n| n.value)
                .sum();
            if inflow > 0.0 {
                inflow
            } else {
                nodes.iter().map(|n| n.value).sum()
            }
        };
        let alpha = (self.link_opacity.clamp(0.0, 1.0) * 255.0).round() as u8;

        // Links first so that node bars are painted over their ends.
        for link in &links {
            let source = &nodes[link.source];
            let target = &nodes[link.target];
            let x0 = source.x + self.node_width;
            let x1 = target.x;
            let top_s = source.y + link.sy;
            let top_t = target.y + link.ty;
            // Closed outline: top edge left-to-right, bottom edge right-to-left.
            let mut points: Vec<Point> = vec![];
            sample_link_edge(x0, top_s, x1, top_t, LINK_SEGMENTS, &mut points);
            sample_link_edge(
                x1,
                top_t + link.width,
                x0,
                top_s + link.width,
                LINK_SEGMENTS,
                &mut points,
            );
            let (fill, gradient) = if self.link_gradient {
                (
                    None,
                    Some(Fill::LinearGradient {
                        start_color: source.color.with_alpha(alpha),
                        end_color: target.color.with_alpha(alpha),
                        angle: 90.0,
                    }),
                )
            } else {
                (Some(source.color.with_alpha(alpha)), None)
            };
            content.polygon(Polygon {
                color: None,
                fill,
                gradient,
                points,
                class: self.animation.as_ref().map(|_| "sankey-anim".to_string()),
                // Saturate so a huge configured delay cannot overflow and panic.
                style: self.animation.as_ref().map(|a| {
                    format!(
                        "animation-delay:{}ms",
                        (source.layer as u32).saturating_mul(a.delay)
                    )
                }),
                ..Default::default()
            });
        }

        for node in &nodes {
            if node.dy <= 0.0 {
                continue;
            }
            content.rect(Rect {
                fill: Some(node.color.into()),
                left: node.x,
                top: node.y,
                width: self.node_width,
                height: node.dy,
                class: self.animation.as_ref().map(|_| "sankey-anim".to_string()),
                style: self.animation.as_ref().map(|a| {
                    format!(
                        "animation-delay:{}ms",
                        (node.layer as u32).saturating_mul(a.delay)
                    )
                }),
                ..Default::default()
            });
        }

        let font_size = self.series_label_font_size.max(10.0);
        let font_color = self.series_label_font_color;
        for node in &nodes {
            if node.dy <= 0.0 {
                continue;
            }
            let text = if self.series_label_formatter.is_empty() {
                node.name.clone()
            } else {
                LabelOption {
                    series_name: node.name.clone(),
                    category_name: node.name.clone(),
                    value: node.value,
                    percentage: if grand_total > 0.0 {
                        node.value / grand_total
                    } else {
                        0.0
                    },
                    formatter: self.series_label_formatter.clone(),
                }
                .format()
            };
            if text.is_empty() {
                continue;
            }
            let mid_y = node.y + node.dy / 2.0;
            let (x, anchor) = if node.x + self.node_width / 2.0 < cw / 2.0 {
                (node.x + self.node_width + 5.0, "start")
            } else {
                (node.x - 5.0, "end")
            };
            content.text(Text {
                text,
                font_family: Some(self.font_family.clone()),
                font_color: Some(font_color),
                font_size: Some(font_size),
                font_weight: self.series_label_font_weight.clone(),
                x: Some(x),
                y: Some(mid_y),
                text_anchor: Some(anchor.to_string()),
                dominant_baseline: Some("central".to_string()),
                class: self.animation.as_ref().map(|_| "sankey-fade".to_string()),
                ..Default::default()
            });
        }

        if let Some(ref anim) = self.animation {
            let css = format!(
                "@keyframes sankey-grow{{from{{transform:scaleX(0)}}to{{transform:scaleX(1)}}}} \
                 @keyframes sankey-fade{{from{{opacity:0}}to{{opacity:1}}}} \
                 .sankey-anim{{transform-box:fill-box;transform-origin:left center;\
                 animation:sankey-grow {}ms {} both}} \
                 .sankey-fade{{animation:sankey-fade {}ms {} both}}",
                anim.duration, anim.easing, anim.duration, anim.easing
            );
            c.svg_with_style(&css)
        } else {
            c.svg()
        }
    }
}
/// Vertical midpoint of a node bar (`y` is the top edge, `dy` the height).
fn center(node: &LayoutNode) -> f32 {
    let half_height = node.dy * 0.5;
    half_height + node.y
}
/// One relaxation sweep of the layout.
///
/// Every node with at least one link on the chosen side is moved a fraction
/// `alpha` of the way from its current center toward the value-weighted mean
/// center of the nodes at the other end of those links. With `use_targets`
/// the outgoing links are used and columns are visited right to left;
/// otherwise the incoming links are used and columns are visited left to
/// right. Updates apply immediately, so later nodes in the sweep see the
/// already-moved positions of earlier ones. Nodes whose links on that side
/// carry no positive total value are left in place.
fn relax(
    nodes: &mut [LayoutNode],
    links: &[LayoutLink],
    layers_nodes: &[Vec<usize>],
    alpha: f32,
    use_targets: bool,
) {
    let mut nudge = |ni: usize| {
        let node = &nodes[ni];
        let neighbours = if use_targets {
            &node.out_links
        } else {
            &node.in_links
        };
        if neighbours.is_empty() {
            return;
        }
        let mut weighted_sum = 0.0;
        let mut total_value = 0.0;
        for &li in neighbours {
            let link = &links[li];
            let peer = if use_targets { link.target } else { link.source };
            weighted_sum += center(&nodes[peer]) * link.value;
            total_value += link.value;
        }
        if total_value <= 0.0 {
            return;
        }
        let shift = (weighted_sum / total_value - center(node)) * alpha;
        nodes[ni].y += shift;
    };
    if use_targets {
        for &ni in layers_nodes.iter().rev().flatten() {
            nudge(ni);
        }
    } else {
        for &ni in layers_nodes.iter().flatten() {
            nudge(ni);
        }
    }
}
/// Removes vertical overlaps between nodes of the same column.
///
/// For each non-empty column, the nodes are visited top to bottom (by current
/// `y`). A downward sweep pushes any node that starts above the previous
/// node's bottom edge plus `node_gap`. If the lowest node then extends past
/// `height`, it is moved up to end exactly at `height`, and an upward sweep
/// pushes the nodes above it up as needed to keep `node_gap` between neighbours.
fn resolve_collisions(
    nodes: &mut [LayoutNode],
    layers_nodes: &[Vec<usize>],
    height: f32,
    node_gap: f32,
) {
    for column in layers_nodes.iter().filter(|column| !column.is_empty()) {
        let mut sorted = column.to_vec();
        sorted.sort_by(|&lhs, &rhs| {
            nodes[lhs]
                .y
                .partial_cmp(&nodes[rhs].y)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Downward sweep: keep each node below the previous one plus the gap.
        let mut next_free = 0.0_f32;
        for &ni in &sorted {
            let push = next_free - nodes[ni].y;
            if push > 0.0 {
                nodes[ni].y += push;
            }
            next_free = nodes[ni].y + nodes[ni].dy + node_gap;
        }

        // Upward sweep, only needed when the column overflows the bottom.
        let Some((&bottom, rest)) = sorted.split_last() else {
            continue;
        };
        let overflow = nodes[bottom].y + nodes[bottom].dy - height;
        if overflow > 0.0 {
            nodes[bottom].y -= overflow;
            let mut ceiling = nodes[bottom].y;
            for &ni in rest.iter().rev() {
                let excess = nodes[ni].y + nodes[ni].dy + node_gap - ceiling;
                if excess > 0.0 {
                    nodes[ni].y -= excess;
                }
                ceiling = nodes[ni].y;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{SankeyChart, SankeyLink, SankeyNode};
    use pretty_assertions::assert_eq;

    /// Energy-flow fixture: three primary sources feed two carriers, which in
    /// turn feed three consumer sectors.
    fn make_links() -> Vec<SankeyLink> {
        let rows: [(&str, &str, f32); 10] = [
            ("Coal", "Electricity", 25.0),
            ("Coal", "Heat", 10.0),
            ("Gas", "Electricity", 15.0),
            ("Gas", "Heat", 20.0),
            ("Solar", "Electricity", 10.0),
            ("Electricity", "Residential", 18.0),
            ("Electricity", "Industrial", 22.0),
            ("Electricity", "Commercial", 10.0),
            ("Heat", "Residential", 12.0),
            ("Heat", "Industrial", 18.0),
        ];
        rows.into_iter().map(SankeyLink::from).collect()
    }

    /// Asserts that each `(needle, message)` needle occurs in `svg`, failing
    /// with the paired message otherwise.
    fn assert_contains_all(svg: &str, expectations: &[(&str, &str)]) {
        for &(needle, message) in expectations {
            assert!(svg.contains(needle), "{message}");
        }
    }

    #[test]
    fn sankey_chart_basic() {
        let rendered = SankeyChart::new(vec![], make_links()).svg().unwrap();
        assert_eq!(
            include_str!("../../asset/sankey_chart/basic.svg"),
            rendered
        );
    }

    #[test]
    fn sankey_chart_basic_json() {
        let json = r##"{
"title_text": "Energy Flow",
"nodes": [
{"name": "Coal"},
{"name": "Gas"},
{"name": "Solar"},
{"name": "Electricity"},
{"name": "Heat"},
{"name": "Residential"},
{"name": "Industrial"},
{"name": "Commercial"}
],
"links": [
{"source": "Coal", "target": "Electricity", "value": 25},
{"source": "Coal", "target": "Heat", "value": 10},
{"source": "Gas", "target": "Electricity", "value": 15},
{"source": "Gas", "target": "Heat", "value": 20},
{"source": "Solar", "target": "Electricity", "value": 10},
{"source": "Electricity", "target": "Residential", "value": 18},
{"source": "Electricity", "target": "Industrial", "value": 22},
{"source": "Electricity", "target": "Commercial", "value": 10},
{"source": "Heat", "target": "Residential", "value": 12},
{"source": "Heat", "target": "Industrial", "value": 18}
]
}"##;
        let rendered = SankeyChart::from_json(json).unwrap().svg().unwrap();
        assert_eq!(
            include_str!("../../asset/sankey_chart/basic_json.svg"),
            rendered
        );
    }

    #[test]
    fn sankey_chart_label_formatter() {
        let nodes: Vec<SankeyNode> = ["a", "b"].into_iter().map(SankeyNode::from).collect();
        let mut chart = SankeyChart::new(nodes, vec![SankeyLink::from(("a", "b", 10.0_f32))]);
        chart.series_label_formatter = "{b}: {c}".to_string();
        assert_contains_all(
            &chart.svg().unwrap(),
            &[
                ("a: 10", "missing formatted source label"),
                ("b: 10", "missing formatted target label"),
            ],
        );
    }

    #[test]
    fn sankey_chart_animation() {
        let mut chart = SankeyChart::new(vec![], make_links());
        chart.animation = Some(super::AnimationConfig {
            duration: 700,
            easing: "linear".to_string(),
            delay: 100,
        });
        assert_contains_all(
            &chart.svg().unwrap(),
            &[
                ("sankey-grow", "missing @keyframes sankey-grow"),
                (r#"class="sankey-anim""#, "missing class on node/link"),
                (r#"class="sankey-fade""#, "missing fade class on labels"),
                ("700ms linear", "missing duration/easing"),
                ("animation-delay:0ms", "missing layer-0 delay"),
                ("animation-delay:100ms", "missing layer-1 delay"),
            ],
        );
    }

    #[test]
    fn sankey_chart_link_gradient() {
        let mut chart = SankeyChart::new(vec![], make_links());
        chart.link_gradient = true;
        assert_contains_all(
            &chart.svg().unwrap(),
            &[
                ("<linearGradient", "missing gradient def"),
                ("url(#grad_", "links should reference a gradient"),
            ],
        );
    }

    #[test]
    fn sankey_chart_node_align() {
        let rows: [(&str, &str, f32); 3] = [("a", "b", 4.0), ("b", "c", 4.0), ("a", "x", 2.0)];
        let links: Vec<SankeyLink> = rows.into_iter().map(SankeyLink::from).collect();
        let left = SankeyChart::new(vec![], links.clone()).svg().unwrap();
        let mut justify = SankeyChart::new(vec![], links);
        justify.node_align = Some("justify".to_string());
        assert_ne!(left, justify.svg().unwrap());
    }

    #[test]
    fn sankey_chart_empty_links() {
        let svg = SankeyChart::new(vec![SankeyNode::from("a")], vec![])
            .svg()
            .unwrap();
        assert!(svg.starts_with("<svg"));
    }
}