1use bon::bon;
2use std::collections::{hash_map::Entry, HashMap};
3
4use crate::{
5 components::{Arrangement, FacetConfig, Legend, Orientation, Rgb, Text},
6 ir::data::ColumnData,
7 ir::layout::LayoutIR,
8 ir::trace::{SankeyDiagramIR, TraceIR},
9};
10use polars::frame::DataFrame;
11
/// A Sankey diagram built from a Polars `DataFrame` of links.
///
/// Stores the intermediate representation produced by the constructor: the Sankey traces
/// (a single trace, or one per non-empty facet group when faceting) and the plot layout.
#[derive(Clone)]
// NOTE(review): unclear why `dead_code` is allowed; both fields are read by the
// `crate::Plot` impl in this file — confirm whether this attribute is still needed.
#[allow(dead_code)]
pub struct SankeyDiagram {
    /// Sankey traces to render.
    traces: Vec<TraceIR>,
    /// Plot-level layout (title, legend settings, optional facet grid).
    layout: LayoutIR,
}
105
/// Rectangle of the figure assigned to one facet's Sankey trace.
///
/// Coordinates are fractions of the figure as computed by `calculate_facet_cell`, with `y`
/// growing upwards (row 0 is the top row).
#[derive(Debug, Clone, Copy, PartialEq)]
struct FacetCell {
    /// Left edge of the trace domain.
    domain_x_start: f64,
    /// Right edge of the trace domain.
    domain_x_end: f64,
    /// Bottom edge of the trace domain.
    domain_y_start: f64,
    /// Top edge of the trace domain (below the strip reserved for the facet title).
    domain_y_end: f64,
}
112
113#[bon]
114impl SankeyDiagram {
115 #[builder(on(String, into), on(Text, into))]
116 pub fn new(
117 data: &DataFrame,
118 sources: &str,
119 targets: &str,
120 values: &str,
121 facet: Option<&str>,
122 facet_config: Option<&FacetConfig>,
123 orientation: Option<Orientation>,
124 arrangement: Option<Arrangement>,
125 pad: Option<usize>,
126 thickness: Option<usize>,
127 node_color: Option<Rgb>,
128 node_colors: Option<Vec<Rgb>>,
129 link_color: Option<Rgb>,
130 link_colors: Option<Vec<Rgb>>,
131 plot_title: Option<Text>,
132 legend_title: Option<Text>,
133 legend: Option<&Legend>,
134 ) -> Self {
135 let grid = facet.map(|facet_column| {
136 let config = facet_config.cloned().unwrap_or_default();
137 let facet_categories =
138 crate::data::get_unique_groups(data, facet_column, config.sorter);
139 let n_facets = facet_categories.len();
140 let (ncols, nrows) =
141 crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
142 crate::ir::facet::GridSpec {
143 kind: crate::ir::facet::FacetKind::Domain,
144 rows: nrows,
145 cols: ncols,
146 h_gap: config.h_gap,
147 v_gap: config.v_gap,
148 scales: config.scales.clone(),
149 n_facets,
150 facet_categories,
151 title_style: config.title_style.clone(),
152 x_title: None,
153 y_title: None,
154 x_axis: None,
155 y_axis: None,
156 legend_title: legend_title.clone(),
157 legend: legend.cloned(),
158 }
159 });
160
161 let layout = LayoutIR {
162 title: plot_title,
163 x_title: None,
164 y_title: None,
165 y2_title: None,
166 z_title: None,
167 legend_title: if grid.is_some() { None } else { legend_title },
168 legend: if grid.is_some() {
169 None
170 } else {
171 legend.cloned()
172 },
173 dimensions: None,
174 bar_mode: None,
175 box_mode: None,
176 box_gap: None,
177 margin_bottom: None,
178 axes_2d: None,
179 scene_3d: None,
180 polar: None,
181 mapbox: None,
182 grid,
183 annotations: vec![],
184 };
185
186 let traces = match facet {
187 Some(facet_column) => {
188 let config = facet_config.cloned().unwrap_or_default();
189 Self::create_ir_traces_faceted(
190 data,
191 sources,
192 targets,
193 values,
194 facet_column,
195 &config,
196 orientation,
197 arrangement,
198 pad,
199 thickness,
200 node_color,
201 node_colors,
202 link_color,
203 link_colors,
204 )
205 }
206 None => Self::create_ir_traces(
207 data,
208 sources,
209 targets,
210 values,
211 orientation,
212 arrangement,
213 pad,
214 thickness,
215 node_color,
216 node_colors,
217 link_color,
218 link_colors,
219 ),
220 };
221 Self { traces, layout }
222 }
223}
224
225#[bon]
226impl SankeyDiagram {
227 #[builder(
228 start_fn = try_builder,
229 finish_fn = try_build,
230 builder_type = SankeyDiagramTryBuilder,
231 on(String, into),
232 on(Text, into),
233 )]
234 pub fn try_new(
235 data: &DataFrame,
236 sources: &str,
237 targets: &str,
238 values: &str,
239 facet: Option<&str>,
240 facet_config: Option<&FacetConfig>,
241 orientation: Option<Orientation>,
242 arrangement: Option<Arrangement>,
243 pad: Option<usize>,
244 thickness: Option<usize>,
245 node_color: Option<Rgb>,
246 node_colors: Option<Vec<Rgb>>,
247 link_color: Option<Rgb>,
248 link_colors: Option<Vec<Rgb>>,
249 plot_title: Option<Text>,
250 legend_title: Option<Text>,
251 legend: Option<&Legend>,
252 ) -> Result<Self, crate::io::PlotlarsError> {
253 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
254 Self::__orig_new(
255 data,
256 sources,
257 targets,
258 values,
259 facet,
260 facet_config,
261 orientation,
262 arrangement,
263 pad,
264 thickness,
265 node_color,
266 node_colors,
267 link_color,
268 link_colors,
269 plot_title,
270 legend_title,
271 legend,
272 )
273 }))
274 .map_err(|panic| {
275 let msg = panic
276 .downcast_ref::<String>()
277 .cloned()
278 .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
279 .unwrap_or_else(|| "unknown error".to_string());
280 crate::io::PlotlarsError::PlotBuild { message: msg }
281 })
282 }
283}
284
285impl SankeyDiagram {
286 #[allow(clippy::too_many_arguments)]
287 fn create_ir_traces(
288 data: &DataFrame,
289 sources: &str,
290 targets: &str,
291 values: &str,
292 orientation: Option<Orientation>,
293 arrangement: Option<Arrangement>,
294 pad: Option<usize>,
295 thickness: Option<usize>,
296 node_color: Option<Rgb>,
297 node_colors: Option<Vec<Rgb>>,
298 link_color: Option<Rgb>,
299 link_colors: Option<Vec<Rgb>>,
300 ) -> Vec<TraceIR> {
301 let sources_col = crate::data::get_string_column(data, sources);
302 let targets_col = crate::data::get_string_column(data, targets);
303 let values_data = crate::data::get_numeric_column(data, values);
304
305 let (labels_unique, label_to_idx) = Self::build_label_index(&sources_col, &targets_col);
306
307 let sources_idx = Self::column_to_indices(&sources_col, &label_to_idx);
308 let targets_idx = Self::column_to_indices(&targets_col, &label_to_idx);
309
310 let resolved_node_colors = Self::resolve_node_colors(node_color, node_colors);
311 let resolved_link_colors = Self::resolve_link_colors(link_color, link_colors);
312
313 vec![TraceIR::SankeyDiagram(SankeyDiagramIR {
314 sources: sources_idx,
315 targets: targets_idx,
316 values: ColumnData::Numeric(values_data),
317 node_labels: labels_unique.iter().map(|s| s.to_string()).collect(),
318 orientation,
319 arrangement,
320 pad,
321 thickness,
322 node_colors: resolved_node_colors,
323 link_colors: resolved_link_colors,
324 domain_x: None,
325 domain_y: None,
326 })]
327 }
328
329 #[allow(clippy::too_many_arguments)]
330 fn create_ir_traces_faceted(
331 data: &DataFrame,
332 sources: &str,
333 targets: &str,
334 values: &str,
335 facet_column: &str,
336 config: &FacetConfig,
337 orientation: Option<Orientation>,
338 arrangement: Option<Arrangement>,
339 pad: Option<usize>,
340 thickness: Option<usize>,
341 node_color: Option<Rgb>,
342 node_colors: Option<Vec<Rgb>>,
343 link_color: Option<Rgb>,
344 link_colors: Option<Vec<Rgb>>,
345 ) -> Vec<TraceIR> {
346 const MAX_FACETS: usize = 8;
347
348 let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
349
350 if facet_categories.len() > MAX_FACETS {
351 panic!(
352 "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
353 facet_column,
354 facet_categories.len(),
355 MAX_FACETS
356 );
357 }
358
359 let n_facets = facet_categories.len();
360 let (ncols, nrows) =
361 crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
362
363 let facet_categories_non_empty: Vec<String> = facet_categories
364 .iter()
365 .filter(|facet_value| {
366 let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
367 facet_data.height() > 0
368 })
369 .cloned()
370 .collect();
371
372 let resolved_node_colors = Self::resolve_node_colors(node_color, node_colors);
373 let resolved_link_colors = Self::resolve_link_colors(link_color, link_colors);
374
375 let mut traces = Vec::new();
376
377 for (idx, facet_value) in facet_categories_non_empty.iter().enumerate() {
378 let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
379
380 let cell = Self::calculate_facet_cell(idx, ncols, nrows, config.h_gap, config.v_gap);
381
382 let sources_col = crate::data::get_string_column(&facet_data, sources);
383 let targets_col = crate::data::get_string_column(&facet_data, targets);
384 let values_data = crate::data::get_numeric_column(&facet_data, values);
385
386 let (labels_unique, label_to_idx) = Self::build_label_index(&sources_col, &targets_col);
387
388 let sources_idx = Self::column_to_indices(&sources_col, &label_to_idx);
389 let targets_idx = Self::column_to_indices(&targets_col, &label_to_idx);
390
391 traces.push(TraceIR::SankeyDiagram(SankeyDiagramIR {
392 sources: sources_idx,
393 targets: targets_idx,
394 values: ColumnData::Numeric(values_data),
395 node_labels: labels_unique.iter().map(|s| s.to_string()).collect(),
396 orientation: orientation.clone(),
397 arrangement: arrangement.clone(),
398 pad,
399 thickness,
400 node_colors: resolved_node_colors.clone(),
401 link_colors: resolved_link_colors.clone(),
402 domain_x: Some((cell.domain_x_start, cell.domain_x_end)),
403 domain_y: Some((cell.domain_y_start, cell.domain_y_end)),
404 }));
405 }
406
407 traces
408 }
409
410 fn resolve_node_colors(
411 node_color: Option<Rgb>,
412 node_colors: Option<Vec<Rgb>>,
413 ) -> Option<Vec<Rgb>> {
414 if let Some(colors) = node_colors {
415 Some(colors)
416 } else if let Some(color) = node_color {
417 Some(vec![color])
418 } else {
419 None
420 }
421 }
422
423 fn resolve_link_colors(
424 link_color: Option<Rgb>,
425 link_colors: Option<Vec<Rgb>>,
426 ) -> Option<Vec<Rgb>> {
427 if let Some(colors) = link_colors {
428 Some(colors)
429 } else if let Some(color) = link_color {
430 Some(vec![color])
431 } else {
432 None
433 }
434 }
435 fn calculate_facet_cell(
440 subplot_index: usize,
441 ncols: usize,
442 nrows: usize,
443 x_gap: Option<f64>,
444 y_gap: Option<f64>,
445 ) -> FacetCell {
446 let row = subplot_index / ncols;
447 let col = subplot_index % ncols;
448
449 let x_gap_val = x_gap.unwrap_or(0.05);
450 let y_gap_val = y_gap.unwrap_or(0.10);
451
452 const TITLE_HEIGHT_RATIO: f64 = 0.10;
454
455 let cell_width = (1.0 - x_gap_val * (ncols - 1) as f64) / ncols as f64;
457 let cell_height = (1.0 - y_gap_val * (nrows - 1) as f64) / nrows as f64;
458
459 let cell_x_start = col as f64 * (cell_width + x_gap_val);
461 let cell_y_top = 1.0 - row as f64 * (cell_height + y_gap_val);
462 let cell_y_bottom = cell_y_top - cell_height;
463
464 let title_height = cell_height * TITLE_HEIGHT_RATIO;
466 let sankey_y_top = cell_y_top - title_height;
467
468 let sankey_x_start = cell_x_start;
470 let sankey_x_end = cell_x_start + cell_width;
471 let sankey_y_start = cell_y_bottom;
472 let sankey_y_end = sankey_y_top;
473
474 FacetCell {
475 domain_x_start: sankey_x_start,
476 domain_x_end: sankey_x_end,
477 domain_y_start: sankey_y_start,
478 domain_y_end: sankey_y_end,
479 }
480 }
481
482 fn build_label_index<'a>(
483 sources: &'a [Option<String>],
484 targets: &'a [Option<String>],
485 ) -> (Vec<&'a str>, HashMap<&'a str, usize>) {
486 let mut label_to_idx: HashMap<&'a str, usize> = HashMap::new();
487 let mut labels_unique: Vec<&'a str> = Vec::new();
488
489 let iter = sources
490 .iter()
491 .chain(targets.iter())
492 .filter_map(|opt| opt.as_deref());
493
494 for lbl in iter {
495 if let Entry::Vacant(v) = label_to_idx.entry(lbl) {
496 let next_id = labels_unique.len();
497 v.insert(next_id);
498 labels_unique.push(lbl);
499 }
500 }
501
502 (labels_unique, label_to_idx)
503 }
504
505 fn column_to_indices(
506 column: &[Option<String>],
507 label_to_idx: &HashMap<&str, usize>,
508 ) -> Vec<usize> {
509 column
510 .iter()
511 .filter_map(|opt| opt.as_deref())
512 .map(|lbl| *label_to_idx.get(lbl).expect("label must exist in map"))
513 .collect()
514 }
515}
516
517impl crate::Plot for SankeyDiagram {
518 fn ir_traces(&self) -> &[TraceIR] {
519 &self.traces
520 }
521
522 fn ir_layout(&self) -> &LayoutIR {
523 &self.layout
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    /// Returns the Sankey IR carried by `trace`, failing the test for any other trace kind.
    fn sankey_ir(trace: &TraceIR) -> &SankeyDiagramIR {
        match trace {
            TraceIR::SankeyDiagram(ir) => ir,
            #[allow(unreachable_patterns)]
            _ => panic!("expected a SankeyDiagram trace"),
        }
    }

    #[test]
    fn test_basic_one_trace() {
        let df = df![
            "source" => ["A", "A", "B"],
            "target" => ["B", "C", "C"],
            "value" => [10.0, 20.0, 30.0]
        ]
        .unwrap();
        let plot = SankeyDiagram::builder()
            .data(&df)
            .sources("source")
            .targets("target")
            .values("value")
            .build();
        assert_eq!(plot.ir_traces().len(), 1);
        assert!(matches!(plot.ir_traces()[0], TraceIR::SankeyDiagram(_)));
    }

    #[test]
    fn test_node_labels_and_link_indices() {
        let df = df![
            "source" => ["A", "A", "B"],
            "target" => ["B", "C", "C"],
            "value" => [10.0, 20.0, 30.0]
        ]
        .unwrap();
        let plot = SankeyDiagram::builder()
            .data(&df)
            .sources("source")
            .targets("target")
            .values("value")
            .build();
        let ir = sankey_ir(&plot.ir_traces()[0]);
        // Labels are numbered in first-seen order: all sources first, then targets.
        assert_eq!(ir.node_labels, vec!["A", "B", "C"]);
        assert_eq!(ir.sources, vec![0, 0, 1]);
        assert_eq!(ir.targets, vec![1, 2, 2]);
        assert!(ir.domain_x.is_none());
        assert!(ir.domain_y.is_none());
    }

    #[test]
    fn test_layout_no_axes() {
        let df = df![
            "source" => ["A", "B"],
            "target" => ["B", "C"],
            "value" => [10.0, 20.0]
        ]
        .unwrap();
        let plot = SankeyDiagram::builder()
            .data(&df)
            .sources("source")
            .targets("target")
            .values("value")
            .build();
        assert!(plot.ir_layout().axes_2d.is_none());
    }

    #[test]
    fn test_layout_title() {
        let df = df![
            "source" => ["A"],
            "target" => ["B"],
            "value" => [10.0]
        ]
        .unwrap();
        let plot = SankeyDiagram::builder()
            .data(&df)
            .sources("source")
            .targets("target")
            .values("value")
            .plot_title("Sankey")
            .build();
        assert!(plot.ir_layout().title.is_some());
    }

    #[test]
    fn test_faceted_one_trace_per_group() {
        let df = df![
            "source" => ["A", "A", "B", "C"],
            "target" => ["B", "C", "C", "D"],
            "value" => [1.0, 2.0, 3.0, 4.0],
            "group" => ["g1", "g1", "g2", "g2"]
        ]
        .unwrap();
        let plot = SankeyDiagram::builder()
            .data(&df)
            .sources("source")
            .targets("target")
            .values("value")
            .facet("group")
            .legend_title("Legend")
            .build();

        assert_eq!(plot.ir_traces().len(), 2);
        let layout = plot.ir_layout();
        assert!(layout.grid.is_some());
        // With faceting, legend settings live on the grid, not the top-level layout.
        assert!(layout.legend_title.is_none());
        for trace in plot.ir_traces() {
            let ir = sankey_ir(trace);
            assert!(ir.domain_x.is_some());
            assert!(ir.domain_y.is_some());
        }
    }

    #[test]
    fn test_try_build_ok_and_too_many_facets() {
        let ok_df = df![
            "source" => ["A"],
            "target" => ["B"],
            "value" => [1.0]
        ]
        .unwrap();
        let ok = SankeyDiagram::try_builder()
            .data(&ok_df)
            .sources("source")
            .targets("target")
            .values("value")
            .try_build();
        assert!(ok.is_ok());

        let groups: Vec<String> = (0..9).map(|i| format!("g{i}")).collect();
        let df = df![
            "source" => vec!["A"; 9],
            "target" => vec!["B"; 9],
            "value" => vec![1.0; 9],
            "group" => groups
        ]
        .unwrap();
        let result = SankeyDiagram::try_builder()
            .data(&df)
            .sources("source")
            .targets("target")
            .values("value")
            .facet("group")
            .try_build();
        assert!(result.is_err());
    }
}