1use bon::bon;
2
3use plotly::{
4 common::{Marker as MarkerPlotly, Mode},
5 layout::{GridPattern, LayoutGrid},
6 Layout as LayoutPlotly, Scatter, Trace,
7};
8
9use polars::frame::DataFrame;
10use serde::Serialize;
11
12use crate::{
13 common::{Layout, Marker, PlotHelper, Polar},
14 components::{Axis, FacetConfig, FacetScales, Legend, Rgb, Shape, Text, DEFAULT_PLOTLY_COLORS},
15};
16
17#[derive(Clone, Serialize)]
pub struct ScatterPlot {
    /// Plotly traces built from the input DataFrame: one per group (and, when faceting, per
    /// facet), plus grey context traces when facet highlighting is enabled.
    traces: Vec<Box<dyn Trace + 'static>>,
    /// Plotly layout: titles, axes, legend and, when faceting, the subplot grid.
    layout: LayoutPlotly,
}
126
127#[bon]
128impl ScatterPlot {
129 #[builder(on(String, into), on(Text, into))]
130 pub fn new(
131 data: &DataFrame,
132 x: &str,
133 y: &str,
134 group: Option<&str>,
135 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
136 facet: Option<&str>,
137 facet_config: Option<&FacetConfig>,
138 opacity: Option<f64>,
139 size: Option<usize>,
140 color: Option<Rgb>,
141 colors: Option<Vec<Rgb>>,
142 shape: Option<Shape>,
143 shapes: Option<Vec<Shape>>,
144 plot_title: Option<Text>,
145 x_title: Option<Text>,
146 y_title: Option<Text>,
147 legend_title: Option<Text>,
148 x_axis: Option<&Axis>,
149 y_axis: Option<&Axis>,
150 legend: Option<&Legend>,
151 ) -> Self {
152 let z_title = None;
153 let z_axis = None;
154 let y2_title = None;
155 let y2_axis = None;
156
157 let (layout, traces) = match facet {
158 Some(facet_column) => {
159 let config = facet_config.cloned().unwrap_or_default();
160
161 let layout = Self::create_faceted_layout(
162 data,
163 facet_column,
164 &config,
165 plot_title,
166 x_title,
167 y_title,
168 legend_title,
169 x_axis,
170 y_axis,
171 legend,
172 );
173
174 let traces = Self::create_faceted_traces(
175 data,
176 x,
177 y,
178 group,
179 sort_groups_by,
180 facet_column,
181 &config,
182 opacity,
183 size,
184 color,
185 colors,
186 shape,
187 shapes,
188 );
189
190 (layout, traces)
191 }
192 None => {
193 let layout = Self::create_layout(
194 plot_title,
195 x_title,
196 y_title,
197 y2_title,
198 z_title,
199 legend_title,
200 x_axis,
201 y_axis,
202 y2_axis,
203 z_axis,
204 legend,
205 None,
206 );
207
208 let traces = Self::create_traces(
209 data,
210 x,
211 y,
212 group,
213 sort_groups_by,
214 opacity,
215 size,
216 color,
217 colors,
218 shape,
219 shapes,
220 );
221
222 (layout, traces)
223 }
224 };
225
226 Self { traces, layout }
227 }
228
229 #[allow(clippy::too_many_arguments)]
230 fn create_traces(
231 data: &DataFrame,
232 x: &str,
233 y: &str,
234 group: Option<&str>,
235 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
236 opacity: Option<f64>,
237 size: Option<usize>,
238 color: Option<Rgb>,
239 colors: Option<Vec<Rgb>>,
240 shape: Option<Shape>,
241 shapes: Option<Vec<Shape>>,
242 ) -> Vec<Box<dyn Trace + 'static>> {
243 let mut traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
244
245 match group {
246 Some(group_col) => {
247 let groups = Self::get_unique_groups(data, group_col, sort_groups_by);
248
249 let groups = groups.iter().map(|s| s.as_str());
250
251 for (i, group) in groups.enumerate() {
252 let marker = Self::create_marker(
253 i,
254 opacity,
255 size,
256 color,
257 colors.clone(),
258 shape,
259 shapes.clone(),
260 );
261
262 let subset = Self::filter_data_by_group(data, group_col, group);
263
264 let trace = Self::create_trace(&subset, x, y, Some(group), marker);
265
266 traces.push(trace);
267 }
268 }
269 None => {
270 let group = None;
271
272 let marker = Self::create_marker(
273 0,
274 opacity,
275 size,
276 color,
277 colors.clone(),
278 shape,
279 shapes.clone(),
280 );
281
282 let trace = Self::create_trace(data, x, y, group, marker);
283
284 traces.push(trace);
285 }
286 }
287
288 traces
289 }
290
291 fn create_trace(
292 data: &DataFrame,
293 x: &str,
294 y: &str,
295 group_name: Option<&str>,
296 marker: MarkerPlotly,
297 ) -> Box<dyn Trace + 'static> {
298 Self::build_scatter_trace(data, x, y, group_name, marker)
299 }
300
301 fn build_scatter_trace(
302 data: &DataFrame,
303 x: &str,
304 y: &str,
305 group_name: Option<&str>,
306 marker: MarkerPlotly,
307 ) -> Box<dyn Trace + 'static> {
308 Self::build_scatter_trace_with_axes(data, x, y, group_name, marker, None, None, true, None)
309 }
310
311 #[allow(clippy::too_many_arguments)]
312 fn build_scatter_trace_with_axes(
313 data: &DataFrame,
314 x_col: &str,
315 y_col: &str,
316 group_name: Option<&str>,
317 marker: MarkerPlotly,
318 x_axis: Option<&str>,
319 y_axis: Option<&str>,
320 show_legend: bool,
321 legend_group: Option<&str>,
322 ) -> Box<dyn Trace + 'static> {
323 let x = Self::get_numeric_column(data, x_col);
324 let y = Self::get_numeric_column(data, y_col);
325
326 let trace = Scatter::default().x(x).y(y).mode(Mode::Markers);
327
328 let trace = trace.marker(marker);
329
330 let trace = if let Some(name) = group_name {
331 trace.name(name)
332 } else {
333 trace
334 };
335
336 let trace = if let Some(axis) = x_axis {
337 trace.x_axis(axis)
338 } else {
339 trace
340 };
341
342 let trace = if let Some(axis) = y_axis {
343 trace.y_axis(axis)
344 } else {
345 trace
346 };
347
348 let trace = if let Some(group) = legend_group {
349 trace.legend_group(group)
350 } else {
351 trace
352 };
353
354 if !show_legend {
355 trace.show_legend(false)
356 } else {
357 trace
358 }
359 }
360
361 #[allow(clippy::too_many_arguments)]
362 fn create_faceted_traces(
363 data: &DataFrame,
364 x: &str,
365 y: &str,
366 group: Option<&str>,
367 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
368 facet_column: &str,
369 config: &FacetConfig,
370 opacity: Option<f64>,
371 size: Option<usize>,
372 color: Option<Rgb>,
373 colors: Option<Vec<Rgb>>,
374 shape: Option<Shape>,
375 shapes: Option<Vec<Shape>>,
376 ) -> Vec<Box<dyn Trace + 'static>> {
377 const MAX_FACETS: usize = 8;
378
379 let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
380
381 if facet_categories.len() > MAX_FACETS {
382 panic!(
383 "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
384 facet_column,
385 facet_categories.len(),
386 MAX_FACETS
387 );
388 }
389
390 if let Some(ref color_vec) = colors {
391 if group.is_none() {
392 let color_count = color_vec.len();
393 let facet_count = facet_categories.len();
394
395 if color_count != facet_count {
396 panic!(
397 "When using colors with facet (without group), colors.len() must equal number of facets. \
398 Expected {} colors for {} facets, but got {} colors. \
399 Each facet must be assigned exactly one color.",
400 facet_count, facet_count, color_count
401 );
402 }
403 } else if let Some(group_col) = group {
404 let groups = Self::get_unique_groups(data, group_col, sort_groups_by);
405 let color_count = color_vec.len();
406 let group_count = groups.len();
407
408 if color_count < group_count {
409 panic!(
410 "When using colors with group, colors.len() must be >= number of groups. \
411 Need at least {} colors for {} groups, but got {} colors",
412 group_count, group_count, color_count
413 );
414 }
415 }
416 }
417
418 let global_group_indices: std::collections::HashMap<String, usize> =
419 if let Some(group_col) = group {
420 let global_groups = Self::get_unique_groups(data, group_col, sort_groups_by);
421 global_groups
422 .into_iter()
423 .enumerate()
424 .map(|(idx, group_name)| (group_name, idx))
425 .collect()
426 } else {
427 std::collections::HashMap::new()
428 };
429
430 let colors = if group.is_some() && colors.is_none() {
431 Some(DEFAULT_PLOTLY_COLORS.to_vec())
432 } else {
433 colors
434 };
435
436 let mut all_traces = Vec::new();
437
438 if config.highlight_facet {
439 for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
440 let x_axis = Self::get_axis_reference(facet_idx, "x");
441 let y_axis = Self::get_axis_reference(facet_idx, "y");
442
443 for other_facet_value in facet_categories.iter() {
444 if other_facet_value != facet_value {
445 let other_data =
446 Self::filter_data_by_group(data, facet_column, other_facet_value);
447
448 let grey_color = config.unhighlighted_color.unwrap_or(Rgb(200, 200, 200));
449 let grey_marker = Self::create_marker(
450 0,
451 opacity,
452 size,
453 Some(grey_color),
454 None,
455 shape,
456 None,
457 );
458
459 let trace = Self::build_scatter_trace_with_axes(
460 &other_data,
461 x,
462 y,
463 None,
464 grey_marker,
465 Some(&x_axis),
466 Some(&y_axis),
467 false,
468 None,
469 );
470
471 all_traces.push(trace);
472 }
473 }
474
475 let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
476
477 match group {
478 Some(group_col) => {
479 let groups =
480 Self::get_unique_groups(&facet_data, group_col, sort_groups_by);
481
482 for group_val in groups.iter() {
483 let group_data =
484 Self::filter_data_by_group(&facet_data, group_col, group_val);
485
486 let global_idx =
487 global_group_indices.get(group_val).copied().unwrap_or(0);
488
489 let marker = Self::create_marker(
490 global_idx,
491 opacity,
492 size,
493 color,
494 colors.clone(),
495 shape,
496 shapes.clone(),
497 );
498
499 let show_legend = facet_idx == 0;
500
501 let trace = Self::build_scatter_trace_with_axes(
502 &group_data,
503 x,
504 y,
505 Some(group_val),
506 marker,
507 Some(&x_axis),
508 Some(&y_axis),
509 show_legend,
510 Some(group_val),
511 );
512
513 all_traces.push(trace);
514 }
515 }
516 None => {
517 let marker = Self::create_marker(
518 facet_idx,
519 opacity,
520 size,
521 color,
522 colors.clone(),
523 shape,
524 shapes.clone(),
525 );
526
527 let trace = Self::build_scatter_trace_with_axes(
528 &facet_data,
529 x,
530 y,
531 None,
532 marker,
533 Some(&x_axis),
534 Some(&y_axis),
535 false,
536 None,
537 );
538
539 all_traces.push(trace);
540 }
541 }
542 }
543 } else {
544 for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
545 let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
546
547 let x_axis = Self::get_axis_reference(facet_idx, "x");
548 let y_axis = Self::get_axis_reference(facet_idx, "y");
549
550 match group {
551 Some(group_col) => {
552 let groups =
553 Self::get_unique_groups(&facet_data, group_col, sort_groups_by);
554
555 for group_val in groups.iter() {
556 let group_data =
557 Self::filter_data_by_group(&facet_data, group_col, group_val);
558
559 let global_idx =
560 global_group_indices.get(group_val).copied().unwrap_or(0);
561
562 let marker = Self::create_marker(
563 global_idx,
564 opacity,
565 size,
566 color,
567 colors.clone(),
568 shape,
569 shapes.clone(),
570 );
571
572 let show_legend = facet_idx == 0;
573
574 let trace = Self::build_scatter_trace_with_axes(
575 &group_data,
576 x,
577 y,
578 Some(group_val),
579 marker,
580 Some(&x_axis),
581 Some(&y_axis),
582 show_legend,
583 Some(group_val),
584 );
585
586 all_traces.push(trace);
587 }
588 }
589 None => {
590 let marker = Self::create_marker(
591 facet_idx,
592 opacity,
593 size,
594 color,
595 colors.clone(),
596 shape,
597 shapes.clone(),
598 );
599
600 let trace = Self::build_scatter_trace_with_axes(
601 &facet_data,
602 x,
603 y,
604 None,
605 marker,
606 Some(&x_axis),
607 Some(&y_axis),
608 false,
609 None,
610 );
611
612 all_traces.push(trace);
613 }
614 }
615 }
616 }
617
618 all_traces
619 }
620
621 #[allow(clippy::too_many_arguments)]
622 fn create_faceted_layout(
623 data: &DataFrame,
624 facet_column: &str,
625 config: &FacetConfig,
626 plot_title: Option<Text>,
627 x_title: Option<Text>,
628 y_title: Option<Text>,
629 legend_title: Option<Text>,
630 x_axis: Option<&Axis>,
631 y_axis: Option<&Axis>,
632 legend: Option<&Legend>,
633 ) -> LayoutPlotly {
634 let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
635 let n_facets = facet_categories.len();
636
637 let (ncols, nrows) = Self::calculate_grid_dimensions(n_facets, config.cols, config.rows);
638
639 let mut grid = LayoutGrid::new()
640 .rows(nrows)
641 .columns(ncols)
642 .pattern(GridPattern::Independent);
643
644 if let Some(x_gap) = config.h_gap {
645 grid = grid.x_gap(x_gap);
646 }
647 if let Some(y_gap) = config.v_gap {
648 grid = grid.y_gap(y_gap);
649 }
650
651 let mut layout = LayoutPlotly::new().grid(grid);
652
653 if let Some(title) = plot_title {
654 layout = layout.title(title.to_plotly());
655 }
656
657 layout = Self::apply_axis_matching(layout, n_facets, &config.scales);
658
659 layout = Self::apply_facet_axis_titles(
660 layout, n_facets, ncols, nrows, x_title, y_title, x_axis, y_axis,
661 );
662
663 let annotations =
664 Self::create_facet_annotations(&facet_categories, config.title_style.as_ref());
665 layout = layout.annotations(annotations);
666
667 layout = layout.legend(Legend::set_legend(legend_title, legend));
668
669 layout
670 }
671
672 fn apply_axis_matching(
673 mut layout: LayoutPlotly,
674 n_facets: usize,
675 scales: &FacetScales,
676 ) -> LayoutPlotly {
677 use plotly::layout::Axis as AxisPlotly;
678
679 match scales {
680 FacetScales::Fixed => {
681 for i in 1..n_facets {
682 let x_axis = AxisPlotly::new().matches("x");
683 let y_axis = AxisPlotly::new().matches("y");
684 layout = match i {
685 1 => layout.x_axis2(x_axis).y_axis2(y_axis),
686 2 => layout.x_axis3(x_axis).y_axis3(y_axis),
687 3 => layout.x_axis4(x_axis).y_axis4(y_axis),
688 4 => layout.x_axis5(x_axis).y_axis5(y_axis),
689 5 => layout.x_axis6(x_axis).y_axis6(y_axis),
690 6 => layout.x_axis7(x_axis).y_axis7(y_axis),
691 7 => layout.x_axis8(x_axis).y_axis8(y_axis),
692 _ => layout,
693 };
694 }
695 }
696 FacetScales::FreeX => {
697 for i in 1..n_facets {
698 let axis = AxisPlotly::new().matches("y");
699 layout = match i {
700 1 => layout.y_axis2(axis),
701 2 => layout.y_axis3(axis),
702 3 => layout.y_axis4(axis),
703 4 => layout.y_axis5(axis),
704 5 => layout.y_axis6(axis),
705 6 => layout.y_axis7(axis),
706 7 => layout.y_axis8(axis),
707 _ => layout,
708 };
709 }
710 }
711 FacetScales::FreeY => {
712 for i in 1..n_facets {
713 let axis = AxisPlotly::new().matches("x");
714 layout = match i {
715 1 => layout.x_axis2(axis),
716 2 => layout.x_axis3(axis),
717 3 => layout.x_axis4(axis),
718 4 => layout.x_axis5(axis),
719 5 => layout.x_axis6(axis),
720 6 => layout.x_axis7(axis),
721 7 => layout.x_axis8(axis),
722 _ => layout,
723 };
724 }
725 }
726 FacetScales::Free => {}
727 }
728
729 layout
730 }
731
732 #[allow(clippy::too_many_arguments)]
733 fn apply_facet_axis_titles(
734 mut layout: LayoutPlotly,
735 n_facets: usize,
736 ncols: usize,
737 nrows: usize,
738 x_title: Option<Text>,
739 y_title: Option<Text>,
740 x_axis_config: Option<&Axis>,
741 y_axis_config: Option<&Axis>,
742 ) -> LayoutPlotly {
743 for i in 0..n_facets {
744 let is_bottom = Self::is_bottom_row(i, ncols, nrows, n_facets);
745 let is_left = Self::is_left_column(i, ncols);
746
747 let x_title_for_subplot = if is_bottom { x_title.clone() } else { None };
748 let y_title_for_subplot = if is_left { y_title.clone() } else { None };
749
750 if x_title_for_subplot.is_some() || x_axis_config.is_some() {
751 let axis = match x_axis_config {
752 Some(config) => Axis::set_axis(x_title_for_subplot, config, None),
753 None => {
754 if let Some(title) = x_title_for_subplot {
755 Axis::set_axis(Some(title), &Axis::default(), None)
756 } else {
757 continue;
758 }
759 }
760 };
761
762 layout = match i {
763 0 => layout.x_axis(axis),
764 1 => layout.x_axis2(axis),
765 2 => layout.x_axis3(axis),
766 3 => layout.x_axis4(axis),
767 4 => layout.x_axis5(axis),
768 5 => layout.x_axis6(axis),
769 6 => layout.x_axis7(axis),
770 7 => layout.x_axis8(axis),
771 _ => layout,
772 };
773 }
774
775 if y_title_for_subplot.is_some() || y_axis_config.is_some() {
776 let axis = match y_axis_config {
777 Some(config) => Axis::set_axis(y_title_for_subplot, config, None),
778 None => {
779 if let Some(title) = y_title_for_subplot {
780 Axis::set_axis(Some(title), &Axis::default(), None)
781 } else {
782 continue;
783 }
784 }
785 };
786
787 layout = match i {
788 0 => layout.y_axis(axis),
789 1 => layout.y_axis2(axis),
790 2 => layout.y_axis3(axis),
791 3 => layout.y_axis4(axis),
792 4 => layout.y_axis5(axis),
793 5 => layout.y_axis6(axis),
794 6 => layout.y_axis7(axis),
795 7 => layout.y_axis8(axis),
796 _ => layout,
797 };
798 }
799 }
800
801 layout
802 }
803}
804
805impl Layout for ScatterPlot {}
806impl Marker for ScatterPlot {}
807impl Polar for ScatterPlot {}
808
809impl PlotHelper for ScatterPlot {
810 fn get_layout(&self) -> &LayoutPlotly {
811 &self.layout
812 }
813
814 fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>> {
815 &self.traces
816 }
817}