1use crate::core::renderer::Vertex;
2use crate::core::scene::{DrawIndirectArgsRaw, GpuVertexBuffer};
3use crate::gpu::scatter2::{ScatterAttributeBuffer, ScatterColorBuffer};
4use crate::gpu::shaders;
5use crate::gpu::{tuning, ScalarType};
6use glam::Vec4;
7use std::sync::Arc;
8use wgpu::util::DeviceExt;
9
10#[derive(Clone, Debug)]
12pub struct Scatter3GpuInputs {
13 pub x_buffer: Arc<wgpu::Buffer>,
14 pub y_buffer: Arc<wgpu::Buffer>,
15 pub z_buffer: Arc<wgpu::Buffer>,
16 pub len: u32,
17 pub scalar: ScalarType,
18 pub colors: ScatterColorBuffer,
19}
20
21pub struct Scatter3GpuParams {
23 pub color: Vec4,
24 pub point_size: f32,
25 pub sizes: ScatterAttributeBuffer,
26 pub colors: ScatterColorBuffer,
27 pub lod_stride: u32,
28}
29
/// Uniform block for the scatter3 pack compute shader (bound at binding 4).
///
/// `#[repr(C)]` plus `Pod` lets it be uploaded directly with `bytemuck::bytes_of`; the field
/// order must match the WGSL uniform struct exactly. Boolean flags are encoded as `u32` 0/1.
///
/// NOTE(review): this struct is 44 bytes (16 for `color` + 7 x 4). If the WGSL struct declares
/// `color` as `vec4<f32>`, WGSL rounds the struct size up to 48 bytes, so the uploaded buffer
/// must be padded to that size — confirm against the shader.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct Scatter3Uniforms {
    // Base RGBA color.
    color: [f32; 4],
    point_size: f32,
    // Number of input points.
    count: u32,
    // Already clamped to >= 1 by the caller.
    lod_stride: u32,
    // 1 when per-point sizes are bound at binding 5, else 0.
    has_sizes: u32,
    // 1 when per-point colors are bound at binding 6, else 0.
    has_colors: u32,
    // Components per color entry (4 for host/fallback colors, caller-provided for GPU colors).
    color_stride: u32,
    // Padding; always written as 0.
    _pad: u32,
}
42
43pub fn pack_vertices_from_xyz(
46 device: &Arc<wgpu::Device>,
47 queue: &Arc<wgpu::Queue>,
48 inputs: &Scatter3GpuInputs,
49 params: &Scatter3GpuParams,
50) -> Result<GpuVertexBuffer, String> {
51 if inputs.len == 0 {
52 return Err("scatter3: empty input tensors".to_string());
53 }
54
55 let workgroup_size = tuning::effective_workgroup_size();
56 let shader = compile_shader(device, workgroup_size, inputs.scalar);
57
58 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
59 label: Some("scatter3-pack-bind-layout"),
60 entries: &[
61 wgpu::BindGroupLayoutEntry {
62 binding: 0,
63 visibility: wgpu::ShaderStages::COMPUTE,
64 ty: wgpu::BindingType::Buffer {
65 ty: wgpu::BufferBindingType::Storage { read_only: true },
66 has_dynamic_offset: false,
67 min_binding_size: None,
68 },
69 count: None,
70 },
71 wgpu::BindGroupLayoutEntry {
72 binding: 1,
73 visibility: wgpu::ShaderStages::COMPUTE,
74 ty: wgpu::BindingType::Buffer {
75 ty: wgpu::BufferBindingType::Storage { read_only: true },
76 has_dynamic_offset: false,
77 min_binding_size: None,
78 },
79 count: None,
80 },
81 wgpu::BindGroupLayoutEntry {
82 binding: 2,
83 visibility: wgpu::ShaderStages::COMPUTE,
84 ty: wgpu::BindingType::Buffer {
85 ty: wgpu::BufferBindingType::Storage { read_only: true },
86 has_dynamic_offset: false,
87 min_binding_size: None,
88 },
89 count: None,
90 },
91 wgpu::BindGroupLayoutEntry {
92 binding: 3,
93 visibility: wgpu::ShaderStages::COMPUTE,
94 ty: wgpu::BindingType::Buffer {
95 ty: wgpu::BufferBindingType::Storage { read_only: false },
96 has_dynamic_offset: false,
97 min_binding_size: None,
98 },
99 count: None,
100 },
101 wgpu::BindGroupLayoutEntry {
102 binding: 4,
103 visibility: wgpu::ShaderStages::COMPUTE,
104 ty: wgpu::BindingType::Buffer {
105 ty: wgpu::BufferBindingType::Uniform,
106 has_dynamic_offset: false,
107 min_binding_size: None,
108 },
109 count: None,
110 },
111 wgpu::BindGroupLayoutEntry {
112 binding: 5,
113 visibility: wgpu::ShaderStages::COMPUTE,
114 ty: wgpu::BindingType::Buffer {
115 ty: wgpu::BufferBindingType::Storage { read_only: true },
116 has_dynamic_offset: false,
117 min_binding_size: None,
118 },
119 count: None,
120 },
121 wgpu::BindGroupLayoutEntry {
122 binding: 6,
123 visibility: wgpu::ShaderStages::COMPUTE,
124 ty: wgpu::BindingType::Buffer {
125 ty: wgpu::BufferBindingType::Storage { read_only: true },
126 has_dynamic_offset: false,
127 min_binding_size: None,
128 },
129 count: None,
130 },
131 wgpu::BindGroupLayoutEntry {
132 binding: 7,
133 visibility: wgpu::ShaderStages::COMPUTE,
134 ty: wgpu::BindingType::Buffer {
135 ty: wgpu::BufferBindingType::Storage { read_only: false },
136 has_dynamic_offset: false,
137 min_binding_size: None,
138 },
139 count: None,
140 },
141 ],
142 });
143
144 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
145 label: Some("scatter3-pack-pipeline-layout"),
146 bind_group_layouts: &[&bind_group_layout],
147 push_constant_ranges: &[],
148 });
149
150 let pipeline =
151 device.create_compute_pipeline(&crate::wgpu_compat::wgpu_compute_pipeline_descriptor! {
152 label: Some("scatter3-pack-pipeline"),
153 layout: Some(&pipeline_layout),
154 module: &shader,
155 entry_point: "main",
156 });
157
158 let lod_stride = params.lod_stride.max(1);
159 let max_points = inputs.len.div_ceil(lod_stride);
160 let output_size = max_points as u64 * 6 * std::mem::size_of::<Vertex>() as u64;
161 let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
162 label: Some("scatter3-gpu-vertices"),
163 size: output_size,
164 usage: wgpu::BufferUsages::STORAGE
165 | wgpu::BufferUsages::VERTEX
166 | wgpu::BufferUsages::COPY_DST
167 | wgpu::BufferUsages::COPY_SRC,
168 mapped_at_creation: false,
169 }));
170
171 let indirect_args = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
172 label: Some("scatter3-gpu-indirect-args"),
173 size: std::mem::size_of::<DrawIndirectArgsRaw>() as u64,
174 usage: wgpu::BufferUsages::STORAGE
175 | wgpu::BufferUsages::INDIRECT
176 | wgpu::BufferUsages::COPY_DST
177 | wgpu::BufferUsages::COPY_SRC,
178 mapped_at_creation: false,
179 }));
180 let init = DrawIndirectArgsRaw {
181 vertex_count: 0,
182 instance_count: 1,
183 first_vertex: 0,
184 first_instance: 0,
185 };
186 queue.write_buffer(&indirect_args, 0, bytemuck::bytes_of(&init));
187
188 let (size_buffer, has_sizes) = prepare_size_buffer(device, params);
189 let (color_buffer, has_colors, color_stride) = prepare_color_buffer(device, params);
190
191 let uniforms = Scatter3Uniforms {
192 color: params.color.to_array(),
193 point_size: params.point_size,
194 count: inputs.len,
195 lod_stride,
196 has_sizes: if has_sizes { 1 } else { 0 },
197 has_colors: if has_colors { 1 } else { 0 },
198 color_stride,
199 _pad: 0,
200 };
201 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
202 label: Some("scatter3-pack-uniforms"),
203 contents: bytemuck::bytes_of(&uniforms),
204 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
205 });
206
207 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
208 label: Some("scatter3-pack-bind-group"),
209 layout: &bind_group_layout,
210 entries: &[
211 wgpu::BindGroupEntry {
212 binding: 0,
213 resource: inputs.x_buffer.as_entire_binding(),
214 },
215 wgpu::BindGroupEntry {
216 binding: 1,
217 resource: inputs.y_buffer.as_entire_binding(),
218 },
219 wgpu::BindGroupEntry {
220 binding: 2,
221 resource: inputs.z_buffer.as_entire_binding(),
222 },
223 wgpu::BindGroupEntry {
224 binding: 3,
225 resource: output_buffer.as_entire_binding(),
226 },
227 wgpu::BindGroupEntry {
228 binding: 4,
229 resource: uniform_buffer.as_entire_binding(),
230 },
231 wgpu::BindGroupEntry {
232 binding: 5,
233 resource: size_buffer.as_entire_binding(),
234 },
235 wgpu::BindGroupEntry {
236 binding: 6,
237 resource: color_buffer.as_entire_binding(),
238 },
239 wgpu::BindGroupEntry {
240 binding: 7,
241 resource: indirect_args.as_entire_binding(),
242 },
243 ],
244 });
245
246 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
247 label: Some("scatter3-pack-encoder"),
248 });
249 {
250 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
251 label: Some("scatter3-pack-pass"),
252 timestamp_writes: None,
253 });
254 pass.set_pipeline(&pipeline);
255 pass.set_bind_group(0, &bind_group, &[]);
256 let workgroups = inputs.len.div_ceil(workgroup_size);
257 pass.dispatch_workgroups(workgroups, 1, 1);
258 }
259 queue.submit(Some(encoder.finish()));
260
261 Ok(GpuVertexBuffer::with_indirect(
262 output_buffer,
263 (max_points as usize) * 6,
264 indirect_args,
265 ))
266}
267
268fn prepare_size_buffer(
269 device: &Arc<wgpu::Device>,
270 params: &Scatter3GpuParams,
271) -> (Arc<wgpu::Buffer>, bool) {
272 match ¶ms.sizes {
273 ScatterAttributeBuffer::None => (
274 Arc::new(
275 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
276 label: Some("scatter3-size-fallback"),
277 contents: bytemuck::cast_slice(&[0.0f32]),
278 usage: wgpu::BufferUsages::STORAGE
279 | wgpu::BufferUsages::COPY_DST
280 | wgpu::BufferUsages::COPY_SRC,
281 }),
282 ),
283 false,
284 ),
285 ScatterAttributeBuffer::Host(data) => {
286 if data.is_empty() {
287 (
288 Arc::new(
289 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
290 label: Some("scatter3-size-fallback"),
291 contents: bytemuck::cast_slice(&[0.0f32]),
292 usage: wgpu::BufferUsages::STORAGE
293 | wgpu::BufferUsages::COPY_DST
294 | wgpu::BufferUsages::COPY_SRC,
295 }),
296 ),
297 false,
298 )
299 } else {
300 (
301 Arc::new(
302 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
303 label: Some("scatter3-size-host"),
304 contents: bytemuck::cast_slice(data.as_slice()),
305 usage: wgpu::BufferUsages::STORAGE
306 | wgpu::BufferUsages::COPY_DST
307 | wgpu::BufferUsages::COPY_SRC,
308 }),
309 ),
310 true,
311 )
312 }
313 }
314 ScatterAttributeBuffer::Gpu(buffer) => (buffer.clone(), true),
315 }
316}
317
318fn prepare_color_buffer(
319 device: &Arc<wgpu::Device>,
320 params: &Scatter3GpuParams,
321) -> (Arc<wgpu::Buffer>, bool, u32) {
322 match ¶ms.colors {
323 ScatterColorBuffer::None => (
324 Arc::new(
325 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
326 label: Some("scatter3-color-fallback"),
327 contents: bytemuck::cast_slice(&[
328 params.color.x,
329 params.color.y,
330 params.color.z,
331 params.color.w,
332 ]),
333 usage: wgpu::BufferUsages::STORAGE
334 | wgpu::BufferUsages::COPY_DST
335 | wgpu::BufferUsages::COPY_SRC,
336 }),
337 ),
338 false,
339 4,
340 ),
341 ScatterColorBuffer::Host(colors) => {
342 if colors.is_empty() {
343 (
344 Arc::new(
345 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
346 label: Some("scatter3-color-fallback"),
347 contents: bytemuck::cast_slice(&[
348 params.color.x,
349 params.color.y,
350 params.color.z,
351 params.color.w,
352 ]),
353 usage: wgpu::BufferUsages::STORAGE
354 | wgpu::BufferUsages::COPY_DST
355 | wgpu::BufferUsages::COPY_SRC,
356 }),
357 ),
358 false,
359 4,
360 )
361 } else {
362 (
363 Arc::new(
364 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
365 label: Some("scatter3-color-host"),
366 contents: bytemuck::cast_slice(colors.as_slice()),
367 usage: wgpu::BufferUsages::STORAGE
368 | wgpu::BufferUsages::COPY_DST
369 | wgpu::BufferUsages::COPY_SRC,
370 }),
371 ),
372 true,
373 4,
374 )
375 }
376 }
377 ScatterColorBuffer::Gpu { buffer, components } => (buffer.clone(), true, *components),
378 }
379}
380
381fn compile_shader(
382 device: &Arc<wgpu::Device>,
383 workgroup_size: u32,
384 scalar: ScalarType,
385) -> wgpu::ShaderModule {
386 let template = match scalar {
387 ScalarType::F32 => shaders::scatter3::F32,
388 ScalarType::F64 => shaders::scatter3::F64,
389 };
390 let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
391 device.create_shader_module(wgpu::ShaderModuleDescriptor {
392 label: Some("scatter3-pack-shader"),
393 source: wgpu::ShaderSource::Wgsl(source.into()),
394 })
395}
396
#[cfg(test)]
mod stress_tests {
    use super::*;
    use pollster::FutureExt;

    /// Returns a real device/queue pair only when GPU tests are explicitly forced on
    /// (`RUNMAT_PLOT_FORCE_GPU_TESTS` set) and not disabled (`RUNMAT_PLOT_SKIP_GPU_TESTS` unset).
    /// Also returns `None` when no adapter or device can be obtained.
    fn maybe_device() -> Option<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        let skip_requested = std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok();
        let force_requested = std::env::var("RUNMAT_PLOT_FORCE_GPU_TESTS").is_ok();
        if skip_requested || !force_requested {
            return None;
        }

        let instance = wgpu::Instance::default();
        let adapter_options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        };
        let adapter = instance.request_adapter(&adapter_options).block_on()?;

        let descriptor = crate::wgpu_compat::device_descriptor(
            Some("scatter3-test-device"),
            wgpu::Features::empty(),
            adapter.limits(),
        );
        let (device, queue) = adapter.request_device(&descriptor, None).block_on().ok()?;
        Some((Arc::new(device), Arc::new(queue)))
    }

    #[test]
    fn lod_stride_limits_vertex_count() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };
        let total_points: u32 = 1_200_000;
        let lod_stride: u32 = 4;
        let expected_vertices = total_points.div_ceil(lod_stride) as usize * 6;

        // A helix: x grows linearly, (y, z) trace a circle.
        let mut xs = Vec::with_capacity(total_points as usize);
        let mut ys = Vec::with_capacity(total_points as usize);
        let mut zs = Vec::with_capacity(total_points as usize);
        for i in 0..total_points {
            let t = i as f32 * 0.001;
            xs.push(t);
            ys.push(t.cos());
            zs.push(t.sin());
        }

        let upload = |label: &str, values: &[f32]| -> Arc<wgpu::Buffer> {
            Arc::new(device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some(label),
                contents: bytemuck::cast_slice(values),
                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            }))
        };

        let inputs = Scatter3GpuInputs {
            x_buffer: upload("scatter3-test-x", &xs),
            y_buffer: upload("scatter3-test-y", &ys),
            z_buffer: upload("scatter3-test-z", &zs),
            len: total_points,
            scalar: ScalarType::F32,
            colors: ScatterColorBuffer::None,
        };
        let params = Scatter3GpuParams {
            color: Vec4::new(0.2, 0.6, 0.9, 1.0),
            point_size: 6.0,
            sizes: ScatterAttributeBuffer::None,
            colors: ScatterColorBuffer::None,
            lod_stride,
        };

        let packed =
            pack_vertices_from_xyz(&device, &queue, &inputs, &params).expect("gpu scatter3 pack");
        assert_eq!(packed.vertex_count, expected_vertices);
    }
}
473}