oxiphysics_gpu/bvh/gpu.rs
1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! GPU-accelerated BVH traversal using a persistent `WgpuBackendReal` instance.
5
6use super::cpu::{Bvh, flatten, ray_aabb_t};
7use super::types::{FlatBvhNode, GpuRay};
8
/// WGSL source for the BVH traversal kernel.
///
/// Embedded at compile time from `../shaders/bvh_traversal.wgsl` (path is
/// relative to this source file). `BvhGpuTraverser::traverse_rays_gpu` passes
/// it to `dispatch_wgsl` with the entry point `"main"`.
#[cfg(feature = "wgpu-backend")]
const BVH_TRAVERSAL_WGSL: &str = include_str!("../shaders/bvh_traversal.wgsl");
12
13// ============================================================================
14// BvhGpuState — per-BVH GPU resources (allocated once at construction)
15// ============================================================================
16
/// Persistent GPU resources for a BVH.
///
/// `BvhGpuState` owns a `WgpuBackendReal` and the handles of three primitive
/// buffers (primitive AABBs, primitive indices, object IDs) that are created
/// and uploaded once, in `BvhGpuTraverser::new`.
///
/// NOTE(review): `traverse_rays_gpu` currently creates the params, nodes, rays
/// and results buffers on every call. Nothing in this file caches them or
/// re-sizes them "on demand", and the node data is re-encoded and re-uploaded
/// each dispatch even though it never changes.
///
/// # Thread safety
///
/// `WgpuBackendReal` is `Send + Sync` (device/queue are `Arc`-wrapped and the
/// shader cache uses `Mutex`). We wrap the backend in an additional `Mutex` so
/// that callers who share a `BvhGpuTraverser` across threads can do so safely
/// without re-entrant dispatch issues. `BvhGpuTraverser` is never placed inside
/// a `rayon::ParallelIterator` closure in the current code base (verified by the
/// send-bound audit in Step 0 of the block spec); nevertheless we default to
/// `Mutex` so the type is unconditionally `Send + Sync`.
#[cfg(feature = "wgpu-backend")]
pub(crate) struct BvhGpuState {
    /// Backend stored once; locked for the whole duration of each dispatch
    /// (buffer creation, upload, kernel launch and read-back).
    ///
    /// Using `Mutex<WgpuBackendReal>` rather than `RefCell` so that
    /// `BvhGpuTraverser` is `Send + Sync` regardless of call-site threading.
    /// The audit found no parallel call sites today, but the Mutex overhead is
    /// negligible compared to GPU dispatch latency.
    pub(crate) backend: std::sync::Mutex<crate::compute::wgpu_backend::real::WgpuBackendReal>,
    /// Primitive AABB buffer: 6 × f32 per primitive, laid out as
    /// `[min_x, min_y, min_z, max_x, max_y, max_z]`.
    pub(crate) prim_aabbs_buf: crate::compute::WgpuBufferHandle,
    /// Primitive-index buffer: one u32 per entry in the flat prim-index array.
    pub(crate) prim_indices_buf: crate::compute::WgpuBufferHandle,
    /// Object-ID buffer: one i32 per primitive.
    pub(crate) object_ids_buf: crate::compute::WgpuBufferHandle,
    /// Number of GPU dispatches that completed successfully; incremented at the
    /// end of `traverse_rays_gpu` (observability / reuse test).
    pub(crate) dispatch_count: std::sync::atomic::AtomicU64,
    /// Monotonically increasing ID taken from `CREATION_COUNTER` at
    /// construction (reuse test).
    pub(crate) creation_id: u64,
}
53
54// ============================================================================
55// BvhTraverserInner — enum over CPU-only and GPU variants
56// ============================================================================
57
/// Execution backend selected for a `BvhGpuTraverser`.
///
/// The `Gpu` variant only exists when the `wgpu-backend` feature is enabled;
/// without the feature the enum has the single `Cpu` variant.
pub(crate) enum BvhTraverserInner {
    /// CPU-only fallback.
    Cpu,
    /// GPU backend with pre-uploaded primitive buffers. Boxed so the enum stays
    /// small regardless of the size of `BvhGpuState`.
    #[cfg(feature = "wgpu-backend")]
    Gpu(Box<BvhGpuState>),
}
65
66// ============================================================================
67// BvhGpuTraverser
68// ============================================================================
69
/// GPU-accelerated BVH ray traversal.
///
/// Encodes a flat BVH into GPU-resident buffers and dispatches the
/// `bvh_traversal.wgsl` kernel. Falls back to CPU traversal when no GPU
/// adapter is available, and also when an individual GPU dispatch fails.
///
/// # Usage
///
/// ```rust
/// use oxiphysics_gpu::bvh::{Aabb, Bvh, BvhPrimitive, BvhGpuTraverser, GpuRay};
///
/// let prims = vec![
///     BvhPrimitive::new(Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), 0),
/// ];
/// let bvh = Bvh::build(prims);
/// let traverser = BvhGpuTraverser::new(&bvh);
/// let rays = vec![GpuRay::new([0.5, 0.5, -1.0], [0.0, 0.0, 1.0], 100.0)];
/// let hits = traverser.traverse_rays(&rays);
/// assert_eq!(hits.len(), 1);
/// ```
pub struct BvhGpuTraverser {
    /// Flat BVH nodes. The CPU copy drives the fallback traversal and is also
    /// the source re-encoded for every GPU dispatch.
    pub(crate) flat_nodes: Vec<FlatBvhNode>,
    /// Primitive indices (CPU copy); each entry indexes into `primitives`.
    pub(crate) prim_indices: Vec<usize>,
    /// Primitives (CPU copy), cloned from the source `Bvh`.
    pub(crate) primitives: Vec<super::types::BvhPrimitive>,
    /// Selected backend; the GPU resources are gated behind the
    /// `wgpu-backend` feature.
    pub(crate) inner: BvhTraverserInner,
}
100
/// Monotonically increasing counter for creation_id assignment.
///
/// Starts at 1; `BvhGpuTraverser::new` takes the next value with a relaxed
/// `fetch_add` each time a GPU-backed traverser is built.
#[cfg(feature = "wgpu-backend")]
static CREATION_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
104
105impl BvhGpuTraverser {
106 /// Create a traverser from a BVH.
107 ///
108 /// Flattens the BVH and uploads the node + primitive buffers to the GPU.
109 /// Falls back to CPU traversal if no GPU adapter is found.
110 pub fn new(bvh: &Bvh) -> Self {
111 let (flat_nodes, prim_indices) = flatten(bvh);
112 let primitives = bvh.primitives.clone();
113
114 #[cfg(feature = "wgpu-backend")]
115 {
116 use crate::compute::wgpu_backend::real::WgpuBackendReal;
117 use std::sync::atomic::Ordering;
118
119 if let Ok(mut backend) = WgpuBackendReal::try_new() {
120 // ── Upload primitive data (once, reused across all traverse_rays calls) ──
121
122 // prim_aabbs: 6 × f32 per primitive
123 let prim_aabb_f32s: Vec<f32> = primitives
124 .iter()
125 .flat_map(|p| {
126 [
127 p.aabb.min[0],
128 p.aabb.min[1],
129 p.aabb.min[2],
130 p.aabb.max[0],
131 p.aabb.max[1],
132 p.aabb.max[2],
133 ]
134 })
135 .collect();
136 let prim_aabbs_buf =
137 backend.create_buffer_storage((prim_aabb_f32s.len() * 4).max(16) as u64);
138 backend.queue_write_buffer_f32(&prim_aabbs_buf, &prim_aabb_f32s);
139
140 // prim_indices: one u32 per entry
141 let prim_u32s: Vec<u32> = prim_indices.iter().map(|&i| i as u32).collect();
142 let prim_indices_buf =
143 backend.create_buffer_storage((prim_u32s.len() * 4).max(16) as u64);
144 backend.queue_write_buffer_raw(&prim_indices_buf, bytemuck::cast_slice(&prim_u32s));
145
146 // object_ids: one i32 per primitive
147 let obj_ids: Vec<i32> = primitives.iter().map(|p| p.object_id as i32).collect();
148 let object_ids_buf =
149 backend.create_buffer_storage((obj_ids.len() * 4).max(16) as u64);
150 backend.queue_write_buffer_raw(&object_ids_buf, bytemuck::cast_slice(&obj_ids));
151
152 let creation_id = CREATION_COUNTER.fetch_add(1, Ordering::Relaxed);
153
154 return Self {
155 flat_nodes,
156 prim_indices,
157 primitives,
158 inner: BvhTraverserInner::Gpu(Box::new(BvhGpuState {
159 backend: std::sync::Mutex::new(backend),
160 prim_aabbs_buf,
161 prim_indices_buf,
162 object_ids_buf,
163 dispatch_count: std::sync::atomic::AtomicU64::new(0),
164 creation_id,
165 })),
166 };
167 }
168 }
169
170 Self {
171 flat_nodes,
172 prim_indices,
173 primitives,
174 inner: BvhTraverserInner::Cpu,
175 }
176 }
177
178 /// Create a CPU-only traverser (useful for testing without a GPU).
179 pub fn new_cpu(bvh: &Bvh) -> Self {
180 let (flat_nodes, prim_indices) = flatten(bvh);
181 Self {
182 flat_nodes,
183 prim_indices,
184 primitives: bvh.primitives.clone(),
185 inner: BvhTraverserInner::Cpu,
186 }
187 }
188
189 /// Returns `true` if using a real GPU backend.
190 pub fn is_gpu(&self) -> bool {
191 match &self.inner {
192 BvhTraverserInner::Cpu => false,
193 #[cfg(feature = "wgpu-backend")]
194 BvhTraverserInner::Gpu(_) => true,
195 }
196 }
197
/// Number of GPU dispatches completed so far.
///
/// A CPU-only traverser always reports 0.
#[cfg(feature = "wgpu-backend")]
pub fn dispatch_count(&self) -> u64 {
    use std::sync::atomic::Ordering;

    if let BvhTraverserInner::Gpu(gpu) = &self.inner {
        gpu.dispatch_count.load(Ordering::Relaxed)
    } else {
        0
    }
}
208
/// Creation ID of the underlying GPU state (used by reuse tests).
///
/// `None` for a CPU-only traverser.
#[cfg(feature = "wgpu-backend")]
pub fn creation_id(&self) -> Option<u64> {
    if let BvhTraverserInner::Gpu(gpu) = &self.inner {
        Some(gpu.creation_id)
    } else {
        None
    }
}
217
218 /// Traverse the BVH for each ray.
219 ///
220 /// Returns a `Vec<i32>` of length `rays.len()`. Each element is either:
221 /// - the `object_id` of the first hit leaf's primitive, or
222 /// - `-1` if no intersection was found.
223 pub fn traverse_rays(&self, rays: &[GpuRay]) -> Vec<i32> {
224 if rays.is_empty() {
225 return Vec::new();
226 }
227 match &self.inner {
228 BvhTraverserInner::Cpu => self.traverse_rays_cpu(rays),
229 #[cfg(feature = "wgpu-backend")]
230 BvhTraverserInner::Gpu(state) => self
231 .traverse_rays_gpu(state, rays)
232 .unwrap_or_else(|_| self.traverse_rays_cpu(rays)),
233 }
234 }
235
236 // ── CPU traversal ─────────────────────────────────────────────────────────
237
238 fn traverse_rays_cpu(&self, rays: &[GpuRay]) -> Vec<i32> {
239 rays.iter()
240 .map(|ray| self.traverse_single_cpu(ray))
241 .collect()
242 }
243
244 fn traverse_single_cpu(&self, ray: &GpuRay) -> i32 {
245 if self.flat_nodes.is_empty() {
246 return -1;
247 }
248 let inv_dir = [
249 1.0 / ray.direction[0],
250 1.0 / ray.direction[1],
251 1.0 / ray.direction[2],
252 ];
253 let origin = ray.origin;
254 let max_t = ray.max_t;
255 let mut best_hit: i32 = -1;
256 let mut best_t = max_t;
257
258 let mut stack = Vec::with_capacity(64);
259 stack.push(0usize);
260
261 while let Some(idx) = stack.pop() {
262 let node = &self.flat_nodes[idx];
263 // Slab test
264 if ray_aabb_t(origin, inv_dir, &node.aabb).is_none() {
265 continue;
266 }
267 if node.count > 0 {
268 // Leaf: check each primitive
269 let start = node.left_first as usize;
270 let end = (start + node.count as usize).min(self.prim_indices.len());
271 for &pi in &self.prim_indices[start..end] {
272 if pi >= self.primitives.len() {
273 continue;
274 }
275 if let Some((t_near, _)) =
276 ray_aabb_t(origin, inv_dir, &self.primitives[pi].aabb)
277 && t_near >= 0.0
278 && t_near < best_t
279 {
280 best_t = t_near;
281 best_hit = self.primitives[pi].object_id as i32;
282 }
283 }
284 } else {
285 let right = node.left_first as usize;
286 let left = idx + 1;
287 if right < self.flat_nodes.len() {
288 stack.push(right);
289 }
290 if left < self.flat_nodes.len() && left != right {
291 stack.push(left);
292 }
293 }
294 }
295 best_hit
296 }
297
298 // ── GPU traversal ─────────────────────────────────────────────────────────
299
/// Trace `rays` on the GPU with the `bvh_traversal.wgsl` kernel.
///
/// Per call this locks the backend, creates and fills the params, nodes and
/// rays buffers, creates a results buffer, dispatches one thread per ray in
/// workgroups of 64, and reads the results back. The three primitive buffers
/// held by `state` are bound as-is.
///
/// Returns one `i32` per ray: the raw bit pattern written by the kernel,
/// padded with `-1` if the read-back is shorter than the ray count.
///
/// # Errors
///
/// Returns `GpuError::ShaderDispatch` when
/// - the node, ray or primitive-index count does not fit in `u32`,
/// - the backend mutex is poisoned, or
/// - the backend reports a dispatch failure.
///
/// The caller (`traverse_rays`) falls back to the CPU on any of these.
#[cfg(feature = "wgpu-backend")]
fn traverse_rays_gpu(
    &self,
    state: &BvhGpuState,
    rays: &[GpuRay],
) -> Result<Vec<i32>, crate::GpuError> {
    use std::sync::atomic::Ordering;

    // The kernel receives its sizes as u32; reject anything larger rather
    // than silently truncating and returning a wrong-length result.
    let to_u32 = |len: usize, what: &str| {
        u32::try_from(len).map_err(|_| {
            crate::GpuError::ShaderDispatch(format!(
                "{what} count {len} does not fit in u32"
            ))
        })
    };
    let n_rays = to_u32(rays.len(), "ray")?;
    let n_nodes = to_u32(self.flat_nodes.len(), "BVH node")?;
    let n_prims = to_u32(self.prim_indices.len(), "primitive index")?;

    // Lock the backend for this dispatch. A poisoned lock means an earlier
    // dispatch panicked; report it so the caller can use the CPU path.
    let mut backend = state.backend.lock().map_err(|_| {
        crate::GpuError::ShaderDispatch("BvhGpuState backend lock poisoned".to_string())
    })?;

    // ── Encode params [n_nodes, n_rays, n_prims, 0] (per-call) ──────────
    let params_data: [u32; 4] = [n_nodes, n_rays, n_prims, 0];
    let params_buf = backend.create_buffer_storage(16);
    backend.queue_write_buffer_raw(&params_buf, bytemuck::cast_slice(&params_data));

    // ── Encode BVH nodes (per-call, same data each time) ─────────────────
    // Layout: [min_x, min_y, min_z, max_x, max_y, max_z, left_first_bits, count_bits]
    // The two u32 fields travel as raw bit patterns inside f32 slots.
    // NOTE(review): this data never changes between calls; uploading it once
    // next to the primitive buffers would avoid a re-encode per dispatch.
    let node_f32s: Vec<f32> = self
        .flat_nodes
        .iter()
        .flat_map(|n| {
            [
                n.aabb.min[0],
                n.aabb.min[1],
                n.aabb.min[2],
                n.aabb.max[0],
                n.aabb.max[1],
                n.aabb.max[2],
                f32::from_bits(n.left_first),
                f32::from_bits(n.count),
            ]
        })
        .collect();
    let nodes_buf = backend.create_buffer_storage((node_f32s.len() * 4).max(16) as u64);
    backend.queue_write_buffer_f32(&nodes_buf, &node_f32s);

    // ── Encode rays: 8 f32 per ray ────────────────────────────────────────
    let ray_f32s: Vec<f32> = rays
        .iter()
        .flat_map(|r| {
            [
                r.origin[0],
                r.origin[1],
                r.origin[2],
                r.direction[0],
                r.direction[1],
                r.direction[2],
                r.max_t,
                0.0_f32, // pad
            ]
        })
        .collect();
    let rays_buf = backend.create_buffer_storage((ray_f32s.len() * 4).max(16) as u64);
    backend.queue_write_buffer_f32(&rays_buf, &ray_f32s);

    // ── Results buffer: one 4-byte slot per ray ───────────────────────────
    let results_buf = backend.create_buffer_storage((rays.len() * 4).max(16) as u64);

    // ── Dispatch: one thread per ray, 64 threads per workgroup ────────────
    // NOTE(review): very large batches may exceed the device's
    // per-dimension workgroup limit — confirm how the backend reports that.
    let workgroups_x = n_rays.div_ceil(64);
    backend
        .dispatch_wgsl(
            BVH_TRAVERSAL_WGSL,
            "main",
            &[
                (
                    params_buf,
                    wgpu::BufferBindingType::Storage { read_only: true },
                ),
                (
                    nodes_buf,
                    wgpu::BufferBindingType::Storage { read_only: true },
                ),
                (
                    rays_buf,
                    wgpu::BufferBindingType::Storage { read_only: true },
                ),
                (
                    results_buf,
                    wgpu::BufferBindingType::Storage { read_only: false },
                ),
                (
                    state.prim_indices_buf,
                    wgpu::BufferBindingType::Storage { read_only: true },
                ),
                (
                    state.object_ids_buf,
                    wgpu::BufferBindingType::Storage { read_only: true },
                ),
                (
                    state.prim_aabbs_buf,
                    wgpu::BufferBindingType::Storage { read_only: true },
                ),
            ],
            [workgroups_x, 1, 1],
        )
        .map_err(|e| crate::GpuError::ShaderDispatch(e.to_string()))?;

    // ── Read results back ─────────────────────────────────────────────────
    // The read-back API is f32-typed; reinterpret each slot's bits as i32.
    let raw = backend.read_buffer_f32(results_buf);
    let mut hits: Vec<i32> = raw
        .iter()
        .take(rays.len())
        .map(|f| f.to_bits() as i32)
        .collect();

    // Pad to one entry per ray if the buffer was short.
    hits.resize(rays.len(), -1);

    // Increment observability counter (successful dispatches only).
    state.dispatch_count.fetch_add(1, Ordering::Relaxed);

    Ok(hits)
}
423}