1use crate::core::{BoundingBox, RenderData};
7use crate::plots::{BarChart, Histogram, LinePlot, PointCloudPlot, ScatterPlot};
8use glam::Vec4;
9use std::collections::HashMap;
10
11#[derive(Debug, Clone)]
13pub struct Figure {
14 plots: Vec<PlotElement>,
16
17 pub title: Option<String>,
19 pub x_label: Option<String>,
20 pub y_label: Option<String>,
21 pub legend_enabled: bool,
22 pub grid_enabled: bool,
23 pub background_color: Vec4,
24
25 pub x_limits: Option<(f64, f64)>,
27 pub y_limits: Option<(f64, f64)>,
28
29 bounds: Option<BoundingBox>,
31 dirty: bool,
32}
33
34#[derive(Debug, Clone)]
36pub enum PlotElement {
37 Line(LinePlot),
38 Scatter(ScatterPlot),
39 Bar(BarChart),
40 Histogram(Histogram),
41 PointCloud(PointCloudPlot),
42}
43
44#[derive(Debug, Clone)]
46pub struct LegendEntry {
47 pub label: String,
48 pub color: Vec4,
49 pub plot_type: PlotType,
50}
51
/// Data-free tag identifying the kind of a plot element.
///
/// It is `Copy`, comparable and hashable, so it can be used directly as a
/// map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlotType {
    /// Line plot.
    Line,
    /// Scatter plot.
    Scatter,
    /// Bar chart.
    Bar,
    /// Histogram.
    Histogram,
    /// Point-cloud plot.
    PointCloud,
}
61
62impl Figure {
63 pub fn new() -> Self {
65 Self {
66 plots: Vec::new(),
67 title: None,
68 x_label: None,
69 y_label: None,
70 legend_enabled: true,
71 grid_enabled: true,
72 background_color: Vec4::new(1.0, 1.0, 1.0, 1.0), x_limits: None,
74 y_limits: None,
75 bounds: None,
76 dirty: true,
77 }
78 }
79
80 pub fn with_title<S: Into<String>>(mut self, title: S) -> Self {
82 self.title = Some(title.into());
83 self
84 }
85
86 pub fn with_labels<S: Into<String>>(mut self, x_label: S, y_label: S) -> Self {
88 self.x_label = Some(x_label.into());
89 self.y_label = Some(y_label.into());
90 self
91 }
92
93 pub fn with_limits(mut self, x_limits: (f64, f64), y_limits: (f64, f64)) -> Self {
95 self.x_limits = Some(x_limits);
96 self.y_limits = Some(y_limits);
97 self.dirty = true;
98 self
99 }
100
101 pub fn with_legend(mut self, enabled: bool) -> Self {
103 self.legend_enabled = enabled;
104 self
105 }
106
107 pub fn with_grid(mut self, enabled: bool) -> Self {
109 self.grid_enabled = enabled;
110 self
111 }
112
113 pub fn with_background_color(mut self, color: Vec4) -> Self {
115 self.background_color = color;
116 self
117 }
118
119 pub fn add_line_plot(&mut self, plot: LinePlot) -> usize {
121 self.plots.push(PlotElement::Line(plot));
122 self.dirty = true;
123 self.plots.len() - 1
124 }
125
126 pub fn add_scatter_plot(&mut self, plot: ScatterPlot) -> usize {
128 self.plots.push(PlotElement::Scatter(plot));
129 self.dirty = true;
130 self.plots.len() - 1
131 }
132
133 pub fn add_bar_chart(&mut self, plot: BarChart) -> usize {
135 self.plots.push(PlotElement::Bar(plot));
136 self.dirty = true;
137 self.plots.len() - 1
138 }
139
140 pub fn add_histogram(&mut self, plot: Histogram) -> usize {
142 self.plots.push(PlotElement::Histogram(plot));
143 self.dirty = true;
144 self.plots.len() - 1
145 }
146
147 pub fn add_point_cloud_plot(&mut self, plot: PointCloudPlot) -> usize {
149 self.plots.push(PlotElement::PointCloud(plot));
150 self.dirty = true;
151 self.plots.len() - 1
152 }
153
154 pub fn remove_plot(&mut self, index: usize) -> Result<(), String> {
156 if index >= self.plots.len() {
157 return Err(format!("Plot index {index} out of bounds"));
158 }
159 self.plots.remove(index);
160 self.dirty = true;
161 Ok(())
162 }
163
164 pub fn clear(&mut self) {
166 self.plots.clear();
167 self.dirty = true;
168 }
169
170 pub fn len(&self) -> usize {
172 self.plots.len()
173 }
174
175 pub fn is_empty(&self) -> bool {
177 self.plots.is_empty()
178 }
179
180 pub fn plots(&self) -> impl Iterator<Item = &PlotElement> {
182 self.plots.iter()
183 }
184
185 pub fn get_plot_mut(&mut self, index: usize) -> Option<&mut PlotElement> {
187 self.dirty = true;
188 self.plots.get_mut(index)
189 }
190
191 pub fn bounds(&mut self) -> BoundingBox {
193 if self.dirty || self.bounds.is_none() {
194 self.compute_bounds();
195 }
196 self.bounds.unwrap()
197 }
198
199 fn compute_bounds(&mut self) {
201 if self.plots.is_empty() {
202 self.bounds = Some(BoundingBox::default());
203 return;
204 }
205
206 let mut combined_bounds = None;
207
208 for plot in &mut self.plots {
209 if !plot.is_visible() {
210 continue;
211 }
212
213 let plot_bounds = plot.bounds();
214
215 combined_bounds = match combined_bounds {
216 None => Some(plot_bounds),
217 Some(existing) => Some(existing.union(&plot_bounds)),
218 };
219 }
220
221 self.bounds = combined_bounds.or_else(|| Some(BoundingBox::default()));
222 self.dirty = false;
223 }
224
225 pub fn render_data(&mut self) -> Vec<RenderData> {
227 let mut render_data = Vec::new();
228
229 for plot in &mut self.plots {
230 if plot.is_visible() {
231 render_data.push(plot.render_data());
232 }
233 }
234
235 render_data
236 }
237
238 pub fn legend_entries(&self) -> Vec<LegendEntry> {
240 let mut entries = Vec::new();
241
242 for plot in &self.plots {
243 if let Some(label) = plot.label() {
244 entries.push(LegendEntry {
245 label,
246 color: plot.color(),
247 plot_type: plot.plot_type(),
248 });
249 }
250 }
251
252 entries
253 }
254
255 pub fn statistics(&self) -> FigureStatistics {
257 let plot_counts = self.plots.iter().fold(HashMap::new(), |mut acc, plot| {
258 let plot_type = plot.plot_type();
259 *acc.entry(plot_type).or_insert(0) += 1;
260 acc
261 });
262
263 let total_memory: usize = self
264 .plots
265 .iter()
266 .map(|plot| plot.estimated_memory_usage())
267 .sum();
268
269 let visible_count = self.plots.iter().filter(|plot| plot.is_visible()).count();
270
271 FigureStatistics {
272 total_plots: self.plots.len(),
273 visible_plots: visible_count,
274 plot_type_counts: plot_counts,
275 total_memory_usage: total_memory,
276 has_legend: self.legend_enabled && !self.legend_entries().is_empty(),
277 }
278 }
279}
280
281impl Default for Figure {
282 fn default() -> Self {
283 Self::new()
284 }
285}
286
287impl PlotElement {
288 pub fn is_visible(&self) -> bool {
290 match self {
291 PlotElement::Line(plot) => plot.visible,
292 PlotElement::Scatter(plot) => plot.visible,
293 PlotElement::Bar(plot) => plot.visible,
294 PlotElement::Histogram(plot) => plot.visible,
295 PlotElement::PointCloud(plot) => plot.visible,
296 }
297 }
298
299 pub fn label(&self) -> Option<String> {
301 match self {
302 PlotElement::Line(plot) => plot.label.clone(),
303 PlotElement::Scatter(plot) => plot.label.clone(),
304 PlotElement::Bar(plot) => plot.label.clone(),
305 PlotElement::Histogram(plot) => plot.label.clone(),
306 PlotElement::PointCloud(plot) => plot.label.clone(),
307 }
308 }
309
310 pub fn color(&self) -> Vec4 {
312 match self {
313 PlotElement::Line(plot) => plot.color,
314 PlotElement::Scatter(plot) => plot.color,
315 PlotElement::Bar(plot) => plot.color,
316 PlotElement::Histogram(plot) => plot.color,
317 PlotElement::PointCloud(plot) => plot.default_color,
318 }
319 }
320
321 pub fn plot_type(&self) -> PlotType {
323 match self {
324 PlotElement::Line(_) => PlotType::Line,
325 PlotElement::Scatter(_) => PlotType::Scatter,
326 PlotElement::Bar(_) => PlotType::Bar,
327 PlotElement::Histogram(_) => PlotType::Histogram,
328 PlotElement::PointCloud(_) => PlotType::PointCloud,
329 }
330 }
331
332 pub fn bounds(&mut self) -> BoundingBox {
334 match self {
335 PlotElement::Line(plot) => plot.bounds(),
336 PlotElement::Scatter(plot) => plot.bounds(),
337 PlotElement::Bar(plot) => plot.bounds(),
338 PlotElement::Histogram(plot) => plot.bounds(),
339 PlotElement::PointCloud(plot) => plot.bounds(),
340 }
341 }
342
343 pub fn render_data(&mut self) -> RenderData {
345 match self {
346 PlotElement::Line(plot) => plot.render_data(),
347 PlotElement::Scatter(plot) => plot.render_data(),
348 PlotElement::Bar(plot) => plot.render_data(),
349 PlotElement::Histogram(plot) => plot.render_data(),
350 PlotElement::PointCloud(plot) => plot.render_data(),
351 }
352 }
353
354 pub fn estimated_memory_usage(&self) -> usize {
356 match self {
357 PlotElement::Line(plot) => plot.estimated_memory_usage(),
358 PlotElement::Scatter(plot) => plot.estimated_memory_usage(),
359 PlotElement::Bar(plot) => plot.estimated_memory_usage(),
360 PlotElement::Histogram(plot) => plot.estimated_memory_usage(),
361 PlotElement::PointCloud(plot) => plot.estimated_memory_usage(),
362 }
363 }
364}
365
366#[derive(Debug)]
368pub struct FigureStatistics {
369 pub total_plots: usize,
370 pub visible_plots: usize,
371 pub plot_type_counts: HashMap<PlotType, usize>,
372 pub total_memory_usage: usize,
373 pub has_legend: bool,
374}
375
376pub mod matlab_compat {
378 use super::*;
379 use crate::plots::{LinePlot, ScatterPlot};
380
381 pub fn figure() -> Figure {
383 Figure::new()
384 }
385
386 pub fn figure_with_title<S: Into<String>>(title: S) -> Figure {
388 Figure::new().with_title(title)
389 }
390
391 pub fn plot_multiple_lines(
393 figure: &mut Figure,
394 data_sets: Vec<(Vec<f64>, Vec<f64>, Option<String>)>,
395 ) -> Result<Vec<usize>, String> {
396 let mut indices = Vec::new();
397
398 for (i, (x, y, label)) in data_sets.into_iter().enumerate() {
399 let mut line = LinePlot::new(x, y)?;
400
401 let colors = [
403 Vec4::new(0.0, 0.4470, 0.7410, 1.0), Vec4::new(0.8500, 0.3250, 0.0980, 1.0), Vec4::new(0.9290, 0.6940, 0.1250, 1.0), Vec4::new(0.4940, 0.1840, 0.5560, 1.0), Vec4::new(0.4660, 0.6740, 0.1880, 1.0), Vec4::new(std::f64::consts::LOG10_2 as f32, 0.7450, 0.9330, 1.0), Vec4::new(0.6350, 0.0780, 0.1840, 1.0), ];
411 let color = colors[i % colors.len()];
412 line.set_color(color);
413
414 if let Some(label) = label {
415 line = line.with_label(label);
416 }
417
418 indices.push(figure.add_line_plot(line));
419 }
420
421 Ok(indices)
422 }
423
424 pub fn scatter_multiple(
426 figure: &mut Figure,
427 data_sets: Vec<(Vec<f64>, Vec<f64>, Option<String>)>,
428 ) -> Result<Vec<usize>, String> {
429 let mut indices = Vec::new();
430
431 for (i, (x, y, label)) in data_sets.into_iter().enumerate() {
432 let mut scatter = ScatterPlot::new(x, y)?;
433
434 let colors = [
436 Vec4::new(1.0, 0.0, 0.0, 1.0), Vec4::new(0.0, 1.0, 0.0, 1.0), Vec4::new(0.0, 0.0, 1.0, 1.0), Vec4::new(1.0, 1.0, 0.0, 1.0), Vec4::new(1.0, 0.0, 1.0, 1.0), Vec4::new(0.0, 1.0, 1.0, 1.0), Vec4::new(0.5, 0.5, 0.5, 1.0), ];
444 let color = colors[i % colors.len()];
445 scatter.set_color(color);
446
447 if let Some(label) = label {
448 scatter = scatter.with_label(label);
449 }
450
451 indices.push(figure.add_scatter_plot(scatter));
452 }
453
454 Ok(indices)
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plots::line::LineStyle;

    /// A new figure is empty and has legend and grid switched on.
    #[test]
    fn test_figure_creation() {
        let fig = Figure::new();

        assert!(fig.is_empty());
        assert_eq!(fig.len(), 0);
        assert!(fig.legend_enabled);
        assert!(fig.grid_enabled);
    }

    /// The builder methods store exactly what they are given.
    #[test]
    fn test_figure_styling() {
        let fig = Figure::new()
            .with_title("Test Figure")
            .with_labels("X Axis", "Y Axis")
            .with_legend(false)
            .with_grid(false);

        assert_eq!(fig.title.as_deref(), Some("Test Figure"));
        assert_eq!(fig.x_label.as_deref(), Some("X Axis"));
        assert_eq!(fig.y_label.as_deref(), Some("Y Axis"));
        assert!(!fig.legend_enabled);
        assert!(!fig.grid_enabled);
    }

    /// Line plots receive consecutive indices and appear in the legend in
    /// insertion order.
    #[test]
    fn test_multiple_line_plots() {
        let mut fig = Figure::new();

        let quadratic = LinePlot::new(vec![0.0, 1.0, 2.0], vec![0.0, 1.0, 4.0])
            .unwrap()
            .with_label("Quadratic");
        let quadratic_index = fig.add_line_plot(quadratic);

        let linear = LinePlot::new(vec![0.0, 1.0, 2.0], vec![0.0, 1.0, 2.0])
            .unwrap()
            .with_style(Vec4::new(1.0, 0.0, 0.0, 1.0), 2.0, LineStyle::Dashed)
            .with_label("Linear");
        let linear_index = fig.add_line_plot(linear);

        assert_eq!(fig.len(), 2);
        assert_eq!((quadratic_index, linear_index), (0, 1));

        let entries = fig.legend_entries();
        let labels: Vec<&str> = entries.iter().map(|entry| entry.label.as_str()).collect();
        assert_eq!(labels, ["Quadratic", "Linear"]);
    }

    /// Different plot kinds coexist in one figure and all show up in the
    /// render data and the statistics.
    #[test]
    fn test_mixed_plot_types() {
        let mut fig = Figure::new();

        fig.add_line_plot(
            LinePlot::new(vec![0.0, 1.0, 2.0], vec![1.0, 2.0, 3.0])
                .unwrap()
                .with_label("Line"),
        );
        fig.add_scatter_plot(
            ScatterPlot::new(vec![0.5, 1.5, 2.5], vec![1.5, 2.5, 3.5])
                .unwrap()
                .with_label("Scatter"),
        );
        fig.add_bar_chart(
            BarChart::new(vec!["A".to_string(), "B".to_string()], vec![2.0, 4.0])
                .unwrap()
                .with_label("Bar"),
        );

        assert_eq!(fig.len(), 3);
        assert_eq!(fig.render_data().len(), 3);

        let stats = fig.statistics();
        assert_eq!(stats.total_plots, 3);
        assert_eq!(stats.visible_plots, 3);
        assert!(stats.has_legend);
    }

    /// Hidden plots are skipped when rendering but still counted in the
    /// total.
    #[test]
    fn test_plot_visibility() {
        let mut fig = Figure::new();

        let mut hidden_line = LinePlot::new(vec![0.0, 1.0], vec![0.0, 1.0]).unwrap();
        hidden_line.set_visible(false);
        fig.add_line_plot(hidden_line);

        let visible_scatter = ScatterPlot::new(vec![0.0, 1.0], vec![1.0, 2.0]).unwrap();
        fig.add_scatter_plot(visible_scatter);

        assert_eq!(fig.render_data().len(), 1);

        let stats = fig.statistics();
        assert_eq!(stats.total_plots, 2);
        assert_eq!(stats.visible_plots, 1);
    }

    /// The figure bounds enclose the data of every plot.
    #[test]
    fn test_bounds_computation() {
        let mut fig = Figure::new();

        fig.add_line_plot(LinePlot::new(vec![-1.0, 0.0, 1.0], vec![-2.0, 0.0, 2.0]).unwrap());
        fig.add_scatter_plot(ScatterPlot::new(vec![2.0, 3.0, 4.0], vec![1.0, 3.0, 5.0]).unwrap());

        let bbox = fig.bounds();

        assert!(bbox.min.x <= -1.0);
        assert!(bbox.max.x >= 4.0);
        assert!(bbox.min.y <= -2.0);
        assert!(bbox.max.y >= 5.0);
    }

    /// `plot_multiple_lines` adds every data set and gives neighbouring
    /// plots different colors.
    #[test]
    fn test_matlab_compat_multiple_lines() {
        use super::matlab_compat::*;

        let mut fig = figure_with_title("Multiple Lines Test");

        let xs = vec![0.0, 1.0, 2.0];
        let data_sets = vec![
            (
                xs.clone(),
                vec![0.0, 1.0, 4.0],
                Some("Quadratic".to_string()),
            ),
            (xs.clone(), vec![0.0, 1.0, 2.0], Some("Linear".to_string())),
            (xs, vec![1.0, 1.0, 1.0], Some("Constant".to_string())),
        ];

        let indices = plot_multiple_lines(&mut fig, data_sets).unwrap();

        assert_eq!(indices.len(), 3);
        assert_eq!(fig.len(), 3);

        let entries = fig.legend_entries();
        assert_eq!(entries.len(), 3);
        assert_ne!(entries[0].color, entries[1].color);
        assert_ne!(entries[1].color, entries[2].color);
    }
}