1use bon::bon;
2use std::collections::{HashMap, hash_map::Entry};
3
4use plotly::{
5 Layout as LayoutPlotly, Sankey, Trace,
6 sankey::{Link, Node},
7};
8
9use polars::frame::DataFrame;
10use serde::Serialize;
11
12use crate::{
13 common::{Layout, PlotHelper, Polar},
14 components::{Arrangement, Orientation, Rgb, Text},
15};
16
/// A Sankey diagram: the plotly traces and layout assembled from a polars
/// `DataFrame` of (source, target, value) rows.
///
/// Instances are produced by the `#[builder]`-annotated constructor in the
/// `#[bon]` impl below (exposed by `bon` as `SankeyDiagram::builder()`).
#[derive(Clone, Serialize)]
pub struct SankeyDiagram {
    // Traces handed to plotly; the constructor fills this with a single Sankey trace.
    traces: Vec<Box<dyn Trace + 'static>>,
    // Figure layout; for this chart only the plot title is configured.
    layout: LayoutPlotly,
}
98
99#[bon]
100impl SankeyDiagram {
101 #[builder(on(String, into), on(Text, into))]
102 pub fn new(
103 data: &DataFrame,
104 sources: &str,
105 targets: &str,
106 values: &str,
107 orientation: Option<Orientation>,
108 arrangement: Option<Arrangement>,
109 pad: Option<usize>,
110 thickness: Option<usize>,
111 node_color: Option<Rgb>,
112 node_colors: Option<Vec<Rgb>>,
113 link_color: Option<Rgb>,
114 link_colors: Option<Vec<Rgb>>,
115 plot_title: Option<Text>,
116 ) -> Self {
117 let legend = None;
118 let legend_title = None;
119 let x_title = None;
120 let y_title = None;
121 let z_title = None;
122 let x_axis = None;
123 let y_axis = None;
124 let z_axis = None;
125 let y2_title = None;
126 let y2_axis = None;
127
128 let layout = Self::create_layout(
129 plot_title,
130 x_title,
131 y_title,
132 y2_title,
133 z_title,
134 legend_title,
135 x_axis,
136 y_axis,
137 y2_axis,
138 z_axis,
139 legend,
140 );
141
142 let traces = Self::create_traces(
143 data,
144 sources,
145 targets,
146 values,
147 orientation,
148 arrangement,
149 pad,
150 thickness,
151 node_color,
152 node_colors,
153 link_color,
154 link_colors,
155 );
156
157 Self { traces, layout }
158 }
159
160 #[allow(clippy::too_many_arguments)]
161 fn create_traces(
162 data: &DataFrame,
163 sources: &str,
164 targets: &str,
165 values: &str,
166 orientation: Option<Orientation>,
167 arrangement: Option<Arrangement>,
168 pad: Option<usize>,
169 thickness: Option<usize>,
170 node_color: Option<Rgb>,
171 node_colors: Option<Vec<Rgb>>,
172 link_color: Option<Rgb>,
173 link_colors: Option<Vec<Rgb>>,
174 ) -> Vec<Box<dyn Trace + 'static>> {
175 let mut traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
176
177 let trace = Self::create_trace(
178 data,
179 sources,
180 targets,
181 values,
182 orientation,
183 arrangement,
184 pad,
185 thickness,
186 node_color,
187 node_colors,
188 link_color,
189 link_colors,
190 );
191
192 traces.push(trace);
193 traces
194 }
195
196 #[allow(clippy::too_many_arguments)]
197 fn create_trace(
198 data: &DataFrame,
199 sources: &str,
200 targets: &str,
201 values: &str,
202 orientation: Option<Orientation>,
203 arrangement: Option<Arrangement>,
204 pad: Option<usize>,
205 thickness: Option<usize>,
206 node_color: Option<Rgb>,
207 node_colors: Option<Vec<Rgb>>,
208 link_color: Option<Rgb>,
209 link_colors: Option<Vec<Rgb>>,
210 ) -> Box<dyn Trace + 'static> {
211 let sources = Self::get_string_column(data, sources);
212 let targets = Self::get_string_column(data, targets);
213 let values = Self::get_numeric_column(data, values);
214
215 let (labels_unique, label_to_idx) = Self::build_label_index(&sources, &targets);
216
217 let sources_idx = Self::column_to_indices(&sources, &label_to_idx);
218 let targets_idx = Self::column_to_indices(&targets, &label_to_idx);
219
220 let mut node = Node::new().label(labels_unique);
221
222 node = Self::set_pad(node, pad);
223 node = Self::set_thickness(node, thickness);
224 node = Self::set_node_color(node, node_color);
225 node = Self::set_node_colors(node, node_colors);
226
227 let mut link = Link::new()
228 .source(sources_idx)
229 .target(targets_idx)
230 .value(values);
231
232 link = Self::set_link_color(link, link_color);
233 link = Self::set_link_colors(link, link_colors);
234
235 let mut trace = Sankey::new().node(node).link(link);
236
237 trace = Self::set_orientation(trace, orientation);
238 trace = Self::set_arrangement(trace, arrangement);
239 trace
240 }
241
242 fn set_thickness(mut node: Node, thickness: Option<usize>) -> Node {
243 if let Some(thickness) = thickness {
244 node = node.thickness(thickness);
245 }
246
247 node
248 }
249
250 fn set_pad(mut node: Node, pad: Option<usize>) -> Node {
251 if let Some(pad) = pad {
252 node = node.pad(pad);
253 }
254
255 node
256 }
257
258 fn set_link_colors<V>(mut link: Link<V>, colors: Option<Vec<Rgb>>) -> Link<V>
259 where
260 V: Serialize + Clone,
261 {
262 if let Some(colors) = colors {
263 link = link.color_array(colors.iter().map(|color| color.to_plotly()).collect());
264 }
265
266 link
267 }
268
269 fn set_link_color<V>(mut link: Link<V>, color: Option<Rgb>) -> Link<V>
270 where
271 V: Serialize + Clone,
272 {
273 if let Some(color) = color {
274 link = link.color(color);
275 }
276
277 link
278 }
279
280 fn set_node_colors(mut node: Node, colors: Option<Vec<Rgb>>) -> Node {
281 if let Some(colors) = colors {
282 node = node.color_array(colors.iter().map(|color| color.to_plotly()).collect());
283 }
284
285 node
286 }
287
288 fn set_node_color(mut node: Node, color: Option<Rgb>) -> Node {
289 if let Some(color) = color {
290 node = node.color(color);
291 }
292
293 node
294 }
295
296 fn set_arrangement(
297 mut trace: Box<Sankey<Option<f32>>>,
298 arrangement: Option<Arrangement>,
299 ) -> Box<Sankey<Option<f32>>> {
300 if let Some(arrangement) = arrangement {
301 trace = trace.arrangement(arrangement.to_plotly())
302 }
303
304 trace
305 }
306
307 fn set_orientation(
308 mut trace: Box<Sankey<Option<f32>>>,
309 orientation: Option<Orientation>,
310 ) -> Box<Sankey<Option<f32>>> {
311 if let Some(orientation) = orientation {
312 trace = trace.orientation(orientation.to_plotly())
313 }
314
315 trace
316 }
317
318 fn build_label_index<'a>(
319 sources: &'a [Option<String>],
320 targets: &'a [Option<String>],
321 ) -> (Vec<&'a str>, HashMap<&'a str, usize>) {
322 let mut label_to_idx: HashMap<&'a str, usize> = HashMap::new();
323 let mut labels_unique: Vec<&'a str> = Vec::new();
324
325 let iter = sources
326 .iter()
327 .chain(targets.iter())
328 .filter_map(|opt| opt.as_deref());
329
330 for lbl in iter {
331 if let Entry::Vacant(v) = label_to_idx.entry(lbl) {
332 let next_id = labels_unique.len();
333 v.insert(next_id);
334 labels_unique.push(lbl);
335 }
336 }
337
338 (labels_unique, label_to_idx)
339 }
340
341 fn column_to_indices(
342 column: &[Option<String>],
343 label_to_idx: &HashMap<&str, usize>,
344 ) -> Vec<usize> {
345 column
346 .iter()
347 .filter_map(|opt| opt.as_deref())
348 .map(|lbl| *label_to_idx.get(lbl).expect("label must exist in map"))
349 .collect()
350 }
351}
352
353impl Layout for SankeyDiagram {}
354impl Polar for SankeyDiagram {}
355
356impl PlotHelper for SankeyDiagram {
357 fn get_layout(&self) -> &LayoutPlotly {
358 &self.layout
359 }
360
361 fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>> {
362 &self.traces
363 }
364}