use crate::core::{BoundingBox, RenderData};
use crate::plots::{BarChart, Histogram, LinePlot, PointCloudPlot, ScatterPlot};
use glam::Vec4;
use std::collections::HashMap;
/// A collection of plots together with figure-level presentation settings
/// (title, axis labels, legend/grid toggles, background colour, axis limits).
///
/// The combined bounding box of the contained plots is cached in `bounds`
/// and recomputed on demand whenever `dirty` is set.
#[derive(Debug, Clone)]
pub struct Figure {
/// Plot elements in insertion order. The indices returned by the `add_*`
/// methods refer to positions in this vector; `remove_plot` uses
/// `Vec::remove`, so indices after a removed element shift down by one.
plots: Vec<PlotElement>,
/// Optional figure title (set by `with_title`).
pub title: Option<String>,
/// Optional x-axis label (set by `with_labels`).
pub x_label: Option<String>,
/// Optional y-axis label (set by `with_labels`).
pub y_label: Option<String>,
/// Whether the legend is enabled; `Figure::new` sets this to `true`.
pub legend_enabled: bool,
/// Whether the grid is enabled; `Figure::new` sets this to `true`.
pub grid_enabled: bool,
/// Background colour; `Figure::new` sets all four components to 1.0
/// (presumably RGBA, i.e. opaque white -- confirm against the renderer).
pub background_color: Vec4,
/// Optional explicit x-axis limits (presumably `(min, max)` -- confirm
/// against the consumer of this field).
pub x_limits: Option<(f64, f64)>,
/// Optional explicit y-axis limits (presumably `(min, max)` -- confirm
/// against the consumer of this field).
pub y_limits: Option<(f64, f64)>,
/// Cached union of the bounding boxes of the visible plots, filled in by
/// `compute_bounds`; `None` until the first computation.
bounds: Option<BoundingBox>,
/// Set whenever the plot list (or the limits) change, signalling that
/// `bounds` must be recomputed before it is returned.
dirty: bool,
}
/// A single plot stored in a [`Figure`], wrapping one of the concrete plot
/// types so that heterogeneous plots can live in the same vector.
#[derive(Debug, Clone)]
pub enum PlotElement {
/// A line plot.
Line(LinePlot),
/// A scatter plot.
Scatter(ScatterPlot),
/// A bar chart.
Bar(BarChart),
/// A histogram.
Histogram(Histogram),
/// A point-cloud plot.
PointCloud(PointCloudPlot),
}
/// One legend row, produced by `Figure::legend_entries` for every plot that
/// carries a label.
#[derive(Debug, Clone)]
pub struct LegendEntry {
/// The plot's label text.
pub label: String,
/// The colour reported by `PlotElement::color` for the plot.
pub color: Vec4,
/// Which kind of plot this entry describes.
pub plot_type: PlotType,
}
/// Data-free tag identifying the kind of a [`PlotElement`].
///
/// It derives `Eq` and `Hash`, which lets it serve as the key of
/// `FigureStatistics::plot_type_counts`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlotType {
/// Tag for `PlotElement::Line`.
Line,
/// Tag for `PlotElement::Scatter`.
Scatter,
/// Tag for `PlotElement::Bar`.
Bar,
/// Tag for `PlotElement::Histogram`.
Histogram,
/// Tag for `PlotElement::PointCloud`.
PointCloud,
}
impl Figure {
    /// Creates an empty figure: no plots, no title or axis labels, legend and
    /// grid enabled, every background colour component set to 1.0, and no
    /// explicit axis limits.
    pub fn new() -> Self {
        Self {
            plots: Vec::new(),
            title: None,
            x_label: None,
            y_label: None,
            legend_enabled: true,
            grid_enabled: true,
            background_color: Vec4::new(1.0, 1.0, 1.0, 1.0),
            x_limits: None,
            y_limits: None,
            bounds: None,
            dirty: true,
        }
    }

    /// Builder: sets the figure title.
    pub fn with_title<S: Into<String>>(mut self, title: S) -> Self {
        self.title = Some(title.into());
        self
    }

    /// Builder: sets both axis labels. Both arguments must have the same type.
    pub fn with_labels<S: Into<String>>(mut self, x_label: S, y_label: S) -> Self {
        self.x_label = Some(x_label.into());
        self.y_label = Some(y_label.into());
        self
    }

    /// Builder: sets explicit x and y axis limits and marks the cached bounds
    /// as stale.
    pub fn with_limits(mut self, x_limits: (f64, f64), y_limits: (f64, f64)) -> Self {
        self.x_limits = Some(x_limits);
        self.y_limits = Some(y_limits);
        self.dirty = true;
        self
    }

    /// Builder: enables or disables the legend.
    pub fn with_legend(mut self, enabled: bool) -> Self {
        self.legend_enabled = enabled;
        self
    }

    /// Builder: enables or disables the grid.
    pub fn with_grid(mut self, enabled: bool) -> Self {
        self.grid_enabled = enabled;
        self
    }

    /// Builder: sets the background colour.
    pub fn with_background_color(mut self, color: Vec4) -> Self {
        self.background_color = color;
        self
    }

    /// Appends `element`, invalidates the cached bounds and returns the index
    /// at which the element was stored.
    fn push_plot(&mut self, element: PlotElement) -> usize {
        let index = self.plots.len();
        self.plots.push(element);
        self.dirty = true;
        index
    }

    /// Adds a line plot and returns its index within the figure.
    pub fn add_line_plot(&mut self, plot: LinePlot) -> usize {
        self.push_plot(PlotElement::Line(plot))
    }

    /// Adds a scatter plot and returns its index within the figure.
    pub fn add_scatter_plot(&mut self, plot: ScatterPlot) -> usize {
        self.push_plot(PlotElement::Scatter(plot))
    }

    /// Adds a bar chart and returns its index within the figure.
    pub fn add_bar_chart(&mut self, plot: BarChart) -> usize {
        self.push_plot(PlotElement::Bar(plot))
    }

    /// Adds a histogram and returns its index within the figure.
    pub fn add_histogram(&mut self, plot: Histogram) -> usize {
        self.push_plot(PlotElement::Histogram(plot))
    }

    /// Adds a point-cloud plot and returns its index within the figure.
    pub fn add_point_cloud_plot(&mut self, plot: PointCloudPlot) -> usize {
        self.push_plot(PlotElement::PointCloud(plot))
    }

    /// Removes the plot at `index`.
    ///
    /// Later plots shift down by one position, so indices handed out earlier
    /// for those plots are no longer valid.
    ///
    /// # Errors
    ///
    /// Returns an error message when `index` is out of bounds.
    pub fn remove_plot(&mut self, index: usize) -> Result<(), String> {
        if index >= self.plots.len() {
            return Err(format!("Plot index {index} out of bounds"));
        }
        self.plots.remove(index);
        self.dirty = true;
        Ok(())
    }

    /// Removes every plot from the figure.
    pub fn clear(&mut self) {
        self.plots.clear();
        self.dirty = true;
    }

    /// Number of plots in the figure, visible or not.
    pub fn len(&self) -> usize {
        self.plots.len()
    }

    /// Returns `true` when the figure contains no plots.
    pub fn is_empty(&self) -> bool {
        self.plots.is_empty()
    }

    /// Iterates over the plots in insertion order.
    pub fn plots(&self) -> impl Iterator<Item = &PlotElement> {
        self.plots.iter()
    }

    /// Returns a mutable reference to the plot at `index`, or `None` when the
    /// index is out of range.
    ///
    /// The cached bounds are invalidated only when a plot is actually handed
    /// out, since the caller may then change its data. A miss cannot change
    /// anything, so the cache is left untouched.
    pub fn get_plot_mut(&mut self, index: usize) -> Option<&mut PlotElement> {
        let plot = self.plots.get_mut(index)?;
        self.dirty = true;
        Some(plot)
    }

    /// Returns the combined bounding box of all visible plots, recomputing
    /// the cached value first when it is stale or missing.
    pub fn bounds(&mut self) -> BoundingBox {
        if self.dirty || self.bounds.is_none() {
            self.compute_bounds();
        }
        self.bounds
            .expect("compute_bounds always populates the bounds cache")
    }

    /// Recomputes the cached bounds as the union of the bounds of every
    /// visible plot, falling back to `BoundingBox::default()` when there is
    /// no visible plot (including the empty figure).
    ///
    /// The dirty flag is cleared on every path, so an empty figure is also
    /// cached instead of being recomputed on each `bounds` call.
    fn compute_bounds(&mut self) {
        let combined = self
            .plots
            .iter_mut()
            .filter(|plot| plot.is_visible())
            .map(|plot| plot.bounds())
            .reduce(|acc, next| acc.union(&next));
        self.bounds = Some(combined.unwrap_or_else(BoundingBox::default));
        self.dirty = false;
    }

    /// Collects the render data of every visible plot, in insertion order.
    pub fn render_data(&mut self) -> Vec<RenderData> {
        self.plots
            .iter_mut()
            .filter(|plot| plot.is_visible())
            .map(|plot| plot.render_data())
            .collect()
    }

    /// Builds one legend entry for every plot that has a label, in insertion
    /// order. Visibility is not taken into account.
    pub fn legend_entries(&self) -> Vec<LegendEntry> {
        self.plots
            .iter()
            .filter_map(|plot| {
                plot.label().map(|label| LegendEntry {
                    label,
                    color: plot.color(),
                    plot_type: plot.plot_type(),
                })
            })
            .collect()
    }

    /// Summarises the figure: total and visible plot counts, per-type counts,
    /// the summed memory estimate of all plots, and whether a legend would be
    /// shown (legend enabled and at least one plot labelled).
    pub fn statistics(&self) -> FigureStatistics {
        let mut plot_type_counts = HashMap::new();
        for plot in &self.plots {
            *plot_type_counts.entry(plot.plot_type()).or_insert(0) += 1;
        }

        let total_memory_usage: usize = self
            .plots
            .iter()
            .map(|plot| plot.estimated_memory_usage())
            .sum();
        let visible_plots = self.plots.iter().filter(|plot| plot.is_visible()).count();
        // Only existence matters here, so stop at the first labelled plot
        // rather than materialising the full legend.
        let any_labelled = self.plots.iter().any(|plot| plot.label().is_some());

        FigureStatistics {
            total_plots: self.plots.len(),
            visible_plots,
            plot_type_counts,
            total_memory_usage,
            has_legend: self.legend_enabled && any_labelled,
        }
    }
}
impl Default for Figure {
fn default() -> Self {
Self::new()
}
}
impl PlotElement {
pub fn is_visible(&self) -> bool {
match self {
PlotElement::Line(plot) => plot.visible,
PlotElement::Scatter(plot) => plot.visible,
PlotElement::Bar(plot) => plot.visible,
PlotElement::Histogram(plot) => plot.visible,
PlotElement::PointCloud(plot) => plot.visible,
}
}
pub fn label(&self) -> Option<String> {
match self {
PlotElement::Line(plot) => plot.label.clone(),
PlotElement::Scatter(plot) => plot.label.clone(),
PlotElement::Bar(plot) => plot.label.clone(),
PlotElement::Histogram(plot) => plot.label.clone(),
PlotElement::PointCloud(plot) => plot.label.clone(),
}
}
pub fn color(&self) -> Vec4 {
match self {
PlotElement::Line(plot) => plot.color,
PlotElement::Scatter(plot) => plot.color,
PlotElement::Bar(plot) => plot.color,
PlotElement::Histogram(plot) => plot.color,
PlotElement::PointCloud(plot) => plot.default_color,
}
}
pub fn plot_type(&self) -> PlotType {
match self {
PlotElement::Line(_) => PlotType::Line,
PlotElement::Scatter(_) => PlotType::Scatter,
PlotElement::Bar(_) => PlotType::Bar,
PlotElement::Histogram(_) => PlotType::Histogram,
PlotElement::PointCloud(_) => PlotType::PointCloud,
}
}
pub fn bounds(&mut self) -> BoundingBox {
match self {
PlotElement::Line(plot) => plot.bounds(),
PlotElement::Scatter(plot) => plot.bounds(),
PlotElement::Bar(plot) => plot.bounds(),
PlotElement::Histogram(plot) => plot.bounds(),
PlotElement::PointCloud(plot) => plot.bounds(),
}
}
pub fn render_data(&mut self) -> RenderData {
match self {
PlotElement::Line(plot) => plot.render_data(),
PlotElement::Scatter(plot) => plot.render_data(),
PlotElement::Bar(plot) => plot.render_data(),
PlotElement::Histogram(plot) => plot.render_data(),
PlotElement::PointCloud(plot) => plot.render_data(),
}
}
pub fn estimated_memory_usage(&self) -> usize {
match self {
PlotElement::Line(plot) => plot.estimated_memory_usage(),
PlotElement::Scatter(plot) => plot.estimated_memory_usage(),
PlotElement::Bar(plot) => plot.estimated_memory_usage(),
PlotElement::Histogram(plot) => plot.estimated_memory_usage(),
PlotElement::PointCloud(plot) => plot.estimated_memory_usage(),
}
}
}
/// Summary of a figure's contents as returned by `Figure::statistics`.
#[derive(Debug)]
pub struct FigureStatistics {
/// Number of plots in the figure, visible or not.
pub total_plots: usize,
/// Number of plots whose `is_visible()` is `true`.
pub visible_plots: usize,
/// Number of plots per plot type; types with no plots have no entry.
pub plot_type_counts: HashMap<PlotType, usize>,
/// Sum of `estimated_memory_usage()` over all plots (unit is defined by
/// the individual plot types -- presumably bytes, confirm there).
pub total_memory_usage: usize,
/// `true` when the legend is enabled and at least one plot has a label.
pub has_legend: bool,
}
pub mod matlab_compat {
use super::*;
use crate::plots::{LinePlot, ScatterPlot};
pub fn figure() -> Figure {
Figure::new()
}
pub fn figure_with_title<S: Into<String>>(title: S) -> Figure {
Figure::new().with_title(title)
}
pub fn plot_multiple_lines(
figure: &mut Figure,
data_sets: Vec<(Vec<f64>, Vec<f64>, Option<String>)>,
) -> Result<Vec<usize>, String> {
let mut indices = Vec::new();
for (i, (x, y, label)) in data_sets.into_iter().enumerate() {
let mut line = LinePlot::new(x, y)?;
let colors = [
Vec4::new(0.0, 0.4470, 0.7410, 1.0), Vec4::new(0.8500, 0.3250, 0.0980, 1.0), Vec4::new(0.9290, 0.6940, 0.1250, 1.0), Vec4::new(0.4940, 0.1840, 0.5560, 1.0), Vec4::new(0.4660, 0.6740, 0.1880, 1.0), Vec4::new(std::f64::consts::LOG10_2 as f32, 0.7450, 0.9330, 1.0), Vec4::new(0.6350, 0.0780, 0.1840, 1.0), ];
let color = colors[i % colors.len()];
line.set_color(color);
if let Some(label) = label {
line = line.with_label(label);
}
indices.push(figure.add_line_plot(line));
}
Ok(indices)
}
pub fn scatter_multiple(
figure: &mut Figure,
data_sets: Vec<(Vec<f64>, Vec<f64>, Option<String>)>,
) -> Result<Vec<usize>, String> {
let mut indices = Vec::new();
for (i, (x, y, label)) in data_sets.into_iter().enumerate() {
let mut scatter = ScatterPlot::new(x, y)?;
let colors = [
Vec4::new(1.0, 0.0, 0.0, 1.0), Vec4::new(0.0, 1.0, 0.0, 1.0), Vec4::new(0.0, 0.0, 1.0, 1.0), Vec4::new(1.0, 1.0, 0.0, 1.0), Vec4::new(1.0, 0.0, 1.0, 1.0), Vec4::new(0.0, 1.0, 1.0, 1.0), Vec4::new(0.5, 0.5, 0.5, 1.0), ];
let color = colors[i % colors.len()];
scatter.set_color(color);
if let Some(label) = label {
scatter = scatter.with_label(label);
}
indices.push(figure.add_scatter_plot(scatter));
}
Ok(indices)
}
}
// Unit tests for `Figure`, `PlotElement` dispatch and the `matlab_compat`
// helpers.
#[cfg(test)]
mod tests {
use super::*;
use crate::plots::line::LineStyle;
// A new figure is empty and has legend and grid switched on.
#[test]
fn test_figure_creation() {
let figure = Figure::new();
assert_eq!(figure.len(), 0);
assert!(figure.is_empty());
assert!(figure.legend_enabled);
assert!(figure.grid_enabled);
}
// The builder methods store title, labels and the legend/grid toggles.
#[test]
fn test_figure_styling() {
let figure = Figure::new()
.with_title("Test Figure")
.with_labels("X Axis", "Y Axis")
.with_legend(false)
.with_grid(false);
assert_eq!(figure.title, Some("Test Figure".to_string()));
assert_eq!(figure.x_label, Some("X Axis".to_string()));
assert_eq!(figure.y_label, Some("Y Axis".to_string()));
assert!(!figure.legend_enabled);
assert!(!figure.grid_enabled);
}
// Adding plots hands out consecutive indices, and legend entries follow
// insertion order.
#[test]
fn test_multiple_line_plots() {
let mut figure = Figure::new();
let line1 = LinePlot::new(vec![0.0, 1.0, 2.0], vec![0.0, 1.0, 4.0])
.unwrap()
.with_label("Quadratic");
let index1 = figure.add_line_plot(line1);
let line2 = LinePlot::new(vec![0.0, 1.0, 2.0], vec![0.0, 1.0, 2.0])
.unwrap()
.with_style(Vec4::new(1.0, 0.0, 0.0, 1.0), 2.0, LineStyle::Dashed)
.with_label("Linear");
let index2 = figure.add_line_plot(line2);
assert_eq!(figure.len(), 2);
assert_eq!(index1, 0);
assert_eq!(index2, 1);
let legend = figure.legend_entries();
assert_eq!(legend.len(), 2);
assert_eq!(legend[0].label, "Quadratic");
assert_eq!(legend[1].label, "Linear");
}
// Different plot kinds coexist in one figure; all three are visible and
// labelled, so each contributes render data and the legend is reported.
#[test]
fn test_mixed_plot_types() {
let mut figure = Figure::new();
let line = LinePlot::new(vec![0.0, 1.0, 2.0], vec![1.0, 2.0, 3.0])
.unwrap()
.with_label("Line");
figure.add_line_plot(line);
let scatter = ScatterPlot::new(vec![0.5, 1.5, 2.5], vec![1.5, 2.5, 3.5])
.unwrap()
.with_label("Scatter");
figure.add_scatter_plot(scatter);
let bar = BarChart::new(vec!["A".to_string(), "B".to_string()], vec![2.0, 4.0])
.unwrap()
.with_label("Bar");
figure.add_bar_chart(bar);
assert_eq!(figure.len(), 3);
let render_data = figure.render_data();
assert_eq!(render_data.len(), 3);
let stats = figure.statistics();
assert_eq!(stats.total_plots, 3);
assert_eq!(stats.visible_plots, 3);
assert!(stats.has_legend);
}
// A hidden plot still counts towards the total but is skipped by
// `render_data` and by the visible-plot count.
#[test]
fn test_plot_visibility() {
let mut figure = Figure::new();
let mut line = LinePlot::new(vec![0.0, 1.0], vec![0.0, 1.0]).unwrap();
line.set_visible(false); figure.add_line_plot(line);
let scatter = ScatterPlot::new(vec![0.0, 1.0], vec![1.0, 2.0]).unwrap();
figure.add_scatter_plot(scatter);
let render_data = figure.render_data();
assert_eq!(render_data.len(), 1);
let stats = figure.statistics();
assert_eq!(stats.total_plots, 2);
assert_eq!(stats.visible_plots, 1);
}
// The figure bounds must enclose the data of every plot. The inequalities
// leave room for any padding the individual plots may add.
#[test]
fn test_bounds_computation() {
let mut figure = Figure::new();
let line = LinePlot::new(vec![-1.0, 0.0, 1.0], vec![-2.0, 0.0, 2.0]).unwrap();
figure.add_line_plot(line);
let scatter = ScatterPlot::new(vec![2.0, 3.0, 4.0], vec![1.0, 3.0, 5.0]).unwrap();
figure.add_scatter_plot(scatter);
let bounds = figure.bounds();
assert!(bounds.min.x <= -1.0);
assert!(bounds.max.x >= 4.0);
assert!(bounds.min.y <= -2.0);
assert!(bounds.max.y >= 5.0);
}
// `plot_multiple_lines` adds one plot per data set and gives neighbouring
// plots different colours.
#[test]
fn test_matlab_compat_multiple_lines() {
use super::matlab_compat::*;
let mut figure = figure_with_title("Multiple Lines Test");
let data_sets = vec![
(
vec![0.0, 1.0, 2.0],
vec![0.0, 1.0, 4.0],
Some("Quadratic".to_string()),
),
(
vec![0.0, 1.0, 2.0],
vec![0.0, 1.0, 2.0],
Some("Linear".to_string()),
),
(
vec![0.0, 1.0, 2.0],
vec![1.0, 1.0, 1.0],
Some("Constant".to_string()),
),
];
let indices = plot_multiple_lines(&mut figure, data_sets).unwrap();
assert_eq!(indices.len(), 3);
assert_eq!(figure.len(), 3);
let legend = figure.legend_entries();
assert_eq!(legend.len(), 3);
assert_ne!(legend[0].color, legend[1].color);
assert_ne!(legend[1].color, legend[2].color);
}
}