use crate::color::Color;
use crate::core::{Bounds, Canvas, Drawable, Point2D};
use crate::error::{Error, Result};
use crate::legend::LegendEntry;
use crate::plots::bar::BarOrientation;
/// A bar chart in which each category's bar is built from stacked segments, one per series.
pub struct StackedBarPlot {
    /// Category names; a category's index is its bar position on the category axis.
    categories: Vec<String>,
    /// One vector per series, each holding one value per category.
    series: Vec<Vec<f64>>,
    /// Series labels shown in the legend.
    labels: Vec<String>,
    /// Fill colors for the series segments.
    colors: Vec<Color>,
    /// Whether bars grow vertically or horizontally.
    orientation: BarOrientation,
    /// Bar thickness along the category axis, in category-position units.
    bar_width: f64,
    /// Spacing value stored by the `gap` builder.
    gap: f64,
}
impl StackedBarPlot {
    /// Creates a stacked bar plot with one bar per category and one stacked segment per series.
    ///
    /// Each inner vector of `series` holds one value per category, in category order.
    /// Series get colors from a built-in eight-color palette (reused cyclically when there are
    /// more than eight series) and labels `"Series 1"`, `"Series 2"`, …. The orientation
    /// defaults to vertical, `bar_width` to `0.8` and `gap` to `0.2`.
    ///
    /// Values themselves are not checked here; [`Drawable::draw`] rejects negative or
    /// non-finite values.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidData`] if `categories` is empty, `series` is empty, or any series
    /// does not have exactly one value per category.
    pub fn new(categories: Vec<impl Into<String>>, series: Vec<Vec<f64>>) -> Result<Self> {
        const DEFAULT_PALETTE: [&str; 8] = [
            "#3498db", "#e74c3c", "#2ecc71", "#f39c12", "#9b59b6", "#1abc9c", "#e67e22", "#34495e",
        ];

        if categories.is_empty() {
            return Err(Error::InvalidData("Categories cannot be empty".into()));
        }
        if series.is_empty() {
            return Err(Error::InvalidData("Series cannot be empty".into()));
        }
        let n_categories = categories.len();
        if let Some((idx, bad)) = series
            .iter()
            .enumerate()
            .find(|(_, s)| s.len() != n_categories)
        {
            return Err(Error::InvalidData(format!(
                "Series {} length ({}) doesn't match categories length ({})",
                idx,
                bad.len(),
                n_categories
            )));
        }

        let n_series = series.len();
        let categories: Vec<String> = categories.into_iter().map(Into::into).collect();
        let colors: Vec<Color> = DEFAULT_PALETTE
            .iter()
            .cycle()
            .take(n_series)
            .map(|&hex| {
                Color::from_hex(hex).expect("built-in palette entries are valid hex colors")
            })
            .collect();
        let labels = (0..n_series).map(|i| format!("Series {}", i + 1)).collect();

        Ok(Self {
            categories,
            series,
            colors,
            labels,
            orientation: BarOrientation::Vertical,
            bar_width: 0.8,
            gap: 0.2,
        })
    }

    /// Replaces the series colors.
    ///
    /// An empty vector is ignored and the current colors are kept. If fewer colors than series
    /// are given, they are reused cyclically, both when drawing and in the legend.
    #[must_use]
    pub fn colors(mut self, colors: Vec<Color>) -> Self {
        // Keeping `colors` non-empty is what makes `series_color`'s modulo safe.
        if !colors.is_empty() {
            self.colors = colors;
        }
        self
    }

    /// Replaces the series labels used by [`legend_entries`](Self::legend_entries).
    ///
    /// Labels are matched to series by position. Series without a label get no legend entry,
    /// and labels beyond the number of series are not shown.
    #[must_use]
    pub fn labels(mut self, labels: Vec<impl Into<String>>) -> Self {
        self.labels = labels.into_iter().map(Into::into).collect();
        self
    }

    /// Sets whether bars grow vertically (default) or horizontally.
    #[must_use]
    pub fn orientation(mut self, orientation: BarOrientation) -> Self {
        self.orientation = orientation;
        self
    }

    /// Sets the bar thickness along the category axis, clamped to `[0.1, 1.0]`.
    ///
    /// Bars are centered on integer category positions, so a width of `1.0` makes adjacent
    /// bars touch. A NaN width is ignored and the current width is kept.
    #[must_use]
    pub fn bar_width(mut self, width: f64) -> Self {
        // `f64::clamp` returns NaN for a NaN input, which would corrupt every rectangle.
        if !width.is_nan() {
            self.bar_width = width.clamp(0.1, 1.0);
        }
        self
    }

    /// Sets the gap value; negative inputs are raised to `0.0`.
    ///
    /// NOTE(review): the stored gap is not currently read by `draw` or `bounds`.
    #[must_use]
    pub fn gap(mut self, gap: f64) -> Self {
        self.gap = gap.max(0.0);
        self
    }

    /// Builds one legend entry per labelled series, using the same color the series is drawn with.
    #[must_use]
    pub fn legend_entries(&self) -> Vec<LegendEntry> {
        self.labels
            .iter()
            .take(self.series.len())
            .enumerate()
            .map(|(idx, label)| {
                LegendEntry::new(label.clone())
                    .color(self.series_color(idx))
                    .swatch_shape()
            })
            .collect()
    }

    /// Returns the color for series `idx`, reusing `colors` cyclically when there are more
    /// series than colors.
    fn series_color(&self, idx: usize) -> Color {
        // `colors` is never empty: `new` fills at least one entry (it requires a series) and
        // the `colors` builder ignores empty input, so the modulo cannot divide by zero.
        self.colors[idx % self.colors.len()]
    }

    /// Checks that every value can be stacked, i.e. is non-negative and finite.
    fn validate_values(&self) -> Result<()> {
        for value in self.series.iter().flatten().copied() {
            if value < 0.0 {
                return Err(Error::InvalidData(
                    "Stacked bar plots do not support negative values".into(),
                ));
            }
            // NaN/inf would propagate through the running offset into every later segment.
            if !value.is_finite() {
                return Err(Error::InvalidData(
                    "Stacked bar plots require finite values".into(),
                ));
            }
        }
        Ok(())
    }
}

impl Drawable for StackedBarPlot {
    /// Draws one bar per category, made of one segment per series stacked in series order.
    ///
    /// Bars are centered on positions `0, 1, …, n_categories - 1` along the category axis, and
    /// each segment starts where the previous series' segment for that category ended.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidData`] if any value is negative or not finite. All values are
    /// validated before anything is drawn, so an invalid plot never leaves a partially drawn
    /// canvas. Errors returned by the canvas are propagated.
    fn draw(&self, canvas: &mut dyn Canvas) -> Result<()> {
        self.validate_values()?;

        let half_width = self.bar_width / 2.0;
        for cat_idx in 0..self.categories.len() {
            let center = cat_idx as f64;
            let mut cumulative = 0.0;
            for (series_idx, series_data) in self.series.iter().enumerate() {
                let value = series_data[cat_idx];
                let color = self.series_color(series_idx).to_rgba();
                match self.orientation {
                    BarOrientation::Vertical => {
                        let bottom_left = Point2D::new(center - half_width, cumulative);
                        canvas.draw_rectangle(&bottom_left, self.bar_width, value, &color)?;
                    }
                    BarOrientation::Horizontal => {
                        let bottom_left = Point2D::new(cumulative, center - half_width);
                        canvas.draw_rectangle(&bottom_left, value, self.bar_width, &color)?;
                    }
                }
                cumulative += value;
            }
        }
        Ok(())
    }
}
impl StackedBarPlot {
    /// Returns the data-space extent of the plot, or `None` if there is nothing to draw.
    ///
    /// The category axis spans `-0.5 ..= n_categories - 0.5`. The value axis spans from `0.0`
    /// to the tallest stack, which is never below `0.0`. For vertical bars the categories lie
    /// on x; for horizontal bars they lie on y.
    #[must_use]
    pub fn bounds(&self) -> Option<Bounds> {
        let n_categories = self.categories.len();
        if n_categories == 0 || self.series.is_empty() {
            return None;
        }

        // Sum each category's column in series order, then keep the largest stack (floor 0).
        let tallest = (0..n_categories)
            .map(|cat| self.series.iter().fold(0.0_f64, |acc, s| acc + s[cat]))
            .fold(0.0_f64, f64::max);

        let (cat_lo, cat_hi) = (-0.5, n_categories as f64 - 0.5);
        Some(match self.orientation {
            BarOrientation::Vertical => Bounds::new(cat_lo, cat_hi, 0.0, tallest),
            BarOrientation::Horizontal => Bounds::new(0.0, tallest, cat_lo, cat_hi),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two categories, two series; the stacks total 15 and 30.
    fn two_by_two() -> StackedBarPlot {
        StackedBarPlot::new(vec!["A", "B"], vec![vec![10.0, 20.0], vec![5.0, 10.0]])
            .expect("valid two-by-two input")
    }

    #[test]
    fn test_stacked_bar_creation() {
        let plot = StackedBarPlot::new(
            vec!["A", "B", "C"],
            vec![vec![10.0, 20.0, 15.0], vec![5.0, 10.0, 8.0]],
        )
        .unwrap();
        assert!(plot.bounds().is_some());
    }

    #[test]
    fn test_stacked_bar_empty_categories() {
        let no_categories: Vec<&str> = Vec::new();
        assert!(StackedBarPlot::new(no_categories, vec![vec![10.0, 20.0]]).is_err());
    }

    #[test]
    fn test_stacked_bar_empty_series() {
        let no_series: Vec<Vec<f64>> = Vec::new();
        assert!(StackedBarPlot::new(vec!["A", "B"], no_series).is_err());
    }

    #[test]
    fn test_stacked_bar_mismatched_lengths() {
        // The first series is one value short of the three categories.
        let ragged = vec![vec![10.0, 20.0], vec![5.0, 10.0, 8.0]];
        assert!(StackedBarPlot::new(vec!["A", "B", "C"], ragged).is_err());
    }

    #[test]
    fn test_stacked_bar_bounds() {
        let b = two_by_two().bounds().unwrap();
        assert_eq!((b.x_min, b.x_max), (-0.5, 1.5));
        assert_eq!((b.y_min, b.y_max), (0.0, 30.0));
    }

    #[test]
    fn test_stacked_bar_horizontal_bounds() {
        let b = two_by_two()
            .orientation(BarOrientation::Horizontal)
            .bounds()
            .unwrap();
        assert_eq!((b.x_min, b.x_max), (0.0, 30.0));
        assert_eq!((b.y_min, b.y_max), (-0.5, 1.5));
    }
}