Skip to main content

runmat_plot/gpu/
scatter3.rs

1use crate::core::renderer::Vertex;
2use crate::core::scene::{DrawIndirectArgsRaw, GpuVertexBuffer};
3use crate::gpu::scatter2::{ScatterAttributeBuffer, ScatterColorBuffer};
4use crate::gpu::shaders;
5use crate::gpu::{tuning, ScalarType};
6use glam::Vec4;
7use std::sync::Arc;
8use wgpu::util::DeviceExt;
9
/// Inputs required to pack scatter3 vertices directly on the GPU.
///
/// The three coordinate buffers are bound as read-only storage buffers by
/// `pack_vertices_from_xyz`, so they must have been created with storage usage.
pub struct Scatter3GpuInputs {
    /// Buffer holding the X coordinates (bound at binding 0).
    pub x_buffer: Arc<wgpu::Buffer>,
    /// Buffer holding the Y coordinates (bound at binding 1).
    pub y_buffer: Arc<wgpu::Buffer>,
    /// Buffer holding the Z coordinates (bound at binding 2).
    pub z_buffer: Arc<wgpu::Buffer>,
    /// Number of points. Drives the compute dispatch size and is uploaded as the
    /// `count` uniform; a value of zero is rejected by the packer.
    pub len: u32,
    /// Element precision of the coordinate buffers; selects the F32 or F64
    /// shader template.
    pub scalar: ScalarType,
}
18
/// Parameters describing how the GPU vertices should be generated.
pub struct Scatter3GpuParams {
    /// Uniform RGBA color uploaded to the shader. It also fills the placeholder
    /// color buffer when no per-point colors are supplied.
    pub color: Vec4,
    /// Point size uploaded as the `point_size` uniform.
    pub point_size: f32,
    /// Optional per-point sizes: absent, host-side data that gets uploaded, or
    /// an already GPU-resident buffer.
    pub sizes: ScatterAttributeBuffer,
    /// Optional per-point colors: absent, host-side data that gets uploaded, or
    /// an already GPU-resident buffer with its own component count.
    pub colors: ScatterColorBuffer,
    /// Level-of-detail stride. The packer clamps it to at least 1 and sizes the
    /// output for `ceil(len / lod_stride)` points.
    /// NOTE(review): presumably the shader keeps every `lod_stride`-th point —
    /// confirm against the WGSL source.
    pub lod_stride: u32,
}
27
/// CPU-side mirror of the uniform block consumed by the scatter3 packing shader.
///
/// The struct is `repr(C)` and `Pod` so it can be uploaded verbatim with
/// `bytemuck::bytes_of`. Field order therefore defines the byte layout and must
/// not be changed independently of the shader.
///
/// NOTE(review): the fields add up to 44 bytes (16 + 7 * 4). WGSL rounds the
/// size of a uniform struct containing a `vec4<f32>` up to a multiple of 16
/// (48 bytes) — confirm that the shader-side declaration and the bound buffer
/// size agree, or widen the padding.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct Scatter3Uniforms {
    /// Uniform RGBA color.
    color: [f32; 4],
    /// Uniform point size.
    point_size: f32,
    /// Number of input points.
    count: u32,
    /// Level-of-detail stride (already clamped to at least 1 by the packer).
    lod_stride: u32,
    /// 1 when the size buffer carries real per-point sizes, 0 otherwise.
    has_sizes: u32,
    /// 1 when the color buffer carries real per-point colors, 0 otherwise.
    has_colors: u32,
    /// Number of components per entry in the color buffer.
    color_stride: u32,
    /// Explicit trailing padding; always written as 0.
    _pad: u32,
}
40
41/// Builds a GPU-resident vertex buffer for scatter3 plots directly from
42/// provider-owned XYZ arrays with either single- or double-precision inputs.
43pub fn pack_vertices_from_xyz(
44    device: &Arc<wgpu::Device>,
45    queue: &Arc<wgpu::Queue>,
46    inputs: &Scatter3GpuInputs,
47    params: &Scatter3GpuParams,
48) -> Result<GpuVertexBuffer, String> {
49    if inputs.len == 0 {
50        return Err("scatter3: empty input tensors".to_string());
51    }
52
53    let workgroup_size = tuning::effective_workgroup_size();
54    let shader = compile_shader(device, workgroup_size, inputs.scalar);
55
56    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
57        label: Some("scatter3-pack-bind-layout"),
58        entries: &[
59            wgpu::BindGroupLayoutEntry {
60                binding: 0,
61                visibility: wgpu::ShaderStages::COMPUTE,
62                ty: wgpu::BindingType::Buffer {
63                    ty: wgpu::BufferBindingType::Storage { read_only: true },
64                    has_dynamic_offset: false,
65                    min_binding_size: None,
66                },
67                count: None,
68            },
69            wgpu::BindGroupLayoutEntry {
70                binding: 1,
71                visibility: wgpu::ShaderStages::COMPUTE,
72                ty: wgpu::BindingType::Buffer {
73                    ty: wgpu::BufferBindingType::Storage { read_only: true },
74                    has_dynamic_offset: false,
75                    min_binding_size: None,
76                },
77                count: None,
78            },
79            wgpu::BindGroupLayoutEntry {
80                binding: 2,
81                visibility: wgpu::ShaderStages::COMPUTE,
82                ty: wgpu::BindingType::Buffer {
83                    ty: wgpu::BufferBindingType::Storage { read_only: true },
84                    has_dynamic_offset: false,
85                    min_binding_size: None,
86                },
87                count: None,
88            },
89            wgpu::BindGroupLayoutEntry {
90                binding: 3,
91                visibility: wgpu::ShaderStages::COMPUTE,
92                ty: wgpu::BindingType::Buffer {
93                    ty: wgpu::BufferBindingType::Storage { read_only: false },
94                    has_dynamic_offset: false,
95                    min_binding_size: None,
96                },
97                count: None,
98            },
99            wgpu::BindGroupLayoutEntry {
100                binding: 4,
101                visibility: wgpu::ShaderStages::COMPUTE,
102                ty: wgpu::BindingType::Buffer {
103                    ty: wgpu::BufferBindingType::Uniform,
104                    has_dynamic_offset: false,
105                    min_binding_size: None,
106                },
107                count: None,
108            },
109            wgpu::BindGroupLayoutEntry {
110                binding: 5,
111                visibility: wgpu::ShaderStages::COMPUTE,
112                ty: wgpu::BindingType::Buffer {
113                    ty: wgpu::BufferBindingType::Storage { read_only: true },
114                    has_dynamic_offset: false,
115                    min_binding_size: None,
116                },
117                count: None,
118            },
119            wgpu::BindGroupLayoutEntry {
120                binding: 6,
121                visibility: wgpu::ShaderStages::COMPUTE,
122                ty: wgpu::BindingType::Buffer {
123                    ty: wgpu::BufferBindingType::Storage { read_only: true },
124                    has_dynamic_offset: false,
125                    min_binding_size: None,
126                },
127                count: None,
128            },
129            wgpu::BindGroupLayoutEntry {
130                binding: 7,
131                visibility: wgpu::ShaderStages::COMPUTE,
132                ty: wgpu::BindingType::Buffer {
133                    ty: wgpu::BufferBindingType::Storage { read_only: false },
134                    has_dynamic_offset: false,
135                    min_binding_size: None,
136                },
137                count: None,
138            },
139        ],
140    });
141
142    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
143        label: Some("scatter3-pack-pipeline-layout"),
144        bind_group_layouts: &[&bind_group_layout],
145        push_constant_ranges: &[],
146    });
147
148    let pipeline =
149        device.create_compute_pipeline(&crate::wgpu_compat::wgpu_compute_pipeline_descriptor! {
150            label: Some("scatter3-pack-pipeline"),
151            layout: Some(&pipeline_layout),
152            module: &shader,
153            entry_point: "main",
154        });
155
156    let lod_stride = params.lod_stride.max(1);
157    let max_points = inputs.len.div_ceil(lod_stride);
158    let output_size = max_points as u64 * 6 * std::mem::size_of::<Vertex>() as u64;
159    let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
160        label: Some("scatter3-gpu-vertices"),
161        size: output_size,
162        usage: wgpu::BufferUsages::STORAGE
163            | wgpu::BufferUsages::VERTEX
164            | wgpu::BufferUsages::COPY_DST,
165        mapped_at_creation: false,
166    }));
167
168    let indirect_args = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
169        label: Some("scatter3-gpu-indirect-args"),
170        size: std::mem::size_of::<DrawIndirectArgsRaw>() as u64,
171        usage: wgpu::BufferUsages::STORAGE
172            | wgpu::BufferUsages::INDIRECT
173            | wgpu::BufferUsages::COPY_DST,
174        mapped_at_creation: false,
175    }));
176    let init = DrawIndirectArgsRaw {
177        vertex_count: 0,
178        instance_count: 1,
179        first_vertex: 0,
180        first_instance: 0,
181    };
182    queue.write_buffer(&indirect_args, 0, bytemuck::bytes_of(&init));
183
184    let (size_buffer, has_sizes) = prepare_size_buffer(device, params);
185    let (color_buffer, has_colors, color_stride) = prepare_color_buffer(device, params);
186
187    let uniforms = Scatter3Uniforms {
188        color: params.color.to_array(),
189        point_size: params.point_size,
190        count: inputs.len,
191        lod_stride,
192        has_sizes: if has_sizes { 1 } else { 0 },
193        has_colors: if has_colors { 1 } else { 0 },
194        color_stride,
195        _pad: 0,
196    };
197    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
198        label: Some("scatter3-pack-uniforms"),
199        contents: bytemuck::bytes_of(&uniforms),
200        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
201    });
202
203    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
204        label: Some("scatter3-pack-bind-group"),
205        layout: &bind_group_layout,
206        entries: &[
207            wgpu::BindGroupEntry {
208                binding: 0,
209                resource: inputs.x_buffer.as_entire_binding(),
210            },
211            wgpu::BindGroupEntry {
212                binding: 1,
213                resource: inputs.y_buffer.as_entire_binding(),
214            },
215            wgpu::BindGroupEntry {
216                binding: 2,
217                resource: inputs.z_buffer.as_entire_binding(),
218            },
219            wgpu::BindGroupEntry {
220                binding: 3,
221                resource: output_buffer.as_entire_binding(),
222            },
223            wgpu::BindGroupEntry {
224                binding: 4,
225                resource: uniform_buffer.as_entire_binding(),
226            },
227            wgpu::BindGroupEntry {
228                binding: 5,
229                resource: size_buffer.as_entire_binding(),
230            },
231            wgpu::BindGroupEntry {
232                binding: 6,
233                resource: color_buffer.as_entire_binding(),
234            },
235            wgpu::BindGroupEntry {
236                binding: 7,
237                resource: indirect_args.as_entire_binding(),
238            },
239        ],
240    });
241
242    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
243        label: Some("scatter3-pack-encoder"),
244    });
245    {
246        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
247            label: Some("scatter3-pack-pass"),
248            timestamp_writes: None,
249        });
250        pass.set_pipeline(&pipeline);
251        pass.set_bind_group(0, &bind_group, &[]);
252        let workgroups = inputs.len.div_ceil(workgroup_size);
253        pass.dispatch_workgroups(workgroups, 1, 1);
254    }
255    queue.submit(Some(encoder.finish()));
256
257    Ok(GpuVertexBuffer::with_indirect(
258        output_buffer,
259        (max_points as usize) * 6,
260        indirect_args,
261    ))
262}
263
264fn prepare_size_buffer(
265    device: &Arc<wgpu::Device>,
266    params: &Scatter3GpuParams,
267) -> (Arc<wgpu::Buffer>, bool) {
268    match &params.sizes {
269        ScatterAttributeBuffer::None => (
270            Arc::new(
271                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
272                    label: Some("scatter3-size-fallback"),
273                    contents: bytemuck::cast_slice(&[0.0f32]),
274                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
275                }),
276            ),
277            false,
278        ),
279        ScatterAttributeBuffer::Host(data) => {
280            if data.is_empty() {
281                (
282                    Arc::new(
283                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
284                            label: Some("scatter3-size-fallback"),
285                            contents: bytemuck::cast_slice(&[0.0f32]),
286                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
287                        }),
288                    ),
289                    false,
290                )
291            } else {
292                (
293                    Arc::new(
294                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
295                            label: Some("scatter3-size-host"),
296                            contents: bytemuck::cast_slice(data.as_slice()),
297                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
298                        }),
299                    ),
300                    true,
301                )
302            }
303        }
304        ScatterAttributeBuffer::Gpu(buffer) => (buffer.clone(), true),
305    }
306}
307
308fn prepare_color_buffer(
309    device: &Arc<wgpu::Device>,
310    params: &Scatter3GpuParams,
311) -> (Arc<wgpu::Buffer>, bool, u32) {
312    match &params.colors {
313        ScatterColorBuffer::None => (
314            Arc::new(
315                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
316                    label: Some("scatter3-color-fallback"),
317                    contents: bytemuck::cast_slice(&[
318                        params.color.x,
319                        params.color.y,
320                        params.color.z,
321                        params.color.w,
322                    ]),
323                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
324                }),
325            ),
326            false,
327            4,
328        ),
329        ScatterColorBuffer::Host(colors) => {
330            if colors.is_empty() {
331                (
332                    Arc::new(
333                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
334                            label: Some("scatter3-color-fallback"),
335                            contents: bytemuck::cast_slice(&[
336                                params.color.x,
337                                params.color.y,
338                                params.color.z,
339                                params.color.w,
340                            ]),
341                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
342                        }),
343                    ),
344                    false,
345                    4,
346                )
347            } else {
348                (
349                    Arc::new(
350                        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
351                            label: Some("scatter3-color-host"),
352                            contents: bytemuck::cast_slice(colors.as_slice()),
353                            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
354                        }),
355                    ),
356                    true,
357                    4,
358                )
359            }
360        }
361        ScatterColorBuffer::Gpu { buffer, components } => (buffer.clone(), true, *components),
362    }
363}
364
365fn compile_shader(
366    device: &Arc<wgpu::Device>,
367    workgroup_size: u32,
368    scalar: ScalarType,
369) -> wgpu::ShaderModule {
370    let template = match scalar {
371        ScalarType::F32 => shaders::scatter3::F32,
372        ScalarType::F64 => shaders::scatter3::F64,
373    };
374    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
375    device.create_shader_module(wgpu::ShaderModuleDescriptor {
376        label: Some("scatter3-pack-shader"),
377        source: wgpu::ShaderSource::Wgsl(source.into()),
378    })
379}
380
#[cfg(test)]
mod stress_tests {
    use super::*;
    use pollster::FutureExt;

    /// Acquires a device/queue pair for GPU-backed tests.
    ///
    /// Returns `None` unless `RUNMAT_PLOT_FORCE_GPU_TESTS` is set and
    /// `RUNMAT_PLOT_SKIP_GPU_TESTS` is not, or when no adapter or device can be
    /// obtained.
    fn maybe_device() -> Option<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        let opted_out = std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok();
        let opted_in = std::env::var("RUNMAT_PLOT_FORCE_GPU_TESTS").is_ok();
        if opted_out || !opted_in {
            return None;
        }

        let instance = wgpu::Instance::default();
        let adapter_options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        };
        let adapter = instance.request_adapter(&adapter_options).block_on()?;

        let descriptor = crate::wgpu_compat::device_descriptor(
            Some("scatter3-test-device"),
            wgpu::Features::empty(),
            adapter.limits(),
        );
        adapter
            .request_device(&descriptor, None)
            .block_on()
            .ok()
            .map(|(device, queue)| (Arc::new(device), Arc::new(queue)))
    }

    /// Uploads `data` into a storage-only buffer labelled `label`.
    fn storage_buffer(device: &wgpu::Device, label: &str, data: &[f32]) -> Arc<wgpu::Buffer> {
        Arc::new(
            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some(label),
                contents: bytemuck::cast_slice(data),
                usage: wgpu::BufferUsages::STORAGE,
            }),
        )
    }

    /// Packing 1.2M points with a stride of 4 must report six vertices for each
    /// of the `ceil(len / stride)` retained points.
    #[test]
    fn lod_stride_limits_vertex_count() {
        let (device, queue) = match maybe_device() {
            Some(pair) => pair,
            None => return,
        };

        let point_count = 1_200_000u32;
        let stride = 4u32;
        let expected_vertices = point_count.div_ceil(stride) as usize * 6;

        let xs: Vec<f32> = (0..point_count).map(|i| i as f32 * 0.001).collect();
        let ys: Vec<f32> = xs.iter().copied().map(f32::cos).collect();
        let zs: Vec<f32> = xs.iter().copied().map(f32::sin).collect();

        let inputs = Scatter3GpuInputs {
            x_buffer: storage_buffer(&device, "scatter3-test-x", &xs),
            y_buffer: storage_buffer(&device, "scatter3-test-y", &ys),
            z_buffer: storage_buffer(&device, "scatter3-test-z", &zs),
            len: point_count,
            scalar: ScalarType::F32,
        };
        let params = Scatter3GpuParams {
            color: Vec4::new(0.2, 0.6, 0.9, 1.0),
            point_size: 6.0,
            sizes: ScatterAttributeBuffer::None,
            colors: ScatterColorBuffer::None,
            lod_stride: stride,
        };

        let packed =
            pack_vertices_from_xyz(&device, &queue, &inputs, &params).expect("gpu scatter3 pack");
        assert_eq!(packed.vertex_count, expected_vertices);
    }
}