1use super::chroma_key::bt709_luma;
4#[cfg(feature = "wgpu")]
5use super::helpers::{
6 fullscreen_pipeline, linear_sampler, submit_render_pass, two_tex_sampler_uniform_bgl,
7 upload_rgba_texture,
8};
9use crate::nodes::RenderNodeCpu;
10
/// GPU resources shared by the mask/matte nodes in this file.
///
/// One instance is built lazily per node by `create_mask_pipeline` and reused
/// for every `submit_mask_pass` call made by that node.
#[cfg(feature = "wgpu")]
struct MaskPipeline {
    /// Fullscreen render pipeline compiled from `shaders/mask.wgsl`.
    render_pipeline: wgpu::RenderPipeline,
    /// Layout with two texture views (bindings 0 and 1), a sampler (binding 2)
    /// and a uniform buffer (binding 3), as bound in `submit_mask_pass`.
    bind_group_layout: wgpu::BindGroupLayout,
    /// Sampler shared by both bound textures.
    sampler: wgpu::Sampler,
    /// 16-byte uniform block; bytes 0..4 hold the little-endian blend mode.
    uniform_buf: wgpu::Buffer,
}
20
/// Builds the GPU resources shared by the mask/matte nodes.
///
/// This function:
/// - compiles `shaders/mask.wgsl`;
/// - creates the bind-group layout via `two_tex_sampler_uniform_bgl`;
/// - builds a fullscreen render pipeline on that layout;
/// - creates a linear sampler;
/// - allocates a 16-byte `UNIFORM | COPY_DST` buffer, which `submit_mask_pass`
///   overwrites before every draw.
#[cfg(feature = "wgpu")]
fn create_mask_pipeline(ctx: &crate::context::RenderContext) -> MaskPipeline {
    // Byte size of the uniform block. This must agree with the 16-byte array written
    // by `submit_mask_pass`. NOTE(review): it presumably also matches the uniform
    // struct declared in mask.wgsl; confirm.
    const UNIFORM_BYTES: u64 = 16;

    let device = &ctx.device;
    let wgsl_source = include_str!("../../shaders/mask.wgsl");
    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Mask shader"),
        source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
    });

    let bind_group_layout = two_tex_sampler_uniform_bgl(device, "Mask");
    let render_pipeline = fullscreen_pipeline(device, &shader, "Mask", &bind_group_layout);
    let sampler = linear_sampler(device, "Mask");

    let uniform_usage = wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST;
    let uniform_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Mask uniforms"),
        size: UNIFORM_BYTES,
        usage: uniform_usage,
        mapped_at_creation: false,
    });

    MaskPipeline {
        render_pipeline,
        bind_group_layout,
        sampler,
        uniform_buf,
    }
}
44
/// Runs one mask/matte draw into `output_tex`.
///
/// The draw is set up as follows:
/// - `mode` is written (little-endian, zero-padded to 16 bytes) into the uniform
///   buffer.
/// - The bind group is built with `base_tex` at binding 0, `mask_tex` at binding 1,
///   the sampler at binding 2 and the uniforms at binding 3.
/// - The draw is issued with `submit_render_pass`.
///
/// `label` names both the bind group and the render pass.
#[cfg(feature = "wgpu")]
fn submit_mask_pass(
    ctx: &crate::context::RenderContext,
    pd: &MaskPipeline,
    base_tex: &wgpu::Texture,
    mask_tex: &wgpu::Texture,
    output_tex: &wgpu::Texture,
    mode: u32,
    label: &str,
) {
    // Uniform block layout: mode in bytes 0..4, zero padding in bytes 4..16.
    let mut uniform_bytes = [0u8; 16];
    uniform_bytes[..4].copy_from_slice(&mode.to_le_bytes());
    ctx.queue.write_buffer(&pd.uniform_buf, 0, &uniform_bytes);

    let full_view = wgpu::TextureViewDescriptor::default();
    let base_view = base_tex.create_view(&full_view);
    let mask_view = mask_tex.create_view(&full_view);
    let out_view = output_tex.create_view(&full_view);

    let entries = [
        wgpu::BindGroupEntry {
            binding: 0,
            resource: wgpu::BindingResource::TextureView(&base_view),
        },
        wgpu::BindGroupEntry {
            binding: 1,
            resource: wgpu::BindingResource::TextureView(&mask_view),
        },
        wgpu::BindGroupEntry {
            binding: 2,
            resource: wgpu::BindingResource::Sampler(&pd.sampler),
        },
        wgpu::BindGroupEntry {
            binding: 3,
            resource: pd.uniform_buf.as_entire_binding(),
        },
    ];
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some(label),
        layout: &pd.bind_group_layout,
        entries: &entries,
    });

    submit_render_pass(ctx, &pd.render_pipeline, &bind_group, &out_view, label);
}
104
/// Hard-edged alpha mask taken from the alpha channel of a stored RGBA image.
///
/// On the CPU path, a base pixel's alpha is cleared wherever the matching mask
/// pixel's alpha is 0 or 1. Everywhere else the base pixel is left unchanged.
pub struct ShapeMaskNode {
    /// Mask pixels, four bytes per pixel in R, G, B, A order. Only the alpha byte
    /// is read by the CPU path.
    pub mask_rgba: Vec<u8>,
    /// Mask width in pixels, used when uploading the mask texture.
    pub mask_width: u32,
    /// Mask height in pixels, used when uploading the mask texture.
    pub mask_height: u32,
    /// GPU pipeline, built on first use by `get_or_create_pipeline`.
    #[cfg(feature = "wgpu")]
    pipeline: std::sync::OnceLock<MaskPipeline>,
}
121
122impl ShapeMaskNode {
123 #[must_use]
124 pub fn new(mask_rgba: Vec<u8>, mask_width: u32, mask_height: u32) -> Self {
125 Self {
126 mask_rgba,
127 mask_width,
128 mask_height,
129 #[cfg(feature = "wgpu")]
130 pipeline: std::sync::OnceLock::new(),
131 }
132 }
133}
134
135impl RenderNodeCpu for ShapeMaskNode {
136 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
137 fn process_cpu(&self, rgba: &mut [u8], _w: u32, _h: u32) {
138 if self.mask_rgba.len() != rgba.len() {
139 return;
140 }
141 for (base, mask) in rgba.chunks_exact_mut(4).zip(self.mask_rgba.chunks_exact(4)) {
142 let keep = if mask[3] > 1 { 1.0_f32 } else { 0.0_f32 };
143 let a = f32::from(base[3]) / 255.0;
144 base[3] = ((a * keep).clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
145 }
146 }
147}
148
#[cfg(feature = "wgpu")]
impl ShapeMaskNode {
    /// Returns this node's mask pipeline, building it with `create_mask_pipeline`
    /// the first time it is requested.
    fn get_or_create_pipeline(&self, ctx: &crate::context::RenderContext) -> &MaskPipeline {
        let build = || create_mask_pipeline(ctx);
        self.pipeline.get_or_init(build)
    }
}
155
#[cfg(feature = "wgpu")]
impl crate::nodes::RenderNode for ShapeMaskNode {
    /// Declares two graph inputs.
    ///
    /// NOTE(review): `process` only reads `inputs[0]`; the mask comes from
    /// `self.mask_rgba`. Confirm the second input is intentionally unused.
    fn input_count(&self) -> usize {
        2
    }

    /// Uploads `mask_rgba` as a texture, then draws the mask pass with mode 0
    /// from `inputs[0]` into `outputs[0]`.
    ///
    /// Logs a warning and returns early if either slice is empty.
    fn process(
        &self,
        inputs: &[&wgpu::Texture],
        outputs: &[&wgpu::Texture],
        ctx: &crate::context::RenderContext,
    ) {
        // Mode value written into the shader uniforms for the shape mask.
        const SHAPE_MASK_MODE: u32 = 0;

        let base_tex = match inputs.first() {
            Some(tex) => *tex,
            None => {
                log::warn!("ShapeMaskNode::process called with no inputs");
                return;
            }
        };
        let output = match outputs.first() {
            Some(tex) => *tex,
            None => {
                log::warn!("ShapeMaskNode::process called with no outputs");
                return;
            }
        };

        let pd = self.get_or_create_pipeline(ctx);
        let mask_tex = upload_rgba_texture(
            ctx,
            &self.mask_rgba,
            self.mask_width,
            self.mask_height,
            "ShapeMask mask",
        );
        submit_mask_pass(
            ctx,
            pd,
            base_tex,
            &mask_tex,
            output,
            SHAPE_MASK_MODE,
            "ShapeMask BG",
        );
    }
}
187
/// Soft alpha mask driven by the luminance of a stored RGBA image.
///
/// On the CPU path, each base pixel's alpha is multiplied by the luma of the
/// matching mask pixel's RGB (computed with `bt709_luma`). The mask's own alpha
/// is ignored.
pub struct LumaMaskNode {
    /// Mask pixels, four bytes per pixel in R, G, B, A order. Only R, G and B are
    /// read by the CPU path.
    pub mask_rgba: Vec<u8>,
    /// Mask width in pixels, used when uploading the mask texture.
    pub mask_width: u32,
    /// Mask height in pixels, used when uploading the mask texture.
    pub mask_height: u32,
    /// GPU pipeline, built on first use by `get_or_create_pipeline`.
    #[cfg(feature = "wgpu")]
    pipeline: std::sync::OnceLock<MaskPipeline>,
}
203
204impl LumaMaskNode {
205 #[must_use]
206 pub fn new(mask_rgba: Vec<u8>, mask_width: u32, mask_height: u32) -> Self {
207 Self {
208 mask_rgba,
209 mask_width,
210 mask_height,
211 #[cfg(feature = "wgpu")]
212 pipeline: std::sync::OnceLock::new(),
213 }
214 }
215}
216
217impl RenderNodeCpu for LumaMaskNode {
218 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
219 fn process_cpu(&self, rgba: &mut [u8], _w: u32, _h: u32) {
220 if self.mask_rgba.len() != rgba.len() {
221 return;
222 }
223 for (base, mask) in rgba.chunks_exact_mut(4).zip(self.mask_rgba.chunks_exact(4)) {
224 let mr = f32::from(mask[0]) / 255.0;
225 let mg = f32::from(mask[1]) / 255.0;
226 let mb = f32::from(mask[2]) / 255.0;
227 let luma = bt709_luma(mr, mg, mb);
228 let ba = f32::from(base[3]) / 255.0;
229 base[3] = ((ba * luma).clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
230 }
231 }
232}
233
#[cfg(feature = "wgpu")]
impl LumaMaskNode {
    /// Returns this node's mask pipeline, building it with `create_mask_pipeline`
    /// the first time it is requested.
    fn get_or_create_pipeline(&self, ctx: &crate::context::RenderContext) -> &MaskPipeline {
        let build = || create_mask_pipeline(ctx);
        self.pipeline.get_or_init(build)
    }
}
240
#[cfg(feature = "wgpu")]
impl crate::nodes::RenderNode for LumaMaskNode {
    /// Declares two graph inputs.
    ///
    /// NOTE(review): `process` only reads `inputs[0]`; the mask comes from
    /// `self.mask_rgba`. Confirm the second input is intentionally unused.
    fn input_count(&self) -> usize {
        2
    }

    /// Uploads `mask_rgba` as a texture, then draws the mask pass with mode 1
    /// from `inputs[0]` into `outputs[0]`.
    ///
    /// Logs a warning and returns early if either slice is empty.
    fn process(
        &self,
        inputs: &[&wgpu::Texture],
        outputs: &[&wgpu::Texture],
        ctx: &crate::context::RenderContext,
    ) {
        // Mode value written into the shader uniforms for the luma mask.
        const LUMA_MASK_MODE: u32 = 1;

        let base_tex = match inputs.first() {
            Some(tex) => *tex,
            None => {
                log::warn!("LumaMaskNode::process called with no inputs");
                return;
            }
        };
        let output = match outputs.first() {
            Some(tex) => *tex,
            None => {
                log::warn!("LumaMaskNode::process called with no outputs");
                return;
            }
        };

        let pd = self.get_or_create_pipeline(ctx);
        let mask_tex = upload_rgba_texture(
            ctx,
            &self.mask_rgba,
            self.mask_width,
            self.mask_height,
            "LumaMask mask",
        );
        submit_mask_pass(
            ctx,
            pd,
            base_tex,
            &mask_tex,
            output,
            LUMA_MASK_MODE,
            "LumaMask BG",
        );
    }
}
272
/// Composites the node's input (the foreground) over a stored RGBA background.
///
/// On the CPU path, the output alpha is `fa + ba·(1 − fa)`, where `fa` and `ba`
/// are the foreground and background alphas.
pub struct AlphaMatteNode {
    /// Background pixels, four bytes per pixel in R, G, B, A order.
    pub background_rgba: Vec<u8>,
    /// Background width in pixels, used when uploading the background texture.
    pub background_width: u32,
    /// Background height in pixels, used when uploading the background texture.
    pub background_height: u32,
    /// GPU pipeline, built on first use by `get_or_create_pipeline`.
    #[cfg(feature = "wgpu")]
    pipeline: std::sync::OnceLock<MaskPipeline>,
}
289
290impl AlphaMatteNode {
291 #[must_use]
292 pub fn new(background_rgba: Vec<u8>, background_width: u32, background_height: u32) -> Self {
293 Self {
294 background_rgba,
295 background_width,
296 background_height,
297 #[cfg(feature = "wgpu")]
298 pipeline: std::sync::OnceLock::new(),
299 }
300 }
301}
302
303impl RenderNodeCpu for AlphaMatteNode {
304 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
305 fn process_cpu(&self, rgba: &mut [u8], _w: u32, _h: u32) {
306 if self.background_rgba.len() != rgba.len() {
307 return;
308 }
309 for (fg, bg) in rgba
310 .chunks_exact_mut(4)
311 .zip(self.background_rgba.chunks_exact(4))
312 {
313 let fa = f32::from(fg[3]) / 255.0;
314 let ba = f32::from(bg[3]) / 255.0;
315 for ch in 0..3 {
316 let fc = f32::from(fg[ch]) / 255.0;
317 let bc = f32::from(bg[ch]) / 255.0;
318 fg[ch] = ((fc * fa + bc * (1.0 - fa)).clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
319 }
320 fg[3] = ((fa + ba * (1.0 - fa)).clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
321 }
322 }
323}
324
#[cfg(feature = "wgpu")]
impl AlphaMatteNode {
    /// Returns this node's mask pipeline, building it with `create_mask_pipeline`
    /// the first time it is requested.
    fn get_or_create_pipeline(&self, ctx: &crate::context::RenderContext) -> &MaskPipeline {
        let build = || create_mask_pipeline(ctx);
        self.pipeline.get_or_init(build)
    }
}
331
#[cfg(feature = "wgpu")]
impl crate::nodes::RenderNode for AlphaMatteNode {
    /// Declares two graph inputs.
    ///
    /// NOTE(review): `process` only reads `inputs[0]`; the background comes from
    /// `self.background_rgba`. Confirm the second input is intentionally unused.
    fn input_count(&self) -> usize {
        2
    }

    /// Uploads `background_rgba` as a texture, then draws the mask pass with mode 2.
    ///
    /// `inputs[0]` (foreground) is bound at binding 0, the background at binding 1,
    /// and the result is written to `outputs[0]`. Logs a warning and returns early
    /// if either slice is empty.
    fn process(
        &self,
        inputs: &[&wgpu::Texture],
        outputs: &[&wgpu::Texture],
        ctx: &crate::context::RenderContext,
    ) {
        // Mode value written into the shader uniforms for the alpha matte.
        const ALPHA_MATTE_MODE: u32 = 2;

        let fg_tex = match inputs.first() {
            Some(tex) => *tex,
            None => {
                log::warn!("AlphaMatteNode::process called with no inputs");
                return;
            }
        };
        let output = match outputs.first() {
            Some(tex) => *tex,
            None => {
                log::warn!("AlphaMatteNode::process called with no outputs");
                return;
            }
        };

        let pd = self.get_or_create_pipeline(ctx);
        let bg_tex = upload_rgba_texture(
            ctx,
            &self.background_rgba,
            self.background_width,
            self.background_height,
            "AlphaMatte bg",
        );
        submit_mask_pass(
            ctx,
            pd,
            fg_tex,
            &bg_tex,
            output,
            ALPHA_MATTE_MODE,
            "AlphaMatte BG",
        );
    }
}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nodes::RenderNodeCpu;

    /// Runs `node` over a single 1×1 RGBA pixel and returns the processed bytes.
    fn run_pixel(node: &impl RenderNodeCpu, pixel: [u8; 4]) -> Vec<u8> {
        let mut rgba = pixel.to_vec();
        node.process_cpu(&mut rgba, 1, 1);
        rgba
    }

    /// Absolute distance between an output channel and its expected value.
    fn off_by(actual: u8, expected: u8) -> i32 {
        (i32::from(actual) - i32::from(expected)).abs()
    }

    #[test]
    fn shape_mask_node_opaque_mask_should_keep_base_alpha() {
        let node = ShapeMaskNode::new(vec![0, 0, 0, 255], 1, 1);
        let out = run_pixel(&node, [128, 128, 128, 200]);
        assert!(off_by(out[3], 200) <= 1, "opaque mask must preserve base alpha");
    }

    #[test]
    fn shape_mask_node_transparent_mask_should_zero_alpha() {
        let node = ShapeMaskNode::new(vec![255, 255, 255, 0], 1, 1);
        let out = run_pixel(&node, [128, 128, 128, 255]);
        assert_eq!(out[3], 0, "transparent mask must produce zero alpha");
    }

    #[test]
    fn luma_mask_node_white_mask_should_preserve_alpha() {
        let node = LumaMaskNode::new(vec![255, 255, 255, 255], 1, 1);
        let out = run_pixel(&node, [100, 100, 100, 200]);
        assert!(off_by(out[3], 200) <= 2, "white mask preserves alpha");
    }

    #[test]
    fn luma_mask_node_black_mask_should_zero_alpha() {
        let node = LumaMaskNode::new(vec![0, 0, 0, 255], 1, 1);
        let out = run_pixel(&node, [100, 100, 100, 255]);
        assert_eq!(out[3], 0, "black mask must zero out alpha");
    }

    #[test]
    fn alpha_matte_node_opaque_fg_should_replace_background() {
        let node = AlphaMatteNode::new(vec![50, 50, 50, 255], 1, 1);
        let out = run_pixel(&node, [200, 100, 50, 255]);
        assert!(
            off_by(out[0], 200) <= 1,
            "opaque fg must dominate; got {}",
            out[0]
        );
    }

    #[test]
    fn alpha_matte_node_transparent_fg_should_show_background() {
        let node = AlphaMatteNode::new(vec![50, 80, 120, 255], 1, 1);
        let out = run_pixel(&node, [200, 200, 200, 0]);
        assert!(
            off_by(out[0], 50) <= 1,
            "transparent fg must show bg; got {}",
            out[0]
        );
    }
}