plotlars/plots/
sankeydiagram.rs

1use bon::bon;
2use std::collections::{HashMap, hash_map::Entry};
3
4use plotly::{
5    Layout as LayoutPlotly, Sankey, Trace,
6    sankey::{Link, Node},
7};
8
9use polars::frame::DataFrame;
10use serde::Serialize;
11
12use crate::{
13    common::{Layout, PlotHelper, Polar},
14    components::{Arrangement, Orientation, Rgb, Text},
15};
16
17/// A structure representing a Sankey diagram.
18///
19/// The `SankeyDiagram` struct enables the creation of Sankey diagrams, which visualize flows
20/// between discrete nodes with link widths proportional to the magnitude of the flow. It
21/// offers extensive configuration options for flow orientation, node arrangement, spacing,
22/// thickness, and coloring, as well as axis and title customization. Users can specify a
23/// single uniform color or per-item colors for both nodes and links, adjust padding between
24/// nodes, set node thickness, and supply custom titles and axis labels to produce clear,
25/// publication-quality flow visualizations.
26///
27/// # Arguments
28///
29/// * `data` – A reference to the `DataFrame` containing the data to be plotted.
30/// * `sources` – A string slice naming the column in `data` that contains the source node for each flow.
31/// * `targets` – A string slice naming the column in `data` that contains the target node for each flow.
32/// * `values` – A string slice naming the column in `data` that contains the numeric value of each flow.
33/// * `orientation` – An optional `Orientation` enum to set the overall direction of the diagram
34///   (e.g. `Orientation::Horizontal` or `Orientation::Vertical`).
35/// * `arrangement` – An optional `Arrangement` enum to choose the node-layout algorithm
36///   (e.g. `Arrangement::Snap`, `Arrangement::Perpendicular`, etc.).
37/// * `pad` – An optional `usize` specifying the padding (in pixels) between adjacent nodes.
38/// * `thickness` – An optional `usize` defining the uniform thickness (in pixels) of all nodes.
39/// * `node_color` – An optional `Rgb` value to apply a single uniform color to every node.
40/// * `node_colors` – An optional `Vec<Rgb>` supplying individual colors for each node in order.
41/// * `link_color` – An optional `Rgb` value to apply a single uniform color to every link.
42/// * `link_colors` – An optional `Vec<Rgb>` supplying individual colors for each link in order.
43/// * `plot_title` – An optional `Text` struct for setting the overall title of the plot.
44///
45/// # Example
46///
47/// ```rust
48/// use polars::prelude::*;
49/// use plotlars::{Arrangement, SankeyDiagram, Orientation, Plot, Rgb, Text};
50///
51/// let dataset = df![
52///         "source" => ["A1", "A2", "A1", "B1", "B2", "B2"],
53///         "target" => &["B1", "B2", "B2", "C1", "C1", "C2"],
54///         "value" => &[8, 4, 2, 8, 4, 2],
55///     ]
56///     .unwrap();
57///
58/// SankeyDiagram::builder()
59///     .data(&dataset)
60///     .sources("source")
61///     .targets("target")
62///     .values("value")
63///     .orientation(Orientation::Horizontal)
64///     .arrangement(Arrangement::Freeform)
65///     .node_colors(vec![
66///         Rgb(222, 235, 247),
67///         Rgb(198, 219, 239),
68///         Rgb(158, 202, 225),
69///         Rgb(107, 174, 214),
70///         Rgb( 66, 146, 198),
71///         Rgb( 33, 113, 181),
72///     ])
73///     .link_colors(vec![
74///         Rgb(222, 235, 247),
75///         Rgb(198, 219, 239),
76///         Rgb(158, 202, 225),
77///         Rgb(107, 174, 214),
78///         Rgb( 66, 146, 198),
79///         Rgb( 33, 113, 181),
80///     ])
81///     .pad(20)
82///     .thickness(30)
83///     .plot_title(
84///         Text::from("Sankey Diagram")
85///             .font("Arial")
86///             .size(18)
87///     )
88///     .build()
89///     .plot();
90/// ```
91///
92/// ![Example](https://imgur.com/jvAew8u.png)
93#[derive(Clone, Serialize)]
94pub struct SankeyDiagram {
95    traces: Vec<Box<dyn Trace + 'static>>,
96    layout: LayoutPlotly,
97}
98
99#[bon]
100impl SankeyDiagram {
101    #[builder(on(String, into), on(Text, into))]
102    pub fn new(
103        data: &DataFrame,
104        sources: &str,
105        targets: &str,
106        values: &str,
107        orientation: Option<Orientation>,
108        arrangement: Option<Arrangement>,
109        pad: Option<usize>,
110        thickness: Option<usize>,
111        node_color: Option<Rgb>,
112        node_colors: Option<Vec<Rgb>>,
113        link_color: Option<Rgb>,
114        link_colors: Option<Vec<Rgb>>,
115        plot_title: Option<Text>,
116    ) -> Self {
117        let legend = None;
118        let legend_title = None;
119        let x_title = None;
120        let y_title = None;
121        let z_title = None;
122        let x_axis = None;
123        let y_axis = None;
124        let z_axis = None;
125        let y2_title = None;
126        let y2_axis = None;
127
128        let layout = Self::create_layout(
129            plot_title,
130            x_title,
131            y_title,
132            y2_title,
133            z_title,
134            legend_title,
135            x_axis,
136            y_axis,
137            y2_axis,
138            z_axis,
139            legend,
140        );
141
142        let traces = Self::create_traces(
143            data,
144            sources,
145            targets,
146            values,
147            orientation,
148            arrangement,
149            pad,
150            thickness,
151            node_color,
152            node_colors,
153            link_color,
154            link_colors,
155        );
156
157        Self { traces, layout }
158    }
159
160    #[allow(clippy::too_many_arguments)]
161    fn create_traces(
162        data: &DataFrame,
163        sources: &str,
164        targets: &str,
165        values: &str,
166        orientation: Option<Orientation>,
167        arrangement: Option<Arrangement>,
168        pad: Option<usize>,
169        thickness: Option<usize>,
170        node_color: Option<Rgb>,
171        node_colors: Option<Vec<Rgb>>,
172        link_color: Option<Rgb>,
173        link_colors: Option<Vec<Rgb>>,
174    ) -> Vec<Box<dyn Trace + 'static>> {
175        let mut traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
176
177        let trace = Self::create_trace(
178            data,
179            sources,
180            targets,
181            values,
182            orientation,
183            arrangement,
184            pad,
185            thickness,
186            node_color,
187            node_colors,
188            link_color,
189            link_colors,
190        );
191
192        traces.push(trace);
193        traces
194    }
195
196    #[allow(clippy::too_many_arguments)]
197    fn create_trace(
198        data: &DataFrame,
199        sources: &str,
200        targets: &str,
201        values: &str,
202        orientation: Option<Orientation>,
203        arrangement: Option<Arrangement>,
204        pad: Option<usize>,
205        thickness: Option<usize>,
206        node_color: Option<Rgb>,
207        node_colors: Option<Vec<Rgb>>,
208        link_color: Option<Rgb>,
209        link_colors: Option<Vec<Rgb>>,
210    ) -> Box<dyn Trace + 'static> {
211        let sources = Self::get_string_column(data, sources);
212        let targets = Self::get_string_column(data, targets);
213        let values = Self::get_numeric_column(data, values);
214
215        let (labels_unique, label_to_idx) = Self::build_label_index(&sources, &targets);
216
217        let sources_idx = Self::column_to_indices(&sources, &label_to_idx);
218        let targets_idx = Self::column_to_indices(&targets, &label_to_idx);
219
220        let mut node = Node::new().label(labels_unique);
221
222        node = Self::set_pad(node, pad);
223        node = Self::set_thickness(node, thickness);
224        node = Self::set_node_color(node, node_color);
225        node = Self::set_node_colors(node, node_colors);
226
227        let mut link = Link::new()
228            .source(sources_idx)
229            .target(targets_idx)
230            .value(values);
231
232        link = Self::set_link_color(link, link_color);
233        link = Self::set_link_colors(link, link_colors);
234
235        let mut trace = Sankey::new().node(node).link(link);
236
237        trace = Self::set_orientation(trace, orientation);
238        trace = Self::set_arrangement(trace, arrangement);
239        trace
240    }
241
242    fn set_thickness(mut node: Node, thickness: Option<usize>) -> Node {
243        if let Some(thickness) = thickness {
244            node = node.thickness(thickness);
245        }
246
247        node
248    }
249
250    fn set_pad(mut node: Node, pad: Option<usize>) -> Node {
251        if let Some(pad) = pad {
252            node = node.pad(pad);
253        }
254
255        node
256    }
257
258    fn set_link_colors<V>(mut link: Link<V>, colors: Option<Vec<Rgb>>) -> Link<V>
259    where
260        V: Serialize + Clone,
261    {
262        if let Some(colors) = colors {
263            link = link.color_array(colors.iter().map(|color| color.to_plotly()).collect());
264        }
265
266        link
267    }
268
269    fn set_link_color<V>(mut link: Link<V>, color: Option<Rgb>) -> Link<V>
270    where
271        V: Serialize + Clone,
272    {
273        if let Some(color) = color {
274            link = link.color(color);
275        }
276
277        link
278    }
279
280    fn set_node_colors(mut node: Node, colors: Option<Vec<Rgb>>) -> Node {
281        if let Some(colors) = colors {
282            node = node.color_array(colors.iter().map(|color| color.to_plotly()).collect());
283        }
284
285        node
286    }
287
288    fn set_node_color(mut node: Node, color: Option<Rgb>) -> Node {
289        if let Some(color) = color {
290            node = node.color(color);
291        }
292
293        node
294    }
295
296    fn set_arrangement(
297        mut trace: Box<Sankey<Option<f32>>>,
298        arrangement: Option<Arrangement>,
299    ) -> Box<Sankey<Option<f32>>> {
300        if let Some(arrangement) = arrangement {
301            trace = trace.arrangement(arrangement.to_plotly())
302        }
303
304        trace
305    }
306
307    fn set_orientation(
308        mut trace: Box<Sankey<Option<f32>>>,
309        orientation: Option<Orientation>,
310    ) -> Box<Sankey<Option<f32>>> {
311        if let Some(orientation) = orientation {
312            trace = trace.orientation(orientation.to_plotly())
313        }
314
315        trace
316    }
317
318    fn build_label_index<'a>(
319        sources: &'a [Option<String>],
320        targets: &'a [Option<String>],
321    ) -> (Vec<&'a str>, HashMap<&'a str, usize>) {
322        let mut label_to_idx: HashMap<&'a str, usize> = HashMap::new();
323        let mut labels_unique: Vec<&'a str> = Vec::new();
324
325        let iter = sources
326            .iter()
327            .chain(targets.iter())
328            .filter_map(|opt| opt.as_deref());
329
330        for lbl in iter {
331            if let Entry::Vacant(v) = label_to_idx.entry(lbl) {
332                let next_id = labels_unique.len();
333                v.insert(next_id);
334                labels_unique.push(lbl);
335            }
336        }
337
338        (labels_unique, label_to_idx)
339    }
340
341    fn column_to_indices(
342        column: &[Option<String>],
343        label_to_idx: &HashMap<&str, usize>,
344    ) -> Vec<usize> {
345        column
346            .iter()
347            .filter_map(|opt| opt.as_deref())
348            .map(|lbl| *label_to_idx.get(lbl).expect("label must exist in map"))
349            .collect()
350    }
351}
352
353impl Layout for SankeyDiagram {}
354impl Polar for SankeyDiagram {}
355
356impl PlotHelper for SankeyDiagram {
357    fn get_layout(&self) -> &LayoutPlotly {
358        &self.layout
359    }
360
361    fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>> {
362        &self.traces
363    }
364}