Skip to main content

plotlars/plots/
scatterpolar.rs

1use bon::bon;
2
3use plotly::{
4    common::{Line as LinePlotly, Marker as MarkerPlotly},
5    layout::Margin,
6    Layout as LayoutPlotly, ScatterPolar as ScatterPolarPlotly, Trace,
7};
8
9use polars::frame::DataFrame;
10use serde::Serialize;
11
12use crate::{
13    common::{Layout, Marker, PlotHelper, Polar},
14    components::{
15        FacetConfig, Fill, Legend, Line as LineStyle, Mode, Rgb, Shape, Text, DEFAULT_PLOTLY_COLORS,
16    },
17};
18
19/// A structure representing a scatter polar plot.
20///
21/// The `ScatterPolar` struct facilitates the creation and customization of polar scatter plots with various options
22/// for data selection, grouping, layout configuration, and aesthetic adjustments. It supports grouping of data,
23/// customization of marker shapes, colors, sizes, line styles, and comprehensive layout customization
24/// including titles and legends.
25///
26/// # Arguments
27///
28/// * `data` - A reference to the `DataFrame` containing the data to be plotted.
29/// * `theta` - A string slice specifying the column name to be used for the angular coordinates (in degrees).
30/// * `r` - A string slice specifying the column name to be used for the radial coordinates.
31/// * `group` - An optional string slice specifying the column name to be used for grouping data points.
32/// * `sort_groups_by` - Optional comparator `fn(&str, &str) -> std::cmp::Ordering` to control group ordering. Groups are sorted lexically by default.
33/// * `facet` - An optional string slice specifying the column name to be used for faceting (creating multiple subplots).
34/// * `facet_config` - An optional reference to a `FacetConfig` struct for customizing facet behavior (grid dimensions, scales, gaps, etc.).
35/// * `mode` - An optional `Mode` specifying the drawing mode (lines, markers, or both). Defaults to markers.
36/// * `opacity` - An optional `f64` value specifying the opacity of the plot elements (range: 0.0 to 1.0).
37/// * `fill` - An optional `Fill` type specifying how to fill the area under the trace.
38/// * `size` - An optional `usize` specifying the size of the markers.
39/// * `color` - An optional `Rgb` value specifying the color of the markers. This is used when `group` is not specified.
40/// * `colors` - An optional vector of `Rgb` values specifying the colors for the markers. This is used when `group` is specified to differentiate between groups.
41/// * `shape` - An optional `Shape` specifying the shape of the markers. This is used when `group` is not specified.
42/// * `shapes` - An optional vector of `Shape` values specifying multiple shapes for the markers when plotting multiple groups.
43/// * `width` - An optional `f64` specifying the width of the lines.
44/// * `line` - An optional `LineStyle` specifying the style of the line (e.g., solid, dashed).
45/// * `lines` - An optional vector of `LineStyle` enums specifying the styles of lines for multiple traces.
46/// * `plot_title` - An optional `Text` struct specifying the title of the plot.
47/// * `legend_title` - An optional `Text` struct specifying the title of the legend.
48/// * `legend` - An optional reference to a `Legend` struct for customizing the legend of the plot (e.g., positioning, font, etc.).
49///
50/// # Example
51///
52/// ```rust
53/// use plotlars::{Legend, Line, Mode, Plot, Rgb, ScatterPolar, Shape, Text};
54/// use polars::prelude::*;
55///
56/// let dataset = LazyCsvReader::new(PlRefPath::new("data/product_comparison_polar.csv"))
57///     .finish()
58///     .unwrap()
59///     .collect()
60///     .unwrap();
61///
62/// ScatterPolar::builder()
63///     .data(&dataset)
64///     .theta("angle")
65///     .r("score")
66///     .group("product")
67///     .mode(Mode::LinesMarkers)
68///     .colors(vec![
69///         Rgb(255, 99, 71),
70///         Rgb(60, 179, 113),
71///     ])
72///     .shapes(vec![
73///         Shape::Circle,
74///         Shape::Square,
75///     ])
76///     .lines(vec![
77///         Line::Solid,
78///         Line::Dash,
79///     ])
80///     .width(2.5)
81///     .size(8)
82///     .plot_title(
83///         Text::from("Scatter Polar Plot")
84///             .font("Arial")
85///             .size(24)
86///     )
87///     .legend_title(
88///         Text::from("Products")
89///             .font("Arial")
90///             .size(14)
91///     )
92///     .legend(
93///         &Legend::new()
94///             .x(0.85)
95///             .y(0.95)
96///     )
97///     .build()
98///     .plot();
99/// ```
100///
101/// ![Example](https://imgur.com/kl1pY9c.png)
102#[derive(Clone)]
103pub struct ScatterPolar {
104    traces: Vec<Box<dyn Trace + 'static>>,
105    layout: LayoutPlotly,
106    layout_json: Option<serde_json::Value>,
107}
108
109impl Serialize for ScatterPolar {
110    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
111    where
112        S: serde::Serializer,
113    {
114        use serde::ser::SerializeStruct;
115        let mut state = serializer.serialize_struct("ScatterPolar", 2)?;
116        state.serialize_field("traces", &self.traces)?;
117        // Use modified layout JSON if available, otherwise serialize the layout
118        if let Some(ref layout_json) = self.layout_json {
119            state.serialize_field("layout", layout_json)?;
120        } else {
121            state.serialize_field("layout", &self.layout)?;
122        }
123        state.end()
124    }
125}
126
/// Grid geometry of a faceted polar plot.
///
/// Recorded while building the layout so the polar subplot domains can be injected into the
/// serialized layout afterwards. All fields are plain values, so the struct is `Copy`.
#[derive(Clone, Copy, Debug, PartialEq)]
struct FacetGrid {
    /// Number of subplot columns.
    ncols: usize,
    /// Number of subplot rows.
    nrows: usize,
    /// Horizontal gap between subplots, in normalized (0.0 to 1.0) layout units.
    x_gap: f64,
    /// Vertical gap between subplots, in normalized (0.0 to 1.0) layout units.
    y_gap: f64,
}
134
/// Fraction of a facet cell's height reserved at its top for the facet title.
const POLAR_FACET_TITLE_HEIGHT_RATIO: f64 = 0.12_f64;
/// Further fraction of a facet cell's height left as padding between the title band and the
/// top of the polar domain.
const POLAR_FACET_TOP_INSET_RATIO: f64 = 0.10_f64;
137
138#[bon]
139impl ScatterPolar {
140    #[builder(on(String, into), on(Text, into))]
141    pub fn new(
142        data: &DataFrame,
143        theta: &str,
144        r: &str,
145        group: Option<&str>,
146        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
147        facet: Option<&str>,
148        facet_config: Option<&FacetConfig>,
149        mode: Option<Mode>,
150        opacity: Option<f64>,
151        fill: Option<Fill>,
152        size: Option<usize>,
153        color: Option<Rgb>,
154        colors: Option<Vec<Rgb>>,
155        shape: Option<Shape>,
156        shapes: Option<Vec<Shape>>,
157        width: Option<f64>,
158        line: Option<LineStyle>,
159        lines: Option<Vec<LineStyle>>,
160        plot_title: Option<Text>,
161        legend_title: Option<Text>,
162        legend: Option<&Legend>,
163    ) -> Self {
164        let x_title = None;
165        let y_title = None;
166        let y2_title = None;
167        let z_title = None;
168        let x_axis = None;
169        let y_axis = None;
170        let y2_axis = None;
171        let z_axis = None;
172
173        let (layout, traces, layout_json) = match facet {
174            Some(facet_column) => {
175                let config = facet_config.cloned().unwrap_or_default();
176
177                let (layout, grid) = Self::create_faceted_layout(
178                    data,
179                    facet_column,
180                    &config,
181                    plot_title,
182                    legend_title,
183                    legend,
184                );
185
186                let traces = Self::create_faceted_traces(
187                    data,
188                    theta,
189                    r,
190                    group,
191                    sort_groups_by,
192                    facet_column,
193                    &config,
194                    mode,
195                    opacity,
196                    fill,
197                    size,
198                    color,
199                    colors,
200                    shape,
201                    shapes,
202                    width,
203                    line,
204                    lines,
205                );
206
207                // Inject polar subplot domains into layout JSON
208                let mut layout_json = serde_json::to_value(&layout).unwrap();
209                Self::inject_polar_domains_static(
210                    &mut layout_json,
211                    grid.ncols,
212                    grid.nrows,
213                    grid.x_gap,
214                    grid.y_gap,
215                );
216
217                (layout, traces, Some(layout_json))
218            }
219            None => {
220                let layout = Self::create_layout(
221                    plot_title,
222                    x_title,
223                    y_title,
224                    y2_title,
225                    z_title,
226                    legend_title,
227                    x_axis,
228                    y_axis,
229                    y2_axis,
230                    z_axis,
231                    legend,
232                    None,
233                );
234
235                let traces = Self::create_traces(
236                    data,
237                    theta,
238                    r,
239                    group,
240                    sort_groups_by,
241                    mode,
242                    opacity,
243                    fill,
244                    size,
245                    color,
246                    colors,
247                    shape,
248                    shapes,
249                    width,
250                    line,
251                    lines,
252                );
253
254                (layout, traces, None)
255            }
256        };
257
258        Self {
259            traces,
260            layout,
261            layout_json,
262        }
263    }
264
265    #[allow(clippy::too_many_arguments)]
266    fn create_traces(
267        data: &DataFrame,
268        theta: &str,
269        r: &str,
270        group: Option<&str>,
271        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
272        mode: Option<Mode>,
273        opacity: Option<f64>,
274        fill: Option<Fill>,
275        size: Option<usize>,
276        color: Option<Rgb>,
277        colors: Option<Vec<Rgb>>,
278        shape: Option<Shape>,
279        shapes: Option<Vec<Shape>>,
280        width: Option<f64>,
281        line: Option<LineStyle>,
282        lines: Option<Vec<LineStyle>>,
283    ) -> Vec<Box<dyn Trace + 'static>> {
284        let mut traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
285        let mode = mode
286            .map(|m| m.to_plotly())
287            .unwrap_or(plotly::common::Mode::Markers);
288
289        match group {
290            Some(group_col) => {
291                let groups = Self::get_unique_groups(data, group_col, sort_groups_by);
292                let groups = groups.iter().map(|s| s.as_str());
293
294                for (i, group) in groups.enumerate() {
295                    let marker = Self::create_marker(
296                        i,
297                        opacity,
298                        size,
299                        color,
300                        colors.clone(),
301                        shape,
302                        shapes.clone(),
303                    );
304
305                    let line_style = Self::create_line_with_color(
306                        i,
307                        width,
308                        color,
309                        colors.clone(),
310                        line,
311                        lines.clone(),
312                    );
313
314                    let subset = Self::filter_data_by_group(data, group_col, group);
315
316                    let trace = Self::create_trace(
317                        &subset,
318                        theta,
319                        r,
320                        Some(group),
321                        mode.clone(),
322                        marker,
323                        line_style,
324                        fill,
325                    );
326
327                    traces.push(trace);
328                }
329            }
330            None => {
331                let group = None;
332
333                let marker = Self::create_marker(
334                    0,
335                    opacity,
336                    size,
337                    color,
338                    colors.clone(),
339                    shape,
340                    shapes.clone(),
341                );
342
343                let line_style = Self::create_line_with_color(
344                    0,
345                    width,
346                    color,
347                    colors.clone(),
348                    line,
349                    lines.clone(),
350                );
351
352                let trace =
353                    Self::create_trace(data, theta, r, group, mode, marker, line_style, fill);
354
355                traces.push(trace);
356            }
357        }
358
359        traces
360    }
361
362    #[allow(clippy::too_many_arguments)]
363    fn create_trace(
364        data: &DataFrame,
365        theta: &str,
366        r: &str,
367        group_name: Option<&str>,
368        mode: plotly::common::Mode,
369        marker: MarkerPlotly,
370        line: LinePlotly,
371        fill: Option<Fill>,
372    ) -> Box<dyn Trace + 'static> {
373        let theta_values = Self::get_numeric_column(data, theta);
374        let r_values = Self::get_numeric_column(data, r);
375
376        let mut trace = ScatterPolarPlotly::default()
377            .theta(theta_values)
378            .r(r_values)
379            .mode(mode);
380
381        trace = trace.marker(marker);
382        trace = trace.line(line);
383
384        if let Some(fill_type) = fill {
385            trace = trace.fill(fill_type.to_plotly());
386        }
387
388        if let Some(name) = group_name {
389            trace = trace.name(name);
390        }
391
392        trace
393    }
394
395    fn create_line_with_color(
396        index: usize,
397        width: Option<f64>,
398        color: Option<Rgb>,
399        colors: Option<Vec<Rgb>>,
400        style: Option<LineStyle>,
401        styles: Option<Vec<LineStyle>>,
402    ) -> LinePlotly {
403        let mut line = LinePlotly::new();
404
405        // Set width
406        if let Some(width) = width {
407            line = line.width(width);
408        }
409
410        // Set style
411        if let Some(style) = style {
412            line = line.dash(style.to_plotly());
413        } else if let Some(styles) = styles {
414            if let Some(style) = styles.get(index) {
415                line = line.dash(style.to_plotly());
416            }
417        }
418
419        // Set color
420        if let Some(color) = color {
421            line = line.color(color.to_plotly());
422        } else if let Some(colors) = colors {
423            if let Some(color) = colors.get(index) {
424                line = line.color(color.to_plotly());
425            }
426        }
427
428        line
429    }
430
431    fn get_polar_subplot_reference(index: usize) -> String {
432        match index {
433            0 => "polar".to_string(),
434            1 => "polar2".to_string(),
435            2 => "polar3".to_string(),
436            3 => "polar4".to_string(),
437            4 => "polar5".to_string(),
438            5 => "polar6".to_string(),
439            6 => "polar7".to_string(),
440            7 => "polar8".to_string(),
441            _ => "polar".to_string(),
442        }
443    }
444
445    #[allow(clippy::too_many_arguments)]
446    fn build_scatter_polar_trace_with_subplot(
447        data: &DataFrame,
448        theta: &str,
449        r: &str,
450        group_name: Option<&str>,
451        mode: plotly::common::Mode,
452        marker: MarkerPlotly,
453        line: LinePlotly,
454        fill: Option<Fill>,
455        subplot: Option<&str>,
456        show_legend: bool,
457        legend_group: Option<&str>,
458    ) -> Box<dyn Trace + 'static> {
459        let theta_values = Self::get_numeric_column(data, theta);
460        let r_values = Self::get_numeric_column(data, r);
461
462        let mut trace = ScatterPolarPlotly::default()
463            .theta(theta_values)
464            .r(r_values)
465            .mode(mode);
466
467        trace = trace.marker(marker);
468        trace = trace.line(line);
469
470        if let Some(fill_type) = fill {
471            trace = trace.fill(fill_type.to_plotly());
472        }
473
474        if let Some(name) = group_name {
475            trace = trace.name(name);
476        }
477
478        if let Some(subplot_ref) = subplot {
479            trace = trace.subplot(subplot_ref);
480        }
481
482        let trace = if let Some(group) = legend_group {
483            trace.legend_group(group)
484        } else {
485            trace
486        };
487
488        if !show_legend {
489            trace.show_legend(false)
490        } else {
491            trace
492        }
493    }
494
495    #[allow(clippy::too_many_arguments)]
496    fn create_faceted_traces(
497        data: &DataFrame,
498        theta: &str,
499        r: &str,
500        group: Option<&str>,
501        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
502        facet_column: &str,
503        config: &FacetConfig,
504        mode: Option<Mode>,
505        opacity: Option<f64>,
506        fill: Option<Fill>,
507        size: Option<usize>,
508        color: Option<Rgb>,
509        colors: Option<Vec<Rgb>>,
510        shape: Option<Shape>,
511        shapes: Option<Vec<Shape>>,
512        width: Option<f64>,
513        line: Option<LineStyle>,
514        lines: Option<Vec<LineStyle>>,
515    ) -> Vec<Box<dyn Trace + 'static>> {
516        const MAX_FACETS: usize = 8;
517
518        let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
519
520        if facet_categories.len() > MAX_FACETS {
521            panic!(
522                "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} polar subplots",
523                facet_column,
524                facet_categories.len(),
525                MAX_FACETS
526            );
527        }
528
529        if let Some(ref color_vec) = colors {
530            if group.is_none() {
531                let color_count = color_vec.len();
532                let facet_count = facet_categories.len();
533
534                if color_count != facet_count {
535                    panic!(
536                        "When using colors with facet (without group), colors.len() must equal number of facets. \
537                         Expected {} colors for {} facets, but got {} colors. \
538                         Each facet must be assigned exactly one color.",
539                        facet_count, facet_count, color_count
540                    );
541                }
542            } else if let Some(group_col) = group {
543                let groups = Self::get_unique_groups(data, group_col, sort_groups_by);
544                let color_count = color_vec.len();
545                let group_count = groups.len();
546
547                if color_count < group_count {
548                    panic!(
549                        "When using colors with group, colors.len() must be >= number of groups. \
550                         Need at least {} colors for {} groups, but got {} colors",
551                        group_count, group_count, color_count
552                    );
553                }
554            }
555        }
556
557        let global_group_indices: std::collections::HashMap<String, usize> =
558            if let Some(group_col) = group {
559                let global_groups = Self::get_unique_groups(data, group_col, sort_groups_by);
560                global_groups
561                    .into_iter()
562                    .enumerate()
563                    .map(|(idx, group_name)| (group_name, idx))
564                    .collect()
565            } else {
566                std::collections::HashMap::new()
567            };
568
569        let colors = if group.is_some() && colors.is_none() {
570            Some(DEFAULT_PLOTLY_COLORS.to_vec())
571        } else {
572            colors
573        };
574
575        let mode = mode
576            .map(|m| m.to_plotly())
577            .unwrap_or(plotly::common::Mode::Markers);
578
579        let mut all_traces = Vec::new();
580
581        if config.highlight_facet {
582            for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
583                let subplot = Self::get_polar_subplot_reference(facet_idx);
584
585                for other_facet_value in facet_categories.iter() {
586                    if other_facet_value != facet_value {
587                        let other_data =
588                            Self::filter_data_by_group(data, facet_column, other_facet_value);
589
590                        let grey_color = config.unhighlighted_color.unwrap_or(Rgb(200, 200, 200));
591                        let grey_marker = Self::create_marker(
592                            0,
593                            opacity,
594                            size,
595                            Some(grey_color),
596                            None,
597                            shape,
598                            None,
599                        );
600
601                        let grey_line = Self::create_line_with_color(
602                            0,
603                            width,
604                            Some(grey_color),
605                            None,
606                            line,
607                            None,
608                        );
609
610                        let trace = Self::build_scatter_polar_trace_with_subplot(
611                            &other_data,
612                            theta,
613                            r,
614                            None,
615                            mode.clone(),
616                            grey_marker,
617                            grey_line,
618                            fill,
619                            Some(&subplot),
620                            false,
621                            None,
622                        );
623
624                        all_traces.push(trace);
625                    }
626                }
627
628                let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
629
630                match group {
631                    Some(group_col) => {
632                        let groups =
633                            Self::get_unique_groups(&facet_data, group_col, sort_groups_by);
634
635                        for group_val in groups.iter() {
636                            let group_data =
637                                Self::filter_data_by_group(&facet_data, group_col, group_val);
638
639                            let global_idx =
640                                global_group_indices.get(group_val).copied().unwrap_or(0);
641
642                            let marker = Self::create_marker(
643                                global_idx,
644                                opacity,
645                                size,
646                                color,
647                                colors.clone(),
648                                shape,
649                                shapes.clone(),
650                            );
651
652                            let line_style = Self::create_line_with_color(
653                                global_idx,
654                                width,
655                                color,
656                                colors.clone(),
657                                line,
658                                lines.clone(),
659                            );
660
661                            let show_legend = facet_idx == 0;
662
663                            let trace = Self::build_scatter_polar_trace_with_subplot(
664                                &group_data,
665                                theta,
666                                r,
667                                Some(group_val),
668                                mode.clone(),
669                                marker,
670                                line_style,
671                                fill,
672                                Some(&subplot),
673                                show_legend,
674                                Some(group_val),
675                            );
676
677                            all_traces.push(trace);
678                        }
679                    }
680                    None => {
681                        let marker = Self::create_marker(
682                            facet_idx,
683                            opacity,
684                            size,
685                            color,
686                            colors.clone(),
687                            shape,
688                            shapes.clone(),
689                        );
690
691                        let line_style = Self::create_line_with_color(
692                            facet_idx,
693                            width,
694                            color,
695                            colors.clone(),
696                            line,
697                            lines.clone(),
698                        );
699
700                        let trace = Self::build_scatter_polar_trace_with_subplot(
701                            &facet_data,
702                            theta,
703                            r,
704                            None,
705                            mode.clone(),
706                            marker,
707                            line_style,
708                            fill,
709                            Some(&subplot),
710                            false,
711                            None,
712                        );
713
714                        all_traces.push(trace);
715                    }
716                }
717            }
718        } else {
719            for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
720                let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
721
722                let subplot = Self::get_polar_subplot_reference(facet_idx);
723
724                match group {
725                    Some(group_col) => {
726                        let groups =
727                            Self::get_unique_groups(&facet_data, group_col, sort_groups_by);
728
729                        for group_val in groups.iter() {
730                            let group_data =
731                                Self::filter_data_by_group(&facet_data, group_col, group_val);
732
733                            let global_idx =
734                                global_group_indices.get(group_val).copied().unwrap_or(0);
735
736                            let marker = Self::create_marker(
737                                global_idx,
738                                opacity,
739                                size,
740                                color,
741                                colors.clone(),
742                                shape,
743                                shapes.clone(),
744                            );
745
746                            let line_style = Self::create_line_with_color(
747                                global_idx,
748                                width,
749                                color,
750                                colors.clone(),
751                                line,
752                                lines.clone(),
753                            );
754
755                            let show_legend = facet_idx == 0;
756
757                            let trace = Self::build_scatter_polar_trace_with_subplot(
758                                &group_data,
759                                theta,
760                                r,
761                                Some(group_val),
762                                mode.clone(),
763                                marker,
764                                line_style,
765                                fill,
766                                Some(&subplot),
767                                show_legend,
768                                Some(group_val),
769                            );
770
771                            all_traces.push(trace);
772                        }
773                    }
774                    None => {
775                        let marker = Self::create_marker(
776                            facet_idx,
777                            opacity,
778                            size,
779                            color,
780                            colors.clone(),
781                            shape,
782                            shapes.clone(),
783                        );
784
785                        let line_style = Self::create_line_with_color(
786                            facet_idx,
787                            width,
788                            color,
789                            colors.clone(),
790                            line,
791                            lines.clone(),
792                        );
793
794                        let trace = Self::build_scatter_polar_trace_with_subplot(
795                            &facet_data,
796                            theta,
797                            r,
798                            None,
799                            mode.clone(),
800                            marker,
801                            line_style,
802                            fill,
803                            Some(&subplot),
804                            false,
805                            None,
806                        );
807
808                        all_traces.push(trace);
809                    }
810                }
811            }
812        }
813
814        all_traces
815    }
816
817    fn create_faceted_layout(
818        data: &DataFrame,
819        facet_column: &str,
820        config: &FacetConfig,
821        plot_title: Option<Text>,
822        legend_title: Option<Text>,
823        legend: Option<&Legend>,
824    ) -> (LayoutPlotly, FacetGrid) {
825        let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
826        let n_facets = facet_categories.len();
827
828        let (ncols, nrows) = Self::calculate_grid_dimensions(n_facets, config.cols, config.rows);
829
830        // Store grid dimensions for polar domain injection later
831        let x_gap = config.h_gap.unwrap_or(0.08);
832        let y_gap = config.v_gap.unwrap_or(0.12);
833
834        let grid = FacetGrid {
835            ncols,
836            nrows,
837            x_gap,
838            y_gap,
839        };
840
841        // Note: We'll inject polar subplot domain configurations manually via Plot trait
842        // since plotly.rs doesn't support LayoutPolar
843        let mut layout = LayoutPlotly::new();
844
845        if let Some(title) = plot_title {
846            layout = layout.title(title.to_plotly());
847        }
848
849        let annotations = Self::create_facet_annotations_polar(
850            &facet_categories,
851            ncols,
852            nrows,
853            config.title_style.as_ref(),
854            config.h_gap,
855            config.v_gap,
856        );
857        layout = layout.annotations(annotations);
858
859        layout = layout.legend(Legend::set_legend(legend_title, legend));
860
861        // Add margins to provide adequate space for polar subplots
862        // Top margin accounts for plot title and facet labels
863        // Side margins prevent clipping of circular polar plots
864        layout = layout.margin(Margin::new().top(140).bottom(80).left(80).right(80));
865
866        (layout, grid)
867    }
868
869    /// Calculates the geometry for a polar facet cell, including subplot domain bounds and title baseline.
870    ///
871    /// Returning both the domain and annotation placement keeps titles aligned with their subplot
872    /// while guaranteeing consistent padding above the polar chart.
873    fn calculate_polar_facet_cell(
874        subplot_index: usize,
875        ncols: usize,
876        nrows: usize,
877        x_gap: Option<f64>,
878        y_gap: Option<f64>,
879    ) -> PolarFacetCell {
880        let row = subplot_index / ncols;
881        let col = subplot_index % ncols;
882
883        let x_gap_val = x_gap.unwrap_or(0.08);
884        let y_gap_val = y_gap.unwrap_or(0.12);
885
886        let cell_width = (1.0 - x_gap_val * (ncols - 1) as f64) / ncols as f64;
887        let cell_height = (1.0 - y_gap_val * (nrows - 1) as f64) / nrows as f64;
888
889        let title_height = cell_height * POLAR_FACET_TITLE_HEIGHT_RATIO;
890        let polar_padding = cell_height * POLAR_FACET_TOP_INSET_RATIO;
891
892        let cell_x_start = col as f64 * (cell_width + x_gap_val);
893        let cell_y_top = 1.0 - row as f64 * (cell_height + y_gap_val);
894        let cell_y_bottom = cell_y_top - cell_height;
895
896        let domain_y_top = cell_y_top - title_height - polar_padding;
897        let domain_y_bottom = cell_y_bottom;
898
899        let domain_x = [cell_x_start, cell_x_start + cell_width];
900        let domain_y = [domain_y_bottom, domain_y_top];
901
902        let annotation_x = cell_x_start + cell_width / 2.0;
903        let annotation_y = cell_y_top - title_height / 2.0;
904
905        PolarFacetCell {
906            annotation_x,
907            annotation_y,
908            domain_x,
909            domain_y,
910        }
911    }
912
913    fn create_facet_annotations_polar(
914        categories: &[String],
915        ncols: usize,
916        nrows: usize,
917        title_style: Option<&Text>,
918        x_gap: Option<f64>,
919        y_gap: Option<f64>,
920    ) -> Vec<plotly::layout::Annotation> {
921        use plotly::common::Anchor;
922        use plotly::layout::Annotation;
923
924        categories
925            .iter()
926            .enumerate()
927            .map(|(i, cat)| {
928                let cell = Self::calculate_polar_facet_cell(i, ncols, nrows, x_gap, y_gap);
929
930                let mut ann = Annotation::new()
931                    .text(cat.as_str())
932                    .x_ref("paper")
933                    .y_ref("paper")
934                    .x_anchor(Anchor::Center)
935                    .y_anchor(Anchor::Bottom)
936                    .x(cell.annotation_x)
937                    .y(cell.annotation_y)
938                    .show_arrow(false);
939
940                if let Some(style) = title_style {
941                    ann = ann.font(style.to_font());
942                }
943
944                ann
945            })
946            .collect()
947    }
948}
949
/// Geometry of one polar facet cell, in paper coordinates (`0.0..=1.0`).
///
/// It carries both where the facet's title annotation goes and the `domain` bounds of the polar
/// subplot drawn in that cell. Both are derived from one layout computation, so the title stays
/// aligned with its chart.
#[derive(Debug, Clone, Copy, PartialEq)]
struct PolarFacetCell {
    /// Horizontal centre of the cell, where the facet title is anchored.
    annotation_x: f64,
    /// Vertical anchor for the facet title, in the middle of the title band at the top of the cell.
    annotation_y: f64,
    /// `[start, end]` horizontal extent of the polar subplot domain.
    domain_x: [f64; 2],
    /// `[bottom, top]` vertical extent of the polar subplot domain.
    domain_y: [f64; 2],
}
957
958impl ScatterPolar {
959    /// Injects polar subplot domain configurations into the layout JSON
960    /// This is a workaround for plotly.rs not supporting LayoutPolar configuration
961    fn inject_polar_domains_static(
962        layout_json: &mut serde_json::Value,
963        ncols: usize,
964        nrows: usize,
965        x_gap: f64,
966        y_gap: f64,
967    ) {
968        // Configure all 8 possible polar subplots (polar, polar2, ..., polar8)
969        // Traces reference these via their subplot parameter
970
971        let total_cells = (ncols * nrows).clamp(1, 8);
972
973        for i in 0..total_cells {
974            let polar_key = if i == 0 {
975                "polar".to_string()
976            } else {
977                format!("polar{}", i + 1)
978            };
979
980            let cell = Self::calculate_polar_facet_cell(i, ncols, nrows, Some(x_gap), Some(y_gap));
981
982            let compression_factor = 0.9;
983            let domain_height = cell.domain_y[1] - cell.domain_y[0];
984            let height_reduction = domain_height * (1.0 - compression_factor);
985            let compressed_domain_y = [
986                cell.domain_y[0] + height_reduction / 2.0,
987                cell.domain_y[1] - height_reduction / 2.0,
988            ];
989
990            let polar_config = serde_json::json!({
991                "domain": {
992                    "x": cell.domain_x,
993                    "y": compressed_domain_y
994                }
995            });
996
997            layout_json[polar_key] = polar_config;
998        }
999    }
1000}
1001
1002impl Layout for ScatterPolar {}
1003impl Marker for ScatterPolar {}
1004impl Polar for ScatterPolar {}
1005
1006impl PlotHelper for ScatterPolar {
1007    fn get_layout(&self) -> &LayoutPlotly {
1008        &self.layout
1009    }
1010
1011    fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>> {
1012        &self.traces
1013    }
1014
1015    fn get_layout_override(&self) -> Option<&serde_json::Value> {
1016        self.layout_json.as_ref()
1017    }
1018}