1use bon::bon;
2use std::collections::{hash_map::Entry, HashMap};
3
4use crate::{
5 components::{Arrangement, FacetConfig, Legend, Orientation, Rgb, Text},
6 ir::data::ColumnData,
7 ir::layout::LayoutIR,
8 ir::trace::{SankeyDiagramIR, TraceIR},
9};
10use polars::frame::DataFrame;
11
/// A Sankey diagram described in the crate's intermediate representation.
///
/// Instances are produced by the builder generated for this type and expose
/// their contents through the `crate::Plot` implementation.
#[derive(Clone)]
#[allow(dead_code)]
pub struct SankeyDiagram {
    /// Sankey traces: a single trace for an un-faceted diagram, or one trace
    /// per non-empty facet when a facet column is supplied.
    traces: Vec<TraceIR>,
    /// Layout settings (title, legend, optional facet grid) built alongside
    /// the traces.
    layout: LayoutIR,
}
105
/// Rectangle reserved for one facet's Sankey trace.
///
/// The four values are copied into the trace's `domain_x` / `domain_y` pairs
/// and are computed as fractions of a unit-sized figure.
struct FacetCell {
    /// Left edge of the cell.
    domain_x_start: f64,
    /// Right edge of the cell.
    domain_x_end: f64,
    /// Bottom edge of the cell.
    domain_y_start: f64,
    /// Top edge of the cell, already lowered to leave room for the facet title.
    domain_y_end: f64,
}
112
113#[bon]
114impl SankeyDiagram {
115 #[builder(on(String, into), on(Text, into))]
116 pub fn new(
117 data: &DataFrame,
118 sources: &str,
119 targets: &str,
120 values: &str,
121 facet: Option<&str>,
122 facet_config: Option<&FacetConfig>,
123 orientation: Option<Orientation>,
124 arrangement: Option<Arrangement>,
125 pad: Option<usize>,
126 thickness: Option<usize>,
127 node_color: Option<Rgb>,
128 node_colors: Option<Vec<Rgb>>,
129 link_color: Option<Rgb>,
130 link_colors: Option<Vec<Rgb>>,
131 plot_title: Option<Text>,
132 legend_title: Option<Text>,
133 legend: Option<&Legend>,
134 ) -> Self {
135 let grid = facet.map(|facet_column| {
136 let config = facet_config.cloned().unwrap_or_default();
137 let facet_categories =
138 crate::data::get_unique_groups(data, facet_column, config.sorter);
139 let n_facets = facet_categories.len();
140 let (ncols, nrows) =
141 crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
142 crate::ir::facet::GridSpec {
143 kind: crate::ir::facet::FacetKind::Domain,
144 rows: nrows,
145 cols: ncols,
146 h_gap: config.h_gap,
147 v_gap: config.v_gap,
148 scales: config.scales.clone(),
149 n_facets,
150 facet_categories,
151 title_style: config.title_style.clone(),
152 x_title: None,
153 y_title: None,
154 x_axis: None,
155 y_axis: None,
156 legend_title: legend_title.clone(),
157 legend: legend.cloned(),
158 }
159 });
160
161 let layout = LayoutIR {
162 title: plot_title,
163 x_title: None,
164 y_title: None,
165 y2_title: None,
166 z_title: None,
167 legend_title: if grid.is_some() { None } else { legend_title },
168 legend: if grid.is_some() {
169 None
170 } else {
171 legend.cloned()
172 },
173 dimensions: None,
174 bar_mode: None,
175 box_mode: None,
176 box_gap: None,
177 margin_bottom: None,
178 axes_2d: None,
179 scene_3d: None,
180 polar: None,
181 mapbox: None,
182 grid,
183 annotations: vec![],
184 };
185
186 let traces = match facet {
187 Some(facet_column) => {
188 let config = facet_config.cloned().unwrap_or_default();
189 Self::create_ir_traces_faceted(
190 data,
191 sources,
192 targets,
193 values,
194 facet_column,
195 &config,
196 orientation,
197 arrangement,
198 pad,
199 thickness,
200 node_color,
201 node_colors,
202 link_color,
203 link_colors,
204 )
205 }
206 None => Self::create_ir_traces(
207 data,
208 sources,
209 targets,
210 values,
211 orientation,
212 arrangement,
213 pad,
214 thickness,
215 node_color,
216 node_colors,
217 link_color,
218 link_colors,
219 ),
220 };
221 Self { traces, layout }
222 }
223}
224
225#[bon]
226impl SankeyDiagram {
227 #[builder(
228 start_fn = try_builder,
229 finish_fn = try_build,
230 builder_type = SankeyDiagramTryBuilder,
231 on(String, into),
232 on(Text, into),
233 )]
234 pub fn try_new(
235 data: &DataFrame,
236 sources: &str,
237 targets: &str,
238 values: &str,
239 facet: Option<&str>,
240 facet_config: Option<&FacetConfig>,
241 orientation: Option<Orientation>,
242 arrangement: Option<Arrangement>,
243 pad: Option<usize>,
244 thickness: Option<usize>,
245 node_color: Option<Rgb>,
246 node_colors: Option<Vec<Rgb>>,
247 link_color: Option<Rgb>,
248 link_colors: Option<Vec<Rgb>>,
249 plot_title: Option<Text>,
250 legend_title: Option<Text>,
251 legend: Option<&Legend>,
252 ) -> Result<Self, crate::io::PlotlarsError> {
253 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
254 Self::__orig_new(
255 data,
256 sources,
257 targets,
258 values,
259 facet,
260 facet_config,
261 orientation,
262 arrangement,
263 pad,
264 thickness,
265 node_color,
266 node_colors,
267 link_color,
268 link_colors,
269 plot_title,
270 legend_title,
271 legend,
272 )
273 }))
274 .map_err(|panic| {
275 let msg = panic
276 .downcast_ref::<String>()
277 .cloned()
278 .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
279 .unwrap_or_else(|| "unknown error".to_string());
280 crate::io::PlotlarsError::PlotBuild { message: msg }
281 })
282 }
283}
284
285impl SankeyDiagram {
286 #[allow(clippy::too_many_arguments)]
287 fn create_ir_traces(
288 data: &DataFrame,
289 sources: &str,
290 targets: &str,
291 values: &str,
292 orientation: Option<Orientation>,
293 arrangement: Option<Arrangement>,
294 pad: Option<usize>,
295 thickness: Option<usize>,
296 node_color: Option<Rgb>,
297 node_colors: Option<Vec<Rgb>>,
298 link_color: Option<Rgb>,
299 link_colors: Option<Vec<Rgb>>,
300 ) -> Vec<TraceIR> {
301 let sources_col = crate::data::get_string_column(data, sources);
302 let targets_col = crate::data::get_string_column(data, targets);
303 let values_data = crate::data::get_numeric_column(data, values);
304
305 let (labels_unique, label_to_idx) = Self::build_label_index(&sources_col, &targets_col);
306
307 let sources_idx = Self::column_to_indices(&sources_col, &label_to_idx);
308 let targets_idx = Self::column_to_indices(&targets_col, &label_to_idx);
309
310 let resolved_node_colors = Self::resolve_node_colors(node_color, node_colors);
311 let resolved_link_colors = Self::resolve_link_colors(link_color, link_colors);
312
313 vec![TraceIR::SankeyDiagram(SankeyDiagramIR {
314 sources: sources_idx,
315 targets: targets_idx,
316 values: ColumnData::Numeric(values_data),
317 node_labels: labels_unique.iter().map(|s| s.to_string()).collect(),
318 orientation,
319 arrangement,
320 pad,
321 thickness,
322 node_colors: resolved_node_colors,
323 link_colors: resolved_link_colors,
324 domain_x: None,
325 domain_y: None,
326 })]
327 }
328
329 #[allow(clippy::too_many_arguments)]
330 fn create_ir_traces_faceted(
331 data: &DataFrame,
332 sources: &str,
333 targets: &str,
334 values: &str,
335 facet_column: &str,
336 config: &FacetConfig,
337 orientation: Option<Orientation>,
338 arrangement: Option<Arrangement>,
339 pad: Option<usize>,
340 thickness: Option<usize>,
341 node_color: Option<Rgb>,
342 node_colors: Option<Vec<Rgb>>,
343 link_color: Option<Rgb>,
344 link_colors: Option<Vec<Rgb>>,
345 ) -> Vec<TraceIR> {
346 const MAX_FACETS: usize = 8;
347
348 let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
349
350 if facet_categories.len() > MAX_FACETS {
351 panic!(
352 "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
353 facet_column,
354 facet_categories.len(),
355 MAX_FACETS
356 );
357 }
358
359 let n_facets = facet_categories.len();
360 let (ncols, nrows) =
361 crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
362
363 let facet_categories_non_empty: Vec<String> = facet_categories
364 .iter()
365 .filter(|facet_value| {
366 let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
367 facet_data.height() > 0
368 })
369 .cloned()
370 .collect();
371
372 let resolved_node_colors = Self::resolve_node_colors(node_color, node_colors);
373 let resolved_link_colors = Self::resolve_link_colors(link_color, link_colors);
374
375 let mut traces = Vec::new();
376
377 for (idx, facet_value) in facet_categories_non_empty.iter().enumerate() {
378 let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
379
380 let cell = Self::calculate_facet_cell(idx, ncols, nrows, config.h_gap, config.v_gap);
381
382 let sources_col = crate::data::get_string_column(&facet_data, sources);
383 let targets_col = crate::data::get_string_column(&facet_data, targets);
384 let values_data = crate::data::get_numeric_column(&facet_data, values);
385
386 let (labels_unique, label_to_idx) = Self::build_label_index(&sources_col, &targets_col);
387
388 let sources_idx = Self::column_to_indices(&sources_col, &label_to_idx);
389 let targets_idx = Self::column_to_indices(&targets_col, &label_to_idx);
390
391 traces.push(TraceIR::SankeyDiagram(SankeyDiagramIR {
392 sources: sources_idx,
393 targets: targets_idx,
394 values: ColumnData::Numeric(values_data),
395 node_labels: labels_unique.iter().map(|s| s.to_string()).collect(),
396 orientation: orientation.clone(),
397 arrangement: arrangement.clone(),
398 pad,
399 thickness,
400 node_colors: resolved_node_colors.clone(),
401 link_colors: resolved_link_colors.clone(),
402 domain_x: Some((cell.domain_x_start, cell.domain_x_end)),
403 domain_y: Some((cell.domain_y_start, cell.domain_y_end)),
404 }));
405 }
406
407 traces
408 }
409
410 fn resolve_node_colors(
411 node_color: Option<Rgb>,
412 node_colors: Option<Vec<Rgb>>,
413 ) -> Option<Vec<Rgb>> {
414 node_colors.or_else(|| node_color.map(|color| vec![color]))
415 }
416
417 fn resolve_link_colors(
418 link_color: Option<Rgb>,
419 link_colors: Option<Vec<Rgb>>,
420 ) -> Option<Vec<Rgb>> {
421 link_colors.or_else(|| link_color.map(|color| vec![color]))
422 }
423 fn calculate_facet_cell(
428 subplot_index: usize,
429 ncols: usize,
430 nrows: usize,
431 x_gap: Option<f64>,
432 y_gap: Option<f64>,
433 ) -> FacetCell {
434 let row = subplot_index / ncols;
435 let col = subplot_index % ncols;
436
437 let x_gap_val = x_gap.unwrap_or(0.05);
438 let y_gap_val = y_gap.unwrap_or(0.10);
439
440 const TITLE_HEIGHT_RATIO: f64 = 0.10;
442
443 let cell_width = (1.0 - x_gap_val * (ncols - 1) as f64) / ncols as f64;
445 let cell_height = (1.0 - y_gap_val * (nrows - 1) as f64) / nrows as f64;
446
447 let cell_x_start = col as f64 * (cell_width + x_gap_val);
449 let cell_y_top = 1.0 - row as f64 * (cell_height + y_gap_val);
450 let cell_y_bottom = cell_y_top - cell_height;
451
452 let title_height = cell_height * TITLE_HEIGHT_RATIO;
454 let sankey_y_top = cell_y_top - title_height;
455
456 let sankey_x_start = cell_x_start;
458 let sankey_x_end = cell_x_start + cell_width;
459 let sankey_y_start = cell_y_bottom;
460 let sankey_y_end = sankey_y_top;
461
462 FacetCell {
463 domain_x_start: sankey_x_start,
464 domain_x_end: sankey_x_end,
465 domain_y_start: sankey_y_start,
466 domain_y_end: sankey_y_end,
467 }
468 }
469
470 fn build_label_index<'a>(
471 sources: &'a [Option<String>],
472 targets: &'a [Option<String>],
473 ) -> (Vec<&'a str>, HashMap<&'a str, usize>) {
474 let mut label_to_idx: HashMap<&'a str, usize> = HashMap::new();
475 let mut labels_unique: Vec<&'a str> = Vec::new();
476
477 let iter = sources
478 .iter()
479 .chain(targets.iter())
480 .filter_map(|opt| opt.as_deref());
481
482 for lbl in iter {
483 if let Entry::Vacant(v) = label_to_idx.entry(lbl) {
484 let next_id = labels_unique.len();
485 v.insert(next_id);
486 labels_unique.push(lbl);
487 }
488 }
489
490 (labels_unique, label_to_idx)
491 }
492
493 fn column_to_indices(
494 column: &[Option<String>],
495 label_to_idx: &HashMap<&str, usize>,
496 ) -> Vec<usize> {
497 column
498 .iter()
499 .filter_map(|opt| opt.as_deref())
500 .map(|lbl| *label_to_idx.get(lbl).expect("label must exist in map"))
501 .collect()
502 }
503}
504
505impl crate::Plot for SankeyDiagram {
506 fn ir_traces(&self) -> &[TraceIR] {
507 &self.traces
508 }
509
510 fn ir_layout(&self) -> &LayoutIR {
511 &self.layout
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    /// Builds a diagram from the `source` / `target` / `value` columns of `frame`.
    fn sankey_from(frame: &DataFrame) -> SankeyDiagram {
        SankeyDiagram::builder()
            .data(frame)
            .sources("source")
            .targets("target")
            .values("value")
            .build()
    }

    /// An un-faceted diagram yields exactly one Sankey trace.
    #[test]
    fn test_basic_one_trace() {
        let frame = df![
            "source" => ["A", "A", "B"],
            "target" => ["B", "C", "C"],
            "value" => [10.0, 20.0, 30.0]
        ]
        .unwrap();

        let plot = sankey_from(&frame);
        let traces = plot.ir_traces();

        assert_eq!(traces.len(), 1);
        assert!(matches!(traces[0], TraceIR::SankeyDiagram(_)));
    }

    /// Sankey layouts carry no 2D axes.
    #[test]
    fn test_layout_no_axes() {
        let frame = df![
            "source" => ["A", "B"],
            "target" => ["B", "C"],
            "value" => [10.0, 20.0]
        ]
        .unwrap();

        let plot = sankey_from(&frame);

        assert!(plot.ir_layout().axes_2d.is_none());
    }

    /// A plot title passed to the builder ends up in the layout.
    #[test]
    fn test_layout_title() {
        let frame = df![
            "source" => ["A"],
            "target" => ["B"],
            "value" => [10.0]
        ]
        .unwrap();

        let plot = SankeyDiagram::builder()
            .data(&frame)
            .sources("source")
            .targets("target")
            .values("value")
            .plot_title("Sankey")
            .build();

        assert!(plot.ir_layout().title.is_some());
    }
}