use bytemuck::{Pod, Zeroable};
use wgpu::util::DeviceExt;

use crate::spatial::{SpatialConfig, MORTON_WGSL};

const WORKGROUP_SIZE: u32 = 256;
const RADIX_BITS: u32 = 4;
const RADIX_SIZE: u32 = 16;

/// Number of 4-bit radix passes needed to sort a 3D Morton code at the given
/// grid resolution, rounded up to an even count so the ping-pong sort ends
/// with its output back in buffer A.
fn calculate_sort_passes(grid_resolution: u32) -> u32 {
    let bits_per_axis = (grid_resolution as f32).log2().ceil() as u32;
    let total_bits = bits_per_axis * 3;
    let passes = total_bits.div_ceil(RADIX_BITS);
    if passes % 2 == 1 { passes + 1 } else { passes }
}

/// Uniform parameters for the spatial shaders; the layout must match the
/// `SpatialParams` struct declared in the WGSL sources below.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
pub struct SpatialParams {
    pub cell_size: f32,
    pub grid_resolution: u32,
    pub num_particles: u32,
    pub max_neighbors: u32,
}

/// Per-pass parameters for the radix sort, padded to 16 bytes.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct SortParams {
    num_elements: u32,
    bit_offset: u32,
    _pad0: u32,
    _pad1: u32,
}

/// GPU resources for the Morton-code spatial hash: sort buffers, cell tables,
/// compute pipelines, and their bind groups.
#[allow(dead_code)]
pub struct SpatialGpu {
    // Ping-pong buffers for the radix sort of Morton codes / particle indices.
    morton_codes_a: wgpu::Buffer,
    morton_codes_b: wgpu::Buffer,
    pub particle_indices_a: wgpu::Buffer,
    particle_indices_b: wgpu::Buffer,
    histogram: wgpu::Buffer,
    pub cell_start: wgpu::Buffer,
    pub cell_end: wgpu::Buffer,
    pub spatial_params_buffer: wgpu::Buffer,
    pub sort_params_buffer: wgpu::Buffer,

    compute_morton_pipeline: wgpu::ComputePipeline,
    histogram_pipeline: wgpu::ComputePipeline,
    prefix_sum_pipeline: wgpu::ComputePipeline,
    scatter_pipeline: wgpu::ComputePipeline,
    build_cells_pipeline: wgpu::ComputePipeline,
    clear_histogram_pipeline: wgpu::ComputePipeline,
    clear_cells_pipeline: wgpu::ComputePipeline,

    morton_bind_group: wgpu::BindGroup,
    histogram_bind_group_a: wgpu::BindGroup,
    histogram_bind_group_b: wgpu::BindGroup,
    prefix_sum_bind_group: wgpu::BindGroup,
    scatter_bind_group_a_to_b: wgpu::BindGroup,
    scatter_bind_group_b_to_a: wgpu::BindGroup,
    build_cells_bind_group: wgpu::BindGroup,
    clear_histogram_bind_group: wgpu::BindGroup,
    clear_cells_bind_group: wgpu::BindGroup,

    pub config: SpatialConfig,
    num_particles: u32,
    sort_passes: u32,
}

impl SpatialGpu {
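    /// Creates all buffers, pipelines, and bind groups for the spatial hash.
    ///
    /// A minimal construction sketch; `device`, `particle_buffer`,
    /// `num_particles`, and `config` are assumed to come from the surrounding
    /// application, and the `Particle` WGSL source shown here is illustrative:
    ///
    /// ```ignore
    /// let spatial = SpatialGpu::new(
    ///     &device,
    ///     &particle_buffer,
    ///     num_particles,
    ///     config,
    ///     "struct Particle { position: vec3<f32>, velocity: vec3<f32>, }",
    /// );
    /// ```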
    pub fn new(
        device: &wgpu::Device,
        particle_buffer: &wgpu::Buffer,
        num_particles: u32,
        config: SpatialConfig,
        particle_wgsl_struct: &str,
    ) -> Self {
        let buffer_size = (num_particles as usize * std::mem::size_of::<u32>()) as u64;

        let morton_codes_a = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Morton Codes A"),
            size: buffer_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let morton_codes_b = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Morton Codes B"),
            size: buffer_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let particle_indices_a = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Particle Indices A"),
            size: buffer_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let particle_indices_b = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Particle Indices B"),
            size: buffer_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let histogram = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Radix Histogram"),
            size: (RADIX_SIZE as usize * std::mem::size_of::<u32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let total_cells = config.total_cells();
        let cell_table_size = (total_cells as usize * std::mem::size_of::<u32>()) as u64;

        let cell_start = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Cell Start"),
            size: cell_table_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let cell_end = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Cell End"),
            size: cell_table_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let spatial_params = SpatialParams {
            cell_size: config.cell_size,
            grid_resolution: config.grid_resolution,
            num_particles,
            max_neighbors: config.max_neighbors,
        };

        let spatial_params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("Spatial Params"),
            contents: bytemuck::cast_slice(&[spatial_params]),
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
        });

        let sort_params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Sort Params"),
            size: std::mem::size_of::<SortParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let (
            compute_morton_pipeline,
            histogram_pipeline,
            prefix_sum_pipeline,
            scatter_pipeline,
            build_cells_pipeline,
            clear_histogram_pipeline,
            clear_cells_pipeline,
        ) = create_pipelines(device, particle_wgsl_struct);

        let morton_bind_group = create_morton_bind_group(
            device,
            &compute_morton_pipeline,
            particle_buffer,
            &morton_codes_a,
            &particle_indices_a,
            &spatial_params_buffer,
        );

        let histogram_bind_group_a = create_histogram_bind_group(
            device,
            &histogram_pipeline,
            &morton_codes_a,
            &histogram,
            &sort_params_buffer,
        );

        let histogram_bind_group_b = create_histogram_bind_group(
            device,
            &histogram_pipeline,
            &morton_codes_b,
            &histogram,
            &sort_params_buffer,
        );

        let prefix_sum_bind_group = create_prefix_sum_bind_group(
            device,
            &prefix_sum_pipeline,
            &histogram,
        );

        let scatter_bind_group_a_to_b = create_scatter_bind_group(
            device,
            &scatter_pipeline,
            &morton_codes_a,
            &particle_indices_a,
            &morton_codes_b,
            &particle_indices_b,
            &histogram,
            &sort_params_buffer,
        );

        let scatter_bind_group_b_to_a = create_scatter_bind_group(
            device,
            &scatter_pipeline,
            &morton_codes_b,
            &particle_indices_b,
            &morton_codes_a,
            &particle_indices_a,
            &histogram,
            &sort_params_buffer,
        );

        // The sort runs an even number of passes, so the sorted Morton codes
        // are back in buffer A when the cell table is built.
        let build_cells_bind_group = create_build_cells_bind_group(
            device,
            &build_cells_pipeline,
            &morton_codes_a,
            &cell_start,
            &cell_end,
            &spatial_params_buffer,
        );

        let clear_histogram_bind_group = create_clear_bind_group(
            device,
            &clear_histogram_pipeline,
            &histogram,
            RADIX_SIZE,
        );

        let clear_cells_bind_group = create_clear_bind_group(
            device,
            &clear_cells_pipeline,
            &cell_start,
            total_cells,
        );

        let sort_passes = calculate_sort_passes(config.grid_resolution);

        Self {
            morton_codes_a,
            morton_codes_b,
            particle_indices_a,
            particle_indices_b,
            histogram,
            cell_start,
            cell_end,
            spatial_params_buffer,
            sort_params_buffer,
            compute_morton_pipeline,
            histogram_pipeline,
            prefix_sum_pipeline,
            scatter_pipeline,
            build_cells_pipeline,
            clear_histogram_pipeline,
            clear_cells_pipeline,
            morton_bind_group,
            histogram_bind_group_a,
            histogram_bind_group_b,
            prefix_sum_bind_group,
            scatter_bind_group_a_to_b,
            scatter_bind_group_b_to_a,
            build_cells_bind_group,
            clear_histogram_bind_group,
            clear_cells_bind_group,
            config,
            num_particles,
            sort_passes,
        }
    }

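    /// Records the full spatial-hash update into `encoder`: Morton codes,
    /// the radix sort, and the cell table rebuild.
    ///
    /// A per-frame usage sketch, assuming an existing `device` and `queue`
    /// (illustrative names, not defined in this module):
    ///
    /// ```ignore
    /// let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
    ///     label: Some("Spatial Hash Encoder"),
    /// });
    /// spatial.execute(&mut encoder, &queue);
    /// queue.submit(Some(encoder.finish()));
    /// ```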
    pub fn execute(&self, encoder: &mut wgpu::CommandEncoder, queue: &wgpu::Queue) {
        let workgroups = self.num_particles.div_ceil(WORKGROUP_SIZE);

        // 1. Compute a Morton code and an identity index for every particle.
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Compute Morton"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&self.compute_morton_pipeline);
            pass.set_bind_group(0, &self.morton_bind_group, &[]);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }

        // 2. Radix sort the Morton codes, RADIX_BITS per pass, ping-ponging A <-> B.
        let mut source_is_a = true;

        for pass_idx in 0..self.sort_passes {
            let bit_offset = pass_idx * RADIX_BITS;

            // Note: Queue::write_buffer is staged and applied at the next
            // queue.submit(), before the commands recorded in this encoder run.
            let sort_params = SortParams {
                num_elements: self.num_particles,
                bit_offset,
                _pad0: 0,
                _pad1: 0,
            };
            queue.write_buffer(&self.sort_params_buffer, 0, bytemuck::cast_slice(&[sort_params]));

            // Clear the per-digit histogram.
            {
                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                    label: Some("Clear Histogram"),
                    timestamp_writes: None,
                });
                pass.set_pipeline(&self.clear_histogram_pipeline);
                pass.set_bind_group(0, &self.clear_histogram_bind_group, &[]);
                pass.dispatch_workgroups(1, 1, 1);
            }

            // Count how many keys fall into each of the 16 digit bins.
            {
                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                    label: Some("Radix Histogram"),
                    timestamp_writes: None,
                });
                pass.set_pipeline(&self.histogram_pipeline);
                pass.set_bind_group(
                    0,
                    if source_is_a { &self.histogram_bind_group_a } else { &self.histogram_bind_group_b },
                    &[],
                );
                pass.dispatch_workgroups(workgroups, 1, 1);
            }

            // Exclusive prefix sum over the histogram to get scatter offsets.
            {
                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                    label: Some("Prefix Sum"),
                    timestamp_writes: None,
                });
                pass.set_pipeline(&self.prefix_sum_pipeline);
                pass.set_bind_group(0, &self.prefix_sum_bind_group, &[]);
                pass.dispatch_workgroups(1, 1, 1);
            }

            // Scatter keys and particle indices into the other buffer by digit.
            {
                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                    label: Some("Radix Scatter"),
                    timestamp_writes: None,
                });
                pass.set_pipeline(&self.scatter_pipeline);
                pass.set_bind_group(
                    0,
                    if source_is_a { &self.scatter_bind_group_a_to_b } else { &self.scatter_bind_group_b_to_a },
                    &[],
                );
                pass.dispatch_workgroups(workgroups, 1, 1);
            }

            source_is_a = !source_is_a;
        }

        // 3. Rebuild the cell table from the sorted Morton codes.
        // Reset cell_start to the "empty cell" sentinel.
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Clear Cells"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&self.clear_cells_pipeline);
            pass.set_bind_group(0, &self.clear_cells_bind_group, &[]);
            let cell_workgroups = self.config.total_cells().div_ceil(WORKGROUP_SIZE);
            pass.dispatch_workgroups(cell_workgroups, 1, 1);
        }

        // Mark each cell's start/end index wherever the sorted code changes.
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Build Cell Table"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&self.build_cells_pipeline);
            pass.set_bind_group(0, &self.build_cells_bind_group, &[]);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }
    }
}

fn create_pipelines(
    device: &wgpu::Device,
    particle_wgsl_struct: &str,
) -> (
    wgpu::ComputePipeline,
    wgpu::ComputePipeline,
    wgpu::ComputePipeline,
    wgpu::ComputePipeline,
    wgpu::ComputePipeline,
    wgpu::ComputePipeline,
    wgpu::ComputePipeline,
) {
    // Compute shader: one Morton code and identity index per particle.
    let morton_shader_src = format!(
        r#"{}

{}

struct SpatialParams {{
    cell_size: f32,
    grid_resolution: u32,
    num_particles: u32,
    max_neighbors: u32,
}};

@group(0) @binding(0) var<storage, read> particles: array<Particle>;
@group(0) @binding(1) var<storage, read_write> morton_codes: array<u32>;
@group(0) @binding(2) var<storage, read_write> particle_indices: array<u32>;
@group(0) @binding(3) var<uniform> params: SpatialParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let idx = global_id.x;
    if idx >= params.num_particles {{
        return;
    }}

    let pos = particles[idx].position;
    morton_codes[idx] = pos_to_morton(pos, params.cell_size, params.grid_resolution);
    particle_indices[idx] = idx;
}}
"#,
        MORTON_WGSL,
        particle_wgsl_struct
    );

    let morton_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Morton Shader"),
        source: wgpu::ShaderSource::Wgsl(morton_shader_src.into()),
    });

    // Compute shader: count how many keys carry each 4-bit digit value.
    let histogram_shader_src = r#"
struct SortParams {
    num_elements: u32,
    bit_offset: u32,
    _pad0: u32,
    _pad1: u32,
};

@group(0) @binding(0) var<storage, read> keys: array<u32>;
@group(0) @binding(1) var<storage, read_write> histogram: array<atomic<u32>>;
@group(0) @binding(2) var<uniform> params: SortParams;

const RADIX_SIZE: u32 = 16u;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if idx >= params.num_elements {
        return;
    }

    let key = keys[idx];
    let digit = (key >> params.bit_offset) & (RADIX_SIZE - 1u);
    atomicAdd(&histogram[digit], 1u);
}
"#;

    let histogram_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Histogram Shader"),
        source: wgpu::ShaderSource::Wgsl(histogram_shader_src.into()),
    });

    // Compute shader: exclusive prefix sum over the 16-bin histogram.
    let prefix_sum_shader_src = r#"
@group(0) @binding(0) var<storage, read_write> data: array<u32>;

var<workgroup> temp: array<u32, 16>;

@compute @workgroup_size(16)
fn main(@builtin(local_invocation_id) local_id: vec3<u32>) {
    let tid = local_id.x;

    // Load into shared memory
    temp[tid] = data[tid];
    workgroupBarrier();

    // Hillis-Steele inclusive scan over the 16 bins (read, barrier, then
    // write so a step never races on temp)
    for (var stride = 1u; stride < 16u; stride *= 2u) {
        var addend = 0u;
        if tid >= stride {
            addend = temp[tid - stride];
        }
        workgroupBarrier();
        temp[tid] += addend;
        workgroupBarrier();
    }

    // Convert to an exclusive scan by shifting right one slot
    if tid == 0u {
        data[tid] = 0u;
    } else {
        data[tid] = temp[tid - 1u];
    }
}
"#;

    let prefix_sum_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Prefix Sum Shader"),
        source: wgpu::ShaderSource::Wgsl(prefix_sum_shader_src.into()),
    });

    // Compute shader: scatter keys/values into their digit's output range.
    let scatter_shader_src = r#"
struct SortParams {
    num_elements: u32,
    bit_offset: u32,
    _pad0: u32,
    _pad1: u32,
};

@group(0) @binding(0) var<storage, read> keys_in: array<u32>;
@group(0) @binding(1) var<storage, read> vals_in: array<u32>;
@group(0) @binding(2) var<storage, read_write> keys_out: array<u32>;
@group(0) @binding(3) var<storage, read_write> vals_out: array<u32>;
@group(0) @binding(4) var<storage, read_write> offsets: array<atomic<u32>>;
@group(0) @binding(5) var<uniform> params: SortParams;

const RADIX_SIZE: u32 = 16u;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if idx >= params.num_elements {
        return;
    }

    let key = keys_in[idx];
    let val = vals_in[idx];
    let digit = (key >> params.bit_offset) & (RADIX_SIZE - 1u);

    // Claim the next free slot in this digit's range; slots within a digit
    // are handed out in whatever order the atomics resolve.
    let dest = atomicAdd(&offsets[digit], 1u);

    keys_out[dest] = key;
    vals_out[dest] = val;
}
"#;

    let scatter_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Scatter Shader"),
        source: wgpu::ShaderSource::Wgsl(scatter_shader_src.into()),
    });

    // Compute shader: find each Morton code's start/end range in the sorted array.
    let build_cells_shader_src = r#"
struct SpatialParams {
    cell_size: f32,
    grid_resolution: u32,
    num_particles: u32,
    max_neighbors: u32,
};

@group(0) @binding(0) var<storage, read> sorted_morton: array<u32>;
@group(0) @binding(1) var<storage, read_write> cell_start: array<u32>;
@group(0) @binding(2) var<storage, read_write> cell_end: array<u32>;
@group(0) @binding(3) var<uniform> params: SpatialParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if idx >= params.num_particles {
        return;
    }

    let code = sorted_morton[idx];

    if idx == 0u {
        cell_start[code] = 0u;
    } else {
        let prev_code = sorted_morton[idx - 1u];
        if code != prev_code {
            cell_start[code] = idx;
            cell_end[prev_code] = idx;
        }
    }

    if idx == params.num_particles - 1u {
        cell_end[code] = params.num_particles;
    }
}
"#;

    let build_cells_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Build Cells Shader"),
        source: wgpu::ShaderSource::Wgsl(build_cells_shader_src.into()),
    });

    // The radix histogram must start at zero before the atomic counting pass.
    let clear_zero_shader_src = r#"
@group(0) @binding(0) var<storage, read_write> data: array<u32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if idx < arrayLength(&data) {
        data[idx] = 0u;
    }
}
"#;

    // Cell-table entries are reset to 0xFFFFFFFF, the "empty cell" sentinel.
    let clear_sentinel_shader_src = r#"
@group(0) @binding(0) var<storage, read_write> data: array<u32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let idx = global_id.x;
    if idx < arrayLength(&data) {
        data[idx] = 0xFFFFFFFFu;
    }
}
"#;

    let clear_zero_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Clear To Zero Shader"),
        source: wgpu::ShaderSource::Wgsl(clear_zero_shader_src.into()),
    });

    let clear_sentinel_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Clear To Sentinel Shader"),
        source: wgpu::ShaderSource::Wgsl(clear_sentinel_shader_src.into()),
    });

    let morton_pipeline = create_compute_pipeline(device, &morton_shader, "main", "Morton Pipeline");
    let histogram_pipeline = create_compute_pipeline(device, &histogram_shader, "main", "Histogram Pipeline");
    let prefix_sum_pipeline = create_compute_pipeline(device, &prefix_sum_shader, "main", "Prefix Sum Pipeline");
    let scatter_pipeline = create_compute_pipeline(device, &scatter_shader, "main", "Scatter Pipeline");
    let build_cells_pipeline = create_compute_pipeline(device, &build_cells_shader, "main", "Build Cells Pipeline");
    let clear_histogram_pipeline = create_compute_pipeline(device, &clear_zero_shader, "main", "Clear Histogram Pipeline");
    let clear_cells_pipeline = create_compute_pipeline(device, &clear_sentinel_shader, "main", "Clear Cells Pipeline");

    (
        morton_pipeline,
        histogram_pipeline,
        prefix_sum_pipeline,
        scatter_pipeline,
        build_cells_pipeline,
        clear_histogram_pipeline,
        clear_cells_pipeline,
    )
}

fn create_compute_pipeline(
    device: &wgpu::Device,
    shader: &wgpu::ShaderModule,
    entry_point: &str,
    label: &str,
) -> wgpu::ComputePipeline {
    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some(label),
        // `None` derives the pipeline layout automatically from the shader.
        layout: None,
        module: shader,
        entry_point: Some(entry_point),
        compilation_options: Default::default(),
        cache: None,
    })
}

fn create_morton_bind_group(
    device: &wgpu::Device,
    pipeline: &wgpu::ComputePipeline,
    particles: &wgpu::Buffer,
    morton_codes: &wgpu::Buffer,
    particle_indices: &wgpu::Buffer,
    params: &wgpu::Buffer,
) -> wgpu::BindGroup {
    let layout = pipeline.get_bind_group_layout(0);
    device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("Morton Bind Group"),
        layout: &layout,
        entries: &[
            wgpu::BindGroupEntry { binding: 0, resource: particles.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 1, resource: morton_codes.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 2, resource: particle_indices.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 3, resource: params.as_entire_binding() },
        ],
    })
}

fn create_histogram_bind_group(
    device: &wgpu::Device,
    pipeline: &wgpu::ComputePipeline,
    keys: &wgpu::Buffer,
    histogram: &wgpu::Buffer,
    params: &wgpu::Buffer,
) -> wgpu::BindGroup {
    let layout = pipeline.get_bind_group_layout(0);
    device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("Histogram Bind Group"),
        layout: &layout,
        entries: &[
            wgpu::BindGroupEntry { binding: 0, resource: keys.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 1, resource: histogram.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 2, resource: params.as_entire_binding() },
        ],
    })
}

fn create_prefix_sum_bind_group(
    device: &wgpu::Device,
    pipeline: &wgpu::ComputePipeline,
    data: &wgpu::Buffer,
) -> wgpu::BindGroup {
    let layout = pipeline.get_bind_group_layout(0);
    device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("Prefix Sum Bind Group"),
        layout: &layout,
        entries: &[
            wgpu::BindGroupEntry { binding: 0, resource: data.as_entire_binding() },
        ],
    })
}

#[allow(clippy::too_many_arguments)]
fn create_scatter_bind_group(
    device: &wgpu::Device,
    pipeline: &wgpu::ComputePipeline,
    keys_in: &wgpu::Buffer,
    vals_in: &wgpu::Buffer,
    keys_out: &wgpu::Buffer,
    vals_out: &wgpu::Buffer,
    offsets: &wgpu::Buffer,
    params: &wgpu::Buffer,
) -> wgpu::BindGroup {
    let layout = pipeline.get_bind_group_layout(0);
    device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("Scatter Bind Group"),
        layout: &layout,
        entries: &[
            wgpu::BindGroupEntry { binding: 0, resource: keys_in.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 1, resource: vals_in.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 2, resource: keys_out.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 3, resource: vals_out.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 4, resource: offsets.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 5, resource: params.as_entire_binding() },
        ],
    })
}

fn create_build_cells_bind_group(
    device: &wgpu::Device,
    pipeline: &wgpu::ComputePipeline,
    sorted_morton: &wgpu::Buffer,
    cell_start: &wgpu::Buffer,
    cell_end: &wgpu::Buffer,
    params: &wgpu::Buffer,
) -> wgpu::BindGroup {
    let layout = pipeline.get_bind_group_layout(0);
    device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("Build Cells Bind Group"),
        layout: &layout,
        entries: &[
            wgpu::BindGroupEntry { binding: 0, resource: sorted_morton.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 1, resource: cell_start.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 2, resource: cell_end.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 3, resource: params.as_entire_binding() },
        ],
    })
}

fn create_clear_bind_group(
    device: &wgpu::Device,
    pipeline: &wgpu::ComputePipeline,
    buffer: &wgpu::Buffer,
    _count: u32,
) -> wgpu::BindGroup {
    let layout = pipeline.get_bind_group_layout(0);
    device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("Clear Bind Group"),
        layout: &layout,
        entries: &[
            wgpu::BindGroupEntry { binding: 0, resource: buffer.as_entire_binding() },
        ],
    })
}
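
// A small sanity check for the pass-count math above; a minimal illustrative
// test (the module and case names are arbitrary).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sort_passes_cover_all_morton_bits_and_are_even() {
        // grid_resolution = 64 -> 6 bits per axis -> 18 Morton bits
        // -> ceil(18 / 4) = 5 passes -> rounded up to 6 so the ping-pong
        // sort finishes back in buffer A.
        assert_eq!(calculate_sort_passes(64), 6);
        // 16 -> 4 bits per axis -> 12 bits -> 3 passes -> rounded to 4.
        assert_eq!(calculate_sort_passes(16), 4);
        // 128 -> 7 bits per axis -> 21 bits -> 6 passes (already even).
        assert_eq!(calculate_sort_passes(128), 6);
    }
}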