// Copyright 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0
//
// Source path: oxiphysics_gpu/gpu_ray_tracing.rs

//! GPU ray tracing (CPU mock implementation).
//!
//! Provides BVH construction, ray-AABB intersection (slab method),
//! ray-triangle intersection (Möller-Trumbore), BVH traversal, and batch
//! ray casting.

/// A half-line starting at `origin` and extending along `direction`.
///
/// Points on the ray are `origin + t * direction`; the intersection routines
/// in this module report hits by that parameter `t`.
#[derive(Debug, Clone, Copy)]
pub struct Ray {
    /// Starting point of the ray, in world space.
    pub origin: [f64; 3],
    /// Direction of travel, stored exactly as given. Normalise it if `t`
    /// values are to be read as Euclidean distances.
    pub direction: [f64; 3],
}

impl Ray {
    /// Builds a ray from `origin` and `direction` without normalising either.
    pub fn new(origin: [f64; 3], direction: [f64; 3]) -> Self {
        Ray { origin, direction }
    }

    /// Point reached at parameter `t`, i.e. `origin + t * direction`.
    ///
    /// Negative `t` yields points behind the origin.
    pub fn at(&self, t: f64) -> [f64; 3] {
        let Ray { origin, direction } = *self;
        std::array::from_fn(|axis| origin[axis] + t * direction[axis])
    }
}

/// Axis-aligned bounding box stored as two opposite corners.
#[derive(Debug, Clone, Copy)]
pub struct Aabb {
    /// Lower corner. Stored as given; no `min <= max` check is performed.
    pub min: [f64; 3],
    /// Upper corner.
    pub max: [f64; 3],
}

impl Aabb {
    /// Builds a box from its two corners, stored as given.
    pub fn new(min: [f64; 3], max: [f64; 3]) -> Self {
        Aabb { min, max }
    }

    /// Smallest box that contains both `self` and `other`.
    pub fn union(&self, other: &Aabb) -> Aabb {
        let lower: [f64; 3] = std::array::from_fn(|k| self.min[k].min(other.min[k]));
        let upper: [f64; 3] = std::array::from_fn(|k| self.max[k].max(other.max[k]));
        Aabb {
            min: lower,
            max: upper,
        }
    }

    /// Geometric centre: the per-axis midpoint of `min` and `max`.
    pub fn centroid(&self) -> [f64; 3] {
        std::array::from_fn(|k| 0.5 * (self.min[k] + self.max[k]))
    }
}

75/// A triangle defined by three vertices.
76#[derive(Debug, Clone, Copy)]
77pub struct Triangle {
78    /// First vertex.
79    pub v0: [f64; 3],
80    /// Second vertex.
81    pub v1: [f64; 3],
82    /// Third vertex.
83    pub v2: [f64; 3],
84}
85
86impl Triangle {
87    /// Create a new triangle from three vertices.
88    pub fn new(v0: [f64; 3], v1: [f64; 3], v2: [f64; 3]) -> Self {
89        Self { v0, v1, v2 }
90    }
91
92    /// Return the AABB that encloses this triangle.
93    pub fn aabb(&self) -> Aabb {
94        Aabb {
95            min: [
96                self.v0[0].min(self.v1[0]).min(self.v2[0]),
97                self.v0[1].min(self.v1[1]).min(self.v2[1]),
98                self.v0[2].min(self.v1[2]).min(self.v2[2]),
99            ],
100            max: [
101                self.v0[0].max(self.v1[0]).max(self.v2[0]),
102                self.v0[1].max(self.v1[1]).max(self.v2[1]),
103                self.v0[2].max(self.v1[2]).max(self.v2[2]),
104            ],
105        }
106    }
107}
108
109/// A node in the bounding volume hierarchy.
110#[derive(Debug, Clone)]
111pub struct BvhNode {
112    /// Bounding box of this node.
113    pub bounds: Aabb,
114    /// Index of the left child node, or `usize::MAX` if this is a leaf.
115    pub left: usize,
116    /// Index of the right child node, or `usize::MAX` if this is a leaf.
117    pub right: usize,
118    /// Index into the triangle list if this is a leaf (`usize::MAX` otherwise).
119    pub triangle_index: usize,
120}
121
122impl BvhNode {
123    /// Returns `true` if this node is a leaf.
124    pub fn is_leaf(&self) -> bool {
125        self.triangle_index != usize::MAX
126    }
127}
128
/// Details of a successful ray-triangle intersection.
#[derive(Debug, Clone, Copy)]
pub struct HitRecord {
    /// Ray parameter of the hit point (`ray.origin + t * ray.direction`).
    pub t: f64,
    /// Index of the triangle that was hit.
    pub triangle_index: usize,
    /// Barycentric coordinates `[u, v]` of the hit point on the triangle.
    pub uv: [f64; 2],
}

// ── Internal helpers ─────────────────────────────────────────────────────────

/// Dot product of two 3-vectors.
fn dot3(a: [f64; 3], b: [f64; 3]) -> f64 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}

/// Cross product `a × b` of two 3-vectors (right-handed).
fn cross3(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}

/// Component-wise difference `a - b` of two 3-vectors.
fn sub3(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    std::array::from_fn(|i| a[i] - b[i])
}

// ── Public API ────────────────────────────────────────────────────────────────

160/// Test whether a ray intersects an AABB using the slab method.
161///
162/// Returns `Some(t_near)` when there is an intersection with `t_near >= t_min`,
163/// or `None` when the ray misses.
164pub fn ray_aabb_intersect(ray: &Ray, aabb: &Aabb, t_min: f64, t_max: f64) -> Option<f64> {
165    let mut t_lo = t_min;
166    let mut t_hi = t_max;
167
168    for axis in 0..3 {
169        let inv_d = if ray.direction[axis].abs() > 1e-15 {
170            1.0 / ray.direction[axis]
171        } else {
172            f64::INFINITY
173        };
174        let mut t0 = (aabb.min[axis] - ray.origin[axis]) * inv_d;
175        let mut t1 = (aabb.max[axis] - ray.origin[axis]) * inv_d;
176        if inv_d < 0.0 {
177            std::mem::swap(&mut t0, &mut t1);
178        }
179        t_lo = t_lo.max(t0);
180        t_hi = t_hi.min(t1);
181        if t_hi < t_lo {
182            return None;
183        }
184    }
185    Some(t_lo)
186}
187
188/// Test whether a ray intersects a triangle using the Möller-Trumbore algorithm.
189///
190/// Returns `Some(HitRecord)` when the ray hits the front face within
191/// `[t_min, t_max]`, or `None` on a miss.
192pub fn ray_triangle_intersect(
193    ray: &Ray,
194    tri: &Triangle,
195    tri_index: usize,
196    t_min: f64,
197    t_max: f64,
198) -> Option<HitRecord> {
199    const EPSILON: f64 = 1e-10;
200
201    let edge1 = sub3(tri.v1, tri.v0);
202    let edge2 = sub3(tri.v2, tri.v0);
203    let h = cross3(ray.direction, edge2);
204    let det = dot3(edge1, h);
205
206    if det.abs() < EPSILON {
207        return None; // Ray is parallel to triangle
208    }
209
210    let inv_det = 1.0 / det;
211    let s = sub3(ray.origin, tri.v0);
212    let u = inv_det * dot3(s, h);
213
214    if !(0.0..=1.0).contains(&u) {
215        return None;
216    }
217
218    let q = cross3(s, edge1);
219    let v = inv_det * dot3(ray.direction, q);
220
221    if v < 0.0 || u + v > 1.0 {
222        return None;
223    }
224
225    let t = inv_det * dot3(edge2, q);
226    if t < t_min || t > t_max {
227        return None;
228    }
229
230    Some(HitRecord {
231        t,
232        triangle_index: tri_index,
233        uv: [u, v],
234    })
235}
236
237/// Build a BVH from a list of triangles using bottom-up SAH-lite construction.
238///
239/// Returns a flat node list; node 0 is the root.
240pub fn build_bvh(triangles: &[Triangle]) -> Vec<BvhNode> {
241    if triangles.is_empty() {
242        return Vec::new();
243    }
244
245    let mut nodes: Vec<BvhNode> = Vec::new();
246
247    // Create leaf nodes
248    let mut leaf_indices: Vec<usize> = (0..triangles.len()).collect();
249
250    fn build_recursive(
251        tris: &[Triangle],
252        indices: &mut [usize],
253        nodes: &mut Vec<BvhNode>,
254    ) -> usize {
255        if indices.len() == 1 {
256            let tri_idx = indices[0];
257            let bounds = tris[tri_idx].aabb();
258            let node = BvhNode {
259                bounds,
260                left: usize::MAX,
261                right: usize::MAX,
262                triangle_index: tri_idx,
263            };
264            let idx = nodes.len();
265            nodes.push(node);
266            return idx;
267        }
268
269        // Compute combined AABB
270        let mut combined = tris[indices[0]].aabb();
271        for &i in indices.iter().skip(1) {
272            combined = combined.union(&tris[i].aabb());
273        }
274
275        // Choose split axis as the longest dimension
276        let extent = [
277            combined.max[0] - combined.min[0],
278            combined.max[1] - combined.min[1],
279            combined.max[2] - combined.min[2],
280        ];
281        let axis = if extent[0] >= extent[1] && extent[0] >= extent[2] {
282            0
283        } else if extent[1] >= extent[2] {
284            1
285        } else {
286            2
287        };
288
289        // Sort by centroid along chosen axis
290        indices.sort_by(|&a, &b| {
291            let ca = tris[a].aabb().centroid()[axis];
292            let cb = tris[b].aabb().centroid()[axis];
293            ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
294        });
295
296        let mid = indices.len() / 2;
297        let (left_ids, right_ids) = indices.split_at_mut(mid);
298
299        let left_child = build_recursive(tris, left_ids, nodes);
300        let right_child = build_recursive(tris, right_ids, nodes);
301
302        let left_bounds = nodes[left_child].bounds;
303        let right_bounds = nodes[right_child].bounds;
304        let node = BvhNode {
305            bounds: left_bounds.union(&right_bounds),
306            left: left_child,
307            right: right_child,
308            triangle_index: usize::MAX,
309        };
310        let idx = nodes.len();
311        nodes.push(node);
312        idx
313    }
314
315    build_recursive(triangles, &mut leaf_indices, &mut nodes);
316    nodes
317}
318
319/// Traverse the BVH to find the closest ray-triangle intersection.
320///
321/// `root` is the index of the root node (last element returned by `build_bvh`).
322/// Returns `Some(HitRecord)` for the closest hit, or `None`.
323pub fn traverse_bvh(
324    ray: &Ray,
325    nodes: &[BvhNode],
326    triangles: &[Triangle],
327    root: usize,
328    t_min: f64,
329    t_max: f64,
330) -> Option<HitRecord> {
331    if nodes.is_empty() {
332        return None;
333    }
334
335    let mut best: Option<HitRecord> = None;
336    let mut t_closest = t_max;
337
338    // Stack-based traversal
339    let mut stack = Vec::with_capacity(64);
340    stack.push(root);
341
342    while let Some(node_idx) = stack.pop() {
343        if node_idx >= nodes.len() {
344            continue;
345        }
346        let node = &nodes[node_idx];
347
348        // AABB test
349        if ray_aabb_intersect(ray, &node.bounds, t_min, t_closest).is_none() {
350            continue;
351        }
352
353        if node.is_leaf() {
354            if node.triangle_index < triangles.len()
355                && let Some(hit) = ray_triangle_intersect(
356                    ray,
357                    &triangles[node.triangle_index],
358                    node.triangle_index,
359                    t_min,
360                    t_closest,
361                )
362            {
363                t_closest = hit.t;
364                best = Some(hit);
365            }
366        } else {
367            if node.left != usize::MAX {
368                stack.push(node.left);
369            }
370            if node.right != usize::MAX {
371                stack.push(node.right);
372            }
373        }
374    }
375
376    best
377}
378
379/// Cast multiple rays against a BVH and return the closest hit for each.
380///
381/// This is a CPU mock of a GPU batch ray cast. Returns a `Vec` of optional
382/// `HitRecord`s, one per input ray.
383pub fn batch_ray_cast(
384    rays: &[Ray],
385    nodes: &[BvhNode],
386    triangles: &[Triangle],
387    root: usize,
388) -> Vec<Option<HitRecord>> {
389    rays.iter()
390        .map(|ray| traverse_bvh(ray, nodes, triangles, root, 1e-4, f64::INFINITY))
391        .collect()
392}
393
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    fn unit_box_aabb() -> Aabb {
        Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])
    }

    fn simple_tri() -> Triangle {
        Triangle::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0])
    }

    // Ray tests
    #[test]
    fn test_ray_at() {
        let ray = Ray::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0]);
        let p = ray.at(3.0);
        assert!((p[0] - 3.0).abs() < 1e-12);
        assert!(p[1].abs() < 1e-12);
        assert!(p[2].abs() < 1e-12);
    }

    #[test]
    fn test_ray_at_negative_t() {
        let ray = Ray::new([1.0, 0.0, 0.0], [1.0, 0.0, 0.0]);
        let p = ray.at(-1.0);
        assert!(p[0].abs() < 1e-12);
    }

    // AABB union / centroid
    #[test]
    fn test_aabb_union() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        let b = Aabb::new([0.5, 0.5, 0.5], [2.0, 2.0, 2.0]);
        let u = a.union(&b);
        assert!((u.max[0] - 2.0).abs() < 1e-12);
        assert!(u.min[0].abs() < 1e-12);
    }

    #[test]
    fn test_aabb_centroid() {
        let aabb = Aabb::new([0.0, 0.0, 0.0], [2.0, 4.0, 6.0]);
        let c = aabb.centroid();
        assert!((c[0] - 1.0).abs() < 1e-12);
        assert!((c[1] - 2.0).abs() < 1e-12);
        assert!((c[2] - 3.0).abs() < 1e-12);
    }

    // Triangle AABB
    #[test]
    fn test_triangle_aabb() {
        let tri = simple_tri();
        let aabb = tri.aabb();
        assert!((aabb.max[0] - 1.0).abs() < 1e-12);
        assert!((aabb.max[1] - 1.0).abs() < 1e-12);
        assert!(aabb.max[2].abs() < 1e-12);
    }

    // Ray-AABB hit
    #[test]
    fn test_ray_aabb_hit() {
        let ray = Ray::new([0.5, 0.5, -1.0], [0.0, 0.0, 1.0]);
        let aabb = unit_box_aabb();
        let t = ray_aabb_intersect(&ray, &aabb, 0.0, f64::INFINITY)
            .expect("ray should hit the box");
        assert!((t - 1.0).abs() < 1e-10);
    }

    // Ray-AABB miss
    #[test]
    fn test_ray_aabb_miss() {
        let ray = Ray::new([2.0, 2.0, -1.0], [0.0, 0.0, 1.0]);
        let aabb = unit_box_aabb();
        assert!(ray_aabb_intersect(&ray, &aabb, 0.0, f64::INFINITY).is_none());
    }

    // Ray-AABB from inside
    #[test]
    fn test_ray_aabb_inside() {
        let ray = Ray::new([0.5, 0.5, 0.5], [0.0, 0.0, 1.0]);
        let aabb = unit_box_aabb();
        let result = ray_aabb_intersect(&ray, &aabb, 0.0, f64::INFINITY);
        assert!(result.is_some());
    }

    // Ray-AABB behind ray
    #[test]
    fn test_ray_aabb_behind() {
        let ray = Ray::new([0.5, 0.5, 5.0], [0.0, 0.0, 1.0]);
        let aabb = unit_box_aabb();
        assert!(ray_aabb_intersect(&ray, &aabb, 0.0, f64::INFINITY).is_none());
    }

    // Ray-triangle hit
    #[test]
    fn test_ray_triangle_hit() {
        // Ray straight down the z-axis hitting a triangle in the xy-plane
        let tri = Triangle::new([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]);
        let ray = Ray::new([0.5, 0.5, 1.0], [0.0, 0.0, -1.0]);
        let hit = ray_triangle_intersect(&ray, &tri, 0, 0.0, f64::INFINITY)
            .expect("ray should hit the triangle");
        assert!((hit.t - 1.0).abs() < 1e-9);
    }

    // Ray-triangle hit from the other side: the test is double-sided
    #[test]
    fn test_ray_triangle_back_face_hit() {
        let tri = Triangle::new([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]);
        let ray = Ray::new([0.5, 0.5, -1.0], [0.0, 0.0, 1.0]);
        let hit = ray_triangle_intersect(&ray, &tri, 0, 0.0, f64::INFINITY)
            .expect("back-facing triangle should still be hit");
        assert!((hit.t - 1.0).abs() < 1e-9);
    }

    // Ray-triangle miss (outside)
    #[test]
    fn test_ray_triangle_miss_outside() {
        let tri = Triangle::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]);
        let ray = Ray::new([2.0, 2.0, 1.0], [0.0, 0.0, -1.0]);
        assert!(ray_triangle_intersect(&ray, &tri, 0, 0.0, f64::INFINITY).is_none());
    }

    // Ray-triangle parallel
    #[test]
    fn test_ray_triangle_parallel() {
        let tri = Triangle::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]);
        let ray = Ray::new([0.0, 0.0, 1.0], [1.0, 0.0, 0.0]); // parallel to z=0 plane
        assert!(ray_triangle_intersect(&ray, &tri, 0, 0.0, f64::INFINITY).is_none());
    }

    // Ray-triangle t outside range
    #[test]
    fn test_ray_triangle_t_range() {
        let tri = Triangle::new([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]);
        let ray = Ray::new([0.5, 0.5, 1.0], [0.0, 0.0, -1.0]);
        // t_max = 0.5, hit is at t=1.0 → miss
        assert!(ray_triangle_intersect(&ray, &tri, 0, 0.0, 0.5).is_none());
    }

    // BVH build with single triangle
    #[test]
    fn test_build_bvh_single() {
        let tris = vec![simple_tri()];
        let nodes = build_bvh(&tris);
        assert!(!nodes.is_empty());
        assert!(nodes.last().unwrap().is_leaf());
    }

    // BVH build empty
    #[test]
    fn test_build_bvh_empty() {
        let nodes = build_bvh(&[]);
        assert!(nodes.is_empty());
    }

    // BVH build multiple triangles
    #[test]
    fn test_build_bvh_multiple() {
        let tris = vec![
            Triangle::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]),
            Triangle::new([2.0, 0.0, 0.0], [3.0, 0.0, 0.0], [2.0, 1.0, 0.0]),
            Triangle::new([4.0, 0.0, 0.0], [5.0, 0.0, 0.0], [4.0, 1.0, 0.0]),
            Triangle::new([6.0, 0.0, 0.0], [7.0, 0.0, 0.0], [6.0, 1.0, 0.0]),
        ];
        let nodes = build_bvh(&tris);
        assert_eq!(nodes.len(), 2 * tris.len() - 1);
        // Nodes are emitted in post-order: the root is the last node and
        // node 0 is the first leaf.
        let root = nodes.len() - 1;
        assert!(!nodes[root].is_leaf());
        assert!(nodes[0].is_leaf());
    }

    // BVH traversal hit
    #[test]
    fn test_traverse_bvh_hit() {
        let tris = vec![
            Triangle::new([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]),
            Triangle::new([3.0, 0.0, 0.0], [5.0, 0.0, 0.0], [3.0, 2.0, 0.0]),
        ];
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        let ray = Ray::new([0.5, 0.5, 1.0], [0.0, 0.0, -1.0]);
        let hit = traverse_bvh(&ray, &nodes, &tris, root, 1e-4, f64::INFINITY);
        assert!(hit.is_some());
    }

    // BVH traversal miss
    #[test]
    fn test_traverse_bvh_miss() {
        let tris = vec![Triangle::new(
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
        )];
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        let ray = Ray::new([5.0, 5.0, 1.0], [0.0, 0.0, -1.0]);
        let hit = traverse_bvh(&ray, &nodes, &tris, root, 1e-4, f64::INFINITY);
        assert!(hit.is_none());
    }

    // BVH traversal empty
    #[test]
    fn test_traverse_bvh_empty_nodes() {
        let ray = Ray::new([0.0, 0.0, 0.0], [0.0, 0.0, 1.0]);
        let hit = traverse_bvh(&ray, &[], &[], 0, 0.0, f64::INFINITY);
        assert!(hit.is_none());
    }

    // Batch ray cast
    #[test]
    fn test_batch_ray_cast() {
        let tris = vec![Triangle::new(
            [0.0, 0.0, 0.0],
            [2.0, 0.0, 0.0],
            [0.0, 2.0, 0.0],
        )];
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        let rays = vec![
            Ray::new([0.5, 0.5, 1.0], [0.0, 0.0, -1.0]),
            Ray::new([5.0, 5.0, 1.0], [0.0, 0.0, -1.0]),
        ];
        let results = batch_ray_cast(&rays, &nodes, &tris, root);
        assert_eq!(results.len(), 2);
        assert!(results[0].is_some());
        assert!(results[1].is_none());
    }

    // Batch ray cast empty
    #[test]
    fn test_batch_ray_cast_empty_rays() {
        let tris = vec![simple_tri()];
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        let results = batch_ray_cast(&[], &nodes, &tris, root);
        assert!(results.is_empty());
    }

    // BVH leaf detection
    #[test]
    fn test_bvh_node_is_leaf() {
        let node = BvhNode {
            bounds: unit_box_aabb(),
            left: usize::MAX,
            right: usize::MAX,
            triangle_index: 0,
        };
        assert!(node.is_leaf());
    }

    #[test]
    fn test_bvh_node_not_leaf() {
        let node = BvhNode {
            bounds: unit_box_aabb(),
            left: 0,
            right: 1,
            triangle_index: usize::MAX,
        };
        assert!(!node.is_leaf());
    }

    // Hit record fields
    #[test]
    fn test_hit_record_uv() {
        let tris = [Triangle::new(
            [0.0, 0.0, 0.0],
            [4.0, 0.0, 0.0],
            [0.0, 4.0, 0.0],
        )];
        let ray = Ray::new([1.0, 1.0, 1.0], [0.0, 0.0, -1.0]);
        let h = ray_triangle_intersect(&ray, &tris[0], 0, 0.0, f64::INFINITY)
            .expect("ray should hit the triangle");
        assert!(h.uv[0] >= 0.0 && h.uv[0] <= 1.0);
        assert!(h.uv[1] >= 0.0 && h.uv[1] <= 1.0);
    }

    // Closest hit in batch
    #[test]
    fn test_batch_returns_closest_hit() {
        let tris = vec![
            Triangle::new([0.0, 0.0, 2.0], [2.0, 0.0, 2.0], [0.0, 2.0, 2.0]),
            Triangle::new([0.0, 0.0, 5.0], [2.0, 0.0, 5.0], [0.0, 2.0, 5.0]),
        ];
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        let rays = vec![Ray::new([0.5, 0.5, 0.0], [0.0, 0.0, 1.0])];
        let results = batch_ray_cast(&rays, &nodes, &tris, root);
        // Must hit, and must report the nearer triangle at z=2.
        let hit = results[0].expect("ray should hit one of the triangles");
        assert!((hit.t - 2.0).abs() < 1e-9);
        assert_eq!(hit.triangle_index, 0);
    }

    // BVH with 8 triangles — deeper tree
    #[test]
    fn test_build_bvh_8_triangles() {
        let tris: Vec<Triangle> = (0..8)
            .map(|i| {
                let x = i as f64 * 2.0;
                Triangle::new([x, 0.0, 0.0], [x + 1.0, 0.0, 0.0], [x, 1.0, 0.0])
            })
            .collect();
        let nodes = build_bvh(&tris);
        // There should be 2*N-1 nodes for N leaves
        assert_eq!(nodes.len(), 2 * tris.len() - 1);
    }

    // Intersection distance ordering
    #[test]
    fn test_ray_triangle_t_value() {
        let tri = Triangle::new([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]);
        let ray = Ray::new([0.5, 0.5, 3.0], [0.0, 0.0, -1.0]);
        let hit = ray_triangle_intersect(&ray, &tri, 0, 0.0, f64::INFINITY)
            .expect("ray should hit the triangle");
        assert!((hit.t - 3.0).abs() < 1e-9);
    }

    // AABB ray direction components near zero
    #[test]
    fn test_ray_aabb_near_zero_dir_component() {
        // x and y direction components are tiny but non-zero (below the
        // 1e-15 parallel cut-off), with opposite signs.
        let ray = Ray::new([0.5, 0.5, -1.0], [1e-20, -1e-20, 1.0]);
        let aabb = unit_box_aabb();
        let t = ray_aabb_intersect(&ray, &aabb, 0.0, f64::INFINITY)
            .expect("ray should hit the box");
        assert!((t - 1.0).abs() < 1e-10);
    }

    // BVH traversal picks the right triangle
    #[test]
    fn test_traverse_picks_correct_triangle() {
        let tris = vec![
            Triangle::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]),
            Triangle::new([10.0, 0.0, 0.0], [11.0, 0.0, 0.0], [10.0, 1.0, 0.0]),
        ];
        let nodes = build_bvh(&tris);
        let root = nodes.len() - 1;
        let ray = Ray::new([10.2, 0.2, 1.0], [0.0, 0.0, -1.0]);
        let hit = traverse_bvh(&ray, &nodes, &tris, root, 1e-4, f64::INFINITY)
            .expect("ray should hit the second triangle");
        assert_eq!(hit.triangle_index, 1);
    }
}
730}