Skip to main content

plotlars_core/plots/
sankeydiagram.rs

1use bon::bon;
2use std::collections::{hash_map::Entry, HashMap};
3
4use crate::{
5    components::{Arrangement, FacetConfig, Legend, Orientation, Rgb, Text},
6    ir::data::ColumnData,
7    ir::layout::LayoutIR,
8    ir::trace::{SankeyDiagramIR, TraceIR},
9};
10use polars::frame::DataFrame;
11
/// A structure representing a Sankey diagram.
///
/// The `SankeyDiagram` struct enables the creation of Sankey diagrams, which visualize flows
/// between discrete nodes with link widths proportional to the magnitude of the flow. It
/// offers extensive configuration options for flow orientation, node arrangement, spacing,
/// thickness, and coloring, as well as plot-title and legend customization. Users can specify a
/// single uniform color or per-item colors for both nodes and links, adjust padding between
/// nodes, set node thickness, and supply a custom plot title to produce clear,
/// publication-quality flow visualizations. Sankey diagrams have no axes, so no axis
/// titles are configured. Faceting support allows creating multiple Sankey
/// diagrams in a grid layout for comparing flows across categories.
///
/// # Backend Support
///
/// | Backend | Supported |
/// |---------|-----------|
/// | Plotly  | Yes       |
/// | Plotters| --        |
///
/// # Arguments
///
/// * `data` – A reference to the `DataFrame` containing the data to be plotted.
/// * `sources` – A string slice naming the column in `data` that contains the source node for each flow.
/// * `targets` – A string slice naming the column in `data` that contains the target node for each flow.
/// * `values` – A string slice naming the column in `data` that contains the numeric value of each flow.
/// * `facet` – An optional string slice naming the column in `data` to be used for creating facets (small multiples).
/// * `facet_config` – An optional reference to a `FacetConfig` struct for customizing facet layout and behavior.
/// * `orientation` – An optional `Orientation` enum to set the overall direction of the diagram
///   (e.g. `Orientation::Horizontal` or `Orientation::Vertical`).
/// * `arrangement` – An optional `Arrangement` enum to choose the node-layout algorithm
///   (e.g. `Arrangement::Snap`, `Arrangement::Perpendicular`, etc.).
/// * `pad` – An optional `usize` specifying the padding (in pixels) between adjacent nodes.
/// * `thickness` – An optional `usize` defining the uniform thickness (in pixels) of all nodes.
/// * `node_color` – An optional `Rgb` value to apply a single uniform color to every node.
///   Ignored when `node_colors` is also supplied.
/// * `node_colors` – An optional `Vec<Rgb>` supplying individual colors for each node in order.
///   Takes precedence over `node_color`.
/// * `link_color` – An optional `Rgb` value to apply a single uniform color to every link.
///   Ignored when `link_colors` is also supplied.
/// * `link_colors` – An optional `Vec<Rgb>` supplying individual colors for each link in order.
///   Takes precedence over `link_color`.
/// * `plot_title` – An optional `Text` struct for setting the overall title of the plot.
/// * `legend_title` – An optional `Text` struct specifying the title of the legend.
/// * `legend` – An optional reference to a `Legend` struct for customizing the legend of the plot.
///
/// # Panics
///
/// `build()` panics when `facet` is set and the facet column contains more than 8 unique
/// values. Use `SankeyDiagram::try_builder()` with `.try_build()` to receive a
/// `PlotlarsError::PlotBuild` instead of a panic.
///
/// # Example
///
/// ```rust
/// use plotlars::{Arrangement, SankeyDiagram, Orientation, Plot, Rgb, Text};
/// use polars::prelude::*;
///
/// let dataset = LazyCsvReader::new(PlRefPath::new("data/sankey_flow.csv"))
///     .finish()
///     .unwrap()
///     .collect()
///     .unwrap();
///
/// SankeyDiagram::builder()
///     .data(&dataset)
///     .sources("source")
///     .targets("target")
///     .values("value")
///     .orientation(Orientation::Horizontal)
///     .arrangement(Arrangement::Freeform)
///     .node_colors(vec![
///         Rgb(222, 235, 247),
///         Rgb(198, 219, 239),
///         Rgb(158, 202, 225),
///         Rgb(107, 174, 214),
///         Rgb( 66, 146, 198),
///         Rgb( 33, 113, 181),
///     ])
///     .link_colors(vec![
///         Rgb(222, 235, 247),
///         Rgb(198, 219, 239),
///         Rgb(158, 202, 225),
///         Rgb(107, 174, 214),
///         Rgb( 66, 146, 198),
///         Rgb( 33, 113, 181),
///     ])
///     .pad(20)
///     .thickness(30)
///     .plot_title(
///         Text::from("Sankey Diagram")
///             .font("Arial")
///             .size(18)
///     )
///     .build()
///     .plot();
/// ```
///
/// ![Example](https://imgur.com/jvAew8u.png)
#[derive(Clone)]
#[allow(dead_code)]
pub struct SankeyDiagram {
    // One Sankey trace when unfaceted; one per non-empty facet category otherwise.
    traces: Vec<TraceIR>,
    // Title, legend and (when faceted) grid description shared by all traces.
    layout: LayoutIR,
}
105
/// Rectangular region assigned to one facet's Sankey trace.
///
/// Produced by `SankeyDiagram::calculate_facet_cell` and copied into the trace's
/// `domain_x` / `domain_y` pairs. Values are presumably fractions of the plot area
/// in the 0.0–1.0 range (they are derived from `1.0`-based arithmetic) — confirm
/// against the renderer.
struct FacetCell {
    /// Lower horizontal bound of the trace domain.
    domain_x_start: f64,
    /// Upper horizontal bound of the trace domain.
    domain_x_end: f64,
    /// Lower vertical bound of the trace domain.
    domain_y_start: f64,
    /// Upper vertical bound of the trace domain (already excludes the title strip).
    domain_y_end: f64,
}
112
113#[bon]
114impl SankeyDiagram {
115    #[builder(on(String, into), on(Text, into))]
116    pub fn new(
117        data: &DataFrame,
118        sources: &str,
119        targets: &str,
120        values: &str,
121        facet: Option<&str>,
122        facet_config: Option<&FacetConfig>,
123        orientation: Option<Orientation>,
124        arrangement: Option<Arrangement>,
125        pad: Option<usize>,
126        thickness: Option<usize>,
127        node_color: Option<Rgb>,
128        node_colors: Option<Vec<Rgb>>,
129        link_color: Option<Rgb>,
130        link_colors: Option<Vec<Rgb>>,
131        plot_title: Option<Text>,
132        legend_title: Option<Text>,
133        legend: Option<&Legend>,
134    ) -> Self {
135        let grid = facet.map(|facet_column| {
136            let config = facet_config.cloned().unwrap_or_default();
137            let facet_categories =
138                crate::data::get_unique_groups(data, facet_column, config.sorter);
139            let n_facets = facet_categories.len();
140            let (ncols, nrows) =
141                crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
142            crate::ir::facet::GridSpec {
143                kind: crate::ir::facet::FacetKind::Domain,
144                rows: nrows,
145                cols: ncols,
146                h_gap: config.h_gap,
147                v_gap: config.v_gap,
148                scales: config.scales.clone(),
149                n_facets,
150                facet_categories,
151                title_style: config.title_style.clone(),
152                x_title: None,
153                y_title: None,
154                x_axis: None,
155                y_axis: None,
156                legend_title: legend_title.clone(),
157                legend: legend.cloned(),
158            }
159        });
160
161        let layout = LayoutIR {
162            title: plot_title,
163            x_title: None,
164            y_title: None,
165            y2_title: None,
166            z_title: None,
167            legend_title: if grid.is_some() { None } else { legend_title },
168            legend: if grid.is_some() {
169                None
170            } else {
171                legend.cloned()
172            },
173            dimensions: None,
174            bar_mode: None,
175            box_mode: None,
176            box_gap: None,
177            margin_bottom: None,
178            axes_2d: None,
179            scene_3d: None,
180            polar: None,
181            mapbox: None,
182            grid,
183            annotations: vec![],
184        };
185
186        let traces = match facet {
187            Some(facet_column) => {
188                let config = facet_config.cloned().unwrap_or_default();
189                Self::create_ir_traces_faceted(
190                    data,
191                    sources,
192                    targets,
193                    values,
194                    facet_column,
195                    &config,
196                    orientation,
197                    arrangement,
198                    pad,
199                    thickness,
200                    node_color,
201                    node_colors,
202                    link_color,
203                    link_colors,
204                )
205            }
206            None => Self::create_ir_traces(
207                data,
208                sources,
209                targets,
210                values,
211                orientation,
212                arrangement,
213                pad,
214                thickness,
215                node_color,
216                node_colors,
217                link_color,
218                link_colors,
219            ),
220        };
221        Self { traces, layout }
222    }
223}
224
225#[bon]
226impl SankeyDiagram {
227    #[builder(
228        start_fn = try_builder,
229        finish_fn = try_build,
230        builder_type = SankeyDiagramTryBuilder,
231        on(String, into),
232        on(Text, into),
233    )]
234    pub fn try_new(
235        data: &DataFrame,
236        sources: &str,
237        targets: &str,
238        values: &str,
239        facet: Option<&str>,
240        facet_config: Option<&FacetConfig>,
241        orientation: Option<Orientation>,
242        arrangement: Option<Arrangement>,
243        pad: Option<usize>,
244        thickness: Option<usize>,
245        node_color: Option<Rgb>,
246        node_colors: Option<Vec<Rgb>>,
247        link_color: Option<Rgb>,
248        link_colors: Option<Vec<Rgb>>,
249        plot_title: Option<Text>,
250        legend_title: Option<Text>,
251        legend: Option<&Legend>,
252    ) -> Result<Self, crate::io::PlotlarsError> {
253        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
254            Self::__orig_new(
255                data,
256                sources,
257                targets,
258                values,
259                facet,
260                facet_config,
261                orientation,
262                arrangement,
263                pad,
264                thickness,
265                node_color,
266                node_colors,
267                link_color,
268                link_colors,
269                plot_title,
270                legend_title,
271                legend,
272            )
273        }))
274        .map_err(|panic| {
275            let msg = panic
276                .downcast_ref::<String>()
277                .cloned()
278                .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
279                .unwrap_or_else(|| "unknown error".to_string());
280            crate::io::PlotlarsError::PlotBuild { message: msg }
281        })
282    }
283}
284
285impl SankeyDiagram {
286    #[allow(clippy::too_many_arguments)]
287    fn create_ir_traces(
288        data: &DataFrame,
289        sources: &str,
290        targets: &str,
291        values: &str,
292        orientation: Option<Orientation>,
293        arrangement: Option<Arrangement>,
294        pad: Option<usize>,
295        thickness: Option<usize>,
296        node_color: Option<Rgb>,
297        node_colors: Option<Vec<Rgb>>,
298        link_color: Option<Rgb>,
299        link_colors: Option<Vec<Rgb>>,
300    ) -> Vec<TraceIR> {
301        let sources_col = crate::data::get_string_column(data, sources);
302        let targets_col = crate::data::get_string_column(data, targets);
303        let values_data = crate::data::get_numeric_column(data, values);
304
305        let (labels_unique, label_to_idx) = Self::build_label_index(&sources_col, &targets_col);
306
307        let sources_idx = Self::column_to_indices(&sources_col, &label_to_idx);
308        let targets_idx = Self::column_to_indices(&targets_col, &label_to_idx);
309
310        let resolved_node_colors = Self::resolve_node_colors(node_color, node_colors);
311        let resolved_link_colors = Self::resolve_link_colors(link_color, link_colors);
312
313        vec![TraceIR::SankeyDiagram(SankeyDiagramIR {
314            sources: sources_idx,
315            targets: targets_idx,
316            values: ColumnData::Numeric(values_data),
317            node_labels: labels_unique.iter().map(|s| s.to_string()).collect(),
318            orientation,
319            arrangement,
320            pad,
321            thickness,
322            node_colors: resolved_node_colors,
323            link_colors: resolved_link_colors,
324            domain_x: None,
325            domain_y: None,
326        })]
327    }
328
329    #[allow(clippy::too_many_arguments)]
330    fn create_ir_traces_faceted(
331        data: &DataFrame,
332        sources: &str,
333        targets: &str,
334        values: &str,
335        facet_column: &str,
336        config: &FacetConfig,
337        orientation: Option<Orientation>,
338        arrangement: Option<Arrangement>,
339        pad: Option<usize>,
340        thickness: Option<usize>,
341        node_color: Option<Rgb>,
342        node_colors: Option<Vec<Rgb>>,
343        link_color: Option<Rgb>,
344        link_colors: Option<Vec<Rgb>>,
345    ) -> Vec<TraceIR> {
346        const MAX_FACETS: usize = 8;
347
348        let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
349
350        if facet_categories.len() > MAX_FACETS {
351            panic!(
352                "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
353                facet_column,
354                facet_categories.len(),
355                MAX_FACETS
356            );
357        }
358
359        let n_facets = facet_categories.len();
360        let (ncols, nrows) =
361            crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
362
363        let facet_categories_non_empty: Vec<String> = facet_categories
364            .iter()
365            .filter(|facet_value| {
366                let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
367                facet_data.height() > 0
368            })
369            .cloned()
370            .collect();
371
372        let resolved_node_colors = Self::resolve_node_colors(node_color, node_colors);
373        let resolved_link_colors = Self::resolve_link_colors(link_color, link_colors);
374
375        let mut traces = Vec::new();
376
377        for (idx, facet_value) in facet_categories_non_empty.iter().enumerate() {
378            let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
379
380            let cell = Self::calculate_facet_cell(idx, ncols, nrows, config.h_gap, config.v_gap);
381
382            let sources_col = crate::data::get_string_column(&facet_data, sources);
383            let targets_col = crate::data::get_string_column(&facet_data, targets);
384            let values_data = crate::data::get_numeric_column(&facet_data, values);
385
386            let (labels_unique, label_to_idx) = Self::build_label_index(&sources_col, &targets_col);
387
388            let sources_idx = Self::column_to_indices(&sources_col, &label_to_idx);
389            let targets_idx = Self::column_to_indices(&targets_col, &label_to_idx);
390
391            traces.push(TraceIR::SankeyDiagram(SankeyDiagramIR {
392                sources: sources_idx,
393                targets: targets_idx,
394                values: ColumnData::Numeric(values_data),
395                node_labels: labels_unique.iter().map(|s| s.to_string()).collect(),
396                orientation: orientation.clone(),
397                arrangement: arrangement.clone(),
398                pad,
399                thickness,
400                node_colors: resolved_node_colors.clone(),
401                link_colors: resolved_link_colors.clone(),
402                domain_x: Some((cell.domain_x_start, cell.domain_x_end)),
403                domain_y: Some((cell.domain_y_start, cell.domain_y_end)),
404            }));
405        }
406
407        traces
408    }
409
410    fn resolve_node_colors(
411        node_color: Option<Rgb>,
412        node_colors: Option<Vec<Rgb>>,
413    ) -> Option<Vec<Rgb>> {
414        node_colors.or_else(|| node_color.map(|color| vec![color]))
415    }
416
417    fn resolve_link_colors(
418        link_color: Option<Rgb>,
419        link_colors: Option<Vec<Rgb>>,
420    ) -> Option<Vec<Rgb>> {
421        link_colors.or_else(|| link_color.map(|color| vec![color]))
422    }
423    /// Calculates the grid cell positions for a subplot with reserved space for titles.
424    ///
425    /// This function computes both the sankey diagram domain and annotation position,
426    /// ensuring that space is reserved above each diagram for the facet title.
427    fn calculate_facet_cell(
428        subplot_index: usize,
429        ncols: usize,
430        nrows: usize,
431        x_gap: Option<f64>,
432        y_gap: Option<f64>,
433    ) -> FacetCell {
434        let row = subplot_index / ncols;
435        let col = subplot_index % ncols;
436
437        let x_gap_val = x_gap.unwrap_or(0.05);
438        let y_gap_val = y_gap.unwrap_or(0.10);
439
440        // Reserve space for facet title (10% of each cell's height)
441        const TITLE_HEIGHT_RATIO: f64 = 0.10;
442
443        // Calculate total cell dimensions
444        let cell_width = (1.0 - x_gap_val * (ncols - 1) as f64) / ncols as f64;
445        let cell_height = (1.0 - y_gap_val * (nrows - 1) as f64) / nrows as f64;
446
447        // Calculate cell boundaries
448        let cell_x_start = col as f64 * (cell_width + x_gap_val);
449        let cell_y_top = 1.0 - row as f64 * (cell_height + y_gap_val);
450        let cell_y_bottom = cell_y_top - cell_height;
451
452        // Reserve title space at the top of the cell
453        let title_height = cell_height * TITLE_HEIGHT_RATIO;
454        let sankey_y_top = cell_y_top - title_height;
455
456        // Sankey diagram domain (bottom 90% of the cell)
457        let sankey_x_start = cell_x_start;
458        let sankey_x_end = cell_x_start + cell_width;
459        let sankey_y_start = cell_y_bottom;
460        let sankey_y_end = sankey_y_top;
461
462        FacetCell {
463            domain_x_start: sankey_x_start,
464            domain_x_end: sankey_x_end,
465            domain_y_start: sankey_y_start,
466            domain_y_end: sankey_y_end,
467        }
468    }
469
470    fn build_label_index<'a>(
471        sources: &'a [Option<String>],
472        targets: &'a [Option<String>],
473    ) -> (Vec<&'a str>, HashMap<&'a str, usize>) {
474        let mut label_to_idx: HashMap<&'a str, usize> = HashMap::new();
475        let mut labels_unique: Vec<&'a str> = Vec::new();
476
477        let iter = sources
478            .iter()
479            .chain(targets.iter())
480            .filter_map(|opt| opt.as_deref());
481
482        for lbl in iter {
483            if let Entry::Vacant(v) = label_to_idx.entry(lbl) {
484                let next_id = labels_unique.len();
485                v.insert(next_id);
486                labels_unique.push(lbl);
487            }
488        }
489
490        (labels_unique, label_to_idx)
491    }
492
493    fn column_to_indices(
494        column: &[Option<String>],
495        label_to_idx: &HashMap<&str, usize>,
496    ) -> Vec<usize> {
497        column
498            .iter()
499            .filter_map(|opt| opt.as_deref())
500            .map(|lbl| *label_to_idx.get(lbl).expect("label must exist in map"))
501            .collect()
502    }
503}
504
/// Exposes the prebuilt traces and layout through the crate's `Plot` trait.
impl crate::Plot for SankeyDiagram {
    /// Returns the traces assembled when the diagram was constructed.
    fn ir_traces(&self) -> &[TraceIR] {
        &self.traces
    }

    /// Returns the layout assembled when the diagram was constructed.
    fn ir_layout(&self) -> &LayoutIR {
        &self.layout
    }
}
514
// Unit tests covering the unfaceted builder path.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    /// An unfaceted diagram produces exactly one trace, and it is a Sankey trace.
    #[test]
    fn test_basic_one_trace() {
        let df = df![
            "source" => ["A", "A", "B"],
            "target" => ["B", "C", "C"],
            "value" => [10.0, 20.0, 30.0]
        ]
        .unwrap();
        let plot = SankeyDiagram::builder()
            .data(&df)
            .sources("source")
            .targets("target")
            .values("value")
            .build();
        assert_eq!(plot.ir_traces().len(), 1);
        assert!(matches!(plot.ir_traces()[0], TraceIR::SankeyDiagram(_)));
    }

    /// The layout never carries 2D axes for a Sankey diagram.
    #[test]
    fn test_layout_no_axes() {
        let df = df![
            "source" => ["A", "B"],
            "target" => ["B", "C"],
            "value" => [10.0, 20.0]
        ]
        .unwrap();
        let plot = SankeyDiagram::builder()
            .data(&df)
            .sources("source")
            .targets("target")
            .values("value")
            .build();
        assert!(plot.ir_layout().axes_2d.is_none());
    }

    /// A plot title given to the builder ends up on the layout.
    #[test]
    fn test_layout_title() {
        let df = df![
            "source" => ["A"],
            "target" => ["B"],
            "value" => [10.0]
        ]
        .unwrap();
        let plot = SankeyDiagram::builder()
            .data(&df)
            .sources("source")
            .targets("target")
            .values("value")
            .plot_title("Sankey")
            .build();
        assert!(plot.ir_layout().title.is_some());
    }
}