1use bon::bon;
2
3use plotly::{
4 common::{Line as LinePlotly, Marker as MarkerPlotly},
5 layout::Margin,
6 Layout as LayoutPlotly, ScatterPolar as ScatterPolarPlotly, Trace,
7};
8
9use polars::frame::DataFrame;
10use serde::Serialize;
11
12use crate::{
13 common::{Layout, Marker, PlotHelper, Polar},
14 components::{
15 FacetConfig, Fill, Legend, Line as LineStyle, Mode, Rgb, Shape, Text, DEFAULT_PLOTLY_COLORS,
16 },
17};
18
19#[derive(Clone)]
103pub struct ScatterPolar {
104 traces: Vec<Box<dyn Trace + 'static>>,
105 layout: LayoutPlotly,
106 layout_json: Option<serde_json::Value>,
107}
108
109impl Serialize for ScatterPolar {
110 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
111 where
112 S: serde::Serializer,
113 {
114 use serde::ser::SerializeStruct;
115 let mut state = serializer.serialize_struct("ScatterPolar", 2)?;
116 state.serialize_field("traces", &self.traces)?;
117 if let Some(ref layout_json) = self.layout_json {
119 state.serialize_field("layout", layout_json)?;
120 } else {
121 state.serialize_field("layout", &self.layout)?;
122 }
123 state.end()
124 }
125}
126
/// Grid geometry of a faceted polar plot, in paper coordinates (0..1).
#[derive(Clone)]
struct FacetGrid {
    /// Number of subplot columns.
    ncols: usize,
    /// Number of subplot rows.
    nrows: usize,
    /// Horizontal gap between adjacent columns, as a fraction of the plot width.
    x_gap: f64,
    /// Vertical gap between adjacent rows, as a fraction of the plot height.
    y_gap: f64,
}
134
/// Share of a facet cell's height reserved at its top for the facet title.
const POLAR_FACET_TITLE_HEIGHT_RATIO: f64 = 0.12;
/// Share of a facet cell's height kept as padding between the title band and the polar domain.
const POLAR_FACET_TOP_INSET_RATIO: f64 = 0.1;
137
138#[bon]
139impl ScatterPolar {
140 #[builder(on(String, into), on(Text, into))]
141 pub fn new(
142 data: &DataFrame,
143 theta: &str,
144 r: &str,
145 group: Option<&str>,
146 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
147 facet: Option<&str>,
148 facet_config: Option<&FacetConfig>,
149 mode: Option<Mode>,
150 opacity: Option<f64>,
151 fill: Option<Fill>,
152 size: Option<usize>,
153 color: Option<Rgb>,
154 colors: Option<Vec<Rgb>>,
155 shape: Option<Shape>,
156 shapes: Option<Vec<Shape>>,
157 width: Option<f64>,
158 line: Option<LineStyle>,
159 lines: Option<Vec<LineStyle>>,
160 plot_title: Option<Text>,
161 legend_title: Option<Text>,
162 legend: Option<&Legend>,
163 ) -> Self {
164 let x_title = None;
165 let y_title = None;
166 let y2_title = None;
167 let z_title = None;
168 let x_axis = None;
169 let y_axis = None;
170 let y2_axis = None;
171 let z_axis = None;
172
173 let (layout, traces, layout_json) = match facet {
174 Some(facet_column) => {
175 let config = facet_config.cloned().unwrap_or_default();
176
177 let (layout, grid) = Self::create_faceted_layout(
178 data,
179 facet_column,
180 &config,
181 plot_title,
182 legend_title,
183 legend,
184 );
185
186 let traces = Self::create_faceted_traces(
187 data,
188 theta,
189 r,
190 group,
191 sort_groups_by,
192 facet_column,
193 &config,
194 mode,
195 opacity,
196 fill,
197 size,
198 color,
199 colors,
200 shape,
201 shapes,
202 width,
203 line,
204 lines,
205 );
206
207 let mut layout_json = serde_json::to_value(&layout).unwrap();
209 Self::inject_polar_domains_static(
210 &mut layout_json,
211 grid.ncols,
212 grid.nrows,
213 grid.x_gap,
214 grid.y_gap,
215 );
216
217 (layout, traces, Some(layout_json))
218 }
219 None => {
220 let layout = Self::create_layout(
221 plot_title,
222 x_title,
223 y_title,
224 y2_title,
225 z_title,
226 legend_title,
227 x_axis,
228 y_axis,
229 y2_axis,
230 z_axis,
231 legend,
232 None,
233 );
234
235 let traces = Self::create_traces(
236 data,
237 theta,
238 r,
239 group,
240 sort_groups_by,
241 mode,
242 opacity,
243 fill,
244 size,
245 color,
246 colors,
247 shape,
248 shapes,
249 width,
250 line,
251 lines,
252 );
253
254 (layout, traces, None)
255 }
256 };
257
258 Self {
259 traces,
260 layout,
261 layout_json,
262 }
263 }
264
265 #[allow(clippy::too_many_arguments)]
266 fn create_traces(
267 data: &DataFrame,
268 theta: &str,
269 r: &str,
270 group: Option<&str>,
271 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
272 mode: Option<Mode>,
273 opacity: Option<f64>,
274 fill: Option<Fill>,
275 size: Option<usize>,
276 color: Option<Rgb>,
277 colors: Option<Vec<Rgb>>,
278 shape: Option<Shape>,
279 shapes: Option<Vec<Shape>>,
280 width: Option<f64>,
281 line: Option<LineStyle>,
282 lines: Option<Vec<LineStyle>>,
283 ) -> Vec<Box<dyn Trace + 'static>> {
284 let mut traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
285 let mode = mode
286 .map(|m| m.to_plotly())
287 .unwrap_or(plotly::common::Mode::Markers);
288
289 match group {
290 Some(group_col) => {
291 let groups = Self::get_unique_groups(data, group_col, sort_groups_by);
292 let groups = groups.iter().map(|s| s.as_str());
293
294 for (i, group) in groups.enumerate() {
295 let marker = Self::create_marker(
296 i,
297 opacity,
298 size,
299 color,
300 colors.clone(),
301 shape,
302 shapes.clone(),
303 );
304
305 let line_style = Self::create_line_with_color(
306 i,
307 width,
308 color,
309 colors.clone(),
310 line,
311 lines.clone(),
312 );
313
314 let subset = Self::filter_data_by_group(data, group_col, group);
315
316 let trace = Self::create_trace(
317 &subset,
318 theta,
319 r,
320 Some(group),
321 mode.clone(),
322 marker,
323 line_style,
324 fill,
325 );
326
327 traces.push(trace);
328 }
329 }
330 None => {
331 let group = None;
332
333 let marker = Self::create_marker(
334 0,
335 opacity,
336 size,
337 color,
338 colors.clone(),
339 shape,
340 shapes.clone(),
341 );
342
343 let line_style = Self::create_line_with_color(
344 0,
345 width,
346 color,
347 colors.clone(),
348 line,
349 lines.clone(),
350 );
351
352 let trace =
353 Self::create_trace(data, theta, r, group, mode, marker, line_style, fill);
354
355 traces.push(trace);
356 }
357 }
358
359 traces
360 }
361
362 #[allow(clippy::too_many_arguments)]
363 fn create_trace(
364 data: &DataFrame,
365 theta: &str,
366 r: &str,
367 group_name: Option<&str>,
368 mode: plotly::common::Mode,
369 marker: MarkerPlotly,
370 line: LinePlotly,
371 fill: Option<Fill>,
372 ) -> Box<dyn Trace + 'static> {
373 let theta_values = Self::get_numeric_column(data, theta);
374 let r_values = Self::get_numeric_column(data, r);
375
376 let mut trace = ScatterPolarPlotly::default()
377 .theta(theta_values)
378 .r(r_values)
379 .mode(mode);
380
381 trace = trace.marker(marker);
382 trace = trace.line(line);
383
384 if let Some(fill_type) = fill {
385 trace = trace.fill(fill_type.to_plotly());
386 }
387
388 if let Some(name) = group_name {
389 trace = trace.name(name);
390 }
391
392 trace
393 }
394
395 fn create_line_with_color(
396 index: usize,
397 width: Option<f64>,
398 color: Option<Rgb>,
399 colors: Option<Vec<Rgb>>,
400 style: Option<LineStyle>,
401 styles: Option<Vec<LineStyle>>,
402 ) -> LinePlotly {
403 let mut line = LinePlotly::new();
404
405 if let Some(width) = width {
407 line = line.width(width);
408 }
409
410 if let Some(style) = style {
412 line = line.dash(style.to_plotly());
413 } else if let Some(styles) = styles {
414 if let Some(style) = styles.get(index) {
415 line = line.dash(style.to_plotly());
416 }
417 }
418
419 if let Some(color) = color {
421 line = line.color(color.to_plotly());
422 } else if let Some(colors) = colors {
423 if let Some(color) = colors.get(index) {
424 line = line.color(color.to_plotly());
425 }
426 }
427
428 line
429 }
430
431 fn get_polar_subplot_reference(index: usize) -> String {
432 match index {
433 0 => "polar".to_string(),
434 1 => "polar2".to_string(),
435 2 => "polar3".to_string(),
436 3 => "polar4".to_string(),
437 4 => "polar5".to_string(),
438 5 => "polar6".to_string(),
439 6 => "polar7".to_string(),
440 7 => "polar8".to_string(),
441 _ => "polar".to_string(),
442 }
443 }
444
445 #[allow(clippy::too_many_arguments)]
446 fn build_scatter_polar_trace_with_subplot(
447 data: &DataFrame,
448 theta: &str,
449 r: &str,
450 group_name: Option<&str>,
451 mode: plotly::common::Mode,
452 marker: MarkerPlotly,
453 line: LinePlotly,
454 fill: Option<Fill>,
455 subplot: Option<&str>,
456 show_legend: bool,
457 legend_group: Option<&str>,
458 ) -> Box<dyn Trace + 'static> {
459 let theta_values = Self::get_numeric_column(data, theta);
460 let r_values = Self::get_numeric_column(data, r);
461
462 let mut trace = ScatterPolarPlotly::default()
463 .theta(theta_values)
464 .r(r_values)
465 .mode(mode);
466
467 trace = trace.marker(marker);
468 trace = trace.line(line);
469
470 if let Some(fill_type) = fill {
471 trace = trace.fill(fill_type.to_plotly());
472 }
473
474 if let Some(name) = group_name {
475 trace = trace.name(name);
476 }
477
478 if let Some(subplot_ref) = subplot {
479 trace = trace.subplot(subplot_ref);
480 }
481
482 let trace = if let Some(group) = legend_group {
483 trace.legend_group(group)
484 } else {
485 trace
486 };
487
488 if !show_legend {
489 trace.show_legend(false)
490 } else {
491 trace
492 }
493 }
494
495 #[allow(clippy::too_many_arguments)]
496 fn create_faceted_traces(
497 data: &DataFrame,
498 theta: &str,
499 r: &str,
500 group: Option<&str>,
501 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
502 facet_column: &str,
503 config: &FacetConfig,
504 mode: Option<Mode>,
505 opacity: Option<f64>,
506 fill: Option<Fill>,
507 size: Option<usize>,
508 color: Option<Rgb>,
509 colors: Option<Vec<Rgb>>,
510 shape: Option<Shape>,
511 shapes: Option<Vec<Shape>>,
512 width: Option<f64>,
513 line: Option<LineStyle>,
514 lines: Option<Vec<LineStyle>>,
515 ) -> Vec<Box<dyn Trace + 'static>> {
516 const MAX_FACETS: usize = 8;
517
518 let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
519
520 if facet_categories.len() > MAX_FACETS {
521 panic!(
522 "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} polar subplots",
523 facet_column,
524 facet_categories.len(),
525 MAX_FACETS
526 );
527 }
528
529 if let Some(ref color_vec) = colors {
530 if group.is_none() {
531 let color_count = color_vec.len();
532 let facet_count = facet_categories.len();
533
534 if color_count != facet_count {
535 panic!(
536 "When using colors with facet (without group), colors.len() must equal number of facets. \
537 Expected {} colors for {} facets, but got {} colors. \
538 Each facet must be assigned exactly one color.",
539 facet_count, facet_count, color_count
540 );
541 }
542 } else if let Some(group_col) = group {
543 let groups = Self::get_unique_groups(data, group_col, sort_groups_by);
544 let color_count = color_vec.len();
545 let group_count = groups.len();
546
547 if color_count < group_count {
548 panic!(
549 "When using colors with group, colors.len() must be >= number of groups. \
550 Need at least {} colors for {} groups, but got {} colors",
551 group_count, group_count, color_count
552 );
553 }
554 }
555 }
556
557 let global_group_indices: std::collections::HashMap<String, usize> =
558 if let Some(group_col) = group {
559 let global_groups = Self::get_unique_groups(data, group_col, sort_groups_by);
560 global_groups
561 .into_iter()
562 .enumerate()
563 .map(|(idx, group_name)| (group_name, idx))
564 .collect()
565 } else {
566 std::collections::HashMap::new()
567 };
568
569 let colors = if group.is_some() && colors.is_none() {
570 Some(DEFAULT_PLOTLY_COLORS.to_vec())
571 } else {
572 colors
573 };
574
575 let mode = mode
576 .map(|m| m.to_plotly())
577 .unwrap_or(plotly::common::Mode::Markers);
578
579 let mut all_traces = Vec::new();
580
581 if config.highlight_facet {
582 for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
583 let subplot = Self::get_polar_subplot_reference(facet_idx);
584
585 for other_facet_value in facet_categories.iter() {
586 if other_facet_value != facet_value {
587 let other_data =
588 Self::filter_data_by_group(data, facet_column, other_facet_value);
589
590 let grey_color = config.unhighlighted_color.unwrap_or(Rgb(200, 200, 200));
591 let grey_marker = Self::create_marker(
592 0,
593 opacity,
594 size,
595 Some(grey_color),
596 None,
597 shape,
598 None,
599 );
600
601 let grey_line = Self::create_line_with_color(
602 0,
603 width,
604 Some(grey_color),
605 None,
606 line,
607 None,
608 );
609
610 let trace = Self::build_scatter_polar_trace_with_subplot(
611 &other_data,
612 theta,
613 r,
614 None,
615 mode.clone(),
616 grey_marker,
617 grey_line,
618 fill,
619 Some(&subplot),
620 false,
621 None,
622 );
623
624 all_traces.push(trace);
625 }
626 }
627
628 let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
629
630 match group {
631 Some(group_col) => {
632 let groups =
633 Self::get_unique_groups(&facet_data, group_col, sort_groups_by);
634
635 for group_val in groups.iter() {
636 let group_data =
637 Self::filter_data_by_group(&facet_data, group_col, group_val);
638
639 let global_idx =
640 global_group_indices.get(group_val).copied().unwrap_or(0);
641
642 let marker = Self::create_marker(
643 global_idx,
644 opacity,
645 size,
646 color,
647 colors.clone(),
648 shape,
649 shapes.clone(),
650 );
651
652 let line_style = Self::create_line_with_color(
653 global_idx,
654 width,
655 color,
656 colors.clone(),
657 line,
658 lines.clone(),
659 );
660
661 let show_legend = facet_idx == 0;
662
663 let trace = Self::build_scatter_polar_trace_with_subplot(
664 &group_data,
665 theta,
666 r,
667 Some(group_val),
668 mode.clone(),
669 marker,
670 line_style,
671 fill,
672 Some(&subplot),
673 show_legend,
674 Some(group_val),
675 );
676
677 all_traces.push(trace);
678 }
679 }
680 None => {
681 let marker = Self::create_marker(
682 facet_idx,
683 opacity,
684 size,
685 color,
686 colors.clone(),
687 shape,
688 shapes.clone(),
689 );
690
691 let line_style = Self::create_line_with_color(
692 facet_idx,
693 width,
694 color,
695 colors.clone(),
696 line,
697 lines.clone(),
698 );
699
700 let trace = Self::build_scatter_polar_trace_with_subplot(
701 &facet_data,
702 theta,
703 r,
704 None,
705 mode.clone(),
706 marker,
707 line_style,
708 fill,
709 Some(&subplot),
710 false,
711 None,
712 );
713
714 all_traces.push(trace);
715 }
716 }
717 }
718 } else {
719 for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
720 let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
721
722 let subplot = Self::get_polar_subplot_reference(facet_idx);
723
724 match group {
725 Some(group_col) => {
726 let groups =
727 Self::get_unique_groups(&facet_data, group_col, sort_groups_by);
728
729 for group_val in groups.iter() {
730 let group_data =
731 Self::filter_data_by_group(&facet_data, group_col, group_val);
732
733 let global_idx =
734 global_group_indices.get(group_val).copied().unwrap_or(0);
735
736 let marker = Self::create_marker(
737 global_idx,
738 opacity,
739 size,
740 color,
741 colors.clone(),
742 shape,
743 shapes.clone(),
744 );
745
746 let line_style = Self::create_line_with_color(
747 global_idx,
748 width,
749 color,
750 colors.clone(),
751 line,
752 lines.clone(),
753 );
754
755 let show_legend = facet_idx == 0;
756
757 let trace = Self::build_scatter_polar_trace_with_subplot(
758 &group_data,
759 theta,
760 r,
761 Some(group_val),
762 mode.clone(),
763 marker,
764 line_style,
765 fill,
766 Some(&subplot),
767 show_legend,
768 Some(group_val),
769 );
770
771 all_traces.push(trace);
772 }
773 }
774 None => {
775 let marker = Self::create_marker(
776 facet_idx,
777 opacity,
778 size,
779 color,
780 colors.clone(),
781 shape,
782 shapes.clone(),
783 );
784
785 let line_style = Self::create_line_with_color(
786 facet_idx,
787 width,
788 color,
789 colors.clone(),
790 line,
791 lines.clone(),
792 );
793
794 let trace = Self::build_scatter_polar_trace_with_subplot(
795 &facet_data,
796 theta,
797 r,
798 None,
799 mode.clone(),
800 marker,
801 line_style,
802 fill,
803 Some(&subplot),
804 false,
805 None,
806 );
807
808 all_traces.push(trace);
809 }
810 }
811 }
812 }
813
814 all_traces
815 }
816
817 fn create_faceted_layout(
818 data: &DataFrame,
819 facet_column: &str,
820 config: &FacetConfig,
821 plot_title: Option<Text>,
822 legend_title: Option<Text>,
823 legend: Option<&Legend>,
824 ) -> (LayoutPlotly, FacetGrid) {
825 let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
826 let n_facets = facet_categories.len();
827
828 let (ncols, nrows) = Self::calculate_grid_dimensions(n_facets, config.cols, config.rows);
829
830 let x_gap = config.h_gap.unwrap_or(0.08);
832 let y_gap = config.v_gap.unwrap_or(0.12);
833
834 let grid = FacetGrid {
835 ncols,
836 nrows,
837 x_gap,
838 y_gap,
839 };
840
841 let mut layout = LayoutPlotly::new();
844
845 if let Some(title) = plot_title {
846 layout = layout.title(title.to_plotly());
847 }
848
849 let annotations = Self::create_facet_annotations_polar(
850 &facet_categories,
851 ncols,
852 nrows,
853 config.title_style.as_ref(),
854 config.h_gap,
855 config.v_gap,
856 );
857 layout = layout.annotations(annotations);
858
859 layout = layout.legend(Legend::set_legend(legend_title, legend));
860
861 layout = layout.margin(Margin::new().top(140).bottom(80).left(80).right(80));
865
866 (layout, grid)
867 }
868
869 fn calculate_polar_facet_cell(
874 subplot_index: usize,
875 ncols: usize,
876 nrows: usize,
877 x_gap: Option<f64>,
878 y_gap: Option<f64>,
879 ) -> PolarFacetCell {
880 let row = subplot_index / ncols;
881 let col = subplot_index % ncols;
882
883 let x_gap_val = x_gap.unwrap_or(0.08);
884 let y_gap_val = y_gap.unwrap_or(0.12);
885
886 let cell_width = (1.0 - x_gap_val * (ncols - 1) as f64) / ncols as f64;
887 let cell_height = (1.0 - y_gap_val * (nrows - 1) as f64) / nrows as f64;
888
889 let title_height = cell_height * POLAR_FACET_TITLE_HEIGHT_RATIO;
890 let polar_padding = cell_height * POLAR_FACET_TOP_INSET_RATIO;
891
892 let cell_x_start = col as f64 * (cell_width + x_gap_val);
893 let cell_y_top = 1.0 - row as f64 * (cell_height + y_gap_val);
894 let cell_y_bottom = cell_y_top - cell_height;
895
896 let domain_y_top = cell_y_top - title_height - polar_padding;
897 let domain_y_bottom = cell_y_bottom;
898
899 let domain_x = [cell_x_start, cell_x_start + cell_width];
900 let domain_y = [domain_y_bottom, domain_y_top];
901
902 let annotation_x = cell_x_start + cell_width / 2.0;
903 let annotation_y = cell_y_top - title_height / 2.0;
904
905 PolarFacetCell {
906 annotation_x,
907 annotation_y,
908 domain_x,
909 domain_y,
910 }
911 }
912
913 fn create_facet_annotations_polar(
914 categories: &[String],
915 ncols: usize,
916 nrows: usize,
917 title_style: Option<&Text>,
918 x_gap: Option<f64>,
919 y_gap: Option<f64>,
920 ) -> Vec<plotly::layout::Annotation> {
921 use plotly::common::Anchor;
922 use plotly::layout::Annotation;
923
924 categories
925 .iter()
926 .enumerate()
927 .map(|(i, cat)| {
928 let cell = Self::calculate_polar_facet_cell(i, ncols, nrows, x_gap, y_gap);
929
930 let mut ann = Annotation::new()
931 .text(cat.as_str())
932 .x_ref("paper")
933 .y_ref("paper")
934 .x_anchor(Anchor::Center)
935 .y_anchor(Anchor::Bottom)
936 .x(cell.annotation_x)
937 .y(cell.annotation_y)
938 .show_arrow(false);
939
940 if let Some(style) = title_style {
941 ann = ann.font(style.to_font());
942 }
943
944 ann
945 })
946 .collect()
947 }
948}
949
/// Paper-coordinate geometry of one cell of a faceted polar plot.
struct PolarFacetCell {
    /// Horizontal position of the facet title (centre of the cell).
    annotation_x: f64,
    /// Vertical position of the facet title (middle of the cell's title band).
    annotation_y: f64,
    /// `[start, end]` horizontal extent of the polar domain.
    domain_x: [f64; 2],
    /// `[bottom, top]` vertical extent of the polar domain.
    domain_y: [f64; 2],
}
957
958impl ScatterPolar {
959 fn inject_polar_domains_static(
962 layout_json: &mut serde_json::Value,
963 ncols: usize,
964 nrows: usize,
965 x_gap: f64,
966 y_gap: f64,
967 ) {
968 let total_cells = (ncols * nrows).clamp(1, 8);
972
973 for i in 0..total_cells {
974 let polar_key = if i == 0 {
975 "polar".to_string()
976 } else {
977 format!("polar{}", i + 1)
978 };
979
980 let cell = Self::calculate_polar_facet_cell(i, ncols, nrows, Some(x_gap), Some(y_gap));
981
982 let compression_factor = 0.9;
983 let domain_height = cell.domain_y[1] - cell.domain_y[0];
984 let height_reduction = domain_height * (1.0 - compression_factor);
985 let compressed_domain_y = [
986 cell.domain_y[0] + height_reduction / 2.0,
987 cell.domain_y[1] - height_reduction / 2.0,
988 ];
989
990 let polar_config = serde_json::json!({
991 "domain": {
992 "x": cell.domain_x,
993 "y": compressed_domain_y
994 }
995 });
996
997 layout_json[polar_key] = polar_config;
998 }
999 }
1000}
1001
1002impl Layout for ScatterPolar {}
1003impl Marker for ScatterPolar {}
1004impl Polar for ScatterPolar {}
1005
1006impl PlotHelper for ScatterPolar {
1007 fn get_layout(&self) -> &LayoutPlotly {
1008 &self.layout
1009 }
1010
1011 fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>> {
1012 &self.traces
1013 }
1014
1015 fn get_layout_override(&self) -> Option<&serde_json::Value> {
1016 self.layout_json.as_ref()
1017 }
1018}