1use super::RenderNodeCpu;
2
/// GPU resources for the colour-grade pass, built once by
/// `ColorGradeNode::get_or_create_pipeline` and cached inside the node.
#[cfg(feature = "wgpu")]
struct ColorGradePipeline {
    /// Full-screen render pipeline (`vs_main` / `fs_main` from `color_grade.wgsl`)
    /// writing to an `Rgba8Unorm` colour target.
    render_pipeline: wgpu::RenderPipeline,
    /// Layout of bind group 0: input texture (binding 0), sampler (binding 1),
    /// uniform buffer (binding 2).
    bind_group_layout: wgpu::BindGroupLayout,
    /// Linear-filtering, clamp-to-edge sampler used to read the input texture.
    sampler: wgpu::Sampler,
    /// 32-byte uniform buffer holding the grading parameters; `process`
    /// rewrites it on every call.
    uniform_buf: wgpu::Buffer,
}
12
/// Colour-grading node with five scalar controls, usable on the CPU
/// (`process_cpu`) and, with the `wgpu` feature, on the GPU (`process`).
///
/// On the CPU path every RGB channel is normalised to `[0, 1]` and adjusted in
/// the order brightness → contrast → temperature → tint → saturation, then
/// clamped back to `[0, 1]`; alpha is left untouched.
/// [`ColorGradeNode::default`] is the identity grade.
pub struct ColorGradeNode {
    /// Offset added to every normalised RGB channel (`0.0` = unchanged).
    pub brightness: f32,
    /// Multiplier on each channel's distance from mid-grey `0.5`
    /// (`1.0` = unchanged, `0.0` = flat mid-grey, `> 1.0` = more contrast).
    pub contrast: f32,
    /// Multiplier on each channel's distance from its Rec. 709 luma
    /// (`1.0` = unchanged, `0.0` = greyscale).
    pub saturation: f32,
    /// Warm/cool shift: `temperature * 0.1` is added to red and subtracted from
    /// blue (`0.0` = unchanged).
    pub temperature: f32,
    /// Green shift: `tint * 0.1` is added to green (`0.0` = unchanged).
    pub tint: f32,
    /// Lazily built GPU resources, created on the first `process` call.
    #[cfg(feature = "wgpu")]
    pipeline: std::sync::OnceLock<ColorGradePipeline>,
}

/// Hand-written because the cached GPU pipeline is not `Debug`; only the
/// grading controls are shown.
impl std::fmt::Debug for ColorGradeNode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ColorGradeNode")
            .field("brightness", &self.brightness)
            .field("contrast", &self.contrast)
            .field("saturation", &self.saturation)
            .field("temperature", &self.temperature)
            .field("tint", &self.tint)
            .finish_non_exhaustive()
    }
}

impl ColorGradeNode {
    /// Creates a node with explicit grading controls.
    ///
    /// Arguments map one-to-one onto the public fields of the same name; no
    /// validation or clamping is performed. Use [`ColorGradeNode::default`]
    /// for the identity grade `(0.0, 1.0, 1.0, 0.0, 0.0)`.
    #[must_use]
    pub fn new(
        brightness: f32,
        contrast: f32,
        saturation: f32,
        temperature: f32,
        tint: f32,
    ) -> Self {
        Self {
            brightness,
            contrast,
            saturation,
            temperature,
            tint,
            #[cfg(feature = "wgpu")]
            pipeline: std::sync::OnceLock::new(),
        }
    }
}

impl Default for ColorGradeNode {
    /// Identity grade: no brightness offset, unit contrast and saturation, and
    /// neutral temperature and tint.
    fn default() -> Self {
        Self::new(0.0, 1.0, 1.0, 0.0, 0.0)
    }
}
62
63impl RenderNodeCpu for ColorGradeNode {
66 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
67 fn process_cpu(&self, rgba: &mut [u8], _w: u32, _h: u32) {
68 for pixel in rgba.chunks_exact_mut(4) {
69 let r = f32::from(pixel[0]) / 255.0;
70 let g = f32::from(pixel[1]) / 255.0;
71 let b = f32::from(pixel[2]) / 255.0;
72
73 let r = r + self.brightness;
75 let g = g + self.brightness;
76 let b = b + self.brightness;
77
78 let r = (r - 0.5) * self.contrast + 0.5;
80 let g = (g - 0.5) * self.contrast + 0.5;
81 let b = (b - 0.5) * self.contrast + 0.5;
82
83 let r = r + self.temperature * 0.1;
85 let b = b - self.temperature * 0.1;
86
87 let g = g + self.tint * 0.1;
89
90 let luma = 0.2126 * r + 0.7152 * g + 0.0722 * b;
92 let r = luma + (r - luma) * self.saturation;
93 let g = luma + (g - luma) * self.saturation;
94 let b = luma + (b - luma) * self.saturation;
95
96 pixel[0] = (r.clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
97 pixel[1] = (g.clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
98 pixel[2] = (b.clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
99 }
101 }
102}
103
#[cfg(feature = "wgpu")]
impl ColorGradeNode {
    /// Returns this node's GPU resources, building them on first use.
    ///
    /// Everything is created inside `OnceLock::get_or_init`, so the shader
    /// module, layouts, render pipeline, sampler and uniform buffer are built at
    /// most once per node. Later calls return the cached value and do not look
    /// at `ctx` at all.
    ///
    /// NOTE(review): the resources belong to whichever `ctx.device` was passed on
    /// the first call. If one node can be driven by more than one
    /// `RenderContext`/device, later calls would hand back resources from the
    /// wrong device — confirm nodes are used with a single device.
    fn get_or_create_pipeline(&self, ctx: &crate::context::RenderContext) -> &ColorGradePipeline {
        self.pipeline.get_or_init(|| {
            let device = &ctx.device;

            // The WGSL source is embedded into the binary at compile time.
            let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("ColorGrade shader"),
                source: wgpu::ShaderSource::Wgsl(
                    include_str!("../shaders/color_grade.wgsl").into(),
                ),
            });

            // Bind group 0, fragment stage only:
            //   0 = filterable 2D float texture (the input image),
            //   1 = filtering sampler,
            //   2 = uniform buffer with the grading parameters.
            // NOTE(review): must match the bindings declared in color_grade.wgsl,
            // which is not visible here — keep the two in sync.
            let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("ColorGrade BGL"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Texture {
                            sample_type: wgpu::TextureSampleType::Float { filterable: true },
                            view_dimension: wgpu::TextureViewDimension::D2,
                            multisampled: false,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::FRAGMENT,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

            let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("ColorGrade layout"),
                bind_group_layouts: &[Some(&bgl)],
                immediate_size: 0,
            });

            // No vertex buffers are bound, and `process` issues `draw(0..6, 0..1)`.
            // `vs_main` presumably derives a full-screen quad from the vertex
            // index — confirm in the shader.
            let render_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
                label: Some("ColorGrade pipeline"),
                layout: Some(&pipeline_layout),
                vertex: wgpu::VertexState {
                    module: &shader,
                    entry_point: Some("vs_main"),
                    buffers: &[],
                    compilation_options: wgpu::PipelineCompilationOptions::default(),
                },
                fragment: Some(wgpu::FragmentState {
                    module: &shader,
                    entry_point: Some("fs_main"),
                    // Render-pass attachments must use this format. With
                    // `blend: None`, the fragment output replaces the target contents.
                    targets: &[Some(wgpu::ColorTargetState {
                        format: wgpu::TextureFormat::Rgba8Unorm,
                        blend: None,
                        write_mask: wgpu::ColorWrites::ALL,
                    })],
                    compilation_options: wgpu::PipelineCompilationOptions::default(),
                }),
                primitive: wgpu::PrimitiveState::default(),
                depth_stencil: None,
                multisample: wgpu::MultisampleState::default(),
                multiview_mask: None,
                cache: None,
            });

            // Linear filtering with clamp-to-edge addressing, so samples near the
            // borders do not wrap around to the opposite edge.
            let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
                label: Some("ColorGrade sampler"),
                address_mode_u: wgpu::AddressMode::ClampToEdge,
                address_mode_v: wgpu::AddressMode::ClampToEdge,
                mag_filter: wgpu::FilterMode::Linear,
                min_filter: wgpu::FilterMode::Linear,
                ..Default::default()
            });

            // 32 bytes = 8 × f32: `process` uploads five grading parameters
            // followed by three zeros of padding. NOTE(review): presumably this
            // matches the shader's uniform struct size — confirm. COPY_DST is
            // needed because the contents are updated with `queue.write_buffer`.
            let uniform_buf = device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("ColorGrade uniforms"),
                size: 32,
                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });

            ColorGradePipeline {
                render_pipeline,
                bind_group_layout: bgl,
                sampler,
                uniform_buf,
            }
        })
    }
}
209
#[cfg(feature = "wgpu")]
impl super::RenderNode for ColorGradeNode {
    /// GPU path: draws one full-screen pass that samples `inputs[0]` through the
    /// colour-grade shader and writes the result into `outputs[0]`.
    ///
    /// Only the first input and the first output texture are used; any others
    /// are ignored. If either slice is empty, a warning is logged and nothing is
    /// recorded or submitted. The output is cleared to transparent before the
    /// draw, and the commands are submitted to `ctx.queue` before returning.
    fn process(
        &self,
        inputs: &[&wgpu::Texture],
        outputs: &[&wgpu::Texture],
        ctx: &crate::context::RenderContext,
    ) {
        let Some(source) = inputs.first() else {
            log::warn!("ColorGradeNode::process called with no inputs");
            return;
        };
        let Some(target) = outputs.first() else {
            log::warn!("ColorGradeNode::process called with no outputs");
            return;
        };

        let gpu = self.get_or_create_pipeline(ctx);

        // Five grading parameters followed by three zeros of padding:
        // 8 × f32 = 32 bytes, the size the uniform buffer is allocated with.
        let params: [f32; 8] = [
            self.brightness,
            self.contrast,
            self.saturation,
            self.temperature,
            self.tint,
            0.0,
            0.0,
            0.0,
        ];
        ctx.queue.write_buffer(&gpu.uniform_buf, 0, &pack_f32(&params));

        let whole_texture = wgpu::TextureViewDescriptor::default();
        let source_view = source.create_view(&whole_texture);
        let target_view = target.create_view(&whole_texture);

        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("ColorGrade BG"),
            layout: &gpu.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: wgpu::BindingResource::TextureView(&source_view),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: wgpu::BindingResource::Sampler(&gpu.sampler),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: gpu.uniform_buf.as_entire_binding(),
                },
            ],
        });

        let mut encoder = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("ColorGrade pass"),
            });

        let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
            label: Some("ColorGrade pass"),
            color_attachments: &[Some(wgpu::RenderPassColorAttachment {
                view: &target_view,
                resolve_target: None,
                depth_slice: None,
                ops: wgpu::Operations {
                    load: wgpu::LoadOp::Clear(wgpu::Color::TRANSPARENT),
                    store: wgpu::StoreOp::Store,
                },
            })],
            depth_stencil_attachment: None,
            timestamp_writes: None,
            occlusion_query_set: None,
            multiview_mask: None,
        });
        pass.set_pipeline(&gpu.render_pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        // Six vertices with no vertex buffer. Presumably `vs_main` expands them
        // into a two-triangle full-screen quad — confirm in the WGSL.
        pass.draw(0..6, 0..1);
        // The pass borrows the encoder; end it before finishing the encoder.
        drop(pass);

        ctx.queue.submit([encoder.finish()]);
    }
}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Channel labels for assertion messages, in RGBA byte order.
    const RGB: [&str; 3] = ["R", "G", "B"];

    /// Grades one RGBA pixel on the CPU path and returns the graded copy.
    fn grade(node: &ColorGradeNode, pixel: [u8; 4]) -> [u8; 4] {
        let mut out = pixel;
        node.process_cpu(&mut out, 1, 1);
        out
    }

    #[test]
    fn color_grade_node_default_should_be_identity() {
        let original = [100u8, 150, 200, 255];
        let graded = grade(&ColorGradeNode::default(), original);
        for (i, (&got, &expected)) in graded.iter().zip(&original).take(3).enumerate() {
            assert!(
                got.abs_diff(expected) <= 1,
                "identity must preserve pixel at channel {i}: expected ~{expected} got {got}"
            );
        }
        assert_eq!(graded[3], 255, "alpha must not change");
    }

    #[test]
    fn color_grade_node_brightness_positive_should_increase_mid_grey() {
        let node = ColorGradeNode {
            brightness: 0.5,
            ..Default::default()
        };
        let graded = grade(&node, [128, 128, 128, 255]);
        for (name, &value) in RGB.iter().zip(&graded) {
            assert!(value > 128, "brightness +0.5 must increase {name}; got {value}");
        }
        assert_eq!(graded[3], 255, "alpha must not change");
    }

    #[test]
    fn color_grade_node_brightness_negative_should_decrease_mid_grey() {
        let node = ColorGradeNode {
            brightness: -0.5,
            ..Default::default()
        };
        let graded = grade(&node, [128, 128, 128, 255]);
        for (name, &value) in RGB.iter().zip(&graded) {
            assert!(value < 128, "brightness −0.5 must decrease {name}; got {value}");
        }
        assert_eq!(graded[3], 255, "alpha must not change");
    }

    #[test]
    fn color_grade_node_contrast_above_one_should_spread_from_mid_grey() {
        let node = ColorGradeNode {
            contrast: 2.0,
            ..Default::default()
        };
        let [r, g, b, a] = grade(&node, [64, 128, 192, 255]);
        assert!(r < 64, "contrast 2.0 must darken a dark R; got {r}");
        assert!(g.abs_diff(128) <= 1, "contrast must keep mid-grey G; got {g}");
        assert!(b > 192, "contrast 2.0 must brighten a bright B; got {b}");
        assert_eq!(a, 255, "alpha must not change");
    }

    #[test]
    fn color_grade_node_temperature_positive_should_raise_red_and_lower_blue() {
        let node = ColorGradeNode {
            temperature: 1.0,
            ..Default::default()
        };
        let [r, g, b, a] = grade(&node, [128, 128, 128, 255]);
        assert!(r > 128, "temperature +1 must increase R; got {r}");
        assert!(g.abs_diff(128) <= 1, "temperature must not move G; got {g}");
        assert!(b < 128, "temperature +1 must decrease B; got {b}");
        assert_eq!(a, 255, "alpha must not change");
    }

    #[test]
    fn color_grade_node_tint_positive_should_raise_green_only() {
        let node = ColorGradeNode {
            tint: 1.0,
            ..Default::default()
        };
        let [r, g, b, a] = grade(&node, [128, 128, 128, 255]);
        assert!(r.abs_diff(128) <= 1, "tint must not move R; got {r}");
        assert!(g > 128, "tint +1 must increase G; got {g}");
        assert!(b.abs_diff(128) <= 1, "tint must not move B; got {b}");
        assert_eq!(a, 255, "alpha must not change");
    }

    #[test]
    fn color_grade_node_saturation_zero_should_produce_greyscale() {
        let node = ColorGradeNode {
            saturation: 0.0,
            ..Default::default()
        };
        let [r, g, b, _] = grade(&node, [200, 100, 50, 255]);
        assert!(
            r.abs_diff(g) <= 1,
            "saturation=0 must equalise R and G; got R={r} G={g}"
        );
        assert!(
            r.abs_diff(b) <= 1,
            "saturation=0 must equalise R and B; got R={r} B={b}"
        );
    }

    #[test]
    fn color_grade_node_clamp_should_not_exceed_255() {
        let node = ColorGradeNode {
            brightness: 2.0,
            ..Default::default()
        };
        let graded = grade(&node, [200, 200, 200, 255]);
        assert_eq!(graded, [255, 255, 255, 255], "channels must clamp to 255");
    }

    #[test]
    fn color_grade_node_clamp_should_not_go_below_0() {
        let node = ColorGradeNode {
            brightness: -2.0,
            ..Default::default()
        };
        let graded = grade(&node, [50, 50, 50, 255]);
        assert_eq!(graded, [0, 0, 0, 255], "channels must clamp to 0");
    }

    #[test]
    fn color_grade_node_should_leave_trailing_partial_pixel_untouched() {
        let node = ColorGradeNode {
            brightness: 0.5,
            ..Default::default()
        };
        let mut rgba = [128u8, 128, 128, 255, 7, 9];
        node.process_cpu(&mut rgba, 1, 1);
        assert_eq!(&rgba[4..], &[7, 9], "bytes past the last whole pixel must not change");
    }

    #[test]
    fn color_grade_node_variants_should_construct_via_default() {
        let node = ColorGradeNode {
            brightness: 0.1,
            contrast: 1.2,
            saturation: 0.9,
            ..Default::default()
        };
        assert_eq!(node.temperature, 0.0, "unset fields must come from Default");
        assert_eq!(node.tint, 0.0, "unset fields must come from Default");
    }
}
409
/// Serialises `values` as consecutive little-endian `f32`s (4 bytes each).
/// `process` uses it to build the bytes it uploads to the uniform buffer.
#[cfg(feature = "wgpu")]
fn pack_f32(values: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(values.len() * std::mem::size_of::<f32>());
    for value in values {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}