1use crate::context::shared_wgpu_context;
6use crate::core::{
7 vertex_utils, BoundingBox, DrawCall, GpuVertexBuffer, Material, PipelineType, RenderData,
8 Vertex,
9};
10use crate::gpu::scatter2::Scatter2GpuInputs;
11use crate::gpu::util::readback_scalar_buffer_f64;
12use crate::plots::surface::ColorMap;
13use glam::{Vec3, Vec4};
14
15#[derive(Debug, Clone)]
17pub struct ScatterPlot {
18 pub x_data: Vec<f64>,
20 pub y_data: Vec<f64>,
21
22 pub color: Vec4,
24 pub edge_color: Vec4,
25 pub edge_thickness: f32,
26 pub marker_size: f32,
27 pub marker_style: MarkerStyle,
28 pub per_point_sizes: Option<Vec<f32>>, pub per_point_colors: Option<Vec<Vec4>>, pub color_values: Option<Vec<f64>>, pub color_limits: Option<(f64, f64)>,
32 pub colormap: ColorMap,
33 pub filled: bool,
34 pub edge_color_from_vertex_colors: bool,
35
36 pub label: Option<String>,
38 pub visible: bool,
39
40 vertices: Option<Vec<Vertex>>,
42 bounds: Option<BoundingBox>,
43 dirty: bool,
44 gpu_vertices: Option<GpuVertexBuffer>,
45 gpu_point_count: Option<usize>,
46 gpu_inputs: Option<Scatter2GpuInputs>,
47 gpu_has_per_point_sizes: bool,
48 gpu_has_per_point_colors: bool,
49}
50
/// Marker glyph used to draw each scatter point.
///
/// Defaults to [`MarkerStyle::Circle`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MarkerStyle {
    /// Round marker (the default).
    #[default]
    Circle,
    /// Square marker.
    Square,
    /// Triangle marker.
    Triangle,
    /// Diamond marker.
    Diamond,
    /// Plus-sign marker.
    Plus,
    /// Cross (x) marker.
    Cross,
    /// Star marker.
    Star,
    /// Hexagon marker.
    Hexagon,
}
69
70#[derive(Clone, Copy, Debug)]
71pub struct ScatterGpuStyle {
72 pub color: Vec4,
73 pub edge_color: Vec4,
74 pub edge_thickness: f32,
75 pub marker_size: f32,
76 pub marker_style: MarkerStyle,
77 pub filled: bool,
78 pub has_per_point_sizes: bool,
79 pub has_per_point_colors: bool,
80 pub edge_from_vertex_colors: bool,
81}
82
83impl ScatterPlot {
84 pub async fn export_scene_xy_data(&self) -> Result<(Vec<f64>, Vec<f64>), String> {
85 if !self.x_data.is_empty() && self.x_data.len() == self.y_data.len() {
86 return Ok((self.x_data.clone(), self.y_data.clone()));
87 }
88 if !self.x_data.is_empty() || !self.y_data.is_empty() {
89 return Err(format!(
90 "scatter plot has partial CPU source data: x has {} values, y has {} values",
91 self.x_data.len(),
92 self.y_data.len()
93 ));
94 }
95
96 if let Some(inputs) = &self.gpu_inputs {
97 let context = shared_wgpu_context().ok_or_else(|| {
98 "scatter plot has GPU source data but no shared WGPU context is installed"
99 .to_string()
100 })?;
101 let len = inputs.len as usize;
102 let x = readback_scalar_buffer_f64(
103 &context.device,
104 &context.queue,
105 &inputs.x_buffer,
106 len,
107 inputs.scalar,
108 )
109 .await?;
110 let y = readback_scalar_buffer_f64(
111 &context.device,
112 &context.queue,
113 &inputs.y_buffer,
114 len,
115 inputs.scalar,
116 )
117 .await?;
118 return Ok((x, y));
119 }
120
121 if self.gpu_vertices.is_some() {
122 return Err(
123 "scatter plot has GPU render vertices but no exportable source data".to_string(),
124 );
125 }
126
127 Ok((Vec::new(), Vec::new()))
128 }
129
130 pub fn new(x_data: Vec<f64>, y_data: Vec<f64>) -> Result<Self, String> {
132 if x_data.len() != y_data.len() {
133 return Err(format!(
134 "Data length mismatch: x_data has {} points, y_data has {} points",
135 x_data.len(),
136 y_data.len()
137 ));
138 }
139
140 if x_data.is_empty() {
141 return Err("Cannot create scatter plot with empty data".to_string());
142 }
143
144 Ok(Self {
145 x_data,
146 y_data,
147 color: Vec4::new(1.0, 0.2, 0.2, 1.0), edge_color: Vec4::new(0.0, 0.0, 0.0, 1.0),
149 edge_thickness: 1.0,
150 marker_size: 12.0,
151 marker_style: MarkerStyle::default(),
152 per_point_sizes: None,
153 per_point_colors: None,
154 color_values: None,
155 color_limits: None,
156 colormap: ColorMap::Parula,
157 filled: false,
158 edge_color_from_vertex_colors: false,
159 label: None,
160 visible: true,
161 vertices: None,
162 bounds: None,
163 dirty: true,
164 gpu_vertices: None,
165 gpu_point_count: None,
166 gpu_inputs: None,
167 gpu_has_per_point_sizes: false,
168 gpu_has_per_point_colors: false,
169 })
170 }
171
172 pub fn from_gpu_buffer(
174 buffer: GpuVertexBuffer,
175 point_count: usize,
176 bounds: BoundingBox,
177 style: ScatterGpuStyle,
178 ) -> Self {
179 Self {
180 x_data: Vec::new(),
181 y_data: Vec::new(),
182 color: style.color,
183 edge_color: style.edge_color,
184 edge_thickness: style.edge_thickness,
185 marker_size: style.marker_size,
186 marker_style: style.marker_style,
187 per_point_sizes: None,
188 per_point_colors: None,
189 color_values: None,
190 color_limits: None,
191 colormap: ColorMap::Parula,
192 filled: style.filled,
193 edge_color_from_vertex_colors: style.edge_from_vertex_colors,
194 label: None,
195 visible: true,
196 vertices: None,
197 bounds: Some(bounds),
198 dirty: false,
199 gpu_vertices: Some(buffer),
200 gpu_point_count: Some(point_count),
201 gpu_inputs: None,
202 gpu_has_per_point_sizes: style.has_per_point_sizes,
203 gpu_has_per_point_colors: style.has_per_point_colors,
204 }
205 }
206
207 pub fn with_gpu_source_inputs(mut self, inputs: Scatter2GpuInputs) -> Self {
208 self.gpu_inputs = Some(inputs);
209 self
210 }
211
212 fn invalidate_gpu_vertices(&mut self) {
213 self.gpu_vertices = None;
214 self.gpu_point_count = None;
215 }
216
217 fn clear_gpu_source_inputs(&mut self) {
218 self.gpu_inputs = None;
219 self.gpu_has_per_point_sizes = false;
220 self.gpu_has_per_point_colors = false;
221 }
222
223 pub fn with_style(mut self, color: Vec4, marker_size: f32, marker_style: MarkerStyle) -> Self {
225 self.color = color;
226 self.marker_size = marker_size;
227 self.marker_style = marker_style;
228 self.dirty = true;
229 self.invalidate_gpu_vertices();
230 self.gpu_has_per_point_sizes = false;
231 self.gpu_has_per_point_colors = false;
232 self
233 }
234
235 pub fn with_label<S: Into<String>>(mut self, label: S) -> Self {
237 self.label = Some(label.into());
238 self
239 }
240
241 pub fn set_face_color(&mut self, color: Vec4) {
243 self.color = color;
244 self.dirty = true;
245 self.invalidate_gpu_vertices();
246 }
247 pub fn set_edge_color(&mut self, color: Vec4) {
249 self.edge_color = color;
250 self.dirty = true;
251 self.invalidate_gpu_vertices();
252 }
253 pub fn set_edge_color_from_vertex(&mut self, enabled: bool) {
254 self.edge_color_from_vertex_colors = enabled;
255 }
256 pub fn set_edge_thickness(&mut self, px: f32) {
258 self.edge_thickness = px.max(0.0);
259 self.dirty = true;
260 self.invalidate_gpu_vertices();
261 }
262 pub fn set_sizes(&mut self, sizes: Vec<f32>) {
263 self.per_point_sizes = Some(sizes);
264 self.dirty = true;
265 self.invalidate_gpu_vertices();
266 self.gpu_has_per_point_sizes = false;
267 }
268 pub fn set_colors(&mut self, colors: Vec<Vec4>) {
269 self.per_point_colors = Some(colors);
270 self.dirty = true;
271 self.invalidate_gpu_vertices();
272 self.gpu_has_per_point_colors = false;
273 }
274 pub fn set_color_values(&mut self, values: Vec<f64>, limits: Option<(f64, f64)>) {
275 self.color_values = Some(values);
276 self.color_limits = limits;
277 self.dirty = true;
278 self.invalidate_gpu_vertices();
279 self.gpu_has_per_point_colors = false;
280 }
281 pub fn with_colormap(mut self, cmap: ColorMap) -> Self {
282 self.colormap = cmap;
283 self.dirty = true;
284 self.invalidate_gpu_vertices();
285 self
286 }
287 pub fn set_filled(&mut self, filled: bool) {
288 self.filled = filled;
289 self.dirty = true;
290 self.invalidate_gpu_vertices();
291 }
292
293 pub fn update_data(&mut self, x_data: Vec<f64>, y_data: Vec<f64>) -> Result<(), String> {
295 if x_data.len() != y_data.len() {
296 return Err(format!(
297 "Data length mismatch: x_data has {} points, y_data has {} points",
298 x_data.len(),
299 y_data.len()
300 ));
301 }
302
303 if x_data.is_empty() {
304 return Err("Cannot update with empty data".to_string());
305 }
306
307 self.x_data = x_data;
308 self.y_data = y_data;
309 self.dirty = true;
310 self.invalidate_gpu_vertices();
311 self.clear_gpu_source_inputs();
312 Ok(())
313 }
314
315 pub fn set_color(&mut self, color: Vec4) {
317 self.color = color;
318 self.dirty = true;
319 self.invalidate_gpu_vertices();
320 }
321
322 pub fn set_marker_size(&mut self, size: f32) {
324 self.marker_size = size.max(0.1); self.dirty = true;
326 self.invalidate_gpu_vertices();
327 }
328
329 pub fn set_marker_style(&mut self, style: MarkerStyle) {
331 self.marker_style = style;
332 self.dirty = true;
333 self.invalidate_gpu_vertices();
334 }
335
336 pub fn set_visible(&mut self, visible: bool) {
338 self.visible = visible;
339 }
340
341 pub fn len(&self) -> usize {
343 if !self.x_data.is_empty() {
344 self.x_data.len()
345 } else {
346 self.gpu_point_count.unwrap_or(0)
347 }
348 }
349
350 pub fn is_empty(&self) -> bool {
352 self.len() == 0
353 }
354
355 pub fn generate_vertices(&mut self) -> &Vec<Vertex> {
357 if self.gpu_vertices.is_some() {
358 if self.vertices.is_none() {
359 self.vertices = Some(Vec::new());
360 }
361 return self.vertices.as_ref().unwrap();
362 }
363 if self.dirty || self.vertices.is_none() {
364 let base_color = self.color;
365 if self.per_point_colors.is_some() || self.color_values.is_some() { }
367 let mut verts =
368 vertex_utils::create_scatter_plot(&self.x_data, &self.y_data, base_color);
369 if let Some(ref colors) = self.per_point_colors {
371 let m = colors.len().min(verts.len());
372 for i in 0..m {
373 verts[i].color = colors[i].to_array();
374 }
375 } else if let Some(ref vals) = self.color_values {
376 let n = verts.len();
377 let (mut cmin, mut cmax) = if let Some(lims) = self.color_limits {
378 lims
379 } else {
380 let mut lo = f64::INFINITY;
381 let mut hi = f64::NEG_INFINITY;
382 for &v in vals {
383 if v.is_finite() {
384 if v < lo {
385 lo = v;
386 }
387 if v > hi {
388 hi = v;
389 }
390 }
391 }
392 if !lo.is_finite() || !hi.is_finite() || hi <= lo {
393 (0.0, 1.0)
394 } else {
395 (lo, hi)
396 }
397 };
398 if !(cmin.is_finite() && cmax.is_finite()) || cmax <= cmin {
399 cmin = 0.0;
400 cmax = 1.0;
401 }
402 let denom = (cmax - cmin).max(f64::EPSILON);
403 for (i, vert) in verts.iter_mut().enumerate().take(n) {
404 let t = ((vals[i] - cmin) / denom) as f32;
405 let rgb = self.colormap.map_value(t);
406 vert.color = [rgb.x, rgb.y, rgb.z, 1.0];
407 }
408 }
409 if let Some(ref sizes) = self.per_point_sizes {
411 for (i, vert) in verts.iter_mut().enumerate() {
412 let s = sizes.get(i).copied().unwrap_or(self.marker_size);
413 vert.normal[2] = s.max(1.0);
414 }
415 } else {
416 for v in &mut verts {
417 v.normal[2] = self.marker_size.max(1.0);
418 }
419 }
420 self.vertices = Some(verts);
421 self.dirty = false;
422 }
423 self.vertices.as_ref().unwrap()
424 }
425
426 pub fn bounds(&mut self) -> BoundingBox {
428 if self.gpu_vertices.is_some() {
429 return self.bounds.unwrap_or_default();
430 }
431 if self.dirty || self.bounds.is_none() {
432 let points: Vec<Vec3> = self
433 .x_data
434 .iter()
435 .zip(self.y_data.iter())
436 .map(|(&x, &y)| Vec3::new(x as f32, y as f32, 0.0))
437 .collect();
438 self.bounds = Some(BoundingBox::from_points(&points));
439 }
440 self.bounds.unwrap()
441 }
442
443 pub fn render_data(&mut self) -> RenderData {
445 let using_gpu = self.gpu_vertices.is_some();
446 let gpu_vertices = self.gpu_vertices.clone();
447 let bounds = self.bounds();
448 let (vertices, vertex_count) = if using_gpu {
449 let count = self
450 .gpu_point_count
451 .or_else(|| gpu_vertices.as_ref().map(|buf| buf.vertex_count))
452 .unwrap_or(0);
453 (Vec::new(), count)
454 } else {
455 let verts = self.generate_vertices().clone();
456 let count = verts.len();
457 (verts, count)
458 };
459
460 let mut material = Material {
461 albedo: self.color,
462 ..Default::default()
463 };
464 let is_multi_color = if using_gpu {
466 self.gpu_has_per_point_colors
467 || self.per_point_colors.is_some()
468 || self.color_values.is_some()
469 } else if vertices.is_empty() {
470 false
471 } else {
472 let first = vertices[0].color;
473 vertices.iter().any(|v| v.color != first)
474 };
475 if is_multi_color {
476 material.albedo.w = 0.0;
477 } else if self.filled {
478 material.albedo.w = 1.0;
479 }
480 material.emissive = self.edge_color; material.roughness = self.edge_thickness; material.metallic = match self.marker_style {
483 MarkerStyle::Circle => 0.0,
484 MarkerStyle::Square => 1.0,
485 MarkerStyle::Triangle => 2.0,
486 MarkerStyle::Diamond => 3.0,
487 MarkerStyle::Plus => 4.0,
488 MarkerStyle::Cross => 5.0,
489 MarkerStyle::Star => 6.0,
490 MarkerStyle::Hexagon => 7.0,
491 };
492 let has_vertex_colors = if using_gpu {
493 self.gpu_has_per_point_colors
494 } else {
495 self.per_point_colors.is_some() || self.color_values.is_some()
496 };
497 let use_vertex_edge_color = self.edge_color_from_vertex_colors && has_vertex_colors;
498 material.emissive.w = if use_vertex_edge_color { 0.0 } else { 1.0 };
499
500 let draw_call = DrawCall {
501 vertex_offset: 0,
502 vertex_count,
503 index_offset: None,
504 index_count: None,
505 instance_count: 1,
506 };
507
508 RenderData {
509 pipeline_type: PipelineType::Points,
510 vertices,
511 indices: None,
512 gpu_vertices,
513 bounds: Some(bounds),
514 material,
515 draw_calls: vec![draw_call],
516 image: None,
517 }
518 }
519
520 pub fn statistics(&self) -> PlotStatistics {
522 let (min_x, max_x, min_y, max_y) = if !self.x_data.is_empty() {
523 let (min_x, max_x) = self
524 .x_data
525 .iter()
526 .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), &x| {
527 (min.min(x), max.max(x))
528 });
529 let (min_y, max_y) = self
530 .y_data
531 .iter()
532 .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), &y| {
533 (min.min(y), max.max(y))
534 });
535 (min_x, max_x, min_y, max_y)
536 } else if let Some(bounds) = &self.bounds {
537 (
538 bounds.min.x as f64,
539 bounds.max.x as f64,
540 bounds.min.y as f64,
541 bounds.max.y as f64,
542 )
543 } else {
544 (0.0, 0.0, 0.0, 0.0)
545 };
546
547 PlotStatistics {
548 point_count: self.len(),
549 x_range: (min_x, max_x),
550 y_range: (min_y, max_y),
551 memory_usage: self.estimated_memory_usage(),
552 }
553 }
554
555 pub fn estimated_memory_usage(&self) -> usize {
557 std::mem::size_of::<f64>() * (self.x_data.len() + self.y_data.len())
558 + self
559 .vertices
560 .as_ref()
561 .map_or(0, |v| v.len() * std::mem::size_of::<Vertex>())
562 + self.gpu_point_count.unwrap_or(0) * std::mem::size_of::<Vertex>()
563 }
564}
565
/// Summary numbers describing a scatter plot, as returned by `ScatterPlot::statistics`.
#[derive(Debug, Clone)]
pub struct PlotStatistics {
    /// Number of points in the plot.
    pub point_count: usize,
    /// `(min, max)` of the X data.
    pub x_range: (f64, f64),
    /// `(min, max)` of the Y data.
    pub y_range: (f64, f64),
    /// Estimated memory footprint in bytes.
    pub memory_usage: usize,
}
574
575pub mod matlab_compat {
577 use super::*;
578
579 pub fn scatter(x: Vec<f64>, y: Vec<f64>) -> Result<ScatterPlot, String> {
581 ScatterPlot::new(x, y)
582 }
583
584 pub fn scatter_with_style(
586 x: Vec<f64>,
587 y: Vec<f64>,
588 size: f32,
589 color: &str,
590 ) -> Result<ScatterPlot, String> {
591 let color_vec = parse_matlab_color(color)?;
592 Ok(ScatterPlot::new(x, y)?.with_style(color_vec, size, MarkerStyle::Circle))
593 }
594
595 fn parse_matlab_color(color: &str) -> Result<Vec4, String> {
597 match color {
598 "r" | "red" => Ok(Vec4::new(1.0, 0.0, 0.0, 1.0)),
599 "g" | "green" => Ok(Vec4::new(0.0, 1.0, 0.0, 1.0)),
600 "b" | "blue" => Ok(Vec4::new(0.0, 0.0, 1.0, 1.0)),
601 "c" | "cyan" => Ok(Vec4::new(0.0, 1.0, 1.0, 1.0)),
602 "m" | "magenta" => Ok(Vec4::new(1.0, 0.0, 1.0, 1.0)),
603 "y" | "yellow" => Ok(Vec4::new(1.0, 1.0, 0.0, 1.0)),
604 "k" | "black" => Ok(Vec4::new(0.0, 0.0, 0.0, 1.0)),
605 "w" | "white" => Ok(Vec4::new(1.0, 1.0, 1.0, 1.0)),
606 _ => Err(format!("Unknown color: {color}")),
607 }
608 }
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scatter_plot_creation() {
        let xs = vec![0.0, 1.0, 2.0, 3.0];
        let ys = vec![0.0, 1.0, 4.0, 9.0];

        let plot =
            ScatterPlot::new(xs.clone(), ys.clone()).expect("equal-length, non-empty data");

        assert_eq!(plot.x_data, xs);
        assert_eq!(plot.y_data, ys);
        assert_eq!(plot.len(), 4);
        assert!(!plot.is_empty());
        assert!(plot.visible);
    }

    #[test]
    fn test_scatter_plot_styling() {
        let green = Vec4::new(0.0, 1.0, 0.0, 1.0);

        let plot = ScatterPlot::new(vec![0.0, 1.0, 2.0], vec![1.0, 2.0, 1.5])
            .expect("valid scatter data")
            .with_style(green, 5.0, MarkerStyle::Square)
            .with_label("Test Scatter");

        assert_eq!(plot.color, green);
        assert_eq!(plot.marker_size, 5.0);
        assert_eq!(plot.marker_style, MarkerStyle::Square);
        assert_eq!(plot.label.as_deref(), Some("Test Scatter"));
    }

    #[test]
    fn test_scatter_plot_render_data() {
        let mut plot = ScatterPlot::new(vec![0.0, 1.0, 2.0], vec![1.0, 2.0, 1.0])
            .expect("valid scatter data");

        let rd = plot.render_data();

        assert_eq!(rd.pipeline_type, PipelineType::Points);
        // One vertex per point.
        assert_eq!(rd.vertices.len(), 3);
        assert!(rd.indices.is_none());
        assert_eq!(rd.draw_calls.len(), 1);
    }

    #[test]
    fn test_matlab_compat_scatter() {
        use super::matlab_compat::*;

        let xs = vec![0.0, 1.0];
        let ys = vec![0.0, 1.0];

        let plain = scatter(xs.clone(), ys.clone()).expect("valid scatter data");
        assert_eq!(plain.len(), 2);

        let styled = scatter_with_style(xs.clone(), ys.clone(), 5.0, "g")
            .expect("known colour and valid data");
        assert_eq!(styled.color, Vec4::new(0.0, 1.0, 0.0, 1.0));
        assert_eq!(styled.marker_size, 5.0);
    }
}