1use plotly_derive::FieldSetter;
4use serde::Serialize;
5
6use crate::{
7 color::{Color, ColorArray},
8 common::{Dim, Domain, Font, HoverInfo, Label, LegendGroupTitle, Orientation, PlotType},
9 Trace,
10};
11
12#[derive(Serialize, Clone)]
13#[serde(rename_all = "lowercase")]
14pub enum Arrangement {
15 Snap,
16 Perpendicular,
17 Freeform,
18 Fixed,
19}
20
21#[serde_with::skip_serializing_none]
22#[derive(Serialize, Clone, Default)]
23pub struct Line {
24 color: Option<Dim<Box<dyn Color>>>,
25 width: Option<f64>,
26}
27
28impl Line {
29 pub fn new() -> Self {
30 Default::default()
31 }
32
33 pub fn color<C: Color>(mut self, color: C) -> Self {
34 self.color = Some(Dim::Scalar(Box::new(color)));
35 self
36 }
37
38 pub fn color_array<C: Color>(mut self, colors: Vec<C>) -> Self {
39 self.color = Some(Dim::Vector(ColorArray(colors).into()));
40 self
41 }
42
43 pub fn width(mut self, width: f64) -> Self {
44 self.width = Some(width);
45 self
46 }
47}
48
49#[serde_with::skip_serializing_none]
50#[derive(Serialize, Default, Clone)]
51pub struct Node {
52 color: Option<Dim<Box<dyn Color>>>,
54 #[serde(rename = "hoverinfo")]
55 hover_info: Option<HoverInfo>,
56 #[serde(rename = "hoverlabel")]
57 hover_label: Option<Label>,
58 #[serde(rename = "hovertemplate")]
59 hover_template: Option<Dim<String>>,
60 label: Option<Vec<String>>,
61 line: Option<Line>,
62 pad: Option<usize>,
63 thickness: Option<usize>,
64 x: Option<f64>,
65 y: Option<f64>,
66}
67
68impl Node {
69 pub fn new() -> Self {
70 Default::default()
71 }
72
73 pub fn color<C: Color>(mut self, color: C) -> Self {
74 self.color = Some(Dim::Scalar(Box::new(color)));
75 self
76 }
77
78 pub fn color_array<C: Color>(mut self, colors: Vec<C>) -> Self {
79 self.color = Some(Dim::Vector(ColorArray(colors).into()));
80 self
81 }
82
83 pub fn hover_info(mut self, hover_info: HoverInfo) -> Self {
84 self.hover_info = Some(hover_info);
85 self
86 }
87
88 pub fn hover_label(mut self, hover_label: Label) -> Self {
89 self.hover_label = Some(hover_label);
90 self
91 }
92
93 pub fn hover_template(mut self, hover_template: &str) -> Self {
94 self.hover_template = Some(Dim::Scalar(hover_template.to_string()));
95 self
96 }
97
98 pub fn label(mut self, label: Vec<&str>) -> Self {
99 self.label = Some(label.iter().map(|&el| el.to_string()).collect());
100 self
101 }
102
103 pub fn line(mut self, line: Line) -> Self {
104 self.line = Some(line);
105 self
106 }
107
108 pub fn pad(mut self, pad: usize) -> Self {
109 self.pad = Some(pad);
110 self
111 }
112
113 pub fn thickness(mut self, thickness: usize) -> Self {
114 self.thickness = Some(thickness);
115 self
116 }
117
118 pub fn x(mut self, x: f64) -> Self {
119 self.x = Some(x);
120 self
121 }
122
123 pub fn y(mut self, y: f64) -> Self {
124 self.y = Some(y);
125 self
126 }
127}
128
129#[serde_with::skip_serializing_none]
130#[derive(Serialize, Clone)]
131pub struct Link<V>
132where
133 V: Serialize + Clone,
134{
135 color: Option<Dim<Box<dyn Color>>>,
137 #[serde(rename = "hoverinfo")]
138 hover_info: Option<HoverInfo>,
139 #[serde(rename = "hoverlabel")]
140 hover_label: Option<Label>,
141 #[serde(rename = "hovertemplate")]
142 hover_template: Option<Dim<String>>,
143 line: Option<Line>,
144 source: Option<Vec<usize>>,
145 target: Option<Vec<usize>>,
146 value: Option<Vec<V>>,
147}
148
149impl<V> Default for Link<V>
150where
151 V: Serialize + Clone,
152{
153 fn default() -> Self {
154 Self {
155 color: None,
156 hover_info: None,
157 hover_label: None,
158 hover_template: None,
159 line: None,
160 source: None,
161 target: None,
162 value: None,
163 }
164 }
165}
166
167impl<V> Link<V>
168where
169 V: Serialize + Clone,
170{
171 pub fn new() -> Self {
172 Default::default()
173 }
174
175 pub fn color<C: Color>(mut self, color: C) -> Self {
176 self.color = Some(Dim::Scalar(Box::new(color)));
177 self
178 }
179
180 pub fn color_array<C: Color>(mut self, colors: Vec<C>) -> Self {
181 self.color = Some(Dim::Vector(ColorArray(colors).into()));
182 self
183 }
184
185 pub fn hover_info(mut self, hover_info: HoverInfo) -> Self {
186 self.hover_info = Some(hover_info);
187 self
188 }
189
190 pub fn hover_label(mut self, hover_label: Label) -> Self {
191 self.hover_label = Some(hover_label);
192 self
193 }
194
195 pub fn hover_template(mut self, hover_template: &str) -> Self {
196 self.hover_template = Some(Dim::Scalar(hover_template.to_string()));
197 self
198 }
199
200 pub fn line(mut self, line: Line) -> Self {
201 self.line = Some(line);
202 self
203 }
204
205 pub fn source(mut self, source: Vec<usize>) -> Self {
206 self.source = Some(source);
207 self
208 }
209
210 pub fn target(mut self, target: Vec<usize>) -> Self {
211 self.target = Some(target);
212 self
213 }
214
215 pub fn value(mut self, target: Vec<V>) -> Self {
216 self.value = Some(target);
217 self
218 }
219}
220
221#[serde_with::skip_serializing_none]
274#[derive(Serialize, Clone, FieldSetter)]
275#[field_setter(box_self, kind = "trace")]
276pub struct Sankey<V>
277where
278 V: Serialize + Clone,
279{
280 #[field_setter(default = "PlotType::Sankey")]
282 r#type: PlotType,
283 arrangement: Option<Arrangement>,
290 domain: Option<Domain>,
292 ids: Option<Vec<String>>,
295 #[serde(rename = "hoverinfo")]
301 hover_info: Option<HoverInfo>,
302 #[serde(rename = "hoverlabel")]
304 hover_label: Option<Label>,
305 #[serde(rename = "legendgrouptitle")]
307 legend_group_title: Option<LegendGroupTitle>,
308 #[serde(rename = "legendrank")]
315 legend_rank: Option<usize>,
316 link: Option<Link<V>>,
318 name: Option<String>,
321 node: Option<Node>,
323 orientation: Option<Orientation>,
325 #[serde(rename = "selectedpoints")]
330 selected_points: Option<Vec<usize>>,
331 #[serde(rename = "textfont")]
333 text_font: Option<Font>,
334 #[serde(rename = "valueformat")]
337 value_format: Option<String>,
338 #[serde(rename = "valuesuffix")]
341 value_suffix: Option<String>,
342 visible: Option<bool>,
346}
347
348impl<V> Sankey<V>
349where
350 V: Serialize + Clone,
351{
352 pub fn new() -> Box<Self> {
354 Box::default()
355 }
356}
357
358impl<V> Trace for Sankey<V>
359where
360 V: Serialize + Clone,
361{
362 fn to_json(&self) -> String {
363 serde_json::to_string(self).unwrap()
364 }
365}
366
#[cfg(test)]
mod tests {
    use serde_json::{json, to_value};

    use super::*;
    use crate::color::NamedColor;

    /// A default trace carries nothing but its `type` key.
    #[test]
    fn test_serialize_default_sankey() {
        let sankey = Sankey::<i32>::default();
        let actual = to_value(sankey).unwrap();

        assert_eq!(actual, json!({"type": "sankey"}));
    }

    /// Styled nodes plus a six-entry link table.
    #[test]
    fn test_serialize_basic_sankey_trace() {
        let nodes = Node::new()
            .pad(15)
            .thickness(30)
            .line(Line::new().color(NamedColor::Black).width(0.5))
            .label(vec!["A1", "A2", "B1", "B2", "C1", "C2"])
            .color_array(vec![
                NamedColor::Blue,
                NamedColor::Blue,
                NamedColor::Blue,
                NamedColor::Blue,
                NamedColor::Blue,
                NamedColor::Blue,
            ]);
        let links = Link::new()
            .value(vec![8, 4, 2, 8, 4, 2])
            .source(vec![0, 1, 0, 2, 3, 3])
            .target(vec![2, 3, 3, 4, 4, 5]);
        let sankey = Sankey::new()
            .orientation(Orientation::Horizontal)
            .node(nodes)
            .link(links);

        let want = json!({
            "link": {
                "source": [0, 1, 0, 2, 3, 3],
                "target": [2, 3, 3, 4, 4, 5],
                "value": [8, 4, 2, 8, 4, 2]
            },
            "orientation": "h",
            "type": "sankey",
            "node": {
                "color": ["blue", "blue", "blue", "blue", "blue", "blue"],
                "label": ["A1", "A2", "B1", "B2", "C1", "C2"],
                "line": {
                    "color": "black",
                    "width": 0.5
                },
                "pad": 15,
                "thickness": 30
            }
        });
        let actual = to_value(sankey).unwrap();

        assert_eq!(actual, want);
    }

    /// Every top-level setter ends up under its plotly key.
    #[test]
    fn test_serialize_full_sankey_trace() {
        let sankey = Sankey::<i32>::new()
            .name("sankey")
            .visible(true)
            .legend_rank(1000)
            .legend_group_title(LegendGroupTitle::new("Legend Group Title"))
            .ids(vec!["one"])
            .hover_info(HoverInfo::All)
            .hover_label(Label::new())
            .domain(Domain::new())
            .orientation(Orientation::Horizontal)
            .node(Node::new())
            .link(Link::new())
            .text_font(Font::new())
            .selected_points(vec![0])
            .arrangement(Arrangement::Fixed)
            .value_format(".3f")
            .value_suffix("nT");

        let want = json!({
            "type": "sankey",
            "name": "sankey",
            "visible": true,
            "legendrank": 1000,
            "legendgrouptitle": {"text": "Legend Group Title"},
            "ids": ["one"],
            "hoverinfo": "all",
            "hoverlabel": {},
            "domain": {},
            "orientation": "h",
            "node": {},
            "link": {},
            "textfont": {},
            "selectedpoints": [0],
            "arrangement": "fixed",
            "valueformat": ".3f",
            "valuesuffix": "nT"
        });
        let actual = to_value(sankey).unwrap();

        assert_eq!(actual, want);
    }

    /// Each arrangement variant serializes to its lowercase name.
    #[test]
    fn test_serialize_arrangement() {
        let cases = [
            (Arrangement::Snap, "snap"),
            (Arrangement::Perpendicular, "perpendicular"),
            (Arrangement::Freeform, "freeform"),
            (Arrangement::Fixed, "fixed"),
        ];

        for (arrangement, name) in &cases {
            assert_eq!(to_value(arrangement).unwrap(), json!(name));
        }
    }

    /// A scalar color set after an array replaces the array.
    #[test]
    fn test_serialize_line() {
        let styled = Line::new()
            .color_array(vec![NamedColor::Black, NamedColor::Blue])
            .color(NamedColor::Black)
            .width(0.1);

        let want = json!({
            "color": "black",
            "width": 0.1
        });
        let actual = to_value(styled).unwrap();

        assert_eq!(actual, want)
    }

    /// An array color set after a scalar replaces the scalar.
    #[test]
    fn test_serialize_node() {
        let styled = Node::new()
            .color(NamedColor::Blue)
            .color_array(vec![NamedColor::Blue])
            .hover_info(HoverInfo::All)
            .hover_label(Label::new())
            .hover_template("template")
            .line(Line::new())
            .pad(5)
            .thickness(10)
            .x(0.5)
            .y(0.25);

        let want = json!({
            "color": ["blue"],
            "hoverinfo": "all",
            "hoverlabel": {},
            "hovertemplate": "template",
            "line": {},
            "pad": 5,
            "thickness": 10,
            "x": 0.5,
            "y": 0.25
        });
        let actual = to_value(styled).unwrap();

        assert_eq!(actual, want)
    }

    /// Link setters, including a scalar color overriding an earlier array.
    #[test]
    fn test_serialize_link() {
        let styled = Link::new()
            .color_array(vec![NamedColor::Blue])
            .color(NamedColor::Blue)
            .hover_info(HoverInfo::All)
            .hover_label(Label::new())
            .hover_template("template")
            .line(Line::new())
            .value(vec![2, 2, 2])
            .source(vec![0, 1, 2])
            .target(vec![1, 2, 0]);

        let want = json!({
            "color": "blue",
            "hoverinfo": "all",
            "hoverlabel": {},
            "hovertemplate": "template",
            "line": {},
            "source": [0, 1, 2],
            "target": [1, 2, 0],
            "value": [2, 2, 2],
        });
        let actual = to_value(styled).unwrap();

        assert_eq!(actual, want)
    }
}