1use bon::bon;
2
3use polars::frame::DataFrame;
4
5use crate::{
6 components::{FacetConfig, Legend, Rgb, Shape, Text, DEFAULT_PLOTLY_COLORS},
7 ir::data::ColumnData,
8 ir::layout::LayoutIR,
9 ir::marker::MarkerIR,
10 ir::trace::{Scatter3dPlotIR, TraceIR},
11};
12
13#[derive(Clone)]
101#[allow(dead_code)]
102pub struct Scatter3dPlot {
103 traces: Vec<TraceIR>,
104 layout: LayoutIR,
105}
106
107#[bon]
108impl Scatter3dPlot {
109 #[builder(on(String, into), on(Text, into))]
110 pub fn new(
111 data: &DataFrame,
112 x: &str,
113 y: &str,
114 z: &str,
115 group: Option<&str>,
116 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
117 facet: Option<&str>,
118 facet_config: Option<&FacetConfig>,
119 opacity: Option<f64>,
120 size: Option<usize>,
121 color: Option<Rgb>,
122 colors: Option<Vec<Rgb>>,
123 shape: Option<Shape>,
124 shapes: Option<Vec<Shape>>,
125 plot_title: Option<Text>,
126 legend: Option<&Legend>,
127 ) -> Self {
128 let grid = facet.map(|facet_column| {
129 let config = facet_config.cloned().unwrap_or_default();
130 let facet_categories =
131 crate::data::get_unique_groups(data, facet_column, config.sorter);
132 let n_facets = facet_categories.len();
133 let (ncols, nrows) =
134 crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
135 crate::ir::facet::GridSpec {
136 kind: crate::ir::facet::FacetKind::Scene,
137 rows: nrows,
138 cols: ncols,
139 h_gap: config.h_gap,
140 v_gap: config.v_gap,
141 scales: config.scales.clone(),
142 n_facets,
143 facet_categories,
144 title_style: config.title_style.clone(),
145 x_title: None,
146 y_title: None,
147 x_axis: None,
148 y_axis: None,
149 legend_title: None,
150 legend: legend.cloned(),
151 }
152 });
153
154 let traces = match facet {
155 Some(facet_column) => {
156 let config = facet_config.cloned().unwrap_or_default();
157 Self::create_ir_traces_faceted(
158 data,
159 x,
160 y,
161 z,
162 group,
163 sort_groups_by,
164 facet_column,
165 &config,
166 opacity,
167 size,
168 color,
169 colors,
170 shape,
171 shapes,
172 )
173 }
174 None => Self::create_ir_traces(
175 data,
176 x,
177 y,
178 z,
179 group,
180 sort_groups_by,
181 opacity,
182 size,
183 color,
184 colors,
185 shape,
186 shapes,
187 ),
188 };
189
190 let layout = LayoutIR {
191 title: plot_title,
192 x_title: None,
193 y_title: None,
194 y2_title: None,
195 z_title: None,
196 legend_title: None,
197 legend: if grid.is_some() {
198 None
199 } else {
200 legend.cloned()
201 },
202 dimensions: None,
203 bar_mode: None,
204 box_mode: None,
205 box_gap: None,
206 margin_bottom: None,
207 axes_2d: None,
208 scene_3d: None,
209 polar: None,
210 mapbox: None,
211 grid,
212 annotations: vec![],
213 };
214
215 Self { traces, layout }
216 }
217}
218
219#[bon]
220impl Scatter3dPlot {
221 #[builder(
222 start_fn = try_builder,
223 finish_fn = try_build,
224 builder_type = Scatter3dPlotTryBuilder,
225 on(String, into),
226 on(Text, into),
227 )]
228 pub fn try_new(
229 data: &DataFrame,
230 x: &str,
231 y: &str,
232 z: &str,
233 group: Option<&str>,
234 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
235 facet: Option<&str>,
236 facet_config: Option<&FacetConfig>,
237 opacity: Option<f64>,
238 size: Option<usize>,
239 color: Option<Rgb>,
240 colors: Option<Vec<Rgb>>,
241 shape: Option<Shape>,
242 shapes: Option<Vec<Shape>>,
243 plot_title: Option<Text>,
244 legend: Option<&Legend>,
245 ) -> Result<Self, crate::io::PlotlarsError> {
246 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
247 Self::__orig_new(
248 data,
249 x,
250 y,
251 z,
252 group,
253 sort_groups_by,
254 facet,
255 facet_config,
256 opacity,
257 size,
258 color,
259 colors,
260 shape,
261 shapes,
262 plot_title,
263 legend,
264 )
265 }))
266 .map_err(|panic| {
267 let msg = panic
268 .downcast_ref::<String>()
269 .cloned()
270 .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
271 .unwrap_or_else(|| "unknown error".to_string());
272 crate::io::PlotlarsError::PlotBuild { message: msg }
273 })
274 }
275}
276
277impl Scatter3dPlot {
278 fn get_scene_reference(index: usize) -> String {
279 match index {
280 0 => "scene".to_string(),
281 1 => "scene2".to_string(),
282 2 => "scene3".to_string(),
283 3 => "scene4".to_string(),
284 4 => "scene5".to_string(),
285 5 => "scene6".to_string(),
286 6 => "scene7".to_string(),
287 7 => "scene8".to_string(),
288 _ => "scene".to_string(),
289 }
290 }
291
292 #[allow(clippy::too_many_arguments)]
293 fn create_ir_traces(
294 data: &DataFrame,
295 x: &str,
296 y: &str,
297 z: &str,
298 group: Option<&str>,
299 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
300 opacity: Option<f64>,
301 size: Option<usize>,
302 color: Option<Rgb>,
303 colors: Option<Vec<Rgb>>,
304 shape: Option<Shape>,
305 shapes: Option<Vec<Shape>>,
306 ) -> Vec<TraceIR> {
307 let mut traces = Vec::new();
308
309 match group {
310 Some(group_col) => {
311 let groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
312
313 for (i, group_name) in groups.iter().enumerate() {
314 let subset = crate::data::filter_data_by_group(data, group_col, group_name);
315
316 let marker_ir = MarkerIR {
317 opacity,
318 size,
319 color: Self::resolve_color(i, color, colors.clone()),
320 shape: Self::resolve_shape(i, shape, shapes.clone()),
321 };
322
323 traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
324 x: ColumnData::Numeric(crate::data::get_numeric_column(&subset, x)),
325 y: ColumnData::Numeric(crate::data::get_numeric_column(&subset, y)),
326 z: ColumnData::Numeric(crate::data::get_numeric_column(&subset, z)),
327 name: Some(group_name.to_string()),
328 mode: None,
329 marker: Some(marker_ir),
330 show_legend: None,
331 legend_group: None,
332 scene_ref: None,
333 }));
334 }
335 }
336 None => {
337 let marker_ir = MarkerIR {
338 opacity,
339 size,
340 color: Self::resolve_color(0, color, colors),
341 shape: Self::resolve_shape(0, shape, shapes),
342 };
343
344 traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
345 x: ColumnData::Numeric(crate::data::get_numeric_column(data, x)),
346 y: ColumnData::Numeric(crate::data::get_numeric_column(data, y)),
347 z: ColumnData::Numeric(crate::data::get_numeric_column(data, z)),
348 name: None,
349 mode: None,
350 marker: Some(marker_ir),
351 show_legend: None,
352 legend_group: None,
353 scene_ref: None,
354 }));
355 }
356 }
357
358 traces
359 }
360
361 #[allow(clippy::too_many_arguments)]
362 fn create_ir_traces_faceted(
363 data: &DataFrame,
364 x: &str,
365 y: &str,
366 z: &str,
367 group: Option<&str>,
368 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
369 facet_column: &str,
370 config: &FacetConfig,
371 opacity: Option<f64>,
372 size: Option<usize>,
373 color: Option<Rgb>,
374 colors: Option<Vec<Rgb>>,
375 shape: Option<Shape>,
376 shapes: Option<Vec<Shape>>,
377 ) -> Vec<TraceIR> {
378 const MAX_FACETS: usize = 8;
379
380 let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
381
382 if facet_categories.len() > MAX_FACETS {
383 panic!(
384 "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} 3D scenes",
385 facet_column,
386 facet_categories.len(),
387 MAX_FACETS
388 );
389 }
390
391 if let Some(ref color_vec) = colors {
392 if group.is_none() {
393 let color_count = color_vec.len();
394 let facet_count = facet_categories.len();
395 if color_count != facet_count {
396 panic!(
397 "When using colors with facet (without group), colors.len() must equal number of facets. \
398 Expected {} colors for {} facets, but got {} colors.",
399 facet_count, facet_count, color_count
400 );
401 }
402 } else if let Some(group_col) = group {
403 let groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
404 let color_count = color_vec.len();
405 let group_count = groups.len();
406 if color_count < group_count {
407 panic!(
408 "When using colors with group, colors.len() must be >= number of groups. \
409 Need at least {} colors for {} groups, but got {} colors",
410 group_count, group_count, color_count
411 );
412 }
413 }
414 }
415
416 let global_group_indices: std::collections::HashMap<String, usize> =
417 if let Some(group_col) = group {
418 let global_groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
419 global_groups
420 .into_iter()
421 .enumerate()
422 .map(|(idx, group_name)| (group_name, idx))
423 .collect()
424 } else {
425 std::collections::HashMap::new()
426 };
427
428 let colors = if group.is_some() && colors.is_none() {
429 Some(DEFAULT_PLOTLY_COLORS.to_vec())
430 } else {
431 colors
432 };
433
434 let mut traces = Vec::new();
435
436 if config.highlight_facet {
437 for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
438 let scene = Self::get_scene_reference(facet_idx);
439
440 for other_facet_value in facet_categories.iter() {
441 if other_facet_value != facet_value {
442 let other_data = crate::data::filter_data_by_group(
443 data,
444 facet_column,
445 other_facet_value,
446 );
447
448 let grey_color = config.unhighlighted_color.unwrap_or(Rgb(200, 200, 200));
449 let marker_ir = MarkerIR {
450 opacity,
451 size,
452 color: Some(grey_color),
453 shape: Self::resolve_shape(0, shape, None),
454 };
455
456 traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
457 x: ColumnData::Numeric(crate::data::get_numeric_column(&other_data, x)),
458 y: ColumnData::Numeric(crate::data::get_numeric_column(&other_data, y)),
459 z: ColumnData::Numeric(crate::data::get_numeric_column(&other_data, z)),
460 name: None,
461 mode: None,
462 marker: Some(marker_ir),
463 show_legend: Some(false),
464 legend_group: None,
465 scene_ref: Some(scene.clone()),
466 }));
467 }
468 }
469
470 let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
471
472 match group {
473 Some(group_col) => {
474 let groups =
475 crate::data::get_unique_groups(&facet_data, group_col, sort_groups_by);
476
477 for group_val in groups.iter() {
478 let group_data = crate::data::filter_data_by_group(
479 &facet_data,
480 group_col,
481 group_val,
482 );
483
484 let global_idx =
485 global_group_indices.get(group_val).copied().unwrap_or(0);
486
487 let marker_ir = MarkerIR {
488 opacity,
489 size,
490 color: Self::resolve_color(global_idx, color, colors.clone()),
491 shape: Self::resolve_shape(global_idx, shape, shapes.clone()),
492 };
493
494 traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
495 x: ColumnData::Numeric(crate::data::get_numeric_column(
496 &group_data,
497 x,
498 )),
499 y: ColumnData::Numeric(crate::data::get_numeric_column(
500 &group_data,
501 y,
502 )),
503 z: ColumnData::Numeric(crate::data::get_numeric_column(
504 &group_data,
505 z,
506 )),
507 name: Some(group_val.to_string()),
508 mode: None,
509 marker: Some(marker_ir),
510 show_legend: Some(facet_idx == 0),
511 legend_group: Some(group_val.to_string()),
512 scene_ref: Some(scene.clone()),
513 }));
514 }
515 }
516 None => {
517 let marker_ir = MarkerIR {
518 opacity,
519 size,
520 color: Self::resolve_color(facet_idx, color, colors.clone()),
521 shape: Self::resolve_shape(facet_idx, shape, shapes.clone()),
522 };
523
524 traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
525 x: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, x)),
526 y: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, y)),
527 z: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, z)),
528 name: None,
529 mode: None,
530 marker: Some(marker_ir),
531 show_legend: Some(false),
532 legend_group: None,
533 scene_ref: Some(scene.clone()),
534 }));
535 }
536 }
537 }
538 } else {
539 for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
540 let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
541 let scene = Self::get_scene_reference(facet_idx);
542
543 match group {
544 Some(group_col) => {
545 let groups =
546 crate::data::get_unique_groups(&facet_data, group_col, sort_groups_by);
547
548 for group_val in groups.iter() {
549 let group_data = crate::data::filter_data_by_group(
550 &facet_data,
551 group_col,
552 group_val,
553 );
554
555 let global_idx =
556 global_group_indices.get(group_val).copied().unwrap_or(0);
557
558 let marker_ir = MarkerIR {
559 opacity,
560 size,
561 color: Self::resolve_color(global_idx, color, colors.clone()),
562 shape: Self::resolve_shape(global_idx, shape, shapes.clone()),
563 };
564
565 traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
566 x: ColumnData::Numeric(crate::data::get_numeric_column(
567 &group_data,
568 x,
569 )),
570 y: ColumnData::Numeric(crate::data::get_numeric_column(
571 &group_data,
572 y,
573 )),
574 z: ColumnData::Numeric(crate::data::get_numeric_column(
575 &group_data,
576 z,
577 )),
578 name: Some(group_val.to_string()),
579 mode: None,
580 marker: Some(marker_ir),
581 show_legend: Some(facet_idx == 0),
582 legend_group: Some(group_val.to_string()),
583 scene_ref: Some(scene.clone()),
584 }));
585 }
586 }
587 None => {
588 let marker_ir = MarkerIR {
589 opacity,
590 size,
591 color: Self::resolve_color(facet_idx, color, colors.clone()),
592 shape: Self::resolve_shape(facet_idx, shape, shapes.clone()),
593 };
594
595 traces.push(TraceIR::Scatter3dPlot(Scatter3dPlotIR {
596 x: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, x)),
597 y: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, y)),
598 z: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, z)),
599 name: None,
600 mode: None,
601 marker: Some(marker_ir),
602 show_legend: Some(false),
603 legend_group: None,
604 scene_ref: Some(scene.clone()),
605 }));
606 }
607 }
608 }
609 }
610
611 traces
612 }
613
614 fn resolve_color(index: usize, color: Option<Rgb>, colors: Option<Vec<Rgb>>) -> Option<Rgb> {
615 if let Some(c) = color {
616 return Some(c);
617 }
618 if let Some(ref cs) = colors {
619 return cs.get(index).copied();
620 }
621 None
622 }
623
624 fn resolve_shape(
625 index: usize,
626 shape: Option<Shape>,
627 shapes: Option<Vec<Shape>>,
628 ) -> Option<Shape> {
629 if let Some(s) = shape {
630 return Some(s);
631 }
632 if let Some(ref ss) = shapes {
633 return ss.get(index).cloned();
634 }
635 None
636 }
637}
638
639impl crate::Plot for Scatter3dPlot {
640 fn ir_traces(&self) -> &[TraceIR] {
641 &self.traces
642 }
643
644 fn ir_layout(&self) -> &LayoutIR {
645 &self.layout
646 }
647}
648
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    #[test]
    fn test_basic_one_trace() {
        let df = df![
            "x" => [1.0, 2.0, 3.0],
            "y" => [4.0, 5.0, 6.0],
            "z" => [7.0, 8.0, 9.0]
        ]
        .unwrap();
        let plot = Scatter3dPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .build();
        assert_eq!(plot.ir_traces().len(), 1);
        assert!(matches!(plot.ir_traces()[0], TraceIR::Scatter3dPlot(_)));
    }

    #[test]
    fn test_with_group() {
        let df = df![
            "x" => [1.0, 2.0, 3.0, 4.0],
            "y" => [4.0, 5.0, 6.0, 7.0],
            "z" => [7.0, 8.0, 9.0, 10.0],
            "g" => ["a", "b", "a", "b"]
        ]
        .unwrap();
        let plot = Scatter3dPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .group("g")
            .build();
        assert_eq!(plot.ir_traces().len(), 2);
    }

    #[test]
    fn test_layout_no_axes_2d() {
        let df = df![
            "x" => [1.0, 2.0],
            "y" => [3.0, 4.0],
            "z" => [5.0, 6.0]
        ]
        .unwrap();
        let plot = Scatter3dPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .build();
        assert!(plot.ir_layout().axes_2d.is_none());
    }

    /// Faceting must place every trace in a scene and move the legend into the grid.
    #[test]
    fn test_facet_assigns_scenes_and_grid() {
        let df = df![
            "x" => [1.0, 2.0, 3.0, 4.0],
            "y" => [4.0, 5.0, 6.0, 7.0],
            "z" => [7.0, 8.0, 9.0, 10.0],
            "f" => ["p", "q", "p", "q"]
        ]
        .unwrap();
        let plot = Scatter3dPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .facet("f")
            .build();
        assert!(!plot.ir_traces().is_empty());
        for trace in plot.ir_traces() {
            if let TraceIR::Scatter3dPlot(ir) = trace {
                assert!(ir.scene_ref.is_some());
            } else {
                panic!("expected only Scatter3dPlot traces");
            }
        }
        assert!(plot.ir_layout().grid.is_some());
        assert!(plot.ir_layout().legend.is_none());
    }

    /// More than 8 facets panics inside construction; `try_build` must surface it as `Err`.
    #[test]
    fn test_try_build_rejects_more_than_eight_facets() {
        let values: Vec<f64> = (0..9).map(f64::from).collect();
        let labels: Vec<String> = (0..9).map(|i| format!("f{i}")).collect();
        let df = df![
            "x" => values.clone(),
            "y" => values.clone(),
            "z" => values,
            "f" => labels
        ]
        .unwrap();
        let result = Scatter3dPlot::try_builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .facet("f")
            .try_build();
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_color_both_none() {
        let result = Scatter3dPlot::resolve_color(0, None, None);
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_color_single_color_takes_precedence() {
        let result = Scatter3dPlot::resolve_color(
            1,
            Some(Rgb(1, 2, 3)),
            Some(vec![Rgb(4, 5, 6), Rgb(7, 8, 9)]),
        );
        assert!(matches!(result, Some(Rgb(1, 2, 3))));
    }

    #[test]
    fn test_resolve_color_indexes_into_colors() {
        let palette = vec![Rgb(4, 5, 6), Rgb(7, 8, 9)];
        let picked = Scatter3dPlot::resolve_color(1, None, Some(palette.clone()));
        assert!(matches!(picked, Some(Rgb(7, 8, 9))));
        let out_of_range = Scatter3dPlot::resolve_color(5, None, Some(palette));
        assert!(out_of_range.is_none());
    }

    #[test]
    fn test_resolve_shape_both_none() {
        assert!(Scatter3dPlot::resolve_shape(0, None, None).is_none());
    }
}