1use std::sync::Arc;
2
3use ambient_core::gpu_ecs::{GpuWorldShaderModuleKey, ENTITIES_BIND_GROUP};
4use ambient_ecs::{EntityId, World};
5use ambient_gpu::{
6    gpu::{Gpu, GpuKey},
7    multi_buffer::TypedMultiBuffer,
8    shader_module::{BindGroupDesc, ComputePipeline, Shader, ShaderIdent, ShaderModule},
9    typed_buffer::TypedBuffer,
10};
11use ambient_std::{
12    asset_cache::{AssetCache, SyncAssetKey, SyncAssetKeyExt},
13    include_file,
14};
15use glam::{uvec2, UVec2, UVec3};
16use parking_lot::Mutex;
17use wgpu::{
18    BindGroupEntry, BindGroupLayout, BindGroupLayoutEntry, BindingType, BufferBindingType,
19    ShaderStages,
20};
21
22use crate::{get_mesh_meta_module, GLOBALS_BIND_GROUP};
23
24use super::{get_defs_module, DrawIndexedIndirect, PrimitiveIndex};
25
/// One input record for the collect compute pass (bound as storage at binding 1 by
/// `RendererCollect::run`).
///
/// `#[repr(C)]` + `Pod`: the struct is uploaded to the GPU as raw bytes, so field order and types
/// must stay in sync with the matching struct in `collect.wgsl` (NOTE(review): confirm there).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, bytemuck::Pod, bytemuck::Zeroable)]
pub struct CollectPrimitive {
    /// `(archetype, index)` of the owning entity, as filled in by `from_primitive`.
    entity_loc: UVec2,
    /// Primitive index of the entity, narrowed to `u32`.
    primitive_index: u32,
    /// Material slot for this primitive; presumably indexes `material_layouts` / `counts`
    /// (TODO confirm against `collect.wgsl`).
    material_index: u32,
}
33
34impl CollectPrimitive {
35    pub fn from_primitive(
36        world: &World,
37        id: EntityId,
38        primitive_index: PrimitiveIndex,
39        material_index: u32,
40    ) -> Self {
41        let loc = world.entity_loc(id).unwrap();
42        Self {
43            entity_loc: uvec2(loc.archetype as u32, loc.index as u32),
44            primitive_index: primitive_index as u32,
45            material_index,
46        }
47    }
48}
49
50pub struct RendererCollectState {
51    pub params: TypedBuffer<RendererCollectParams>,
52    pub commands: TypedBuffer<DrawIndexedIndirect>,
53    pub counts: TypedBuffer<u32>,
54    #[cfg(target_os = "macos")]
55    pub counts_cpu: Arc<Mutex<Vec<u32>>>,
56    pub material_layouts: TypedBuffer<UVec2>,
57}
58impl RendererCollectState {
59    pub fn new(assets: &AssetCache) -> Self {
60        log::debug!("Setting up renderer collect state");
61        let gpu = GpuKey.get(assets);
62        Self {
63            params: TypedBuffer::new(
64                gpu.clone(),
65                "RendererCollectState.params",
66                1,
67                1,
68                wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
69            ),
70            commands: TypedBuffer::new(
71                gpu.clone(),
72                "RendererCollectState.commands",
73                1,
74                1,
75                wgpu::BufferUsages::STORAGE
76                    | wgpu::BufferUsages::COPY_DST
77                    | wgpu::BufferUsages::COPY_SRC
78                    | wgpu::BufferUsages::INDIRECT,
79            ),
80            counts: TypedBuffer::new(
81                gpu.clone(),
82                "RendererCollectState.counts",
83                1,
84                1,
85                wgpu::BufferUsages::STORAGE
86                    | wgpu::BufferUsages::COPY_DST
87                    | wgpu::BufferUsages::COPY_SRC
88                    | wgpu::BufferUsages::INDIRECT,
89            ),
90            #[cfg(target_os = "macos")]
91            counts_cpu: Arc::new(Mutex::new(Vec::new())),
92            material_layouts: TypedBuffer::new(
93                gpu,
94                "RendererCollectState.materials",
95                1,
96                1,
97                wgpu::BufferUsages::STORAGE
98                    | wgpu::BufferUsages::COPY_DST
99                    | wgpu::BufferUsages::COPY_SRC
100                    | wgpu::BufferUsages::INDIRECT,
101            ),
102        }
103    }
104    pub fn set_camera(&self, camera: u32) {
105        let collect_params = RendererCollectParams {
106            camera,
107            _padding: Default::default(),
108        };
109        self.params.write(0, &[collect_params]);
110    }
111}
112
/// Uniform parameters of the collect pass (binding 0 of `RendererCollect.layout`).
///
/// `#[repr(C)]` + `Pod`: uploaded as raw bytes. `u32` + `UVec3` makes the struct 16 bytes;
/// the padding is presumably there to satisfy uniform-buffer size/alignment rules — keep in sync
/// with `collect.wgsl` (TODO confirm).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, bytemuck::Pod, bytemuck::Zeroable)]
pub struct RendererCollectParams {
    /// Camera index, written by `RendererCollectState::set_camera`.
    pub camera: u32,
    /// Explicit padding; always written as zero.
    pub _padding: UVec3,
}
119
/// Primitives handled per workgroup: injected into `collect.wgsl` as `COLLECT_WORKGROUP_SIZE`,
/// and `run` divides `primitives_count` by it to size the dispatch.
const COLLECT_WORKGROUP_SIZE: u32 = 32;
/// Maximum workgroups per dispatch row (x); injected into `collect.wgsl` as `COLLECT_CHUNK_SIZE`.
/// Workgroups beyond this spill into additional rows (y).
const COLLECT_CHUNK_SIZE: u32 = 256;
122
/// Compute pass that consumes [`CollectPrimitive`]s and writes draw commands and per-material
/// counts into a [`RendererCollectState`].
// `assets` is only read by the macOS-only count read-back in `run`, hence the allow.
#[allow(dead_code)]
pub struct RendererCollect {
    gpu: Arc<Gpu>,
    /// Compiled `collect.wgsl` compute pipeline (entry point `main`).
    pipeline: ComputePipeline,
    /// Layout of the pass's own bind group (`RendererCollect.layout`, bound at index 2).
    layout: Arc<BindGroupLayout>,
    assets: AssetCache,
}
131
132impl RendererCollect {
133    pub fn new(assets: &AssetCache) -> Self {
134        let gpu = GpuKey.get(assets);
135
136        let layout_desc = BindGroupDesc {
137            label: "RendererCollect.layout".into(),
138            entries: vec![
139                BindGroupLayoutEntry {
140                    binding: 0,
141                    visibility: ShaderStages::COMPUTE,
142                    ty: BindingType::Buffer {
143                        ty: BufferBindingType::Uniform,
144                        has_dynamic_offset: false,
145                        min_binding_size: None,
146                    },
147                    count: None,
148                },
149                BindGroupLayoutEntry {
150                    binding: 1,
151                    visibility: ShaderStages::COMPUTE,
152                    ty: BindingType::Buffer {
153                        ty: BufferBindingType::Storage { read_only: true },
154                        has_dynamic_offset: false,
155                        min_binding_size: None,
156                    },
157                    count: None,
158                },
159                BindGroupLayoutEntry {
160                    binding: 2,
161                    visibility: ShaderStages::COMPUTE,
162                    ty: BindingType::Buffer {
163                        ty: BufferBindingType::Storage { read_only: false },
164                        has_dynamic_offset: false,
165                        min_binding_size: None,
166                    },
167                    count: None,
168                },
169                BindGroupLayoutEntry {
170                    binding: 3,
171                    visibility: ShaderStages::COMPUTE,
172                    ty: BindingType::Buffer {
173                        ty: BufferBindingType::Storage { read_only: false },
174                        has_dynamic_offset: false,
175                        min_binding_size: None,
176                    },
177                    count: None,
178                },
179                BindGroupLayoutEntry {
180                    binding: 4,
181                    visibility: ShaderStages::COMPUTE,
182                    ty: BindingType::Buffer {
183                        ty: BufferBindingType::Storage { read_only: true },
184                        has_dynamic_offset: false,
185                        min_binding_size: None,
186                    },
187                    count: None,
188                },
189            ],
190        };
191
192        let layout = layout_desc.load(assets.clone());
193        let shader = Shader::new(
194            assets,
195            "collect",
196            &[
197                GLOBALS_BIND_GROUP,
198                ENTITIES_BIND_GROUP,
199                "RendererCollect.layout",
200            ],
201            &ShaderModule::new("RendererCollect", include_file!("collect.wgsl"))
202                .with_ident(ShaderIdent::constant(
203                    "COLLECT_WORKGROUP_SIZE",
204                    COLLECT_WORKGROUP_SIZE,
205                ))
206                .with_ident(ShaderIdent::constant(
207                    "COLLECT_CHUNK_SIZE",
208                    COLLECT_CHUNK_SIZE,
209                ))
210                .with_binding_desc(layout_desc)
211                .with_dependency(get_defs_module())
212                .with_dependency(get_mesh_meta_module(0))
213                .with_dependency(GpuWorldShaderModuleKey { read_only: true }.get(assets)),
214        )
215        .unwrap();
216
217        let pipeline = shader.to_compute_pipeline(&gpu, "main");
218
219        Self {
220            gpu,
221            pipeline,
222            layout,
223            assets: assets.clone(),
224        }
225    }
226
227    #[allow(clippy::too_many_arguments)]
228    #[allow(clippy::ptr_arg)]
229    #[ambient_profiling::function]
230    pub fn run(
231        &self,
232        encoder: &mut wgpu::CommandEncoder,
233        _post_submit: &mut Vec<Box<dyn FnOnce() + Send + Send>>,
234        mesh_meta_bind_group: &wgpu::BindGroup,
235        entities_bind_group: &wgpu::BindGroup,
236        input_primitives: &TypedMultiBuffer<CollectPrimitive>,
237        output: &mut RendererCollectState,
238        primitives_count: u32,
239        material_layouts: Vec<UVec2>,
240    ) {
241        if primitives_count == 0 {
242            return;
243        }
244
245        output.commands.resize(primitives_count as u64, true);
246        let counts = vec![0; material_layouts.len()];
247        output.counts.fill(&counts, |_| {});
248        output.material_layouts.fill(&material_layouts, |_| {});
249
250        let bind_group = self
251            .gpu
252            .device
253            .create_bind_group(&wgpu::BindGroupDescriptor {
254                label: None,
255                layout: &self.layout,
256                entries: &[
257                    BindGroupEntry {
258                        binding: 0,
259                        resource: output.params.buffer().as_entire_binding(),
260                    },
261                    BindGroupEntry {
262                        binding: 1,
263                        resource: input_primitives.buffer().as_entire_binding(),
264                    },
265                    BindGroupEntry {
266                        binding: 2,
267                        resource: output.commands.buffer().as_entire_binding(),
268                    },
269                    BindGroupEntry {
270                        binding: 3,
271                        resource: output.counts.buffer().as_entire_binding(),
272                    },
273                    BindGroupEntry {
274                        binding: 4,
275                        resource: output.material_layouts.buffer().as_entire_binding(),
276                    },
277                ],
278            });
279
280        {
281            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
282                label: Some("Collect"),
283            });
284            cpass.set_pipeline(self.pipeline.pipeline());
285
286            for (i, bind_group) in [mesh_meta_bind_group, entities_bind_group, &bind_group]
287                .iter()
288                .enumerate()
289            {
290                cpass.set_bind_group(i as _, bind_group, &[]);
291            }
292
293            let count = (primitives_count as f32 / COLLECT_WORKGROUP_SIZE as f32).ceil() as u32;
294            let width = if count < COLLECT_CHUNK_SIZE {
295                count
296            } else {
297                COLLECT_CHUNK_SIZE
298            };
299            let height = (count as f32 / COLLECT_CHUNK_SIZE as f32).ceil() as u32;
300            cpass.dispatch_workgroups(width, height, 1);
301        }
302
303        #[cfg(target_os = "macos")]
304        {
305            use ambient_core::RuntimeKey;
306
307            let buffs = CollectCountStagingBuffersKey.get(&self.assets);
308            let staging = buffs.take_buffer(output.counts.len());
309            encoder.copy_buffer_to_buffer(
310                output.counts.buffer(),
311                0,
312                staging.buffer(),
313                0,
314                output.counts.byte_size(),
315            );
316            let counts_res = output.counts_cpu.clone();
317            let runtime = RuntimeKey.get(&self.assets);
318            _post_submit.push(Box::new(move || {
319                runtime.spawn(async move {
320                    if let Ok(res) = staging.read(.., false).await {
321                        *counts_res.lock() = res;
322                        buffs.return_buffer(staging);
323                    }
324                });
325            }))
326        }
327    }
328}
329
330#[derive(Clone, Debug)]
331struct CollectCountStagingBuffersKey;
332impl SyncAssetKey<CollectCountStagingBuffers> for CollectCountStagingBuffersKey {
333    fn load(&self, assets: AssetCache) -> CollectCountStagingBuffers {
334        CollectCountStagingBuffers::new(GpuKey.get(&assets))
335    }
336}
337
338#[derive(Clone)]
339#[allow(dead_code)]
340struct CollectCountStagingBuffers {
341    gpu: Arc<Gpu>,
342    buffers: Arc<Mutex<Vec<TypedBuffer<u32>>>>,
343}
344impl CollectCountStagingBuffers {
345    fn new(gpu: Arc<Gpu>) -> Self {
346        Self {
347            gpu,
348            buffers: Arc::new(Mutex::new(Vec::new())),
349        }
350    }
351
352    #[cfg(target_os = "macos")]
353    fn take_buffer(&self, size: u64) -> TypedBuffer<u32> {
354        match self.buffers.lock().pop() {
355            Some(mut buffer) => {
356                buffer.resize(size, false);
357                buffer
358            }
359            None => TypedBuffer::<u32>::new(
360                self.gpu.clone(),
361                "RendererCollectState.counts_staging",
362                size,
363                size,
364                wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
365            ),
366        }
367    }
368
369    #[cfg(target_os = "macos")]
370    fn return_buffer(&self, buffer: TypedBuffer<u32>) {
371        self.buffers.lock().push(buffer)
372    }
373}