Skip to main content

plotlars_core/plots/
scatter3dplot.rs

1use bon::bon;
2
3use polars::frame::DataFrame;
4
5use crate::{
6    components::{FacetConfig, Legend, Rgb, Shape, Text, DEFAULT_PLOTLY_COLORS},
7    ir::data::ColumnData,
8    ir::layout::LayoutIR,
9    ir::marker::MarkerIR,
10    ir::trace::{Scatter3dPlotIR, TraceIR},
11};
12
/// A structure representing a 3D scatter plot.
///
/// The `Scatter3dPlot` struct is designed to create and customize 3D scatter plots with options for data selection,
/// grouping, faceting, and aesthetic adjustments. It supports visual differentiation in data groups
/// through varied marker shapes, colors, sizes, and opacity levels, plus a plot title and legend customization.
///
/// # Backend Support
///
/// | Backend | Supported |
/// |---------|-----------|
/// | Plotly  | Yes       |
/// | Plotters| --        |
///
/// # Arguments
///
/// * `data` - A reference to the `DataFrame` containing the data to be plotted.
/// * `x` - A string slice specifying the column name to be used for the x-axis (independent variable).
/// * `y` - A string slice specifying the column name to be used for the y-axis (dependent variable).
/// * `z` - A string slice specifying the column name to be used for the z-axis, adding a third dimension to the scatter plot.
/// * `group` - An optional string slice specifying the column name used for grouping data points by category.
/// * `sort_groups_by` - Optional comparator `fn(&str, &str) -> std::cmp::Ordering` to control group ordering. Groups are sorted lexically by default.
/// * `facet` - An optional string slice specifying the column name to be used for faceting (creating multiple subplots).
/// * `facet_config` - An optional reference to a `FacetConfig` struct for customizing facet behavior (grid dimensions, scales, gaps, etc.).
/// * `opacity` - An optional `f64` value specifying the opacity of the plot markers (range: 0.0 to 1.0).
/// * `size` - An optional `usize` specifying the size of the markers.
/// * `color` - An optional `Rgb` value for marker color when `group` is not specified.
/// * `colors` - An optional vector of `Rgb` values specifying colors for markers when `group` is specified, enhancing group distinction.
/// * `shape` - An optional `Shape` specifying the shape of markers when `group` is not specified.
/// * `shapes` - An optional vector of `Shape` values defining multiple marker shapes for different groups.
/// * `plot_title` - An optional `Text` struct specifying the plot title.
/// * `legend` - An optional reference to a `Legend` struct for legend customization, including position and font settings.
///
/// Axis titles, a legend title, and per-axis settings are not accepted by this builder; the
/// corresponding layout fields are always left unset.
///
/// # Panics
///
/// When `facet` is set, `build()` panics if:
///
/// * the facet column has more than 8 unique values;
/// * `colors` is given without `group` and its length differs from the number of facets;
/// * `colors` is given with `group` and has fewer entries than there are groups.
///
/// `Scatter3dPlot::try_builder()` / `try_build()` catches such panics and returns them as
/// `PlotlarsError::PlotBuild` instead.
///
/// # Example
///
/// ```rust
/// use plotlars::{Legend, Plot, Rgb, Scatter3dPlot, Shape};
/// use polars::prelude::*;
///
/// let dataset = LazyCsvReader::new(PlRefPath::new("data/penguins.csv"))
///     .finish()
///     .unwrap()
///     .select([
///         col("species"),
///         col("sex").alias("gender"),
///         col("bill_length_mm").cast(DataType::Float32),
///         col("flipper_length_mm").cast(DataType::Int16),
///         col("body_mass_g").cast(DataType::Int16),
///     ])
///     .collect()
///     .unwrap();
///
/// Scatter3dPlot::builder()
///     .data(&dataset)
///     .x("body_mass_g")
///     .y("flipper_length_mm")
///     .z("bill_length_mm")
///     .group("species")
///     .opacity(0.25)
///     .size(8)
///     .colors(vec![
///         Rgb(178, 34, 34),
///         Rgb(65, 105, 225),
///         Rgb(255, 140, 0),
///     ])
///     .shapes(vec![
///         Shape::Circle,
///         Shape::Square,
///         Shape::Diamond,
///     ])
///     .plot_title("Scatter Plot")
///     .legend(
///         &Legend::new()
///             .x(0.6)
///     )
///     .build()
///     .plot();
/// ```
///
/// ![Example](https://imgur.com/WYTQxHA.png)
#[derive(Clone)]
#[allow(dead_code)]
pub struct Scatter3dPlot {
    // Trace IR produced by the builder; exposed through `Plot::ir_traces`.
    traces: Vec<TraceIR>,
    // Layout IR produced by the builder; exposed through `Plot::ir_layout`.
    layout: LayoutIR,
}
106
107#[bon]
108impl Scatter3dPlot {
109    #[builder(on(String, into), on(Text, into))]
110    pub fn new(
111        data: &DataFrame,
112        x: &str,
113        y: &str,
114        z: &str,
115        group: Option<&str>,
116        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
117        facet: Option<&str>,
118        facet_config: Option<&FacetConfig>,
119        opacity: Option<f64>,
120        size: Option<usize>,
121        color: Option<Rgb>,
122        colors: Option<Vec<Rgb>>,
123        shape: Option<Shape>,
124        shapes: Option<Vec<Shape>>,
125        plot_title: Option<Text>,
126        legend: Option<&Legend>,
127    ) -> Self {
128        let grid = facet.map(|facet_column| {
129            let config = facet_config.cloned().unwrap_or_default();
130            let facet_categories =
131                crate::data::get_unique_groups(data, facet_column, config.sorter);
132            let n_facets = facet_categories.len();
133            let (ncols, nrows) =
134                crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
135            crate::ir::facet::GridSpec {
136                kind: crate::ir::facet::FacetKind::Scene,
137                rows: nrows,
138                cols: ncols,
139                h_gap: config.h_gap,
140                v_gap: config.v_gap,
141                scales: config.scales.clone(),
142                n_facets,
143                facet_categories,
144                title_style: config.title_style.clone(),
145                x_title: None,
146                y_title: None,
147                x_axis: None,
148                y_axis: None,
149                legend_title: None,
150                legend: legend.cloned(),
151            }
152        });
153
154        let traces = match facet {
155            Some(facet_column) => {
156                let config = facet_config.cloned().unwrap_or_default();
157                Self::create_ir_traces_faceted(
158                    data,
159                    x,
160                    y,
161                    z,
162                    group,
163                    sort_groups_by,
164                    facet_column,
165                    &config,
166                    opacity,
167                    size,
168                    color,
169                    colors,
170                    shape,
171                    shapes,
172                )
173            }
174            None => Self::create_ir_traces(
175                data,
176                x,
177                y,
178                z,
179                group,
180                sort_groups_by,
181                opacity,
182                size,
183                color,
184                colors,
185                shape,
186                shapes,
187            ),
188        };
189
190        let layout = LayoutIR {
191            title: plot_title,
192            x_title: None,
193            y_title: None,
194            y2_title: None,
195            z_title: None,
196            legend_title: None,
197            legend: if grid.is_some() {
198                None
199            } else {
200                legend.cloned()
201            },
202            dimensions: None,
203            bar_mode: None,
204            box_mode: None,
205            box_gap: None,
206            margin_bottom: None,
207            axes_2d: None,
208            scene_3d: None,
209            polar: None,
210            mapbox: None,
211            grid,
212            annotations: vec![],
213        };
214
215        Self { traces, layout }
216    }
217}
218
219#[bon]
220impl Scatter3dPlot {
221    #[builder(
222        start_fn = try_builder,
223        finish_fn = try_build,
224        builder_type = Scatter3dPlotTryBuilder,
225        on(String, into),
226        on(Text, into),
227    )]
228    pub fn try_new(
229        data: &DataFrame,
230        x: &str,
231        y: &str,
232        z: &str,
233        group: Option<&str>,
234        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
235        facet: Option<&str>,
236        facet_config: Option<&FacetConfig>,
237        opacity: Option<f64>,
238        size: Option<usize>,
239        color: Option<Rgb>,
240        colors: Option<Vec<Rgb>>,
241        shape: Option<Shape>,
242        shapes: Option<Vec<Shape>>,
243        plot_title: Option<Text>,
244        legend: Option<&Legend>,
245    ) -> Result<Self, crate::io::PlotlarsError> {
246        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
247            Self::__orig_new(
248                data,
249                x,
250                y,
251                z,
252                group,
253                sort_groups_by,
254                facet,
255                facet_config,
256                opacity,
257                size,
258                color,
259                colors,
260                shape,
261                shapes,
262                plot_title,
263                legend,
264            )
265        }))
266        .map_err(|panic| {
267            let msg = panic
268                .downcast_ref::<String>()
269                .cloned()
270                .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
271                .unwrap_or_else(|| "unknown error".to_string());
272            crate::io::PlotlarsError::PlotBuild { message: msg }
273        })
274    }
275}
276
277impl Scatter3dPlot {
278    fn get_scene_reference(index: usize) -> String {
279        match index {
280            0 => "scene".to_string(),
281            1 => "scene2".to_string(),
282            2 => "scene3".to_string(),
283            3 => "scene4".to_string(),
284            4 => "scene5".to_string(),
285            5 => "scene6".to_string(),
286            6 => "scene7".to_string(),
287            7 => "scene8".to_string(),
288            _ => "scene".to_string(),
289        }
290    }
291
292    #[allow(clippy::too_many_arguments)]
293    fn create_ir_traces(
294        data: &DataFrame,
295        x: &str,
296        y: &str,
297        z: &str,
298        group: Option<&str>,
299        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
300        opacity: Option<f64>,
301        size: Option<usize>,
302        color: Option<Rgb>,
303        colors: Option<Vec<Rgb>>,
304        shape: Option<Shape>,
305        shapes: Option<Vec<Shape>>,
306    ) -> Vec<TraceIR> {
307        let mut traces = Vec::new();
308
309        match group {
310            Some(group_col) => {
311                let groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
312
313                for (i, group_name) in groups.iter().enumerate() {
314                    let subset = crate::data::filter_data_by_group(data, group_col, group_name);
315
316                    let marker_ir = MarkerIR {
317                        opacity,
318                        size,
319                        color: Self::resolve_color(i, color, colors.clone()),
320                        shape: Self::resolve_shape(i, shape, shapes.clone()),
321                    };
322
323                    traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
324                        x: ColumnData::Numeric(crate::data::get_numeric_column(&subset, x)),
325                        y: ColumnData::Numeric(crate::data::get_numeric_column(&subset, y)),
326                        z: ColumnData::Numeric(crate::data::get_numeric_column(&subset, z)),
327                        name: Some(group_name.to_string()),
328                        mode: None,
329                        marker: Some(marker_ir),
330                        show_legend: None,
331                        legend_group: None,
332                        scene_ref: None,
333                    }));
334                }
335            }
336            None => {
337                let marker_ir = MarkerIR {
338                    opacity,
339                    size,
340                    color: Self::resolve_color(0, color, colors),
341                    shape: Self::resolve_shape(0, shape, shapes),
342                };
343
344                traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
345                    x: ColumnData::Numeric(crate::data::get_numeric_column(data, x)),
346                    y: ColumnData::Numeric(crate::data::get_numeric_column(data, y)),
347                    z: ColumnData::Numeric(crate::data::get_numeric_column(data, z)),
348                    name: None,
349                    mode: None,
350                    marker: Some(marker_ir),
351                    show_legend: None,
352                    legend_group: None,
353                    scene_ref: None,
354                }));
355            }
356        }
357
358        traces
359    }
360
361    #[allow(clippy::too_many_arguments)]
362    fn create_ir_traces_faceted(
363        data: &DataFrame,
364        x: &str,
365        y: &str,
366        z: &str,
367        group: Option<&str>,
368        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
369        facet_column: &str,
370        config: &FacetConfig,
371        opacity: Option<f64>,
372        size: Option<usize>,
373        color: Option<Rgb>,
374        colors: Option<Vec<Rgb>>,
375        shape: Option<Shape>,
376        shapes: Option<Vec<Shape>>,
377    ) -> Vec<TraceIR> {
378        const MAX_FACETS: usize = 8;
379
380        let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
381
382        if facet_categories.len() > MAX_FACETS {
383            panic!(
384                "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} 3D scenes",
385                facet_column,
386                facet_categories.len(),
387                MAX_FACETS
388            );
389        }
390
391        if let Some(ref color_vec) = colors {
392            if group.is_none() {
393                let color_count = color_vec.len();
394                let facet_count = facet_categories.len();
395                if color_count != facet_count {
396                    panic!(
397                        "When using colors with facet (without group), colors.len() must equal number of facets. \
398                         Expected {} colors for {} facets, but got {} colors.",
399                        facet_count, facet_count, color_count
400                    );
401                }
402            } else if let Some(group_col) = group {
403                let groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
404                let color_count = color_vec.len();
405                let group_count = groups.len();
406                if color_count < group_count {
407                    panic!(
408                        "When using colors with group, colors.len() must be >= number of groups. \
409                         Need at least {} colors for {} groups, but got {} colors",
410                        group_count, group_count, color_count
411                    );
412                }
413            }
414        }
415
416        let global_group_indices: std::collections::HashMap<String, usize> =
417            if let Some(group_col) = group {
418                let global_groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
419                global_groups
420                    .into_iter()
421                    .enumerate()
422                    .map(|(idx, group_name)| (group_name, idx))
423                    .collect()
424            } else {
425                std::collections::HashMap::new()
426            };
427
428        let colors = if group.is_some() && colors.is_none() {
429            Some(DEFAULT_PLOTLY_COLORS.to_vec())
430        } else {
431            colors
432        };
433
434        let mut traces = Vec::new();
435
436        if config.highlight_facet {
437            for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
438                let scene = Self::get_scene_reference(facet_idx);
439
440                for other_facet_value in facet_categories.iter() {
441                    if other_facet_value != facet_value {
442                        let other_data = crate::data::filter_data_by_group(
443                            data,
444                            facet_column,
445                            other_facet_value,
446                        );
447
448                        let grey_color = config.unhighlighted_color.unwrap_or(Rgb(200, 200, 200));
449                        let marker_ir = MarkerIR {
450                            opacity,
451                            size,
452                            color: Some(grey_color),
453                            shape: Self::resolve_shape(0, shape, None),
454                        };
455
456                        traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
457                            x: ColumnData::Numeric(crate::data::get_numeric_column(&other_data, x)),
458                            y: ColumnData::Numeric(crate::data::get_numeric_column(&other_data, y)),
459                            z: ColumnData::Numeric(crate::data::get_numeric_column(&other_data, z)),
460                            name: None,
461                            mode: None,
462                            marker: Some(marker_ir),
463                            show_legend: Some(false),
464                            legend_group: None,
465                            scene_ref: Some(scene.clone()),
466                        }));
467                    }
468                }
469
470                let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
471
472                match group {
473                    Some(group_col) => {
474                        let groups =
475                            crate::data::get_unique_groups(&facet_data, group_col, sort_groups_by);
476
477                        for group_val in groups.iter() {
478                            let group_data = crate::data::filter_data_by_group(
479                                &facet_data,
480                                group_col,
481                                group_val,
482                            );
483
484                            let global_idx =
485                                global_group_indices.get(group_val).copied().unwrap_or(0);
486
487                            let marker_ir = MarkerIR {
488                                opacity,
489                                size,
490                                color: Self::resolve_color(global_idx, color, colors.clone()),
491                                shape: Self::resolve_shape(global_idx, shape, shapes.clone()),
492                            };
493
494                            traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
495                                x: ColumnData::Numeric(crate::data::get_numeric_column(
496                                    &group_data,
497                                    x,
498                                )),
499                                y: ColumnData::Numeric(crate::data::get_numeric_column(
500                                    &group_data,
501                                    y,
502                                )),
503                                z: ColumnData::Numeric(crate::data::get_numeric_column(
504                                    &group_data,
505                                    z,
506                                )),
507                                name: Some(group_val.to_string()),
508                                mode: None,
509                                marker: Some(marker_ir),
510                                show_legend: Some(facet_idx == 0),
511                                legend_group: Some(group_val.to_string()),
512                                scene_ref: Some(scene.clone()),
513                            }));
514                        }
515                    }
516                    None => {
517                        let marker_ir = MarkerIR {
518                            opacity,
519                            size,
520                            color: Self::resolve_color(facet_idx, color, colors.clone()),
521                            shape: Self::resolve_shape(facet_idx, shape, shapes.clone()),
522                        };
523
524                        traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
525                            x: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, x)),
526                            y: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, y)),
527                            z: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, z)),
528                            name: None,
529                            mode: None,
530                            marker: Some(marker_ir),
531                            show_legend: Some(false),
532                            legend_group: None,
533                            scene_ref: Some(scene.clone()),
534                        }));
535                    }
536                }
537            }
538        } else {
539            for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
540                let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
541                let scene = Self::get_scene_reference(facet_idx);
542
543                match group {
544                    Some(group_col) => {
545                        let groups =
546                            crate::data::get_unique_groups(&facet_data, group_col, sort_groups_by);
547
548                        for group_val in groups.iter() {
549                            let group_data = crate::data::filter_data_by_group(
550                                &facet_data,
551                                group_col,
552                                group_val,
553                            );
554
555                            let global_idx =
556                                global_group_indices.get(group_val).copied().unwrap_or(0);
557
558                            let marker_ir = MarkerIR {
559                                opacity,
560                                size,
561                                color: Self::resolve_color(global_idx, color, colors.clone()),
562                                shape: Self::resolve_shape(global_idx, shape, shapes.clone()),
563                            };
564
565                            traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
566                                x: ColumnData::Numeric(crate::data::get_numeric_column(
567                                    &group_data,
568                                    x,
569                                )),
570                                y: ColumnData::Numeric(crate::data::get_numeric_column(
571                                    &group_data,
572                                    y,
573                                )),
574                                z: ColumnData::Numeric(crate::data::get_numeric_column(
575                                    &group_data,
576                                    z,
577                                )),
578                                name: Some(group_val.to_string()),
579                                mode: None,
580                                marker: Some(marker_ir),
581                                show_legend: Some(facet_idx == 0),
582                                legend_group: Some(group_val.to_string()),
583                                scene_ref: Some(scene.clone()),
584                            }));
585                        }
586                    }
587                    None => {
588                        let marker_ir = MarkerIR {
589                            opacity,
590                            size,
591                            color: Self::resolve_color(facet_idx, color, colors.clone()),
592                            shape: Self::resolve_shape(facet_idx, shape, shapes.clone()),
593                        };
594
595                        traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
596                            x: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, x)),
597                            y: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, y)),
598                            z: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, z)),
599                            name: None,
600                            mode: None,
601                            marker: Some(marker_ir),
602                            show_legend: Some(false),
603                            legend_group: None,
604                            scene_ref: Some(scene.clone()),
605                        }));
606                    }
607                }
608            }
609        }
610
611        traces
612    }
613
614    fn resolve_color(index: usize, color: Option<Rgb>, colors: Option<Vec<Rgb>>) -> Option<Rgb> {
615        if let Some(c) = color {
616            return Some(c);
617        }
618        if let Some(ref cs) = colors {
619            return cs.get(index).copied();
620        }
621        None
622    }
623
624    fn resolve_shape(
625        index: usize,
626        shape: Option<Shape>,
627        shapes: Option<Vec<Shape>>,
628    ) -> Option<Shape> {
629        if let Some(s) = shape {
630            return Some(s);
631        }
632        if let Some(ref ss) = shapes {
633            return ss.get(index).cloned();
634        }
635        None
636    }
637}
638
639impl crate::Plot for Scatter3dPlot {
640    fn ir_traces(&self) -> &[TraceIR] {
641        &self.traces
642    }
643
644    fn ir_layout(&self) -> &LayoutIR {
645        &self.layout
646    }
647}
648
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    /// An ungrouped, non-faceted plot yields exactly one 3D scatter trace.
    #[test]
    fn test_basic_one_trace() {
        let frame = df![
            "x" => [1.0, 2.0, 3.0],
            "y" => [4.0, 5.0, 6.0],
            "z" => [7.0, 8.0, 9.0]
        ]
        .unwrap();

        let plot = Scatter3dPlot::builder()
            .data(&frame)
            .x("x")
            .y("y")
            .z("z")
            .build();

        let traces = plot.ir_traces();
        assert_eq!(traces.len(), 1);
        assert!(matches!(traces[0], TraceIR::Scatter3dPlot(_)));
    }

    /// Two distinct group values yield two traces.
    #[test]
    fn test_with_group() {
        let frame = df![
            "x" => [1.0, 2.0, 3.0, 4.0],
            "y" => [4.0, 5.0, 6.0, 7.0],
            "z" => [7.0, 8.0, 9.0, 10.0],
            "g" => ["a", "b", "a", "b"]
        ]
        .unwrap();

        let plot = Scatter3dPlot::builder()
            .data(&frame)
            .x("x")
            .y("y")
            .z("z")
            .group("g")
            .build();

        let trace_count = plot.ir_traces().len();
        assert_eq!(trace_count, 2);
    }

    /// A 3D scatter plot never populates the 2D axes section of the layout.
    #[test]
    fn test_layout_no_axes_2d() {
        let frame = df![
            "x" => [1.0, 2.0],
            "y" => [3.0, 4.0],
            "z" => [5.0, 6.0]
        ]
        .unwrap();

        let plot = Scatter3dPlot::builder()
            .data(&frame)
            .x("x")
            .y("y")
            .z("z")
            .build();

        let layout = plot.ir_layout();
        assert!(layout.axes_2d.is_none());
    }

    /// With neither a single color nor a palette, no color is resolved.
    #[test]
    fn test_resolve_color_both_none() {
        assert!(Scatter3dPlot::resolve_color(0, None, None).is_none());
    }
}