Skip to main content

plotlars_core/plots/
scatterplot.rs

1use bon::bon;
2
3use polars::frame::DataFrame;
4
5use crate::{
6    components::{Axis, FacetConfig, Legend, Rgb, Shape, Text, DEFAULT_PLOTLY_COLORS},
7    ir::data::ColumnData,
8    ir::layout::LayoutIR,
9    ir::marker::MarkerIR,
10    ir::trace::{ScatterPlotIR, TraceIR},
11};
12
13/// A structure representing a scatter plot.
14///
15/// The `ScatterPlot` struct facilitates the creation and customization of scatter plots with various options
16/// for data selection, grouping, layout configuration, and aesthetic adjustments. It supports grouping of data,
17/// customization of marker shapes, colors, sizes, opacity settings, and comprehensive layout customization
18/// including titles, axes, and legends.
19///
20/// # Backend Support
21///
22/// | Backend | Supported |
23/// |---------|-----------|
24/// | Plotly  | Yes       |
25/// | Plotters| Yes       |
26///
27/// # Arguments
28///
29/// * `data` - A reference to the `DataFrame` containing the data to be plotted.
30/// * `x` - A string slice specifying the column name to be used for the x-axis (independent variable).
31/// * `y` - A string slice specifying the column name to be used for the y-axis (dependent variable).
32/// * `group` - An optional string slice specifying the column name to be used for grouping data points.
33/// * `sort_groups_by` - Optional comparator `fn(&str, &str) -> std::cmp::Ordering` to control group ordering. Groups are sorted lexically by default.
34/// * `facet` - An optional string slice specifying the column name to be used for faceting (creating multiple subplots).
35/// * `facet_config` - An optional reference to a `FacetConfig` struct for customizing facet behavior (grid dimensions, scales, gaps, etc.).
36/// * `opacity` - An optional `f64` value specifying the opacity of the plot markers (range: 0.0 to 1.0).
37/// * `size` - An optional `usize` specifying the size of the markers.
38/// * `color` - An optional `Rgb` value specifying the color of the markers. This is used when `group` is not specified.
39/// * `colors` - An optional vector of `Rgb` values specifying the colors for the markers. This is used when `group` is specified to differentiate between groups.
40/// * `shape` - An optional `Shape` specifying the shape of the markers. This is used when `group` is not specified.
41/// * `shapes` - An optional vector of `Shape` values specifying multiple shapes for the markers when plotting multiple groups.
42/// * `plot_title` - An optional `Text` struct specifying the title of the plot.
43/// * `x_title` - An optional `Text` struct specifying the title of the x-axis.
44/// * `y_title` - An optional `Text` struct specifying the title of the y-axis.
45/// * `legend_title` - An optional `Text` struct specifying the title of the legend.
46/// * `x_axis` - An optional reference to an `Axis` struct for customizing the x-axis.
47/// * `y_axis` - An optional reference to an `Axis` struct for customizing the y-axis.
48/// * `legend` - An optional reference to a `Legend` struct for customizing the legend of the plot (e.g., positioning, font, etc.).
49///
50/// # Example
51///
52/// ```rust
53/// use plotlars::{Axis, Legend, Plot, Rgb, ScatterPlot, Shape, Text, TickDirection};
54/// use polars::prelude::*;
55///
56/// let dataset = LazyCsvReader::new(PlRefPath::new("data/penguins.csv"))
57///     .finish()
58///     .unwrap()
59///     .select([
60///         col("species"),
61///         col("sex").alias("gender"),
62///         col("flipper_length_mm").cast(DataType::Int16),
63///         col("body_mass_g").cast(DataType::Int16),
64///     ])
65///     .collect()
66///     .unwrap();
67///
68/// let axis = Axis::new()
69///     .show_line(true)
70///     .tick_direction(TickDirection::OutSide)
71///     .value_thousands(true);
72///
73/// ScatterPlot::builder()
74///     .data(&dataset)
75///     .x("body_mass_g")
76///     .y("flipper_length_mm")
77///     .group("species")
78///     .sort_groups_by(|a, b| {
79///         if a.len() == b.len() {
80///             a.cmp(b)
81///         } else {
82///             a.len().cmp(&b.len())
83///         }
84///     })
85///     .opacity(0.5)
86///     .size(12)
87///     .colors(vec![
88///         Rgb(178, 34, 34),
89///         Rgb(65, 105, 225),
90///         Rgb(255, 140, 0),
91///     ])
92///     .shapes(vec![
93///         Shape::Circle,
94///         Shape::Square,
95///         Shape::Diamond,
96///     ])
97///     .plot_title(
98///         Text::from("Scatter Plot")
99///             .font("Arial")
100///             .size(20)
101///             .x(0.065)
102///     )
103///     .x_title("body mass (g)")
104///     .y_title("flipper length (mm)")
105///     .legend_title("species")
106///     .x_axis(
107///         &axis.clone()
108///             .value_range(2500.0, 6500.0)
109///     )
110///     .y_axis(
111///         &axis.clone()
112///             .value_range(170.0, 240.0)
113///     )
114///     .legend(
115///         &Legend::new()
116///             .x(0.85)
117///             .y(0.15)
118///     )
119///     .build()
120///     .plot();
121/// ```
122///
123/// ![Example](https://imgur.com/9jfO8RU.png)
124#[derive(Clone)]
125#[allow(dead_code)]
126pub struct ScatterPlot {
127    traces: Vec<TraceIR>,
128    layout: LayoutIR,
129}
130
131#[bon]
132impl ScatterPlot {
133    #[builder(on(String, into), on(Text, into))]
134    pub fn new(
135        data: &DataFrame,
136        x: &str,
137        y: &str,
138        group: Option<&str>,
139        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
140        facet: Option<&str>,
141        facet_config: Option<&FacetConfig>,
142        opacity: Option<f64>,
143        size: Option<usize>,
144        color: Option<Rgb>,
145        colors: Option<Vec<Rgb>>,
146        shape: Option<Shape>,
147        shapes: Option<Vec<Shape>>,
148        plot_title: Option<Text>,
149        x_title: Option<Text>,
150        y_title: Option<Text>,
151        legend_title: Option<Text>,
152        x_axis: Option<&Axis>,
153        y_axis: Option<&Axis>,
154        legend: Option<&Legend>,
155    ) -> Self {
156        let grid = facet.map(|facet_column| {
157            let config = facet_config.cloned().unwrap_or_default();
158            let facet_categories =
159                crate::data::get_unique_groups(data, facet_column, config.sorter);
160            let n_facets = facet_categories.len();
161            let (ncols, nrows) =
162                crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
163            crate::ir::facet::GridSpec {
164                kind: crate::ir::facet::FacetKind::Axis,
165                rows: nrows,
166                cols: ncols,
167                h_gap: config.h_gap,
168                v_gap: config.v_gap,
169                scales: config.scales.clone(),
170                n_facets,
171                facet_categories,
172                title_style: config.title_style.clone(),
173                x_title: x_title.clone(),
174                y_title: y_title.clone(),
175                x_axis: x_axis.cloned(),
176                y_axis: y_axis.cloned(),
177                legend_title: legend_title.clone(),
178                legend: legend.cloned(),
179            }
180        });
181
182        let layout = LayoutIR {
183            title: plot_title.clone(),
184            x_title: if grid.is_some() {
185                None
186            } else {
187                x_title.clone()
188            },
189            y_title: if grid.is_some() {
190                None
191            } else {
192                y_title.clone()
193            },
194            y2_title: None,
195            z_title: None,
196            legend_title: if grid.is_some() {
197                None
198            } else {
199                legend_title.clone()
200            },
201            legend: if grid.is_some() {
202                None
203            } else {
204                legend.cloned()
205            },
206            dimensions: None,
207            bar_mode: None,
208            box_mode: None,
209            box_gap: None,
210            margin_bottom: None,
211            axes_2d: if grid.is_some() {
212                None
213            } else {
214                Some(crate::ir::layout::Axes2dIR {
215                    x_axis: x_axis.cloned(),
216                    y_axis: y_axis.cloned(),
217                    y2_axis: None,
218                })
219            },
220            scene_3d: None,
221            polar: None,
222            mapbox: None,
223            grid,
224            annotations: vec![],
225        };
226
227        let traces = match facet {
228            Some(facet_column) => {
229                let config = facet_config.cloned().unwrap_or_default();
230                Self::create_ir_traces_faceted(
231                    data,
232                    x,
233                    y,
234                    group,
235                    sort_groups_by,
236                    facet_column,
237                    &config,
238                    opacity,
239                    size,
240                    color,
241                    colors.clone(),
242                    shape,
243                    shapes.clone(),
244                )
245            }
246            None => Self::create_ir_traces(
247                data,
248                x,
249                y,
250                group,
251                sort_groups_by,
252                opacity,
253                size,
254                color,
255                colors,
256                shape,
257                shapes,
258            ),
259        };
260
261        Self { traces, layout }
262    }
263}
264
265#[bon]
266impl ScatterPlot {
267    #[builder(
268        start_fn = try_builder,
269        finish_fn = try_build,
270        builder_type = ScatterPlotTryBuilder,
271        on(String, into),
272        on(Text, into),
273    )]
274    pub fn try_new(
275        data: &DataFrame,
276        x: &str,
277        y: &str,
278        group: Option<&str>,
279        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
280        facet: Option<&str>,
281        facet_config: Option<&FacetConfig>,
282        opacity: Option<f64>,
283        size: Option<usize>,
284        color: Option<Rgb>,
285        colors: Option<Vec<Rgb>>,
286        shape: Option<Shape>,
287        shapes: Option<Vec<Shape>>,
288        plot_title: Option<Text>,
289        x_title: Option<Text>,
290        y_title: Option<Text>,
291        legend_title: Option<Text>,
292        x_axis: Option<&Axis>,
293        y_axis: Option<&Axis>,
294        legend: Option<&Legend>,
295    ) -> Result<Self, crate::io::PlotlarsError> {
296        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
297            Self::__orig_new(
298                data,
299                x,
300                y,
301                group,
302                sort_groups_by,
303                facet,
304                facet_config,
305                opacity,
306                size,
307                color,
308                colors,
309                shape,
310                shapes,
311                plot_title,
312                x_title,
313                y_title,
314                legend_title,
315                x_axis,
316                y_axis,
317                legend,
318            )
319        }))
320        .map_err(|panic| {
321            let msg = panic
322                .downcast_ref::<String>()
323                .cloned()
324                .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
325                .unwrap_or_else(|| "unknown error".to_string());
326            crate::io::PlotlarsError::PlotBuild { message: msg }
327        })
328    }
329}
330
331impl ScatterPlot {
332    #[allow(clippy::too_many_arguments)]
333    fn create_ir_traces(
334        data: &DataFrame,
335        x: &str,
336        y: &str,
337        group: Option<&str>,
338        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
339        opacity: Option<f64>,
340        size: Option<usize>,
341        color: Option<Rgb>,
342        colors: Option<Vec<Rgb>>,
343        shape: Option<Shape>,
344        shapes: Option<Vec<Shape>>,
345    ) -> Vec<TraceIR> {
346        let mut traces = Vec::new();
347
348        match group {
349            Some(group_col) => {
350                let groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
351
352                for (i, group_name) in groups.iter().enumerate() {
353                    let subset = crate::data::filter_data_by_group(data, group_col, group_name);
354
355                    let marker_ir = MarkerIR {
356                        opacity,
357                        size,
358                        color: Self::resolve_color(i, color, colors.clone()),
359                        shape: Self::resolve_shape(i, shape, shapes.clone()),
360                    };
361
362                    traces.push(TraceIR::ScatterPlot(ScatterPlotIR {
363                        x: ColumnData::Numeric(crate::data::get_numeric_column(&subset, x)),
364                        y: ColumnData::Numeric(crate::data::get_numeric_column(&subset, y)),
365                        name: Some(group_name.to_string()),
366                        marker: Some(marker_ir),
367                        fill: None,
368                        show_legend: None,
369                        legend_group: None,
370                        subplot_ref: None,
371                    }));
372                }
373            }
374            None => {
375                let marker_ir = MarkerIR {
376                    opacity,
377                    size,
378                    color: Self::resolve_color(0, color, colors),
379                    shape: Self::resolve_shape(0, shape, shapes),
380                };
381
382                traces.push(TraceIR::ScatterPlot(ScatterPlotIR {
383                    x: ColumnData::Numeric(crate::data::get_numeric_column(data, x)),
384                    y: ColumnData::Numeric(crate::data::get_numeric_column(data, y)),
385                    name: None,
386                    marker: Some(marker_ir),
387                    fill: None,
388                    show_legend: None,
389                    legend_group: None,
390                    subplot_ref: None,
391                }));
392            }
393        }
394
395        traces
396    }
397
398    #[allow(clippy::too_many_arguments)]
399    fn create_ir_traces_faceted(
400        data: &DataFrame,
401        x: &str,
402        y: &str,
403        group: Option<&str>,
404        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
405        facet_column: &str,
406        config: &FacetConfig,
407        opacity: Option<f64>,
408        size: Option<usize>,
409        color: Option<Rgb>,
410        colors: Option<Vec<Rgb>>,
411        shape: Option<Shape>,
412        shapes: Option<Vec<Shape>>,
413    ) -> Vec<TraceIR> {
414        const MAX_FACETS: usize = 8;
415
416        let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
417
418        if facet_categories.len() > MAX_FACETS {
419            panic!(
420                "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
421                facet_column,
422                facet_categories.len(),
423                MAX_FACETS
424            );
425        }
426
427        if let Some(ref color_vec) = colors {
428            if group.is_none() {
429                let color_count = color_vec.len();
430                let facet_count = facet_categories.len();
431                if color_count != facet_count {
432                    panic!(
433                        "When using colors with facet (without group), colors.len() must equal number of facets. \
434                         Expected {} colors for {} facets, but got {} colors. \
435                         Each facet must be assigned exactly one color.",
436                        facet_count, facet_count, color_count
437                    );
438                }
439            } else if let Some(group_col) = group {
440                let groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
441                let color_count = color_vec.len();
442                let group_count = groups.len();
443                if color_count < group_count {
444                    panic!(
445                        "When using colors with group, colors.len() must be >= number of groups. \
446                         Need at least {} colors for {} groups, but got {} colors",
447                        group_count, group_count, color_count
448                    );
449                }
450            }
451        }
452
453        let global_group_indices: std::collections::HashMap<String, usize> =
454            if let Some(group_col) = group {
455                let global_groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
456                global_groups
457                    .into_iter()
458                    .enumerate()
459                    .map(|(idx, group_name)| (group_name, idx))
460                    .collect()
461            } else {
462                std::collections::HashMap::new()
463            };
464
465        let colors = if group.is_some() && colors.is_none() {
466            Some(DEFAULT_PLOTLY_COLORS.to_vec())
467        } else {
468            colors
469        };
470
471        let mut traces = Vec::new();
472
473        if config.highlight_facet {
474            for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
475                let subplot_ref = format!(
476                    "{}{}",
477                    crate::faceting::get_axis_reference(facet_idx, "x"),
478                    crate::faceting::get_axis_reference(facet_idx, "y")
479                );
480
481                for other_facet_value in facet_categories.iter() {
482                    if other_facet_value != facet_value {
483                        let other_data = crate::data::filter_data_by_group(
484                            data,
485                            facet_column,
486                            other_facet_value,
487                        );
488
489                        let grey_color = config.unhighlighted_color.unwrap_or(Rgb(200, 200, 200));
490                        let marker_ir = MarkerIR {
491                            opacity,
492                            size,
493                            color: Some(grey_color),
494                            shape: Self::resolve_shape(0, shape, None),
495                        };
496
497                        traces.push(TraceIR::ScatterPlot(ScatterPlotIR {
498                            x: ColumnData::Numeric(crate::data::get_numeric_column(&other_data, x)),
499                            y: ColumnData::Numeric(crate::data::get_numeric_column(&other_data, y)),
500                            name: None,
501                            marker: Some(marker_ir),
502                            fill: None,
503                            show_legend: Some(false),
504                            legend_group: None,
505                            subplot_ref: Some(subplot_ref.clone()),
506                        }));
507                    }
508                }
509
510                let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
511
512                match group {
513                    Some(group_col) => {
514                        let groups =
515                            crate::data::get_unique_groups(&facet_data, group_col, sort_groups_by);
516
517                        for group_val in groups.iter() {
518                            let group_data = crate::data::filter_data_by_group(
519                                &facet_data,
520                                group_col,
521                                group_val,
522                            );
523
524                            let global_idx =
525                                global_group_indices.get(group_val).copied().unwrap_or(0);
526
527                            let marker_ir = MarkerIR {
528                                opacity,
529                                size,
530                                color: Self::resolve_color(global_idx, color, colors.clone()),
531                                shape: Self::resolve_shape(global_idx, shape, shapes.clone()),
532                            };
533
534                            traces.push(TraceIR::ScatterPlot(ScatterPlotIR {
535                                x: ColumnData::Numeric(crate::data::get_numeric_column(
536                                    &group_data,
537                                    x,
538                                )),
539                                y: ColumnData::Numeric(crate::data::get_numeric_column(
540                                    &group_data,
541                                    y,
542                                )),
543                                name: Some(group_val.to_string()),
544                                marker: Some(marker_ir),
545                                fill: None,
546                                show_legend: Some(facet_idx == 0),
547                                legend_group: Some(group_val.to_string()),
548                                subplot_ref: Some(subplot_ref.clone()),
549                            }));
550                        }
551                    }
552                    None => {
553                        let marker_ir = MarkerIR {
554                            opacity,
555                            size,
556                            color: Self::resolve_color(facet_idx, color, colors.clone()),
557                            shape: Self::resolve_shape(facet_idx, shape, shapes.clone()),
558                        };
559
560                        traces.push(TraceIR::ScatterPlot(ScatterPlotIR {
561                            x: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, x)),
562                            y: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, y)),
563                            name: None,
564                            marker: Some(marker_ir),
565                            fill: None,
566                            show_legend: Some(false),
567                            legend_group: None,
568                            subplot_ref: Some(subplot_ref.clone()),
569                        }));
570                    }
571                }
572            }
573        } else {
574            for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
575                let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
576
577                let subplot_ref = format!(
578                    "{}{}",
579                    crate::faceting::get_axis_reference(facet_idx, "x"),
580                    crate::faceting::get_axis_reference(facet_idx, "y")
581                );
582
583                match group {
584                    Some(group_col) => {
585                        let groups =
586                            crate::data::get_unique_groups(&facet_data, group_col, sort_groups_by);
587
588                        for group_val in groups.iter() {
589                            let group_data = crate::data::filter_data_by_group(
590                                &facet_data,
591                                group_col,
592                                group_val,
593                            );
594
595                            let global_idx =
596                                global_group_indices.get(group_val).copied().unwrap_or(0);
597
598                            let marker_ir = MarkerIR {
599                                opacity,
600                                size,
601                                color: Self::resolve_color(global_idx, color, colors.clone()),
602                                shape: Self::resolve_shape(global_idx, shape, shapes.clone()),
603                            };
604
605                            traces.push(TraceIR::ScatterPlot(ScatterPlotIR {
606                                x: ColumnData::Numeric(crate::data::get_numeric_column(
607                                    &group_data,
608                                    x,
609                                )),
610                                y: ColumnData::Numeric(crate::data::get_numeric_column(
611                                    &group_data,
612                                    y,
613                                )),
614                                name: Some(group_val.to_string()),
615                                marker: Some(marker_ir),
616                                fill: None,
617                                show_legend: Some(facet_idx == 0),
618                                legend_group: Some(group_val.to_string()),
619                                subplot_ref: Some(subplot_ref.clone()),
620                            }));
621                        }
622                    }
623                    None => {
624                        let marker_ir = MarkerIR {
625                            opacity,
626                            size,
627                            color: Self::resolve_color(facet_idx, color, colors.clone()),
628                            shape: Self::resolve_shape(facet_idx, shape, shapes.clone()),
629                        };
630
631                        traces.push(TraceIR::ScatterPlot(ScatterPlotIR {
632                            x: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, x)),
633                            y: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, y)),
634                            name: None,
635                            marker: Some(marker_ir),
636                            fill: None,
637                            show_legend: Some(false),
638                            legend_group: None,
639                            subplot_ref: Some(subplot_ref.clone()),
640                        }));
641                    }
642                }
643            }
644        }
645
646        traces
647    }
648
649    fn resolve_color(index: usize, color: Option<Rgb>, colors: Option<Vec<Rgb>>) -> Option<Rgb> {
650        if let Some(c) = color {
651            return Some(c);
652        }
653        if let Some(ref cs) = colors {
654            return cs.get(index).copied();
655        }
656        None
657    }
658
659    fn resolve_shape(
660        index: usize,
661        shape: Option<Shape>,
662        shapes: Option<Vec<Shape>>,
663    ) -> Option<Shape> {
664        if let Some(s) = shape {
665            return Some(s);
666        }
667        if let Some(ref ss) = shapes {
668            return ss.get(index).cloned();
669        }
670        None
671    }
672}
673
674impl crate::Plot for ScatterPlot {
675    fn ir_traces(&self) -> &[TraceIR] {
676        &self.traces
677    }
678
679    fn ir_layout(&self) -> &LayoutIR {
680        &self.layout
681    }
682}
683
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    /// Unwraps `actual` and compares its red/green/blue channels.
    fn assert_rgb(actual: Option<Rgb>, r: u8, g: u8, b: u8) {
        let rgb = actual.expect("expected Some(Rgb)");
        assert_eq!((rgb.0, rgb.1, rgb.2), (r, g, b));
    }

    /// Two facets (`f1`, `f2`), each containing both groups (`g1`, `g2`).
    fn two_facets_two_groups() -> DataFrame {
        df![
            "x" => [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            "y" => [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0],
            "f" => ["f1", "f1", "f1", "f1", "f2", "f2", "f2", "f2"],
            "g" => ["g1", "g2", "g1", "g2", "g1", "g2", "g1", "g2"]
        ]
        .unwrap()
    }

    #[test]
    fn test_resolve_color_singular_priority() {
        let picked =
            ScatterPlot::resolve_color(0, Some(Rgb(255, 0, 0)), Some(vec![Rgb(0, 0, 255)]));
        assert_rgb(picked, 255, 0, 0);
    }

    #[test]
    fn test_resolve_color_from_vec() {
        let palette = vec![Rgb(1, 0, 0), Rgb(0, 1, 0), Rgb(0, 0, 1)];
        let picked = ScatterPlot::resolve_color(1, None, Some(palette));
        assert_rgb(picked, 0, 1, 0);
    }

    #[test]
    fn test_resolve_color_out_of_bounds() {
        let picked = ScatterPlot::resolve_color(5, None, Some(vec![Rgb(1, 0, 0)]));
        assert!(picked.is_none());
    }

    #[test]
    fn test_resolve_color_both_none() {
        assert!(ScatterPlot::resolve_color(0, None, None).is_none());
    }

    #[test]
    fn test_resolve_shape_singular_priority() {
        let picked = ScatterPlot::resolve_shape(0, Some(Shape::Circle), Some(vec![Shape::Square]));
        assert!(matches!(picked, Some(Shape::Circle)));
    }

    #[test]
    fn test_resolve_shape_from_vec() {
        let palette = vec![Shape::Circle, Shape::Diamond, Shape::Square];
        let picked = ScatterPlot::resolve_shape(1, None, Some(palette));
        assert!(matches!(picked, Some(Shape::Diamond)));
    }

    #[test]
    fn test_resolve_shape_out_of_bounds() {
        let picked = ScatterPlot::resolve_shape(5, None, Some(vec![Shape::Circle]));
        assert!(picked.is_none());
    }

    #[test]
    fn test_resolve_shape_both_none() {
        assert!(ScatterPlot::resolve_shape(0, None, None).is_none());
    }

    #[test]
    fn test_no_group_one_trace() {
        let df = df!["x" => [1.0, 2.0, 3.0], "y" => [4.0, 5.0, 6.0]].unwrap();
        let plot = ScatterPlot::builder().data(&df).x("x").y("y").build();
        assert_eq!(plot.ir_traces().len(), 1);
    }

    #[test]
    fn test_with_group_multiple_traces() {
        let df = df![
            "x" => [1.0, 2.0, 3.0, 4.0],
            "y" => [4.0, 5.0, 6.0, 7.0],
            "g" => ["a", "b", "a", "b"]
        ]
        .unwrap();
        let plot = ScatterPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .group("g")
            .build();
        assert_eq!(plot.ir_traces().len(), 2);
    }

    #[test]
    fn test_faceted_trace_count() {
        let df = df![
            "x" => [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            "y" => [10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
            "f" => ["a", "b", "c", "a", "b", "c"]
        ]
        .unwrap();
        let plot = ScatterPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .facet("f")
            .build();
        assert_eq!(plot.ir_traces().len(), 3);
    }

    #[test]
    fn test_faceted_with_group_trace_count() {
        let df = two_facets_two_groups();
        let plot = ScatterPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .facet("f")
            .group("g")
            .build();
        // 2 facets * 2 groups = 4 traces
        assert_eq!(plot.ir_traces().len(), 4);
    }

    #[test]
    fn test_faceted_show_legend_first_only() {
        let df = two_facets_two_groups();
        let plot = ScatterPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .facet("f")
            .group("g")
            .build();

        for trace in plot.ir_traces() {
            let ir = match trace {
                TraceIR::ScatterPlot(ir) => ir,
                _ => panic!("expected ScatterPlot trace"),
            };
            // Both groups appear in the first facet ("xy"), so only its traces list legends.
            let in_first_facet = ir.subplot_ref.as_deref().unwrap() == "xy";
            assert_eq!(ir.show_legend, Some(in_first_facet));
        }
    }

    #[test]
    fn test_faceted_subplot_ref() {
        let df = df![
            "x" => [1.0, 2.0, 3.0, 4.0],
            "y" => [10.0, 20.0, 30.0, 40.0],
            "f" => ["a", "b", "a", "b"]
        ]
        .unwrap();
        let plot = ScatterPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .facet("f")
            .build();

        let refs: Vec<&str> = plot
            .ir_traces()
            .iter()
            .map(|trace| match trace {
                TraceIR::ScatterPlot(ir) => ir.subplot_ref.as_deref().unwrap(),
                _ => panic!("expected ScatterPlot trace"),
            })
            .collect();
        assert_eq!(refs[0], "xy");
        assert_eq!(refs[1], "x2y2");
    }

    #[test]
    #[should_panic(expected = "maximum")]
    fn test_max_facets_panics() {
        // Nine distinct facet values: one more than the supported maximum.
        let facet_values: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i"];
        let n = facet_values.len();
        let x_vals: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let y_vals: Vec<f64> = x_vals.iter().map(|v| v * 10.0).collect();
        let df = DataFrame::new(
            n,
            vec![
                Column::new("x".into(), &x_vals),
                Column::new("y".into(), &y_vals),
                Column::new("f".into(), &facet_values),
            ],
        )
        .unwrap();
        ScatterPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .facet("f")
            .build();
    }

    #[test]
    #[should_panic(expected = "colors.len() must equal number of facets")]
    fn test_faceted_colors_mismatch_panics() {
        let df = df![
            "x" => [1.0, 2.0, 3.0],
            "y" => [10.0, 20.0, 30.0],
            "f" => ["a", "b", "c"]
        ]
        .unwrap();
        ScatterPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .facet("f")
            .colors(vec![Rgb(255, 0, 0), Rgb(0, 255, 0)])
            .build();
    }
}
927}