Skip to main content

oxiphysics_gpu/
ray_tracing_gpu.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! GPU ray tracing with f32 precision (CPU mock).
5//!
6//! Provides BVH nodes, triangle/sphere/AABB intersection tests,
7//! full scene traversal, ambient occlusion sampling, and a simple
8//! orthographic renderer — all running on the CPU as a mock GPU backend.
9
10// ── Core types ────────────────────────────────────────────────────────────────
11
/// A half-line in 3-D space stored in single precision.
#[derive(Clone, Copy, Debug)]
pub struct Ray {
    /// World-space point the ray starts from.
    pub origin: [f32; 3],
    /// Direction of travel. It is stored exactly as supplied, so `t` values
    /// only equal distances when the caller passes a unit vector.
    pub direction: [f32; 3],
}

impl Ray {
    /// Build a ray from its `origin` and `direction`; neither is modified.
    pub fn new(origin: [f32; 3], direction: [f32; 3]) -> Self {
        Ray { origin, direction }
    }

    /// Point reached at parameter `t`, i.e. `origin + t * direction`
    /// evaluated component by component.
    pub fn at(&self, t: f32) -> [f32; 3] {
        std::array::from_fn(|axis| self.origin[axis] + t * self.direction[axis])
    }
}
36
/// Result of a successful ray–surface intersection query.
#[derive(Clone, Copy, Debug)]
pub struct HitRecord {
    /// Parameter `t` along the ray at which the surface was hit.
    pub t: f32,
    /// Hit position in world space.
    pub point: [f32; 3],
    /// Surface normal reported for the hit (intended to face outward).
    pub normal: [f32; 3],
    /// Material-table index of the surface that was hit.
    pub material_id: u32,
}
49
/// One node of a bounding-volume hierarchy (BVH).
///
/// This is a plain data record; nothing in this part of the file builds or
/// traverses a hierarchy of these nodes.
#[derive(Clone, Copy, Debug)]
pub struct BvhNode {
    /// Lower corner of the node's axis-aligned bounding box.
    pub aabb_min: [f32; 3],
    /// Upper corner of the node's axis-aligned bounding box.
    pub aabb_max: [f32; 3],
    /// Left child index; meaningful only when `is_leaf` is `false`.
    pub left: u32,
    /// Right child index; meaningful only when `is_leaf` is `false`.
    pub right: u32,
    /// `true` when the node holds a triangle directly instead of children.
    pub is_leaf: bool,
    /// Index of the triangle held by a leaf node.
    pub tri_idx: u32,
}
66
/// Triangle record laid out for GPU-style consumption.
#[derive(Clone, Copy, Debug)]
pub struct GpuTriangle {
    /// Vertex 0.
    pub v0: [f32; 3],
    /// Vertex 1.
    pub v1: [f32; 3],
    /// Vertex 2.
    pub v2: [f32; 3],
    /// Face normal supplied by the caller; it is stored, never recomputed
    /// from the vertices.
    pub normal: [f32; 3],
    /// Material-table index for this triangle.
    pub material_id: u32,
}
81
82// ── Vector helpers (f32) ──────────────────────────────────────────────────────
83
/// Dot product of two 3-vectors.
#[inline]
fn dot3f(a: [f32; 3], b: [f32; 3]) -> f32 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}
88
/// Cross product `a × b` of two 3-vectors.
#[inline]
fn cross3f(a: [f32; 3], b: [f32; 3]) -> [f32; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
97
/// Component-wise difference `a - b` of two 3-vectors.
#[inline]
fn sub3f(a: [f32; 3], b: [f32; 3]) -> [f32; 3] {
    std::array::from_fn(|axis| a[axis] - b[axis])
}
102
103#[inline]
104fn normalize3f(v: [f32; 3]) -> [f32; 3] {
105    let len = dot3f(v, v).sqrt();
106    if len < 1e-10 {
107        return [0.0; 3];
108    }
109    [v[0] / len, v[1] / len, v[2] / len]
110}
111
112// ── Intersection tests ────────────────────────────────────────────────────────
113
114/// Test ray–sphere intersection.
115///
116/// Returns the smallest positive `t`, or `None` if there is no hit.
117pub fn ray_sphere_intersect(ray: &Ray, center: [f32; 3], radius: f32) -> Option<f32> {
118    let oc = sub3f(ray.origin, center);
119    let a = dot3f(ray.direction, ray.direction);
120    let half_b = dot3f(oc, ray.direction);
121    let c = dot3f(oc, oc) - radius * radius;
122    let discriminant = half_b * half_b - a * c;
123    if discriminant < 0.0 {
124        return None;
125    }
126    let sqrt_d = discriminant.sqrt();
127    let t1 = (-half_b - sqrt_d) / a;
128    if t1 > 1e-4 {
129        return Some(t1);
130    }
131    let t2 = (-half_b + sqrt_d) / a;
132    if t2 > 1e-4 { Some(t2) } else { None }
133}
134
135/// Test ray–triangle intersection using the Möller–Trumbore algorithm.
136///
137/// Returns `Some(t)` on a valid hit, or `None`.
138pub fn ray_triangle_intersect(ray: &Ray, tri: &GpuTriangle) -> Option<f32> {
139    let edge1 = sub3f(tri.v1, tri.v0);
140    let edge2 = sub3f(tri.v2, tri.v0);
141    let h = cross3f(ray.direction, edge2);
142    let a = dot3f(edge1, h);
143    if a.abs() < 1e-8 {
144        return None; // Ray is parallel to triangle
145    }
146    let f = 1.0 / a;
147    let s = sub3f(ray.origin, tri.v0);
148    let u = f * dot3f(s, h);
149    if !(0.0..=1.0).contains(&u) {
150        return None;
151    }
152    let q = cross3f(s, edge1);
153    let v = f * dot3f(ray.direction, q);
154    if v < 0.0 || u + v > 1.0 {
155        return None;
156    }
157    let t = f * dot3f(edge2, q);
158    if t > 1e-4 { Some(t) } else { None }
159}
160
161/// Test ray–AABB intersection using the slab method.
162///
163/// Returns `true` if the ray hits the box `[aabb_min, aabb_max]`.
164pub fn ray_aabb_intersect(ray: &Ray, aabb_min: [f32; 3], aabb_max: [f32; 3]) -> bool {
165    let mut t_min = 0.0_f32;
166    let mut t_max = f32::MAX;
167    for i in 0..3 {
168        let inv_d = 1.0 / ray.direction[i];
169        let t0 = (aabb_min[i] - ray.origin[i]) * inv_d;
170        let t1 = (aabb_max[i] - ray.origin[i]) * inv_d;
171        let (t_near, t_far) = if inv_d >= 0.0 { (t0, t1) } else { (t1, t0) };
172        t_min = t_min.max(t_near);
173        t_max = t_max.min(t_far);
174        if t_max < t_min {
175            return false;
176        }
177    }
178    t_max >= 0.0
179}
180
181// ── Scene traversal ───────────────────────────────────────────────────────────
182
183/// Find the closest triangle hit by `ray` in `triangles`.
184///
185/// Returns `Some(HitRecord)` if any triangle is hit, `None` otherwise.
186pub fn trace_ray(ray: &Ray, triangles: &[GpuTriangle]) -> Option<HitRecord> {
187    let mut best_t = f32::MAX;
188    let mut best_hit: Option<HitRecord> = None;
189
190    for tri in triangles {
191        if let Some(t) = ray_triangle_intersect(ray, tri)
192            && t < best_t
193        {
194            best_t = t;
195            let point = ray.at(t);
196            best_hit = Some(HitRecord {
197                t,
198                point,
199                normal: tri.normal,
200                material_id: tri.material_id,
201            });
202        }
203    }
204    best_hit
205}
206
207// ── Rendering ─────────────────────────────────────────────────────────────────
208
209/// Render a scene with orthographic projection, returning one RGB pixel per
210/// sample.
211///
212/// * `camera_pos` — origin of all primary rays (ortho camera).
213/// * `nx`, `ny` — image dimensions in pixels.
214/// * Returns a flat `Vec` of `[r, g, b]` triples (row-major).
215///
216/// Each pixel fires one ray in the `-Z` direction; hit pixels get the
217/// absolute-value of the triangle normal as colour, miss pixels return black.
218pub fn gpu_render_pixels(
219    triangles: &[GpuTriangle],
220    camera_pos: [f32; 3],
221    nx: usize,
222    ny: usize,
223) -> Vec<[f32; 3]> {
224    let mut pixels = Vec::with_capacity(nx * ny);
225    for row in 0..ny {
226        for col in 0..nx {
227            let u = (col as f32 + 0.5) / nx as f32 * 2.0 - 1.0;
228            let v = (row as f32 + 0.5) / ny as f32 * 2.0 - 1.0;
229            let ray = Ray::new(
230                [camera_pos[0] + u, camera_pos[1] + v, camera_pos[2]],
231                [0.0, 0.0, -1.0],
232            );
233            let colour = match trace_ray(&ray, triangles) {
234                Some(hit) => [
235                    hit.normal[0].abs(),
236                    hit.normal[1].abs(),
237                    hit.normal[2].abs(),
238                ],
239                None => [0.0, 0.0, 0.0],
240            };
241            pixels.push(colour);
242        }
243    }
244    pixels
245}
246
247// ── Ambient occlusion ─────────────────────────────────────────────────────────
248
249/// Estimate ambient occlusion at a hit point by sampling hemisphere directions.
250///
251/// `n_samples` rays are fired in directions cosine-weighted around `hit.normal`.
252/// Returns the fraction of rays that are *not* occluded (1.0 = fully lit).
253pub fn ambient_occlusion_sample(
254    hit: &HitRecord,
255    triangles: &[GpuTriangle],
256    n_samples: usize,
257) -> f32 {
258    use rand::RngExt;
259    if n_samples == 0 {
260        return 1.0;
261    }
262    let mut rng = rand::rng();
263    let mut unoccluded = 0usize;
264
265    let n = normalize3f(hit.normal);
266    // Build a local tangent frame
267    let up = if n[0].abs() < 0.9 {
268        [1.0_f32, 0.0, 0.0]
269    } else {
270        [0.0_f32, 1.0, 0.0]
271    };
272    let tangent = normalize3f(cross3f(n, up));
273    let bitangent = cross3f(n, tangent);
274
275    for _ in 0..n_samples {
276        // Sample hemisphere using cosine-weighting
277        let r1: f32 = rng.random_range(0.0_f32..1.0_f32);
278        let r2: f32 = rng.random_range(0.0_f32..1.0_f32);
279        let phi = 2.0 * std::f32::consts::PI * r1;
280        let cos_theta = r2.sqrt();
281        let sin_theta = (1.0_f32 - cos_theta * cos_theta).sqrt();
282        let lx = sin_theta * phi.cos();
283        let ly = sin_theta * phi.sin();
284        let lz = cos_theta;
285        let dir = [
286            lx * tangent[0] + ly * bitangent[0] + lz * n[0],
287            lx * tangent[1] + ly * bitangent[1] + lz * n[1],
288            lx * tangent[2] + ly * bitangent[2] + lz * n[2],
289        ];
290        let ao_ray = Ray::new(hit.point, dir);
291        if trace_ray(&ao_ray, triangles).is_none() {
292            unoccluded += 1;
293        }
294    }
295    unoccluded as f32 / n_samples as f32
296}
297
298// ── Tests ─────────────────────────────────────────────────────────────────────
299
#[cfg(test)]
mod tests {
    use super::*;

    // Corners of the cube used by every AABB test.
    const BOX_MIN: [f32; 3] = [-1.0, -1.0, -1.0];
    const BOX_MAX: [f32; 3] = [1.0, 1.0, 1.0];
    // Direction shared by most rays in this module.
    const NEG_Z: [f32; 3] = [0.0, 0.0, -1.0];
    const BLACK: [f32; 3] = [0.0, 0.0, 0.0];

    /// `true` when `actual` is within `tol` (exclusive) of `expected`.
    fn near(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    /// Right triangle with legs of length 1 in the plane `z`, facing +Z.
    fn right_triangle_at(z: f32, material_id: u32) -> GpuTriangle {
        GpuTriangle {
            v0: [0.0, 0.0, z],
            v1: [1.0, 0.0, z],
            v2: [0.0, 1.0, z],
            normal: [0.0, 0.0, 1.0],
            material_id,
        }
    }

    fn unit_triangle() -> GpuTriangle {
        right_triangle_at(-1.0, 0)
    }

    /// Ray that passes through the interior of `unit_triangle`.
    fn centered_ray() -> Ray {
        Ray::new([0.25, 0.25, 0.0], NEG_Z)
    }

    /// Hit record with `t = 1` and material 0 at the given point and normal.
    fn surface_hit(point: [f32; 3], normal: [f32; 3]) -> HitRecord {
        HitRecord {
            t: 1.0,
            point,
            normal,
            material_id: 0,
        }
    }

    // ── Ray ──────────────────────────────────────────────────────────────

    #[test]
    fn test_ray_at() {
        let [x, y, z] = Ray::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0]).at(2.0);
        assert!(near(x, 2.0, 1e-6));
        assert!(near(y, 0.0, 1e-6));
        assert!(near(z, 0.0, 1e-6));
    }

    // ── Sphere ───────────────────────────────────────────────────────────

    #[test]
    fn test_ray_sphere_hit() {
        let ray = Ray::new([0.0, 0.0, 5.0], NEG_Z);
        let t = ray_sphere_intersect(&ray, [0.0, 0.0, 0.0], 1.0)
            .expect("ray aimed at the sphere should hit it");
        assert!(near(t, 4.0, 1e-4));
    }

    #[test]
    fn test_ray_sphere_miss() {
        let ray = Ray::new([5.0, 0.0, 0.0], NEG_Z);
        assert!(ray_sphere_intersect(&ray, [0.0, 0.0, 0.0], 1.0).is_none());
    }

    #[test]
    fn test_ray_sphere_inside() {
        let ray = Ray::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0]);
        assert!(ray_sphere_intersect(&ray, [0.0, 0.0, 0.0], 2.0).is_some());
    }

    #[test]
    fn test_ray_sphere_tangent() {
        // Unit sphere at the origin; the ray runs along x = 1 and just
        // grazes it.
        let ray = Ray::new([1.0, 0.0, 5.0], NEG_Z);
        assert!(ray_sphere_intersect(&ray, [0.0, 0.0, 0.0], 1.0).is_some());
    }

    // ── Triangle ─────────────────────────────────────────────────────────

    #[test]
    fn test_ray_triangle_hit() {
        let t = ray_triangle_intersect(&centered_ray(), &unit_triangle())
            .expect("centred ray should hit the unit triangle");
        assert!(near(t, 1.0, 1e-4));
    }

    #[test]
    fn test_ray_triangle_miss_outside() {
        let ray = Ray::new([2.0, 2.0, 0.0], NEG_Z);
        assert!(ray_triangle_intersect(&ray, &unit_triangle()).is_none());
    }

    #[test]
    fn test_ray_triangle_parallel() {
        let ray = Ray::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0]);
        assert!(ray_triangle_intersect(&ray, &unit_triangle()).is_none());
    }

    // ── AABB ─────────────────────────────────────────────────────────────

    #[test]
    fn test_ray_aabb_hit_direct() {
        let ray = Ray::new([0.0, 0.0, 2.0], NEG_Z);
        assert!(ray_aabb_intersect(&ray, BOX_MIN, BOX_MAX));
    }

    #[test]
    fn test_ray_aabb_miss() {
        let ray = Ray::new([5.0, 0.0, 0.0], NEG_Z);
        assert!(!ray_aabb_intersect(&ray, BOX_MIN, BOX_MAX));
    }

    #[test]
    fn test_ray_aabb_from_inside() {
        let ray = Ray::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0]);
        assert!(ray_aabb_intersect(&ray, BOX_MIN, BOX_MAX));
    }

    #[test]
    fn test_ray_aabb_negative_direction() {
        let ray = Ray::new([2.0, 0.0, 0.0], [-1.0, 0.0, 0.0]);
        assert!(ray_aabb_intersect(&ray, BOX_MIN, BOX_MAX));
    }

    // ── Scene traversal ──────────────────────────────────────────────────

    #[test]
    fn test_trace_ray_hit() {
        assert!(trace_ray(&centered_ray(), &[unit_triangle()]).is_some());
    }

    #[test]
    fn test_trace_ray_miss() {
        let ray = Ray::new([5.0, 5.0, 0.0], NEG_Z);
        assert!(trace_ray(&ray, &[unit_triangle()]).is_none());
    }

    #[test]
    fn test_trace_ray_closest() {
        // Stacked triangles at z = -1 (material 0) and z = -2 (material 1):
        // the nearer one must be reported.
        let scene = [unit_triangle(), right_triangle_at(-2.0, 1)];
        let hit = trace_ray(&centered_ray(), &scene).unwrap();
        assert_eq!(hit.material_id, 0);
    }

    #[test]
    fn test_trace_ray_empty_scene() {
        assert!(trace_ray(&centered_ray(), &[]).is_none());
    }

    #[test]
    fn test_hit_record_point() {
        let hit = trace_ray(&centered_ray(), &[unit_triangle()]).unwrap();
        assert!(near(hit.point[2], -1.0, 1e-4));
    }

    #[test]
    fn test_hit_record_normal() {
        let hit = trace_ray(&centered_ray(), &[unit_triangle()]).unwrap();
        assert!(near(hit.normal[2], 1.0, 1e-4));
    }

    // ── Rendering ────────────────────────────────────────────────────────

    #[test]
    fn test_gpu_render_pixels_count() {
        let image = gpu_render_pixels(&[], [0.0, 0.0, 5.0], 4, 4);
        assert_eq!(image.len(), 16);
    }

    #[test]
    fn test_gpu_render_pixels_miss_black() {
        // Nothing to hit, so every pixel stays black.
        let image = gpu_render_pixels(&[], [0.0, 0.0, 5.0], 2, 2);
        assert!(image.iter().all(|pixel| *pixel == BLACK));
    }

    #[test]
    fn test_gpu_render_pixels_hit_coloured() {
        // One large triangle covering the centre of the viewport.
        let scene = [GpuTriangle {
            v0: [-2.0, -2.0, -1.0],
            v1: [2.0, -2.0, -1.0],
            v2: [0.0, 2.0, -1.0],
            normal: [0.0, 0.0, 1.0],
            material_id: 0,
        }];
        let image = gpu_render_pixels(&scene, [0.0, 0.0, 0.0], 3, 3);
        // At least one pixel must carry the triangle's +Z normal as colour.
        assert!(image.iter().any(|pixel| pixel[2] > 0.5));
    }

    #[test]
    fn test_render_1x1_empty() {
        let image = gpu_render_pixels(&[], [0.0, 0.0, 1.0], 1, 1);
        assert_eq!(image.len(), 1);
        assert_eq!(image[0], BLACK);
    }

    // ── Ambient occlusion ────────────────────────────────────────────────

    #[test]
    fn test_ambient_occlusion_zero_samples() {
        let hit = surface_hit([0.0, 0.0, 0.0], [0.0, 1.0, 0.0]);
        let ao = ambient_occlusion_sample(&hit, &[], 0);
        assert!(near(ao, 1.0, 1e-6));
    }

    #[test]
    fn test_ambient_occlusion_empty_scene() {
        let hit = surface_hit([0.0, 1.0, 0.0], [0.0, 1.0, 0.0]);
        let ao = ambient_occlusion_sample(&hit, &[], 32);
        assert!(near(ao, 1.0, 1e-6));
    }

    #[test]
    fn test_ambient_occlusion_range() {
        let hit = surface_hit([0.0, 0.0, 0.0], [0.0, 0.0, 1.0]);
        let ao = ambient_occlusion_sample(&hit, &[], 16);
        assert!((0.0..=1.0).contains(&ao));
    }

    // ── Plain data types ─────────────────────────────────────────────────

    #[test]
    fn test_bvh_node_fields() {
        let node = BvhNode {
            aabb_min: BOX_MIN,
            aabb_max: BOX_MAX,
            left: 0,
            right: 1,
            is_leaf: true,
            tri_idx: 42,
        };
        assert_eq!(node.tri_idx, 42);
        assert!(node.is_leaf);
    }

    #[test]
    fn test_gpu_triangle_fields() {
        let tri = unit_triangle();
        assert_eq!(tri.material_id, 0);
        assert!(near(tri.normal[2], 1.0, 1e-6));
    }
}
551}