Skip to main content

ff_render/nodes/composite/
masks.rs

1//! Mask nodes: `ShapeMaskNode`, `LumaMaskNode`, `AlphaMatteNode` + shared pipeline.
2
3use super::chroma_key::bt709_luma;
4#[cfg(feature = "wgpu")]
5use super::helpers::{
6    fullscreen_pipeline, linear_sampler, submit_render_pass, two_tex_sampler_uniform_bgl,
7    upload_rgba_texture,
8};
9use crate::nodes::RenderNodeCpu;
10
11// ── Shared mask pipeline ──────────────────────────────────────────────────────
12
/// GPU objects shared by all mask nodes: a fullscreen render pipeline, the
/// layout its bind groups use, a sampler, and a small uniform buffer that
/// carries the mode selector written by `submit_mask_pass`.
#[cfg(feature = "wgpu")]
struct MaskPipeline {
    /// Fullscreen pipeline compiled from `shaders/mask.wgsl`.
    render_pipeline: wgpu::RenderPipeline,
    /// Layout with two textures, a sampler and a uniform buffer (bindings 0–3).
    bind_group_layout: wgpu::BindGroupLayout,
    /// Sampler produced by `linear_sampler`, bound at binding 2.
    sampler: wgpu::Sampler,
    /// 16-byte uniform buffer rewritten before every mask pass.
    uniform_buf: wgpu::Buffer,
}

/// Builds a [`MaskPipeline`] on `ctx.device`.
///
/// Compiles the mask WGSL shader, creates the two-texture/sampler/uniform bind
/// group layout and a fullscreen pipeline over it, plus a sampler and a
/// 16-byte uniform buffer. The buffer is `COPY_DST` so it can be updated via
/// `queue.write_buffer`.
#[cfg(feature = "wgpu")]
fn create_mask_pipeline(ctx: &crate::context::RenderContext) -> MaskPipeline {
    // Uniform block size in bytes (a u32 mode plus padding to 16 bytes).
    const UNIFORM_SIZE: wgpu::BufferAddress = 16;

    let device = &ctx.device;
    let wgsl_source = include_str!("../../shaders/mask.wgsl");
    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Mask shader"),
        source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
    });

    let bind_group_layout = two_tex_sampler_uniform_bgl(device, "Mask");
    let render_pipeline = fullscreen_pipeline(device, &shader, "Mask", &bind_group_layout);
    let sampler = linear_sampler(device, "Mask");
    let uniform_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Mask uniforms"),
        size: UNIFORM_SIZE,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::UNIFORM,
        mapped_at_creation: false,
    });

    MaskPipeline {
        render_pipeline,
        bind_group_layout,
        sampler,
        uniform_buf,
    }
}
44
/// Runs one mask render pass into `output_tex`.
///
/// Writes `mode` (little-endian u32, zero-padded to 16 bytes) into the
/// pipeline's uniform buffer, binds `base_tex` (binding 0), `mask_tex`
/// (binding 1), the sampler (binding 2) and the uniforms (binding 3), then
/// submits a fullscreen pass via `submit_render_pass`. `label` names both the
/// bind group and the pass.
#[cfg(feature = "wgpu")]
fn submit_mask_pass(
    ctx: &crate::context::RenderContext,
    pd: &MaskPipeline,
    base_tex: &wgpu::Texture,
    mask_tex: &wgpu::Texture,
    output_tex: &wgpu::Texture,
    mode: u32,
    label: &str,
) {
    // Bytes 0..4 hold the mode; the remaining 12 bytes are padding.
    let mut uniforms = [0u8; 16];
    uniforms[..4].copy_from_slice(&mode.to_le_bytes());
    ctx.queue.write_buffer(&pd.uniform_buf, 0, &uniforms);

    let full_view =
        |tex: &wgpu::Texture| tex.create_view(&wgpu::TextureViewDescriptor::default());
    let base_view = full_view(base_tex);
    let mask_view = full_view(mask_tex);
    let out_view = full_view(output_tex);

    let entries = [
        wgpu::BindGroupEntry {
            binding: 0,
            resource: wgpu::BindingResource::TextureView(&base_view),
        },
        wgpu::BindGroupEntry {
            binding: 1,
            resource: wgpu::BindingResource::TextureView(&mask_view),
        },
        wgpu::BindGroupEntry {
            binding: 2,
            resource: wgpu::BindingResource::Sampler(&pd.sampler),
        },
        wgpu::BindGroupEntry {
            binding: 3,
            resource: pd.uniform_buf.as_entire_binding(),
        },
    ];
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some(label),
        layout: &pd.bind_group_layout,
        entries: &entries,
    });

    submit_render_pass(ctx, &pd.render_pipeline, &bind_group, &out_view, label);
}
104
105// ── ShapeMaskNode ─────────────────────────────────────────────────────────────
106
/// Shape matte: masks `inputs[0]` with the alpha channel of `inputs[1]`
/// (or of the bytes held in `mask_rgba`).
///
/// A pixel stays visible only where the mask alpha is non-zero (a hard
/// threshold at roughly 1/255); everywhere else it becomes fully transparent.
pub struct ShapeMaskNode {
    /// RGBA8 bytes of the mask frame; the CPU path cannot run without them.
    pub mask_rgba: Vec<u8>,
    /// Mask frame width in pixels.
    pub mask_width: u32,
    /// Mask frame height in pixels.
    pub mask_height: u32,
    /// GPU pipeline, built lazily the first time the node is rendered.
    #[cfg(feature = "wgpu")]
    pipeline: std::sync::OnceLock<MaskPipeline>,
}

impl ShapeMaskNode {
    /// Creates a shape-mask node from a `mask_width` × `mask_height` RGBA mask.
    ///
    /// No GPU work happens here; the pipeline is created on demand.
    #[must_use]
    pub fn new(mask_rgba: Vec<u8>, mask_width: u32, mask_height: u32) -> Self {
        Self {
            mask_width,
            mask_height,
            mask_rgba,
            #[cfg(feature = "wgpu")]
            pipeline: std::sync::OnceLock::default(),
        }
    }
}
134
135impl RenderNodeCpu for ShapeMaskNode {
136    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
137    fn process_cpu(&self, rgba: &mut [u8], _w: u32, _h: u32) {
138        if self.mask_rgba.len() != rgba.len() {
139            return;
140        }
141        for (base, mask) in rgba.chunks_exact_mut(4).zip(self.mask_rgba.chunks_exact(4)) {
142            let keep = if mask[3] > 1 { 1.0_f32 } else { 0.0_f32 };
143            let a = f32::from(base[3]) / 255.0;
144            base[3] = ((a * keep).clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
145        }
146    }
147}
148
#[cfg(feature = "wgpu")]
impl ShapeMaskNode {
    /// Returns this node's mask pipeline, building it on the first call.
    fn get_or_create_pipeline(&self, ctx: &crate::context::RenderContext) -> &MaskPipeline {
        self.pipeline.get_or_init(|| create_mask_pipeline(ctx))
    }

    /// Whether `mask_rgba` describes a non-empty `mask_width` × `mask_height`
    /// RGBA frame with at least that many 4-byte pixels.
    ///
    /// wgpu rejects zero-sized textures and texture writes with too little
    /// data, so an empty or short buffer must not reach `upload_rgba_texture`.
    fn has_uploadable_mask(&self) -> bool {
        let needed_pixels = u64::from(self.mask_width) * u64::from(self.mask_height);
        let available_pixels = u64::try_from(self.mask_rgba.len() / 4).unwrap_or(u64::MAX);
        needed_pixels > 0 && available_pixels >= needed_pixels
    }
}

#[cfg(feature = "wgpu")]
impl crate::nodes::RenderNode for ShapeMaskNode {
    fn input_count(&self) -> usize {
        2
    }

    /// Masks `inputs[0]` into `outputs[0]` on the GPU (mask shader mode 0).
    ///
    /// When `mask_rgba` holds a complete frame, it is uploaded and used as
    /// the mask. Otherwise the second input, `inputs[1]`, is used; this
    /// covers `mask_rgba` left empty, since it is only required for the CPU
    /// path. If neither is available, a warning is logged and nothing is
    /// rendered.
    fn process(
        &self,
        inputs: &[&wgpu::Texture],
        outputs: &[&wgpu::Texture],
        ctx: &crate::context::RenderContext,
    ) {
        let Some(base_tex) = inputs.first() else {
            log::warn!("ShapeMaskNode::process called with no inputs");
            return;
        };
        let Some(output) = outputs.first() else {
            log::warn!("ShapeMaskNode::process called with no outputs");
            return;
        };
        let pd = self.get_or_create_pipeline(ctx);

        // Owns the uploaded texture (if any) for the duration of the pass.
        let uploaded;
        let mask_tex: &wgpu::Texture = if self.has_uploadable_mask() {
            uploaded = upload_rgba_texture(
                ctx,
                &self.mask_rgba,
                self.mask_width,
                self.mask_height,
                "ShapeMask mask",
            );
            &uploaded
        } else if let Some(&tex) = inputs.get(1) {
            tex
        } else {
            log::warn!(
                "ShapeMaskNode::process: mask_rgba is empty or too short and no mask input was provided"
            );
            return;
        };
        submit_mask_pass(ctx, pd, base_tex, mask_tex, output, 0, "ShapeMask BG");
    }
}
187
188// ── LumaMaskNode ──────────────────────────────────────────────────────────────
189
/// Luma matte: scales the alpha of `inputs[0]` by the BT.709 luma of the mask
/// (`inputs[1]`, or the bytes held in `mask_rgba`).
///
/// The mask luma (0.0–1.0) multiplies the base alpha. A white mask keeps the
/// base alpha and a black mask makes the pixel fully transparent.
pub struct LumaMaskNode {
    /// RGBA8 bytes of the mask frame; the CPU path cannot run without them.
    pub mask_rgba: Vec<u8>,
    /// Mask frame width in pixels.
    pub mask_width: u32,
    /// Mask frame height in pixels.
    pub mask_height: u32,
    /// GPU pipeline, built lazily the first time the node is rendered.
    #[cfg(feature = "wgpu")]
    pipeline: std::sync::OnceLock<MaskPipeline>,
}

impl LumaMaskNode {
    /// Creates a luma-mask node from a `mask_width` × `mask_height` RGBA mask.
    ///
    /// No GPU work happens here; the pipeline is created on demand.
    #[must_use]
    pub fn new(mask_rgba: Vec<u8>, mask_width: u32, mask_height: u32) -> Self {
        Self {
            mask_width,
            mask_height,
            mask_rgba,
            #[cfg(feature = "wgpu")]
            pipeline: std::sync::OnceLock::default(),
        }
    }
}
216
217impl RenderNodeCpu for LumaMaskNode {
218    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
219    fn process_cpu(&self, rgba: &mut [u8], _w: u32, _h: u32) {
220        if self.mask_rgba.len() != rgba.len() {
221            return;
222        }
223        for (base, mask) in rgba.chunks_exact_mut(4).zip(self.mask_rgba.chunks_exact(4)) {
224            let mr = f32::from(mask[0]) / 255.0;
225            let mg = f32::from(mask[1]) / 255.0;
226            let mb = f32::from(mask[2]) / 255.0;
227            let luma = bt709_luma(mr, mg, mb);
228            let ba = f32::from(base[3]) / 255.0;
229            base[3] = ((ba * luma).clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
230        }
231    }
232}
233
#[cfg(feature = "wgpu")]
impl LumaMaskNode {
    /// Returns this node's mask pipeline, building it on the first call.
    fn get_or_create_pipeline(&self, ctx: &crate::context::RenderContext) -> &MaskPipeline {
        self.pipeline.get_or_init(|| create_mask_pipeline(ctx))
    }

    /// Whether `mask_rgba` describes a non-empty `mask_width` × `mask_height`
    /// RGBA frame with at least that many 4-byte pixels.
    ///
    /// wgpu rejects zero-sized textures and texture writes with too little
    /// data, so an empty or short buffer must not reach `upload_rgba_texture`.
    fn has_uploadable_mask(&self) -> bool {
        let needed_pixels = u64::from(self.mask_width) * u64::from(self.mask_height);
        let available_pixels = u64::try_from(self.mask_rgba.len() / 4).unwrap_or(u64::MAX);
        needed_pixels > 0 && available_pixels >= needed_pixels
    }
}

#[cfg(feature = "wgpu")]
impl crate::nodes::RenderNode for LumaMaskNode {
    fn input_count(&self) -> usize {
        2
    }

    /// Luma-masks `inputs[0]` into `outputs[0]` on the GPU (mask shader mode 1).
    ///
    /// When `mask_rgba` holds a complete frame, it is uploaded and used as
    /// the mask. Otherwise the second input, `inputs[1]`, is used; this
    /// covers `mask_rgba` left empty, since it is only required for the CPU
    /// path. If neither is available, a warning is logged and nothing is
    /// rendered.
    fn process(
        &self,
        inputs: &[&wgpu::Texture],
        outputs: &[&wgpu::Texture],
        ctx: &crate::context::RenderContext,
    ) {
        let Some(base_tex) = inputs.first() else {
            log::warn!("LumaMaskNode::process called with no inputs");
            return;
        };
        let Some(output) = outputs.first() else {
            log::warn!("LumaMaskNode::process called with no outputs");
            return;
        };
        let pd = self.get_or_create_pipeline(ctx);

        // Owns the uploaded texture (if any) for the duration of the pass.
        let uploaded;
        let mask_tex: &wgpu::Texture = if self.has_uploadable_mask() {
            uploaded = upload_rgba_texture(
                ctx,
                &self.mask_rgba,
                self.mask_width,
                self.mask_height,
                "LumaMask mask",
            );
            &uploaded
        } else if let Some(&tex) = inputs.get(1) {
            tex
        } else {
            log::warn!(
                "LumaMaskNode::process: mask_rgba is empty or too short and no mask input was provided"
            );
            return;
        };
        submit_mask_pass(ctx, pd, base_tex, mask_tex, output, 1, "LumaMask BG");
    }
}
272
273// ── AlphaMatteNode ────────────────────────────────────────────────────────────
274
/// Porter-Duff source-over: composites `inputs[0]` (foreground) on top of
/// `inputs[1]` (background), weighted by the foreground's own alpha.
///
/// The CPU path receives only the foreground frame, so it reads the
/// background from `background_rgba`.
pub struct AlphaMatteNode {
    /// RGBA8 bytes of the background frame; the CPU path cannot run without them.
    pub background_rgba: Vec<u8>,
    /// Background frame width in pixels.
    pub background_width: u32,
    /// Background frame height in pixels.
    pub background_height: u32,
    /// GPU pipeline, built lazily the first time the node is rendered.
    #[cfg(feature = "wgpu")]
    pipeline: std::sync::OnceLock<MaskPipeline>,
}

impl AlphaMatteNode {
    /// Creates an alpha-matte node over a `background_width` ×
    /// `background_height` RGBA background.
    ///
    /// No GPU work happens here; the pipeline is created on demand.
    #[must_use]
    pub fn new(background_rgba: Vec<u8>, background_width: u32, background_height: u32) -> Self {
        Self {
            background_width,
            background_height,
            background_rgba,
            #[cfg(feature = "wgpu")]
            pipeline: std::sync::OnceLock::default(),
        }
    }
}
302
303impl RenderNodeCpu for AlphaMatteNode {
304    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
305    fn process_cpu(&self, rgba: &mut [u8], _w: u32, _h: u32) {
306        if self.background_rgba.len() != rgba.len() {
307            return;
308        }
309        for (fg, bg) in rgba
310            .chunks_exact_mut(4)
311            .zip(self.background_rgba.chunks_exact(4))
312        {
313            let fa = f32::from(fg[3]) / 255.0;
314            let ba = f32::from(bg[3]) / 255.0;
315            for ch in 0..3 {
316                let fc = f32::from(fg[ch]) / 255.0;
317                let bc = f32::from(bg[ch]) / 255.0;
318                fg[ch] = ((fc * fa + bc * (1.0 - fa)).clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
319            }
320            fg[3] = ((fa + ba * (1.0 - fa)).clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
321        }
322    }
323}
324
#[cfg(feature = "wgpu")]
impl AlphaMatteNode {
    /// Returns this node's mask pipeline, building it on the first call.
    fn get_or_create_pipeline(&self, ctx: &crate::context::RenderContext) -> &MaskPipeline {
        self.pipeline.get_or_init(|| create_mask_pipeline(ctx))
    }

    /// Whether `background_rgba` describes a non-empty `background_width` ×
    /// `background_height` RGBA frame with at least that many 4-byte pixels.
    ///
    /// wgpu rejects zero-sized textures and texture writes with too little
    /// data, so an empty or short buffer must not reach `upload_rgba_texture`.
    fn has_uploadable_background(&self) -> bool {
        let needed_pixels =
            u64::from(self.background_width) * u64::from(self.background_height);
        let available_pixels =
            u64::try_from(self.background_rgba.len() / 4).unwrap_or(u64::MAX);
        needed_pixels > 0 && available_pixels >= needed_pixels
    }
}

#[cfg(feature = "wgpu")]
impl crate::nodes::RenderNode for AlphaMatteNode {
    fn input_count(&self) -> usize {
        2
    }

    /// Composites `inputs[0]` over the background into `outputs[0]` on the
    /// GPU (mask shader mode 2).
    ///
    /// When `background_rgba` holds a complete frame, it is uploaded and used
    /// as the background. Otherwise the second input, `inputs[1]`, is used;
    /// this covers `background_rgba` left empty, since it is only required for
    /// the CPU path. If neither is available, a warning is logged and nothing
    /// is rendered.
    fn process(
        &self,
        inputs: &[&wgpu::Texture],
        outputs: &[&wgpu::Texture],
        ctx: &crate::context::RenderContext,
    ) {
        let Some(fg_tex) = inputs.first() else {
            log::warn!("AlphaMatteNode::process called with no inputs");
            return;
        };
        let Some(output) = outputs.first() else {
            log::warn!("AlphaMatteNode::process called with no outputs");
            return;
        };
        let pd = self.get_or_create_pipeline(ctx);

        // Owns the uploaded texture (if any) for the duration of the pass.
        let uploaded;
        let bg_tex: &wgpu::Texture = if self.has_uploadable_background() {
            uploaded = upload_rgba_texture(
                ctx,
                &self.background_rgba,
                self.background_width,
                self.background_height,
                "AlphaMatte bg",
            );
            &uploaded
        } else if let Some(&tex) = inputs.get(1) {
            tex
        } else {
            log::warn!(
                "AlphaMatteNode::process: background_rgba is empty or too short and no background input was provided"
            );
            return;
        };
        submit_mask_pass(ctx, pd, fg_tex, bg_tex, output, 2, "AlphaMatte BG");
    }
}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nodes::RenderNodeCpu;

    /// Runs `node` on a single 1×1 RGBA pixel and returns the processed pixel.
    fn run_one(node: &impl RenderNodeCpu, mut pixel: [u8; 4]) -> [u8; 4] {
        node.process_cpu(&mut pixel, 1, 1);
        pixel
    }

    // ── ShapeMaskNode ────────────────────────────────────────────────────────

    #[test]
    fn shape_mask_node_opaque_mask_should_keep_base_alpha() {
        // Fully opaque mask.
        let node = ShapeMaskNode::new(vec![0u8, 0, 0, 255], 1, 1);
        let out = run_one(&node, [128, 128, 128, 200]);
        assert!(
            out[3].abs_diff(200) <= 1,
            "opaque mask must preserve base alpha"
        );
    }

    #[test]
    fn shape_mask_node_transparent_mask_should_zero_alpha() {
        // Fully transparent mask.
        let node = ShapeMaskNode::new(vec![255u8, 255, 255, 0], 1, 1);
        let out = run_one(&node, [128, 128, 128, 255]);
        assert_eq!(out[3], 0, "transparent mask must produce zero alpha");
    }

    // ── LumaMaskNode ─────────────────────────────────────────────────────────

    #[test]
    fn luma_mask_node_white_mask_should_preserve_alpha() {
        // White mask: luma = 1.0.
        let node = LumaMaskNode::new(vec![255u8, 255, 255, 255], 1, 1);
        let out = run_one(&node, [100, 100, 100, 200]);
        assert!(out[3].abs_diff(200) <= 2, "white mask preserves alpha");
    }

    #[test]
    fn luma_mask_node_black_mask_should_zero_alpha() {
        // Black mask: luma = 0.0.
        let node = LumaMaskNode::new(vec![0u8, 0, 0, 255], 1, 1);
        let out = run_one(&node, [100, 100, 100, 255]);
        assert_eq!(out[3], 0, "black mask must zero out alpha");
    }

    // ── AlphaMatteNode ───────────────────────────────────────────────────────

    #[test]
    fn alpha_matte_node_opaque_fg_should_replace_background() {
        let node = AlphaMatteNode::new(vec![50u8, 50, 50, 255], 1, 1);
        // Fully opaque foreground.
        let out = run_one(&node, [200, 100, 50, 255]);
        assert!(
            out[0].abs_diff(200) <= 1,
            "opaque fg must dominate; got {}",
            out[0]
        );
    }

    #[test]
    fn alpha_matte_node_transparent_fg_should_show_background() {
        let node = AlphaMatteNode::new(vec![50u8, 80, 120, 255], 1, 1);
        // Fully transparent foreground.
        let out = run_one(&node, [200, 200, 200, 0]);
        assert!(
            out[0].abs_diff(50) <= 1,
            "transparent fg must show bg; got {}",
            out[0]
        );
    }
}