1use crate::core::renderer::Vertex;
2use crate::core::scene::{DrawIndirectArgsRaw, GpuVertexBuffer};
3use crate::gpu::shaders;
4use crate::gpu::{tuning, ScalarType};
5use glam::Vec4;
6use std::sync::Arc;
7use wgpu::util::DeviceExt;
8
9#[derive(Clone, Debug)]
11pub enum ScatterAttributeBuffer {
12 None,
13 Host(Vec<f32>),
14 Gpu(Arc<wgpu::Buffer>),
15}
16
17impl ScatterAttributeBuffer {
18 pub fn has_data(&self) -> bool {
19 !matches!(self, ScatterAttributeBuffer::None)
20 }
21}
22
23#[derive(Clone, Debug)]
26pub enum ScatterColorBuffer {
27 None,
28 Host(Vec<[f32; 4]>),
29 Gpu {
30 buffer: Arc<wgpu::Buffer>,
31 components: u32,
32 },
33}
34
35impl ScatterColorBuffer {
36 pub fn has_data(&self) -> bool {
37 !matches!(self, ScatterColorBuffer::None)
38 }
39
40 pub fn stride(&self) -> u32 {
41 match self {
42 ScatterColorBuffer::None => 4,
43 ScatterColorBuffer::Host(_) => 4,
44 ScatterColorBuffer::Gpu { components, .. } => *components,
45 }
46 }
47
48 pub fn buffer(&self) -> Option<Arc<wgpu::Buffer>> {
49 match self {
50 ScatterColorBuffer::Gpu { buffer, .. } => Some(buffer.clone()),
51 _ => None,
52 }
53 }
54}
55
56#[derive(Clone, Debug)]
58pub struct Scatter2GpuInputs {
59 pub x_buffer: Arc<wgpu::Buffer>,
60 pub y_buffer: Arc<wgpu::Buffer>,
61 pub len: u32,
62 pub scalar: ScalarType,
63}
64
65pub struct Scatter2GpuParams {
67 pub color: Vec4,
68 pub point_size: f32,
69 pub sizes: ScatterAttributeBuffer,
70 pub colors: ScatterColorBuffer,
71 pub lod_stride: u32,
72}
73
/// Uniform block for the scatter pack compute shader (bound at binding 3).
///
/// `#[repr(C)]` + `Pod` so it can be uploaded verbatim with `bytemuck::bytes_of`. The field order
/// must mirror the WGSL uniform struct in `shaders::scatter2` (NOTE(review): confirm against the
/// shader source). This Rust struct is 40 bytes. If the WGSL counterpart begins with a
/// `vec4<f32>`, WGSL rounds its size up to a 16-byte multiple (48 bytes), so the uploaded uniform
/// buffer must be at least that large.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct Scatter2Uniforms {
    /// Base marker color (from `Scatter2GpuParams::color`).
    color: [f32; 4],
    /// Marker size (from `Scatter2GpuParams::point_size`).
    point_size: f32,
    /// Total number of input points (`Scatter2GpuInputs::len`).
    count: u32,
    /// Level-of-detail stride, already clamped to at least 1.
    lod_stride: u32,
    /// 1 when real per-point sizes were supplied (not the placeholder buffer), else 0.
    has_sizes: u32,
    /// 1 when real per-point colors were supplied (not the placeholder buffer), else 0.
    has_colors: u32,
    /// Number of `f32` components per entry in the bound color buffer.
    color_stride: u32,
}
85
86pub fn pack_vertices_from_xy(
90 device: &Arc<wgpu::Device>,
91 queue: &Arc<wgpu::Queue>,
92 inputs: &Scatter2GpuInputs,
93 params: &Scatter2GpuParams,
94) -> Result<GpuVertexBuffer, String> {
95 if inputs.len == 0 {
96 return Err("scatter: empty input tensors".to_string());
97 }
98
99 let lod_stride = params.lod_stride.max(1);
100 let max_points = inputs.len.div_ceil(lod_stride);
101
102 let workgroup_size = tuning::effective_workgroup_size();
103 let shader = compile_shader(device, workgroup_size, inputs.scalar);
104
105 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
106 label: Some("scatter2-pack-bind-layout"),
107 entries: &[
108 wgpu::BindGroupLayoutEntry {
109 binding: 0,
110 visibility: wgpu::ShaderStages::COMPUTE,
111 ty: wgpu::BindingType::Buffer {
112 ty: wgpu::BufferBindingType::Storage { read_only: true },
113 has_dynamic_offset: false,
114 min_binding_size: None,
115 },
116 count: None,
117 },
118 wgpu::BindGroupLayoutEntry {
119 binding: 1,
120 visibility: wgpu::ShaderStages::COMPUTE,
121 ty: wgpu::BindingType::Buffer {
122 ty: wgpu::BufferBindingType::Storage { read_only: true },
123 has_dynamic_offset: false,
124 min_binding_size: None,
125 },
126 count: None,
127 },
128 wgpu::BindGroupLayoutEntry {
129 binding: 2,
130 visibility: wgpu::ShaderStages::COMPUTE,
131 ty: wgpu::BindingType::Buffer {
132 ty: wgpu::BufferBindingType::Storage { read_only: false },
133 has_dynamic_offset: false,
134 min_binding_size: None,
135 },
136 count: None,
137 },
138 wgpu::BindGroupLayoutEntry {
139 binding: 3,
140 visibility: wgpu::ShaderStages::COMPUTE,
141 ty: wgpu::BindingType::Buffer {
142 ty: wgpu::BufferBindingType::Uniform,
143 has_dynamic_offset: false,
144 min_binding_size: None,
145 },
146 count: None,
147 },
148 wgpu::BindGroupLayoutEntry {
149 binding: 4,
150 visibility: wgpu::ShaderStages::COMPUTE,
151 ty: wgpu::BindingType::Buffer {
152 ty: wgpu::BufferBindingType::Storage { read_only: true },
153 has_dynamic_offset: false,
154 min_binding_size: None,
155 },
156 count: None,
157 },
158 wgpu::BindGroupLayoutEntry {
159 binding: 5,
160 visibility: wgpu::ShaderStages::COMPUTE,
161 ty: wgpu::BindingType::Buffer {
162 ty: wgpu::BufferBindingType::Storage { read_only: true },
163 has_dynamic_offset: false,
164 min_binding_size: None,
165 },
166 count: None,
167 },
168 wgpu::BindGroupLayoutEntry {
169 binding: 6,
170 visibility: wgpu::ShaderStages::COMPUTE,
171 ty: wgpu::BindingType::Buffer {
172 ty: wgpu::BufferBindingType::Storage { read_only: false },
173 has_dynamic_offset: false,
174 min_binding_size: None,
175 },
176 count: None,
177 },
178 ],
179 });
180
181 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
182 label: Some("scatter2-pack-pipeline-layout"),
183 bind_group_layouts: &[&bind_group_layout],
184 push_constant_ranges: &[],
185 });
186
187 let pipeline =
188 device.create_compute_pipeline(&crate::wgpu_compat::wgpu_compute_pipeline_descriptor! {
189 label: Some("scatter2-pack-pipeline"),
190 layout: Some(&pipeline_layout),
191 module: &shader,
192 entry_point: "main",
193 });
194
195 let output_size = max_points as u64 * 6 * std::mem::size_of::<Vertex>() as u64;
196 let output_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
197 label: Some("scatter2-gpu-vertices"),
198 size: output_size,
199 usage: wgpu::BufferUsages::STORAGE
200 | wgpu::BufferUsages::VERTEX
201 | wgpu::BufferUsages::COPY_DST
202 | wgpu::BufferUsages::COPY_SRC,
203 mapped_at_creation: false,
204 }));
205
206 let indirect_args = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
207 label: Some("scatter2-gpu-indirect-args"),
208 size: std::mem::size_of::<DrawIndirectArgsRaw>() as u64,
209 usage: wgpu::BufferUsages::STORAGE
210 | wgpu::BufferUsages::INDIRECT
211 | wgpu::BufferUsages::COPY_DST
212 | wgpu::BufferUsages::COPY_SRC,
213 mapped_at_creation: false,
214 }));
215 let init = DrawIndirectArgsRaw {
216 vertex_count: 0,
217 instance_count: 1,
218 first_vertex: 0,
219 first_instance: 0,
220 };
221 queue.write_buffer(&indirect_args, 0, bytemuck::bytes_of(&init));
222
223 let (size_buffer, has_sizes) = prepare_size_buffer(device, params);
224 let (color_buffer, has_colors, color_stride) = prepare_color_buffer(device, params);
225
226 let uniforms = Scatter2Uniforms {
227 color: params.color.to_array(),
228 point_size: params.point_size,
229 count: inputs.len,
230 lod_stride,
231 has_sizes: if has_sizes { 1 } else { 0 },
232 has_colors: if has_colors { 1 } else { 0 },
233 color_stride,
234 };
235 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
236 label: Some("scatter2-pack-uniforms"),
237 contents: bytemuck::bytes_of(&uniforms),
238 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
239 });
240
241 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
242 label: Some("scatter2-pack-bind-group"),
243 layout: &bind_group_layout,
244 entries: &[
245 wgpu::BindGroupEntry {
246 binding: 0,
247 resource: inputs.x_buffer.as_entire_binding(),
248 },
249 wgpu::BindGroupEntry {
250 binding: 1,
251 resource: inputs.y_buffer.as_entire_binding(),
252 },
253 wgpu::BindGroupEntry {
254 binding: 2,
255 resource: output_buffer.as_entire_binding(),
256 },
257 wgpu::BindGroupEntry {
258 binding: 3,
259 resource: uniform_buffer.as_entire_binding(),
260 },
261 wgpu::BindGroupEntry {
262 binding: 4,
263 resource: size_buffer.as_entire_binding(),
264 },
265 wgpu::BindGroupEntry {
266 binding: 5,
267 resource: color_buffer.as_entire_binding(),
268 },
269 wgpu::BindGroupEntry {
270 binding: 6,
271 resource: indirect_args.as_entire_binding(),
272 },
273 ],
274 });
275
276 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
277 label: Some("scatter2-pack-encoder"),
278 });
279 {
280 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
281 label: Some("scatter2-pack-pass"),
282 timestamp_writes: None,
283 });
284 pass.set_pipeline(&pipeline);
285 pass.set_bind_group(0, &bind_group, &[]);
286 let workgroups = inputs.len.div_ceil(workgroup_size);
287 pass.dispatch_workgroups(workgroups, 1, 1);
288 }
289 queue.submit(Some(encoder.finish()));
290
291 Ok(GpuVertexBuffer::with_indirect(
292 output_buffer,
293 (max_points as usize) * 6,
294 indirect_args,
295 ))
296}
297
298fn prepare_size_buffer(
299 device: &Arc<wgpu::Device>,
300 params: &Scatter2GpuParams,
301) -> (Arc<wgpu::Buffer>, bool) {
302 match ¶ms.sizes {
303 ScatterAttributeBuffer::None => (
304 Arc::new(
305 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
306 label: Some("scatter2-size-fallback"),
307 contents: bytemuck::cast_slice(&[0.0f32]),
308 usage: wgpu::BufferUsages::STORAGE
309 | wgpu::BufferUsages::COPY_DST
310 | wgpu::BufferUsages::COPY_SRC,
311 }),
312 ),
313 false,
314 ),
315 ScatterAttributeBuffer::Host(data) => {
316 if data.is_empty() {
317 (
318 Arc::new(
319 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
320 label: Some("scatter2-size-fallback"),
321 contents: bytemuck::cast_slice(&[0.0f32]),
322 usage: wgpu::BufferUsages::STORAGE
323 | wgpu::BufferUsages::COPY_DST
324 | wgpu::BufferUsages::COPY_SRC,
325 }),
326 ),
327 false,
328 )
329 } else {
330 (
331 Arc::new(
332 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
333 label: Some("scatter2-size-host"),
334 contents: bytemuck::cast_slice(data.as_slice()),
335 usage: wgpu::BufferUsages::STORAGE
336 | wgpu::BufferUsages::COPY_DST
337 | wgpu::BufferUsages::COPY_SRC,
338 }),
339 ),
340 true,
341 )
342 }
343 }
344 ScatterAttributeBuffer::Gpu(buffer) => (buffer.clone(), true),
345 }
346}
347
348fn prepare_color_buffer(
349 device: &Arc<wgpu::Device>,
350 params: &Scatter2GpuParams,
351) -> (Arc<wgpu::Buffer>, bool, u32) {
352 match ¶ms.colors {
353 ScatterColorBuffer::None => (
354 Arc::new(
355 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
356 label: Some("scatter2-color-fallback"),
357 contents: bytemuck::cast_slice(&[
358 params.color.x,
359 params.color.y,
360 params.color.z,
361 params.color.w,
362 ]),
363 usage: wgpu::BufferUsages::STORAGE
364 | wgpu::BufferUsages::COPY_DST
365 | wgpu::BufferUsages::COPY_SRC,
366 }),
367 ),
368 false,
369 4,
370 ),
371 ScatterColorBuffer::Host(colors) => {
372 if colors.is_empty() {
373 (
374 Arc::new(
375 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
376 label: Some("scatter2-color-fallback"),
377 contents: bytemuck::cast_slice(&[
378 params.color.x,
379 params.color.y,
380 params.color.z,
381 params.color.w,
382 ]),
383 usage: wgpu::BufferUsages::STORAGE
384 | wgpu::BufferUsages::COPY_DST
385 | wgpu::BufferUsages::COPY_SRC,
386 }),
387 ),
388 false,
389 4,
390 )
391 } else {
392 (
393 Arc::new(
394 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
395 label: Some("scatter2-color-host"),
396 contents: bytemuck::cast_slice(colors.as_slice()),
397 usage: wgpu::BufferUsages::STORAGE
398 | wgpu::BufferUsages::COPY_DST
399 | wgpu::BufferUsages::COPY_SRC,
400 }),
401 ),
402 true,
403 4,
404 )
405 }
406 }
407 ScatterColorBuffer::Gpu { buffer, components } => (buffer.clone(), true, *components),
408 }
409}
410
411fn compile_shader(
412 device: &Arc<wgpu::Device>,
413 workgroup_size: u32,
414 scalar: ScalarType,
415) -> wgpu::ShaderModule {
416 let template = match scalar {
417 ScalarType::F32 => shaders::scatter2::F32,
418 ScalarType::F64 => shaders::scatter2::F64,
419 };
420 let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
421 device.create_shader_module(wgpu::ShaderModuleDescriptor {
422 label: Some("scatter2-pack-shader"),
423 source: wgpu::ShaderSource::Wgsl(source.into()),
424 })
425}
426
#[cfg(test)]
mod stress_tests {
    use super::*;
    use pollster::FutureExt;

    /// Creates a device/queue pair for GPU tests.
    ///
    /// Returns `None` in two situations:
    /// - GPU tests are disabled: they run only when `RUNMAT_PLOT_FORCE_GPU_TESTS` is set and
    ///   `RUNMAT_PLOT_SKIP_GPU_TESTS` is not;
    /// - no adapter or device could be obtained.
    fn maybe_device() -> Option<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        if std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok()
            || std::env::var("RUNMAT_PLOT_FORCE_GPU_TESTS").is_err()
        {
            return None;
        }
        let instance = wgpu::Instance::default();
        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .block_on()?;
        let limits = adapter.limits();
        let (device, queue) = adapter
            .request_device(
                &crate::wgpu_compat::device_descriptor(
                    Some("runmat-plot-scatter-test-device"),
                    wgpu::Features::empty(),
                    limits,
                ),
                None,
            )
            .block_on()
            .ok()?;
        Some((Arc::new(device), Arc::new(queue)))
    }

    /// Packs 1.2M points with an LOD stride targeting ~250k points and checks the reported
    /// vertex capacity (six vertices per retained point).
    #[test]
    fn gpu_packer_handles_large_point_cloud() {
        let Some((device, queue)) = maybe_device() else {
            return;
        };
        let point_count = 1_200_000u32;
        let x_data: Vec<f32> = (0..point_count).map(|i| i as f32 * 0.001).collect();
        let y_data: Vec<f32> = x_data.iter().map(|v| v.sin()).collect();

        let x_buffer = Arc::new(
            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("scatter2-test-x"),
                contents: bytemuck::cast_slice(&x_data),
                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            }),
        );
        let y_buffer = Arc::new(
            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("scatter2-test-y"),
                contents: bytemuck::cast_slice(&y_data),
                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            }),
        );

        let target = 250_000u32;
        let stride = if point_count <= target {
            1
        } else {
            point_count.div_ceil(target)
        };
        let expected_vertices = point_count.div_ceil(stride) as usize * 6;

        let inputs = Scatter2GpuInputs {
            x_buffer,
            y_buffer,
            len: point_count,
            scalar: ScalarType::F32,
        };
        let params = Scatter2GpuParams {
            color: Vec4::new(0.8, 0.1, 0.3, 1.0),
            point_size: 8.0,
            sizes: ScatterAttributeBuffer::None,
            colors: ScatterColorBuffer::None,
            lod_stride: stride,
        };

        let gpu_vertices =
            pack_vertices_from_xy(&device, &queue, &inputs, &params).expect("gpu packing failed");
        assert!(gpu_vertices.vertex_count > 0);
        assert_eq!(gpu_vertices.vertex_count, expected_vertices);
    }
}