1use super::{blas::BlasManager, extract::StandardMaterialAssets, RaytracingMesh3d};
2use bevy_asset::{AssetId, Handle};
3use bevy_color::{ColorToComponents, LinearRgba};
4use bevy_ecs::{
5 entity::{Entity, EntityHashMap},
6 resource::Resource,
7 system::{Query, Res, ResMut},
8 world::{FromWorld, World},
9};
10use bevy_math::{ops::cos, Mat4, Vec3};
11use bevy_pbr::{ExtractedDirectionalLight, MeshMaterial3d, StandardMaterial};
12use bevy_platform::{collections::HashMap, hash::FixedHasher};
13use bevy_render::{
14 mesh::allocator::MeshAllocator,
15 render_asset::RenderAssets,
16 render_resource::{binding_types::*, *},
17 renderer::{RenderDevice, RenderQueue},
18 texture::{FallbackImage, GpuImage},
19};
20use bevy_transform::components::GlobalTransform;
21use core::{f32::consts::TAU, hash::Hash, num::NonZeroU32, ops::Deref};
22
/// Length of the vertex-buffer and index-buffer binding arrays in the scene bind group layout,
/// i.e. the maximum number of distinct mesh slab buffers that can be bound at once.
const MAX_MESH_SLAB_COUNT: NonZeroU32 = NonZeroU32::new(500).unwrap();
/// Length of the texture and sampler binding arrays in the scene bind group layout.
const MAX_TEXTURE_COUNT: NonZeroU32 = NonZeroU32::new(5000).unwrap();

/// Texture index stored in a `GpuMaterial` slot when the material has no texture there.
const TEXTURE_MAP_NONE: u32 = u32::MAX;
/// Value written into the previous-frame light ID translation table for a light whose entity
/// does not provide a light source this frame.
const LIGHT_NOT_PRESENT_THIS_FRAME: u32 = u32::MAX;
28
/// Resource holding the raytracing scene bind group and its layout.
///
/// `bind_group` is rebuilt every frame by `prepare_raytracing_scene_bindings`; the layout is
/// created once in `FromWorld::from_world`.
#[derive(Resource)]
pub struct RaytracingSceneBindings {
    /// This frame's scene bind group, or `None` when no scene could be assembled this frame.
    pub bind_group: Option<BindGroup>,
    /// Layout used for every scene bind group.
    pub bind_group_layout: BindGroupLayout,
    /// Entities that provided a light source last frame; each entity's position in this list is
    /// the light ID it had in that frame.
    previous_frame_light_entities: Vec<Entity>,
}
35
36pub fn prepare_raytracing_scene_bindings(
37 instances_query: Query<(
38 Entity,
39 &RaytracingMesh3d,
40 &MeshMaterial3d<StandardMaterial>,
41 &GlobalTransform,
42 )>,
43 directional_lights_query: Query<(Entity, &ExtractedDirectionalLight)>,
44 mesh_allocator: Res<MeshAllocator>,
45 blas_manager: Res<BlasManager>,
46 material_assets: Res<StandardMaterialAssets>,
47 texture_assets: Res<RenderAssets<GpuImage>>,
48 fallback_texture: Res<FallbackImage>,
49 render_device: Res<RenderDevice>,
50 render_queue: Res<RenderQueue>,
51 mut raytracing_scene_bindings: ResMut<RaytracingSceneBindings>,
52) {
53 raytracing_scene_bindings.bind_group = None;
54
55 let mut this_frame_entity_to_light_id = EntityHashMap::<u32>::default();
56 let previous_frame_light_entities: Vec<_> = raytracing_scene_bindings
57 .previous_frame_light_entities
58 .drain(..)
59 .collect();
60
61 if instances_query.iter().len() == 0 {
62 return;
63 }
64
65 let mut vertex_buffers = CachedBindingArray::new();
66 let mut index_buffers = CachedBindingArray::new();
67 let mut textures = CachedBindingArray::new();
68 let mut samplers = Vec::new();
69 let mut materials = StorageBufferList::<GpuMaterial>::default();
70 let mut tlas = render_device
71 .wgpu_device()
72 .create_tlas(&CreateTlasDescriptor {
73 label: Some("tlas"),
74 flags: AccelerationStructureFlags::PREFER_FAST_TRACE,
75 update_mode: AccelerationStructureUpdateMode::Build,
76 max_instances: instances_query.iter().len() as u32,
77 });
78 let mut transforms = StorageBufferList::<Mat4>::default();
79 let mut geometry_ids = StorageBufferList::<GpuInstanceGeometryIds>::default();
80 let mut material_ids = StorageBufferList::<u32>::default();
81 let mut light_sources = StorageBufferList::<GpuLightSource>::default();
82 let mut directional_lights = StorageBufferList::<GpuDirectionalLight>::default();
83 let mut previous_frame_light_id_translations = StorageBufferList::<u32>::default();
84
85 let mut material_id_map: HashMap<AssetId<StandardMaterial>, u32, FixedHasher> =
86 HashMap::default();
87 let mut material_id = 0;
88 let mut process_texture = |texture_handle: &Option<Handle<_>>| -> Option<u32> {
89 match texture_handle {
90 Some(texture_handle) => match texture_assets.get(texture_handle.id()) {
91 Some(texture) => {
92 let (texture_id, is_new) =
93 textures.push_if_absent(texture.texture_view.deref(), texture_handle.id());
94 if is_new {
95 samplers.push(texture.sampler.deref());
96 }
97 Some(texture_id)
98 }
99 None => None,
100 },
101 None => Some(TEXTURE_MAP_NONE),
102 }
103 };
104 for (asset_id, material) in material_assets.iter() {
105 let Some(base_color_texture_id) = process_texture(&material.base_color_texture) else {
106 continue;
107 };
108 let Some(normal_map_texture_id) = process_texture(&material.normal_map_texture) else {
109 continue;
110 };
111 let Some(emissive_texture_id) = process_texture(&material.emissive_texture) else {
112 continue;
113 };
114 let Some(metallic_roughness_texture_id) =
115 process_texture(&material.metallic_roughness_texture)
116 else {
117 continue;
118 };
119
120 materials.get_mut().push(GpuMaterial {
121 normal_map_texture_id,
122 base_color_texture_id,
123 emissive_texture_id,
124 metallic_roughness_texture_id,
125
126 base_color: LinearRgba::from(material.base_color).to_vec3(),
127 perceptual_roughness: material.perceptual_roughness,
128 emissive: material.emissive.to_vec3(),
129 metallic: material.metallic,
130 reflectance: LinearRgba::from(material.specular_tint).to_vec3() * material.reflectance,
131 _padding: Default::default(),
132 });
133
134 material_id_map.insert(*asset_id, material_id);
135 material_id += 1;
136 }
137
138 if material_id == 0 {
139 return;
140 }
141
142 if textures.is_empty() {
143 textures.vec.push(fallback_texture.d2.texture_view.deref());
144 samplers.push(fallback_texture.d2.sampler.deref());
145 }
146
147 let mut instance_id = 0;
148 for (entity, mesh, material, transform) in &instances_query {
149 let Some(blas) = blas_manager.get(&mesh.id()) else {
150 continue;
151 };
152 let Some(vertex_slice) = mesh_allocator.mesh_vertex_slice(&mesh.id()) else {
153 continue;
154 };
155 let Some(index_slice) = mesh_allocator.mesh_index_slice(&mesh.id()) else {
156 continue;
157 };
158 let Some(material_id) = material_id_map.get(&material.id()).copied() else {
159 continue;
160 };
161 let Some(material) = materials.get().get(material_id as usize) else {
162 continue;
163 };
164
165 let transform = transform.to_matrix();
166 *tlas.get_mut_single(instance_id).unwrap() = Some(TlasInstance::new(
167 blas,
168 tlas_transform(&transform),
169 Default::default(),
170 0xFF,
171 ));
172
173 transforms.get_mut().push(transform);
174
175 let (vertex_buffer_id, _) = vertex_buffers.push_if_absent(
176 vertex_slice.buffer.as_entire_buffer_binding(),
177 vertex_slice.buffer.id(),
178 );
179 let (index_buffer_id, _) = index_buffers.push_if_absent(
180 index_slice.buffer.as_entire_buffer_binding(),
181 index_slice.buffer.id(),
182 );
183
184 geometry_ids.get_mut().push(GpuInstanceGeometryIds {
185 vertex_buffer_id,
186 vertex_buffer_offset: vertex_slice.range.start,
187 index_buffer_id,
188 index_buffer_offset: index_slice.range.start,
189 triangle_count: (index_slice.range.len() / 3) as u32,
190 });
191
192 material_ids.get_mut().push(material_id);
193
194 if material.emissive != Vec3::ZERO {
195 light_sources
196 .get_mut()
197 .push(GpuLightSource::new_emissive_mesh_light(
198 instance_id as u32,
199 (index_slice.range.len() / 3) as u32,
200 ));
201
202 this_frame_entity_to_light_id.insert(entity, light_sources.get().len() as u32 - 1);
203 raytracing_scene_bindings
204 .previous_frame_light_entities
205 .push(entity);
206 }
207
208 instance_id += 1;
209 }
210
211 if instance_id == 0 {
212 return;
213 }
214
215 for (entity, directional_light) in &directional_lights_query {
216 let directional_lights = directional_lights.get_mut();
217 let directional_light_id = directional_lights.len() as u32;
218
219 directional_lights.push(GpuDirectionalLight::new(directional_light));
220
221 light_sources
222 .get_mut()
223 .push(GpuLightSource::new_directional_light(directional_light_id));
224
225 this_frame_entity_to_light_id.insert(entity, light_sources.get().len() as u32 - 1);
226 raytracing_scene_bindings
227 .previous_frame_light_entities
228 .push(entity);
229 }
230
231 for previous_frame_light_entity in previous_frame_light_entities {
232 let current_frame_index = this_frame_entity_to_light_id
233 .get(&previous_frame_light_entity)
234 .copied()
235 .unwrap_or(LIGHT_NOT_PRESENT_THIS_FRAME);
236 previous_frame_light_id_translations
237 .get_mut()
238 .push(current_frame_index);
239 }
240
241 if light_sources.get().len() > u16::MAX as usize {
242 panic!("Too many light sources in the scene, maximum is 65536.");
243 }
244
245 materials.write_buffer(&render_device, &render_queue);
246 transforms.write_buffer(&render_device, &render_queue);
247 geometry_ids.write_buffer(&render_device, &render_queue);
248 material_ids.write_buffer(&render_device, &render_queue);
249 light_sources.write_buffer(&render_device, &render_queue);
250 directional_lights.write_buffer(&render_device, &render_queue);
251 previous_frame_light_id_translations.write_buffer(&render_device, &render_queue);
252
253 let mut command_encoder = render_device.create_command_encoder(&CommandEncoderDescriptor {
254 label: Some("build_tlas_command_encoder"),
255 });
256 command_encoder.build_acceleration_structures(&[], [&tlas]);
257 render_queue.submit([command_encoder.finish()]);
258
259 raytracing_scene_bindings.bind_group = Some(render_device.create_bind_group(
260 "raytracing_scene_bind_group",
261 &raytracing_scene_bindings.bind_group_layout,
262 &BindGroupEntries::sequential((
263 vertex_buffers.as_slice(),
264 index_buffers.as_slice(),
265 textures.as_slice(),
266 samplers.as_slice(),
267 materials.binding().unwrap(),
268 tlas.as_binding(),
269 transforms.binding().unwrap(),
270 geometry_ids.binding().unwrap(),
271 material_ids.binding().unwrap(),
272 light_sources.binding().unwrap(),
273 directional_lights.binding().unwrap(),
274 previous_frame_light_id_translations.binding().unwrap(),
275 )),
276 ));
277}
278
279impl FromWorld for RaytracingSceneBindings {
280 fn from_world(world: &mut World) -> Self {
281 let render_device = world.resource::<RenderDevice>();
282
283 Self {
284 bind_group: None,
285 bind_group_layout: render_device.create_bind_group_layout(
286 "raytracing_scene_bind_group_layout",
287 &BindGroupLayoutEntries::sequential(
288 ShaderStages::COMPUTE,
289 (
290 storage_buffer_read_only_sized(false, None).count(MAX_MESH_SLAB_COUNT),
291 storage_buffer_read_only_sized(false, None).count(MAX_MESH_SLAB_COUNT),
292 texture_2d(TextureSampleType::Float { filterable: true })
293 .count(MAX_TEXTURE_COUNT),
294 sampler(SamplerBindingType::Filtering).count(MAX_TEXTURE_COUNT),
295 storage_buffer_read_only_sized(false, None),
296 acceleration_structure(),
297 storage_buffer_read_only_sized(false, None),
298 storage_buffer_read_only_sized(false, None),
299 storage_buffer_read_only_sized(false, None),
300 storage_buffer_read_only_sized(false, None),
301 storage_buffer_read_only_sized(false, None),
302 storage_buffer_read_only_sized(false, None),
303 ),
304 ),
305 ),
306 previous_frame_light_entities: Vec::new(),
307 }
308 }
309}
310
/// A binding array that stores each item at most once.
///
/// Items are deduplicated by a separate identifier `I` rather than by the item itself; the
/// position of an item in `vec` is its index in the binding array.
struct CachedBindingArray<T, I: Eq + Hash> {
    /// Identifier -> index of the corresponding item in `vec`.
    map: HashMap<I, u32>,
    /// Stored items, in insertion order.
    vec: Vec<T>,
}

impl<T, I: Eq + Hash> CachedBindingArray<T, I> {
    /// Creates an empty array.
    fn new() -> Self {
        Self {
            vec: Vec::new(),
            map: HashMap::default(),
        }
    }

    /// Returns the index of the item identified by `item_id`, appending `item` first when that
    /// identifier has not been seen before.
    ///
    /// The returned flag is `true` only if `item` was appended; otherwise `item` is dropped and
    /// the existing entry is left unchanged.
    fn push_if_absent(&mut self, item: T, item_id: I) -> (u32, bool) {
        if let Some(&existing_index) = self.map.get(&item_id) {
            return (existing_index, false);
        }

        let new_index = self.vec.len() as u32;
        self.vec.push(item);
        self.map.insert(item_id, new_index);
        (new_index, true)
    }

    /// Whether the array holds no items.
    fn is_empty(&self) -> bool {
        self.vec.is_empty()
    }

    /// The stored items in index order.
    fn as_slice(&self) -> &[T] {
        &self.vec
    }
}
343
/// A storage buffer whose CPU-side contents are a `Vec<T>`: filled through `get_mut().push(..)`
/// and uploaded with `write_buffer`.
type StorageBufferList<T> = StorageBuffer<Vec<T>>;
345
/// Per-instance geometry lookup data; one entry is pushed for every TLAS instance, in the same
/// order as the instances.
///
/// NOTE(review): field order defines the GPU-side layout and presumably mirrors a shader struct —
/// do not reorder without updating the shader.
#[derive(ShaderType)]
struct GpuInstanceGeometryIds {
    /// Index into the scene's vertex buffer binding array.
    vertex_buffer_id: u32,
    /// Start of the mesh's range within that vertex buffer, as reported by the mesh allocator
    /// (NOTE(review): assumed to be in vertices rather than bytes — confirm).
    vertex_buffer_offset: u32,
    /// Index into the scene's index buffer binding array.
    index_buffer_id: u32,
    /// Start of the mesh's range within that index buffer, as reported by the mesh allocator.
    index_buffer_offset: u32,
    /// Number of triangles (index count / 3).
    triangle_count: u32,
}
354
/// GPU copy of the `StandardMaterial` properties used by the raytracer.
///
/// NOTE(review): field order defines the GPU-side layout and presumably mirrors a shader struct —
/// do not reorder without updating the shader.
#[derive(ShaderType)]
struct GpuMaterial {
    // Texture IDs: indices into the scene texture binding array, or `TEXTURE_MAP_NONE` when the
    // material has no texture in that slot.
    normal_map_texture_id: u32,
    base_color_texture_id: u32,
    emissive_texture_id: u32,
    metallic_roughness_texture_id: u32,

    /// Base color converted to linear RGB (alpha dropped).
    base_color: Vec3,
    perceptual_roughness: f32,
    /// Emissive color RGB components as stored on the material.
    emissive: Vec3,
    metallic: f32,
    /// `specular_tint` (linear RGB) scaled by the material's `reflectance`.
    reflectance: Vec3,
    /// Explicit trailing padding (NOTE(review): presumably matches the shader struct — confirm).
    _padding: f32,
}
369
/// One entry of the scene's light source list (emissive mesh or directional light).
#[derive(ShaderType)]
struct GpuLightSource {
    /// Light type encoding: `1` for a directional light; otherwise the emissive mesh's triangle
    /// count shifted left by one bit (so bit 0 is clear).
    kind: u32,
    /// TLAS instance index for emissive meshes, or index into the directional light buffer.
    id: u32,
}
375
376impl GpuLightSource {
377 fn new_emissive_mesh_light(instance_id: u32, triangle_count: u32) -> GpuLightSource {
378 if triangle_count > u16::MAX as u32 {
379 panic!("Too many triangles ({triangle_count}) in an emissive mesh, maximum is 65535.");
380 }
381
382 Self {
383 kind: triangle_count << 1,
384 id: instance_id,
385 }
386 }
387
388 fn new_directional_light(directional_light_id: u32) -> GpuLightSource {
389 Self {
390 kind: 1,
391 id: directional_light_id,
392 }
393 }
394}
395
/// GPU representation of a directional light modelled as a disk of constant radiance.
///
/// NOTE(review): field order defines the GPU-side layout and presumably mirrors a shader struct —
/// do not reorder without updating the shader.
#[derive(ShaderType, Default)]
struct GpuDirectionalLight {
    /// Direction pointing toward the light (the light transform's back vector).
    direction_to_light: Vec3,
    /// Cosine of the disk's angular radius (half of `sun_disk_angular_size`).
    cos_theta_max: f32,
    /// Light color times illuminance, divided by the disk's solid angle.
    luminance: Vec3,
    /// The disk's solid angle, i.e. the inverse of a uniform pdf over that cone.
    inverse_pdf: f32,
}
403
404impl GpuDirectionalLight {
405 fn new(directional_light: &ExtractedDirectionalLight) -> Self {
406 let cos_theta_max = cos(directional_light.sun_disk_angular_size / 2.0);
407 let solid_angle = TAU * (1.0 - cos_theta_max);
408 let luminance =
409 (directional_light.color.to_vec3() * directional_light.illuminance) / solid_angle;
410
411 Self {
412 direction_to_light: directional_light.transform.back().into(),
413 cos_theta_max,
414 luminance,
415 inverse_pdf: solid_angle,
416 }
417 }
418}
419
420fn tlas_transform(transform: &Mat4) -> [f32; 12] {
421 transform.transpose().to_cols_array()[..12]
422 .try_into()
423 .unwrap()
424}