Skip to main content

cranpose_ui_graphics/
render_effect.rs

1//! Render effects that can be applied to graphics layers.
2//!
3//! Matches the Jetpack Compose `RenderEffect` API with extensions for custom
4//! WGSL shaders (`RuntimeShader`).
5
6use crate::LayerShape;
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex, OnceLock, Weak};
9
/// How a blur samples pixels that lie beyond the boundary of the blurred region.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TileMode {
    /// Extend the color of the nearest edge pixel outward.
    Clamp,
    /// Tile the content, restarting from the opposite edge on every repetition.
    Repeated,
    /// Tile the content, flipping direction on every other repetition.
    Mirror,
    /// Sample fully transparent pixels outside the boundary.
    Decal,
}

impl Default for TileMode {
    /// Returns [`TileMode::Clamp`].
    fn default() -> Self {
        TileMode::Clamp
    }
}
23
24/// Controls blur behavior outside source bounds.
25///
26/// This mirrors Compose's `BlurredEdgeTreatment`:
27/// - bounded treatment (`shape != None`) clips blur output and uses `TileMode::Clamp`
28/// - unbounded treatment (`shape == None`) does not clip and uses `TileMode::Decal`
29#[derive(Clone, Copy, Debug, PartialEq)]
30pub struct BlurredEdgeTreatment {
31    shape: Option<LayerShape>,
32}
33
34impl BlurredEdgeTreatment {
35    /// Bounded treatment that clips to a rectangle.
36    pub const RECTANGLE: Self = Self {
37        shape: Some(LayerShape::Rectangle),
38    };
39
40    /// Unbounded treatment that does not clip blurred output.
41    pub const UNBOUNDED: Self = Self { shape: None };
42
43    /// Bounded treatment with a specific clip shape.
44    pub const fn with_shape(shape: LayerShape) -> Self {
45        Self { shape: Some(shape) }
46    }
47
48    pub fn shape(self) -> Option<LayerShape> {
49        self.shape
50    }
51
52    pub fn clip(self) -> bool {
53        self.shape.is_some()
54    }
55
56    pub fn tile_mode(self) -> TileMode {
57        if self.clip() {
58            TileMode::Clamp
59        } else {
60            TileMode::Decal
61        }
62    }
63}
64
65impl Default for BlurredEdgeTreatment {
66    fn default() -> Self {
67        Self::RECTANGLE
68    }
69}
70
71/// A custom WGSL shader effect, analogous to Android's `RuntimeShader`.
72///
73/// The shader source must be a complete WGSL module that declares:
74/// ```wgsl
75/// @group(0) @binding(0) var input_texture: texture_2d<f32>;
76/// @group(0) @binding(1) var input_sampler: sampler;
77/// @group(1) @binding(0) var<uniform> u: array<vec4<f32>, 64>;
78/// ```
79///
80/// Float uniforms are packed linearly into the `u` array. Access them in WGSL
81/// as `u[index / 4][index % 4]` for individual floats, or `u[index / 4].xy`
82/// for vec2, etc. User uniforms may use indices `0..248`; slots `248..256`
83/// are reserved for renderer metadata.
84///
85/// RuntimeShader pipelines operate on premultiplied-alpha textures. Custom
86/// shaders should preserve premultiplied output semantics.
87#[derive(Clone, Debug)]
88pub struct RuntimeShader {
89    source: Arc<str>,
90    source_hash: u64,
91    uniforms: Vec<f32>,
92}
93
94impl RuntimeShader {
95    /// Total uniform storage size in floats (64 vec4s = 256 floats).
96    ///
97    /// The final slots are reserved for renderer-managed data.
98    pub const MAX_UNIFORMS: usize = 256;
99    /// First renderer-reserved uniform slot.
100    pub const RESERVED_UNIFORM_START: usize = 248;
101    /// Maximum user-addressable uniform count.
102    pub const MAX_USER_UNIFORMS: usize = Self::RESERVED_UNIFORM_START;
103    const INITIAL_UNIFORM_CAPACITY: usize = 16;
104
105    /// Create a new RuntimeShader from WGSL source code.
106    pub fn new(wgsl_source: &str) -> Self {
107        let source_hash = hash_shader_source(wgsl_source);
108        Self {
109            source: intern_shader_source(wgsl_source, source_hash),
110            source_hash,
111            uniforms: Vec::with_capacity(Self::INITIAL_UNIFORM_CAPACITY),
112        }
113    }
114
115    /// Create a RuntimeShader from shared WGSL source code.
116    ///
117    /// This avoids repeatedly copying large shader modules for animated effects
118    /// that rebuild only their uniform payload every frame.
119    pub fn from_shared_source(source: Arc<str>) -> Self {
120        let source_hash = hash_shader_source(&source);
121        Self {
122            source,
123            source_hash,
124            uniforms: Vec::with_capacity(Self::INITIAL_UNIFORM_CAPACITY),
125        }
126    }
127
128    /// Set a single float uniform at the given index.
129    pub fn set_float(&mut self, index: usize, value: f32) {
130        self.ensure_capacity(index + 1);
131        self.uniforms[index] = value;
132    }
133
134    /// Set a vec2 uniform at the given index (consumes indices `[index, index+1]`).
135    pub fn set_float2(&mut self, index: usize, x: f32, y: f32) {
136        self.ensure_capacity(index + 2);
137        self.uniforms[index] = x;
138        self.uniforms[index + 1] = y;
139    }
140
141    /// Set a vec4 uniform at the given index (consumes indices `[index..index+4]`).
142    pub fn set_float4(&mut self, index: usize, x: f32, y: f32, z: f32, w: f32) {
143        self.ensure_capacity(index + 4);
144        self.uniforms[index] = x;
145        self.uniforms[index + 1] = y;
146        self.uniforms[index + 2] = z;
147        self.uniforms[index + 3] = w;
148    }
149
150    /// Get the WGSL source code.
151    pub fn source(&self) -> &str {
152        &self.source
153    }
154
155    /// Get the uniform data as a float slice (for uploading to GPU).
156    pub fn uniforms(&self) -> &[f32] {
157        &self.uniforms
158    }
159
160    /// Get the uniform data padded to full 256-float array (for GPU uniform buffer).
161    pub fn uniforms_padded(&self) -> [f32; Self::MAX_UNIFORMS] {
162        let mut padded = [0.0f32; Self::MAX_UNIFORMS];
163        let len = self.uniforms.len().min(Self::MAX_UNIFORMS);
164        padded[..len].copy_from_slice(&self.uniforms[..len]);
165        padded
166    }
167
168    /// Compute a hash of the shader source for pipeline caching.
169    pub fn source_hash(&self) -> u64 {
170        self.source_hash
171    }
172
173    fn ensure_capacity(&mut self, min_len: usize) {
174        assert!(
175            min_len <= Self::MAX_USER_UNIFORMS,
176            "uniform index {} exceeds user maximum {}; slots {}..{} are reserved for renderer data",
177            min_len - 1,
178            Self::MAX_USER_UNIFORMS - 1,
179            Self::RESERVED_UNIFORM_START,
180            Self::MAX_UNIFORMS - 1
181        );
182        if self.uniforms.len() < min_len {
183            self.uniforms.resize(min_len, 0.0);
184        }
185    }
186}
187
188impl PartialEq for RuntimeShader {
189    fn eq(&self, other: &Self) -> bool {
190        self.source_hash == other.source_hash
191            && (Arc::ptr_eq(&self.source, &other.source)
192                || self.source.as_ref() == other.source.as_ref())
193            && self.uniforms == other.uniforms
194    }
195}
196
/// 64-bit FNV-1a hash of the shader text, used as a cheap cache/intern key.
///
/// Each byte is XOR-ed into the state, which is then multiplied by the FNV
/// prime with wrapping arithmetic.
fn hash_shader_source(source: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let mut state = FNV_OFFSET_BASIS;
    for &byte in source.as_bytes() {
        state ^= u64::from(byte);
        state = state.wrapping_mul(FNV_PRIME);
    }
    state
}
208
/// Interner state: source hash -> weak handles to live sources with that hash.
///
/// A `Vec` per hash keeps distinct sources whose hashes collide apart.
type ShaderSourceInterner = HashMap<u64, Vec<Weak<str>>>;

/// Returns a shared `Arc<str>` for `source`, reusing an existing allocation when
/// an identical string was interned earlier and is still referenced.
///
/// `source_hash` selects the bucket. Matches are always confirmed by a full
/// string comparison, so hash collisions only cost a few extra compares.
///
/// The interner holds only `Weak` handles, so it never keeps a source reachable
/// by itself. However, a `Weak` still pins the backing allocation (the string
/// bytes) until the handle is dropped. On every miss the whole map is swept of
/// dead handles and empty buckets. This keeps retained memory proportional to
/// the set of live sources rather than every source ever interned.
fn intern_shader_source(source: &str, source_hash: u64) -> Arc<str> {
    static INTERNER: OnceLock<Mutex<ShaderSourceInterner>> = OnceLock::new();

    let interner = INTERNER.get_or_init(|| Mutex::new(HashMap::new()));
    // The map only holds weak handles, so a panic mid-update cannot corrupt
    // anything callers rely on; recover from poisoning instead of propagating.
    let mut guard = interner
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    // Fast path: an identical source is still alive somewhere.
    let live_match = guard.get(&source_hash).and_then(|bucket| {
        bucket
            .iter()
            .filter_map(|weak| weak.upgrade())
            .find(|existing| **existing == *source)
    });
    if let Some(existing) = live_match {
        return existing;
    }

    // Miss: release dead handles (and their pinned allocations) across every
    // bucket, dropping buckets that become empty, before inserting.
    guard.retain(|_, bucket| {
        bucket.retain(|weak| weak.strong_count() > 0);
        !bucket.is_empty()
    });

    let shared = Arc::<str>::from(source);
    guard
        .entry(source_hash)
        .or_default()
        .push(Arc::downgrade(&shared));
    shared
}
234
235/// A render effect applied to a graphics layer's rendered content.
236///
237/// Matches Jetpack Compose's `RenderEffect` sealed class hierarchy,
238/// extended with `Shader` for custom WGSL effects.
239#[derive(Clone, Debug, PartialEq)]
240pub enum RenderEffect {
241    /// Gaussian blur applied to the layer's rendered content.
242    Blur {
243        radius_x: f32,
244        radius_y: f32,
245        edge_treatment: TileMode,
246    },
247    /// Offset the rendered content by a fixed amount.
248    Offset { offset_x: f32, offset_y: f32 },
249    /// Apply a custom WGSL shader effect.
250    Shader { shader: RuntimeShader },
251    /// Chain two effects: apply `first`, then apply `second` to the result.
252    Chain {
253        first: Box<RenderEffect>,
254        second: Box<RenderEffect>,
255    },
256}
257
258impl RenderEffect {
259    /// Create a blur effect with equal radius in both directions.
260    pub fn blur(radius: f32) -> Self {
261        Self::blur_with_edge_treatment(radius, TileMode::default())
262    }
263
264    /// Create a blur effect with equal radius in both directions and explicit
265    /// edge treatment semantics.
266    pub fn blur_with_edge_treatment(radius: f32, edge_treatment: TileMode) -> Self {
267        Self::Blur {
268            radius_x: radius,
269            radius_y: radius,
270            edge_treatment,
271        }
272    }
273
274    /// Create a blur effect with separate horizontal and vertical radii.
275    pub fn blur_xy(radius_x: f32, radius_y: f32, edge_treatment: TileMode) -> Self {
276        Self::Blur {
277            radius_x,
278            radius_y,
279            edge_treatment,
280        }
281    }
282
283    /// Create an offset effect.
284    pub fn offset(offset_x: f32, offset_y: f32) -> Self {
285        Self::Offset { offset_x, offset_y }
286    }
287
288    /// Create a custom shader effect from a RuntimeShader.
289    pub fn runtime_shader(shader: RuntimeShader) -> Self {
290        Self::Shader { shader }
291    }
292
293    /// Chain this effect with another: `self` is applied first, then `other`.
294    pub fn then(self, other: RenderEffect) -> Self {
295        Self::Chain {
296            first: Box::new(self),
297            second: Box::new(other),
298        }
299    }
300
301    /// Returns `true` if this effect or any chained sub-effect is a
302    /// `RuntimeShader`. Animated shaders produce different output every frame,
303    /// so layer surface caching is counterproductive for them.
304    pub fn contains_runtime_shader(&self) -> bool {
305        match self {
306            RenderEffect::Shader { .. } => true,
307            RenderEffect::Chain { first, second } => {
308                first.contains_runtime_shader() || second.contains_runtime_shader()
309            }
310            _ => false,
311        }
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::RoundedCornerShape;

    /// Unpacks a `RenderEffect::Blur` into `(radius_x, radius_y, edge_treatment)`,
    /// failing the test with "expected Blur" for any other variant.
    fn unpack_blur(effect: RenderEffect) -> (f32, f32, TileMode) {
        let RenderEffect::Blur {
            radius_x,
            radius_y,
            edge_treatment,
        } = effect
        else {
            panic!("expected Blur");
        };
        (radius_x, radius_y, edge_treatment)
    }

    #[test]
    fn runtime_shader_set_uniforms() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(0, 1.0);
        shader.set_float2(2, 3.0, 4.0);
        shader.set_float4(4, 5.0, 6.0, 7.0, 8.0);

        // Slot 1 is skipped by the writes above and must read back as zero.
        let expected = [1.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        for (slot, &want) in expected.iter().enumerate() {
            assert_eq!(shader.uniforms()[slot], want, "uniform slot {slot}");
        }
    }

    #[test]
    fn runtime_shader_padded() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(0, 42.0);

        let padded = shader.uniforms_padded();
        assert_eq!(padded[0], 42.0);
        assert_eq!(padded[1], 0.0);
        assert_eq!(padded[255], 0.0);
    }

    #[test]
    #[should_panic(expected = "uniform index 256 exceeds user maximum 247")]
    fn runtime_shader_overflow() {
        RuntimeShader::new("// test").set_float(256, 1.0);
    }

    #[test]
    #[should_panic(expected = "reserved for renderer data")]
    fn runtime_shader_rejects_reserved_uniform_slots() {
        RuntimeShader::new("// test").set_float(RuntimeShader::RESERVED_UNIFORM_START, 1.0);
    }

    #[test]
    fn render_effect_chaining() {
        let chained = RenderEffect::blur(10.0).then(RenderEffect::offset(5.0, 5.0));
        let RenderEffect::Chain { first, second } = chained else {
            panic!("expected Chain");
        };
        assert!(matches!(*first, RenderEffect::Blur { .. }));
        assert!(matches!(*second, RenderEffect::Offset { .. }));
    }

    #[test]
    fn blur_convenience() {
        let (radius_x, radius_y, edge_treatment) = unpack_blur(RenderEffect::blur(15.0));
        assert_eq!(radius_x, 15.0);
        assert_eq!(radius_y, 15.0);
        assert_eq!(edge_treatment, TileMode::Clamp);
    }

    #[test]
    fn blur_with_edge_treatment_uses_explicit_mode() {
        let (radius_x, radius_y, edge_treatment) =
            unpack_blur(RenderEffect::blur_with_edge_treatment(6.0, TileMode::Decal));
        assert_eq!(radius_x, 6.0);
        assert_eq!(radius_y, 6.0);
        assert_eq!(edge_treatment, TileMode::Decal);
    }

    #[test]
    fn source_hash_consistent() {
        let first = RuntimeShader::new("fn main() {}");
        let second = RuntimeShader::new("fn main() {}");
        assert_eq!(first.source_hash(), second.source_hash());
    }

    #[test]
    fn runtime_shader_new_interns_repeated_sources() {
        let wgsl = "fn fragment() -> vec4<f32> { return vec4<f32>(1.0); }";
        let first = RuntimeShader::new(wgsl);
        let second = RuntimeShader::new(wgsl);

        assert!(Arc::ptr_eq(&first.source, &second.source));
        assert_eq!(first.source_hash(), second.source_hash());
    }

    #[test]
    fn blur_xy_preserves_tile_mode() {
        let (radius_x, radius_y, edge_treatment) =
            unpack_blur(RenderEffect::blur_xy(3.0, 7.0, TileMode::Clamp));
        assert_eq!(radius_x, 3.0);
        assert_eq!(radius_y, 7.0);
        assert_eq!(edge_treatment, TileMode::Clamp);
    }

    #[test]
    fn offset_constructor_sets_components() {
        let RenderEffect::Offset { offset_x, offset_y } = RenderEffect::offset(11.0, -5.0) else {
            panic!("expected Offset");
        };
        assert_eq!(offset_x, 11.0);
        assert_eq!(offset_y, -5.0);
    }

    #[test]
    fn runtime_shader_equality_is_source_value_based() {
        let build = || {
            let mut shader = RuntimeShader::new("fn main() {}");
            shader.set_float(0, 1.0);
            shader
        };
        let (left, right) = (build(), build());
        assert_eq!(left, right);
    }

    #[test]
    fn blurred_edge_treatment_defaults_to_bounded_rectangle() {
        let treatment = BlurredEdgeTreatment::default();
        assert_eq!(treatment.shape(), Some(LayerShape::Rectangle));
        assert!(treatment.clip());
        assert_eq!(treatment.tile_mode(), TileMode::Clamp);
    }

    #[test]
    fn blurred_edge_treatment_unbounded_uses_decal_and_no_clip() {
        let treatment = BlurredEdgeTreatment::UNBOUNDED;
        assert_eq!(treatment.shape(), None);
        assert!(!treatment.clip());
        assert_eq!(treatment.tile_mode(), TileMode::Decal);
    }

    #[test]
    fn blurred_edge_treatment_with_shape_uses_bounded_mode() {
        let rounded = LayerShape::Rounded(RoundedCornerShape::uniform(8.0));
        let treatment = BlurredEdgeTreatment::with_shape(rounded);
        assert_eq!(treatment.shape(), Some(rounded));
        assert!(treatment.clip());
        assert_eq!(treatment.tile_mode(), TileMode::Clamp);
    }
}