// rend3_pbr/culling/gpu.rs

1use std::{borrow::Cow, mem, num::NonZeroU64};
2
3use glam::Mat4;
4use ordered_float::OrderedFloat;
5use rend3::{
6    resources::{CameraManager, GPUCullingInput, InternalObject, ObjectManager},
7    util::{bind_merge::BindGroupBuilder, frustum::ShaderFrustum},
8    ModeData,
9};
10use wgpu::{
11    util::{BufferInitDescriptor, DeviceExt},
12    BindGroupDescriptor, BindGroupEntry, BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType,
13    Buffer, BufferBindingType, BufferDescriptor, BufferUsages, CommandEncoder, ComputePassDescriptor, ComputePipeline,
14    ComputePipelineDescriptor, Device, PipelineLayoutDescriptor, PushConstantRange, RenderPass, ShaderModuleDescriptor,
15    ShaderModuleDescriptorSpirV, ShaderStages,
16};
17
18use crate::{
19    common::interfaces::{PerObjectData, ShaderInterfaces},
20    culling::{CulledObjectSet, GPUIndirectData, Sorting},
21    material::{PbrMaterial, TransparencyType},
22    shaders::SPIRV_SHADERS,
23};
24
/// Output of [`GpuCuller::pre_cull`]: the per-object culling inputs uploaded
/// to the GPU (already distance-sorted if sorting was requested), plus the
/// number of objects serialized into the buffer.
pub struct PreCulledBuffer {
    // Storage buffer holding one `GPUCullingInput` per object.
    inner: Buffer,
    // Number of objects written into `inner`.
    count: usize,
}
29
/// Uniform block consumed by the culling compute shaders.
///
/// `repr(C, align(16))` so the field layout matches the shader-side uniform
/// block layout — NOTE(review): assumed to mirror the layout declared in the
/// `cull-*.comp` shaders; confirm against the GLSL source.
#[repr(C, align(16))]
#[derive(Debug, Copy, Clone)]
struct GPUCullingUniforms {
    // Camera view matrix.
    view: Mat4,
    // Camera view * projection matrix.
    view_proj: Mat4,
    // Frustum planes derived from the projection matrix, used for the
    // visibility test.
    frustum: ShaderFrustum,
    // Number of objects in the input buffer; presumably used by the shader to
    // bounds-check the dispatch — confirm in shader source.
    object_count: u32,
}

// SAFETY: asserted manually because of the foreign `ShaderFrustum` field.
// NOTE(review): `align(16)` with a trailing `u32` means the struct very likely
// has tail padding; bytemuck's `Pod` contract forbids padding bytes, so this
// manual impl (and `bytes_of` below) may read uninitialized padding — worth
// auditing or switching to a derived Pod with explicit padding fields.
unsafe impl bytemuck::Pod for GPUCullingUniforms {}
// SAFETY: all fields are plain floating-point/integer data; the all-zero bit
// pattern is a valid value.
unsafe impl bytemuck::Zeroable for GPUCullingUniforms {}
41
/// Arguments for [`GpuCuller::pre_cull`].
pub struct GpuCullerPreCullArgs<'a> {
    // Device used to create the culling-input buffer.
    pub device: &'a Device,

    // Camera whose location is used for distance sorting.
    pub camera: &'a CameraManager,

    // Object manager the objects are pulled from.
    pub objects: &'a mut ObjectManager,

    // Which transparency bucket of PbrMaterial objects to gather.
    pub transparency: TransparencyType,
    // `Some` to sort objects by camera distance before upload; `None` keeps
    // manager order.
    pub sort: Option<Sorting>,
}
52
/// Arguments for [`GpuCuller::cull`].
pub struct GpuCullerCullArgs<'a> {
    // Device used to create output/uniform/scratch buffers and bind groups.
    pub device: &'a Device,
    // Encoder the culling compute passes are recorded into.
    pub encoder: &'a mut CommandEncoder,

    // Provides the bind group layout for the culled-object output.
    pub interfaces: &'a ShaderInterfaces,

    // Camera supplying view/projection matrices and the frustum.
    pub camera: &'a CameraManager,

    // Pre-built (and possibly pre-sorted) culling inputs from `pre_cull`.
    pub input_buffer: &'a PreCulledBuffer,

    // Must match the `sort` passed to `pre_cull`: `Some` selects the
    // order-preserving prefix path, `None` the atomic path.
    pub sort: Option<Sorting>,
}
65
/// GPU-driven culler holding the resources for both culling strategies.
pub struct GpuCuller {
    // Single-pass path: culls and compacts output with atomics. Output order
    // is not deterministic, so it is used when no sorting was requested.
    atomic_bgl: BindGroupLayout,
    atomic_pipeline: ComputePipeline,

    // Order-preserving path: cull pass, prefix-sum over visibility flags,
    // then an output pass. Used when the input was sorted.
    prefix_bgl: BindGroupLayout,
    prefix_cull_pipeline: ComputePipeline,
    prefix_sum_pipeline: ComputePipeline,
    prefix_output_pipeline: ComputePipeline,
}
75impl GpuCuller {
76    pub fn new(device: &Device) -> Self {
77        profiling::scope!("GpuCuller::new");
78
79        let atomic_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
80            label: Some("atomic culling pll"),
81            entries: &[
82                BindGroupLayoutEntry {
83                    binding: 0,
84                    visibility: ShaderStages::COMPUTE,
85                    ty: BindingType::Buffer {
86                        ty: BufferBindingType::Storage { read_only: true },
87                        has_dynamic_offset: false,
88                        min_binding_size: NonZeroU64::new(mem::size_of::<GPUCullingInput>() as _),
89                    },
90                    count: None,
91                },
92                BindGroupLayoutEntry {
93                    binding: 1,
94                    visibility: ShaderStages::COMPUTE,
95                    ty: BindingType::Buffer {
96                        ty: BufferBindingType::Uniform,
97                        has_dynamic_offset: false,
98                        min_binding_size: NonZeroU64::new(mem::size_of::<GPUCullingUniforms>() as _),
99                    },
100                    count: None,
101                },
102                BindGroupLayoutEntry {
103                    binding: 2,
104                    visibility: ShaderStages::COMPUTE,
105                    ty: BindingType::Buffer {
106                        ty: BufferBindingType::Storage { read_only: false },
107                        has_dynamic_offset: false,
108                        min_binding_size: NonZeroU64::new(mem::size_of::<PerObjectData>() as _),
109                    },
110                    count: None,
111                },
112                BindGroupLayoutEntry {
113                    binding: 3,
114                    visibility: ShaderStages::COMPUTE,
115                    ty: BindingType::Buffer {
116                        ty: BufferBindingType::Storage { read_only: false },
117                        has_dynamic_offset: false,
118                        min_binding_size: NonZeroU64::new(16 + 20),
119                    },
120                    count: None,
121                },
122            ],
123        });
124
125        let prefix_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
126            label: Some("prefix culling pll"),
127            entries: &[
128                BindGroupLayoutEntry {
129                    binding: 0,
130                    visibility: ShaderStages::COMPUTE,
131                    ty: BindingType::Buffer {
132                        ty: BufferBindingType::Storage { read_only: true },
133                        has_dynamic_offset: false,
134                        min_binding_size: NonZeroU64::new(mem::size_of::<GPUCullingInput>() as _),
135                    },
136                    count: None,
137                },
138                BindGroupLayoutEntry {
139                    binding: 1,
140                    visibility: ShaderStages::COMPUTE,
141                    ty: BindingType::Buffer {
142                        ty: BufferBindingType::Uniform,
143                        has_dynamic_offset: false,
144                        min_binding_size: NonZeroU64::new(mem::size_of::<GPUCullingUniforms>() as _),
145                    },
146                    count: None,
147                },
148                BindGroupLayoutEntry {
149                    binding: 2,
150                    visibility: ShaderStages::COMPUTE,
151                    ty: BindingType::Buffer {
152                        ty: BufferBindingType::Storage { read_only: false },
153                        has_dynamic_offset: false,
154                        min_binding_size: NonZeroU64::new(mem::size_of::<u32>() as _),
155                    },
156                    count: None,
157                },
158                BindGroupLayoutEntry {
159                    binding: 3,
160                    visibility: ShaderStages::COMPUTE,
161                    ty: BindingType::Buffer {
162                        ty: BufferBindingType::Storage { read_only: false },
163                        has_dynamic_offset: false,
164                        min_binding_size: NonZeroU64::new(mem::size_of::<u32>() as _),
165                    },
166                    count: None,
167                },
168                BindGroupLayoutEntry {
169                    binding: 4,
170                    visibility: ShaderStages::COMPUTE,
171                    ty: BindingType::Buffer {
172                        ty: BufferBindingType::Storage { read_only: false },
173                        has_dynamic_offset: false,
174                        min_binding_size: NonZeroU64::new(mem::size_of::<PerObjectData>() as _),
175                    },
176                    count: None,
177                },
178                BindGroupLayoutEntry {
179                    binding: 5,
180                    visibility: ShaderStages::COMPUTE,
181                    ty: BindingType::Buffer {
182                        ty: BufferBindingType::Storage { read_only: false },
183                        has_dynamic_offset: false,
184                        min_binding_size: NonZeroU64::new(16 + 20),
185                    },
186                    count: None,
187                },
188            ],
189        });
190
191        let atomic_pll = device.create_pipeline_layout(&PipelineLayoutDescriptor {
192            label: Some("atomic culling pll"),
193            bind_group_layouts: &[&atomic_bgl],
194            push_constant_ranges: &[],
195        });
196
197        let prefix_pll = device.create_pipeline_layout(&PipelineLayoutDescriptor {
198            label: Some("prefix culling pll"),
199            bind_group_layouts: &[&prefix_bgl],
200            push_constant_ranges: &[],
201        });
202
203        let prefix_sum_pll = device.create_pipeline_layout(&PipelineLayoutDescriptor {
204            label: Some("prefix sum pll"),
205            bind_group_layouts: &[&prefix_bgl],
206            push_constant_ranges: &[PushConstantRange {
207                stages: ShaderStages::COMPUTE,
208                range: 0..4,
209            }],
210        });
211
212        let atomic_sm = unsafe {
213            device.create_shader_module_spirv(&ShaderModuleDescriptorSpirV {
214                label: Some("cull-atomic-cull"),
215                source: wgpu::util::make_spirv_raw(
216                    SPIRV_SHADERS.get_file("cull-atomic-cull.comp.spv").unwrap().contents(),
217                ),
218            })
219        };
220
221        let prefix_cull_sm = device.create_shader_module(&ShaderModuleDescriptor {
222            label: Some("cull-prefix-cull"),
223            source: wgpu::util::make_spirv(SPIRV_SHADERS.get_file("cull-prefix-cull.comp.spv").unwrap().contents()),
224        });
225
226        let prefix_sum_sm = device.create_shader_module(&ShaderModuleDescriptor {
227            label: Some("cull-prefix-sum"),
228            source: wgpu::util::make_spirv(SPIRV_SHADERS.get_file("cull-prefix-sum.comp.spv").unwrap().contents()),
229        });
230
231        let prefix_output_sm = device.create_shader_module(&ShaderModuleDescriptor {
232            label: Some("cull-prefix-output"),
233            source: wgpu::util::make_spirv(
234                SPIRV_SHADERS
235                    .get_file("cull-prefix-output.comp.spv")
236                    .unwrap()
237                    .contents(),
238            ),
239        });
240
241        let atomic_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
242            label: Some("atomic culling pl"),
243            layout: Some(&atomic_pll),
244            module: &atomic_sm,
245            entry_point: "main",
246        });
247
248        let prefix_cull_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
249            label: Some("prefix cull pl"),
250            layout: Some(&prefix_pll),
251            module: &prefix_cull_sm,
252            entry_point: "main",
253        });
254
255        let prefix_sum_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
256            label: Some("prefix sum pl"),
257            layout: Some(&prefix_sum_pll),
258            module: &prefix_sum_sm,
259            entry_point: "main",
260        });
261
262        let prefix_output_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
263            label: Some("prefix output pl"),
264            layout: Some(&prefix_pll),
265            module: &prefix_output_sm,
266            entry_point: "main",
267        });
268
269        Self {
270            atomic_bgl,
271            atomic_pipeline,
272            prefix_bgl,
273            prefix_cull_pipeline,
274            prefix_sum_pipeline,
275            prefix_output_pipeline,
276        }
277    }
278
279    pub fn pre_cull(&self, args: GpuCullerPreCullArgs<'_>) -> PreCulledBuffer {
280        let objects = args.objects.get_objects::<PbrMaterial>(args.transparency as u64);
281        let count = objects.len();
282
283        let objects = if let Some(sorting) = args.sort {
284            profiling::scope!("Sorting");
285
286            let mut sort_objects = objects.to_vec();
287
288            let camera_location = args.camera.location().into();
289
290            match sorting {
291                Sorting::FrontToBack => {
292                    sort_objects
293                        .sort_unstable_by_key(|o| OrderedFloat(o.mesh_location().distance_squared(camera_location)));
294                }
295                Sorting::BackToFront => {
296                    sort_objects
297                        .sort_unstable_by_key(|o| OrderedFloat(-o.mesh_location().distance_squared(camera_location)));
298                }
299            }
300
301            Cow::Owned(sort_objects)
302        } else {
303            Cow::Borrowed(objects)
304        };
305        let buffer = build_cull_data(args.device, &objects);
306
307        PreCulledBuffer { inner: buffer, count }
308    }
309
310    pub fn cull(&self, args: GpuCullerCullArgs<'_>) -> CulledObjectSet {
311        profiling::scope!("Record GPU Culling");
312
313        let count = args.input_buffer.count;
314
315        let uniform = GPUCullingUniforms {
316            view: args.camera.view(),
317            view_proj: args.camera.view_proj(),
318            frustum: ShaderFrustum::from_matrix(args.camera.proj()),
319            object_count: count as u32,
320        };
321
322        let uniform_buffer = args.device.create_buffer_init(&BufferInitDescriptor {
323            label: Some("gpu culling uniform buffer"),
324            contents: bytemuck::bytes_of(&uniform),
325            usage: BufferUsages::UNIFORM,
326        });
327
328        let output_buffer = args.device.create_buffer(&BufferDescriptor {
329            label: Some("culling output"),
330            size: (count.max(1) * mem::size_of::<PerObjectData>()) as _,
331            usage: BufferUsages::STORAGE,
332            mapped_at_creation: false,
333        });
334
335        let indirect_buffer = args.device.create_buffer(&BufferDescriptor {
336            label: Some("indirect buffer"),
337            // 16 bytes for count, the rest for the indirect count
338            size: (count * 20 + 16) as _,
339            usage: BufferUsages::STORAGE | BufferUsages::INDIRECT | BufferUsages::VERTEX,
340            mapped_at_creation: false,
341        });
342
343        if count != 0 {
344            let dispatch_count = ((count + 255) / 256) as u32;
345
346            if args.sort.is_some() {
347                let buffer_a = args.device.create_buffer(&BufferDescriptor {
348                    label: Some("cull result index buffer A"),
349                    size: (count * 4) as _,
350                    usage: BufferUsages::STORAGE,
351                    mapped_at_creation: false,
352                });
353
354                let buffer_b = args.device.create_buffer(&BufferDescriptor {
355                    label: Some("cull result index buffer B"),
356                    size: (count * 4) as _,
357                    usage: BufferUsages::STORAGE,
358                    mapped_at_creation: false,
359                });
360
361                let bg_a = BindGroupBuilder::new(Some("prefix cull A bg"))
362                    .with_buffer(&args.input_buffer.inner)
363                    .with_buffer(&uniform_buffer)
364                    .with_buffer(&buffer_a)
365                    .with_buffer(&buffer_b)
366                    .with_buffer(&output_buffer)
367                    .with_buffer(&indirect_buffer)
368                    .build(args.device, &self.prefix_bgl);
369
370                let bg_b = BindGroupBuilder::new(Some("prefix cull B bg"))
371                    .with_buffer(&args.input_buffer.inner)
372                    .with_buffer(&uniform_buffer)
373                    .with_buffer(&buffer_b)
374                    .with_buffer(&buffer_a)
375                    .with_buffer(&output_buffer)
376                    .with_buffer(&indirect_buffer)
377                    .build(args.device, &self.prefix_bgl);
378
379                let mut cpass = args.encoder.begin_compute_pass(&ComputePassDescriptor {
380                    label: Some("prefix cull"),
381                });
382
383                cpass.set_pipeline(&self.prefix_cull_pipeline);
384                cpass.set_bind_group(0, &bg_a, &[]);
385                cpass.dispatch(dispatch_count, 1, 1);
386
387                cpass.set_pipeline(&self.prefix_sum_pipeline);
388                let mut stride = 1_u32;
389                let mut iteration = 0;
390                while stride < count as u32 {
391                    let bind_group = if iteration % 2 == 0 { &bg_a } else { &bg_b };
392
393                    cpass.set_push_constants(0, bytemuck::cast_slice(&[stride]));
394                    cpass.set_bind_group(0, bind_group, &[]);
395                    cpass.dispatch(dispatch_count, 1, 1);
396                    stride <<= 1;
397                    iteration += 1;
398                }
399
400                let bind_group = if iteration % 2 == 0 { &bg_a } else { &bg_b };
401                cpass.set_pipeline(&self.prefix_output_pipeline);
402                cpass.set_bind_group(0, bind_group, &[]);
403                cpass.dispatch(dispatch_count, 1, 1);
404            } else {
405                let bg = BindGroupBuilder::new(Some("atomic culling bg"))
406                    .with_buffer(&args.input_buffer.inner)
407                    .with_buffer(&uniform_buffer)
408                    .with_buffer(&output_buffer)
409                    .with_buffer(&indirect_buffer)
410                    .build(args.device, &self.atomic_bgl);
411
412                let mut cpass = args.encoder.begin_compute_pass(&ComputePassDescriptor {
413                    label: Some("atomic cull"),
414                });
415
416                cpass.set_pipeline(&self.atomic_pipeline);
417                cpass.set_bind_group(0, &bg, &[]);
418                cpass.dispatch(dispatch_count, 1, 1);
419
420                drop(cpass);
421            }
422        }
423
424        let output_bg = args.device.create_bind_group(&BindGroupDescriptor {
425            label: Some("culling input bg"),
426            layout: &args.interfaces.culled_object_bgl,
427            entries: &[BindGroupEntry {
428                binding: 0,
429                resource: output_buffer.as_entire_binding(),
430            }],
431        });
432
433        CulledObjectSet {
434            calls: ModeData::GPU(GPUIndirectData { indirect_buffer, count }),
435            output_bg,
436        }
437    }
438}
439
440fn build_cull_data(device: &Device, objects: &[InternalObject]) -> Buffer {
441    profiling::scope!("Building Input Data");
442
443    let total_length = objects.len() * mem::size_of::<GPUCullingInput>();
444
445    let buffer = device.create_buffer(&BufferDescriptor {
446        label: Some("culling inputs"),
447        size: total_length as u64,
448        usage: BufferUsages::STORAGE,
449        mapped_at_creation: true,
450    });
451
452    let mut data = buffer.slice(..).get_mapped_range_mut();
453
454    // This unsafe block measured a bit faster in my tests, and as this is basically _the_ hot path, so this is worthwhile.
455    unsafe {
456        let data_ptr = data.as_mut_ptr() as *mut GPUCullingInput;
457
458        // Iterate over the objects
459        for idx in 0..objects.len() {
460            // We're iterating over 0..len so this is never going to be out of bounds
461            let object = objects.get_unchecked(idx);
462
463            // This is aligned, and we know the vector has enough bytes to hold this, so this is safe
464            data_ptr.add(idx).write_unaligned(object.input);
465        }
466    }
467
468    drop(data);
469    buffer.unmap();
470
471    buffer
472}
473
474pub fn run<'rpass>(rpass: &mut RenderPass<'rpass>, indirect_data: &'rpass GPUIndirectData) {
475    if indirect_data.count != 0 {
476        rpass.set_vertex_buffer(7, indirect_data.indirect_buffer.slice(16..));
477        rpass.multi_draw_indexed_indirect_count(
478            &indirect_data.indirect_buffer,
479            16,
480            &indirect_data.indirect_buffer,
481            0,
482            indirect_data.count as _,
483        );
484    }
485}