rdpe/gpu/
spatial_gpu.rs

1//! GPU spatial hashing infrastructure
2//!
3//! Handles Morton code computation, radix sort, and cell table building.
4
5use bytemuck::{Pod, Zeroable};
6use wgpu::util::DeviceExt;
7
8use crate::spatial::{SpatialConfig, MORTON_WGSL};
9
/// Invocations per workgroup; the particle- and cell-sized shaders declare
/// `@workgroup_size(256)` to match.
const WORKGROUP_SIZE: u32 = 256;
/// Number of key bits consumed by a single radix-sort pass.
const RADIX_BITS: u32 = 4;
/// Number of histogram buckets per pass (one per possible digit value).
const RADIX_SIZE: u32 = 1 << RADIX_BITS;
13
14/// Calculate number of sort passes needed based on grid resolution.
15/// Morton codes use 3 * log2(grid_resolution) bits.
16/// Always returns an even number so final result is in buffer A (for build_cells bind group).
17fn calculate_sort_passes(grid_resolution: u32) -> u32 {
18    let bits_per_axis = (grid_resolution as f32).log2().ceil() as u32;
19    let total_bits = bits_per_axis * 3; // Morton code interleaves 3 axes
20    let passes = total_bits.div_ceil(RADIX_BITS);
21    // Round up to even number so result ends in buffer A
22    if passes % 2 == 1 { passes + 1 } else { passes }
23}
24
25/// GPU-side parameters for spatial hashing, uploaded to shaders.
26#[repr(C)]
27#[derive(Copy, Clone, Pod, Zeroable)]
28pub struct SpatialParams {
29    pub cell_size: f32,
30    pub grid_resolution: u32,
31    pub num_particles: u32,
32    pub max_neighbors: u32,
33}
34
35#[repr(C)]
36#[derive(Copy, Clone, Pod, Zeroable)]
37struct SortParams {
38    num_elements: u32,
39    bit_offset: u32,
40    _pad0: u32,
41    _pad1: u32,
42}
43
44/// GPU resources for spatial hashing
45#[allow(dead_code)] // Fields used indirectly via bind groups
46pub struct SpatialGpu {
47    // Buffers
48    morton_codes_a: wgpu::Buffer,
49    morton_codes_b: wgpu::Buffer,
50    pub particle_indices_a: wgpu::Buffer,
51    particle_indices_b: wgpu::Buffer,
52    histogram: wgpu::Buffer,
53    pub cell_start: wgpu::Buffer,
54    pub cell_end: wgpu::Buffer,
55    pub spatial_params_buffer: wgpu::Buffer,
56    pub sort_params_buffer: wgpu::Buffer,
57
58    // Pipelines
59    compute_morton_pipeline: wgpu::ComputePipeline,
60    histogram_pipeline: wgpu::ComputePipeline,
61    prefix_sum_pipeline: wgpu::ComputePipeline,
62    scatter_pipeline: wgpu::ComputePipeline,
63    build_cells_pipeline: wgpu::ComputePipeline,
64    clear_histogram_pipeline: wgpu::ComputePipeline,
65    clear_cells_pipeline: wgpu::ComputePipeline,
66
67    // Bind groups (we'll need to swap for ping-pong)
68    morton_bind_group: wgpu::BindGroup,
69    histogram_bind_group_a: wgpu::BindGroup,
70    histogram_bind_group_b: wgpu::BindGroup,
71    prefix_sum_bind_group: wgpu::BindGroup,
72    scatter_bind_group_a_to_b: wgpu::BindGroup,
73    scatter_bind_group_b_to_a: wgpu::BindGroup,
74    build_cells_bind_group: wgpu::BindGroup,
75    clear_histogram_bind_group: wgpu::BindGroup,
76    clear_cells_bind_group: wgpu::BindGroup,
77
78    pub config: SpatialConfig,
79    num_particles: u32,
80    sort_passes: u32,
81}
82
83impl SpatialGpu {
84    /// Create a new spatial hashing system for the given particle buffer and configuration.
85    pub fn new(
86        device: &wgpu::Device,
87        particle_buffer: &wgpu::Buffer,
88        num_particles: u32,
89        config: SpatialConfig,
90        particle_wgsl_struct: &str,
91    ) -> Self {
92        // Create buffers
93        let buffer_size = (num_particles as usize * std::mem::size_of::<u32>()) as u64;
94
95        let morton_codes_a = device.create_buffer(&wgpu::BufferDescriptor {
96            label: Some("Morton Codes A"),
97            size: buffer_size,
98            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
99            mapped_at_creation: false,
100        });
101
102        let morton_codes_b = device.create_buffer(&wgpu::BufferDescriptor {
103            label: Some("Morton Codes B"),
104            size: buffer_size,
105            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
106            mapped_at_creation: false,
107        });
108
109        let particle_indices_a = device.create_buffer(&wgpu::BufferDescriptor {
110            label: Some("Particle Indices A"),
111            size: buffer_size,
112            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
113            mapped_at_creation: false,
114        });
115
116        let particle_indices_b = device.create_buffer(&wgpu::BufferDescriptor {
117            label: Some("Particle Indices B"),
118            size: buffer_size,
119            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
120            mapped_at_creation: false,
121        });
122
123        let histogram = device.create_buffer(&wgpu::BufferDescriptor {
124            label: Some("Radix Histogram"),
125            size: (RADIX_SIZE as usize * std::mem::size_of::<u32>()) as u64,
126            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
127            mapped_at_creation: false,
128        });
129
130        let total_cells = config.total_cells();
131        let cell_table_size = (total_cells as usize * std::mem::size_of::<u32>()) as u64;
132
133        let cell_start = device.create_buffer(&wgpu::BufferDescriptor {
134            label: Some("Cell Start"),
135            size: cell_table_size,
136            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
137            mapped_at_creation: false,
138        });
139
140        let cell_end = device.create_buffer(&wgpu::BufferDescriptor {
141            label: Some("Cell End"),
142            size: cell_table_size,
143            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
144            mapped_at_creation: false,
145        });
146
147        let spatial_params = SpatialParams {
148            cell_size: config.cell_size,
149            grid_resolution: config.grid_resolution,
150            num_particles,
151            max_neighbors: config.max_neighbors,
152        };
153
154        let spatial_params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
155            label: Some("Spatial Params"),
156            contents: bytemuck::cast_slice(&[spatial_params]),
157            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
158        });
159
160        let sort_params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
161            label: Some("Sort Params"),
162            size: std::mem::size_of::<SortParams>() as u64,
163            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
164            mapped_at_creation: false,
165        });
166
167        // Create shaders and pipelines
168        let (
169            compute_morton_pipeline,
170            histogram_pipeline,
171            prefix_sum_pipeline,
172            scatter_pipeline,
173            build_cells_pipeline,
174            clear_histogram_pipeline,
175            clear_cells_pipeline,
176        ) = create_pipelines(device, particle_wgsl_struct);
177
178        // Create bind groups
179        let morton_bind_group = create_morton_bind_group(
180            device,
181            &compute_morton_pipeline,
182            particle_buffer,
183            &morton_codes_a,
184            &particle_indices_a,
185            &spatial_params_buffer,
186        );
187
188        let histogram_bind_group_a = create_histogram_bind_group(
189            device,
190            &histogram_pipeline,
191            &morton_codes_a,
192            &histogram,
193            &sort_params_buffer,
194        );
195
196        let histogram_bind_group_b = create_histogram_bind_group(
197            device,
198            &histogram_pipeline,
199            &morton_codes_b,
200            &histogram,
201            &sort_params_buffer,
202        );
203
204        let prefix_sum_bind_group = create_prefix_sum_bind_group(
205            device,
206            &prefix_sum_pipeline,
207            &histogram,
208        );
209
210        let scatter_bind_group_a_to_b = create_scatter_bind_group(
211            device,
212            &scatter_pipeline,
213            &morton_codes_a,
214            &particle_indices_a,
215            &morton_codes_b,
216            &particle_indices_b,
217            &histogram,
218            &sort_params_buffer,
219        );
220
221        let scatter_bind_group_b_to_a = create_scatter_bind_group(
222            device,
223            &scatter_pipeline,
224            &morton_codes_b,
225            &particle_indices_b,
226            &morton_codes_a,
227            &particle_indices_a,
228            &histogram,
229            &sort_params_buffer,
230        );
231
232        let build_cells_bind_group = create_build_cells_bind_group(
233            device,
234            &build_cells_pipeline,
235            &morton_codes_a, // After even number of passes, result is in A
236            &cell_start,
237            &cell_end,
238            &spatial_params_buffer,
239        );
240
241        let clear_histogram_bind_group = create_clear_bind_group(
242            device,
243            &clear_histogram_pipeline,
244            &histogram,
245            RADIX_SIZE,
246        );
247
248        let clear_cells_bind_group = create_clear_bind_group(
249            device,
250            &clear_cells_pipeline,
251            &cell_start,
252            total_cells,
253        );
254
255        let sort_passes = calculate_sort_passes(config.grid_resolution);
256
257        Self {
258            morton_codes_a,
259            morton_codes_b,
260            particle_indices_a,
261            particle_indices_b,
262            histogram,
263            cell_start,
264            cell_end,
265            spatial_params_buffer,
266            sort_params_buffer,
267            compute_morton_pipeline,
268            histogram_pipeline,
269            prefix_sum_pipeline,
270            scatter_pipeline,
271            build_cells_pipeline,
272            clear_histogram_pipeline,
273            clear_cells_pipeline,
274            morton_bind_group,
275            histogram_bind_group_a,
276            histogram_bind_group_b,
277            prefix_sum_bind_group,
278            scatter_bind_group_a_to_b,
279            scatter_bind_group_b_to_a,
280            build_cells_bind_group,
281            clear_histogram_bind_group,
282            clear_cells_bind_group,
283            config,
284            num_particles,
285            sort_passes,
286        }
287    }
288
289    /// Execute spatial hashing passes
290    pub fn execute(&self, encoder: &mut wgpu::CommandEncoder, queue: &wgpu::Queue) {
291        let workgroups = self.num_particles.div_ceil(WORKGROUP_SIZE);
292
293        // Step 1: Compute Morton codes
294        {
295            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
296                label: Some("Compute Morton"),
297                timestamp_writes: None,
298            });
299            pass.set_pipeline(&self.compute_morton_pipeline);
300            pass.set_bind_group(0, &self.morton_bind_group, &[]);
301            pass.dispatch_workgroups(workgroups, 1, 1);
302        }
303
304        // Step 2: Radix sort (dynamic passes based on grid resolution)
305        let mut source_is_a = true;
306
307        for pass_idx in 0..self.sort_passes {
308            let bit_offset = pass_idx * RADIX_BITS;
309
310            // Update sort params
311            let sort_params = SortParams {
312                num_elements: self.num_particles,
313                bit_offset,
314                _pad0: 0,
315                _pad1: 0,
316            };
317            queue.write_buffer(&self.sort_params_buffer, 0, bytemuck::cast_slice(&[sort_params]));
318
319            // Clear histogram
320            {
321                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
322                    label: Some("Clear Histogram"),
323                    timestamp_writes: None,
324                });
325                pass.set_pipeline(&self.clear_histogram_pipeline);
326                pass.set_bind_group(0, &self.clear_histogram_bind_group, &[]);
327                pass.dispatch_workgroups(1, 1, 1);
328            }
329
330            // Histogram pass
331            {
332                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
333                    label: Some("Radix Histogram"),
334                    timestamp_writes: None,
335                });
336                pass.set_pipeline(&self.histogram_pipeline);
337                pass.set_bind_group(
338                    0,
339                    if source_is_a { &self.histogram_bind_group_a } else { &self.histogram_bind_group_b },
340                    &[],
341                );
342                pass.dispatch_workgroups(workgroups, 1, 1);
343            }
344
345            // Prefix sum
346            {
347                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
348                    label: Some("Prefix Sum"),
349                    timestamp_writes: None,
350                });
351                pass.set_pipeline(&self.prefix_sum_pipeline);
352                pass.set_bind_group(0, &self.prefix_sum_bind_group, &[]);
353                pass.dispatch_workgroups(1, 1, 1);
354            }
355
356            // Scatter pass
357            {
358                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
359                    label: Some("Radix Scatter"),
360                    timestamp_writes: None,
361                });
362                pass.set_pipeline(&self.scatter_pipeline);
363                pass.set_bind_group(
364                    0,
365                    if source_is_a { &self.scatter_bind_group_a_to_b } else { &self.scatter_bind_group_b_to_a },
366                    &[],
367                );
368                pass.dispatch_workgroups(workgroups, 1, 1);
369            }
370
371            source_is_a = !source_is_a;
372        }
373
374        // Step 3: Build cell table
375        // Clear cell tables first
376        {
377            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
378                label: Some("Clear Cells"),
379                timestamp_writes: None,
380            });
381            pass.set_pipeline(&self.clear_cells_pipeline);
382            pass.set_bind_group(0, &self.clear_cells_bind_group, &[]);
383            let cell_workgroups = self.config.total_cells().div_ceil(WORKGROUP_SIZE);
384            pass.dispatch_workgroups(cell_workgroups, 1, 1);
385        }
386
387        // Build cell start/end
388        {
389            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
390                label: Some("Build Cell Table"),
391                timestamp_writes: None,
392            });
393            pass.set_pipeline(&self.build_cells_pipeline);
394            pass.set_bind_group(0, &self.build_cells_bind_group, &[]);
395            pass.dispatch_workgroups(workgroups, 1, 1);
396        }
397    }
398}
399
400fn create_pipelines(
401    device: &wgpu::Device,
402    particle_wgsl_struct: &str,
403) -> (
404    wgpu::ComputePipeline,
405    wgpu::ComputePipeline,
406    wgpu::ComputePipeline,
407    wgpu::ComputePipeline,
408    wgpu::ComputePipeline,
409    wgpu::ComputePipeline,
410    wgpu::ComputePipeline,
411) {
412    // Morton code computation shader - uses actual particle struct for correct stride
413    let morton_shader_src = format!(
414        r#"{}
415
416{}
417
418struct SpatialParams {{
419    cell_size: f32,
420    grid_resolution: u32,
421    num_particles: u32,
422    max_neighbors: u32,
423}};
424
425@group(0) @binding(0) var<storage, read> particles: array<Particle>;
426@group(0) @binding(1) var<storage, read_write> morton_codes: array<u32>;
427@group(0) @binding(2) var<storage, read_write> particle_indices: array<u32>;
428@group(0) @binding(3) var<uniform> params: SpatialParams;
429
430@compute @workgroup_size(256)
431fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {{
432    let idx = global_id.x;
433    if idx >= params.num_particles {{
434        return;
435    }}
436
437    let pos = particles[idx].position;
438    morton_codes[idx] = pos_to_morton(pos, params.cell_size, params.grid_resolution);
439    particle_indices[idx] = idx;
440}}
441"#,
442        MORTON_WGSL,
443        particle_wgsl_struct
444    );
445
446    let morton_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
447        label: Some("Morton Shader"),
448        source: wgpu::ShaderSource::Wgsl(morton_shader_src.into()),
449    });
450
451    // Histogram shader
452    let histogram_shader_src = r#"
453struct SortParams {
454    num_elements: u32,
455    bit_offset: u32,
456    _pad0: u32,
457    _pad1: u32,
458};
459
460@group(0) @binding(0) var<storage, read> keys: array<u32>;
461@group(0) @binding(1) var<storage, read_write> histogram: array<atomic<u32>>;
462@group(0) @binding(2) var<uniform> params: SortParams;
463
464const RADIX_SIZE: u32 = 16u;
465
466@compute @workgroup_size(256)
467fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
468    let idx = global_id.x;
469    if idx >= params.num_elements {
470        return;
471    }
472
473    let key = keys[idx];
474    let digit = (key >> params.bit_offset) & (RADIX_SIZE - 1u);
475    atomicAdd(&histogram[digit], 1u);
476}
477"#;
478
479    let histogram_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
480        label: Some("Histogram Shader"),
481        source: wgpu::ShaderSource::Wgsl(histogram_shader_src.into()),
482    });
483
484    // Prefix sum shader
485    let prefix_sum_shader_src = r#"
486@group(0) @binding(0) var<storage, read_write> data: array<u32>;
487
488var<workgroup> temp: array<u32, 16>;
489
490@compute @workgroup_size(16)
491fn main(@builtin(local_invocation_id) local_id: vec3<u32>) {
492    let tid = local_id.x;
493
494    // Load into shared memory
495    temp[tid] = data[tid];
496    workgroupBarrier();
497
498    // Inclusive scan using up-sweep and down-sweep
499    // Up-sweep
500    for (var stride = 1u; stride < 16u; stride *= 2u) {
501        if tid >= stride {
502            temp[tid] += temp[tid - stride];
503        }
504        workgroupBarrier();
505    }
506
507    // Convert to exclusive scan
508    let inclusive = temp[tid];
509    workgroupBarrier();
510
511    if tid == 0u {
512        data[tid] = 0u;
513    } else {
514        data[tid] = temp[tid - 1u];
515    }
516}
517"#;
518
519    let prefix_sum_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
520        label: Some("Prefix Sum Shader"),
521        source: wgpu::ShaderSource::Wgsl(prefix_sum_shader_src.into()),
522    });
523
524    // Scatter shader
525    let scatter_shader_src = r#"
526struct SortParams {
527    num_elements: u32,
528    bit_offset: u32,
529    _pad0: u32,
530    _pad1: u32,
531};
532
533@group(0) @binding(0) var<storage, read> keys_in: array<u32>;
534@group(0) @binding(1) var<storage, read> vals_in: array<u32>;
535@group(0) @binding(2) var<storage, read_write> keys_out: array<u32>;
536@group(0) @binding(3) var<storage, read_write> vals_out: array<u32>;
537@group(0) @binding(4) var<storage, read_write> offsets: array<atomic<u32>>;
538@group(0) @binding(5) var<uniform> params: SortParams;
539
540const RADIX_SIZE: u32 = 16u;
541
542@compute @workgroup_size(256)
543fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
544    let idx = global_id.x;
545    if idx >= params.num_elements {
546        return;
547    }
548
549    let key = keys_in[idx];
550    let val = vals_in[idx];
551    let digit = (key >> params.bit_offset) & (RADIX_SIZE - 1u);
552
553    let dest = atomicAdd(&offsets[digit], 1u);
554
555    keys_out[dest] = key;
556    vals_out[dest] = val;
557}
558"#;
559
560    let scatter_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
561        label: Some("Scatter Shader"),
562        source: wgpu::ShaderSource::Wgsl(scatter_shader_src.into()),
563    });
564
565    // Build cells shader
566    let build_cells_shader_src = r#"
567struct SpatialParams {
568    cell_size: f32,
569    grid_resolution: u32,
570    num_particles: u32,
571    max_neighbors: u32,
572};
573
574@group(0) @binding(0) var<storage, read> sorted_morton: array<u32>;
575@group(0) @binding(1) var<storage, read_write> cell_start: array<u32>;
576@group(0) @binding(2) var<storage, read_write> cell_end: array<u32>;
577@group(0) @binding(3) var<uniform> params: SpatialParams;
578
579@compute @workgroup_size(256)
580fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
581    let idx = global_id.x;
582    if idx >= params.num_particles {
583        return;
584    }
585
586    let code = sorted_morton[idx];
587
588    if idx == 0u {
589        cell_start[code] = 0u;
590    } else {
591        let prev_code = sorted_morton[idx - 1u];
592        if code != prev_code {
593            cell_start[code] = idx;
594            cell_end[prev_code] = idx;
595        }
596    }
597
598    if idx == params.num_particles - 1u {
599        cell_end[code] = params.num_particles;
600    }
601}
602"#;
603
604    let build_cells_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
605        label: Some("Build Cells Shader"),
606        source: wgpu::ShaderSource::Wgsl(build_cells_shader_src.into()),
607    });
608
609    // Clear buffer shader
610    let clear_shader_src = r#"
611@group(0) @binding(0) var<storage, read_write> data: array<u32>;
612
613@compute @workgroup_size(256)
614fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
615    let idx = global_id.x;
616    if idx < arrayLength(&data) {
617        data[idx] = 0xFFFFFFFFu;
618    }
619}
620"#;
621
622    let clear_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
623        label: Some("Clear Shader"),
624        source: wgpu::ShaderSource::Wgsl(clear_shader_src.into()),
625    });
626
627    // Create pipeline layouts and pipelines
628    let morton_pipeline = create_compute_pipeline(device, &morton_shader, "main", "Morton Pipeline");
629    let histogram_pipeline = create_compute_pipeline(device, &histogram_shader, "main", "Histogram Pipeline");
630    let prefix_sum_pipeline = create_compute_pipeline(device, &prefix_sum_shader, "main", "Prefix Sum Pipeline");
631    let scatter_pipeline = create_compute_pipeline(device, &scatter_shader, "main", "Scatter Pipeline");
632    let build_cells_pipeline = create_compute_pipeline(device, &build_cells_shader, "main", "Build Cells Pipeline");
633    let clear_histogram_pipeline = create_compute_pipeline(device, &clear_shader, "main", "Clear Histogram Pipeline");
634    let clear_cells_pipeline = create_compute_pipeline(device, &clear_shader, "main", "Clear Cells Pipeline");
635
636    (
637        morton_pipeline,
638        histogram_pipeline,
639        prefix_sum_pipeline,
640        scatter_pipeline,
641        build_cells_pipeline,
642        clear_histogram_pipeline,
643        clear_cells_pipeline,
644    )
645}
646
647fn create_compute_pipeline(
648    device: &wgpu::Device,
649    shader: &wgpu::ShaderModule,
650    entry_point: &str,
651    label: &str,
652) -> wgpu::ComputePipeline {
653    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
654        label: Some(label),
655        layout: None, // Auto layout
656        module: shader,
657        entry_point: Some(entry_point),
658        compilation_options: Default::default(),
659        cache: None,
660    })
661}
662
663fn create_morton_bind_group(
664    device: &wgpu::Device,
665    pipeline: &wgpu::ComputePipeline,
666    particles: &wgpu::Buffer,
667    morton_codes: &wgpu::Buffer,
668    particle_indices: &wgpu::Buffer,
669    params: &wgpu::Buffer,
670) -> wgpu::BindGroup {
671    let layout = pipeline.get_bind_group_layout(0);
672    device.create_bind_group(&wgpu::BindGroupDescriptor {
673        label: Some("Morton Bind Group"),
674        layout: &layout,
675        entries: &[
676            wgpu::BindGroupEntry { binding: 0, resource: particles.as_entire_binding() },
677            wgpu::BindGroupEntry { binding: 1, resource: morton_codes.as_entire_binding() },
678            wgpu::BindGroupEntry { binding: 2, resource: particle_indices.as_entire_binding() },
679            wgpu::BindGroupEntry { binding: 3, resource: params.as_entire_binding() },
680        ],
681    })
682}
683
684fn create_histogram_bind_group(
685    device: &wgpu::Device,
686    pipeline: &wgpu::ComputePipeline,
687    keys: &wgpu::Buffer,
688    histogram: &wgpu::Buffer,
689    params: &wgpu::Buffer,
690) -> wgpu::BindGroup {
691    let layout = pipeline.get_bind_group_layout(0);
692    device.create_bind_group(&wgpu::BindGroupDescriptor {
693        label: Some("Histogram Bind Group"),
694        layout: &layout,
695        entries: &[
696            wgpu::BindGroupEntry { binding: 0, resource: keys.as_entire_binding() },
697            wgpu::BindGroupEntry { binding: 1, resource: histogram.as_entire_binding() },
698            wgpu::BindGroupEntry { binding: 2, resource: params.as_entire_binding() },
699        ],
700    })
701}
702
703fn create_prefix_sum_bind_group(
704    device: &wgpu::Device,
705    pipeline: &wgpu::ComputePipeline,
706    data: &wgpu::Buffer,
707) -> wgpu::BindGroup {
708    let layout = pipeline.get_bind_group_layout(0);
709    device.create_bind_group(&wgpu::BindGroupDescriptor {
710        label: Some("Prefix Sum Bind Group"),
711        layout: &layout,
712        entries: &[
713            wgpu::BindGroupEntry { binding: 0, resource: data.as_entire_binding() },
714        ],
715    })
716}
717
718#[allow(clippy::too_many_arguments)]
719fn create_scatter_bind_group(
720    device: &wgpu::Device,
721    pipeline: &wgpu::ComputePipeline,
722    keys_in: &wgpu::Buffer,
723    vals_in: &wgpu::Buffer,
724    keys_out: &wgpu::Buffer,
725    vals_out: &wgpu::Buffer,
726    offsets: &wgpu::Buffer,
727    params: &wgpu::Buffer,
728) -> wgpu::BindGroup {
729    let layout = pipeline.get_bind_group_layout(0);
730    device.create_bind_group(&wgpu::BindGroupDescriptor {
731        label: Some("Scatter Bind Group"),
732        layout: &layout,
733        entries: &[
734            wgpu::BindGroupEntry { binding: 0, resource: keys_in.as_entire_binding() },
735            wgpu::BindGroupEntry { binding: 1, resource: vals_in.as_entire_binding() },
736            wgpu::BindGroupEntry { binding: 2, resource: keys_out.as_entire_binding() },
737            wgpu::BindGroupEntry { binding: 3, resource: vals_out.as_entire_binding() },
738            wgpu::BindGroupEntry { binding: 4, resource: offsets.as_entire_binding() },
739            wgpu::BindGroupEntry { binding: 5, resource: params.as_entire_binding() },
740        ],
741    })
742}
743
744fn create_build_cells_bind_group(
745    device: &wgpu::Device,
746    pipeline: &wgpu::ComputePipeline,
747    sorted_morton: &wgpu::Buffer,
748    cell_start: &wgpu::Buffer,
749    cell_end: &wgpu::Buffer,
750    params: &wgpu::Buffer,
751) -> wgpu::BindGroup {
752    let layout = pipeline.get_bind_group_layout(0);
753    device.create_bind_group(&wgpu::BindGroupDescriptor {
754        label: Some("Build Cells Bind Group"),
755        layout: &layout,
756        entries: &[
757            wgpu::BindGroupEntry { binding: 0, resource: sorted_morton.as_entire_binding() },
758            wgpu::BindGroupEntry { binding: 1, resource: cell_start.as_entire_binding() },
759            wgpu::BindGroupEntry { binding: 2, resource: cell_end.as_entire_binding() },
760            wgpu::BindGroupEntry { binding: 3, resource: params.as_entire_binding() },
761        ],
762    })
763}
764
765fn create_clear_bind_group(
766    device: &wgpu::Device,
767    pipeline: &wgpu::ComputePipeline,
768    buffer: &wgpu::Buffer,
769    _count: u32,
770) -> wgpu::BindGroup {
771    let layout = pipeline.get_bind_group_layout(0);
772    device.create_bind_group(&wgpu::BindGroupDescriptor {
773        label: Some("Clear Bind Group"),
774        layout: &layout,
775        entries: &[
776            wgpu::BindGroupEntry { binding: 0, resource: buffer.as_entire_binding() },
777        ],
778    })
779}