ambient_renderer/
tree_renderer.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::Arc,
4};
5
6use ambient_ecs::{query, ArchetypeFilter, EntityId, FramedEventsReader, QueryState, World};
7use ambient_gpu::{
8    gpu::Gpu,
9    multi_buffer::{MultiBufferSizeStrategy, SubBufferId, TypedMultiBuffer},
10    shader_module::{GraphicsPipeline, GraphicsPipelineInfo},
11};
12use ambient_std::asset_cache::AssetCache;
13use glam::{uvec2, UVec2};
14use itertools::Itertools;
15use wgpu::DepthBiasState;
16
17use super::{
18    double_sided, lod::cpu_lod_visible, primitives, CollectPrimitive, DrawIndexedIndirect, FSMain,
19    PrimitiveIndex, RendererCollectState, RendererResources, RendererShader, SharedMaterial,
20};
21use crate::{bind_groups::BindGroups, is_transparent, RendererConfig};
22
23pub struct TreeRendererConfig {
24    pub gpu: Arc<Gpu>,
25    pub renderer_config: RendererConfig,
26    pub assets: AssetCache,
27    pub filter: ArchetypeFilter,
28    pub targets: Vec<Option<wgpu::ColorTargetState>>,
29    pub renderer_resources: RendererResources,
30    pub fs_main: FSMain,
31    pub opaque_only: bool,
32    pub depth_stencil: bool,
33    pub cull_mode: Option<wgpu::Face>,
34    pub depth_bias: DepthBiasState,
35}
36
37pub struct TreeRenderer {
38    config: Arc<TreeRendererConfig>,
39    tree: HashMap<String, ShaderNode>,
40    entity_primitive_count: HashMap<EntityId, usize>,
41    primitives_lookup: HashMap<(EntityId, PrimitiveIndex), (String, String, usize)>,
42    loc_changed_reader: FramedEventsReader<EntityId>,
43
44    primitives: TypedMultiBuffer<CollectPrimitive>,
45    primitives_bind_group: Option<wgpu::BindGroup>,
46    spawn_qs: QueryState,
47    despawn_qs: QueryState,
48    material_indices: MaterialIndices,
49}
50impl TreeRenderer {
51    pub fn new(config: TreeRendererConfig) -> Self {
52        Self {
53            tree: HashMap::new(),
54            entity_primitive_count: HashMap::new(),
55            primitives_lookup: HashMap::new(),
56            loc_changed_reader: FramedEventsReader::new(),
57
58            primitives_bind_group: None,
59            primitives: TypedMultiBuffer::new(
60                config.gpu.clone(),
61                "TreeRenderer.primitives",
62                wgpu::BufferUsages::STORAGE
63                    | wgpu::BufferUsages::COPY_DST
64                    | wgpu::BufferUsages::COPY_SRC
65                    | wgpu::BufferUsages::INDIRECT,
66                MultiBufferSizeStrategy::Pow2,
67            ),
68
69            config: Arc::new(config),
70            spawn_qs: QueryState::new(),
71            despawn_qs: QueryState::new(),
72            material_indices: MaterialIndices::new(),
73        }
74    }
75    fn create_primitives_bind_group(
76        gpu: &Gpu,
77        layout: &wgpu::BindGroupLayout,
78        buffer: &wgpu::Buffer,
79    ) -> wgpu::BindGroup {
80        gpu.device.create_bind_group(&wgpu::BindGroupDescriptor {
81            layout,
82            entries: &[wgpu::BindGroupEntry {
83                binding: 0,
84                resource: buffer.as_entire_binding(),
85            }],
86            label: Some("TreeRenderer.primitives"),
87        })
88    }
89    #[ambient_profiling::function]
90    pub fn update(&mut self, world: &mut World) {
91        let mut to_update = HashSet::new();
92        let mut spawn_qs = std::mem::replace(&mut self.spawn_qs, QueryState::new());
93        let mut despawn_qs = std::mem::replace(&mut self.despawn_qs, QueryState::new());
94
95        for (id, (primitives,)) in query((primitives().changed(),))
96            .optional_changed(cpu_lod_visible())
97            .filter(&self.config.filter)
98            .iter(world, Some(&mut spawn_qs))
99        {
100            if let Some(primitive_count) = self.entity_primitive_count.get(&id) {
101                for primitive_index in 0..*primitive_count {
102                    if let Some(update) = self.remove_primitive(id, primitive_index) {
103                        to_update.insert(update);
104                    }
105                }
106            }
107            for (primitive_index, primitive) in primitives.iter().enumerate() {
108                let primitive_shader =
109                    (primitive.shader)(&self.config.assets, &self.config.renderer_config);
110                if let Some(update) = self.insert(
111                    world,
112                    id,
113                    primitive_index,
114                    &primitive_shader,
115                    &primitive.material,
116                ) {
117                    to_update.insert(update);
118                }
119            }
120            self.entity_primitive_count.insert(id, primitives.len());
121        }
122
123        for (id, _) in query(())
124            .incl(primitives())
125            .filter(&self.config.filter)
126            .despawned()
127            .iter(world, Some(&mut despawn_qs))
128        {
129            if let Some(primitive_count) = self.entity_primitive_count.get(&id) {
130                for primitive_index in 0..*primitive_count {
131                    if let Some(update) = self.remove_primitive(id, primitive_index) {
132                        to_update.insert(update);
133                    }
134                }
135            }
136            self.entity_primitive_count.remove(&id);
137        }
138
139        self.spawn_qs = spawn_qs;
140        self.despawn_qs = despawn_qs;
141        self.clean_empty();
142        for (_, id) in self.loc_changed_reader.iter(world.loc_changed()) {
143            if let Ok(primitives) = world.get_ref(*id, primitives()) {
144                for primivite_index in 0..primitives.len() {
145                    if let Some((shader_id, material_id, _)) =
146                        self.primitives_lookup.get(&(*id, primivite_index))
147                    {
148                        to_update.insert((shader_id.clone(), material_id.clone()));
149                    }
150                }
151            }
152        }
153
154        let mut encoder =
155            self.config
156                .gpu
157                .device
158                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
159                    label: Some("TreeRenderer.update"),
160                });
161        let mut primitives_to_write = Vec::new();
162        for (shader_id, material_id) in to_update.into_iter() {
163            if let Some(shader) = self.tree.get(&shader_id) {
164                if let Some(mat) = shader.tree.get(&material_id) {
165                    let primitives = mat
166                        .primitives
167                        .iter()
168                        .map(|(id, primitive_index)| {
169                            CollectPrimitive::from_primitive(
170                                world,
171                                *id,
172                                *primitive_index,
173                                mat.material_index,
174                            )
175                        })
176                        .collect_vec();
177                    self.primitives
178                        .resize_buffer_with_encoder(
179                            &mut encoder,
180                            mat.primitives_subbuffer,
181                            primitives.len() as u64,
182                        )
183                        .unwrap();
184                    primitives_to_write.push((mat.primitives_subbuffer, primitives));
185                }
186            }
187        }
188
189        self.config.gpu.queue.submit(Some(encoder.finish()));
190        for (subbuffer, primitives) in primitives_to_write.into_iter() {
191            self.primitives.write(subbuffer, 0, &primitives).unwrap();
192        }
193
194        for node in self.tree.values_mut() {
195            for mat in node.tree.values_mut() {
196                // TODO: Materials can be shared between many renderers, so this should be moved
197                // somewhere where it's done just once for all of them
198                mat.material.update(world);
199            }
200        }
201
202        self.primitives_bind_group = if self.primitives.total_len() > 0 {
203            Some(Self::create_primitives_bind_group(
204                &self.config.gpu,
205                &self.config.renderer_resources.primitives_layout,
206                self.primitives.buffer(),
207            ))
208        } else {
209            None
210        };
211    }
212    pub fn run_collect(
213        &self,
214        encoder: &mut wgpu::CommandEncoder,
215        post_submit: &mut Vec<Box<dyn FnOnce() + Send + Send>>,
216        resources_bind_group: &wgpu::BindGroup,
217        entities_bind_group: &wgpu::BindGroup,
218        collect_state: &mut RendererCollectState,
219    ) {
220        let mut material_layouts = vec![UVec2::ZERO; self.material_indices.counter as usize];
221        for node in self.tree.values() {
222            for mat in node.tree.values() {
223                let offset = self
224                    .primitives
225                    .buffer_offset(mat.primitives_subbuffer)
226                    .unwrap();
227                material_layouts[mat.material_index as usize] =
228                    uvec2(offset as u32, mat.primitives.len() as u32);
229            }
230        }
231
232        self.config.renderer_resources.collect.run(
233            encoder,
234            post_submit,
235            resources_bind_group,
236            entities_bind_group,
237            &self.primitives,
238            collect_state,
239            self.primitives.total_len() as u32,
240            material_layouts,
241        );
242    }
243
244    fn insert(
245        &mut self,
246        world: &World,
247        id: EntityId,
248        primitive_index: usize,
249        shader: &Arc<RendererShader>,
250        material: &SharedMaterial,
251    ) -> Option<(String, String)> {
252        let transparent = is_transparent(world, id, material, shader);
253        if (!transparent || !self.config.opaque_only)
254            && world.get(id, cpu_lod_visible()).unwrap_or(true)
255        {
256            let config = &self.config;
257            let double_sided = world
258                .get(id, double_sided())
259                .unwrap_or(material.double_sided().unwrap_or(shader.double_sided));
260            let shader_id = format!("{}-{}", shader.id, double_sided);
261            let node = self
262                .tree
263                .entry(shader_id.clone())
264                .or_insert_with(|| ShaderNode::new(config, shader.clone(), double_sided));
265
266            let mat = node
267                .tree
268                .entry(material.id().to_string())
269                .or_insert_with(|| MaterialNode {
270                    material_index: self.material_indices.acquire_index(),
271                    primitives_subbuffer: self.primitives.create_buffer(None),
272                    material: material.clone(),
273                    primitives: Vec::new(),
274                });
275            self.primitives_lookup.insert(
276                (id, primitive_index),
277                (
278                    shader_id.clone(),
279                    material.id().to_string(),
280                    mat.primitives.len(),
281                ),
282            );
283            mat.primitives.push((id, primitive_index));
284            Some((shader_id, material.id().to_string()))
285        } else {
286            None
287        }
288    }
289
290    fn remove_primitive(
291        &mut self,
292        id: EntityId,
293        primitive_index: usize,
294    ) -> Option<(String, String)> {
295        if let Some((shader_id, material_id, index)) =
296            self.primitives_lookup.remove(&(id, primitive_index))
297        {
298            let shader = self.tree.get_mut(&shader_id).unwrap();
299            let material = shader.tree.get_mut(&material_id).unwrap();
300            let is_last = material.primitives.len() == index + 1;
301            if !is_last {
302                if let Some(last_id) = material.primitives.last() {
303                    self.primitives_lookup.get_mut(last_id).unwrap().2 = index;
304                }
305            }
306            material.primitives.swap_remove(index);
307            Some((shader_id, material_id))
308        } else {
309            None
310        }
311    }
312    fn clean_empty(&mut self) {
313        for node in self.tree.values_mut() {
314            node.tree.retain(|_, mat| {
315                let to_remove = mat.primitives.is_empty();
316                if to_remove {
317                    self.primitives
318                        .remove_buffer(mat.primitives_subbuffer)
319                        .unwrap();
320                    self.material_indices.release_index(mat.material_index);
321                }
322                !to_remove
323            });
324        }
325        self.tree.retain(|_, v| !v.is_empty());
326    }
327    #[ambient_profiling::function]
328    pub fn render<'a>(
329        &'a self,
330        render_pass: &mut wgpu::RenderPass<'a>,
331        collect_state: &'a RendererCollectState,
332        bind_groups: &BindGroups<'a>,
333    ) {
334        let primitives_bind_group = if let Some(primitives_bind_group) = &self.primitives_bind_group
335        {
336            primitives_bind_group
337        } else {
338            return; // Nothing to render
339        };
340
341        #[cfg(target_os = "macos")]
342        let counts = collect_state.counts_cpu.lock().clone();
343
344        let mut is_bound = false;
345
346        for node in self.tree.values() {
347            render_pass.set_pipeline(node.pipeline.pipeline());
348            // Bind on first invocation
349            let bind_groups = [
350                bind_groups.globals,
351                bind_groups.entities,
352                primitives_bind_group,
353            ];
354            if !is_bound {
355                for (i, bind_group) in bind_groups.iter().enumerate() {
356                    render_pass.set_bind_group(i as _, bind_group, &[]);
357                    is_bound = true
358                }
359            }
360
361            for mat in node.tree.values() {
362                let material = &mat.material;
363
364                render_pass.set_bind_group(bind_groups.len() as _, material.bind_group(), &[]);
365
366                let offset = self
367                    .primitives
368                    .buffer_offset(mat.primitives_subbuffer)
369                    .unwrap();
370                #[cfg(not(target_os = "macos"))]
371                {
372                    render_pass.multi_draw_indexed_indirect_count(
373                        collect_state.commands.buffer(),
374                        offset * std::mem::size_of::<DrawIndexedIndirect>() as u64,
375                        collect_state.counts.buffer(),
376                        mat.material_index as u64 * std::mem::size_of::<u32>() as u64,
377                        mat.primitives.len() as u32,
378                    );
379                }
380                #[cfg(target_os = "macos")]
381                {
382                    if let Some(count) = counts.get(mat.material_index as usize) {
383                        for i in 0..*count {
384                            render_pass.draw_indexed_indirect(
385                                collect_state.commands.buffer(),
386                                (offset + i as u64)
387                                    * std::mem::size_of::<DrawIndexedIndirect>() as u64,
388                            );
389                        }
390                    }
391                }
392            }
393        }
394    }
395    pub fn n_entities(&self) -> usize {
396        self.tree.values().fold(0, |p, n| p + n.n_entities())
397    }
398    pub fn n_nodes(&self) -> usize {
399        self.tree.values().fold(0, |p, n| p + n.n_nodes())
400    }
401    pub fn dump(&self, f: &mut dyn std::io::Write) {
402        for (key, node) in self.tree.iter() {
403            writeln!(f, "    shader {key:?}").unwrap();
404            node.dump(f);
405        }
406    }
407}
408struct ShaderNode {
409    pipeline: GraphicsPipeline,
410    tree: HashMap<String, MaterialNode>,
411}
412impl ShaderNode {
413    pub fn new(
414        config: &TreeRendererConfig,
415        shader: Arc<RendererShader>,
416        double_sided: bool,
417    ) -> Self {
418        let gpu = config.gpu.clone();
419
420        let mut pipeline_info = GraphicsPipelineInfo {
421            vs_main: &shader.vs_main,
422            fs_main: shader.get_fs_main_name(config.fs_main),
423            targets: &config.targets,
424            cull_mode: config
425                .cull_mode
426                .and_then(|f| if double_sided { None } else { Some(f) }),
427            ..Default::default()
428        };
429        if config.depth_stencil {
430            pipeline_info = pipeline_info
431                .with_depth()
432                .with_depth_bias(config.depth_bias);
433        }
434
435        let pipeline = shader.shader.to_pipeline(&gpu, pipeline_info);
436
437        Self {
438            pipeline,
439            tree: HashMap::new(),
440        }
441    }
442    fn is_empty(&self) -> bool {
443        self.tree.is_empty()
444    }
445    pub fn n_entities(&self) -> usize {
446        self.tree.values().fold(0, |p, n| p + n.primitives.len())
447    }
448    pub fn n_nodes(&self) -> usize {
449        self.tree.len() + 1
450    }
451    pub fn dump(&self, f: &mut dyn std::io::Write) {
452        for (_key, node) in self.tree.iter() {
453            writeln!(
454                f,
455                "      material {:?}: {} entities",
456                node.material.name(),
457                node.primitives.len()
458            )
459            .unwrap();
460        }
461    }
462}
463struct MaterialNode {
464    material_index: u32,
465    primitives_subbuffer: SubBufferId,
466    material: SharedMaterial,
467    primitives: Vec<(EntityId, PrimitiveIndex)>,
468}
469
/// Hands out small dense `u32` slots, reusing released ones (most recently released first).
struct MaterialIndices {
    /// Released slots waiting to be reused; popped from the back.
    free_indices: Vec<u32>,
    /// Number of distinct slots ever minted, which is also the next fresh slot.
    counter: u32,
}
impl MaterialIndices {
    /// An allocator with no slots handed out yet.
    fn new() -> Self {
        MaterialIndices {
            free_indices: Vec::new(),
            counter: 0,
        }
    }

    /// Returns a recycled slot if one is available, otherwise mints the next fresh one.
    fn acquire_index(&mut self) -> u32 {
        match self.free_indices.pop() {
            Some(recycled) => recycled,
            None => {
                let fresh = self.counter;
                self.counter += 1;
                fresh
            }
        }
    }

    /// Makes `index` available for reuse; `counter` is never decreased.
    fn release_index(&mut self, index: u32) {
        self.free_indices.push(index);
    }
}