1use std::{borrow::Cow, mem, num::NonZeroU64};
2
3use glam::Mat4;
4use ordered_float::OrderedFloat;
5use rend3::{
6 resources::{CameraManager, GPUCullingInput, InternalObject, ObjectManager},
7 util::{bind_merge::BindGroupBuilder, frustum::ShaderFrustum},
8 ModeData,
9};
10use wgpu::{
11 util::{BufferInitDescriptor, DeviceExt},
12 BindGroupDescriptor, BindGroupEntry, BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType,
13 Buffer, BufferBindingType, BufferDescriptor, BufferUsages, CommandEncoder, ComputePassDescriptor, ComputePipeline,
14 ComputePipelineDescriptor, Device, PipelineLayoutDescriptor, PushConstantRange, RenderPass, ShaderModuleDescriptor,
15 ShaderModuleDescriptorSpirV, ShaderStages,
16};
17
18use crate::{
19 common::interfaces::{PerObjectData, ShaderInterfaces},
20 culling::{CulledObjectSet, GPUIndirectData, Sorting},
21 material::{PbrMaterial, TransparencyType},
22 shaders::SPIRV_SHADERS,
23};
24
/// Output of [`GpuCuller::pre_cull`]: a GPU buffer holding one culling input
/// per object, plus the object count, ready to be consumed by [`GpuCuller::cull`].
pub struct PreCulledBuffer {
    // STORAGE buffer of `GPUCullingInput` records built by `build_cull_data`.
    inner: Buffer,
    // Number of objects serialized into `inner`.
    count: usize,
}
29
/// Uniform data consumed by the culling compute shaders.
///
/// `repr(C, align(16))` pins a C-compatible, 16-byte-aligned layout for the
/// GPU-side uniform block.
#[repr(C, align(16))]
#[derive(Debug, Copy, Clone)]
struct GPUCullingUniforms {
    view: Mat4,
    view_proj: Mat4,
    frustum: ShaderFrustum,
    object_count: u32,
}

// SAFETY: NOTE(review): with `align(16)` and a trailing `u32` field this
// struct very likely ends in padding bytes, which technically violates the
// `bytemuck::Pod` contract (Pod requires no padding). The uploaded bytes
// would include uninitialized padding. Consider an explicit trailing
// `_padding` field — confirm `ShaderFrustum`'s size and the shader-side
// layout before changing anything GPU-visible.
unsafe impl bytemuck::Pod for GPUCullingUniforms {}
unsafe impl bytemuck::Zeroable for GPUCullingUniforms {}
41
/// Arguments for [`GpuCuller::pre_cull`].
pub struct GpuCullerPreCullArgs<'a> {
    pub device: &'a Device,

    // Camera whose location is the reference point when `sort` is `Some`.
    pub camera: &'a CameraManager,

    pub objects: &'a mut ObjectManager,

    // Selects which transparency bucket of objects to fetch.
    pub transparency: TransparencyType,
    // Optional distance sort applied before uploading the culling inputs.
    pub sort: Option<Sorting>,
}
52
/// Arguments for [`GpuCuller::cull`].
pub struct GpuCullerCullArgs<'a> {
    pub device: &'a Device,
    // Encoder the culling compute passes are recorded into.
    pub encoder: &'a mut CommandEncoder,

    // Provides `culled_object_bgl`, the layout of the output bind group.
    pub interfaces: &'a ShaderInterfaces,

    pub camera: &'a CameraManager,

    // Buffer previously produced by `pre_cull`.
    pub input_buffer: &'a PreCulledBuffer,

    // `Some` selects the order-preserving prefix-sum path; `None` the atomic path.
    pub sort: Option<Sorting>,
}
65
/// GPU-driven culler: records compute passes that cull objects on the GPU and
/// build an indirect draw buffer.
///
/// Two code paths exist: an "atomic" path (single dispatch, output order
/// unspecified) and a "prefix" path (prefix-sum based, preserves the input
/// object order — required when the objects were pre-sorted).
pub struct GpuCuller {
    atomic_bgl: BindGroupLayout,
    atomic_pipeline: ComputePipeline,

    prefix_bgl: BindGroupLayout,
    prefix_cull_pipeline: ComputePipeline,
    prefix_sum_pipeline: ComputePipeline,
    prefix_output_pipeline: ComputePipeline,
}
75impl GpuCuller {
76 pub fn new(device: &Device) -> Self {
77 profiling::scope!("GpuCuller::new");
78
79 let atomic_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
80 label: Some("atomic culling pll"),
81 entries: &[
82 BindGroupLayoutEntry {
83 binding: 0,
84 visibility: ShaderStages::COMPUTE,
85 ty: BindingType::Buffer {
86 ty: BufferBindingType::Storage { read_only: true },
87 has_dynamic_offset: false,
88 min_binding_size: NonZeroU64::new(mem::size_of::<GPUCullingInput>() as _),
89 },
90 count: None,
91 },
92 BindGroupLayoutEntry {
93 binding: 1,
94 visibility: ShaderStages::COMPUTE,
95 ty: BindingType::Buffer {
96 ty: BufferBindingType::Uniform,
97 has_dynamic_offset: false,
98 min_binding_size: NonZeroU64::new(mem::size_of::<GPUCullingUniforms>() as _),
99 },
100 count: None,
101 },
102 BindGroupLayoutEntry {
103 binding: 2,
104 visibility: ShaderStages::COMPUTE,
105 ty: BindingType::Buffer {
106 ty: BufferBindingType::Storage { read_only: false },
107 has_dynamic_offset: false,
108 min_binding_size: NonZeroU64::new(mem::size_of::<PerObjectData>() as _),
109 },
110 count: None,
111 },
112 BindGroupLayoutEntry {
113 binding: 3,
114 visibility: ShaderStages::COMPUTE,
115 ty: BindingType::Buffer {
116 ty: BufferBindingType::Storage { read_only: false },
117 has_dynamic_offset: false,
118 min_binding_size: NonZeroU64::new(16 + 20),
119 },
120 count: None,
121 },
122 ],
123 });
124
125 let prefix_bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
126 label: Some("prefix culling pll"),
127 entries: &[
128 BindGroupLayoutEntry {
129 binding: 0,
130 visibility: ShaderStages::COMPUTE,
131 ty: BindingType::Buffer {
132 ty: BufferBindingType::Storage { read_only: true },
133 has_dynamic_offset: false,
134 min_binding_size: NonZeroU64::new(mem::size_of::<GPUCullingInput>() as _),
135 },
136 count: None,
137 },
138 BindGroupLayoutEntry {
139 binding: 1,
140 visibility: ShaderStages::COMPUTE,
141 ty: BindingType::Buffer {
142 ty: BufferBindingType::Uniform,
143 has_dynamic_offset: false,
144 min_binding_size: NonZeroU64::new(mem::size_of::<GPUCullingUniforms>() as _),
145 },
146 count: None,
147 },
148 BindGroupLayoutEntry {
149 binding: 2,
150 visibility: ShaderStages::COMPUTE,
151 ty: BindingType::Buffer {
152 ty: BufferBindingType::Storage { read_only: false },
153 has_dynamic_offset: false,
154 min_binding_size: NonZeroU64::new(mem::size_of::<u32>() as _),
155 },
156 count: None,
157 },
158 BindGroupLayoutEntry {
159 binding: 3,
160 visibility: ShaderStages::COMPUTE,
161 ty: BindingType::Buffer {
162 ty: BufferBindingType::Storage { read_only: false },
163 has_dynamic_offset: false,
164 min_binding_size: NonZeroU64::new(mem::size_of::<u32>() as _),
165 },
166 count: None,
167 },
168 BindGroupLayoutEntry {
169 binding: 4,
170 visibility: ShaderStages::COMPUTE,
171 ty: BindingType::Buffer {
172 ty: BufferBindingType::Storage { read_only: false },
173 has_dynamic_offset: false,
174 min_binding_size: NonZeroU64::new(mem::size_of::<PerObjectData>() as _),
175 },
176 count: None,
177 },
178 BindGroupLayoutEntry {
179 binding: 5,
180 visibility: ShaderStages::COMPUTE,
181 ty: BindingType::Buffer {
182 ty: BufferBindingType::Storage { read_only: false },
183 has_dynamic_offset: false,
184 min_binding_size: NonZeroU64::new(16 + 20),
185 },
186 count: None,
187 },
188 ],
189 });
190
191 let atomic_pll = device.create_pipeline_layout(&PipelineLayoutDescriptor {
192 label: Some("atomic culling pll"),
193 bind_group_layouts: &[&atomic_bgl],
194 push_constant_ranges: &[],
195 });
196
197 let prefix_pll = device.create_pipeline_layout(&PipelineLayoutDescriptor {
198 label: Some("prefix culling pll"),
199 bind_group_layouts: &[&prefix_bgl],
200 push_constant_ranges: &[],
201 });
202
203 let prefix_sum_pll = device.create_pipeline_layout(&PipelineLayoutDescriptor {
204 label: Some("prefix sum pll"),
205 bind_group_layouts: &[&prefix_bgl],
206 push_constant_ranges: &[PushConstantRange {
207 stages: ShaderStages::COMPUTE,
208 range: 0..4,
209 }],
210 });
211
212 let atomic_sm = unsafe {
213 device.create_shader_module_spirv(&ShaderModuleDescriptorSpirV {
214 label: Some("cull-atomic-cull"),
215 source: wgpu::util::make_spirv_raw(
216 SPIRV_SHADERS.get_file("cull-atomic-cull.comp.spv").unwrap().contents(),
217 ),
218 })
219 };
220
221 let prefix_cull_sm = device.create_shader_module(&ShaderModuleDescriptor {
222 label: Some("cull-prefix-cull"),
223 source: wgpu::util::make_spirv(SPIRV_SHADERS.get_file("cull-prefix-cull.comp.spv").unwrap().contents()),
224 });
225
226 let prefix_sum_sm = device.create_shader_module(&ShaderModuleDescriptor {
227 label: Some("cull-prefix-sum"),
228 source: wgpu::util::make_spirv(SPIRV_SHADERS.get_file("cull-prefix-sum.comp.spv").unwrap().contents()),
229 });
230
231 let prefix_output_sm = device.create_shader_module(&ShaderModuleDescriptor {
232 label: Some("cull-prefix-output"),
233 source: wgpu::util::make_spirv(
234 SPIRV_SHADERS
235 .get_file("cull-prefix-output.comp.spv")
236 .unwrap()
237 .contents(),
238 ),
239 });
240
241 let atomic_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
242 label: Some("atomic culling pl"),
243 layout: Some(&atomic_pll),
244 module: &atomic_sm,
245 entry_point: "main",
246 });
247
248 let prefix_cull_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
249 label: Some("prefix cull pl"),
250 layout: Some(&prefix_pll),
251 module: &prefix_cull_sm,
252 entry_point: "main",
253 });
254
255 let prefix_sum_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
256 label: Some("prefix sum pl"),
257 layout: Some(&prefix_sum_pll),
258 module: &prefix_sum_sm,
259 entry_point: "main",
260 });
261
262 let prefix_output_pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
263 label: Some("prefix output pl"),
264 layout: Some(&prefix_pll),
265 module: &prefix_output_sm,
266 entry_point: "main",
267 });
268
269 Self {
270 atomic_bgl,
271 atomic_pipeline,
272 prefix_bgl,
273 prefix_cull_pipeline,
274 prefix_sum_pipeline,
275 prefix_output_pipeline,
276 }
277 }
278
279 pub fn pre_cull(&self, args: GpuCullerPreCullArgs<'_>) -> PreCulledBuffer {
280 let objects = args.objects.get_objects::<PbrMaterial>(args.transparency as u64);
281 let count = objects.len();
282
283 let objects = if let Some(sorting) = args.sort {
284 profiling::scope!("Sorting");
285
286 let mut sort_objects = objects.to_vec();
287
288 let camera_location = args.camera.location().into();
289
290 match sorting {
291 Sorting::FrontToBack => {
292 sort_objects
293 .sort_unstable_by_key(|o| OrderedFloat(o.mesh_location().distance_squared(camera_location)));
294 }
295 Sorting::BackToFront => {
296 sort_objects
297 .sort_unstable_by_key(|o| OrderedFloat(-o.mesh_location().distance_squared(camera_location)));
298 }
299 }
300
301 Cow::Owned(sort_objects)
302 } else {
303 Cow::Borrowed(objects)
304 };
305 let buffer = build_cull_data(args.device, &objects);
306
307 PreCulledBuffer { inner: buffer, count }
308 }
309
    /// Records the GPU culling compute passes into `args.encoder` and returns
    /// the per-object output bind group plus the indirect draw data.
    ///
    /// When `args.sort` is `Some`, the order-preserving prefix-sum path is
    /// used; otherwise a single atomic-compaction dispatch runs.
    pub fn cull(&self, args: GpuCullerCullArgs<'_>) -> CulledObjectSet {
        profiling::scope!("Record GPU Culling");

        let count = args.input_buffer.count;

        let uniform = GPUCullingUniforms {
            view: args.camera.view(),
            view_proj: args.camera.view_proj(),
            frustum: ShaderFrustum::from_matrix(args.camera.proj()),
            object_count: count as u32,
        };

        let uniform_buffer = args.device.create_buffer_init(&BufferInitDescriptor {
            label: Some("gpu culling uniform buffer"),
            contents: bytemuck::bytes_of(&uniform),
            usage: BufferUsages::UNIFORM,
        });

        // `max(1)` keeps the buffer non-empty when there are zero objects, so
        // the output bind group below is always valid.
        let output_buffer = args.device.create_buffer(&BufferDescriptor {
            label: Some("culling output"),
            size: (count.max(1) * mem::size_of::<PerObjectData>()) as _,
            usage: BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        // 16-byte count header followed by one 20-byte indexed-indirect draw
        // command per object; also bound as a vertex buffer in `run`.
        let indirect_buffer = args.device.create_buffer(&BufferDescriptor {
            label: Some("indirect buffer"),
            size: (count * 20 + 16) as _,
            usage: BufferUsages::STORAGE | BufferUsages::INDIRECT | BufferUsages::VERTEX,
            mapped_at_creation: false,
        });

        if count != 0 {
            // Round up to whole workgroups; the /256 rounding implies a
            // 256-wide shader workgroup — confirm against the shader source.
            let dispatch_count = ((count + 255) / 256) as u32;

            if args.sort.is_some() {
                // Ping-pong index buffers for the stride-doubling prefix sum.
                let buffer_a = args.device.create_buffer(&BufferDescriptor {
                    label: Some("cull result index buffer A"),
                    size: (count * 4) as _,
                    usage: BufferUsages::STORAGE,
                    mapped_at_creation: false,
                });

                let buffer_b = args.device.create_buffer(&BufferDescriptor {
                    label: Some("cull result index buffer B"),
                    size: (count * 4) as _,
                    usage: BufferUsages::STORAGE,
                    mapped_at_creation: false,
                });

                // A and B differ only in the order of the two index buffers
                // (bindings 2 and 3), so alternating bind groups swaps the
                // prefix-sum read/write roles each iteration.
                let bg_a = BindGroupBuilder::new(Some("prefix cull A bg"))
                    .with_buffer(&args.input_buffer.inner)
                    .with_buffer(&uniform_buffer)
                    .with_buffer(&buffer_a)
                    .with_buffer(&buffer_b)
                    .with_buffer(&output_buffer)
                    .with_buffer(&indirect_buffer)
                    .build(args.device, &self.prefix_bgl);

                let bg_b = BindGroupBuilder::new(Some("prefix cull B bg"))
                    .with_buffer(&args.input_buffer.inner)
                    .with_buffer(&uniform_buffer)
                    .with_buffer(&buffer_b)
                    .with_buffer(&buffer_a)
                    .with_buffer(&output_buffer)
                    .with_buffer(&indirect_buffer)
                    .build(args.device, &self.prefix_bgl);

                let mut cpass = args.encoder.begin_compute_pass(&ComputePassDescriptor {
                    label: Some("prefix cull"),
                });

                // Pass 1: frustum test, writing visibility flags.
                cpass.set_pipeline(&self.prefix_cull_pipeline);
                cpass.set_bind_group(0, &bg_a, &[]);
                cpass.dispatch(dispatch_count, 1, 1);

                // Pass 2: stride-doubling prefix sum over the flags,
                // ping-ponging between the A and B bind groups each step.
                cpass.set_pipeline(&self.prefix_sum_pipeline);
                let mut stride = 1_u32;
                let mut iteration = 0;
                while stride < count as u32 {
                    let bind_group = if iteration % 2 == 0 { &bg_a } else { &bg_b };

                    cpass.set_push_constants(0, bytemuck::cast_slice(&[stride]));
                    cpass.set_bind_group(0, bind_group, &[]);
                    cpass.dispatch(dispatch_count, 1, 1);
                    stride <<= 1;
                    iteration += 1;
                }

                // Pass 3: scatter outputs using whichever buffer holds the
                // final prefix sums (depends on the iteration parity).
                let bind_group = if iteration % 2 == 0 { &bg_a } else { &bg_b };
                cpass.set_pipeline(&self.prefix_output_pipeline);
                cpass.set_bind_group(0, bind_group, &[]);
                cpass.dispatch(dispatch_count, 1, 1);
            } else {
                // Unsorted path: one dispatch compacting survivors atomically.
                let bg = BindGroupBuilder::new(Some("atomic culling bg"))
                    .with_buffer(&args.input_buffer.inner)
                    .with_buffer(&uniform_buffer)
                    .with_buffer(&output_buffer)
                    .with_buffer(&indirect_buffer)
                    .build(args.device, &self.atomic_bgl);

                let mut cpass = args.encoder.begin_compute_pass(&ComputePassDescriptor {
                    label: Some("atomic cull"),
                });

                cpass.set_pipeline(&self.atomic_pipeline);
                cpass.set_bind_group(0, &bg, &[]);
                cpass.dispatch(dispatch_count, 1, 1);

                // End the compute pass before the encoder is used again.
                drop(cpass);
            }
        }

        let output_bg = args.device.create_bind_group(&BindGroupDescriptor {
            label: Some("culling input bg"),
            layout: &args.interfaces.culled_object_bgl,
            entries: &[BindGroupEntry {
                binding: 0,
                resource: output_buffer.as_entire_binding(),
            }],
        });

        CulledObjectSet {
            calls: ModeData::GPU(GPUIndirectData { indirect_buffer, count }),
            output_bg,
        }
    }
438}
439
440fn build_cull_data(device: &Device, objects: &[InternalObject]) -> Buffer {
441 profiling::scope!("Building Input Data");
442
443 let total_length = objects.len() * mem::size_of::<GPUCullingInput>();
444
445 let buffer = device.create_buffer(&BufferDescriptor {
446 label: Some("culling inputs"),
447 size: total_length as u64,
448 usage: BufferUsages::STORAGE,
449 mapped_at_creation: true,
450 });
451
452 let mut data = buffer.slice(..).get_mapped_range_mut();
453
454 unsafe {
456 let data_ptr = data.as_mut_ptr() as *mut GPUCullingInput;
457
458 for idx in 0..objects.len() {
460 let object = objects.get_unchecked(idx);
462
463 data_ptr.add(idx).write_unaligned(object.input);
465 }
466 }
467
468 drop(data);
469 buffer.unmap();
470
471 buffer
472}
473
474pub fn run<'rpass>(rpass: &mut RenderPass<'rpass>, indirect_data: &'rpass GPUIndirectData) {
475 if indirect_data.count != 0 {
476 rpass.set_vertex_buffer(7, indirect_data.indirect_buffer.slice(16..));
477 rpass.multi_draw_indexed_indirect_count(
478 &indirect_data.indirect_buffer,
479 16,
480 &indirect_data.indirect_buffer,
481 0,
482 indirect_data.count as _,
483 );
484 }
485}