Skip to main content

runmat_plot/gpu/
scatter3.rs

1use crate::core::renderer::Vertex;
2use crate::core::scene::{DrawIndirectArgsRaw, GpuVertexBuffer};
3use crate::gpu::scatter2::{ScatterAttributeBuffer, ScatterColorBuffer};
4use crate::gpu::shaders;
5use crate::gpu::{tuning, ScalarType};
6use glam::Vec4;
7use std::sync::Arc;
8use wgpu::util::DeviceExt;
9
10/// Inputs required to pack scatter3 vertices directly on the GPU.
11pub struct Scatter3GpuInputs {
12    pub x_buffer: Arc<wgpu::Buffer>,
13    pub y_buffer: Arc<wgpu::Buffer>,
14    pub z_buffer: Arc<wgpu::Buffer>,
15    pub len: u32,
16    pub scalar: ScalarType,
17}
18
19/// Parameters describing how the GPU vertices should be generated.
20pub struct Scatter3GpuParams {
21    pub color: Vec4,
22    pub point_size: f32,
23    pub sizes: ScatterAttributeBuffer,
24    pub colors: ScatterColorBuffer,
25    pub lod_stride: u32,
26}
27
/// Uniform block consumed by the scatter3 pack compute shader (binding 4).
///
/// `#[repr(C)]` + `Pod` so it can be uploaded verbatim with `bytemuck::bytes_of`.
/// Field order and types must match the uniform struct declared in the WGSL
/// templates in `shaders::scatter3`. NOTE(review): confirm against the shader.
/// NOTE(review): this struct is 44 bytes. If the WGSL side declares `color` as
/// `vec4<f32>`, WGSL rounds the struct size up to 48. In that case the uploaded
/// buffer must be padded, or wgpu's binding-size validation will reject it.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct Scatter3Uniforms {
    // Base RGBA color (from `Scatter3GpuParams::color`).
    color: [f32; 4],
    // Point size (from `Scatter3GpuParams::point_size`).
    point_size: f32,
    // Number of input points (`Scatter3GpuInputs::len`).
    count: u32,
    // Decimation stride; the packer clamps it to at least 1 before writing it.
    lod_stride: u32,
    // 1 when a real per-point size buffer is bound at binding 5, else 0.
    has_sizes: u32,
    // 1 when a real per-point color buffer is bound at binding 6, else 0.
    has_colors: u32,
    // Components per color entry in the color buffer (4 for host/fallback colors).
    color_stride: u32,
    // Explicit padding; always written as 0.
    _pad: u32,
}
40
41/// Builds a GPU-resident vertex buffer for scatter3 plots directly from
42/// provider-owned XYZ arrays with either single- or double-precision inputs.
43pub fn pack_vertices_from_xyz(
44    device: &Arc<wgpu::Device>,
45    queue: &Arc<wgpu::Queue>,
46    inputs: &Scatter3GpuInputs,
47    params: &Scatter3GpuParams,
48) -> Result<GpuVertexBuffer, String> {
49    if inputs.len == 0 {
50        return Err("scatter3: empty input tensors".to_string());
51    }
52
53    let workgroup_size = tuning::effective_workgroup_size();
54    let shader = compile_shader(device, workgroup_size, inputs.scalar);
55
56    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
57        label: Some("scatter3-pack-bind-layout"),
58        entries: &[
59            wgpu::BindGroupLayoutEntry {
60                binding: 0,
61                visibility: wgpu::ShaderStages::COMPUTE,
62                ty: wgpu::BindingType::Buffer {
63                    ty: wgpu::BufferBindingType::Storage { read_only: true },
64                    has_dynamic_offset: false,
65                    min_binding_size: None,
66                },
67                count: None,
68            },
69            wgpu::BindGroupLayoutEntry {
70                binding: 1,
71                visibility: wgpu::ShaderStages::COMPUTE,
72                ty: wgpu::BindingType::Buffer {
73                    ty: wgpu::BufferBindingType::Storage { read_only: true },
74                    has_dynamic_offset: false,
75                    min_binding_size: None,
76                },
77                count: None,
78            },
79            wgpu::BindGroupLayoutEntry {
80                binding: 2,
81                visibility: wgpu::ShaderStages::COMPUTE,
82                ty: wgpu::BindingType::Buffer {
83                    ty: wgpu::BufferBindingType::Storage { read_only: true },
84                    has_dynamic_offset: false,
85                    min_binding_size: None,
86                },
87                count: None,
88            },
89            wgpu::BindGroupLayoutEntry {
90                binding: 3,
91                visibility: wgpu::ShaderStages::COMPUTE,
92                ty: wgpu::BindingType::Buffer {
93                    ty: wgpu::BufferBindingType::Storage { read_only: false },
94                    has_dynamic_offset: false,
95                    min_binding_size: None,
96                },
97                count: None,
98            },
99            wgpu::BindGroupLayoutEntry {
100                binding: 4,
101                visibility: wgpu::ShaderStages::COMPUTE,
102                ty: wgpu::BindingType::Buffer {
103                    ty: wgpu::BufferBindingType::Uniform,
104                    has_dynamic_offset: false,
105                    min_binding_size: None,
106                },
107                count: None,
108            },
109            wgpu::BindGroupLayoutEntry {
110                binding: 5,
111                visibility: wgpu::ShaderStages::COMPUTE,
112                ty: wgpu::BindingType::Buffer {
113                    ty: wgpu::BufferBindingType::Storage { read_only: true },
114                    has_dynamic_offset: false,
115                    min_binding_size: None,
116                },
117                count: None,
118            },
119            wgpu::BindGroupLayoutEntry {
120                binding: 6,
121                visibility: wgpu::ShaderStages::COMPUTE,
122                ty: wgpu::BindingType::Buffer {
123                    ty: wgpu::BufferBindingType::Storage { read_only: true },
124                    has_dynamic_offset: false,
125                    min_binding_size: None,
126                },
127                count: None,
128            },
129            wgpu::BindGroupLayoutEntry {
130                binding: 7,
131                visibility: wgpu::ShaderStages::COMPUTE,
132                ty: wgpu::BindingType::Buffer {
133                    ty: wgpu::BufferBindingType::Storage { read_only: false },
134                    has_dynamic_offset: false,
135                    min_binding_size: None,
136                },
137                count: None,
138            },
139        ],
140    });
141
142    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
143        label: Some("scatter3-pack-pipeline-layout"),
144        bind_group_layouts: &[&bind_group_layout],
145        push_constant_ranges: &[],
146    });
147
148    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
149        label: Some("scatter3-pack-pipeline"),
150        layout: Some(&pipeline_layout),
151        module: &shader,
152        entry_point: "main",
153    });
154
155    let lod_stride = params.lod_stride.max(1);
156    let max_points = inputs.len.div_ceil(lod_stride);
157    let output_size = max_points as u64 * std::mem::size_of::<Vertex>() as u64;
158    let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
159        label: Some("scatter3-gpu-vertices"),
160        size: output_size,
161        usage: wgpu::BufferUsages::STORAGE
162            | wgpu::BufferUsages::VERTEX
163            | wgpu::BufferUsages::COPY_DST,
164        mapped_at_creation: false,
165    }));
166
167    let indirect_args = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
168        label: Some("scatter3-gpu-indirect-args"),
169        size: std::mem::size_of::<DrawIndirectArgsRaw>() as u64,
170        usage: wgpu::BufferUsages::STORAGE
171            | wgpu::BufferUsages::INDIRECT
172            | wgpu::BufferUsages::COPY_DST,
173        mapped_at_creation: false,
174    }));
175    let init = DrawIndirectArgsRaw {
176        vertex_count: 0,
177        instance_count: 1,
178        first_vertex: 0,
179        first_instance: 0,
180    };
181    queue.write_buffer(&indirect_args, 0, bytemuck::bytes_of(&init));
182
183    let (size_buffer, has_sizes) = prepare_size_buffer(device, params);
184    let (color_buffer, has_colors, color_stride) = prepare_color_buffer(device, params);
185
186    let uniforms = Scatter3Uniforms {
187        color: params.color.to_array(),
188        point_size: params.point_size,
189        count: inputs.len,
190        lod_stride,
191        has_sizes: if has_sizes { 1 } else { 0 },
192        has_colors: if has_colors { 1 } else { 0 },
193        color_stride,
194        _pad: 0,
195    };
196    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
197        label: Some("scatter3-pack-uniforms"),
198        contents: bytemuck::bytes_of(&uniforms),
199        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
200    });
201
202    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
203        label: Some("scatter3-pack-bind-group"),
204        layout: &bind_group_layout,
205        entries: &[
206            wgpu::BindGroupEntry {
207                binding: 0,
208                resource: inputs.x_buffer.as_entire_binding(),
209            },
210            wgpu::BindGroupEntry {
211                binding: 1,
212                resource: inputs.y_buffer.as_entire_binding(),
213            },
214            wgpu::BindGroupEntry {
215                binding: 2,
216                resource: inputs.z_buffer.as_entire_binding(),
217            },
218            wgpu::BindGroupEntry {
219                binding: 3,
220                resource: output_buffer.as_entire_binding(),
221            },
222            wgpu::BindGroupEntry {
223                binding: 4,
224                resource: uniform_buffer.as_entire_binding(),
225            },
226            wgpu::BindGroupEntry {
227                binding: 5,
228                resource: size_buffer.as_entire_binding(),
229            },
230            wgpu::BindGroupEntry {
231                binding: 6,
232                resource: color_buffer.as_entire_binding(),
233            },
234            wgpu::BindGroupEntry {
235                binding: 7,
236                resource: indirect_args.as_entire_binding(),
237            },
238        ],
239    });
240
241    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
242        label: Some("scatter3-pack-encoder"),
243    });
244    {
245        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
246            label: Some("scatter3-pack-pass"),
247            timestamp_writes: None,
248        });
249        pass.set_pipeline(&pipeline);
250        pass.set_bind_group(0, &bind_group, &[]);
251        let workgroups = inputs.len.div_ceil(workgroup_size);
252        pass.dispatch_workgroups(workgroups, 1, 1);
253    }
254    queue.submit(Some(encoder.finish()));
255
256    Ok(GpuVertexBuffer::with_indirect(
257        output_buffer,
258        max_points as usize,
259        indirect_args,
260    ))
261}
262
263fn prepare_size_buffer(
264    device: &Arc<wgpu::Device>,
265    params: &Scatter3GpuParams,
266) -> (Arc<wgpu::Buffer>, bool) {
267    match &params.sizes {
268        ScatterAttributeBuffer::None => (
269            Arc::new(
270                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
271                    label: Some("scatter3-size-fallback"),
272                    contents: bytemuck::cast_slice(&[0.0f32]),
273                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
274                }),
275            ),
276            false,
277        ),
278        ScatterAttributeBuffer::Host(data) => {
279            if data.is_empty() {
280                (
281                    Arc::new(
282                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
283                            label: Some("scatter3-size-fallback"),
284                            contents: bytemuck::cast_slice(&[0.0f32]),
285                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
286                        }),
287                    ),
288                    false,
289                )
290            } else {
291                (
292                    Arc::new(
293                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
294                            label: Some("scatter3-size-host"),
295                            contents: bytemuck::cast_slice(data.as_slice()),
296                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
297                        }),
298                    ),
299                    true,
300                )
301            }
302        }
303        ScatterAttributeBuffer::Gpu(buffer) => (buffer.clone(), true),
304    }
305}
306
307fn prepare_color_buffer(
308    device: &Arc<wgpu::Device>,
309    params: &Scatter3GpuParams,
310) -> (Arc<wgpu::Buffer>, bool, u32) {
311    match &params.colors {
312        ScatterColorBuffer::None => (
313            Arc::new(
314                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
315                    label: Some("scatter3-color-fallback"),
316                    contents: bytemuck::cast_slice(&[
317                        params.color.x,
318                        params.color.y,
319                        params.color.z,
320                        params.color.w,
321                    ]),
322                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
323                }),
324            ),
325            false,
326            4,
327        ),
328        ScatterColorBuffer::Host(colors) => {
329            if colors.is_empty() {
330                (
331                    Arc::new(
332                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
333                            label: Some("scatter3-color-fallback"),
334                            contents: bytemuck::cast_slice(&[
335                                params.color.x,
336                                params.color.y,
337                                params.color.z,
338                                params.color.w,
339                            ]),
340                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
341                        }),
342                    ),
343                    false,
344                    4,
345                )
346            } else {
347                (
348                    Arc::new(
349                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
350                            label: Some("scatter3-color-host"),
351                            contents: bytemuck::cast_slice(colors.as_slice()),
352                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
353                        }),
354                    ),
355                    true,
356                    4,
357                )
358            }
359        }
360        ScatterColorBuffer::Gpu { buffer, components } => (buffer.clone(), true, *components),
361    }
362}
363
364fn compile_shader(
365    device: &Arc<wgpu::Device>,
366    workgroup_size: u32,
367    scalar: ScalarType,
368) -> wgpu::ShaderModule {
369    let template = match scalar {
370        ScalarType::F32 => shaders::scatter3::F32,
371        ScalarType::F64 => shaders::scatter3::F64,
372    };
373    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
374    device.create_shader_module(wgpu::ShaderModuleDescriptor {
375        label: Some("scatter3-pack-shader"),
376        source: wgpu::ShaderSource::Wgsl(source.into()),
377    })
378}
379
/// Opt-in GPU stress tests for the scatter3 packer.
///
/// They run only when `RUNMAT_PLOT_FORCE_GPU_TESTS` is set,
/// `RUNMAT_PLOT_SKIP_GPU_TESTS` is not set, and an adapter and device can be
/// created. Otherwise each test returns early without asserting anything.
#[cfg(test)]
mod stress_tests {
    use super::*;
    use pollster::FutureExt;

    /// Returns a high-performance device/queue pair. Returns `None` when GPU
    /// tests are not enabled or no adapter/device can be obtained.
    fn maybe_device() -> Option<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        let skip_requested = std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok();
        let force_requested = std::env::var("RUNMAT_PLOT_FORCE_GPU_TESTS").is_ok();
        if skip_requested || !force_requested {
            return None;
        }

        let instance = wgpu::Instance::default();
        let adapter_options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        };
        let adapter = instance.request_adapter(&adapter_options).block_on()?;

        let device_descriptor = wgpu::DeviceDescriptor {
            label: Some("scatter3-test-device"),
            required_features: wgpu::Features::empty(),
            required_limits: adapter.limits(),
        };
        let (device, queue) = adapter
            .request_device(&device_descriptor, None)
            .block_on()
            .ok()?;
        Some((Arc::new(device), Arc::new(queue)))
    }

    /// Packs 1.2M points with an LOD stride of 4. Checks that the returned buffer
    /// reports capacity for exactly `ceil(1_200_000 / 4)` vertices.
    #[test]
    fn lod_stride_limits_vertex_count() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };
        let point_count = 1_200_000u32;
        let stride = 4u32;

        let xs: Vec<f32> = (0..point_count).map(|i| i as f32 * 0.001).collect();
        let (ys, zs): (Vec<f32>, Vec<f32>) = xs.iter().map(|t| (t.cos(), t.sin())).unzip();

        let upload = |label: &str, values: &[f32]| -> Arc<wgpu::Buffer> {
            let buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some(label),
                contents: bytemuck::cast_slice(values),
                usage: wgpu::BufferUsages::STORAGE,
            });
            Arc::new(buffer)
        };

        let inputs = Scatter3GpuInputs {
            x_buffer: upload("scatter3-test-x", &xs),
            y_buffer: upload("scatter3-test-y", &ys),
            z_buffer: upload("scatter3-test-z", &zs),
            len: point_count,
            scalar: ScalarType::F32,
        };
        let params = Scatter3GpuParams {
            color: Vec4::new(0.2, 0.6, 0.9, 1.0),
            point_size: 6.0,
            sizes: ScatterAttributeBuffer::None,
            colors: ScatterColorBuffer::None,
            lod_stride: stride,
        };

        let packed =
            pack_vertices_from_xyz(&device, &queue, &inputs, &params).expect("gpu scatter3 pack");
        let expected = point_count.div_ceil(stride) as usize;
        assert_eq!(packed.vertex_count, expected);
    }
}