Skip to main content

oxiphysics_gpu/bvh/
mod.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! CPU BVH (Bounding Volume Hierarchy) tree for broad-phase acceleration.
5//!
6//! Provides axis-aligned bounding box (AABB) structures, a recursive BVH tree
7//! built with a SAH-inspired median split, ray and AABB queries, and a
8//! linearised flat representation for cache-friendly traversal.
9//!
10//! # Module layout
11//!
12//! - [`types`] — Core data types: `Aabb`, `BvhPrimitive`, `BvhNode`, `FlatBvhNode`, `GpuRay`, `RayHit`, statistics types, Morton/LBVH types.
13//! - [`cpu`] — CPU builder (`Bvh::build`), query helpers, flat traversal, LBVH construction, statistics.
14//! - [`gpu`] — GPU-accelerated ray traversal (`BvhGpuTraverser`) backed by `WgpuBackendReal`; falls back to CPU when no GPU is present.
15
16pub mod cpu;
17pub mod gpu;
18pub mod types;
19
20// ── Public re-exports ────────────────────────────────────────────────────────
21
22pub use cpu::{
23    Bvh, build_morton_clusters, bvh_closest_hit, compute_bvh_from_sorted, compute_cluster_radius,
24    flatten, hlbvh_split, lbvh_build, morton_code, query_flat, ray_aabb_intersect, refit, sah_cost,
25};
26pub use gpu::BvhGpuTraverser;
27pub use types::{
28    Aabb, BvhNode, BvhPrimitive, BvhStats, BvhTreeStatistics, FlatBvhNode, GpuRay, LbvhPrimitive,
29    MortonCluster, RayHit,
30};
31
32// ============================================================================
33// Tests (migrated from the original monolithic bvh.rs)
34// ============================================================================
35
#[cfg(test)]
mod tests {
    //! Unit tests for the BVH module (migrated from the original monolithic `bvh.rs`).
    //!
    //! Most fixtures use [`make_grid_primitives`], which lays unit boxes out in a
    //! single row along +X (box `i` spans `[i, i + 1] × [0, 1] × [0, 1]`), so the
    //! expected results of AABB and ray queries can be stated exactly.

    use super::*;

    // ------------------------------------------------------------------
    // Aabb tests
    // ------------------------------------------------------------------

    #[test]
    fn aabb_new_stores_corners() {
        let a = Aabb::new([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
        assert_eq!(a.min, [1.0, 2.0, 3.0]);
        assert_eq!(a.max, [4.0, 5.0, 6.0]);
    }

    #[test]
    fn aabb_point_is_degenerate() {
        let p = [3.0, 3.0, 3.0];
        let a = Aabb::point(p);
        assert_eq!(a.min, p);
        assert_eq!(a.max, p);
    }

    #[test]
    fn aabb_merge_covers_both() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let b = Aabb::new([2.0, 2.0, 2.0], [3.0, 3.0, 3.0]);
        let m = Aabb::merge(&a, &b);
        assert_eq!(m.min, [0.0, 0.0, 0.0]);
        assert_eq!(m.max, [3.0, 3.0, 3.0]);
    }

    #[test]
    fn aabb_intersects_overlapping() {
        let a = Aabb::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0]);
        let b = Aabb::new([1.0, 1.0, 1.0], [3.0, 3.0, 3.0]);
        assert!(a.intersects(&b));
    }

    #[test]
    fn aabb_intersects_disjoint() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let b = Aabb::new([2.0, 2.0, 2.0], [3.0, 3.0, 3.0]);
        assert!(!a.intersects(&b));
    }

    #[test]
    fn aabb_intersects_touching_edge() {
        // Boxes that share only the face x = 1 still count as intersecting.
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let b = Aabb::new([1.0, 0.0, 0.0], [2.0, 1.0, 1.0]);
        assert!(a.intersects(&b));
    }

    #[test]
    fn aabb_contains_inside() {
        let a = Aabb::new([0.0, 0.0, 0.0], [4.0, 4.0, 4.0]);
        assert!(a.contains([2.0, 2.0, 2.0]));
    }

    #[test]
    fn aabb_contains_outside() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        assert!(!a.contains([2.0, 0.0, 0.0]));
    }

    #[test]
    fn aabb_contains_on_surface() {
        // Containment is inclusive of the boundary.
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        assert!(a.contains([1.0, 0.5, 0.5]));
    }

    #[test]
    fn aabb_surface_area_unit_cube() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        assert!((a.surface_area() - 6.0).abs() < 1e-6);
    }

    #[test]
    fn aabb_surface_area_flat() {
        // 2×3×0 slab: area = 2*(2*3 + 3*0 + 0*2) = 12
        let a = Aabb::new([0.0, 0.0, 0.0], [2.0, 3.0, 0.0]);
        assert!((a.surface_area() - 12.0).abs() < 1e-6);
    }

    #[test]
    fn aabb_center_correct() {
        let a = Aabb::new([0.0, 0.0, 0.0], [2.0, 4.0, 6.0]);
        let c = a.center();
        assert!((c[0] - 1.0).abs() < 1e-6);
        assert!((c[1] - 2.0).abs() < 1e-6);
        assert!((c[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn aabb_expand_increases_bounds() {
        let a = Aabb::new([1.0, 1.0, 1.0], [2.0, 2.0, 2.0]);
        let e = a.expand(0.5);
        assert_eq!(e.min, [0.5, 0.5, 0.5]);
        assert_eq!(e.max, [2.5, 2.5, 2.5]);
    }

    // ------------------------------------------------------------------
    // SAH cost
    // ------------------------------------------------------------------

    #[test]
    fn sah_cost_balanced() {
        // Cost = (A_left / A_parent)·n_left + (A_right / A_parent)·n_right.
        // Each child covers half of the parent's area and holds 4 primitives:
        // (1/2)*4 + (1/2)*4 = 4.
        let cost = sah_cost(4, 1.0, 4, 1.0, 2.0);
        assert!((cost - 4.0).abs() < 1e-6);
    }

    #[test]
    fn sah_cost_zero_parent_area_returns_max() {
        // A degenerate parent must never be chosen as a split candidate.
        let cost = sah_cost(1, 1.0, 1, 1.0, 0.0);
        assert_eq!(cost, f32::MAX);
    }

    // ------------------------------------------------------------------
    // Ray–AABB slab intersection
    //
    // The axis-aligned directions below have zero components, so the
    // inverse directions contain +inf (IEEE-754 1.0 / 0.0); these tests
    // therefore also exercise the slab test on axis-parallel rays.
    // ------------------------------------------------------------------

    #[test]
    fn ray_hits_unit_cube() {
        let aabb = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let origin = [-1.0, 0.5, 0.5];
        let dir = [1.0, 0.0, 0.0];
        let inv = [1.0 / dir[0], 1.0 / dir[1], 1.0 / dir[2]];
        assert!(ray_aabb_intersect(origin, inv, &aabb, 10.0));
    }

    #[test]
    fn ray_misses_unit_cube() {
        // The ray travels at y = 2, above the cube.
        let aabb = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let origin = [-1.0, 2.0, 0.5];
        let dir = [1.0, 0.0, 0.0];
        let inv = [1.0 / dir[0], 1.0 / dir[1], 1.0 / dir[2]];
        assert!(!ray_aabb_intersect(origin, inv, &aabb, 10.0));
    }

    #[test]
    fn ray_too_short_misses() {
        // The box starts at x = 5, but the ray is clipped at t = 3.
        let aabb = Aabb::new([5.0, 0.0, 0.0], [6.0, 1.0, 1.0]);
        let origin = [0.0, 0.5, 0.5];
        let dir = [1.0, 0.0, 0.0];
        let inv = [1.0 / dir[0], 1.0 / dir[1], 1.0 / dir[2]];
        assert!(!ray_aabb_intersect(origin, inv, &aabb, 3.0));
    }

    // ------------------------------------------------------------------
    // Bvh build / query
    // ------------------------------------------------------------------

    /// Returns `n` unit boxes in a row along +X; box `i` spans
    /// `[i, i + 1] × [0, 1] × [0, 1]` and carries object id `i`.
    fn make_grid_primitives(n: usize) -> Vec<BvhPrimitive> {
        (0..n)
            .map(|i| {
                let x = i as f32;
                BvhPrimitive::new(Aabb::new([x, 0.0, 0.0], [x + 1.0, 1.0, 1.0]), i)
            })
            .collect()
    }

    #[test]
    fn bvh_build_empty() {
        let bvh = Bvh::build(vec![]);
        assert!(bvh.root.is_none());
        assert_eq!(bvh.node_count(), 0);
        assert_eq!(bvh.depth(), 0);
    }

    #[test]
    fn bvh_build_single() {
        let bvh = Bvh::build(make_grid_primitives(1));
        let root = bvh.root.as_ref().expect("one primitive must produce a root");
        assert!(root.is_leaf());
        assert_eq!(bvh.node_count(), 1);
        assert_eq!(bvh.depth(), 1);
    }

    #[test]
    fn bvh_query_aabb_finds_overlap() {
        let bvh = Bvh::build(make_grid_primitives(10));
        // Strictly inside box 5 ([5, 6] on X), so boxes 4 and 6 are not touched.
        let query = Aabb::new([5.1, 0.1, 0.1], [5.9, 0.9, 0.9]);
        let mut hits = bvh.query_aabb(&query);
        hits.sort();
        assert_eq!(hits, vec![5]);
    }

    #[test]
    fn bvh_query_aabb_empty_result() {
        let bvh = Bvh::build(make_grid_primitives(5));
        let query = Aabb::new([100.0, 0.0, 0.0], [101.0, 1.0, 1.0]);
        assert!(bvh.query_aabb(&query).is_empty());
    }

    #[test]
    fn bvh_query_aabb_finds_multiple() {
        let bvh = Bvh::build(make_grid_primitives(10));
        // X range [2.1, 4.9] overlaps boxes 2, 3 and 4 only.
        let query = Aabb::new([2.1, 0.1, 0.1], [4.9, 0.9, 0.9]);
        let mut hits = bvh.query_aabb(&query);
        hits.sort();
        assert_eq!(hits, vec![2, 3, 4]);
    }

    #[test]
    fn bvh_query_ray_hits() {
        let bvh = Bvh::build(make_grid_primitives(8));
        // Ray along +X at y = 0.5, z = 0.5 passes through all 8 primitives.
        let mut hits = bvh.query_ray([-1.0, 0.5, 0.5], [1.0, 0.0, 0.0], 20.0);
        hits.sort();
        assert_eq!(hits, (0..8).collect::<Vec<_>>());
    }

    #[test]
    fn bvh_query_ray_misses() {
        let bvh = Bvh::build(make_grid_primitives(5));
        // Starts above the row at y = 10 and travels further away (+Y).
        let hits = bvh.query_ray([0.5, 10.0, 0.5], [0.0, 1.0, 0.0], 100.0);
        assert!(hits.is_empty());
    }

    #[test]
    fn bvh_node_count_and_depth_consistent() {
        let bvh = Bvh::build(make_grid_primitives(16));
        // With LEAF_SIZE=4 and 16 prims, depth should be at least 2.
        assert!(bvh.depth() >= 2);
        // An n-primitive tree has at most 2n-1 nodes.
        assert!(bvh.node_count() < 2 * 16);
    }

    // ------------------------------------------------------------------
    // Flat BVH
    // ------------------------------------------------------------------

    #[test]
    fn flatten_empty_bvh() {
        let bvh = Bvh::build(vec![]);
        let (nodes, prim_indices) = flatten(&bvh);
        assert!(nodes.is_empty());
        assert!(prim_indices.is_empty());
    }

    #[test]
    fn flatten_single_primitive() {
        let bvh = Bvh::build(make_grid_primitives(1));
        let (nodes, prim_indices) = flatten(&bvh);
        assert_eq!(nodes.len(), 1);
        assert_eq!(prim_indices.len(), 1);
        assert_eq!(nodes[0].count, 1);
    }

    #[test]
    fn query_flat_finds_overlap() {
        let bvh = Bvh::build(make_grid_primitives(10));
        let (nodes, prim_indices) = flatten(&bvh);
        let query = Aabb::new([3.1, 0.1, 0.1], [3.9, 0.9, 0.9]);
        let mut hits = query_flat(&nodes, &prim_indices, &bvh.primitives, &query);
        hits.sort();
        assert_eq!(hits, vec![3]);
    }

    #[test]
    fn query_flat_empty_result() {
        let bvh = Bvh::build(make_grid_primitives(5));
        let (nodes, prim_indices) = flatten(&bvh);
        let query = Aabb::new([50.0, 0.0, 0.0], [51.0, 1.0, 1.0]);
        assert!(query_flat(&nodes, &prim_indices, &bvh.primitives, &query).is_empty());
    }

    #[test]
    fn query_flat_matches_recursive() {
        let bvh = Bvh::build(make_grid_primitives(20));
        let query = Aabb::new([7.1, 0.0, 0.0], [12.9, 1.0, 1.0]);

        let mut recursive_hits = bvh.query_aabb(&query);
        recursive_hits.sort();

        let (nodes, prim_indices) = flatten(&bvh);
        let mut flat_hits = query_flat(&nodes, &prim_indices, &bvh.primitives, &query);
        flat_hits.sort();

        assert_eq!(recursive_hits, flat_hits);
    }

    // ------------------------------------------------------------------
    // Morton code tests
    // ------------------------------------------------------------------

    #[test]
    fn morton_origin_is_zero() {
        assert_eq!(morton_code([0.0, 0.0, 0.0]), 0);
    }

    #[test]
    fn morton_increases_along_x() {
        let m0 = morton_code([0.0, 0.0, 0.0]);
        let m1 = morton_code([0.5, 0.0, 0.0]);
        let m2 = morton_code([1.0, 0.0, 0.0]);
        assert!(m0 <= m1, "m0={} m1={}", m0, m1);
        assert!(m1 <= m2, "m1={} m2={}", m1, m2);
    }

    #[test]
    fn morton_clamps_outside_unit_cube() {
        // Inputs are clamped to [0, 1] per axis before encoding.
        let m_neg = morton_code([-1.0, -1.0, -1.0]);
        let m_zero = morton_code([0.0, 0.0, 0.0]);
        assert_eq!(m_neg, m_zero);

        let m_big = morton_code([2.0, 2.0, 2.0]);
        let m_one = morton_code([1.0, 1.0, 1.0]);
        assert_eq!(m_big, m_one);
    }

    // ------------------------------------------------------------------
    // LBVH construction tests
    // ------------------------------------------------------------------

    #[test]
    fn lbvh_build_empty() {
        let bvh = lbvh_build(vec![]);
        assert!(bvh.root.is_none());
    }

    #[test]
    fn lbvh_build_single() {
        let bvh = lbvh_build(make_grid_primitives(1));
        let root = bvh.root.as_ref().expect("one primitive must produce a root");
        assert!(root.is_leaf());
    }

    #[test]
    fn lbvh_build_query_finds_correct_objects() {
        let bvh = lbvh_build(make_grid_primitives(10));
        let query = Aabb::new([4.1, 0.1, 0.1], [4.9, 0.9, 0.9]);
        let mut hits = bvh.query_aabb(&query);
        hits.sort();
        assert_eq!(hits, vec![4]);
    }

    #[test]
    fn lbvh_build_covers_all_primitives() {
        let bvh = lbvh_build(make_grid_primitives(8));
        // The root AABB must enclose the whole row, X ∈ [0, 8].
        let root = bvh.root.as_ref().expect("eight primitives must produce a root");
        assert!(root.aabb.min[0] <= 0.0);
        assert!(root.aabb.max[0] >= 8.0);
    }

    // ------------------------------------------------------------------
    // BVH closest-hit traversal
    // ------------------------------------------------------------------

    #[test]
    fn closest_hit_returns_nearest() {
        let bvh = Bvh::build(make_grid_primitives(10));
        // Ray along +X from x = -1 reaches object 0 (x ∈ [0, 1]) before any other.
        let hit = bvh_closest_hit(&bvh, [-1.0, 0.5, 0.5], [1.0, 0.0, 0.0], 100.0)
            .expect("ray should hit something");
        assert_eq!(
            hit.object_id, 0,
            "closest hit should be object 0, got {}",
            hit.object_id
        );
    }

    #[test]
    fn closest_hit_misses_returns_none() {
        let bvh = Bvh::build(make_grid_primitives(5));
        let hit = bvh_closest_hit(&bvh, [0.5, 10.0, 0.5], [0.0, 1.0, 0.0], 100.0);
        assert!(hit.is_none());
    }

    #[test]
    fn closest_hit_empty_bvh_returns_none() {
        let bvh = Bvh::build(vec![]);
        let hit = bvh_closest_hit(&bvh, [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], 100.0);
        assert!(hit.is_none());
    }

    #[test]
    fn closest_hit_t_is_positive() {
        let bvh = Bvh::build(make_grid_primitives(5));
        // This ray is guaranteed to hit object 0. Require the hit, so the
        // t-check cannot be skipped silently if traversal regresses to a miss.
        let hit = bvh_closest_hit(&bvh, [-1.0, 0.5, 0.5], [1.0, 0.0, 0.0], 100.0)
            .expect("ray along +X from x = -1 must hit object 0");
        assert!(hit.t >= 0.0, "hit t should be non-negative, got {}", hit.t);
    }

    // ------------------------------------------------------------------
    // HLBVH split
    // ------------------------------------------------------------------

    #[test]
    fn hlbvh_split_single_pair_splits_at_1() {
        let mortons = vec![0u32, 1u32];
        assert_eq!(hlbvh_split(&mortons), 1);
    }

    #[test]
    fn hlbvh_split_identical_codes_splits_at_end() {
        // With no differing bit to split on, only require a valid interior
        // split index so that both halves are non-empty.
        let mortons = vec![42u32; 4];
        let s = hlbvh_split(&mortons);
        assert!(s > 0 && s < mortons.len());
    }

    #[test]
    fn hlbvh_split_returns_valid_range() {
        let mortons: Vec<u32> = (0..16).map(|i| i * 2).collect();
        let s = hlbvh_split(&mortons);
        assert!(s > 0, "split must be > 0");
        assert!(s < mortons.len(), "split must be < len");
    }

    // ------------------------------------------------------------------
    // BVH refit
    // ------------------------------------------------------------------

    #[test]
    fn refit_does_not_panic_on_leaf() {
        // Two primitives presumably fit in one leaf given the builder's leaf size.
        let mut bvh = Bvh::build(make_grid_primitives(2));
        // `root` and `primitives` are disjoint fields, so they can be borrowed
        // mutably and immutably at the same time without cloning.
        if let Some(root) = bvh.root.as_mut() {
            refit(root, &bvh.primitives);
        }
        // After refitting, the root must still enclose both boxes (X ∈ [0, 2]).
        let root = bvh.root.as_ref().expect("two primitives must produce a root");
        assert!(
            root.aabb.min[0] <= 0.0 && root.aabb.max[0] >= 2.0,
            "refit root AABB must cover all primitives"
        );
    }

    // ------------------------------------------------------------------
    // BvhStats
    // ------------------------------------------------------------------

    #[test]
    fn bvh_stats_empty_tree() {
        let bvh = Bvh::build(vec![]);
        let s = BvhStats::compute(&bvh);
        assert_eq!(s.node_count, 0);
        assert_eq!(s.leaf_count, 0);
    }

    #[test]
    fn bvh_stats_single_leaf() {
        let bvh = Bvh::build(make_grid_primitives(1));
        let s = BvhStats::compute(&bvh);
        assert_eq!(s.leaf_count, 1);
        assert_eq!(s.total_primitives, 1);
    }

    #[test]
    fn bvh_stats_counts_consistent() {
        let bvh = Bvh::build(make_grid_primitives(16));
        let s = BvhStats::compute(&bvh);
        assert_eq!(s.leaf_count + s.internal_count, s.node_count);
    }

    #[test]
    fn bvh_stats_total_primitives() {
        let bvh = Bvh::build(make_grid_primitives(16));
        let s = BvhStats::compute(&bvh);
        assert_eq!(s.total_primitives, 16);
    }

    // ------------------------------------------------------------------
    // BvhTreeStatistics
    // ------------------------------------------------------------------

    #[test]
    fn bvh_tree_stats_empty() {
        let bvh = Bvh::build(vec![]);
        let s = BvhTreeStatistics::compute(&bvh);
        assert_eq!(s.node_count, 0);
    }

    #[test]
    fn bvh_tree_stats_single() {
        let bvh = Bvh::build(make_grid_primitives(1));
        let s = BvhTreeStatistics::compute(&bvh);
        assert_eq!(s.leaf_count, 1);
        assert_eq!(s.total_primitives, 1);
    }

    #[test]
    fn bvh_tree_stats_consistent() {
        let prims = make_grid_primitives(32);
        let prim_count = prims.len();
        let bvh = Bvh::build(prims);
        let s = BvhTreeStatistics::compute(&bvh);
        assert_eq!(s.leaf_count + s.internal_count, s.node_count);
        assert_eq!(s.total_primitives, prim_count);
    }

    #[test]
    fn bvh_tree_stats_leaf_surface_area_positive() {
        let bvh = Bvh::build(make_grid_primitives(8));
        let s = BvhTreeStatistics::compute(&bvh);
        assert!(
            s.total_leaf_surface_area > 0.0,
            "leaf surface area should be > 0"
        );
    }

    // ------------------------------------------------------------------
    // build_morton_clusters tests
    // ------------------------------------------------------------------

    /// Same row of boxes as [`make_grid_primitives`], wrapped as
    /// `LbvhPrimitive`s relative to a scene box that encloses them all, and
    /// sorted by Morton code (the input order the LBVH helpers operate on).
    fn make_sorted_lbvh_prims(n: usize) -> Vec<LbvhPrimitive> {
        let scene = Aabb::new([0.0, 0.0, 0.0], [n as f32 + 1.0, 1.0, 1.0]);
        let mut prims: Vec<LbvhPrimitive> = (0..n)
            .map(|i| {
                let x = i as f32;
                LbvhPrimitive::new(Aabb::new([x, 0.0, 0.0], [x + 1.0, 1.0, 1.0]), i, &scene)
            })
            .collect();
        prims.sort_unstable_by_key(|lp| lp.morton);
        prims
    }

    #[test]
    fn build_morton_clusters_empty() {
        let clusters = build_morton_clusters(&[], 4);
        assert!(clusters.is_empty());
    }

    #[test]
    fn build_morton_clusters_count() {
        let sorted = make_sorted_lbvh_prims(10);
        let clusters = build_morton_clusters(&sorted, 3);
        // 10 primitives in chunks of 3 → ceil(10/3) = 4 clusters
        assert_eq!(clusters.len(), 4);
    }

    #[test]
    fn build_morton_clusters_radii_non_negative() {
        let sorted = make_sorted_lbvh_prims(8);
        let clusters = build_morton_clusters(&sorted, 2);
        for c in &clusters {
            assert!(c.radius >= 0.0, "cluster radius must be non-negative");
        }
    }

    // ------------------------------------------------------------------
    // LbvhPrimitive Morton code assignment
    // ------------------------------------------------------------------

    #[test]
    fn lbvh_primitive_morton_in_range() {
        let aabb = Aabb::new([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]);
        let scene = Aabb::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0]);
        let lp = LbvhPrimitive::new(aabb, 0, &scene);
        // Morton code is 30-bit so max is (1<<30)-1
        assert!(lp.morton < (1u32 << 30));
    }

    #[test]
    fn lbvh_primitive_at_origin_small_code() {
        let aabb = Aabb::point([0.0, 0.0, 0.0]);
        let scene = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let lp = LbvhPrimitive::new(aabb, 0, &scene);
        assert_eq!(lp.morton, 0);
    }

    // ------------------------------------------------------------------
    // compute_bvh_from_sorted tests
    // ------------------------------------------------------------------

    #[test]
    fn compute_bvh_from_sorted_empty() {
        let bvh = compute_bvh_from_sorted(&[]);
        assert!(bvh.root.is_none());
        assert_eq!(bvh.primitives.len(), 0);
    }

    #[test]
    fn compute_bvh_from_sorted_single() {
        let scene = Aabb::new([0.0, 0.0, 0.0], [2.0, 1.0, 1.0]);
        let lp = LbvhPrimitive::new(Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), 7, &scene);
        let bvh = compute_bvh_from_sorted(&[lp]);
        assert!(bvh.root.is_some());
        assert_eq!(bvh.primitives[0].object_id, 7);
    }

    #[test]
    fn compute_bvh_from_sorted_covers_all() {
        let sorted = make_sorted_lbvh_prims(8);
        let bvh = compute_bvh_from_sorted(&sorted);
        assert_eq!(bvh.primitives.len(), 8);
    }

    // ------------------------------------------------------------------
    // BvhGpuTraverser tests
    // ------------------------------------------------------------------

    /// Builds 1000 AABBs on a 10×10×10 lattice, fires 100 rays along +Z, and
    /// compares the GPU traverser with the CPU reference traversal.
    ///
    /// Every result must be either a non-negative hit ID or the `-1` miss
    /// sentinel. Wherever both back-ends report a hit, the IDs must agree.
    /// With no GPU present, `BvhGpuTraverser::new` falls back to the CPU path.
    #[test]
    fn test_bvh_gpu_matches_cpu() {
        // Unit boxes spaced 2 apart on every axis, leaving 1-unit gaps.
        let prims: Vec<BvhPrimitive> = (0..1000)
            .map(|i| {
                let x = (i % 10) as f32 * 2.0;
                let y = ((i / 10) % 10) as f32 * 2.0;
                let z = (i / 100) as f32 * 2.0;
                BvhPrimitive::new(Aabb::new([x, y, z], [x + 1.0, y + 1.0, z + 1.0]), i)
            })
            .collect();

        let bvh = Bvh::build(prims);

        let cpu_traverser = BvhGpuTraverser::new_cpu(&bvh);
        let gpu_traverser = BvhGpuTraverser::new(&bvh);

        // Scattered origins in the z = -1 plane; some fall in the gaps between
        // boxes and miss, which exercises the miss path as well.
        let rays: Vec<GpuRay> = (0..100)
            .map(|i| {
                let t = i as f32 * 0.19;
                let x = (t * 18.0) % 20.0;
                let y = (t * 7.3) % 20.0;
                GpuRay::new([x, y, -1.0], [0.0, 0.0, 1.0], 100.0)
            })
            .collect();

        let cpu_hits = cpu_traverser.traverse_rays(&rays);
        let gpu_hits = gpu_traverser.traverse_rays(&rays);

        // One result per ray from each back-end.
        assert_eq!(cpu_hits.len(), rays.len());
        assert_eq!(gpu_hits.len(), rays.len());

        for (i, (&cpu_hit, &gpu_hit)) in cpu_hits.iter().zip(gpu_hits.iter()).enumerate() {
            // -1 is the miss sentinel; any other negative value is invalid.
            assert!(cpu_hit >= -1, "ray {i}: cpu returned invalid id {cpu_hit}");
            assert!(gpu_hit >= -1, "ray {i}: gpu returned invalid id {gpu_hit}");
            if cpu_hit >= 0 && gpu_hit >= 0 {
                assert_eq!(
                    cpu_hit, gpu_hit,
                    "ray {i} hit mismatch: cpu={cpu_hit} gpu={gpu_hit}"
                );
            }
        }
    }

    #[test]
    fn test_bvh_gpu_traverser_cpu_fallback() {
        let bvh = Bvh::build(make_grid_primitives(16));
        let traverser = BvhGpuTraverser::new_cpu(&bvh);
        assert!(!traverser.is_gpu());

        let rays = vec![GpuRay::new([0.5, 0.5, -1.0], [0.0, 0.0, 1.0], 100.0)];
        let hits = traverser.traverse_rays(&rays);
        assert_eq!(hits.len(), 1);
        // The primitives form a row of unit boxes along +X. This ray, at
        // x = 0.5, y = 0.5, travels along +Z straight through box 0.
        assert!(hits[0] >= 0, "expected a hit, got -1");
    }

    #[test]
    fn test_bvh_gpu_traverser_no_hit() {
        let prims = vec![BvhPrimitive::new(
            Aabb::new([10.0, 10.0, 10.0], [11.0, 11.0, 11.0]),
            42,
        )];
        let bvh = Bvh::build(prims);
        let traverser = BvhGpuTraverser::new_cpu(&bvh);

        // Ray that misses the box completely
        let rays = vec![GpuRay::new([0.0, 0.0, -1.0], [0.0, 0.0, 1.0], 5.0)];
        let hits = traverser.traverse_rays(&rays);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0], -1, "expected no hit, got {}", hits[0]);
    }
}
727}