Skip to main content

cvkg_render_gpu/
color_blindness.rs

//! Color blindness simulation post-process pass.
//!
//! Provides color vision deficiency simulation for:
//! - **Protanopia** (no L / red cones) — roughly 1% of males
//! - **Deuteranopia** (no M / green cones) — roughly 1% of males
//!   (about 6% when the milder deuteranomaly is included)
//! - **Tritanopia** (no S / blue cones) — very rare
//! - **Protanomaly** / **Deuteranomaly** — milder, reduced-sensitivity forms
//!
//! Each mode is a single 3x3 RGB *simulation* matrix (not a Daltonization /
//! correction matrix) applied per pixel in the fragment shader. The coefficients
//! are widely used single-matrix approximations, not a full Brettel/Viénot
//! LMS-space projection. The module provides the transformation matrices, the
//! WGSL shader source, and the uniform type needed to integrate the effect into
//! a GPU render pipeline.
//!
//! # Integration
//!
//! The `SurtrRenderer` in cvkg-render-gpu uses a multi-pass pipeline architecture.
//! To add color blindness simulation, create a dedicated render pipeline from
//! `shader_source()` (entry points `fs_main_vs` and `fs_color_blind`) with a
//! uniform buffer holding `ColorBlindUniforms`. Then draw a full-screen triangle
//! (3 vertices, no vertex buffer) after the main pass but before composite/present.
19
/// Color vision deficiency simulated by the post-process pass.
///
/// The discriminants are explicit because `ColorBlindUniforms::new` uploads
/// `mode as u32` to the shader; reordering variants must not change them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum ColorBlindMode {
    /// Normal vision (identity transform — no-op, useful for A/B comparison).
    Normal = 0,
    /// Protanopia: absence of L (red) cones.
    Protanopia = 1,
    /// Deuteranopia: absence of M (green) cones.
    Deuteranopia = 2,
    /// Tritanopia: absence of S (blue) cones.
    Tritanopia = 3,
    /// Protanomaly: reduced L cone sensitivity (milder form).
    Protanomaly = 4,
    /// Deuteranomaly: reduced M cone sensitivity (milder form).
    Deuteranomaly = 5,
}

impl ColorBlindMode {
    /// Returns the 3x3 simulation matrix for this mode in **row-major** order.
    ///
    /// `m[3 * row + col]` is the weight of input channel `col` in output channel
    /// `row`, with channels ordered R, G, B. The first three entries therefore give
    /// `R' = m[0]·R + m[1]·G + m[2]·B`.
    ///
    /// WGSL `mat3x3` constructors take *columns*, so the matrix must be transposed
    /// before upload, as `ColorBlindUniforms::new` does.
    ///
    /// Every row sums to 1.0 (up to rounding), so white and neutral greys are
    /// left unchanged.
    ///
    /// The coefficients are the widely circulated single-matrix RGB approximations
    /// (the "ColorMatrix" tables). They are a coarse approximation, not the
    /// LMS-space two-half-plane projection of Brettel, Viénot & Mollon (1997).
    pub fn matrix(&self) -> [f32; 9] {
        match self {
            // Identity — no transformation.
            ColorBlindMode::Normal => [
                1.0, 0.0, 0.0, // R' = R
                0.0, 1.0, 0.0, // G' = G
                0.0, 0.0, 1.0, // B' = B
            ],
            // Protanopia: L cone absent.
            ColorBlindMode::Protanopia => [
                0.567, 0.433, 0.000, // R' = 0.567R + 0.433G
                0.558, 0.442, 0.000, // G' = 0.558R + 0.442G
                0.000, 0.242, 0.758, // B' = 0.242G + 0.758B
            ],
            // Deuteranopia: M cone absent.
            ColorBlindMode::Deuteranopia => [
                0.625, 0.375, 0.000, // R' = 0.625R + 0.375G
                0.700, 0.300, 0.000, // G' = 0.700R + 0.300G
                0.000, 0.300, 0.700, // B' = 0.300G + 0.700B
            ],
            // Tritanopia: S cone absent.
            ColorBlindMode::Tritanopia => [
                0.950, 0.050, 0.000, // R' = 0.950R + 0.050G
                0.000, 0.433, 0.567, // G' = 0.433G + 0.567B
                0.000, 0.475, 0.525, // B' = 0.475G + 0.525B
            ],
            // Protanomaly: partial L cone loss. This is its own milder table,
            // not a linear blend of identity and the protanopia matrix.
            ColorBlindMode::Protanomaly => [
                0.817, 0.183, 0.000, // R' = 0.817R + 0.183G
                0.333, 0.667, 0.000, // G' = 0.333R + 0.667G
                0.000, 0.125, 0.875, // B' = 0.125G + 0.875B
            ],
            // Deuteranomaly: partial M cone loss. This is its own milder table,
            // not a linear blend of identity and the deuteranopia matrix.
            ColorBlindMode::Deuteranomaly => [
                0.800, 0.200, 0.000, // R' = 0.800R + 0.200G
                0.258, 0.742, 0.000, // G' = 0.258R + 0.742G
                0.000, 0.142, 0.858, // B' = 0.142G + 0.858B
            ],
        }
    }

    /// Human-readable display name.
    pub fn display_name(&self) -> &'static str {
        match self {
            ColorBlindMode::Normal => "Normal Vision",
            ColorBlindMode::Protanopia => "Protanopia (no red)",
            ColorBlindMode::Deuteranopia => "Deuteranopia (no green)",
            ColorBlindMode::Tritanopia => "Tritanopia (no blue)",
            ColorBlindMode::Protanomaly => "Protanomaly (reduced red)",
            ColorBlindMode::Deuteranomaly => "Deuteranomaly (reduced green)",
        }
    }

    /// Returns `true` when this mode leaves colors unchanged.
    ///
    /// Only `Normal` qualifies, since its matrix is the identity. The pass can be
    /// skipped entirely in that case.
    pub fn is_identity(&self) -> bool {
        matches!(self, ColorBlindMode::Normal)
    }
}
94
/// Returns the WGSL source for the color blindness post-process pass.
///
/// Entry points:
/// - `fs_main_vs` (vertex): emits a full-screen triangle from `vertex_index`.
///   Draw 3 vertices with no vertex buffer.
/// - `fs_color_blind` (fragment): samples the screen texture and multiplies its
///   RGB by the 3x3 matrix from the uniform buffer. It then blends with the
///   original color by `intensity`; alpha passes through unchanged.
///
/// Bindings (group 0):
/// - `0`: screen `texture_2d<f32>`.
/// - `1`: sampler.
/// - `2`: uniform buffer laid out exactly like the Rust `ColorBlindUniforms`
///   (`#[repr(C)]`, 64 bytes).
///
/// The WGSL struct pads each `vec3<f32>` column with an explicit `f32`. Without
/// it, WGSL would pack `mode` into the last 4 bytes of `matrix_2` (offset 44) and
/// read `intensity` from the Rust `mode` field.
///
/// The matrix is applied to whatever values `textureSample` returns. That is
/// linear RGB only if the bound texture holds linear data or is an sRGB-format
/// view (decoded on sampling).
pub fn shader_source() -> &'static str {
    r#"
// Field offsets must match the Rust `ColorBlindUniforms` (#[repr(C)], 64 bytes):
// each column is 12 bytes of data + 4 bytes of explicit padding.
struct ColorBlindUniforms {
    matrix_0: vec3<f32>,
    _pad_m0: f32,
    matrix_1: vec3<f32>,
    _pad_m1: f32,
    matrix_2: vec3<f32>,
    _pad_m2: f32,
    mode: u32,
    intensity: f32,  // 0.0 = no effect, 1.0 = full simulation
    _pad0: f32,
    _pad1: f32,
};

@group(0) @binding(0) var t_screen: texture_2d<f32>;
@group(0) @binding(1) var s_screen: sampler;
@group(0) @binding(2) var<uniform> cb: ColorBlindUniforms;

struct VertexOutput {
    @builtin(position) pos: vec4<f32>,
    @location(0) uv: vec2<f32>,
};

@vertex
fn fs_main_vs(@builtin(vertex_index) vid: u32) -> VertexOutput {
    // Full-screen triangle that covers all of clip space:
    //   vid 0 -> uv (0, 0), pos (-1, -1)
    //   vid 1 -> uv (2, 0), pos ( 3, -1)
    //   vid 2 -> uv (0, 2), pos (-1,  3)
    let uv = vec2<f32>(f32((vid << 1u) & 2u), f32(vid & 2u));
    let pos = vec4<f32>(uv * 2.0 - vec2<f32>(1.0, 1.0), 0.0, 1.0);
    return VertexOutput(pos, uv);
}

@fragment
fn fs_color_blind(in: VertexOutput) -> @location(0) vec4<f32> {
    // uv is y-up (clip space); texture coordinates are y-down, so flip v.
    let screen_uv = vec2<f32>(in.uv.x, 1.0 - in.uv.y);
    let color = textureSample(t_screen, s_screen, screen_uv);
    let rgb = color.rgb;

    // Columns were transposed on the CPU from the row-major
    // ColorBlindMode::matrix(), so `mat * rgb` applies the intended rows.
    let mat = mat3x3<f32>(cb.matrix_0, cb.matrix_1, cb.matrix_2);
    let simulated = mat * rgb;
    let result = mix(rgb, simulated, cb.intensity);

    return vec4<f32>(result, color.a);
}
"#
}
151
152/// Uniform data for the color blindness shader.
153#[repr(C)]
154#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
155pub struct ColorBlindUniforms {
156    /// Column 0 of the 3x3 transformation matrix.
157    pub matrix_0: [f32; 3],
158    _pad_m0: f32, // vec3<f32> is 16-byte aligned in WGSL
159    /// Column 1.
160    pub matrix_1: [f32; 3],
161    _pad_m1: f32,
162    /// Column 2.
163    pub matrix_2: [f32; 3],
164    _pad_m2: f32,
165    /// Mode ID (for debugging).
166    pub mode: u32,
167    /// Effect intensity (0.0–1.0).
168    pub intensity: f32,
169    _pad0: f32,
170    _pad1: f32,
171}
172
173impl ColorBlindUniforms {
174    /// Create uniforms from a mode and intensity.
175    pub fn new(mode: ColorBlindMode, intensity: f32) -> Self {
176        let m = mode.matrix();
177        Self {
178            matrix_0: [m[0], m[3], m[6]],
179            _pad_m0: 0.0,
180            matrix_1: [m[1], m[4], m[7]],
181            _pad_m1: 0.0,
182            matrix_2: [m[2], m[5], m[8]],
183            _pad_m2: 0.0,
184            mode: mode as u32,
185            intensity: intensity.clamp(0.0, 1.0),
186            _pad0: 0.0,
187            _pad1: 0.0,
188        }
189    }
190}
191
192/// All available color blindness modes for iteration.
193pub const ALL_MODES: &[ColorBlindMode] = &[
194    ColorBlindMode::Normal,
195    ColorBlindMode::Protanopia,
196    ColorBlindMode::Protanomaly,
197    ColorBlindMode::Deuteranopia,
198    ColorBlindMode::Deuteranomaly,
199    ColorBlindMode::Tritanopia,
200];
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normal_matrix_is_identity() {
        // Check all nine entries, not just the diagonal.
        assert_eq!(
            ColorBlindMode::Normal.matrix(),
            [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
        );
        assert!(ColorBlindMode::Normal.is_identity());
        // Normal must be the only no-op mode.
        assert_eq!(ALL_MODES.iter().filter(|m| m.is_identity()).count(), 1);
    }

    #[test]
    fn test_protanopia_blue_input_isolated() {
        let m = ColorBlindMode::Protanopia.matrix();
        // Blue channel input should have zero contribution to R' and G' outputs
        // (entries (0, 2) and (1, 2) of the row-major matrix).
        assert_eq!(m[2], 0.0);
        assert_eq!(m[5], 0.0);
    }

    #[test]
    fn test_matrix_rows_sum_to_one() {
        // Rows summing to 1 keep white and greys unchanged; this also pins the
        // row-major interpretation (columns do not sum to 1 for most modes).
        for mode in ALL_MODES {
            for row in mode.matrix().chunks_exact(3) {
                let sum: f32 = row.iter().sum();
                assert!((sum - 1.0).abs() < 1e-5, "{:?} row {:?}", mode, row);
            }
        }
    }

    #[test]
    fn test_uniform_columns_are_transposed_rows() {
        for &mode in ALL_MODES {
            let m = mode.matrix();
            let u = ColorBlindUniforms::new(mode, 1.0);
            assert_eq!(u.matrix_0, [m[0], m[3], m[6]]);
            assert_eq!(u.matrix_1, [m[1], m[4], m[7]]);
            assert_eq!(u.matrix_2, [m[2], m[5], m[8]]);
        }
    }

    #[test]
    fn test_uniform_size_matches_shader_struct() {
        assert_eq!(std::mem::size_of::<ColorBlindUniforms>(), 64);
    }

    #[test]
    fn test_uniforms_creation() {
        let u = ColorBlindUniforms::new(ColorBlindMode::Deuteranopia, 0.8);
        assert_eq!(u.intensity, 0.8);
        assert_eq!(u.mode, 2); // Deuteranopia = index 2 in enum
    }

    #[test]
    fn test_intensity_clamping() {
        let u = ColorBlindUniforms::new(ColorBlindMode::Normal, 999.0);
        assert_eq!(u.intensity, 1.0);
        let u2 = ColorBlindUniforms::new(ColorBlindMode::Normal, -1.0);
        assert_eq!(u2.intensity, 0.0);
    }

    #[test]
    fn test_all_modes_have_names() {
        for mode in ALL_MODES {
            let name = mode.display_name();
            assert!(!name.is_empty());
        }
    }

    #[test]
    fn test_all_modes_lists_every_variant_once() {
        // Six distinct entries out of six variants means every variant is present.
        assert_eq!(ALL_MODES.len(), 6);
        for (i, a) in ALL_MODES.iter().enumerate() {
            for b in &ALL_MODES[i + 1..] {
                assert_ne!(a, b);
            }
        }
    }
}