plotlars/plots/
sankeydiagram.rs

1use bon::bon;
2use std::collections::{
3    HashMap,
4    hash_map::Entry,
5};
6
7use plotly::{
8    sankey::{Link, Node}, Layout as LayoutPlotly, Sankey, Trace
9};
10
11use polars::frame::DataFrame;
12use serde::Serialize;
13
14use crate::{
15    common::{Layout, PlotHelper, Polar},
16    components::{Arrangement, Orientation, Rgb, Text},
17};
18
/// A structure representing a Sankey diagram.
///
/// The `SankeyDiagram` struct enables the creation of Sankey diagrams, which visualize flows
/// between discrete nodes with link widths proportional to the magnitude of the flow. It
/// offers configuration options for flow orientation, node arrangement, spacing between
/// nodes, node thickness, and coloring, as well as a plot title. Users can specify a single
/// uniform color or per-item colors for both nodes and links.
///
/// Axis titles, axes and legend settings are not exposed by this builder; they are always
/// passed to the layout as `None`.
///
/// # Arguments
///
/// * `data` – A reference to the `DataFrame` containing the data to be plotted.
/// * `sources` – A string slice naming the column in `data` that contains the source node for each flow.
/// * `targets` – A string slice naming the column in `data` that contains the target node for each flow.
/// * `values` – A string slice naming the column in `data` that contains the numeric value of each flow.
/// * `orientation` – An optional `Orientation` enum to set the overall direction of the diagram
///   (e.g. `Orientation::Horizontal` or `Orientation::Vertical`).
/// * `arrangement` – An optional `Arrangement` enum to choose the node-layout algorithm
///   (e.g. `Arrangement::Snap`, `Arrangement::Perpendicular`, etc.).
/// * `pad` – An optional `usize` specifying the padding (in pixels) between adjacent nodes.
/// * `thickness` – An optional `usize` defining the uniform thickness (in pixels) of all nodes.
/// * `node_color` – An optional `Rgb` value to apply a single uniform color to every node.
/// * `node_colors` – An optional `Vec<Rgb>` supplying individual colors for each node. Nodes are
///   ordered by first appearance, scanning the `sources` column and then the `targets` column.
///   It is applied after `node_color`, so it presumably takes precedence when both are given.
/// * `link_color` – An optional `Rgb` value to apply a single uniform color to every link.
/// * `link_colors` – An optional `Vec<Rgb>` supplying individual colors for each link in row order.
///   It is applied after `link_color`, so it presumably takes precedence when both are given.
/// * `plot_title` – An optional `Text` struct for setting the overall title of the plot.
///
/// # Example
///
/// ```rust
/// use polars::prelude::*;
/// use plotlars::{Arrangement, SankeyDiagram, Orientation, Plot, Rgb, Text};
///
/// let dataset = df![
///         "source" => ["A1", "A2", "A1", "B1", "B2", "B2"],
///         "target" => &["B1", "B2", "B2", "C1", "C1", "C2"],
///         "value" => &[8, 4, 2, 8, 4, 2],
///     ]
///     .unwrap();
///
/// SankeyDiagram::builder()
///     .data(&dataset)
///     .sources("source")
///     .targets("target")
///     .values("value")
///     .orientation(Orientation::Horizontal)
///     .arrangement(Arrangement::Freeform)
///     .node_colors(vec![
///         Rgb(222, 235, 247),
///         Rgb(198, 219, 239),
///         Rgb(158, 202, 225),
///         Rgb(107, 174, 214),
///         Rgb( 66, 146, 198),
///         Rgb( 33, 113, 181),
///     ])
///     .link_colors(vec![
///         Rgb(222, 235, 247),
///         Rgb(198, 219, 239),
///         Rgb(158, 202, 225),
///         Rgb(107, 174, 214),
///         Rgb( 66, 146, 198),
///         Rgb( 33, 113, 181),
///     ])
///     .pad(20)
///     .thickness(30)
///     .plot_title(
///         Text::from("Sankey Diagram")
///             .font("Arial")
///             .size(18)
///     )
///     .build()
///     .plot();
/// ```
///
/// ![Example](https://imgur.com/jvAew8u.png)
#[derive(Clone, Serialize)]
pub struct SankeyDiagram {
    // Traces produced by `create_traces` (a single Sankey trace).
    traces: Vec<Box<dyn Trace + 'static>>,
    // Plotly layout produced by `create_layout`; only the plot title is ever set.
    layout: LayoutPlotly,
}
99
100#[bon]
101impl SankeyDiagram {
102    #[builder(on(String, into), on(Text, into))]
103    pub fn new(
104        data: &DataFrame,
105        sources: &str,
106        targets: &str,
107        values: &str,
108        orientation: Option<Orientation>,
109        arrangement: Option<Arrangement>,
110        pad: Option<usize>,
111        thickness: Option<usize>,
112        node_color: Option<Rgb>,
113        node_colors: Option<Vec<Rgb>>,
114        link_color: Option<Rgb>,
115        link_colors: Option<Vec<Rgb>>,
116        plot_title: Option<Text>,
117    ) -> Self {
118        let legend = None;
119        let legend_title = None;
120        let x_title = None;
121        let y_title = None;
122        let z_title = None;
123        let x_axis = None;
124        let y_axis = None;
125        let z_axis = None;
126
127        let layout = Self::create_layout(
128            plot_title,
129            x_title,
130            y_title,
131            None, // y2_title,
132            z_title,
133            legend_title,
134            x_axis,
135            y_axis,
136            None, // y2_axis,
137            z_axis,
138            legend,
139        );
140
141        let traces = Self::create_traces(
142            data,
143            sources,
144            targets,
145            values,
146            orientation,
147            arrangement,
148            pad,
149            thickness,
150            node_color,
151            node_colors,
152            link_color,
153            link_colors,
154        );
155
156        Self { traces, layout }
157    }
158
159    #[allow(clippy::too_many_arguments)]
160    fn create_traces(
161        data: &DataFrame,
162        sources: &str,
163        targets: &str,
164        values: &str,
165        orientation: Option<Orientation>,
166        arrangement: Option<Arrangement>,
167        pad: Option<usize>,
168        thickness: Option<usize>,
169        node_color: Option<Rgb>,
170        node_colors: Option<Vec<Rgb>>,
171        link_color: Option<Rgb>,
172        link_colors: Option<Vec<Rgb>>,
173    ) -> Vec<Box<dyn Trace + 'static>> {
174        let mut traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
175
176        let trace = Self::create_trace(
177            data,
178            sources,
179            targets,
180            values,
181            orientation,
182            arrangement,
183            pad,
184            thickness,
185            node_color,
186            node_colors,
187            link_color,
188            link_colors,
189        );
190
191        traces.push(trace);
192        traces
193    }
194
195    #[allow(clippy::too_many_arguments)]
196    fn create_trace(
197        data: &DataFrame,
198        sources: &str,
199        targets: &str,
200        values: &str,
201        orientation: Option<Orientation>,
202        arrangement: Option<Arrangement>,
203        pad: Option<usize>,
204        thickness: Option<usize>,
205        node_color: Option<Rgb>,
206        node_colors: Option<Vec<Rgb>>,
207        link_color: Option<Rgb>,
208        link_colors: Option<Vec<Rgb>>,
209    ) -> Box<dyn Trace + 'static> {
210        let sources = Self::get_string_column(data, sources);
211        let targets = Self::get_string_column(data, targets);
212        let values = Self::get_numeric_column(data, values);
213
214        let (labels_unique, label_to_idx) = Self::build_label_index(&sources, &targets);
215
216        let sources_idx = Self::column_to_indices(&sources, &label_to_idx);
217        let targets_idx = Self::column_to_indices(&targets, &label_to_idx);
218
219        let mut node = Node::new()
220            .label(labels_unique);
221
222        node = Self::set_pad(node, pad);
223        node = Self::set_thickness(node, thickness);
224        node = Self::set_node_color(node, node_color);
225        node = Self::set_node_colors(node, node_colors);
226
227        let mut link = Link::new()
228            .source(sources_idx)
229            .target(targets_idx)
230            .value(values);
231
232        link = Self::set_link_color(link, link_color);
233        link = Self::set_link_colors(link, link_colors);
234
235        let mut trace = Sankey::new()
236            .node(node)
237            .link(link);
238
239        trace = Self::set_orientation(trace, orientation);
240        trace = Self::set_arrangement(trace, arrangement);
241        trace
242    }
243
244    fn set_thickness(
245        mut node: Node,
246        thickness: Option<usize>,
247    ) -> Node     {
248        if let Some(thickness) = thickness {
249            node = node.thickness(thickness);
250        }
251
252        node
253    }
254
255    fn set_pad(
256        mut node: Node,
257        pad: Option<usize>,
258    ) -> Node     {
259        if let Some(pad) = pad {
260            node = node.pad(pad);
261        }
262
263        node
264    }
265
266    fn set_link_colors<V>(
267        mut link: Link<V>,
268        colors: Option<Vec<Rgb>>,
269    ) -> Link<V>
270    where
271        V: Serialize + Clone,
272    {
273        if let Some(colors) = colors {
274            link = link.color_array(
275                colors
276                    .iter()
277                    .map(|color| color.to_plotly())
278                    .collect()
279            );
280        }
281
282        link
283    }
284
285    fn set_link_color<V>(
286        mut link: Link<V>,
287        color: Option<Rgb>,
288    ) -> Link<V>
289    where
290        V: Serialize + Clone,
291    {
292        if let Some(color) = color {
293            link = link.color(color);
294        }
295
296        link
297    }
298
299    fn set_node_colors(
300        mut node: Node,
301        colors: Option<Vec<Rgb>>,
302    ) -> Node {
303        if let Some(colors) = colors {
304            node = node.color_array(
305                colors
306                    .iter()
307                    .map(|color| color.to_plotly())
308                    .collect()
309            );
310        }
311
312        node
313    }
314
315    fn set_node_color(
316        mut node: Node,
317        color: Option<Rgb>,
318    ) -> Node {
319        if let Some(color) = color {
320            node = node.color(color);
321        }
322
323        node
324    }
325
326    fn set_arrangement(
327        mut trace: Box<Sankey<Option<f32>>>,
328        arrangement: Option<Arrangement>,
329    ) -> Box<Sankey<Option<f32>>> {
330        if let Some(arrangement) = arrangement {
331            trace = trace.arrangement(arrangement.to_plotly())
332        }
333
334        trace
335    }
336
337    fn set_orientation(
338        mut trace: Box<Sankey<Option<f32>>>,
339        orientation: Option<Orientation>,
340    ) -> Box<Sankey<Option<f32>>> {
341        if let Some(orientation) = orientation {
342            trace = trace.orientation(orientation.to_plotly())
343        }
344
345        trace
346    }
347
348    fn build_label_index<'a>(
349        sources: &'a [Option<String>],
350        targets: &'a [Option<String>],
351    ) -> (Vec<&'a str>, HashMap<&'a str, usize>) {
352        let mut label_to_idx: HashMap<&'a str, usize> = HashMap::new();
353        let mut labels_unique: Vec<&'a str> = Vec::new();
354
355        let iter = sources
356            .iter()
357            .chain(targets.iter())
358            .filter_map(|opt| opt.as_deref());
359
360        for lbl in iter {
361            if let Entry::Vacant(v) = label_to_idx.entry(lbl) {
362                let next_id = labels_unique.len();
363                v.insert(next_id);
364                labels_unique.push(lbl);
365            }
366        }
367
368        (labels_unique, label_to_idx)
369    }
370
371    fn column_to_indices(
372        column: &[Option<String>],
373        label_to_idx: &HashMap<&str, usize>,
374    ) -> Vec<usize> {
375        column
376            .iter()
377            .filter_map(|opt| opt.as_deref())
378            .map(|lbl| *label_to_idx.get(lbl).expect("label must exist in map"))
379            .collect()
380    }
381}
382
// Empty impl: `SankeyDiagram` overrides nothing and relies on whatever default
// methods `Layout` provides (presumably including `create_layout`, which `new`
// calls — confirm in `crate::common`).
impl Layout for SankeyDiagram {}
// Empty impl as well; presumably the source of the `get_string_column` /
// `get_numeric_column` helpers used by `create_trace` — confirm in `crate::common`.
impl Polar for SankeyDiagram {}
385
386impl PlotHelper for SankeyDiagram {
387    fn get_layout(&self) -> &LayoutPlotly {
388        &self.layout
389    }
390
391    fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>> {
392        &self.traces
393    }
394}