Skip to main content

oxiphysics_gpu/bvh/
gpu.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! GPU-accelerated BVH traversal using a persistent `WgpuBackendReal` instance.
5
6use super::cpu::{Bvh, flatten, ray_aabb_t};
7use super::types::{FlatBvhNode, GpuRay};
8
/// WGSL source for the BVH traversal kernel.
///
/// Embedded at compile time from `../shaders/bvh_traversal.wgsl`.  The GPU
/// traversal path dispatches it with the entry point `"main"`.
#[cfg(feature = "wgpu-backend")]
const BVH_TRAVERSAL_WGSL: &str = include_str!("../shaders/bvh_traversal.wgsl");
12
13// ============================================================================
14// BvhGpuState — per-BVH GPU resources (allocated once at construction)
15// ============================================================================
16
/// Persistent GPU resources for a BVH.
///
/// `BvhGpuState` owns a `WgpuBackendReal` and three primitive buffers that are
/// uploaded once at construction time.  The params, node, ray and result
/// buffers are *not* stored here: the GPU traversal path creates them afresh on
/// every call.
///
/// # Thread safety
///
/// `WgpuBackendReal` is `Send + Sync` (device/queue are `Arc`-wrapped and the
/// shader cache uses `Mutex`).  We wrap the backend in an additional `Mutex` so
/// that callers who share a `BvhGpuTraverser` across threads can do so safely
/// without re-entrant dispatch issues.  `BvhGpuTraverser` is never placed inside
/// a `rayon::ParallelIterator` closure in the current code-base (verified by the
/// send-bound audit in Step 0 of the block spec); nevertheless we default to
/// `Mutex` so the type is unconditionally `Send + Sync`.
#[cfg(feature = "wgpu-backend")]
pub(crate) struct BvhGpuState {
    /// Backend stored once; locked for the whole duration of each dispatch.
    ///
    /// Using `Mutex<WgpuBackendReal>` rather than `RefCell` so that
    /// `BvhGpuTraverser` is `Send + Sync` regardless of call-site threading.
    /// The audit found no parallel call sites today, but the Mutex overhead is
    /// negligible compared to GPU dispatch latency.
    pub(crate) backend: std::sync::Mutex<crate::compute::wgpu_backend::real::WgpuBackendReal>,
    /// Primitive AABB buffer: 6 × f32 per primitive, laid out as
    /// `[min_x, min_y, min_z, max_x, max_y, max_z]`.
    pub(crate) prim_aabbs_buf: crate::compute::WgpuBufferHandle,
    /// Primitive-index buffer: one u32 per entry in the flat prim-index array.
    pub(crate) prim_indices_buf: crate::compute::WgpuBufferHandle,
    /// Object-ID buffer: one i32 per primitive.
    pub(crate) object_ids_buf: crate::compute::WgpuBufferHandle,
    /// Number of dispatches completed successfully (observability / reuse test).
    pub(crate) dispatch_count: std::sync::atomic::AtomicU64,
    /// Monotonically increasing ID assigned at construction (reuse test).
    pub(crate) creation_id: u64,
}
53
54// ============================================================================
55// BvhTraverserInner — enum over CPU-only and GPU variants
56// ============================================================================
57
/// Execution backend selected for a `BvhGpuTraverser`.
///
/// The `Gpu` variant only exists when the `wgpu-backend` feature is enabled;
/// without it, `Cpu` is the sole variant.
pub(crate) enum BvhTraverserInner {
    /// CPU-only fallback.
    Cpu,
    /// GPU backend with pre-uploaded primitive buffers.
    ///
    /// Boxed so the enum stays small regardless of the GPU state's size.
    #[cfg(feature = "wgpu-backend")]
    Gpu(Box<BvhGpuState>),
}
65
66// ============================================================================
67// BvhGpuTraverser
68// ============================================================================
69
/// GPU-accelerated BVH ray traversal.
///
/// Encodes a flat BVH into GPU-resident buffers and dispatches the
/// `bvh_traversal.wgsl` kernel.  Falls back to CPU traversal when no GPU
/// adapter is available, and also when a GPU dispatch returns an error.
///
/// # Usage
///
/// ```rust
/// use oxiphysics_gpu::bvh::{Aabb, Bvh, BvhPrimitive, BvhGpuTraverser, GpuRay};
///
/// let prims = vec![
///     BvhPrimitive::new(Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), 0),
/// ];
/// let bvh = Bvh::build(prims);
/// let traverser = BvhGpuTraverser::new(&bvh);
/// let rays = vec![GpuRay::new([0.5, 0.5, -1.0], [0.0, 0.0, 1.0], 100.0)];
/// let hits = traverser.traverse_rays(&rays);
/// assert_eq!(hits.len(), 1);
/// ```
pub struct BvhGpuTraverser {
    /// Flat BVH nodes (CPU copy kept for fallback; also the source of the node
    /// data encoded for the GPU on each dispatch).
    pub(crate) flat_nodes: Vec<FlatBvhNode>,
    /// Flat primitive-index array (CPU copy).  Leaf nodes reference a
    /// contiguous range of this array; each entry indexes into `primitives`.
    pub(crate) prim_indices: Vec<usize>,
    /// Primitives (CPU copy), cloned from the source BVH.
    pub(crate) primitives: Vec<super::types::BvhPrimitive>,
    /// Selected backend: CPU-only, or GPU resources (gated behind feature).
    pub(crate) inner: BvhTraverserInner,
}
100
/// Monotonically increasing counter for creation_id assignment.
///
/// Starts at 1; `BvhGpuTraverser::new` takes the next value with a relaxed
/// `fetch_add` each time it successfully builds a GPU-backed traverser.
#[cfg(feature = "wgpu-backend")]
static CREATION_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
104
105impl BvhGpuTraverser {
106    /// Create a traverser from a BVH.
107    ///
108    /// Flattens the BVH and uploads the node + primitive buffers to the GPU.
109    /// Falls back to CPU traversal if no GPU adapter is found.
110    pub fn new(bvh: &Bvh) -> Self {
111        let (flat_nodes, prim_indices) = flatten(bvh);
112        let primitives = bvh.primitives.clone();
113
114        #[cfg(feature = "wgpu-backend")]
115        {
116            use crate::compute::wgpu_backend::real::WgpuBackendReal;
117            use std::sync::atomic::Ordering;
118
119            if let Ok(mut backend) = WgpuBackendReal::try_new() {
120                // ── Upload primitive data (once, reused across all traverse_rays calls) ──
121
122                // prim_aabbs: 6 × f32 per primitive
123                let prim_aabb_f32s: Vec<f32> = primitives
124                    .iter()
125                    .flat_map(|p| {
126                        [
127                            p.aabb.min[0],
128                            p.aabb.min[1],
129                            p.aabb.min[2],
130                            p.aabb.max[0],
131                            p.aabb.max[1],
132                            p.aabb.max[2],
133                        ]
134                    })
135                    .collect();
136                let prim_aabbs_buf =
137                    backend.create_buffer_storage((prim_aabb_f32s.len() * 4).max(16) as u64);
138                backend.queue_write_buffer_f32(&prim_aabbs_buf, &prim_aabb_f32s);
139
140                // prim_indices: one u32 per entry
141                let prim_u32s: Vec<u32> = prim_indices.iter().map(|&i| i as u32).collect();
142                let prim_indices_buf =
143                    backend.create_buffer_storage((prim_u32s.len() * 4).max(16) as u64);
144                backend.queue_write_buffer_raw(&prim_indices_buf, bytemuck::cast_slice(&prim_u32s));
145
146                // object_ids: one i32 per primitive
147                let obj_ids: Vec<i32> = primitives.iter().map(|p| p.object_id as i32).collect();
148                let object_ids_buf =
149                    backend.create_buffer_storage((obj_ids.len() * 4).max(16) as u64);
150                backend.queue_write_buffer_raw(&object_ids_buf, bytemuck::cast_slice(&obj_ids));
151
152                let creation_id = CREATION_COUNTER.fetch_add(1, Ordering::Relaxed);
153
154                return Self {
155                    flat_nodes,
156                    prim_indices,
157                    primitives,
158                    inner: BvhTraverserInner::Gpu(Box::new(BvhGpuState {
159                        backend: std::sync::Mutex::new(backend),
160                        prim_aabbs_buf,
161                        prim_indices_buf,
162                        object_ids_buf,
163                        dispatch_count: std::sync::atomic::AtomicU64::new(0),
164                        creation_id,
165                    })),
166                };
167            }
168        }
169
170        Self {
171            flat_nodes,
172            prim_indices,
173            primitives,
174            inner: BvhTraverserInner::Cpu,
175        }
176    }
177
178    /// Create a CPU-only traverser (useful for testing without a GPU).
179    pub fn new_cpu(bvh: &Bvh) -> Self {
180        let (flat_nodes, prim_indices) = flatten(bvh);
181        Self {
182            flat_nodes,
183            prim_indices,
184            primitives: bvh.primitives.clone(),
185            inner: BvhTraverserInner::Cpu,
186        }
187    }
188
189    /// Returns `true` if using a real GPU backend.
190    pub fn is_gpu(&self) -> bool {
191        match &self.inner {
192            BvhTraverserInner::Cpu => false,
193            #[cfg(feature = "wgpu-backend")]
194            BvhTraverserInner::Gpu(_) => true,
195        }
196    }
197
198    /// Returns the current dispatch count (GPU variant only; always 0 for CPU).
199    #[cfg(feature = "wgpu-backend")]
200    pub fn dispatch_count(&self) -> u64 {
201        match &self.inner {
202            BvhTraverserInner::Cpu => 0,
203            BvhTraverserInner::Gpu(state) => state
204                .dispatch_count
205                .load(std::sync::atomic::Ordering::Relaxed),
206        }
207    }
208
209    /// Returns the creation_id of the underlying GPU state (for reuse tests).
210    #[cfg(feature = "wgpu-backend")]
211    pub fn creation_id(&self) -> Option<u64> {
212        match &self.inner {
213            BvhTraverserInner::Cpu => None,
214            BvhTraverserInner::Gpu(state) => Some(state.creation_id),
215        }
216    }
217
218    /// Traverse the BVH for each ray.
219    ///
220    /// Returns a `Vec<i32>` of length `rays.len()`.  Each element is either:
221    /// - the `object_id` of the first hit leaf's primitive, or
222    /// - `-1` if no intersection was found.
223    pub fn traverse_rays(&self, rays: &[GpuRay]) -> Vec<i32> {
224        if rays.is_empty() {
225            return Vec::new();
226        }
227        match &self.inner {
228            BvhTraverserInner::Cpu => self.traverse_rays_cpu(rays),
229            #[cfg(feature = "wgpu-backend")]
230            BvhTraverserInner::Gpu(state) => self
231                .traverse_rays_gpu(state, rays)
232                .unwrap_or_else(|_| self.traverse_rays_cpu(rays)),
233        }
234    }
235
236    // ── CPU traversal ─────────────────────────────────────────────────────────
237
238    fn traverse_rays_cpu(&self, rays: &[GpuRay]) -> Vec<i32> {
239        rays.iter()
240            .map(|ray| self.traverse_single_cpu(ray))
241            .collect()
242    }
243
244    fn traverse_single_cpu(&self, ray: &GpuRay) -> i32 {
245        if self.flat_nodes.is_empty() {
246            return -1;
247        }
248        let inv_dir = [
249            1.0 / ray.direction[0],
250            1.0 / ray.direction[1],
251            1.0 / ray.direction[2],
252        ];
253        let origin = ray.origin;
254        let max_t = ray.max_t;
255        let mut best_hit: i32 = -1;
256        let mut best_t = max_t;
257
258        let mut stack = Vec::with_capacity(64);
259        stack.push(0usize);
260
261        while let Some(idx) = stack.pop() {
262            let node = &self.flat_nodes[idx];
263            // Slab test
264            if ray_aabb_t(origin, inv_dir, &node.aabb).is_none() {
265                continue;
266            }
267            if node.count > 0 {
268                // Leaf: check each primitive
269                let start = node.left_first as usize;
270                let end = (start + node.count as usize).min(self.prim_indices.len());
271                for &pi in &self.prim_indices[start..end] {
272                    if pi >= self.primitives.len() {
273                        continue;
274                    }
275                    if let Some((t_near, _)) =
276                        ray_aabb_t(origin, inv_dir, &self.primitives[pi].aabb)
277                        && t_near >= 0.0
278                        && t_near < best_t
279                    {
280                        best_t = t_near;
281                        best_hit = self.primitives[pi].object_id as i32;
282                    }
283                }
284            } else {
285                let right = node.left_first as usize;
286                let left = idx + 1;
287                if right < self.flat_nodes.len() {
288                    stack.push(right);
289                }
290                if left < self.flat_nodes.len() && left != right {
291                    stack.push(left);
292                }
293            }
294        }
295        best_hit
296    }
297
298    // ── GPU traversal ─────────────────────────────────────────────────────────
299
300    #[cfg(feature = "wgpu-backend")]
301    fn traverse_rays_gpu(
302        &self,
303        state: &BvhGpuState,
304        rays: &[GpuRay],
305    ) -> Result<Vec<i32>, crate::GpuError> {
306        use std::sync::atomic::Ordering;
307
308        let n_rays = rays.len() as u32;
309        let n_nodes = self.flat_nodes.len() as u32;
310        let n_prims = self.prim_indices.len() as u32;
311
312        // Lock the backend for this dispatch.
313        let mut backend = state
314            .backend
315            .lock()
316            .expect("BvhGpuState backend lock poisoned");
317
318        // ── Encode params [n_nodes, n_rays, n_prims, 0] (per-call) ──────────
319        let params_data: [u32; 4] = [n_nodes, n_rays, n_prims, 0];
320        let params_buf = backend.create_buffer_storage(16);
321        backend.queue_write_buffer_raw(&params_buf, bytemuck::cast_slice(&params_data));
322
323        // ── Encode BVH nodes (per-call, same data each time) ─────────────────
324        // Layout: [min_x, min_y, min_z, max_x, max_y, max_z, left_first_bits, count_bits]
325        let node_f32s: Vec<f32> = self
326            .flat_nodes
327            .iter()
328            .flat_map(|n| {
329                [
330                    n.aabb.min[0],
331                    n.aabb.min[1],
332                    n.aabb.min[2],
333                    n.aabb.max[0],
334                    n.aabb.max[1],
335                    n.aabb.max[2],
336                    f32::from_bits(n.left_first),
337                    f32::from_bits(n.count),
338                ]
339            })
340            .collect();
341        let nodes_buf = backend.create_buffer_storage((node_f32s.len() * 4).max(16) as u64);
342        backend.queue_write_buffer_f32(&nodes_buf, &node_f32s);
343
344        // ── Encode rays: 8 f32 per ray ────────────────────────────────────────
345        let ray_f32s: Vec<f32> = rays
346            .iter()
347            .flat_map(|r| {
348                [
349                    r.origin[0],
350                    r.origin[1],
351                    r.origin[2],
352                    r.direction[0],
353                    r.direction[1],
354                    r.direction[2],
355                    r.max_t,
356                    0.0_f32, // pad
357                ]
358            })
359            .collect();
360        let rays_buf = backend.create_buffer_storage((ray_f32s.len() * 4).max(16) as u64);
361        backend.queue_write_buffer_f32(&rays_buf, &ray_f32s);
362
363        // ── Results buffer ────────────────────────────────────────────────────
364        let results_buf = backend.create_buffer_storage((n_rays as usize * 4).max(16) as u64);
365
366        // ── Dispatch ──────────────────────────────────────────────────────────
367        let workgroups_x = n_rays.div_ceil(64);
368        backend
369            .dispatch_wgsl(
370                BVH_TRAVERSAL_WGSL,
371                "main",
372                &[
373                    (
374                        params_buf,
375                        wgpu::BufferBindingType::Storage { read_only: true },
376                    ),
377                    (
378                        nodes_buf,
379                        wgpu::BufferBindingType::Storage { read_only: true },
380                    ),
381                    (
382                        rays_buf,
383                        wgpu::BufferBindingType::Storage { read_only: true },
384                    ),
385                    (
386                        results_buf,
387                        wgpu::BufferBindingType::Storage { read_only: false },
388                    ),
389                    (
390                        state.prim_indices_buf,
391                        wgpu::BufferBindingType::Storage { read_only: true },
392                    ),
393                    (
394                        state.object_ids_buf,
395                        wgpu::BufferBindingType::Storage { read_only: true },
396                    ),
397                    (
398                        state.prim_aabbs_buf,
399                        wgpu::BufferBindingType::Storage { read_only: true },
400                    ),
401                ],
402                [workgroups_x, 1, 1],
403            )
404            .map_err(|e| crate::GpuError::ShaderDispatch(e.to_string()))?;
405
406        // ── Read results back ─────────────────────────────────────────────────
407        let raw = backend.read_buffer_f32(results_buf);
408        let hits: Vec<i32> = raw
409            .iter()
410            .take(n_rays as usize)
411            .map(|&f| f32::to_bits(f) as i32)
412            .collect();
413
414        // Pad to n_rays if buffer was short.
415        let mut result = hits;
416        result.resize(n_rays as usize, -1);
417
418        // Increment observability counter.
419        state.dispatch_count.fetch_add(1, Ordering::Relaxed);
420
421        Ok(result)
422    }
423}