bevy_hikari/mesh_material/
instance.rs

1use super::{
2    material::GpuStandardMaterials, mesh::GpuMeshes, GpuAliasEntry, GpuAliasTableBuffer,
3    GpuEmissive, GpuEmissiveBuffer, GpuMesh, GpuStandardMaterial, MeshMaterialSystems,
4};
5use crate::{
6    mesh_material::{GpuInstance, GpuInstanceBuffer, GpuNode, GpuNodeBuffer},
7    transform::GlobalTransformQueue,
8    HikariUniversalSettings,
9};
10use bevy::{
11    asset::Asset,
12    ecs::query::QueryItem,
13    math::{Vec3A, Vec4Swizzles},
14    prelude::*,
15    render::{
16        extract_component::{ExtractComponent, ExtractComponentPlugin, UniformComponentPlugin},
17        primitives::Aabb,
18        render_resource::*,
19        renderer::{RenderDevice, RenderQueue},
20        view::VisibilitySystems,
21        Extract, RenderApp, RenderStage,
22    },
23    transform::TransformSystem,
24};
25use bvh::bvh::BVH;
26use itertools::Itertools;
27use std::{collections::BTreeMap, marker::PhantomData};
28
29pub struct InstancePlugin;
30impl Plugin for InstancePlugin {
31    fn build(&self, app: &mut App) {
32        app.add_plugin(ExtractComponentPlugin::<PreviousMeshUniform>::default())
33            .add_plugin(UniformComponentPlugin::<PreviousMeshUniform>::default());
34
35        if let Ok(render_app) = app.get_sub_app_mut(RenderApp) {
36            render_app
37                .init_resource::<ExtractedInstances>()
38                .init_resource::<InstanceRenderAssets>()
39                .add_system_to_stage(
40                    RenderStage::Prepare,
41                    prepare_instances
42                        .label(MeshMaterialSystems::PrepareInstances)
43                        .after(MeshMaterialSystems::PrepareAssets),
44                );
45        }
46    }
47}
48
49#[derive(Default)]
50pub struct GenericInstancePlugin<M: Into<StandardMaterial>>(PhantomData<M>);
51
52impl<M> Plugin for GenericInstancePlugin<M>
53where
54    M: Into<StandardMaterial> + Asset,
55{
56    fn build(&self, app: &mut App) {
57        app.add_event::<InstanceEvent<M>>().add_system_to_stage(
58            CoreStage::PostUpdate,
59            instance_event_system::<M>
60                .after(TransformSystem::TransformPropagate)
61                .after(VisibilitySystems::VisibilityPropagate)
62                .after(VisibilitySystems::CalculateBounds),
63        );
64
65        if let Ok(render_app) = app.get_sub_app_mut(RenderApp) {
66            render_app.add_system_to_stage(RenderStage::Extract, extract_instances::<M>);
67        }
68    }
69}
70
71#[derive(Default, Resource)]
72pub struct InstanceRenderAssets {
73    pub instance_buffer: StorageBuffer<GpuInstanceBuffer>,
74    pub instance_node_buffer: StorageBuffer<GpuNodeBuffer>,
75    pub emissive_buffer: StorageBuffer<GpuEmissiveBuffer>,
76    pub emissive_node_buffer: StorageBuffer<GpuNodeBuffer>,
77    pub alias_table_buffer: StorageBuffer<GpuAliasTableBuffer>,
78    pub instance_indices: DynamicUniformBuffer<InstanceIndex>,
79}
80
81impl InstanceRenderAssets {
82    pub fn set(
83        &mut self,
84        instances: Vec<GpuInstance>,
85        instance_nodes: Vec<GpuNode>,
86        emissives: Vec<GpuEmissive>,
87        emissive_nodes: Vec<GpuNode>,
88        alias_table: Vec<GpuAliasEntry>,
89    ) {
90        self.instance_buffer.get_mut().data = instances;
91        self.emissive_buffer.get_mut().data = emissives;
92        self.alias_table_buffer.get_mut().data = alias_table;
93
94        self.instance_node_buffer.get_mut().count = instance_nodes.len() as u32;
95        self.instance_node_buffer.get_mut().data = instance_nodes;
96
97        self.emissive_node_buffer.get_mut().count = emissive_nodes.len() as u32;
98        self.emissive_node_buffer.get_mut().data = emissive_nodes;
99    }
100
101    pub fn write_buffer(&mut self, device: &RenderDevice, queue: &RenderQueue) {
102        self.instance_buffer.write_buffer(device, queue);
103        self.instance_node_buffer.write_buffer(device, queue);
104        self.emissive_buffer.write_buffer(device, queue);
105        self.emissive_node_buffer.write_buffer(device, queue);
106        self.instance_indices.write_buffer(device, queue);
107        self.alias_table_buffer.write_buffer(device, queue);
108    }
109}
110
111#[derive(Default, Component, Clone, ShaderType)]
112pub struct PreviousMeshUniform {
113    pub transform: Mat4,
114    pub inverse_transpose_model: Mat4,
115}
116
117impl ExtractComponent for PreviousMeshUniform {
118    type Query = &'static GlobalTransformQueue;
119    type Filter = With<Handle<Mesh>>;
120
121    fn extract_component(queue: QueryItem<Self::Query>) -> Self {
122        let transform = queue[1];
123        PreviousMeshUniform {
124            transform,
125            inverse_transpose_model: transform.inverse().transpose(),
126        }
127    }
128}
129
130pub enum InstanceEvent<M: Into<StandardMaterial> + Asset> {
131    Created(Entity, Handle<Mesh>, Handle<M>, ComputedVisibility),
132    Modified(Entity, Handle<Mesh>, Handle<M>, ComputedVisibility),
133    Removed(Entity),
134}
135
136#[allow(clippy::type_complexity)]
137fn instance_event_system<M: Into<StandardMaterial> + Asset>(
138    mut events: EventWriter<InstanceEvent<M>>,
139    removed: RemovedComponents<Handle<Mesh>>,
140    mut set: ParamSet<(
141        Query<
142            (Entity, &Handle<Mesh>, &Handle<M>, &ComputedVisibility),
143            Or<(Added<Handle<Mesh>>, Added<Handle<M>>)>,
144        >,
145        Query<
146            (Entity, &Handle<Mesh>, &Handle<M>, &ComputedVisibility),
147            Or<(
148                Changed<GlobalTransform>,
149                Changed<Handle<Mesh>>,
150                Changed<Handle<M>>,
151                Changed<ComputedVisibility>,
152            )>,
153        >,
154    )>,
155) {
156    for entity in removed.iter() {
157        events.send(InstanceEvent::Removed(entity));
158    }
159    for (entity, mesh, material, visibility) in &set.p0() {
160        events.send(InstanceEvent::Created(
161            entity,
162            mesh.clone_weak(),
163            material.clone_weak(),
164            visibility.clone(),
165        ));
166    }
167    for (entity, mesh, material, visibility) in &set.p1() {
168        events.send(InstanceEvent::Modified(
169            entity,
170            mesh.clone_weak(),
171            material.clone_weak(),
172            visibility.clone(),
173        ));
174    }
175}
176
177#[allow(clippy::type_complexity)]
178#[derive(Default, Resource)]
179pub struct ExtractedInstances {
180    extracted: Vec<(
181        Entity,
182        Aabb,
183        GlobalTransform,
184        Handle<Mesh>,
185        HandleUntyped,
186        ComputedVisibility,
187    )>,
188    removed: Vec<Entity>,
189}
190
191fn extract_instances<M: Into<StandardMaterial> + Asset>(
192    mut events: Extract<EventReader<InstanceEvent<M>>>,
193    query: Extract<Query<(&Aabb, &GlobalTransform)>>,
194    mut extracted_instances: ResMut<ExtractedInstances>,
195) {
196    let mut extracted = vec![];
197    let mut removed = vec![];
198
199    for event in events.iter() {
200        match event {
201            InstanceEvent::Created(entity, mesh, material, visibility)
202            | InstanceEvent::Modified(entity, mesh, material, visibility) => {
203                if let Ok((aabb, transform)) = query.get(*entity) {
204                    extracted.push((
205                        *entity,
206                        aabb.clone(),
207                        *transform,
208                        mesh.clone_weak(),
209                        material.clone_weak_untyped(),
210                        visibility.clone(),
211                    ));
212                }
213            }
214            InstanceEvent::Removed(entity) => removed.push(*entity),
215        }
216    }
217
218    extracted_instances.extracted.append(&mut extracted);
219    extracted_instances.removed.append(&mut removed);
220}
221
222#[derive(Component, Default, Clone, Copy, ShaderType)]
223pub struct InstanceIndex {
224    pub instance: u32,
225    pub material: u32,
226}
227
228#[derive(Component, Default, Clone, Copy)]
229pub struct DynamicInstanceIndex(pub u32);
230
231type Instances = BTreeMap<
232    Entity,
233    (
234        GpuInstance,
235        GpuMesh,
236        GpuStandardMaterial,
237        ComputedVisibility,
238    ),
239>;
240
241type AlisaTableCache = BTreeMap<Entity, (Vec3, Vec<GpuAliasEntry>)>;
242
243/// Note: this system must run AFTER [`prepare_mesh_assets`].
244#[allow(clippy::too_many_arguments)]
245fn prepare_instances(
246    mut commands: Commands,
247    render_device: Res<RenderDevice>,
248    render_queue: Res<RenderQueue>,
249    mut render_assets: ResMut<InstanceRenderAssets>,
250    mut extracted_instances: ResMut<ExtractedInstances>,
251    mut collection: Local<Instances>,
252    mut alias_table_cache: Local<AlisaTableCache>,
253    meshes: Res<GpuMeshes>,
254    materials: Res<GpuStandardMaterials>,
255    universal_settings: Res<HikariUniversalSettings>,
256) {
257    if !universal_settings.build_instance_acceleration_structure {
258        return;
259    }
260
261    let instance_changed =
262        !extracted_instances.extracted.is_empty() || !extracted_instances.removed.is_empty();
263
264    for removed in extracted_instances.removed.drain(..) {
265        collection.remove(&removed);
266        alias_table_cache.remove(&removed);
267    }
268
269    let mut prepare_next_frame = vec![];
270
271    for (entity, aabb, transform, mesh, material, visibility) in extracted_instances
272        .extracted
273        .drain(..)
274        .filter_map(|(entity, aabb, transform, mesh, material, visibility)| {
275            match (meshes.get(&mesh), materials.get(&material)) {
276                (Some(mesh), Some(material)) => {
277                    Some((entity, aabb, transform, mesh, material, visibility))
278                }
279                _ => {
280                    prepare_next_frame.push((entity, aabb, transform, mesh, material, visibility));
281                    None
282                }
283            }
284        })
285    {
286        let transform = transform.compute_matrix();
287        let center = transform.transform_point3a(aabb.center);
288        let vertices: Vec<_> = (0..8i32)
289            .map(|index| {
290                let x = 2 * (index & 1) - 1;
291                let y = 2 * ((index >> 1) & 1) - 1;
292                let z = 2 * ((index >> 2) & 1) - 1;
293                let vertex = aabb.half_extents * Vec3A::new(x as f32, y as f32, z as f32);
294                transform.transform_vector3a(vertex)
295            })
296            .collect();
297
298        let mut min = Vec3A::ZERO;
299        let mut max = Vec3A::ZERO;
300        for vertex in vertices {
301            min = min.min(vertex);
302            max = max.max(vertex);
303        }
304        min += center;
305        max += center;
306
307        // Note that the `GpuInstance` is partially constructed:
308        // since node index is unknown at this point.
309        let min = Vec3::from(min);
310        let max = Vec3::from(max);
311        collection.insert(
312            entity,
313            (
314                GpuInstance {
315                    min,
316                    max,
317                    transform,
318                    inverse_transpose_model: transform.inverse().transpose(),
319                    mesh: mesh.1,
320                    material: material.1,
321                    ..Default::default()
322                },
323                mesh.0.clone(),
324                material.0.clone(),
325                visibility,
326            ),
327        );
328    }
329
330    extracted_instances
331        .extracted
332        .append(&mut prepare_next_frame);
333
334    // Since entities are cleared every frame, this should always be called.
335    let mut add_instance_indices = |instances: &Instances| {
336        render_assets.instance_indices.clear();
337        let command_batch: Vec<_> = instances
338            .iter()
339            .enumerate()
340            .map(|(id, (entity, (instance, _, _, _)))| {
341                let component = InstanceIndex {
342                    instance: id as u32,
343                    material: instance.material,
344                };
345                let index = render_assets.instance_indices.push(component);
346                (*entity, (DynamicInstanceIndex(index),))
347            })
348            .collect();
349        commands.insert_or_spawn_batch(command_batch);
350    };
351
352    if instance_changed || meshes.is_changed() || materials.is_changed() {
353        // Important: update mesh and material info for every instance
354        let mut emissives = vec![];
355        let mut alias_table = vec![];
356
357        collection.retain(|_, (_, _, _, visibility)| visibility.is_visible_in_hierarchy());
358
359        let mut instances: Vec<_> = collection
360            .values()
361            .map(|(instance, _, _, _)| instance)
362            .cloned()
363            .collect();
364
365        let instance_nodes = match collection.is_empty() {
366            true => vec![],
367            false => {
368                let bvh = BVH::build(&mut instances);
369                bvh.flatten_custom(&GpuNode::pack)
370            }
371        };
372
373        for ((instance, _, _, _), value) in collection.values_mut().zip_eq(instances.iter()) {
374            // Assign the computed BVH node index, and mesh/material indices.
375            *instance = value.clone();
376        }
377
378        add_instance_indices(&collection);
379
380        for (id, (entity, (instance, mesh, material, _))) in collection.iter().enumerate() {
381            let emissive = material.emissive;
382            let intensity = 255.0 * emissive.w * emissive.xyz().length();
383            if intensity > 0.0 {
384                // Compute alias table for light sampling
385                let instance_scale = instance.transform.to_scale_rotation_translation().0;
386                let alias_table = {
387                    let cached_table = alias_table_cache.get(entity).and_then(|(scale, table)| {
388                        scale.abs_diff_eq(instance_scale, 0.01).then_some(table)
389                    });
390                    let cache_hit = cached_table.is_some();
391                    let mut instance_table = cached_table
392                        .map_or_else(|| mesh.build_alias_table(instance.transform), Clone::clone);
393                    if !cache_hit {
394                        alias_table_cache.insert(*entity, (instance_scale, instance_table.clone()));
395                    }
396
397                    let index = UVec2::new(alias_table.len() as u32, instance_table.len() as u32);
398                    alias_table.append(&mut instance_table);
399                    index
400                };
401
402                let surface_area = mesh
403                    .transformed_primitive_areas(instance.transform)
404                    .iter()
405                    .sum();
406
407                // Add to emissive list.
408                let position = 0.5 * (instance.max + instance.min);
409                let radius = 0.5 * (instance.max - instance.min).length() + intensity.sqrt();
410                emissives.push(GpuEmissive {
411                    emissive,
412                    position,
413                    radius,
414                    instance: id as u32,
415                    alias_table,
416                    surface_area,
417                    node_index: 0,
418                });
419            }
420        }
421
422        let emissive_nodes = match emissives.is_empty() {
423            true => vec![],
424            false => {
425                let bvh = BVH::build(&mut emissives);
426                bvh.flatten_custom(&GpuNode::pack)
427            }
428        };
429
430        render_assets.set(
431            instances,
432            instance_nodes,
433            emissives,
434            emissive_nodes,
435            alias_table,
436        );
437        render_assets.write_buffer(&render_device, &render_queue);
438    } else {
439        add_instance_indices(&collection);
440        render_assets
441            .instance_indices
442            .write_buffer(&render_device, &render_queue);
443    }
444}