1use bon::bon;
2
3use plotly::{
4 common::{Anchor, Domain},
5 layout::Annotation,
6 Layout as LayoutPlotly, Pie, Trace,
7};
8
9use polars::frame::DataFrame;
10use serde::Serialize;
11use std::collections::HashMap;
12
13use crate::{
14 common::{Layout, PlotHelper, Polar},
15 components::{FacetConfig, Legend, Rgb, Text},
16};
17
/// A pie chart built from a polars `DataFrame`, optionally faceted into a grid of pies.
///
/// Instances are created through the `bon` builder generated from [`PieChart::new`].
#[derive(Clone, Serialize)]
pub struct PieChart {
    /// Pie traces: a single trace when not faceted, one per non-empty facet group otherwise.
    traces: Vec<Box<dyn Trace + 'static>>,
    /// Plotly layout holding the title, the legend and, when faceted, the facet-title
    /// annotations.
    layout: LayoutPlotly,
}
76
77#[bon]
78impl PieChart {
79 #[builder(on(String, into), on(Text, into))]
80 pub fn new(
81 data: &DataFrame,
82 labels: &str,
83 facet: Option<&str>,
84 facet_config: Option<&FacetConfig>,
85 hole: Option<f64>,
86 pull: Option<f64>,
87 rotation: Option<f64>,
88 colors: Option<Vec<Rgb>>,
89 plot_title: Option<Text>,
90 legend_title: Option<Text>,
91 legend: Option<&Legend>,
92 ) -> Self {
93 let x_title = None;
94 let y_title = None;
95 let z_title = None;
96 let x_axis = None;
97 let y_axis = None;
98 let z_axis = None;
99 let y2_title = None;
100 let y2_axis = None;
101
102 let (layout, traces) = match facet {
103 Some(facet_column) => {
104 let config = facet_config.cloned().unwrap_or_default();
105
106 let layout = Self::create_faceted_layout(
107 data,
108 facet_column,
109 &config,
110 plot_title,
111 legend_title,
112 legend,
113 );
114
115 let traces = Self::create_faceted_traces(
116 data,
117 labels,
118 facet_column,
119 &config,
120 hole,
121 pull,
122 rotation,
123 colors,
124 );
125
126 (layout, traces)
127 }
128 None => {
129 let layout = Self::create_layout(
130 plot_title,
131 x_title,
132 y_title,
133 y2_title,
134 z_title,
135 legend_title,
136 x_axis,
137 y_axis,
138 y2_axis,
139 z_axis,
140 legend,
141 None,
142 );
143
144 let traces = Self::create_traces(data, labels, hole, pull, rotation, colors);
145
146 (layout, traces)
147 }
148 };
149
150 Self { traces, layout }
151 }
152
153 fn create_traces(
154 data: &DataFrame,
155 labels: &str,
156 hole: Option<f64>,
157 pull: Option<f64>,
158 rotation: Option<f64>,
159 colors: Option<Vec<Rgb>>,
160 ) -> Vec<Box<dyn Trace + 'static>> {
161 let mut traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
162
163 let color_map = if let Some(ref color_vec) = colors {
164 let label_values = Self::get_string_column(data, labels);
165 let unique_labels: Vec<String> = label_values
166 .iter()
167 .filter_map(|s| s.as_ref().map(|v| v.to_string()))
168 .collect::<std::collections::HashSet<_>>()
169 .into_iter()
170 .collect();
171
172 Some(Self::create_global_color_map(&unique_labels, color_vec))
173 } else {
174 None
175 };
176
177 let default_domain = Domain::new().x(&[0.0, 1.0]).y(&[0.0, 0.9]);
180
181 let trace = Self::create_trace(
182 data,
183 labels,
184 hole,
185 pull,
186 rotation,
187 Some(default_domain),
188 color_map,
189 );
190
191 traces.push(trace);
192 traces
193 }
194
195 #[allow(clippy::too_many_arguments)]
196 fn create_trace(
197 data: &DataFrame,
198 labels: &str,
199 hole: Option<f64>,
200 pull: Option<f64>,
201 rotation: Option<f64>,
202 domain: Option<Domain>,
203 color_map: Option<HashMap<String, String>>,
204 ) -> Box<dyn Trace + 'static> {
205 let labels = Self::get_string_column(data, labels)
206 .iter()
207 .filter_map(|s| {
208 if s.is_some() {
209 Some(s.clone().unwrap().to_owned())
210 } else {
211 None
212 }
213 })
214 .collect::<Vec<String>>();
215
216 let mut trace = Pie::<u32>::from_labels(&labels);
217
218 if let Some(hole) = hole {
219 trace = trace.hole(hole);
220 }
221
222 if let Some(pull) = pull {
223 trace = trace.pull(pull);
224 }
225
226 if let Some(rotation) = rotation {
227 trace = trace.rotation(rotation);
228 }
229
230 if let Some(domain_val) = domain {
231 trace = trace.domain(domain_val);
232 }
233
234 if let Some(color_mapping) = color_map {
235 let colors: Vec<String> = labels
236 .iter()
237 .map(|label| {
238 color_mapping
239 .get(label)
240 .cloned()
241 .unwrap_or_else(|| "#636EFA".to_string())
242 })
243 .collect();
244 trace = trace.marker(plotly::common::Marker::new().color_array(colors));
245 }
246
247 trace
248 }
249
250 #[allow(clippy::too_many_arguments)]
251 fn create_faceted_traces(
252 data: &DataFrame,
253 labels: &str,
254 facet_column: &str,
255 config: &FacetConfig,
256 hole: Option<f64>,
257 pull: Option<f64>,
258 rotation: Option<f64>,
259 colors: Option<Vec<Rgb>>,
260 ) -> Vec<Box<dyn Trace + 'static>> {
261 const MAX_FACETS: usize = 8;
262
263 let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
264
265 if facet_categories.len() > MAX_FACETS {
266 panic!(
267 "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
268 facet_column,
269 facet_categories.len(),
270 MAX_FACETS
271 );
272 }
273
274 let color_map = if let Some(ref color_vec) = colors {
275 let label_values = Self::get_string_column(data, labels);
276 let unique_labels: Vec<String> = label_values
277 .iter()
278 .filter_map(|s| s.as_ref().map(|v| v.to_string()))
279 .collect::<std::collections::HashSet<_>>()
280 .into_iter()
281 .collect();
282
283 Some(Self::create_global_color_map(&unique_labels, color_vec))
284 } else {
285 None
286 };
287
288 let n_facets = facet_categories.len();
289 let (ncols, nrows) = Self::calculate_grid_dimensions(n_facets, config.cols, config.rows);
290
291 let facet_categories_non_empty: Vec<String> = facet_categories
292 .iter()
293 .filter(|facet_value| {
294 let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
295 facet_data.height() > 0
296 })
297 .cloned()
298 .collect();
299
300 let mut all_traces = Vec::new();
301
302 for (idx, facet_value) in facet_categories_non_empty.iter().enumerate() {
303 let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
304
305 let domain = Self::calculate_pie_domain(idx, ncols, nrows, config.h_gap, config.v_gap);
306
307 let trace = Self::create_trace(
308 &facet_data,
309 labels,
310 hole,
311 pull,
312 rotation,
313 Some(domain),
314 color_map.clone(),
315 );
316
317 all_traces.push(trace);
318 }
319
320 all_traces
321 }
322
323 fn create_faceted_layout(
324 data: &DataFrame,
325 facet_column: &str,
326 config: &FacetConfig,
327 plot_title: Option<Text>,
328 legend_title: Option<Text>,
329 legend: Option<&Legend>,
330 ) -> LayoutPlotly {
331 let facet_categories = Self::get_unique_groups(data, facet_column, config.sorter);
332
333 let facet_categories_non_empty: Vec<String> = facet_categories
334 .iter()
335 .filter(|facet_value| {
336 let facet_data = Self::filter_data_by_group(data, facet_column, facet_value);
337 facet_data.height() > 0
338 })
339 .cloned()
340 .collect();
341
342 let n_facets = facet_categories_non_empty.len();
343 let (ncols, nrows) = Self::calculate_grid_dimensions(n_facets, config.cols, config.rows);
344
345 let mut layout = LayoutPlotly::new();
346
347 if let Some(title) = plot_title {
348 layout = layout.title(title.to_plotly());
349 }
350
351 let annotations = Self::create_facet_annotations_pie(
352 &facet_categories_non_empty,
353 ncols,
354 nrows,
355 config.title_style.as_ref(),
356 config.h_gap,
357 config.v_gap,
358 );
359 layout = layout.annotations(annotations);
360
361 layout = layout.legend(Legend::set_legend(legend_title, legend));
362
363 layout
364 }
365
366 fn calculate_facet_cell(
372 subplot_index: usize,
373 ncols: usize,
374 nrows: usize,
375 x_gap: Option<f64>,
376 y_gap: Option<f64>,
377 ) -> FacetCell {
378 let row = subplot_index / ncols;
379 let col = subplot_index % ncols;
380
381 let x_gap_val = x_gap.unwrap_or(0.05);
382 let y_gap_val = y_gap.unwrap_or(0.10);
383
384 const TITLE_HEIGHT_RATIO: f64 = 0.10;
386 const TITLE_PADDING_RATIO: f64 = 0.35;
388
389 let cell_width = (1.0 - x_gap_val * (ncols - 1) as f64) / ncols as f64;
391 let cell_height = (1.0 - y_gap_val * (nrows - 1) as f64) / nrows as f64;
392
393 let cell_x_start = col as f64 * (cell_width + x_gap_val);
395 let cell_y_top = 1.0 - row as f64 * (cell_height + y_gap_val);
396 let cell_y_bottom = cell_y_top - cell_height;
397
398 let title_height = cell_height * TITLE_HEIGHT_RATIO;
400 let pie_y_top = cell_y_top - title_height;
401
402 let pie_x_start = cell_x_start;
404 let pie_x_end = cell_x_start + cell_width;
405 let pie_y_start = cell_y_bottom;
406 let pie_y_end = pie_y_top;
407
408 let padding_height = title_height * TITLE_PADDING_RATIO;
411 let actual_title_height = title_height - padding_height;
412 let annotation_x = cell_x_start + cell_width / 2.0;
413 let annotation_y = pie_y_top + padding_height + (actual_title_height / 2.0);
414
415 FacetCell {
416 pie_x_start,
417 pie_x_end,
418 pie_y_start,
419 pie_y_end,
420 annotation_x,
421 annotation_y,
422 }
423 }
424
425 fn calculate_pie_domain(
426 subplot_index: usize,
427 ncols: usize,
428 nrows: usize,
429 x_gap: Option<f64>,
430 y_gap: Option<f64>,
431 ) -> Domain {
432 let cell = Self::calculate_facet_cell(subplot_index, ncols, nrows, x_gap, y_gap);
433 Domain::new()
434 .x(&[cell.pie_x_start, cell.pie_x_end])
435 .y(&[cell.pie_y_start, cell.pie_y_end])
436 }
437
438 fn create_facet_annotations_pie(
439 categories: &[String],
440 ncols: usize,
441 nrows: usize,
442 title_style: Option<&Text>,
443 x_gap: Option<f64>,
444 y_gap: Option<f64>,
445 ) -> Vec<Annotation> {
446 categories
447 .iter()
448 .enumerate()
449 .map(|(i, cat)| {
450 let cell = Self::calculate_facet_cell(i, ncols, nrows, x_gap, y_gap);
451
452 let mut ann = Annotation::new()
453 .text(cat.as_str())
454 .x_ref("paper")
455 .y_ref("paper")
456 .x_anchor(Anchor::Center)
457 .y_anchor(Anchor::Middle)
458 .x(cell.annotation_x)
459 .y(cell.annotation_y)
460 .show_arrow(false);
461
462 if let Some(style) = title_style {
463 ann = ann.font(style.to_font());
464 }
465
466 ann
467 })
468 .collect()
469 }
470
471 fn create_global_color_map(labels: &[String], colors: &[Rgb]) -> HashMap<String, String> {
472 labels
473 .iter()
474 .enumerate()
475 .map(|(i, label)| {
476 let color_idx = i % colors.len();
477 let rgb = &colors[color_idx];
478 let color_str = format!("rgb({},{},{})", rgb.0, rgb.1, rgb.2);
479 (label.clone(), color_str)
480 })
481 .collect()
482 }
483}
484
/// Geometry of one facet grid cell, in paper coordinates (0..1).
///
/// This is plain `f64` data, so it derives `Copy`, `PartialEq` and `Debug` for easy
/// inspection and comparison.
#[derive(Debug, Clone, Copy, PartialEq)]
struct FacetCell {
    /// Left edge of the pie domain.
    pie_x_start: f64,
    /// Right edge of the pie domain.
    pie_x_end: f64,
    /// Bottom edge of the pie domain.
    pie_y_start: f64,
    /// Top edge of the pie domain (below the cell's title band).
    pie_y_end: f64,
    /// Horizontal position of the facet-title annotation.
    annotation_x: f64,
    /// Vertical position of the facet-title annotation.
    annotation_y: f64,
}
494
// Empty impls: `PieChart` uses only the default methods these traits provide.
// NOTE(review): `Self::create_layout` (called from `PieChart::new`) presumably comes from
// `Layout`; which helpers `Polar` contributes is not visible here — confirm.
impl Layout for PieChart {}
impl Polar for PieChart {}
497
498impl PlotHelper for PieChart {
499 fn get_layout(&self) -> &LayoutPlotly {
500 &self.layout
501 }
502
503 fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>> {
504 &self.traces
505 }
506}