Skip to main content

plotlars/plots/
sankeydiagram.rs

1use bon::bon;
2use std::collections::{hash_map::Entry, HashMap};
3
4use plotly::{
5    common::{Anchor, Domain},
6    layout::Annotation,
7    sankey::{Link, Node},
8    Layout as LayoutPlotly, Sankey, Trace,
9};
10
11use polars::frame::DataFrame;
12use serde::Serialize;
13
14use crate::{
15    common::{Layout, PlotHelper, Polar},
16    components::{Arrangement, FacetConfig, Legend, Orientation, Rgb, Text},
17};
18
19/// A structure representing a Sankey diagram.
20///
21/// The `SankeyDiagram` struct enables the creation of Sankey diagrams, which visualize flows
22/// between discrete nodes with link widths proportional to the magnitude of the flow. It
23/// offers extensive configuration options for flow orientation, node arrangement, spacing,
24/// thickness, and coloring, as well as axis and title customization. Users can specify a
25/// single uniform color or per-item colors for both nodes and links, adjust padding between
26/// nodes, set node thickness, and supply custom titles and axis labels to produce clear,
27/// publication-quality flow visualizations. Faceting support allows creating multiple Sankey
28/// diagrams in a grid layout for comparing flows across categories.
29///
30/// # Arguments
31///
32/// * `data` – A reference to the `DataFrame` containing the data to be plotted.
33/// * `sources` – A string slice naming the column in `data` that contains the source node for each flow.
34/// * `targets` – A string slice naming the column in `data` that contains the target node for each flow.
35/// * `values` – A string slice naming the column in `data` that contains the numeric value of each flow.
36/// * `facet` – An optional string slice naming the column in `data` to be used for creating facets (small multiples).
37/// * `facet_config` – An optional reference to a `FacetConfig` struct for customizing facet layout and behavior.
38/// * `orientation` – An optional `Orientation` enum to set the overall direction of the diagram
39///   (e.g. `Orientation::Horizontal` or `Orientation::Vertical`).
40/// * `arrangement` – An optional `Arrangement` enum to choose the node-layout algorithm
41///   (e.g. `Arrangement::Snap`, `Arrangement::Perpendicular`, etc.).
42/// * `pad` – An optional `usize` specifying the padding (in pixels) between adjacent nodes.
43/// * `thickness` – An optional `usize` defining the uniform thickness (in pixels) of all nodes.
44/// * `node_color` – An optional `Rgb` value to apply a single uniform color to every node.
45/// * `node_colors` – An optional `Vec<Rgb>` supplying individual colors for each node in order.
46/// * `link_color` – An optional `Rgb` value to apply a single uniform color to every link.
47/// * `link_colors` – An optional `Vec<Rgb>` supplying individual colors for each link in order.
48/// * `plot_title` – An optional `Text` struct for setting the overall title of the plot.
49/// * `legend_title` – An optional `Text` struct specifying the title of the legend.
50/// * `legend` – An optional reference to a `Legend` struct for customizing the legend of the plot.
51///
52/// # Example
53///
54/// ```rust
55/// use plotlars::{Arrangement, SankeyDiagram, Orientation, Plot, Rgb, Text};
56/// use polars::prelude::*;
57///
58/// let dataset = LazyCsvReader::new(PlRefPath::new("data/sankey_flow.csv"))
59///     .finish()
60///     .unwrap()
61///     .collect()
62///     .unwrap();
63///
64/// SankeyDiagram::builder()
65///     .data(&dataset)
66///     .sources("source")
67///     .targets("target")
68///     .values("value")
69///     .orientation(Orientation::Horizontal)
70///     .arrangement(Arrangement::Freeform)
71///     .node_colors(vec![
72///         Rgb(222, 235, 247),
73///         Rgb(198, 219, 239),
74///         Rgb(158, 202, 225),
75///         Rgb(107, 174, 214),
76///         Rgb( 66, 146, 198),
77///         Rgb( 33, 113, 181),
78///     ])
79///     .link_colors(vec![
80///         Rgb(222, 235, 247),
81///         Rgb(198, 219, 239),
82///         Rgb(158, 202, 225),
83///         Rgb(107, 174, 214),
84///         Rgb( 66, 146, 198),
85///         Rgb( 33, 113, 181),
86///     ])
87///     .pad(20)
88///     .thickness(30)
89///     .plot_title(
90///         Text::from("Sankey Diagram")
91///             .font("Arial")
92///             .size(18)
93///     )
94///     .build()
95///     .plot();
96/// ```
97///
98/// ![Example](https://imgur.com/jvAew8u.png)
99#[derive(Clone, Serialize)]
100pub struct SankeyDiagram {
101    traces: Vec<Box<dyn Trace + 'static>>,
102    layout: LayoutPlotly,
103}
104
105#[bon]
106impl SankeyDiagram {
107    #[builder(on(String, into), on(Text, into))]
108    pub fn new(
109        data: &DataFrame,
110        sources: &str,
111        targets: &str,
112        values: &str,
113        facet: Option<&str>,
114        facet_config: Option<&FacetConfig>,
115        orientation: Option<Orientation>,
116        arrangement: Option<Arrangement>,
117        pad: Option<usize>,
118        thickness: Option<usize>,
119        node_color: Option<Rgb>,
120        node_colors: Option<Vec<Rgb>>,
121        link_color: Option<Rgb>,
122        link_colors: Option<Vec<Rgb>>,
123        plot_title: Option<Text>,
124        legend_title: Option<Text>,
125        legend: Option<&Legend>,
126    ) -> Self {
127        let x_title = None;
128        let y_title = None;
129        let z_title = None;
130        let x_axis = None;
131        let y_axis = None;
132        let z_axis = None;
133        let y2_title = None;
134        let y2_axis = None;
135
136        let (layout, traces) = match facet {
137            Some(facet_column) => {
138                let config = facet_config.cloned().unwrap_or_default();
139
140                let layout = Self::create_faceted_layout(
141                    data,
142                    facet_column,
143                    &config,
144                    plot_title,
145                    legend_title,
146                    legend,
147                );
148
149                let traces = Self::create_faceted_traces(
150                    data,
151                    sources,
152                    targets,
153                    values,
154                    facet_column,
155                    &config,
156                    orientation,
157                    arrangement,
158                    pad,
159                    thickness,
160                    node_color,
161                    node_colors,
162                    link_color,
163                    link_colors,
164                );
165
166                (layout, traces)
167            }
168            None => {
169                let layout = Self::create_layout(
170                    plot_title,
171                    x_title,
172                    y_title,
173                    y2_title,
174                    z_title,
175                    legend_title,
176                    x_axis,
177                    y_axis,
178                    y2_axis,
179                    z_axis,
180                    legend,
181                    None,
182                );
183
184                let traces = Self::create_traces(
185                    data,
186                    sources,
187                    targets,
188                    values,
189                    orientation,
190                    arrangement,
191                    pad,
192                    thickness,
193                    node_color,
194                    node_colors,
195                    link_color,
196                    link_colors,
197                );
198
199                (layout, traces)
200            }
201        };
202
203        Self { traces, layout }
204    }
205
206    #[allow(clippy::too_many_arguments)]
207    fn create_traces(
208        data: &DataFrame,
209        sources: &str,
210        targets: &str,
211        values: &str,
212        orientation: Option<Orientation>,
213        arrangement: Option<Arrangement>,
214        pad: Option<usize>,
215        thickness: Option<usize>,
216        node_color: Option<Rgb>,
217        node_colors: Option<Vec<Rgb>>,
218        link_color: Option<Rgb>,
219        link_colors: Option<Vec<Rgb>>,
220    ) -> Vec<Box<dyn Trace + 'static>> {
221        let mut traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
222
223        let trace = Self::create_trace(
224            data,
225            sources,
226            targets,
227            values,
228            orientation,
229            arrangement,
230            pad,
231            thickness,
232            node_color,
233            node_colors,
234            link_color,
235            link_colors,
236            None,
237        );
238
239        traces.push(trace);
240        traces
241    }
242
243    #[allow(clippy::too_many_arguments)]
244    fn create_trace(
245        data: &DataFrame,
246        sources: &str,
247        targets: &str,
248        values: &str,
249        orientation: Option<Orientation>,
250        arrangement: Option<Arrangement>,
251        pad: Option<usize>,
252        thickness: Option<usize>,
253        node_color: Option<Rgb>,
254        node_colors: Option<Vec<Rgb>>,
255        link_color: Option<Rgb>,
256        link_colors: Option<Vec<Rgb>>,
257        domain: Option<Domain>,
258    ) -> Box<dyn Trace + 'static> {
259        let sources = Self::get_string_column(data, sources);
260        let targets = Self::get_string_column(data, targets);
261        let values = Self::get_numeric_column(data, values);
262
263        let (labels_unique, label_to_idx) = Self::build_label_index(&sources, &targets);
264
265        let sources_idx = Self::column_to_indices(&sources, &label_to_idx);
266        let targets_idx = Self::column_to_indices(&targets, &label_to_idx);
267
268        let mut node = Node::new().label(labels_unique);
269
270        node = Self::set_pad(node, pad);
271        node = Self::set_thickness(node, thickness);
272        node = Self::set_node_color(node, node_color);
273        node = Self::set_node_colors(node, node_colors);
274
275        let mut link = Link::new()
276            .source(sources_idx)
277            .target(targets_idx)
278            .value(values);
279
280        link = Self::set_link_color(link, link_color);
281        link = Self::set_link_colors(link, link_colors);
282
283        let mut trace = Sankey::new().node(node).link(link);
284
285        trace = Self::set_orientation(trace, orientation);
286        trace = Self::set_arrangement(trace, arrangement);
287
288        if let Some(domain_val) = domain {
289            trace = trace.domain(domain_val);
290        }
291
292        trace
293    }
294
295    #[allow(clippy::too_many_arguments)]
296    fn create_faceted_traces(
297        data: &DataFrame,
298        sources: &str,
299        targets: &str,
300        values: &str,
301        facet_column: &str,
302        config: &FacetConfig,
303        orientation: Option<Orientation>,
304        arrangement: Option<Arrangement>,
305        pad: Option<usize>,
306        thickness: Option<usize>,
307        node_color: Option<Rgb>,
308        node_colors: Option<Vec<Rgb>>,
309        link_color: Option<Rgb>,
310        link_colors: Option<Vec<Rgb>>,
311    ) -> Vec<Box<dyn Trace + 'static>> {
312        const MAX_FACETS: usize = 8;
313
314        let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
315
316        if facet_categories.len() > MAX_FACETS {
317            panic!(
318                "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
319                facet_column,
320                facet_categories.len(),
321                MAX_FACETS
322            );
323        }
324
325        let n_facets = facet_categories.len();
326        let (ncols, nrows) = Self::calculate_grid_dimensions(n_facets, config.cols, config.rows);
327
328        // Filter out facets with no data to prevent empty diagrams
329        let facet_categories_non_empty: Vec<String> = facet_categories
330            .iter()
331            .filter(|facet_value| {
332                let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
333                facet_data.height() > 0
334            })
335            .cloned()
336            .collect();
337
338        let mut all_traces = Vec::new();
339
340        // Need to clone Vec colors for reuse across facets
341        let node_colors_cloned = node_colors.clone();
342        let link_colors_cloned = link_colors.clone();
343
344        for (idx, facet_value) in facet_categories_non_empty.iter().enumerate() {
345            let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
346
347            let domain =
348                Self::calculate_sankey_domain(idx, ncols, nrows, config.h_gap, config.v_gap);
349
350            let trace = Self::create_trace(
351                &facet_data,
352                sources,
353                targets,
354                values,
355                orientation.clone(),
356                arrangement.clone(),
357                pad,
358                thickness,
359                node_color,
360                node_colors_cloned.clone(),
361                link_color,
362                link_colors_cloned.clone(),
363                Some(domain),
364            );
365
366            all_traces.push(trace);
367        }
368
369        all_traces
370    }
371
372    fn create_faceted_layout(
373        data: &DataFrame,
374        facet_column: &str,
375        config: &FacetConfig,
376        plot_title: Option<Text>,
377        legend_title: Option<Text>,
378        legend: Option<&Legend>,
379    ) -> LayoutPlotly {
380        let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
381
382        // Filter out facets with no data
383        let facet_categories_non_empty: Vec<String> = facet_categories
384            .iter()
385            .filter(|facet_value| {
386                let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
387                facet_data.height() > 0
388            })
389            .cloned()
390            .collect();
391
392        let n_facets = facet_categories_non_empty.len();
393        let (ncols, nrows) = Self::calculate_grid_dimensions(n_facets, config.cols, config.rows);
394
395        let mut layout = LayoutPlotly::new();
396
397        if let Some(title) = plot_title {
398            layout = layout.title(title.to_plotly());
399        }
400
401        let annotations = Self::create_facet_annotations_sankey(
402            &facet_categories_non_empty,
403            ncols,
404            nrows,
405            config.title_style.as_ref(),
406            config.h_gap,
407            config.v_gap,
408        );
409        layout = layout.annotations(annotations);
410
411        layout = layout.legend(Legend::set_legend(legend_title, legend));
412
413        layout
414    }
415
416    /// Calculates the grid cell positions for a subplot with reserved space for titles.
417    ///
418    /// This function computes both the sankey diagram domain and annotation position,
419    /// ensuring that space is reserved above each diagram for the facet title.
420    fn calculate_facet_cell(
421        subplot_index: usize,
422        ncols: usize,
423        nrows: usize,
424        x_gap: Option<f64>,
425        y_gap: Option<f64>,
426    ) -> FacetCell {
427        let row = subplot_index / ncols;
428        let col = subplot_index % ncols;
429
430        let x_gap_val = x_gap.unwrap_or(0.05);
431        let y_gap_val = y_gap.unwrap_or(0.10);
432
433        // Reserve space for facet title (10% of each cell's height)
434        const TITLE_HEIGHT_RATIO: f64 = 0.10;
435        // Padding ratio creates buffer space above annotation (35% of reserved title space)
436        const TITLE_PADDING_RATIO: f64 = 0.35;
437
438        // Calculate total cell dimensions
439        let cell_width = (1.0 - x_gap_val * (ncols - 1) as f64) / ncols as f64;
440        let cell_height = (1.0 - y_gap_val * (nrows - 1) as f64) / nrows as f64;
441
442        // Calculate cell boundaries
443        let cell_x_start = col as f64 * (cell_width + x_gap_val);
444        let cell_y_top = 1.0 - row as f64 * (cell_height + y_gap_val);
445        let cell_y_bottom = cell_y_top - cell_height;
446
447        // Reserve title space at the top of the cell
448        let title_height = cell_height * TITLE_HEIGHT_RATIO;
449        let sankey_y_top = cell_y_top - title_height;
450
451        // Sankey diagram domain (bottom 90% of the cell)
452        let sankey_x_start = cell_x_start;
453        let sankey_x_end = cell_x_start + cell_width;
454        let sankey_y_start = cell_y_bottom;
455        let sankey_y_end = sankey_y_top;
456
457        // Calculate annotation position with padding buffer
458        let padding_height = title_height * TITLE_PADDING_RATIO;
459        let actual_title_height = title_height - padding_height;
460        let annotation_x = cell_x_start + cell_width / 2.0;
461        let annotation_y = sankey_y_top + padding_height + (actual_title_height / 2.0);
462
463        FacetCell {
464            domain_x_start: sankey_x_start,
465            domain_x_end: sankey_x_end,
466            domain_y_start: sankey_y_start,
467            domain_y_end: sankey_y_end,
468            annotation_x,
469            annotation_y,
470        }
471    }
472
473    fn calculate_sankey_domain(
474        subplot_index: usize,
475        ncols: usize,
476        nrows: usize,
477        x_gap: Option<f64>,
478        y_gap: Option<f64>,
479    ) -> Domain {
480        let cell = Self::calculate_facet_cell(subplot_index, ncols, nrows, x_gap, y_gap);
481        Domain::new()
482            .x(&[cell.domain_x_start, cell.domain_x_end])
483            .y(&[cell.domain_y_start, cell.domain_y_end])
484    }
485
486    fn create_facet_annotations_sankey(
487        categories: &[String],
488        ncols: usize,
489        nrows: usize,
490        title_style: Option<&Text>,
491        x_gap: Option<f64>,
492        y_gap: Option<f64>,
493    ) -> Vec<Annotation> {
494        categories
495            .iter()
496            .enumerate()
497            .map(|(i, cat)| {
498                let cell = Self::calculate_facet_cell(i, ncols, nrows, x_gap, y_gap);
499
500                let mut ann = Annotation::new()
501                    .text(cat.as_str())
502                    .x_ref("paper")
503                    .y_ref("paper")
504                    .x_anchor(Anchor::Center)
505                    .y_anchor(Anchor::Middle)
506                    .x(cell.annotation_x)
507                    .y(cell.annotation_y)
508                    .show_arrow(false);
509
510                if let Some(style) = title_style {
511                    ann = ann.font(style.to_font());
512                }
513
514                ann
515            })
516            .collect()
517    }
518
519    fn set_thickness(mut node: Node, thickness: Option<usize>) -> Node {
520        if let Some(thickness) = thickness {
521            node = node.thickness(thickness);
522        }
523
524        node
525    }
526
527    fn set_pad(mut node: Node, pad: Option<usize>) -> Node {
528        if let Some(pad) = pad {
529            node = node.pad(pad);
530        }
531
532        node
533    }
534
535    fn set_link_colors<V>(mut link: Link<V>, colors: Option<Vec<Rgb>>) -> Link<V>
536    where
537        V: Serialize + Clone,
538    {
539        if let Some(colors) = colors {
540            link = link.color_array(colors.iter().map(|color| color.to_plotly()).collect());
541        }
542
543        link
544    }
545
546    fn set_link_color<V>(mut link: Link<V>, color: Option<Rgb>) -> Link<V>
547    where
548        V: Serialize + Clone,
549    {
550        if let Some(color) = color {
551            link = link.color(color);
552        }
553
554        link
555    }
556
557    fn set_node_colors(mut node: Node, colors: Option<Vec<Rgb>>) -> Node {
558        if let Some(colors) = colors {
559            node = node.color_array(colors.iter().map(|color| color.to_plotly()).collect());
560        }
561
562        node
563    }
564
565    fn set_node_color(mut node: Node, color: Option<Rgb>) -> Node {
566        if let Some(color) = color {
567            node = node.color(color);
568        }
569
570        node
571    }
572
573    fn set_arrangement(
574        mut trace: Box<Sankey<Option<f32>>>,
575        arrangement: Option<Arrangement>,
576    ) -> Box<Sankey<Option<f32>>> {
577        if let Some(arrangement) = arrangement {
578            trace = trace.arrangement(arrangement.to_plotly())
579        }
580
581        trace
582    }
583
584    fn set_orientation(
585        mut trace: Box<Sankey<Option<f32>>>,
586        orientation: Option<Orientation>,
587    ) -> Box<Sankey<Option<f32>>> {
588        if let Some(orientation) = orientation {
589            trace = trace.orientation(orientation.to_plotly())
590        }
591
592        trace
593    }
594
595    fn build_label_index<'a>(
596        sources: &'a [Option<String>],
597        targets: &'a [Option<String>],
598    ) -> (Vec<&'a str>, HashMap<&'a str, usize>) {
599        let mut label_to_idx: HashMap<&'a str, usize> = HashMap::new();
600        let mut labels_unique: Vec<&'a str> = Vec::new();
601
602        let iter = sources
603            .iter()
604            .chain(targets.iter())
605            .filter_map(|opt| opt.as_deref());
606
607        for lbl in iter {
608            if let Entry::Vacant(v) = label_to_idx.entry(lbl) {
609                let next_id = labels_unique.len();
610                v.insert(next_id);
611                labels_unique.push(lbl);
612            }
613        }
614
615        (labels_unique, label_to_idx)
616    }
617
618    fn column_to_indices(
619        column: &[Option<String>],
620        label_to_idx: &HashMap<&str, usize>,
621    ) -> Vec<usize> {
622        column
623            .iter()
624            .filter_map(|opt| opt.as_deref())
625            .map(|lbl| *label_to_idx.get(lbl).expect("label must exist in map"))
626            .collect()
627    }
628}
629
/// Layout of a single facet's grid cell, as fractions of the figure in `[0, 1]`.
///
/// The `domain_*` fields bound the region given to the Sankey trace. The `annotation_*`
/// fields locate the centre of the facet's title.
struct FacetCell {
    // Horizontal placement: trace extent, then title centre.
    domain_x_start: f64,
    domain_x_end: f64,
    annotation_x: f64,
    // Vertical placement: trace extent (below the title band), then title centre.
    domain_y_start: f64,
    domain_y_end: f64,
    annotation_y: f64,
}
639
640impl Layout for SankeyDiagram {}
641impl Polar for SankeyDiagram {}
642
643impl PlotHelper for SankeyDiagram {
644    fn get_layout(&self) -> &LayoutPlotly {
645        &self.layout
646    }
647
648    fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>> {
649        &self.traces
650    }
651}