1use bon::bon;
2
3use polars::frame::DataFrame;
4
5use crate::{
6 components::{
7 Axis, BarMode, FacetConfig, Legend, Orientation, Rgb, Text, DEFAULT_PLOTLY_COLORS,
8 },
9 ir::data::ColumnData,
10 ir::layout::LayoutIR,
11 ir::marker::MarkerIR,
12 ir::trace::{BarPlotIR, TraceIR},
13};
14
15#[derive(Clone)]
107#[allow(dead_code)]
108pub struct BarPlot {
109 traces: Vec<TraceIR>,
110 layout: LayoutIR,
111}
112
113#[bon]
114impl BarPlot {
115 #[builder(on(String, into), on(Text, into))]
116 pub fn new(
117 data: &DataFrame,
118 labels: &str,
119 values: &str,
120 orientation: Option<Orientation>,
121 group: Option<&str>,
122 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
123 facet: Option<&str>,
124 facet_config: Option<&FacetConfig>,
125 error: Option<&str>,
126 color: Option<Rgb>,
127 colors: Option<Vec<Rgb>>,
128 mode: Option<BarMode>,
129 plot_title: Option<Text>,
130 x_title: Option<Text>,
131 y_title: Option<Text>,
132 legend_title: Option<Text>,
133 x_axis: Option<&Axis>,
134 y_axis: Option<&Axis>,
135 legend: Option<&Legend>,
136 ) -> Self {
137 let grid = facet.map(|facet_column| {
138 let config = facet_config.cloned().unwrap_or_default();
139 let facet_categories =
140 crate::data::get_unique_groups(data, facet_column, config.sorter);
141 let n_facets = facet_categories.len();
142 let (ncols, nrows) =
143 crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
144 crate::ir::facet::GridSpec {
145 kind: crate::ir::facet::FacetKind::Axis,
146 rows: nrows,
147 cols: ncols,
148 h_gap: config.h_gap,
149 v_gap: config.v_gap,
150 scales: config.scales.clone(),
151 n_facets,
152 facet_categories,
153 title_style: config.title_style.clone(),
154 x_title: x_title.clone(),
155 y_title: y_title.clone(),
156 x_axis: x_axis.cloned(),
157 y_axis: y_axis.cloned(),
158 legend_title: legend_title.clone(),
159 legend: legend.cloned(),
160 }
161 });
162
163 let layout = LayoutIR {
164 title: plot_title.clone(),
165 x_title: if grid.is_some() {
166 None
167 } else {
168 x_title.clone()
169 },
170 y_title: if grid.is_some() {
171 None
172 } else {
173 y_title.clone()
174 },
175 y2_title: None,
176 z_title: None,
177 legend_title: if grid.is_some() {
178 None
179 } else {
180 legend_title.clone()
181 },
182 legend: if grid.is_some() {
183 None
184 } else {
185 legend.cloned()
186 },
187 dimensions: None,
188 bar_mode: Some(mode.clone().unwrap_or(crate::components::BarMode::Group)),
189 box_mode: None,
190 box_gap: None,
191 margin_bottom: None,
192 axes_2d: if grid.is_some() {
193 None
194 } else {
195 Some(crate::ir::layout::Axes2dIR {
196 x_axis: x_axis.cloned(),
197 y_axis: y_axis.cloned(),
198 y2_axis: None,
199 })
200 },
201 scene_3d: None,
202 polar: None,
203 mapbox: None,
204 grid,
205 annotations: vec![],
206 };
207
208 let traces = match facet {
209 Some(facet_column) => {
210 let config = facet_config.cloned().unwrap_or_default();
211 Self::create_ir_traces_faceted(
212 data,
213 labels,
214 values,
215 orientation.clone(),
216 group,
217 sort_groups_by,
218 facet_column,
219 &config,
220 error,
221 color,
222 colors.clone(),
223 )
224 }
225 None => Self::create_ir_traces(
226 data,
227 labels,
228 values,
229 orientation,
230 group,
231 sort_groups_by,
232 error,
233 color,
234 colors,
235 ),
236 };
237
238 Self { traces, layout }
239 }
240}
241
242#[bon]
243impl BarPlot {
244 #[builder(
245 start_fn = try_builder,
246 finish_fn = try_build,
247 builder_type = BarPlotTryBuilder,
248 on(String, into),
249 on(Text, into),
250 )]
251 pub fn try_new(
252 data: &DataFrame,
253 labels: &str,
254 values: &str,
255 orientation: Option<Orientation>,
256 group: Option<&str>,
257 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
258 facet: Option<&str>,
259 facet_config: Option<&FacetConfig>,
260 error: Option<&str>,
261 color: Option<Rgb>,
262 colors: Option<Vec<Rgb>>,
263 mode: Option<BarMode>,
264 plot_title: Option<Text>,
265 x_title: Option<Text>,
266 y_title: Option<Text>,
267 legend_title: Option<Text>,
268 x_axis: Option<&Axis>,
269 y_axis: Option<&Axis>,
270 legend: Option<&Legend>,
271 ) -> Result<Self, crate::io::PlotlarsError> {
272 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
273 Self::__orig_new(
274 data,
275 labels,
276 values,
277 orientation,
278 group,
279 sort_groups_by,
280 facet,
281 facet_config,
282 error,
283 color,
284 colors,
285 mode,
286 plot_title,
287 x_title,
288 y_title,
289 legend_title,
290 x_axis,
291 y_axis,
292 legend,
293 )
294 }))
295 .map_err(|panic| {
296 let msg = panic
297 .downcast_ref::<String>()
298 .cloned()
299 .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
300 .unwrap_or_else(|| "unknown error".to_string());
301 crate::io::PlotlarsError::PlotBuild { message: msg }
302 })
303 }
304}
305
306impl BarPlot {
307 #[allow(clippy::too_many_arguments)]
308 fn create_ir_traces(
309 data: &DataFrame,
310 labels: &str,
311 values: &str,
312 orientation: Option<Orientation>,
313 group: Option<&str>,
314 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
315 error: Option<&str>,
316 color: Option<Rgb>,
317 colors: Option<Vec<Rgb>>,
318 ) -> Vec<TraceIR> {
319 let mut traces = Vec::new();
320
321 match group {
322 Some(group_col) => {
323 let groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
324
325 for (i, group_name) in groups.iter().enumerate() {
326 let subset = crate::data::filter_data_by_group(data, group_col, group_name);
327
328 let marker_ir = MarkerIR {
329 opacity: None,
330 size: None,
331 color: Self::resolve_color(i, color, colors.clone()),
332 shape: None,
333 };
334
335 let error_data = error
336 .map(|e| ColumnData::Numeric(crate::data::get_numeric_column(&subset, e)));
337
338 traces.push(TraceIR::BarPlot(BarPlotIR {
339 labels: ColumnData::String(crate::data::get_string_column(&subset, labels)),
340 values: ColumnData::Numeric(crate::data::get_numeric_column(
341 &subset, values,
342 )),
343 name: Some(group_name.to_string()),
344 orientation: orientation.clone(),
345 marker: Some(marker_ir),
346 error: error_data,
347 show_legend: None,
348 legend_group: None,
349 subplot_ref: None,
350 }));
351 }
352 }
353 None => {
354 let marker_ir = MarkerIR {
355 opacity: None,
356 size: None,
357 color: Self::resolve_color(0, color, colors),
358 shape: None,
359 };
360
361 let error_data =
362 error.map(|e| ColumnData::Numeric(crate::data::get_numeric_column(data, e)));
363
364 traces.push(TraceIR::BarPlot(BarPlotIR {
365 labels: ColumnData::String(crate::data::get_string_column(data, labels)),
366 values: ColumnData::Numeric(crate::data::get_numeric_column(data, values)),
367 name: None,
368 orientation: orientation.clone(),
369 marker: Some(marker_ir),
370 error: error_data,
371 show_legend: None,
372 legend_group: None,
373 subplot_ref: None,
374 }));
375 }
376 }
377
378 traces
379 }
380
381 #[allow(clippy::too_many_arguments)]
382 fn create_ir_traces_faceted(
383 data: &DataFrame,
384 labels: &str,
385 values: &str,
386 orientation: Option<Orientation>,
387 group: Option<&str>,
388 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
389 facet_column: &str,
390 config: &FacetConfig,
391 error: Option<&str>,
392 color: Option<Rgb>,
393 colors: Option<Vec<Rgb>>,
394 ) -> Vec<TraceIR> {
395 const MAX_FACETS: usize = 8;
396
397 let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
398
399 if facet_categories.len() > MAX_FACETS {
400 panic!(
401 "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
402 facet_column,
403 facet_categories.len(),
404 MAX_FACETS
405 );
406 }
407
408 if let Some(ref color_vec) = colors {
409 if group.is_none() {
410 let color_count = color_vec.len();
411 let facet_count = facet_categories.len();
412 if color_count != facet_count {
413 panic!(
414 "When using colors with facet (without group), colors.len() must equal number of facets. \
415 Expected {} colors for {} facets, but got {} colors. \
416 Each facet must be assigned exactly one color.",
417 facet_count, facet_count, color_count
418 );
419 }
420 } else if let Some(group_col) = group {
421 let groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
422 let color_count = color_vec.len();
423 let group_count = groups.len();
424 if color_count < group_count {
425 panic!(
426 "When using colors with group, colors.len() must be >= number of groups. \
427 Need at least {} colors for {} groups, but got {} colors",
428 group_count, group_count, color_count
429 );
430 }
431 }
432 }
433
434 let global_group_indices: std::collections::HashMap<String, usize> =
435 if let Some(group_col) = group {
436 let global_groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
437 global_groups
438 .into_iter()
439 .enumerate()
440 .map(|(idx, group_name)| (group_name, idx))
441 .collect()
442 } else {
443 std::collections::HashMap::new()
444 };
445
446 let colors = if group.is_some() && colors.is_none() {
447 Some(DEFAULT_PLOTLY_COLORS.to_vec())
448 } else {
449 colors
450 };
451
452 let mut traces = Vec::new();
453
454 for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
455 let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
456
457 let subplot_ref = format!(
458 "{}{}",
459 crate::faceting::get_axis_reference(facet_idx, "x"),
460 crate::faceting::get_axis_reference(facet_idx, "y")
461 );
462
463 match group {
464 Some(group_col) => {
465 let groups =
466 crate::data::get_unique_groups(&facet_data, group_col, sort_groups_by);
467
468 for group_val in groups.iter() {
469 let group_data =
470 crate::data::filter_data_by_group(&facet_data, group_col, group_val);
471
472 let global_idx = global_group_indices.get(group_val).copied().unwrap_or(0);
473
474 let marker_ir = MarkerIR {
475 opacity: None,
476 size: None,
477 color: Self::resolve_color(global_idx, color, colors.clone()),
478 shape: None,
479 };
480
481 let error_data = error.map(|e| {
482 ColumnData::Numeric(crate::data::get_numeric_column(&group_data, e))
483 });
484
485 traces.push(TraceIR::BarPlot(BarPlotIR {
486 labels: ColumnData::String(crate::data::get_string_column(
487 &group_data,
488 labels,
489 )),
490 values: ColumnData::Numeric(crate::data::get_numeric_column(
491 &group_data,
492 values,
493 )),
494 name: Some(group_val.to_string()),
495 orientation: orientation.clone(),
496 marker: Some(marker_ir),
497 error: error_data,
498 show_legend: Some(facet_idx == 0),
499 legend_group: Some(group_val.to_string()),
500 subplot_ref: Some(subplot_ref.clone()),
501 }));
502 }
503 }
504 None => {
505 let marker_ir = MarkerIR {
506 opacity: None,
507 size: None,
508 color: Self::resolve_color(facet_idx, color, colors.clone()),
509 shape: None,
510 };
511
512 let error_data = error.map(|e| {
513 ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, e))
514 });
515
516 traces.push(TraceIR::BarPlot(BarPlotIR {
517 labels: ColumnData::String(crate::data::get_string_column(
518 &facet_data,
519 labels,
520 )),
521 values: ColumnData::Numeric(crate::data::get_numeric_column(
522 &facet_data,
523 values,
524 )),
525 name: None,
526 orientation: orientation.clone(),
527 marker: Some(marker_ir),
528 error: error_data,
529 show_legend: Some(false),
530 legend_group: None,
531 subplot_ref: Some(subplot_ref.clone()),
532 }));
533 }
534 }
535 }
536
537 traces
538 }
539
540 fn resolve_color(index: usize, color: Option<Rgb>, colors: Option<Vec<Rgb>>) -> Option<Rgb> {
541 if let Some(c) = color {
542 return Some(c);
543 }
544 if let Some(ref cs) = colors {
545 return cs.get(index).copied();
546 }
547 None
548 }
549}
550
551impl crate::Plot for BarPlot {
552 fn ir_traces(&self) -> &[TraceIR] {
553 &self.traces
554 }
555
556 fn ir_layout(&self) -> &LayoutIR {
557 &self.layout
558 }
559}
560
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    /// Asserts that `actual` is `Some` colour with exactly the given channels.
    fn assert_rgb(actual: Option<Rgb>, r: u8, g: u8, b: u8) {
        let c = actual.expect("expected Some(Rgb)");
        assert_eq!((c.0, c.1, c.2), (r, g, b));
    }

    /// Nine rows, each in its own facet: one more facet than the builder accepts.
    fn nine_facet_frame() -> DataFrame {
        df![
            "labels" => ["label"; 9],
            "values" => [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            "f" => ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
        ]
        .unwrap()
    }

    #[test]
    fn test_resolve_color_singular_priority() {
        let result = BarPlot::resolve_color(0, Some(Rgb(255, 0, 0)), Some(vec![Rgb(0, 0, 255)]));
        assert_rgb(result, 255, 0, 0);
    }

    #[test]
    fn test_resolve_color_both_none() {
        let result = BarPlot::resolve_color(0, None, None);
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_color_picks_from_list_by_index() {
        let palette = vec![Rgb(1, 1, 1), Rgb(2, 2, 2)];
        assert_rgb(BarPlot::resolve_color(1, None, Some(palette.clone())), 2, 2, 2);
        assert!(BarPlot::resolve_color(2, None, Some(palette)).is_none());
    }

    #[test]
    fn test_no_group_one_trace() {
        let df = df!["labels" => ["a", "b", "c"], "values" => [1.0, 2.0, 3.0]].unwrap();
        let plot = BarPlot::builder()
            .data(&df)
            .labels("labels")
            .values("values")
            .build();
        assert_eq!(plot.ir_traces().len(), 1);
    }

    #[test]
    fn test_with_group() {
        let df = df![
            "labels" => ["a", "b", "a", "b"],
            "values" => [1.0, 2.0, 3.0, 4.0],
            "g" => ["x", "x", "y", "y"]
        ]
        .unwrap();
        let plot = BarPlot::builder()
            .data(&df)
            .labels("labels")
            .values("values")
            .group("g")
            .build();
        assert_eq!(plot.ir_traces().len(), 2);
    }

    #[test]
    fn test_faceted_trace_count() {
        let df = df![
            "labels" => ["a", "b", "c", "a", "b", "c"],
            "values" => [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            "f" => ["f1", "f2", "f1", "f2", "f1", "f2"]
        ]
        .unwrap();
        let plot = BarPlot::builder()
            .data(&df)
            .labels("labels")
            .values("values")
            .facet("f")
            .build();
        assert_eq!(plot.ir_traces().len(), 2);
    }

    #[test]
    #[should_panic(expected = "maximum")]
    fn test_max_facets_panics() {
        let df = nine_facet_frame();
        BarPlot::builder()
            .data(&df)
            .labels("labels")
            .values("values")
            .facet("f")
            .build();
    }

    #[test]
    fn test_try_build_reports_too_many_facets() {
        let df = nine_facet_frame();
        let result = BarPlot::try_builder()
            .data(&df)
            .labels("labels")
            .values("values")
            .facet("f")
            .try_build();
        let Err(crate::io::PlotlarsError::PlotBuild { message }) = result else {
            panic!("expected a PlotBuild error for more than 8 facets");
        };
        assert!(message.contains("maximum"), "unexpected message: {message}");
    }

    #[test]
    #[should_panic(expected = "colors.len() must equal number of facets")]
    fn test_facet_colors_must_match_facet_count() {
        let df = df![
            "labels" => ["a", "b"],
            "values" => [1.0, 2.0],
            "f" => ["f1", "f2"]
        ]
        .unwrap();
        BarPlot::builder()
            .data(&df)
            .labels("labels")
            .values("values")
            .facet("f")
            .colors(vec![Rgb(1, 2, 3)])
            .build();
    }
}
664}