1use plotly_derive::FieldSetter;
4use serde::Serialize;
5
6use crate::{
7 color::{Color, ColorArray},
8 common::{Dim, Domain, Font, HoverInfo, Label, LegendGroupTitle, Orientation, PlotType},
9 Trace,
10};
11
12#[derive(Serialize, Clone)]
13#[serde(rename_all = "lowercase")]
14pub enum Arrangement {
15 Snap,
16 Perpendicular,
17 Freeform,
18 Fixed,
19}
20
21#[serde_with::skip_serializing_none]
22#[derive(Serialize, Clone, Default)]
23pub struct Line {
24 color: Option<Dim<Box<dyn Color>>>,
25 width: Option<f64>,
26}
27
28impl Line {
29 pub fn new() -> Self {
30 Default::default()
31 }
32
33 pub fn color<C: Color>(mut self, color: C) -> Self {
34 self.color = Some(Dim::Scalar(Box::new(color)));
35 self
36 }
37
38 pub fn color_array<C: Color>(mut self, colors: Vec<C>) -> Self {
39 self.color = Some(Dim::Vector(ColorArray(colors).into()));
40 self
41 }
42
43 pub fn width(mut self, width: f64) -> Self {
44 self.width = Some(width);
45 self
46 }
47}
48
49#[serde_with::skip_serializing_none]
50#[derive(Serialize, Default, Clone)]
51pub struct Node {
52 color: Option<Dim<Box<dyn Color>>>,
54 #[serde(rename = "hoverinfo")]
55 hover_info: Option<HoverInfo>,
56 #[serde(rename = "hoverlabel")]
57 hover_label: Option<Label>,
58 #[serde(rename = "hovertemplate")]
59 hover_template: Option<Dim<String>>,
60 label: Option<Vec<String>>,
61 line: Option<Line>,
62 pad: Option<usize>,
64 thickness: Option<usize>,
66 x: Option<Vec<f64>>,
68 y: Option<Vec<f64>>,
70}
71
72impl Node {
73 pub fn new() -> Self {
74 Default::default()
75 }
76
77 pub fn color<C: Color>(mut self, color: C) -> Self {
78 self.color = Some(Dim::Scalar(Box::new(color)));
79 self
80 }
81
82 pub fn color_array<C: Color>(mut self, colors: Vec<C>) -> Self {
83 self.color = Some(Dim::Vector(ColorArray(colors).into()));
84 self
85 }
86
87 pub fn hover_info(mut self, hover_info: HoverInfo) -> Self {
88 self.hover_info = Some(hover_info);
89 self
90 }
91
92 pub fn hover_label(mut self, hover_label: Label) -> Self {
93 self.hover_label = Some(hover_label);
94 self
95 }
96
97 pub fn hover_template(mut self, hover_template: &str) -> Self {
98 self.hover_template = Some(Dim::Scalar(hover_template.to_string()));
99 self
100 }
101
102 pub fn label(mut self, label: Vec<&str>) -> Self {
103 self.label = Some(label.iter().map(|&el| el.to_string()).collect());
104 self
105 }
106
107 pub fn line(mut self, line: Line) -> Self {
108 self.line = Some(line);
109 self
110 }
111
112 pub fn pad(mut self, pad: usize) -> Self {
113 self.pad = Some(pad);
114 self
115 }
116
117 pub fn thickness(mut self, thickness: usize) -> Self {
118 self.thickness = Some(thickness);
119 self
120 }
121
122 pub fn x(mut self, x: Vec<f64>) -> Self {
123 self.x = Some(x);
124 self
125 }
126
127 pub fn y(mut self, y: Vec<f64>) -> Self {
128 self.y = Some(y);
129 self
130 }
131}
132
133#[serde_with::skip_serializing_none]
134#[derive(Serialize, Clone)]
135pub struct Link<V>
136where
137 V: Serialize + Clone,
138{
139 color: Option<Dim<Box<dyn Color>>>,
141 #[serde(rename = "hoverinfo")]
142 hover_info: Option<HoverInfo>,
143 #[serde(rename = "hoverlabel")]
144 hover_label: Option<Label>,
145 #[serde(rename = "hovertemplate")]
146 hover_template: Option<Dim<String>>,
147 line: Option<Line>,
148 source: Option<Vec<usize>>,
149 target: Option<Vec<usize>>,
150 value: Option<Vec<V>>,
151}
152
153impl<V> Default for Link<V>
154where
155 V: Serialize + Clone,
156{
157 fn default() -> Self {
158 Self {
159 color: None,
160 hover_info: None,
161 hover_label: None,
162 hover_template: None,
163 line: None,
164 source: None,
165 target: None,
166 value: None,
167 }
168 }
169}
170
171impl<V> Link<V>
172where
173 V: Serialize + Clone,
174{
175 pub fn new() -> Self {
176 Default::default()
177 }
178
179 pub fn color<C: Color>(mut self, color: C) -> Self {
180 self.color = Some(Dim::Scalar(Box::new(color)));
181 self
182 }
183
184 pub fn color_array<C: Color>(mut self, colors: Vec<C>) -> Self {
185 self.color = Some(Dim::Vector(ColorArray(colors).into()));
186 self
187 }
188
189 pub fn hover_info(mut self, hover_info: HoverInfo) -> Self {
190 self.hover_info = Some(hover_info);
191 self
192 }
193
194 pub fn hover_label(mut self, hover_label: Label) -> Self {
195 self.hover_label = Some(hover_label);
196 self
197 }
198
199 pub fn hover_template(mut self, hover_template: &str) -> Self {
200 self.hover_template = Some(Dim::Scalar(hover_template.to_string()));
201 self
202 }
203
204 pub fn line(mut self, line: Line) -> Self {
205 self.line = Some(line);
206 self
207 }
208
209 pub fn source(mut self, source: Vec<usize>) -> Self {
210 self.source = Some(source);
211 self
212 }
213
214 pub fn target(mut self, target: Vec<usize>) -> Self {
215 self.target = Some(target);
216 self
217 }
218
219 pub fn value(mut self, target: Vec<V>) -> Self {
220 self.value = Some(target);
221 self
222 }
223}
224
/// A Sankey diagram trace, combining a [`Node`] set with a [`Link`] set.
///
/// `V` is the element type of the link `value` array (see [`Link`]). Fields
/// left unset are skipped during serialization, so a default trace serializes
/// to just `{"type": "sankey"}`. Setters are generated by the `FieldSetter`
/// derive; the `box_self` option appears to make them chain on `Box<Self>`,
/// matching the return type of `Sankey::new`.
#[serde_with::skip_serializing_none]
#[derive(Serialize, Clone, FieldSetter)]
#[field_setter(box_self, kind = "trace")]
pub struct Sankey<V>
where
    V: Serialize + Clone,
{
    /// Trace type tag. Defaults to `PlotType::Sankey`, serialized as `"sankey"`.
    #[field_setter(default = "PlotType::Sankey")]
    r#type: PlotType,
    /// Interactive node arrangement mode; see [`Arrangement`].
    arrangement: Option<Arrangement>,
    /// Serialized as `domain`.
    domain: Option<Domain>,
    /// Serialized as `ids`.
    ids: Option<Vec<String>>,
    /// Serialized as `hoverinfo`.
    #[serde(rename = "hoverinfo")]
    hover_info: Option<HoverInfo>,
    /// Serialized as `hoverlabel`.
    #[serde(rename = "hoverlabel")]
    hover_label: Option<Label>,
    /// Serialized as `legendgrouptitle`.
    #[serde(rename = "legendgrouptitle")]
    legend_group_title: Option<LegendGroupTitle>,
    /// Serialized as `legendrank`.
    #[serde(rename = "legendrank")]
    legend_rank: Option<usize>,
    /// Link (flow) definitions; see [`Link`].
    link: Option<Link<V>>,
    /// Trace name.
    name: Option<String>,
    /// Node definitions; see [`Node`].
    node: Option<Node>,
    /// Diagram orientation (e.g. `Orientation::Horizontal` serializes as `"h"`).
    orientation: Option<Orientation>,
    /// Serialized as `selectedpoints`.
    #[serde(rename = "selectedpoints")]
    selected_points: Option<Vec<usize>>,
    /// Serialized as `textfont`.
    #[serde(rename = "textfont")]
    text_font: Option<Font>,
    /// Serialized as `valueformat` (e.g. `".3f"`).
    #[serde(rename = "valueformat")]
    value_format: Option<String>,
    /// Serialized as `valuesuffix`.
    #[serde(rename = "valuesuffix")]
    value_suffix: Option<String>,
    /// Serialized as `visible`.
    visible: Option<bool>,
}
351
352impl<V> Sankey<V>
353where
354 V: Serialize + Clone,
355{
356 pub fn new() -> Box<Self> {
358 Box::default()
359 }
360}
361
362impl<V> Trace for Sankey<V>
363where
364 V: Serialize + Clone,
365{
366 fn to_json(&self) -> String {
367 serde_json::to_string(self).unwrap()
368 }
369}
370
#[cfg(test)]
mod tests {
    use serde_json::{json, to_value};

    use super::*;
    use crate::color::NamedColor;

    /// Serializes `value` to a JSON value and asserts it equals `expected`.
    fn assert_json<T: Serialize>(value: T, expected: serde_json::Value) {
        let actual = to_value(value).unwrap();
        assert_eq!(actual, expected);
    }

    /// A default trace serializes to nothing but its type tag.
    #[test]
    fn serialize_default_sankey() {
        assert_json(Sankey::<i32>::default(), json!({"type": "sankey"}));
    }

    /// Styled nodes plus links with integer values.
    #[test]
    fn serialize_basic_sankey_trace() {
        let node_outline = Line::new().color(NamedColor::Black).width(0.5);
        let nodes = Node::new()
            .pad(15)
            .thickness(30)
            .line(node_outline)
            .label(vec!["A1", "A2", "B1", "B2", "C1", "C2"])
            .color_array(vec![
                NamedColor::Blue,
                NamedColor::Blue,
                NamedColor::Blue,
                NamedColor::Blue,
                NamedColor::Blue,
                NamedColor::Blue,
            ]);
        let links = Link::new()
            .value(vec![8, 4, 2, 8, 4, 2])
            .source(vec![0, 1, 0, 2, 3, 3])
            .target(vec![2, 3, 3, 4, 4, 5]);
        let trace = Sankey::new()
            .orientation(Orientation::Horizontal)
            .node(nodes)
            .link(links);

        assert_json(
            trace,
            json!({
                "link": {
                    "source": [0, 1, 0, 2, 3, 3],
                    "target": [2, 3, 3, 4, 4, 5],
                    "value": [8, 4, 2, 8, 4, 2]
                },
                "orientation": "h",
                "type": "sankey",
                "node": {
                    "color": ["blue", "blue", "blue", "blue", "blue", "blue"],
                    "label": ["A1", "A2", "B1", "B2", "C1", "C2"],
                    "line": {
                        "color": "black",
                        "width": 0.5
                    },
                    "pad": 15,
                    "thickness": 30
                }
            }),
        );
    }

    /// Every top-level setter is reflected under its plotly.js key.
    #[test]
    fn serialize_full_sankey_trace() {
        let trace = Sankey::<i32>::new()
            .name("sankey")
            .visible(true)
            .legend_rank(1000)
            .legend_group_title("Legend Group Title")
            .ids(vec!["one"])
            .hover_info(HoverInfo::All)
            .hover_label(Label::new())
            .domain(Domain::new())
            .orientation(Orientation::Horizontal)
            .node(Node::new())
            .link(Link::new())
            .text_font(Font::new())
            .selected_points(vec![0])
            .arrangement(Arrangement::Fixed)
            .value_format(".3f")
            .value_suffix("nT");

        assert_json(
            trace,
            json!({
                "type": "sankey",
                "name": "sankey",
                "visible": true,
                "legendrank": 1000,
                "legendgrouptitle": {"text": "Legend Group Title"},
                "ids": ["one"],
                "hoverinfo": "all",
                "hoverlabel": {},
                "domain": {},
                "orientation": "h",
                "node": {},
                "link": {},
                "textfont": {},
                "selectedpoints": [0],
                "arrangement": "fixed",
                "valueformat": ".3f",
                "valuesuffix": "nT"
            }),
        );
    }

    /// Each arrangement variant serializes to its lowercase name.
    #[test]
    fn serialize_arrangement() {
        let cases = [
            (Arrangement::Snap, "snap"),
            (Arrangement::Perpendicular, "perpendicular"),
            (Arrangement::Freeform, "freeform"),
            (Arrangement::Fixed, "fixed"),
        ];
        for (arrangement, name) in cases.iter() {
            assert_json(arrangement, json!(name));
        }
    }

    /// A later scalar `color` overrides an earlier `color_array`.
    #[test]
    fn serialize_line() {
        let line = Line::new()
            .color_array(vec![NamedColor::Black, NamedColor::Blue])
            .color(NamedColor::Black)
            .width(0.1);

        assert_json(
            line,
            json!({
                "color": "black",
                "width": 0.1
            }),
        );
    }

    /// A later `color_array` overrides an earlier scalar `color`.
    #[test]
    fn serialize_node() {
        let node = Node::new()
            .color(NamedColor::Blue)
            .color_array(vec![NamedColor::Blue])
            .hover_info(HoverInfo::All)
            .hover_label(Label::new())
            .hover_template("template")
            .line(Line::new())
            .pad(5)
            .thickness(10)
            .x(vec![0.5])
            .y(vec![0.25]);

        assert_json(
            node,
            json!({
                "color": ["blue"],
                "hoverinfo": "all",
                "hoverlabel": {},
                "hovertemplate": "template",
                "line": {},
                "pad": 5,
                "thickness": 10,
                "x": [0.5],
                "y": [0.25]
            }),
        );
    }

    /// All link setters, with a scalar `color` overriding `color_array`.
    #[test]
    fn serialize_link() {
        let link = Link::new()
            .color_array(vec![NamedColor::Blue])
            .color(NamedColor::Blue)
            .hover_info(HoverInfo::All)
            .hover_label(Label::new())
            .hover_template("template")
            .line(Line::new())
            .value(vec![2, 2, 2])
            .source(vec![0, 1, 2])
            .target(vec![1, 2, 0]);

        assert_json(
            link,
            json!({
                "color": "blue",
                "hoverinfo": "all",
                "hoverlabel": {},
                "hovertemplate": "template",
                "line": {},
                "source": [0, 1, 2],
                "target": [1, 2, 0],
                "value": [2, 2, 2],
            }),
        );
    }
}