1use bon::bon;
2use std::collections::{hash_map::Entry, HashMap};
3
4use plotly::{
5 common::{Anchor, Domain},
6 layout::Annotation,
7 sankey::{Link, Node},
8 Layout as LayoutPlotly, Sankey, Trace,
9};
10
11use polars::frame::DataFrame;
12use serde::Serialize;
13
14use crate::{
15 common::{Layout, PlotHelper, Polar},
16 components::{Arrangement, FacetConfig, Legend, Orientation, Rgb, Text},
17};
18
19#[derive(Clone, Serialize)]
100pub struct SankeyDiagram {
101 traces: Vec<Box<dyn Trace + 'static>>,
102 layout: LayoutPlotly,
103}
104
105#[bon]
106impl SankeyDiagram {
107 #[builder(on(String, into), on(Text, into))]
108 pub fn new(
109 data: &DataFrame,
110 sources: &str,
111 targets: &str,
112 values: &str,
113 facet: Option<&str>,
114 facet_config: Option<&FacetConfig>,
115 orientation: Option<Orientation>,
116 arrangement: Option<Arrangement>,
117 pad: Option<usize>,
118 thickness: Option<usize>,
119 node_color: Option<Rgb>,
120 node_colors: Option<Vec<Rgb>>,
121 link_color: Option<Rgb>,
122 link_colors: Option<Vec<Rgb>>,
123 plot_title: Option<Text>,
124 legend_title: Option<Text>,
125 legend: Option<&Legend>,
126 ) -> Self {
127 let x_title = None;
128 let y_title = None;
129 let z_title = None;
130 let x_axis = None;
131 let y_axis = None;
132 let z_axis = None;
133 let y2_title = None;
134 let y2_axis = None;
135
136 let (layout, traces) = match facet {
137 Some(facet_column) => {
138 let config = facet_config.cloned().unwrap_or_default();
139
140 let layout = Self::create_faceted_layout(
141 data,
142 facet_column,
143 &config,
144 plot_title,
145 legend_title,
146 legend,
147 );
148
149 let traces = Self::create_faceted_traces(
150 data,
151 sources,
152 targets,
153 values,
154 facet_column,
155 &config,
156 orientation,
157 arrangement,
158 pad,
159 thickness,
160 node_color,
161 node_colors,
162 link_color,
163 link_colors,
164 );
165
166 (layout, traces)
167 }
168 None => {
169 let layout = Self::create_layout(
170 plot_title,
171 x_title,
172 y_title,
173 y2_title,
174 z_title,
175 legend_title,
176 x_axis,
177 y_axis,
178 y2_axis,
179 z_axis,
180 legend,
181 None,
182 );
183
184 let traces = Self::create_traces(
185 data,
186 sources,
187 targets,
188 values,
189 orientation,
190 arrangement,
191 pad,
192 thickness,
193 node_color,
194 node_colors,
195 link_color,
196 link_colors,
197 );
198
199 (layout, traces)
200 }
201 };
202
203 Self { traces, layout }
204 }
205
206 #[allow(clippy::too_many_arguments)]
207 fn create_traces(
208 data: &DataFrame,
209 sources: &str,
210 targets: &str,
211 values: &str,
212 orientation: Option<Orientation>,
213 arrangement: Option<Arrangement>,
214 pad: Option<usize>,
215 thickness: Option<usize>,
216 node_color: Option<Rgb>,
217 node_colors: Option<Vec<Rgb>>,
218 link_color: Option<Rgb>,
219 link_colors: Option<Vec<Rgb>>,
220 ) -> Vec<Box<dyn Trace + 'static>> {
221 let mut traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
222
223 let trace = Self::create_trace(
224 data,
225 sources,
226 targets,
227 values,
228 orientation,
229 arrangement,
230 pad,
231 thickness,
232 node_color,
233 node_colors,
234 link_color,
235 link_colors,
236 None,
237 );
238
239 traces.push(trace);
240 traces
241 }
242
243 #[allow(clippy::too_many_arguments)]
244 fn create_trace(
245 data: &DataFrame,
246 sources: &str,
247 targets: &str,
248 values: &str,
249 orientation: Option<Orientation>,
250 arrangement: Option<Arrangement>,
251 pad: Option<usize>,
252 thickness: Option<usize>,
253 node_color: Option<Rgb>,
254 node_colors: Option<Vec<Rgb>>,
255 link_color: Option<Rgb>,
256 link_colors: Option<Vec<Rgb>>,
257 domain: Option<Domain>,
258 ) -> Box<dyn Trace + 'static> {
259 let sources = Self::get_string_column(data, sources);
260 let targets = Self::get_string_column(data, targets);
261 let values = Self::get_numeric_column(data, values);
262
263 let (labels_unique, label_to_idx) = Self::build_label_index(&sources, &targets);
264
265 let sources_idx = Self::column_to_indices(&sources, &label_to_idx);
266 let targets_idx = Self::column_to_indices(&targets, &label_to_idx);
267
268 let mut node = Node::new().label(labels_unique);
269
270 node = Self::set_pad(node, pad);
271 node = Self::set_thickness(node, thickness);
272 node = Self::set_node_color(node, node_color);
273 node = Self::set_node_colors(node, node_colors);
274
275 let mut link = Link::new()
276 .source(sources_idx)
277 .target(targets_idx)
278 .value(values);
279
280 link = Self::set_link_color(link, link_color);
281 link = Self::set_link_colors(link, link_colors);
282
283 let mut trace = Sankey::new().node(node).link(link);
284
285 trace = Self::set_orientation(trace, orientation);
286 trace = Self::set_arrangement(trace, arrangement);
287
288 if let Some(domain_val) = domain {
289 trace = trace.domain(domain_val);
290 }
291
292 trace
293 }
294
295 #[allow(clippy::too_many_arguments)]
296 fn create_faceted_traces(
297 data: &DataFrame,
298 sources: &str,
299 targets: &str,
300 values: &str,
301 facet_column: &str,
302 config: &FacetConfig,
303 orientation: Option<Orientation>,
304 arrangement: Option<Arrangement>,
305 pad: Option<usize>,
306 thickness: Option<usize>,
307 node_color: Option<Rgb>,
308 node_colors: Option<Vec<Rgb>>,
309 link_color: Option<Rgb>,
310 link_colors: Option<Vec<Rgb>>,
311 ) -> Vec<Box<dyn Trace + 'static>> {
312 const MAX_FACETS: usize = 8;
313
314 let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
315
316 if facet_categories.len() > MAX_FACETS {
317 panic!(
318 "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
319 facet_column,
320 facet_categories.len(),
321 MAX_FACETS
322 );
323 }
324
325 let n_facets = facet_categories.len();
326 let (ncols, nrows) = Self::calculate_grid_dimensions(n_facets, config.cols, config.rows);
327
328 let facet_categories_non_empty: Vec<String> = facet_categories
330 .iter()
331 .filter(|facet_value| {
332 let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
333 facet_data.height() > 0
334 })
335 .cloned()
336 .collect();
337
338 let mut all_traces = Vec::new();
339
340 let node_colors_cloned = node_colors.clone();
342 let link_colors_cloned = link_colors.clone();
343
344 for (idx, facet_value) in facet_categories_non_empty.iter().enumerate() {
345 let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
346
347 let domain =
348 Self::calculate_sankey_domain(idx, ncols, nrows, config.h_gap, config.v_gap);
349
350 let trace = Self::create_trace(
351 &facet_data,
352 sources,
353 targets,
354 values,
355 orientation.clone(),
356 arrangement.clone(),
357 pad,
358 thickness,
359 node_color,
360 node_colors_cloned.clone(),
361 link_color,
362 link_colors_cloned.clone(),
363 Some(domain),
364 );
365
366 all_traces.push(trace);
367 }
368
369 all_traces
370 }
371
372 fn create_faceted_layout(
373 data: &DataFrame,
374 facet_column: &str,
375 config: &FacetConfig,
376 plot_title: Option<Text>,
377 legend_title: Option<Text>,
378 legend: Option<&Legend>,
379 ) -> LayoutPlotly {
380 let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
381
382 let facet_categories_non_empty: Vec<String> = facet_categories
384 .iter()
385 .filter(|facet_value| {
386 let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
387 facet_data.height() > 0
388 })
389 .cloned()
390 .collect();
391
392 let n_facets = facet_categories_non_empty.len();
393 let (ncols, nrows) = Self::calculate_grid_dimensions(n_facets, config.cols, config.rows);
394
395 let mut layout = LayoutPlotly::new();
396
397 if let Some(title) = plot_title {
398 layout = layout.title(title.to_plotly());
399 }
400
401 let annotations = Self::create_facet_annotations_sankey(
402 &facet_categories_non_empty,
403 ncols,
404 nrows,
405 config.title_style.as_ref(),
406 config.h_gap,
407 config.v_gap,
408 );
409 layout = layout.annotations(annotations);
410
411 layout = layout.legend(Legend::set_legend(legend_title, legend));
412
413 layout
414 }
415
416 fn calculate_facet_cell(
421 subplot_index: usize,
422 ncols: usize,
423 nrows: usize,
424 x_gap: Option<f64>,
425 y_gap: Option<f64>,
426 ) -> FacetCell {
427 let row = subplot_index / ncols;
428 let col = subplot_index % ncols;
429
430 let x_gap_val = x_gap.unwrap_or(0.05);
431 let y_gap_val = y_gap.unwrap_or(0.10);
432
433 const TITLE_HEIGHT_RATIO: f64 = 0.10;
435 const TITLE_PADDING_RATIO: f64 = 0.35;
437
438 let cell_width = (1.0 - x_gap_val * (ncols - 1) as f64) / ncols as f64;
440 let cell_height = (1.0 - y_gap_val * (nrows - 1) as f64) / nrows as f64;
441
442 let cell_x_start = col as f64 * (cell_width + x_gap_val);
444 let cell_y_top = 1.0 - row as f64 * (cell_height + y_gap_val);
445 let cell_y_bottom = cell_y_top - cell_height;
446
447 let title_height = cell_height * TITLE_HEIGHT_RATIO;
449 let sankey_y_top = cell_y_top - title_height;
450
451 let sankey_x_start = cell_x_start;
453 let sankey_x_end = cell_x_start + cell_width;
454 let sankey_y_start = cell_y_bottom;
455 let sankey_y_end = sankey_y_top;
456
457 let padding_height = title_height * TITLE_PADDING_RATIO;
459 let actual_title_height = title_height - padding_height;
460 let annotation_x = cell_x_start + cell_width / 2.0;
461 let annotation_y = sankey_y_top + padding_height + (actual_title_height / 2.0);
462
463 FacetCell {
464 domain_x_start: sankey_x_start,
465 domain_x_end: sankey_x_end,
466 domain_y_start: sankey_y_start,
467 domain_y_end: sankey_y_end,
468 annotation_x,
469 annotation_y,
470 }
471 }
472
473 fn calculate_sankey_domain(
474 subplot_index: usize,
475 ncols: usize,
476 nrows: usize,
477 x_gap: Option<f64>,
478 y_gap: Option<f64>,
479 ) -> Domain {
480 let cell = Self::calculate_facet_cell(subplot_index, ncols, nrows, x_gap, y_gap);
481 Domain::new()
482 .x(&[cell.domain_x_start, cell.domain_x_end])
483 .y(&[cell.domain_y_start, cell.domain_y_end])
484 }
485
486 fn create_facet_annotations_sankey(
487 categories: &[String],
488 ncols: usize,
489 nrows: usize,
490 title_style: Option<&Text>,
491 x_gap: Option<f64>,
492 y_gap: Option<f64>,
493 ) -> Vec<Annotation> {
494 categories
495 .iter()
496 .enumerate()
497 .map(|(i, cat)| {
498 let cell = Self::calculate_facet_cell(i, ncols, nrows, x_gap, y_gap);
499
500 let mut ann = Annotation::new()
501 .text(cat.as_str())
502 .x_ref("paper")
503 .y_ref("paper")
504 .x_anchor(Anchor::Center)
505 .y_anchor(Anchor::Middle)
506 .x(cell.annotation_x)
507 .y(cell.annotation_y)
508 .show_arrow(false);
509
510 if let Some(style) = title_style {
511 ann = ann.font(style.to_font());
512 }
513
514 ann
515 })
516 .collect()
517 }
518
519 fn set_thickness(mut node: Node, thickness: Option<usize>) -> Node {
520 if let Some(thickness) = thickness {
521 node = node.thickness(thickness);
522 }
523
524 node
525 }
526
527 fn set_pad(mut node: Node, pad: Option<usize>) -> Node {
528 if let Some(pad) = pad {
529 node = node.pad(pad);
530 }
531
532 node
533 }
534
535 fn set_link_colors<V>(mut link: Link<V>, colors: Option<Vec<Rgb>>) -> Link<V>
536 where
537 V: Serialize + Clone,
538 {
539 if let Some(colors) = colors {
540 link = link.color_array(colors.iter().map(|color| color.to_plotly()).collect());
541 }
542
543 link
544 }
545
546 fn set_link_color<V>(mut link: Link<V>, color: Option<Rgb>) -> Link<V>
547 where
548 V: Serialize + Clone,
549 {
550 if let Some(color) = color {
551 link = link.color(color);
552 }
553
554 link
555 }
556
557 fn set_node_colors(mut node: Node, colors: Option<Vec<Rgb>>) -> Node {
558 if let Some(colors) = colors {
559 node = node.color_array(colors.iter().map(|color| color.to_plotly()).collect());
560 }
561
562 node
563 }
564
565 fn set_node_color(mut node: Node, color: Option<Rgb>) -> Node {
566 if let Some(color) = color {
567 node = node.color(color);
568 }
569
570 node
571 }
572
573 fn set_arrangement(
574 mut trace: Box<Sankey<Option<f32>>>,
575 arrangement: Option<Arrangement>,
576 ) -> Box<Sankey<Option<f32>>> {
577 if let Some(arrangement) = arrangement {
578 trace = trace.arrangement(arrangement.to_plotly())
579 }
580
581 trace
582 }
583
584 fn set_orientation(
585 mut trace: Box<Sankey<Option<f32>>>,
586 orientation: Option<Orientation>,
587 ) -> Box<Sankey<Option<f32>>> {
588 if let Some(orientation) = orientation {
589 trace = trace.orientation(orientation.to_plotly())
590 }
591
592 trace
593 }
594
/// Collects the distinct labels of both columns and assigns each an index.
///
/// Labels are visited in the order all of `sources`, then all of `targets`;
/// `None` entries are skipped. Returns the labels in first-seen order together
/// with a map from each label to its position in that list.
fn build_label_index<'a>(
    sources: &'a [Option<String>],
    targets: &'a [Option<String>],
) -> (Vec<&'a str>, HashMap<&'a str, usize>) {
    let mut ordered: Vec<&'a str> = Vec::new();
    let mut positions: HashMap<&'a str, usize> = HashMap::new();

    for column in [sources, targets] {
        for label in column.iter().flatten().map(String::as_str) {
            match positions.entry(label) {
                // Already numbered on an earlier occurrence.
                Entry::Occupied(_) => {}
                Entry::Vacant(slot) => {
                    slot.insert(ordered.len());
                    ordered.push(label);
                }
            }
        }
    }

    (ordered, positions)
}
617
/// Translates every non-null label of `column` into its index in `label_to_idx`.
///
/// `None` entries are skipped, so the result can be shorter than `column`.
///
/// # Panics
///
/// Panics if a label is missing from `label_to_idx`.
fn column_to_indices(
    column: &[Option<String>],
    label_to_idx: &HashMap<&str, usize>,
) -> Vec<usize> {
    let mut indices = Vec::new();

    for label in column.iter().flatten() {
        let position = label_to_idx
            .get(label.as_str())
            .copied()
            .expect("label must exist in map");
        indices.push(position);
    }

    indices
}
628}
629
/// Geometry of one facet cell, as produced by `SankeyDiagram::calculate_facet_cell`.
///
/// The `domain_*` fields bound the rectangle the Sankey trace is drawn in; the
/// `annotation_*` fields give the point the facet title is centred on.
struct FacetCell {
    /// Left edge of the Sankey domain.
    domain_x_start: f64,
    /// Right edge of the Sankey domain.
    domain_x_end: f64,
    /// Bottom edge of the Sankey domain.
    domain_y_start: f64,
    /// Top edge of the Sankey domain (just below the title band).
    domain_y_end: f64,
    /// Horizontal centre of the facet title.
    annotation_x: f64,
    /// Vertical centre of the facet title.
    annotation_y: f64,
}
639
640impl Layout for SankeyDiagram {}
641impl Polar for SankeyDiagram {}
642
643impl PlotHelper for SankeyDiagram {
644 fn get_layout(&self) -> &LayoutPlotly {
645 &self.layout
646 }
647
648 fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>> {
649 &self.traces
650 }
651}