rend3_routine/culling/
gpu.rs

1use std::{borrow::Cow, mem, num::NonZeroU64};
2
3use glam::Mat4;
4use rend3::{
5    managers::{CameraManager, GpuCullingInput, InternalObject, VERTEX_OBJECT_INDEX_SLOT},
6    util::{bind_merge::BindGroupBuilder, frustum::ShaderFrustum},
7    ProfileData,
8};
9use wgpu::{
10    util::{BufferInitDescriptor, DeviceExt},
11    BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, Buffer, BufferBindingType,
12    BufferDescriptor, BufferUsages, CommandEncoder, ComputePassDescriptor, ComputePipeline, ComputePipelineDescriptor,
13    Device, PipelineLayoutDescriptor, PushConstantRange, RenderPass, ShaderModuleDescriptor,
14    ShaderModuleDescriptorSpirV, ShaderSource, ShaderStages,
15};
16
17use crate::{
18    common::{PerObjectDataAbi, Sorting},
19    culling::CulledObjectSet,
20    shaders::{SPIRV_SHADERS, WGSL_SHADERS},
21};
22
23#[repr(C, align(16))]
24#[derive(Debug, Copy, Clone)]
25struct GPUCullingUniforms {
26    view: Mat4,
27    view_proj: Mat4,
28    frustum: ShaderFrustum,
29    object_count: u32,
30}
31
32unsafe impl bytemuck::Pod for GPUCullingUniforms {}
33unsafe impl bytemuck::Zeroable for GPUCullingUniforms {}
34
35/// The data needed to do an indirect draw call for an entire material
36/// archetype.
37pub struct GpuIndirectData {
38    pub indirect_buffer: Buffer,
39    pub count: usize,
40}
41
42/// Provides GPU based object culling.
43pub struct GpuCuller {
44    atomic_bgl: BindGroupLayout,
45    atomic_pipeline: ComputePipeline,
46
47    prefix_bgl: BindGroupLayout,
48    prefix_cull_pipeline: ComputePipeline,
49    prefix_sum_pipeline: ComputePipeline,
50    prefix_output_pipeline: ComputePipeline,
51}
52impl GpuCuller {
53    pub fn new(device: &Device) -> Self {
54        profiling::scope!("GpuCuller::new");
55
56        let atomic_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
57            label: Some("atomic culling pll"),
58            entries: &[
59                BindGroupLayoutEntry {
60                    binding: 0,
61                    visibility: ShaderStages::COMPUTE,
62                    ty: BindingType::Buffer {
63                        ty: BufferBindingType::Storage { read_only: true },
64                        has_dynamic_offset: false,
65                        min_binding_size: NonZeroU64::new(mem::size_of::<GpuCullingInput>() as _),
66                    },
67                    count: None,
68                },
69                BindGroupLayoutEntry {
70                    binding: 1,
71                    visibility: ShaderStages::COMPUTE,
72                    ty: BindingType::Buffer {
73                        ty: BufferBindingType::Uniform,
74                        has_dynamic_offset: false,
75                        min_binding_size: NonZeroU64::new(mem::size_of::<GPUCullingUniforms>() as _),
76                    },
77                    count: None,
78                },
79                BindGroupLayoutEntry {
80                    binding: 2,
81                    visibility: ShaderStages::COMPUTE,
82                    ty: BindingType::Buffer {
83                        ty: BufferBindingType::Storage { read_only: false },
84                        has_dynamic_offset: false,
85                        min_binding_size: NonZeroU64::new(mem::size_of::<PerObjectDataAbi>() as _),
86                    },
87                    count: None,
88                },
89                BindGroupLayoutEntry {
90                    binding: 3,
91                    visibility: ShaderStages::COMPUTE,
92                    ty: BindingType::Buffer {
93                        ty: BufferBindingType::Storage { read_only: false },
94                        has_dynamic_offset: false,
95                        min_binding_size: NonZeroU64::new(16 + 20),
96                    },
97                    count: None,
98                },
99            ],
100        });
101
102        let prefix_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
103            label: Some("prefix culling pll"),
104            entries: &[
105                BindGroupLayoutEntry {
106                    binding: 0,
107                    visibility: ShaderStages::COMPUTE,
108                    ty: BindingType::Buffer {
109                        ty: BufferBindingType::Storage { read_only: true },
110                        has_dynamic_offset: false,
111                        min_binding_size: NonZeroU64::new(mem::size_of::<GpuCullingInput>() as _),
112                    },
113                    count: None,
114                },
115                BindGroupLayoutEntry {
116                    binding: 1,
117                    visibility: ShaderStages::COMPUTE,
118                    ty: BindingType::Buffer {
119                        ty: BufferBindingType::Uniform,
120                        has_dynamic_offset: false,
121                        min_binding_size: NonZeroU64::new(mem::size_of::<GPUCullingUniforms>() as _),
122                    },
123                    count: None,
124                },
125                BindGroupLayoutEntry {
126                    binding: 2,
127                    visibility: ShaderStages::COMPUTE,
128                    ty: BindingType::Buffer {
129                        ty: BufferBindingType::Storage { read_only: false },
130                        has_dynamic_offset: false,
131                        min_binding_size: NonZeroU64::new(mem::size_of::<u32>() as _),
132                    },
133                    count: None,
134                },
135                BindGroupLayoutEntry {
136                    binding: 3,
137                    visibility: ShaderStages::COMPUTE,
138                    ty: BindingType::Buffer {
139                        ty: BufferBindingType::Storage { read_only: false },
140                        has_dynamic_offset: false,
141                        min_binding_size: NonZeroU64::new(mem::size_of::<u32>() as _),
142                    },
143                    count: None,
144                },
145                BindGroupLayoutEntry {
146                    binding: 4,
147                    visibility: ShaderStages::COMPUTE,
148                    ty: BindingType::Buffer {
149                        ty: BufferBindingType::Storage { read_only: false },
150                        has_dynamic_offset: false,
151                        min_binding_size: NonZeroU64::new(mem::size_of::<PerObjectDataAbi>() as _),
152                    },
153                    count: None,
154                },
155                BindGroupLayoutEntry {
156                    binding: 5,
157                    visibility: ShaderStages::COMPUTE,
158                    ty: BindingType::Buffer {
159                        ty: BufferBindingType::Storage { read_only: false },
160                        has_dynamic_offset: false,
161                        min_binding_size: NonZeroU64::new(16 + 20),
162                    },
163                    count: None,
164                },
165            ],
166        });
167
168        let atomic_pll = device.create_pipeline_layout(&PipelineLayoutDescriptor {
169            label: Some("atomic culling pll"),
170            bind_group_layouts: &[&atomic_bgl],
171            push_constant_ranges: &[],
172        });
173
174        let prefix_pll = device.create_pipeline_layout(&PipelineLayoutDescriptor {
175            label: Some("prefix culling pll"),
176            bind_group_layouts: &[&prefix_bgl],
177            push_constant_ranges: &[],
178        });
179
180        let prefix_sum_pll = device.create_pipeline_layout(&PipelineLayoutDescriptor {
181            label: Some("prefix sum pll"),
182            bind_group_layouts: &[&prefix_bgl],
183            push_constant_ranges: &[PushConstantRange {
184                stages: ShaderStages::COMPUTE,
185                range: 0..4,
186            }],
187        });
188
189        let atomic_sm = unsafe {
190            device.create_shader_module_spirv(&ShaderModuleDescriptorSpirV {
191                label: Some("cull-atomic-cull"),
192                source: wgpu::util::make_spirv_raw(
193                    SPIRV_SHADERS.get_file("cull-atomic-cull.comp.spv").unwrap().contents(),
194                ),
195            })
196        };
197
198        let prefix_cull_sm = device.create_shader_module(&ShaderModuleDescriptor {
199            label: Some("cull-prefix-cull"),
200            source: ShaderSource::Wgsl(Cow::Borrowed(
201                WGSL_SHADERS
202                    .get_file("cull-prefix-cull.comp.wgsl")
203                    .unwrap()
204                    .contents_utf8()
205                    .unwrap(),
206            )),
207        });
208
209        let prefix_sum_sm = device.create_shader_module(&ShaderModuleDescriptor {
210            label: Some("cull-prefix-sum"),
211            source: ShaderSource::Wgsl(Cow::Borrowed(
212                WGSL_SHADERS
213                    .get_file("cull-prefix-sum.comp.wgsl")
214                    .unwrap()
215                    .contents_utf8()
216                    .unwrap(),
217            )),
218        });
219
220        let prefix_output_sm = device.create_shader_module(&ShaderModuleDescriptor {
221            label: Some("cull-prefix-output"),
222            source: ShaderSource::Wgsl(Cow::Borrowed(
223                WGSL_SHADERS
224                    .get_file("cull-prefix-output.comp.wgsl")
225                    .unwrap()
226                    .contents_utf8()
227                    .unwrap(),
228            )),
229        });
230
231        let atomic_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
232            label: Some("atomic culling pl"),
233            layout: Some(&atomic_pll),
234            module: &atomic_sm,
235            entry_point: "main",
236        });
237
238        let prefix_cull_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
239            label: Some("prefix cull pl"),
240            layout: Some(&prefix_pll),
241            module: &prefix_cull_sm,
242            entry_point: "main",
243        });
244
245        let prefix_sum_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
246            label: Some("prefix sum pl"),
247            layout: Some(&prefix_sum_pll),
248            module: &prefix_sum_sm,
249            entry_point: "main",
250        });
251
252        let prefix_output_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
253            label: Some("prefix output pl"),
254            layout: Some(&prefix_pll),
255            module: &prefix_output_sm,
256            entry_point: "main",
257        });
258
259        Self {
260            atomic_bgl,
261            atomic_pipeline,
262            prefix_bgl,
263            prefix_cull_pipeline,
264            prefix_sum_pipeline,
265            prefix_output_pipeline,
266        }
267    }
268
269    /// Perform culling on a given camera and input.
270    pub fn cull(
271        &self,
272        device: &Device,
273        encoder: &mut CommandEncoder,
274        camera: &CameraManager,
275        input_buffer: &Buffer,
276        input_count: usize,
277        sorting: Option<Sorting>,
278    ) -> CulledObjectSet {
279        profiling::scope!("Record GPU Culling");
280
281        let count = input_count;
282
283        let uniform = GPUCullingUniforms {
284            view: camera.view(),
285            view_proj: camera.view_proj(),
286            frustum: ShaderFrustum::from_matrix(camera.proj()),
287            object_count: count as u32,
288        };
289
290        let uniform_buffer = device.create_buffer_init(&BufferInitDescriptor {
291            label: Some("gpu culling uniform buffer"),
292            contents: bytemuck::bytes_of(&uniform),
293            usage: BufferUsages::UNIFORM,
294        });
295
296        let output_buffer = device.create_buffer(&BufferDescriptor {
297            label: Some("culling output"),
298            size: (count.max(1) * mem::size_of::<PerObjectDataAbi>()) as _,
299            usage: BufferUsages::STORAGE,
300            mapped_at_creation: false,
301        });
302
303        let indirect_buffer = device.create_buffer(&BufferDescriptor {
304            label: Some("indirect buffer"),
305            // 16 bytes for count, the rest for the indirect count
306            size: (count * 20 + 16) as _,
307            usage: BufferUsages::STORAGE | BufferUsages::INDIRECT | BufferUsages::VERTEX,
308            mapped_at_creation: false,
309        });
310
311        if count != 0 {
312            let dispatch_count = ((count + 255) / 256) as u32;
313
314            if sorting.is_some() {
315                let buffer_a = device.create_buffer(&BufferDescriptor {
316                    label: Some("cull result index buffer A"),
317                    size: (count * 4) as _,
318                    usage: BufferUsages::STORAGE,
319                    mapped_at_creation: false,
320                });
321
322                let buffer_b = device.create_buffer(&BufferDescriptor {
323                    label: Some("cull result index buffer B"),
324                    size: (count * 4) as _,
325                    usage: BufferUsages::STORAGE,
326                    mapped_at_creation: false,
327                });
328
329                let bg_a = BindGroupBuilder::new()
330                    .append_buffer(input_buffer)
331                    .append_buffer(&uniform_buffer)
332                    .append_buffer(&buffer_a)
333                    .append_buffer(&buffer_b)
334                    .append_buffer(&output_buffer)
335                    .append_buffer(&indirect_buffer)
336                    .build(device, Some("prefix cull A bg"), &self.prefix_bgl);
337
338                let bg_b = BindGroupBuilder::new()
339                    .append_buffer(input_buffer)
340                    .append_buffer(&uniform_buffer)
341                    .append_buffer(&buffer_b)
342                    .append_buffer(&buffer_a)
343                    .append_buffer(&output_buffer)
344                    .append_buffer(&indirect_buffer)
345                    .build(device, Some("prefix cull B bg"), &self.prefix_bgl);
346
347                let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
348                    label: Some("prefix cull"),
349                });
350
351                cpass.set_pipeline(&self.prefix_cull_pipeline);
352                cpass.set_bind_group(0, &bg_a, &[]);
353                cpass.dispatch(dispatch_count, 1, 1);
354
355                cpass.set_pipeline(&self.prefix_sum_pipeline);
356                let mut stride = 1_u32;
357                let mut iteration = 0;
358                while stride < count as u32 {
359                    let bind_group = if iteration % 2 == 0 { &bg_a } else { &bg_b };
360
361                    cpass.set_push_constants(0, bytemuck::cast_slice(&[stride]));
362                    cpass.set_bind_group(0, bind_group, &[]);
363                    cpass.dispatch(dispatch_count, 1, 1);
364                    stride <<= 1;
365                    iteration += 1;
366                }
367
368                let bind_group = if iteration % 2 == 0 { &bg_a } else { &bg_b };
369                cpass.set_pipeline(&self.prefix_output_pipeline);
370                cpass.set_bind_group(0, bind_group, &[]);
371                cpass.dispatch(dispatch_count, 1, 1);
372            } else {
373                let bg = BindGroupBuilder::new()
374                    .append_buffer(input_buffer)
375                    .append_buffer(&uniform_buffer)
376                    .append_buffer(&output_buffer)
377                    .append_buffer(&indirect_buffer)
378                    .build(device, Some("atomic culling bg"), &self.atomic_bgl);
379
380                let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
381                    label: Some("atomic cull"),
382                });
383
384                cpass.set_pipeline(&self.atomic_pipeline);
385                cpass.set_bind_group(0, &bg, &[]);
386                cpass.dispatch(dispatch_count, 1, 1);
387
388                drop(cpass);
389            }
390        }
391
392        CulledObjectSet {
393            calls: ProfileData::Gpu(GpuIndirectData { indirect_buffer, count }),
394            output_buffer,
395        }
396    }
397}
398
399/// Build and upload the inputs into a buffer to be passed to
400/// [`GpuCuller::cull`].
401pub fn build_gpu_cull_input(device: &Device, objects: &[InternalObject]) -> Buffer {
402    profiling::scope!("Building Input Data");
403
404    let total_length = objects.len() * mem::size_of::<GpuCullingInput>();
405
406    let buffer = device.create_buffer(&BufferDescriptor {
407        label: Some("culling inputs"),
408        size: total_length as u64,
409        usage: BufferUsages::STORAGE,
410        mapped_at_creation: true,
411    });
412
413    let mut data = buffer.slice(..).get_mapped_range_mut();
414
415    // This unsafe block measured a bit faster in my tests, and as this is basically
416    // _the_ hot path, so this is worthwhile.
417    unsafe {
418        let data_ptr = data.as_mut_ptr() as *mut GpuCullingInput;
419
420        // Iterate over the objects
421        for idx in 0..objects.len() {
422            // We're iterating over 0..len so this is never going to be out of bounds
423            let object = objects.get_unchecked(idx);
424
425            // This is aligned, and we know the vector has enough bytes to hold this, so
426            // this is safe
427            data_ptr.add(idx).write_unaligned(object.input);
428        }
429    }
430
431    drop(data);
432    buffer.unmap();
433
434    buffer
435}
436
437/// Draw the given indirect call.
438///
439/// No-op if there are 0 objects.
440pub fn draw_gpu_powered<'rpass>(rpass: &mut RenderPass<'rpass>, indirect_data: &'rpass GpuIndirectData) {
441    if indirect_data.count != 0 {
442        rpass.set_vertex_buffer(VERTEX_OBJECT_INDEX_SLOT, indirect_data.indirect_buffer.slice(16..));
443        rpass.multi_draw_indexed_indirect_count(
444            &indirect_data.indirect_buffer,
445            16,
446            &indirect_data.indirect_buffer,
447            0,
448            indirect_data.count as _,
449        );
450    }
451}