1use super::types::{
7 Aabb, BvhNode, BvhPrimitive, BvhStats, BvhTreeStatistics, FlatBvhNode, LbvhPrimitive,
8 MortonCluster, RayHit,
9};
10
/// Surface-area-heuristic cost of splitting a node into two children.
///
/// Each child contributes its primitive count weighted by the ratio of its
/// surface area to the parent's surface area. No traversal or intersection
/// constants are applied.
///
/// When `sa_parent` is zero or negative, the function returns `f32::MAX`, so a
/// degenerate parent reads as the most expensive possible split.
pub fn sah_cost(n_left: usize, sa_left: f32, n_right: usize, sa_right: f32, sa_parent: f32) -> f32 {
    if sa_parent <= 0.0 {
        return f32::MAX;
    }
    let weighted = |count: usize, area: f32| (area / sa_parent) * count as f32;
    weighted(n_left, sa_left) + weighted(n_right, sa_right)
}
23
24pub fn ray_aabb_intersect(origin: [f32; 3], inv_dir: [f32; 3], aabb: &Aabb, max_t: f32) -> bool {
33 let mut t_min = 0.0_f32;
34 let mut t_max = max_t;
35
36 for i in 0..3 {
37 let t1 = (aabb.min[i] - origin[i]) * inv_dir[i];
38 let t2 = (aabb.max[i] - origin[i]) * inv_dir[i];
39 let lo = t1.min(t2);
40 let hi = t1.max(t2);
41 t_min = t_min.max(lo);
42 t_max = t_max.min(hi);
43 }
44
45 t_min <= t_max
46}
47
48pub(crate) fn ray_aabb_t(origin: [f32; 3], inv_dir: [f32; 3], aabb: &Aabb) -> Option<(f32, f32)> {
50 let mut t_min = 0.0_f32;
51 let mut t_max = f32::MAX;
52 for i in 0..3 {
53 let t1 = (aabb.min[i] - origin[i]) * inv_dir[i];
54 let t2 = (aabb.max[i] - origin[i]) * inv_dir[i];
55 t_min = t_min.max(t1.min(t2));
56 t_max = t_max.min(t1.max(t2));
57 }
58 if t_min <= t_max {
59 Some((t_min, t_max))
60 } else {
61 None
62 }
63}
64
/// Maximum number of primitives the builders in this module put in one leaf.
///
/// `build_recursive`, `lbvh_recursive` and `bvh_from_sorted_recursive` all
/// stop subdividing once a range holds at most this many primitives.
pub(crate) const LEAF_SIZE: usize = 4;
71
/// Bounding volume hierarchy over a set of primitives, stored as a tree of
/// boxed nodes.
pub struct Bvh {
    /// Root node, or `None` when the hierarchy was built from no primitives.
    pub root: Option<BvhNode>,
    /// The primitives the tree was built over. Leaf nodes hold indices into
    /// this vector rather than copies of the primitives.
    pub primitives: Vec<BvhPrimitive>,
}
79
80impl Bvh {
81 pub fn build(primitives: Vec<BvhPrimitive>) -> Self {
84 if primitives.is_empty() {
85 return Self {
86 root: None,
87 primitives,
88 };
89 }
90 let indices: Vec<usize> = (0..primitives.len()).collect();
91 let root = build_recursive(&primitives, indices);
92 Self {
93 root: Some(root),
94 primitives,
95 }
96 }
97
98 pub fn query_aabb(&self, query: &Aabb) -> Vec<usize> {
100 let mut result = Vec::new();
101 if let Some(root) = &self.root {
102 query_aabb_recursive(root, query, &self.primitives, &mut result);
103 }
104 result
105 }
106
107 pub fn query_ray(&self, origin: [f32; 3], direction: [f32; 3], max_t: f32) -> Vec<usize> {
110 let inv_dir = [1.0 / direction[0], 1.0 / direction[1], 1.0 / direction[2]];
111 let mut result = Vec::new();
112 if let Some(root) = &self.root {
113 query_ray_recursive(root, origin, inv_dir, max_t, &self.primitives, &mut result);
114 }
115 result
116 }
117
118 pub fn node_count(&self) -> usize {
120 match &self.root {
121 None => 0,
122 Some(root) => count_nodes(root),
123 }
124 }
125
126 pub fn depth(&self) -> usize {
128 match &self.root {
129 None => 0,
130 Some(root) => node_depth(root),
131 }
132 }
133}
134
135pub(crate) fn bounding_box(primitives: &[BvhPrimitive], indices: &[usize]) -> Aabb {
140 let mut aabb = primitives[indices[0]].aabb.clone();
141 for &i in &indices[1..] {
142 aabb = Aabb::merge(&aabb, &primitives[i].aabb);
143 }
144 aabb
145}
146
147fn build_recursive(primitives: &[BvhPrimitive], mut indices: Vec<usize>) -> BvhNode {
148 let aabb = bounding_box(primitives, &indices);
149
150 if indices.len() <= LEAF_SIZE {
151 return BvhNode {
152 aabb,
153 left: None,
154 right: None,
155 primitives: indices,
156 };
157 }
158
159 let dx = aabb.max[0] - aabb.min[0];
161 let dy = aabb.max[1] - aabb.min[1];
162 let dz = aabb.max[2] - aabb.min[2];
163 let axis = if dx >= dy && dx >= dz {
164 0
165 } else if dy >= dz {
166 1
167 } else {
168 2
169 };
170
171 indices.sort_unstable_by(|&a, &b| {
173 let ca = primitives[a].aabb.center()[axis];
174 let cb = primitives[b].aabb.center()[axis];
175 ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
176 });
177
178 let mid = indices.len() / 2;
179 let right_indices = indices.split_off(mid);
180 let left_indices = indices;
181
182 let left = build_recursive(primitives, left_indices);
183 let right = build_recursive(primitives, right_indices);
184
185 BvhNode {
186 aabb,
187 left: Some(Box::new(left)),
188 right: Some(Box::new(right)),
189 primitives: Vec::new(),
190 }
191}
192
193fn query_aabb_recursive(
194 node: &BvhNode,
195 query: &Aabb,
196 primitives: &[BvhPrimitive],
197 result: &mut Vec<usize>,
198) {
199 if !node.aabb.intersects(query) {
200 return;
201 }
202 if node.is_leaf() {
203 for &idx in &node.primitives {
204 if primitives[idx].aabb.intersects(query) {
205 result.push(primitives[idx].object_id);
206 }
207 }
208 } else {
209 if let Some(left) = &node.left {
210 query_aabb_recursive(left, query, primitives, result);
211 }
212 if let Some(right) = &node.right {
213 query_aabb_recursive(right, query, primitives, result);
214 }
215 }
216}
217
218fn query_ray_recursive(
219 node: &BvhNode,
220 origin: [f32; 3],
221 inv_dir: [f32; 3],
222 max_t: f32,
223 primitives: &[BvhPrimitive],
224 result: &mut Vec<usize>,
225) {
226 if !ray_aabb_intersect(origin, inv_dir, &node.aabb, max_t) {
227 return;
228 }
229 if node.is_leaf() {
230 for &idx in &node.primitives {
231 if ray_aabb_intersect(origin, inv_dir, &primitives[idx].aabb, max_t) {
232 result.push(primitives[idx].object_id);
233 }
234 }
235 } else {
236 if let Some(left) = &node.left {
237 query_ray_recursive(left, origin, inv_dir, max_t, primitives, result);
238 }
239 if let Some(right) = &node.right {
240 query_ray_recursive(right, origin, inv_dir, max_t, primitives, result);
241 }
242 }
243}
244
245fn count_nodes(node: &BvhNode) -> usize {
246 1 + node.left.as_ref().map_or(0, |n| count_nodes(n))
247 + node.right.as_ref().map_or(0, |n| count_nodes(n))
248}
249
250fn node_depth(node: &BvhNode) -> usize {
251 1 + node
252 .left
253 .as_ref()
254 .map_or(0, |n| node_depth(n))
255 .max(node.right.as_ref().map_or(0, |n| node_depth(n)))
256}
257
258pub fn flatten(bvh: &Bvh) -> (Vec<FlatBvhNode>, Vec<usize>) {
268 let mut nodes: Vec<FlatBvhNode> = Vec::new();
269 let mut prim_indices: Vec<usize> = Vec::new();
270
271 if let Some(root) = &bvh.root {
272 flatten_recursive(root, &mut nodes, &mut prim_indices);
273 }
274
275 (nodes, prim_indices)
276}
277
278fn flatten_recursive(
280 node: &BvhNode,
281 nodes: &mut Vec<FlatBvhNode>,
282 prim_indices: &mut Vec<usize>,
283) -> usize {
284 let node_idx = nodes.len();
285
286 if node.is_leaf() {
287 let first = prim_indices.len() as u32;
288 let count = node.primitives.len() as u32;
289 prim_indices.extend_from_slice(&node.primitives);
290 nodes.push(FlatBvhNode {
291 aabb: node.aabb.clone(),
292 left_first: first,
293 count,
294 });
295 } else {
296 nodes.push(FlatBvhNode {
298 aabb: node.aabb.clone(),
299 left_first: 0,
300 count: 0,
301 });
302 if let Some(left) = &node.left {
304 flatten_recursive(left, nodes, prim_indices);
305 }
306 let right_idx = if let Some(right) = &node.right {
308 flatten_recursive(right, nodes, prim_indices)
309 } else {
310 0
311 };
312 nodes[node_idx].left_first = right_idx as u32;
313 }
314
315 node_idx
316}
317
318pub fn query_flat(
323 nodes: &[FlatBvhNode],
324 prim_indices: &[usize],
325 bvh_primitives: &[BvhPrimitive],
326 query: &Aabb,
327) -> Vec<usize> {
328 let mut result = Vec::new();
329 if nodes.is_empty() {
330 return result;
331 }
332
333 let mut stack: Vec<usize> = Vec::with_capacity(64);
334 stack.push(0);
335
336 while let Some(idx) = stack.pop() {
337 let node = &nodes[idx];
338 if !node.aabb.intersects(query) {
339 continue;
340 }
341 if node.count > 0 {
342 let start = node.left_first as usize;
344 let end = start + node.count as usize;
345 for &pi in &prim_indices[start..end] {
346 if bvh_primitives[pi].aabb.intersects(query) {
347 result.push(bvh_primitives[pi].object_id);
348 }
349 }
350 } else {
351 let right = node.left_first as usize;
353 stack.push(right);
354 stack.push(idx + 1);
355 }
356 }
357
358 result
359}
360
361pub fn bvh_closest_hit(
367 bvh: &Bvh,
368 origin: [f32; 3],
369 direction: [f32; 3],
370 max_t: f32,
371) -> Option<RayHit> {
372 let inv_dir = [1.0 / direction[0], 1.0 / direction[1], 1.0 / direction[2]];
373 let root = bvh.root.as_ref()?;
374 let mut best: Option<RayHit> = None;
375 let mut current_max = max_t;
376 closest_hit_recursive(
377 root,
378 origin,
379 inv_dir,
380 &bvh.primitives,
381 &mut best,
382 &mut current_max,
383 );
384 best
385}
386
387fn closest_hit_recursive(
388 node: &BvhNode,
389 origin: [f32; 3],
390 inv_dir: [f32; 3],
391 primitives: &[BvhPrimitive],
392 best: &mut Option<RayHit>,
393 max_t: &mut f32,
394) {
395 if ray_aabb_t(origin, inv_dir, &node.aabb).is_none() {
396 return;
397 }
398 if node.is_leaf() {
399 for &idx in &node.primitives {
400 if let Some((t_min, _)) = ray_aabb_t(origin, inv_dir, &primitives[idx].aabb)
401 && t_min >= 0.0
402 && t_min < *max_t
403 {
404 *max_t = t_min;
405 *best = Some(RayHit {
406 object_id: primitives[idx].object_id,
407 t: t_min,
408 });
409 }
410 }
411 } else {
412 if let Some(left) = &node.left {
413 closest_hit_recursive(left, origin, inv_dir, primitives, best, max_t);
414 }
415 if let Some(right) = &node.right {
416 closest_hit_recursive(right, origin, inv_dir, primitives, best, max_t);
417 }
418 }
419}
420
421pub fn refit(node: &mut BvhNode, primitives: &[BvhPrimitive]) {
429 if node.is_leaf() {
430 if !node.primitives.is_empty() {
431 node.aabb = bounding_box(primitives, &node.primitives);
432 }
433 return;
434 }
435 if let Some(left) = node.left.as_mut() {
436 refit(left, primitives);
437 }
438 if let Some(right) = node.right.as_mut() {
439 refit(right, primitives);
440 }
441 let left_aabb = node.left.as_ref().map(|n| n.aabb.clone());
443 let right_aabb = node.right.as_ref().map(|n| n.aabb.clone());
444 node.aabb = match (left_aabb, right_aabb) {
445 (Some(l), Some(r)) => Aabb::merge(&l, &r),
446 (Some(l), None) => l,
447 (None, Some(r)) => r,
448 (None, None) => node.aabb.clone(),
449 };
450}
451
/// Spreads the low 10 bits of `v` so that input bit `i` lands at bit `3 * i`,
/// leaving two zero bits between neighbours. This prepares a value for 3-D
/// Morton interleaving.
///
/// Only the low 10 bits are meaningful; callers pass values below 1024.
fn expand_bits(v: u32) -> u32 {
    // (shift, mask) pairs applied in order; each step doubles the gap between bit groups.
    const STEPS: [(u32, u32); 4] = [
        (16, 0x0300_00FF),
        (8, 0x0300_F00F),
        (4, 0x030C_30C3),
        (2, 0x0924_9249),
    ];
    STEPS
        .iter()
        .fold(v, |acc, &(shift, mask)| (acc | (acc << shift)) & mask)
}
464
465pub fn morton_code(p: [f32; 3]) -> u32 {
467 let x = (p[0].clamp(0.0, 1.0) * 1023.0) as u32;
468 let y = (p[1].clamp(0.0, 1.0) * 1023.0) as u32;
469 let z = (p[2].clamp(0.0, 1.0) * 1023.0) as u32;
470 expand_bits(x) | (expand_bits(y) << 1) | (expand_bits(z) << 2)
471}
472
473impl LbvhPrimitive {
474 pub fn new(aabb: Aabb, object_id: usize, scene_aabb: &Aabb) -> Self {
477 let c = aabb.center();
478 let scene_size = [
479 (scene_aabb.max[0] - scene_aabb.min[0]).max(1e-10),
480 (scene_aabb.max[1] - scene_aabb.min[1]).max(1e-10),
481 (scene_aabb.max[2] - scene_aabb.min[2]).max(1e-10),
482 ];
483 let norm = [
484 (c[0] - scene_aabb.min[0]) / scene_size[0],
485 (c[1] - scene_aabb.min[1]) / scene_size[1],
486 (c[2] - scene_aabb.min[2]) / scene_size[2],
487 ];
488 let morton = morton_code(norm);
489 Self {
490 aabb,
491 object_id,
492 morton,
493 }
494 }
495}
496
497pub fn lbvh_build(primitives: Vec<BvhPrimitive>) -> Bvh {
502 if primitives.is_empty() {
503 return Bvh {
504 root: None,
505 primitives,
506 };
507 }
508
509 let mut scene = primitives[0].aabb.clone();
511 for p in &primitives[1..] {
512 scene = Aabb::merge(&scene, &p.aabb);
513 }
514
515 let mut indexed: Vec<(u32, usize)> = primitives
517 .iter()
518 .enumerate()
519 .map(|(i, p)| {
520 let lp = LbvhPrimitive::new(p.aabb.clone(), p.object_id, &scene);
521 (lp.morton, i)
522 })
523 .collect();
524 indexed.sort_unstable_by_key(|&(m, _)| m);
525
526 let sorted_indices: Vec<usize> = indexed.iter().map(|&(_, i)| i).collect();
527 let root = lbvh_recursive(&primitives, &sorted_indices);
528
529 Bvh {
530 root: Some(root),
531 primitives,
532 }
533}
534
535fn lbvh_recursive(primitives: &[BvhPrimitive], indices: &[usize]) -> BvhNode {
536 let aabb = bounding_box(primitives, indices);
537
538 if indices.len() <= LEAF_SIZE {
539 return BvhNode {
540 aabb,
541 left: None,
542 right: None,
543 primitives: indices.to_vec(),
544 };
545 }
546
547 let mid = indices.len() / 2;
548 let left = lbvh_recursive(primitives, &indices[..mid]);
549 let right = lbvh_recursive(primitives, &indices[mid..]);
550
551 BvhNode {
552 aabb,
553 left: Some(Box::new(left)),
554 right: Some(Box::new(right)),
555 primitives: Vec::new(),
556 }
557}
558
/// Picks the split position for a run of Morton codes sorted in ascending order.
///
/// Let `p` be the number of leading bits shared by the first and last code. The
/// result is the first index whose code shares at most `p` leading bits with
/// `mortons[0]`. For sorted input, that is where the run's highest differing
/// bit changes from 0 to 1. The left child takes `mortons[..split]` and the
/// right child takes `mortons[split..]`.
///
/// Special cases:
/// - Fewer than two codes: returns `1`.
/// - All codes equal (no differing bit): returns `len / 2`, halving the run.
///   Returning `1` here would peel off one element per level, so a run of `n`
///   duplicate codes (e.g. coincident box centres) would become a chain of
///   depth about `n`.
///
/// For any input of length >= 2 the result lies in `1..len`, so both sides are
/// non-empty even if the input is not actually sorted.
pub fn hlbvh_split(mortons: &[u32]) -> usize {
    let n = mortons.len();
    if n < 2 {
        return 1;
    }
    let first = mortons[0];
    let last = mortons[n - 1];
    if first == last {
        return n / 2;
    }
    let common_prefix = (first ^ last).leading_zeros();
    // mortons[0] always satisfies the predicate and mortons[n - 1] never does,
    // so only the interior needs searching; this also bounds the result to 1..n.
    let interior = &mortons[1..n - 1];
    1 + interior.partition_point(|&code| (first ^ code).leading_zeros() > common_prefix)
}
588
589pub fn compute_bvh_from_sorted(sorted: &[LbvhPrimitive]) -> Bvh {
596 if sorted.is_empty() {
597 return Bvh {
598 root: None,
599 primitives: Vec::new(),
600 };
601 }
602
603 let primitives: Vec<BvhPrimitive> = sorted
605 .iter()
606 .map(|lp| BvhPrimitive::new(lp.aabb.clone(), lp.object_id))
607 .collect();
608
609 let mortons: Vec<u32> = sorted.iter().map(|lp| lp.morton).collect();
610 let indices: Vec<usize> = (0..primitives.len()).collect();
611 let root = bvh_from_sorted_recursive(&primitives, &indices, &mortons);
612 Bvh {
613 root: Some(root),
614 primitives,
615 }
616}
617
618fn bvh_from_sorted_recursive(
619 primitives: &[BvhPrimitive],
620 indices: &[usize],
621 mortons: &[u32],
622) -> BvhNode {
623 let aabb = bounding_box(primitives, indices);
624 if indices.len() <= LEAF_SIZE {
625 return BvhNode {
626 aabb,
627 left: None,
628 right: None,
629 primitives: indices.to_vec(),
630 };
631 }
632 let local_mortons: Vec<u32> = indices.iter().map(|&i| mortons[i]).collect();
634 let split = hlbvh_split(&local_mortons);
635 let left = bvh_from_sorted_recursive(primitives, &indices[..split], mortons);
636 let right = bvh_from_sorted_recursive(primitives, &indices[split..], mortons);
637 BvhNode {
638 aabb,
639 left: Some(Box::new(left)),
640 right: Some(Box::new(right)),
641 primitives: Vec::new(),
642 }
643}
644
645pub fn compute_cluster_radius(cluster: &[LbvhPrimitive]) -> f32 {
647 if cluster.is_empty() {
648 return 0.0;
649 }
650 let mut merged = cluster[0].aabb.clone();
652 for lp in &cluster[1..] {
653 merged = Aabb::merge(&merged, &lp.aabb);
654 }
655 let cx = (merged.min[0] + merged.max[0]) * 0.5;
656 let cy = (merged.min[1] + merged.max[1]) * 0.5;
657 let cz = (merged.min[2] + merged.max[2]) * 0.5;
658
659 let mut max_dist_sq = 0.0_f32;
660 for lp in cluster {
661 let c = lp.aabb.center();
662 let dx = c[0] - cx;
663 let dy = c[1] - cy;
664 let dz = c[2] - cz;
665 let d2 = dx * dx + dy * dy + dz * dz;
666 if d2 > max_dist_sq {
667 max_dist_sq = d2;
668 }
669 }
670 max_dist_sq.sqrt()
671}
672
673pub fn build_morton_clusters(sorted: &[LbvhPrimitive], cluster_size: usize) -> Vec<MortonCluster> {
676 if sorted.is_empty() || cluster_size == 0 {
677 return Vec::new();
678 }
679 sorted
680 .chunks(cluster_size)
681 .map(|chunk| {
682 let indices: Vec<usize> = (0..chunk.len()).collect();
683 let mut aabb = chunk[0].aabb.clone();
684 for lp in &chunk[1..] {
685 aabb = Aabb::merge(&aabb, &lp.aabb);
686 }
687 let radius = compute_cluster_radius(chunk);
688 MortonCluster {
689 indices,
690 aabb,
691 radius,
692 }
693 })
694 .collect()
695}
696
697impl BvhStats {
702 pub fn compute(bvh: &Bvh) -> Self {
704 let mut s = BvhStats {
705 node_count: 0,
706 leaf_count: 0,
707 internal_count: 0,
708 max_depth: 0,
709 total_primitives: 0,
710 avg_primitives_per_leaf: 0.0,
711 };
712 if let Some(root) = &bvh.root {
713 collect_stats(root, 1, &mut s);
714 }
715 if s.leaf_count > 0 {
716 s.avg_primitives_per_leaf = s.total_primitives as f32 / s.leaf_count as f32;
717 }
718 s
719 }
720}
721
722fn collect_stats(node: &BvhNode, depth: usize, s: &mut BvhStats) {
723 s.node_count += 1;
724 if depth > s.max_depth {
725 s.max_depth = depth;
726 }
727 if node.is_leaf() {
728 s.leaf_count += 1;
729 s.total_primitives += node.primitives.len();
730 } else {
731 s.internal_count += 1;
732 if let Some(left) = &node.left {
733 collect_stats(left, depth + 1, s);
734 }
735 if let Some(right) = &node.right {
736 collect_stats(right, depth + 1, s);
737 }
738 }
739}
740
741impl BvhTreeStatistics {
742 pub fn compute(bvh: &Bvh) -> Self {
744 let mut s = BvhTreeStatistics {
745 node_count: 0,
746 leaf_count: 0,
747 internal_count: 0,
748 max_depth: 0,
749 total_primitives: 0,
750 avg_fanout: 0.0,
751 total_leaf_surface_area: 0.0,
752 };
753 if let Some(root) = &bvh.root {
754 let mut child_sum = 0usize;
755 collect_tree_stats(root, 1, &mut s, &mut child_sum);
756 s.avg_fanout = if s.internal_count > 0 {
757 child_sum as f32 / s.internal_count as f32
758 } else {
759 0.0
760 };
761 }
762 s
763 }
764}
765
766fn collect_tree_stats(
767 node: &BvhNode,
768 depth: usize,
769 s: &mut BvhTreeStatistics,
770 child_sum: &mut usize,
771) {
772 s.node_count += 1;
773 if depth > s.max_depth {
774 s.max_depth = depth;
775 }
776 if node.is_leaf() {
777 s.leaf_count += 1;
778 s.total_primitives += node.primitives.len();
779 s.total_leaf_surface_area += node.aabb.surface_area();
780 } else {
781 s.internal_count += 1;
782 let mut children = 0usize;
783 if let Some(left) = &node.left {
784 children += 1;
785 collect_tree_stats(left, depth + 1, s, child_sum);
786 }
787 if let Some(right) = &node.right {
788 children += 1;
789 collect_tree_stats(right, depth + 1, s, child_sum);
790 }
791 *child_sum += children;
792 }
793}