// plotlars/plots/piechart.rs

1use bon::bon;
2
3use plotly::{
4    common::{Anchor, Domain},
5    layout::Annotation,
6    Layout as LayoutPlotly, Pie, Trace,
7};
8
9use polars::frame::DataFrame;
10use serde::Serialize;
11use std::collections::HashMap;
12
13use crate::{
14    common::{Layout, PlotHelper, Polar},
15    components::{FacetConfig, Legend, Rgb, Text},
16};
17
18/// A structure representing a pie chart.
19///
20/// The `PieChart` struct allows for the creation and customization of pie charts, supporting
21/// features such as labels, hole size for donut-style charts, slice pulling, rotation, faceting, and customizable plot titles.
22/// It is ideal for visualizing proportions and distributions in categorical data.
23///
24/// # Arguments
25///
26/// * `data` - A reference to the `DataFrame` containing the data to be plotted.
27/// * `labels` - A string slice specifying the column name to be used for slice labels.
28/// * `facet` - An optional string slice specifying the column name to be used for creating facets (small multiples).
29/// * `facet_config` - An optional reference to a `FacetConfig` struct for customizing facet layout and behavior.
30/// * `hole` - An optional `f64` value specifying the size of the hole in the center of the pie chart.
31///   A value of `0.0` creates a full pie chart, while a value closer to `1.0` creates a thinner ring.
32/// * `pull` - An optional `f64` value specifying the fraction by which each slice should be pulled out from the center.
33/// * `rotation` - An optional `f64` value specifying the starting angle (in degrees) of the first slice.
34/// * `colors` - An optional vector of `Rgb` values specifying colors for consistent slice colors across facets.
35/// * `plot_title` - An optional `Text` struct specifying the title of the plot.
36/// * `legend_title` - An optional `Text` struct specifying the title of the legend.
37/// * `legend` - An optional reference to a `Legend` struct for customizing the legend of the plot (e.g., positioning, font, etc.).
38///
39/// # Example
40///
41/// ## Basic Pie Chart with Customization
42///
43/// ```rust
44/// use plotlars::{PieChart, Plot, Text};
45/// use polars::prelude::*;
46///
47/// let dataset = LazyCsvReader::new(PlPath::new("data/penguins.csv"))
48///     .finish()
49///     .unwrap()
50///     .select([col("species")])
51///     .collect()
52///     .unwrap();
53///
54/// PieChart::builder()
55///     .data(&dataset)
56///     .labels("species")
57///     .hole(0.4)
58///     .pull(0.01)
59///     .rotation(20.0)
60///     .plot_title(
61///         Text::from("Pie Chart")
62///             .font("Arial")
63///             .size(18)
64///             .x(0.485)
65///     )
66///     .build()
67///     .plot();
68/// ```
69///
70/// ![Example](https://imgur.com/q44HDwT.png)
71#[derive(Clone, Serialize)]
72pub struct PieChart {
73    traces: Vec<Box<dyn Trace + 'static>>,
74    layout: LayoutPlotly,
75}
76
77#[bon]
78impl PieChart {
79    #[builder(on(String, into), on(Text, into))]
80    pub fn new(
81        data: &DataFrame,
82        labels: &str,
83        facet: Option<&str>,
84        facet_config: Option<&FacetConfig>,
85        hole: Option<f64>,
86        pull: Option<f64>,
87        rotation: Option<f64>,
88        colors: Option<Vec<Rgb>>,
89        plot_title: Option<Text>,
90        legend_title: Option<Text>,
91        legend: Option<&Legend>,
92    ) -> Self {
93        let x_title = None;
94        let y_title = None;
95        let z_title = None;
96        let x_axis = None;
97        let y_axis = None;
98        let z_axis = None;
99        let y2_title = None;
100        let y2_axis = None;
101
102        let (layout, traces) = match facet {
103            Some(facet_column) => {
104                let config = facet_config.cloned().unwrap_or_default();
105
106                let layout = Self::create_faceted_layout(
107                    data,
108                    facet_column,
109                    &config,
110                    plot_title,
111                    legend_title,
112                    legend,
113                );
114
115                let traces = Self::create_faceted_traces(
116                    data,
117                    labels,
118                    facet_column,
119                    &config,
120                    hole,
121                    pull,
122                    rotation,
123                    colors,
124                );
125
126                (layout, traces)
127            }
128            None => {
129                let layout = Self::create_layout(
130                    plot_title,
131                    x_title,
132                    y_title,
133                    y2_title,
134                    z_title,
135                    legend_title,
136                    x_axis,
137                    y_axis,
138                    y2_axis,
139                    z_axis,
140                    legend,
141                    None,
142                );
143
144                let traces = Self::create_traces(data, labels, hole, pull, rotation, colors);
145
146                (layout, traces)
147            }
148        };
149
150        Self { traces, layout }
151    }
152
153    fn create_traces(
154        data: &DataFrame,
155        labels: &str,
156        hole: Option<f64>,
157        pull: Option<f64>,
158        rotation: Option<f64>,
159        colors: Option<Vec<Rgb>>,
160    ) -> Vec<Box<dyn Trace + 'static>> {
161        let mut traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
162
163        let color_map = if let Some(ref color_vec) = colors {
164            let label_values = Self::get_string_column(data, labels);
165            let unique_labels: Vec<String> = label_values
166                .iter()
167                .filter_map(|s| s.as_ref().map(|v| v.to_string()))
168                .collect::<std::collections::HashSet<_>>()
169                .into_iter()
170                .collect();
171
172            Some(Self::create_global_color_map(&unique_labels, color_vec))
173        } else {
174            None
175        };
176
177        // Create default domain that reserves 10% space at top for title
178        // This matches the default title y-position of 0.9, creating visual separation
179        let default_domain = Domain::new().x(&[0.0, 1.0]).y(&[0.0, 0.9]);
180
181        let trace = Self::create_trace(
182            data,
183            labels,
184            hole,
185            pull,
186            rotation,
187            Some(default_domain),
188            color_map,
189        );
190
191        traces.push(trace);
192        traces
193    }
194
195    #[allow(clippy::too_many_arguments)]
196    fn create_trace(
197        data: &DataFrame,
198        labels: &str,
199        hole: Option<f64>,
200        pull: Option<f64>,
201        rotation: Option<f64>,
202        domain: Option<Domain>,
203        color_map: Option<HashMap<String, String>>,
204    ) -> Box<dyn Trace + 'static> {
205        let labels = Self::get_string_column(data, labels)
206            .iter()
207            .filter_map(|s| {
208                if s.is_some() {
209                    Some(s.clone().unwrap().to_owned())
210                } else {
211                    None
212                }
213            })
214            .collect::<Vec<String>>();
215
216        let mut trace = Pie::<u32>::from_labels(&labels);
217
218        if let Some(hole) = hole {
219            trace = trace.hole(hole);
220        }
221
222        if let Some(pull) = pull {
223            trace = trace.pull(pull);
224        }
225
226        if let Some(rotation) = rotation {
227            trace = trace.rotation(rotation);
228        }
229
230        if let Some(domain_val) = domain {
231            trace = trace.domain(domain_val);
232        }
233
234        if let Some(color_mapping) = color_map {
235            let colors: Vec<String> = labels
236                .iter()
237                .map(|label| {
238                    color_mapping
239                        .get(label)
240                        .cloned()
241                        .unwrap_or_else(|| "#636EFA".to_string())
242                })
243                .collect();
244            trace = trace.marker(plotly::common::Marker::new().color_array(colors));
245        }
246
247        trace
248    }
249
250    #[allow(clippy::too_many_arguments)]
251    fn create_faceted_traces(
252        data: &DataFrame,
253        labels: &str,
254        facet_column: &str,
255        config: &FacetConfig,
256        hole: Option<f64>,
257        pull: Option<f64>,
258        rotation: Option<f64>,
259        colors: Option<Vec<Rgb>>,
260    ) -> Vec<Box<dyn Trace + 'static>> {
261        const MAX_FACETS: usize = 8;
262
263        let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
264
265        if facet_categories.len() > MAX_FACETS {
266            panic!(
267                "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
268                facet_column,
269                facet_categories.len(),
270                MAX_FACETS
271            );
272        }
273
274        let color_map = if let Some(ref color_vec) = colors {
275            let label_values = Self::get_string_column(data, labels);
276            let unique_labels: Vec<String> = label_values
277                .iter()
278                .filter_map(|s| s.as_ref().map(|v| v.to_string()))
279                .collect::<std::collections::HashSet<_>>()
280                .into_iter()
281                .collect();
282
283            Some(Self::create_global_color_map(&unique_labels, color_vec))
284        } else {
285            None
286        };
287
288        let n_facets = facet_categories.len();
289        let (ncols, nrows) = Self::calculate_grid_dimensions(n_facets, config.cols, config.rows);
290
291        let facet_categories_non_empty: Vec<String> = facet_categories
292            .iter()
293            .filter(|facet_value| {
294                let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
295                facet_data.height() > 0
296            })
297            .cloned()
298            .collect();
299
300        let mut all_traces = Vec::new();
301
302        for (idx, facet_value) in facet_categories_non_empty.iter().enumerate() {
303            let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
304
305            let domain = Self::calculate_pie_domain(idx, ncols, nrows, config.h_gap, config.v_gap);
306
307            let trace = Self::create_trace(
308                &facet_data,
309                labels,
310                hole,
311                pull,
312                rotation,
313                Some(domain),
314                color_map.clone(),
315            );
316
317            all_traces.push(trace);
318        }
319
320        all_traces
321    }
322
323    fn create_faceted_layout(
324        data: &DataFrame,
325        facet_column: &str,
326        config: &FacetConfig,
327        plot_title: Option<Text>,
328        legend_title: Option<Text>,
329        legend: Option<&Legend>,
330    ) -> LayoutPlotly {
331        let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
332
333        let facet_categories_non_empty: Vec<String> = facet_categories
334            .iter()
335            .filter(|facet_value| {
336                let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
337                facet_data.height() > 0
338            })
339            .cloned()
340            .collect();
341
342        let n_facets = facet_categories_non_empty.len();
343        let (ncols, nrows) = Self::calculate_grid_dimensions(n_facets, config.cols, config.rows);
344
345        let mut layout = LayoutPlotly::new();
346
347        if let Some(title) = plot_title {
348            layout = layout.title(title.to_plotly());
349        }
350
351        let annotations = Self::create_facet_annotations_pie(
352            &facet_categories_non_empty,
353            ncols,
354            nrows,
355            config.title_style.as_ref(),
356            config.h_gap,
357            config.v_gap,
358        );
359        layout = layout.annotations(annotations);
360
361        layout = layout.legend(Legend::set_legend(legend_title, legend));
362
363        layout
364    }
365
366    /// Calculates the grid cell positions for a subplot with reserved space for titles.
367    ///
368    /// This function computes both the pie chart domain and annotation position,
369    /// ensuring that space is reserved above each pie chart for the facet title.
370    /// The title space prevents overlap between annotations and adjacent pie charts.
371    fn calculate_facet_cell(
372        subplot_index: usize,
373        ncols: usize,
374        nrows: usize,
375        x_gap: Option<f64>,
376        y_gap: Option<f64>,
377    ) -> FacetCell {
378        let row = subplot_index / ncols;
379        let col = subplot_index % ncols;
380
381        let x_gap_val = x_gap.unwrap_or(0.05);
382        let y_gap_val = y_gap.unwrap_or(0.10);
383
384        // Reserve space for facet title (10% of each cell's height)
385        const TITLE_HEIGHT_RATIO: f64 = 0.10;
386        // Padding ratio creates buffer space above annotation (35% of reserved title space)
387        const TITLE_PADDING_RATIO: f64 = 0.35;
388
389        // Calculate total cell dimensions
390        let cell_width = (1.0 - x_gap_val * (ncols - 1) as f64) / ncols as f64;
391        let cell_height = (1.0 - y_gap_val * (nrows - 1) as f64) / nrows as f64;
392
393        // Calculate cell boundaries
394        let cell_x_start = col as f64 * (cell_width + x_gap_val);
395        let cell_y_top = 1.0 - row as f64 * (cell_height + y_gap_val);
396        let cell_y_bottom = cell_y_top - cell_height;
397
398        // Reserve title space at the top of the cell (maintains 90% pie size)
399        let title_height = cell_height * TITLE_HEIGHT_RATIO;
400        let pie_y_top = cell_y_top - title_height;
401
402        // Pie chart domain (bottom 90% of the cell - preserved from original)
403        let pie_x_start = cell_x_start;
404        let pie_x_end = cell_x_start + cell_width;
405        let pie_y_start = cell_y_bottom;
406        let pie_y_end = pie_y_top;
407
408        // Calculate annotation position with padding buffer
409        // Padding creates visual space above annotation without reducing pie size
410        let padding_height = title_height * TITLE_PADDING_RATIO;
411        let actual_title_height = title_height - padding_height;
412        let annotation_x = cell_x_start + cell_width / 2.0;
413        let annotation_y = pie_y_top + padding_height + (actual_title_height / 2.0);
414
415        FacetCell {
416            pie_x_start,
417            pie_x_end,
418            pie_y_start,
419            pie_y_end,
420            annotation_x,
421            annotation_y,
422        }
423    }
424
425    fn calculate_pie_domain(
426        subplot_index: usize,
427        ncols: usize,
428        nrows: usize,
429        x_gap: Option<f64>,
430        y_gap: Option<f64>,
431    ) -> Domain {
432        let cell = Self::calculate_facet_cell(subplot_index, ncols, nrows, x_gap, y_gap);
433        Domain::new()
434            .x(&[cell.pie_x_start, cell.pie_x_end])
435            .y(&[cell.pie_y_start, cell.pie_y_end])
436    }
437
438    fn create_facet_annotations_pie(
439        categories: &[String],
440        ncols: usize,
441        nrows: usize,
442        title_style: Option<&Text>,
443        x_gap: Option<f64>,
444        y_gap: Option<f64>,
445    ) -> Vec<Annotation> {
446        categories
447            .iter()
448            .enumerate()
449            .map(|(i, cat)| {
450                let cell = Self::calculate_facet_cell(i, ncols, nrows, x_gap, y_gap);
451
452                let mut ann = Annotation::new()
453                    .text(cat.as_str())
454                    .x_ref("paper")
455                    .y_ref("paper")
456                    .x_anchor(Anchor::Center)
457                    .y_anchor(Anchor::Middle)
458                    .x(cell.annotation_x)
459                    .y(cell.annotation_y)
460                    .show_arrow(false);
461
462                if let Some(style) = title_style {
463                    ann = ann.font(style.to_font());
464                }
465
466                ann
467            })
468            .collect()
469    }
470
471    fn create_global_color_map(labels: &[String], colors: &[Rgb]) -> HashMap<String, String> {
472        labels
473            .iter()
474            .enumerate()
475            .map(|(i, label)| {
476                let color_idx = i % colors.len();
477                let rgb = &colors[color_idx];
478                let color_str = format!("rgb({},{},{})", rgb.0, rgb.1, rgb.2);
479                (label.clone(), color_str)
480            })
481            .collect()
482    }
483}
484
/// Geometry of one facet cell in paper coordinates (0..1 on both axes).
///
/// Holds the pie's domain and the centre point of the facet title annotation above it.
#[derive(Debug, Clone, Copy, PartialEq)]
struct FacetCell {
    /// Left edge of the pie domain.
    pie_x_start: f64,
    /// Right edge of the pie domain.
    pie_x_end: f64,
    /// Bottom edge of the pie domain.
    pie_y_start: f64,
    /// Top edge of the pie domain (just below the title band).
    pie_y_end: f64,
    /// Horizontal centre of the facet title annotation.
    annotation_x: f64,
    /// Vertical centre of the facet title annotation.
    annotation_y: f64,
}
494
495impl Layout for PieChart {}
496impl Polar for PieChart {}
497
498impl PlotHelper for PieChart {
499    fn get_layout(&self) -> &LayoutPlotly {
500        &self.layout
501    }
502
503    fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>> {
504        &self.traces
505    }
506}