Skip to main content

plotlars_core/plots/
surfaceplot.rs

1use bon::bon;
2use indexmap::IndexSet;
3use ordered_float::OrderedFloat;
4
5use crate::{
6    components::{ColorBar, FacetConfig, Legend, Lighting, Palette, Text},
7    ir::data::ColumnData,
8    ir::layout::LayoutIR,
9    ir::trace::{SurfacePlotIR, TraceIR},
10};
11use polars::frame::DataFrame;
12
13/// A structure representing a 3-D surface plot.
14///
15/// The `SurfacePlot` struct is designed to build and customize 3-dimensional
16/// surface visualizations.  It supports fine-grained control over the color
17/// mapping of *z* values, interactive color bars, lighting effects that enhance
18/// depth perception, and global opacity settings.  Layout elements such as the
19/// plot title and axis labels can also be configured through the builder API,
20/// allowing you to embed the surface seamlessly in complex dashboards.
21///
22/// # Backend Support
23///
24/// | Backend | Supported |
25/// |---------|-----------|
26/// | Plotly  | Yes       |
27/// | Plotters| --        |
28///
29/// # Arguments
30///
31/// * `data` – A reference to the `DataFrame` that supplies the data.
32///   It must contain three numeric columns whose names are given by `x`, `y`
33///   and `z`.
34/// * `x` – The column name to be used for the x-axis coordinates.
35///   Each distinct *x* value becomes a row in the underlying *z* grid.
36/// * `y` – The column name to be used for the y-axis coordinates.
37///   Each distinct *y* value becomes a column in the *z* grid.
38/// * `z` – The column name that provides the z-axis heights.  These values
39///   are mapped to colors according to `color_scale` / `reverse_scale`.
40/// * `color_bar` – An optional Reference to a `ColorBar` describing the
41///   appearance of the color legend (length, tick formatting, border, etc.).
42/// * `color_scale` – An optional `Palette` that defines the color gradient
43///   (e.g. *Viridis*, *Cividis*).
44/// * `reverse_scale` – An optional `bool` indicating whether the chosen
45///   `color_scale` should run in the opposite direction.
46/// * `show_scale` – An optional `bool` that toggles the visibility of the
47///   color bar.  Useful when you have multiple surfaces that share an external
48///   legend.
49/// * `lighting` – An optional Reference to a `Lighting` struct that
50///   specifies *ambient*, *diffuse*, *specular* components, *roughness*,
51///   *fresnel* and light position.  Leaving it `None` applies Plotly's
52///   default Phong shading.
53/// * `opacity` – An optional `f64` in `[0.0, 1.0]` that sets the global
54///   transparency of the surface (1 = opaque, 0 = fully transparent).
55/// * `facet` – An optional string slice specifying the column name to create faceted subplots (one surface per category).
56/// * `facet_config` – An optional reference to a `FacetConfig` struct for customizing facet layout (ncol, nrow, gap sizes, etc.).
57/// * `plot_title` – An optional `Text` that customizes the title (content,
58///   font, size, alignment).
59/// * `legend` – An optional reference to a `Legend` struct for legend customization.
60///
61/// # Example
62///
63/// ```rust
64/// use ndarray::Array;
65/// use plotlars::{ColorBar, Lighting, Palette, Plot, SurfacePlot, Text};
66/// use polars::prelude::*;
67/// use std::iter;
68///
69/// let n: usize = 100;
70/// let (x_base, _): (Vec<f64>, Option<usize>) = Array::linspace(-10.0, 10.0, n).into_raw_vec_and_offset();
71/// let (y_base, _): (Vec<f64>, Option<usize>) = Array::linspace(-10.0, 10.0, n).into_raw_vec_and_offset();
72///
73/// let x = x_base
74///     .iter()
75///     .flat_map(|&xi| iter::repeat_n(xi, n))
76///     .collect::<Vec<_>>();
77///
78/// let y = y_base
79///     .iter()
80///     .cycle()
81///     .take(n * n)
82///     .cloned()
83///     .collect::<Vec<_>>();
84///
85/// let z = x_base
86///     .iter()
87///     .flat_map(|i| {
88///         y_base
89///             .iter()
90///             .map(|j| 1.0 / (j * j + 5.0) * j.sin() + 1.0 / (i * i + 5.0) * i.cos())
91///             .collect::<Vec<_>>()
92///     })
93///     .collect::<Vec<_>>();
94///
95/// let dataset = df![
96///         "x" => &x,
97///         "y" => &y,
98///         "z" => &z,
99///     ]
100///     .unwrap();
101///
102/// SurfacePlot::builder()
103///     .data(&dataset)
104///     .x("x")
105///     .y("y")
106///     .z("z")
107///     .plot_title(
108///         Text::from("Surface Plot")
109///             .font("Arial")
110///             .size(18),
111///     )
112///     .color_bar(
113///         &ColorBar::new()
114///             .border_width(1),
115///     )
116///     .color_scale(Palette::Cividis)
117///     .reverse_scale(true)
118///     .opacity(0.5)
119///     .build()
120///     .plot();
121/// ```
122///
123/// ![Example](https://imgur.com/tdVte4l.png)
124#[derive(Clone)]
125#[allow(dead_code)]
126pub struct SurfacePlot {
127    traces: Vec<TraceIR>,
128    layout: LayoutIR,
129}
130
131#[bon]
132impl SurfacePlot {
133    #[builder(on(String, into), on(Text, into))]
134    pub fn new(
135        data: &DataFrame,
136        x: &str,
137        y: &str,
138        z: &str,
139        color_bar: Option<&ColorBar>,
140        color_scale: Option<Palette>,
141        reverse_scale: Option<bool>,
142        show_scale: Option<bool>,
143        lighting: Option<&Lighting>,
144        opacity: Option<f64>,
145        facet: Option<&str>,
146        facet_config: Option<&FacetConfig>,
147        plot_title: Option<Text>,
148        legend: Option<&Legend>,
149    ) -> Self {
150        let grid = facet.map(|facet_column| {
151            let config = facet_config.cloned().unwrap_or_default();
152            let facet_categories =
153                crate::data::get_unique_groups(data, facet_column, config.sorter);
154            let n_facets = facet_categories.len();
155            let (ncols, nrows) =
156                crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
157            crate::ir::facet::GridSpec {
158                kind: crate::ir::facet::FacetKind::Scene,
159                rows: nrows,
160                cols: ncols,
161                h_gap: config.h_gap,
162                v_gap: config.v_gap,
163                scales: config.scales.clone(),
164                n_facets,
165                facet_categories,
166                title_style: config.title_style.clone(),
167                x_title: None,
168                y_title: None,
169                x_axis: None,
170                y_axis: None,
171                legend_title: None,
172                legend: legend.cloned(),
173            }
174        });
175
176        let traces = match facet {
177            Some(facet_column) => {
178                let config = facet_config.cloned().unwrap_or_default();
179                Self::create_ir_traces_faceted(
180                    data,
181                    x,
182                    y,
183                    z,
184                    facet_column,
185                    &config,
186                    color_bar,
187                    color_scale,
188                    reverse_scale,
189                    show_scale,
190                    lighting,
191                    opacity,
192                )
193            }
194            None => Self::create_ir_traces(
195                data,
196                x,
197                y,
198                z,
199                color_bar,
200                color_scale,
201                reverse_scale,
202                show_scale,
203                lighting,
204                opacity,
205            ),
206        };
207
208        let layout = LayoutIR {
209            title: plot_title,
210            x_title: None,
211            y_title: None,
212            y2_title: None,
213            z_title: None,
214            legend_title: None,
215            legend: if grid.is_some() {
216                None
217            } else {
218                legend.cloned()
219            },
220            dimensions: None,
221            bar_mode: None,
222            box_mode: None,
223            box_gap: None,
224            margin_bottom: None,
225            axes_2d: None,
226            scene_3d: None,
227            polar: None,
228            mapbox: None,
229            grid,
230            annotations: vec![],
231        };
232
233        Self { traces, layout }
234    }
235}
236
237#[bon]
238impl SurfacePlot {
239    #[builder(
240        start_fn = try_builder,
241        finish_fn = try_build,
242        builder_type = SurfacePlotTryBuilder,
243        on(String, into),
244        on(Text, into),
245    )]
246    pub fn try_new(
247        data: &DataFrame,
248        x: &str,
249        y: &str,
250        z: &str,
251        color_bar: Option<&ColorBar>,
252        color_scale: Option<Palette>,
253        reverse_scale: Option<bool>,
254        show_scale: Option<bool>,
255        lighting: Option<&Lighting>,
256        opacity: Option<f64>,
257        facet: Option<&str>,
258        facet_config: Option<&FacetConfig>,
259        plot_title: Option<Text>,
260        legend: Option<&Legend>,
261    ) -> Result<Self, crate::io::PlotlarsError> {
262        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
263            Self::__orig_new(
264                data,
265                x,
266                y,
267                z,
268                color_bar,
269                color_scale,
270                reverse_scale,
271                show_scale,
272                lighting,
273                opacity,
274                facet,
275                facet_config,
276                plot_title,
277                legend,
278            )
279        }))
280        .map_err(|panic| {
281            let msg = panic
282                .downcast_ref::<String>()
283                .cloned()
284                .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
285                .unwrap_or_else(|| "unknown error".to_string());
286            crate::io::PlotlarsError::PlotBuild { message: msg }
287        })
288    }
289}
290
291impl SurfacePlot {
292    fn unique_ordered(v: Vec<Option<f32>>) -> Vec<f32> {
293        IndexSet::<OrderedFloat<f32>>::from_iter(v.into_iter().flatten().map(OrderedFloat))
294            .into_iter()
295            .map(|of| of.into_inner())
296            .collect()
297    }
298
299    #[allow(clippy::too_many_arguments)]
300    fn create_ir_traces(
301        data: &DataFrame,
302        x: &str,
303        y: &str,
304        z: &str,
305        color_bar: Option<&ColorBar>,
306        color_scale: Option<Palette>,
307        reverse_scale: Option<bool>,
308        show_scale: Option<bool>,
309        lighting: Option<&Lighting>,
310        opacity: Option<f64>,
311    ) -> Vec<TraceIR> {
312        let ir = Self::build_surface_ir(
313            data,
314            x,
315            y,
316            z,
317            color_bar,
318            color_scale,
319            reverse_scale,
320            show_scale,
321            lighting,
322            opacity,
323            None,
324        );
325        vec![TraceIR::SurfacePlot(ir)]
326    }
327
328    #[allow(clippy::too_many_arguments)]
329    fn create_ir_traces_faceted(
330        data: &DataFrame,
331        x: &str,
332        y: &str,
333        z: &str,
334        facet_column: &str,
335        config: &FacetConfig,
336        color_bar: Option<&ColorBar>,
337        color_scale: Option<Palette>,
338        reverse_scale: Option<bool>,
339        show_scale: Option<bool>,
340        lighting: Option<&Lighting>,
341        opacity: Option<f64>,
342    ) -> Vec<TraceIR> {
343        const MAX_FACETS: usize = 8;
344
345        let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
346
347        if facet_categories.len() > MAX_FACETS {
348            panic!(
349                "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} 3D scenes",
350                facet_column,
351                facet_categories.len(),
352                MAX_FACETS
353            );
354        }
355
356        let mut traces = Vec::new();
357
358        for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
359            let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
360            let scene = Self::get_scene_reference(facet_idx);
361
362            // Only show colorbar on the first faceted trace to avoid overlap
363            let facet_show_scale = if facet_idx == 0 {
364                show_scale
365            } else {
366                Some(false)
367            };
368
369            let ir = Self::build_surface_ir(
370                &facet_data,
371                x,
372                y,
373                z,
374                if facet_idx == 0 { color_bar } else { None },
375                color_scale,
376                reverse_scale,
377                facet_show_scale,
378                lighting,
379                opacity,
380                Some(scene),
381            );
382
383            traces.push(TraceIR::SurfacePlot(ir));
384        }
385
386        traces
387    }
388
389    #[allow(clippy::too_many_arguments)]
390    fn build_surface_ir(
391        data: &DataFrame,
392        x: &str,
393        y: &str,
394        z: &str,
395        color_bar: Option<&ColorBar>,
396        color_scale: Option<Palette>,
397        reverse_scale: Option<bool>,
398        show_scale: Option<bool>,
399        lighting: Option<&Lighting>,
400        opacity: Option<f64>,
401        scene_ref: Option<String>,
402    ) -> SurfacePlotIR {
403        let x_raw = crate::data::get_numeric_column(data, x);
404        let y_raw = crate::data::get_numeric_column(data, y);
405        let z_raw = crate::data::get_numeric_column(data, z);
406
407        let x_unique = Self::unique_ordered(x_raw);
408        let y_unique = Self::unique_ordered(y_raw.clone());
409
410        let z_grid: Vec<Vec<f64>> = z_raw
411            .into_iter()
412            .collect::<Vec<_>>()
413            .chunks(y_unique.len())
414            .map(|chunk| chunk.iter().map(|v| v.unwrap_or(0.0) as f64).collect())
415            .collect();
416
417        SurfacePlotIR {
418            x: ColumnData::Numeric(x_unique.iter().map(|v| Some(*v)).collect()),
419            y: ColumnData::Numeric(y_unique.iter().map(|v| Some(*v)).collect()),
420            z: z_grid,
421            color_scale,
422            color_bar: color_bar.cloned(),
423            reverse_scale,
424            show_scale,
425            lighting: lighting.cloned(),
426            opacity,
427            scene_ref,
428        }
429    }
430
431    fn get_scene_reference(index: usize) -> String {
432        match index {
433            0 => "scene".to_string(),
434            1 => "scene2".to_string(),
435            2 => "scene3".to_string(),
436            3 => "scene4".to_string(),
437            4 => "scene5".to_string(),
438            5 => "scene6".to_string(),
439            6 => "scene7".to_string(),
440            7 => "scene8".to_string(),
441            _ => "scene".to_string(),
442        }
443    }
444}
445
446impl crate::Plot for SurfacePlot {
447    fn ir_traces(&self) -> &[TraceIR] {
448        &self.traces
449    }
450
451    fn ir_layout(&self) -> &LayoutIR {
452        &self.layout
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    /// A complete 2×2 grid, ordered x-major with y varying fastest.
    fn grid_2x2() -> DataFrame {
        df![
            "x" => [1.0, 1.0, 2.0, 2.0],
            "y" => [1.0, 2.0, 1.0, 2.0],
            "z" => [5.0, 6.0, 7.0, 8.0]
        ]
        .unwrap()
    }

    #[test]
    fn test_basic_one_trace() {
        let frame = grid_2x2();
        let plot = SurfacePlot::builder()
            .data(&frame)
            .x("x")
            .y("y")
            .z("z")
            .build();

        let traces = plot.ir_traces();
        assert_eq!(traces.len(), 1);
        assert!(matches!(traces[0], TraceIR::SurfacePlot(_)));
    }

    #[test]
    fn test_layout_no_axes_2d() {
        let frame = grid_2x2();
        let plot = SurfacePlot::builder()
            .data(&frame)
            .x("x")
            .y("y")
            .z("z")
            .build();

        assert!(plot.ir_layout().axes_2d.is_none());
    }

    #[test]
    fn test_layout_title() {
        let frame = grid_2x2();
        let plot = SurfacePlot::builder()
            .data(&frame)
            .x("x")
            .y("y")
            .z("z")
            .plot_title("Surface")
            .build();

        assert!(plot.ir_layout().title.is_some());
    }
}