Skip to main content

runmat_plot/gpu/
scatter2.rs

1use crate::core::renderer::Vertex;
2use crate::core::scene::{DrawIndirectArgsRaw, GpuVertexBuffer};
3use crate::gpu::shaders;
4use crate::gpu::{tuning, ScalarType};
5use glam::Vec4;
6use std::sync::Arc;
7use wgpu::util::DeviceExt;
8
9/// Optional per-point scalar attribute buffer (marker sizes).
10#[derive(Clone, Debug)]
11pub enum ScatterAttributeBuffer {
12    None,
13    Host(Vec<f32>),
14    Gpu(Arc<wgpu::Buffer>),
15}
16
17impl ScatterAttributeBuffer {
18    pub fn has_data(&self) -> bool {
19        !matches!(self, ScatterAttributeBuffer::None)
20    }
21}
22
23/// Optional per-point color buffer. Host data is supplied as RGBA tuples; GPU buffers
24/// may contain RGB (3 floats) or RGBA (4 floats) sequences.
25#[derive(Clone, Debug)]
26pub enum ScatterColorBuffer {
27    None,
28    Host(Vec<[f32; 4]>),
29    Gpu {
30        buffer: Arc<wgpu::Buffer>,
31        components: u32,
32    },
33}
34
35impl ScatterColorBuffer {
36    pub fn has_data(&self) -> bool {
37        !matches!(self, ScatterColorBuffer::None)
38    }
39
40    pub fn stride(&self) -> u32 {
41        match self {
42            ScatterColorBuffer::None => 4,
43            ScatterColorBuffer::Host(_) => 4,
44            ScatterColorBuffer::Gpu { components, .. } => *components,
45        }
46    }
47
48    pub fn buffer(&self) -> Option<Arc<wgpu::Buffer>> {
49        match self {
50            ScatterColorBuffer::Gpu { buffer, .. } => Some(buffer.clone()),
51            _ => None,
52        }
53    }
54}
55
/// Inputs required to pack scatter vertices directly on the GPU.
///
/// The X and Y buffers are bound as read-only storage buffers by
/// `pack_vertices_from_xy`. They presumably hold at least `len` elements of
/// type `scalar` — TODO(review): confirm with the providers that build this.
#[derive(Clone, Debug)]
pub struct Scatter2GpuInputs {
    /// X coordinates (bound at binding 0 of the packing shader).
    pub x_buffer: Arc<wgpu::Buffer>,
    /// Y coordinates (bound at binding 1 of the packing shader).
    pub y_buffer: Arc<wgpu::Buffer>,
    /// Number of input points; packing fails with `Err` when this is zero.
    pub len: u32,
    /// Element type of `x_buffer`/`y_buffer`; selects the shader variant.
    pub scalar: ScalarType,
}
64
/// Parameters describing how the GPU vertices should be generated.
pub struct Scatter2GpuParams {
    /// Default RGBA marker color. It is written to the shader uniforms and is
    /// also uploaded as the placeholder color buffer when no per-point colors
    /// are supplied.
    pub color: Vec4,
    /// Marker size written to the shader uniforms (units not visible here;
    /// presumably pixels — TODO confirm).
    pub point_size: f32,
    /// Optional per-point marker sizes.
    pub sizes: ScatterAttributeBuffer,
    /// Optional per-point colors.
    pub colors: ScatterColorBuffer,
    /// Level-of-detail decimation stride. Values below 1 are treated as 1, and
    /// the output buffer is sized for `ceil(len / lod_stride)` points.
    pub lod_stride: u32,
}
73
/// Uniform block uploaded to the packing shader at binding 3.
///
/// `#[repr(C)]` + `Pod` lets it be uploaded verbatim with `bytemuck::bytes_of`.
/// The field order presumably mirrors the uniform struct declared in
/// `shaders::scatter2` — TODO(review): confirm against the WGSL source.
///
/// NOTE(review): as declared this struct is 40 bytes (a 16-byte color plus six
/// 4-byte scalars). If the WGSL struct starts with a `vec4<f32>`, WGSL rounds
/// its size up to a multiple of 16 (48 bytes), so the uploaded buffer needs
/// trailing padding to satisfy wgpu's binding-size validation.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct Scatter2Uniforms {
    // Default RGBA color (from `Scatter2GpuParams::color`).
    color: [f32; 4],
    // Marker size (from `Scatter2GpuParams::point_size`).
    point_size: f32,
    // Number of input points (`Scatter2GpuInputs::len`).
    count: u32,
    // Decimation stride, already clamped to at least 1.
    lod_stride: u32,
    // 1 when a real per-point size buffer is bound at binding 4, else 0.
    has_sizes: u32,
    // 1 when a real per-point color buffer is bound at binding 5, else 0.
    has_colors: u32,
    // Floats per color entry in the binding-5 buffer (4 for host colors).
    color_stride: u32,
}
85
86/// Builds a GPU-resident vertex buffer for scatter plots directly from
87/// provider-owned XY arrays. Only single-precision tensors are supported
88/// today; callers should fall back to host packing for other types.
89pub fn pack_vertices_from_xy(
90    device: &Arc<wgpu::Device>,
91    queue: &Arc<wgpu::Queue>,
92    inputs: &Scatter2GpuInputs,
93    params: &Scatter2GpuParams,
94) -> Result<GpuVertexBuffer, String> {
95    if inputs.len == 0 {
96        return Err("scatter: empty input tensors".to_string());
97    }
98
99    let lod_stride = params.lod_stride.max(1);
100    let max_points = inputs.len.div_ceil(lod_stride);
101
102    let workgroup_size = tuning::effective_workgroup_size();
103    let shader = compile_shader(device, workgroup_size, inputs.scalar);
104
105    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
106        label: Some("scatter2-pack-bind-layout"),
107        entries: &[
108            wgpu::BindGroupLayoutEntry {
109                binding: 0,
110                visibility: wgpu::ShaderStages::COMPUTE,
111                ty: wgpu::BindingType::Buffer {
112                    ty: wgpu::BufferBindingType::Storage { read_only: true },
113                    has_dynamic_offset: false,
114                    min_binding_size: None,
115                },
116                count: None,
117            },
118            wgpu::BindGroupLayoutEntry {
119                binding: 1,
120                visibility: wgpu::ShaderStages::COMPUTE,
121                ty: wgpu::BindingType::Buffer {
122                    ty: wgpu::BufferBindingType::Storage { read_only: true },
123                    has_dynamic_offset: false,
124                    min_binding_size: None,
125                },
126                count: None,
127            },
128            wgpu::BindGroupLayoutEntry {
129                binding: 2,
130                visibility: wgpu::ShaderStages::COMPUTE,
131                ty: wgpu::BindingType::Buffer {
132                    ty: wgpu::BufferBindingType::Storage { read_only: false },
133                    has_dynamic_offset: false,
134                    min_binding_size: None,
135                },
136                count: None,
137            },
138            wgpu::BindGroupLayoutEntry {
139                binding: 3,
140                visibility: wgpu::ShaderStages::COMPUTE,
141                ty: wgpu::BindingType::Buffer {
142                    ty: wgpu::BufferBindingType::Uniform,
143                    has_dynamic_offset: false,
144                    min_binding_size: None,
145                },
146                count: None,
147            },
148            wgpu::BindGroupLayoutEntry {
149                binding: 4,
150                visibility: wgpu::ShaderStages::COMPUTE,
151                ty: wgpu::BindingType::Buffer {
152                    ty: wgpu::BufferBindingType::Storage { read_only: true },
153                    has_dynamic_offset: false,
154                    min_binding_size: None,
155                },
156                count: None,
157            },
158            wgpu::BindGroupLayoutEntry {
159                binding: 5,
160                visibility: wgpu::ShaderStages::COMPUTE,
161                ty: wgpu::BindingType::Buffer {
162                    ty: wgpu::BufferBindingType::Storage { read_only: true },
163                    has_dynamic_offset: false,
164                    min_binding_size: None,
165                },
166                count: None,
167            },
168            wgpu::BindGroupLayoutEntry {
169                binding: 6,
170                visibility: wgpu::ShaderStages::COMPUTE,
171                ty: wgpu::BindingType::Buffer {
172                    ty: wgpu::BufferBindingType::Storage { read_only: false },
173                    has_dynamic_offset: false,
174                    min_binding_size: None,
175                },
176                count: None,
177            },
178        ],
179    });
180
181    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
182        label: Some("scatter2-pack-pipeline-layout"),
183        bind_group_layouts: &[&bind_group_layout],
184        push_constant_ranges: &[],
185    });
186
187    let pipeline =
188        device.create_compute_pipeline(&crate::wgpu_compat::wgpu_compute_pipeline_descriptor! {
189            label: Some("scatter2-pack-pipeline"),
190            layout: Some(&pipeline_layout),
191            module: &shader,
192            entry_point: "main",
193        });
194
195    let output_size = max_points as u64 * 6 * std::mem::size_of::<Vertex>() as u64;
196    let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
197        label: Some("scatter2-gpu-vertices"),
198        size: output_size,
199        usage: wgpu::BufferUsages::STORAGE
200            | wgpu::BufferUsages::VERTEX
201            | wgpu::BufferUsages::COPY_DST
202            | wgpu::BufferUsages::COPY_SRC,
203        mapped_at_creation: false,
204    }));
205
206    let indirect_args = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
207        label: Some("scatter2-gpu-indirect-args"),
208        size: std::mem::size_of::<DrawIndirectArgsRaw>() as u64,
209        usage: wgpu::BufferUsages::STORAGE
210            | wgpu::BufferUsages::INDIRECT
211            | wgpu::BufferUsages::COPY_DST
212            | wgpu::BufferUsages::COPY_SRC,
213        mapped_at_creation: false,
214    }));
215    let init = DrawIndirectArgsRaw {
216        vertex_count: 0,
217        instance_count: 1,
218        first_vertex: 0,
219        first_instance: 0,
220    };
221    queue.write_buffer(&indirect_args, 0, bytemuck::bytes_of(&init));
222
223    let (size_buffer, has_sizes) = prepare_size_buffer(device, params);
224    let (color_buffer, has_colors, color_stride) = prepare_color_buffer(device, params);
225
226    let uniforms = Scatter2Uniforms {
227        color: params.color.to_array(),
228        point_size: params.point_size,
229        count: inputs.len,
230        lod_stride,
231        has_sizes: if has_sizes { 1 } else { 0 },
232        has_colors: if has_colors { 1 } else { 0 },
233        color_stride,
234    };
235    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
236        label: Some("scatter2-pack-uniforms"),
237        contents: bytemuck::bytes_of(&uniforms),
238        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
239    });
240
241    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
242        label: Some("scatter2-pack-bind-group"),
243        layout: &bind_group_layout,
244        entries: &[
245            wgpu::BindGroupEntry {
246                binding: 0,
247                resource: inputs.x_buffer.as_entire_binding(),
248            },
249            wgpu::BindGroupEntry {
250                binding: 1,
251                resource: inputs.y_buffer.as_entire_binding(),
252            },
253            wgpu::BindGroupEntry {
254                binding: 2,
255                resource: output_buffer.as_entire_binding(),
256            },
257            wgpu::BindGroupEntry {
258                binding: 3,
259                resource: uniform_buffer.as_entire_binding(),
260            },
261            wgpu::BindGroupEntry {
262                binding: 4,
263                resource: size_buffer.as_entire_binding(),
264            },
265            wgpu::BindGroupEntry {
266                binding: 5,
267                resource: color_buffer.as_entire_binding(),
268            },
269            wgpu::BindGroupEntry {
270                binding: 6,
271                resource: indirect_args.as_entire_binding(),
272            },
273        ],
274    });
275
276    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
277        label: Some("scatter2-pack-encoder"),
278    });
279    {
280        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
281            label: Some("scatter2-pack-pass"),
282            timestamp_writes: None,
283        });
284        pass.set_pipeline(&pipeline);
285        pass.set_bind_group(0, &bind_group, &[]);
286        let workgroups = inputs.len.div_ceil(workgroup_size);
287        pass.dispatch_workgroups(workgroups, 1, 1);
288    }
289    queue.submit(Some(encoder.finish()));
290
291    Ok(GpuVertexBuffer::with_indirect(
292        output_buffer,
293        (max_points as usize) * 6,
294        indirect_args,
295    ))
296}
297
298fn prepare_size_buffer(
299    device: &Arc<wgpu::Device>,
300    params: &Scatter2GpuParams,
301) -> (Arc<wgpu::Buffer>, bool) {
302    match &params.sizes {
303        ScatterAttributeBuffer::None => (
304            Arc::new(
305                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
306                    label: Some("scatter2-size-fallback"),
307                    contents: bytemuck::cast_slice(&[0.0f32]),
308                    usage: wgpu::BufferUsages::STORAGE
309                        | wgpu::BufferUsages::COPY_DST
310                        | wgpu::BufferUsages::COPY_SRC,
311                }),
312            ),
313            false,
314        ),
315        ScatterAttributeBuffer::Host(data) => {
316            if data.is_empty() {
317                (
318                    Arc::new(
319                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
320                            label: Some("scatter2-size-fallback"),
321                            contents: bytemuck::cast_slice(&[0.0f32]),
322                            usage: wgpu::BufferUsages::STORAGE
323                                | wgpu::BufferUsages::COPY_DST
324                                | wgpu::BufferUsages::COPY_SRC,
325                        }),
326                    ),
327                    false,
328                )
329            } else {
330                (
331                    Arc::new(
332                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
333                            label: Some("scatter2-size-host"),
334                            contents: bytemuck::cast_slice(data.as_slice()),
335                            usage: wgpu::BufferUsages::STORAGE
336                                | wgpu::BufferUsages::COPY_DST
337                                | wgpu::BufferUsages::COPY_SRC,
338                        }),
339                    ),
340                    true,
341                )
342            }
343        }
344        ScatterAttributeBuffer::Gpu(buffer) => (buffer.clone(), true),
345    }
346}
347
348fn prepare_color_buffer(
349    device: &Arc<wgpu::Device>,
350    params: &Scatter2GpuParams,
351) -> (Arc<wgpu::Buffer>, bool, u32) {
352    match &params.colors {
353        ScatterColorBuffer::None => (
354            Arc::new(
355                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
356                    label: Some("scatter2-color-fallback"),
357                    contents: bytemuck::cast_slice(&[
358                        params.color.x,
359                        params.color.y,
360                        params.color.z,
361                        params.color.w,
362                    ]),
363                    usage: wgpu::BufferUsages::STORAGE
364                        | wgpu::BufferUsages::COPY_DST
365                        | wgpu::BufferUsages::COPY_SRC,
366                }),
367            ),
368            false,
369            4,
370        ),
371        ScatterColorBuffer::Host(colors) => {
372            if colors.is_empty() {
373                (
374                    Arc::new(
375                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
376                            label: Some("scatter2-color-fallback"),
377                            contents: bytemuck::cast_slice(&[
378                                params.color.x,
379                                params.color.y,
380                                params.color.z,
381                                params.color.w,
382                            ]),
383                            usage: wgpu::BufferUsages::STORAGE
384                                | wgpu::BufferUsages::COPY_DST
385                                | wgpu::BufferUsages::COPY_SRC,
386                        }),
387                    ),
388                    false,
389                    4,
390                )
391            } else {
392                (
393                    Arc::new(
394                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
395                            label: Some("scatter2-color-host"),
396                            contents: bytemuck::cast_slice(colors.as_slice()),
397                            usage: wgpu::BufferUsages::STORAGE
398                                | wgpu::BufferUsages::COPY_DST
399                                | wgpu::BufferUsages::COPY_SRC,
400                        }),
401                    ),
402                    true,
403                    4,
404                )
405            }
406        }
407        ScatterColorBuffer::Gpu { buffer, components } => (buffer.clone(), true, *components),
408    }
409}
410
411fn compile_shader(
412    device: &Arc<wgpu::Device>,
413    workgroup_size: u32,
414    scalar: ScalarType,
415) -> wgpu::ShaderModule {
416    let template = match scalar {
417        ScalarType::F32 => shaders::scatter2::F32,
418        ScalarType::F64 => shaders::scatter2::F64,
419    };
420    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
421    device.create_shader_module(wgpu::ShaderModuleDescriptor {
422        label: Some("scatter2-pack-shader"),
423        source: wgpu::ShaderSource::Wgsl(source.into()),
424    })
425}
426
#[cfg(test)]
mod stress_tests {
    use super::*;
    use pollster::FutureExt;

    /// Acquires a real GPU device for the opt-in stress tests.
    ///
    /// Yields `None` in three cases, after which the caller returns early and
    /// the test passes vacuously: `RUNMAT_PLOT_SKIP_GPU_TESTS` is set,
    /// `RUNMAT_PLOT_FORCE_GPU_TESTS` is absent, or no adapter/device can be
    /// obtained.
    fn maybe_device() -> Option<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        let skip_requested = std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok();
        let force_requested = std::env::var("RUNMAT_PLOT_FORCE_GPU_TESTS").is_ok();
        if skip_requested || !force_requested {
            return None;
        }

        let instance = wgpu::Instance::default();
        let adapter_options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        };
        let adapter = instance.request_adapter(&adapter_options).block_on()?;

        let descriptor = crate::wgpu_compat::device_descriptor(
            Some("runmat-plot-scatter-test-device"),
            wgpu::Features::empty(),
            adapter.limits(),
        );
        let (device, queue) = adapter
            .request_device(&descriptor, None)
            .block_on()
            .ok()?;
        Some((Arc::new(device), Arc::new(queue)))
    }

    /// Packs 1.2M points with a LOD stride chosen to emit at most ~250k points
    /// and checks the vertex capacity reported by the packer.
    #[test]
    fn gpu_packer_handles_large_point_cloud() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };

        const POINT_COUNT: u32 = 1_200_000;
        const LOD_TARGET: u32 = 250_000;

        let xs: Vec<f32> = (0..POINT_COUNT).map(|i| i as f32 * 0.001).collect();
        let ys: Vec<f32> = xs.iter().map(|x| x.sin()).collect();

        let upload = |label: &'static str, values: &[f32]| {
            Arc::new(
                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some(label),
                    contents: bytemuck::cast_slice(values),
                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
                }),
            )
        };

        let inputs = Scatter2GpuInputs {
            x_buffer: upload("scatter2-test-x", &xs),
            y_buffer: upload("scatter2-test-y", &ys),
            len: POINT_COUNT,
            scalar: ScalarType::F32,
        };

        // Smallest stride that keeps the emitted point count at or below the target.
        let lod_stride = POINT_COUNT.div_ceil(LOD_TARGET).max(1);
        let expected_vertices = POINT_COUNT.div_ceil(lod_stride) as usize * 6;

        let params = Scatter2GpuParams {
            color: Vec4::new(0.8, 0.1, 0.3, 1.0),
            point_size: 8.0,
            sizes: ScatterAttributeBuffer::None,
            colors: ScatterColorBuffer::None,
            lod_stride,
        };

        let packed = pack_vertices_from_xy(&device, &queue, &inputs, &params)
            .expect("gpu packing failed");
        assert!(packed.vertex_count > 0);
        assert_eq!(packed.vertex_count, expected_vertices);
    }
}