Skip to main content

oxiphysics_gpu/
bvh.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! CPU BVH (Bounding Volume Hierarchy) tree for broad-phase acceleration.
5//!
6//! Provides axis-aligned bounding box (AABB) structures, a recursive BVH tree
7//! built with a SAH-inspired median split, ray and AABB queries, and a
8//! linearised flat representation for cache-friendly traversal.
9
10// --------------------------------------------------------------------------
11// Silence dead-code warnings for public API that may not be exercised at link
12// time within this crate.
13// --------------------------------------------------------------------------
14#![allow(dead_code)]
15
16// ============================================================================
17// Aabb
18// ============================================================================
19
/// Axis-aligned bounding box described by its two extreme corners.
///
/// `min` holds the smallest coordinate on each axis and `max` the largest.
/// All tests below treat the box as closed: boundary contact counts.
#[derive(Debug, Clone, PartialEq)]
pub struct Aabb {
    /// Lower corner: smallest x, y and z.
    pub min: [f32; 3],
    /// Upper corner: largest x, y and z.
    pub max: [f32; 3],
}

impl Aabb {
    /// Build a box directly from its `min` and `max` corners (no validation).
    pub fn new(min: [f32; 3], max: [f32; 3]) -> Self {
        Aabb { min, max }
    }

    /// Build a zero-volume box whose corners both sit at `p`.
    pub fn point(p: [f32; 3]) -> Self {
        Aabb::new(p, p)
    }

    /// Smallest box enclosing both `a` and `b`: per-axis minimum of the
    /// lower corners and maximum of the upper corners.
    pub fn merge(a: &Aabb, b: &Aabb) -> Aabb {
        let lo: [f32; 3] = std::array::from_fn(|axis| a.min[axis].min(b.min[axis]));
        let hi: [f32; 3] = std::array::from_fn(|axis| a.max[axis].max(b.max[axis]));
        Aabb { min: lo, max: hi }
    }

    /// `true` when the two boxes share at least one point (touching counts).
    pub fn intersects(&self, other: &Aabb) -> bool {
        (0..3usize).all(|axis| self.min[axis] <= other.max[axis] && self.max[axis] >= other.min[axis])
    }

    /// `true` when `p` lies inside the box or on its surface.
    pub fn contains(&self, p: [f32; 3]) -> bool {
        (0..3usize).all(|axis| p[axis] >= self.min[axis] && p[axis] <= self.max[axis])
    }

    /// Total area of the six faces: `2 * (dx*dy + dy*dz + dz*dx)`.
    pub fn surface_area(&self) -> f32 {
        let [dx, dy, dz] = [0_usize, 1, 2].map(|axis| self.max[axis] - self.min[axis]);
        2.0 * (dx * dy + dy * dz + dz * dx)
    }

    /// Midpoint between the two corners.
    pub fn center(&self) -> [f32; 3] {
        std::array::from_fn(|axis| 0.5 * (self.min[axis] + self.max[axis]))
    }

    /// Copy of this box grown by `margin` on every face (a negative
    /// `margin` shrinks it).
    pub fn expand(&self, margin: f32) -> Aabb {
        Aabb {
            min: self.min.map(|v| v - margin),
            max: self.max.map(|v| v + margin),
        }
    }
}
109
110// ============================================================================
111// BvhPrimitive
112// ============================================================================
113
114/// A leaf primitive: an AABB together with the logical object it belongs to.
115#[derive(Debug, Clone)]
116pub struct BvhPrimitive {
117    /// Bounding box of the primitive.
118    pub aabb: Aabb,
119    /// Caller-defined identifier (returned by queries).
120    pub object_id: usize,
121}
122
123impl BvhPrimitive {
124    /// Construct a `BvhPrimitive`.
125    pub fn new(aabb: Aabb, object_id: usize) -> Self {
126        Self { aabb, object_id }
127    }
128}
129
130// ============================================================================
131// BvhNode
132// ============================================================================
133
134/// A node in the recursive BVH tree.
135#[derive(Debug)]
136pub struct BvhNode {
137    /// Bounding box that contains all children / primitives.
138    pub aabb: Aabb,
139    /// Left subtree (internal nodes only).
140    pub left: Option<Box<BvhNode>>,
141    /// Right subtree (internal nodes only).
142    pub right: Option<Box<BvhNode>>,
143    /// Indices into `Bvh::primitives` (leaf nodes only).
144    pub primitives: Vec<usize>,
145}
146
147impl BvhNode {
148    /// Returns `true` if this node is a leaf (holds primitives directly).
149    pub fn is_leaf(&self) -> bool {
150        self.left.is_none() && self.right.is_none()
151    }
152}
153
154// ============================================================================
155// SAH helper
156// ============================================================================
157
/// Surface Area Heuristic cost of a candidate split:
/// `C = (SA_left / SA_parent) * N_left + (SA_right / SA_parent) * N_right`
///
/// A parent area that is zero or negative yields `f32::MAX`, marking the
/// split as unusable.
pub fn sah_cost(n_left: usize, sa_left: f32, n_right: usize, sa_right: f32, sa_parent: f32) -> f32 {
    let degenerate_parent = sa_parent <= 0.0;
    if degenerate_parent {
        f32::MAX
    } else {
        let left_term = (sa_left / sa_parent) * n_left as f32;
        let right_term = (sa_right / sa_parent) * n_right as f32;
        left_term + right_term
    }
}
166
167// ============================================================================
168// Slab-method ray–AABB intersection
169// ============================================================================
170
171/// Test whether a ray defined by `origin` + t * direction intersects `aabb`
172/// within `[0, max_t]`.
173///
174/// `inv_dir` must be the component-wise reciprocal of the ray direction.
175pub fn ray_aabb_intersect(origin: [f32; 3], inv_dir: [f32; 3], aabb: &Aabb, max_t: f32) -> bool {
176    let mut t_min = 0.0_f32;
177    let mut t_max = max_t;
178
179    for i in 0..3 {
180        let t1 = (aabb.min[i] - origin[i]) * inv_dir[i];
181        let t2 = (aabb.max[i] - origin[i]) * inv_dir[i];
182        let lo = t1.min(t2);
183        let hi = t1.max(t2);
184        t_min = t_min.max(lo);
185        t_max = t_max.min(hi);
186    }
187
188    t_min <= t_max
189}
190
191// ============================================================================
192// Bvh
193// ============================================================================
194
195/// Maximum number of primitives per leaf before the tree stops splitting.
196const LEAF_SIZE: usize = 4;
197
198/// A BVH tree built from a flat list of [`BvhPrimitive`]s.
199pub struct Bvh {
200    /// Tree root (if there is at least one primitive).
201    pub root: Option<BvhNode>,
202    /// All primitives passed to [`Bvh::build`].
203    pub primitives: Vec<BvhPrimitive>,
204}
205
206impl Bvh {
207    /// Build a BVH from a list of primitives using a median-split strategy
208    /// guided by the longest axis (SAH-inspired).
209    pub fn build(primitives: Vec<BvhPrimitive>) -> Self {
210        if primitives.is_empty() {
211            return Self {
212                root: None,
213                primitives,
214            };
215        }
216        let indices: Vec<usize> = (0..primitives.len()).collect();
217        let root = build_recursive(&primitives, indices);
218        Self {
219            root: Some(root),
220            primitives,
221        }
222    }
223
224    /// Return the `object_id`s of all primitives whose AABB overlaps `query`.
225    pub fn query_aabb(&self, query: &Aabb) -> Vec<usize> {
226        let mut result = Vec::new();
227        if let Some(root) = &self.root {
228            query_aabb_recursive(root, query, &self.primitives, &mut result);
229        }
230        result
231    }
232
233    /// Return the `object_id`s of all primitives hit by the given ray within
234    /// distance `max_t`.
235    pub fn query_ray(&self, origin: [f32; 3], direction: [f32; 3], max_t: f32) -> Vec<usize> {
236        let inv_dir = [1.0 / direction[0], 1.0 / direction[1], 1.0 / direction[2]];
237        let mut result = Vec::new();
238        if let Some(root) = &self.root {
239            query_ray_recursive(root, origin, inv_dir, max_t, &self.primitives, &mut result);
240        }
241        result
242    }
243
244    /// Total number of nodes in the tree.
245    pub fn node_count(&self) -> usize {
246        match &self.root {
247            None => 0,
248            Some(root) => count_nodes(root),
249        }
250    }
251
252    /// Maximum depth of the tree (root = depth 1).
253    pub fn depth(&self) -> usize {
254        match &self.root {
255            None => 0,
256            Some(root) => node_depth(root),
257        }
258    }
259}
260
261// ============================================================================
262// Internal build / query helpers
263// ============================================================================
264
265fn bounding_box(primitives: &[BvhPrimitive], indices: &[usize]) -> Aabb {
266    let mut aabb = primitives[indices[0]].aabb.clone();
267    for &i in &indices[1..] {
268        aabb = Aabb::merge(&aabb, &primitives[i].aabb);
269    }
270    aabb
271}
272
273fn build_recursive(primitives: &[BvhPrimitive], mut indices: Vec<usize>) -> BvhNode {
274    let aabb = bounding_box(primitives, &indices);
275
276    if indices.len() <= LEAF_SIZE {
277        return BvhNode {
278            aabb,
279            left: None,
280            right: None,
281            primitives: indices,
282        };
283    }
284
285    // Choose longest axis for split.
286    let dx = aabb.max[0] - aabb.min[0];
287    let dy = aabb.max[1] - aabb.min[1];
288    let dz = aabb.max[2] - aabb.min[2];
289    let axis = if dx >= dy && dx >= dz {
290        0
291    } else if dy >= dz {
292        1
293    } else {
294        2
295    };
296
297    // Median split by centre of primitive AABB along the chosen axis.
298    indices.sort_unstable_by(|&a, &b| {
299        let ca = primitives[a].aabb.center()[axis];
300        let cb = primitives[b].aabb.center()[axis];
301        ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
302    });
303
304    let mid = indices.len() / 2;
305    let right_indices = indices.split_off(mid);
306    let left_indices = indices;
307
308    let left = build_recursive(primitives, left_indices);
309    let right = build_recursive(primitives, right_indices);
310
311    BvhNode {
312        aabb,
313        left: Some(Box::new(left)),
314        right: Some(Box::new(right)),
315        primitives: Vec::new(),
316    }
317}
318
319fn query_aabb_recursive(
320    node: &BvhNode,
321    query: &Aabb,
322    primitives: &[BvhPrimitive],
323    result: &mut Vec<usize>,
324) {
325    if !node.aabb.intersects(query) {
326        return;
327    }
328    if node.is_leaf() {
329        for &idx in &node.primitives {
330            if primitives[idx].aabb.intersects(query) {
331                result.push(primitives[idx].object_id);
332            }
333        }
334    } else {
335        if let Some(left) = &node.left {
336            query_aabb_recursive(left, query, primitives, result);
337        }
338        if let Some(right) = &node.right {
339            query_aabb_recursive(right, query, primitives, result);
340        }
341    }
342}
343
344fn query_ray_recursive(
345    node: &BvhNode,
346    origin: [f32; 3],
347    inv_dir: [f32; 3],
348    max_t: f32,
349    primitives: &[BvhPrimitive],
350    result: &mut Vec<usize>,
351) {
352    if !ray_aabb_intersect(origin, inv_dir, &node.aabb, max_t) {
353        return;
354    }
355    if node.is_leaf() {
356        for &idx in &node.primitives {
357            if ray_aabb_intersect(origin, inv_dir, &primitives[idx].aabb, max_t) {
358                result.push(primitives[idx].object_id);
359            }
360        }
361    } else {
362        if let Some(left) = &node.left {
363            query_ray_recursive(left, origin, inv_dir, max_t, primitives, result);
364        }
365        if let Some(right) = &node.right {
366            query_ray_recursive(right, origin, inv_dir, max_t, primitives, result);
367        }
368    }
369}
370
371fn count_nodes(node: &BvhNode) -> usize {
372    1 + node.left.as_ref().map_or(0, |n| count_nodes(n))
373        + node.right.as_ref().map_or(0, |n| count_nodes(n))
374}
375
376fn node_depth(node: &BvhNode) -> usize {
377    1 + node
378        .left
379        .as_ref()
380        .map_or(0, |n| node_depth(n))
381        .max(node.right.as_ref().map_or(0, |n| node_depth(n)))
382}
383
384// ============================================================================
385// Flat BVH
386// ============================================================================
387
/// A single node in the linearised (flat) BVH representation produced by
/// [`flatten`] and traversed by [`query_flat`].
///
/// Layout:
/// * If `count == 0` this is an **internal** node.
///   - The **left** child is always at index `node_idx + 1` (i.e. stored
///     immediately after the parent in DFS pre-order).
///   - `left_first` holds the index of the **right** child.
/// * If `count > 0` this is a **leaf**; `left_first` is the start index into
///   the accompanying primitive-index slice and `count` is the number of
///   entries.
///
/// Because `count == 0` is what marks an internal node, a leaf holding zero
/// primitives cannot be represented distinctly.
/// NOTE(review): the builders in this module never create empty leaves, but
/// hand-built `BvhNode` trees could.
#[derive(Debug, Clone)]
pub struct FlatBvhNode {
    /// Bounding box of this node.
    pub aabb: Aabb,
    /// Right-child index (internal) or first-primitive index (leaf).
    pub left_first: u32,
    /// 0 for internal nodes; number of primitives for leaf nodes.
    pub count: u32,
}
407
408/// Flatten a [`Bvh`] into a `Vec<FlatBvhNode>` (DFS pre-order) together with
409/// a reordered primitive-index slice.
410///
411/// Returns `(flat_nodes, prim_indices)` where `prim_indices[i]` is an index
412/// into `bvh.primitives`.
413pub fn flatten(bvh: &Bvh) -> (Vec<FlatBvhNode>, Vec<usize>) {
414    let mut nodes: Vec<FlatBvhNode> = Vec::new();
415    let mut prim_indices: Vec<usize> = Vec::new();
416
417    if let Some(root) = &bvh.root {
418        flatten_recursive(root, &mut nodes, &mut prim_indices);
419    }
420
421    (nodes, prim_indices)
422}
423
424/// Returns the index at which `node` was stored.
425fn flatten_recursive(
426    node: &BvhNode,
427    nodes: &mut Vec<FlatBvhNode>,
428    prim_indices: &mut Vec<usize>,
429) -> usize {
430    let node_idx = nodes.len();
431
432    if node.is_leaf() {
433        let first = prim_indices.len() as u32;
434        let count = node.primitives.len() as u32;
435        prim_indices.extend_from_slice(&node.primitives);
436        nodes.push(FlatBvhNode {
437            aabb: node.aabb.clone(),
438            left_first: first,
439            count,
440        });
441    } else {
442        // Reserve a slot; left_first (right child index) filled after recursion.
443        nodes.push(FlatBvhNode {
444            aabb: node.aabb.clone(),
445            left_first: 0,
446            count: 0,
447        });
448        // Left child is always node_idx + 1 (no explicit storage needed).
449        if let Some(left) = &node.left {
450            flatten_recursive(left, nodes, prim_indices);
451        }
452        // Right child comes after the entire left subtree.
453        let right_idx = if let Some(right) = &node.right {
454            flatten_recursive(right, nodes, prim_indices)
455        } else {
456            0
457        };
458        nodes[node_idx].left_first = right_idx as u32;
459    }
460
461    node_idx
462}
463
464/// Iterative AABB query over a flat BVH.
465///
466/// Returns the `object_id` values of all primitives whose AABB overlaps
467/// `query`. `bvh_primitives` is the `Bvh::primitives` slice.
468pub fn query_flat(
469    nodes: &[FlatBvhNode],
470    prim_indices: &[usize],
471    bvh_primitives: &[BvhPrimitive],
472    query: &Aabb,
473) -> Vec<usize> {
474    let mut result = Vec::new();
475    if nodes.is_empty() {
476        return result;
477    }
478
479    let mut stack: Vec<usize> = Vec::with_capacity(64);
480    stack.push(0);
481
482    while let Some(idx) = stack.pop() {
483        let node = &nodes[idx];
484        if !node.aabb.intersects(query) {
485            continue;
486        }
487        if node.count > 0 {
488            // Leaf
489            let start = node.left_first as usize;
490            let end = start + node.count as usize;
491            for &pi in &prim_indices[start..end] {
492                if bvh_primitives[pi].aabb.intersects(query) {
493                    result.push(bvh_primitives[pi].object_id);
494                }
495            }
496        } else {
497            // Internal: left child is at idx+1, right child at left_first.
498            let right = node.left_first as usize;
499            stack.push(right);
500            stack.push(idx + 1);
501        }
502    }
503
504    result
505}
506
507// ============================================================================
508// Morton code (LBVH)
509// ============================================================================
510
/// Spread a 10-bit value so that input bit `i` lands on bit `3 * i`, leaving
/// two zero bits between consecutive input bits. Inputs are expected to be
/// below 1024.
fn expand_bits(v: u32) -> u32 {
    // (shift, mask) rounds of the classic "magic bits" interleave.
    const ROUNDS: [(u32, u32); 4] = [
        (16, 0x030000FF),
        (8, 0x0300F00F),
        (4, 0x030C30C3),
        (2, 0x09249249),
    ];
    ROUNDS
        .iter()
        .fold(v, |acc, &(shift, mask)| (acc | (acc << shift)) & mask)
}

/// 30-bit Morton (Z-order) code of a point in the unit cube \[0, 1\]^3.
///
/// Each coordinate is clamped to `[0, 1]` and scaled to `0..=1023`
/// (truncating). The three 10-bit values are then interleaved, with x in
/// the lowest bit of each triplet, then y, then z.
pub fn morton_code(p: [f32; 3]) -> u32 {
    let [x, y, z] = p.map(|c| expand_bits((c.clamp(0.0, 1.0) * 1023.0) as u32));
    x | (y << 1) | (z << 2)
}
527
528/// LBVH primitive: AABB + Morton code.
529#[derive(Debug, Clone)]
530pub struct LbvhPrimitive {
531    /// Bounding box.
532    pub aabb: Aabb,
533    /// Caller-defined object ID.
534    pub object_id: usize,
535    /// 30-bit Morton code computed from the AABB centroid.
536    pub morton: u32,
537}
538
539impl LbvhPrimitive {
540    /// Construct an `LbvhPrimitive`, computing the Morton code from the centroid
541    /// normalised by `scene_aabb`.
542    pub fn new(aabb: Aabb, object_id: usize, scene_aabb: &Aabb) -> Self {
543        let c = aabb.center();
544        let scene_size = [
545            (scene_aabb.max[0] - scene_aabb.min[0]).max(1e-10),
546            (scene_aabb.max[1] - scene_aabb.min[1]).max(1e-10),
547            (scene_aabb.max[2] - scene_aabb.min[2]).max(1e-10),
548        ];
549        let norm = [
550            (c[0] - scene_aabb.min[0]) / scene_size[0],
551            (c[1] - scene_aabb.min[1]) / scene_size[1],
552            (c[2] - scene_aabb.min[2]) / scene_size[2],
553        ];
554        let morton = morton_code(norm);
555        Self {
556            aabb,
557            object_id,
558            morton,
559        }
560    }
561}
562
563/// Build an LBVH (Linear BVH) from a set of primitives using Morton-code
564/// sorting and a recursive binary splitting strategy.
565///
566/// Returns a standard [`Bvh`] so the same query functions can be used.
567pub fn lbvh_build(primitives: Vec<BvhPrimitive>) -> Bvh {
568    if primitives.is_empty() {
569        return Bvh {
570            root: None,
571            primitives,
572        };
573    }
574
575    // Compute scene bounding box.
576    let mut scene = primitives[0].aabb.clone();
577    for p in &primitives[1..] {
578        scene = Aabb::merge(&scene, &p.aabb);
579    }
580
581    // Assign Morton codes and sort.
582    let mut indexed: Vec<(u32, usize)> = primitives
583        .iter()
584        .enumerate()
585        .map(|(i, p)| {
586            let lp = LbvhPrimitive::new(p.aabb.clone(), p.object_id, &scene);
587            (lp.morton, i)
588        })
589        .collect();
590    indexed.sort_unstable_by_key(|&(m, _)| m);
591
592    let sorted_indices: Vec<usize> = indexed.iter().map(|&(_, i)| i).collect();
593    let root = lbvh_recursive(&primitives, &sorted_indices);
594
595    Bvh {
596        root: Some(root),
597        primitives,
598    }
599}
600
601fn lbvh_recursive(primitives: &[BvhPrimitive], indices: &[usize]) -> BvhNode {
602    let aabb = bounding_box(primitives, indices);
603
604    if indices.len() <= LEAF_SIZE {
605        return BvhNode {
606            aabb,
607            left: None,
608            right: None,
609            primitives: indices.to_vec(),
610        };
611    }
612
613    let mid = indices.len() / 2;
614    let left = lbvh_recursive(primitives, &indices[..mid]);
615    let right = lbvh_recursive(primitives, &indices[mid..]);
616
617    BvhNode {
618        aabb,
619        left: Some(Box::new(left)),
620        right: Some(Box::new(right)),
621        primitives: Vec::new(),
622    }
623}
624
625// ============================================================================
626// BVH Traversal (closest hit)
627// ============================================================================
628
/// Result of a closest-hit ray traversal (see `bvh_closest_hit`).
///
/// Hits are computed against primitive AABBs, not underlying geometry.
#[derive(Debug, Clone)]
pub struct RayHit {
    /// Object ID of the primitive whose AABB the ray enters first.
    pub object_id: usize,
    /// Ray parameter at which the ray enters that primitive's AABB
    /// (0 when the ray origin already lies inside it).
    pub t: f32,
}
637
638/// Ray–AABB slab intersection returning the near/far t values.
639fn ray_aabb_t(origin: [f32; 3], inv_dir: [f32; 3], aabb: &Aabb) -> Option<(f32, f32)> {
640    let mut t_min = 0.0_f32;
641    let mut t_max = f32::MAX;
642    for i in 0..3 {
643        let t1 = (aabb.min[i] - origin[i]) * inv_dir[i];
644        let t2 = (aabb.max[i] - origin[i]) * inv_dir[i];
645        t_min = t_min.max(t1.min(t2));
646        t_max = t_max.min(t1.max(t2));
647    }
648    if t_min <= t_max {
649        Some((t_min, t_max))
650    } else {
651        None
652    }
653}
654
655/// Traverse the BVH returning the **closest** hit (smallest positive t).
656pub fn bvh_closest_hit(
657    bvh: &Bvh,
658    origin: [f32; 3],
659    direction: [f32; 3],
660    max_t: f32,
661) -> Option<RayHit> {
662    let inv_dir = [1.0 / direction[0], 1.0 / direction[1], 1.0 / direction[2]];
663    let root = bvh.root.as_ref()?;
664    let mut best: Option<RayHit> = None;
665    let mut current_max = max_t;
666    closest_hit_recursive(
667        root,
668        origin,
669        inv_dir,
670        &bvh.primitives,
671        &mut best,
672        &mut current_max,
673    );
674    best
675}
676
677fn closest_hit_recursive(
678    node: &BvhNode,
679    origin: [f32; 3],
680    inv_dir: [f32; 3],
681    primitives: &[BvhPrimitive],
682    best: &mut Option<RayHit>,
683    max_t: &mut f32,
684) {
685    if ray_aabb_t(origin, inv_dir, &node.aabb).is_none() {
686        return;
687    }
688    if node.is_leaf() {
689        for &idx in &node.primitives {
690            if let Some((t_min, _)) = ray_aabb_t(origin, inv_dir, &primitives[idx].aabb)
691                && t_min >= 0.0
692                && t_min < *max_t
693            {
694                *max_t = t_min;
695                *best = Some(RayHit {
696                    object_id: primitives[idx].object_id,
697                    t: t_min,
698                });
699            }
700        }
701    } else {
702        if let Some(left) = &node.left {
703            closest_hit_recursive(left, origin, inv_dir, primitives, best, max_t);
704        }
705        if let Some(right) = &node.right {
706            closest_hit_recursive(right, origin, inv_dir, primitives, best, max_t);
707        }
708    }
709}
710
711// ============================================================================
712// BVH Refit
713// ============================================================================
714
715/// Refit the bounding boxes of an existing BVH after primitives have moved.
716///
717/// The topology (splits) are preserved; only bounding boxes are recomputed.
718pub fn refit(node: &mut BvhNode, primitives: &[BvhPrimitive]) {
719    if node.is_leaf() {
720        if !node.primitives.is_empty() {
721            node.aabb = bounding_box(primitives, &node.primitives);
722        }
723        return;
724    }
725    if let Some(left) = node.left.as_mut() {
726        refit(left, primitives);
727    }
728    if let Some(right) = node.right.as_mut() {
729        refit(right, primitives);
730    }
731    // Recompute bounding box from children.
732    let left_aabb = node.left.as_ref().map(|n| n.aabb.clone());
733    let right_aabb = node.right.as_ref().map(|n| n.aabb.clone());
734    node.aabb = match (left_aabb, right_aabb) {
735        (Some(l), Some(r)) => Aabb::merge(&l, &r),
736        (Some(l), None) => l,
737        (None, Some(r)) => r,
738        (None, None) => node.aabb.clone(),
739    };
740}
741
742// ============================================================================
743// HLBVH Split (spatial median on the highest non-degenerate bit)
744// ============================================================================
745
/// Find the split index for a run of Morton codes sorted in ascending order.
///
/// The split follows the HLBVH strategy: split on the highest bit that
/// differs between the first and last code. The result is the index of the
/// first code whose value of that bit differs from `mortons[0]`.
///
/// When every code in the run is identical there is no differing bit. The
/// run is then split in the middle (`len / 2`), so trees over many
/// coincident primitives stay balanced instead of degenerating into a chain
/// of depth O(n).
///
/// For `len >= 2` the result always satisfies `0 < split < len`. For
/// `len < 2` there is nothing to split and `1` is returned.
pub fn hlbvh_split(mortons: &[u32]) -> usize {
    let len = mortons.len();
    if len < 2 {
        return 1;
    }
    let first = mortons[0];
    let last = mortons[len - 1];
    if first == last {
        // No differing bit: fall back to an even split.
        return len / 2;
    }
    let common_prefix = (first ^ last).leading_zeros();
    // Invariant: mortons[lo] agrees with `first` on the split bit and
    // mortons[hi] does not. Maintaining lo < hi structurally keeps the
    // result in 1..len even if the input is not perfectly sorted.
    let mut lo = 0usize;
    let mut hi = len - 1;
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        let prefix = (first ^ mortons[mid]).leading_zeros();
        if prefix > common_prefix {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    hi
}
771
772// ============================================================================
773// BVH Statistics
774// ============================================================================
775
776/// Runtime statistics about a BVH tree.
777#[derive(Debug, Clone)]
778pub struct BvhStats {
779    /// Total number of nodes (internal + leaf).
780    pub node_count: usize,
781    /// Number of leaf nodes.
782    pub leaf_count: usize,
783    /// Number of internal nodes.
784    pub internal_count: usize,
785    /// Maximum tree depth.
786    pub max_depth: usize,
787    /// Total number of primitives stored across all leaves.
788    pub total_primitives: usize,
789    /// Average primitives per leaf.
790    pub avg_primitives_per_leaf: f32,
791}
792
793impl BvhStats {
794    /// Compute statistics by traversing the given BVH.
795    pub fn compute(bvh: &Bvh) -> Self {
796        let mut s = BvhStats {
797            node_count: 0,
798            leaf_count: 0,
799            internal_count: 0,
800            max_depth: 0,
801            total_primitives: 0,
802            avg_primitives_per_leaf: 0.0,
803        };
804        if let Some(root) = &bvh.root {
805            collect_stats(root, 1, &mut s);
806        }
807        if s.leaf_count > 0 {
808            s.avg_primitives_per_leaf = s.total_primitives as f32 / s.leaf_count as f32;
809        }
810        s
811    }
812}
813
814fn collect_stats(node: &BvhNode, depth: usize, s: &mut BvhStats) {
815    s.node_count += 1;
816    if depth > s.max_depth {
817        s.max_depth = depth;
818    }
819    if node.is_leaf() {
820        s.leaf_count += 1;
821        s.total_primitives += node.primitives.len();
822    } else {
823        s.internal_count += 1;
824        if let Some(left) = &node.left {
825            collect_stats(left, depth + 1, s);
826        }
827        if let Some(right) = &node.right {
828            collect_stats(right, depth + 1, s);
829        }
830    }
831}
832
833// ============================================================================
834// MortonCluster — radix-sorted BVH construction helpers
835// ============================================================================
836
/// A cluster of Morton-coded primitives, with its bounding box and a
/// pre-computed bounding radius (built by `build_morton_clusters`).
#[derive(Debug, Clone)]
pub struct MortonCluster {
    /// Indices of the primitives in this cluster (into the parent slice).
    pub indices: Vec<usize>,
    /// Axis-aligned bounding box of the cluster.
    pub aabb: Aabb,
    /// Bounding sphere radius (centred at the AABB centre).
    ///
    /// NOTE(review): as computed by `compute_cluster_radius`, this sphere
    /// encloses the primitive *centroids*, not their full boxes.
    pub radius: f32,
}
847
848/// Build a flat BVH from a pre-sorted (by Morton code) slice of `LbvhPrimitive`s
849/// using a radix-sort-inspired clustering strategy.
850///
851/// Primitives are split at the highest differing Morton bit, producing a
852/// balanced binary tree stored as a flat `Bvh`.
853pub fn compute_bvh_from_sorted(sorted: &[LbvhPrimitive]) -> Bvh {
854    if sorted.is_empty() {
855        return Bvh {
856            root: None,
857            primitives: Vec::new(),
858        };
859    }
860
861    // Reconstruct BvhPrimitives in sorted order.
862    let primitives: Vec<BvhPrimitive> = sorted
863        .iter()
864        .map(|lp| BvhPrimitive::new(lp.aabb.clone(), lp.object_id))
865        .collect();
866
867    let mortons: Vec<u32> = sorted.iter().map(|lp| lp.morton).collect();
868    let indices: Vec<usize> = (0..primitives.len()).collect();
869    let root = bvh_from_sorted_recursive(&primitives, &indices, &mortons);
870    Bvh {
871        root: Some(root),
872        primitives,
873    }
874}
875
876fn bvh_from_sorted_recursive(
877    primitives: &[BvhPrimitive],
878    indices: &[usize],
879    mortons: &[u32],
880) -> BvhNode {
881    let aabb = bounding_box(primitives, indices);
882    if indices.len() <= LEAF_SIZE {
883        return BvhNode {
884            aabb,
885            left: None,
886            right: None,
887            primitives: indices.to_vec(),
888        };
889    }
890    // Use HLBVH-style split on Morton codes at corresponding positions.
891    let local_mortons: Vec<u32> = indices.iter().map(|&i| mortons[i]).collect();
892    let split = hlbvh_split(&local_mortons);
893    let left = bvh_from_sorted_recursive(primitives, &indices[..split], mortons);
894    let right = bvh_from_sorted_recursive(primitives, &indices[split..], mortons);
895    BvhNode {
896        aabb,
897        left: Some(Box::new(left)),
898        right: Some(Box::new(right)),
899        primitives: Vec::new(),
900    }
901}
902
903/// Compute the bounding sphere radius for a cluster of `LbvhPrimitive`s.
904///
905/// The sphere is centred at the centroid of the cluster AABB and has the
906/// minimum radius that encloses all primitive centroids.
907pub fn compute_cluster_radius(cluster: &[LbvhPrimitive]) -> f32 {
908    if cluster.is_empty() {
909        return 0.0;
910    }
911    // Compute merged AABB.
912    let mut merged = cluster[0].aabb.clone();
913    for lp in &cluster[1..] {
914        merged = Aabb::merge(&merged, &lp.aabb);
915    }
916    let cx = (merged.min[0] + merged.max[0]) * 0.5;
917    let cy = (merged.min[1] + merged.max[1]) * 0.5;
918    let cz = (merged.min[2] + merged.max[2]) * 0.5;
919
920    let mut max_dist_sq = 0.0_f32;
921    for lp in cluster {
922        let c = lp.aabb.center();
923        let dx = c[0] - cx;
924        let dy = c[1] - cy;
925        let dz = c[2] - cz;
926        let d2 = dx * dx + dy * dy + dz * dz;
927        if d2 > max_dist_sq {
928            max_dist_sq = d2;
929        }
930    }
931    max_dist_sq.sqrt()
932}
933
934/// Compute clusters by grouping Morton-sorted `LbvhPrimitive`s into chunks of
935/// `cluster_size` and returning a `MortonCluster` per group.
936pub fn build_morton_clusters(sorted: &[LbvhPrimitive], cluster_size: usize) -> Vec<MortonCluster> {
937    if sorted.is_empty() || cluster_size == 0 {
938        return Vec::new();
939    }
940    sorted
941        .chunks(cluster_size)
942        .map(|chunk| {
943            let indices: Vec<usize> = (0..chunk.len()).collect();
944            let mut aabb = chunk[0].aabb.clone();
945            for lp in &chunk[1..] {
946                aabb = Aabb::merge(&aabb, &lp.aabb);
947            }
948            let radius = compute_cluster_radius(chunk);
949            MortonCluster {
950                indices,
951                aabb,
952                radius,
953            }
954        })
955        .collect()
956}
957
958// ============================================================================
959// BvhNode tree statistics (extended)
960// ============================================================================
961
/// Extended BVH tree statistics including average fan-out.
///
/// Produced by [`BvhTreeStatistics::compute`], which visits every node of a
/// [`Bvh`]. An empty BVH (no root) yields all-zero fields.
#[derive(Debug, Clone)]
pub struct BvhTreeStatistics {
    /// Total node count (every node is counted as either a leaf or internal).
    pub node_count: usize,
    /// Number of leaf nodes.
    pub leaf_count: usize,
    /// Number of internal nodes.
    pub internal_count: usize,
    /// Maximum depth from root (1-indexed: a lone root has depth 1; 0 for an empty tree).
    pub max_depth: usize,
    /// Total number of primitives across all leaves.
    pub total_primitives: usize,
    /// Average number of children per internal node (fan-out).
    /// For a binary tree this is at most 2; `0.0` when there are no internal nodes.
    pub avg_fanout: f32,
    /// Total surface area of all leaf AABBs (sum of each leaf's `Aabb::surface_area`).
    pub total_leaf_surface_area: f32,
}
981
982impl BvhTreeStatistics {
983    /// Compute extended tree statistics by traversing the given BVH.
984    pub fn compute(bvh: &Bvh) -> Self {
985        let mut s = BvhTreeStatistics {
986            node_count: 0,
987            leaf_count: 0,
988            internal_count: 0,
989            max_depth: 0,
990            total_primitives: 0,
991            avg_fanout: 0.0,
992            total_leaf_surface_area: 0.0,
993        };
994        if let Some(root) = &bvh.root {
995            let mut child_sum = 0usize;
996            collect_tree_stats(root, 1, &mut s, &mut child_sum);
997            s.avg_fanout = if s.internal_count > 0 {
998                child_sum as f32 / s.internal_count as f32
999            } else {
1000                0.0
1001            };
1002        }
1003        s
1004    }
1005}
1006
1007fn collect_tree_stats(
1008    node: &BvhNode,
1009    depth: usize,
1010    s: &mut BvhTreeStatistics,
1011    child_sum: &mut usize,
1012) {
1013    s.node_count += 1;
1014    if depth > s.max_depth {
1015        s.max_depth = depth;
1016    }
1017    if node.is_leaf() {
1018        s.leaf_count += 1;
1019        s.total_primitives += node.primitives.len();
1020        s.total_leaf_surface_area += node.aabb.surface_area();
1021    } else {
1022        s.internal_count += 1;
1023        let mut children = 0usize;
1024        if let Some(left) = &node.left {
1025            children += 1;
1026            collect_tree_stats(left, depth + 1, s, child_sum);
1027        }
1028        if let Some(right) = &node.right {
1029            children += 1;
1030            collect_tree_stats(right, depth + 1, s, child_sum);
1031        }
1032        *child_sum += children;
1033    }
1034}
1035
1036// ============================================================================
1037// Tests
1038// ============================================================================
1039
#[cfg(test)]
mod tests {
    //! Unit tests for `Aabb`, the BVH builders, the queries, and the statistics helpers.

    use super::*;

    // ------------------------------------------------------------------
    // Aabb tests
    // ------------------------------------------------------------------

    #[test]
    fn aabb_new_stores_corners() {
        let a = Aabb::new([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
        assert_eq!(a.min, [1.0, 2.0, 3.0]);
        assert_eq!(a.max, [4.0, 5.0, 6.0]);
    }

    #[test]
    fn aabb_point_is_degenerate() {
        let p = [3.0, 3.0, 3.0];
        let a = Aabb::point(p);
        assert_eq!(a.min, p);
        assert_eq!(a.max, p);
    }

    #[test]
    fn aabb_merge_covers_both() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let b = Aabb::new([2.0, 2.0, 2.0], [3.0, 3.0, 3.0]);
        let m = Aabb::merge(&a, &b);
        assert_eq!(m.min, [0.0, 0.0, 0.0]);
        assert_eq!(m.max, [3.0, 3.0, 3.0]);
    }

    #[test]
    fn aabb_intersects_overlapping() {
        let a = Aabb::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0]);
        let b = Aabb::new([1.0, 1.0, 1.0], [3.0, 3.0, 3.0]);
        assert!(a.intersects(&b));
    }

    #[test]
    fn aabb_intersects_disjoint() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let b = Aabb::new([2.0, 2.0, 2.0], [3.0, 3.0, 3.0]);
        assert!(!a.intersects(&b));
    }

    #[test]
    fn aabb_intersects_touching_edge() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let b = Aabb::new([1.0, 0.0, 0.0], [2.0, 1.0, 1.0]);
        assert!(a.intersects(&b));
    }

    #[test]
    fn aabb_contains_inside() {
        let a = Aabb::new([0.0, 0.0, 0.0], [4.0, 4.0, 4.0]);
        assert!(a.contains([2.0, 2.0, 2.0]));
    }

    #[test]
    fn aabb_contains_outside() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        assert!(!a.contains([2.0, 0.0, 0.0]));
    }

    #[test]
    fn aabb_contains_on_surface() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        assert!(a.contains([1.0, 0.5, 0.5]));
    }

    #[test]
    fn aabb_surface_area_unit_cube() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        assert!((a.surface_area() - 6.0).abs() < 1e-6);
    }

    #[test]
    fn aabb_surface_area_flat() {
        // 2×3×0 slab: area = 2*(2*3 + 3*0 + 0*2) = 12
        let a = Aabb::new([0.0, 0.0, 0.0], [2.0, 3.0, 0.0]);
        assert!((a.surface_area() - 12.0).abs() < 1e-6);
    }

    #[test]
    fn aabb_center_correct() {
        let a = Aabb::new([0.0, 0.0, 0.0], [2.0, 4.0, 6.0]);
        let c = a.center();
        assert!((c[0] - 1.0).abs() < 1e-6);
        assert!((c[1] - 2.0).abs() < 1e-6);
        assert!((c[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn aabb_expand_increases_bounds() {
        let a = Aabb::new([1.0, 1.0, 1.0], [2.0, 2.0, 2.0]);
        let e = a.expand(0.5);
        assert_eq!(e.min, [0.5, 0.5, 0.5]);
        assert_eq!(e.max, [2.5, 2.5, 2.5]);
    }

    // ------------------------------------------------------------------
    // SAH cost
    // ------------------------------------------------------------------

    #[test]
    fn sah_cost_balanced() {
        // Each half has 4 primitives and half the parent's area, so
        // cost = (1/2)*4 + (1/2)*4 = 4, i.e. (n_left + n_right) / 2.
        let cost = sah_cost(4, 1.0, 4, 1.0, 2.0);
        assert!((cost - 4.0).abs() < 1e-6);
    }

    #[test]
    fn sah_cost_zero_parent_area_returns_max() {
        let cost = sah_cost(1, 1.0, 1, 1.0, 0.0);
        assert_eq!(cost, f32::MAX);
    }

    // ------------------------------------------------------------------
    // Ray–AABB slab intersection
    // ------------------------------------------------------------------

    #[test]
    fn ray_hits_unit_cube() {
        let aabb = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let origin = [-1.0, 0.5, 0.5];
        let dir = [1.0, 0.0, 0.0];
        let inv = [1.0 / dir[0], 1.0 / dir[1], 1.0 / dir[2]];
        assert!(ray_aabb_intersect(origin, inv, &aabb, 10.0));
    }

    #[test]
    fn ray_misses_unit_cube() {
        let aabb = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let origin = [-1.0, 2.0, 0.5];
        let dir = [1.0, 0.0, 0.0];
        let inv = [1.0 / dir[0], 1.0 / dir[1], 1.0 / dir[2]];
        assert!(!ray_aabb_intersect(origin, inv, &aabb, 10.0));
    }

    #[test]
    fn ray_too_short_misses() {
        let aabb = Aabb::new([5.0, 0.0, 0.0], [6.0, 1.0, 1.0]);
        let origin = [0.0, 0.5, 0.5];
        let dir = [1.0, 0.0, 0.0];
        let inv = [1.0 / dir[0], 1.0 / dir[1], 1.0 / dir[2]];
        assert!(!ray_aabb_intersect(origin, inv, &aabb, 3.0));
    }

    // ------------------------------------------------------------------
    // Bvh build / query
    // ------------------------------------------------------------------

    /// `n` unit cubes laid side by side along +X; cube `i` spans x ∈ [i, i+1].
    fn make_grid_primitives(n: usize) -> Vec<BvhPrimitive> {
        (0..n)
            .map(|i| {
                let x = i as f32;
                BvhPrimitive::new(Aabb::new([x, 0.0, 0.0], [x + 1.0, 1.0, 1.0]), i)
            })
            .collect()
    }

    #[test]
    fn bvh_build_empty() {
        let bvh = Bvh::build(vec![]);
        assert!(bvh.root.is_none());
        assert_eq!(bvh.node_count(), 0);
        assert_eq!(bvh.depth(), 0);
    }

    #[test]
    fn bvh_build_single() {
        let prims = make_grid_primitives(1);
        let bvh = Bvh::build(prims);
        assert!(bvh.root.is_some());
        assert!(bvh.root.as_ref().unwrap().is_leaf());
        assert_eq!(bvh.node_count(), 1);
        assert_eq!(bvh.depth(), 1);
    }

    #[test]
    fn bvh_query_aabb_finds_overlap() {
        let prims = make_grid_primitives(10);
        let bvh = Bvh::build(prims);
        // Query the box that overlaps object 5 only.
        let query = Aabb::new([5.1, 0.1, 0.1], [5.9, 0.9, 0.9]);
        let mut hits = bvh.query_aabb(&query);
        hits.sort();
        assert_eq!(hits, vec![5]);
    }

    #[test]
    fn bvh_query_aabb_empty_result() {
        let prims = make_grid_primitives(5);
        let bvh = Bvh::build(prims);
        let query = Aabb::new([100.0, 0.0, 0.0], [101.0, 1.0, 1.0]);
        assert!(bvh.query_aabb(&query).is_empty());
    }

    #[test]
    fn bvh_query_aabb_finds_multiple() {
        let prims = make_grid_primitives(10);
        let bvh = Bvh::build(prims);
        // Query spanning objects 2, 3, 4
        let query = Aabb::new([2.1, 0.1, 0.1], [4.9, 0.9, 0.9]);
        let mut hits = bvh.query_aabb(&query);
        hits.sort();
        assert_eq!(hits, vec![2, 3, 4]);
    }

    #[test]
    fn bvh_query_ray_hits() {
        let prims = make_grid_primitives(8);
        let bvh = Bvh::build(prims);
        // Ray along X axis at y=0.5, z=0.5 hits all 8 primitives.
        let mut hits = bvh.query_ray([-1.0, 0.5, 0.5], [1.0, 0.0, 0.0], 20.0);
        hits.sort();
        assert_eq!(hits, (0..8).collect::<Vec<_>>());
    }

    #[test]
    fn bvh_query_ray_misses() {
        let prims = make_grid_primitives(5);
        let bvh = Bvh::build(prims);
        let hits = bvh.query_ray([0.5, 10.0, 0.5], [0.0, 1.0, 0.0], 100.0);
        assert!(hits.is_empty());
    }

    #[test]
    fn bvh_node_count_and_depth_consistent() {
        let prims = make_grid_primitives(16);
        let bvh = Bvh::build(prims);
        // With LEAF_SIZE=4 and 16 prims, depth should be at least 2.
        assert!(bvh.depth() >= 2);
        // An n-primitive tree has at most 2n-1 nodes.
        assert!(bvh.node_count() < 2 * 16);
    }

    // ------------------------------------------------------------------
    // Flat BVH
    // ------------------------------------------------------------------

    #[test]
    fn flatten_empty_bvh() {
        let bvh = Bvh::build(vec![]);
        let (nodes, prim_indices) = flatten(&bvh);
        assert!(nodes.is_empty());
        assert!(prim_indices.is_empty());
    }

    #[test]
    fn flatten_single_primitive() {
        let prims = make_grid_primitives(1);
        let bvh = Bvh::build(prims);
        let (nodes, prim_indices) = flatten(&bvh);
        assert_eq!(nodes.len(), 1);
        assert_eq!(prim_indices.len(), 1);
        assert_eq!(nodes[0].count, 1);
    }

    #[test]
    fn query_flat_finds_overlap() {
        let prims = make_grid_primitives(10);
        let bvh = Bvh::build(prims);
        let (nodes, prim_indices) = flatten(&bvh);
        let query = Aabb::new([3.1, 0.1, 0.1], [3.9, 0.9, 0.9]);
        let mut hits = query_flat(&nodes, &prim_indices, &bvh.primitives, &query);
        hits.sort();
        assert_eq!(hits, vec![3]);
    }

    #[test]
    fn query_flat_empty_result() {
        let prims = make_grid_primitives(5);
        let bvh = Bvh::build(prims);
        let (nodes, prim_indices) = flatten(&bvh);
        let query = Aabb::new([50.0, 0.0, 0.0], [51.0, 1.0, 1.0]);
        assert!(query_flat(&nodes, &prim_indices, &bvh.primitives, &query).is_empty());
    }

    #[test]
    fn query_flat_matches_recursive() {
        let prims = make_grid_primitives(20);
        let bvh = Bvh::build(prims);
        let query = Aabb::new([7.1, 0.0, 0.0], [12.9, 1.0, 1.0]);
        let mut recursive_hits = bvh.query_aabb(&query);
        recursive_hits.sort();

        let (nodes, prim_indices) = flatten(&bvh);
        let mut flat_hits = query_flat(&nodes, &prim_indices, &bvh.primitives, &query);
        flat_hits.sort();

        assert_eq!(recursive_hits, flat_hits);
    }

    // ------------------------------------------------------------------
    // Morton code tests
    // ------------------------------------------------------------------

    #[test]
    fn morton_origin_is_zero() {
        assert_eq!(morton_code([0.0, 0.0, 0.0]), 0);
    }

    #[test]
    fn morton_increases_along_x() {
        let m0 = morton_code([0.0, 0.0, 0.0]);
        let m1 = morton_code([0.5, 0.0, 0.0]);
        let m2 = morton_code([1.0, 0.0, 0.0]);
        // With all y,z=0 the Morton code grows with x.
        // (bits interleaved so we check non-decreasing)
        assert!(m0 <= m1, "m0={} m1={}", m0, m1);
        assert!(m1 <= m2, "m1={} m2={}", m1, m2);
    }

    #[test]
    fn morton_clamps_outside_unit_cube() {
        let m_neg = morton_code([-1.0, -1.0, -1.0]);
        let m_zero = morton_code([0.0, 0.0, 0.0]);
        assert_eq!(m_neg, m_zero);

        let m_big = morton_code([2.0, 2.0, 2.0]);
        let m_one = morton_code([1.0, 1.0, 1.0]);
        assert_eq!(m_big, m_one);
    }

    // ------------------------------------------------------------------
    // LBVH construction tests
    // ------------------------------------------------------------------

    #[test]
    fn lbvh_build_empty() {
        let bvh = lbvh_build(vec![]);
        assert!(bvh.root.is_none());
    }

    #[test]
    fn lbvh_build_single() {
        let prims = make_grid_primitives(1);
        let bvh = lbvh_build(prims);
        assert!(bvh.root.is_some());
        assert!(bvh.root.as_ref().unwrap().is_leaf());
    }

    #[test]
    fn lbvh_build_query_finds_correct_objects() {
        let prims = make_grid_primitives(10);
        let bvh = lbvh_build(prims);
        let query = Aabb::new([4.1, 0.1, 0.1], [4.9, 0.9, 0.9]);
        let mut hits = bvh.query_aabb(&query);
        hits.sort();
        assert_eq!(hits, vec![4]);
    }

    #[test]
    fn lbvh_build_covers_all_primitives() {
        let prims = make_grid_primitives(8);
        let bvh = lbvh_build(prims);
        // Root AABB should contain all primitives.
        let root = bvh.root.as_ref().unwrap();
        assert!(root.aabb.min[0] <= 0.0);
        assert!(root.aabb.max[0] >= 8.0);
    }

    // ------------------------------------------------------------------
    // BVH closest-hit traversal
    // ------------------------------------------------------------------

    #[test]
    fn closest_hit_returns_nearest() {
        let prims = make_grid_primitives(10);
        let bvh = Bvh::build(prims);
        // Ray along X from x=-1: should hit object 0 first (x ∈ [0,1])
        let hit = bvh_closest_hit(&bvh, [-1.0, 0.5, 0.5], [1.0, 0.0, 0.0], 100.0);
        assert!(hit.is_some(), "ray should hit something");
        let hit = hit.unwrap();
        assert_eq!(
            hit.object_id, 0,
            "closest hit should be object 0, got {}",
            hit.object_id
        );
    }

    #[test]
    fn closest_hit_misses_returns_none() {
        let prims = make_grid_primitives(5);
        let bvh = Bvh::build(prims);
        let hit = bvh_closest_hit(&bvh, [0.5, 10.0, 0.5], [0.0, 1.0, 0.0], 100.0);
        assert!(hit.is_none());
    }

    #[test]
    fn closest_hit_empty_bvh_returns_none() {
        let bvh = Bvh::build(vec![]);
        let hit = bvh_closest_hit(&bvh, [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], 100.0);
        assert!(hit.is_none());
    }

    #[test]
    fn closest_hit_t_is_positive() {
        let prims = make_grid_primitives(5);
        let bvh = Bvh::build(prims);
        // The ray starts outside the grid and points straight at it, so a hit is
        // required. Asserting that keeps the `t` check from passing vacuously.
        let hit = bvh_closest_hit(&bvh, [-1.0, 0.5, 0.5], [1.0, 0.0, 0.0], 100.0)
            .expect("ray aimed at the grid should hit a primitive");
        assert!(hit.t >= 0.0, "t should be non-negative, got {}", hit.t);
    }

    // ------------------------------------------------------------------
    // BVH refit tests
    // ------------------------------------------------------------------

    #[test]
    fn refit_preserves_topology() {
        let prims = make_grid_primitives(8);
        let mut bvh = Bvh::build(prims);
        let before_count = bvh.node_count();
        if let Some(root) = bvh.root.as_mut() {
            refit(root, &bvh.primitives);
        }
        assert_eq!(
            bvh.node_count(),
            before_count,
            "refit should not change node count"
        );
    }

    #[test]
    fn refit_root_aabb_covers_all() {
        let prims = make_grid_primitives(8);
        let mut bvh = Bvh::build(prims);
        if let Some(root) = bvh.root.as_mut() {
            refit(root, &bvh.primitives);
        }
        let root = bvh.root.as_ref().unwrap();
        assert!(root.aabb.min[0] <= 0.0 + 1e-5);
        assert!(root.aabb.max[0] >= 8.0 - 1e-5);
    }

    // ------------------------------------------------------------------
    // HLBVH split tests
    // ------------------------------------------------------------------

    #[test]
    fn hlbvh_split_two_distinct_values() {
        let mortons = vec![0u32, 1u32];
        let split = hlbvh_split(&mortons);
        assert_eq!(split, 1);
    }

    #[test]
    fn hlbvh_split_returns_valid_index() {
        let mortons: Vec<u32> = (0..16).map(|i| i * 64).collect();
        let split = hlbvh_split(&mortons);
        assert!(split > 0 && split < mortons.len(), "split={}", split);
    }

    #[test]
    fn hlbvh_split_equal_values_returns_one() {
        let mortons = vec![5u32; 8];
        let split = hlbvh_split(&mortons);
        // All codes are equal, so there is no differing bit to split on. The
        // split must still be non-degenerate (both halves non-empty) so the
        // recursive builders terminate.
        assert!(split >= 1 && split < mortons.len());
    }

    // ------------------------------------------------------------------
    // BVH statistics tests
    // ------------------------------------------------------------------

    #[test]
    fn bvh_stats_empty() {
        let bvh = Bvh::build(vec![]);
        let s = BvhStats::compute(&bvh);
        assert_eq!(s.node_count, 0);
        assert_eq!(s.leaf_count, 0);
        assert_eq!(s.total_primitives, 0);
    }

    #[test]
    fn bvh_stats_single_primitive() {
        let prims = make_grid_primitives(1);
        let bvh = Bvh::build(prims);
        let s = BvhStats::compute(&bvh);
        assert_eq!(s.node_count, 1);
        assert_eq!(s.leaf_count, 1);
        assert_eq!(s.total_primitives, 1);
        assert_eq!(s.max_depth, 1);
    }

    #[test]
    fn bvh_stats_node_count_consistent() {
        let prims = make_grid_primitives(16);
        let bvh = Bvh::build(prims.clone());
        let s = BvhStats::compute(&bvh);
        assert_eq!(s.node_count, bvh.node_count());
        assert_eq!(s.leaf_count + s.internal_count, s.node_count);
        assert_eq!(s.total_primitives, prims.len());
    }

    #[test]
    fn bvh_stats_avg_primitives_per_leaf() {
        let prims = make_grid_primitives(8);
        let bvh = Bvh::build(prims);
        let s = BvhStats::compute(&bvh);
        assert!(s.avg_primitives_per_leaf > 0.0);
        // avg should not exceed LEAF_SIZE + 1
        assert!(s.avg_primitives_per_leaf <= (LEAF_SIZE + 1) as f32);
    }

    #[test]
    fn bvh_stats_max_depth_reasonable() {
        let prims = make_grid_primitives(32);
        let bvh = Bvh::build(prims);
        let s = BvhStats::compute(&bvh);
        // Depth should be at most log2(32/LEAF_SIZE) + 1 = ~4
        assert!(
            s.max_depth >= 1 && s.max_depth <= 20,
            "depth={}",
            s.max_depth
        );
    }

    // ------------------------------------------------------------------
    // LbvhPrimitive Morton code assignment
    // ------------------------------------------------------------------

    #[test]
    fn lbvh_primitive_morton_in_range() {
        let aabb = Aabb::new([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]);
        let scene = Aabb::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0]);
        let lp = LbvhPrimitive::new(aabb, 0, &scene);
        // Morton code is 30-bit so max is (1<<30)-1
        assert!(lp.morton < (1u32 << 30));
    }

    #[test]
    fn lbvh_primitive_at_origin_small_code() {
        let aabb = Aabb::point([0.0, 0.0, 0.0]);
        let scene = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let lp = LbvhPrimitive::new(aabb, 0, &scene);
        assert_eq!(lp.morton, 0);
    }

    // ------------------------------------------------------------------
    // compute_bvh_from_sorted tests
    // ------------------------------------------------------------------

    /// `n` unit cubes along +X, wrapped as `LbvhPrimitive`s and sorted by Morton code.
    fn make_sorted_lbvh_prims(n: usize) -> Vec<LbvhPrimitive> {
        let scene = Aabb::new([0.0, 0.0, 0.0], [n as f32 + 1.0, 1.0, 1.0]);
        let mut prims: Vec<LbvhPrimitive> = (0..n)
            .map(|i| {
                let x = i as f32;
                LbvhPrimitive::new(Aabb::new([x, 0.0, 0.0], [x + 1.0, 1.0, 1.0]), i, &scene)
            })
            .collect();
        prims.sort_unstable_by_key(|lp| lp.morton);
        prims
    }

    #[test]
    fn compute_bvh_from_sorted_empty() {
        let bvh = compute_bvh_from_sorted(&[]);
        assert!(bvh.root.is_none());
        assert_eq!(bvh.primitives.len(), 0);
    }

    #[test]
    fn compute_bvh_from_sorted_single() {
        let scene = Aabb::new([0.0, 0.0, 0.0], [2.0, 1.0, 1.0]);
        let lp = LbvhPrimitive::new(Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), 7, &scene);
        let bvh = compute_bvh_from_sorted(&[lp]);
        assert!(bvh.root.is_some());
        assert_eq!(bvh.primitives.len(), 1);
    }

    #[test]
    fn compute_bvh_from_sorted_preserves_count() {
        let sorted = make_sorted_lbvh_prims(16);
        let bvh = compute_bvh_from_sorted(&sorted);
        assert_eq!(bvh.primitives.len(), 16);
    }

    #[test]
    fn compute_bvh_from_sorted_root_covers_all() {
        let sorted = make_sorted_lbvh_prims(8);
        let bvh = compute_bvh_from_sorted(&sorted);
        let root_aabb = &bvh.root.as_ref().unwrap().aabb;
        assert!(root_aabb.min[0] <= 0.0 + 1e-5);
        assert!(root_aabb.max[0] >= 8.0 - 1e-5);
    }

    // ------------------------------------------------------------------
    // compute_cluster_radius tests
    // ------------------------------------------------------------------

    #[test]
    fn compute_cluster_radius_empty() {
        let r = compute_cluster_radius(&[]);
        assert_eq!(r, 0.0);
    }

    #[test]
    fn compute_cluster_radius_single() {
        let scene = Aabb::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0]);
        let lp = LbvhPrimitive::new(Aabb::point([1.0, 1.0, 1.0]), 0, &scene);
        let r = compute_cluster_radius(&[lp]);
        // Single point cluster: radius = 0
        assert!(
            r < 1e-6,
            "single-point cluster radius should be ~0, got {r}"
        );
    }

    #[test]
    fn compute_cluster_radius_two_points() {
        let scene = Aabb::new([0.0, 0.0, 0.0], [4.0, 1.0, 1.0]);
        let lp0 = LbvhPrimitive::new(Aabb::point([0.0, 0.0, 0.0]), 0, &scene);
        let lp1 = LbvhPrimitive::new(Aabb::point([2.0, 0.0, 0.0]), 1, &scene);
        let r = compute_cluster_radius(&[lp0, lp1]);
        // Centroid is (1,0,0); each point is at distance 1.
        assert!((r - 1.0).abs() < 1e-5, "radius should be 1.0, got {r}");
    }

    #[test]
    fn compute_cluster_radius_is_non_negative() {
        let sorted = make_sorted_lbvh_prims(12);
        let r = compute_cluster_radius(&sorted);
        assert!(r >= 0.0, "radius must be non-negative, got {r}");
    }

    // ------------------------------------------------------------------
    // BvhTreeStatistics tests
    // ------------------------------------------------------------------

    #[test]
    fn bvh_tree_stats_empty() {
        let bvh = Bvh::build(vec![]);
        let s = BvhTreeStatistics::compute(&bvh);
        assert_eq!(s.node_count, 0);
        assert_eq!(s.leaf_count, 0);
        assert_eq!(s.internal_count, 0);
        assert_eq!(s.total_primitives, 0);
    }

    #[test]
    fn bvh_tree_stats_fanout_binary() {
        let prims = make_grid_primitives(16);
        let bvh = Bvh::build(prims);
        let s = BvhTreeStatistics::compute(&bvh);
        // Binary tree: avg fanout must be <= 2.0
        assert!(s.avg_fanout <= 2.0 + 1e-6, "fanout = {}", s.avg_fanout);
    }

    #[test]
    fn bvh_tree_stats_node_count_consistent() {
        let prims = make_grid_primitives(16);
        let bvh = Bvh::build(prims.clone());
        let s = BvhTreeStatistics::compute(&bvh);
        assert_eq!(s.leaf_count + s.internal_count, s.node_count);
        assert_eq!(s.total_primitives, prims.len());
    }

    #[test]
    fn bvh_tree_stats_leaf_surface_area_positive() {
        let prims = make_grid_primitives(8);
        let bvh = Bvh::build(prims);
        let s = BvhTreeStatistics::compute(&bvh);
        assert!(
            s.total_leaf_surface_area > 0.0,
            "leaf surface area should be > 0"
        );
    }

    // ------------------------------------------------------------------
    // build_morton_clusters tests
    // ------------------------------------------------------------------

    #[test]
    fn build_morton_clusters_empty() {
        let clusters = build_morton_clusters(&[], 4);
        assert!(clusters.is_empty());
    }

    #[test]
    fn build_morton_clusters_count() {
        let sorted = make_sorted_lbvh_prims(10);
        let clusters = build_morton_clusters(&sorted, 3);
        // 10 primitives in chunks of 3 → ceil(10/3) = 4 clusters
        assert_eq!(clusters.len(), 4);
    }

    #[test]
    fn build_morton_clusters_radii_non_negative() {
        let sorted = make_sorted_lbvh_prims(8);
        let clusters = build_morton_clusters(&sorted, 2);
        for c in &clusters {
            assert!(c.radius >= 0.0, "cluster radius must be non-negative");
        }
    }
}
1740}