Skip to main content

runmat_plot/plots/
scatter.rs

1//! Scatter plot implementation
2//!
3//! High-performance scatter plotting with GPU acceleration.
4
5use crate::context::shared_wgpu_context;
6use crate::core::{
7    vertex_utils, BoundingBox, DrawCall, GpuVertexBuffer, Material, PipelineType, RenderData,
8    Vertex,
9};
10use crate::gpu::scatter2::Scatter2GpuInputs;
11use crate::gpu::util::readback_scalar_buffer_f64;
12use crate::plots::surface::ColorMap;
13use glam::{Vec3, Vec4};
14
15/// High-performance GPU-accelerated scatter plot
16#[derive(Debug, Clone)]
17pub struct ScatterPlot {
18    /// Raw data points (x, y coordinates)
19    pub x_data: Vec<f64>,
20    pub y_data: Vec<f64>,
21
22    /// Visual styling
23    pub color: Vec4,
24    pub edge_color: Vec4,
25    pub edge_thickness: f32,
26    pub marker_size: f32,
27    pub marker_style: MarkerStyle,
28    pub per_point_sizes: Option<Vec<f32>>, // pixel diameters per point
29    pub per_point_colors: Option<Vec<Vec4>>, // per-point RGBA
30    pub color_values: Option<Vec<f64>>,    // scalar values mapped by colormap
31    pub color_limits: Option<(f64, f64)>,
32    pub colormap: ColorMap,
33    pub filled: bool,
34    pub edge_color_from_vertex_colors: bool,
35
36    /// Metadata
37    pub label: Option<String>,
38    pub visible: bool,
39
40    /// Generated rendering data (cached)
41    vertices: Option<Vec<Vertex>>,
42    bounds: Option<BoundingBox>,
43    dirty: bool,
44    gpu_vertices: Option<GpuVertexBuffer>,
45    gpu_point_count: Option<usize>,
46    gpu_inputs: Option<Scatter2GpuInputs>,
47    gpu_has_per_point_sizes: bool,
48    gpu_has_per_point_colors: bool,
49}
50
/// Marker shapes available for scatter plots.
///
/// `ScatterPlot::render_data` hands the shape to the renderer as a float index in
/// declaration order (`Circle` = 0.0 through `Hexagon` = 7.0), so variant order matters.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MarkerStyle {
    Circle,
    Square,
    Triangle,
    Diamond,
    Plus,
    Cross,
    Star,
    Hexagon,
}

/// Markers are drawn as circles unless a style is chosen explicitly.
impl Default for MarkerStyle {
    fn default() -> Self {
        MarkerStyle::Circle
    }
}
69
70#[derive(Clone, Copy, Debug)]
71pub struct ScatterGpuStyle {
72    pub color: Vec4,
73    pub edge_color: Vec4,
74    pub edge_thickness: f32,
75    pub marker_size: f32,
76    pub marker_style: MarkerStyle,
77    pub filled: bool,
78    pub has_per_point_sizes: bool,
79    pub has_per_point_colors: bool,
80    pub edge_from_vertex_colors: bool,
81}
82
83impl ScatterPlot {
84    pub async fn export_scene_xy_data(&self) -> Result<(Vec<f64>, Vec<f64>), String> {
85        if !self.x_data.is_empty() && self.x_data.len() == self.y_data.len() {
86            return Ok((self.x_data.clone(), self.y_data.clone()));
87        }
88        if !self.x_data.is_empty() || !self.y_data.is_empty() {
89            return Err(format!(
90                "scatter plot has partial CPU source data: x has {} values, y has {} values",
91                self.x_data.len(),
92                self.y_data.len()
93            ));
94        }
95
96        if let Some(inputs) = &self.gpu_inputs {
97            let context = shared_wgpu_context().ok_or_else(|| {
98                "scatter plot has GPU source data but no shared WGPU context is installed"
99                    .to_string()
100            })?;
101            let len = inputs.len as usize;
102            let x = readback_scalar_buffer_f64(
103                &context.device,
104                &context.queue,
105                &inputs.x_buffer,
106                len,
107                inputs.scalar,
108            )
109            .await?;
110            let y = readback_scalar_buffer_f64(
111                &context.device,
112                &context.queue,
113                &inputs.y_buffer,
114                len,
115                inputs.scalar,
116            )
117            .await?;
118            return Ok((x, y));
119        }
120
121        if self.gpu_vertices.is_some() {
122            return Err(
123                "scatter plot has GPU render vertices but no exportable source data".to_string(),
124            );
125        }
126
127        Ok((Vec::new(), Vec::new()))
128    }
129
130    /// Create a new scatter plot with data
131    pub fn new(x_data: Vec<f64>, y_data: Vec<f64>) -> Result<Self, String> {
132        if x_data.len() != y_data.len() {
133            return Err(format!(
134                "Data length mismatch: x_data has {} points, y_data has {} points",
135                x_data.len(),
136                y_data.len()
137            ));
138        }
139
140        if x_data.is_empty() {
141            return Err("Cannot create scatter plot with empty data".to_string());
142        }
143
144        Ok(Self {
145            x_data,
146            y_data,
147            color: Vec4::new(1.0, 0.2, 0.2, 1.0), // Brighter red
148            edge_color: Vec4::new(0.0, 0.0, 0.0, 1.0),
149            edge_thickness: 1.0,
150            marker_size: 12.0,
151            marker_style: MarkerStyle::default(),
152            per_point_sizes: None,
153            per_point_colors: None,
154            color_values: None,
155            color_limits: None,
156            colormap: ColorMap::Parula,
157            filled: false,
158            edge_color_from_vertex_colors: false,
159            label: None,
160            visible: true,
161            vertices: None,
162            bounds: None,
163            dirty: true,
164            gpu_vertices: None,
165            gpu_point_count: None,
166            gpu_inputs: None,
167            gpu_has_per_point_sizes: false,
168            gpu_has_per_point_colors: false,
169        })
170    }
171
172    /// Build a scatter plot directly from a GPU vertex buffer to avoid CPU copies.
173    pub fn from_gpu_buffer(
174        buffer: GpuVertexBuffer,
175        point_count: usize,
176        bounds: BoundingBox,
177        style: ScatterGpuStyle,
178    ) -> Self {
179        Self {
180            x_data: Vec::new(),
181            y_data: Vec::new(),
182            color: style.color,
183            edge_color: style.edge_color,
184            edge_thickness: style.edge_thickness,
185            marker_size: style.marker_size,
186            marker_style: style.marker_style,
187            per_point_sizes: None,
188            per_point_colors: None,
189            color_values: None,
190            color_limits: None,
191            colormap: ColorMap::Parula,
192            filled: style.filled,
193            edge_color_from_vertex_colors: style.edge_from_vertex_colors,
194            label: None,
195            visible: true,
196            vertices: None,
197            bounds: Some(bounds),
198            dirty: false,
199            gpu_vertices: Some(buffer),
200            gpu_point_count: Some(point_count),
201            gpu_inputs: None,
202            gpu_has_per_point_sizes: style.has_per_point_sizes,
203            gpu_has_per_point_colors: style.has_per_point_colors,
204        }
205    }
206
207    pub fn with_gpu_source_inputs(mut self, inputs: Scatter2GpuInputs) -> Self {
208        self.gpu_inputs = Some(inputs);
209        self
210    }
211
212    fn invalidate_gpu_vertices(&mut self) {
213        self.gpu_vertices = None;
214        self.gpu_point_count = None;
215    }
216
217    fn clear_gpu_source_inputs(&mut self) {
218        self.gpu_inputs = None;
219        self.gpu_has_per_point_sizes = false;
220        self.gpu_has_per_point_colors = false;
221    }
222
223    /// Create a scatter plot with custom styling
224    pub fn with_style(mut self, color: Vec4, marker_size: f32, marker_style: MarkerStyle) -> Self {
225        self.color = color;
226        self.marker_size = marker_size;
227        self.marker_style = marker_style;
228        self.dirty = true;
229        self.invalidate_gpu_vertices();
230        self.gpu_has_per_point_sizes = false;
231        self.gpu_has_per_point_colors = false;
232        self
233    }
234
235    /// Set the plot label for legends
236    pub fn with_label<S: Into<String>>(mut self, label: S) -> Self {
237        self.label = Some(label.into());
238        self
239    }
240
241    /// Set marker face color
242    pub fn set_face_color(&mut self, color: Vec4) {
243        self.color = color;
244        self.dirty = true;
245        self.invalidate_gpu_vertices();
246    }
247    /// Set marker edge color
248    pub fn set_edge_color(&mut self, color: Vec4) {
249        self.edge_color = color;
250        self.dirty = true;
251        self.invalidate_gpu_vertices();
252    }
253    pub fn set_edge_color_from_vertex(&mut self, enabled: bool) {
254        self.edge_color_from_vertex_colors = enabled;
255    }
256    /// Set marker edge thickness (pixels)
257    pub fn set_edge_thickness(&mut self, px: f32) {
258        self.edge_thickness = px.max(0.0);
259        self.dirty = true;
260        self.invalidate_gpu_vertices();
261    }
262    pub fn set_sizes(&mut self, sizes: Vec<f32>) {
263        self.per_point_sizes = Some(sizes);
264        self.dirty = true;
265        self.invalidate_gpu_vertices();
266        self.gpu_has_per_point_sizes = false;
267    }
268    pub fn set_colors(&mut self, colors: Vec<Vec4>) {
269        self.per_point_colors = Some(colors);
270        self.dirty = true;
271        self.invalidate_gpu_vertices();
272        self.gpu_has_per_point_colors = false;
273    }
274    pub fn set_color_values(&mut self, values: Vec<f64>, limits: Option<(f64, f64)>) {
275        self.color_values = Some(values);
276        self.color_limits = limits;
277        self.dirty = true;
278        self.invalidate_gpu_vertices();
279        self.gpu_has_per_point_colors = false;
280    }
281    pub fn with_colormap(mut self, cmap: ColorMap) -> Self {
282        self.colormap = cmap;
283        self.dirty = true;
284        self.invalidate_gpu_vertices();
285        self
286    }
287    pub fn set_filled(&mut self, filled: bool) {
288        self.filled = filled;
289        self.dirty = true;
290        self.invalidate_gpu_vertices();
291    }
292
293    /// Update the data points
294    pub fn update_data(&mut self, x_data: Vec<f64>, y_data: Vec<f64>) -> Result<(), String> {
295        if x_data.len() != y_data.len() {
296            return Err(format!(
297                "Data length mismatch: x_data has {} points, y_data has {} points",
298                x_data.len(),
299                y_data.len()
300            ));
301        }
302
303        if x_data.is_empty() {
304            return Err("Cannot update with empty data".to_string());
305        }
306
307        self.x_data = x_data;
308        self.y_data = y_data;
309        self.dirty = true;
310        self.invalidate_gpu_vertices();
311        self.clear_gpu_source_inputs();
312        Ok(())
313    }
314
315    /// Set the color of the markers
316    pub fn set_color(&mut self, color: Vec4) {
317        self.color = color;
318        self.dirty = true;
319        self.invalidate_gpu_vertices();
320    }
321
322    /// Set the marker size
323    pub fn set_marker_size(&mut self, size: f32) {
324        self.marker_size = size.max(0.1); // Minimum marker size
325        self.dirty = true;
326        self.invalidate_gpu_vertices();
327    }
328
329    /// Set the marker style
330    pub fn set_marker_style(&mut self, style: MarkerStyle) {
331        self.marker_style = style;
332        self.dirty = true;
333        self.invalidate_gpu_vertices();
334    }
335
336    /// Show or hide the plot
337    pub fn set_visible(&mut self, visible: bool) {
338        self.visible = visible;
339    }
340
341    /// Get the number of data points
342    pub fn len(&self) -> usize {
343        if !self.x_data.is_empty() {
344            self.x_data.len()
345        } else {
346            self.gpu_point_count.unwrap_or(0)
347        }
348    }
349
350    /// Check if the plot has no data
351    pub fn is_empty(&self) -> bool {
352        self.len() == 0
353    }
354
355    /// Generate vertices for GPU rendering
356    pub fn generate_vertices(&mut self) -> &Vec<Vertex> {
357        if self.gpu_vertices.is_some() {
358            if self.vertices.is_none() {
359                self.vertices = Some(Vec::new());
360            }
361            return self.vertices.as_ref().unwrap();
362        }
363        if self.dirty || self.vertices.is_none() {
364            let base_color = self.color;
365            if self.per_point_colors.is_some() || self.color_values.is_some() { /* vertex color takes precedence; shader blends by face alpha */
366            }
367            let mut verts =
368                vertex_utils::create_scatter_plot(&self.x_data, &self.y_data, base_color);
369            // per-point colors
370            if let Some(ref colors) = self.per_point_colors {
371                let m = colors.len().min(verts.len());
372                for i in 0..m {
373                    verts[i].color = colors[i].to_array();
374                }
375            } else if let Some(ref vals) = self.color_values {
376                let n = verts.len();
377                let (mut cmin, mut cmax) = if let Some(lims) = self.color_limits {
378                    lims
379                } else {
380                    let mut lo = f64::INFINITY;
381                    let mut hi = f64::NEG_INFINITY;
382                    for &v in vals {
383                        if v.is_finite() {
384                            if v < lo {
385                                lo = v;
386                            }
387                            if v > hi {
388                                hi = v;
389                            }
390                        }
391                    }
392                    if !lo.is_finite() || !hi.is_finite() || hi <= lo {
393                        (0.0, 1.0)
394                    } else {
395                        (lo, hi)
396                    }
397                };
398                if !(cmin.is_finite() && cmax.is_finite()) || cmax <= cmin {
399                    cmin = 0.0;
400                    cmax = 1.0;
401                }
402                let denom = (cmax - cmin).max(f64::EPSILON);
403                for (i, vert) in verts.iter_mut().enumerate().take(n) {
404                    let t = ((vals[i] - cmin) / denom) as f32;
405                    let rgb = self.colormap.map_value(t);
406                    vert.color = [rgb.x, rgb.y, rgb.z, 1.0];
407                }
408            }
409            // Store marker size in normal.z for direct point expansion
410            if let Some(ref sizes) = self.per_point_sizes {
411                for (i, vert) in verts.iter_mut().enumerate() {
412                    let s = sizes.get(i).copied().unwrap_or(self.marker_size);
413                    vert.normal[2] = s.max(1.0);
414                }
415            } else {
416                for v in &mut verts {
417                    v.normal[2] = self.marker_size.max(1.0);
418                }
419            }
420            self.vertices = Some(verts);
421            self.dirty = false;
422        }
423        self.vertices.as_ref().unwrap()
424    }
425
426    /// Get the bounding box of the data
427    pub fn bounds(&mut self) -> BoundingBox {
428        if self.gpu_vertices.is_some() {
429            return self.bounds.unwrap_or_default();
430        }
431        if self.dirty || self.bounds.is_none() {
432            let points: Vec<Vec3> = self
433                .x_data
434                .iter()
435                .zip(self.y_data.iter())
436                .map(|(&x, &y)| Vec3::new(x as f32, y as f32, 0.0))
437                .collect();
438            self.bounds = Some(BoundingBox::from_points(&points));
439        }
440        self.bounds.unwrap()
441    }
442
443    /// Generate complete render data for the graphics pipeline
444    pub fn render_data(&mut self) -> RenderData {
445        let using_gpu = self.gpu_vertices.is_some();
446        let gpu_vertices = self.gpu_vertices.clone();
447        let bounds = self.bounds();
448        let (vertices, vertex_count) = if using_gpu {
449            let count = self
450                .gpu_point_count
451                .or_else(|| gpu_vertices.as_ref().map(|buf| buf.vertex_count))
452                .unwrap_or(0);
453            (Vec::new(), count)
454        } else {
455            let verts = self.generate_vertices().clone();
456            let count = verts.len();
457            (verts, count)
458        };
459
460        let mut material = Material {
461            albedo: self.color,
462            ..Default::default()
463        };
464        // If vertex colors vary across points, prefer per-vertex colors (alpha=0)
465        let is_multi_color = if using_gpu {
466            self.gpu_has_per_point_colors
467                || self.per_point_colors.is_some()
468                || self.color_values.is_some()
469        } else if vertices.is_empty() {
470            false
471        } else {
472            let first = vertices[0].color;
473            vertices.iter().any(|v| v.color != first)
474        };
475        if is_multi_color {
476            material.albedo.w = 0.0;
477        } else if self.filled {
478            material.albedo.w = 1.0;
479        }
480        material.emissive = self.edge_color; // stash edge color
481        material.roughness = self.edge_thickness; // stash thickness in roughness
482        material.metallic = match self.marker_style {
483            MarkerStyle::Circle => 0.0,
484            MarkerStyle::Square => 1.0,
485            MarkerStyle::Triangle => 2.0,
486            MarkerStyle::Diamond => 3.0,
487            MarkerStyle::Plus => 4.0,
488            MarkerStyle::Cross => 5.0,
489            MarkerStyle::Star => 6.0,
490            MarkerStyle::Hexagon => 7.0,
491        };
492        let has_vertex_colors = if using_gpu {
493            self.gpu_has_per_point_colors
494        } else {
495            self.per_point_colors.is_some() || self.color_values.is_some()
496        };
497        let use_vertex_edge_color = self.edge_color_from_vertex_colors && has_vertex_colors;
498        material.emissive.w = if use_vertex_edge_color { 0.0 } else { 1.0 };
499
500        let draw_call = DrawCall {
501            vertex_offset: 0,
502            vertex_count,
503            index_offset: None,
504            index_count: None,
505            instance_count: 1,
506        };
507
508        RenderData {
509            pipeline_type: PipelineType::Points,
510            vertices,
511            indices: None,
512            gpu_vertices,
513            bounds: Some(bounds),
514            material,
515            draw_calls: vec![draw_call],
516            image: None,
517        }
518    }
519
520    /// Get plot statistics for debugging
521    pub fn statistics(&self) -> PlotStatistics {
522        let (min_x, max_x, min_y, max_y) = if !self.x_data.is_empty() {
523            let (min_x, max_x) = self
524                .x_data
525                .iter()
526                .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), &x| {
527                    (min.min(x), max.max(x))
528                });
529            let (min_y, max_y) = self
530                .y_data
531                .iter()
532                .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), &y| {
533                    (min.min(y), max.max(y))
534                });
535            (min_x, max_x, min_y, max_y)
536        } else if let Some(bounds) = &self.bounds {
537            (
538                bounds.min.x as f64,
539                bounds.max.x as f64,
540                bounds.min.y as f64,
541                bounds.max.y as f64,
542            )
543        } else {
544            (0.0, 0.0, 0.0, 0.0)
545        };
546
547        PlotStatistics {
548            point_count: self.len(),
549            x_range: (min_x, max_x),
550            y_range: (min_y, max_y),
551            memory_usage: self.estimated_memory_usage(),
552        }
553    }
554
555    /// Estimate memory usage in bytes
556    pub fn estimated_memory_usage(&self) -> usize {
557        std::mem::size_of::<f64>() * (self.x_data.len() + self.y_data.len())
558            + self
559                .vertices
560                .as_ref()
561                .map_or(0, |v| v.len() * std::mem::size_of::<Vertex>())
562            + self.gpu_point_count.unwrap_or(0) * std::mem::size_of::<Vertex>()
563    }
564}
565
/// Plot performance and data statistics, as returned by `ScatterPlot::statistics`.
#[derive(Clone, Debug)]
pub struct PlotStatistics {
    /// Number of points in the plot.
    pub point_count: usize,
    /// `(min, max)` of the x coordinates.
    pub x_range: (f64, f64),
    /// `(min, max)` of the y coordinates.
    pub y_range: (f64, f64),
    /// Estimated memory footprint in bytes.
    pub memory_usage: usize,
}
574
575/// MATLAB-compatible scatter plot creation utilities
576pub mod matlab_compat {
577    use super::*;
578
579    /// Create a simple scatter plot (equivalent to MATLAB's `scatter(x, y)`)
580    pub fn scatter(x: Vec<f64>, y: Vec<f64>) -> Result<ScatterPlot, String> {
581        ScatterPlot::new(x, y)
582    }
583
584    /// Create a scatter plot with specified color and size (`scatter(x, y, size, color)`)
585    pub fn scatter_with_style(
586        x: Vec<f64>,
587        y: Vec<f64>,
588        size: f32,
589        color: &str,
590    ) -> Result<ScatterPlot, String> {
591        let color_vec = parse_matlab_color(color)?;
592        Ok(ScatterPlot::new(x, y)?.with_style(color_vec, size, MarkerStyle::Circle))
593    }
594
595    /// Parse MATLAB color specifications
596    fn parse_matlab_color(color: &str) -> Result<Vec4, String> {
597        match color {
598            "r" | "red" => Ok(Vec4::new(1.0, 0.0, 0.0, 1.0)),
599            "g" | "green" => Ok(Vec4::new(0.0, 1.0, 0.0, 1.0)),
600            "b" | "blue" => Ok(Vec4::new(0.0, 0.0, 1.0, 1.0)),
601            "c" | "cyan" => Ok(Vec4::new(0.0, 1.0, 1.0, 1.0)),
602            "m" | "magenta" => Ok(Vec4::new(1.0, 0.0, 1.0, 1.0)),
603            "y" | "yellow" => Ok(Vec4::new(1.0, 1.0, 0.0, 1.0)),
604            "k" | "black" => Ok(Vec4::new(0.0, 0.0, 0.0, 1.0)),
605            "w" | "white" => Ok(Vec4::new(1.0, 1.0, 1.0, 1.0)),
606            _ => Err(format!("Unknown color: {color}")),
607        }
608    }
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scatter_plot_creation() {
        let x = vec![0.0, 1.0, 2.0, 3.0];
        let y = vec![0.0, 1.0, 4.0, 9.0];

        let plot = ScatterPlot::new(x.clone(), y.clone()).unwrap();

        assert_eq!(plot.x_data, x);
        assert_eq!(plot.y_data, y);
        assert_eq!(plot.len(), 4);
        assert!(!plot.is_empty());
        assert!(plot.visible);
    }

    #[test]
    fn test_scatter_plot_rejects_mismatched_lengths() {
        let err = ScatterPlot::new(vec![0.0, 1.0], vec![0.0]).unwrap_err();
        assert!(err.contains("Data length mismatch"));
    }

    #[test]
    fn test_scatter_plot_rejects_empty_data() {
        assert!(ScatterPlot::new(Vec::new(), Vec::new()).is_err());
    }

    #[test]
    fn test_update_data_validates_input() {
        let mut plot = ScatterPlot::new(vec![0.0], vec![0.0]).unwrap();

        assert!(plot.update_data(vec![1.0, 2.0], vec![1.0]).is_err());
        assert!(plot.update_data(Vec::new(), Vec::new()).is_err());
        // A rejected update must leave the existing data untouched.
        assert_eq!(plot.x_data, vec![0.0]);

        plot.update_data(vec![1.0, 2.0], vec![3.0, 4.0]).unwrap();
        assert_eq!(plot.len(), 2);
    }

    #[test]
    fn test_scatter_plot_styling() {
        let x = vec![0.0, 1.0, 2.0];
        let y = vec![1.0, 2.0, 1.5];
        let color = Vec4::new(0.0, 1.0, 0.0, 1.0);

        let plot = ScatterPlot::new(x, y)
            .unwrap()
            .with_style(color, 5.0, MarkerStyle::Square)
            .with_label("Test Scatter");

        assert_eq!(plot.color, color);
        assert_eq!(plot.marker_size, 5.0);
        assert_eq!(plot.marker_style, MarkerStyle::Square);
        assert_eq!(plot.label, Some("Test Scatter".to_string()));
    }

    #[test]
    fn test_scatter_plot_render_data() {
        let x = vec![0.0, 1.0, 2.0];
        let y = vec![1.0, 2.0, 1.0];

        let mut plot = ScatterPlot::new(x, y).unwrap();
        let render_data = plot.render_data();

        assert_eq!(render_data.pipeline_type, PipelineType::Points);
        assert_eq!(render_data.vertices.len(), 3); // One vertex per point
        assert!(render_data.indices.is_none());
        assert_eq!(render_data.draw_calls.len(), 1);
    }

    #[test]
    fn test_matlab_compat_scatter() {
        use super::matlab_compat::*;

        let x = vec![0.0, 1.0];
        let y = vec![0.0, 1.0];

        let basic_scatter = scatter(x.clone(), y.clone()).unwrap();
        assert_eq!(basic_scatter.len(), 2);

        let styled_scatter = scatter_with_style(x.clone(), y.clone(), 5.0, "g").unwrap();
        assert_eq!(styled_scatter.color, Vec4::new(0.0, 1.0, 0.0, 1.0));
        assert_eq!(styled_scatter.marker_size, 5.0);
    }

    #[test]
    fn test_matlab_compat_rejects_unknown_color() {
        use super::matlab_compat::*;

        let err = scatter_with_style(vec![0.0], vec![0.0], 5.0, "purple").unwrap_err();
        assert_eq!(err, "Unknown color: purple");
    }
}