Skip to main content

runmat_plot/plots/
scatter.rs

1//! Scatter plot implementation
2//!
3//! High-performance scatter plotting with GPU acceleration.
4
5use crate::core::{
6    vertex_utils, BoundingBox, DrawCall, GpuVertexBuffer, Material, PipelineType, RenderData,
7    Vertex,
8};
9use crate::plots::surface::ColorMap;
10use glam::{Vec3, Vec4};
11
12/// High-performance GPU-accelerated scatter plot
13#[derive(Debug, Clone)]
14pub struct ScatterPlot {
15    /// Raw data points (x, y coordinates)
16    pub x_data: Vec<f64>,
17    pub y_data: Vec<f64>,
18
19    /// Visual styling
20    pub color: Vec4,
21    pub edge_color: Vec4,
22    pub edge_thickness: f32,
23    pub marker_size: f32,
24    pub marker_style: MarkerStyle,
25    pub per_point_sizes: Option<Vec<f32>>, // pixel diameters per point
26    pub per_point_colors: Option<Vec<Vec4>>, // per-point RGBA
27    pub color_values: Option<Vec<f64>>,    // scalar values mapped by colormap
28    pub color_limits: Option<(f64, f64)>,
29    pub colormap: ColorMap,
30    pub filled: bool,
31    pub edge_color_from_vertex_colors: bool,
32
33    /// Metadata
34    pub label: Option<String>,
35    pub visible: bool,
36
37    /// Generated rendering data (cached)
38    vertices: Option<Vec<Vertex>>,
39    bounds: Option<BoundingBox>,
40    dirty: bool,
41    gpu_vertices: Option<GpuVertexBuffer>,
42    gpu_point_count: Option<usize>,
43    gpu_has_per_point_sizes: bool,
44    gpu_has_per_point_colors: bool,
45}
46
/// Shape used to draw each point of a scatter plot.
///
/// `ScatterPlot::render_data` hands the selected style to the renderer as a
/// numeric code stored in the material's `metallic` field: `Circle` is 0 and
/// the remaining variants follow in declaration order up to `Hexagon` at 7.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MarkerStyle {
    /// Style code 0.
    Circle,
    /// Style code 1.
    Square,
    /// Style code 2.
    Triangle,
    /// Style code 3.
    Diamond,
    /// Style code 4.
    Plus,
    /// Style code 5.
    Cross,
    /// Style code 6.
    Star,
    /// Style code 7.
    Hexagon,
}
59
60impl Default for MarkerStyle {
61    fn default() -> Self {
62        Self::Circle
63    }
64}
65
66#[derive(Clone, Copy, Debug)]
67pub struct ScatterGpuStyle {
68    pub color: Vec4,
69    pub edge_color: Vec4,
70    pub edge_thickness: f32,
71    pub marker_size: f32,
72    pub marker_style: MarkerStyle,
73    pub filled: bool,
74    pub has_per_point_sizes: bool,
75    pub has_per_point_colors: bool,
76    pub edge_from_vertex_colors: bool,
77}
78
79impl ScatterPlot {
80    /// Create a new scatter plot with data
81    pub fn new(x_data: Vec<f64>, y_data: Vec<f64>) -> Result<Self, String> {
82        if x_data.len() != y_data.len() {
83            return Err(format!(
84                "Data length mismatch: x_data has {} points, y_data has {} points",
85                x_data.len(),
86                y_data.len()
87            ));
88        }
89
90        if x_data.is_empty() {
91            return Err("Cannot create scatter plot with empty data".to_string());
92        }
93
94        Ok(Self {
95            x_data,
96            y_data,
97            color: Vec4::new(1.0, 0.2, 0.2, 1.0), // Brighter red
98            edge_color: Vec4::new(0.0, 0.0, 0.0, 1.0),
99            edge_thickness: 1.0,
100            marker_size: 12.0,
101            marker_style: MarkerStyle::default(),
102            per_point_sizes: None,
103            per_point_colors: None,
104            color_values: None,
105            color_limits: None,
106            colormap: ColorMap::Parula,
107            filled: false,
108            edge_color_from_vertex_colors: false,
109            label: None,
110            visible: true,
111            vertices: None,
112            bounds: None,
113            dirty: true,
114            gpu_vertices: None,
115            gpu_point_count: None,
116            gpu_has_per_point_sizes: false,
117            gpu_has_per_point_colors: false,
118        })
119    }
120
121    /// Build a scatter plot directly from a GPU vertex buffer to avoid CPU copies.
122    pub fn from_gpu_buffer(
123        buffer: GpuVertexBuffer,
124        point_count: usize,
125        bounds: BoundingBox,
126        style: ScatterGpuStyle,
127    ) -> Self {
128        Self {
129            x_data: Vec::new(),
130            y_data: Vec::new(),
131            color: style.color,
132            edge_color: style.edge_color,
133            edge_thickness: style.edge_thickness,
134            marker_size: style.marker_size,
135            marker_style: style.marker_style,
136            per_point_sizes: None,
137            per_point_colors: None,
138            color_values: None,
139            color_limits: None,
140            colormap: ColorMap::Parula,
141            filled: style.filled,
142            edge_color_from_vertex_colors: style.edge_from_vertex_colors,
143            label: None,
144            visible: true,
145            vertices: None,
146            bounds: Some(bounds),
147            dirty: false,
148            gpu_vertices: Some(buffer),
149            gpu_point_count: Some(point_count),
150            gpu_has_per_point_sizes: style.has_per_point_sizes,
151            gpu_has_per_point_colors: style.has_per_point_colors,
152        }
153    }
154
155    fn invalidate_gpu_vertices(&mut self) {
156        self.gpu_vertices = None;
157        self.gpu_point_count = None;
158    }
159
160    /// Create a scatter plot with custom styling
161    pub fn with_style(mut self, color: Vec4, marker_size: f32, marker_style: MarkerStyle) -> Self {
162        self.color = color;
163        self.marker_size = marker_size;
164        self.marker_style = marker_style;
165        self.dirty = true;
166        self.invalidate_gpu_vertices();
167        self.gpu_has_per_point_sizes = false;
168        self.gpu_has_per_point_colors = false;
169        self
170    }
171
172    /// Set the plot label for legends
173    pub fn with_label<S: Into<String>>(mut self, label: S) -> Self {
174        self.label = Some(label.into());
175        self
176    }
177
178    /// Set marker face color
179    pub fn set_face_color(&mut self, color: Vec4) {
180        self.color = color;
181        self.dirty = true;
182        self.invalidate_gpu_vertices();
183    }
184    /// Set marker edge color
185    pub fn set_edge_color(&mut self, color: Vec4) {
186        self.edge_color = color;
187        self.dirty = true;
188        self.invalidate_gpu_vertices();
189    }
190    pub fn set_edge_color_from_vertex(&mut self, enabled: bool) {
191        self.edge_color_from_vertex_colors = enabled;
192    }
193    /// Set marker edge thickness (pixels)
194    pub fn set_edge_thickness(&mut self, px: f32) {
195        self.edge_thickness = px.max(0.0);
196        self.dirty = true;
197        self.invalidate_gpu_vertices();
198    }
199    pub fn set_sizes(&mut self, sizes: Vec<f32>) {
200        self.per_point_sizes = Some(sizes);
201        self.dirty = true;
202        self.invalidate_gpu_vertices();
203        self.gpu_has_per_point_sizes = false;
204    }
205    pub fn set_colors(&mut self, colors: Vec<Vec4>) {
206        self.per_point_colors = Some(colors);
207        self.dirty = true;
208        self.invalidate_gpu_vertices();
209        self.gpu_has_per_point_colors = false;
210    }
211    pub fn set_color_values(&mut self, values: Vec<f64>, limits: Option<(f64, f64)>) {
212        self.color_values = Some(values);
213        self.color_limits = limits;
214        self.dirty = true;
215        self.invalidate_gpu_vertices();
216        self.gpu_has_per_point_colors = false;
217    }
218    pub fn with_colormap(mut self, cmap: ColorMap) -> Self {
219        self.colormap = cmap;
220        self.dirty = true;
221        self.invalidate_gpu_vertices();
222        self
223    }
224    pub fn set_filled(&mut self, filled: bool) {
225        self.filled = filled;
226        self.dirty = true;
227        self.invalidate_gpu_vertices();
228    }
229
230    /// Update the data points
231    pub fn update_data(&mut self, x_data: Vec<f64>, y_data: Vec<f64>) -> Result<(), String> {
232        if x_data.len() != y_data.len() {
233            return Err(format!(
234                "Data length mismatch: x_data has {} points, y_data has {} points",
235                x_data.len(),
236                y_data.len()
237            ));
238        }
239
240        if x_data.is_empty() {
241            return Err("Cannot update with empty data".to_string());
242        }
243
244        self.x_data = x_data;
245        self.y_data = y_data;
246        self.dirty = true;
247        self.invalidate_gpu_vertices();
248        Ok(())
249    }
250
251    /// Set the color of the markers
252    pub fn set_color(&mut self, color: Vec4) {
253        self.color = color;
254        self.dirty = true;
255        self.invalidate_gpu_vertices();
256    }
257
258    /// Set the marker size
259    pub fn set_marker_size(&mut self, size: f32) {
260        self.marker_size = size.max(0.1); // Minimum marker size
261        self.dirty = true;
262        self.invalidate_gpu_vertices();
263    }
264
265    /// Set the marker style
266    pub fn set_marker_style(&mut self, style: MarkerStyle) {
267        self.marker_style = style;
268        self.dirty = true;
269        self.invalidate_gpu_vertices();
270    }
271
272    /// Show or hide the plot
273    pub fn set_visible(&mut self, visible: bool) {
274        self.visible = visible;
275    }
276
277    /// Get the number of data points
278    pub fn len(&self) -> usize {
279        if !self.x_data.is_empty() {
280            self.x_data.len()
281        } else {
282            self.gpu_point_count.unwrap_or(0)
283        }
284    }
285
286    /// Check if the plot has no data
287    pub fn is_empty(&self) -> bool {
288        self.len() == 0
289    }
290
291    /// Generate vertices for GPU rendering
292    pub fn generate_vertices(&mut self) -> &Vec<Vertex> {
293        if self.gpu_vertices.is_some() {
294            if self.vertices.is_none() {
295                self.vertices = Some(Vec::new());
296            }
297            return self.vertices.as_ref().unwrap();
298        }
299        if self.dirty || self.vertices.is_none() {
300            let base_color = self.color;
301            if self.per_point_colors.is_some() || self.color_values.is_some() { /* vertex color takes precedence; shader blends by face alpha */
302            }
303            let mut verts =
304                vertex_utils::create_scatter_plot(&self.x_data, &self.y_data, base_color);
305            // per-point colors
306            if let Some(ref colors) = self.per_point_colors {
307                let m = colors.len().min(verts.len());
308                for i in 0..m {
309                    verts[i].color = colors[i].to_array();
310                }
311            } else if let Some(ref vals) = self.color_values {
312                let n = verts.len();
313                let (mut cmin, mut cmax) = if let Some(lims) = self.color_limits {
314                    lims
315                } else {
316                    let mut lo = f64::INFINITY;
317                    let mut hi = f64::NEG_INFINITY;
318                    for &v in vals {
319                        if v.is_finite() {
320                            if v < lo {
321                                lo = v;
322                            }
323                            if v > hi {
324                                hi = v;
325                            }
326                        }
327                    }
328                    if !lo.is_finite() || !hi.is_finite() || hi <= lo {
329                        (0.0, 1.0)
330                    } else {
331                        (lo, hi)
332                    }
333                };
334                if !(cmin.is_finite() && cmax.is_finite()) || cmax <= cmin {
335                    cmin = 0.0;
336                    cmax = 1.0;
337                }
338                let denom = (cmax - cmin).max(f64::EPSILON);
339                for (i, vert) in verts.iter_mut().enumerate().take(n) {
340                    let t = ((vals[i] - cmin) / denom) as f32;
341                    let rgb = self.colormap.map_value(t);
342                    vert.color = [rgb.x, rgb.y, rgb.z, 1.0];
343                }
344            }
345            // Store marker size in normal.z for direct point expansion
346            if let Some(ref sizes) = self.per_point_sizes {
347                for (i, vert) in verts.iter_mut().enumerate() {
348                    let s = sizes.get(i).copied().unwrap_or(self.marker_size);
349                    vert.normal[2] = s.max(1.0);
350                }
351            } else {
352                for v in &mut verts {
353                    v.normal[2] = self.marker_size.max(1.0);
354                }
355            }
356            self.vertices = Some(verts);
357            self.dirty = false;
358        }
359        self.vertices.as_ref().unwrap()
360    }
361
362    /// Get the bounding box of the data
363    pub fn bounds(&mut self) -> BoundingBox {
364        if self.gpu_vertices.is_some() {
365            return self.bounds.unwrap_or_default();
366        }
367        if self.dirty || self.bounds.is_none() {
368            let points: Vec<Vec3> = self
369                .x_data
370                .iter()
371                .zip(self.y_data.iter())
372                .map(|(&x, &y)| Vec3::new(x as f32, y as f32, 0.0))
373                .collect();
374            self.bounds = Some(BoundingBox::from_points(&points));
375        }
376        self.bounds.unwrap()
377    }
378
379    /// Generate complete render data for the graphics pipeline
380    pub fn render_data(&mut self) -> RenderData {
381        let using_gpu = self.gpu_vertices.is_some();
382        let gpu_vertices = self.gpu_vertices.clone();
383        let bounds = self.bounds();
384        let (vertices, vertex_count) = if using_gpu {
385            let count = self
386                .gpu_point_count
387                .or_else(|| gpu_vertices.as_ref().map(|buf| buf.vertex_count))
388                .unwrap_or(0);
389            (Vec::new(), count)
390        } else {
391            let verts = self.generate_vertices().clone();
392            let count = verts.len();
393            (verts, count)
394        };
395
396        let mut material = Material {
397            albedo: self.color,
398            ..Default::default()
399        };
400        // If vertex colors vary across points, prefer per-vertex colors (alpha=0)
401        let is_multi_color = if using_gpu {
402            self.gpu_has_per_point_colors
403                || self.per_point_colors.is_some()
404                || self.color_values.is_some()
405        } else if vertices.is_empty() {
406            false
407        } else {
408            let first = vertices[0].color;
409            vertices.iter().any(|v| v.color != first)
410        };
411        if is_multi_color {
412            material.albedo.w = 0.0;
413        } else if self.filled {
414            material.albedo.w = 1.0;
415        }
416        material.emissive = self.edge_color; // stash edge color
417        material.roughness = self.edge_thickness; // stash thickness in roughness
418        material.metallic = match self.marker_style {
419            MarkerStyle::Circle => 0.0,
420            MarkerStyle::Square => 1.0,
421            MarkerStyle::Triangle => 2.0,
422            MarkerStyle::Diamond => 3.0,
423            MarkerStyle::Plus => 4.0,
424            MarkerStyle::Cross => 5.0,
425            MarkerStyle::Star => 6.0,
426            MarkerStyle::Hexagon => 7.0,
427        };
428        let has_vertex_colors = if using_gpu {
429            self.gpu_has_per_point_colors
430        } else {
431            self.per_point_colors.is_some() || self.color_values.is_some()
432        };
433        let use_vertex_edge_color = self.edge_color_from_vertex_colors && has_vertex_colors;
434        material.emissive.w = if use_vertex_edge_color { 0.0 } else { 1.0 };
435
436        let draw_call = DrawCall {
437            vertex_offset: 0,
438            vertex_count,
439            index_offset: None,
440            index_count: None,
441            instance_count: 1,
442        };
443
444        RenderData {
445            pipeline_type: PipelineType::Points,
446            vertices,
447            indices: None,
448            gpu_vertices,
449            bounds: Some(bounds),
450            material,
451            draw_calls: vec![draw_call],
452            image: None,
453        }
454    }
455
456    /// Get plot statistics for debugging
457    pub fn statistics(&self) -> PlotStatistics {
458        let (min_x, max_x, min_y, max_y) = if !self.x_data.is_empty() {
459            let (min_x, max_x) = self
460                .x_data
461                .iter()
462                .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), &x| {
463                    (min.min(x), max.max(x))
464                });
465            let (min_y, max_y) = self
466                .y_data
467                .iter()
468                .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), &y| {
469                    (min.min(y), max.max(y))
470                });
471            (min_x, max_x, min_y, max_y)
472        } else if let Some(bounds) = &self.bounds {
473            (
474                bounds.min.x as f64,
475                bounds.max.x as f64,
476                bounds.min.y as f64,
477                bounds.max.y as f64,
478            )
479        } else {
480            (0.0, 0.0, 0.0, 0.0)
481        };
482
483        PlotStatistics {
484            point_count: self.len(),
485            x_range: (min_x, max_x),
486            y_range: (min_y, max_y),
487            memory_usage: self.estimated_memory_usage(),
488        }
489    }
490
491    /// Estimate memory usage in bytes
492    pub fn estimated_memory_usage(&self) -> usize {
493        std::mem::size_of::<f64>() * (self.x_data.len() + self.y_data.len())
494            + self
495                .vertices
496                .as_ref()
497                .map_or(0, |v| v.len() * std::mem::size_of::<Vertex>())
498            + self.gpu_point_count.unwrap_or(0) * std::mem::size_of::<Vertex>()
499    }
500}
501
/// Summary figures describing a plot's data and memory footprint.
#[derive(Debug, Clone)]
pub struct PlotStatistics {
    /// Number of points in the plot.
    pub point_count: usize,
    /// `(min, max)` of the x coordinates.
    pub x_range: (f64, f64),
    /// `(min, max)` of the y coordinates.
    pub y_range: (f64, f64),
    /// Estimated memory usage in bytes.
    pub memory_usage: usize,
}
510
511/// MATLAB-compatible scatter plot creation utilities
512pub mod matlab_compat {
513    use super::*;
514
515    /// Create a simple scatter plot (equivalent to MATLAB's `scatter(x, y)`)
516    pub fn scatter(x: Vec<f64>, y: Vec<f64>) -> Result<ScatterPlot, String> {
517        ScatterPlot::new(x, y)
518    }
519
520    /// Create a scatter plot with specified color and size (`scatter(x, y, size, color)`)
521    pub fn scatter_with_style(
522        x: Vec<f64>,
523        y: Vec<f64>,
524        size: f32,
525        color: &str,
526    ) -> Result<ScatterPlot, String> {
527        let color_vec = parse_matlab_color(color)?;
528        Ok(ScatterPlot::new(x, y)?.with_style(color_vec, size, MarkerStyle::Circle))
529    }
530
531    /// Parse MATLAB color specifications
532    fn parse_matlab_color(color: &str) -> Result<Vec4, String> {
533        match color {
534            "r" | "red" => Ok(Vec4::new(1.0, 0.0, 0.0, 1.0)),
535            "g" | "green" => Ok(Vec4::new(0.0, 1.0, 0.0, 1.0)),
536            "b" | "blue" => Ok(Vec4::new(0.0, 0.0, 1.0, 1.0)),
537            "c" | "cyan" => Ok(Vec4::new(0.0, 1.0, 1.0, 1.0)),
538            "m" | "magenta" => Ok(Vec4::new(1.0, 0.0, 1.0, 1.0)),
539            "y" | "yellow" => Ok(Vec4::new(1.0, 1.0, 0.0, 1.0)),
540            "k" | "black" => Ok(Vec4::new(0.0, 0.0, 0.0, 1.0)),
541            "w" | "white" => Ok(Vec4::new(1.0, 1.0, 1.0, 1.0)),
542            _ => Err(format!("Unknown color: {color}")),
543        }
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction keeps the supplied data and reports it via `len`/`is_empty`.
    #[test]
    fn test_scatter_plot_creation() {
        let xs = vec![0.0, 1.0, 2.0, 3.0];
        let ys = vec![0.0, 1.0, 4.0, 9.0];

        let plot = ScatterPlot::new(xs.clone(), ys.clone()).unwrap();

        assert!(plot.visible);
        assert!(!plot.is_empty());
        assert_eq!(plot.len(), 4);
        assert_eq!(plot.x_data, xs);
        assert_eq!(plot.y_data, ys);
    }

    /// The builder methods store color, size, shape and label.
    #[test]
    fn test_scatter_plot_styling() {
        let green = Vec4::new(0.0, 1.0, 0.0, 1.0);

        let plot = ScatterPlot::new(vec![0.0, 1.0, 2.0], vec![1.0, 2.0, 1.5])
            .unwrap()
            .with_style(green, 5.0, MarkerStyle::Square)
            .with_label("Test Scatter");

        assert_eq!(plot.color, green);
        assert_eq!(plot.marker_size, 5.0);
        assert_eq!(plot.marker_style, MarkerStyle::Square);
        assert_eq!(plot.label, Some("Test Scatter".to_string()));
    }

    /// A CPU-backed plot renders as one point vertex per data point.
    #[test]
    fn test_scatter_plot_render_data() {
        let mut plot = ScatterPlot::new(vec![0.0, 1.0, 2.0], vec![1.0, 2.0, 1.0]).unwrap();

        let render_data = plot.render_data();

        assert_eq!(render_data.pipeline_type, PipelineType::Points);
        assert_eq!(render_data.vertices.len(), 3); // One vertex per point
        assert!(render_data.indices.is_none());
        assert_eq!(render_data.draw_calls.len(), 1);
    }

    /// The MATLAB-style helpers build plots with the expected size and color.
    #[test]
    fn test_matlab_compat_scatter() {
        let x = vec![0.0, 1.0];
        let y = vec![0.0, 1.0];

        let basic_scatter = matlab_compat::scatter(x.clone(), y.clone()).unwrap();
        assert_eq!(basic_scatter.len(), 2);

        let styled_scatter = matlab_compat::scatter_with_style(x, y, 5.0, "g").unwrap();
        assert_eq!(styled_scatter.color, Vec4::new(0.0, 1.0, 0.0, 1.0));
        assert_eq!(styled_scatter.marker_size, 5.0);
    }
}