bevy_solari/scene/binder.rs

1use super::{blas::BlasManager, extract::StandardMaterialAssets, RaytracingMesh3d};
2use bevy_asset::{AssetId, Handle};
3use bevy_color::{ColorToComponents, LinearRgba};
4use bevy_ecs::{
5    entity::{Entity, EntityHashMap},
6    resource::Resource,
7    system::{Query, Res, ResMut},
8    world::{FromWorld, World},
9};
10use bevy_math::{ops::cos, Mat4, Vec3};
11use bevy_pbr::{ExtractedDirectionalLight, MeshMaterial3d, StandardMaterial};
12use bevy_platform::{collections::HashMap, hash::FixedHasher};
13use bevy_render::{
14    mesh::allocator::MeshAllocator,
15    render_asset::RenderAssets,
16    render_resource::{binding_types::*, *},
17    renderer::{RenderDevice, RenderQueue},
18    texture::{FallbackImage, GpuImage},
19};
20use bevy_transform::components::GlobalTransform;
21use core::{f32::consts::TAU, hash::Hash, num::NonZeroU32, ops::Deref};
22
/// Size of the vertex-buffer and index-buffer binding arrays in the scene bind group layout,
/// i.e. the maximum number of distinct mesh slab buffers the scene can reference.
/// NOTE(review): not checked on the CPU side before the bind group is created.
const MAX_MESH_SLAB_COUNT: NonZeroU32 = NonZeroU32::new(500).unwrap();
/// Size of the texture and sampler binding arrays in the scene bind group layout.
/// NOTE(review): not checked on the CPU side before the bind group is created.
const MAX_TEXTURE_COUNT: NonZeroU32 = NonZeroU32::new(5_000).unwrap();

/// Texture id stored in a `GpuMaterial` slot whose material has no texture for that slot.
const TEXTURE_MAP_NONE: u32 = u32::MAX;
/// Entry in the previous-to-current light id translation table for a light that had an id
/// last frame but has none this frame.
const LIGHT_NOT_PRESENT_THIS_FRAME: u32 = u32::MAX;
28
/// Scene-wide raytracing bindings, rebuilt every frame by `prepare_raytracing_scene_bindings`.
#[derive(Resource)]
pub struct RaytracingSceneBindings {
    /// This frame's scene bind group, or `None` if no scene could be built this frame
    /// (for example, no instances or no usable materials).
    pub bind_group: Option<BindGroup>,
    /// Layout used for every frame's `bind_group`; created once in `from_world`.
    pub bind_group_layout: BindGroupLayout,
    /// Entities that owned a light source in the previous frame, indexed by that frame's light
    /// id. Used to build the previous-to-current light id translation table.
    previous_frame_light_entities: Vec<Entity>,
}
35
36pub fn prepare_raytracing_scene_bindings(
37    instances_query: Query<(
38        Entity,
39        &RaytracingMesh3d,
40        &MeshMaterial3d<StandardMaterial>,
41        &GlobalTransform,
42    )>,
43    directional_lights_query: Query<(Entity, &ExtractedDirectionalLight)>,
44    mesh_allocator: Res<MeshAllocator>,
45    blas_manager: Res<BlasManager>,
46    material_assets: Res<StandardMaterialAssets>,
47    texture_assets: Res<RenderAssets<GpuImage>>,
48    fallback_texture: Res<FallbackImage>,
49    render_device: Res<RenderDevice>,
50    render_queue: Res<RenderQueue>,
51    mut raytracing_scene_bindings: ResMut<RaytracingSceneBindings>,
52) {
53    raytracing_scene_bindings.bind_group = None;
54
55    let mut this_frame_entity_to_light_id = EntityHashMap::<u32>::default();
56    let previous_frame_light_entities: Vec<_> = raytracing_scene_bindings
57        .previous_frame_light_entities
58        .drain(..)
59        .collect();
60
61    if instances_query.iter().len() == 0 {
62        return;
63    }
64
65    let mut vertex_buffers = CachedBindingArray::new();
66    let mut index_buffers = CachedBindingArray::new();
67    let mut textures = CachedBindingArray::new();
68    let mut samplers = Vec::new();
69    let mut materials = StorageBufferList::<GpuMaterial>::default();
70    let mut tlas = render_device
71        .wgpu_device()
72        .create_tlas(&CreateTlasDescriptor {
73            label: Some("tlas"),
74            flags: AccelerationStructureFlags::PREFER_FAST_TRACE,
75            update_mode: AccelerationStructureUpdateMode::Build,
76            max_instances: instances_query.iter().len() as u32,
77        });
78    let mut transforms = StorageBufferList::<Mat4>::default();
79    let mut geometry_ids = StorageBufferList::<GpuInstanceGeometryIds>::default();
80    let mut material_ids = StorageBufferList::<u32>::default();
81    let mut light_sources = StorageBufferList::<GpuLightSource>::default();
82    let mut directional_lights = StorageBufferList::<GpuDirectionalLight>::default();
83    let mut previous_frame_light_id_translations = StorageBufferList::<u32>::default();
84
85    let mut material_id_map: HashMap<AssetId<StandardMaterial>, u32, FixedHasher> =
86        HashMap::default();
87    let mut material_id = 0;
88    let mut process_texture = |texture_handle: &Option<Handle<_>>| -> Option<u32> {
89        match texture_handle {
90            Some(texture_handle) => match texture_assets.get(texture_handle.id()) {
91                Some(texture) => {
92                    let (texture_id, is_new) =
93                        textures.push_if_absent(texture.texture_view.deref(), texture_handle.id());
94                    if is_new {
95                        samplers.push(texture.sampler.deref());
96                    }
97                    Some(texture_id)
98                }
99                None => None,
100            },
101            None => Some(TEXTURE_MAP_NONE),
102        }
103    };
104    for (asset_id, material) in material_assets.iter() {
105        let Some(base_color_texture_id) = process_texture(&material.base_color_texture) else {
106            continue;
107        };
108        let Some(normal_map_texture_id) = process_texture(&material.normal_map_texture) else {
109            continue;
110        };
111        let Some(emissive_texture_id) = process_texture(&material.emissive_texture) else {
112            continue;
113        };
114        let Some(metallic_roughness_texture_id) =
115            process_texture(&material.metallic_roughness_texture)
116        else {
117            continue;
118        };
119
120        materials.get_mut().push(GpuMaterial {
121            normal_map_texture_id,
122            base_color_texture_id,
123            emissive_texture_id,
124            metallic_roughness_texture_id,
125
126            base_color: LinearRgba::from(material.base_color).to_vec3(),
127            perceptual_roughness: material.perceptual_roughness,
128            emissive: material.emissive.to_vec3(),
129            metallic: material.metallic,
130            reflectance: LinearRgba::from(material.specular_tint).to_vec3() * material.reflectance,
131            _padding: Default::default(),
132        });
133
134        material_id_map.insert(*asset_id, material_id);
135        material_id += 1;
136    }
137
138    if material_id == 0 {
139        return;
140    }
141
142    if textures.is_empty() {
143        textures.vec.push(fallback_texture.d2.texture_view.deref());
144        samplers.push(fallback_texture.d2.sampler.deref());
145    }
146
147    let mut instance_id = 0;
148    for (entity, mesh, material, transform) in &instances_query {
149        let Some(blas) = blas_manager.get(&mesh.id()) else {
150            continue;
151        };
152        let Some(vertex_slice) = mesh_allocator.mesh_vertex_slice(&mesh.id()) else {
153            continue;
154        };
155        let Some(index_slice) = mesh_allocator.mesh_index_slice(&mesh.id()) else {
156            continue;
157        };
158        let Some(material_id) = material_id_map.get(&material.id()).copied() else {
159            continue;
160        };
161        let Some(material) = materials.get().get(material_id as usize) else {
162            continue;
163        };
164
165        let transform = transform.to_matrix();
166        *tlas.get_mut_single(instance_id).unwrap() = Some(TlasInstance::new(
167            blas,
168            tlas_transform(&transform),
169            Default::default(),
170            0xFF,
171        ));
172
173        transforms.get_mut().push(transform);
174
175        let (vertex_buffer_id, _) = vertex_buffers.push_if_absent(
176            vertex_slice.buffer.as_entire_buffer_binding(),
177            vertex_slice.buffer.id(),
178        );
179        let (index_buffer_id, _) = index_buffers.push_if_absent(
180            index_slice.buffer.as_entire_buffer_binding(),
181            index_slice.buffer.id(),
182        );
183
184        geometry_ids.get_mut().push(GpuInstanceGeometryIds {
185            vertex_buffer_id,
186            vertex_buffer_offset: vertex_slice.range.start,
187            index_buffer_id,
188            index_buffer_offset: index_slice.range.start,
189            triangle_count: (index_slice.range.len() / 3) as u32,
190        });
191
192        material_ids.get_mut().push(material_id);
193
194        if material.emissive != Vec3::ZERO {
195            light_sources
196                .get_mut()
197                .push(GpuLightSource::new_emissive_mesh_light(
198                    instance_id as u32,
199                    (index_slice.range.len() / 3) as u32,
200                ));
201
202            this_frame_entity_to_light_id.insert(entity, light_sources.get().len() as u32 - 1);
203            raytracing_scene_bindings
204                .previous_frame_light_entities
205                .push(entity);
206        }
207
208        instance_id += 1;
209    }
210
211    if instance_id == 0 {
212        return;
213    }
214
215    for (entity, directional_light) in &directional_lights_query {
216        let directional_lights = directional_lights.get_mut();
217        let directional_light_id = directional_lights.len() as u32;
218
219        directional_lights.push(GpuDirectionalLight::new(directional_light));
220
221        light_sources
222            .get_mut()
223            .push(GpuLightSource::new_directional_light(directional_light_id));
224
225        this_frame_entity_to_light_id.insert(entity, light_sources.get().len() as u32 - 1);
226        raytracing_scene_bindings
227            .previous_frame_light_entities
228            .push(entity);
229    }
230
231    for previous_frame_light_entity in previous_frame_light_entities {
232        let current_frame_index = this_frame_entity_to_light_id
233            .get(&previous_frame_light_entity)
234            .copied()
235            .unwrap_or(LIGHT_NOT_PRESENT_THIS_FRAME);
236        previous_frame_light_id_translations
237            .get_mut()
238            .push(current_frame_index);
239    }
240
241    if light_sources.get().len() > u16::MAX as usize {
242        panic!("Too many light sources in the scene, maximum is 65536.");
243    }
244
245    materials.write_buffer(&render_device, &render_queue);
246    transforms.write_buffer(&render_device, &render_queue);
247    geometry_ids.write_buffer(&render_device, &render_queue);
248    material_ids.write_buffer(&render_device, &render_queue);
249    light_sources.write_buffer(&render_device, &render_queue);
250    directional_lights.write_buffer(&render_device, &render_queue);
251    previous_frame_light_id_translations.write_buffer(&render_device, &render_queue);
252
253    let mut command_encoder = render_device.create_command_encoder(&CommandEncoderDescriptor {
254        label: Some("build_tlas_command_encoder"),
255    });
256    command_encoder.build_acceleration_structures(&[], [&tlas]);
257    render_queue.submit([command_encoder.finish()]);
258
259    raytracing_scene_bindings.bind_group = Some(render_device.create_bind_group(
260        "raytracing_scene_bind_group",
261        &raytracing_scene_bindings.bind_group_layout,
262        &BindGroupEntries::sequential((
263            vertex_buffers.as_slice(),
264            index_buffers.as_slice(),
265            textures.as_slice(),
266            samplers.as_slice(),
267            materials.binding().unwrap(),
268            tlas.as_binding(),
269            transforms.binding().unwrap(),
270            geometry_ids.binding().unwrap(),
271            material_ids.binding().unwrap(),
272            light_sources.binding().unwrap(),
273            directional_lights.binding().unwrap(),
274            previous_frame_light_id_translations.binding().unwrap(),
275        )),
276    ));
277}
278
279impl FromWorld for RaytracingSceneBindings {
280    fn from_world(world: &mut World) -> Self {
281        let render_device = world.resource::<RenderDevice>();
282
283        Self {
284            bind_group: None,
285            bind_group_layout: render_device.create_bind_group_layout(
286                "raytracing_scene_bind_group_layout",
287                &BindGroupLayoutEntries::sequential(
288                    ShaderStages::COMPUTE,
289                    (
290                        storage_buffer_read_only_sized(false, None).count(MAX_MESH_SLAB_COUNT),
291                        storage_buffer_read_only_sized(false, None).count(MAX_MESH_SLAB_COUNT),
292                        texture_2d(TextureSampleType::Float { filterable: true })
293                            .count(MAX_TEXTURE_COUNT),
294                        sampler(SamplerBindingType::Filtering).count(MAX_TEXTURE_COUNT),
295                        storage_buffer_read_only_sized(false, None),
296                        acceleration_structure(),
297                        storage_buffer_read_only_sized(false, None),
298                        storage_buffer_read_only_sized(false, None),
299                        storage_buffer_read_only_sized(false, None),
300                        storage_buffer_read_only_sized(false, None),
301                        storage_buffer_read_only_sized(false, None),
302                        storage_buffer_read_only_sized(false, None),
303                    ),
304                ),
305            ),
306            previous_frame_light_entities: Vec::new(),
307        }
308    }
309}
310
/// Builds a binding array while deduplicating its items by a caller-supplied id.
///
/// The first item pushed for a given id is stored at the next free index. Later pushes with
/// the same id return that existing index and discard the new item.
struct CachedBindingArray<T, I: Eq + Hash> {
    /// Maps each id to the index of its item in `vec`.
    map: HashMap<I, u32>,
    /// The array contents, in index order.
    vec: Vec<T>,
}

impl<T, I: Eq + Hash> CachedBindingArray<T, I> {
    /// Creates an empty array.
    fn new() -> Self {
        Self {
            vec: Vec::new(),
            map: HashMap::default(),
        }
    }

    /// Returns `(index, is_new)` for `item_id`, appending `item` only if the id is unseen.
    fn push_if_absent(&mut self, item: T, item_id: I) -> (u32, bool) {
        // Every index already in `map` is below `vec.len()`. So if the entry yields
        // `candidate`, this call is the one that inserted it.
        let candidate = self.vec.len() as u32;
        let index = *self.map.entry(item_id).or_insert(candidate);
        let is_new = index == candidate;
        if is_new {
            self.vec.push(item);
        }
        (index, is_new)
    }

    /// Returns `true` if the array holds no items.
    fn is_empty(&self) -> bool {
        self.vec.is_empty()
    }

    /// Returns the array contents in index order.
    fn as_slice(&self) -> &[T] {
        &self.vec
    }
}
343
/// A storage buffer whose contents are a `Vec<T>`. It is filled on the CPU via
/// `get_mut().push(..)` and uploaded with `write_buffer`.
type StorageBufferList<T> = StorageBuffer<Vec<T>>;
345
/// Per-instance geometry lookup, pushed in TLAS instance id order.
///
/// NOTE(review): field order defines the GPU layout and presumably mirrors a shader struct;
/// keep the two in sync.
#[derive(ShaderType)]
struct GpuInstanceGeometryIds {
    /// Index into the scene bind group's vertex buffer binding array.
    vertex_buffer_id: u32,
    /// Start of this mesh's range within that vertex buffer, as reported by `MeshAllocator`.
    vertex_buffer_offset: u32,
    /// Index into the scene bind group's index buffer binding array.
    index_buffer_id: u32,
    /// Start of this mesh's range within that index buffer, as reported by `MeshAllocator`.
    index_buffer_offset: u32,
    /// Length of the mesh's index range divided by 3.
    triangle_count: u32,
}
354
/// GPU copy of a `StandardMaterial`, indexed by material id.
///
/// NOTE(review): field order defines the GPU layout and presumably mirrors a shader struct;
/// keep the two in sync. Each `Vec3` is followed by a 4-byte scalar.
#[derive(ShaderType)]
struct GpuMaterial {
    /// Index into the texture/sampler binding arrays, or `TEXTURE_MAP_NONE`.
    normal_map_texture_id: u32,
    /// Index into the texture/sampler binding arrays, or `TEXTURE_MAP_NONE`.
    base_color_texture_id: u32,
    /// Index into the texture/sampler binding arrays, or `TEXTURE_MAP_NONE`.
    emissive_texture_id: u32,
    /// Index into the texture/sampler binding arrays, or `TEXTURE_MAP_NONE`.
    metallic_roughness_texture_id: u32,

    /// RGB of the material's base color converted to `LinearRgba`.
    base_color: Vec3,
    perceptual_roughness: f32,
    /// RGB of the material's emissive color. A non-zero value makes instances using this
    /// material light sources.
    emissive: Vec3,
    metallic: f32,
    /// Linear specular tint scaled by the material's `reflectance`.
    reflectance: Vec3,
    /// Unused; pads the struct.
    _padding: f32,
}
369
/// One entry in the scene's light source list. A light id is an index into this list.
///
/// Encoding, as produced by the constructors:
/// - `kind == 1` (low bit set): directional light; `id` indexes the directional light buffer.
/// - low bit clear: emissive mesh; `kind >> 1` is its triangle count and `id` is its TLAS
///   instance id.
///
/// NOTE(review): presumably decoded the same way in the shaders; keep in sync.
#[derive(ShaderType)]
struct GpuLightSource {
    kind: u32,
    id: u32,
}
375
376impl GpuLightSource {
377    fn new_emissive_mesh_light(instance_id: u32, triangle_count: u32) -> GpuLightSource {
378        if triangle_count > u16::MAX as u32 {
379            panic!("Too many triangles ({triangle_count}) in an emissive mesh, maximum is 65535.");
380        }
381
382        Self {
383            kind: triangle_count << 1,
384            id: instance_id,
385        }
386    }
387
388    fn new_directional_light(directional_light_id: u32) -> GpuLightSource {
389        Self {
390            kind: 1,
391            id: directional_light_id,
392        }
393    }
394}
395
/// GPU description of a directional light modeled as a sun disk (a cone of directions).
///
/// NOTE(review): field order defines the GPU layout and presumably mirrors a shader struct;
/// keep the two in sync.
#[derive(ShaderType, Default)]
struct GpuDirectionalLight {
    /// The light transform's `back()` direction.
    direction_to_light: Vec3,
    /// Cosine of half of the light's `sun_disk_angular_size` (the cone's half-angle).
    cos_theta_max: f32,
    /// `color * illuminance` divided by the disk's solid angle.
    luminance: Vec3,
    /// The disk's solid angle, `2π(1 - cos_theta_max)`. Presumably the reciprocal of a
    /// uniform cone-sampling pdf; confirm against the shader.
    inverse_pdf: f32,
}
403
404impl GpuDirectionalLight {
405    fn new(directional_light: &ExtractedDirectionalLight) -> Self {
406        let cos_theta_max = cos(directional_light.sun_disk_angular_size / 2.0);
407        let solid_angle = TAU * (1.0 - cos_theta_max);
408        let luminance =
409            (directional_light.color.to_vec3() * directional_light.illuminance) / solid_angle;
410
411        Self {
412            direction_to_light: directional_light.transform.back().into(),
413            cos_theta_max,
414            luminance,
415            inverse_pdf: solid_angle,
416        }
417    }
418}
419
420fn tlas_transform(transform: &Mat4) -> [f32; 12] {
421    transform.transpose().to_cols_array()[..12]
422        .try_into()
423        .unwrap()
424}