runmat_plot/plots/figure.rs

1//! Figure management for multiple overlaid plots
2//!
3//! This module provides the `Figure` struct that manages multiple plots in a single
4//! coordinate system, handling overlays, legends, and proper rendering order.
5
6use crate::core::{BoundingBox, RenderData};
7use crate::plots::{BarChart, Histogram, LinePlot, PointCloudPlot, ScatterPlot};
8use glam::Vec4;
9use std::collections::HashMap;
10
11/// A figure that can contain multiple overlaid plots
12#[derive(Debug, Clone)]
13pub struct Figure {
14    /// All plots in this figure
15    plots: Vec<PlotElement>,
16
17    /// Figure-level settings
18    pub title: Option<String>,
19    pub x_label: Option<String>,
20    pub y_label: Option<String>,
21    pub legend_enabled: bool,
22    pub grid_enabled: bool,
23    pub background_color: Vec4,
24
25    /// Axis limits (None = auto-scale)
26    pub x_limits: Option<(f64, f64)>,
27    pub y_limits: Option<(f64, f64)>,
28
29    /// Cached data
30    bounds: Option<BoundingBox>,
31    dirty: bool,
32}
33
34/// A plot element that can be any type of plot
35#[derive(Debug, Clone)]
36pub enum PlotElement {
37    Line(LinePlot),
38    Scatter(ScatterPlot),
39    Bar(BarChart),
40    Histogram(Histogram),
41    PointCloud(PointCloudPlot),
42}
43
44/// Legend entry for a plot
45#[derive(Debug, Clone)]
46pub struct LegendEntry {
47    pub label: String,
48    pub color: Vec4,
49    pub plot_type: PlotType,
50}
51
/// Kind of plot, used when rendering legend entries and counting plots.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlotType {
    /// Corresponds to `PlotElement::Line`.
    Line,
    /// Corresponds to `PlotElement::Scatter`.
    Scatter,
    /// Corresponds to `PlotElement::Bar`.
    Bar,
    /// Corresponds to `PlotElement::Histogram`.
    Histogram,
    /// Corresponds to `PlotElement::PointCloud`.
    PointCloud,
}
61
62impl Figure {
63    /// Create a new empty figure
64    pub fn new() -> Self {
65        Self {
66            plots: Vec::new(),
67            title: None,
68            x_label: None,
69            y_label: None,
70            legend_enabled: true,
71            grid_enabled: true,
72            background_color: Vec4::new(1.0, 1.0, 1.0, 1.0), // White background
73            x_limits: None,
74            y_limits: None,
75            bounds: None,
76            dirty: true,
77        }
78    }
79
80    /// Set the figure title
81    pub fn with_title<S: Into<String>>(mut self, title: S) -> Self {
82        self.title = Some(title.into());
83        self
84    }
85
86    /// Set axis labels
87    pub fn with_labels<S: Into<String>>(mut self, x_label: S, y_label: S) -> Self {
88        self.x_label = Some(x_label.into());
89        self.y_label = Some(y_label.into());
90        self
91    }
92
93    /// Set axis limits manually
94    pub fn with_limits(mut self, x_limits: (f64, f64), y_limits: (f64, f64)) -> Self {
95        self.x_limits = Some(x_limits);
96        self.y_limits = Some(y_limits);
97        self.dirty = true;
98        self
99    }
100
101    /// Enable or disable the legend
102    pub fn with_legend(mut self, enabled: bool) -> Self {
103        self.legend_enabled = enabled;
104        self
105    }
106
107    /// Enable or disable the grid
108    pub fn with_grid(mut self, enabled: bool) -> Self {
109        self.grid_enabled = enabled;
110        self
111    }
112
113    /// Set background color
114    pub fn with_background_color(mut self, color: Vec4) -> Self {
115        self.background_color = color;
116        self
117    }
118
119    /// Add a line plot to the figure
120    pub fn add_line_plot(&mut self, plot: LinePlot) -> usize {
121        self.plots.push(PlotElement::Line(plot));
122        self.dirty = true;
123        self.plots.len() - 1
124    }
125
126    /// Add a scatter plot to the figure
127    pub fn add_scatter_plot(&mut self, plot: ScatterPlot) -> usize {
128        self.plots.push(PlotElement::Scatter(plot));
129        self.dirty = true;
130        self.plots.len() - 1
131    }
132
133    /// Add a bar chart to the figure
134    pub fn add_bar_chart(&mut self, plot: BarChart) -> usize {
135        self.plots.push(PlotElement::Bar(plot));
136        self.dirty = true;
137        self.plots.len() - 1
138    }
139
140    /// Add a histogram to the figure
141    pub fn add_histogram(&mut self, plot: Histogram) -> usize {
142        self.plots.push(PlotElement::Histogram(plot));
143        self.dirty = true;
144        self.plots.len() - 1
145    }
146
147    /// Add a point cloud to the figure
148    pub fn add_point_cloud_plot(&mut self, plot: PointCloudPlot) -> usize {
149        self.plots.push(PlotElement::PointCloud(plot));
150        self.dirty = true;
151        self.plots.len() - 1
152    }
153
154    /// Remove a plot by index
155    pub fn remove_plot(&mut self, index: usize) -> Result<(), String> {
156        if index >= self.plots.len() {
157            return Err(format!("Plot index {index} out of bounds"));
158        }
159        self.plots.remove(index);
160        self.dirty = true;
161        Ok(())
162    }
163
164    /// Clear all plots
165    pub fn clear(&mut self) {
166        self.plots.clear();
167        self.dirty = true;
168    }
169
170    /// Get the number of plots
171    pub fn len(&self) -> usize {
172        self.plots.len()
173    }
174
175    /// Check if figure has no plots
176    pub fn is_empty(&self) -> bool {
177        self.plots.is_empty()
178    }
179
180    /// Get an iterator over all plots in this figure
181    pub fn plots(&self) -> impl Iterator<Item = &PlotElement> {
182        self.plots.iter()
183    }
184
185    /// Get a mutable reference to a plot
186    pub fn get_plot_mut(&mut self, index: usize) -> Option<&mut PlotElement> {
187        self.dirty = true;
188        self.plots.get_mut(index)
189    }
190
191    /// Get the combined bounds of all visible plots
192    pub fn bounds(&mut self) -> BoundingBox {
193        if self.dirty || self.bounds.is_none() {
194            self.compute_bounds();
195        }
196        self.bounds.unwrap()
197    }
198
199    /// Compute the combined bounds from all plots
200    fn compute_bounds(&mut self) {
201        if self.plots.is_empty() {
202            self.bounds = Some(BoundingBox::default());
203            return;
204        }
205
206        let mut combined_bounds = None;
207
208        for plot in &mut self.plots {
209            if !plot.is_visible() {
210                continue;
211            }
212
213            let plot_bounds = plot.bounds();
214
215            combined_bounds = match combined_bounds {
216                None => Some(plot_bounds),
217                Some(existing) => Some(existing.union(&plot_bounds)),
218            };
219        }
220
221        self.bounds = combined_bounds.or_else(|| Some(BoundingBox::default()));
222        self.dirty = false;
223    }
224
225    /// Generate all render data for all visible plots
226    pub fn render_data(&mut self) -> Vec<RenderData> {
227        let mut render_data = Vec::new();
228
229        for plot in &mut self.plots {
230            if plot.is_visible() {
231                render_data.push(plot.render_data());
232            }
233        }
234
235        render_data
236    }
237
238    /// Get legend entries for all labeled plots
239    pub fn legend_entries(&self) -> Vec<LegendEntry> {
240        let mut entries = Vec::new();
241
242        for plot in &self.plots {
243            if let Some(label) = plot.label() {
244                entries.push(LegendEntry {
245                    label,
246                    color: plot.color(),
247                    plot_type: plot.plot_type(),
248                });
249            }
250        }
251
252        entries
253    }
254
255    /// Get figure statistics
256    pub fn statistics(&self) -> FigureStatistics {
257        let plot_counts = self.plots.iter().fold(HashMap::new(), |mut acc, plot| {
258            let plot_type = plot.plot_type();
259            *acc.entry(plot_type).or_insert(0) += 1;
260            acc
261        });
262
263        let total_memory: usize = self
264            .plots
265            .iter()
266            .map(|plot| plot.estimated_memory_usage())
267            .sum();
268
269        let visible_count = self.plots.iter().filter(|plot| plot.is_visible()).count();
270
271        FigureStatistics {
272            total_plots: self.plots.len(),
273            visible_plots: visible_count,
274            plot_type_counts: plot_counts,
275            total_memory_usage: total_memory,
276            has_legend: self.legend_enabled && !self.legend_entries().is_empty(),
277        }
278    }
279}
280
281impl Default for Figure {
282    fn default() -> Self {
283        Self::new()
284    }
285}
286
287impl PlotElement {
288    /// Check if the plot is visible
289    pub fn is_visible(&self) -> bool {
290        match self {
291            PlotElement::Line(plot) => plot.visible,
292            PlotElement::Scatter(plot) => plot.visible,
293            PlotElement::Bar(plot) => plot.visible,
294            PlotElement::Histogram(plot) => plot.visible,
295            PlotElement::PointCloud(plot) => plot.visible,
296        }
297    }
298
299    /// Get the plot's label
300    pub fn label(&self) -> Option<String> {
301        match self {
302            PlotElement::Line(plot) => plot.label.clone(),
303            PlotElement::Scatter(plot) => plot.label.clone(),
304            PlotElement::Bar(plot) => plot.label.clone(),
305            PlotElement::Histogram(plot) => plot.label.clone(),
306            PlotElement::PointCloud(plot) => plot.label.clone(),
307        }
308    }
309
310    /// Get the plot's primary color
311    pub fn color(&self) -> Vec4 {
312        match self {
313            PlotElement::Line(plot) => plot.color,
314            PlotElement::Scatter(plot) => plot.color,
315            PlotElement::Bar(plot) => plot.color,
316            PlotElement::Histogram(plot) => plot.color,
317            PlotElement::PointCloud(plot) => plot.default_color,
318        }
319    }
320
321    /// Get the plot type
322    pub fn plot_type(&self) -> PlotType {
323        match self {
324            PlotElement::Line(_) => PlotType::Line,
325            PlotElement::Scatter(_) => PlotType::Scatter,
326            PlotElement::Bar(_) => PlotType::Bar,
327            PlotElement::Histogram(_) => PlotType::Histogram,
328            PlotElement::PointCloud(_) => PlotType::PointCloud,
329        }
330    }
331
332    /// Get the plot's bounds
333    pub fn bounds(&mut self) -> BoundingBox {
334        match self {
335            PlotElement::Line(plot) => plot.bounds(),
336            PlotElement::Scatter(plot) => plot.bounds(),
337            PlotElement::Bar(plot) => plot.bounds(),
338            PlotElement::Histogram(plot) => plot.bounds(),
339            PlotElement::PointCloud(plot) => plot.bounds(),
340        }
341    }
342
343    /// Generate render data for this plot
344    pub fn render_data(&mut self) -> RenderData {
345        match self {
346            PlotElement::Line(plot) => plot.render_data(),
347            PlotElement::Scatter(plot) => plot.render_data(),
348            PlotElement::Bar(plot) => plot.render_data(),
349            PlotElement::Histogram(plot) => plot.render_data(),
350            PlotElement::PointCloud(plot) => plot.render_data(),
351        }
352    }
353
354    /// Estimate memory usage
355    pub fn estimated_memory_usage(&self) -> usize {
356        match self {
357            PlotElement::Line(plot) => plot.estimated_memory_usage(),
358            PlotElement::Scatter(plot) => plot.estimated_memory_usage(),
359            PlotElement::Bar(plot) => plot.estimated_memory_usage(),
360            PlotElement::Histogram(plot) => plot.estimated_memory_usage(),
361            PlotElement::PointCloud(plot) => plot.estimated_memory_usage(),
362        }
363    }
364}
365
366/// Figure statistics for debugging and optimization
367#[derive(Debug)]
368pub struct FigureStatistics {
369    pub total_plots: usize,
370    pub visible_plots: usize,
371    pub plot_type_counts: HashMap<PlotType, usize>,
372    pub total_memory_usage: usize,
373    pub has_legend: bool,
374}
375
376/// MATLAB-compatible figure creation utilities
377pub mod matlab_compat {
378    use super::*;
379    use crate::plots::{LinePlot, ScatterPlot};
380
381    /// Create a new figure (equivalent to MATLAB's `figure`)
382    pub fn figure() -> Figure {
383        Figure::new()
384    }
385
386    /// Create a figure with a title
387    pub fn figure_with_title<S: Into<String>>(title: S) -> Figure {
388        Figure::new().with_title(title)
389    }
390
391    /// Add multiple line plots to a figure (`hold on` behavior)
392    pub fn plot_multiple_lines(
393        figure: &mut Figure,
394        data_sets: Vec<(Vec<f64>, Vec<f64>, Option<String>)>,
395    ) -> Result<Vec<usize>, String> {
396        let mut indices = Vec::new();
397
398        for (i, (x, y, label)) in data_sets.into_iter().enumerate() {
399            let mut line = LinePlot::new(x, y)?;
400
401            // Automatic color cycling (similar to MATLAB)
402            let colors = [
403                Vec4::new(0.0, 0.4470, 0.7410, 1.0),    // Blue
404                Vec4::new(0.8500, 0.3250, 0.0980, 1.0), // Orange
405                Vec4::new(0.9290, 0.6940, 0.1250, 1.0), // Yellow
406                Vec4::new(0.4940, 0.1840, 0.5560, 1.0), // Purple
407                Vec4::new(0.4660, 0.6740, 0.1880, 1.0), // Green
408                Vec4::new(std::f64::consts::LOG10_2 as f32, 0.7450, 0.9330, 1.0), // Cyan
409                Vec4::new(0.6350, 0.0780, 0.1840, 1.0), // Red
410            ];
411            let color = colors[i % colors.len()];
412            line.set_color(color);
413
414            if let Some(label) = label {
415                line = line.with_label(label);
416            }
417
418            indices.push(figure.add_line_plot(line));
419        }
420
421        Ok(indices)
422    }
423
424    /// Add multiple scatter plots to a figure
425    pub fn scatter_multiple(
426        figure: &mut Figure,
427        data_sets: Vec<(Vec<f64>, Vec<f64>, Option<String>)>,
428    ) -> Result<Vec<usize>, String> {
429        let mut indices = Vec::new();
430
431        for (i, (x, y, label)) in data_sets.into_iter().enumerate() {
432            let mut scatter = ScatterPlot::new(x, y)?;
433
434            // Automatic color cycling
435            let colors = [
436                Vec4::new(1.0, 0.0, 0.0, 1.0), // Red
437                Vec4::new(0.0, 1.0, 0.0, 1.0), // Green
438                Vec4::new(0.0, 0.0, 1.0, 1.0), // Blue
439                Vec4::new(1.0, 1.0, 0.0, 1.0), // Yellow
440                Vec4::new(1.0, 0.0, 1.0, 1.0), // Magenta
441                Vec4::new(0.0, 1.0, 1.0, 1.0), // Cyan
442                Vec4::new(0.5, 0.5, 0.5, 1.0), // Gray
443            ];
444            let color = colors[i % colors.len()];
445            scatter.set_color(color);
446
447            if let Some(label) = label {
448                scatter = scatter.with_label(label);
449            }
450
451            indices.push(figure.add_scatter_plot(scatter));
452        }
453
454        Ok(indices)
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plots::line::LineStyle;

    #[test]
    fn test_figure_creation() {
        let fig = Figure::new();

        assert!(fig.is_empty());
        assert_eq!(fig.len(), 0);
        assert!(fig.legend_enabled, "legend should default to on");
        assert!(fig.grid_enabled, "grid should default to on");
    }

    #[test]
    fn test_figure_styling() {
        let fig = Figure::new()
            .with_title("Test Figure")
            .with_labels("X Axis", "Y Axis")
            .with_legend(false)
            .with_grid(false);

        assert_eq!(fig.title.as_deref(), Some("Test Figure"));
        assert_eq!(fig.x_label.as_deref(), Some("X Axis"));
        assert_eq!(fig.y_label.as_deref(), Some("Y Axis"));
        assert!(!fig.legend_enabled);
        assert!(!fig.grid_enabled);
    }

    #[test]
    fn test_multiple_line_plots() {
        let mut fig = Figure::new();
        let xs = vec![0.0, 1.0, 2.0];

        let quadratic = LinePlot::new(xs.clone(), vec![0.0, 1.0, 4.0])
            .unwrap()
            .with_label("Quadratic");
        let first = fig.add_line_plot(quadratic);

        let red = Vec4::new(1.0, 0.0, 0.0, 1.0);
        let linear = LinePlot::new(xs, vec![0.0, 1.0, 2.0])
            .unwrap()
            .with_style(red, 2.0, LineStyle::Dashed)
            .with_label("Linear");
        let second = fig.add_line_plot(linear);

        assert_eq!(fig.len(), 2);
        assert_eq!((first, second), (0, 1));

        // Legend rows follow insertion order.
        let labels: Vec<String> = fig
            .legend_entries()
            .into_iter()
            .map(|entry| entry.label)
            .collect();
        assert_eq!(labels, vec!["Quadratic".to_string(), "Linear".to_string()]);
    }

    #[test]
    fn test_mixed_plot_types() {
        let mut fig = Figure::new();

        fig.add_line_plot(
            LinePlot::new(vec![0.0, 1.0, 2.0], vec![1.0, 2.0, 3.0])
                .unwrap()
                .with_label("Line"),
        );
        fig.add_scatter_plot(
            ScatterPlot::new(vec![0.5, 1.5, 2.5], vec![1.5, 2.5, 3.5])
                .unwrap()
                .with_label("Scatter"),
        );
        let categories = vec!["A".to_string(), "B".to_string()];
        fig.add_bar_chart(
            BarChart::new(categories, vec![2.0, 4.0])
                .unwrap()
                .with_label("Bar"),
        );

        assert_eq!(fig.len(), 3);
        assert_eq!(fig.render_data().len(), 3);

        let stats = fig.statistics();
        assert_eq!(stats.total_plots, 3);
        assert_eq!(stats.visible_plots, 3);
        assert!(stats.has_legend);
    }

    #[test]
    fn test_plot_visibility() {
        let mut fig = Figure::new();

        let mut hidden = LinePlot::new(vec![0.0, 1.0], vec![0.0, 1.0]).unwrap();
        hidden.set_visible(false);
        fig.add_line_plot(hidden);
        fig.add_scatter_plot(ScatterPlot::new(vec![0.0, 1.0], vec![1.0, 2.0]).unwrap());

        // Hidden plots are skipped when rendering but still counted in totals.
        assert_eq!(fig.render_data().len(), 1);

        let stats = fig.statistics();
        assert_eq!(stats.total_plots, 2);
        assert_eq!(stats.visible_plots, 1);
    }

    #[test]
    fn test_bounds_computation() {
        let mut fig = Figure::new();
        fig.add_line_plot(LinePlot::new(vec![-1.0, 0.0, 1.0], vec![-2.0, 0.0, 2.0]).unwrap());
        fig.add_scatter_plot(
            ScatterPlot::new(vec![2.0, 3.0, 4.0], vec![1.0, 3.0, 5.0]).unwrap(),
        );

        // The combined box must enclose both data ranges.
        let bb = fig.bounds();
        assert!(bb.min.x <= -1.0 && bb.min.y <= -2.0);
        assert!(bb.max.x >= 4.0 && bb.max.y >= 5.0);
    }

    #[test]
    fn test_matlab_compat_multiple_lines() {
        use super::matlab_compat::*;

        let mut fig = figure_with_title("Multiple Lines Test");
        let xs = vec![0.0, 1.0, 2.0];
        let data_sets: Vec<(Vec<f64>, Vec<f64>, Option<String>)> = vec![
            (vec![0.0, 1.0, 4.0], "Quadratic"),
            (vec![0.0, 1.0, 2.0], "Linear"),
            (vec![1.0, 1.0, 1.0], "Constant"),
        ]
        .into_iter()
        .map(|(ys, name)| (xs.clone(), ys, Some(name.to_string())))
        .collect();

        let indices = plot_multiple_lines(&mut fig, data_sets).unwrap();
        assert_eq!(indices.len(), 3);
        assert_eq!(fig.len(), 3);

        // Color cycling gives neighbouring plots distinct colors.
        let legend = fig.legend_entries();
        assert_eq!(legend.len(), 3);
        assert_ne!(legend[0].color, legend[1].color);
        assert_ne!(legend[1].color, legend[2].color);
    }
}