#[cfg(feature = "visualization")]
use std::collections::HashMap;
use std::iter;
use bytemuck_derive::{Pod, Zeroable};
use glam::Affine3A;
use wgpu::util::DeviceExt;
pub use wgpu;
pub mod depth_camera;
pub mod lidar;
pub mod utils;
/// Flattens an affine transform into 12 floats laid out row by row.
///
/// Each group of four values holds one row of the 3x3 matrix part followed by
/// the matching translation component, i.e. the top three rows of the
/// equivalent 4x4 matrix. The result is what this file passes to
/// `wgpu::TlasInstance::new` as the instance transform.
#[inline]
fn affine_to_rows(mat: &Affine3A) -> [f32; 12] {
    let offset = mat.translation;
    let [top, middle, bottom] = [0, 1, 2].map(|i| mat.matrix3.row(i));
    [
        top.x, top.y, top.z, offset.x,
        middle.x, middle.y, middle.z, offset.y,
        bottom.x, bottom.y, bottom.z, offset.z,
    ]
}
/// Expands an affine transform into a full 4x4 matrix, flattened row by row.
///
/// The first three rows hold the rows of the 3x3 matrix part followed by the
/// matching translation component. The last row is the homogeneous row
/// `[0, 0, 0, 1]`, so multiplying a point `(x, y, z, 1)` leaves its `w`
/// component equal to one.
#[inline]
fn affine_to_4x4rows(mat: &Affine3A) -> [f32; 16] {
    let row_0 = mat.matrix3.row(0);
    let row_1 = mat.matrix3.row(1);
    let row_2 = mat.matrix3.row(2);
    let translation = mat.translation;
    [
        row_0.x,
        row_0.y,
        row_0.z,
        translation.x,
        row_1.x,
        row_1.y,
        row_1.z,
        translation.y,
        row_2.x,
        row_2.y,
        row_2.z,
        translation.z,
        0.0,
        0.0,
        0.0,
        // Homogeneous coordinate: must be exactly 1.0 (was mistyped as 0.1).
        1.0,
    ]
}
/// A single mesh vertex in the layout uploaded to the GPU.
///
/// The struct is `#[repr(C)]` and `Pod`, so a slice of vertices can be
/// reinterpreted as raw bytes when the vertex buffer is created.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
pub struct Vertex {
/// Position as four floats; [`vertex`] stores `1.0` in the fourth component.
_pos: [f32; 4],
/// Texture coordinate; [`vertex`] initialises it to `[0.0, 0.0]`.
_tex_coord: [f32; 2],
}
/// Creates a [`Vertex`] at the given 3D position.
///
/// The position is extended with a fourth component of `1.0`, and the texture
/// coordinate is set to the origin.
pub fn vertex(pos: [f32; 3]) -> Vertex {
    let [x, y, z] = pos;
    Vertex {
        _tex_coord: [0.0, 0.0],
        _pos: [x, y, z, 1.0],
    }
}
/// CPU-side geometry of one mesh asset that instances can refer to.
#[derive(Clone, Debug)]
pub struct AssetMesh {
/// Vertices of the mesh.
pub vertex_buf: Vec<Vertex>,
/// 16-bit indices into `vertex_buf`; the scene code treats every three
/// consecutive indices as one triangle.
pub index_buf: Vec<u16>,
}
/// One placement of a mesh asset in the scene.
#[derive(Clone, Debug)]
pub struct Instance {
/// Index of the asset to instantiate, counted in the order of the asset
/// list handed to `RayTraceScene::new`.
pub asset_mesh_index: usize,
/// Affine transform applied to the asset for this instance.
pub transform: Affine3A,
}
/// GPU acceleration structures for a set of mesh assets and their instances.
///
/// Several fields exist only when the `visualization` feature is enabled.
pub struct RayTraceScene {
/// Shared vertex buffer holding the vertices of all assets back to back.
#[cfg(feature = "visualization")]
pub(crate) vertex_buf: wgpu::Buffer,
/// Shared index buffer holding the indices of all assets back to back.
#[cfg(feature = "visualization")]
pub(crate) index_buf: wgpu::Buffer,
/// One bottom-level acceleration structure per asset, in asset order.
pub(crate) blas: Vec<wgpu::Blas>,
/// Top-level acceleration structure whose slots hold the scene instances.
pub(crate) tlas_package: wgpu::Tlas,
/// Copy of the asset meshes, read by `visualize`.
#[cfg(feature = "visualization")]
pub(crate) assets: Vec<AssetMesh>,
/// Host-side copy of the scene instances, read by `visualize` when that
/// feature is enabled.
pub(crate) instances: Vec<Instance>,
}
impl RayTraceScene {
/// Uploads the given meshes and builds the acceleration structures for them.
///
/// All asset vertices and indices are packed into one shared vertex buffer
/// and one shared index buffer. One BLAS is created per asset, one TLAS slot
/// is filled per instance, and a single command buffer that builds all of
/// them is submitted to `queue`.
///
/// The function is declared `async` but does not await anything; the build
/// is only submitted, not waited for.
///
/// # Panics
///
/// Panics if an instance's `asset_mesh_index` is not a valid index into
/// `assets`.
pub async fn new(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    assets: &Vec<AssetMesh>,
    instances: &[Instance],
) -> Self {
    // Pack every mesh into shared arrays, remembering where each one starts so
    // the BLAS build can address its sub-range.
    let mut vertex_data: Vec<Vertex> = Vec::new();
    let mut index_data: Vec<u16> = Vec::new();
    let mut start_vertex_address = Vec::with_capacity(assets.len());
    let mut start_indices_address = Vec::with_capacity(assets.len());
    for asset in assets {
        start_vertex_address.push(vertex_data.len());
        vertex_data.extend_from_slice(&asset.vertex_buf);
        start_indices_address.push(index_data.len());
        index_data.extend_from_slice(&asset.index_buf);
    }
    let vertex_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Vertex Buffer"),
        contents: bytemuck::cast_slice(&vertex_data),
        usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::BLAS_INPUT,
    });
    let index_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Index Buffer"),
        contents: bytemuck::cast_slice(&index_data),
        usage: wgpu::BufferUsages::INDEX | wgpu::BufferUsages::BLAS_INPUT,
    });

    // One BLAS per asset. The size descriptor is kept because the build entry
    // further down has to borrow it.
    let mut geometry_desc_sizes = Vec::with_capacity(assets.len());
    let mut blas = Vec::with_capacity(assets.len());
    println!("Creating BLAS for {} assets", assets.len());
    for asset in assets {
        println!(
            "Creating BLAS for asset with {} vertices and {} indices",
            asset.vertex_buf.len(),
            asset.index_buf.len()
        );
        let size_desc = wgpu::BlasTriangleGeometrySizeDescriptor {
            vertex_count: asset.vertex_buf.len() as u32,
            // Only the first three floats of `Vertex::_pos` are read as the position.
            vertex_format: wgpu::VertexFormat::Float32x3,
            index_count: Some(asset.index_buf.len() as u32),
            index_format: Some(wgpu::IndexFormat::Uint16),
            flags: wgpu::AccelerationStructureGeometryFlags::OPAQUE,
        };
        blas.push(device.create_blas(
            &wgpu::CreateBlasDescriptor {
                label: Some(&format!("BLAS {}", blas.len())),
                flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
                update_mode: wgpu::AccelerationStructureUpdateMode::Build,
            },
            wgpu::BlasGeometrySizeDescriptors::Triangles {
                descriptors: vec![size_desc.clone()],
            },
        ));
        geometry_desc_sizes.push(size_desc);
    }

    // The TLAS gets exactly one slot per instance.
    let mut tlas_package = device.create_tlas(&wgpu::CreateTlasDescriptor {
        label: None,
        flags: wgpu::AccelerationStructureFlags::PREFER_FAST_TRACE,
        update_mode: wgpu::AccelerationStructureUpdateMode::Build,
        max_instances: instances.len() as u32,
    });
    for (slot, instance) in instances.iter().enumerate() {
        tlas_package[slot] = Some(wgpu::TlasInstance::new(
            &blas[instance.asset_mesh_index],
            affine_to_rows(&instance.transform),
            0,
            0xff,
        ));
    }

    let mut encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    let blas_entries: Vec<_> = blas
        .iter()
        .enumerate()
        .map(|(index, blas)| wgpu::BlasBuildEntry {
            blas,
            geometry: wgpu::BlasGeometries::TriangleGeometries(vec![
                wgpu::BlasTriangleGeometry {
                    size: &geometry_desc_sizes[index],
                    vertex_buffer: &vertex_buf,
                    first_vertex: start_vertex_address[index] as u32,
                    vertex_stride: std::mem::size_of::<Vertex>() as u64,
                    index_buffer: Some(&index_buf),
                    first_index: Some(start_indices_address[index] as u32),
                    transform_buffer: None,
                    transform_buffer_offset: None,
                },
            ]),
        })
        .collect();
    encoder.build_acceleration_structures(blas_entries.iter(), iter::once(&tlas_package));
    queue.submit(Some(encoder.finish()));
    // No error scope is pushed here: the previous code pushed a validation
    // scope after submission and never popped it, leaving it open on the device.

    Self {
        #[cfg(feature = "visualization")]
        vertex_buf,
        #[cfg(feature = "visualization")]
        index_buf,
        blas,
        tlas_package,
        #[cfg(feature = "visualization")]
        assets: assets.clone(),
        instances: instances.to_vec(),
    }
}
/// Replaces selected instances of the scene.
///
/// `update_instance[i]` is written to TLAS slot `idx[i]`, and the host-side
/// instance list is updated at the same position. Instances that are not
/// mentioned in `idx` are left untouched.
///
/// # Errors
///
/// Returns an error, without modifying the scene, if
/// * `update_instance` and `idx` have different lengths,
/// * an entry of `idx` is not a valid instance index of this scene, or
/// * an instance refers to an asset index that has no BLAS.
pub async fn set_transform(
    &mut self,
    device: &wgpu::Device,
    update_instance: &[Instance],
    idx: &[usize],
) -> Result<(), String> {
    if update_instance.len() != idx.len() {
        return Err("Instance and index length mismatch".to_string());
    }
    // Validate everything first so that a bad entry cannot leave the scene
    // half-updated (and cannot panic on an out-of-range index).
    for (instance, &slot) in update_instance.iter().zip(idx) {
        if slot >= self.instances.len() {
            return Err(format!(
                "Instance index {} out of range (scene has {} instances)",
                slot,
                self.instances.len()
            ));
        }
        if instance.asset_mesh_index >= self.blas.len() {
            return Err(format!(
                "Asset mesh index {} out of range (scene has {} assets)",
                instance.asset_mesh_index,
                self.blas.len()
            ));
        }
    }
    for (instance, &slot) in update_instance.iter().zip(idx) {
        self.tlas_package[slot] = Some(wgpu::TlasInstance::new(
            &self.blas[instance.asset_mesh_index],
            affine_to_rows(&instance.transform),
            0,
            0xff,
        ));
        // Keep the host-side copy in step with the TLAS slot instead of
        // replacing the whole list with just the updated subset.
        self.instances[slot] = instance.clone();
    }
    // NOTE(review): this encoder is neither finished nor submitted here, and no
    // queue is available in this function — confirm that the TLAS rebuild is
    // submitted by the code that subsequently uses `tlas_package`.
    let mut encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    encoder.build_acceleration_structures(iter::empty(), iter::once(&self.tlas_package));
    Ok(())
}
/// Logs the scene to a rerun recording stream.
///
/// Every asset is logged as a `Mesh3D` under the entity path `mesh_<asset index>`.
/// The poses of all instances that use an asset are then logged as
/// `InstancePoses3D` under that same entity path.
///
/// Results returned by `rerun.log` are not inspected. A trailing group of
/// fewer than three indices in an asset's index buffer is ignored.
#[cfg(feature = "visualization")]
pub fn visualize(&self, rerun: &rerun::RecordingStream) {
    for (idx, mesh) in self.assets.iter().enumerate() {
        let vertex: Vec<_> = mesh
            .vertex_buf
            .iter()
            .map(|a| [a._pos[0], a._pos[1], a._pos[2]])
            .collect();
        // `chunks_exact` skips an incomplete trailing triangle instead of
        // panicking on it.
        let indices: Vec<_> = mesh
            .index_buf
            .chunks_exact(3)
            .map(|a| [u32::from(a[0]), u32::from(a[1]), u32::from(a[2])])
            .collect();
        rerun.log(
            format!("mesh_{}", idx),
            &rerun::Mesh3D::new(vertex).with_triangle_indices(indices),
        );
    }

    // Group instance poses by the asset they instantiate. The key must be the
    // asset index (not the instance's position in the list) so that the poses
    // end up on the entity of the mesh they belong to.
    let mut instance_map: HashMap<usize, Vec<_>> = HashMap::new();
    for instance in &self.instances {
        let translations = [
            instance.transform.translation.x,
            instance.transform.translation.y,
            instance.transform.translation.z,
        ];
        let rotation = glam::Quat::from_mat3a(&instance.transform.matrix3);
        let rotation =
            rerun::Quaternion::from_xyzw([rotation.x, rotation.y, rotation.z, rotation.w]);
        instance_map
            .entry(instance.asset_mesh_index)
            .or_default()
            .push((translations, rotation));
    }
    for (idx, transform) in instance_map.iter() {
        let translations = transform.iter().map(|f| f.0);
        let rotations = transform.iter().map(|f| f.1);
        rerun.log(
            format!("mesh_{}", idx),
            &rerun::InstancePoses3D::new()
                .with_translations(translations)
                .with_quaternions(rotations),
        );
    }
}
}