/// A ray `origin + t * direction` in 3-D space.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Ray {
    /// Point the ray starts from (the point at `t = 0`).
    pub origin: [f64; 3],
    /// Direction of travel; stored as given, so it need not be unit length.
    pub direction: [f64; 3],
}

impl Ray {
    /// Creates a ray; `direction` is stored as given (it is not normalized).
    pub fn new(origin: [f64; 3], direction: [f64; 3]) -> Self {
        Self { origin, direction }
    }

    /// Returns the point `origin + t * direction`.
    ///
    /// `t` may be negative, which yields points behind the origin.
    pub fn at(&self, t: f64) -> [f64; 3] {
        std::array::from_fn(|i| self.origin[i] + t * self.direction[i])
    }
}
33
/// Axis-aligned bounding box described by its per-axis `min` and `max` corners.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Aabb {
    /// Minimum corner (smallest x, y, z).
    pub min: [f64; 3],
    /// Maximum corner (largest x, y, z).
    pub max: [f64; 3],
}

impl Aabb {
    /// Creates a box from its two corners, stored as given (no reordering is done).
    pub fn new(min: [f64; 3], max: [f64; 3]) -> Self {
        Self { min, max }
    }

    /// Returns the box spanning both `self` and `other`.
    ///
    /// Computed per axis as the smaller of the minima and the larger of the maxima.
    pub fn union(&self, other: &Aabb) -> Aabb {
        Aabb {
            min: std::array::from_fn(|i| self.min[i].min(other.min[i])),
            max: std::array::from_fn(|i| self.max[i].max(other.max[i])),
        }
    }

    /// Returns the midpoint of the box, `(min + max) / 2` per axis.
    pub fn centroid(&self) -> [f64; 3] {
        std::array::from_fn(|i| (self.min[i] + self.max[i]) * 0.5)
    }
}
74
75#[derive(Debug, Clone, Copy)]
77pub struct Triangle {
78 pub v0: [f64; 3],
80 pub v1: [f64; 3],
82 pub v2: [f64; 3],
84}
85
86impl Triangle {
87 pub fn new(v0: [f64; 3], v1: [f64; 3], v2: [f64; 3]) -> Self {
89 Self { v0, v1, v2 }
90 }
91
92 pub fn aabb(&self) -> Aabb {
94 Aabb {
95 min: [
96 self.v0[0].min(self.v1[0]).min(self.v2[0]),
97 self.v0[1].min(self.v1[1]).min(self.v2[1]),
98 self.v0[2].min(self.v1[2]).min(self.v2[2]),
99 ],
100 max: [
101 self.v0[0].max(self.v1[0]).max(self.v2[0]),
102 self.v0[1].max(self.v1[1]).max(self.v2[1]),
103 self.v0[2].max(self.v1[2]).max(self.v2[2]),
104 ],
105 }
106 }
107}
108
109#[derive(Debug, Clone)]
111pub struct BvhNode {
112 pub bounds: Aabb,
114 pub left: usize,
116 pub right: usize,
118 pub triangle_index: usize,
120}
121
122impl BvhNode {
123 pub fn is_leaf(&self) -> bool {
125 self.triangle_index != usize::MAX
126 }
127}
128
/// Result of a successful ray/triangle intersection.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HitRecord {
    /// Ray parameter of the hit; the hit point is `ray.at(t)`.
    pub t: f64,
    /// Index of the hit triangle in the triangle slice that was searched.
    pub triangle_index: usize,
    /// Barycentric coordinates `[u, v]` as produced by `ray_triangle_intersect`:
    /// `u` weights `v1`, `v` weights `v2`, and `v0` gets `1 - u - v`.
    pub uv: [f64; 2],
}
139
/// Dot product of two 3-vectors.
fn dot3(a: [f64; 3], b: [f64; 3]) -> f64 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}
145
/// Cross product `a × b` of two 3-vectors (right-handed).
fn cross3(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
153
/// Component-wise difference `a - b` of two 3-vectors.
fn sub3(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    std::array::from_fn(|axis| a[axis] - b[axis])
}
157
158pub fn ray_aabb_intersect(ray: &Ray, aabb: &Aabb, t_min: f64, t_max: f64) -> Option<f64> {
165 let mut t_lo = t_min;
166 let mut t_hi = t_max;
167
168 for axis in 0..3 {
169 let inv_d = if ray.direction[axis].abs() > 1e-15 {
170 1.0 / ray.direction[axis]
171 } else {
172 f64::INFINITY
173 };
174 let mut t0 = (aabb.min[axis] - ray.origin[axis]) * inv_d;
175 let mut t1 = (aabb.max[axis] - ray.origin[axis]) * inv_d;
176 if inv_d < 0.0 {
177 std::mem::swap(&mut t0, &mut t1);
178 }
179 t_lo = t_lo.max(t0);
180 t_hi = t_hi.min(t1);
181 if t_hi < t_lo {
182 return None;
183 }
184 }
185 Some(t_lo)
186}
187
188pub fn ray_triangle_intersect(
193 ray: &Ray,
194 tri: &Triangle,
195 tri_index: usize,
196 t_min: f64,
197 t_max: f64,
198) -> Option<HitRecord> {
199 const EPSILON: f64 = 1e-10;
200
201 let edge1 = sub3(tri.v1, tri.v0);
202 let edge2 = sub3(tri.v2, tri.v0);
203 let h = cross3(ray.direction, edge2);
204 let det = dot3(edge1, h);
205
206 if det.abs() < EPSILON {
207 return None; }
209
210 let inv_det = 1.0 / det;
211 let s = sub3(ray.origin, tri.v0);
212 let u = inv_det * dot3(s, h);
213
214 if !(0.0..=1.0).contains(&u) {
215 return None;
216 }
217
218 let q = cross3(s, edge1);
219 let v = inv_det * dot3(ray.direction, q);
220
221 if v < 0.0 || u + v > 1.0 {
222 return None;
223 }
224
225 let t = inv_det * dot3(edge2, q);
226 if t < t_min || t > t_max {
227 return None;
228 }
229
230 Some(HitRecord {
231 t,
232 triangle_index: tri_index,
233 uv: [u, v],
234 })
235}
236
237pub fn build_bvh(triangles: &[Triangle]) -> Vec<BvhNode> {
241 if triangles.is_empty() {
242 return Vec::new();
243 }
244
245 let mut nodes: Vec<BvhNode> = Vec::new();
246
247 let mut leaf_indices: Vec<usize> = (0..triangles.len()).collect();
249
250 fn build_recursive(
251 tris: &[Triangle],
252 indices: &mut [usize],
253 nodes: &mut Vec<BvhNode>,
254 ) -> usize {
255 if indices.len() == 1 {
256 let tri_idx = indices[0];
257 let bounds = tris[tri_idx].aabb();
258 let node = BvhNode {
259 bounds,
260 left: usize::MAX,
261 right: usize::MAX,
262 triangle_index: tri_idx,
263 };
264 let idx = nodes.len();
265 nodes.push(node);
266 return idx;
267 }
268
269 let mut combined = tris[indices[0]].aabb();
271 for &i in indices.iter().skip(1) {
272 combined = combined.union(&tris[i].aabb());
273 }
274
275 let extent = [
277 combined.max[0] - combined.min[0],
278 combined.max[1] - combined.min[1],
279 combined.max[2] - combined.min[2],
280 ];
281 let axis = if extent[0] >= extent[1] && extent[0] >= extent[2] {
282 0
283 } else if extent[1] >= extent[2] {
284 1
285 } else {
286 2
287 };
288
289 indices.sort_by(|&a, &b| {
291 let ca = tris[a].aabb().centroid()[axis];
292 let cb = tris[b].aabb().centroid()[axis];
293 ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
294 });
295
296 let mid = indices.len() / 2;
297 let (left_ids, right_ids) = indices.split_at_mut(mid);
298
299 let left_child = build_recursive(tris, left_ids, nodes);
300 let right_child = build_recursive(tris, right_ids, nodes);
301
302 let left_bounds = nodes[left_child].bounds;
303 let right_bounds = nodes[right_child].bounds;
304 let node = BvhNode {
305 bounds: left_bounds.union(&right_bounds),
306 left: left_child,
307 right: right_child,
308 triangle_index: usize::MAX,
309 };
310 let idx = nodes.len();
311 nodes.push(node);
312 idx
313 }
314
315 build_recursive(triangles, &mut leaf_indices, &mut nodes);
316 nodes
317}
318
319pub fn traverse_bvh(
324 ray: &Ray,
325 nodes: &[BvhNode],
326 triangles: &[Triangle],
327 root: usize,
328 t_min: f64,
329 t_max: f64,
330) -> Option<HitRecord> {
331 if nodes.is_empty() {
332 return None;
333 }
334
335 let mut best: Option<HitRecord> = None;
336 let mut t_closest = t_max;
337
338 let mut stack = Vec::with_capacity(64);
340 stack.push(root);
341
342 while let Some(node_idx) = stack.pop() {
343 if node_idx >= nodes.len() {
344 continue;
345 }
346 let node = &nodes[node_idx];
347
348 if ray_aabb_intersect(ray, &node.bounds, t_min, t_closest).is_none() {
350 continue;
351 }
352
353 if node.is_leaf() {
354 if node.triangle_index < triangles.len()
355 && let Some(hit) = ray_triangle_intersect(
356 ray,
357 &triangles[node.triangle_index],
358 node.triangle_index,
359 t_min,
360 t_closest,
361 )
362 {
363 t_closest = hit.t;
364 best = Some(hit);
365 }
366 } else {
367 if node.left != usize::MAX {
368 stack.push(node.left);
369 }
370 if node.right != usize::MAX {
371 stack.push(node.right);
372 }
373 }
374 }
375
376 best
377}
378
379pub fn batch_ray_cast(
384 rays: &[Ray],
385 nodes: &[BvhNode],
386 triangles: &[Triangle],
387 root: usize,
388) -> Vec<Option<HitRecord>> {
389 rays.iter()
390 .map(|ray| traverse_bvh(ray, nodes, triangles, root, 1e-4, f64::INFINITY))
391 .collect()
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    fn unit_box_aabb() -> Aabb {
        Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])
    }

    fn simple_tri() -> Triangle {
        Triangle::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0])
    }

    #[test]
    fn test_ray_at() {
        let ray = Ray::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0]);
        let p = ray.at(3.0);
        assert!((p[0] - 3.0).abs() < 1e-12);
        assert!(p[1].abs() < 1e-12);
        assert!(p[2].abs() < 1e-12);
    }

    #[test]
    fn test_ray_at_negative_t() {
        let ray = Ray::new([1.0, 0.0, 0.0], [1.0, 0.0, 0.0]);
        let p = ray.at(-1.0);
        assert!(p[0].abs() < 1e-12);
    }

    #[test]
    fn test_aabb_union() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let b = Aabb::new([0.5, 0.5, 0.5], [2.0, 2.0, 2.0]);
        let u = a.union(&b);
        assert!((u.max[0] - 2.0).abs() < 1e-12);
        assert!(u.min[0].abs() < 1e-12);
    }

    #[test]
    fn test_aabb_centroid() {
        let aabb = Aabb::new([0.0, 0.0, 0.0], [2.0, 4.0, 6.0]);
        let c = aabb.centroid();
        assert!((c[0] - 1.0).abs() < 1e-12);
        assert!((c[1] - 2.0).abs() < 1e-12);
        assert!((c[2] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_triangle_aabb() {
        let tri = simple_tri();
        let aabb = tri.aabb();
        assert!((aabb.max[0] - 1.0).abs() < 1e-12);
        assert!((aabb.max[1] - 1.0).abs() < 1e-12);
        assert!(aabb.max[2].abs() < 1e-12);
    }

    #[test]
    fn test_ray_aabb_hit() {
        let ray = Ray::new([0.5, 0.5, -1.0], [0.0, 0.0, 1.0]);
        let aabb = unit_box_aabb();
        let result = ray_aabb_intersect(&ray, &aabb, 0.0, f64::INFINITY);
        assert!(result.is_some());
        let t = result.unwrap();
        assert!((t - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_ray_aabb_miss() {
        let ray = Ray::new([2.0, 2.0, -1.0], [0.0, 0.0, 1.0]);
        let aabb = unit_box_aabb();
        assert!(ray_aabb_intersect(&ray, &aabb, 0.0, f64::INFINITY).is_none());
    }

    #[test]
    fn test_ray_aabb_inside() {
        let ray = Ray::new([0.5, 0.5, 0.5], [0.0, 0.0, 1.0]);
        let aabb = unit_box_aabb();
        let result = ray_aabb_intersect(&ray, &aabb, 0.0, f64::INFINITY);
        assert!(result.is_some());
    }

    #[test]
    fn test_ray_aabb_behind() {
        let ray = Ray::new([0.5, 0.5, 5.0], [0.0, 0.0, 1.0]);
        let aabb = unit_box_aabb();
        assert!(ray_aabb_intersect(&ray, &aabb, 0.0, f64::INFINITY).is_none());
    }

    #[test]
    fn test_ray_triangle_hit() {
        let tri = Triangle::new([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]);
        let ray = Ray::new([0.5, 0.5, 1.0], [0.0, 0.0, -1.0]);
        let result = ray_triangle_intersect(&ray, &tri, 0, 0.0, f64::INFINITY);
        assert!(result.is_some());
        let hit = result.unwrap();
        assert!((hit.t - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_ray_triangle_miss_outside() {
        let tri = Triangle::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]);
        let ray = Ray::new([2.0, 2.0, 1.0], [0.0, 0.0, -1.0]);
        assert!(ray_triangle_intersect(&ray, &tri, 0, 0.0, f64::INFINITY).is_none());
    }

    #[test]
    fn test_ray_triangle_parallel() {
        let tri = Triangle::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]);
        let ray = Ray::new([0.0, 0.0, 1.0], [1.0, 0.0, 0.0]);
        assert!(ray_triangle_intersect(&ray, &tri, 0, 0.0, f64::INFINITY).is_none());
    }

    #[test]
    fn test_ray_triangle_t_range() {
        let tri = Triangle::new([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]);
        let ray = Ray::new([0.5, 0.5, 1.0], [0.0, 0.0, -1.0]);
        assert!(ray_triangle_intersect(&ray, &tri, 0, 0.0, 0.5).is_none());
    }

    #[test]
    fn test_build_bvh_single() {
        let tris = vec![simple_tri()];
        let nodes = build_bvh(&tris);
        assert!(!nodes.is_empty());
        assert!(nodes.last().unwrap().is_leaf());
    }

    #[test]
    fn test_build_bvh_empty() {
        let nodes = build_bvh(&[]);
        assert!(nodes.is_empty());
    }

    #[test]
    fn test_build_bvh_multiple() {
        let tris = vec![
            Triangle::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]),
            Triangle::new([2.0, 0.0, 0.0], [3.0, 0.0, 0.0], [2.0, 1.0, 0.0]),
            Triangle::new([4.0, 0.0, 0.0], [5.0, 0.0, 0.0], [4.0, 1.0, 0.0]),
            Triangle::new([6.0, 0.0, 0.0], [7.0, 0.0, 0.0], [6.0, 1.0, 0.0]),
        ];
        let nodes = build_bvh(&tris);
        assert!(!nodes.is_empty());
        let root = nodes.len() - 1;
        assert!(!nodes[root].is_leaf());
    }

    #[test]
    fn test_traverse_bvh_hit() {
        let tris = vec![
            Triangle::new([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]),
            Triangle::new([3.0, 0.0, 0.0], [5.0, 0.0, 0.0], [3.0, 2.0, 0.0]),
        ];
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        let ray = Ray::new([0.5, 0.5, 1.0], [0.0, 0.0, -1.0]);
        let hit = traverse_bvh(&ray, &nodes, &tris, root, 1e-4, f64::INFINITY);
        assert!(hit.is_some());
    }

    #[test]
    fn test_traverse_bvh_miss() {
        let tris = vec![Triangle::new(
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
        )];
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        let ray = Ray::new([5.0, 5.0, 1.0], [0.0, 0.0, -1.0]);
        let hit = traverse_bvh(&ray, &nodes, &tris, root, 1e-4, f64::INFINITY);
        assert!(hit.is_none());
    }

    #[test]
    fn test_traverse_bvh_empty_nodes() {
        let ray = Ray::new([0.0, 0.0, 0.0], [0.0, 0.0, 1.0]);
        let hit = traverse_bvh(&ray, &[], &[], 0, 0.0, f64::INFINITY);
        assert!(hit.is_none());
    }

    #[test]
    fn test_batch_ray_cast() {
        let tris = vec![Triangle::new(
            [0.0, 0.0, 0.0],
            [2.0, 0.0, 0.0],
            [0.0, 2.0, 0.0],
        )];
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        let rays = vec![
            Ray::new([0.5, 0.5, 1.0], [0.0, 0.0, -1.0]),
            Ray::new([5.0, 5.0, 1.0], [0.0, 0.0, -1.0]),
        ];
        let results = batch_ray_cast(&rays, &nodes, &tris, root);
        assert_eq!(results.len(), 2);
        assert!(results[0].is_some());
        assert!(results[1].is_none());
    }

    #[test]
    fn test_batch_ray_cast_empty_rays() {
        let tris = vec![simple_tri()];
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        let results = batch_ray_cast(&[], &nodes, &tris, root);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bvh_node_is_leaf() {
        let node = BvhNode {
            bounds: unit_box_aabb(),
            left: usize::MAX,
            right: usize::MAX,
            triangle_index: 0,
        };
        assert!(node.is_leaf());
    }

    #[test]
    fn test_bvh_node_not_leaf() {
        let node = BvhNode {
            bounds: unit_box_aabb(),
            left: 0,
            right: 1,
            triangle_index: usize::MAX,
        };
        assert!(!node.is_leaf());
    }

    #[test]
    fn test_hit_record_uv() {
        let tris = [Triangle::new(
            [0.0, 0.0, 0.0],
            [4.0, 0.0, 0.0],
            [0.0, 4.0, 0.0],
        )];
        let ray = Ray::new([1.0, 1.0, 1.0], [0.0, 0.0, -1.0]);
        let hit = ray_triangle_intersect(&ray, &tris[0], 0, 0.0, f64::INFINITY);
        assert!(hit.is_some());
        let h = hit.unwrap();
        assert!(h.uv[0] >= 0.0 && h.uv[0] <= 1.0);
        assert!(h.uv[1] >= 0.0 && h.uv[1] <= 1.0);
    }

    #[test]
    fn test_batch_returns_closest_hit() {
        let tris = vec![
            Triangle::new([0.0, 0.0, 2.0], [2.0, 0.0, 2.0], [0.0, 2.0, 2.0]),
            Triangle::new([0.0, 0.0, 5.0], [2.0, 0.0, 5.0], [0.0, 2.0, 5.0]),
        ];
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        let rays = vec![Ray::new([0.5, 0.5, 0.0], [0.0, 0.0, 1.0])];
        let results = batch_ray_cast(&rays, &nodes, &tris, root);
        // The ray passes through both triangles, so a miss is itself a failure.
        let hit = results[0].expect("ray should hit the nearer triangle");
        assert!((hit.t - 2.0).abs() < 1e-9);
        assert_eq!(hit.triangle_index, 0);
    }

    #[test]
    fn test_build_bvh_8_triangles() {
        let tris: Vec<Triangle> = (0..8)
            .map(|i| {
                let x = i as f64 * 2.0;
                Triangle::new([x, 0.0, 0.0], [x + 1.0, 0.0, 0.0], [x, 1.0, 0.0])
            })
            .collect();
        let nodes = build_bvh(&tris);
        assert_eq!(nodes.len(), 2 * tris.len() - 1);
    }

    #[test]
    fn test_traverse_finds_each_triangle_in_row() {
        let tris: Vec<Triangle> = (0..8)
            .map(|i| {
                let x = i as f64 * 2.0;
                Triangle::new([x, 0.0, 0.0], [x + 1.0, 0.0, 0.0], [x, 1.0, 0.0])
            })
            .collect();
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        for i in 0..tris.len() {
            let x = i as f64 * 2.0 + 0.25;
            let ray = Ray::new([x, 0.25, 1.0], [0.0, 0.0, -1.0]);
            let hit = traverse_bvh(&ray, &nodes, &tris, root, 1e-4, f64::INFINITY)
                .expect("ray above a triangle should hit it");
            assert_eq!(hit.triangle_index, i);
            assert!((hit.t - 1.0).abs() < 1e-9);
        }
    }

    #[test]
    fn test_ray_triangle_t_value() {
        let tri = Triangle::new([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]);
        let ray = Ray::new([0.5, 0.5, 3.0], [0.0, 0.0, -1.0]);
        let hit = ray_triangle_intersect(&ray, &tri, 0, 0.0, f64::INFINITY);
        assert!(hit.is_some());
        assert!((hit.unwrap().t - 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_ray_aabb_near_zero_dir_component() {
        let ray = Ray::new([0.5, 0.5, -1.0], [0.0, 0.0, 1.0]);
        let aabb = unit_box_aabb();
        let result = ray_aabb_intersect(&ray, &aabb, 0.0, f64::INFINITY);
        assert!(result.is_some());
    }

    #[test]
    fn test_ray_aabb_tiny_dir_component() {
        // Components below the 1e-15 threshold (either sign) are treated as parallel.
        let aabb = unit_box_aabb();
        for dx in [1e-20, -1e-20] {
            let ray = Ray::new([0.5, 0.5, -1.0], [dx, 0.0, 1.0]);
            let t = ray_aabb_intersect(&ray, &aabb, 0.0, f64::INFINITY)
                .expect("ray should enter the box");
            assert!((t - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_traverse_picks_correct_triangle() {
        let tris = vec![
            Triangle::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]),
            Triangle::new([10.0, 0.0, 0.0], [11.0, 0.0, 0.0], [10.0, 1.0, 0.0]),
        ];
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        let ray = Ray::new([10.2, 0.2, 1.0], [0.0, 0.0, -1.0]);
        let hit = traverse_bvh(&ray, &nodes, &tris, root, 1e-4, f64::INFINITY);
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().triangle_index, 1);
    }
}