Skip to main content

runmat_plot/plots/
scatter3.rs

1//! 3D scatter plot implementation for MATLAB's `scatter3`.
2
3use crate::context::shared_wgpu_context;
4use crate::core::{
5    vertex_utils, BoundingBox, DrawCall, GpuVertexBuffer, Material, PipelineType, RenderData,
6    Vertex,
7};
8use crate::gpu::scatter2::ScatterColorBuffer;
9use crate::gpu::scatter3::Scatter3GpuInputs;
10use crate::gpu::util::{copy_readback_bytes, readback_scalar_buffer_f64};
11use crate::plots::scatter::MarkerStyle;
12use glam::{Vec3, Vec4};
13
/// Styling parameters for a scatter3 plot whose vertices already live on the GPU.
///
/// Consumed by [`Scatter3Plot::from_gpu_buffer`], which copies each field into the
/// corresponding field of the resulting [`Scatter3Plot`].
#[derive(Clone, Copy, Debug)]
pub struct Scatter3GpuStyle {
    /// Base RGBA color; becomes the plot's single `colors` entry.
    pub color: Vec4,
    /// Marker edge RGBA color.
    pub edge_color: Vec4,
    /// Marker edge thickness in pixels.
    pub edge_thickness: f32,
    /// Marker shape.
    pub marker_style: MarkerStyle,
    /// Whether marker faces are filled.
    pub filled: bool,
    /// Whether the GPU-side data is flagged as carrying per-point colors
    /// (stored as the plot's private `gpu_has_per_point_colors`).
    pub has_per_point_colors: bool,
    /// Whether edge color should come from per-vertex colors
    /// (stored as `Scatter3Plot::edge_color_from_vertex_colors`).
    pub edge_from_vertex_colors: bool,
}
24
/// GPU-accelerated scatter3 plot for MATLAB semantics.
///
/// Data comes either from host-side `points`/`colors` (see [`Scatter3Plot::new`]) or
/// from GPU-resident buffers (see [`Scatter3Plot::from_gpu_buffer`] and
/// [`Scatter3Plot::with_gpu_vertices`]).
#[derive(Debug, Clone)]
pub struct Scatter3Plot {
    /// Point positions in 3D space. Empty for plots built from a GPU buffer.
    pub points: Vec<Vec3>,
    /// Per-point RGBA colors (a single entry for GPU-backed or uniformly colored
    /// plots without host points).
    pub colors: Vec<Vec4>,
    /// Marker size in pixels.
    pub point_size: f32,
    /// Optional per-point marker sizes.
    pub point_sizes: Option<Vec<f32>>,
    /// Marker edge color.
    pub edge_color: Vec4,
    /// Marker edge thickness in pixels.
    pub edge_thickness: f32,
    /// Marker shape.
    pub marker_style: MarkerStyle,
    /// Whether marker faces are filled.
    pub filled: bool,
    /// Whether edge color should come from per-vertex colors.
    pub edge_color_from_vertex_colors: bool,
    /// Legend label.
    pub label: Option<String>,
    /// Visibility flag.
    pub visible: bool,
    // Lazily built CPU vertices (from `points`, `colors`, and marker sizes).
    vertices: Option<Vec<Vertex>>,
    // Cached axis-aligned bounds; supplied up front for GPU-backed plots.
    bounds: Option<BoundingBox>,
    // GPU-resident render vertex buffer, when the plot bypasses CPU vertices.
    gpu_vertices: Option<GpuVertexBuffer>,
    // Number of points encoded in `gpu_vertices`; set and cleared together with it.
    gpu_point_count: Option<usize>,
    // GPU source buffers (X/Y/Z and colors) retained so the scene can be exported.
    gpu_inputs: Option<Scatter3GpuInputs>,
    // Whether the GPU-side data is flagged as carrying per-point colors.
    gpu_has_per_point_colors: bool,
}
57
58impl Scatter3Plot {
59    pub async fn export_scene_points(&self) -> Result<Vec<Vec3>, String> {
60        if !self.points.is_empty() {
61            return Ok(self.points.clone());
62        }
63
64        if let Some(inputs) = &self.gpu_inputs {
65            let context = shared_wgpu_context().ok_or_else(|| {
66                "scatter3 plot has GPU source data but no shared WGPU context is installed"
67                    .to_string()
68            })?;
69            let len = inputs.len as usize;
70            let x = readback_scalar_buffer_f64(
71                &context.device,
72                &context.queue,
73                &inputs.x_buffer,
74                len,
75                inputs.scalar,
76            )
77            .await?;
78            let y = readback_scalar_buffer_f64(
79                &context.device,
80                &context.queue,
81                &inputs.y_buffer,
82                len,
83                inputs.scalar,
84            )
85            .await?;
86            let z = readback_scalar_buffer_f64(
87                &context.device,
88                &context.queue,
89                &inputs.z_buffer,
90                len,
91                inputs.scalar,
92            )
93            .await?;
94            let points = x
95                .into_iter()
96                .zip(y)
97                .zip(z)
98                .map(|((x, y), z)| Vec3::new(x as f32, y as f32, z as f32))
99                .collect();
100            return Ok(points);
101        }
102
103        if self.gpu_vertices.is_some() {
104            return Err(
105                "scatter3 plot has GPU render vertices but no exportable source data".to_string(),
106            );
107        }
108
109        Ok(Vec::new())
110    }
111
112    pub async fn export_scene_colors(&self, point_count: usize) -> Result<Vec<Vec4>, String> {
113        if self.colors.len() == point_count {
114            return Ok(self.colors.clone());
115        }
116        if self.colors.len() == 1 && !self.gpu_has_per_point_colors {
117            return Ok(vec![self.colors[0]; point_count]);
118        }
119
120        if let Some(inputs) = &self.gpu_inputs {
121            match &inputs.colors {
122                ScatterColorBuffer::None => {
123                    let color = self.colors.first().copied().unwrap_or(Vec4::ONE);
124                    return Ok(vec![color; point_count]);
125                }
126                ScatterColorBuffer::Host(colors) => {
127                    if colors.len() != point_count {
128                        return Err(format!(
129                            "scatter3 color count ({}) does not match point count ({point_count})",
130                            colors.len()
131                        ));
132                    }
133                    return Ok(colors
134                        .iter()
135                        .map(|color| Vec4::from_array(*color))
136                        .collect());
137                }
138                ScatterColorBuffer::Gpu { buffer, components } => {
139                    let context = shared_wgpu_context().ok_or_else(|| {
140                        "scatter3 plot has GPU color data but no shared WGPU context is installed"
141                            .to_string()
142                    })?;
143                    let components = *components as usize;
144                    if components != 3 && components != 4 {
145                        return Err(format!(
146                            "scatter3 GPU color source has unsupported component count {components}"
147                        ));
148                    }
149                    let value_count = point_count
150                        .checked_mul(components)
151                        .ok_or_else(|| "scatter3 GPU color source size overflowed".to_string())?;
152                    let byte_len = value_count
153                        .checked_mul(std::mem::size_of::<f32>())
154                        .ok_or_else(|| {
155                            "scatter3 GPU color source byte size overflowed".to_string()
156                        })?;
157                    let bytes =
158                        copy_readback_bytes(&context.device, &context.queue, buffer, byte_len)
159                            .await?;
160                    let values: &[f32] = bytemuck::try_cast_slice(&bytes)
161                        .map_err(|err| format!("scatter3 GPU color readback failed: {err}"))?;
162                    if values.len() != value_count {
163                        return Err(format!(
164                            "scatter3 GPU color readback returned {} values, expected {value_count}",
165                            values.len()
166                        ));
167                    }
168                    let mut colors = Vec::with_capacity(point_count);
169                    for chunk in values.chunks_exact(components) {
170                        let alpha = if components == 4 { chunk[3] } else { 1.0 };
171                        colors.push(Vec4::new(chunk[0], chunk[1], chunk[2], alpha));
172                    }
173                    return Ok(colors);
174                }
175            }
176        }
177
178        if self.gpu_has_per_point_colors {
179            return Err(
180                "scatter3 plot has GPU per-point colors but no exportable color source".to_string(),
181            );
182        }
183        if self.colors.is_empty() {
184            return Ok(vec![Vec4::ONE; point_count]);
185        }
186        Err(format!(
187            "scatter3 color count ({}) does not match point count {point_count}",
188            self.colors.len()
189        ))
190    }
191
192    /// Create a new scatter3 plot. Colors default to a blue colormap.
193    pub fn new(points: Vec<Vec3>) -> Result<Self, String> {
194        let default_color = Vec4::new(0.1, 0.7, 0.3, 1.0);
195        let colors = vec![default_color; points.len()];
196        Ok(Self {
197            points,
198            colors,
199            point_size: 8.0,
200            point_sizes: None,
201            edge_color: default_color,
202            edge_thickness: 1.0,
203            marker_style: MarkerStyle::Circle,
204            filled: true,
205            edge_color_from_vertex_colors: false,
206            label: None,
207            visible: true,
208            vertices: None,
209            bounds: None,
210            gpu_vertices: None,
211            gpu_point_count: None,
212            gpu_inputs: None,
213            gpu_has_per_point_colors: false,
214        })
215    }
216
217    /// Build a scatter plot directly from a GPU vertex buffer, bypassing CPU copies.
218    pub fn from_gpu_buffer(
219        buffer: GpuVertexBuffer,
220        point_count: usize,
221        style: Scatter3GpuStyle,
222        point_size: f32,
223        bounds: BoundingBox,
224    ) -> Self {
225        Self {
226            points: Vec::new(),
227            colors: vec![style.color],
228            point_size,
229            point_sizes: None,
230            edge_color: style.edge_color,
231            edge_thickness: style.edge_thickness,
232            marker_style: style.marker_style,
233            filled: style.filled,
234            edge_color_from_vertex_colors: style.edge_from_vertex_colors,
235            label: None,
236            visible: true,
237            vertices: None,
238            bounds: Some(bounds),
239            gpu_vertices: Some(buffer),
240            gpu_point_count: Some(point_count),
241            gpu_inputs: None,
242            gpu_has_per_point_colors: style.has_per_point_colors,
243        }
244    }
245
246    pub fn with_gpu_source_inputs(mut self, inputs: Scatter3GpuInputs) -> Self {
247        self.gpu_inputs = Some(inputs);
248        self
249    }
250
251    fn invalidate_gpu_vertices(&mut self) {
252        self.vertices = None;
253        self.gpu_vertices = None;
254        self.gpu_point_count = None;
255    }
256
257    fn clear_gpu_source_inputs(&mut self) {
258        self.gpu_inputs = None;
259        self.gpu_has_per_point_colors = false;
260    }
261
262    /// Override all point colors with a single RGBA value.
263    pub fn with_color(mut self, color: Vec4) -> Self {
264        self.colors = if self.points.is_empty() {
265            vec![color]
266        } else {
267            vec![color; self.points.len()]
268        };
269        self.invalidate_gpu_vertices();
270        self.gpu_has_per_point_colors = false;
271        self
272    }
273
274    /// Supply per-point colors. Length must match the number of points.
275    pub fn with_colors(mut self, colors: Vec<Vec4>) -> Result<Self, String> {
276        if colors.len() != self.points.len() {
277            return Err(format!(
278                "Point cloud color count ({}) must match point count ({})",
279                colors.len(),
280                self.points.len()
281            ));
282        }
283        self.colors = colors;
284        self.invalidate_gpu_vertices();
285        self.clear_gpu_source_inputs();
286        Ok(self)
287    }
288
289    /// Set the legend label.
290    pub fn with_label<S: Into<String>>(mut self, label: S) -> Self {
291        self.label = Some(label.into());
292        self
293    }
294
295    /// Set marker size in pixels.
296    pub fn with_point_size(mut self, size: f32) -> Self {
297        self.point_size = size.max(1.0);
298        self.point_sizes = None;
299        self.invalidate_gpu_vertices();
300        self
301    }
302
303    pub fn set_marker_style(&mut self, style: MarkerStyle) {
304        self.marker_style = style;
305        self.invalidate_gpu_vertices();
306    }
307
308    pub fn set_filled(&mut self, filled: bool) {
309        self.filled = filled;
310        self.invalidate_gpu_vertices();
311    }
312
313    pub fn set_edge_color(&mut self, color: Vec4) {
314        self.edge_color = color;
315        self.invalidate_gpu_vertices();
316    }
317
318    pub fn set_edge_thickness(&mut self, px: f32) {
319        self.edge_thickness = px.max(0.0);
320        self.invalidate_gpu_vertices();
321    }
322
323    pub fn set_edge_color_from_vertex(&mut self, enabled: bool) {
324        self.edge_color_from_vertex_colors = enabled;
325        self.invalidate_gpu_vertices();
326    }
327
328    /// Enable or disable visibility.
329    pub fn set_visible(&mut self, visible: bool) {
330        self.visible = visible;
331    }
332
333    /// Attach a GPU-resident vertex buffer that already encodes this point cloud in the renderer's vertex format.
334    /// When provided, the renderer can skip per-frame uploads and reuse the supplied buffer directly.
335    pub fn with_gpu_vertices(mut self, buffer: GpuVertexBuffer, point_count: usize) -> Self {
336        self.gpu_vertices = Some(buffer);
337        self.gpu_point_count = Some(point_count);
338        self.vertices = None;
339        self.clear_gpu_source_inputs();
340        self
341    }
342
343    /// Supply per-point sizes in pixels.
344    pub fn set_point_sizes(&mut self, sizes: Vec<f32>) {
345        self.point_sizes = Some(sizes);
346        self.invalidate_gpu_vertices();
347    }
348
349    fn ensure_vertices(&mut self) {
350        if self.vertices.is_none() {
351            let mut verts = vertex_utils::create_point_cloud(&self.points, &self.colors);
352            if let Some(sizes) = self.point_sizes.as_ref() {
353                for (idx, vertex) in verts.iter_mut().enumerate() {
354                    let size = sizes.get(idx).copied().unwrap_or(self.point_size);
355                    vertex.normal[2] = size;
356                }
357            } else {
358                for vertex in &mut verts {
359                    vertex.normal[2] = self.point_size;
360                }
361            }
362            self.vertices = Some(verts);
363        }
364    }
365
366    fn ensure_bounds(&mut self) {
367        if self.bounds.is_none() {
368            self.bounds = Some(BoundingBox::from_points(&self.points));
369        }
370    }
371
372    /// Estimate memory required for this plot.
373    pub fn estimated_memory_usage(&self) -> usize {
374        let gpu_bytes = self
375            .gpu_point_count
376            .map(|count| count * std::mem::size_of::<Vertex>())
377            .unwrap_or(0);
378        self.points.len() * std::mem::size_of::<Vec3>()
379            + self.colors.len() * std::mem::size_of::<Vec4>()
380            + self
381                .point_sizes
382                .as_ref()
383                .map(|sizes| sizes.len() * std::mem::size_of::<f32>())
384                .unwrap_or(0)
385            + gpu_bytes
386    }
387
388    /// Generate render data for the renderer.
389    pub fn render_data(&mut self) -> RenderData {
390        let bounds = self.bounds();
391        let vertex_count = self.gpu_point_count.unwrap_or_else(|| {
392            self.ensure_vertices();
393            self.vertices
394                .as_ref()
395                .map(|v| v.len())
396                .unwrap_or(self.points.len())
397        });
398
399        let vertices = if self.gpu_vertices.is_some() {
400            Vec::new()
401        } else {
402            self.ensure_vertices();
403            self.vertices.clone().unwrap_or_default()
404        };
405
406        let is_multi_color = if self.gpu_vertices.is_some() {
407            self.gpu_has_per_point_colors || self.colors.len() > 1
408        } else if vertices.is_empty() {
409            false
410        } else {
411            let first = vertices[0].color;
412            vertices.iter().any(|v| v.color != first)
413        };
414        let has_vertex_colors = if self.gpu_vertices.is_some() {
415            self.gpu_has_per_point_colors
416        } else {
417            self.colors.len() > 1
418        };
419        let use_vertex_edge_color = self.edge_color_from_vertex_colors && has_vertex_colors;
420        let mut material = Material {
421            albedo: self.colors.first().copied().unwrap_or(Vec4::ONE),
422            roughness: self.edge_thickness,
423            metallic: match self.marker_style {
424                MarkerStyle::Circle => 0.0,
425                MarkerStyle::Square => 1.0,
426                MarkerStyle::Triangle => 2.0,
427                MarkerStyle::Diamond => 3.0,
428                MarkerStyle::Plus => 4.0,
429                MarkerStyle::Cross => 5.0,
430                MarkerStyle::Star => 6.0,
431                MarkerStyle::Hexagon => 7.0,
432            },
433            emissive: self.edge_color,
434            alpha_mode: crate::core::scene::AlphaMode::Blend,
435            double_sided: true,
436        };
437        if is_multi_color {
438            material.albedo.w = 0.0;
439        } else if self.filled {
440            material.albedo.w = 1.0;
441        }
442        material.emissive.w = if use_vertex_edge_color { 0.0 } else { 1.0 };
443
444        RenderData {
445            pipeline_type: PipelineType::Scatter3,
446            vertices,
447            indices: None,
448            gpu_vertices: self.gpu_vertices.clone(),
449            bounds: Some(bounds),
450            material,
451            draw_calls: vec![DrawCall {
452                vertex_offset: 0,
453                vertex_count,
454                index_offset: None,
455                index_count: None,
456                instance_count: 1,
457            }],
458            image: None,
459        }
460    }
461
462    /// Compute the axis-aligned bounding box.
463    pub fn bounds(&mut self) -> BoundingBox {
464        self.ensure_bounds();
465        self.bounds.unwrap_or_default()
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Construct a plot from `points`, failing the test if construction errors.
    fn plot_from(points: Vec<Vec3>) -> Scatter3Plot {
        Scatter3Plot::new(points).expect("scatter3 construction should succeed")
    }

    #[test]
    fn scatter3_defaults() {
        let input = vec![Vec3::ZERO, Vec3::new(1.0, 2.0, 3.0)];
        let expected = input.len();
        let plot = plot_from(input);
        assert_eq!(plot.points.len(), expected);
        assert_eq!(plot.colors.len(), expected);
        assert!(plot.visible);
    }

    #[test]
    fn scatter3_custom_colors() {
        let red = Vec4::new(1.0, 0.0, 0.0, 1.0);
        let plot = Scatter3Plot::new(vec![Vec3::ZERO])
            .and_then(|p| p.with_colors(vec![red]))
            .expect("single-point color override should succeed");
        assert_eq!(plot.colors[0], Vec4::new(1.0, 0.0, 0.0, 1.0));
    }

    #[test]
    fn scatter3_render_data_contains_vertices() {
        let mut plot = plot_from(vec![Vec3::ZERO, Vec3::ONE]);
        let data = plot.render_data();
        assert_eq!(data.vertices.len(), 2);
        assert_eq!(data.pipeline_type, PipelineType::Scatter3);
    }

    #[test]
    fn scatter3_marker_style_encodes_material_shape_channel() {
        let mut plot = plot_from(vec![Vec3::ZERO]);
        plot.set_marker_style(MarkerStyle::Diamond);
        assert_eq!(plot.render_data().material.metallic, 3.0);
    }

    #[test]
    fn scatter3_default_material_uses_plot_color_not_white_override() {
        let mut plot = plot_from(vec![Vec3::ZERO]);
        let rgb = plot.render_data().material.albedo.truncate();
        assert_ne!(rgb, Vec4::ONE.truncate());
    }
}