Skip to main content

runmat_plot/gpu/
scatter2.rs

1use crate::core::renderer::Vertex;
2use crate::core::scene::{DrawIndirectArgsRaw, GpuVertexBuffer};
3use crate::gpu::shaders;
4use crate::gpu::{tuning, ScalarType};
5use glam::Vec4;
6use std::sync::Arc;
7use wgpu::util::DeviceExt;
8
9/// Optional per-point scalar attribute buffer (marker sizes).
10pub enum ScatterAttributeBuffer {
11    None,
12    Host(Vec<f32>),
13    Gpu(Arc<wgpu::Buffer>),
14}
15
16impl ScatterAttributeBuffer {
17    pub fn has_data(&self) -> bool {
18        !matches!(self, ScatterAttributeBuffer::None)
19    }
20}
21
22/// Optional per-point color buffer. Host data is supplied as RGBA tuples; GPU buffers
23/// may contain RGB (3 floats) or RGBA (4 floats) sequences.
24pub enum ScatterColorBuffer {
25    None,
26    Host(Vec<[f32; 4]>),
27    Gpu {
28        buffer: Arc<wgpu::Buffer>,
29        components: u32,
30    },
31}
32
33impl ScatterColorBuffer {
34    pub fn has_data(&self) -> bool {
35        !matches!(self, ScatterColorBuffer::None)
36    }
37
38    pub fn stride(&self) -> u32 {
39        match self {
40            ScatterColorBuffer::None => 4,
41            ScatterColorBuffer::Host(_) => 4,
42            ScatterColorBuffer::Gpu { components, .. } => *components,
43        }
44    }
45
46    pub fn buffer(&self) -> Option<Arc<wgpu::Buffer>> {
47        match self {
48            ScatterColorBuffer::Gpu { buffer, .. } => Some(buffer.clone()),
49            _ => None,
50        }
51    }
52}
53
54/// Inputs required to pack scatter vertices directly on the GPU.
55pub struct Scatter2GpuInputs {
56    pub x_buffer: Arc<wgpu::Buffer>,
57    pub y_buffer: Arc<wgpu::Buffer>,
58    pub len: u32,
59    pub scalar: ScalarType,
60}
61
62/// Parameters describing how the GPU vertices should be generated.
63pub struct Scatter2GpuParams {
64    pub color: Vec4,
65    pub point_size: f32,
66    pub sizes: ScatterAttributeBuffer,
67    pub colors: ScatterColorBuffer,
68    pub lod_stride: u32,
69}
70
71#[repr(C)]
72#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
73struct Scatter2Uniforms {
74    color: [f32; 4],
75    point_size: f32,
76    count: u32,
77    lod_stride: u32,
78    has_sizes: u32,
79    has_colors: u32,
80    color_stride: u32,
81}
82
83/// Builds a GPU-resident vertex buffer for scatter plots directly from
84/// provider-owned XY arrays. Only single-precision tensors are supported
85/// today; callers should fall back to host packing for other types.
86pub fn pack_vertices_from_xy(
87    device: &Arc<wgpu::Device>,
88    queue: &Arc<wgpu::Queue>,
89    inputs: &Scatter2GpuInputs,
90    params: &Scatter2GpuParams,
91) -> Result<GpuVertexBuffer, String> {
92    if inputs.len == 0 {
93        return Err("scatter: empty input tensors".to_string());
94    }
95
96    let lod_stride = params.lod_stride.max(1);
97    let max_points = inputs.len.div_ceil(lod_stride);
98
99    let workgroup_size = tuning::effective_workgroup_size();
100    let shader = compile_shader(device, workgroup_size, inputs.scalar);
101
102    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
103        label: Some("scatter2-pack-bind-layout"),
104        entries: &[
105            wgpu::BindGroupLayoutEntry {
106                binding: 0,
107                visibility: wgpu::ShaderStages::COMPUTE,
108                ty: wgpu::BindingType::Buffer {
109                    ty: wgpu::BufferBindingType::Storage { read_only: true },
110                    has_dynamic_offset: false,
111                    min_binding_size: None,
112                },
113                count: None,
114            },
115            wgpu::BindGroupLayoutEntry {
116                binding: 1,
117                visibility: wgpu::ShaderStages::COMPUTE,
118                ty: wgpu::BindingType::Buffer {
119                    ty: wgpu::BufferBindingType::Storage { read_only: true },
120                    has_dynamic_offset: false,
121                    min_binding_size: None,
122                },
123                count: None,
124            },
125            wgpu::BindGroupLayoutEntry {
126                binding: 2,
127                visibility: wgpu::ShaderStages::COMPUTE,
128                ty: wgpu::BindingType::Buffer {
129                    ty: wgpu::BufferBindingType::Storage { read_only: false },
130                    has_dynamic_offset: false,
131                    min_binding_size: None,
132                },
133                count: None,
134            },
135            wgpu::BindGroupLayoutEntry {
136                binding: 3,
137                visibility: wgpu::ShaderStages::COMPUTE,
138                ty: wgpu::BindingType::Buffer {
139                    ty: wgpu::BufferBindingType::Uniform,
140                    has_dynamic_offset: false,
141                    min_binding_size: None,
142                },
143                count: None,
144            },
145            wgpu::BindGroupLayoutEntry {
146                binding: 4,
147                visibility: wgpu::ShaderStages::COMPUTE,
148                ty: wgpu::BindingType::Buffer {
149                    ty: wgpu::BufferBindingType::Storage { read_only: true },
150                    has_dynamic_offset: false,
151                    min_binding_size: None,
152                },
153                count: None,
154            },
155            wgpu::BindGroupLayoutEntry {
156                binding: 5,
157                visibility: wgpu::ShaderStages::COMPUTE,
158                ty: wgpu::BindingType::Buffer {
159                    ty: wgpu::BufferBindingType::Storage { read_only: true },
160                    has_dynamic_offset: false,
161                    min_binding_size: None,
162                },
163                count: None,
164            },
165            wgpu::BindGroupLayoutEntry {
166                binding: 6,
167                visibility: wgpu::ShaderStages::COMPUTE,
168                ty: wgpu::BindingType::Buffer {
169                    ty: wgpu::BufferBindingType::Storage { read_only: false },
170                    has_dynamic_offset: false,
171                    min_binding_size: None,
172                },
173                count: None,
174            },
175        ],
176    });
177
178    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
179        label: Some("scatter2-pack-pipeline-layout"),
180        bind_group_layouts: &[&bind_group_layout],
181        push_constant_ranges: &[],
182    });
183
184    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
185        label: Some("scatter2-pack-pipeline"),
186        layout: Some(&pipeline_layout),
187        module: &shader,
188        entry_point: "main",
189    });
190
191    let output_size = max_points as u64 * std::mem::size_of::<Vertex>() as u64;
192    let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
193        label: Some("scatter2-gpu-vertices"),
194        size: output_size,
195        usage: wgpu::BufferUsages::STORAGE
196            | wgpu::BufferUsages::VERTEX
197            | wgpu::BufferUsages::COPY_DST,
198        mapped_at_creation: false,
199    }));
200
201    let indirect_args = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
202        label: Some("scatter2-gpu-indirect-args"),
203        size: std::mem::size_of::<DrawIndirectArgsRaw>() as u64,
204        usage: wgpu::BufferUsages::STORAGE
205            | wgpu::BufferUsages::INDIRECT
206            | wgpu::BufferUsages::COPY_DST,
207        mapped_at_creation: false,
208    }));
209    let init = DrawIndirectArgsRaw {
210        vertex_count: 0,
211        instance_count: 1,
212        first_vertex: 0,
213        first_instance: 0,
214    };
215    queue.write_buffer(&indirect_args, 0, bytemuck::bytes_of(&init));
216
217    let (size_buffer, has_sizes) = prepare_size_buffer(device, params);
218    let (color_buffer, has_colors, color_stride) = prepare_color_buffer(device, params);
219
220    let uniforms = Scatter2Uniforms {
221        color: params.color.to_array(),
222        point_size: params.point_size,
223        count: inputs.len,
224        lod_stride,
225        has_sizes: if has_sizes { 1 } else { 0 },
226        has_colors: if has_colors { 1 } else { 0 },
227        color_stride,
228    };
229    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
230        label: Some("scatter2-pack-uniforms"),
231        contents: bytemuck::bytes_of(&uniforms),
232        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
233    });
234
235    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
236        label: Some("scatter2-pack-bind-group"),
237        layout: &bind_group_layout,
238        entries: &[
239            wgpu::BindGroupEntry {
240                binding: 0,
241                resource: inputs.x_buffer.as_entire_binding(),
242            },
243            wgpu::BindGroupEntry {
244                binding: 1,
245                resource: inputs.y_buffer.as_entire_binding(),
246            },
247            wgpu::BindGroupEntry {
248                binding: 2,
249                resource: output_buffer.as_entire_binding(),
250            },
251            wgpu::BindGroupEntry {
252                binding: 3,
253                resource: uniform_buffer.as_entire_binding(),
254            },
255            wgpu::BindGroupEntry {
256                binding: 4,
257                resource: size_buffer.as_entire_binding(),
258            },
259            wgpu::BindGroupEntry {
260                binding: 5,
261                resource: color_buffer.as_entire_binding(),
262            },
263            wgpu::BindGroupEntry {
264                binding: 6,
265                resource: indirect_args.as_entire_binding(),
266            },
267        ],
268    });
269
270    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
271        label: Some("scatter2-pack-encoder"),
272    });
273    {
274        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
275            label: Some("scatter2-pack-pass"),
276            timestamp_writes: None,
277        });
278        pass.set_pipeline(&pipeline);
279        pass.set_bind_group(0, &bind_group, &[]);
280        let workgroups = inputs.len.div_ceil(workgroup_size);
281        pass.dispatch_workgroups(workgroups, 1, 1);
282    }
283    queue.submit(Some(encoder.finish()));
284
285    Ok(GpuVertexBuffer::with_indirect(
286        output_buffer,
287        max_points as usize,
288        indirect_args,
289    ))
290}
291
292fn prepare_size_buffer(
293    device: &Arc<wgpu::Device>,
294    params: &Scatter2GpuParams,
295) -> (Arc<wgpu::Buffer>, bool) {
296    match &params.sizes {
297        ScatterAttributeBuffer::None => (
298            Arc::new(
299                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
300                    label: Some("scatter2-size-fallback"),
301                    contents: bytemuck::cast_slice(&[0.0f32]),
302                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
303                }),
304            ),
305            false,
306        ),
307        ScatterAttributeBuffer::Host(data) => {
308            if data.is_empty() {
309                (
310                    Arc::new(
311                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
312                            label: Some("scatter2-size-fallback"),
313                            contents: bytemuck::cast_slice(&[0.0f32]),
314                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
315                        }),
316                    ),
317                    false,
318                )
319            } else {
320                (
321                    Arc::new(
322                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
323                            label: Some("scatter2-size-host"),
324                            contents: bytemuck::cast_slice(data.as_slice()),
325                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
326                        }),
327                    ),
328                    true,
329                )
330            }
331        }
332        ScatterAttributeBuffer::Gpu(buffer) => (buffer.clone(), true),
333    }
334}
335
336fn prepare_color_buffer(
337    device: &Arc<wgpu::Device>,
338    params: &Scatter2GpuParams,
339) -> (Arc<wgpu::Buffer>, bool, u32) {
340    match &params.colors {
341        ScatterColorBuffer::None => (
342            Arc::new(
343                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
344                    label: Some("scatter2-color-fallback"),
345                    contents: bytemuck::cast_slice(&[
346                        params.color.x,
347                        params.color.y,
348                        params.color.z,
349                        params.color.w,
350                    ]),
351                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
352                }),
353            ),
354            false,
355            4,
356        ),
357        ScatterColorBuffer::Host(colors) => {
358            if colors.is_empty() {
359                (
360                    Arc::new(
361                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
362                            label: Some("scatter2-color-fallback"),
363                            contents: bytemuck::cast_slice(&[
364                                params.color.x,
365                                params.color.y,
366                                params.color.z,
367                                params.color.w,
368                            ]),
369                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
370                        }),
371                    ),
372                    false,
373                    4,
374                )
375            } else {
376                (
377                    Arc::new(
378                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
379                            label: Some("scatter2-color-host"),
380                            contents: bytemuck::cast_slice(colors.as_slice()),
381                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
382                        }),
383                    ),
384                    true,
385                    4,
386                )
387            }
388        }
389        ScatterColorBuffer::Gpu { buffer, components } => (buffer.clone(), true, *components),
390    }
391}
392
393fn compile_shader(
394    device: &Arc<wgpu::Device>,
395    workgroup_size: u32,
396    scalar: ScalarType,
397) -> wgpu::ShaderModule {
398    let template = match scalar {
399        ScalarType::F32 => shaders::scatter2::F32,
400        ScalarType::F64 => shaders::scatter2::F64,
401    };
402    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
403    device.create_shader_module(wgpu::ShaderModuleDescriptor {
404        label: Some("scatter2-pack-shader"),
405        source: wgpu::ShaderSource::Wgsl(source.into()),
406    })
407}
408
#[cfg(test)]
mod stress_tests {
    use super::*;
    use pollster::FutureExt;

    /// Opens a high-performance GPU device for the opt-in stress tests.
    ///
    /// Yields `None`, so the calling test returns early, in three cases:
    ///
    /// * `RUNMAT_PLOT_SKIP_GPU_TESTS` is set.
    /// * `RUNMAT_PLOT_FORCE_GPU_TESTS` is not set.
    /// * No adapter or device can be acquired.
    ///
    /// The device is requested with the adapter's own limits and no optional
    /// features.
    fn maybe_device() -> Option<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        let skip_requested = std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok();
        let force_requested = std::env::var("RUNMAT_PLOT_FORCE_GPU_TESTS").is_ok();
        if skip_requested || !force_requested {
            return None;
        }

        let instance = wgpu::Instance::default();
        let adapter_options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        };
        let adapter = instance.request_adapter(&adapter_options).block_on()?;

        let device_descriptor = wgpu::DeviceDescriptor {
            label: Some("runmat-plot-scatter-test-device"),
            required_features: wgpu::Features::empty(),
            required_limits: adapter.limits(),
        };
        let (device, queue) = adapter
            .request_device(&device_descriptor, None)
            .block_on()
            .ok()?;
        Some((Arc::new(device), Arc::new(queue)))
    }

    /// Packs 1.2M points with an LOD stride aimed at ~250k vertices and
    /// checks the reported vertex capacity equals `ceil(points / stride)`.
    #[test]
    fn gpu_packer_handles_large_point_cloud() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };

        const POINT_COUNT: u32 = 1_200_000;
        const TARGET_VERTICES: u32 = 250_000;

        let xs: Vec<f32> = (0..POINT_COUNT).map(|i| i as f32 * 0.001).collect();
        let ys: Vec<f32> = xs.iter().map(|x| x.sin()).collect();

        let upload = |label: &'static str, values: &[f32]| {
            Arc::new(
                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some(label),
                    contents: bytemuck::cast_slice(values),
                    usage: wgpu::BufferUsages::STORAGE,
                }),
            )
        };

        // Smallest stride that brings the point count down to at most
        // TARGET_VERTICES, never below 1.
        let stride = POINT_COUNT.div_ceil(TARGET_VERTICES).max(1);
        let expected_vertices = POINT_COUNT.div_ceil(stride) as usize;

        let inputs = Scatter2GpuInputs {
            x_buffer: upload("scatter2-test-x", &xs),
            y_buffer: upload("scatter2-test-y", &ys),
            len: POINT_COUNT,
            scalar: ScalarType::F32,
        };
        let params = Scatter2GpuParams {
            color: Vec4::new(0.8, 0.1, 0.3, 1.0),
            point_size: 8.0,
            sizes: ScatterAttributeBuffer::None,
            colors: ScatterColorBuffer::None,
            lod_stride: stride,
        };

        let packed =
            pack_vertices_from_xy(&device, &queue, &inputs, &params).expect("gpu packing failed");
        assert!(packed.vertex_count > 0);
        assert_eq!(packed.vertex_count, expected_vertices);
    }
}