Skip to main content

ff_render/nodes/composite/
chroma_key.rs

1//! `ChromaKeyNode` — green-screen keying (CPU + GPU) with tolerance/softness.
2
3#[cfg(feature = "wgpu")]
4use super::helpers::{
5    fullscreen_pipeline, linear_sampler, one_tex_sampler_uniform_bgl, pack_f32, submit_render_pass,
6};
7use crate::nodes::RenderNodeCpu;
8
9// ── ChromaKeyNode ─────────────────────────────────────────────────────────────
10
/// GPU resources for the chroma-key pass.
///
/// Built once by `ChromaKeyNode::get_or_create_pipeline` and cached in the
/// node's `OnceLock`, so every `process` call reuses the same objects.
#[cfg(feature = "wgpu")]
struct ChromaKeyPipeline {
    /// Pipeline returned by `fullscreen_pipeline` for the `chroma_key.wgsl` shader.
    render_pipeline: wgpu::RenderPipeline,
    /// Layout returned by `one_tex_sampler_uniform_bgl`; `process` binds a
    /// texture view at binding 0, a sampler at binding 1 and the uniform
    /// buffer at binding 2 against it.
    bind_group_layout: wgpu::BindGroupLayout,
    /// Sampler returned by `linear_sampler`, bound at binding 1.
    sampler: wgpu::Sampler,
    /// 32-byte `UNIFORM | COPY_DST` buffer that `process` overwrites with the
    /// key colour, tolerance, softness and padding before each pass.
    uniform_buf: wgpu::Buffer,
}
18
/// Remove a solid colour from a texture by chroma distance, producing alpha.
///
/// The algorithm computes the Euclidean distance between the pixel's chroma
/// vector (RGB − luma) and the key colour's chroma vector, then multiplies the
/// pixel's existing alpha by a smoothstep of that distance.
///
/// On the CPU path the transition band is centred on `tolerance` and extends
/// `softness` to either side: pixels closer to the key than
/// `tolerance − softness` become fully transparent, pixels further than
/// `tolerance + softness` keep their original alpha, and distances in between
/// are feathered.
pub struct ChromaKeyNode {
    /// Key colour in linear RGB [0.0, 1.0].
    pub key_color: [f32; 3],
    /// Chroma distance at the centre of the transition band (0.0–1.0).
    pub tolerance: f32,
    /// Half-width of the feathered band around `tolerance` (0.0–1.0).
    pub softness: f32,
    /// GPU resources, created lazily on the first GPU `process` call.
    #[cfg(feature = "wgpu")]
    pipeline: std::sync::OnceLock<ChromaKeyPipeline>,
}

impl ChromaKeyNode {
    /// Creates a keyer for `key_color` with the given `tolerance` and `softness`.
    ///
    /// The values are stored as given; no range validation or clamping is
    /// performed.  The constructor is `const`, so a node can also be built in
    /// constant contexts.
    #[must_use]
    pub const fn new(key_color: [f32; 3], tolerance: f32, softness: f32) -> Self {
        Self {
            key_color,
            tolerance,
            softness,
            #[cfg(feature = "wgpu")]
            pipeline: std::sync::OnceLock::new(),
        }
    }
}
49
50// ── CPU helpers ───────────────────────────────────────────────────────────────
51
52pub(super) fn bt709_luma(r: f32, g: f32, b: f32) -> f32 {
53    0.2126 * r + 0.7152 * g + 0.0722 * b
54}
55
56fn chroma_dist_cpu(pixel: [f32; 3], key: [f32; 3]) -> f32 {
57    let pl = bt709_luma(pixel[0], pixel[1], pixel[2]);
58    let kl = bt709_luma(key[0], key[1], key[2]);
59    let dp = [pixel[0] - pl, pixel[1] - pl, pixel[2] - pl];
60    let dk = [key[0] - kl, key[1] - kl, key[2] - kl];
61    let d = [dp[0] - dk[0], dp[1] - dk[1], dp[2] - dk[2]];
62    (d[0] * d[0] + d[1] * d[1] + d[2] * d[2]).sqrt()
63}
64
/// Hermite interpolation between 0 and 1 as `x` moves from `edge0` to `edge1`.
///
/// Returns 0.0 when `x` is at or before `edge0`, 1.0 when it is at or past
/// `edge1`, and `t² · (3 − 2t)` in between, where `t` is the normalised
/// position of `x` inside the band.
///
/// A zero-width band (`edge0 == edge1`) is treated as a hard step: 0.0 for
/// `x <= edge0` and 1.0 for `x > edge0`.  Without this guard the normalisation
/// divides by zero and yields NaN when `x` sits exactly on the edge.
fn smoothstep(edge0: f32, edge1: f32, x: f32) -> f32 {
    let width = edge1 - edge0;
    if width == 0.0 {
        // Degenerate band: avoid 0/0 and fall back to a step function.
        return if x > edge0 { 1.0 } else { 0.0 };
    }
    let t = ((x - edge0) / width).clamp(0.0, 1.0);
    t * t * (3.0 - 2.0 * t)
}
69
70impl RenderNodeCpu for ChromaKeyNode {
71    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
72    fn process_cpu(&self, rgba: &mut [u8], _w: u32, _h: u32) {
73        for pixel in rgba.chunks_exact_mut(4) {
74            let r = f32::from(pixel[0]) / 255.0;
75            let g = f32::from(pixel[1]) / 255.0;
76            let b = f32::from(pixel[2]) / 255.0;
77            let a = f32::from(pixel[3]) / 255.0;
78            let dist = chroma_dist_cpu([r, g, b], self.key_color);
79            let alpha_factor = smoothstep(
80                self.tolerance - self.softness,
81                self.tolerance + self.softness,
82                dist,
83            );
84            pixel[3] = ((a * alpha_factor).clamp(0.0, 1.0) * 255.0 + 0.5) as u8;
85        }
86    }
87}
88
#[cfg(feature = "wgpu")]
impl ChromaKeyNode {
    /// Returns the cached GPU resources, building them from `ctx.device` the
    /// first time this is called on the node.
    fn get_or_create_pipeline(&self, ctx: &crate::context::RenderContext) -> &ChromaKeyPipeline {
        /// Label prefix passed to the shared helper constructors.
        const LABEL: &str = "ChromaKey";
        /// key_color(3) + tolerance(1) + softness(1) + pad(3) = 8×f32 = 32 bytes.
        const UNIFORM_BYTES: u64 = 8 * 4;

        let build = || {
            let gpu = &ctx.device;

            let module = gpu.create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("ChromaKey shader"),
                source: wgpu::ShaderSource::Wgsl(
                    include_str!("../../shaders/chroma_key.wgsl").into(),
                ),
            });
            let bind_group_layout = one_tex_sampler_uniform_bgl(gpu, LABEL);
            let render_pipeline = fullscreen_pipeline(gpu, &module, LABEL, &bind_group_layout);
            let sampler = linear_sampler(gpu, LABEL);

            // COPY_DST lets `process` refresh the parameters via `write_buffer`.
            let uniform_buf = gpu.create_buffer(&wgpu::BufferDescriptor {
                label: Some("ChromaKey uniforms"),
                size: UNIFORM_BYTES,
                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });

            ChromaKeyPipeline {
                render_pipeline,
                bind_group_layout,
                sampler,
                uniform_buf,
            }
        };

        self.pipeline.get_or_init(build)
    }
}
119
#[cfg(feature = "wgpu")]
impl crate::nodes::RenderNode for ChromaKeyNode {
    /// Runs the chroma-key shader from the first input texture into the first
    /// output texture.
    ///
    /// Logs a warning and returns without doing any GPU work when either
    /// slice is empty.  Additional inputs or outputs are ignored.
    fn process(
        &self,
        inputs: &[&wgpu::Texture],
        outputs: &[&wgpu::Texture],
        ctx: &crate::context::RenderContext,
    ) {
        let (source, target) = match (inputs.first(), outputs.first()) {
            (None, _) => {
                log::warn!("ChromaKeyNode::process called with no inputs");
                return;
            }
            (Some(_), None) => {
                log::warn!("ChromaKeyNode::process called with no outputs");
                return;
            }
            (Some(source), Some(target)) => (source, target),
        };
        let gpu = self.get_or_create_pipeline(ctx);

        // Layout mirrors the 32-byte uniform buffer: colour, tolerance,
        // softness, then three floats of padding.
        let [key_r, key_g, key_b] = self.key_color;
        let uniforms = pack_f32(&[
            key_r,
            key_g,
            key_b,
            self.tolerance,
            self.softness,
            0.0,
            0.0,
            0.0,
        ]);
        ctx.queue.write_buffer(&gpu.uniform_buf, 0, &uniforms);

        let source_view = source.create_view(&wgpu::TextureViewDescriptor::default());
        let target_view = target.create_view(&wgpu::TextureViewDescriptor::default());

        let texture_entry = wgpu::BindGroupEntry {
            binding: 0,
            resource: wgpu::BindingResource::TextureView(&source_view),
        };
        let sampler_entry = wgpu::BindGroupEntry {
            binding: 1,
            resource: wgpu::BindingResource::Sampler(&gpu.sampler),
        };
        let uniform_entry = wgpu::BindGroupEntry {
            binding: 2,
            resource: gpu.uniform_buf.as_entire_binding(),
        };
        let entries = [texture_entry, sampler_entry, uniform_entry];

        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("ChromaKey BG"),
            layout: &gpu.bind_group_layout,
            entries: &entries,
        });

        submit_render_pass(
            ctx,
            &gpu.render_pipeline,
            &bind_group,
            &target_view,
            "ChromaKey",
        );
    }
}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nodes::RenderNodeCpu;

    /// Key colour used by every test: pure green.
    const GREEN_KEY: [f32; 3] = [0.0, 1.0, 0.0];

    /// Runs `node` over a single fully opaque pixel and returns the
    /// resulting alpha byte.
    fn keyed_alpha(node: &ChromaKeyNode, rgb: [u8; 3]) -> u8 {
        let [r, g, b] = rgb;
        let mut pixel = [r, g, b, 255];
        node.process_cpu(&mut pixel, 1, 1);
        pixel[3]
    }

    #[test]
    fn chroma_key_node_pure_green_should_become_transparent() {
        let node = ChromaKeyNode::new(GREEN_KEY, 0.1, 0.05);
        let alpha = keyed_alpha(&node, [0, 255, 0]); // pure green
        assert_eq!(
            alpha, 0,
            "pure green key must produce fully transparent alpha"
        );
    }

    #[test]
    fn chroma_key_node_non_key_colour_should_stay_opaque() {
        let node = ChromaKeyNode::new(GREEN_KEY, 0.1, 0.05);
        let alpha = keyed_alpha(&node, [255, 0, 0]); // pure red
        assert!(
            alpha > 200,
            "non-key colour must stay opaque; got alpha={}",
            alpha
        );
    }

    #[test]
    fn chroma_key_node_tolerances_should_control_threshold() {
        // A dark green should be keyed with a generous tolerance but not with a tight one.
        let dark_green = [0u8, 100, 0];
        let tight = ChromaKeyNode::new(GREEN_KEY, 0.05, 0.01);
        let loose = ChromaKeyNode::new(GREEN_KEY, 0.8, 0.1);

        let alpha_tight = keyed_alpha(&tight, dark_green);
        let alpha_loose = keyed_alpha(&loose, dark_green);

        assert!(
            alpha_loose < alpha_tight,
            "loose tolerance must key more aggressively than tight"
        );
    }
}