Skip to main content

plotlars_core/plots/
barplot.rs

1use bon::bon;
2
3use polars::frame::DataFrame;
4
5use crate::{
6    components::{
7        Axis, BarMode, FacetConfig, Legend, Orientation, Rgb, Text, DEFAULT_PLOTLY_COLORS,
8    },
9    ir::data::ColumnData,
10    ir::layout::LayoutIR,
11    ir::marker::MarkerIR,
12    ir::trace::{BarPlotIR, TraceIR},
13};
14
15/// A structure representing a bar plot.
16///
17/// The `BarPlot` struct allows for the creation and customization of bar plots with various options
18/// for data, layout, and aesthetics. It supports both vertical and horizontal orientations, grouping
19/// of data, error bars, and customizable markers and colors.
20///
21/// # Backend Support
22///
23/// | Backend | Supported |
24/// |---------|-----------|
25/// | Plotly  | Yes       |
26/// | Plotters| Yes       |
27///
28/// # Arguments
29///
30/// * `data` - A reference to the `DataFrame` containing the data to be plotted.
31/// * `labels` - A string slice specifying the column name to be used for the x-axis (independent variable).
32/// * `values` - A string slice specifying the column name to be used for the y-axis (dependent variable).
33/// * `orientation` - An optional `Orientation` enum specifying whether the plot should be horizontal or vertical.
34/// * `group` - An optional string slice specifying the column name to be used for grouping data points.
35/// * `sort_groups_by` - Optional comparator `fn(&str, &str) -> std::cmp::Ordering` to control group ordering.
36///   Groups are sorted lexically by default.
37/// * `facet` - An optional string slice specifying the column name to be used for faceting (creating multiple subplots).
38/// * `facet_config` - An optional reference to a `FacetConfig` struct for customizing facet behavior (grid dimensions, scales, gaps, etc.).
39/// * `error` - An optional string slice specifying the column name containing error values for the y-axis data.
40/// * `color` - An optional `Rgb` value specifying the color of the markers to be used for the plot. This is used when `group` is not specified.
41/// * `colors` - An optional vector of `Rgb` values specifying the colors to be used for the plot. This is used when `group` is specified to differentiate between groups.
42/// * `mode` - An optional `BarMode` enum specifying how bars are displayed (e.g., grouped, stacked, overlaid). Defaults to `BarMode::Group`.
43/// * `plot_title` - An optional `Text` struct specifying the title of the plot.
44/// * `x_title` - An optional `Text` struct specifying the title of the x-axis.
45/// * `y_title` - An optional `Text` struct specifying the title of the y-axis.
46/// * `legend_title` - An optional `Text` struct specifying the title of the legend.
47/// * `x_axis` - An optional reference to an `Axis` struct for customizing the x-axis.
48/// * `y_axis` - An optional reference to an `Axis` struct for customizing the y-axis.
49/// * `legend` - An optional reference to a `Legend` struct for customizing the legend of the plot (e.g., positioning, font, etc.).
50///
51/// # Example
52///
53/// ```rust
54/// use plotlars::{BarPlot, Legend, Orientation, Plot, Rgb, Text};
55/// use polars::prelude::*;
56///
57/// let dataset = LazyCsvReader::new(PlRefPath::new("data/animal_statistics.csv"))
58///     .finish()
59///     .unwrap()
60///     .collect()
61///     .unwrap();
62///
63/// BarPlot::builder()
64///     .data(&dataset)
65///     .labels("animal")
66///     .values("value")
67///     .orientation(Orientation::Vertical)
68///     .group("gender")
69///     .sort_groups_by(|a, b| a.len().cmp(&b.len()))
70///     .error("error")
71///     .colors(vec![
72///         Rgb(255, 127, 80),
73///         Rgb(64, 224, 208),
74///     ])
75///     .plot_title(
76///         Text::from("Bar Plot")
77///             .font("Arial")
78///             .size(18)
79///     )
80///     .x_title(
81///         Text::from("animal")
82///             .font("Arial")
83///             .size(15)
84///     )
85///     .y_title(
86///         Text::from("value")
87///             .font("Arial")
88///             .size(15)
89///     )
90///     .legend_title(
91///         Text::from("gender")
92///             .font("Arial")
93///             .size(15)
94///     )
95///     .legend(
96///         &Legend::new()
97///             .orientation(Orientation::Horizontal)
98///             .y(1.0)
99///             .x(0.43)
100///     )
101///     .build()
102///     .plot();
103/// ```
104///
105/// ![Example](https://imgur.com/HQQvQey.png)
106#[derive(Clone)]
107#[allow(dead_code)]
108pub struct BarPlot {
109    traces: Vec<TraceIR>,
110    layout: LayoutIR,
111}
112
113#[bon]
114impl BarPlot {
115    #[builder(on(String, into), on(Text, into))]
116    pub fn new(
117        data: &DataFrame,
118        labels: &str,
119        values: &str,
120        orientation: Option<Orientation>,
121        group: Option<&str>,
122        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
123        facet: Option<&str>,
124        facet_config: Option<&FacetConfig>,
125        error: Option<&str>,
126        color: Option<Rgb>,
127        colors: Option<Vec<Rgb>>,
128        mode: Option<BarMode>,
129        plot_title: Option<Text>,
130        x_title: Option<Text>,
131        y_title: Option<Text>,
132        legend_title: Option<Text>,
133        x_axis: Option<&Axis>,
134        y_axis: Option<&Axis>,
135        legend: Option<&Legend>,
136    ) -> Self {
137        let grid = facet.map(|facet_column| {
138            let config = facet_config.cloned().unwrap_or_default();
139            let facet_categories =
140                crate::data::get_unique_groups(data, facet_column, config.sorter);
141            let n_facets = facet_categories.len();
142            let (ncols, nrows) =
143                crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
144            crate::ir::facet::GridSpec {
145                kind: crate::ir::facet::FacetKind::Axis,
146                rows: nrows,
147                cols: ncols,
148                h_gap: config.h_gap,
149                v_gap: config.v_gap,
150                scales: config.scales.clone(),
151                n_facets,
152                facet_categories,
153                title_style: config.title_style.clone(),
154                x_title: x_title.clone(),
155                y_title: y_title.clone(),
156                x_axis: x_axis.cloned(),
157                y_axis: y_axis.cloned(),
158                legend_title: legend_title.clone(),
159                legend: legend.cloned(),
160            }
161        });
162
163        let layout = LayoutIR {
164            title: plot_title.clone(),
165            x_title: if grid.is_some() {
166                None
167            } else {
168                x_title.clone()
169            },
170            y_title: if grid.is_some() {
171                None
172            } else {
173                y_title.clone()
174            },
175            y2_title: None,
176            z_title: None,
177            legend_title: if grid.is_some() {
178                None
179            } else {
180                legend_title.clone()
181            },
182            legend: if grid.is_some() {
183                None
184            } else {
185                legend.cloned()
186            },
187            dimensions: None,
188            bar_mode: Some(mode.clone().unwrap_or(crate::components::BarMode::Group)),
189            box_mode: None,
190            box_gap: None,
191            margin_bottom: None,
192            axes_2d: if grid.is_some() {
193                None
194            } else {
195                Some(crate::ir::layout::Axes2dIR {
196                    x_axis: x_axis.cloned(),
197                    y_axis: y_axis.cloned(),
198                    y2_axis: None,
199                })
200            },
201            scene_3d: None,
202            polar: None,
203            mapbox: None,
204            grid,
205            annotations: vec![],
206        };
207
208        let traces = match facet {
209            Some(facet_column) => {
210                let config = facet_config.cloned().unwrap_or_default();
211                Self::create_ir_traces_faceted(
212                    data,
213                    labels,
214                    values,
215                    orientation.clone(),
216                    group,
217                    sort_groups_by,
218                    facet_column,
219                    &config,
220                    error,
221                    color,
222                    colors.clone(),
223                )
224            }
225            None => Self::create_ir_traces(
226                data,
227                labels,
228                values,
229                orientation,
230                group,
231                sort_groups_by,
232                error,
233                color,
234                colors,
235            ),
236        };
237
238        Self { traces, layout }
239    }
240}
241
242#[bon]
243impl BarPlot {
244    #[builder(
245        start_fn = try_builder,
246        finish_fn = try_build,
247        builder_type = BarPlotTryBuilder,
248        on(String, into),
249        on(Text, into),
250    )]
251    pub fn try_new(
252        data: &DataFrame,
253        labels: &str,
254        values: &str,
255        orientation: Option<Orientation>,
256        group: Option<&str>,
257        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
258        facet: Option<&str>,
259        facet_config: Option<&FacetConfig>,
260        error: Option<&str>,
261        color: Option<Rgb>,
262        colors: Option<Vec<Rgb>>,
263        mode: Option<BarMode>,
264        plot_title: Option<Text>,
265        x_title: Option<Text>,
266        y_title: Option<Text>,
267        legend_title: Option<Text>,
268        x_axis: Option<&Axis>,
269        y_axis: Option<&Axis>,
270        legend: Option<&Legend>,
271    ) -> Result<Self, crate::io::PlotlarsError> {
272        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
273            Self::__orig_new(
274                data,
275                labels,
276                values,
277                orientation,
278                group,
279                sort_groups_by,
280                facet,
281                facet_config,
282                error,
283                color,
284                colors,
285                mode,
286                plot_title,
287                x_title,
288                y_title,
289                legend_title,
290                x_axis,
291                y_axis,
292                legend,
293            )
294        }))
295        .map_err(|panic| {
296            let msg = panic
297                .downcast_ref::<String>()
298                .cloned()
299                .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
300                .unwrap_or_else(|| "unknown error".to_string());
301            crate::io::PlotlarsError::PlotBuild { message: msg }
302        })
303    }
304}
305
306impl BarPlot {
307    #[allow(clippy::too_many_arguments)]
308    fn create_ir_traces(
309        data: &DataFrame,
310        labels: &str,
311        values: &str,
312        orientation: Option<Orientation>,
313        group: Option<&str>,
314        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
315        error: Option<&str>,
316        color: Option<Rgb>,
317        colors: Option<Vec<Rgb>>,
318    ) -> Vec<TraceIR> {
319        let mut traces = Vec::new();
320
321        match group {
322            Some(group_col) => {
323                let groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
324
325                for (i, group_name) in groups.iter().enumerate() {
326                    let subset = crate::data::filter_data_by_group(data, group_col, group_name);
327
328                    let marker_ir = MarkerIR {
329                        opacity: None,
330                        size: None,
331                        color: Self::resolve_color(i, color, colors.clone()),
332                        shape: None,
333                    };
334
335                    let error_data = error
336                        .map(|e| ColumnData::Numeric(crate::data::get_numeric_column(&subset, e)));
337
338                    traces.push(TraceIR::BarPlot(BarPlotIR {
339                        labels: ColumnData::String(crate::data::get_string_column(&subset, labels)),
340                        values: ColumnData::Numeric(crate::data::get_numeric_column(
341                            &subset, values,
342                        )),
343                        name: Some(group_name.to_string()),
344                        orientation: orientation.clone(),
345                        marker: Some(marker_ir),
346                        error: error_data,
347                        show_legend: None,
348                        legend_group: None,
349                        subplot_ref: None,
350                    }));
351                }
352            }
353            None => {
354                let marker_ir = MarkerIR {
355                    opacity: None,
356                    size: None,
357                    color: Self::resolve_color(0, color, colors),
358                    shape: None,
359                };
360
361                let error_data =
362                    error.map(|e| ColumnData::Numeric(crate::data::get_numeric_column(data, e)));
363
364                traces.push(TraceIR::BarPlot(BarPlotIR {
365                    labels: ColumnData::String(crate::data::get_string_column(data, labels)),
366                    values: ColumnData::Numeric(crate::data::get_numeric_column(data, values)),
367                    name: None,
368                    orientation: orientation.clone(),
369                    marker: Some(marker_ir),
370                    error: error_data,
371                    show_legend: None,
372                    legend_group: None,
373                    subplot_ref: None,
374                }));
375            }
376        }
377
378        traces
379    }
380
381    #[allow(clippy::too_many_arguments)]
382    fn create_ir_traces_faceted(
383        data: &DataFrame,
384        labels: &str,
385        values: &str,
386        orientation: Option<Orientation>,
387        group: Option<&str>,
388        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
389        facet_column: &str,
390        config: &FacetConfig,
391        error: Option<&str>,
392        color: Option<Rgb>,
393        colors: Option<Vec<Rgb>>,
394    ) -> Vec<TraceIR> {
395        const MAX_FACETS: usize = 8;
396
397        let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
398
399        if facet_categories.len() > MAX_FACETS {
400            panic!(
401                "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
402                facet_column,
403                facet_categories.len(),
404                MAX_FACETS
405            );
406        }
407
408        if let Some(ref color_vec) = colors {
409            if group.is_none() {
410                let color_count = color_vec.len();
411                let facet_count = facet_categories.len();
412                if color_count != facet_count {
413                    panic!(
414                        "When using colors with facet (without group), colors.len() must equal number of facets. \
415                         Expected {} colors for {} facets, but got {} colors. \
416                         Each facet must be assigned exactly one color.",
417                        facet_count, facet_count, color_count
418                    );
419                }
420            } else if let Some(group_col) = group {
421                let groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
422                let color_count = color_vec.len();
423                let group_count = groups.len();
424                if color_count < group_count {
425                    panic!(
426                        "When using colors with group, colors.len() must be >= number of groups. \
427                         Need at least {} colors for {} groups, but got {} colors",
428                        group_count, group_count, color_count
429                    );
430                }
431            }
432        }
433
434        let global_group_indices: std::collections::HashMap<String, usize> =
435            if let Some(group_col) = group {
436                let global_groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
437                global_groups
438                    .into_iter()
439                    .enumerate()
440                    .map(|(idx, group_name)| (group_name, idx))
441                    .collect()
442            } else {
443                std::collections::HashMap::new()
444            };
445
446        let colors = if group.is_some() && colors.is_none() {
447            Some(DEFAULT_PLOTLY_COLORS.to_vec())
448        } else {
449            colors
450        };
451
452        let mut traces = Vec::new();
453
454        for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
455            let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
456
457            let subplot_ref = format!(
458                "{}{}",
459                crate::faceting::get_axis_reference(facet_idx, "x"),
460                crate::faceting::get_axis_reference(facet_idx, "y")
461            );
462
463            match group {
464                Some(group_col) => {
465                    let groups =
466                        crate::data::get_unique_groups(&facet_data, group_col, sort_groups_by);
467
468                    for group_val in groups.iter() {
469                        let group_data =
470                            crate::data::filter_data_by_group(&facet_data, group_col, group_val);
471
472                        let global_idx = global_group_indices.get(group_val).copied().unwrap_or(0);
473
474                        let marker_ir = MarkerIR {
475                            opacity: None,
476                            size: None,
477                            color: Self::resolve_color(global_idx, color, colors.clone()),
478                            shape: None,
479                        };
480
481                        let error_data = error.map(|e| {
482                            ColumnData::Numeric(crate::data::get_numeric_column(&group_data, e))
483                        });
484
485                        traces.push(TraceIR::BarPlot(BarPlotIR {
486                            labels: ColumnData::String(crate::data::get_string_column(
487                                &group_data,
488                                labels,
489                            )),
490                            values: ColumnData::Numeric(crate::data::get_numeric_column(
491                                &group_data,
492                                values,
493                            )),
494                            name: Some(group_val.to_string()),
495                            orientation: orientation.clone(),
496                            marker: Some(marker_ir),
497                            error: error_data,
498                            show_legend: Some(facet_idx == 0),
499                            legend_group: Some(group_val.to_string()),
500                            subplot_ref: Some(subplot_ref.clone()),
501                        }));
502                    }
503                }
504                None => {
505                    let marker_ir = MarkerIR {
506                        opacity: None,
507                        size: None,
508                        color: Self::resolve_color(facet_idx, color, colors.clone()),
509                        shape: None,
510                    };
511
512                    let error_data = error.map(|e| {
513                        ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, e))
514                    });
515
516                    traces.push(TraceIR::BarPlot(BarPlotIR {
517                        labels: ColumnData::String(crate::data::get_string_column(
518                            &facet_data,
519                            labels,
520                        )),
521                        values: ColumnData::Numeric(crate::data::get_numeric_column(
522                            &facet_data,
523                            values,
524                        )),
525                        name: None,
526                        orientation: orientation.clone(),
527                        marker: Some(marker_ir),
528                        error: error_data,
529                        show_legend: Some(false),
530                        legend_group: None,
531                        subplot_ref: Some(subplot_ref.clone()),
532                    }));
533                }
534            }
535        }
536
537        traces
538    }
539
540    fn resolve_color(index: usize, color: Option<Rgb>, colors: Option<Vec<Rgb>>) -> Option<Rgb> {
541        if let Some(c) = color {
542            return Some(c);
543        }
544        if let Some(ref cs) = colors {
545            return cs.get(index).copied();
546        }
547        None
548    }
549}
550
551impl crate::Plot for BarPlot {
552    fn ir_traces(&self) -> &[TraceIR] {
553        &self.traces
554    }
555
556    fn ir_layout(&self) -> &LayoutIR {
557        &self.layout
558    }
559}
560
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    /// Asserts that a colour was resolved and that its channels match `(r, g, b)`.
    fn assert_rgb(actual: Option<Rgb>, r: u8, g: u8, b: u8) {
        let Rgb(red, green, blue) = actual.expect("expected Some(Rgb)");
        assert_eq!((red, green, blue), (r, g, b));
    }

    #[test]
    fn test_resolve_color_singular_priority() {
        let single = Some(Rgb(255, 0, 0));
        let palette = Some(vec![Rgb(0, 0, 255)]);
        assert_rgb(BarPlot::resolve_color(0, single, palette), 255, 0, 0);
    }

    #[test]
    fn test_resolve_color_both_none() {
        assert!(BarPlot::resolve_color(0, None, None).is_none());
    }

    #[test]
    fn test_no_group_one_trace() {
        let frame = df!["labels" => ["a", "b", "c"], "values" => [1.0, 2.0, 3.0]].unwrap();
        let plot = BarPlot::builder()
            .data(&frame)
            .labels("labels")
            .values("values")
            .build();
        let trace_count = plot.ir_traces().len();
        assert_eq!(trace_count, 1);
    }

    #[test]
    fn test_with_group() {
        let frame = df![
            "labels" => ["a", "b", "a", "b"],
            "values" => [1.0, 2.0, 3.0, 4.0],
            "g" => ["x", "x", "y", "y"]
        ]
        .unwrap();
        let plot = BarPlot::builder()
            .data(&frame)
            .labels("labels")
            .values("values")
            .group("g")
            .build();
        let trace_count = plot.ir_traces().len();
        assert_eq!(trace_count, 2);
    }

    #[test]
    fn test_faceted_trace_count() {
        let frame = df![
            "labels" => ["a", "b", "c", "a", "b", "c"],
            "values" => [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            "f" => ["f1", "f2", "f1", "f2", "f1", "f2"]
        ]
        .unwrap();
        let plot = BarPlot::builder()
            .data(&frame)
            .labels("labels")
            .values("values")
            .facet("f")
            .build();
        let trace_count = plot.ir_traces().len();
        assert_eq!(trace_count, 2);
    }

    #[test]
    #[should_panic(expected = "maximum")]
    fn test_max_facets_panics() {
        // Nine distinct facet values, one more than the 8 subplots allowed.
        let facet_values: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i"];
        let row_count = facet_values.len();
        let labels: Vec<&str> = vec!["label"; row_count];
        let values: Vec<f64> = (0..row_count).map(|i| i as f64).collect();
        let columns = vec![
            Column::new("labels".into(), &labels),
            Column::new("values".into(), &values),
            Column::new("f".into(), &facet_values),
        ];
        let frame = DataFrame::new(row_count, columns).unwrap();
        BarPlot::builder()
            .data(&frame)
            .labels("labels")
            .values("values")
            .facet("f")
            .build();
    }
}