Skip to main content

oxiphysics_gpu/
path_tracer.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! CPU path tracer using GPU-style algorithms.
5//!
6//! Implements a Monte Carlo path tracer with Lambertian, Metal, and Dielectric
7//! materials, sphere and triangle primitives, and a progressive pixel buffer.
8
9#![allow(dead_code)]
10
11use rand::Rng;
12
13use rand::RngExt;
14// ── Vector helpers (f32, 3D) ─────────────────────────────────────────────────
15
/// Component-wise sum `a + b` of two 3-vectors.
#[inline]
fn vadd(a: [f32; 3], b: [f32; 3]) -> [f32; 3] {
    std::array::from_fn(|i| a[i] + b[i])
}
20
/// Component-wise difference `a - b` of two 3-vectors.
#[inline]
fn vsub(a: [f32; 3], b: [f32; 3]) -> [f32; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ax - bx, ay - by, az - bz]
}
25
/// Scale every component of `a` by the scalar `s`.
#[inline]
fn vmul(a: [f32; 3], s: f32) -> [f32; 3] {
    a.map(|component| component * s)
}
30
/// Component-wise (Hadamard) product of two 3-vectors, e.g. to tint a colour by an albedo.
#[inline]
fn vmul3(a: [f32; 3], b: [f32; 3]) -> [f32; 3] {
    std::array::from_fn(|i| a[i] * b[i])
}
35
/// Dot (scalar) product of two 3-vectors.
#[inline]
fn dot(a: [f32; 3], b: [f32; 3]) -> f32 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}
40
/// Cross product `a × b` (right-handed).
#[inline]
fn cross(a: [f32; 3], b: [f32; 3]) -> [f32; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
49
50#[inline]
51fn length(v: [f32; 3]) -> f32 {
52    dot(v, v).sqrt()
53}
54
55#[inline]
56fn normalize(v: [f32; 3]) -> [f32; 3] {
57    let l = length(v);
58    if l < 1e-8 {
59        return [0.0; 3];
60    }
61    vmul(v, 1.0 / l)
62}
63
64#[inline]
65fn reflect(d: [f32; 3], n: [f32; 3]) -> [f32; 3] {
66    vsub(d, vmul(n, 2.0 * dot(d, n)))
67}
68
69fn refract(uv: [f32; 3], n: [f32; 3], ni_over_nt: f32) -> Option<[f32; 3]> {
70    let cos_theta = (-dot(uv, n)).min(1.0);
71    let r_out_perp = vmul(vadd(uv, vmul(n, cos_theta)), ni_over_nt);
72    let r_out_parallel_len2 = (1.0 - dot(r_out_perp, r_out_perp)).abs();
73    let r_out_parallel = vmul(n, -(r_out_parallel_len2.sqrt()));
74    Some(vadd(r_out_perp, r_out_parallel))
75}
76
/// Schlick's polynomial approximation of Fresnel reflectance.
///
/// `cosine` is the cosine of the incidence angle and `ref_idx` the refractive-index
/// ratio. Returns the probability-like reflectance in `[r0, 1]`.
fn schlick(cosine: f32, ref_idx: f32) -> f32 {
    let base = ((1.0 - ref_idx) / (1.0 + ref_idx)).powi(2);
    let grazing = (1.0 - cosine).powi(5);
    base + (1.0 - base) * grazing
}
81
82fn random_in_unit_sphere(rng: &mut impl Rng) -> [f32; 3] {
83    loop {
84        let v = [
85            rng.random_range(-1.0f32..1.0),
86            rng.random_range(-1.0f32..1.0),
87            rng.random_range(-1.0f32..1.0),
88        ];
89        if dot(v, v) < 1.0 {
90            return v;
91        }
92    }
93}
94
95fn random_unit_vector(rng: &mut impl Rng) -> [f32; 3] {
96    normalize(random_in_unit_sphere(rng))
97}
98
99// ── Ray ──────────────────────────────────────────────────────────────────────
100
101/// A ray with an origin and a direction.
102#[derive(Debug, Clone, Copy)]
103pub struct Ray {
104    /// Ray origin in world space.
105    pub origin: [f32; 3],
106    /// Ray direction (need not be normalised, but usually is).
107    pub direction: [f32; 3],
108}
109
110impl Ray {
111    /// Create a new ray.
112    pub fn new(origin: [f32; 3], direction: [f32; 3]) -> Self {
113        Self { origin, direction }
114    }
115
116    /// Evaluate the ray at parameter `t`: `origin + t * direction`.
117    pub fn at(&self, t: f32) -> [f32; 3] {
118        vadd(self.origin, vmul(self.direction, t))
119    }
120}
121
122// ── Material ─────────────────────────────────────────────────────────────────
123
/// How light scatters off a surface.
#[derive(Debug, Clone, Copy)]
pub enum MaterialType {
    /// Ideal diffuse (Lambertian) reflection.
    Lambertian,
    /// Mirror reflection perturbed by a fuzz radius (`Material::metal` clamps it to \[0,1\]).
    Metal(f32),
    /// Transparent refractive material; the payload is its index of refraction.
    Dielectric(f32),
}
134
135/// A surface material with albedo colour and scattering type.
136#[derive(Debug, Clone, Copy)]
137pub struct Material {
138    /// Base colour of the material (RGB in \[0,1\]).
139    pub albedo: [f32; 3],
140    /// Scattering behaviour.
141    pub kind: MaterialType,
142}
143
144impl Material {
145    /// Create a Lambertian (diffuse) material.
146    pub fn lambertian(albedo: [f32; 3]) -> Self {
147        Self {
148            albedo,
149            kind: MaterialType::Lambertian,
150        }
151    }
152
153    /// Create a metallic material.
154    pub fn metal(albedo: [f32; 3], fuzz: f32) -> Self {
155        Self {
156            albedo,
157            kind: MaterialType::Metal(fuzz.clamp(0.0, 1.0)),
158        }
159    }
160
161    /// Create a dielectric (glass) material.
162    pub fn dielectric(ior: f32) -> Self {
163        Self {
164            albedo: [1.0; 3],
165            kind: MaterialType::Dielectric(ior),
166        }
167    }
168
169    /// Scatter a ray off this material surface.
170    ///
171    /// Returns `Some((scattered_ray, attenuation))` or `None` if absorbed.
172    pub fn scatter(
173        &self,
174        ray: &Ray,
175        hit: &HitRecord,
176        rng: &mut impl Rng,
177    ) -> Option<(Ray, [f32; 3])> {
178        match self.kind {
179            MaterialType::Lambertian => {
180                let target = vadd(vadd(hit.point, hit.normal), random_unit_vector(rng));
181                let scattered = Ray::new(hit.point, vsub(target, hit.point));
182                Some((scattered, self.albedo))
183            }
184            MaterialType::Metal(fuzz) => {
185                let reflected = reflect(normalize(ray.direction), hit.normal);
186                let fuzzed = vadd(reflected, vmul(random_in_unit_sphere(rng), fuzz));
187                if dot(fuzzed, hit.normal) > 0.0 {
188                    Some((Ray::new(hit.point, fuzzed), self.albedo))
189                } else {
190                    None
191                }
192            }
193            MaterialType::Dielectric(ior) => {
194                let attenuation = [1.0f32; 3];
195                let refraction_ratio = if hit.front_face { 1.0 / ior } else { ior };
196                let unit_dir = normalize(ray.direction);
197                let cos_theta = (-dot(unit_dir, hit.normal)).min(1.0);
198                let sin_theta = (1.0 - cos_theta * cos_theta).sqrt();
199                let cannot_refract = refraction_ratio * sin_theta > 1.0;
200                let scattered_dir = if cannot_refract
201                    || schlick(cos_theta, refraction_ratio) > rng.random::<f32>()
202                {
203                    reflect(unit_dir, hit.normal)
204                } else {
205                    refract(unit_dir, hit.normal, refraction_ratio)
206                        .unwrap_or_else(|| reflect(unit_dir, hit.normal))
207                };
208                Some((Ray::new(hit.point, scattered_dir), attenuation))
209            }
210        }
211    }
212}
213
214// ── HitRecord ────────────────────────────────────────────────────────────────
215
216/// Record of a ray–surface intersection.
217#[derive(Debug, Clone, Copy)]
218pub struct HitRecord {
219    /// Ray parameter at intersection.
220    pub t: f32,
221    /// World-space hit point.
222    pub point: [f32; 3],
223    /// Outward-facing surface normal (normalised).
224    pub normal: [f32; 3],
225    /// Index into the scene's material list.
226    pub material_index: usize,
227    /// True when the ray hits the front face.
228    pub front_face: bool,
229}
230
231impl HitRecord {
232    fn new(t: f32, point: [f32; 3], outward_normal: [f32; 3], ray: &Ray, mat: usize) -> Self {
233        let front_face = dot(ray.direction, outward_normal) < 0.0;
234        let normal = if front_face {
235            outward_normal
236        } else {
237            vmul(outward_normal, -1.0)
238        };
239        Self {
240            t,
241            point,
242            normal,
243            material_index: mat,
244            front_face,
245        }
246    }
247}
248
249// ── Sphere ───────────────────────────────────────────────────────────────────
250
251/// A sphere primitive.
252#[derive(Debug, Clone, Copy)]
253pub struct Sphere {
254    /// Centre of the sphere.
255    pub center: [f32; 3],
256    /// Radius of the sphere.
257    pub radius: f32,
258    /// Index into the scene's material list.
259    pub material_index: usize,
260}
261
262impl Sphere {
263    /// Create a new sphere.
264    pub fn new(center: [f32; 3], radius: f32, material_index: usize) -> Self {
265        Self {
266            center,
267            radius,
268            material_index,
269        }
270    }
271
272    /// Test ray–sphere intersection in the interval `(t_min, t_max)`.
273    pub fn hit(&self, ray: &Ray, t_min: f32, t_max: f32) -> Option<HitRecord> {
274        let oc = vsub(ray.origin, self.center);
275        let a = dot(ray.direction, ray.direction);
276        let half_b = dot(oc, ray.direction);
277        let c = dot(oc, oc) - self.radius * self.radius;
278        let discriminant = half_b * half_b - a * c;
279        if discriminant < 0.0 {
280            return None;
281        }
282        let sqrt_d = discriminant.sqrt();
283        let mut root = (-half_b - sqrt_d) / a;
284        if root < t_min || root > t_max {
285            root = (-half_b + sqrt_d) / a;
286            if root < t_min || root > t_max {
287                return None;
288            }
289        }
290        let point = ray.at(root);
291        let outward_normal = vmul(vsub(point, self.center), 1.0 / self.radius);
292        Some(HitRecord::new(
293            root,
294            point,
295            outward_normal,
296            ray,
297            self.material_index,
298        ))
299    }
300}
301
302// ── Triangle ─────────────────────────────────────────────────────────────────
303
304/// A triangle primitive.
305#[derive(Debug, Clone, Copy)]
306pub struct Triangle {
307    /// Vertex A.
308    pub v0: [f32; 3],
309    /// Vertex B.
310    pub v1: [f32; 3],
311    /// Vertex C.
312    pub v2: [f32; 3],
313    /// Precomputed geometric normal (normalised).
314    pub normal: [f32; 3],
315    /// Index into the scene's material list.
316    pub material_index: usize,
317}
318
319impl Triangle {
320    /// Create a new triangle, computing the normal automatically.
321    pub fn new(v0: [f32; 3], v1: [f32; 3], v2: [f32; 3], material_index: usize) -> Self {
322        let edge1 = vsub(v1, v0);
323        let edge2 = vsub(v2, v0);
324        let normal = normalize(cross(edge1, edge2));
325        Self {
326            v0,
327            v1,
328            v2,
329            normal,
330            material_index,
331        }
332    }
333
334    /// Möller–Trumbore ray–triangle intersection.
335    pub fn hit(&self, ray: &Ray, t_min: f32, t_max: f32) -> Option<HitRecord> {
336        const EPSILON: f32 = 1e-7;
337        let edge1 = vsub(self.v1, self.v0);
338        let edge2 = vsub(self.v2, self.v0);
339        let h = cross(ray.direction, edge2);
340        let a = dot(edge1, h);
341        if a.abs() < EPSILON {
342            return None; // parallel
343        }
344        let f = 1.0 / a;
345        let s = vsub(ray.origin, self.v0);
346        let u = f * dot(s, h);
347        if !(0.0..=1.0).contains(&u) {
348            return None;
349        }
350        let q = cross(s, edge1);
351        let v = f * dot(ray.direction, q);
352        if v < 0.0 || u + v > 1.0 {
353            return None;
354        }
355        let t = f * dot(edge2, q);
356        if t < t_min || t > t_max {
357            return None;
358        }
359        let point = ray.at(t);
360        Some(HitRecord::new(
361            t,
362            point,
363            self.normal,
364            ray,
365            self.material_index,
366        ))
367    }
368}
369
370// ── Light ────────────────────────────────────────────────────────────────────
371
/// An omnidirectional light located at a single point.
#[derive(Debug, Clone, Copy)]
pub struct PointLight {
    /// World-space position of the light.
    pub position: [f32; 3],
    /// Emitted colour (RGB).
    pub color: [f32; 3],
    /// Scalar multiplier applied to `color`.
    pub intensity: f32,
}

impl PointLight {
    /// Build a point light from its position, colour and intensity multiplier.
    pub fn new(position: [f32; 3], color: [f32; 3], intensity: f32) -> Self {
        PointLight {
            position,
            color,
            intensity,
        }
    }
}
393
394// ── PathTracerScene ──────────────────────────────────────────────────────────
395
396/// A scene containing geometry, materials, and lights.
397#[derive(Debug, Clone, Default)]
398pub struct PathTracerScene {
399    /// Sphere primitives in the scene.
400    pub spheres: Vec<Sphere>,
401    /// Triangle primitives in the scene.
402    pub triangles: Vec<Triangle>,
403    /// Point lights in the scene.
404    pub lights: Vec<PointLight>,
405    /// Material list shared by all primitives.
406    pub materials: Vec<Material>,
407    /// Background (sky) gradient top colour.
408    pub sky_top: [f32; 3],
409    /// Background (sky) gradient bottom colour.
410    pub sky_bottom: [f32; 3],
411}
412
413impl PathTracerScene {
414    /// Create a new empty scene.
415    pub fn new() -> Self {
416        Self {
417            sky_top: [0.5, 0.7, 1.0],
418            sky_bottom: [1.0, 1.0, 1.0],
419            ..Default::default()
420        }
421    }
422
423    /// Add a material and return its index.
424    pub fn add_material(&mut self, mat: Material) -> usize {
425        let idx = self.materials.len();
426        self.materials.push(mat);
427        idx
428    }
429
430    /// Add a sphere to the scene.
431    pub fn add_sphere(&mut self, sphere: Sphere) {
432        self.spheres.push(sphere);
433    }
434
435    /// Add a triangle to the scene.
436    pub fn add_triangle(&mut self, triangle: Triangle) {
437        self.triangles.push(triangle);
438    }
439
440    /// Add a point light to the scene.
441    pub fn add_light(&mut self, light: PointLight) {
442        self.lights.push(light);
443    }
444
445    /// Find the closest intersection along a ray.
446    pub fn hit_scene(&self, ray: &Ray, t_min: f32, t_max: f32) -> Option<HitRecord> {
447        let mut closest: Option<HitRecord> = None;
448        let mut t_closest = t_max;
449        for sphere in &self.spheres {
450            if let Some(rec) = sphere.hit(ray, t_min, t_closest) {
451                t_closest = rec.t;
452                closest = Some(rec);
453            }
454        }
455        for tri in &self.triangles {
456            if let Some(rec) = tri.hit(ray, t_min, t_closest) {
457                t_closest = rec.t;
458                closest = Some(rec);
459            }
460        }
461        closest
462    }
463
464    /// Sky background colour for a given ray direction.
465    fn sky_color(&self, ray: &Ray) -> [f32; 3] {
466        let unit = normalize(ray.direction);
467        let t = 0.5 * (unit[1] + 1.0);
468        let a = self.sky_bottom;
469        let b = self.sky_top;
470        [
471            a[0] * (1.0 - t) + b[0] * t,
472            a[1] * (1.0 - t) + b[1] * t,
473            a[2] * (1.0 - t) + b[2] * t,
474        ]
475    }
476
477    /// Trace a ray through the scene with Monte Carlo path tracing.
478    ///
479    /// Returns the estimated radiance (RGB) for the ray.
480    pub fn trace(&self, ray: &Ray, max_depth: usize, rng: &mut impl Rng) -> [f32; 3] {
481        if max_depth == 0 {
482            return [0.0; 3];
483        }
484        if let Some(hit) = self.hit_scene(ray, 1e-4, f32::INFINITY) {
485            let mat = &self.materials[hit.material_index];
486            if let Some((scattered, attenuation)) = mat.scatter(ray, &hit, rng) {
487                let incoming = self.trace(&scattered, max_depth - 1, rng);
488                vmul3(attenuation, incoming)
489            } else {
490                [0.0; 3]
491            }
492        } else {
493            self.sky_color(ray)
494        }
495    }
496}
497
498// ── PathTracerBuffer ─────────────────────────────────────────────────────────
499
/// A pixel accumulation buffer for progressive rendering.
///
/// Each pixel keeps a running RGB sum and a sample count. [`PathTracerBuffer::get_pixel`]
/// returns their ratio, so successive passes refine the image.
#[derive(Debug, Clone)]
pub struct PathTracerBuffer {
    /// Buffer width in pixels.
    pub width: usize,
    /// Buffer height in pixels.
    pub height: usize,
    /// Summed colour per pixel (RGB, floating point), row-major at `y * width + x`.
    pub accumulator: Vec<[f32; 3]>,
    /// Number of samples accumulated per pixel, same layout as `accumulator`.
    pub sample_count: Vec<u32>,
}

impl PathTracerBuffer {
    /// Create a new buffer of `width × height` pixels, all zeroed.
    ///
    /// # Panics
    /// Panics if `width * height` overflows `usize`.
    pub fn new(width: usize, height: usize) -> Self {
        let n = width
            .checked_mul(height)
            .expect("PathTracerBuffer dimensions overflow usize");
        Self {
            width,
            height,
            accumulator: vec![[0.0; 3]; n],
            sample_count: vec![0; n],
        }
    }

    /// Row-major storage index of pixel `(x, y)`.
    ///
    /// # Panics
    /// Panics if `x >= width` or `y >= height`. Without the explicit column check an
    /// out-of-range `x` would silently alias a pixel on the following row.
    fn pixel_index(&self, x: usize, y: usize) -> usize {
        assert!(
            x < self.width && y < self.height,
            "pixel ({x}, {y}) is outside the {}x{} buffer",
            self.width,
            self.height
        );
        y * self.width + x
    }

    /// Average of an accumulated `sum` over `count` samples; black when `count` is 0.
    fn mean(sum: [f32; 3], count: u32) -> [f32; 3] {
        if count == 0 {
            return [0.0; 3];
        }
        let n = count as f32;
        [sum[0] / n, sum[1] / n, sum[2] / n]
    }

    /// Add a colour sample to pixel `(x, y)`.
    ///
    /// # Panics
    /// Panics if `(x, y)` lies outside the buffer.
    pub fn add_sample(&mut self, x: usize, y: usize, color: [f32; 3]) {
        let idx = self.pixel_index(x, y);
        let acc = &mut self.accumulator[idx];
        acc[0] += color[0];
        acc[1] += color[1];
        acc[2] += color[2];
        self.sample_count[idx] += 1;
    }

    /// Get the averaged colour at pixel `(x, y)` (black if it has no samples yet).
    ///
    /// # Panics
    /// Panics if `(x, y)` lies outside the buffer.
    pub fn get_pixel(&self, x: usize, y: usize) -> [f32; 3] {
        let idx = self.pixel_index(x, y);
        Self::mean(self.accumulator[idx], self.sample_count[idx])
    }

    /// Return a gamma-corrected (gamma = 2) 8-bit RGB image in row-major order,
    /// starting with row `y = 0`. Channels are clamped to \[0,1\] before encoding.
    pub fn to_rgb8(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(self.accumulator.len() * 3);
        for (&sum, &count) in self.accumulator.iter().zip(&self.sample_count) {
            let avg = Self::mean(sum, count);
            for ch in avg {
                let linear = ch.clamp(0.0, 1.0);
                let gamma = linear.sqrt(); // gamma = 2
                out.push((gamma * 255.999) as u8);
            }
        }
        out
    }

    /// Reset all accumulated samples.
    pub fn clear(&mut self) {
        self.accumulator.fill([0.0; 3]);
        self.sample_count.fill(0);
    }

    /// Total number of samples accumulated across all pixels.
    pub fn total_samples(&self) -> u64 {
        self.sample_count.iter().map(|&s| u64::from(s)).sum()
    }
}
577
578// ── Camera ───────────────────────────────────────────────────────────────────
579
580/// A simple pinhole camera.
581#[derive(Debug, Clone)]
582pub struct Camera {
583    /// Camera origin.
584    pub origin: [f32; 3],
585    lower_left_corner: [f32; 3],
586    horizontal: [f32; 3],
587    vertical: [f32; 3],
588    lens_radius: f32,
589    u: [f32; 3],
590    v: [f32; 3],
591}
592
593impl Camera {
594    /// Construct a camera.
595    ///
596    /// * `look_from` — eye position
597    /// * `look_at` — target position
598    /// * `vup` — world up vector
599    /// * `vfov` — vertical field of view in degrees
600    /// * `aspect_ratio` — image width / height
601    /// * `aperture` — lens aperture (0 = pinhole)
602    /// * `focus_dist` — focus distance
603    #[allow(clippy::too_many_arguments)]
604    pub fn new(
605        look_from: [f32; 3],
606        look_at: [f32; 3],
607        vup: [f32; 3],
608        vfov: f32,
609        aspect_ratio: f32,
610        aperture: f32,
611        focus_dist: f32,
612    ) -> Self {
613        let theta = vfov.to_radians();
614        let h = (theta / 2.0).tan();
615        let viewport_height = 2.0 * h;
616        let viewport_width = aspect_ratio * viewport_height;
617
618        let w = normalize(vsub(look_from, look_at));
619        let u = normalize(cross(vup, w));
620        let v = cross(w, u);
621
622        let horizontal = vmul(u, viewport_width * focus_dist);
623        let vertical = vmul(v, viewport_height * focus_dist);
624        let lower_left_corner = vsub(
625            vsub(vsub(look_from, vmul(horizontal, 0.5)), vmul(vertical, 0.5)),
626            vmul(w, focus_dist),
627        );
628
629        Self {
630            origin: look_from,
631            lower_left_corner,
632            horizontal,
633            vertical,
634            lens_radius: aperture / 2.0,
635            u,
636            v,
637        }
638    }
639
640    /// Generate a ray for pixel coordinates `(s, t)` in \[0,1\]×\[0,1\].
641    pub fn get_ray(&self, s: f32, t: f32, rng: &mut impl Rng) -> Ray {
642        let rd = vmul(self.random_in_unit_disk(rng), self.lens_radius);
643        let offset = vadd(vmul(self.u, rd[0]), vmul(self.v, rd[1]));
644        let dir = vsub(
645            vadd(
646                vadd(self.lower_left_corner, vmul(self.horizontal, s)),
647                vmul(self.vertical, t),
648            ),
649            vadd(self.origin, offset),
650        );
651        Ray::new(vadd(self.origin, offset), dir)
652    }
653
654    fn random_in_unit_disk(&self, rng: &mut impl Rng) -> [f32; 3] {
655        loop {
656            let p = [
657                rng.random_range(-1.0f32..1.0),
658                rng.random_range(-1.0f32..1.0),
659                0.0,
660            ];
661            if dot(p, p) < 1.0 {
662                return p;
663            }
664        }
665    }
666}
667
668// ── PathTracerRenderer ───────────────────────────────────────────────────────
669
670/// Renderer that progressively renders a scene into a `PathTracerBuffer`.
671#[derive(Debug, Clone)]
672pub struct PathTracerRenderer {
673    /// Scene to render.
674    pub scene: PathTracerScene,
675    /// Camera.
676    pub camera: Camera,
677    /// Maximum ray bounce depth.
678    pub max_depth: usize,
679    /// Samples per pixel per call to `render_pass`.
680    pub samples_per_pass: usize,
681}
682
683impl PathTracerRenderer {
684    /// Create a new renderer.
685    pub fn new(
686        scene: PathTracerScene,
687        camera: Camera,
688        max_depth: usize,
689        samples_per_pass: usize,
690    ) -> Self {
691        Self {
692            scene,
693            camera,
694            max_depth,
695            samples_per_pass,
696        }
697    }
698
699    /// Render one progressive pass into `buffer`.
700    ///
701    /// Each pixel receives `self.samples_per_pass` new samples.
702    pub fn render_pass(&self, buffer: &mut PathTracerBuffer) {
703        let w = buffer.width;
704        let h = buffer.height;
705        let mut rng = rand::rng();
706        for y in 0..h {
707            for x in 0..w {
708                let mut color = [0.0f32; 3];
709                for _ in 0..self.samples_per_pass {
710                    let u = (x as f32 + rng.random::<f32>()) / (w - 1) as f32;
711                    let v = (y as f32 + rng.random::<f32>()) / (h - 1) as f32;
712                    let ray = self.camera.get_ray(u, v, &mut rng);
713                    let c = self.scene.trace(&ray, self.max_depth, &mut rng);
714                    color[0] += c[0];
715                    color[1] += c[1];
716                    color[2] += c[2];
717                }
718                let inv = 1.0 / self.samples_per_pass as f32;
719                buffer.add_sample(x, y, vmul(color, inv));
720            }
721        }
722    }
723}
724
725// ── Tests ─────────────────────────────────────────────────────────────────────
726
727#[cfg(test)]
728mod tests {
729    use super::*;
730
731    fn make_rng() -> impl Rng {
732        rand::rng()
733    }
734
735    // ── Ray tests ────────────────────────────────────────────────────────
736
737    #[test]
738    fn test_ray_at_origin() {
739        let r = Ray::new([0.0; 3], [1.0, 0.0, 0.0]);
740        let p = r.at(0.0);
741        assert_eq!(p, [0.0; 3]);
742    }
743
744    #[test]
745    fn test_ray_at_t() {
746        let r = Ray::new([1.0, 2.0, 3.0], [1.0, 0.0, 0.0]);
747        let p = r.at(3.0);
748        assert!((p[0] - 4.0).abs() < 1e-6);
749        assert!((p[1] - 2.0).abs() < 1e-6);
750        assert!((p[2] - 3.0).abs() < 1e-6);
751    }
752
753    #[test]
754    fn test_ray_at_negative_t() {
755        let r = Ray::new([0.0; 3], [0.0, 1.0, 0.0]);
756        let p = r.at(-2.0);
757        assert!((p[1] - (-2.0)).abs() < 1e-6);
758    }
759
760    // ── Sphere tests ─────────────────────────────────────────────────────
761
762    #[test]
763    fn test_sphere_hit_center() {
764        let s = Sphere::new([0.0, 0.0, -1.0], 0.5, 0);
765        let r = Ray::new([0.0; 3], [0.0, 0.0, -1.0]);
766        let hit = s.hit(&r, 0.001, f32::INFINITY);
767        assert!(hit.is_some());
768        let rec = hit.unwrap();
769        assert!(rec.t > 0.4 && rec.t < 0.6);
770    }
771
772    #[test]
773    fn test_sphere_miss() {
774        let s = Sphere::new([0.0, 0.0, -1.0], 0.5, 0);
775        let r = Ray::new([0.0; 3], [0.0, 1.0, 0.0]);
776        assert!(s.hit(&r, 0.001, f32::INFINITY).is_none());
777    }
778
779    #[test]
780    fn test_sphere_hit_from_inside() {
781        let s = Sphere::new([0.0; 3], 1.0, 0);
782        let r = Ray::new([0.0; 3], [1.0, 0.0, 0.0]);
783        let hit = s.hit(&r, 0.001, f32::INFINITY);
784        assert!(hit.is_some());
785        let rec = hit.unwrap();
786        assert!(!rec.front_face);
787    }
788
789    #[test]
790    fn test_sphere_normal_outward() {
791        let s = Sphere::new([0.0; 3], 1.0, 0);
792        let r = Ray::new([0.0, 0.0, 5.0], [0.0, 0.0, -1.0]);
793        let hit = s.hit(&r, 0.001, f32::INFINITY).unwrap();
794        assert!(hit.front_face);
795        assert!((hit.normal[2] - 1.0).abs() < 1e-5);
796    }
797
798    #[test]
799    fn test_sphere_t_range_cull() {
800        let s = Sphere::new([0.0, 0.0, -1.0], 0.5, 0);
801        let r = Ray::new([0.0; 3], [0.0, 0.0, -1.0]);
802        // t_max < actual hit
803        assert!(s.hit(&r, 0.001, 0.1).is_none());
804    }
805
806    // ── Triangle tests ───────────────────────────────────────────────────
807
808    #[test]
809    fn test_triangle_hit() {
810        let tri = Triangle::new([-1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], 0);
811        let r = Ray::new([0.0, 0.3, 1.0], [0.0, 0.0, -1.0]);
812        let hit = tri.hit(&r, 0.001, f32::INFINITY);
813        assert!(hit.is_some());
814    }
815
816    #[test]
817    fn test_triangle_miss_outside() {
818        let tri = Triangle::new([-1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], 0);
819        let r = Ray::new([5.0, 5.0, 1.0], [0.0, 0.0, -1.0]);
820        assert!(tri.hit(&r, 0.001, f32::INFINITY).is_none());
821    }
822
823    #[test]
824    fn test_triangle_miss_parallel() {
825        let tri = Triangle::new([-1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], 0);
826        // Ray parallel to triangle plane
827        let r = Ray::new([0.0, 0.0, 1.0], [1.0, 0.0, 0.0]);
828        assert!(tri.hit(&r, 0.001, f32::INFINITY).is_none());
829    }
830
831    #[test]
832    fn test_triangle_normal_direction() {
833        let tri = Triangle::new([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], 0);
834        // Normal should point in +Z for counter-clockwise winding
835        assert!(tri.normal[2].abs() > 0.9);
836    }
837
838    #[test]
839    fn test_triangle_hit_at_vertex() {
840        let tri = Triangle::new([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0], 0);
841        // Ray aimed near edge midpoint
842        let r = Ray::new([0.5, 0.5, 1.0], [0.0, 0.0, -1.0]);
843        assert!(tri.hit(&r, 0.001, f32::INFINITY).is_some());
844    }
845
846    // ── Material tests ───────────────────────────────────────────────────
847
848    #[test]
849    fn test_lambertian_scatter() {
850        let mat = Material::lambertian([0.8, 0.3, 0.3]);
851        let ray = Ray::new([0.0; 3], [0.0, 0.0, -1.0]);
852        let hit = HitRecord {
853            t: 1.0,
854            point: [0.0, 0.0, -1.0],
855            normal: [0.0, 0.0, 1.0],
856            material_index: 0,
857            front_face: true,
858        };
859        let mut rng = make_rng();
860        let result = mat.scatter(&ray, &hit, &mut rng);
861        assert!(result.is_some());
862        let (_scattered, attenuation) = result.unwrap();
863        assert!((attenuation[0] - 0.8).abs() < 1e-6);
864    }
865
866    #[test]
867    fn test_metal_scatter() {
868        let mat = Material::metal([0.8, 0.8, 0.8], 0.0);
869        let ray = Ray::new([0.0; 3], normalize([1.0, -1.0, 0.0]));
870        let hit = HitRecord {
871            t: 1.0,
872            point: [0.0; 3],
873            normal: [0.0, 1.0, 0.0],
874            material_index: 0,
875            front_face: true,
876        };
877        let mut rng = make_rng();
878        let result = mat.scatter(&ray, &hit, &mut rng);
879        assert!(result.is_some());
880        let (scattered, _attenuation) = result.unwrap();
881        // Reflected ray should go up (positive y)
882        assert!(scattered.direction[1] > 0.0);
883    }
884
885    #[test]
886    fn test_dielectric_scatter() {
887        let mat = Material::dielectric(1.5);
888        let ray = Ray::new([0.0, 0.0, 1.0], normalize([0.0, 0.0, -1.0]));
889        let hit = HitRecord {
890            t: 1.0,
891            point: [0.0; 3],
892            normal: [0.0, 0.0, 1.0],
893            material_index: 0,
894            front_face: true,
895        };
896        let mut rng = make_rng();
897        let result = mat.scatter(&ray, &hit, &mut rng);
898        assert!(result.is_some());
899        let (_s, attn) = result.unwrap();
900        assert!((attn[0] - 1.0).abs() < 1e-6);
901    }
902
903    #[test]
904    fn test_metal_fuzz_clamped() {
905        let mat = Material::metal([1.0; 3], 5.0);
906        if let MaterialType::Metal(f) = mat.kind {
907            assert!(f <= 1.0);
908        } else {
909            panic!("expected Metal");
910        }
911    }
912
913    // ── Scene tests ──────────────────────────────────────────────────────
914
915    #[test]
916    fn test_scene_add_material() {
917        let mut scene = PathTracerScene::new();
918        let idx = scene.add_material(Material::lambertian([1.0; 3]));
919        assert_eq!(idx, 0);
920        let idx2 = scene.add_material(Material::lambertian([0.5; 3]));
921        assert_eq!(idx2, 1);
922    }
923
924    #[test]
925    fn test_scene_hit_sphere() {
926        let mut scene = PathTracerScene::new();
927        let m = scene.add_material(Material::lambertian([0.5; 3]));
928        scene.add_sphere(Sphere::new([0.0, 0.0, -1.0], 0.5, m));
929        let r = Ray::new([0.0; 3], [0.0, 0.0, -1.0]);
930        assert!(scene.hit_scene(&r, 0.001, f32::INFINITY).is_some());
931    }
932
933    #[test]
934    fn test_scene_miss() {
935        let scene = PathTracerScene::new();
936        let r = Ray::new([0.0; 3], [0.0, 0.0, -1.0]);
937        assert!(scene.hit_scene(&r, 0.001, f32::INFINITY).is_none());
938    }
939
940    #[test]
941    fn test_scene_sky_color_up() {
942        let scene = PathTracerScene::new();
943        let r = Ray::new([0.0; 3], [0.0, 1.0, 0.0]);
944        let c = scene.sky_color(&r);
945        // Should be close to sky_top
946        assert!(c[2] > 0.9);
947    }
948
949    #[test]
950    fn test_scene_trace_no_hit() {
951        let scene = PathTracerScene::new();
952        let r = Ray::new([0.0; 3], [0.0, 1.0, 0.0]);
953        let mut rng = make_rng();
954        let c = scene.trace(&r, 5, &mut rng);
955        // Should be sky colour
956        assert!(c[2] > 0.0);
957    }
958
959    #[test]
960    fn test_scene_trace_depth_zero() {
961        let mut scene = PathTracerScene::new();
962        let m = scene.add_material(Material::lambertian([0.5; 3]));
963        scene.add_sphere(Sphere::new([0.0, 0.0, -1.0], 0.5, m));
964        let r = Ray::new([0.0; 3], [0.0, 0.0, -1.0]);
965        let mut rng = make_rng();
966        let c = scene.trace(&r, 0, &mut rng);
967        assert_eq!(c, [0.0; 3]);
968    }
969
970    #[test]
971    fn test_scene_closest_hit() {
972        let mut scene = PathTracerScene::new();
973        let m = scene.add_material(Material::lambertian([0.5; 3]));
974        scene.add_sphere(Sphere::new([0.0, 0.0, -2.0], 0.5, m));
975        scene.add_sphere(Sphere::new([0.0, 0.0, -1.0], 0.5, m));
976        let r = Ray::new([0.0; 3], [0.0, 0.0, -1.0]);
977        let hit = scene.hit_scene(&r, 0.001, f32::INFINITY).unwrap();
978        // Closer sphere is at z=-1, so t ~ 0.5
979        assert!(hit.t < 1.0);
980    }
981
982    // ── PathTracerBuffer tests ───────────────────────────────────────────
983
984    #[test]
985    fn test_buffer_new() {
986        let buf = PathTracerBuffer::new(4, 4);
987        assert_eq!(buf.width, 4);
988        assert_eq!(buf.height, 4);
989        assert_eq!(buf.total_samples(), 0);
990    }
991
992    #[test]
993    fn test_buffer_add_and_get() {
994        let mut buf = PathTracerBuffer::new(4, 4);
995        buf.add_sample(1, 2, [0.6, 0.4, 0.2]);
996        buf.add_sample(1, 2, [0.4, 0.6, 0.8]);
997        let p = buf.get_pixel(1, 2);
998        assert!((p[0] - 0.5).abs() < 1e-5);
999        assert!((p[1] - 0.5).abs() < 1e-5);
1000        assert!((p[2] - 0.5).abs() < 1e-5);
1001    }
1002
1003    #[test]
1004    fn test_buffer_zero_samples() {
1005        let buf = PathTracerBuffer::new(4, 4);
1006        let p = buf.get_pixel(0, 0);
1007        assert_eq!(p, [0.0; 3]);
1008    }
1009
1010    #[test]
1011    fn test_buffer_to_rgb8_white() {
1012        let mut buf = PathTracerBuffer::new(1, 1);
1013        buf.add_sample(0, 0, [1.0; 3]);
1014        let rgb = buf.to_rgb8();
1015        assert_eq!(rgb.len(), 3);
1016        assert_eq!(rgb[0], 255);
1017    }
1018
1019    #[test]
1020    fn test_buffer_to_rgb8_black() {
1021        let mut buf = PathTracerBuffer::new(1, 1);
1022        buf.add_sample(0, 0, [0.0; 3]);
1023        let rgb = buf.to_rgb8();
1024        assert_eq!(rgb[0], 0);
1025    }
1026
1027    #[test]
1028    fn test_buffer_total_samples() {
1029        let mut buf = PathTracerBuffer::new(2, 2);
1030        buf.add_sample(0, 0, [1.0; 3]);
1031        buf.add_sample(0, 0, [1.0; 3]);
1032        buf.add_sample(1, 1, [0.5; 3]);
1033        assert_eq!(buf.total_samples(), 3);
1034    }
1035
1036    #[test]
1037    fn test_buffer_clear() {
1038        let mut buf = PathTracerBuffer::new(2, 2);
1039        buf.add_sample(0, 0, [1.0; 3]);
1040        buf.clear();
1041        assert_eq!(buf.total_samples(), 0);
1042        assert_eq!(buf.get_pixel(0, 0), [0.0; 3]);
1043    }
1044
1045    #[test]
1046    fn test_buffer_size() {
1047        let buf = PathTracerBuffer::new(8, 6);
1048        assert_eq!(buf.accumulator.len(), 48);
1049        assert_eq!(buf.sample_count.len(), 48);
1050    }
1051
1052    // ── Camera tests ─────────────────────────────────────────────────────
1053
1054    #[test]
1055    fn test_camera_get_ray_center() {
1056        let cam = Camera::new(
1057            [0.0, 0.0, 0.0],
1058            [0.0, 0.0, -1.0],
1059            [0.0, 1.0, 0.0],
1060            90.0,
1061            1.0,
1062            0.0,
1063            1.0,
1064        );
1065        let mut rng = make_rng();
1066        let ray = cam.get_ray(0.5, 0.5, &mut rng);
1067        // Center ray should point roughly in -Z
1068        let d = normalize(ray.direction);
1069        assert!(d[2] < -0.9);
1070    }
1071
1072    #[test]
1073    fn test_camera_origin() {
1074        let cam = Camera::new(
1075            [1.0, 2.0, 3.0],
1076            [0.0, 0.0, 0.0],
1077            [0.0, 1.0, 0.0],
1078            60.0,
1079            1.5,
1080            0.0,
1081            1.0,
1082        );
1083        assert!((cam.origin[0] - 1.0).abs() < 1e-5);
1084    }
1085
1086    // ── Renderer integration test ────────────────────────────────────────
1087
1088    #[test]
1089    fn test_renderer_render_pass_small() {
1090        let mut scene = PathTracerScene::new();
1091        let m = scene.add_material(Material::lambertian([0.7, 0.3, 0.5]));
1092        scene.add_sphere(Sphere::new([0.0, 0.0, -1.0], 0.5, m));
1093        let cam = Camera::new(
1094            [0.0, 0.0, 0.0],
1095            [0.0, 0.0, -1.0],
1096            [0.0, 1.0, 0.0],
1097            90.0,
1098            1.0,
1099            0.0,
1100            1.0,
1101        );
1102        let renderer = PathTracerRenderer::new(scene, cam, 3, 2);
1103        let mut buf = PathTracerBuffer::new(4, 4);
1104        renderer.render_pass(&mut buf);
1105        assert!(buf.total_samples() > 0);
1106        // Each pixel gets exactly one accumulated sample entry (averaged internally)
1107        assert_eq!(buf.total_samples(), (4 * 4) as u64);
1108    }
1109
1110    #[test]
1111    fn test_renderer_rgb8_output_valid() {
1112        let mut scene = PathTracerScene::new();
1113        let m = scene.add_material(Material::lambertian([0.5; 3]));
1114        scene.add_sphere(Sphere::new([0.0, 0.0, -1.0], 0.5, m));
1115        let cam = Camera::new(
1116            [0.0, 0.0, 0.0],
1117            [0.0, 0.0, -1.0],
1118            [0.0, 1.0, 0.0],
1119            90.0,
1120            1.0,
1121            0.0,
1122            1.0,
1123        );
1124        let renderer = PathTracerRenderer::new(scene, cam, 2, 1);
1125        let mut buf = PathTracerBuffer::new(8, 8);
1126        renderer.render_pass(&mut buf);
1127        let rgb = buf.to_rgb8();
1128        assert_eq!(rgb.len(), 8 * 8 * 3);
1129        // All values valid u8
1130        for &v in &rgb {
1131            let _ = v; // just confirm they exist
1132        }
1133    }
1134
1135    // ── Vector helper tests ──────────────────────────────────────────────
1136
1137    #[test]
1138    fn test_vadd() {
1139        let a = [1.0, 2.0, 3.0];
1140        let b = [4.0, 5.0, 6.0];
1141        let c = vadd(a, b);
1142        assert_eq!(c, [5.0, 7.0, 9.0]);
1143    }
1144
1145    #[test]
1146    fn test_vsub() {
1147        let a = [3.0, 2.0, 1.0];
1148        let b = [1.0, 1.0, 1.0];
1149        assert_eq!(vsub(a, b), [2.0, 1.0, 0.0]);
1150    }
1151
1152    #[test]
1153    fn test_dot_orthogonal() {
1154        assert!((dot([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])).abs() < 1e-7);
1155    }
1156
1157    #[test]
1158    fn test_cross_unit_vectors() {
1159        let k = cross([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]);
1160        assert!((k[2] - 1.0).abs() < 1e-7);
1161    }
1162
1163    #[test]
1164    fn test_normalize_length() {
1165        let v = [3.0, 4.0, 0.0];
1166        let n = normalize(v);
1167        let l = length(n);
1168        assert!((l - 1.0).abs() < 1e-6);
1169    }
1170
1171    #[test]
1172    fn test_reflect_normal_incidence() {
1173        let d = [0.0, -1.0, 0.0];
1174        let n = [0.0, 1.0, 0.0];
1175        let r = reflect(d, n);
1176        assert!((r[1] - 1.0).abs() < 1e-6);
1177    }
1178
1179    #[test]
1180    fn test_schlick_zero_angle() {
1181        // cosine=0 → schlick = r0 + (1-r0)*(1-0)^5 = 1.0
1182        let s = schlick(0.0, 1.5);
1183        assert!((s - 1.0).abs() < 1e-5);
1184    }
1185
1186    #[test]
1187    fn test_schlick_grazing() {
1188        // cosine=1 → schlick = r0 + (1-r0)*(1-1)^5 = r0
1189        let s = schlick(1.0, 1.5);
1190        let r0 = ((1.0 - 1.5f32) / (1.0 + 1.5)).powi(2);
1191        assert!((s - r0).abs() < 1e-5);
1192    }
1193}