use super::cpu::{Bvh, flatten, ray_aabb_t};
use super::types::{FlatBvhNode, GpuRay};
/// WGSL source of the BVH traversal compute shader, embedded at compile time from
/// `../shaders/bvh_traversal.wgsl`. `BvhGpuTraverser::traverse_rays_gpu` dispatches it with
/// entry point `"main"`.
#[cfg(feature = "wgpu-backend")]
const BVH_TRAVERSAL_WGSL: &str = include_str!("../shaders/bvh_traversal.wgsl");
/// GPU-side resources owned by a `BvhGpuTraverser` that runs on the wgpu backend.
///
/// The three primitive buffers are uploaded once, when the traverser is constructed, and
/// reused by every dispatch. Per-dispatch buffers (params, nodes, rays, results) are created
/// inside `traverse_rays_gpu` instead of being stored here.
#[cfg(feature = "wgpu-backend")]
pub(crate) struct BvhGpuState {
    /// Backend used for buffer creation, uploads, dispatch and read-back. The lock is held for
    /// the whole GPU traversal, so traversals issued through a shared traverser are serialized.
    pub(crate) backend: std::sync::Mutex<crate::compute::wgpu_backend::real::WgpuBackendReal>,
    /// Primitive AABBs packed as 6 `f32` per primitive: min x/y/z followed by max x/y/z.
    pub(crate) prim_aabbs_buf: crate::compute::WgpuBufferHandle,
    /// Flattened primitive index list, narrowed from `usize` to `u32`.
    pub(crate) prim_indices_buf: crate::compute::WgpuBufferHandle,
    /// One `object_id` per primitive, stored as `i32`.
    pub(crate) object_ids_buf: crate::compute::WgpuBufferHandle,
    /// Number of GPU traversals that completed successfully (bumped only on the `Ok` path).
    pub(crate) dispatch_count: std::sync::atomic::AtomicU64,
    /// Distinct id taken from the process-wide `CREATION_COUNTER` at construction.
    pub(crate) creation_id: u64,
}
/// Which traversal implementation a `BvhGpuTraverser` uses.
pub(crate) enum BvhTraverserInner {
    /// Traverse on the CPU. Used by `new_cpu`, and by `new` when no GPU backend is available.
    Cpu,
    /// Traverse on the GPU via wgpu, with the uploaded resources kept in `BvhGpuState`.
    #[cfg(feature = "wgpu-backend")]
    Gpu(Box<BvhGpuState>),
}
/// Ray/BVH traverser that runs on the GPU when a wgpu backend is available and on the CPU
/// otherwise.
///
/// It stores its own flattened copy of the BVH and a clone of the primitives, so it does not
/// borrow the source [`Bvh`].
pub struct BvhGpuTraverser {
    /// Flattened BVH nodes. Node 0 is the root. For an interior node (`count == 0`) the
    /// children are at `idx + 1` and `left_first`; for a leaf, `left_first..left_first + count`
    /// indexes `prim_indices`.
    pub(crate) flat_nodes: Vec<FlatBvhNode>,
    /// Primitive indices referenced by leaf nodes; each entry indexes `primitives`.
    pub(crate) prim_indices: Vec<usize>,
    /// Clone of the source BVH's primitives.
    pub(crate) primitives: Vec<super::types::BvhPrimitive>,
    /// Selected traversal backend (CPU or GPU).
    pub(crate) inner: BvhTraverserInner,
}
/// Source of `BvhGpuState::creation_id`. It starts at 1 and is incremented once for every
/// GPU-backed traverser constructed.
#[cfg(feature = "wgpu-backend")]
static CREATION_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
impl BvhGpuTraverser {
    /// Rays handled per workgroup. Must agree with `@workgroup_size` in the WGSL shader
    /// (the shader is not visible here; 64 matches the original dispatch arithmetic).
    #[cfg(feature = "wgpu-backend")]
    const WORKGROUP_SIZE: u32 = 64;

    /// Largest workgroup count per dimension allowed by wgpu's default limits
    /// (`Limits::max_compute_workgroups_per_dimension`). Batches needing more are sent to the
    /// CPU path instead of triggering a wgpu validation error.
    #[cfg(feature = "wgpu-backend")]
    const MAX_WORKGROUPS_PER_DIMENSION: u32 = 65_535;

    /// Builds a traverser for `bvh`, preferring the GPU when one is available.
    ///
    /// The BVH is flattened and its primitives are cloned exactly as in [`Self::new_cpu`].
    /// With the `wgpu-backend` feature enabled, this then tries to create a wgpu backend.
    /// On success, the per-primitive data is uploaded once and later traversals are
    /// dispatched to the GPU. Otherwise the CPU traverser is returned unchanged.
    pub fn new(bvh: &Bvh) -> Self {
        let traverser = Self::new_cpu(bvh);
        #[cfg(feature = "wgpu-backend")]
        {
            if let Ok(backend) = crate::compute::wgpu_backend::real::WgpuBackendReal::try_new()
            {
                let state = Self::build_gpu_state(
                    backend,
                    &traverser.primitives,
                    &traverser.prim_indices,
                );
                return Self {
                    inner: BvhTraverserInner::Gpu(Box::new(state)),
                    ..traverser
                };
            }
        }
        traverser
    }

    /// Uploads the buffers that stay constant for the traverser's lifetime (primitive AABBs,
    /// primitive indices, object ids) and wraps them with `backend` in a new `BvhGpuState`.
    #[cfg(feature = "wgpu-backend")]
    fn build_gpu_state(
        mut backend: crate::compute::wgpu_backend::real::WgpuBackendReal,
        primitives: &[super::types::BvhPrimitive],
        prim_indices: &[usize],
    ) -> BvhGpuState {
        use std::sync::atomic::{AtomicU64, Ordering};

        // Six f32 per primitive: min x/y/z followed by max x/y/z.
        let prim_aabb_f32s: Vec<f32> = primitives
            .iter()
            .flat_map(|p| {
                [
                    p.aabb.min[0],
                    p.aabb.min[1],
                    p.aabb.min[2],
                    p.aabb.max[0],
                    p.aabb.max[1],
                    p.aabb.max[2],
                ]
            })
            .collect();
        let prim_aabbs_buf =
            backend.create_buffer_storage(Self::storage_bytes(prim_aabb_f32s.len()));
        backend.queue_write_buffer_f32(&prim_aabbs_buf, &prim_aabb_f32s);

        // NOTE(review): `as u32` truncates indices above u32::MAX. Such inputs are rejected
        // later in `traverse_rays_gpu`, which checks the index count before dispatching.
        let prim_u32s: Vec<u32> = prim_indices.iter().map(|&i| i as u32).collect();
        let prim_indices_buf = backend.create_buffer_storage(Self::storage_bytes(prim_u32s.len()));
        backend.queue_write_buffer_raw(&prim_indices_buf, bytemuck::cast_slice(&prim_u32s));

        let obj_ids: Vec<i32> = primitives.iter().map(|p| p.object_id as i32).collect();
        let object_ids_buf = backend.create_buffer_storage(Self::storage_bytes(obj_ids.len()));
        backend.queue_write_buffer_raw(&object_ids_buf, bytemuck::cast_slice(&obj_ids));

        let creation_id = CREATION_COUNTER.fetch_add(1, Ordering::Relaxed);
        BvhGpuState {
            backend: std::sync::Mutex::new(backend),
            prim_aabbs_buf,
            prim_indices_buf,
            object_ids_buf,
            dispatch_count: AtomicU64::new(0),
            creation_id,
        }
    }

    /// Builds a traverser for `bvh` that always traverses on the CPU.
    pub fn new_cpu(bvh: &Bvh) -> Self {
        let (flat_nodes, prim_indices) = flatten(bvh);
        Self {
            flat_nodes,
            prim_indices,
            primitives: bvh.primitives.clone(),
            inner: BvhTraverserInner::Cpu,
        }
    }

    /// Returns `true` when this traverser dispatches to the GPU.
    pub fn is_gpu(&self) -> bool {
        match &self.inner {
            BvhTraverserInner::Cpu => false,
            #[cfg(feature = "wgpu-backend")]
            BvhTraverserInner::Gpu(_) => true,
        }
    }

    /// Number of GPU traversals that completed successfully. Always 0 for a CPU traverser.
    #[cfg(feature = "wgpu-backend")]
    pub fn dispatch_count(&self) -> u64 {
        match &self.inner {
            BvhTraverserInner::Cpu => 0,
            BvhTraverserInner::Gpu(state) => state
                .dispatch_count
                .load(std::sync::atomic::Ordering::Relaxed),
        }
    }

    /// Id assigned to this traverser's GPU state, or `None` for a CPU traverser.
    #[cfg(feature = "wgpu-backend")]
    pub fn creation_id(&self) -> Option<u64> {
        match &self.inner {
            BvhTraverserInner::Cpu => None,
            BvhTraverserInner::Gpu(state) => Some(state.creation_id),
        }
    }

    /// Traverses every ray and returns one result per ray, in input order.
    ///
    /// On the CPU path, each entry is the `object_id` (as `i32`) of the primitive whose
    /// *AABB* the ray enters first with `0 <= t < max_t`, or `-1` on a miss. On the GPU path,
    /// the values are whatever the WGSL shader writes; missing read-back entries are padded
    /// with `-1`. Any GPU error falls back to the CPU path. An empty `rays` slice returns an
    /// empty vector without touching the GPU.
    pub fn traverse_rays(&self, rays: &[GpuRay]) -> Vec<i32> {
        if rays.is_empty() {
            return Vec::new();
        }
        match &self.inner {
            BvhTraverserInner::Cpu => self.traverse_rays_cpu(rays),
            #[cfg(feature = "wgpu-backend")]
            BvhTraverserInner::Gpu(state) => self
                .traverse_rays_gpu(state, rays)
                .unwrap_or_else(|_| self.traverse_rays_cpu(rays)),
        }
    }

    /// CPU traversal of a batch of rays.
    fn traverse_rays_cpu(&self, rays: &[GpuRay]) -> Vec<i32> {
        // One traversal stack is shared by the whole batch: a single allocation instead of
        // one per ray.
        let mut stack = Vec::with_capacity(64);
        rays.iter()
            .map(|ray| self.traverse_single_cpu(ray, &mut stack))
            .collect()
    }

    /// Traverses one ray with an explicit stack. `stack` is scratch space and is cleared on
    /// entry.
    ///
    /// Returns the `object_id` of the nearest primitive AABB hit with `0 <= t < max_t`, or
    /// `-1` on a miss. Out-of-range child, leaf or primitive indices are skipped rather than
    /// causing a panic.
    fn traverse_single_cpu(&self, ray: &GpuRay, stack: &mut Vec<usize>) -> i32 {
        if self.flat_nodes.is_empty() {
            return -1;
        }
        let inv_dir = [
            1.0 / ray.direction[0],
            1.0 / ray.direction[1],
            1.0 / ray.direction[2],
        ];
        let origin = ray.origin;
        let mut best_hit: i32 = -1;
        let mut best_t = ray.max_t;

        stack.clear();
        stack.push(0usize);
        while let Some(idx) = stack.pop() {
            // Every index on the stack was bounds-checked before being pushed.
            let node = &self.flat_nodes[idx];
            if ray_aabb_t(origin, inv_dir, &node.aabb).is_none() {
                continue;
            }
            if node.count > 0 {
                let start = node.left_first as usize;
                let end = start
                    .saturating_add(node.count as usize)
                    .min(self.prim_indices.len());
                // `get` also yields nothing when `start > end`, i.e. when `left_first` points
                // past the index list, where plain slicing would panic.
                let leaf_prims = self.prim_indices.get(start..end).unwrap_or_default();
                for &pi in leaf_prims {
                    let Some(prim) = self.primitives.get(pi) else {
                        continue;
                    };
                    if let Some((t_near, _)) = ray_aabb_t(origin, inv_dir, &prim.aabb)
                        && t_near >= 0.0
                        && t_near < best_t
                    {
                        best_t = t_near;
                        best_hit = prim.object_id as i32;
                    }
                }
            } else {
                let right = node.left_first as usize;
                let left = idx + 1;
                // Push right first so the left child (idx + 1) is visited first.
                if right < self.flat_nodes.len() {
                    stack.push(right);
                }
                if left < self.flat_nodes.len() && left != right {
                    stack.push(left);
                }
            }
        }
        best_hit
    }

    /// Byte size for a storage buffer holding `words` 32-bit values, never below 16 bytes
    /// (presumably so empty inputs never produce a zero-sized binding).
    #[cfg(feature = "wgpu-backend")]
    fn storage_bytes(words: usize) -> u64 {
        (words * 4).max(16) as u64
    }

    /// Converts a host-side element count to the `u32` the shader params use, failing instead
    /// of truncating.
    #[cfg(feature = "wgpu-backend")]
    fn gpu_count(len: usize, what: &str) -> Result<u32, crate::GpuError> {
        u32::try_from(len).map_err(|_| {
            crate::GpuError::ShaderDispatch(format!("{what} count {len} exceeds u32::MAX"))
        })
    }

    /// GPU traversal of a batch of rays.
    ///
    /// # Errors
    ///
    /// Returns `GpuError::ShaderDispatch` in these cases: a count does not fit in `u32`; the
    /// batch needs more workgroups than `MAX_WORKGROUPS_PER_DIMENSION`; the backend mutex is
    /// poisoned; or the dispatch itself fails. `traverse_rays` treats every error as a signal
    /// to fall back to the CPU.
    #[cfg(feature = "wgpu-backend")]
    fn traverse_rays_gpu(
        &self,
        state: &BvhGpuState,
        rays: &[GpuRay],
    ) -> Result<Vec<i32>, crate::GpuError> {
        use std::sync::atomic::Ordering;

        let n_rays = Self::gpu_count(rays.len(), "ray")?;
        let n_nodes = Self::gpu_count(self.flat_nodes.len(), "node")?;
        let n_prims = Self::gpu_count(self.prim_indices.len(), "primitive index")?;

        let workgroups_x = n_rays.div_ceil(Self::WORKGROUP_SIZE);
        if workgroups_x > Self::MAX_WORKGROUPS_PER_DIMENSION {
            return Err(crate::GpuError::ShaderDispatch(format!(
                "{n_rays} rays need {workgroups_x} workgroups, above the limit of {}",
                Self::MAX_WORKGROUPS_PER_DIMENSION
            )));
        }

        // A poisoned lock means an earlier dispatch panicked mid-way. Report it as an error so
        // the caller falls back to the CPU instead of panicking again.
        let mut backend = state.backend.lock().map_err(|_| {
            crate::GpuError::ShaderDispatch("BvhGpuState backend lock poisoned".to_string())
        })?;

        // NOTE(review): the per-dispatch buffers below are never explicitly released here;
        // confirm that WgpuBackendReal reclaims them.
        let params_data: [u32; 4] = [n_nodes, n_rays, n_prims, 0];
        let params_buf = backend.create_buffer_storage(Self::storage_bytes(params_data.len()));
        backend.queue_write_buffer_raw(&params_buf, bytemuck::cast_slice(&params_data));

        // Eight 32-bit words per node: AABB min/max, then `left_first` and `count` bit-cast
        // into f32 slots so that the whole record uploads as a single f32 array.
        let node_f32s: Vec<f32> = self
            .flat_nodes
            .iter()
            .flat_map(|n| {
                [
                    n.aabb.min[0],
                    n.aabb.min[1],
                    n.aabb.min[2],
                    n.aabb.max[0],
                    n.aabb.max[1],
                    n.aabb.max[2],
                    f32::from_bits(n.left_first),
                    f32::from_bits(n.count),
                ]
            })
            .collect();
        let nodes_buf = backend.create_buffer_storage(Self::storage_bytes(node_f32s.len()));
        backend.queue_write_buffer_f32(&nodes_buf, &node_f32s);

        // Eight f32 per ray: origin, direction, max_t, and one padding word.
        let ray_f32s: Vec<f32> = rays
            .iter()
            .flat_map(|r| {
                [
                    r.origin[0],
                    r.origin[1],
                    r.origin[2],
                    r.direction[0],
                    r.direction[1],
                    r.direction[2],
                    r.max_t,
                    0.0_f32,
                ]
            })
            .collect();
        let rays_buf = backend.create_buffer_storage(Self::storage_bytes(ray_f32s.len()));
        backend.queue_write_buffer_f32(&rays_buf, &ray_f32s);

        let results_buf = backend.create_buffer_storage(Self::storage_bytes(rays.len()));

        backend
            .dispatch_wgsl(
                BVH_TRAVERSAL_WGSL,
                "main",
                &[
                    (
                        params_buf,
                        wgpu::BufferBindingType::Storage { read_only: true },
                    ),
                    (
                        nodes_buf,
                        wgpu::BufferBindingType::Storage { read_only: true },
                    ),
                    (
                        rays_buf,
                        wgpu::BufferBindingType::Storage { read_only: true },
                    ),
                    (
                        results_buf,
                        wgpu::BufferBindingType::Storage { read_only: false },
                    ),
                    (
                        state.prim_indices_buf,
                        wgpu::BufferBindingType::Storage { read_only: true },
                    ),
                    (
                        state.object_ids_buf,
                        wgpu::BufferBindingType::Storage { read_only: true },
                    ),
                    (
                        state.prim_aabbs_buf,
                        wgpu::BufferBindingType::Storage { read_only: true },
                    ),
                ],
                [workgroups_x, 1, 1],
            )
            .map_err(|e| crate::GpuError::ShaderDispatch(e.to_string()))?;

        // The shader writes i32 hit ids. They are read back through the f32 API and
        // reinterpreted bit-for-bit.
        let raw = backend.read_buffer_f32(results_buf);
        let mut hits: Vec<i32> = raw
            .iter()
            .take(rays.len())
            .map(|&f| f32::to_bits(f) as i32)
            .collect();
        hits.resize(rays.len(), -1);
        state.dispatch_count.fetch_add(1, Ordering::Relaxed);
        Ok(hits)
    }
}