Skip to main content

proof_engine/render/pbr/
probe.rs

1//! Environment probes and global illumination helpers.
2//!
3//! Provides:
4//! * Spherical harmonic (SH3) projection and evaluation
5//! * Cubemap / equirectangular / octahedral map utilities
6//! * Reflection probe parallax correction
7//! * Tetrahedral-interpolated light probe grids
8//! * Screen-space reflection ray generation
9//! * Irradiance cache with validity ageing
10
11use glam::{Mat3, Vec2, Vec3, Vec4};
12use std::f32::consts::{FRAC_1_PI, PI};
13
14// ─────────────────────────────────────────────────────────────────────────────
15// Local Ray type
16// ─────────────────────────────────────────────────────────────────────────────
17
18/// A ray with an origin and a (unit) direction.
19#[derive(Debug, Clone, Copy)]
20pub struct Ray {
21    pub origin: Vec3,
22    pub direction: Vec3,
23}
24
25impl Ray {
26    pub fn new(origin: Vec3, direction: Vec3) -> Self {
27        Self {
28            origin,
29            direction: direction.normalize(),
30        }
31    }
32
33    /// Evaluate the ray at parameter `t`.
34    #[inline]
35    pub fn at(&self, t: f32) -> Vec3 {
36        self.origin + self.direction * t
37    }
38}
39
40// ─────────────────────────────────────────────────────────────────────────────
41// Spherical Harmonics
42// ─────────────────────────────────────────────────────────────────────────────
43
44/// Third-order (L2) spherical harmonics — 9 real coefficients per colour
45/// channel.  Each `Sh3` stores (R, G, B) for all 9 basis functions.
46#[derive(Debug, Clone)]
47pub struct Sh3 {
48    /// `coeffs[i]` = (r, g, b) for the i-th SH basis function.
49    pub coeffs: [Vec3; 9],
50}
51
52impl Sh3 {
53    /// Zero SH (no contribution).
54    pub fn zero() -> Self {
55        Self {
56            coeffs: [Vec3::ZERO; 9],
57        }
58    }
59
60    /// Create from a flat array of (r,g,b) tuples.
61    pub fn from_coeffs(c: [Vec3; 9]) -> Self {
62        Self { coeffs: c }
63    }
64
65    /// Evaluate the SH at direction `dir` to get irradiance.
66    pub fn evaluate(&self, dir: Vec3) -> Vec3 {
67        let b = sh_basis(dir);
68        let mut result = Vec3::ZERO;
69        for i in 0..9 {
70            result += self.coeffs[i] * b[i];
71        }
72        result.max(Vec3::ZERO)
73    }
74}
75
76impl Default for Sh3 {
77    fn default() -> Self {
78        Self::zero()
79    }
80}
81
/// Band-0 normalisation constant of the real SH basis: `1 / (2√π)`.
const SH_C0: f32 = 0.282_094_8;
/// Band-1 normalisation constant: `√(3 / (4π))`.
const SH_C1: f32 = 0.488_602_5;
/// Band-2 constant for the `xy`, `yz` and `xz` terms: `√(15 / (4π))`.
const SH_C2_A: f32 = 1.092_548_4;
/// Band-2 constant for the `2z² − x² − y²` term: `√(5 / (16π))`.
const SH_C2_B: f32 = 0.315_391_6;
/// Band-2 constant for the `x² − y²` term: `√(15 / (16π))`.
const SH_C2_C: f32 = 0.546_274_2;
88
89/// Evaluate all 9 SH basis functions at direction `dir`.
90///
91/// The returned array follows the convention:
92/// `[Y0_0, Y1_{-1}, Y1_0, Y1_1, Y2_{-2}, Y2_{-1}, Y2_0, Y2_1, Y2_2]`
93pub fn sh_basis(dir: Vec3) -> [f32; 9] {
94    let (x, y, z) = (dir.x, dir.y, dir.z);
95    [
96        // L=0
97        SH_C0,
98        // L=1
99        -SH_C1 * y,
100        SH_C1 * z,
101        -SH_C1 * x,
102        // L=2
103        SH_C2_A * x * y,
104        -SH_C2_A * y * z,
105        SH_C2_B * (2.0 * z * z - x * x - y * y),
106        -SH_C2_A * x * z,
107        SH_C2_C * (x * x - y * y),
108    ]
109}
110
111/// Monte Carlo projection of a spherical function onto SH3 basis.
112///
113/// `sample_fn` — callable that returns the RGB radiance for a given direction.
114/// `n_samples` — number of uniform sphere samples.
115pub fn project_to_sh(sample_fn: impl Fn(Vec3) -> Vec3, n_samples: usize) -> Sh3 {
116    let mut coeffs = [Vec3::ZERO; 9];
117
118    for i in 0..n_samples {
119        // Uniform sphere sampling using Fibonacci lattice
120        let golden = (1.0 + 5.0_f32.sqrt()) * 0.5;
121        let theta = (1.0 - 2.0 * (i as f32 + 0.5) / n_samples as f32)
122            .clamp(-1.0, 1.0)
123            .acos();
124        let phi = 2.0 * PI * (i as f32) / golden;
125
126        let dir = Vec3::new(theta.sin() * phi.cos(), theta.sin() * phi.sin(), theta.cos());
127        let radiance = sample_fn(dir);
128        let basis = sh_basis(dir);
129
130        for j in 0..9 {
131            coeffs[j] += radiance * basis[j];
132        }
133    }
134
135    // Normalize by solid-angle weight (4π / N)
136    let weight = 4.0 * PI / n_samples as f32;
137    for c in &mut coeffs {
138        *c *= weight;
139    }
140
141    Sh3 { coeffs }
142}
143
144/// Evaluate irradiance for a surface with `normal` from precomputed SH
145/// coefficients (Ramamoorthi & Hanrahan 2001).
146pub fn irradiance_from_sh(sh: &Sh3, normal: Vec3) -> Vec3 {
147    // Pre-computed zonal harmonic + cosine lobe convolution factors
148    const A0: f32 = PI;
149    const A1: f32 = 2.0 * PI / 3.0;
150    const A2: f32 = PI / 4.0;
151
152    let b = sh_basis(normal);
153
154    sh.coeffs[0] * b[0] * A0
155        + sh.coeffs[1] * b[1] * A1
156        + sh.coeffs[2] * b[2] * A1
157        + sh.coeffs[3] * b[3] * A1
158        + sh.coeffs[4] * b[4] * A2
159        + sh.coeffs[5] * b[5] * A2
160        + sh.coeffs[6] * b[6] * A2
161        + sh.coeffs[7] * b[7] * A2
162        + sh.coeffs[8] * b[8] * A2
163}
164
165/// Convolve SH3 with the clamped-cosine (Lambertian) kernel.
166///
167/// This is the ZH product used for irradiance environment maps.
168pub fn convolve_sh_lambert(sh: &Sh3) -> Sh3 {
169    const ZH0: f32 = 3.141_593;
170    const ZH1: f32 = 2.094_395;
171    const ZH2: f32 = 0.785_398;
172
173    let mut out = sh.clone();
174    // Band 0
175    out.coeffs[0] = sh.coeffs[0] * ZH0;
176    // Band 1
177    for i in 1..=3 {
178        out.coeffs[i] = sh.coeffs[i] * ZH1;
179    }
180    // Band 2
181    for i in 4..=8 {
182        out.coeffs[i] = sh.coeffs[i] * ZH2;
183    }
184    out
185}
186
187/// Add two SH3 sets.
188pub fn sh_add(a: &Sh3, b: &Sh3) -> Sh3 {
189    let mut out = Sh3::zero();
190    for i in 0..9 {
191        out.coeffs[i] = a.coeffs[i] + b.coeffs[i];
192    }
193    out
194}
195
196/// Scale SH3 by scalar `s`.
197pub fn sh_scale(sh: &Sh3, s: f32) -> Sh3 {
198    let mut out = sh.clone();
199    for c in &mut out.coeffs {
200        *c *= s;
201    }
202    out
203}
204
205/// Rotate SH3 by a rotation matrix.
206///
207/// Uses the exact band-0/1/2 SH rotation formulas.  Band 0 is invariant; bands
208/// 1 and 2 are rotated using the provided `rotation` matrix.
209pub fn sh_rotate(sh: &Sh3, rotation: &Mat3) -> Sh3 {
210    let mut out = Sh3::zero();
211
212    // Band 0 is rotationally invariant
213    out.coeffs[0] = sh.coeffs[0];
214
215    // Band 1: Y_1m transforms as Cartesian vector components (x, y, z)
216    // Basis ordering: [1]=y, [2]=z, [3]=x
217    let r = *rotation;
218    // Extract the columns of the rotation matrix that map x,y,z to x',y',z'
219    let rx = Vec3::new(r.x_axis.x, r.x_axis.y, r.x_axis.z);
220    let ry = Vec3::new(r.y_axis.x, r.y_axis.y, r.y_axis.z);
221    let rz = Vec3::new(r.z_axis.x, r.z_axis.y, r.z_axis.z);
222
223    // Coefficients for basis [y, z, x] -> indices [1, 2, 3]
224    let b1_y = sh.coeffs[1];
225    let b1_z = sh.coeffs[2];
226    let b1_x = sh.coeffs[3];
227
228    out.coeffs[1] = b1_x * ry.x + b1_y * ry.y + b1_z * ry.z; // new y
229    out.coeffs[2] = b1_x * rz.x + b1_y * rz.y + b1_z * rz.z; // new z
230    out.coeffs[3] = b1_x * rx.x + b1_y * rx.y + b1_z * rx.z; // new x
231
232    // Band 2: use the analytic 5×5 rotation derived from the Wigner D-matrices.
233    // We sample the rotated SH by evaluating in the original basis.
234    // For each of the 5 band-2 functions we sample the rotated direction to find
235    // the coefficients — this is the "reconstruct & re-project" approach, which
236    // is exact for band-2 if we use the right 5 reference directions.
237    let dirs_b2: [Vec3; 5] = [
238        Vec3::new(1.0, 0.0, 0.0),
239        Vec3::new(0.0, 1.0, 0.0),
240        Vec3::new(0.0, 0.0, 1.0),
241        Vec3::new(1.0, 1.0, 0.0).normalize(),
242        Vec3::new(0.0, 1.0, 1.0).normalize(),
243    ];
244
245    // For each direction, evaluate original band-2 SH and then the rotated dir
246    for (idx, &dir) in dirs_b2.iter().enumerate() {
247        let rot_dir = rotation.mul_vec3(dir).normalize();
248        let b_orig = sh_basis(dir);
249        let b_rot = sh_basis(rot_dir);
250
251        // Contribution from original band-2 coefficients
252        let mut val = Vec3::ZERO;
253        for k in 0..5 {
254            val += sh.coeffs[4 + k] * b_orig[4 + k];
255        }
256
257        // The rotated direction's band-2 basis distributes this back
258        for k in 0..5 {
259            out.coeffs[4 + k] += val * b_rot[4 + k];
260        }
261
262        let _ = idx; // suppress warning
263    }
264
265    // Normalise the band-2 projection (5 samples for 5 coefficients — biased,
266    // but an acceptable approximation when sample directions are well-chosen)
267    for k in 0..5 {
268        out.coeffs[4 + k] /= 5.0;
269    }
270
271    out
272}
273
274// ─────────────────────────────────────────────────────────────────────────────
275// Cubemap utilities
276// ─────────────────────────────────────────────────────────────────────────────
277
/// One face of a cubemap. The declaration order is the conventional
/// +X, -X, +Y, -Y, +Z, -Z layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CubemapFace {
    PosX,
    NegX,
    PosY,
    NegY,
    PosZ,
    NegZ,
}

impl CubemapFace {
    /// Position of this face in the +X, -X, +Y, -Y, +Z, -Z ordering (0–5).
    pub fn index(self) -> usize {
        // The variants have implicit discriminants 0..=5 in declaration order,
        // which is exactly the conventional face order.
        self as usize
    }

    /// Every face, ordered so that `all()[f.index()] == f`.
    pub fn all() -> [CubemapFace; 6] {
        [
            Self::PosX,
            Self::NegX,
            Self::PosY,
            Self::NegY,
            Self::PosZ,
            Self::NegZ,
        ]
    }
}
314
315/// Map a world-space direction `dir` to a cubemap face and per-face UV in [0, 1]².
316pub fn dir_to_face_uv(dir: Vec3) -> (CubemapFace, Vec2) {
317    let abs = dir.abs();
318    let (face, u_raw, v_raw) = if abs.x >= abs.y && abs.x >= abs.z {
319        if dir.x > 0.0 {
320            (CubemapFace::PosX, -dir.z / dir.x, -dir.y / dir.x)
321        } else {
322            (CubemapFace::NegX, dir.z / (-dir.x), -dir.y / (-dir.x))
323        }
324    } else if abs.y >= abs.x && abs.y >= abs.z {
325        if dir.y > 0.0 {
326            (CubemapFace::PosY, dir.x / dir.y, dir.z / dir.y)
327        } else {
328            (CubemapFace::NegY, dir.x / (-dir.y), -dir.z / (-dir.y))
329        }
330    } else if dir.z > 0.0 {
331        (CubemapFace::PosZ, dir.x / dir.z, -dir.y / dir.z)
332    } else {
333        (CubemapFace::NegZ, -dir.x / (-dir.z), -dir.y / (-dir.z))
334    };
335
336    let uv = Vec2::new(u_raw * 0.5 + 0.5, v_raw * 0.5 + 0.5);
337    (face, uv.clamp(Vec2::ZERO, Vec2::ONE))
338}
339
340/// Map a cubemap face + per-face UV back to a world-space direction.
341pub fn face_uv_to_dir(face: CubemapFace, uv: Vec2) -> Vec3 {
342    let uc = uv.x * 2.0 - 1.0;
343    let vc = uv.y * 2.0 - 1.0;
344
345    let dir = match face {
346        CubemapFace::PosX => Vec3::new(1.0, -vc, -uc),
347        CubemapFace::NegX => Vec3::new(-1.0, -vc, uc),
348        CubemapFace::PosY => Vec3::new(uc, 1.0, vc),
349        CubemapFace::NegY => Vec3::new(uc, -1.0, -vc),
350        CubemapFace::PosZ => Vec3::new(uc, -vc, 1.0),
351        CubemapFace::NegZ => Vec3::new(-uc, -vc, -1.0),
352    };
353    dir.normalize()
354}
355
356/// Convert an equirectangular UV `[0,1]²` to a unit direction.
357pub fn equirect_to_dir(uv: Vec2) -> Vec3 {
358    let phi = uv.x * 2.0 * PI - PI; // [-π, π]
359    let theta = uv.y * PI; // [0, π]
360    Vec3::new(theta.sin() * phi.cos(), theta.cos(), theta.sin() * phi.sin())
361}
362
363/// Convert a unit direction to equirectangular UV `[0,1]²`.
364pub fn dir_to_equirect(dir: Vec3) -> Vec2 {
365    let dir = dir.normalize();
366    let phi = dir.z.atan2(dir.x); // [-π, π]
367    let theta = dir.y.clamp(-1.0, 1.0).acos(); // [0, π]
368    Vec2::new(
369        (phi + PI) / (2.0 * PI),
370        theta / PI,
371    )
372}
373
374/// Encode a unit direction using the octahedral map (Cigolle et al., 2014).
375///
376/// Returns a value in `[-1, 1]²`.
377pub fn octahedral_map(dir: Vec3) -> Vec2 {
378    let dir = dir.normalize();
379    let l1 = dir.x.abs() + dir.y.abs() + dir.z.abs();
380    let p = Vec2::new(dir.x / l1, dir.y / l1);
381    if dir.z < 0.0 {
382        let sx = if p.x >= 0.0 { 1.0f32 } else { -1.0f32 };
383        let sy = if p.y >= 0.0 { 1.0f32 } else { -1.0f32 };
384        Vec2::new((1.0 - p.y.abs()) * sx, (1.0 - p.x.abs()) * sy)
385    } else {
386        p
387    }
388}
389
390/// Decode an octahedral-mapped UV back to a unit direction.
391pub fn octahedral_unmap(uv: Vec2) -> Vec3 {
392    let p = uv;
393    let z = 1.0 - p.x.abs() - p.y.abs();
394    let dir = if z >= 0.0 {
395        Vec3::new(p.x, p.y, z)
396    } else {
397        let sx = if p.x >= 0.0 { 1.0f32 } else { -1.0f32 };
398        let sy = if p.y >= 0.0 { 1.0f32 } else { -1.0f32 };
399        Vec3::new((1.0 - p.y.abs()) * sx, (1.0 - p.x.abs()) * sy, z)
400    };
401    dir.normalize()
402}
403
404// ─────────────────────────────────────────────────────────────────────────────
405// Reflection probes
406// ─────────────────────────────────────────────────────────────────────────────
407
408/// A spherical reflection probe that captures environment radiance at a point.
409#[derive(Debug, Clone)]
410pub struct ReflectionProbe {
411    pub position: Vec3,
412    /// Influence radius — objects within this sphere can use this probe.
413    pub radius: f32,
414    /// Priority weight for blending when multiple probes overlap.
415    pub importance: f32,
416    /// Whether to apply box parallax correction.
417    pub parallax_correction: bool,
418}
419
420impl ReflectionProbe {
421    pub fn new(position: Vec3, radius: f32, importance: f32, parallax_correction: bool) -> Self {
422        Self {
423            position,
424            radius,
425            importance,
426            parallax_correction,
427        }
428    }
429
430    /// Blend weight for a sample at `sample_pos`.
431    ///
432    /// Returns 0 outside the influence radius, and 1 at the probe centre.
433    pub fn blend_weight(&self, sample_pos: Vec3) -> f32 {
434        blend_weight(self, sample_pos)
435    }
436}
437
438/// Compute the blend weight for a reflection probe at `sample_pos`.
439pub fn blend_weight(probe: &ReflectionProbe, sample_pos: Vec3) -> f32 {
440    let dist = (probe.position - sample_pos).length();
441    if dist >= probe.radius {
442        return 0.0;
443    }
444    let t = dist / probe.radius;
445    // Smooth step with cubic falloff
446    1.0 - t * t * (3.0 - 2.0 * t)
447}
448
449/// Apply parallax correction to a reflection direction.
450///
451/// Assumes the probe captures the environment inside an AABB centred at
452/// `probe_pos` with half-extents `box_half`.
453///
454/// `dir`        — un-corrected reflection direction (unit)
455/// `sample_pos` — world position of the shaded surface point
456/// `probe_pos`  — world position of the probe centre
457/// `box_half`   — half-extents of the proxy box
458pub fn parallax_correct_dir(
459    dir: Vec3,
460    sample_pos: Vec3,
461    probe_pos: Vec3,
462    box_half: Vec3,
463) -> Vec3 {
464    let dir = dir.normalize();
465    // Ray-AABB intersection from sample_pos
466    let box_min = probe_pos - box_half;
467    let box_max = probe_pos + box_half;
468
469    let inv_dir = Vec3::new(
470        if dir.x.abs() > 1e-10 { 1.0 / dir.x } else { f32::MAX },
471        if dir.y.abs() > 1e-10 { 1.0 / dir.y } else { f32::MAX },
472        if dir.z.abs() > 1e-10 { 1.0 / dir.z } else { f32::MAX },
473    );
474
475    let t0 = (box_min - sample_pos) * inv_dir;
476    let t1 = (box_max - sample_pos) * inv_dir;
477
478    let t_max = Vec3::new(t0.x.max(t1.x), t0.y.max(t1.y), t0.z.max(t1.z));
479    let t_hit = t_max.x.min(t_max.y).min(t_max.z).max(0.0);
480
481    // Intersection point on the proxy box
482    let hit = sample_pos + dir * t_hit;
483
484    // Direction from probe centre to intersection
485    (hit - probe_pos).normalize()
486}
487
488// ─────────────────────────────────────────────────────────────────────────────
489// Light probe grid
490// ─────────────────────────────────────────────────────────────────────────────
491
492/// A regular 3D grid of SH3 irradiance probes.
493///
494/// At runtime, shaders can trilinearly interpolate between the 8 nearest probes
495/// to get smooth irradiance at any position within the grid bounds.
496#[derive(Debug, Clone)]
497pub struct LightProbeGrid {
498    /// World-space AABB minimum corner.
499    pub min: Vec3,
500    /// World-space AABB maximum corner.
501    pub max: Vec3,
502    /// Number of probes along each axis.
503    pub resolution: [u32; 3],
504    /// Flat probe storage, in x-major order (x + y*rx + z*rx*ry).
505    pub probes: Vec<Sh3>,
506}
507
508impl LightProbeGrid {
509    /// Create a grid with all probes initialised to zero.
510    pub fn new(min: Vec3, max: Vec3, resolution: [u32; 3]) -> Self {
511        let count = (resolution[0] * resolution[1] * resolution[2]) as usize;
512        Self {
513            min,
514            max,
515            resolution,
516            probes: vec![Sh3::zero(); count],
517        }
518    }
519
520    /// Cell size along each axis.
521    pub fn cell_size(&self) -> Vec3 {
522        let r = Vec3::new(
523            (self.resolution[0] - 1).max(1) as f32,
524            (self.resolution[1] - 1).max(1) as f32,
525            (self.resolution[2] - 1).max(1) as f32,
526        );
527        (self.max - self.min) / r
528    }
529
530    /// Linear index of probe at grid coordinates `(ix, iy, iz)`.
531    fn index(&self, ix: usize, iy: usize, iz: usize) -> usize {
532        let rx = self.resolution[0] as usize;
533        let ry = self.resolution[1] as usize;
534        ix + iy * rx + iz * rx * ry
535    }
536
537    /// Get a reference to the probe at grid coordinates.
538    pub fn probe_at(&self, ix: usize, iy: usize, iz: usize) -> &Sh3 {
539        &self.probes[self.index(ix, iy, iz)]
540    }
541
542    /// Get a mutable reference to the probe at grid coordinates.
543    pub fn probe_at_mut(&mut self, ix: usize, iy: usize, iz: usize) -> &mut Sh3 {
544        let idx = self.index(ix, iy, iz);
545        &mut self.probes[idx]
546    }
547
548    /// Trilinearly interpolate SH3 at world-space position `pos`.
549    pub fn sample(&self, pos: Vec3) -> Sh3 {
550        let cell_size = self.cell_size();
551        let local = (pos - self.min) / cell_size;
552
553        let rx = (self.resolution[0] - 1) as f32;
554        let ry = (self.resolution[1] - 1) as f32;
555        let rz = (self.resolution[2] - 1) as f32;
556
557        let lx = local.x.clamp(0.0, rx);
558        let ly = local.y.clamp(0.0, ry);
559        let lz = local.z.clamp(0.0, rz);
560
561        let ix0 = (lx as usize).min(self.resolution[0] as usize - 1);
562        let iy0 = (ly as usize).min(self.resolution[1] as usize - 1);
563        let iz0 = (lz as usize).min(self.resolution[2] as usize - 1);
564        let ix1 = (ix0 + 1).min(self.resolution[0] as usize - 1);
565        let iy1 = (iy0 + 1).min(self.resolution[1] as usize - 1);
566        let iz1 = (iz0 + 1).min(self.resolution[2] as usize - 1);
567
568        let tx = lx - ix0 as f32;
569        let ty = ly - iy0 as f32;
570        let tz = lz - iz0 as f32;
571
572        // Trilinear blend of 8 corner probes
573        let p000 = self.probe_at(ix0, iy0, iz0);
574        let p100 = self.probe_at(ix1, iy0, iz0);
575        let p010 = self.probe_at(ix0, iy1, iz0);
576        let p110 = self.probe_at(ix1, iy1, iz0);
577        let p001 = self.probe_at(ix0, iy0, iz1);
578        let p101 = self.probe_at(ix1, iy0, iz1);
579        let p011 = self.probe_at(ix0, iy1, iz1);
580        let p111 = self.probe_at(ix1, iy1, iz1);
581
582        let lerp_sh = |a: &Sh3, b: &Sh3, t: f32| -> Sh3 {
583            sh_add(a, &sh_scale(&sh_add(b, &sh_scale(a, -1.0)), t))
584        };
585
586        let s00 = lerp_sh(p000, p100, tx);
587        let s10 = lerp_sh(p010, p110, tx);
588        let s01 = lerp_sh(p001, p101, tx);
589        let s11 = lerp_sh(p011, p111, tx);
590
591        let s0 = lerp_sh(&s00, &s10, ty);
592        let s1 = lerp_sh(&s01, &s11, ty);
593
594        lerp_sh(&s0, &s1, tz)
595    }
596
597    /// Compute ambient irradiance at `pos` for a surface with `normal`.
598    pub fn compute_ambient(&self, pos: Vec3, normal: Vec3) -> Vec3 {
599        let sh = self.sample(pos);
600        let conv = convolve_sh_lambert(&sh);
601        irradiance_from_sh(&conv, normal.normalize()).max(Vec3::ZERO)
602    }
603}
604
605// ─────────────────────────────────────────────────────────────────────────────
606// Baked AO volume
607// ─────────────────────────────────────────────────────────────────────────────
608
/// A 3D grid of pre-baked ambient-occlusion values, where 1.0 means fully
/// unoccluded.
#[derive(Debug, Clone)]
pub struct BakedAo {
    /// Number of samples along each axis.
    pub grid_size: [u32; 3],
    /// Flat sample storage with x varying fastest (`x + y*gx + z*gx*gy`).
    pub data: Vec<f32>,
}

impl BakedAo {
    /// Allocate an AO volume initialised to 1.0 (fully unoccluded).
    pub fn new(grid_size: [u32; 3]) -> Self {
        // Multiply in usize so large grids cannot overflow u32 before the cast.
        let count: usize = grid_size.iter().map(|&n| n as usize).product();
        Self {
            grid_size,
            data: vec![1.0; count],
        }
    }

    /// Flat index of the sample at grid coordinates `(ix, iy, iz)`.
    fn index(&self, ix: usize, iy: usize, iz: usize) -> usize {
        let rx = self.grid_size[0] as usize;
        let ry = self.grid_size[1] as usize;
        ix + iy * rx + iz * rx * ry
    }

    /// Split one continuous grid coordinate into the two bracketing sample
    /// indices and the blend factor between them, clamping to `[0, last]`.
    /// A NaN coordinate is treated as 0.
    fn bracket(coord: f32, last: usize) -> (usize, usize, f32) {
        let c = if coord.is_nan() {
            0.0
        } else {
            coord.clamp(0.0, last as f32)
        };
        let i0 = (c as usize).min(last);
        let i1 = (i0 + 1).min(last);
        (i0, i1, c - i0 as f32)
    }

    /// Sample AO at continuous grid coordinates using trilinear interpolation.
    ///
    /// Coordinates are in sample units: `0.0` is the first sample along an
    /// axis and `grid_size - 1` the last. They are clamped to that range. An
    /// empty volume returns 1.0, the "unoccluded" value [`BakedAo::new`] fills
    /// with.
    pub fn sample_trilinear(&self, gx: f32, gy: f32, gz: f32) -> f32 {
        // Nothing to interpolate, and `grid_size - 1` would underflow.
        if self.data.is_empty() || self.grid_size.contains(&0) {
            return 1.0;
        }

        let (ix0, ix1, tx) = Self::bracket(gx, self.grid_size[0] as usize - 1);
        let (iy0, iy1, ty) = Self::bracket(gy, self.grid_size[1] as usize - 1);
        let (iz0, iz1, tz) = Self::bracket(gz, self.grid_size[2] as usize - 1);

        let ao = |x: usize, y: usize, z: usize| self.data[self.index(x, y, z)];
        let lerp = |a: f32, b: f32, t: f32| a * (1.0 - t) + b * t;

        // Interpolate along x, then y, then z.
        let i00 = lerp(ao(ix0, iy0, iz0), ao(ix1, iy0, iz0), tx);
        let i10 = lerp(ao(ix0, iy1, iz0), ao(ix1, iy1, iz0), tx);
        let i01 = lerp(ao(ix0, iy0, iz1), ao(ix1, iy0, iz1), tx);
        let i11 = lerp(ao(ix0, iy1, iz1), ao(ix1, iy1, iz1), tx);

        let j0 = lerp(i00, i10, ty);
        let j1 = lerp(i01, i11, ty);

        lerp(j0, j1, tz)
    }
}
670
671// ─────────────────────────────────────────────────────────────────────────────
672// Screen-space reflections (CPU-side math)
673// ─────────────────────────────────────────────────────────────────────────────
674
675/// Generate an SSR ray in view space.
676///
677/// `pos_vs`    — surface position in view space
678/// `normal_vs` — surface normal in view space
679/// `roughness` — material roughness (used for cone jitter)
680/// `jitter`    — random jitter in [0, 1) for anti-aliasing (from blue noise etc.)
681pub fn ssr_ray(pos_vs: Vec3, normal_vs: Vec3, roughness: f32, jitter: f32) -> Ray {
682    let view_dir = -pos_vs.normalize(); // looking from origin toward surface
683    let normal = normal_vs.normalize();
684
685    // Base reflection direction
686    let reflect_dir = (view_dir - 2.0 * view_dir.dot(normal) * normal).normalize();
687
688    // Jitter the reflection direction by roughness-scaled cone
689    let cone_angle = roughness * std::f32::consts::FRAC_PI_2 * 0.5;
690    let jitter_angle = jitter * cone_angle;
691
692    // Build local tangent frame around reflect_dir
693    let (t, b) = orthonormal_basis(reflect_dir);
694    let phi = jitter * 2.0 * PI;
695    let sin_j = jitter_angle.sin();
696    let cos_j = jitter_angle.cos();
697
698    let jittered = (reflect_dir * cos_j + t * sin_j * phi.cos() + b * sin_j * phi.sin()).normalize();
699
700    Ray::new(pos_vs, jittered)
701}
702
703/// Compute the SSR fade factor — attenuates hits near screen edges and for
704/// high roughness.
705///
706/// `screen_uv` — screen UV of the hit point [0,1]²
707/// `hit_dist`  — distance travelled by the SSR ray
708/// `roughness` — material roughness
709pub fn ssr_fade(screen_uv: Vec2, hit_dist: f32, roughness: f32) -> f32 {
710    // Edge fade
711    let edge_dist = Vec2::new(
712        screen_uv.x.min(1.0 - screen_uv.x),
713        screen_uv.y.min(1.0 - screen_uv.y),
714    );
715    let edge_fade = (edge_dist.x / 0.1).clamp(0.0, 1.0) * (edge_dist.y / 0.1).clamp(0.0, 1.0);
716
717    // Distance fade (long SSR rays are less reliable)
718    let dist_fade = (1.0 - (hit_dist / 50.0).clamp(0.0, 1.0)).max(0.0);
719
720    // Roughness fade (only smooth surfaces show reflections)
721    let rough_fade = 1.0 - roughness.clamp(0.0, 1.0);
722
723    edge_fade * dist_fade * rough_fade
724}
725
726/// Duff orthonormal basis (re-exported here for use in ssr_ray).
727fn orthonormal_basis(n: Vec3) -> (Vec3, Vec3) {
728    let sign = if n.z >= 0.0 { 1.0_f32 } else { -1.0_f32 };
729    let a = -1.0 / (sign + n.z);
730    let b = n.x * n.y * a;
731    let t = Vec3::new(1.0 + sign * n.x * n.x * a, sign * b, -sign * n.x);
732    let bi = Vec3::new(b, sign + n.y * n.y * a, -n.y);
733    (t, bi)
734}
735
736// ─────────────────────────────────────────────────────────────────────────────
737// Irradiance Cache
738// ─────────────────────────────────────────────────────────────────────────────
739
740/// A single cached irradiance sample.
741#[derive(Debug, Clone)]
742pub struct IrradianceCacheEntry {
743    pub position: Vec3,
744    pub normal: Vec3,
745    /// Cached irradiance value (linear RGB).
746    pub irradiance: Vec3,
747    /// Validity weight — decays to zero over time / distance.
748    pub validity: f32,
749}
750
751impl IrradianceCacheEntry {
752    pub fn new(position: Vec3, normal: Vec3, irradiance: Vec3, validity: f32) -> Self {
753        Self {
754            position,
755            normal: normal.normalize(),
756            irradiance,
757            validity: validity.clamp(0.0, 1.0),
758        }
759    }
760
761    /// Interpolation weight for a query at `(pos, n)`.
762    pub fn weight(&self, pos: Vec3, normal: Vec3, max_dist: f32) -> f32 {
763        let dist = (self.position - pos).length();
764        if dist >= max_dist || self.validity < 1e-4 {
765            return 0.0;
766        }
767        let dist_w = 1.0 - dist / max_dist;
768        let normal_w = self.normal.dot(normal.normalize()).max(0.0);
769        dist_w * dist_w * normal_w * self.validity
770    }
771}
772
773/// A cache of irradiance samples for global-illumination interpolation.
774///
775/// New samples are inserted and queries interpolate nearby entries using
776/// distance and normal similarity as weights.
777pub struct IrradianceCache {
778    pub entries: Vec<IrradianceCacheEntry>,
779    /// Maximum number of entries (oldest are evicted when at capacity).
780    pub capacity: usize,
781}
782
783impl IrradianceCache {
784    pub fn new(capacity: usize) -> Self {
785        Self {
786            entries: Vec::with_capacity(capacity),
787            capacity,
788        }
789    }
790
791    /// Query the cache for irradiance at `(pos, normal)`.
792    ///
793    /// Interpolates from all entries within `max_dist`.  Returns `None` if
794    /// there are no valid entries nearby.
795    pub fn query(&self, pos: Vec3, normal: Vec3, max_dist: f32) -> Option<Vec3> {
796        let mut weighted_sum = Vec3::ZERO;
797        let mut weight_total = 0.0f32;
798
799        for entry in &self.entries {
800            let w = entry.weight(pos, normal, max_dist);
801            if w > 1e-6 {
802                weighted_sum += entry.irradiance * w;
803                weight_total += w;
804            }
805        }
806
807        if weight_total < 1e-6 {
808            None
809        } else {
810            Some(weighted_sum / weight_total)
811        }
812    }
813
814    /// Insert a new irradiance sample.
815    ///
816    /// If the cache is at capacity the entry with the lowest validity is evicted.
817    pub fn insert(&mut self, pos: Vec3, normal: Vec3, irradiance: Vec3, validity: f32) {
818        // De-duplicate: if a very close entry already exists, update it instead.
819        let merge_dist = 0.01f32;
820        for entry in &mut self.entries {
821            if (entry.position - pos).length() < merge_dist
822                && entry.normal.dot(normal.normalize()) > 0.99
823            {
824                // Exponential moving average update
825                let alpha = 0.2f32;
826                entry.irradiance = entry.irradiance * (1.0 - alpha) + irradiance * alpha;
827                entry.validity = validity;
828                return;
829            }
830        }
831
832        if self.entries.len() >= self.capacity {
833            // Evict lowest-validity entry
834            let evict = self
835                .entries
836                .iter()
837                .enumerate()
838                .min_by(|(_, a), (_, b)| a.validity.partial_cmp(&b.validity).unwrap())
839                .map(|(i, _)| i);
840            if let Some(idx) = evict {
841                self.entries.swap_remove(idx);
842            }
843        }
844
845        self.entries.push(IrradianceCacheEntry::new(pos, normal, irradiance, validity));
846    }
847
848    /// Reduce validity of all entries within `radius` of `pos` by `decay_rate`.
849    pub fn update_validity(&mut self, pos: Vec3, decay_rate: f32) {
850        for entry in &mut self.entries {
851            let dist = (entry.position - pos).length();
852            let dist_w = 1.0 - (dist / 50.0).clamp(0.0, 1.0);
853            entry.validity = (entry.validity - decay_rate * dist_w).max(0.0);
854        }
855    }
856
857    /// Remove all entries with validity below threshold.
858    pub fn prune(&mut self, threshold: f32) {
859        self.entries.retain(|e| e.validity >= threshold);
860    }
861
862    /// Number of entries currently in the cache.
863    pub fn len(&self) -> usize {
864        self.entries.len()
865    }
866
867    pub fn is_empty(&self) -> bool {
868        self.entries.is_empty()
869    }
870
871    /// Clear all entries.
872    pub fn clear(&mut self) {
873        self.entries.clear();
874    }
875}
876
877// ─────────────────────────────────────────────────────────────────────────────
878// Re-exports used by mod.rs (PbrMaterial::f0 references brdf::fresnel)
879// ─────────────────────────────────────────────────────────────────────────────
880
881use super::brdf;
882
883// ─────────────────────────────────────────────────────────────────────────────
884// Tests
885// ─────────────────────────────────────────────────────────────────────────────
886
#[cfg(test)]
mod tests {
    // `use super::*` brings in this module's items plus its glam imports
    // (e.g. `Vec2`, used by the SSR tests below).
    use super::*;
    use glam::Vec3;

    // ── SH basis ──────────────────────────────────────────────────────────────

    #[test]
    fn sh_basis_length_is_9() {
        let b = sh_basis(Vec3::Y);
        assert_eq!(b.len(), 9);
    }

    #[test]
    fn sh_basis_band0_is_constant() {
        let b1 = sh_basis(Vec3::Y);
        let b2 = sh_basis(Vec3::X);
        // Band 0 is constant
        assert!((b1[0] - b2[0]).abs() < 1e-6);
    }

    #[test]
    fn sh_project_and_evaluate_constant_fn() {
        // Projecting f(ω) = 1 yields a band-0 coefficient of C0 · 4π (the
        // monopole integral). Evaluating reconstructs
        // C0 · (C0 · 4π) = 1, because C0 = 1 / (2√π).
        // Higher bands ideally integrate to ~0, up to sampling noise, so the
        // evaluated value should be close to 1 in any direction.
        let sh = project_to_sh(|_| Vec3::ONE, 2048);
        let val = sh.evaluate(Vec3::Y);
        assert!(
            val.x > 0.5 && val.x < 2.0,
            "SH evaluation of constant fn should be ~1: {val:?}"
        );
    }

    #[test]
    fn sh_add_is_commutative() {
        let mut a = Sh3::zero();
        let mut b = Sh3::zero();
        a.coeffs[0] = Vec3::new(1.0, 0.0, 0.0);
        b.coeffs[1] = Vec3::new(0.0, 1.0, 0.0);
        let ab = sh_add(&a, &b);
        let ba = sh_add(&b, &a);
        for i in 0..9 {
            assert!((ab.coeffs[i] - ba.coeffs[i]).length() < 1e-6);
        }
    }

    #[test]
    fn sh_scale_zero_gives_zero() {
        let mut sh = Sh3::zero();
        sh.coeffs[0] = Vec3::ONE;
        let scaled = sh_scale(&sh, 0.0);
        assert_eq!(scaled.coeffs[0], Vec3::ZERO);
    }

    // ── Cubemap ───────────────────────────────────────────────────────────────

    #[test]
    fn dir_to_face_uv_round_trip() {
        let dirs = [
            Vec3::X,
            -Vec3::X,
            Vec3::Y,
            -Vec3::Y,
            Vec3::Z,
            -Vec3::Z,
            Vec3::new(1.0, 1.0, 0.0).normalize(),
        ];
        for dir in dirs {
            let (face, uv) = dir_to_face_uv(dir);
            let recovered = face_uv_to_dir(face, uv);
            let dot = dir.dot(recovered);
            assert!(
                dot > 0.99,
                "Round-trip dir={dir:?} -> face={face:?} uv={uv:?} -> {recovered:?}, dot={dot}"
            );
        }
    }

    #[test]
    fn equirect_round_trip() {
        let dirs = [Vec3::X, Vec3::Y, Vec3::Z, Vec3::new(0.5, 0.7, 0.3).normalize()];
        for dir in dirs {
            let uv = dir_to_equirect(dir);
            let back = equirect_to_dir(uv);
            let dot = dir.dot(back);
            assert!(dot > 0.999, "Equirect round-trip dot={dot} for dir={dir:?}");
        }
    }

    #[test]
    fn octahedral_round_trip() {
        let dirs = [
            Vec3::X,
            Vec3::Y,
            Vec3::Z,
            -Vec3::X,
            -Vec3::Y,
            -Vec3::Z,
            Vec3::new(0.5, 0.3, 0.8).normalize(),
        ];
        for dir in dirs {
            let enc = octahedral_map(dir);
            let dec = octahedral_unmap(enc);
            let dot = dir.dot(dec);
            assert!(dot > 0.999, "Oct round-trip failed for {dir:?}: dot={dot}");
        }
    }

    // ── Reflection probe ──────────────────────────────────────────────────────

    #[test]
    fn probe_weight_zero_outside_radius() {
        let probe = ReflectionProbe::new(Vec3::ZERO, 5.0, 1.0, false);
        let w = probe.blend_weight(Vec3::new(10.0, 0.0, 0.0));
        assert_eq!(w, 0.0);
    }

    #[test]
    fn probe_weight_one_at_centre() {
        let probe = ReflectionProbe::new(Vec3::ZERO, 5.0, 1.0, false);
        let w = probe.blend_weight(Vec3::ZERO);
        assert!((w - 1.0).abs() < 1e-5);
    }

    // ── Light probe grid ──────────────────────────────────────────────────────

    #[test]
    fn light_probe_grid_trilinear_at_corner() {
        let mut grid = LightProbeGrid::new(Vec3::ZERO, Vec3::ONE, [2, 2, 2]);
        // Set one probe to a constant colour
        grid.probe_at_mut(0, 0, 0).coeffs[0] = Vec3::new(1.0, 0.0, 0.0);
        let sh = grid.sample(Vec3::ZERO);
        assert!(sh.coeffs[0].x > 0.5, "Should pick up the red probe at corner");
    }

    // ── SSR ───────────────────────────────────────────────────────────────────

    #[test]
    fn ssr_ray_direction_is_unit() {
        let pos = Vec3::new(0.0, 0.0, -5.0);
        let normal = Vec3::Z;
        let ray = ssr_ray(pos, normal, 0.2, 0.3);
        assert!(
            (ray.direction.length() - 1.0).abs() < 1e-4,
            "SSR ray direction must be unit: {}",
            ray.direction.length()
        );
    }

    #[test]
    fn ssr_fade_zero_at_edge() {
        let uv = Vec2::new(0.0, 0.5); // at left edge
        let fade = ssr_fade(uv, 5.0, 0.1);
        assert_eq!(fade, 0.0, "Fade should be 0 at screen edge");
    }

    #[test]
    fn ssr_fade_high_roughness_is_low() {
        let uv = Vec2::new(0.5, 0.5);
        let f_smooth = ssr_fade(uv, 1.0, 0.0);
        let f_rough = ssr_fade(uv, 1.0, 1.0);
        assert!(
            f_smooth > f_rough,
            "Smooth surfaces should have higher SSR fade factor"
        );
    }

    // ── Irradiance cache ──────────────────────────────────────────────────────

    #[test]
    fn irradiance_cache_insert_and_query() {
        let mut cache = IrradianceCache::new(64);
        cache.insert(Vec3::ZERO, Vec3::Y, Vec3::new(1.0, 0.5, 0.2), 1.0);
        let result = cache.query(Vec3::ZERO, Vec3::Y, 1.0);
        assert!(result.is_some(), "Should find nearby entry");
        let irr = result.unwrap();
        assert!((irr - Vec3::new(1.0, 0.5, 0.2)).length() < 0.01);
    }

    #[test]
    fn irradiance_cache_miss_returns_none() {
        let mut cache = IrradianceCache::new(16);
        cache.insert(Vec3::new(100.0, 0.0, 0.0), Vec3::Y, Vec3::ONE, 1.0);
        let result = cache.query(Vec3::ZERO, Vec3::Y, 1.0);
        assert!(result.is_none(), "Should return None when no nearby entry");
    }

    #[test]
    fn irradiance_cache_eviction() {
        let mut cache = IrradianceCache::new(2);
        cache.insert(Vec3::new(0.0, 0.0, 0.0), Vec3::Y, Vec3::ONE, 1.0);
        cache.insert(Vec3::new(10.0, 0.0, 0.0), Vec3::Y, Vec3::ONE, 0.5);
        // Insert third — should evict lowest validity (0.5 at pos 10)
        cache.insert(Vec3::new(20.0, 0.0, 0.0), Vec3::Y, Vec3::ONE, 0.9);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn irradiance_cache_decay() {
        let mut cache = IrradianceCache::new(8);
        cache.insert(Vec3::ZERO, Vec3::Y, Vec3::ONE, 1.0);
        let v0 = cache.entries[0].validity;
        cache.update_validity(Vec3::ZERO, 0.1);
        let v1 = cache.entries[0].validity;
        assert!(v1 < v0, "Validity should decrease after decay");
    }

    #[test]
    fn irradiance_cache_clear_empties() {
        let mut cache = IrradianceCache::new(8);
        cache.insert(Vec3::ZERO, Vec3::Y, Vec3::ONE, 1.0);
        assert!(!cache.is_empty(), "Cache should hold the inserted entry");
        cache.clear();
        assert!(cache.is_empty(), "Cache should be empty after clear");
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn irradiance_cache_prune_removes_fully_decayed() {
        let mut cache = IrradianceCache::new(8);
        cache.insert(Vec3::ZERO, Vec3::Y, Vec3::ONE, 1.0);
        // A huge decay at the entry's own position (falloff weight 1) clamps
        // its validity to exactly 0.
        cache.update_validity(Vec3::ZERO, 1.0e6);
        assert_eq!(cache.entries[0].validity, 0.0);
        cache.prune(0.5);
        assert!(cache.is_empty(), "Fully decayed entry should be pruned");
    }
}