1use bon::bon;
2use std::collections::{
3 HashMap,
4 hash_map::Entry,
5};
6
7use plotly::{
8 sankey::{Link, Node}, Layout as LayoutPlotly, Sankey, Trace
9};
10
11use polars::frame::DataFrame;
12use serde::Serialize;
13
14use crate::{
15 common::{Layout, PlotHelper, Polar},
16 components::{Arrangement, Orientation, Rgb, Text},
17};
18
/// A Sankey diagram built from a link table held in a Polars [`DataFrame`].
///
/// Construct it with [`SankeyDiagram::new`] (a `bon` builder). The resulting traces and layout
/// are exposed through the [`PlotHelper`] implementation.
#[derive(Clone, Serialize)]
pub struct SankeyDiagram {
    /// Plotly traces; `new` stores exactly one Sankey trace here.
    traces: Vec<Box<dyn Trace + 'static>>,
    /// Plotly layout; `new` only sets the title, all other layout options are left unset.
    layout: LayoutPlotly,
}
99
100#[bon]
101impl SankeyDiagram {
102 #[builder(on(String, into), on(Text, into))]
103 pub fn new(
104 data: &DataFrame,
105 sources: &str,
106 targets: &str,
107 values: &str,
108 orientation: Option<Orientation>,
109 arrangement: Option<Arrangement>,
110 pad: Option<usize>,
111 thickness: Option<usize>,
112 node_color: Option<Rgb>,
113 node_colors: Option<Vec<Rgb>>,
114 link_color: Option<Rgb>,
115 link_colors: Option<Vec<Rgb>>,
116 plot_title: Option<Text>,
117 ) -> Self {
118 let legend = None;
119 let legend_title = None;
120 let x_title = None;
121 let y_title = None;
122 let z_title = None;
123 let x_axis = None;
124 let y_axis = None;
125 let z_axis = None;
126
127 let layout = Self::create_layout(
128 plot_title,
129 x_title,
130 y_title,
131 None, z_title,
133 legend_title,
134 x_axis,
135 y_axis,
136 None, z_axis,
138 legend,
139 );
140
141 let traces = Self::create_traces(
142 data,
143 sources,
144 targets,
145 values,
146 orientation,
147 arrangement,
148 pad,
149 thickness,
150 node_color,
151 node_colors,
152 link_color,
153 link_colors,
154 );
155
156 Self { traces, layout }
157 }
158
159 #[allow(clippy::too_many_arguments)]
160 fn create_traces(
161 data: &DataFrame,
162 sources: &str,
163 targets: &str,
164 values: &str,
165 orientation: Option<Orientation>,
166 arrangement: Option<Arrangement>,
167 pad: Option<usize>,
168 thickness: Option<usize>,
169 node_color: Option<Rgb>,
170 node_colors: Option<Vec<Rgb>>,
171 link_color: Option<Rgb>,
172 link_colors: Option<Vec<Rgb>>,
173 ) -> Vec<Box<dyn Trace + 'static>> {
174 let mut traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
175
176 let trace = Self::create_trace(
177 data,
178 sources,
179 targets,
180 values,
181 orientation,
182 arrangement,
183 pad,
184 thickness,
185 node_color,
186 node_colors,
187 link_color,
188 link_colors,
189 );
190
191 traces.push(trace);
192 traces
193 }
194
195 #[allow(clippy::too_many_arguments)]
196 fn create_trace(
197 data: &DataFrame,
198 sources: &str,
199 targets: &str,
200 values: &str,
201 orientation: Option<Orientation>,
202 arrangement: Option<Arrangement>,
203 pad: Option<usize>,
204 thickness: Option<usize>,
205 node_color: Option<Rgb>,
206 node_colors: Option<Vec<Rgb>>,
207 link_color: Option<Rgb>,
208 link_colors: Option<Vec<Rgb>>,
209 ) -> Box<dyn Trace + 'static> {
210 let sources = Self::get_string_column(data, sources);
211 let targets = Self::get_string_column(data, targets);
212 let values = Self::get_numeric_column(data, values);
213
214 let (labels_unique, label_to_idx) = Self::build_label_index(&sources, &targets);
215
216 let sources_idx = Self::column_to_indices(&sources, &label_to_idx);
217 let targets_idx = Self::column_to_indices(&targets, &label_to_idx);
218
219 let mut node = Node::new()
220 .label(labels_unique);
221
222 node = Self::set_pad(node, pad);
223 node = Self::set_thickness(node, thickness);
224 node = Self::set_node_color(node, node_color);
225 node = Self::set_node_colors(node, node_colors);
226
227 let mut link = Link::new()
228 .source(sources_idx)
229 .target(targets_idx)
230 .value(values);
231
232 link = Self::set_link_color(link, link_color);
233 link = Self::set_link_colors(link, link_colors);
234
235 let mut trace = Sankey::new()
236 .node(node)
237 .link(link);
238
239 trace = Self::set_orientation(trace, orientation);
240 trace = Self::set_arrangement(trace, arrangement);
241 trace
242 }
243
244 fn set_thickness(
245 mut node: Node,
246 thickness: Option<usize>,
247 ) -> Node {
248 if let Some(thickness) = thickness {
249 node = node.thickness(thickness);
250 }
251
252 node
253 }
254
255 fn set_pad(
256 mut node: Node,
257 pad: Option<usize>,
258 ) -> Node {
259 if let Some(pad) = pad {
260 node = node.pad(pad);
261 }
262
263 node
264 }
265
266 fn set_link_colors<V>(
267 mut link: Link<V>,
268 colors: Option<Vec<Rgb>>,
269 ) -> Link<V>
270 where
271 V: Serialize + Clone,
272 {
273 if let Some(colors) = colors {
274 link = link.color_array(
275 colors
276 .iter()
277 .map(|color| color.to_plotly())
278 .collect()
279 );
280 }
281
282 link
283 }
284
285 fn set_link_color<V>(
286 mut link: Link<V>,
287 color: Option<Rgb>,
288 ) -> Link<V>
289 where
290 V: Serialize + Clone,
291 {
292 if let Some(color) = color {
293 link = link.color(color);
294 }
295
296 link
297 }
298
299 fn set_node_colors(
300 mut node: Node,
301 colors: Option<Vec<Rgb>>,
302 ) -> Node {
303 if let Some(colors) = colors {
304 node = node.color_array(
305 colors
306 .iter()
307 .map(|color| color.to_plotly())
308 .collect()
309 );
310 }
311
312 node
313 }
314
315 fn set_node_color(
316 mut node: Node,
317 color: Option<Rgb>,
318 ) -> Node {
319 if let Some(color) = color {
320 node = node.color(color);
321 }
322
323 node
324 }
325
326 fn set_arrangement(
327 mut trace: Box<Sankey<Option<f32>>>,
328 arrangement: Option<Arrangement>,
329 ) -> Box<Sankey<Option<f32>>> {
330 if let Some(arrangement) = arrangement {
331 trace = trace.arrangement(arrangement.to_plotly())
332 }
333
334 trace
335 }
336
337 fn set_orientation(
338 mut trace: Box<Sankey<Option<f32>>>,
339 orientation: Option<Orientation>,
340 ) -> Box<Sankey<Option<f32>>> {
341 if let Some(orientation) = orientation {
342 trace = trace.orientation(orientation.to_plotly())
343 }
344
345 trace
346 }
347
348 fn build_label_index<'a>(
349 sources: &'a [Option<String>],
350 targets: &'a [Option<String>],
351 ) -> (Vec<&'a str>, HashMap<&'a str, usize>) {
352 let mut label_to_idx: HashMap<&'a str, usize> = HashMap::new();
353 let mut labels_unique: Vec<&'a str> = Vec::new();
354
355 let iter = sources
356 .iter()
357 .chain(targets.iter())
358 .filter_map(|opt| opt.as_deref());
359
360 for lbl in iter {
361 if let Entry::Vacant(v) = label_to_idx.entry(lbl) {
362 let next_id = labels_unique.len();
363 v.insert(next_id);
364 labels_unique.push(lbl);
365 }
366 }
367
368 (labels_unique, label_to_idx)
369 }
370
371 fn column_to_indices(
372 column: &[Option<String>],
373 label_to_idx: &HashMap<&str, usize>,
374 ) -> Vec<usize> {
375 column
376 .iter()
377 .filter_map(|opt| opt.as_deref())
378 .map(|lbl| *label_to_idx.get(lbl).expect("label must exist in map"))
379 .collect()
380 }
381}
382
// Opt into the crate's `Layout` and `Polar` traits (imported from `crate::common`) with their
// default methods. `SankeyDiagram::new` calls `Self::create_layout`, and the trace builder calls
// `Self::get_string_column` / `Self::get_numeric_column`. These presumably come from these
// trait defaults — confirm against `crate::common`.
impl Layout for SankeyDiagram {}
impl Polar for SankeyDiagram {}
385
386impl PlotHelper for SankeyDiagram {
387 fn get_layout(&self) -> &LayoutPlotly {
388 &self.layout
389 }
390
391 fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>> {
392 &self.traces
393 }
394}