1use crate::core::renderer::Vertex;
2use crate::core::scene::{DrawIndirectArgsRaw, GpuVertexBuffer};
3use crate::gpu::shaders;
4use crate::gpu::{tuning, ScalarType};
5use glam::Vec4;
6use std::sync::Arc;
7use wgpu::util::DeviceExt;
8
9pub enum ScatterAttributeBuffer {
11 None,
12 Host(Vec<f32>),
13 Gpu(Arc<wgpu::Buffer>),
14}
15
16impl ScatterAttributeBuffer {
17 pub fn has_data(&self) -> bool {
18 !matches!(self, ScatterAttributeBuffer::None)
19 }
20}
21
22pub enum ScatterColorBuffer {
25 None,
26 Host(Vec<[f32; 4]>),
27 Gpu {
28 buffer: Arc<wgpu::Buffer>,
29 components: u32,
30 },
31}
32
33impl ScatterColorBuffer {
34 pub fn has_data(&self) -> bool {
35 !matches!(self, ScatterColorBuffer::None)
36 }
37
38 pub fn stride(&self) -> u32 {
39 match self {
40 ScatterColorBuffer::None => 4,
41 ScatterColorBuffer::Host(_) => 4,
42 ScatterColorBuffer::Gpu { components, .. } => *components,
43 }
44 }
45
46 pub fn buffer(&self) -> Option<Arc<wgpu::Buffer>> {
47 match self {
48 ScatterColorBuffer::Gpu { buffer, .. } => Some(buffer.clone()),
49 _ => None,
50 }
51 }
52}
53
54pub struct Scatter2GpuInputs {
56 pub x_buffer: Arc<wgpu::Buffer>,
57 pub y_buffer: Arc<wgpu::Buffer>,
58 pub len: u32,
59 pub scalar: ScalarType,
60}
61
62pub struct Scatter2GpuParams {
64 pub color: Vec4,
65 pub point_size: f32,
66 pub sizes: ScatterAttributeBuffer,
67 pub colors: ScatterColorBuffer,
68 pub lod_stride: u32,
69}
70
71#[repr(C)]
72#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
73struct Scatter2Uniforms {
74 color: [f32; 4],
75 point_size: f32,
76 count: u32,
77 lod_stride: u32,
78 has_sizes: u32,
79 has_colors: u32,
80 color_stride: u32,
81}
82
83pub fn pack_vertices_from_xy(
87 device: &Arc<wgpu::Device>,
88 queue: &Arc<wgpu::Queue>,
89 inputs: &Scatter2GpuInputs,
90 params: &Scatter2GpuParams,
91) -> Result<GpuVertexBuffer, String> {
92 if inputs.len == 0 {
93 return Err("scatter: empty input tensors".to_string());
94 }
95
96 let lod_stride = params.lod_stride.max(1);
97 let max_points = inputs.len.div_ceil(lod_stride);
98
99 let workgroup_size = tuning::effective_workgroup_size();
100 let shader = compile_shader(device, workgroup_size, inputs.scalar);
101
102 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
103 label: Some("scatter2-pack-bind-layout"),
104 entries: &[
105 wgpu::BindGroupLayoutEntry {
106 binding: 0,
107 visibility: wgpu::ShaderStages::COMPUTE,
108 ty: wgpu::BindingType::Buffer {
109 ty: wgpu::BufferBindingType::Storage { read_only: true },
110 has_dynamic_offset: false,
111 min_binding_size: None,
112 },
113 count: None,
114 },
115 wgpu::BindGroupLayoutEntry {
116 binding: 1,
117 visibility: wgpu::ShaderStages::COMPUTE,
118 ty: wgpu::BindingType::Buffer {
119 ty: wgpu::BufferBindingType::Storage { read_only: true },
120 has_dynamic_offset: false,
121 min_binding_size: None,
122 },
123 count: None,
124 },
125 wgpu::BindGroupLayoutEntry {
126 binding: 2,
127 visibility: wgpu::ShaderStages::COMPUTE,
128 ty: wgpu::BindingType::Buffer {
129 ty: wgpu::BufferBindingType::Storage { read_only: false },
130 has_dynamic_offset: false,
131 min_binding_size: None,
132 },
133 count: None,
134 },
135 wgpu::BindGroupLayoutEntry {
136 binding: 3,
137 visibility: wgpu::ShaderStages::COMPUTE,
138 ty: wgpu::BindingType::Buffer {
139 ty: wgpu::BufferBindingType::Uniform,
140 has_dynamic_offset: false,
141 min_binding_size: None,
142 },
143 count: None,
144 },
145 wgpu::BindGroupLayoutEntry {
146 binding: 4,
147 visibility: wgpu::ShaderStages::COMPUTE,
148 ty: wgpu::BindingType::Buffer {
149 ty: wgpu::BufferBindingType::Storage { read_only: true },
150 has_dynamic_offset: false,
151 min_binding_size: None,
152 },
153 count: None,
154 },
155 wgpu::BindGroupLayoutEntry {
156 binding: 5,
157 visibility: wgpu::ShaderStages::COMPUTE,
158 ty: wgpu::BindingType::Buffer {
159 ty: wgpu::BufferBindingType::Storage { read_only: true },
160 has_dynamic_offset: false,
161 min_binding_size: None,
162 },
163 count: None,
164 },
165 wgpu::BindGroupLayoutEntry {
166 binding: 6,
167 visibility: wgpu::ShaderStages::COMPUTE,
168 ty: wgpu::BindingType::Buffer {
169 ty: wgpu::BufferBindingType::Storage { read_only: false },
170 has_dynamic_offset: false,
171 min_binding_size: None,
172 },
173 count: None,
174 },
175 ],
176 });
177
178 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
179 label: Some("scatter2-pack-pipeline-layout"),
180 bind_group_layouts: &[&bind_group_layout],
181 push_constant_ranges: &[],
182 });
183
184 let pipeline =
185 device.create_compute_pipeline(&crate::wgpu_compat::wgpu_compute_pipeline_descriptor! {
186 label: Some("scatter2-pack-pipeline"),
187 layout: Some(&pipeline_layout),
188 module: &shader,
189 entry_point: "main",
190 });
191
192 let output_size = max_points as u64 * 6 * std::mem::size_of::<Vertex>() as u64;
193 let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
194 label: Some("scatter2-gpu-vertices"),
195 size: output_size,
196 usage: wgpu::BufferUsages::STORAGE
197 | wgpu::BufferUsages::VERTEX
198 | wgpu::BufferUsages::COPY_DST,
199 mapped_at_creation: false,
200 }));
201
202 let indirect_args = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
203 label: Some("scatter2-gpu-indirect-args"),
204 size: std::mem::size_of::<DrawIndirectArgsRaw>() as u64,
205 usage: wgpu::BufferUsages::STORAGE
206 | wgpu::BufferUsages::INDIRECT
207 | wgpu::BufferUsages::COPY_DST,
208 mapped_at_creation: false,
209 }));
210 let init = DrawIndirectArgsRaw {
211 vertex_count: 0,
212 instance_count: 1,
213 first_vertex: 0,
214 first_instance: 0,
215 };
216 queue.write_buffer(&indirect_args, 0, bytemuck::bytes_of(&init));
217
218 let (size_buffer, has_sizes) = prepare_size_buffer(device, params);
219 let (color_buffer, has_colors, color_stride) = prepare_color_buffer(device, params);
220
221 let uniforms = Scatter2Uniforms {
222 color: params.color.to_array(),
223 point_size: params.point_size,
224 count: inputs.len,
225 lod_stride,
226 has_sizes: if has_sizes { 1 } else { 0 },
227 has_colors: if has_colors { 1 } else { 0 },
228 color_stride,
229 };
230 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
231 label: Some("scatter2-pack-uniforms"),
232 contents: bytemuck::bytes_of(&uniforms),
233 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
234 });
235
236 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
237 label: Some("scatter2-pack-bind-group"),
238 layout: &bind_group_layout,
239 entries: &[
240 wgpu::BindGroupEntry {
241 binding: 0,
242 resource: inputs.x_buffer.as_entire_binding(),
243 },
244 wgpu::BindGroupEntry {
245 binding: 1,
246 resource: inputs.y_buffer.as_entire_binding(),
247 },
248 wgpu::BindGroupEntry {
249 binding: 2,
250 resource: output_buffer.as_entire_binding(),
251 },
252 wgpu::BindGroupEntry {
253 binding: 3,
254 resource: uniform_buffer.as_entire_binding(),
255 },
256 wgpu::BindGroupEntry {
257 binding: 4,
258 resource: size_buffer.as_entire_binding(),
259 },
260 wgpu::BindGroupEntry {
261 binding: 5,
262 resource: color_buffer.as_entire_binding(),
263 },
264 wgpu::BindGroupEntry {
265 binding: 6,
266 resource: indirect_args.as_entire_binding(),
267 },
268 ],
269 });
270
271 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
272 label: Some("scatter2-pack-encoder"),
273 });
274 {
275 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
276 label: Some("scatter2-pack-pass"),
277 timestamp_writes: None,
278 });
279 pass.set_pipeline(&pipeline);
280 pass.set_bind_group(0, &bind_group, &[]);
281 let workgroups = inputs.len.div_ceil(workgroup_size);
282 pass.dispatch_workgroups(workgroups, 1, 1);
283 }
284 queue.submit(Some(encoder.finish()));
285
286 Ok(GpuVertexBuffer::with_indirect(
287 output_buffer,
288 (max_points as usize) * 6,
289 indirect_args,
290 ))
291}
292
293fn prepare_size_buffer(
294 device: &Arc<wgpu::Device>,
295 params: &Scatter2GpuParams,
296) -> (Arc<wgpu::Buffer>, bool) {
297 match ¶ms.sizes {
298 ScatterAttributeBuffer::None => (
299 Arc::new(
300 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
301 label: Some("scatter2-size-fallback"),
302 contents: bytemuck::cast_slice(&[0.0f32]),
303 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
304 }),
305 ),
306 false,
307 ),
308 ScatterAttributeBuffer::Host(data) => {
309 if data.is_empty() {
310 (
311 Arc::new(
312 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
313 label: Some("scatter2-size-fallback"),
314 contents: bytemuck::cast_slice(&[0.0f32]),
315 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
316 }),
317 ),
318 false,
319 )
320 } else {
321 (
322 Arc::new(
323 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
324 label: Some("scatter2-size-host"),
325 contents: bytemuck::cast_slice(data.as_slice()),
326 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
327 }),
328 ),
329 true,
330 )
331 }
332 }
333 ScatterAttributeBuffer::Gpu(buffer) => (buffer.clone(), true),
334 }
335}
336
337fn prepare_color_buffer(
338 device: &Arc<wgpu::Device>,
339 params: &Scatter2GpuParams,
340) -> (Arc<wgpu::Buffer>, bool, u32) {
341 match ¶ms.colors {
342 ScatterColorBuffer::None => (
343 Arc::new(
344 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
345 label: Some("scatter2-color-fallback"),
346 contents: bytemuck::cast_slice(&[
347 params.color.x,
348 params.color.y,
349 params.color.z,
350 params.color.w,
351 ]),
352 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
353 }),
354 ),
355 false,
356 4,
357 ),
358 ScatterColorBuffer::Host(colors) => {
359 if colors.is_empty() {
360 (
361 Arc::new(
362 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
363 label: Some("scatter2-color-fallback"),
364 contents: bytemuck::cast_slice(&[
365 params.color.x,
366 params.color.y,
367 params.color.z,
368 params.color.w,
369 ]),
370 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
371 }),
372 ),
373 false,
374 4,
375 )
376 } else {
377 (
378 Arc::new(
379 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
380 label: Some("scatter2-color-host"),
381 contents: bytemuck::cast_slice(colors.as_slice()),
382 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
383 }),
384 ),
385 true,
386 4,
387 )
388 }
389 }
390 ScatterColorBuffer::Gpu { buffer, components } => (buffer.clone(), true, *components),
391 }
392}
393
394fn compile_shader(
395 device: &Arc<wgpu::Device>,
396 workgroup_size: u32,
397 scalar: ScalarType,
398) -> wgpu::ShaderModule {
399 let template = match scalar {
400 ScalarType::F32 => shaders::scatter2::F32,
401 ScalarType::F64 => shaders::scatter2::F64,
402 };
403 let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
404 device.create_shader_module(wgpu::ShaderModuleDescriptor {
405 label: Some("scatter2-pack-shader"),
406 source: wgpu::ShaderSource::Wgsl(source.into()),
407 })
408}
409
#[cfg(test)]
mod stress_tests {
    use super::*;
    use pollster::FutureExt;

    /// Acquires a device/queue pair for GPU tests.
    ///
    /// Returns `None` (so the test is silently skipped) unless `RUNMAT_PLOT_FORCE_GPU_TESTS` is
    /// set and `RUNMAT_PLOT_SKIP_GPU_TESTS` is not, or when no adapter/device is available.
    fn maybe_device() -> Option<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        let skip_requested = std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok();
        let force_requested = std::env::var("RUNMAT_PLOT_FORCE_GPU_TESTS").is_ok();
        if skip_requested || !force_requested {
            return None;
        }

        let adapter_options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        };
        let adapter = wgpu::Instance::default()
            .request_adapter(&adapter_options)
            .block_on()?;

        // Request everything the adapter offers so the large buffers below fit.
        let descriptor = crate::wgpu_compat::device_descriptor(
            Some("runmat-plot-scatter-test-device"),
            wgpu::Features::empty(),
            adapter.limits(),
        );
        let (device, queue) = adapter.request_device(&descriptor, None).block_on().ok()?;
        Some((Arc::new(device), Arc::new(queue)))
    }

    /// Uploads `values` into a read-only storage buffer with the given label.
    fn storage_buffer(
        device: &wgpu::Device,
        label: &'static str,
        values: &[f32],
    ) -> Arc<wgpu::Buffer> {
        Arc::new(
            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some(label),
                contents: bytemuck::cast_slice(values),
                usage: wgpu::BufferUsages::STORAGE,
            }),
        )
    }

    /// Packs 1.2M points with a LOD stride and checks the reported vertex capacity.
    #[test]
    fn gpu_packer_handles_large_point_cloud() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };

        let point_count = 1_200_000u32;
        let xs: Vec<f32> = (0..point_count).map(|i| i as f32 * 0.001).collect();
        let ys: Vec<f32> = xs.iter().map(|v| v.sin()).collect();

        // Choose the smallest stride that keeps at most `target` points.
        let target = 250_000u32;
        let stride = if point_count > target {
            point_count.div_ceil(target)
        } else {
            1
        };
        let expected_vertices = point_count.div_ceil(stride) as usize * 6;

        let inputs = Scatter2GpuInputs {
            x_buffer: storage_buffer(&device, "scatter2-test-x", &xs),
            y_buffer: storage_buffer(&device, "scatter2-test-y", &ys),
            len: point_count,
            scalar: ScalarType::F32,
        };
        let params = Scatter2GpuParams {
            color: Vec4::new(0.8, 0.1, 0.3, 1.0),
            point_size: 8.0,
            sizes: ScatterAttributeBuffer::None,
            colors: ScatterColorBuffer::None,
            lod_stride: stride,
        };

        let packed =
            pack_vertices_from_xy(&device, &queue, &inputs, &params).expect("gpu packing failed");
        assert!(packed.vertex_count > 0);
        assert_eq!(packed.vertex_count, expected_vertices);
    }
}