Skip to main content

cranpose_ui_graphics/
render_effect.rs

//! Render effects that can be applied to graphics layers.
//!
//! Matches the Jetpack Compose `RenderEffect` API with extensions for custom
//! WGSL shaders (`RuntimeShader`).

6use crate::LayerShape;
7use std::hash::{Hash, Hasher};
8use std::sync::Arc;
9
/// How blur effects sample pixels that fall outside the region being blurred.
///
/// Variant names follow Compose's `TileMode`; [`TileMode::Clamp`] is the
/// default.
#[derive(Clone, Copy, Debug, Default)]
#[derive(PartialEq, Eq, Hash)]
pub enum TileMode {
    /// Reuse the color of the nearest edge pixel.
    #[default]
    Clamp,
    /// Tile the gradient/effect, restarting from the beginning on each repetition.
    Repeated,
    /// Tile the gradient/effect, flipping direction on every other repetition.
    Mirror,
    /// Treat everything beyond the boundary as transparent.
    Decal,
}

24/// Controls blur behavior outside source bounds.
25///
26/// This mirrors Compose's `BlurredEdgeTreatment`:
27/// - bounded treatment (`shape != None`) clips blur output and uses `TileMode::Clamp`
28/// - unbounded treatment (`shape == None`) does not clip and uses `TileMode::Decal`
29#[derive(Clone, Copy, Debug, PartialEq)]
30pub struct BlurredEdgeTreatment {
31    shape: Option<LayerShape>,
32}
33
34impl BlurredEdgeTreatment {
35    /// Bounded treatment that clips to a rectangle.
36    pub const RECTANGLE: Self = Self {
37        shape: Some(LayerShape::Rectangle),
38    };
39
40    /// Unbounded treatment that does not clip blurred output.
41    pub const UNBOUNDED: Self = Self { shape: None };
42
43    /// Bounded treatment with a specific clip shape.
44    pub const fn with_shape(shape: LayerShape) -> Self {
45        Self { shape: Some(shape) }
46    }
47
48    pub fn shape(self) -> Option<LayerShape> {
49        self.shape
50    }
51
52    pub fn clip(self) -> bool {
53        self.shape.is_some()
54    }
55
56    pub fn tile_mode(self) -> TileMode {
57        if self.clip() {
58            TileMode::Clamp
59        } else {
60            TileMode::Decal
61        }
62    }
63}
64
65impl Default for BlurredEdgeTreatment {
66    fn default() -> Self {
67        Self::RECTANGLE
68    }
69}
70
/// A custom WGSL shader effect, analogous to Android's `RuntimeShader`.
///
/// The shader source must be a complete WGSL module that declares:
/// ```wgsl
/// @group(0) @binding(0) var input_texture: texture_2d<f32>;
/// @group(0) @binding(1) var input_sampler: sampler;
/// @group(1) @binding(0) var<uniform> u: array<vec4<f32>, 64>;
/// ```
///
/// Float uniforms are packed linearly into the `u` array. Access an individual
/// float in WGSL as `u[index / 4][index % 4]`. Access a vector by swizzling a
/// row: for example, use `u[index / 4].xy` for a vec2 that starts at a multiple
/// of 4, or `.zw` when `index % 4 == 2`.
///
/// No alignment is enforced by the setters. User uniforms may use indices
/// `0..248`; slots `248..256` are reserved for renderer metadata.
///
/// RuntimeShader pipelines operate on premultiplied-alpha textures. Custom
/// shaders should preserve premultiplied output semantics.
#[derive(Clone, Debug)]
pub struct RuntimeShader {
    /// WGSL module source. It is shared so that clones do not copy the text.
    source: Arc<str>,
    /// User uniforms, grown and zero-filled on demand. The length never exceeds
    /// `MAX_USER_UNIFORMS`.
    uniforms: Vec<f32>,
}

impl RuntimeShader {
    /// Total uniform storage size in floats (64 vec4s = 256 floats).
    ///
    /// The final slots are reserved for renderer-managed data.
    pub const MAX_UNIFORMS: usize = 256;
    /// First renderer-reserved uniform slot.
    pub const RESERVED_UNIFORM_START: usize = 248;
    /// Maximum user-addressable uniform count.
    pub const MAX_USER_UNIFORMS: usize = Self::RESERVED_UNIFORM_START;

    /// Create a new RuntimeShader from WGSL source code, with no uniforms set.
    pub fn new(wgsl_source: &str) -> Self {
        Self {
            source: Arc::from(wgsl_source),
            uniforms: Vec::new(),
        }
    }

    /// Set a single float uniform at `index`.
    ///
    /// Unset slots below `index` are zero-filled.
    ///
    /// # Panics
    ///
    /// Panics if `index >= MAX_USER_UNIFORMS`, i.e. if the write would touch a
    /// renderer-reserved slot.
    pub fn set_float(&mut self, index: usize, value: f32) {
        self.write_uniforms(index, &[value]);
    }

    /// Set a vec2 uniform at `index`, consuming indices `[index, index + 1]`.
    ///
    /// # Panics
    ///
    /// Panics if `index + 1 >= MAX_USER_UNIFORMS`.
    pub fn set_float2(&mut self, index: usize, x: f32, y: f32) {
        self.write_uniforms(index, &[x, y]);
    }

    /// Set a vec4 uniform at `index`, consuming indices `[index, index + 4)`.
    ///
    /// # Panics
    ///
    /// Panics if `index + 3 >= MAX_USER_UNIFORMS`.
    pub fn set_float4(&mut self, index: usize, x: f32, y: f32, z: f32, w: f32) {
        self.write_uniforms(index, &[x, y, z, w]);
    }

    /// Get the WGSL source code.
    pub fn source(&self) -> &str {
        &self.source
    }

    /// Get the uniform data as a float slice (for uploading to the GPU).
    ///
    /// The slice is only as long as the highest slot written so far.
    pub fn uniforms(&self) -> &[f32] {
        &self.uniforms
    }

    /// Get the uniform data zero-padded to the full 256-float array (for the
    /// GPU uniform buffer).
    pub fn uniforms_padded(&self) -> [f32; Self::MAX_UNIFORMS] {
        let mut padded = [0.0f32; Self::MAX_UNIFORMS];
        let len = self.uniforms.len().min(Self::MAX_UNIFORMS);
        padded[..len].copy_from_slice(&self.uniforms[..len]);
        padded
    }

    /// Compute a hash of the shader source for pipeline caching.
    ///
    /// Uniform values are not hashed, so shaders that differ only in their
    /// uniforms share a hash.
    ///
    /// The hash uses std's `DefaultHasher`, whose algorithm is unspecified and
    /// may change between Rust releases. It is suitable for in-process caches,
    /// not for persisted keys.
    pub fn source_hash(&self) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        self.source.hash(&mut hasher);
        hasher.finish()
    }

    /// Validate and copy `values` into the slots starting at `index`.
    ///
    /// The uniform vector is grown (zero-filled) as needed. `values` must be
    /// non-empty; callers pass 1, 2 or 4 floats.
    fn write_uniforms(&mut self, index: usize, values: &[f32]) {
        debug_assert!(!values.is_empty());
        // Highest slot touched. Saturating addition means an absurd `index`
        // still reaches the assert below with a readable message. Plain
        // addition would overflow: a debug panic, or a wrapped bound that
        // bypasses the check in release builds.
        let last = index.saturating_add(values.len() - 1);
        assert!(
            last < Self::MAX_USER_UNIFORMS,
            "uniform index {} exceeds user maximum {}; slots {}..{} are reserved for renderer data",
            last,
            Self::MAX_USER_UNIFORMS - 1,
            Self::RESERVED_UNIFORM_START,
            Self::MAX_UNIFORMS
        );
        let end = last + 1;
        if self.uniforms.len() < end {
            self.uniforms.resize(end, 0.0);
        }
        self.uniforms[index..end].copy_from_slice(values);
    }
}

impl PartialEq for RuntimeShader {
    /// Two shaders are equal when their uniform values and WGSL source match.
    fn eq(&self, other: &Self) -> bool {
        // Compare the small uniform vector first. Then skip the potentially
        // long source comparison when both sides share one `Arc` allocation,
        // which is the case for clones.
        self.uniforms == other.uniforms
            && (Arc::ptr_eq(&self.source, &other.source) || *self.source == *other.source)
    }
}

179/// A render effect applied to a graphics layer's rendered content.
180///
181/// Matches Jetpack Compose's `RenderEffect` sealed class hierarchy,
182/// extended with `Shader` for custom WGSL effects.
183#[derive(Clone, Debug, PartialEq)]
184pub enum RenderEffect {
185    /// Gaussian blur applied to the layer's rendered content.
186    Blur {
187        radius_x: f32,
188        radius_y: f32,
189        edge_treatment: TileMode,
190    },
191    /// Offset the rendered content by a fixed amount.
192    Offset { offset_x: f32, offset_y: f32 },
193    /// Apply a custom WGSL shader effect.
194    Shader { shader: RuntimeShader },
195    /// Chain two effects: apply `first`, then apply `second` to the result.
196    Chain {
197        first: Box<RenderEffect>,
198        second: Box<RenderEffect>,
199    },
200}
201
202impl RenderEffect {
203    /// Create a blur effect with equal radius in both directions.
204    pub fn blur(radius: f32) -> Self {
205        Self::blur_with_edge_treatment(radius, TileMode::default())
206    }
207
208    /// Create a blur effect with equal radius in both directions and explicit
209    /// edge treatment semantics.
210    pub fn blur_with_edge_treatment(radius: f32, edge_treatment: TileMode) -> Self {
211        Self::Blur {
212            radius_x: radius,
213            radius_y: radius,
214            edge_treatment,
215        }
216    }
217
218    /// Create a blur effect with separate horizontal and vertical radii.
219    pub fn blur_xy(radius_x: f32, radius_y: f32, edge_treatment: TileMode) -> Self {
220        Self::Blur {
221            radius_x,
222            radius_y,
223            edge_treatment,
224        }
225    }
226
227    /// Create an offset effect.
228    pub fn offset(offset_x: f32, offset_y: f32) -> Self {
229        Self::Offset { offset_x, offset_y }
230    }
231
232    /// Create a custom shader effect from a RuntimeShader.
233    pub fn runtime_shader(shader: RuntimeShader) -> Self {
234        Self::Shader { shader }
235    }
236
237    /// Chain this effect with another: `self` is applied first, then `other`.
238    pub fn then(self, other: RenderEffect) -> Self {
239        Self::Chain {
240            first: Box::new(self),
241            second: Box::new(other),
242        }
243    }
244
245    /// Returns `true` if this effect or any chained sub-effect is a
246    /// `RuntimeShader`. Animated shaders produce different output every frame,
247    /// so layer surface caching is counterproductive for them.
248    pub fn contains_runtime_shader(&self) -> bool {
249        match self {
250            RenderEffect::Shader { .. } => true,
251            RenderEffect::Chain { first, second } => {
252                first.contains_runtime_shader() || second.contains_runtime_shader()
253            }
254            _ => false,
255        }
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::RoundedCornerShape;

    #[test]
    fn runtime_shader_set_uniforms() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(0, 1.0);
        shader.set_float2(2, 3.0, 4.0);
        shader.set_float4(4, 5.0, 6.0, 7.0, 8.0);

        assert_eq!(shader.uniforms()[0], 1.0);
        assert_eq!(shader.uniforms()[1], 0.0); // gap
        assert_eq!(shader.uniforms()[2], 3.0);
        assert_eq!(shader.uniforms()[3], 4.0);
        assert_eq!(shader.uniforms()[4], 5.0);
        assert_eq!(shader.uniforms()[5], 6.0);
        assert_eq!(shader.uniforms()[6], 7.0);
        assert_eq!(shader.uniforms()[7], 8.0);
    }

    #[test]
    fn runtime_shader_padded() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(0, 42.0);
        let padded = shader.uniforms_padded();
        assert_eq!(padded[0], 42.0);
        assert_eq!(padded[1], 0.0);
        assert_eq!(padded[255], 0.0);
    }

    #[test]
    #[should_panic(expected = "uniform index 256 exceeds user maximum 247")]
    fn runtime_shader_overflow() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(256, 1.0);
    }

    #[test]
    #[should_panic(expected = "reserved for renderer data")]
    fn runtime_shader_rejects_reserved_uniform_slots() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(RuntimeShader::RESERVED_UNIFORM_START, 1.0);
    }

    #[test]
    fn runtime_shader_accepts_last_user_slot() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(RuntimeShader::MAX_USER_UNIFORMS - 1, 2.0);
        assert_eq!(shader.uniforms().len(), RuntimeShader::MAX_USER_UNIFORMS);
        let padded = shader.uniforms_padded();
        assert_eq!(padded[RuntimeShader::MAX_USER_UNIFORMS - 1], 2.0);
        assert_eq!(padded[RuntimeShader::RESERVED_UNIFORM_START], 0.0);
    }

    #[test]
    #[should_panic(expected = "uniform index 248 exceeds user maximum 247")]
    fn runtime_shader_rejects_vec2_straddling_reserved_slots() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float2(RuntimeShader::MAX_USER_UNIFORMS - 1, 1.0, 2.0);
    }

    #[test]
    fn render_effect_chaining() {
        let blur = RenderEffect::blur(10.0);
        let offset = RenderEffect::offset(5.0, 5.0);
        let chained = blur.then(offset);
        match chained {
            RenderEffect::Chain { first, second } => {
                assert!(matches!(*first, RenderEffect::Blur { .. }));
                assert!(matches!(*second, RenderEffect::Offset { .. }));
            }
            _ => panic!("expected Chain"),
        }
    }

    #[test]
    fn contains_runtime_shader_checks_every_variant_and_chain_position() {
        let shader = || RenderEffect::runtime_shader(RuntimeShader::new("// s"));

        assert!(shader().contains_runtime_shader());
        assert!(!RenderEffect::blur(4.0).contains_runtime_shader());
        assert!(!RenderEffect::offset(1.0, 2.0).contains_runtime_shader());

        let no_shader = RenderEffect::blur(2.0).then(RenderEffect::offset(1.0, 1.0));
        assert!(!no_shader.contains_runtime_shader());

        let shader_last = RenderEffect::blur(2.0)
            .then(RenderEffect::offset(1.0, 1.0))
            .then(shader());
        assert!(shader_last.contains_runtime_shader());

        let shader_first = shader().then(RenderEffect::blur(1.0));
        assert!(shader_first.contains_runtime_shader());
    }

    #[test]
    fn blur_convenience() {
        let effect = RenderEffect::blur(15.0);
        match effect {
            RenderEffect::Blur {
                radius_x,
                radius_y,
                edge_treatment,
            } => {
                assert_eq!(radius_x, 15.0);
                assert_eq!(radius_y, 15.0);
                assert_eq!(edge_treatment, TileMode::Clamp);
            }
            _ => panic!("expected Blur"),
        }
    }

    #[test]
    fn blur_with_edge_treatment_uses_explicit_mode() {
        let effect = RenderEffect::blur_with_edge_treatment(6.0, TileMode::Decal);
        match effect {
            RenderEffect::Blur {
                radius_x,
                radius_y,
                edge_treatment,
            } => {
                assert_eq!(radius_x, 6.0);
                assert_eq!(radius_y, 6.0);
                assert_eq!(edge_treatment, TileMode::Decal);
            }
            _ => panic!("expected Blur"),
        }
    }

    #[test]
    fn source_hash_consistent() {
        let s1 = RuntimeShader::new("fn main() {}");
        let s2 = RuntimeShader::new("fn main() {}");
        assert_eq!(s1.source_hash(), s2.source_hash());
    }

    #[test]
    fn blur_xy_preserves_tile_mode() {
        let effect = RenderEffect::blur_xy(3.0, 7.0, TileMode::Clamp);
        match effect {
            RenderEffect::Blur {
                radius_x,
                radius_y,
                edge_treatment,
            } => {
                assert_eq!(radius_x, 3.0);
                assert_eq!(radius_y, 7.0);
                assert_eq!(edge_treatment, TileMode::Clamp);
            }
            _ => panic!("expected Blur"),
        }
    }

    #[test]
    fn offset_constructor_sets_components() {
        let effect = RenderEffect::offset(11.0, -5.0);
        match effect {
            RenderEffect::Offset { offset_x, offset_y } => {
                assert_eq!(offset_x, 11.0);
                assert_eq!(offset_y, -5.0);
            }
            _ => panic!("expected Offset"),
        }
    }

    #[test]
    fn runtime_shader_equality_is_source_value_based() {
        let mut s1 = RuntimeShader::new("fn main() {}");
        let mut s2 = RuntimeShader::new("fn main() {}");
        s1.set_float(0, 1.0);
        s2.set_float(0, 1.0);
        assert_eq!(s1, s2);
    }

    #[test]
    fn runtime_shader_clone_is_equal_and_uniforms_distinguish() {
        let mut original = RuntimeShader::new("fn main() {}");
        original.set_float(0, 1.0);
        let copy = original.clone();
        assert_eq!(original, copy);

        let mut changed = copy.clone();
        changed.set_float(0, 2.0);
        assert_ne!(original, changed);

        assert_ne!(RuntimeShader::new("a"), RuntimeShader::new("b"));
    }

    #[test]
    fn blurred_edge_treatment_defaults_to_bounded_rectangle() {
        let treatment = BlurredEdgeTreatment::default();
        assert_eq!(treatment.shape(), Some(LayerShape::Rectangle));
        assert!(treatment.clip());
        assert_eq!(treatment.tile_mode(), TileMode::Clamp);
    }

    #[test]
    fn blurred_edge_treatment_unbounded_uses_decal_and_no_clip() {
        let treatment = BlurredEdgeTreatment::UNBOUNDED;
        assert_eq!(treatment.shape(), None);
        assert!(!treatment.clip());
        assert_eq!(treatment.tile_mode(), TileMode::Decal);
    }

    #[test]
    fn blurred_edge_treatment_with_shape_uses_bounded_mode() {
        let rounded = LayerShape::Rounded(RoundedCornerShape::uniform(8.0));
        let treatment = BlurredEdgeTreatment::with_shape(rounded);
        assert_eq!(treatment.shape(), Some(rounded));
        assert!(treatment.clip());
        assert_eq!(treatment.tile_mode(), TileMode::Clamp);
    }
}