1use bon::bon;
2use indexmap::IndexSet;
3use ordered_float::OrderedFloat;
4
5use crate::{
6 components::{ColorBar, FacetConfig, Legend, Lighting, Palette, Text},
7 ir::data::ColumnData,
8 ir::layout::LayoutIR,
9 ir::trace::{SurfacePlotIR, TraceIR},
10};
11use polars::frame::DataFrame;
12
13#[derive(Clone)]
125#[allow(dead_code)]
126pub struct SurfacePlot {
127 traces: Vec<TraceIR>,
128 layout: LayoutIR,
129}
130
131#[bon]
132impl SurfacePlot {
133 #[builder(on(String, into), on(Text, into))]
134 pub fn new(
135 data: &DataFrame,
136 x: &str,
137 y: &str,
138 z: &str,
139 color_bar: Option<&ColorBar>,
140 color_scale: Option<Palette>,
141 reverse_scale: Option<bool>,
142 show_scale: Option<bool>,
143 lighting: Option<&Lighting>,
144 opacity: Option<f64>,
145 facet: Option<&str>,
146 facet_config: Option<&FacetConfig>,
147 plot_title: Option<Text>,
148 legend: Option<&Legend>,
149 ) -> Self {
150 let grid = facet.map(|facet_column| {
151 let config = facet_config.cloned().unwrap_or_default();
152 let facet_categories =
153 crate::data::get_unique_groups(data, facet_column, config.sorter);
154 let n_facets = facet_categories.len();
155 let (ncols, nrows) =
156 crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
157 crate::ir::facet::GridSpec {
158 kind: crate::ir::facet::FacetKind::Scene,
159 rows: nrows,
160 cols: ncols,
161 h_gap: config.h_gap,
162 v_gap: config.v_gap,
163 scales: config.scales.clone(),
164 n_facets,
165 facet_categories,
166 title_style: config.title_style.clone(),
167 x_title: None,
168 y_title: None,
169 x_axis: None,
170 y_axis: None,
171 legend_title: None,
172 legend: legend.cloned(),
173 }
174 });
175
176 let traces = match facet {
177 Some(facet_column) => {
178 let config = facet_config.cloned().unwrap_or_default();
179 Self::create_ir_traces_faceted(
180 data,
181 x,
182 y,
183 z,
184 facet_column,
185 &config,
186 color_bar,
187 color_scale,
188 reverse_scale,
189 show_scale,
190 lighting,
191 opacity,
192 )
193 }
194 None => Self::create_ir_traces(
195 data,
196 x,
197 y,
198 z,
199 color_bar,
200 color_scale,
201 reverse_scale,
202 show_scale,
203 lighting,
204 opacity,
205 ),
206 };
207
208 let layout = LayoutIR {
209 title: plot_title,
210 x_title: None,
211 y_title: None,
212 y2_title: None,
213 z_title: None,
214 legend_title: None,
215 legend: if grid.is_some() {
216 None
217 } else {
218 legend.cloned()
219 },
220 dimensions: None,
221 bar_mode: None,
222 box_mode: None,
223 box_gap: None,
224 margin_bottom: None,
225 axes_2d: None,
226 scene_3d: None,
227 polar: None,
228 mapbox: None,
229 grid,
230 annotations: vec![],
231 };
232
233 Self { traces, layout }
234 }
235}
236
237#[bon]
238impl SurfacePlot {
239 #[builder(
240 start_fn = try_builder,
241 finish_fn = try_build,
242 builder_type = SurfacePlotTryBuilder,
243 on(String, into),
244 on(Text, into),
245 )]
246 pub fn try_new(
247 data: &DataFrame,
248 x: &str,
249 y: &str,
250 z: &str,
251 color_bar: Option<&ColorBar>,
252 color_scale: Option<Palette>,
253 reverse_scale: Option<bool>,
254 show_scale: Option<bool>,
255 lighting: Option<&Lighting>,
256 opacity: Option<f64>,
257 facet: Option<&str>,
258 facet_config: Option<&FacetConfig>,
259 plot_title: Option<Text>,
260 legend: Option<&Legend>,
261 ) -> Result<Self, crate::io::PlotlarsError> {
262 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
263 Self::__orig_new(
264 data,
265 x,
266 y,
267 z,
268 color_bar,
269 color_scale,
270 reverse_scale,
271 show_scale,
272 lighting,
273 opacity,
274 facet,
275 facet_config,
276 plot_title,
277 legend,
278 )
279 }))
280 .map_err(|panic| {
281 let msg = panic
282 .downcast_ref::<String>()
283 .cloned()
284 .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
285 .unwrap_or_else(|| "unknown error".to_string());
286 crate::io::PlotlarsError::PlotBuild { message: msg }
287 })
288 }
289}
290
291impl SurfacePlot {
292 fn unique_ordered(v: Vec<Option<f32>>) -> Vec<f32> {
293 IndexSet::<OrderedFloat<f32>>::from_iter(v.into_iter().flatten().map(OrderedFloat))
294 .into_iter()
295 .map(|of| of.into_inner())
296 .collect()
297 }
298
299 #[allow(clippy::too_many_arguments)]
300 fn create_ir_traces(
301 data: &DataFrame,
302 x: &str,
303 y: &str,
304 z: &str,
305 color_bar: Option<&ColorBar>,
306 color_scale: Option<Palette>,
307 reverse_scale: Option<bool>,
308 show_scale: Option<bool>,
309 lighting: Option<&Lighting>,
310 opacity: Option<f64>,
311 ) -> Vec<TraceIR> {
312 let ir = Self::build_surface_ir(
313 data,
314 x,
315 y,
316 z,
317 color_bar,
318 color_scale,
319 reverse_scale,
320 show_scale,
321 lighting,
322 opacity,
323 None,
324 );
325 vec![TraceIR::SurfacePlot(ir)]
326 }
327
328 #[allow(clippy::too_many_arguments)]
329 fn create_ir_traces_faceted(
330 data: &DataFrame,
331 x: &str,
332 y: &str,
333 z: &str,
334 facet_column: &str,
335 config: &FacetConfig,
336 color_bar: Option<&ColorBar>,
337 color_scale: Option<Palette>,
338 reverse_scale: Option<bool>,
339 show_scale: Option<bool>,
340 lighting: Option<&Lighting>,
341 opacity: Option<f64>,
342 ) -> Vec<TraceIR> {
343 const MAX_FACETS: usize = 8;
344
345 let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
346
347 if facet_categories.len() > MAX_FACETS {
348 panic!(
349 "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} 3D scenes",
350 facet_column,
351 facet_categories.len(),
352 MAX_FACETS
353 );
354 }
355
356 let mut traces = Vec::new();
357
358 for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
359 let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
360 let scene = Self::get_scene_reference(facet_idx);
361
362 let facet_show_scale = if facet_idx == 0 {
364 show_scale
365 } else {
366 Some(false)
367 };
368
369 let ir = Self::build_surface_ir(
370 &facet_data,
371 x,
372 y,
373 z,
374 if facet_idx == 0 { color_bar } else { None },
375 color_scale,
376 reverse_scale,
377 facet_show_scale,
378 lighting,
379 opacity,
380 Some(scene),
381 );
382
383 traces.push(TraceIR::SurfacePlot(ir));
384 }
385
386 traces
387 }
388
389 #[allow(clippy::too_many_arguments)]
390 fn build_surface_ir(
391 data: &DataFrame,
392 x: &str,
393 y: &str,
394 z: &str,
395 color_bar: Option<&ColorBar>,
396 color_scale: Option<Palette>,
397 reverse_scale: Option<bool>,
398 show_scale: Option<bool>,
399 lighting: Option<&Lighting>,
400 opacity: Option<f64>,
401 scene_ref: Option<String>,
402 ) -> SurfacePlotIR {
403 let x_raw = crate::data::get_numeric_column(data, x);
404 let y_raw = crate::data::get_numeric_column(data, y);
405 let z_raw = crate::data::get_numeric_column(data, z);
406
407 let x_unique = Self::unique_ordered(x_raw);
408 let y_unique = Self::unique_ordered(y_raw.clone());
409
410 let z_grid: Vec<Vec<f64>> = z_raw
411 .into_iter()
412 .collect::<Vec<_>>()
413 .chunks(y_unique.len())
414 .map(|chunk| chunk.iter().map(|v| v.unwrap_or(0.0) as f64).collect())
415 .collect();
416
417 SurfacePlotIR {
418 x: ColumnData::Numeric(x_unique.iter().map(|v| Some(*v)).collect()),
419 y: ColumnData::Numeric(y_unique.iter().map(|v| Some(*v)).collect()),
420 z: z_grid,
421 color_scale,
422 color_bar: color_bar.cloned(),
423 reverse_scale,
424 show_scale,
425 lighting: lighting.cloned(),
426 opacity,
427 scene_ref,
428 }
429 }
430
431 fn get_scene_reference(index: usize) -> String {
432 match index {
433 0 => "scene".to_string(),
434 1 => "scene2".to_string(),
435 2 => "scene3".to_string(),
436 3 => "scene4".to_string(),
437 4 => "scene5".to_string(),
438 5 => "scene6".to_string(),
439 6 => "scene7".to_string(),
440 7 => "scene8".to_string(),
441 _ => "scene".to_string(),
442 }
443 }
444}
445
446impl crate::Plot for SurfacePlot {
447 fn ir_traces(&self) -> &[TraceIR] {
448 &self.traces
449 }
450
451 fn ir_layout(&self) -> &LayoutIR {
452 &self.layout
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    /// A 2x2 grid ordered by `x` with `y` varying fastest.
    fn grid_df() -> DataFrame {
        df![
            "x" => [1.0, 1.0, 2.0, 2.0],
            "y" => [1.0, 2.0, 1.0, 2.0],
            "z" => [5.0, 6.0, 7.0, 8.0]
        ]
        .unwrap()
    }

    /// Repeats the 2x2 grid once per label, tagging each copy in column `g`.
    fn faceted_df(groups: &[&str]) -> DataFrame {
        let mut xs: Vec<f64> = Vec::new();
        let mut ys: Vec<f64> = Vec::new();
        let mut zs: Vec<f64> = Vec::new();
        let mut gs: Vec<&str> = Vec::new();
        for &g in groups {
            xs.extend([1.0, 1.0, 2.0, 2.0]);
            ys.extend([1.0, 2.0, 1.0, 2.0]);
            zs.extend([5.0, 6.0, 7.0, 8.0]);
            gs.extend([g; 4]);
        }
        df!["x" => xs, "y" => ys, "z" => zs, "g" => gs].unwrap()
    }

    #[test]
    fn test_basic_one_trace() {
        let df = grid_df();
        let plot = SurfacePlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .build();
        assert_eq!(plot.ir_traces().len(), 1);
        assert!(matches!(plot.ir_traces()[0], TraceIR::SurfacePlot(_)));
    }

    #[test]
    fn test_layout_no_axes_2d() {
        let df = grid_df();
        let plot = SurfacePlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .build();
        assert!(plot.ir_layout().axes_2d.is_none());
    }

    #[test]
    fn test_layout_title() {
        let df = grid_df();
        let plot = SurfacePlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .plot_title("Surface")
            .build();
        assert!(plot.ir_layout().title.is_some());
    }

    #[test]
    fn test_facet_one_trace_per_group() {
        let df = faceted_df(&["a", "b"]);
        let plot = SurfacePlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .facet("g")
            .build();

        assert_eq!(plot.ir_traces().len(), 2);
        assert!(plot.ir_layout().grid.is_some());

        let scenes: Vec<Option<&str>> = plot
            .ir_traces()
            .iter()
            .map(|trace| match trace {
                TraceIR::SurfacePlot(ir) => ir.scene_ref.as_deref(),
                _ => None,
            })
            .collect();
        assert_eq!(scenes, vec![Some("scene"), Some("scene2")]);

        // Only the first facet may show a colour scale.
        match &plot.ir_traces()[1] {
            TraceIR::SurfacePlot(ir) => assert_eq!(ir.show_scale, Some(false)),
            _ => panic!("expected a surface trace"),
        }
    }

    #[test]
    fn test_try_build_rejects_too_many_facets() {
        let labels: Vec<String> = (0..9).map(|i| format!("g{}", i)).collect();
        let refs: Vec<&str> = labels.iter().map(String::as_str).collect();
        let df = faceted_df(&refs);
        let result = SurfacePlot::try_builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .facet("g")
            .try_build();
        assert!(matches!(
            result,
            Err(crate::io::PlotlarsError::PlotBuild { .. })
        ));
    }
}
514}