Skip to main content

proof_engine/render/postfx/
lut_grade.rs

1//! Color Grading LUT — 3D look-up table generation, blending, and GPU upload.
2//!
//! Generates a 3D LUT from color grade parameters and applies it as a post-processing
//! pass. The LUT transforms input RGB → output RGB in a single texture lookup, so an
//! arbitrarily complex per-pixel color grade costs one filtered 3D texture fetch instead
//! of evaluating the grading math for every pixel.
6//!
7//! # Sizes
8//!
//! - 16³ = 4,096 entries × 3 channels = 12,288 bytes as RGB8 / 49,152 bytes as f32 (fast, lower quality)
//! - 32³ = 32,768 entries × 3 channels = 98,304 bytes as RGB8 / 393,216 bytes as f32 (high quality)
11//!
12//! # Game state LUTs
13//!
14//! Each game state has a pre-baked LUT. When state changes, we smoothly blend
15//! between the old and new LUT over a configurable duration.
16
17use glam::{Vec3, Vec4};
18use super::color_grade::ColorGradeParams;
19
20// ── LUT data ────────────────────────────────────────────────────────────────
21
22/// A 3D color look-up table.
23#[derive(Clone, Debug)]
24pub struct Lut3D {
25    /// LUT dimension (e.g. 16 or 32). Total entries = size³.
26    pub size: u32,
27    /// RGB data, row-major: data[r + g*size + b*size*size] = (out_r, out_g, out_b).
28    pub data: Vec<[f32; 3]>,
29}
30
31impl Lut3D {
32    /// Create an identity LUT (output = input) of the given size.
33    pub fn identity(size: u32) -> Self {
34        let total = (size * size * size) as usize;
35        let mut data = Vec::with_capacity(total);
36        let scale = 1.0 / (size - 1).max(1) as f32;
37
38        for b in 0..size {
39            for g in 0..size {
40                for r in 0..size {
41                    data.push([r as f32 * scale, g as f32 * scale, b as f32 * scale]);
42                }
43            }
44        }
45
46        Self { size, data }
47    }
48
49    /// Generate a LUT from color grade parameters.
50    pub fn from_params(size: u32, params: &ColorGradeParams) -> Self {
51        let mut lut = Self::identity(size);
52        let scale = 1.0 / (size - 1).max(1) as f32;
53
54        for b in 0..size {
55            for g in 0..size {
56                for r in 0..size {
57                    let idx = (r + g * size + b * size * size) as usize;
58                    let input = Vec3::new(r as f32 * scale, g as f32 * scale, b as f32 * scale);
59                    let output = apply_grade(input, params);
60                    lut.data[idx] = [output.x, output.y, output.z];
61                }
62            }
63        }
64
65        lut
66    }
67
68    /// Sample the LUT with trilinear interpolation.
69    pub fn sample(&self, input: Vec3) -> Vec3 {
70        let s = (self.size - 1) as f32;
71        let r = (input.x * s).clamp(0.0, s);
72        let g = (input.y * s).clamp(0.0, s);
73        let b = (input.z * s).clamp(0.0, s);
74
75        let r0 = r.floor() as u32;
76        let g0 = g.floor() as u32;
77        let b0 = b.floor() as u32;
78        let r1 = (r0 + 1).min(self.size - 1);
79        let g1 = (g0 + 1).min(self.size - 1);
80        let b1 = (b0 + 1).min(self.size - 1);
81
82        let fr = r.fract();
83        let fg = g.fract();
84        let fb = b.fract();
85
86        // Trilinear interpolation
87        let fetch = |ri: u32, gi: u32, bi: u32| -> Vec3 {
88            let idx = (ri + gi * self.size + bi * self.size * self.size) as usize;
89            let d = self.data[idx];
90            Vec3::new(d[0], d[1], d[2])
91        };
92
93        let c000 = fetch(r0, g0, b0);
94        let c100 = fetch(r1, g0, b0);
95        let c010 = fetch(r0, g1, b0);
96        let c110 = fetch(r1, g1, b0);
97        let c001 = fetch(r0, g0, b1);
98        let c101 = fetch(r1, g0, b1);
99        let c011 = fetch(r0, g1, b1);
100        let c111 = fetch(r1, g1, b1);
101
102        let c00 = c000.lerp(c100, fr);
103        let c10 = c010.lerp(c110, fr);
104        let c01 = c001.lerp(c101, fr);
105        let c11 = c011.lerp(c111, fr);
106
107        let c0 = c00.lerp(c10, fg);
108        let c1 = c01.lerp(c11, fg);
109
110        c0.lerp(c1, fb)
111    }
112
113    /// Linearly blend two LUTs together. `t` = 0.0 returns `self`, 1.0 returns `other`.
114    pub fn blend(&self, other: &Lut3D, t: f32) -> Lut3D {
115        assert_eq!(self.size, other.size, "LUT sizes must match for blending");
116        let t = t.clamp(0.0, 1.0);
117        let data: Vec<[f32; 3]> = self.data.iter().zip(other.data.iter()).map(|(a, b)| {
118            [
119                a[0] + (b[0] - a[0]) * t,
120                a[1] + (b[1] - a[1]) * t,
121                a[2] + (b[2] - a[2]) * t,
122            ]
123        }).collect();
124
125        Lut3D { size: self.size, data }
126    }
127
128    /// Convert to flat f32 RGB buffer for GPU upload.
129    pub fn to_rgb_f32(&self) -> Vec<f32> {
130        let mut buf = Vec::with_capacity(self.data.len() * 3);
131        for entry in &self.data {
132            buf.push(entry[0]);
133            buf.push(entry[1]);
134            buf.push(entry[2]);
135        }
136        buf
137    }
138
139    /// Convert to u8 RGB buffer for GPU upload (RGBA8 3D texture).
140    pub fn to_rgb_u8(&self) -> Vec<u8> {
141        let mut buf = Vec::with_capacity(self.data.len() * 3);
142        for entry in &self.data {
143            buf.push((entry[0].clamp(0.0, 1.0) * 255.0) as u8);
144            buf.push((entry[1].clamp(0.0, 1.0) * 255.0) as u8);
145            buf.push((entry[2].clamp(0.0, 1.0) * 255.0) as u8);
146        }
147        buf
148    }
149
150    /// Total number of entries.
151    pub fn entry_count(&self) -> usize { (self.size * self.size * self.size) as usize }
152
153    /// Memory size in bytes (f32 RGB).
154    pub fn memory_bytes(&self) -> usize { self.entry_count() * 3 * 4 }
155}
156
157// ── Color grade application (CPU-side) ──────────────────────────────────────
158
159fn apply_grade(color: Vec3, params: &ColorGradeParams) -> Vec3 {
160    let mut c = color;
161
162    // Tint
163    c *= params.tint;
164
165    // Brightness
166    c += Vec3::splat(params.brightness);
167
168    // Contrast (around 0.5 midpoint)
169    c = (c - Vec3::splat(0.5)) * params.contrast + Vec3::splat(0.5);
170
171    // Lift/Gamma/Gain
172    c = c * params.gain + params.lift;
173    if params.gamma != Vec3::ONE {
174        c = Vec3::new(
175            c.x.max(0.0).powf(1.0 / params.gamma.x.max(0.01)),
176            c.y.max(0.0).powf(1.0 / params.gamma.y.max(0.01)),
177            c.z.max(0.0).powf(1.0 / params.gamma.z.max(0.01)),
178        );
179    }
180
181    // Saturation
182    let lum = 0.2126 * c.x + 0.7152 * c.y + 0.0722 * c.z;
183    c = Vec3::splat(lum).lerp(c, params.saturation);
184
185    // Split toning
186    if params.shadow_tint_strength > 0.0 || params.highlight_tint_strength > 0.0 {
187        let shadow_w = (1.0 - lum / params.split_midpoint.max(0.01)).clamp(0.0, 1.0);
188        let highlight_w = (lum / params.split_midpoint.max(0.01) - 1.0).clamp(0.0, 1.0);
189        c += params.shadow_tint * shadow_w * params.shadow_tint_strength;
190        c += params.highlight_tint * highlight_w * params.highlight_tint_strength;
191    }
192
193    // Clamp
194    Vec3::new(c.x.clamp(0.0, 1.0), c.y.clamp(0.0, 1.0), c.z.clamp(0.0, 1.0))
195}
196
197// ── Game State LUT Presets ──────────────────────────────────────────────────
198
199/// Pre-built LUT presets for each game state.
200pub struct GameStateLuts;
201
202impl GameStateLuts {
203    const SIZE: u32 = 16;
204
205    /// Normal gameplay: neutral with slight theme tint.
206    pub fn normal() -> Lut3D {
207        let params = ColorGradeParams::default();
208        Lut3D::from_params(Self::SIZE, &params)
209    }
210
211    /// Low HP: desaturated with red push.
212    pub fn low_hp(severity: f32) -> Lut3D {
213        let params = ColorGradeParams::danger(severity);
214        Lut3D::from_params(Self::SIZE, &params)
215    }
216
217    /// High corruption: purple shift, reduced contrast.
218    pub fn corruption(level: f32) -> Lut3D {
219        let t = (level / 1000.0).clamp(0.0, 1.0);
220        let params = ColorGradeParams {
221            tint: Vec3::new(0.9 + t * 0.1, 0.8 - t * 0.2, 0.95 + t * 0.15),
222            saturation: 1.0 + t * 0.3,
223            contrast: 1.0 - t * 0.15,
224            shadow_tint: Vec3::new(0.15, 0.0, 0.3),
225            shadow_tint_strength: t * 0.5,
226            brightness: t * 0.05,
227            ..Default::default()
228        };
229        Lut3D::from_params(Self::SIZE, &params)
230    }
231
232    /// Death: full desaturation over progress (0→1).
233    pub fn death(progress: f32) -> Lut3D {
234        let params = ColorGradeParams::death(progress);
235        Lut3D::from_params(Self::SIZE, &params)
236    }
237
238    /// Victory: warm golden tint.
239    pub fn victory() -> Lut3D {
240        let params = ColorGradeParams::victory();
241        Lut3D::from_params(Self::SIZE, &params)
242    }
243
244    /// Boss fight: high contrast, deep shadows.
245    pub fn boss_fight() -> Lut3D {
246        let params = ColorGradeParams {
247            contrast: 1.3,
248            saturation: 1.1,
249            shadow_tint: Vec3::new(0.05, 0.0, 0.1),
250            shadow_tint_strength: 0.4,
251            vignette: 0.35,
252            lift: Vec3::new(-0.02, -0.02, -0.01),
253            ..Default::default()
254        };
255        Lut3D::from_params(Self::SIZE, &params)
256    }
257
258    /// Shrine: blue-shifted, soft, low contrast.
259    pub fn shrine() -> Lut3D {
260        let params = ColorGradeParams {
261            tint: Vec3::new(0.85, 0.9, 1.15),
262            contrast: 0.85,
263            saturation: 0.9,
264            brightness: 0.05,
265            highlight_tint: Vec3::new(0.7, 0.8, 1.0),
266            highlight_tint_strength: 0.3,
267            ..Default::default()
268        };
269        Lut3D::from_params(Self::SIZE, &params)
270    }
271
272    /// Chaos Rift: oversaturated, high contrast, green-purple split tone.
273    pub fn chaos_rift() -> Lut3D {
274        let params = ColorGradeParams {
275            saturation: 1.6,
276            contrast: 1.4,
277            shadow_tint: Vec3::new(0.0, 0.2, 0.0),
278            shadow_tint_strength: 0.5,
279            highlight_tint: Vec3::new(0.5, 0.0, 0.5),
280            highlight_tint_strength: 0.4,
281            ..Default::default()
282        };
283        Lut3D::from_params(Self::SIZE, &params)
284    }
285}
286
287// ── LUT Blender ─────────────────────────────────────────────────────────────
288
289/// Smoothly interpolates between LUTs when game state changes.
290pub struct LutBlender {
291    /// Current active LUT.
292    current: Lut3D,
293    /// Target LUT (what we're blending toward).
294    target: Option<Lut3D>,
295    /// Blend progress: 0.0 = current, 1.0 = target.
296    progress: f32,
297    /// Blend duration in seconds.
298    duration: f32,
299    /// Blended result (updated each tick).
300    blended: Lut3D,
301    /// Whether the blended LUT needs re-upload to GPU.
302    pub dirty: bool,
303}
304
305impl LutBlender {
306    pub fn new(initial: Lut3D) -> Self {
307        let blended = initial.clone();
308        Self {
309            current: initial,
310            target: None,
311            progress: 0.0,
312            duration: 0.0,
313            blended,
314            dirty: true,
315        }
316    }
317
318    /// Start blending toward a new LUT over `duration` seconds.
319    pub fn blend_to(&mut self, target: Lut3D, duration: f32) {
320        // If already at this target, skip
321        self.current = self.blended.clone();
322        self.target = Some(target);
323        self.progress = 0.0;
324        self.duration = duration.max(0.01);
325        self.dirty = true;
326    }
327
328    /// Instant cut to a new LUT (no blend).
329    pub fn set(&mut self, lut: Lut3D) {
330        self.current = lut.clone();
331        self.target = None;
332        self.progress = 0.0;
333        self.blended = lut;
334        self.dirty = true;
335    }
336
337    /// Advance the blend by `dt` seconds.
338    pub fn tick(&mut self, dt: f32) {
339        if let Some(ref target) = self.target {
340            self.progress = (self.progress + dt / self.duration).min(1.0);
341
342            // Smooth-step easing
343            let t = self.progress;
344            let eased = t * t * (3.0 - 2.0 * t);
345
346            self.blended = self.current.blend(target, eased);
347            self.dirty = true;
348
349            if self.progress >= 1.0 {
350                self.current = self.blended.clone();
351                self.target = None;
352            }
353        }
354    }
355
356    /// Whether a blend is currently in progress.
357    pub fn is_blending(&self) -> bool { self.target.is_some() }
358
359    /// Current blend progress (0.0 to 1.0).
360    pub fn blend_progress(&self) -> f32 { self.progress }
361
362    /// Get the current (possibly blended) LUT.
363    pub fn current_lut(&self) -> &Lut3D { &self.blended }
364
365    /// Consume the dirty flag. Returns true if the LUT needs re-upload.
366    pub fn take_dirty(&mut self) -> bool {
367        let d = self.dirty;
368        self.dirty = false;
369        d
370    }
371}
372
373// ── Blend durations for game state transitions ──────────────────────────────
374
/// Standard blend durations for LUT transitions between game states.
///
/// All values are in seconds, the unit `LutBlender::blend_to` takes for `duration`.
/// Listed from quickest to slowest.
pub struct LutBlendDurations;

impl LutBlendDurations {
    /// Entering or leaving combat.
    pub const COMBAT: f32 = 0.3;
    /// Crossing an HP threshold.
    pub const HP_CHANGE: f32 = 0.4;
    /// Entering a chaos rift.
    pub const CHAOS_RIFT: f32 = 0.4;
    /// Entering a boss encounter.
    pub const BOSS_ENTRY: f32 = 0.5;
    /// Moving to a new floor or room.
    pub const FLOOR_CHANGE: f32 = 0.5;
    /// Corruption level changing.
    pub const CORRUPTION: f32 = 0.6;
    /// Entering a shrine.
    pub const SHRINE: f32 = 0.8;
    /// Victory celebration.
    pub const VICTORY: f32 = 1.5;
    /// Death sequence.
    pub const DEATH: f32 = 2.0;
}
398
399// ── GLSL shader for LUT application ────────────────────────────────────────
400
/// Fragment shader that applies a 3D LUT to the scene color.
///
/// The LUT is stored as a 3D texture (`GL_TEXTURE_3D`) and indexed with the clamped scene
/// color. Uniforms: `u_scene` (scene color), `u_lut` (the LUT), and `u_lut_strength`
/// (0.0 = original color, 1.0 = fully graded).
///
/// Colors are remapped from [0, 1] onto texel centres (`c * (N - 1) / N + 0.5 / N`) before
/// the lookup. Sampling with raw [0, 1] coordinates would shift interior values by up to half
/// a texel relative to the grid the LUT was baked on, where entry `i` represents `i / (N - 1)`.
pub const LUT_APPLY_FRAG: &str = r#"
#version 330 core

in  vec2 f_uv;
out vec4 frag_color;

uniform sampler2D u_scene;
uniform sampler3D u_lut;
uniform float     u_lut_strength;

void main() {
    vec3 color = texture(u_scene, f_uv).rgb;

    // Clamp to [0,1] before LUT lookup
    vec3 clamped = clamp(color, 0.0, 1.0);

    // Map [0,1] onto texel centres: entry i of an N-sized LUT sits at (i + 0.5) / N,
    // while the LUT was baked with entry i representing the color i / (N - 1).
    float lut_size = float(textureSize(u_lut, 0).x);
    vec3 lut_coord = clamped * ((lut_size - 1.0) / lut_size) + 0.5 / lut_size;

    // 3D LUT lookup (relies on GL_LINEAR filtering for trilinear interpolation)
    vec3 graded = texture(u_lut, lut_coord).rgb;

    // Blend between original and graded
    frag_color = vec4(mix(color, graded, u_lut_strength), 1.0);
}
"#;
426
427// ── Tests ───────────────────────────────────────────────────────────────────
428
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn identity_lut_is_passthrough() {
        let lut = Lut3D::identity(16);
        let input = Vec3::new(0.5, 0.3, 0.8);
        let output = lut.sample(input);
        assert!((output - input).length() < 0.02, "Identity LUT should be passthrough, got {:?}", output);
    }

    #[test]
    fn lut_blend_at_zero_is_first() {
        let a = Lut3D::identity(8);
        let mut b = Lut3D::identity(8);
        // Make b different
        for entry in &mut b.data { entry[0] = 0.0; }
        let blended = a.blend(&b, 0.0);
        // Should be identical to a
        for (ae, be) in a.data.iter().zip(blended.data.iter()) {
            assert!((ae[0] - be[0]).abs() < 1e-6);
        }
    }

    #[test]
    fn lut_blend_at_one_is_second() {
        let a = Lut3D::identity(8);
        let mut b = Lut3D::identity(8);
        for entry in &mut b.data { entry[0] = 0.0; }
        let blended = a.blend(&b, 1.0);
        for (be, re) in b.data.iter().zip(blended.data.iter()) {
            assert!((be[0] - re[0]).abs() < 1e-6);
        }
    }

    #[test]
    fn game_state_luts_differ() {
        let normal = GameStateLuts::normal();
        let boss = GameStateLuts::boss_fight();
        // They shouldn't be identical
        let diffs: usize = normal.data.iter().zip(boss.data.iter())
            .filter(|(a, b)| (a[0] - b[0]).abs() > 0.01)
            .count();
        assert!(diffs > 0, "Boss LUT should differ from normal");
    }

    #[test]
    fn lut_blender_completes() {
        let a = GameStateLuts::normal();
        let b = GameStateLuts::boss_fight();
        let mut blender = LutBlender::new(a);
        blender.blend_to(b, 1.0);

        assert!(blender.is_blending());
        for _ in 0..100 {
            blender.tick(0.02);
        }
        assert!(!blender.is_blending());
        assert!(blender.blend_progress() >= 1.0);
    }

    #[test]
    fn lut_to_u8_correct_range() {
        let lut = Lut3D::identity(4);
        let bytes = lut.to_rgb_u8();
        // One byte per channel, three channels per entry.
        assert_eq!(bytes.len(), 4 * 4 * 4 * 3);
        // The identity endpoints must span the full 0..=255 range. (A `u8 <= 255` check is
        // always true and cannot catch a scaling bug.)
        assert_eq!(bytes[..3], [0u8, 0, 0]);
        assert_eq!(bytes[bytes.len() - 3..], [255u8, 255, 255]);
    }

    #[test]
    fn entry_count_and_memory_match_size() {
        let lut = Lut3D::identity(16);
        assert_eq!(lut.entry_count(), 4096);
        assert_eq!(lut.data.len(), lut.entry_count());
        // f32 RGB: 3 channels × 4 bytes per entry.
        assert_eq!(lut.memory_bytes(), 4096 * 3 * 4);
    }

    #[test]
    fn death_lut_desaturates() {
        let lut = GameStateLuts::death(1.0);
        // A bright red should be nearly grey when fully desaturated
        let output = lut.sample(Vec3::new(1.0, 0.0, 0.0));
        // R and G should be closer together than in the input
        assert!((output.x - output.y).abs() < 0.5, "Death LUT should desaturate, got {:?}", output);
    }
}