1use std::{
2 collections::{HashMap, HashSet},
3 sync::Arc,
4};
5
6use ambient_ecs::{query, ArchetypeFilter, EntityId, FramedEventsReader, QueryState, World};
7use ambient_gpu::{
8 gpu::Gpu,
9 multi_buffer::{MultiBufferSizeStrategy, SubBufferId, TypedMultiBuffer},
10 shader_module::{GraphicsPipeline, GraphicsPipelineInfo},
11};
12use ambient_std::asset_cache::AssetCache;
13use glam::{uvec2, UVec2};
14use itertools::Itertools;
15use wgpu::DepthBiasState;
16
17use super::{
18 double_sided, lod::cpu_lod_visible, primitives, CollectPrimitive, DrawIndexedIndirect, FSMain,
19 PrimitiveIndex, RendererCollectState, RendererResources, RendererShader, SharedMaterial,
20};
21use crate::{bind_groups::BindGroups, is_transparent, RendererConfig};
22
/// Construction parameters for a [`TreeRenderer`].
pub struct TreeRendererConfig {
    /// GPU handle used for the primitives buffer, command encoders and pipelines.
    pub gpu: Arc<Gpu>,
    /// Passed, together with `assets`, to each primitive's shader factory.
    pub renderer_config: RendererConfig,
    /// Passed, together with `renderer_config`, to each primitive's shader factory.
    pub assets: AssetCache,
    /// Only entities matching this filter are picked up by the renderer's queries.
    pub filter: ArchetypeFilter,
    /// Color targets used when building each shader node's pipeline.
    pub targets: Vec<Option<wgpu::ColorTargetState>>,
    /// Shared resources; supplies the primitives bind-group layout and the collect pass.
    pub renderer_resources: RendererResources,
    /// Chooses the fragment entry point name via `RendererShader::get_fs_main_name`.
    pub fs_main: FSMain,
    /// When true, primitives classified as transparent are not inserted.
    pub opaque_only: bool,
    /// When true, pipelines are built with `with_depth()` and `depth_bias`.
    pub depth_stencil: bool,
    /// Face culling for pipelines; replaced by `None` for double-sided shader nodes.
    pub cull_mode: Option<wgpu::Face>,
    /// Depth bias; only applied when `depth_stencil` is true.
    pub depth_bias: DepthBiasState,
}
36
/// Renders entities that have a `primitives` component by grouping each primitive under a
/// shader node (one pipeline) and, beneath it, a material node (one material bind group).
pub struct TreeRenderer {
    config: Arc<TreeRendererConfig>,
    /// Shader nodes keyed by `"{shader.id}-{double_sided}"`.
    tree: HashMap<String, ShaderNode>,
    /// Number of primitives each entity had when it was last registered.
    entity_primitive_count: HashMap<EntityId, usize>,
    /// (entity, primitive index) -> (shader key, material id, position of the pair inside
    /// that material node's `primitives` list).
    primitives_lookup: HashMap<(EntityId, PrimitiveIndex), (String, String, usize)>,
    /// Reader over the world's location-changed events.
    loc_changed_reader: FramedEventsReader<EntityId>,

    /// Holds one sub-buffer of `CollectPrimitive`s per material node.
    primitives: TypedMultiBuffer<CollectPrimitive>,
    /// Bind group over `primitives`; `None` while the buffer holds no elements.
    primitives_bind_group: Option<wgpu::BindGroup>,
    spawn_qs: QueryState,
    despawn_qs: QueryState,
    /// Allocates the per-material slot used by the collect pass and the indirect draw counts.
    material_indices: MaterialIndices,
}
50impl TreeRenderer {
51 pub fn new(config: TreeRendererConfig) -> Self {
52 Self {
53 tree: HashMap::new(),
54 entity_primitive_count: HashMap::new(),
55 primitives_lookup: HashMap::new(),
56 loc_changed_reader: FramedEventsReader::new(),
57
58 primitives_bind_group: None,
59 primitives: TypedMultiBuffer::new(
60 config.gpu.clone(),
61 "TreeRenderer.primitives",
62 wgpu::BufferUsages::STORAGE
63 | wgpu::BufferUsages::COPY_DST
64 | wgpu::BufferUsages::COPY_SRC
65 | wgpu::BufferUsages::INDIRECT,
66 MultiBufferSizeStrategy::Pow2,
67 ),
68
69 config: Arc::new(config),
70 spawn_qs: QueryState::new(),
71 despawn_qs: QueryState::new(),
72 material_indices: MaterialIndices::new(),
73 }
74 }
75 fn create_primitives_bind_group(
76 gpu: &Gpu,
77 layout: &wgpu::BindGroupLayout,
78 buffer: &wgpu::Buffer,
79 ) -> wgpu::BindGroup {
80 gpu.device.create_bind_group(&wgpu::BindGroupDescriptor {
81 layout,
82 entries: &[wgpu::BindGroupEntry {
83 binding: 0,
84 resource: buffer.as_entire_binding(),
85 }],
86 label: Some("TreeRenderer.primitives"),
87 })
88 }
89 #[ambient_profiling::function]
90 pub fn update(&mut self, world: &mut World) {
91 let mut to_update = HashSet::new();
92 let mut spawn_qs = std::mem::replace(&mut self.spawn_qs, QueryState::new());
93 let mut despawn_qs = std::mem::replace(&mut self.despawn_qs, QueryState::new());
94
95 for (id, (primitives,)) in query((primitives().changed(),))
96 .optional_changed(cpu_lod_visible())
97 .filter(&self.config.filter)
98 .iter(world, Some(&mut spawn_qs))
99 {
100 if let Some(primitive_count) = self.entity_primitive_count.get(&id) {
101 for primitive_index in 0..*primitive_count {
102 if let Some(update) = self.remove_primitive(id, primitive_index) {
103 to_update.insert(update);
104 }
105 }
106 }
107 for (primitive_index, primitive) in primitives.iter().enumerate() {
108 let primitive_shader =
109 (primitive.shader)(&self.config.assets, &self.config.renderer_config);
110 if let Some(update) = self.insert(
111 world,
112 id,
113 primitive_index,
114 &primitive_shader,
115 &primitive.material,
116 ) {
117 to_update.insert(update);
118 }
119 }
120 self.entity_primitive_count.insert(id, primitives.len());
121 }
122
123 for (id, _) in query(())
124 .incl(primitives())
125 .filter(&self.config.filter)
126 .despawned()
127 .iter(world, Some(&mut despawn_qs))
128 {
129 if let Some(primitive_count) = self.entity_primitive_count.get(&id) {
130 for primitive_index in 0..*primitive_count {
131 if let Some(update) = self.remove_primitive(id, primitive_index) {
132 to_update.insert(update);
133 }
134 }
135 }
136 self.entity_primitive_count.remove(&id);
137 }
138
139 self.spawn_qs = spawn_qs;
140 self.despawn_qs = despawn_qs;
141 self.clean_empty();
142 for (_, id) in self.loc_changed_reader.iter(world.loc_changed()) {
143 if let Ok(primitives) = world.get_ref(*id, primitives()) {
144 for primivite_index in 0..primitives.len() {
145 if let Some((shader_id, material_id, _)) =
146 self.primitives_lookup.get(&(*id, primivite_index))
147 {
148 to_update.insert((shader_id.clone(), material_id.clone()));
149 }
150 }
151 }
152 }
153
154 let mut encoder =
155 self.config
156 .gpu
157 .device
158 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
159 label: Some("TreeRenderer.update"),
160 });
161 let mut primitives_to_write = Vec::new();
162 for (shader_id, material_id) in to_update.into_iter() {
163 if let Some(shader) = self.tree.get(&shader_id) {
164 if let Some(mat) = shader.tree.get(&material_id) {
165 let primitives = mat
166 .primitives
167 .iter()
168 .map(|(id, primitive_index)| {
169 CollectPrimitive::from_primitive(
170 world,
171 *id,
172 *primitive_index,
173 mat.material_index,
174 )
175 })
176 .collect_vec();
177 self.primitives
178 .resize_buffer_with_encoder(
179 &mut encoder,
180 mat.primitives_subbuffer,
181 primitives.len() as u64,
182 )
183 .unwrap();
184 primitives_to_write.push((mat.primitives_subbuffer, primitives));
185 }
186 }
187 }
188
189 self.config.gpu.queue.submit(Some(encoder.finish()));
190 for (subbuffer, primitives) in primitives_to_write.into_iter() {
191 self.primitives.write(subbuffer, 0, &primitives).unwrap();
192 }
193
194 for node in self.tree.values_mut() {
195 for mat in node.tree.values_mut() {
196 mat.material.update(world);
199 }
200 }
201
202 self.primitives_bind_group = if self.primitives.total_len() > 0 {
203 Some(Self::create_primitives_bind_group(
204 &self.config.gpu,
205 &self.config.renderer_resources.primitives_layout,
206 self.primitives.buffer(),
207 ))
208 } else {
209 None
210 };
211 }
212 pub fn run_collect(
213 &self,
214 encoder: &mut wgpu::CommandEncoder,
215 post_submit: &mut Vec<Box<dyn FnOnce() + Send + Send>>,
216 resources_bind_group: &wgpu::BindGroup,
217 entities_bind_group: &wgpu::BindGroup,
218 collect_state: &mut RendererCollectState,
219 ) {
220 let mut material_layouts = vec![UVec2::ZERO; self.material_indices.counter as usize];
221 for node in self.tree.values() {
222 for mat in node.tree.values() {
223 let offset = self
224 .primitives
225 .buffer_offset(mat.primitives_subbuffer)
226 .unwrap();
227 material_layouts[mat.material_index as usize] =
228 uvec2(offset as u32, mat.primitives.len() as u32);
229 }
230 }
231
232 self.config.renderer_resources.collect.run(
233 encoder,
234 post_submit,
235 resources_bind_group,
236 entities_bind_group,
237 &self.primitives,
238 collect_state,
239 self.primitives.total_len() as u32,
240 material_layouts,
241 );
242 }
243
244 fn insert(
245 &mut self,
246 world: &World,
247 id: EntityId,
248 primitive_index: usize,
249 shader: &Arc<RendererShader>,
250 material: &SharedMaterial,
251 ) -> Option<(String, String)> {
252 let transparent = is_transparent(world, id, material, shader);
253 if (!transparent || !self.config.opaque_only)
254 && world.get(id, cpu_lod_visible()).unwrap_or(true)
255 {
256 let config = &self.config;
257 let double_sided = world
258 .get(id, double_sided())
259 .unwrap_or(material.double_sided().unwrap_or(shader.double_sided));
260 let shader_id = format!("{}-{}", shader.id, double_sided);
261 let node = self
262 .tree
263 .entry(shader_id.clone())
264 .or_insert_with(|| ShaderNode::new(config, shader.clone(), double_sided));
265
266 let mat = node
267 .tree
268 .entry(material.id().to_string())
269 .or_insert_with(|| MaterialNode {
270 material_index: self.material_indices.acquire_index(),
271 primitives_subbuffer: self.primitives.create_buffer(None),
272 material: material.clone(),
273 primitives: Vec::new(),
274 });
275 self.primitives_lookup.insert(
276 (id, primitive_index),
277 (
278 shader_id.clone(),
279 material.id().to_string(),
280 mat.primitives.len(),
281 ),
282 );
283 mat.primitives.push((id, primitive_index));
284 Some((shader_id, material.id().to_string()))
285 } else {
286 None
287 }
288 }
289
290 fn remove_primitive(
291 &mut self,
292 id: EntityId,
293 primitive_index: usize,
294 ) -> Option<(String, String)> {
295 if let Some((shader_id, material_id, index)) =
296 self.primitives_lookup.remove(&(id, primitive_index))
297 {
298 let shader = self.tree.get_mut(&shader_id).unwrap();
299 let material = shader.tree.get_mut(&material_id).unwrap();
300 let is_last = material.primitives.len() == index + 1;
301 if !is_last {
302 if let Some(last_id) = material.primitives.last() {
303 self.primitives_lookup.get_mut(last_id).unwrap().2 = index;
304 }
305 }
306 material.primitives.swap_remove(index);
307 Some((shader_id, material_id))
308 } else {
309 None
310 }
311 }
312 fn clean_empty(&mut self) {
313 for node in self.tree.values_mut() {
314 node.tree.retain(|_, mat| {
315 let to_remove = mat.primitives.is_empty();
316 if to_remove {
317 self.primitives
318 .remove_buffer(mat.primitives_subbuffer)
319 .unwrap();
320 self.material_indices.release_index(mat.material_index);
321 }
322 !to_remove
323 });
324 }
325 self.tree.retain(|_, v| !v.is_empty());
326 }
327 #[ambient_profiling::function]
328 pub fn render<'a>(
329 &'a self,
330 render_pass: &mut wgpu::RenderPass<'a>,
331 collect_state: &'a RendererCollectState,
332 bind_groups: &BindGroups<'a>,
333 ) {
334 let primitives_bind_group = if let Some(primitives_bind_group) = &self.primitives_bind_group
335 {
336 primitives_bind_group
337 } else {
338 return; };
340
341 #[cfg(target_os = "macos")]
342 let counts = collect_state.counts_cpu.lock().clone();
343
344 let mut is_bound = false;
345
346 for node in self.tree.values() {
347 render_pass.set_pipeline(node.pipeline.pipeline());
348 let bind_groups = [
350 bind_groups.globals,
351 bind_groups.entities,
352 primitives_bind_group,
353 ];
354 if !is_bound {
355 for (i, bind_group) in bind_groups.iter().enumerate() {
356 render_pass.set_bind_group(i as _, bind_group, &[]);
357 is_bound = true
358 }
359 }
360
361 for mat in node.tree.values() {
362 let material = &mat.material;
363
364 render_pass.set_bind_group(bind_groups.len() as _, material.bind_group(), &[]);
365
366 let offset = self
367 .primitives
368 .buffer_offset(mat.primitives_subbuffer)
369 .unwrap();
370 #[cfg(not(target_os = "macos"))]
371 {
372 render_pass.multi_draw_indexed_indirect_count(
373 collect_state.commands.buffer(),
374 offset * std::mem::size_of::<DrawIndexedIndirect>() as u64,
375 collect_state.counts.buffer(),
376 mat.material_index as u64 * std::mem::size_of::<u32>() as u64,
377 mat.primitives.len() as u32,
378 );
379 }
380 #[cfg(target_os = "macos")]
381 {
382 if let Some(count) = counts.get(mat.material_index as usize) {
383 for i in 0..*count {
384 render_pass.draw_indexed_indirect(
385 collect_state.commands.buffer(),
386 (offset + i as u64)
387 * std::mem::size_of::<DrawIndexedIndirect>() as u64,
388 );
389 }
390 }
391 }
392 }
393 }
394 }
395 pub fn n_entities(&self) -> usize {
396 self.tree.values().fold(0, |p, n| p + n.n_entities())
397 }
398 pub fn n_nodes(&self) -> usize {
399 self.tree.values().fold(0, |p, n| p + n.n_nodes())
400 }
401 pub fn dump(&self, f: &mut dyn std::io::Write) {
402 for (key, node) in self.tree.iter() {
403 writeln!(f, " shader {key:?}").unwrap();
404 node.dump(f);
405 }
406 }
407}
/// One render pipeline together with the material nodes drawn with it.
struct ShaderNode {
    pipeline: GraphicsPipeline,
    /// Material nodes keyed by material id.
    tree: HashMap<String, MaterialNode>,
}
412impl ShaderNode {
413 pub fn new(
414 config: &TreeRendererConfig,
415 shader: Arc<RendererShader>,
416 double_sided: bool,
417 ) -> Self {
418 let gpu = config.gpu.clone();
419
420 let mut pipeline_info = GraphicsPipelineInfo {
421 vs_main: &shader.vs_main,
422 fs_main: shader.get_fs_main_name(config.fs_main),
423 targets: &config.targets,
424 cull_mode: config
425 .cull_mode
426 .and_then(|f| if double_sided { None } else { Some(f) }),
427 ..Default::default()
428 };
429 if config.depth_stencil {
430 pipeline_info = pipeline_info
431 .with_depth()
432 .with_depth_bias(config.depth_bias);
433 }
434
435 let pipeline = shader.shader.to_pipeline(&gpu, pipeline_info);
436
437 Self {
438 pipeline,
439 tree: HashMap::new(),
440 }
441 }
442 fn is_empty(&self) -> bool {
443 self.tree.is_empty()
444 }
445 pub fn n_entities(&self) -> usize {
446 self.tree.values().fold(0, |p, n| p + n.primitives.len())
447 }
448 pub fn n_nodes(&self) -> usize {
449 self.tree.len() + 1
450 }
451 pub fn dump(&self, f: &mut dyn std::io::Write) {
452 for (_key, node) in self.tree.iter() {
453 writeln!(
454 f,
455 " material {:?}: {} entities",
456 node.material.name(),
457 node.primitives.len()
458 )
459 .unwrap();
460 }
461 }
462}
/// The primitives drawn with one material beneath a shader node.
struct MaterialNode {
    /// Slot from `MaterialIndices`; indexes the collect-pass layouts and indirect draw counts.
    material_index: u32,
    /// This node's region of the renderer's primitives multi-buffer.
    primitives_subbuffer: SubBufferId,
    material: SharedMaterial,
    /// (entity, primitive index) pairs. Order is not stable: removal uses `swap_remove`.
    primitives: Vec<(EntityId, PrimitiveIndex)>,
}
469
/// Hands out dense `u32` material slots, reusing released slots before minting new ones.
struct MaterialIndices {
    /// Released slots, reused most-recently-released first.
    free_indices: Vec<u32>,
    /// Number of slots ever minted; also one past the highest slot handed out.
    counter: u32,
}
impl MaterialIndices {
    /// Creates an allocator with no slots handed out.
    fn new() -> Self {
        MaterialIndices {
            free_indices: vec![],
            counter: 0,
        }
    }
    /// Returns a recycled slot if one is free, otherwise the next fresh slot.
    fn acquire_index(&mut self) -> u32 {
        match self.free_indices.pop() {
            Some(recycled) => recycled,
            None => {
                let fresh = self.counter;
                self.counter += 1;
                fresh
            }
        }
    }
    /// Makes `index` available to a later `acquire_index`.
    fn release_index(&mut self, index: u32) {
        self.free_indices.push(index)
    }
}