Skip to main content

cranpose_ui_graphics/
render_effect.rs

1//! Render effects that can be applied to graphics layers.
2//!
3//! Matches the Jetpack Compose `RenderEffect` API with extensions for custom
4//! WGSL shaders (`RuntimeShader`).
5
6use crate::LayerShape;
7use std::hash::{Hash, Hasher};
8use std::sync::Arc;
9
/// How a blur samples pixels that fall outside the boundary of the blurred region.
///
/// Defaults to [`TileMode::Clamp`].
#[derive(Clone, Copy, Debug, Default)]
#[derive(PartialEq, Eq, Hash)]
pub enum TileMode {
    /// Reuse the color of the nearest edge pixel.
    #[default]
    Clamp,
    /// Tile the gradient/effect, restarting it from the beginning each time.
    Repeated,
    /// Tile the gradient/effect, flipping every other copy.
    Mirror,
    /// Sample everything beyond the boundary as fully transparent.
    Decal,
}
23
24/// Controls blur behavior outside source bounds.
25///
26/// This mirrors Compose's `BlurredEdgeTreatment`:
27/// - bounded treatment (`shape != None`) clips blur output and uses `TileMode::Clamp`
28/// - unbounded treatment (`shape == None`) does not clip and uses `TileMode::Decal`
29#[derive(Clone, Copy, Debug, PartialEq)]
30pub struct BlurredEdgeTreatment {
31    shape: Option<LayerShape>,
32}
33
34impl BlurredEdgeTreatment {
35    /// Bounded treatment that clips to a rectangle.
36    pub const RECTANGLE: Self = Self {
37        shape: Some(LayerShape::Rectangle),
38    };
39
40    /// Unbounded treatment that does not clip blurred output.
41    pub const UNBOUNDED: Self = Self { shape: None };
42
43    /// Bounded treatment with a specific clip shape.
44    pub const fn with_shape(shape: LayerShape) -> Self {
45        Self { shape: Some(shape) }
46    }
47
48    pub fn shape(self) -> Option<LayerShape> {
49        self.shape
50    }
51
52    pub fn clip(self) -> bool {
53        self.shape.is_some()
54    }
55
56    pub fn tile_mode(self) -> TileMode {
57        if self.clip() {
58            TileMode::Clamp
59        } else {
60            TileMode::Decal
61        }
62    }
63}
64
65impl Default for BlurredEdgeTreatment {
66    fn default() -> Self {
67        Self::RECTANGLE
68    }
69}
70
/// A custom WGSL shader effect, analogous to Android's `RuntimeShader`.
///
/// The shader source must be a complete WGSL module that declares:
/// ```wgsl
/// @group(0) @binding(0) var input_texture: texture_2d<f32>;
/// @group(0) @binding(1) var input_sampler: sampler;
/// @group(1) @binding(0) var<uniform> u: array<vec4<f32>, 64>;
/// ```
///
/// Float uniforms are packed linearly into the `u` array. Access them in WGSL
/// as `u[index / 4][index % 4]` for individual floats, or `u[index / 4].xy`
/// for a vec2 stored at an index that is a multiple of 4, etc. User uniforms
/// may use indices `0..248`; slots `248..256` are reserved for renderer metadata.
///
/// RuntimeShader pipelines operate on premultiplied-alpha textures. Custom
/// shaders should preserve premultiplied output semantics.
#[derive(Clone, Debug)]
pub struct RuntimeShader {
    /// WGSL module source; shared so clones do not copy the text.
    source: Arc<str>,
    /// User uniform values, grown on demand and zero-filled. Never longer than
    /// `MAX_USER_UNIFORMS` because every write goes through `user_slots_mut`.
    uniforms: Vec<f32>,
}

impl RuntimeShader {
    /// Total uniform storage size in floats (64 vec4s = 256 floats).
    ///
    /// The final slots are reserved for renderer-managed data.
    pub const MAX_UNIFORMS: usize = 256;
    /// First renderer-reserved uniform slot.
    pub const RESERVED_UNIFORM_START: usize = 248;
    /// Maximum user-addressable uniform count.
    pub const MAX_USER_UNIFORMS: usize = Self::RESERVED_UNIFORM_START;

    /// Create a new RuntimeShader from WGSL source code, with no uniforms set.
    pub fn new(wgsl_source: &str) -> Self {
        Self {
            source: Arc::from(wgsl_source),
            uniforms: Vec::new(),
        }
    }

    /// Set a single float uniform at the given index.
    ///
    /// Unset slots below `index` are zero-filled.
    ///
    /// # Panics
    /// Panics if `index >= MAX_USER_UNIFORMS` (out of range or renderer-reserved).
    pub fn set_float(&mut self, index: usize, value: f32) {
        self.user_slots_mut(index, 1)[0] = value;
    }

    /// Set a vec2 uniform at the given index (consumes indices `index..index + 2`).
    ///
    /// # Panics
    /// Panics if any written slot is `>= MAX_USER_UNIFORMS`.
    pub fn set_float2(&mut self, index: usize, x: f32, y: f32) {
        self.user_slots_mut(index, 2).copy_from_slice(&[x, y]);
    }

    /// Set a vec4 uniform at the given index (consumes indices `index..index + 4`).
    ///
    /// # Panics
    /// Panics if any written slot is `>= MAX_USER_UNIFORMS`.
    pub fn set_float4(&mut self, index: usize, x: f32, y: f32, z: f32, w: f32) {
        self.user_slots_mut(index, 4).copy_from_slice(&[x, y, z, w]);
    }

    /// Get the WGSL source code.
    pub fn source(&self) -> &str {
        &self.source
    }

    /// Get the uniform data as a float slice (for uploading to GPU).
    ///
    /// The slice is only as long as the highest slot written so far.
    pub fn uniforms(&self) -> &[f32] {
        &self.uniforms
    }

    /// Get the uniform data padded to the full 256-float array (for the GPU
    /// uniform buffer). Slots that were never set, including the renderer-reserved
    /// tail, are `0.0`.
    pub fn uniforms_padded(&self) -> [f32; Self::MAX_UNIFORMS] {
        let mut padded = [0.0f32; Self::MAX_UNIFORMS];
        let len = self.uniforms.len().min(Self::MAX_UNIFORMS);
        padded[..len].copy_from_slice(&self.uniforms[..len]);
        padded
    }

    /// Compute a hash of the shader source for pipeline caching.
    ///
    /// Uses `DefaultHasher`, whose algorithm is not guaranteed to be stable across
    /// Rust releases, so do not persist this value across builds.
    pub fn source_hash(&self) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        self.source.hash(&mut hasher);
        hasher.finish()
    }

    /// Returns the `count` user uniform slots starting at `index`, growing the
    /// backing storage with zeros as needed.
    ///
    /// # Panics
    /// Panics if any slot in `index..index + count` is `>= MAX_USER_UNIFORMS`,
    /// including when `index + count` would overflow `usize`.
    fn user_slots_mut(&mut self, index: usize, count: usize) -> &mut [f32] {
        debug_assert!(count > 0, "uniform writes cover at least one slot");
        // Saturate rather than wrap: a huge `index` must fail the bound check
        // below instead of overflowing (debug) or wrapping past it (release).
        let last = index.saturating_add(count - 1);
        assert!(
            last < Self::MAX_USER_UNIFORMS,
            "uniform index {} exceeds user maximum {}; slots {}..{} are reserved for renderer data",
            last,
            Self::MAX_USER_UNIFORMS - 1,
            Self::RESERVED_UNIFORM_START,
            Self::MAX_UNIFORMS
        );
        // Cannot overflow: `last < MAX_USER_UNIFORMS`.
        let end = last + 1;
        if self.uniforms.len() < end {
            self.uniforms.resize(end, 0.0);
        }
        &mut self.uniforms[index..end]
    }
}

/// Two shaders are equal when their sources and uniform vectors are equal.
///
/// Uniforms are compared element-wise with `f32` equality. As a result, a shader
/// holding a NaN uniform never equals itself. Shaders whose uniform vectors
/// differ only by trailing zeros also compare unequal, even though
/// `uniforms_padded` returns the same array for both.
impl PartialEq for RuntimeShader {
    fn eq(&self, other: &Self) -> bool {
        self.source.as_ref() == other.source.as_ref() && self.uniforms == other.uniforms
    }
}
178
179/// A render effect applied to a graphics layer's rendered content.
180///
181/// Matches Jetpack Compose's `RenderEffect` sealed class hierarchy,
182/// extended with `Shader` for custom WGSL effects.
183#[derive(Clone, Debug, PartialEq)]
184pub enum RenderEffect {
185    /// Gaussian blur applied to the layer's rendered content.
186    Blur {
187        radius_x: f32,
188        radius_y: f32,
189        edge_treatment: TileMode,
190    },
191    /// Offset the rendered content by a fixed amount.
192    Offset { offset_x: f32, offset_y: f32 },
193    /// Apply a custom WGSL shader effect.
194    Shader { shader: RuntimeShader },
195    /// Chain two effects: apply `first`, then apply `second` to the result.
196    Chain {
197        first: Box<RenderEffect>,
198        second: Box<RenderEffect>,
199    },
200}
201
202impl RenderEffect {
203    /// Create a blur effect with equal radius in both directions.
204    pub fn blur(radius: f32) -> Self {
205        Self::blur_with_edge_treatment(radius, TileMode::default())
206    }
207
208    /// Create a blur effect with equal radius in both directions and explicit
209    /// edge treatment semantics.
210    pub fn blur_with_edge_treatment(radius: f32, edge_treatment: TileMode) -> Self {
211        Self::Blur {
212            radius_x: radius,
213            radius_y: radius,
214            edge_treatment,
215        }
216    }
217
218    /// Create a blur effect with separate horizontal and vertical radii.
219    pub fn blur_xy(radius_x: f32, radius_y: f32, edge_treatment: TileMode) -> Self {
220        Self::Blur {
221            radius_x,
222            radius_y,
223            edge_treatment,
224        }
225    }
226
227    /// Create an offset effect.
228    pub fn offset(offset_x: f32, offset_y: f32) -> Self {
229        Self::Offset { offset_x, offset_y }
230    }
231
232    /// Create a custom shader effect from a RuntimeShader.
233    pub fn runtime_shader(shader: RuntimeShader) -> Self {
234        Self::Shader { shader }
235    }
236
237    /// Chain this effect with another: `self` is applied first, then `other`.
238    pub fn then(self, other: RenderEffect) -> Self {
239        Self::Chain {
240            first: Box::new(self),
241            second: Box::new(other),
242        }
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::RoundedCornerShape;

    #[test]
    fn runtime_shader_set_uniforms() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(0, 1.0);
        shader.set_float2(2, 3.0, 4.0);
        shader.set_float4(4, 5.0, 6.0, 7.0, 8.0);

        assert_eq!(shader.uniforms()[0], 1.0);
        assert_eq!(shader.uniforms()[1], 0.0); // gap
        assert_eq!(shader.uniforms()[2], 3.0);
        assert_eq!(shader.uniforms()[3], 4.0);
        assert_eq!(shader.uniforms()[4], 5.0);
        assert_eq!(shader.uniforms()[5], 6.0);
        assert_eq!(shader.uniforms()[6], 7.0);
        assert_eq!(shader.uniforms()[7], 8.0);
    }

    #[test]
    fn runtime_shader_setters_overwrite_without_growing() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(0, 1.0);
        shader.set_float(0, 2.0);
        assert_eq!(shader.uniforms(), &[2.0f32][..]);
    }

    #[test]
    fn runtime_shader_padded() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(0, 42.0);
        let padded = shader.uniforms_padded();
        assert_eq!(padded[0], 42.0);
        assert_eq!(padded[1], 0.0);
        assert_eq!(padded[255], 0.0);
    }

    #[test]
    fn runtime_shader_accepts_last_user_slot() {
        let mut shader = RuntimeShader::new("// test");
        let last = RuntimeShader::MAX_USER_UNIFORMS - 1;
        shader.set_float(last, 3.0);
        assert_eq!(shader.uniforms().len(), RuntimeShader::MAX_USER_UNIFORMS);
        let padded = shader.uniforms_padded();
        assert_eq!(padded[last], 3.0);
        assert_eq!(padded[RuntimeShader::RESERVED_UNIFORM_START], 0.0);
    }

    #[test]
    #[should_panic(expected = "uniform index 256 exceeds user maximum 247")]
    fn runtime_shader_overflow() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(256, 1.0);
    }

    #[test]
    #[should_panic(expected = "reserved for renderer data")]
    fn runtime_shader_rejects_reserved_uniform_slots() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(RuntimeShader::RESERVED_UNIFORM_START, 1.0);
    }

    #[test]
    #[should_panic(expected = "reserved for renderer data")]
    fn runtime_shader_rejects_vector_write_straddling_reserved_slots() {
        // Starts in user space but its last two components land in reserved slots.
        let mut shader = RuntimeShader::new("// test");
        shader.set_float4(RuntimeShader::RESERVED_UNIFORM_START - 2, 1.0, 2.0, 3.0, 4.0);
    }

    #[test]
    fn render_effect_chaining() {
        let blur = RenderEffect::blur(10.0);
        let offset = RenderEffect::offset(5.0, 5.0);
        let chained = blur.then(offset);
        match chained {
            RenderEffect::Chain { first, second } => {
                assert!(matches!(*first, RenderEffect::Blur { .. }));
                assert!(matches!(*second, RenderEffect::Offset { .. }));
            }
            _ => panic!("expected Chain"),
        }
    }

    #[test]
    fn blur_convenience() {
        let effect = RenderEffect::blur(15.0);
        match effect {
            RenderEffect::Blur {
                radius_x,
                radius_y,
                edge_treatment,
            } => {
                assert_eq!(radius_x, 15.0);
                assert_eq!(radius_y, 15.0);
                assert_eq!(edge_treatment, TileMode::Clamp);
            }
            _ => panic!("expected Blur"),
        }
    }

    #[test]
    fn blur_with_edge_treatment_uses_explicit_mode() {
        let effect = RenderEffect::blur_with_edge_treatment(6.0, TileMode::Decal);
        match effect {
            RenderEffect::Blur {
                radius_x,
                radius_y,
                edge_treatment,
            } => {
                assert_eq!(radius_x, 6.0);
                assert_eq!(radius_y, 6.0);
                assert_eq!(edge_treatment, TileMode::Decal);
            }
            _ => panic!("expected Blur"),
        }
    }

    #[test]
    fn source_hash_consistent() {
        let s1 = RuntimeShader::new("fn main() {}");
        let s2 = RuntimeShader::new("fn main() {}");
        assert_eq!(s1.source_hash(), s2.source_hash());
    }

    #[test]
    fn source_hash_differs_for_different_sources() {
        let s1 = RuntimeShader::new("fn main() {}");
        let s2 = RuntimeShader::new("fn other() {}");
        assert_ne!(s1.source_hash(), s2.source_hash());
    }

    #[test]
    fn blur_xy_preserves_tile_mode() {
        let effect = RenderEffect::blur_xy(3.0, 7.0, TileMode::Clamp);
        match effect {
            RenderEffect::Blur {
                radius_x,
                radius_y,
                edge_treatment,
            } => {
                assert_eq!(radius_x, 3.0);
                assert_eq!(radius_y, 7.0);
                assert_eq!(edge_treatment, TileMode::Clamp);
            }
            _ => panic!("expected Blur"),
        }
    }

    #[test]
    fn offset_constructor_sets_components() {
        let effect = RenderEffect::offset(11.0, -5.0);
        match effect {
            RenderEffect::Offset { offset_x, offset_y } => {
                assert_eq!(offset_x, 11.0);
                assert_eq!(offset_y, -5.0);
            }
            _ => panic!("expected Offset"),
        }
    }

    #[test]
    fn runtime_shader_equality_is_source_value_based() {
        let mut s1 = RuntimeShader::new("fn main() {}");
        let mut s2 = RuntimeShader::new("fn main() {}");
        s1.set_float(0, 1.0);
        s2.set_float(0, 1.0);
        assert_eq!(s1, s2);

        // Same source, different uniform value.
        let mut s3 = s1.clone();
        s3.set_float(0, 2.0);
        assert_ne!(s1, s3);

        // Same uniforms, different source.
        let mut s4 = RuntimeShader::new("fn other() {}");
        s4.set_float(0, 1.0);
        assert_ne!(s1, s4);
    }

    #[test]
    fn blurred_edge_treatment_defaults_to_bounded_rectangle() {
        let treatment = BlurredEdgeTreatment::default();
        assert_eq!(treatment.shape(), Some(LayerShape::Rectangle));
        assert!(treatment.clip());
        assert_eq!(treatment.tile_mode(), TileMode::Clamp);
    }

    #[test]
    fn blurred_edge_treatment_unbounded_uses_decal_and_no_clip() {
        let treatment = BlurredEdgeTreatment::UNBOUNDED;
        assert_eq!(treatment.shape(), None);
        assert!(!treatment.clip());
        assert_eq!(treatment.tile_mode(), TileMode::Decal);
    }

    #[test]
    fn blurred_edge_treatment_with_shape_uses_bounded_mode() {
        let rounded = LayerShape::Rounded(RoundedCornerShape::uniform(8.0));
        let treatment = BlurredEdgeTreatment::with_shape(rounded);
        assert_eq!(treatment.shape(), Some(rounded));
        assert!(treatment.clip());
        assert_eq!(treatment.tile_mode(), TileMode::Clamp);
    }
}