Skip to main content

ff_render/nodes/
color_grade.rs

1use super::RenderNodeCpu;
2
3// ── Pipeline cache ────────────────────────────────────────────────────────────
4
/// GPU resources for [`ColorGradeNode`], built once per node on first use and then reused.
#[cfg(feature = "wgpu")]
struct ColorGradePipeline {
    /// Render pipeline built from `vs_main` / `fs_main` in `shaders/color_grade.wgsl`.
    render_pipeline: wgpu::RenderPipeline,
    /// Layout of bind group 0: input texture (binding 0), sampler (1), uniform buffer (2).
    bind_group_layout: wgpu::BindGroupLayout,
    /// Linear-filtering, clamp-to-edge sampler used to read the input texture.
    sampler: wgpu::Sampler,
    /// 32-byte uniform buffer holding the grading parameters; rewritten on every `process`.
    uniform_buf: wgpu::Buffer,
}
12
13// ── ColorGradeNode ────────────────────────────────────────────────────────────
14
/// Basic colour grading: brightness, contrast, saturation, temperature, tint.
///
/// # Processing order
///
/// brightness → contrast → temperature/tint → saturation
///
/// On the CPU path every step operates on normalised channels (byte / 255). Intermediate
/// values may leave `0.0..=1.0`; they are only clamped when converted back to bytes, and
/// alpha is never modified there.
///
/// The ranges quoted on each field are guidance only: [`ColorGradeNode::new`] stores its
/// arguments as-is and the fields are public, so nothing validates or clamps them.
pub struct ColorGradeNode {
    /// Additive brightness offset (−1.0 – +1.0; 0.0 = no change).
    ///
    /// Added to each normalised RGB channel before contrast is applied.
    pub brightness: f32,
    /// Contrast multiplier around 0.5 (0.0 – 4.0; 1.0 = no change).
    ///
    /// Scales each channel's distance from mid-grey (0.5).
    pub contrast: f32,
    /// Saturation multiplier (0.0 = greyscale; 1.0 = no change; 2.0 = double).
    ///
    /// Scales each channel's distance from the pixel's BT.709 luma.
    pub saturation: f32,
    /// Colour temperature offset (−1.0 = cool/blue; +1.0 = warm/orange).
    ///
    /// On the CPU path, red is raised and blue lowered by `0.1 × temperature`.
    pub temperature: f32,
    /// Tint offset (−1.0 = magenta; +1.0 = green).
    ///
    /// On the CPU path, green is raised by `0.1 × tint`.
    pub tint: f32,
    /// Lazily built GPU resources: created on the first GPU `process` call, then reused.
    #[cfg(feature = "wgpu")]
    pipeline: std::sync::OnceLock<ColorGradePipeline>,
}
34
35impl ColorGradeNode {
36    /// Identity node (no colour change).
37    #[must_use]
38    pub fn new(
39        brightness: f32,
40        contrast: f32,
41        saturation: f32,
42        temperature: f32,
43        tint: f32,
44    ) -> Self {
45        Self {
46            brightness,
47            contrast,
48            saturation,
49            temperature,
50            tint,
51            #[cfg(feature = "wgpu")]
52            pipeline: std::sync::OnceLock::new(),
53        }
54    }
55}
56
57impl Default for ColorGradeNode {
58    fn default() -> Self {
59        Self::new(0.0, 1.0, 1.0, 0.0, 0.0)
60    }
61}
62
63// ── CPU path ──────────────────────────────────────────────────────────────────
64
65impl RenderNodeCpu for ColorGradeNode {
66    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
67    fn process_cpu(&self, rgba: &mut [u8], _w: u32, _h: u32) {
68        for pixel in rgba.chunks_exact_mut(4) {
69            let r = f32::from(pixel[0]) / 255.0;
70            let g = f32::from(pixel[1]) / 255.0;
71            let b = f32::from(pixel[2]) / 255.0;
72
73            // Brightness
74            let r = r + self.brightness;
75            let g = g + self.brightness;
76            let b = b + self.brightness;
77
78            // Contrast
79            let r = (r - 0.5) * self.contrast + 0.5;
80            let g = (g - 0.5) * self.contrast + 0.5;
81            let b = (b - 0.5) * self.contrast + 0.5;
82
83            // Temperature
84            let r = r + self.temperature * 0.1;
85            let b = b - self.temperature * 0.1;
86
87            // Tint
88            let g = g + self.tint * 0.1;
89
90            // Saturation (BT.709 luma coefficients)
91            let luma = 0.2126 * r + 0.7152 * g + 0.0722 * b;
92            let r = luma + (r - luma) * self.saturation;
93            let g = luma + (g - luma) * self.saturation;
94            let b = luma + (b - luma) * self.saturation;
95
96            pixel[0] = (r.clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
97            pixel[1] = (g.clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
98            pixel[2] = (b.clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
99            // alpha unchanged
100        }
101    }
102}
103
104// ── GPU path ──────────────────────────────────────────────────────────────────
105
#[cfg(feature = "wgpu")]
impl ColorGradeNode {
    /// Returns the cached GPU pipeline, building it on the first call.
    ///
    /// On first use this creates, from `ctx.device`, the WGSL shader module, the bind
    /// group layout, the pipeline layout, the render pipeline, a sampler and the uniform
    /// buffer. Later calls return the cached value and do not touch `ctx`.
    ///
    /// NOTE(review): the cache is not keyed on the device. If this node were ever processed
    /// with a different `RenderContext`, the resources built for the first device would be
    /// reused. Confirm that nodes are never shared across devices.
    fn get_or_create_pipeline(&self, ctx: &crate::context::RenderContext) -> &ColorGradePipeline {
        self.pipeline.get_or_init(|| {
            let device = &ctx.device;

            // Shader source is embedded into the binary at compile time.
            let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("ColorGrade shader"),
                source: wgpu::ShaderSource::Wgsl(
                    include_str!("../shaders/color_grade.wgsl").into(),
                ),
            });

            // Group 0 layout; must stay in sync with the bind group built in `process`:
            // binding 0 = input texture, 1 = sampler, 2 = grading uniforms.
            let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("ColorGrade BGL"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Texture {
                            sample_type: wgpu::TextureSampleType::Float { filterable: true },
                            view_dimension: wgpu::TextureViewDimension::D2,
                            multisampled: false,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

            let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("ColorGrade layout"),
                bind_group_layouts: &[Some(&bgl)],
                immediate_size: 0,
            });

            let render_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
                label: Some("ColorGrade pipeline"),
                layout: Some(&pipeline_layout),
                // No vertex buffers: `vs_main` presumably derives positions from the vertex
                // index (see the 6-vertex draw in `process`) — confirm in color_grade.wgsl.
                vertex: wgpu::VertexState {
                    module: &shader,
                    entry_point: Some("vs_main"),
                    buffers: &[],
                    compilation_options: wgpu::PipelineCompilationOptions::default(),
                },
                fragment: Some(wgpu::FragmentState {
                    module: &shader,
                    entry_point: Some("fs_main"),
                    // Single colour target with a fixed format: output textures must be
                    // Rgba8Unorm. No blending, so the graded colour overwrites the target.
                    targets: &[Some(wgpu::ColorTargetState {
                        format: wgpu::TextureFormat::Rgba8Unorm,
                        blend: None,
                        write_mask: wgpu::ColorWrites::ALL,
                    })],
                    compilation_options: wgpu::PipelineCompilationOptions::default(),
                }),
                primitive: wgpu::PrimitiveState::default(),
                depth_stencil: None,
                multisample: wgpu::MultisampleState::default(),
                multiview_mask: None,
                cache: None,
            });

            // Linear filtering with edge clamping for reading the input texture.
            let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
                label: Some("ColorGrade sampler"),
                address_mode_u: wgpu::AddressMode::ClampToEdge,
                address_mode_v: wgpu::AddressMode::ClampToEdge,
                mag_filter: wgpu::FilterMode::Linear,
                min_filter: wgpu::FilterMode::Linear,
                ..Default::default()
            });

            // 8 × f32 = 32 bytes — matches ColorGradeUniforms in the shader.
            // COPY_DST lets `process` refresh it via `queue.write_buffer` each frame.
            let uniform_buf = device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("ColorGrade uniforms"),
                size: 32,
                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });

            ColorGradePipeline {
                render_pipeline,
                bind_group_layout: bgl,
                sampler,
                uniform_buf,
            }
        })
    }
}
209
#[cfg(feature = "wgpu")]
impl super::RenderNode for ColorGradeNode {
    /// Grades `inputs[0]` into `outputs[0]` on the GPU.
    ///
    /// Only the first input and first output are used. The pass clears the output to
    /// transparent, draws with the cached colour-grade pipeline, and submits the command
    /// buffer to `ctx.queue` immediately.
    ///
    /// Nothing is recorded, and a warning is logged instead, when there is no input, no
    /// output, or the output texture is not `Rgba8Unorm`. That is the only colour-target
    /// format the cached pipeline was built for.
    fn process(
        &self,
        inputs: &[&wgpu::Texture],
        outputs: &[&wgpu::Texture],
        ctx: &crate::context::RenderContext,
    ) {
        let Some(input) = inputs.first() else {
            log::warn!("ColorGradeNode::process called with no inputs");
            return;
        };
        let Some(output) = outputs.first() else {
            log::warn!("ColorGradeNode::process called with no outputs");
            return;
        };

        // The pipeline's colour target is fixed to Rgba8Unorm. A render pass on any other
        // format (even Rgba8UnormSrgb) fails wgpu validation, so skip with a warning,
        // like the guards above.
        let output_format = output.format();
        if output_format != wgpu::TextureFormat::Rgba8Unorm {
            log::warn!(
                "ColorGradeNode::process: output format {output_format:?} is not Rgba8Unorm"
            );
            return;
        }

        let pd = self.get_or_create_pipeline(ctx);

        // Update uniforms for this frame: five parameters padded with zeros to fill the
        // 32-byte uniform buffer. The order presumably mirrors ColorGradeUniforms in
        // color_grade.wgsl — confirm when changing either side.
        let uniform_bytes = pack_f32(&[
            self.brightness,
            self.contrast,
            self.saturation,
            self.temperature,
            self.tint,
            0.0,
            0.0,
            0.0,
        ]);
        ctx.queue.write_buffer(&pd.uniform_buf, 0, &uniform_bytes);

        let input_view = input.create_view(&wgpu::TextureViewDescriptor::default());
        let output_view = output.create_view(&wgpu::TextureViewDescriptor::default());

        // Bindings must match the layout built in `get_or_create_pipeline`.
        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("ColorGrade BG"),
            layout: &pd.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: wgpu::BindingResource::TextureView(&input_view),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: wgpu::BindingResource::Sampler(&pd.sampler),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: pd.uniform_buf.as_entire_binding(),
                },
            ],
        });

        let mut encoder = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("ColorGrade pass"),
            });
        {
            // The pass borrows `encoder`; scope it so the encoder can be finished below.
            let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
                label: Some("ColorGrade pass"),
                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
                    view: &output_view,
                    resolve_target: None,
                    depth_slice: None,
                    ops: wgpu::Operations {
                        load: wgpu::LoadOp::Clear(wgpu::Color::TRANSPARENT),
                        store: wgpu::StoreOp::Store,
                    },
                })],
                depth_stencil_attachment: None,
                timestamp_writes: None,
                occlusion_query_set: None,
                multiview_mask: None,
            });
            pass.set_pipeline(&pd.render_pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // Six vertices, one instance; presumably two triangles covering the whole
            // target — confirm in color_grade.wgsl.
            pass.draw(0..6, 0..1);
        }
        ctx.queue.submit(std::iter::once(encoder.finish()));
    }
}
293
294// ── Tests ─────────────────────────────────────────────────────────────────────
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Mid-grey, fully opaque pixel used by most single-parameter tests.
    const MID_GREY: [u8; 4] = [128, 128, 128, 255];

    /// Runs the CPU path on a single RGBA pixel and returns the graded copy.
    fn grade(node: &ColorGradeNode, pixel: [u8; 4]) -> [u8; 4] {
        let mut rgba = pixel;
        node.process_cpu(&mut rgba, 1, 1);
        rgba
    }

    /// Asserts `actual` is within ±1 of `expected` (allows f32 round-trip rounding).
    fn assert_near(actual: u8, expected: u8, what: &str) {
        let diff = (i32::from(actual) - i32::from(expected)).abs();
        assert!(diff <= 1, "{what}: expected ~{expected}, got {actual}");
    }

    /// Asserts two `f32`s are equal to within `f32::EPSILON`.
    fn assert_f32(actual: f32, expected: f32, what: &str) {
        assert!(
            (actual - expected).abs() <= f32::EPSILON,
            "{what}: expected {expected}, got {actual}"
        );
    }

    #[test]
    fn color_grade_node_new_should_store_parameters_verbatim() {
        // Out-of-range values are stored as given; nothing is clamped.
        let node = ColorGradeNode::new(0.25, 1.5, 0.75, -0.5, 2.0);
        assert_f32(node.brightness, 0.25, "brightness");
        assert_f32(node.contrast, 1.5, "contrast");
        assert_f32(node.saturation, 0.75, "saturation");
        assert_f32(node.temperature, -0.5, "temperature");
        assert_f32(node.tint, 2.0, "tint");
    }

    #[test]
    fn color_grade_node_default_should_be_identity() {
        let original = [100u8, 150, 200, 255];
        let out = grade(&ColorGradeNode::default(), original);
        for (channel, (&got, &want)) in out.iter().zip(original.iter()).take(3).enumerate() {
            assert_near(got, want, &format!("identity must preserve channel {channel}"));
        }
        assert_eq!(out[3], 255, "alpha must not change");
    }

    #[test]
    fn color_grade_node_brightness_positive_should_increase_mid_grey() {
        let node = ColorGradeNode {
            brightness: 0.5,
            ..Default::default()
        };
        let out = grade(&node, MID_GREY);
        for (name, &value) in ["R", "G", "B"].iter().zip(out.iter()) {
            assert!(value > 128, "brightness +0.5 must increase {name}; got {value}");
        }
        assert_eq!(out[3], 255, "alpha must not change");
    }

    #[test]
    fn color_grade_node_brightness_negative_should_decrease_mid_grey() {
        let node = ColorGradeNode {
            brightness: -0.5,
            ..Default::default()
        };
        let out = grade(&node, MID_GREY);
        for (name, &value) in ["R", "G", "B"].iter().zip(out.iter()) {
            assert!(value < 128, "brightness −0.5 must decrease {name}; got {value}");
        }
    }

    #[test]
    fn color_grade_node_contrast_should_push_values_away_from_mid_grey() {
        let node = ColorGradeNode {
            contrast: 2.0,
            ..Default::default()
        };
        let dark = grade(&node, [64, 64, 64, 255]);
        let bright = grade(&node, [192, 192, 192, 255]);
        assert!(dark[0] < 64, "contrast 2.0 must darken shadows; got {}", dark[0]);
        assert!(bright[0] > 192, "contrast 2.0 must brighten highlights; got {}", bright[0]);
    }

    #[test]
    fn color_grade_node_contrast_zero_should_flatten_to_mid_grey() {
        let node = ColorGradeNode {
            contrast: 0.0,
            ..Default::default()
        };
        let out = grade(&node, [10, 128, 250, 255]);
        for &value in &out[..3] {
            assert_near(value, 128, "contrast 0 must map every channel to mid-grey");
        }
    }

    #[test]
    fn color_grade_node_temperature_positive_should_warm_mid_grey() {
        let node = ColorGradeNode {
            temperature: 0.5,
            ..Default::default()
        };
        let out = grade(&node, MID_GREY);
        assert!(out[0] > 128, "warm temperature must raise R; got {}", out[0]);
        assert_near(out[1], 128, "temperature must not move G");
        assert!(out[2] < 128, "warm temperature must lower B; got {}", out[2]);
    }

    #[test]
    fn color_grade_node_tint_should_move_green_only() {
        let greener = grade(
            &ColorGradeNode {
                tint: 0.5,
                ..Default::default()
            },
            MID_GREY,
        );
        assert!(greener[1] > 128, "positive tint must raise G; got {}", greener[1]);
        assert_near(greener[0], 128, "tint must not move R");
        assert_near(greener[2], 128, "tint must not move B");

        let magenta = grade(
            &ColorGradeNode {
                tint: -0.5,
                ..Default::default()
            },
            MID_GREY,
        );
        assert!(magenta[1] < 128, "negative tint must lower G; got {}", magenta[1]);
    }

    #[test]
    fn color_grade_node_saturation_zero_should_produce_greyscale() {
        let node = ColorGradeNode {
            saturation: 0.0,
            ..Default::default()
        };
        let out = grade(&node, [200, 100, 50, 255]); // colourful pixel
        // All channels must be equal (greyscale) — allow ±1 rounding.
        assert_near(out[1], out[0], "saturation=0 must equalise R and G");
        assert_near(out[2], out[0], "saturation=0 must equalise R and B");
    }

    #[test]
    fn color_grade_node_clamp_should_not_exceed_255() {
        let node = ColorGradeNode {
            brightness: 2.0,
            ..Default::default()
        };
        let out = grade(&node, [200, 200, 200, 255]);
        assert_eq!(out[0], 255, "clamped R must be 255");
        assert_eq!(out[1], 255, "clamped G must be 255");
        assert_eq!(out[2], 255, "clamped B must be 255");
    }

    #[test]
    fn color_grade_node_variants_should_construct_via_default() {
        let node = ColorGradeNode {
            brightness: 0.1,
            contrast: 1.2,
            saturation: 0.9,
            ..Default::default()
        };
        assert_f32(node.brightness, 0.1, "brightness");
        assert_f32(node.contrast, 1.2, "contrast");
        assert_f32(node.saturation, 0.9, "saturation");
        // Fields not overridden must come from the identity default.
        assert_f32(node.temperature, 0.0, "temperature");
        assert_f32(node.tint, 0.0, "tint");
    }
}
409
410// ── helpers ───────────────────────────────────────────────────────────────────
411
/// Serialises `values` as consecutive little-endian `f32`s (4 bytes each, input order),
/// ready for `Queue::write_buffer`.
#[cfg(feature = "wgpu")]
fn pack_f32(values: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(values.len() * std::mem::size_of::<f32>());
    for value in values {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}