Skip to main content

cranpose_ui_graphics/
alpha_mask.rs

//! Alpha-mask helpers for graphics-layer effects.
//!
//! These utilities provide a developer-facing API for common Jetpack Compose-style
//! masking workflows, such as revealing or cutting content with a rounded shape
//! and a feathered opacity transition.
6
7use crate::{CornerRadii, RenderEffect, RuntimeShader};
8
/// Edge from which a cut or fade mask is measured.
///
/// Each variant maps to the float direction code consumed by the mask shaders
/// in this module (see `uniform_code`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum CutDirection {
    /// Keep content from the left edge up to the cut progress.
    #[default]
    LeftToRight = 0,
    /// Keep content from the right edge up to the cut progress.
    RightToLeft = 1,
    /// Keep content from the top edge down to the cut progress.
    TopToBottom = 2,
    /// Keep content from the bottom edge up to the cut progress.
    BottomToTop = 3,
}

impl CutDirection {
    /// Returns the shader direction code: 0 = L->R, 1 = R->L, 2 = T->B, 3 = B->T.
    fn uniform_code(self) -> f32 {
        // The discriminants above are pinned explicitly to the shader codes,
        // so a plain integer cast yields the expected value.
        f32::from(self as u8)
    }
}
33
34/// Configuration for a directional gradient cut mask.
35#[derive(Clone, Copy, Debug, PartialEq)]
36pub struct GradientCutMaskSpec {
37    /// Reveal progress in [0, 1].
38    pub progress: f32,
39    /// Width of the edge feather in dp/px.
40    pub feather: f32,
41    /// Rounded-corner radius of the masked area in dp/px.
42    pub corner_radius: f32,
43    /// Direction from which content is revealed.
44    pub direction: CutDirection,
45}
46
47impl Default for GradientCutMaskSpec {
48    fn default() -> Self {
49        Self {
50            progress: 0.5,
51            feather: 24.0,
52            corner_radius: 16.0,
53            direction: CutDirection::LeftToRight,
54        }
55    }
56}
57
58/// Configuration for a directional gradient fade mask that matches
59/// `drawWithContent { drawRect(..., blendMode = DstOut) }` behavior.
60#[derive(Clone, Copy, Debug, PartialEq)]
61pub struct GradientFadeMaskSpec {
62    /// Axis coordinate where fade starts (fully cut) in dp/px.
63    pub start: f32,
64    /// Axis coordinate where fade ends (fully visible) in dp/px.
65    pub end: f32,
66    /// Direction that defines the gradient axis.
67    pub direction: CutDirection,
68}
69
70impl Default for GradientFadeMaskSpec {
71    fn default() -> Self {
72        Self {
73            start: 0.0,
74            end: 64.0,
75            direction: CutDirection::TopToBottom,
76        }
77    }
78}
79
/// WGSL shader for a directional cut combined with a rounded-rect alpha mask.
///
/// Uniform layout (float indices):
/// - 0,1: container size in dp
/// - 2: progress [0,1]
/// - 3: feather in dp
/// - 4: corner radius in dp
/// - 5: direction code (0=L->R, 1=R->L, 2=T->B, 3=B->T)
/// - 248..=251: effect-layer pixel rect (uniform vec4 slot 62), injected by the renderer
///
/// The feathered cut edge sweeps from `-feather/2` to `extent + feather/2` as
/// progress goes from 0 to 1. Progress 0 therefore hides the content completely,
/// and progress 1 shows it completely, with no discontinuity at either end.
/// The corner radius is clamped to half the shorter side so the rounded-rect
/// SDF keeps its intended outline.
pub const GRADIENT_CUT_MASK_WGSL: &str = r#"
struct VertexOutput {
    @builtin(position) position: vec4<f32>,
    @location(0) uv: vec2<f32>,
}

@vertex
fn fullscreen_vs(@builtin(vertex_index) vertex_index: u32) -> VertexOutput {
    var output: VertexOutput;
    let x = f32(i32(vertex_index & 1u) * 2 - 1);
    let y = f32(i32(vertex_index >> 1u) * 2 - 1);
    output.uv = vec2<f32>(x * 0.5 + 0.5, 1.0 - (y * 0.5 + 0.5));
    output.position = vec4<f32>(x, y, 0.0, 1.0);
    return output;
}

@group(0) @binding(0) var input_texture: texture_2d<f32>;
@group(0) @binding(1) var input_sampler: sampler;
@group(1) @binding(0) var<uniform> u: array<vec4<f32>, 64>;

fn get_float(index: u32) -> f32 {
    return u[index / 4u][index % 4u];
}

fn get_vec2(index: u32) -> vec2<f32> {
    return vec2<f32>(get_float(index), get_float(index + 1u));
}

fn sd_round_rect(p: vec2<f32>, half_size: vec2<f32>, radius: f32) -> f32 {
    let q = abs(p) - half_size + vec2<f32>(radius);
    return length(max(q, vec2<f32>(0.0))) + min(max(q.x, q.y), 0.0) - radius;
}

fn rounded_rect_alpha(local_px: vec2<f32>, size_px: vec2<f32>, corner_radius_px: f32) -> f32 {
    let half = size_px * 0.5;
    let p = local_px - half;
    // A radius larger than half the shorter side gives the SDF's inner box a
    // negative extent and distorts the outline, so clamp it.
    let radius = min(corner_radius_px, min(half.x, half.y));
    let d = sd_round_rect(p, half, radius);
    return 1.0 - smoothstep(-1.0, 1.0, d);
}

@fragment
fn effect_fs(input: VertexOutput) -> @location(0) vec4<f32> {
    let uv = input.uv;
    let tex_size = vec2<f32>(textureDimensions(input_texture));

    // Effect layer pixel rect injected by renderer in uniform slot 62.
    let effect_rect = vec4<f32>(get_float(248u), get_float(249u), get_float(250u), get_float(251u));
    let container_dp = get_vec2(0u);

    // dp -> pixel mapping for local effect coordinates.
    let dp_scale = effect_rect.zw / max(container_dp, vec2<f32>(1.0));
    let s = min(dp_scale.x, dp_scale.y);

    let local_px = uv * tex_size - effect_rect.xy;
    let size_px = container_dp * dp_scale;

    let progress = clamp(get_float(2u), 0.0, 1.0);
    let feather_px = max(get_float(3u) * s, 0.001);
    let corner_radius_px = max(get_float(4u) * s, 0.0);
    let direction = get_float(5u);

    var axis_value = local_px.x;
    var axis_extent = max(size_px.x, 0.001);

    if (direction >= 0.5 && direction < 1.5) {
        axis_value = size_px.x - local_px.x;
        axis_extent = max(size_px.x, 0.001);
    } else if (direction >= 1.5 && direction < 2.5) {
        axis_value = local_px.y;
        axis_extent = max(size_px.y, 0.001);
    } else if (direction >= 2.5) {
        axis_value = size_px.y - local_px.y;
        axis_extent = max(size_px.y, 0.001);
    }

    // Sweep the feathered edge over [-feather/2, extent + feather/2] so that
    // progress 0 hides everything and progress 1 shows everything, without a
    // partially transparent sliver or a jump at either end.
    let half_feather = feather_px * 0.5;
    let cut_edge = mix(-half_feather, axis_extent + half_feather, progress);
    // smoothstep is undefined for low >= high on some backends (MSL, GLSL),
    // so evaluate the rising edge and invert it instead of swapping bounds.
    let directional_alpha = 1.0 - smoothstep(cut_edge - half_feather, cut_edge + half_feather, axis_value);
    let shape_alpha = rounded_rect_alpha(local_px, size_px, corner_radius_px);
    let mask = directional_alpha * shape_alpha;

    let sample = textureSample(input_texture, input_sampler, uv);
    return sample * mask;
}
"#;
175
/// WGSL shader for rounded-rectangle alpha masking with feathered edges.
///
/// Uniform layout (float indices):
/// - 0,1: container size in dp
/// - 2: edge feather in dp
/// - 3,4,5,6: corner radii in dp (top-left, top-right, bottom-right, bottom-left)
/// - 248..=251: effect-layer pixel rect (uniform vec4 slot 62), injected by the renderer
///
/// Each corner radius is clamped to half the shorter side of the mask rect.
/// Oversized radii would otherwise distort the SDF outline. Radii are clamped
/// individually rather than scaled proportionally.
pub const ROUNDED_ALPHA_MASK_WGSL: &str = r#"
struct VertexOutput {
    @builtin(position) position: vec4<f32>,
    @location(0) uv: vec2<f32>,
}

@vertex
fn fullscreen_vs(@builtin(vertex_index) vertex_index: u32) -> VertexOutput {
    var output: VertexOutput;
    let x = f32(i32(vertex_index & 1u) * 2 - 1);
    let y = f32(i32(vertex_index >> 1u) * 2 - 1);
    output.uv = vec2<f32>(x * 0.5 + 0.5, 1.0 - (y * 0.5 + 0.5));
    output.position = vec4<f32>(x, y, 0.0, 1.0);
    return output;
}

@group(0) @binding(0) var input_texture: texture_2d<f32>;
@group(0) @binding(1) var input_sampler: sampler;
@group(1) @binding(0) var<uniform> u: array<vec4<f32>, 64>;

fn get_float(index: u32) -> f32 {
    return u[index / 4u][index % 4u];
}

fn get_vec2(index: u32) -> vec2<f32> {
    return vec2<f32>(get_float(index), get_float(index + 1u));
}

fn corner_radius_for_point(p: vec2<f32>, radii: vec4<f32>) -> f32 {
    if (p.x < 0.0) {
        if (p.y < 0.0) {
            return radii.x;
        }
        return radii.w;
    }
    if (p.y < 0.0) {
        return radii.y;
    }
    return radii.z;
}

fn sd_round_rect(p: vec2<f32>, half_size: vec2<f32>, radii: vec4<f32>) -> f32 {
    let radius = corner_radius_for_point(p, radii);
    let q = abs(p) - half_size + vec2<f32>(radius);
    return length(max(q, vec2<f32>(0.0))) + min(max(q.x, q.y), 0.0) - radius;
}

fn rounded_rect_alpha(local_px: vec2<f32>, size_px: vec2<f32>, corner_radii_px: vec4<f32>, feather_px: f32) -> f32 {
    let half = size_px * 0.5;
    let p = local_px - half;
    // A radius larger than half the shorter side gives the SDF's inner box a
    // negative extent and distorts the outline, so clamp every corner.
    let radii = min(corner_radii_px, vec4<f32>(min(half.x, half.y)));
    let d = sd_round_rect(p, half, radii);
    let half_feather = max(feather_px * 0.5, 0.001);
    return 1.0 - smoothstep(-half_feather, half_feather, d);
}

@fragment
fn effect_fs(input: VertexOutput) -> @location(0) vec4<f32> {
    let uv = input.uv;
    let tex_size = vec2<f32>(textureDimensions(input_texture));

    // Effect layer pixel rect injected by renderer in uniform slot 62.
    let effect_rect = vec4<f32>(get_float(248u), get_float(249u), get_float(250u), get_float(251u));
    let container_dp = get_vec2(0u);

    // dp -> pixel mapping for local effect coordinates.
    let dp_scale = effect_rect.zw / max(container_dp, vec2<f32>(1.0));
    let s = min(dp_scale.x, dp_scale.y);

    let local_px = uv * tex_size - effect_rect.xy;
    let size_px = container_dp * dp_scale;

    let corner_radii_px = max(vec4<f32>(
        get_float(3u),
        get_float(4u),
        get_float(5u),
        get_float(6u),
    ) * s, vec4<f32>(0.0));
    let feather_px = max(get_float(2u) * s, 0.0);
    let mask = rounded_rect_alpha(local_px, size_px, corner_radii_px, feather_px);

    let sample = textureSample(input_texture, input_sampler, uv);
    return sample * mask;
}
"#;
266
/// WGSL shader for directional fade-out alpha masking (DstOut-style).
///
/// The fragment entry point `effect_fs` multiplies the input texture sample by a
/// keep-alpha. The keep-alpha ramps linearly from 0 at `start` to 1 at `end`
/// and is clamped to `[0, 1]` outside that span. Both positions are measured
/// along the axis selected by the direction code, starting from the edge that
/// direction begins at. They are converted from dp to pixels with that axis's
/// effect-rect / container scale. When `end < start`, the ramp runs the other
/// way but still reaches 0 at `start` and 1 at `end`.
///
/// Uniform layout (float indices):
/// - 0,1: container size in dp
/// - 2: fade start in dp
/// - 3: fade end in dp
/// - 4: direction code (0=L->R, 1=R->L, 2=T->B, 3=B->T)
/// - 248..=251: effect-layer pixel rect (uniform vec4 slot 62), injected by the renderer
pub const GRADIENT_FADE_DST_OUT_WGSL: &str = r#"
struct VertexOutput {
    @builtin(position) position: vec4<f32>,
    @location(0) uv: vec2<f32>,
}

@vertex
fn fullscreen_vs(@builtin(vertex_index) vertex_index: u32) -> VertexOutput {
    var output: VertexOutput;
    let x = f32(i32(vertex_index & 1u) * 2 - 1);
    let y = f32(i32(vertex_index >> 1u) * 2 - 1);
    output.uv = vec2<f32>(x * 0.5 + 0.5, 1.0 - (y * 0.5 + 0.5));
    output.position = vec4<f32>(x, y, 0.0, 1.0);
    return output;
}

@group(0) @binding(0) var input_texture: texture_2d<f32>;
@group(0) @binding(1) var input_sampler: sampler;
@group(1) @binding(0) var<uniform> u: array<vec4<f32>, 64>;

fn get_float(index: u32) -> f32 {
    return u[index / 4u][index % 4u];
}

fn get_vec2(index: u32) -> vec2<f32> {
    return vec2<f32>(get_float(index), get_float(index + 1u));
}

@fragment
fn effect_fs(input: VertexOutput) -> @location(0) vec4<f32> {
    let uv = input.uv;
    let tex_size = vec2<f32>(textureDimensions(input_texture));

    // Effect layer pixel rect injected by renderer in uniform slot 62.
    let effect_rect = vec4<f32>(get_float(248u), get_float(249u), get_float(250u), get_float(251u));
    let container_dp = get_vec2(0u);

    // dp -> pixel mapping for local effect coordinates.
    let dp_scale = effect_rect.zw / max(container_dp, vec2<f32>(1.0));

    let local_px = uv * tex_size - effect_rect.xy;
    let size_px = container_dp * dp_scale;
    let direction = get_float(4u);

    var axis_value = local_px.x;
    var axis_scale = dp_scale.x;
    if (direction >= 0.5 && direction < 1.5) {
        axis_value = size_px.x - local_px.x;
        axis_scale = dp_scale.x;
    } else if (direction >= 1.5 && direction < 2.5) {
        axis_value = local_px.y;
        axis_scale = dp_scale.y;
    } else if (direction >= 2.5) {
        axis_value = size_px.y - local_px.y;
        axis_scale = dp_scale.y;
    }

    let start_px = get_float(2u) * axis_scale;
    let end_px = get_float(3u) * axis_scale;
    let span = max(abs(end_px - start_px), 0.001);

    var keep_alpha = 1.0;
    if (end_px >= start_px) {
        keep_alpha = clamp((axis_value - start_px) / span, 0.0, 1.0);
    } else {
        keep_alpha = clamp((start_px - axis_value) / span, 0.0, 1.0);
    }

    let sample = textureSample(input_texture, input_sampler, uv);
    return sample * keep_alpha;
}
"#;
346
347/// Builds a directional cut mask effect.
348///
349/// The resulting `RenderEffect` keeps content from one side up to `progress`
350/// with a feathered transition and rounded-rect outer masking.
351pub fn gradient_cut_mask_effect(
352    spec: &GradientCutMaskSpec,
353    area_width: f32,
354    area_height: f32,
355) -> RenderEffect {
356    let mut shader = RuntimeShader::new(GRADIENT_CUT_MASK_WGSL);
357    shader.set_float2(0, area_width.max(1.0), area_height.max(1.0));
358    shader.set_float(2, spec.progress.clamp(0.0, 1.0));
359    shader.set_float(3, spec.feather.max(0.0));
360    shader.set_float(4, spec.corner_radius.max(0.0));
361    shader.set_float(5, spec.direction.uniform_code());
362    RenderEffect::runtime_shader(shader)
363}
364
365/// Builds a rounded alpha mask effect without directional cutting.
366///
367/// Useful for masking other effects (for example blur) to a rounded rectangle
368/// with an explicit feathered edge width.
369pub fn rounded_alpha_mask_effect(
370    area_width: f32,
371    area_height: f32,
372    corner_radius: f32,
373    edge_feather: f32,
374) -> RenderEffect {
375    rounded_corner_alpha_mask_effect(
376        area_width,
377        area_height,
378        CornerRadii::uniform(corner_radius),
379        edge_feather,
380    )
381}
382
383/// Builds a rounded alpha mask effect with independent corner radii.
384pub fn rounded_corner_alpha_mask_effect(
385    area_width: f32,
386    area_height: f32,
387    corner_radii: CornerRadii,
388    edge_feather: f32,
389) -> RenderEffect {
390    let mut shader = RuntimeShader::new(ROUNDED_ALPHA_MASK_WGSL);
391    shader.set_float2(0, area_width.max(1.0), area_height.max(1.0));
392    shader.set_float(2, edge_feather.max(0.0));
393    shader.set_float4(
394        3,
395        corner_radii.top_left.max(0.0),
396        corner_radii.top_right.max(0.0),
397        corner_radii.bottom_right.max(0.0),
398        corner_radii.bottom_left.max(0.0),
399    );
400    RenderEffect::runtime_shader(shader)
401}
402
403/// Builds a directional gradient fade mask with DstOut semantics.
404///
405/// Equivalent to drawing an opaque->transparent gradient on top of content
406/// using destination-out blending: the fade start is fully cut, and the end
407/// is fully preserved.
408pub fn gradient_fade_dst_out_effect(
409    spec: &GradientFadeMaskSpec,
410    area_width: f32,
411    area_height: f32,
412) -> RenderEffect {
413    let mut shader = RuntimeShader::new(GRADIENT_FADE_DST_OUT_WGSL);
414    shader.set_float2(0, area_width.max(1.0), area_height.max(1.0));
415    shader.set_float(2, spec.start);
416    shader.set_float(3, spec.end);
417    shader.set_float(4, spec.direction.uniform_code());
418    RenderEffect::runtime_shader(shader)
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gradient_cut_spec_defaults() {
        let GradientCutMaskSpec {
            progress,
            feather,
            corner_radius,
            direction,
        } = GradientCutMaskSpec::default();
        assert_eq!(progress, 0.5);
        assert_eq!(feather, 24.0);
        assert_eq!(corner_radius, 16.0);
        assert_eq!(direction, CutDirection::LeftToRight);
    }

    #[test]
    fn gradient_fade_spec_defaults() {
        let GradientFadeMaskSpec {
            start,
            end,
            direction,
        } = GradientFadeMaskSpec::default();
        assert_eq!(start, 0.0);
        assert_eq!(end, 64.0);
        assert_eq!(direction, CutDirection::TopToBottom);
    }

    #[test]
    fn gradient_cut_effect_sets_uniforms() {
        let spec = GradientCutMaskSpec {
            progress: 0.33,
            feather: 18.0,
            corner_radius: 20.0,
            direction: CutDirection::BottomToTop,
        };
        let RenderEffect::Shader { shader } = gradient_cut_mask_effect(&spec, 320.0, 180.0) else {
            panic!("expected shader render effect");
        };

        let uniforms = shader.uniforms();
        for (slot, &expected) in [320.0, 180.0, 0.33, 18.0, 20.0, 3.0].iter().enumerate() {
            assert_eq!(uniforms[slot], expected);
        }
    }

    #[test]
    fn gradient_cut_effect_clamps_values() {
        let spec = GradientCutMaskSpec {
            progress: 2.4,
            feather: -3.0,
            corner_radius: -8.0,
            direction: CutDirection::RightToLeft,
        };
        let RenderEffect::Shader { shader } = gradient_cut_mask_effect(&spec, 0.0, 0.0) else {
            panic!("expected shader render effect");
        };

        let uniforms = shader.uniforms();
        for (slot, &expected) in [1.0, 1.0, 1.0, 0.0, 0.0, 1.0].iter().enumerate() {
            assert_eq!(uniforms[slot], expected);
        }
    }

    #[test]
    fn rounded_alpha_mask_uses_dedicated_shader_uniforms() {
        let RenderEffect::Shader { shader } = rounded_alpha_mask_effect(240.0, 120.0, 14.0, 6.0)
        else {
            panic!("expected shader render effect");
        };

        assert_eq!(shader.source(), ROUNDED_ALPHA_MASK_WGSL);
        let uniforms = shader.uniforms();
        for (slot, &expected) in [240.0, 120.0, 6.0, 14.0, 14.0, 14.0, 14.0].iter().enumerate() {
            assert_eq!(uniforms[slot], expected);
        }
    }

    #[test]
    fn rounded_corner_alpha_mask_sets_per_corner_uniforms() {
        let radii = CornerRadii {
            top_left: 4.0,
            top_right: 8.0,
            bottom_right: 12.0,
            bottom_left: 16.0,
        };
        let RenderEffect::Shader { shader } =
            rounded_corner_alpha_mask_effect(240.0, 120.0, radii, 2.0)
        else {
            panic!("expected shader render effect");
        };

        assert_eq!(shader.source(), ROUNDED_ALPHA_MASK_WGSL);
        let uniforms = shader.uniforms();
        for (slot, &expected) in [240.0, 120.0, 2.0, 4.0, 8.0, 12.0, 16.0].iter().enumerate() {
            assert_eq!(uniforms[slot], expected);
        }
    }

    #[test]
    fn gradient_fade_dst_out_effect_sets_uniforms() {
        let spec = GradientFadeMaskSpec {
            start: 24.0,
            end: 52.0,
            direction: CutDirection::BottomToTop,
        };
        let RenderEffect::Shader { shader } = gradient_fade_dst_out_effect(&spec, 300.0, 180.0)
        else {
            panic!("expected shader render effect");
        };

        assert_eq!(shader.source(), GRADIENT_FADE_DST_OUT_WGSL);
        let uniforms = shader.uniforms();
        for (slot, &expected) in [300.0, 180.0, 24.0, 52.0, 3.0].iter().enumerate() {
            assert_eq!(uniforms[slot], expected);
        }
    }
}