1use crate::core::renderer::Vertex;
2use crate::core::scene::{DrawIndirectArgsRaw, GpuVertexBuffer};
3use crate::gpu::shaders;
4use crate::gpu::{tuning, ScalarType};
5use glam::Vec4;
6use std::sync::Arc;
7use wgpu::util::DeviceExt;
8
9pub enum ScatterAttributeBuffer {
11 None,
12 Host(Vec<f32>),
13 Gpu(Arc<wgpu::Buffer>),
14}
15
16impl ScatterAttributeBuffer {
17 pub fn has_data(&self) -> bool {
18 !matches!(self, ScatterAttributeBuffer::None)
19 }
20}
21
22pub enum ScatterColorBuffer {
25 None,
26 Host(Vec<[f32; 4]>),
27 Gpu {
28 buffer: Arc<wgpu::Buffer>,
29 components: u32,
30 },
31}
32
33impl ScatterColorBuffer {
34 pub fn has_data(&self) -> bool {
35 !matches!(self, ScatterColorBuffer::None)
36 }
37
38 pub fn stride(&self) -> u32 {
39 match self {
40 ScatterColorBuffer::None => 4,
41 ScatterColorBuffer::Host(_) => 4,
42 ScatterColorBuffer::Gpu { components, .. } => *components,
43 }
44 }
45
46 pub fn buffer(&self) -> Option<Arc<wgpu::Buffer>> {
47 match self {
48 ScatterColorBuffer::Gpu { buffer, .. } => Some(buffer.clone()),
49 _ => None,
50 }
51 }
52}
53
54pub struct Scatter2GpuInputs {
56 pub x_buffer: Arc<wgpu::Buffer>,
57 pub y_buffer: Arc<wgpu::Buffer>,
58 pub len: u32,
59 pub scalar: ScalarType,
60}
61
62pub struct Scatter2GpuParams {
64 pub color: Vec4,
65 pub point_size: f32,
66 pub sizes: ScatterAttributeBuffer,
67 pub colors: ScatterColorBuffer,
68 pub lod_stride: u32,
69}
70
71#[repr(C)]
72#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
73struct Scatter2Uniforms {
74 color: [f32; 4],
75 point_size: f32,
76 count: u32,
77 lod_stride: u32,
78 has_sizes: u32,
79 has_colors: u32,
80 color_stride: u32,
81}
82
83pub fn pack_vertices_from_xy(
87 device: &Arc<wgpu::Device>,
88 queue: &Arc<wgpu::Queue>,
89 inputs: &Scatter2GpuInputs,
90 params: &Scatter2GpuParams,
91) -> Result<GpuVertexBuffer, String> {
92 if inputs.len == 0 {
93 return Err("scatter: empty input tensors".to_string());
94 }
95
96 let lod_stride = params.lod_stride.max(1);
97 let max_points = inputs.len.div_ceil(lod_stride);
98
99 let workgroup_size = tuning::effective_workgroup_size();
100 let shader = compile_shader(device, workgroup_size, inputs.scalar);
101
102 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
103 label: Some("scatter2-pack-bind-layout"),
104 entries: &[
105 wgpu::BindGroupLayoutEntry {
106 binding: 0,
107 visibility: wgpu::ShaderStages::COMPUTE,
108 ty: wgpu::BindingType::Buffer {
109 ty: wgpu::BufferBindingType::Storage { read_only: true },
110 has_dynamic_offset: false,
111 min_binding_size: None,
112 },
113 count: None,
114 },
115 wgpu::BindGroupLayoutEntry {
116 binding: 1,
117 visibility: wgpu::ShaderStages::COMPUTE,
118 ty: wgpu::BindingType::Buffer {
119 ty: wgpu::BufferBindingType::Storage { read_only: true },
120 has_dynamic_offset: false,
121 min_binding_size: None,
122 },
123 count: None,
124 },
125 wgpu::BindGroupLayoutEntry {
126 binding: 2,
127 visibility: wgpu::ShaderStages::COMPUTE,
128 ty: wgpu::BindingType::Buffer {
129 ty: wgpu::BufferBindingType::Storage { read_only: false },
130 has_dynamic_offset: false,
131 min_binding_size: None,
132 },
133 count: None,
134 },
135 wgpu::BindGroupLayoutEntry {
136 binding: 3,
137 visibility: wgpu::ShaderStages::COMPUTE,
138 ty: wgpu::BindingType::Buffer {
139 ty: wgpu::BufferBindingType::Uniform,
140 has_dynamic_offset: false,
141 min_binding_size: None,
142 },
143 count: None,
144 },
145 wgpu::BindGroupLayoutEntry {
146 binding: 4,
147 visibility: wgpu::ShaderStages::COMPUTE,
148 ty: wgpu::BindingType::Buffer {
149 ty: wgpu::BufferBindingType::Storage { read_only: true },
150 has_dynamic_offset: false,
151 min_binding_size: None,
152 },
153 count: None,
154 },
155 wgpu::BindGroupLayoutEntry {
156 binding: 5,
157 visibility: wgpu::ShaderStages::COMPUTE,
158 ty: wgpu::BindingType::Buffer {
159 ty: wgpu::BufferBindingType::Storage { read_only: true },
160 has_dynamic_offset: false,
161 min_binding_size: None,
162 },
163 count: None,
164 },
165 wgpu::BindGroupLayoutEntry {
166 binding: 6,
167 visibility: wgpu::ShaderStages::COMPUTE,
168 ty: wgpu::BindingType::Buffer {
169 ty: wgpu::BufferBindingType::Storage { read_only: false },
170 has_dynamic_offset: false,
171 min_binding_size: None,
172 },
173 count: None,
174 },
175 ],
176 });
177
178 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
179 label: Some("scatter2-pack-pipeline-layout"),
180 bind_group_layouts: &[&bind_group_layout],
181 push_constant_ranges: &[],
182 });
183
184 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
185 label: Some("scatter2-pack-pipeline"),
186 layout: Some(&pipeline_layout),
187 module: &shader,
188 entry_point: "main",
189 });
190
191 let output_size = max_points as u64 * std::mem::size_of::<Vertex>() as u64;
192 let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
193 label: Some("scatter2-gpu-vertices"),
194 size: output_size,
195 usage: wgpu::BufferUsages::STORAGE
196 | wgpu::BufferUsages::VERTEX
197 | wgpu::BufferUsages::COPY_DST,
198 mapped_at_creation: false,
199 }));
200
201 let indirect_args = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
202 label: Some("scatter2-gpu-indirect-args"),
203 size: std::mem::size_of::<DrawIndirectArgsRaw>() as u64,
204 usage: wgpu::BufferUsages::STORAGE
205 | wgpu::BufferUsages::INDIRECT
206 | wgpu::BufferUsages::COPY_DST,
207 mapped_at_creation: false,
208 }));
209 let init = DrawIndirectArgsRaw {
210 vertex_count: 0,
211 instance_count: 1,
212 first_vertex: 0,
213 first_instance: 0,
214 };
215 queue.write_buffer(&indirect_args, 0, bytemuck::bytes_of(&init));
216
217 let (size_buffer, has_sizes) = prepare_size_buffer(device, params);
218 let (color_buffer, has_colors, color_stride) = prepare_color_buffer(device, params);
219
220 let uniforms = Scatter2Uniforms {
221 color: params.color.to_array(),
222 point_size: params.point_size,
223 count: inputs.len,
224 lod_stride,
225 has_sizes: if has_sizes { 1 } else { 0 },
226 has_colors: if has_colors { 1 } else { 0 },
227 color_stride,
228 };
229 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
230 label: Some("scatter2-pack-uniforms"),
231 contents: bytemuck::bytes_of(&uniforms),
232 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
233 });
234
235 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
236 label: Some("scatter2-pack-bind-group"),
237 layout: &bind_group_layout,
238 entries: &[
239 wgpu::BindGroupEntry {
240 binding: 0,
241 resource: inputs.x_buffer.as_entire_binding(),
242 },
243 wgpu::BindGroupEntry {
244 binding: 1,
245 resource: inputs.y_buffer.as_entire_binding(),
246 },
247 wgpu::BindGroupEntry {
248 binding: 2,
249 resource: output_buffer.as_entire_binding(),
250 },
251 wgpu::BindGroupEntry {
252 binding: 3,
253 resource: uniform_buffer.as_entire_binding(),
254 },
255 wgpu::BindGroupEntry {
256 binding: 4,
257 resource: size_buffer.as_entire_binding(),
258 },
259 wgpu::BindGroupEntry {
260 binding: 5,
261 resource: color_buffer.as_entire_binding(),
262 },
263 wgpu::BindGroupEntry {
264 binding: 6,
265 resource: indirect_args.as_entire_binding(),
266 },
267 ],
268 });
269
270 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
271 label: Some("scatter2-pack-encoder"),
272 });
273 {
274 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
275 label: Some("scatter2-pack-pass"),
276 timestamp_writes: None,
277 });
278 pass.set_pipeline(&pipeline);
279 pass.set_bind_group(0, &bind_group, &[]);
280 let workgroups = inputs.len.div_ceil(workgroup_size);
281 pass.dispatch_workgroups(workgroups, 1, 1);
282 }
283 queue.submit(Some(encoder.finish()));
284
285 Ok(GpuVertexBuffer::with_indirect(
286 output_buffer,
287 max_points as usize,
288 indirect_args,
289 ))
290}
291
292fn prepare_size_buffer(
293 device: &Arc<wgpu::Device>,
294 params: &Scatter2GpuParams,
295) -> (Arc<wgpu::Buffer>, bool) {
296 match ¶ms.sizes {
297 ScatterAttributeBuffer::None => (
298 Arc::new(
299 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
300 label: Some("scatter2-size-fallback"),
301 contents: bytemuck::cast_slice(&[0.0f32]),
302 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
303 }),
304 ),
305 false,
306 ),
307 ScatterAttributeBuffer::Host(data) => {
308 if data.is_empty() {
309 (
310 Arc::new(
311 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
312 label: Some("scatter2-size-fallback"),
313 contents: bytemuck::cast_slice(&[0.0f32]),
314 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
315 }),
316 ),
317 false,
318 )
319 } else {
320 (
321 Arc::new(
322 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
323 label: Some("scatter2-size-host"),
324 contents: bytemuck::cast_slice(data.as_slice()),
325 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
326 }),
327 ),
328 true,
329 )
330 }
331 }
332 ScatterAttributeBuffer::Gpu(buffer) => (buffer.clone(), true),
333 }
334}
335
336fn prepare_color_buffer(
337 device: &Arc<wgpu::Device>,
338 params: &Scatter2GpuParams,
339) -> (Arc<wgpu::Buffer>, bool, u32) {
340 match ¶ms.colors {
341 ScatterColorBuffer::None => (
342 Arc::new(
343 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
344 label: Some("scatter2-color-fallback"),
345 contents: bytemuck::cast_slice(&[
346 params.color.x,
347 params.color.y,
348 params.color.z,
349 params.color.w,
350 ]),
351 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
352 }),
353 ),
354 false,
355 4,
356 ),
357 ScatterColorBuffer::Host(colors) => {
358 if colors.is_empty() {
359 (
360 Arc::new(
361 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
362 label: Some("scatter2-color-fallback"),
363 contents: bytemuck::cast_slice(&[
364 params.color.x,
365 params.color.y,
366 params.color.z,
367 params.color.w,
368 ]),
369 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
370 }),
371 ),
372 false,
373 4,
374 )
375 } else {
376 (
377 Arc::new(
378 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
379 label: Some("scatter2-color-host"),
380 contents: bytemuck::cast_slice(colors.as_slice()),
381 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
382 }),
383 ),
384 true,
385 4,
386 )
387 }
388 }
389 ScatterColorBuffer::Gpu { buffer, components } => (buffer.clone(), true, *components),
390 }
391}
392
393fn compile_shader(
394 device: &Arc<wgpu::Device>,
395 workgroup_size: u32,
396 scalar: ScalarType,
397) -> wgpu::ShaderModule {
398 let template = match scalar {
399 ScalarType::F32 => shaders::scatter2::F32,
400 ScalarType::F64 => shaders::scatter2::F64,
401 };
402 let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
403 device.create_shader_module(wgpu::ShaderModuleDescriptor {
404 label: Some("scatter2-pack-shader"),
405 source: wgpu::ShaderSource::Wgsl(source.into()),
406 })
407}
408
/// GPU stress tests for the scatter packer.
///
/// These tests need a real adapter and are opt-in: they only run when
/// `RUNMAT_PLOT_FORCE_GPU_TESTS` is set and `RUNMAT_PLOT_SKIP_GPU_TESTS` is
/// not set. Otherwise each test returns early and passes.
#[cfg(test)]
mod stress_tests {
    use super::*;
    use pollster::FutureExt;

    /// Returns a device/queue pair, or `None` when GPU tests are disabled by
    /// the environment or no adapter/device can be obtained.
    fn maybe_device() -> Option<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        if std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok()
            || std::env::var("RUNMAT_PLOT_FORCE_GPU_TESTS").is_err()
        {
            return None;
        }
        let instance = wgpu::Instance::default();
        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .block_on()?;
        // Request everything the adapter offers so large buffers are allowed.
        let limits = adapter.limits();
        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("runmat-plot-scatter-test-device"),
                    required_features: wgpu::Features::empty(),
                    required_limits: limits,
                },
                None,
            )
            .block_on()
            .ok()?;
        Some((Arc::new(device), Arc::new(queue)))
    }

    /// Packs 1.2M points with a level-of-detail stride targeting 250k
    /// vertices and checks the reported vertex capacity.
    #[test]
    fn gpu_packer_handles_large_point_cloud() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };
        let point_count = 1_200_000u32;
        let x_data: Vec<f32> = (0..point_count).map(|i| i as f32 * 0.001).collect();
        let y_data: Vec<f32> = x_data.iter().map(|v| v.sin()).collect();

        let x_buffer = Arc::new(
            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("scatter2-test-x"),
                contents: bytemuck::cast_slice(&x_data),
                usage: wgpu::BufferUsages::STORAGE,
            }),
        );
        let y_buffer = Arc::new(
            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("scatter2-test-y"),
                contents: bytemuck::cast_slice(&y_data),
                usage: wgpu::BufferUsages::STORAGE,
            }),
        );

        let target = 250_000u32;
        let stride = if point_count <= target {
            1
        } else {
            point_count.div_ceil(target)
        };
        let expected_vertices = point_count.div_ceil(stride) as usize;

        let inputs = Scatter2GpuInputs {
            x_buffer,
            y_buffer,
            len: point_count,
            scalar: ScalarType::F32,
        };
        let params = Scatter2GpuParams {
            color: Vec4::new(0.8, 0.1, 0.3, 1.0),
            point_size: 8.0,
            sizes: ScatterAttributeBuffer::None,
            colors: ScatterColorBuffer::None,
            lod_stride: stride,
        };

        let gpu_vertices =
            pack_vertices_from_xy(&device, &queue, &inputs, &params).expect("gpu packing failed");
        assert!(gpu_vertices.vertex_count > 0);
        assert_eq!(gpu_vertices.vertex_count, expected_vertices);
    }
}