Skip to main content

oxiphysics_gpu/bvh/
cpu.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! CPU BVH construction, query, traversal, and LBVH helpers.
5
6use super::types::{
7    Aabb, BvhNode, BvhPrimitive, BvhStats, BvhTreeStatistics, FlatBvhNode, LbvhPrimitive,
8    MortonCluster, RayHit,
9};
10
11// ============================================================================
12// SAH helper
13// ============================================================================
14
/// Evaluate the Surface Area Heuristic for a candidate split.
///
/// Computes `(sa_left / sa_parent) * n_left + (sa_right / sa_parent) * n_right`:
/// each child's primitive count weighted by the fraction of the parent's surface
/// area that child covers. A non-positive `sa_parent` yields `f32::MAX`, so such
/// a split is never preferred.
pub fn sah_cost(n_left: usize, sa_left: f32, n_right: usize, sa_right: f32, sa_parent: f32) -> f32 {
    if sa_parent <= 0.0 {
        return f32::MAX;
    }
    let weighted = |area: f32, count: usize| (area / sa_parent) * count as f32;
    weighted(sa_left, n_left) + weighted(sa_right, n_right)
}
23
24// ============================================================================
25// Slab-method ray–AABB intersection
26// ============================================================================
27
28/// Test whether a ray defined by `origin` + t * direction intersects `aabb`
29/// within `[0, max_t]`.
30///
31/// `inv_dir` must be the component-wise reciprocal of the ray direction.
32pub fn ray_aabb_intersect(origin: [f32; 3], inv_dir: [f32; 3], aabb: &Aabb, max_t: f32) -> bool {
33    let mut t_min = 0.0_f32;
34    let mut t_max = max_t;
35
36    for i in 0..3 {
37        let t1 = (aabb.min[i] - origin[i]) * inv_dir[i];
38        let t2 = (aabb.max[i] - origin[i]) * inv_dir[i];
39        let lo = t1.min(t2);
40        let hi = t1.max(t2);
41        t_min = t_min.max(lo);
42        t_max = t_max.min(hi);
43    }
44
45    t_min <= t_max
46}
47
48/// Ray–AABB slab intersection returning the near/far t values.
49pub(crate) fn ray_aabb_t(origin: [f32; 3], inv_dir: [f32; 3], aabb: &Aabb) -> Option<(f32, f32)> {
50    let mut t_min = 0.0_f32;
51    let mut t_max = f32::MAX;
52    for i in 0..3 {
53        let t1 = (aabb.min[i] - origin[i]) * inv_dir[i];
54        let t2 = (aabb.max[i] - origin[i]) * inv_dir[i];
55        t_min = t_min.max(t1.min(t2));
56        t_max = t_max.min(t1.max(t2));
57    }
58    if t_min <= t_max {
59        Some((t_min, t_max))
60    } else {
61        None
62    }
63}
64
65// ============================================================================
66// Bvh
67// ============================================================================
68
/// Maximum number of primitives per leaf before the tree stops splitting.
///
/// The recursive builders in this module turn any node holding at most this
/// many primitives into a leaf.
pub(crate) const LEAF_SIZE: usize = 4;

/// A BVH tree built from a flat list of [`BvhPrimitive`]s.
///
/// Leaf nodes refer to primitives by their index into `primitives`.
pub struct Bvh {
    /// Tree root, or `None` when the tree was built from zero primitives.
    pub root: Option<BvhNode>,
    /// The primitives the tree was built over (e.g. those passed to
    /// [`Bvh::build`]); leaf primitive lists index into this vector.
    pub primitives: Vec<BvhPrimitive>,
}
79
80impl Bvh {
81    /// Build a BVH from a list of primitives using a median-split strategy
82    /// guided by the longest axis (SAH-inspired).
83    pub fn build(primitives: Vec<BvhPrimitive>) -> Self {
84        if primitives.is_empty() {
85            return Self {
86                root: None,
87                primitives,
88            };
89        }
90        let indices: Vec<usize> = (0..primitives.len()).collect();
91        let root = build_recursive(&primitives, indices);
92        Self {
93            root: Some(root),
94            primitives,
95        }
96    }
97
98    /// Return the `object_id`s of all primitives whose AABB overlaps `query`.
99    pub fn query_aabb(&self, query: &Aabb) -> Vec<usize> {
100        let mut result = Vec::new();
101        if let Some(root) = &self.root {
102            query_aabb_recursive(root, query, &self.primitives, &mut result);
103        }
104        result
105    }
106
107    /// Return the `object_id`s of all primitives hit by the given ray within
108    /// distance `max_t`.
109    pub fn query_ray(&self, origin: [f32; 3], direction: [f32; 3], max_t: f32) -> Vec<usize> {
110        let inv_dir = [1.0 / direction[0], 1.0 / direction[1], 1.0 / direction[2]];
111        let mut result = Vec::new();
112        if let Some(root) = &self.root {
113            query_ray_recursive(root, origin, inv_dir, max_t, &self.primitives, &mut result);
114        }
115        result
116    }
117
118    /// Total number of nodes in the tree.
119    pub fn node_count(&self) -> usize {
120        match &self.root {
121            None => 0,
122            Some(root) => count_nodes(root),
123        }
124    }
125
126    /// Maximum depth of the tree (root = depth 1).
127    pub fn depth(&self) -> usize {
128        match &self.root {
129            None => 0,
130            Some(root) => node_depth(root),
131        }
132    }
133}
134
135// ============================================================================
136// Internal build / query helpers
137// ============================================================================
138
139pub(crate) fn bounding_box(primitives: &[BvhPrimitive], indices: &[usize]) -> Aabb {
140    let mut aabb = primitives[indices[0]].aabb.clone();
141    for &i in &indices[1..] {
142        aabb = Aabb::merge(&aabb, &primitives[i].aabb);
143    }
144    aabb
145}
146
147fn build_recursive(primitives: &[BvhPrimitive], mut indices: Vec<usize>) -> BvhNode {
148    let aabb = bounding_box(primitives, &indices);
149
150    if indices.len() <= LEAF_SIZE {
151        return BvhNode {
152            aabb,
153            left: None,
154            right: None,
155            primitives: indices,
156        };
157    }
158
159    // Choose longest axis for split.
160    let dx = aabb.max[0] - aabb.min[0];
161    let dy = aabb.max[1] - aabb.min[1];
162    let dz = aabb.max[2] - aabb.min[2];
163    let axis = if dx >= dy && dx >= dz {
164        0
165    } else if dy >= dz {
166        1
167    } else {
168        2
169    };
170
171    // Median split by centre of primitive AABB along the chosen axis.
172    indices.sort_unstable_by(|&a, &b| {
173        let ca = primitives[a].aabb.center()[axis];
174        let cb = primitives[b].aabb.center()[axis];
175        ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
176    });
177
178    let mid = indices.len() / 2;
179    let right_indices = indices.split_off(mid);
180    let left_indices = indices;
181
182    let left = build_recursive(primitives, left_indices);
183    let right = build_recursive(primitives, right_indices);
184
185    BvhNode {
186        aabb,
187        left: Some(Box::new(left)),
188        right: Some(Box::new(right)),
189        primitives: Vec::new(),
190    }
191}
192
193fn query_aabb_recursive(
194    node: &BvhNode,
195    query: &Aabb,
196    primitives: &[BvhPrimitive],
197    result: &mut Vec<usize>,
198) {
199    if !node.aabb.intersects(query) {
200        return;
201    }
202    if node.is_leaf() {
203        for &idx in &node.primitives {
204            if primitives[idx].aabb.intersects(query) {
205                result.push(primitives[idx].object_id);
206            }
207        }
208    } else {
209        if let Some(left) = &node.left {
210            query_aabb_recursive(left, query, primitives, result);
211        }
212        if let Some(right) = &node.right {
213            query_aabb_recursive(right, query, primitives, result);
214        }
215    }
216}
217
218fn query_ray_recursive(
219    node: &BvhNode,
220    origin: [f32; 3],
221    inv_dir: [f32; 3],
222    max_t: f32,
223    primitives: &[BvhPrimitive],
224    result: &mut Vec<usize>,
225) {
226    if !ray_aabb_intersect(origin, inv_dir, &node.aabb, max_t) {
227        return;
228    }
229    if node.is_leaf() {
230        for &idx in &node.primitives {
231            if ray_aabb_intersect(origin, inv_dir, &primitives[idx].aabb, max_t) {
232                result.push(primitives[idx].object_id);
233            }
234        }
235    } else {
236        if let Some(left) = &node.left {
237            query_ray_recursive(left, origin, inv_dir, max_t, primitives, result);
238        }
239        if let Some(right) = &node.right {
240            query_ray_recursive(right, origin, inv_dir, max_t, primitives, result);
241        }
242    }
243}
244
245fn count_nodes(node: &BvhNode) -> usize {
246    1 + node.left.as_ref().map_or(0, |n| count_nodes(n))
247        + node.right.as_ref().map_or(0, |n| count_nodes(n))
248}
249
250fn node_depth(node: &BvhNode) -> usize {
251    1 + node
252        .left
253        .as_ref()
254        .map_or(0, |n| node_depth(n))
255        .max(node.right.as_ref().map_or(0, |n| node_depth(n)))
256}
257
258// ============================================================================
259// Flat BVH
260// ============================================================================
261
262/// Flatten a [`Bvh`] into a `Vec<FlatBvhNode>` (DFS pre-order) together with
263/// a reordered primitive-index slice.
264///
265/// Returns `(flat_nodes, prim_indices)` where `prim_indices[i]` is an index
266/// into `bvh.primitives`.
267pub fn flatten(bvh: &Bvh) -> (Vec<FlatBvhNode>, Vec<usize>) {
268    let mut nodes: Vec<FlatBvhNode> = Vec::new();
269    let mut prim_indices: Vec<usize> = Vec::new();
270
271    if let Some(root) = &bvh.root {
272        flatten_recursive(root, &mut nodes, &mut prim_indices);
273    }
274
275    (nodes, prim_indices)
276}
277
278/// Returns the index at which `node` was stored.
279fn flatten_recursive(
280    node: &BvhNode,
281    nodes: &mut Vec<FlatBvhNode>,
282    prim_indices: &mut Vec<usize>,
283) -> usize {
284    let node_idx = nodes.len();
285
286    if node.is_leaf() {
287        let first = prim_indices.len() as u32;
288        let count = node.primitives.len() as u32;
289        prim_indices.extend_from_slice(&node.primitives);
290        nodes.push(FlatBvhNode {
291            aabb: node.aabb.clone(),
292            left_first: first,
293            count,
294        });
295    } else {
296        // Reserve a slot; left_first (right child index) filled after recursion.
297        nodes.push(FlatBvhNode {
298            aabb: node.aabb.clone(),
299            left_first: 0,
300            count: 0,
301        });
302        // Left child is always node_idx + 1 (no explicit storage needed).
303        if let Some(left) = &node.left {
304            flatten_recursive(left, nodes, prim_indices);
305        }
306        // Right child comes after the entire left subtree.
307        let right_idx = if let Some(right) = &node.right {
308            flatten_recursive(right, nodes, prim_indices)
309        } else {
310            0
311        };
312        nodes[node_idx].left_first = right_idx as u32;
313    }
314
315    node_idx
316}
317
318/// Iterative AABB query over a flat BVH.
319///
320/// Returns the `object_id` values of all primitives whose AABB overlaps
321/// `query`. `bvh_primitives` is the `Bvh::primitives` slice.
322pub fn query_flat(
323    nodes: &[FlatBvhNode],
324    prim_indices: &[usize],
325    bvh_primitives: &[BvhPrimitive],
326    query: &Aabb,
327) -> Vec<usize> {
328    let mut result = Vec::new();
329    if nodes.is_empty() {
330        return result;
331    }
332
333    let mut stack: Vec<usize> = Vec::with_capacity(64);
334    stack.push(0);
335
336    while let Some(idx) = stack.pop() {
337        let node = &nodes[idx];
338        if !node.aabb.intersects(query) {
339            continue;
340        }
341        if node.count > 0 {
342            // Leaf
343            let start = node.left_first as usize;
344            let end = start + node.count as usize;
345            for &pi in &prim_indices[start..end] {
346                if bvh_primitives[pi].aabb.intersects(query) {
347                    result.push(bvh_primitives[pi].object_id);
348                }
349            }
350        } else {
351            // Internal: left child is at idx+1, right child at left_first.
352            let right = node.left_first as usize;
353            stack.push(right);
354            stack.push(idx + 1);
355        }
356    }
357
358    result
359}
360
361// ============================================================================
362// BVH Traversal (closest hit)
363// ============================================================================
364
365/// Traverse the BVH returning the **closest** hit (smallest positive t).
366pub fn bvh_closest_hit(
367    bvh: &Bvh,
368    origin: [f32; 3],
369    direction: [f32; 3],
370    max_t: f32,
371) -> Option<RayHit> {
372    let inv_dir = [1.0 / direction[0], 1.0 / direction[1], 1.0 / direction[2]];
373    let root = bvh.root.as_ref()?;
374    let mut best: Option<RayHit> = None;
375    let mut current_max = max_t;
376    closest_hit_recursive(
377        root,
378        origin,
379        inv_dir,
380        &bvh.primitives,
381        &mut best,
382        &mut current_max,
383    );
384    best
385}
386
387fn closest_hit_recursive(
388    node: &BvhNode,
389    origin: [f32; 3],
390    inv_dir: [f32; 3],
391    primitives: &[BvhPrimitive],
392    best: &mut Option<RayHit>,
393    max_t: &mut f32,
394) {
395    if ray_aabb_t(origin, inv_dir, &node.aabb).is_none() {
396        return;
397    }
398    if node.is_leaf() {
399        for &idx in &node.primitives {
400            if let Some((t_min, _)) = ray_aabb_t(origin, inv_dir, &primitives[idx].aabb)
401                && t_min >= 0.0
402                && t_min < *max_t
403            {
404                *max_t = t_min;
405                *best = Some(RayHit {
406                    object_id: primitives[idx].object_id,
407                    t: t_min,
408                });
409            }
410        }
411    } else {
412        if let Some(left) = &node.left {
413            closest_hit_recursive(left, origin, inv_dir, primitives, best, max_t);
414        }
415        if let Some(right) = &node.right {
416            closest_hit_recursive(right, origin, inv_dir, primitives, best, max_t);
417        }
418    }
419}
420
421// ============================================================================
422// BVH Refit
423// ============================================================================
424
425/// Refit the bounding boxes of an existing BVH after primitives have moved.
426///
427/// The topology (splits) are preserved; only bounding boxes are recomputed.
428pub fn refit(node: &mut BvhNode, primitives: &[BvhPrimitive]) {
429    if node.is_leaf() {
430        if !node.primitives.is_empty() {
431            node.aabb = bounding_box(primitives, &node.primitives);
432        }
433        return;
434    }
435    if let Some(left) = node.left.as_mut() {
436        refit(left, primitives);
437    }
438    if let Some(right) = node.right.as_mut() {
439        refit(right, primitives);
440    }
441    // Recompute bounding box from children.
442    let left_aabb = node.left.as_ref().map(|n| n.aabb.clone());
443    let right_aabb = node.right.as_ref().map(|n| n.aabb.clone());
444    node.aabb = match (left_aabb, right_aabb) {
445        (Some(l), Some(r)) => Aabb::merge(&l, &r),
446        (Some(l), None) => l,
447        (None, Some(r)) => r,
448        (None, None) => node.aabb.clone(),
449    };
450}
451
452// ============================================================================
453// Morton code (LBVH)
454// ============================================================================
455
/// Spread the low 10 bits of `v` so that bit `i` lands on bit `3 * i`,
/// leaving two zero bits between consecutive input bits (30-bit result).
///
/// Intended for inputs below 1024; the shift/mask cascade is the usual
/// "magic number" bit-interleaving sequence used for 3D Morton codes.
fn expand_bits(v: u32) -> u32 {
    const STEPS: [(u32, u32); 4] = [
        (16, 0x030000FF),
        (8, 0x0300F00F),
        (4, 0x030C30C3),
        (2, 0x09249249),
    ];
    STEPS
        .iter()
        .fold(v, |acc, &(shift, mask)| (acc | (acc << shift)) & mask)
}
464
465/// Compute a 30-bit Morton code for a 3D point normalised to \[0, 1\]^3.
466pub fn morton_code(p: [f32; 3]) -> u32 {
467    let x = (p[0].clamp(0.0, 1.0) * 1023.0) as u32;
468    let y = (p[1].clamp(0.0, 1.0) * 1023.0) as u32;
469    let z = (p[2].clamp(0.0, 1.0) * 1023.0) as u32;
470    expand_bits(x) | (expand_bits(y) << 1) | (expand_bits(z) << 2)
471}
472
473impl LbvhPrimitive {
474    /// Construct an `LbvhPrimitive`, computing the Morton code from the centroid
475    /// normalised by `scene_aabb`.
476    pub fn new(aabb: Aabb, object_id: usize, scene_aabb: &Aabb) -> Self {
477        let c = aabb.center();
478        let scene_size = [
479            (scene_aabb.max[0] - scene_aabb.min[0]).max(1e-10),
480            (scene_aabb.max[1] - scene_aabb.min[1]).max(1e-10),
481            (scene_aabb.max[2] - scene_aabb.min[2]).max(1e-10),
482        ];
483        let norm = [
484            (c[0] - scene_aabb.min[0]) / scene_size[0],
485            (c[1] - scene_aabb.min[1]) / scene_size[1],
486            (c[2] - scene_aabb.min[2]) / scene_size[2],
487        ];
488        let morton = morton_code(norm);
489        Self {
490            aabb,
491            object_id,
492            morton,
493        }
494    }
495}
496
497/// Build an LBVH (Linear BVH) from a set of primitives using Morton-code
498/// sorting and a recursive binary splitting strategy.
499///
500/// Returns a standard [`Bvh`] so the same query functions can be used.
501pub fn lbvh_build(primitives: Vec<BvhPrimitive>) -> Bvh {
502    if primitives.is_empty() {
503        return Bvh {
504            root: None,
505            primitives,
506        };
507    }
508
509    // Compute scene bounding box.
510    let mut scene = primitives[0].aabb.clone();
511    for p in &primitives[1..] {
512        scene = Aabb::merge(&scene, &p.aabb);
513    }
514
515    // Assign Morton codes and sort.
516    let mut indexed: Vec<(u32, usize)> = primitives
517        .iter()
518        .enumerate()
519        .map(|(i, p)| {
520            let lp = LbvhPrimitive::new(p.aabb.clone(), p.object_id, &scene);
521            (lp.morton, i)
522        })
523        .collect();
524    indexed.sort_unstable_by_key(|&(m, _)| m);
525
526    let sorted_indices: Vec<usize> = indexed.iter().map(|&(_, i)| i).collect();
527    let root = lbvh_recursive(&primitives, &sorted_indices);
528
529    Bvh {
530        root: Some(root),
531        primitives,
532    }
533}
534
535fn lbvh_recursive(primitives: &[BvhPrimitive], indices: &[usize]) -> BvhNode {
536    let aabb = bounding_box(primitives, indices);
537
538    if indices.len() <= LEAF_SIZE {
539        return BvhNode {
540            aabb,
541            left: None,
542            right: None,
543            primitives: indices.to_vec(),
544        };
545    }
546
547    let mid = indices.len() / 2;
548    let left = lbvh_recursive(primitives, &indices[..mid]);
549    let right = lbvh_recursive(primitives, &indices[mid..]);
550
551    BvhNode {
552        aabb,
553        left: Some(Box::new(left)),
554        right: Some(Box::new(right)),
555        primitives: Vec::new(),
556    }
557}
558
559// ============================================================================
560// HLBVH Split
561// ============================================================================
562
/// Find the split index for a slice of Morton codes sorted in ascending order,
/// using the highest differing bit (HLBVH strategy).
///
/// The split is placed at the first element that has the highest bit in which
/// `mortons[0]` and the last element differ. When every code in the slice is
/// identical no such bit exists, and the slice is split at its midpoint so
/// that runs of duplicate codes still give a balanced tree.
///
/// For `mortons.len() >= 2` the result satisfies `0 < split < mortons.len()`.
/// Slices shorter than two elements return `1`, which is not an interior split;
/// callers should not split such slices.
pub fn hlbvh_split(mortons: &[u32]) -> usize {
    let n = mortons.len();
    if n < 2 {
        return 1;
    }
    let first = mortons[0];
    let last = mortons[n - 1];
    if first == last {
        // No differing bit. Splitting off a single element here would make the
        // tree depth grow linearly with the number of duplicate codes.
        return n / 2;
    }
    let common_prefix = (first ^ last).leading_zeros();
    // Invariant: mortons[lo] shares the split bit with `first`; mortons[hi]
    // does not. Binary search for the first element that has the bit set.
    let mut lo = 0usize;
    let mut hi = n - 1;
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        let prefix = (first ^ mortons[mid]).leading_zeros();
        if prefix > common_prefix {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    hi
}
588
589// ============================================================================
590// Morton cluster helpers
591// ============================================================================
592
593/// Build a flat BVH from a pre-sorted (by Morton code) slice of `LbvhPrimitive`s
594/// using a radix-sort-inspired clustering strategy.
595pub fn compute_bvh_from_sorted(sorted: &[LbvhPrimitive]) -> Bvh {
596    if sorted.is_empty() {
597        return Bvh {
598            root: None,
599            primitives: Vec::new(),
600        };
601    }
602
603    // Reconstruct BvhPrimitives in sorted order.
604    let primitives: Vec<BvhPrimitive> = sorted
605        .iter()
606        .map(|lp| BvhPrimitive::new(lp.aabb.clone(), lp.object_id))
607        .collect();
608
609    let mortons: Vec<u32> = sorted.iter().map(|lp| lp.morton).collect();
610    let indices: Vec<usize> = (0..primitives.len()).collect();
611    let root = bvh_from_sorted_recursive(&primitives, &indices, &mortons);
612    Bvh {
613        root: Some(root),
614        primitives,
615    }
616}
617
618fn bvh_from_sorted_recursive(
619    primitives: &[BvhPrimitive],
620    indices: &[usize],
621    mortons: &[u32],
622) -> BvhNode {
623    let aabb = bounding_box(primitives, indices);
624    if indices.len() <= LEAF_SIZE {
625        return BvhNode {
626            aabb,
627            left: None,
628            right: None,
629            primitives: indices.to_vec(),
630        };
631    }
632    // Use HLBVH-style split on Morton codes at corresponding positions.
633    let local_mortons: Vec<u32> = indices.iter().map(|&i| mortons[i]).collect();
634    let split = hlbvh_split(&local_mortons);
635    let left = bvh_from_sorted_recursive(primitives, &indices[..split], mortons);
636    let right = bvh_from_sorted_recursive(primitives, &indices[split..], mortons);
637    BvhNode {
638        aabb,
639        left: Some(Box::new(left)),
640        right: Some(Box::new(right)),
641        primitives: Vec::new(),
642    }
643}
644
645/// Compute the bounding sphere radius for a cluster of `LbvhPrimitive`s.
646pub fn compute_cluster_radius(cluster: &[LbvhPrimitive]) -> f32 {
647    if cluster.is_empty() {
648        return 0.0;
649    }
650    // Compute merged AABB.
651    let mut merged = cluster[0].aabb.clone();
652    for lp in &cluster[1..] {
653        merged = Aabb::merge(&merged, &lp.aabb);
654    }
655    let cx = (merged.min[0] + merged.max[0]) * 0.5;
656    let cy = (merged.min[1] + merged.max[1]) * 0.5;
657    let cz = (merged.min[2] + merged.max[2]) * 0.5;
658
659    let mut max_dist_sq = 0.0_f32;
660    for lp in cluster {
661        let c = lp.aabb.center();
662        let dx = c[0] - cx;
663        let dy = c[1] - cy;
664        let dz = c[2] - cz;
665        let d2 = dx * dx + dy * dy + dz * dz;
666        if d2 > max_dist_sq {
667            max_dist_sq = d2;
668        }
669    }
670    max_dist_sq.sqrt()
671}
672
673/// Compute clusters by grouping Morton-sorted `LbvhPrimitive`s into chunks of
674/// `cluster_size` and returning a `MortonCluster` per group.
675pub fn build_morton_clusters(sorted: &[LbvhPrimitive], cluster_size: usize) -> Vec<MortonCluster> {
676    if sorted.is_empty() || cluster_size == 0 {
677        return Vec::new();
678    }
679    sorted
680        .chunks(cluster_size)
681        .map(|chunk| {
682            let indices: Vec<usize> = (0..chunk.len()).collect();
683            let mut aabb = chunk[0].aabb.clone();
684            for lp in &chunk[1..] {
685                aabb = Aabb::merge(&aabb, &lp.aabb);
686            }
687            let radius = compute_cluster_radius(chunk);
688            MortonCluster {
689                indices,
690                aabb,
691                radius,
692            }
693        })
694        .collect()
695}
696
697// ============================================================================
698// BVH Statistics
699// ============================================================================
700
701impl BvhStats {
702    /// Compute statistics by traversing the given BVH.
703    pub fn compute(bvh: &Bvh) -> Self {
704        let mut s = BvhStats {
705            node_count: 0,
706            leaf_count: 0,
707            internal_count: 0,
708            max_depth: 0,
709            total_primitives: 0,
710            avg_primitives_per_leaf: 0.0,
711        };
712        if let Some(root) = &bvh.root {
713            collect_stats(root, 1, &mut s);
714        }
715        if s.leaf_count > 0 {
716            s.avg_primitives_per_leaf = s.total_primitives as f32 / s.leaf_count as f32;
717        }
718        s
719    }
720}
721
722fn collect_stats(node: &BvhNode, depth: usize, s: &mut BvhStats) {
723    s.node_count += 1;
724    if depth > s.max_depth {
725        s.max_depth = depth;
726    }
727    if node.is_leaf() {
728        s.leaf_count += 1;
729        s.total_primitives += node.primitives.len();
730    } else {
731        s.internal_count += 1;
732        if let Some(left) = &node.left {
733            collect_stats(left, depth + 1, s);
734        }
735        if let Some(right) = &node.right {
736            collect_stats(right, depth + 1, s);
737        }
738    }
739}
740
741impl BvhTreeStatistics {
742    /// Compute extended tree statistics by traversing the given BVH.
743    pub fn compute(bvh: &Bvh) -> Self {
744        let mut s = BvhTreeStatistics {
745            node_count: 0,
746            leaf_count: 0,
747            internal_count: 0,
748            max_depth: 0,
749            total_primitives: 0,
750            avg_fanout: 0.0,
751            total_leaf_surface_area: 0.0,
752        };
753        if let Some(root) = &bvh.root {
754            let mut child_sum = 0usize;
755            collect_tree_stats(root, 1, &mut s, &mut child_sum);
756            s.avg_fanout = if s.internal_count > 0 {
757                child_sum as f32 / s.internal_count as f32
758            } else {
759                0.0
760            };
761        }
762        s
763    }
764}
765
766fn collect_tree_stats(
767    node: &BvhNode,
768    depth: usize,
769    s: &mut BvhTreeStatistics,
770    child_sum: &mut usize,
771) {
772    s.node_count += 1;
773    if depth > s.max_depth {
774        s.max_depth = depth;
775    }
776    if node.is_leaf() {
777        s.leaf_count += 1;
778        s.total_primitives += node.primitives.len();
779        s.total_leaf_surface_area += node.aabb.surface_area();
780    } else {
781        s.internal_count += 1;
782        let mut children = 0usize;
783        if let Some(left) = &node.left {
784            children += 1;
785            collect_tree_stats(left, depth + 1, s, child_sum);
786        }
787        if let Some(right) = &node.right {
788            children += 1;
789            collect_tree_stats(right, depth + 1, s, child_sum);
790        }
791        *child_sum += children;
792    }
793}