Skip to main content

cvkg_render_gpu/
color_blindness.rs

//! Color blindness simulation post-process pass.
//!
//! Simulates the following color vision deficiencies:
//! - **Protanopia** (no red cones) — ~1.3% of males
//! - **Deuteranopia** (no green cones) — ~5.9% of males
//! - **Tritanopia** (no blue cones) — ~0.003% of general population
//!
//! together with the milder anomalous forms, protanomaly and deuteranomaly.
//!
//! Each mode is described by a single 3x3 *simulation* matrix that is applied to
//! every pixel's RGB value. It is not a Daltonization matrix; Daltonization is the
//! separate step of correcting colors for color-blind viewers. The matrices
//! approximate the Brettel/Viénot family of models and are meant to be applied in
//! linear RGB. The module provides the transformation matrices, the WGSL shader
//! source, and the uniform type needed to integrate the effect into a GPU render
//! pipeline.
//!
//! # Integration
//!
//! The `SurtrRenderer` in cvkg-render-gpu uses a multi-pass pipeline architecture.
//! To add color blindness simulation, create a dedicated render pipeline using
//! `shader_source()` and `ColorBlindUniforms`. Then draw a full-screen triangle
//! (three vertices, no vertex buffer) after the main pass but before
//! composite/present.
19
/// A simulated color vision deficiency, or normal vision.
///
/// Each variant maps to one 3x3 RGB matrix via `ColorBlindMode::matrix`.
/// The discriminants are written out explicitly because `ColorBlindUniforms::new`
/// forwards `mode as u32` to the GPU.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColorBlindMode {
    /// No simulation: the matrix is the identity. Handy as the "before" side of an A/B comparison.
    Normal = 0,
    /// Protanopia: the L ("red") cones are missing.
    Protanopia = 1,
    /// Deuteranopia: the M ("green") cones are missing.
    Deuteranopia = 2,
    /// Tritanopia: the S ("blue") cones are missing.
    Tritanopia = 3,
    /// Protanomaly: L cones are present but less sensitive. Milder than protanopia.
    Protanomaly = 4,
    /// Deuteranomaly: M cones are present but less sensitive. Milder than deuteranopia.
    Deuteranomaly = 5,
}
36
37impl ColorBlindMode {
38    /// Returns the 3x3 color transformation matrix for this mode.
39    ///
40    /// Matrix is in column-major order for WGLSL, operating on linear RGB.
41    /// Values are based on the Brettel, Viénot & Mollon (1997) model.
42    pub fn matrix(&self) -> [f32; 9] {
43        match self {
44            // Identity — no transformation
45            ColorBlindMode::Normal => [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
46            // Protanopia: L cone absent
47            // Based on Brettel et al. projection plane for protanopes
48            ColorBlindMode::Protanopia => [
49                0.567, 0.433, 0.000, // R' = 0.567R + 0.433G
50                0.558, 0.442, 0.000, // G' = 0.558R + 0.442G
51                0.000, 0.242, 0.758, // B' = 0.242G + 0.758B
52            ],
53            // Deuteranopia: M cone absent
54            ColorBlindMode::Deuteranopia => [
55                0.625, 0.375, 0.000, // R' = 0.625R + 0.375G
56                0.700, 0.300, 0.000, // G' = 0.700R + 0.300G
57                0.000, 0.300, 0.700, // B' = 0.300G + 0.700B
58            ],
59            // Tritanopia: S cone absent
60            ColorBlindMode::Tritanopia => [
61                0.950, 0.050, 0.000, // R' = 0.950R + 0.050G
62                0.000, 0.433, 0.567, // G' = 0.433G + 0.567B
63                0.000, 0.475, 0.525, // B' = 0.475G + 0.525B
64            ],
65            // Protanomaly: partial L cone loss (blend of identity + protanopia)
66            ColorBlindMode::Protanomaly => [
67                0.817, 0.183, 0.000, 0.333, 0.667, 0.000, 0.000, 0.125, 0.875,
68            ],
69            // Deuteranomaly: partial M cone loss (blend of identity + deuteranopia)
70            ColorBlindMode::Deuteranomaly => [
71                0.800, 0.200, 0.000, 0.258, 0.742, 0.000, 0.000, 0.142, 0.858,
72            ],
73        }
74    }
75
76    /// Human-readable display name.
77    pub fn display_name(&self) -> &'static str {
78        match self {
79            ColorBlindMode::Normal => "Normal Vision",
80            ColorBlindMode::Protanopia => "Protanopia (no red)",
81            ColorBlindMode::Deuteranopia => "Deuteranopia (no green)",
82            ColorBlindMode::Tritanopia => "Tritanopia (no blue)",
83            ColorBlindMode::Protanomaly => "Protanomaly (reduced red)",
84            ColorBlindMode::Deuteranomaly => "Deuteranomaly (reduced green)",
85        }
86    }
87
88    /// Whether this mode performs any actual transformation.
89    pub fn is_identity(&self) -> bool {
90        matches!(self, ColorBlindMode::Normal)
91    }
92}
93
/// Returns the WGSL source for the color blindness simulation pass.
///
/// The module defines two entry points:
///
/// - `fs_main_vs` (vertex): generates one triangle that covers the whole viewport
///   from `vertex_index` alone. Draw 3 vertices with no vertex buffers bound.
/// - `fs_color_blind` (fragment): samples `t_screen` (binding 0) with `s_screen`
///   (binding 1) and applies the 3x3 simulation matrix from the `cb` uniform
///   (binding 2). It then blends the result with the original color by
///   `cb.intensity`. Alpha passes through unchanged.
///
/// The uniform's `matrix_0`, `matrix_1` and `matrix_2` are the matrix **rows**, that
/// is, the coefficients producing R', G' and B'. The shader takes one dot product
/// per row.
///
/// The matrix is meant for linear RGB. The shader does no sRGB conversion itself, so
/// the bound texture view must already yield linear values when sampled.
pub fn shader_source() -> &'static str {
    r#"
struct ColorBlindUniforms {
    // Rows of the 3x3 simulation matrix: R' = dot(matrix_0, rgb), and so on.
    // In the uniform address space each vec3<f32> is 16-byte aligned, so these sit at
    // offsets 0, 16 and 32. `mode` packs into the 4 bytes after matrix_2 (offset 44),
    // and the whole struct is 64 bytes. Host-side data must use the same layout.
    matrix_0: vec3<f32>,
    matrix_1: vec3<f32>,
    matrix_2: vec3<f32>,
    mode: u32,
    intensity: f32,  // 0.0 = no effect, 1.0 = full simulation
    _pad0: f32,
    _pad1: f32,
}

@group(0) @binding(0) var t_screen: texture_2d<f32>;
@group(0) @binding(1) var s_screen: sampler;
@group(0) @binding(2) var<uniform> cb: ColorBlindUniforms;

struct VertexOutput {
    @builtin(position) pos: vec4<f32>,
    @location(0) uv: vec2<f32>,
}

@vertex
fn fs_main_vs(@builtin(vertex_index) vid: u32) -> VertexOutput {
    // Full-screen triangle with three distinct corners:
    //   vid 0 -> uv (0, 0), clip (-1, -1)
    //   vid 1 -> uv (2, 0), clip ( 3, -1)
    //   vid 2 -> uv (0, 2), clip (-1,  3)
    // The part of the triangle inside clip space [-1, 1]^2 has uv in [0, 1]^2.
    let uv = vec2<f32>(f32((vid << 1u) & 2u), f32(vid & 2u));
    let pos = vec4<f32>(uv * 2.0 - 1.0, 0.0, 1.0);
    return VertexOutput(pos, uv);
}

@fragment
fn fs_color_blind(in: VertexOutput) -> @location(0) vec4<f32> {
    // Clip-space +Y points up while texture V grows downward, so flip V.
    let screen_uv = vec2<f32>(in.uv.x, 1.0 - in.uv.y);
    let color = textureSample(t_screen, s_screen, screen_uv);
    let rgb = color.rgb;

    // The uniform stores the matrix row by row, so each output channel is the dot
    // product of one row with the input color. Building a matrix from these vectors
    // would treat them as columns and apply the transpose.
    let simulated = vec3<f32>(
        dot(cb.matrix_0, rgb),
        dot(cb.matrix_1, rgb),
        dot(cb.matrix_2, rgb),
    );
    let result = mix(rgb, simulated, cb.intensity);

    return vec4<f32>(result, color.a);
}
"#
}
150
151/// Uniform data for the color blindness shader.
152#[repr(C)]
153#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
154pub struct ColorBlindUniforms {
155    /// Row 0 of the 3x3 transformation matrix (column-major).
156    pub matrix_0: [f32; 3],
157    /// Row 1.
158    pub matrix_1: [f32; 3],
159    /// Row 2.
160    pub matrix_2: [f32; 3],
161    /// Mode ID (for debugging).
162    pub mode: u32,
163    /// Effect intensity (0.0–1.0).
164    pub intensity: f32,
165    _pad0: f32,
166    _pad1: f32,
167}
168
169impl ColorBlindUniforms {
170    /// Create uniforms from a mode and intensity.
171    pub fn new(mode: ColorBlindMode, intensity: f32) -> Self {
172        let m = mode.matrix();
173        Self {
174            matrix_0: [m[0], m[1], m[2]],
175            matrix_1: [m[3], m[4], m[5]],
176            matrix_2: [m[6], m[7], m[8]],
177            mode: mode as u32,
178            intensity: intensity.clamp(0.0, 1.0),
179            _pad0: 0.0,
180            _pad1: 0.0,
181        }
182    }
183}
184
185/// All available color blindness modes for iteration.
186pub const ALL_MODES: &[ColorBlindMode] = &[
187    ColorBlindMode::Normal,
188    ColorBlindMode::Protanopia,
189    ColorBlindMode::Protanomaly,
190    ColorBlindMode::Deuteranopia,
191    ColorBlindMode::Deuteranomaly,
192    ColorBlindMode::Tritanopia,
193];
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for float sums over the hand-written matrix literals.
    const EPS: f32 = 1e-5;

    #[test]
    fn test_normal_matrix_is_identity() {
        let m = ColorBlindMode::Normal.matrix();
        assert_eq!(m, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_protanopia_red_and_green_ignore_blue() {
        // The matrix is row-major: m[2] and m[5] are the blue-input coefficients of
        // the R' and G' rows. Protanopia builds those outputs from R and G only.
        let m = ColorBlindMode::Protanopia.matrix();
        assert_eq!(m[2], 0.0);
        assert_eq!(m[5], 0.0);
        // The B' row (m[6..9]) takes no red input.
        assert_eq!(m[6], 0.0);
    }

    #[test]
    fn test_rows_preserve_neutral_grays() {
        // Every row summing to 1 means (g, g, g) maps to (g, g, g). This also catches
        // a matrix being stored or applied transposed.
        for mode in ALL_MODES {
            let m = mode.matrix();
            for row in m.chunks_exact(3) {
                let sum: f32 = row.iter().sum();
                assert!(
                    (sum - 1.0).abs() < EPS,
                    "{:?}: row {:?} sums to {}",
                    mode,
                    row,
                    sum
                );
            }
        }
    }

    #[test]
    fn test_only_normal_is_identity() {
        for mode in ALL_MODES {
            assert_eq!(mode.is_identity(), *mode == ColorBlindMode::Normal);
        }
    }

    #[test]
    fn test_uniforms_creation() {
        let u = ColorBlindUniforms::new(ColorBlindMode::Deuteranopia, 0.8);
        assert_eq!(u.intensity, 0.8);
        assert_eq!(u.mode, 2); // Deuteranopia = index 2 in enum
        // Rows are packed in order: matrix_0 holds m[0..3], and so on.
        let m = ColorBlindMode::Deuteranopia.matrix();
        assert_eq!(u.matrix_0, [m[0], m[1], m[2]]);
        assert_eq!(u.matrix_1, [m[3], m[4], m[5]]);
        assert_eq!(u.matrix_2, [m[6], m[7], m[8]]);
    }

    #[test]
    fn test_intensity_clamping() {
        let u = ColorBlindUniforms::new(ColorBlindMode::Normal, 999.0);
        assert_eq!(u.intensity, 1.0);
        let u2 = ColorBlindUniforms::new(ColorBlindMode::Normal, -1.0);
        assert_eq!(u2.intensity, 0.0);
    }

    #[test]
    fn test_all_modes_have_names() {
        for mode in ALL_MODES {
            assert!(!mode.display_name().is_empty());
        }
    }

    #[test]
    fn test_all_modes_lists_each_variant_once() {
        assert_eq!(ALL_MODES.len(), 6);
        for (i, a) in ALL_MODES.iter().enumerate() {
            for b in &ALL_MODES[i + 1..] {
                assert_ne!(a, b, "duplicate mode in ALL_MODES");
            }
        }
    }
}