use crate::core::{Plot, PlottingError, Result};
use crate::render::skia::SkiaRenderer;
use tiny_skia::Rect;
/// Layout description for a regular grid of subplots.
///
/// Cells are addressed in row-major order: index `i` lives at row `i / cols`,
/// column `i % cols`. `hspace` / `wspace` are the fraction of each cell's
/// height / width left empty as a gap, split evenly on both sides of the plot.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GridSpec {
    /// Number of grid rows.
    pub rows: usize,
    /// Number of grid columns.
    pub cols: usize,
    /// Fraction of each cell's height reserved as vertical spacing
    /// (clamped to `0.0..=1.0` when set through [`GridSpec::with_hspace`]).
    pub hspace: f32,
    /// Fraction of each cell's width reserved as horizontal spacing
    /// (clamped to `0.0..=1.0` when set through [`GridSpec::with_wspace`]).
    pub wspace: f32,
}

impl GridSpec {
    /// Creates a `rows` x `cols` grid with no spacing between cells.
    ///
    /// The dimensions are not checked here; call [`GridSpec::validate`] for that.
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            hspace: 0.0,
            wspace: 0.0,
        }
    }

    /// Returns the grid with vertical spacing set to `hspace`, clamped to `0.0..=1.0`.
    pub fn with_hspace(mut self, hspace: f32) -> Self {
        self.hspace = hspace.clamp(0.0, 1.0);
        self
    }

    /// Returns the grid with horizontal spacing set to `wspace`, clamped to `0.0..=1.0`.
    pub fn with_wspace(mut self, wspace: f32) -> Self {
        self.wspace = wspace.clamp(0.0, 1.0);
        self
    }

    /// Total number of cells in the grid (`rows * cols`).
    ///
    /// Saturates at `usize::MAX` rather than overflowing: `rows` and `cols` are
    /// public and may hold values that have not been through [`GridSpec::validate`].
    pub fn total_subplots(&self) -> usize {
        self.rows.saturating_mul(self.cols)
    }

    /// Checks that the grid has between 1 and 10 rows and between 1 and 10 columns.
    ///
    /// # Errors
    /// Returns [`PlottingError::InvalidInput`] if either dimension is zero or
    /// larger than 10.
    pub fn validate(&self) -> Result<()> {
        if self.rows == 0 || self.cols == 0 {
            return Err(PlottingError::InvalidInput(
                "Grid must have at least 1 row and 1 column".to_string(),
            ));
        }
        if self.rows > 10 || self.cols > 10 {
            return Err(PlottingError::InvalidInput(
                "Grid size limited to 10x10 for performance".to_string(),
            ));
        }
        Ok(())
    }

    /// Computes the pixel rectangle of the plot area for the cell at `index`.
    ///
    /// * `index` - row-major cell index, must be `< total_subplots()`.
    /// * `figure_width` / `figure_height` - canvas size in pixels.
    /// * `margin` - outer margin applied on all four sides, expressed as a
    ///   fraction of the smaller canvas dimension.
    /// * `top_offset` - extra pixels reserved at the top of the canvas
    ///   (used for the suptitle band by `SubplotFigure`).
    ///
    /// Each cell gets an equal share of the remaining area; the spacing
    /// fractions are then subtracted and the plot is centered in its cell.
    ///
    /// # Errors
    /// Returns [`PlottingError::InvalidInput`] if `index` is out of range, or if
    /// `Rect::from_xywh` rejects the computed geometry (for example a negative
    /// size when margins and offset exceed the canvas, or non-finite spacing).
    pub fn subplot_rect(
        &self,
        index: usize,
        figure_width: u32,
        figure_height: u32,
        margin: f32,
        top_offset: f32,
    ) -> Result<Rect> {
        if index >= self.total_subplots() {
            return Err(PlottingError::InvalidInput(format!(
                "Subplot index {} exceeds grid size {}",
                index,
                self.total_subplots()
            )));
        }
        // `cols` is non-zero here: a zero-column grid has no valid index.
        let row = index / self.cols;
        let col = index % self.cols;

        let margin_px = margin * figure_width.min(figure_height) as f32;
        let available_width = figure_width as f32 - 2.0 * margin_px;
        let available_height = figure_height as f32 - 2.0 * margin_px - top_offset;

        let subplot_width = available_width / self.cols as f32;
        let subplot_height = available_height / self.rows as f32;

        let spacing_x = subplot_width * self.wspace;
        let spacing_y = subplot_height * self.hspace;
        let plot_width = subplot_width - spacing_x;
        let plot_height = subplot_height - spacing_y;

        // Half the gap on each side keeps the plot centered inside its cell.
        let x = margin_px + col as f32 * subplot_width + spacing_x / 2.0;
        let y = margin_px + top_offset + row as f32 * subplot_height + spacing_y / 2.0;

        Rect::from_xywh(x, y, plot_width, plot_height).ok_or_else(|| {
            PlottingError::InvalidInput("Invalid subplot dimensions calculated".to_string())
        })
    }
}
/// A figure that arranges several [`Plot`]s on a regular [`GridSpec`].
///
/// Configured through a consuming builder API and rendered to PNG with
/// [`SubplotFigure::save`] or [`SubplotFigure::save_with_dpi`].
#[derive(Debug, Clone)]
pub struct SubplotFigure {
    /// Grid layout, including inter-cell spacing.
    grid: GridSpec,
    /// One slot per grid cell in row-major order; `None` cells are left empty.
    plots: Vec<Option<Plot>>,
    /// Width of the output canvas in pixels.
    width: u32,
    /// Height of the output canvas in pixels.
    height: u32,
    /// Optional title drawn centered above the whole grid.
    suptitle: Option<String>,
    /// Outer margin as a fraction of the smaller canvas dimension.
    margin: f32,
}

impl SubplotFigure {
    /// Creates an empty `rows` x `cols` figure with a `width` x `height` canvas.
    ///
    /// Spacing starts at zero and the outer margin at `0.05`.
    ///
    /// # Errors
    /// Returns an error if the grid dimensions fail [`GridSpec::validate`].
    pub fn new(rows: usize, cols: usize, width: u32, height: u32) -> Result<Self> {
        let grid = GridSpec::new(rows, cols);
        grid.validate()?;
        let plots = vec![None; grid.total_subplots()];
        Ok(Self {
            grid,
            plots,
            width,
            height,
            suptitle: None,
            margin: 0.05,
        })
    }

    /// Sets the vertical spacing fraction between rows (clamped to `0.0..=1.0`).
    pub fn hspace(mut self, hspace: f32) -> Self {
        self.grid = self.grid.with_hspace(hspace);
        self
    }

    /// Sets the horizontal spacing fraction between columns (clamped to `0.0..=1.0`).
    pub fn wspace(mut self, wspace: f32) -> Self {
        self.grid = self.grid.with_wspace(wspace);
        self
    }

    /// Sets a title drawn centered above the grid; reserves space for it when saving.
    pub fn suptitle<S: Into<String>>(mut self, title: S) -> Self {
        self.suptitle = Some(title.into());
        self
    }

    /// Sets the outer margin as a fraction of the smaller canvas dimension,
    /// clamped to `0.0..=0.4`.
    pub fn margin(mut self, margin: f32) -> Self {
        self.margin = margin.clamp(0.0, 0.4);
        self
    }

    /// Places `plot` in the cell at (`row`, `col`), replacing any existing plot.
    ///
    /// # Errors
    /// Returns [`PlottingError::InvalidInput`] if the position is outside the grid.
    pub fn subplot(mut self, row: usize, col: usize, plot: Plot) -> Result<Self> {
        if row >= self.grid.rows || col >= self.grid.cols {
            return Err(PlottingError::InvalidInput(format!(
                "Subplot position ({}, {}) exceeds grid size {}x{}",
                row, col, self.grid.rows, self.grid.cols
            )));
        }
        let index = row * self.grid.cols + col;
        self.plots[index] = Some(plot);
        Ok(self)
    }

    /// Places `plot` at row-major cell `index`, replacing any existing plot.
    ///
    /// # Errors
    /// Returns [`PlottingError::InvalidInput`] if `index` is outside the grid.
    pub fn subplot_at(mut self, index: usize, plot: Plot) -> Result<Self> {
        if index >= self.plots.len() {
            return Err(PlottingError::InvalidInput(format!(
                "Subplot index {} exceeds total subplots {}",
                index,
                self.plots.len()
            )));
        }
        self.plots[index] = Some(plot);
        Ok(self)
    }

    /// Returns a copy of the grid layout.
    pub fn grid_spec(&self) -> GridSpec {
        self.grid
    }

    /// Number of cells that currently hold a plot.
    pub fn subplot_count(&self) -> usize {
        self.plots.iter().filter(|p| p.is_some()).count()
    }

    /// Renders the figure and writes it to `path` as a PNG (equivalent to
    /// `save_with_dpi(path, 96.0)`).
    ///
    /// # Errors
    /// See [`SubplotFigure::save_with_dpi`].
    pub fn save<P: AsRef<std::path::Path>>(self, path: P) -> Result<()> {
        self.save_with_dpi(path, 96.0)
    }

    /// Renders every populated cell into a `width` x `height` canvas and writes
    /// it to `path` as a PNG.
    ///
    /// Each subplot is rendered into its own canvas sized to its grid rectangle,
    /// with typography scaled down for cells smaller than 300 px (never below
    /// 35%), and then composited onto the main canvas.
    ///
    /// NOTE(review): `dpi` is currently not applied. Subplots are always rendered
    /// with a DPI of 96.0 and the canvas size is fixed by `width`/`height`.
    /// Confirm the intended semantics before wiring it through.
    ///
    /// # Errors
    /// Propagates failures from renderer creation, subplot layout, subplot
    /// rendering, compositing, and PNG writing.
    pub fn save_with_dpi<P: AsRef<std::path::Path>>(self, path: P, dpi: f32) -> Result<()> {
        // Deliberately unused for now; see the NOTE in the doc comment.
        let _ = dpi;

        let theme = crate::render::Theme::default();
        let mut renderer = SkiaRenderer::new(self.width, self.height, theme)?;

        // Pixel band reserved above the grid so the suptitle doesn't overlap subplots.
        let suptitle_height = if self.suptitle.is_some() {
            45.0_f32
        } else {
            0.0_f32
        };

        if let Some(title) = &self.suptitle {
            let title_y = 30.0_f32;
            let title_size = 16.0_f32;
            renderer.draw_text_centered(
                title,
                self.width as f32 / 2.0,
                title_y,
                title_size,
                crate::render::Color::new(0, 0, 0),
            )?;
        }

        // `self` is owned here, so move the plots out rather than cloning each one.
        let populated = self
            .plots
            .into_iter()
            .enumerate()
            .filter_map(|(index, slot)| slot.map(|plot| (index, plot)));

        for (index, plot) in populated {
            let subplot_rect = self.grid.subplot_rect(
                index,
                self.width,
                self.height,
                self.margin,
                suptitle_height,
            )?;

            // Cells at least 300 px on their short side keep full-size text;
            // smaller cells shrink proportionally, bounded below at 35%.
            let reference_dim = 300.0_f32;
            let subplot_min_dim = subplot_rect.width().min(subplot_rect.height());
            let size_scale = (subplot_min_dim / reference_dim).clamp(0.35, 1.0);
            let scaled_plot = plot.scale_typography(size_scale);

            let subplot_theme = scaled_plot.get_theme();
            let mut subplot_renderer = SkiaRenderer::new(
                subplot_rect.width() as u32,
                subplot_rect.height() as u32,
                subplot_theme,
            )?;
            let subplot_dpi = 96.0_f32;
            scaled_plot.render_to_renderer(&mut subplot_renderer, subplot_dpi)?;

            renderer.draw_subplot(
                subplot_renderer.into_image(),
                subplot_rect.left() as u32,
                subplot_rect.top() as u32,
            )?;
        }

        renderer.save_png(path)?;
        Ok(())
    }
}
/// Creates an empty `rows` x `cols` [`SubplotFigure`] with a `width` x `height` canvas.
///
/// Convenience alias for [`SubplotFigure::new`].
///
/// # Errors
/// Returns an error when the grid dimensions fail [`GridSpec::validate`].
pub fn subplots(rows: usize, cols: usize, width: u32, height: u32) -> Result<SubplotFigure> {
    let figure = SubplotFigure::new(rows, cols, width, height)?;
    Ok(figure)
}
/// Creates a `rows` x `cols` [`SubplotFigure`] with an automatically chosen size.
///
/// Each cell gets 400x300 px, and the canvas is capped at 1600x1200 px.
///
/// # Errors
/// Returns an error when the grid dimensions fail [`GridSpec::validate`].
pub fn subplots_default(rows: usize, cols: usize) -> Result<SubplotFigure> {
    const BASE_WIDTH: usize = 400;
    const BASE_HEIGHT: usize = 300;
    const MAX_WIDTH: usize = 1600;
    const MAX_HEIGHT: usize = 1200;

    // Saturate so absurd `rows`/`cols` reach `SubplotFigure::new`'s validation
    // (which returns an `Err`) instead of overflowing in the multiplication.
    // After `.min(..)` the values fit comfortably in a `u32`.
    let width = BASE_WIDTH.saturating_mul(cols).min(MAX_WIDTH) as u32;
    let height = BASE_HEIGHT.saturating_mul(rows).min(MAX_HEIGHT) as u32;
    SubplotFigure::new(rows, cols, width, height)
}
#[cfg(test)]
// NOTE(review): presumably silences deprecation warnings from Plot builder
// methods used below — confirm which ones and narrow or remove this allow.
#[allow(deprecated)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing computed pixel coordinates.
    const EPS: f32 = 1e-3;

    #[test]
    fn test_grid_spec_creation() {
        let grid = GridSpec::new(2, 3);
        assert_eq!(grid.rows, 2);
        assert_eq!(grid.cols, 3);
        assert_eq!(grid.total_subplots(), 6);
    }

    #[test]
    fn test_grid_spec_spacing() {
        let grid = GridSpec::new(2, 2).with_hspace(0.3).with_wspace(0.4);
        assert_eq!(grid.hspace, 0.3);
        assert_eq!(grid.wspace, 0.4);
    }

    #[test]
    fn test_grid_validation() {
        assert!(GridSpec::new(0, 1).validate().is_err());
        assert!(GridSpec::new(1, 0).validate().is_err());
        assert!(GridSpec::new(11, 1).validate().is_err());
        assert!(GridSpec::new(2, 3).validate().is_ok());
    }

    #[test]
    fn test_subplot_rect_calculation() {
        let grid = GridSpec::new(2, 2);
        let margin = 0.1;
        let top_offset = 0.0;
        // margin_px = 0.1 * min(800, 600) = 60; each cell is
        // (800 - 120) / 2 = 340 px wide and (600 - 120) / 2 = 240 px tall.
        let rect = grid.subplot_rect(0, 800, 600, margin, top_offset).unwrap();
        assert!((rect.left() - 60.0).abs() < EPS);
        assert!((rect.top() - 60.0).abs() < EPS);
        assert!((rect.width() - 340.0).abs() < EPS);
        assert!((rect.height() - 240.0).abs() < EPS);

        // The last cell ends exactly at the opposite margins.
        let rect = grid.subplot_rect(3, 800, 600, margin, top_offset).unwrap();
        assert!((rect.right() - 740.0).abs() < EPS);
        assert!((rect.bottom() - 540.0).abs() < EPS);
    }

    #[test]
    fn test_subplot_rect_with_suptitle_offset() {
        let grid = GridSpec::new(2, 2);
        let margin = 0.1;
        let top_offset = 45.0;
        let rect = grid.subplot_rect(0, 800, 600, margin, top_offset).unwrap();
        // The top margin (60 px) is shifted down by the full offset.
        assert!((rect.top() - (60.0 + top_offset)).abs() < EPS);
    }

    #[test]
    fn test_subplot_rect_spacing_is_centered() {
        // One 100x100 cell with a 50% horizontal gap: the plot is 50 px wide and
        // the gap is split 25 px on each side.
        let grid = GridSpec::new(1, 1).with_wspace(0.5);
        let rect = grid.subplot_rect(0, 100, 100, 0.0, 0.0).unwrap();
        assert!((rect.left() - 25.0).abs() < EPS);
        assert!((rect.width() - 50.0).abs() < EPS);
        assert!((rect.height() - 100.0).abs() < EPS);
    }

    #[test]
    fn test_subplot_rect_index_out_of_range() {
        let grid = GridSpec::new(2, 2);
        assert!(grid.subplot_rect(4, 800, 600, 0.1, 0.0).is_err());
    }

    #[test]
    fn test_subplot_figure_creation() {
        let figure = SubplotFigure::new(2, 3, 800, 600).unwrap();
        assert_eq!(figure.subplot_count(), 0);
        assert_eq!(figure.grid_spec().total_subplots(), 6);
    }

    #[test]
    fn test_subplot_positioning() {
        let mut figure = SubplotFigure::new(2, 2, 800, 600).unwrap();
        let plot = Plot::new();
        figure = figure.subplot(0, 1, plot.clone()).unwrap();
        assert_eq!(figure.subplot_count(), 1);
        figure = figure.subplot_at(3, plot).unwrap();
        assert_eq!(figure.subplot_count(), 2);
    }

    #[test]
    fn test_subplot_bounds_checking() {
        let figure = SubplotFigure::new(2, 2, 800, 600).unwrap();
        let plot = Plot::new();
        assert!(figure.clone().subplot(2, 0, plot.clone()).is_err());
        assert!(figure.clone().subplot(0, 2, plot.clone()).is_err());
        assert!(figure.clone().subplot_at(4, plot).is_err());
    }

    #[test]
    fn test_convenience_functions() {
        let figure = subplots(2, 3, 800, 600).unwrap();
        assert_eq!(figure.grid_spec().rows, 2);
        assert_eq!(figure.grid_spec().cols, 3);
        // 2x2 cells at 400x300 px each.
        let figure = subplots_default(2, 2).unwrap();
        assert_eq!(figure.width, 800);
        assert_eq!(figure.height, 600);
    }

    #[test]
    fn test_subplot_rendering_integration() {
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![2.0, 4.0, 3.0];
        let plot = Plot::new().line(&x, &y).end_series().title("Test Plot");
        let figure = SubplotFigure::new(1, 1, 400, 300)
            .unwrap()
            .subplot(0, 0, plot)
            .unwrap();
        assert_eq!(figure.subplot_count(), 1);
        assert_eq!(figure.grid_spec().total_subplots(), 1);
        assert_eq!(figure.width, 400);
        assert_eq!(figure.height, 300);
    }

    #[test]
    fn test_subplot_with_different_themes() {
        use crate::render::Theme;
        let x = vec![1.0, 2.0, 3.0];
        let y1 = vec![2.0, 4.0, 3.0];
        let y2 = vec![1.0, 3.0, 2.0];
        let plot1 = Plot::new()
            .line(&x, &y1)
            .end_series()
            .theme(Theme::default())
            .title("Default Theme");
        let plot2 = Plot::new()
            .line(&x, &y2)
            .end_series()
            .theme(Theme::dark())
            .title("Dark Theme");
        let figure = SubplotFigure::new(1, 2, 800, 400)
            .unwrap()
            .subplot(0, 0, plot1)
            .unwrap()
            .subplot(0, 1, plot2)
            .unwrap();
        assert_eq!(figure.subplot_count(), 2);
        let spec = figure.grid_spec();
        assert_eq!(spec.rows, 1);
        assert_eq!(spec.cols, 2);
    }

    #[test]
    fn test_subplot_suptitle_and_spacing() {
        let plot = Plot::new();
        let figure = SubplotFigure::new(2, 2, 800, 600)
            .unwrap()
            .suptitle("Overall Title")
            .hspace(0.4)
            .wspace(0.5)
            .subplot_at(0, plot)
            .unwrap();
        assert_eq!(figure.subplot_count(), 1);
        assert_eq!(figure.grid_spec().hspace, 0.4);
        assert_eq!(figure.grid_spec().wspace, 0.5);
        assert_eq!(figure.suptitle.as_deref(), Some("Overall Title"));
    }

    #[test]
    fn test_empty_subplot_figure() {
        let figure = SubplotFigure::new(2, 2, 800, 600).unwrap();
        assert_eq!(figure.subplot_count(), 0);
        assert_eq!(figure.grid_spec().total_subplots(), 4);
        let plot = Plot::new();
        let updated_figure = figure.subplot(1, 1, plot).unwrap();
        assert_eq!(updated_figure.subplot_count(), 1);
    }

    #[test]
    fn test_large_subplot_grid() {
        let result = SubplotFigure::new(5, 4, 1200, 900);
        assert!(result.is_ok());
        let figure = result.unwrap();
        assert_eq!(figure.grid_spec().total_subplots(), 20);
        let large_result = SubplotFigure::new(10, 10, 2000, 2000);
        assert!(large_result.is_ok());
        let too_large_result = SubplotFigure::new(11, 10, 2000, 2000);
        assert!(too_large_result.is_err());
    }
}