Skip to main content

runmat_plot/gpu/
scatter3.rs

1use crate::core::renderer::Vertex;
2use crate::core::scene::{DrawIndirectArgsRaw, GpuVertexBuffer};
3use crate::gpu::scatter2::{ScatterAttributeBuffer, ScatterColorBuffer};
4use crate::gpu::shaders;
5use crate::gpu::{tuning, ScalarType};
6use glam::Vec4;
7use std::sync::Arc;
8use wgpu::util::DeviceExt;
9
10/// Inputs required to pack scatter3 vertices directly on the GPU.
11#[derive(Clone, Debug)]
12pub struct Scatter3GpuInputs {
13    pub x_buffer: Arc<wgpu::Buffer>,
14    pub y_buffer: Arc<wgpu::Buffer>,
15    pub z_buffer: Arc<wgpu::Buffer>,
16    pub len: u32,
17    pub scalar: ScalarType,
18    pub colors: ScatterColorBuffer,
19}
20
21/// Parameters describing how the GPU vertices should be generated.
22pub struct Scatter3GpuParams {
23    pub color: Vec4,
24    pub point_size: f32,
25    pub sizes: ScatterAttributeBuffer,
26    pub colors: ScatterColorBuffer,
27    pub lod_stride: u32,
28}
29
30#[repr(C)]
31#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
32struct Scatter3Uniforms {
33    color: [f32; 4],
34    point_size: f32,
35    count: u32,
36    lod_stride: u32,
37    has_sizes: u32,
38    has_colors: u32,
39    color_stride: u32,
40    _pad: u32,
41}
42
43/// Builds a GPU-resident vertex buffer for scatter3 plots directly from
44/// provider-owned XYZ arrays with either single- or double-precision inputs.
45pub fn pack_vertices_from_xyz(
46    device: &Arc<wgpu::Device>,
47    queue: &Arc<wgpu::Queue>,
48    inputs: &Scatter3GpuInputs,
49    params: &Scatter3GpuParams,
50) -> Result<GpuVertexBuffer, String> {
51    if inputs.len == 0 {
52        return Err("scatter3: empty input tensors".to_string());
53    }
54
55    let workgroup_size = tuning::effective_workgroup_size();
56    let shader = compile_shader(device, workgroup_size, inputs.scalar);
57
58    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
59        label: Some("scatter3-pack-bind-layout"),
60        entries: &[
61            wgpu::BindGroupLayoutEntry {
62                binding: 0,
63                visibility: wgpu::ShaderStages::COMPUTE,
64                ty: wgpu::BindingType::Buffer {
65                    ty: wgpu::BufferBindingType::Storage { read_only: true },
66                    has_dynamic_offset: false,
67                    min_binding_size: None,
68                },
69                count: None,
70            },
71            wgpu::BindGroupLayoutEntry {
72                binding: 1,
73                visibility: wgpu::ShaderStages::COMPUTE,
74                ty: wgpu::BindingType::Buffer {
75                    ty: wgpu::BufferBindingType::Storage { read_only: true },
76                    has_dynamic_offset: false,
77                    min_binding_size: None,
78                },
79                count: None,
80            },
81            wgpu::BindGroupLayoutEntry {
82                binding: 2,
83                visibility: wgpu::ShaderStages::COMPUTE,
84                ty: wgpu::BindingType::Buffer {
85                    ty: wgpu::BufferBindingType::Storage { read_only: true },
86                    has_dynamic_offset: false,
87                    min_binding_size: None,
88                },
89                count: None,
90            },
91            wgpu::BindGroupLayoutEntry {
92                binding: 3,
93                visibility: wgpu::ShaderStages::COMPUTE,
94                ty: wgpu::BindingType::Buffer {
95                    ty: wgpu::BufferBindingType::Storage { read_only: false },
96                    has_dynamic_offset: false,
97                    min_binding_size: None,
98                },
99                count: None,
100            },
101            wgpu::BindGroupLayoutEntry {
102                binding: 4,
103                visibility: wgpu::ShaderStages::COMPUTE,
104                ty: wgpu::BindingType::Buffer {
105                    ty: wgpu::BufferBindingType::Uniform,
106                    has_dynamic_offset: false,
107                    min_binding_size: None,
108                },
109                count: None,
110            },
111            wgpu::BindGroupLayoutEntry {
112                binding: 5,
113                visibility: wgpu::ShaderStages::COMPUTE,
114                ty: wgpu::BindingType::Buffer {
115                    ty: wgpu::BufferBindingType::Storage { read_only: true },
116                    has_dynamic_offset: false,
117                    min_binding_size: None,
118                },
119                count: None,
120            },
121            wgpu::BindGroupLayoutEntry {
122                binding: 6,
123                visibility: wgpu::ShaderStages::COMPUTE,
124                ty: wgpu::BindingType::Buffer {
125                    ty: wgpu::BufferBindingType::Storage { read_only: true },
126                    has_dynamic_offset: false,
127                    min_binding_size: None,
128                },
129                count: None,
130            },
131            wgpu::BindGroupLayoutEntry {
132                binding: 7,
133                visibility: wgpu::ShaderStages::COMPUTE,
134                ty: wgpu::BindingType::Buffer {
135                    ty: wgpu::BufferBindingType::Storage { read_only: false },
136                    has_dynamic_offset: false,
137                    min_binding_size: None,
138                },
139                count: None,
140            },
141        ],
142    });
143
144    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
145        label: Some("scatter3-pack-pipeline-layout"),
146        bind_group_layouts: &[&bind_group_layout],
147        push_constant_ranges: &[],
148    });
149
150    let pipeline =
151        device.create_compute_pipeline(&crate::wgpu_compat::wgpu_compute_pipeline_descriptor! {
152            label: Some("scatter3-pack-pipeline"),
153            layout: Some(&pipeline_layout),
154            module: &shader,
155            entry_point: "main",
156        });
157
158    let lod_stride = params.lod_stride.max(1);
159    let max_points = inputs.len.div_ceil(lod_stride);
160    let output_size = max_points as u64 * 6 * std::mem::size_of::<Vertex>() as u64;
161    let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
162        label: Some("scatter3-gpu-vertices"),
163        size: output_size,
164        usage: wgpu::BufferUsages::STORAGE
165            | wgpu::BufferUsages::VERTEX
166            | wgpu::BufferUsages::COPY_DST
167            | wgpu::BufferUsages::COPY_SRC,
168        mapped_at_creation: false,
169    }));
170
171    let indirect_args = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
172        label: Some("scatter3-gpu-indirect-args"),
173        size: std::mem::size_of::<DrawIndirectArgsRaw>() as u64,
174        usage: wgpu::BufferUsages::STORAGE
175            | wgpu::BufferUsages::INDIRECT
176            | wgpu::BufferUsages::COPY_DST
177            | wgpu::BufferUsages::COPY_SRC,
178        mapped_at_creation: false,
179    }));
180    let init = DrawIndirectArgsRaw {
181        vertex_count: 0,
182        instance_count: 1,
183        first_vertex: 0,
184        first_instance: 0,
185    };
186    queue.write_buffer(&indirect_args, 0, bytemuck::bytes_of(&init));
187
188    let (size_buffer, has_sizes) = prepare_size_buffer(device, params);
189    let (color_buffer, has_colors, color_stride) = prepare_color_buffer(device, params);
190
191    let uniforms = Scatter3Uniforms {
192        color: params.color.to_array(),
193        point_size: params.point_size,
194        count: inputs.len,
195        lod_stride,
196        has_sizes: if has_sizes { 1 } else { 0 },
197        has_colors: if has_colors { 1 } else { 0 },
198        color_stride,
199        _pad: 0,
200    };
201    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
202        label: Some("scatter3-pack-uniforms"),
203        contents: bytemuck::bytes_of(&uniforms),
204        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
205    });
206
207    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
208        label: Some("scatter3-pack-bind-group"),
209        layout: &bind_group_layout,
210        entries: &[
211            wgpu::BindGroupEntry {
212                binding: 0,
213                resource: inputs.x_buffer.as_entire_binding(),
214            },
215            wgpu::BindGroupEntry {
216                binding: 1,
217                resource: inputs.y_buffer.as_entire_binding(),
218            },
219            wgpu::BindGroupEntry {
220                binding: 2,
221                resource: inputs.z_buffer.as_entire_binding(),
222            },
223            wgpu::BindGroupEntry {
224                binding: 3,
225                resource: output_buffer.as_entire_binding(),
226            },
227            wgpu::BindGroupEntry {
228                binding: 4,
229                resource: uniform_buffer.as_entire_binding(),
230            },
231            wgpu::BindGroupEntry {
232                binding: 5,
233                resource: size_buffer.as_entire_binding(),
234            },
235            wgpu::BindGroupEntry {
236                binding: 6,
237                resource: color_buffer.as_entire_binding(),
238            },
239            wgpu::BindGroupEntry {
240                binding: 7,
241                resource: indirect_args.as_entire_binding(),
242            },
243        ],
244    });
245
246    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
247        label: Some("scatter3-pack-encoder"),
248    });
249    {
250        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
251            label: Some("scatter3-pack-pass"),
252            timestamp_writes: None,
253        });
254        pass.set_pipeline(&pipeline);
255        pass.set_bind_group(0, &bind_group, &[]);
256        let workgroups = inputs.len.div_ceil(workgroup_size);
257        pass.dispatch_workgroups(workgroups, 1, 1);
258    }
259    queue.submit(Some(encoder.finish()));
260
261    Ok(GpuVertexBuffer::with_indirect(
262        output_buffer,
263        (max_points as usize) * 6,
264        indirect_args,
265    ))
266}
267
268fn prepare_size_buffer(
269    device: &Arc<wgpu::Device>,
270    params: &Scatter3GpuParams,
271) -> (Arc<wgpu::Buffer>, bool) {
272    match &params.sizes {
273        ScatterAttributeBuffer::None => (
274            Arc::new(
275                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
276                    label: Some("scatter3-size-fallback"),
277                    contents: bytemuck::cast_slice(&[0.0f32]),
278                    usage: wgpu::BufferUsages::STORAGE
279                        | wgpu::BufferUsages::COPY_DST
280                        | wgpu::BufferUsages::COPY_SRC,
281                }),
282            ),
283            false,
284        ),
285        ScatterAttributeBuffer::Host(data) => {
286            if data.is_empty() {
287                (
288                    Arc::new(
289                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
290                            label: Some("scatter3-size-fallback"),
291                            contents: bytemuck::cast_slice(&[0.0f32]),
292                            usage: wgpu::BufferUsages::STORAGE
293                                | wgpu::BufferUsages::COPY_DST
294                                | wgpu::BufferUsages::COPY_SRC,
295                        }),
296                    ),
297                    false,
298                )
299            } else {
300                (
301                    Arc::new(
302                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
303                            label: Some("scatter3-size-host"),
304                            contents: bytemuck::cast_slice(data.as_slice()),
305                            usage: wgpu::BufferUsages::STORAGE
306                                | wgpu::BufferUsages::COPY_DST
307                                | wgpu::BufferUsages::COPY_SRC,
308                        }),
309                    ),
310                    true,
311                )
312            }
313        }
314        ScatterAttributeBuffer::Gpu(buffer) => (buffer.clone(), true),
315    }
316}
317
318fn prepare_color_buffer(
319    device: &Arc<wgpu::Device>,
320    params: &Scatter3GpuParams,
321) -> (Arc<wgpu::Buffer>, bool, u32) {
322    match &params.colors {
323        ScatterColorBuffer::None => (
324            Arc::new(
325                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
326                    label: Some("scatter3-color-fallback"),
327                    contents: bytemuck::cast_slice(&[
328                        params.color.x,
329                        params.color.y,
330                        params.color.z,
331                        params.color.w,
332                    ]),
333                    usage: wgpu::BufferUsages::STORAGE
334                        | wgpu::BufferUsages::COPY_DST
335                        | wgpu::BufferUsages::COPY_SRC,
336                }),
337            ),
338            false,
339            4,
340        ),
341        ScatterColorBuffer::Host(colors) => {
342            if colors.is_empty() {
343                (
344                    Arc::new(
345                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
346                            label: Some("scatter3-color-fallback"),
347                            contents: bytemuck::cast_slice(&[
348                                params.color.x,
349                                params.color.y,
350                                params.color.z,
351                                params.color.w,
352                            ]),
353                            usage: wgpu::BufferUsages::STORAGE
354                                | wgpu::BufferUsages::COPY_DST
355                                | wgpu::BufferUsages::COPY_SRC,
356                        }),
357                    ),
358                    false,
359                    4,
360                )
361            } else {
362                (
363                    Arc::new(
364                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
365                            label: Some("scatter3-color-host"),
366                            contents: bytemuck::cast_slice(colors.as_slice()),
367                            usage: wgpu::BufferUsages::STORAGE
368                                | wgpu::BufferUsages::COPY_DST
369                                | wgpu::BufferUsages::COPY_SRC,
370                        }),
371                    ),
372                    true,
373                    4,
374                )
375            }
376        }
377        ScatterColorBuffer::Gpu { buffer, components } => (buffer.clone(), true, *components),
378    }
379}
380
381fn compile_shader(
382    device: &Arc<wgpu::Device>,
383    workgroup_size: u32,
384    scalar: ScalarType,
385) -> wgpu::ShaderModule {
386    let template = match scalar {
387        ScalarType::F32 => shaders::scatter3::F32,
388        ScalarType::F64 => shaders::scatter3::F64,
389    };
390    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
391    device.create_shader_module(wgpu::ShaderModuleDescriptor {
392        label: Some("scatter3-pack-shader"),
393        source: wgpu::ShaderSource::Wgsl(source.into()),
394    })
395}
396
#[cfg(test)]
mod stress_tests {
    use super::*;
    use pollster::FutureExt;

    /// Creates a device/queue pair for GPU stress tests, or returns `None` to skip them.
    ///
    /// These tests are opt-in: they run only when `RUNMAT_PLOT_FORCE_GPU_TESTS` is set and
    /// `RUNMAT_PLOT_SKIP_GPU_TESTS` is not. A missing adapter or a failed device request also
    /// yields `None`.
    fn maybe_device() -> Option<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        let skip_requested = std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok();
        let forced = std::env::var("RUNMAT_PLOT_FORCE_GPU_TESTS").is_ok();
        if skip_requested || !forced {
            return None;
        }

        let instance = wgpu::Instance::default();
        let adapter_options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        };
        let adapter = instance.request_adapter(&adapter_options).block_on()?;

        let descriptor = crate::wgpu_compat::device_descriptor(
            Some("scatter3-test-device"),
            wgpu::Features::empty(),
            adapter.limits(),
        );
        let (device, queue) = adapter.request_device(&descriptor, None).block_on().ok()?;
        Some((Arc::new(device), Arc::new(queue)))
    }

    /// Packs 1.2M points with an LOD stride of 4. Checks that the reported vertex capacity is
    /// six vertices per retained point.
    #[test]
    fn lod_stride_limits_vertex_count() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };
        const POINT_COUNT: u32 = 1_200_000;
        const STRIDE: u32 = 4;

        let mut xs = Vec::with_capacity(POINT_COUNT as usize);
        let mut ys = Vec::with_capacity(POINT_COUNT as usize);
        let mut zs = Vec::with_capacity(POINT_COUNT as usize);
        for i in 0..POINT_COUNT {
            let t = i as f32 * 0.001;
            xs.push(t);
            ys.push(t.cos());
            zs.push(t.sin());
        }

        let upload = |label: &str, data: &[f32]| {
            Arc::new(
                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some(label),
                    contents: bytemuck::cast_slice(data),
                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
                }),
            )
        };

        let inputs = Scatter3GpuInputs {
            x_buffer: upload("scatter3-test-x", &xs),
            y_buffer: upload("scatter3-test-y", &ys),
            z_buffer: upload("scatter3-test-z", &zs),
            len: POINT_COUNT,
            scalar: ScalarType::F32,
            colors: ScatterColorBuffer::None,
        };
        let params = Scatter3GpuParams {
            color: Vec4::new(0.2, 0.6, 0.9, 1.0),
            point_size: 6.0,
            sizes: ScatterAttributeBuffer::None,
            colors: ScatterColorBuffer::None,
            lod_stride: STRIDE,
        };

        let packed =
            pack_vertices_from_xyz(&device, &queue, &inputs, &params).expect("gpu scatter3 pack");
        let expected_vertices = POINT_COUNT.div_ceil(STRIDE) as usize * 6;
        assert_eq!(packed.vertex_count, expected_vertices);
    }
}