Skip to main content

cranpose_ui_graphics/
render_effect.rs

//! Render effects that can be applied to graphics layers.
//!
//! Matches the Jetpack Compose `RenderEffect` API with extensions for custom
//! WGSL shaders (`RuntimeShader`).
5
6use crate::LayerShape;
7use std::sync::Arc;
8
/// How a blur samples pixels that fall outside the region being blurred.
///
/// [`TileMode::Clamp`] is the default mode.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(Default, PartialEq, Eq)]
pub enum TileMode {
    /// Reuse the color of the nearest edge pixel.
    #[default]
    Clamp,
    /// Wrap around and repeat the content from its start.
    Repeated,
    /// Repeat the content, flipping it on every other repetition.
    Mirror,
    /// Sample everything beyond the boundary as fully transparent.
    Decal,
}
22
23/// Controls blur behavior outside source bounds.
24///
25/// This mirrors Compose's `BlurredEdgeTreatment`:
26/// - bounded treatment (`shape != None`) clips blur output and uses `TileMode::Clamp`
27/// - unbounded treatment (`shape == None`) does not clip and uses `TileMode::Decal`
28#[derive(Clone, Copy, Debug, PartialEq)]
29pub struct BlurredEdgeTreatment {
30    shape: Option<LayerShape>,
31}
32
33impl BlurredEdgeTreatment {
34    /// Bounded treatment that clips to a rectangle.
35    pub const RECTANGLE: Self = Self {
36        shape: Some(LayerShape::Rectangle),
37    };
38
39    /// Unbounded treatment that does not clip blurred output.
40    pub const UNBOUNDED: Self = Self { shape: None };
41
42    /// Bounded treatment with a specific clip shape.
43    pub const fn with_shape(shape: LayerShape) -> Self {
44        Self { shape: Some(shape) }
45    }
46
47    pub fn shape(self) -> Option<LayerShape> {
48        self.shape
49    }
50
51    pub fn clip(self) -> bool {
52        self.shape.is_some()
53    }
54
55    pub fn tile_mode(self) -> TileMode {
56        if self.clip() {
57            TileMode::Clamp
58        } else {
59            TileMode::Decal
60        }
61    }
62}
63
64impl Default for BlurredEdgeTreatment {
65    fn default() -> Self {
66        Self::RECTANGLE
67    }
68}
69
/// A custom WGSL shader effect, analogous to Android's `RuntimeShader`.
///
/// The shader source must be a complete WGSL module that declares:
/// ```wgsl
/// @group(0) @binding(0) var input_texture: texture_2d<f32>;
/// @group(0) @binding(1) var input_sampler: sampler;
/// @group(1) @binding(0) var<uniform> u: array<vec4<f32>, 64>;
/// ```
///
/// Float uniforms are packed linearly into the `u` array: float slot `i`
/// lives in `u[i / 4][i % 4]`.
/// - A vec2 written at `i` occupies components `i % 4` and `i % 4 + 1` of
///   `u[i / 4]`. For example, it is `u[i / 4].xy` when `i % 4 == 0` and
///   `u[i / 4].zw` when `i % 4 == 2`.
/// - A vec4 written at `i` is exactly `u[i / 4]` only when `i % 4 == 0`.
/// - A write that crosses a multiple of 4 straddles two `vec4` entries, and
///   the shader must reassemble it component by component.
///
/// User uniforms may use indices `0..248`; slots `248..256` are reserved for
/// renderer metadata.
///
/// RuntimeShader pipelines operate on premultiplied-alpha textures. Custom
/// shaders should preserve premultiplied output semantics.
#[derive(Clone, Debug)]
pub struct RuntimeShader {
    /// WGSL module source, held in an `Arc` so clones share one buffer.
    source: Arc<str>,
    /// FNV-1a hash of `source`, computed once when the shader is created.
    source_hash: u64,
    /// User uniform values; grown on demand and zero-padded for upload.
    uniforms: Vec<f32>,
}
92
93/// Error returned when a shader uniform write targets renderer-owned storage.
94#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
95pub enum RuntimeShaderUniformError {
96    #[error(
97        "uniform range starting at {index} with width {width} exceeds user uniform range 0..{max_user_uniforms}; slots {reserved_start}..{max_uniforms} are reserved for renderer data"
98    )]
99    OutOfUserRange {
100        index: usize,
101        width: usize,
102        max_user_uniforms: usize,
103        reserved_start: usize,
104        max_uniforms: usize,
105    },
106}
107
108impl RuntimeShader {
109    /// Total uniform storage size in floats (64 vec4s = 256 floats).
110    ///
111    /// The final slots are reserved for renderer-managed data.
112    pub const MAX_UNIFORMS: usize = 256;
113    /// First renderer-reserved uniform slot.
114    pub const RESERVED_UNIFORM_START: usize = 248;
115    /// Maximum user-addressable uniform count.
116    pub const MAX_USER_UNIFORMS: usize = Self::RESERVED_UNIFORM_START;
117    const INITIAL_UNIFORM_CAPACITY: usize = 16;
118
119    /// Create a new RuntimeShader from WGSL source code.
120    pub fn new(wgsl_source: &str) -> Self {
121        let source_hash = hash_shader_source(wgsl_source);
122        Self {
123            source: Arc::<str>::from(wgsl_source),
124            source_hash,
125            uniforms: Vec::with_capacity(Self::INITIAL_UNIFORM_CAPACITY),
126        }
127    }
128
129    /// Create a RuntimeShader from shared WGSL source code.
130    ///
131    /// This avoids repeatedly copying large shader modules for animated effects
132    /// that rebuild only their uniform payload every frame.
133    pub fn from_shared_source(source: Arc<str>) -> Self {
134        let source_hash = hash_shader_source(&source);
135        Self {
136            source,
137            source_hash,
138            uniforms: Vec::with_capacity(Self::INITIAL_UNIFORM_CAPACITY),
139        }
140    }
141
142    /// Set a single float uniform at the given index.
143    ///
144    /// Invalid renderer-reserved ranges are ignored. Use [`Self::try_set_float`]
145    /// when the caller needs to handle invalid uniform writes explicitly.
146    pub fn set_float(&mut self, index: usize, value: f32) {
147        let _ = self.try_set_float(index, value);
148    }
149
150    /// Set a single float uniform at the given index.
151    pub fn try_set_float(
152        &mut self,
153        index: usize,
154        value: f32,
155    ) -> Result<(), RuntimeShaderUniformError> {
156        self.try_ensure_capacity(index, 1)?;
157        self.uniforms[index] = value;
158        Ok(())
159    }
160
161    /// Set a vec2 uniform at the given index (consumes indices `[index, index+1]`).
162    ///
163    /// Invalid renderer-reserved ranges are ignored. Use [`Self::try_set_float2`]
164    /// when the caller needs to handle invalid uniform writes explicitly.
165    pub fn set_float2(&mut self, index: usize, x: f32, y: f32) {
166        let _ = self.try_set_float2(index, x, y);
167    }
168
169    /// Set a vec2 uniform at the given index (consumes indices `[index, index+1]`).
170    pub fn try_set_float2(
171        &mut self,
172        index: usize,
173        x: f32,
174        y: f32,
175    ) -> Result<(), RuntimeShaderUniformError> {
176        self.try_ensure_capacity(index, 2)?;
177        self.uniforms[index] = x;
178        self.uniforms[index + 1] = y;
179        Ok(())
180    }
181
182    /// Set a vec4 uniform at the given index (consumes indices `[index..index+4]`).
183    ///
184    /// Invalid renderer-reserved ranges are ignored. Use [`Self::try_set_float4`]
185    /// when the caller needs to handle invalid uniform writes explicitly.
186    pub fn set_float4(&mut self, index: usize, x: f32, y: f32, z: f32, w: f32) {
187        let _ = self.try_set_float4(index, x, y, z, w);
188    }
189
190    /// Set a vec4 uniform at the given index (consumes indices `[index..index+4]`).
191    pub fn try_set_float4(
192        &mut self,
193        index: usize,
194        x: f32,
195        y: f32,
196        z: f32,
197        w: f32,
198    ) -> Result<(), RuntimeShaderUniformError> {
199        self.try_ensure_capacity(index, 4)?;
200        self.uniforms[index] = x;
201        self.uniforms[index + 1] = y;
202        self.uniforms[index + 2] = z;
203        self.uniforms[index + 3] = w;
204        Ok(())
205    }
206
207    /// Get the WGSL source code.
208    pub fn source(&self) -> &str {
209        &self.source
210    }
211
212    /// Get the uniform data as a float slice (for uploading to GPU).
213    pub fn uniforms(&self) -> &[f32] {
214        &self.uniforms
215    }
216
217    /// Get the uniform data padded to full 256-float array (for GPU uniform buffer).
218    pub fn uniforms_padded(&self) -> [f32; Self::MAX_UNIFORMS] {
219        let mut padded = [0.0f32; Self::MAX_UNIFORMS];
220        let len = self.uniforms.len().min(Self::MAX_UNIFORMS);
221        padded[..len].copy_from_slice(&self.uniforms[..len]);
222        padded
223    }
224
225    /// Compute a hash of the shader source for pipeline caching.
226    pub fn source_hash(&self) -> u64 {
227        self.source_hash
228    }
229
230    fn try_ensure_capacity(
231        &mut self,
232        index: usize,
233        width: usize,
234    ) -> Result<(), RuntimeShaderUniformError> {
235        let min_len = index
236            .checked_add(width)
237            .ok_or_else(|| Self::uniform_range_error(index, width))?;
238        if min_len > Self::MAX_USER_UNIFORMS {
239            return Err(Self::uniform_range_error(index, width));
240        }
241        if self.uniforms.len() < min_len {
242            self.uniforms.resize(min_len, 0.0);
243        }
244        Ok(())
245    }
246
247    fn uniform_range_error(index: usize, width: usize) -> RuntimeShaderUniformError {
248        RuntimeShaderUniformError::OutOfUserRange {
249            index,
250            width,
251            max_user_uniforms: Self::MAX_USER_UNIFORMS,
252            reserved_start: Self::RESERVED_UNIFORM_START,
253            max_uniforms: Self::MAX_UNIFORMS,
254        }
255    }
256}
257
258impl PartialEq for RuntimeShader {
259    fn eq(&self, other: &Self) -> bool {
260        self.source_hash == other.source_hash
261            && (Arc::ptr_eq(&self.source, &other.source)
262                || self.source.as_ref() == other.source.as_ref())
263            && self.uniforms == other.uniforms
264    }
265}
266
/// 64-bit FNV-1a hash over the UTF-8 bytes of the shader source.
///
/// FNV-1a is unseeded, so the result depends only on the input bytes.
fn hash_shader_source(source: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    let mut hash = OFFSET_BASIS;
    for &byte in source.as_bytes() {
        hash ^= u64::from(byte);
        hash = hash.wrapping_mul(PRIME);
    }
    hash
}
278
279/// A render effect applied to a graphics layer's rendered content.
280///
281/// Matches Jetpack Compose's `RenderEffect` sealed class hierarchy,
282/// extended with `Shader` for custom WGSL effects.
283#[derive(Clone, Debug, PartialEq)]
284pub enum RenderEffect {
285    /// Gaussian blur applied to the layer's rendered content.
286    Blur {
287        radius_x: f32,
288        radius_y: f32,
289        edge_treatment: TileMode,
290    },
291    /// Offset the rendered content by a fixed amount.
292    Offset { offset_x: f32, offset_y: f32 },
293    /// Apply a custom WGSL shader effect.
294    Shader { shader: RuntimeShader },
295    /// Chain two effects: apply `first`, then apply `second` to the result.
296    Chain {
297        first: Box<RenderEffect>,
298        second: Box<RenderEffect>,
299    },
300}
301
302impl RenderEffect {
303    /// Create a blur effect with equal radius in both directions.
304    pub fn blur(radius: f32) -> Self {
305        Self::blur_with_edge_treatment(radius, TileMode::default())
306    }
307
308    /// Create a blur effect with equal radius in both directions and explicit
309    /// edge treatment semantics.
310    pub fn blur_with_edge_treatment(radius: f32, edge_treatment: TileMode) -> Self {
311        Self::Blur {
312            radius_x: radius,
313            radius_y: radius,
314            edge_treatment,
315        }
316    }
317
318    /// Create a blur effect with separate horizontal and vertical radii.
319    pub fn blur_xy(radius_x: f32, radius_y: f32, edge_treatment: TileMode) -> Self {
320        Self::Blur {
321            radius_x,
322            radius_y,
323            edge_treatment,
324        }
325    }
326
327    /// Create an offset effect.
328    pub fn offset(offset_x: f32, offset_y: f32) -> Self {
329        Self::Offset { offset_x, offset_y }
330    }
331
332    /// Create a custom shader effect from a RuntimeShader.
333    pub fn runtime_shader(shader: RuntimeShader) -> Self {
334        Self::Shader { shader }
335    }
336
337    /// Chain this effect with another: `self` is applied first, then `other`.
338    pub fn then(self, other: RenderEffect) -> Self {
339        Self::Chain {
340            first: Box::new(self),
341            second: Box::new(other),
342        }
343    }
344
345    /// Returns `true` if this effect or any chained sub-effect is a
346    /// `RuntimeShader`. Animated shaders produce different output every frame,
347    /// so layer surface caching is counterproductive for them.
348    pub fn contains_runtime_shader(&self) -> bool {
349        match self {
350            RenderEffect::Shader { .. } => true,
351            RenderEffect::Chain { first, second } => {
352                first.contains_runtime_shader() || second.contains_runtime_shader()
353            }
354            _ => false,
355        }
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use crate::RoundedCornerShape;

    #[test]
    fn runtime_shader_set_uniforms() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(0, 1.0);
        shader.set_float2(2, 3.0, 4.0);
        shader.set_float4(4, 5.0, 6.0, 7.0, 8.0);

        assert_eq!(shader.uniforms()[0], 1.0);
        assert_eq!(shader.uniforms()[1], 0.0); // gap
        assert_eq!(shader.uniforms()[2], 3.0);
        assert_eq!(shader.uniforms()[3], 4.0);
        assert_eq!(shader.uniforms()[4], 5.0);
        assert_eq!(shader.uniforms()[5], 6.0);
        assert_eq!(shader.uniforms()[6], 7.0);
        assert_eq!(shader.uniforms()[7], 8.0);
    }

    #[test]
    fn runtime_shader_padded() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(0, 42.0);
        let padded = shader.uniforms_padded();
        assert_eq!(padded[0], 42.0);
        assert_eq!(padded[1], 0.0);
        assert_eq!(padded[255], 0.0);
    }

    #[test]
    fn runtime_shader_try_set_reports_reserved_uniform_slots() {
        let mut shader = RuntimeShader::new("// test");

        let err = shader
            .try_set_float(RuntimeShader::RESERVED_UNIFORM_START, 1.0)
            .unwrap_err();
        assert_eq!(
            err,
            RuntimeShaderUniformError::OutOfUserRange {
                index: RuntimeShader::RESERVED_UNIFORM_START,
                width: 1,
                max_user_uniforms: RuntimeShader::MAX_USER_UNIFORMS,
                reserved_start: RuntimeShader::RESERVED_UNIFORM_START,
                max_uniforms: RuntimeShader::MAX_UNIFORMS,
            }
        );
        assert!(shader.uniforms().is_empty());

        let err = shader
            .try_set_float4(RuntimeShader::MAX_USER_UNIFORMS - 3, 1.0, 2.0, 3.0, 4.0)
            .unwrap_err();
        assert_eq!(
            err,
            RuntimeShaderUniformError::OutOfUserRange {
                index: RuntimeShader::MAX_USER_UNIFORMS - 3,
                width: 4,
                max_user_uniforms: RuntimeShader::MAX_USER_UNIFORMS,
                reserved_start: RuntimeShader::RESERVED_UNIFORM_START,
                max_uniforms: RuntimeShader::MAX_UNIFORMS,
            }
        );
    }

    #[test]
    fn runtime_shader_setters_ignore_invalid_uniform_slots_without_panicking() {
        let mut shader = RuntimeShader::new("// test");
        shader.set_float(0, 7.0);

        shader.set_float(RuntimeShader::RESERVED_UNIFORM_START, 1.0);
        shader.set_float4(RuntimeShader::MAX_USER_UNIFORMS - 3, 1.0, 2.0, 3.0, 4.0);

        assert_eq!(shader.uniforms(), &[7.0]);
    }

    #[test]
    fn runtime_shader_accepts_writes_ending_at_last_user_slot() {
        let mut shader = RuntimeShader::new("// test");
        let last_vec4 = RuntimeShader::MAX_USER_UNIFORMS - 4;

        shader
            .try_set_float4(last_vec4, 1.0, 2.0, 3.0, 4.0)
            .expect("a vec4 ending exactly at MAX_USER_UNIFORMS is in range");

        assert_eq!(shader.uniforms().len(), RuntimeShader::MAX_USER_UNIFORMS);
        assert_eq!(&shader.uniforms()[last_vec4..], &[1.0, 2.0, 3.0, 4.0]);
        let padded = shader.uniforms_padded();
        assert!(padded[RuntimeShader::RESERVED_UNIFORM_START..]
            .iter()
            .all(|&v| v == 0.0));
    }

    #[test]
    fn render_effect_chaining() {
        let blur = RenderEffect::blur(10.0);
        let offset = RenderEffect::offset(5.0, 5.0);
        let chained = blur.then(offset);
        match chained {
            RenderEffect::Chain { first, second } => {
                assert!(matches!(*first, RenderEffect::Blur { .. }));
                assert!(matches!(*second, RenderEffect::Offset { .. }));
            }
            _ => panic!("expected Chain"),
        }
    }

    #[test]
    fn contains_runtime_shader_detects_direct_and_nested_shaders() {
        let shader = RenderEffect::runtime_shader(RuntimeShader::new("// test"));
        assert!(shader.contains_runtime_shader());
        assert!(!RenderEffect::blur(2.0).contains_runtime_shader());
        assert!(!RenderEffect::offset(1.0, 1.0).contains_runtime_shader());
        assert!(!RenderEffect::blur(2.0)
            .then(RenderEffect::offset(0.0, 0.0))
            .contains_runtime_shader());

        let nested = RenderEffect::blur(2.0).then(RenderEffect::offset(1.0, 0.0).then(shader));
        assert!(nested.contains_runtime_shader());
    }

    #[test]
    fn blur_convenience() {
        let effect = RenderEffect::blur(15.0);
        match effect {
            RenderEffect::Blur {
                radius_x,
                radius_y,
                edge_treatment,
            } => {
                assert_eq!(radius_x, 15.0);
                assert_eq!(radius_y, 15.0);
                assert_eq!(edge_treatment, TileMode::Clamp);
            }
            _ => panic!("expected Blur"),
        }
    }

    #[test]
    fn blur_with_edge_treatment_uses_explicit_mode() {
        let effect = RenderEffect::blur_with_edge_treatment(6.0, TileMode::Decal);
        match effect {
            RenderEffect::Blur {
                radius_x,
                radius_y,
                edge_treatment,
            } => {
                assert_eq!(radius_x, 6.0);
                assert_eq!(radius_y, 6.0);
                assert_eq!(edge_treatment, TileMode::Decal);
            }
            _ => panic!("expected Blur"),
        }
    }

    #[test]
    fn source_hash_consistent() {
        let s1 = RuntimeShader::new("fn main() {}");
        let s2 = RuntimeShader::new("fn main() {}");
        assert_eq!(s1.source_hash(), s2.source_hash());
    }

    #[test]
    fn runtime_shader_from_shared_source_reuses_shared_source() {
        let source = Arc::<str>::from("fn fragment() -> vec4<f32> { return vec4<f32>(1.0); }");
        let s1 = RuntimeShader::from_shared_source(source.clone());
        let s2 = RuntimeShader::from_shared_source(source);

        assert!(Arc::ptr_eq(&s1.source, &s2.source));
        assert_eq!(s1.source_hash(), s2.source_hash());
    }

    #[test]
    fn runtime_shader_source_storage_has_no_process_global_interner() {
        let source = include_str!("render_effect.rs");
        let blocked_static = ["static ", "INTERNER"].concat();
        let blocked_type = ["ShaderSource", "Interner"].concat();

        assert!(
            !source.contains(&blocked_static) && !source.contains(&blocked_type),
            "RuntimeShader source sharing must be explicit via from_shared_source, not a process-global interner"
        );
    }

    #[test]
    fn blur_xy_preserves_tile_mode() {
        let effect = RenderEffect::blur_xy(3.0, 7.0, TileMode::Clamp);
        match effect {
            RenderEffect::Blur {
                radius_x,
                radius_y,
                edge_treatment,
            } => {
                assert_eq!(radius_x, 3.0);
                assert_eq!(radius_y, 7.0);
                assert_eq!(edge_treatment, TileMode::Clamp);
            }
            _ => panic!("expected Blur"),
        }
    }

    #[test]
    fn offset_constructor_sets_components() {
        let effect = RenderEffect::offset(11.0, -5.0);
        match effect {
            RenderEffect::Offset { offset_x, offset_y } => {
                assert_eq!(offset_x, 11.0);
                assert_eq!(offset_y, -5.0);
            }
            _ => panic!("expected Offset"),
        }
    }

    #[test]
    fn runtime_shader_equality_is_source_value_based() {
        let mut s1 = RuntimeShader::new("fn main() {}");
        let mut s2 = RuntimeShader::new("fn main() {}");
        s1.set_float(0, 1.0);
        s2.set_float(0, 1.0);
        assert_eq!(s1, s2);
    }

    #[test]
    fn runtime_shader_inequality_on_source_or_uniform_change() {
        let mut a = RuntimeShader::new("fn main() {}");
        let b = RuntimeShader::new("fn other() {}");
        assert_ne!(a, b);

        let snapshot = a.clone();
        a.set_float(0, 2.0);
        assert_ne!(a, snapshot);
    }

    #[test]
    fn blurred_edge_treatment_defaults_to_bounded_rectangle() {
        let treatment = BlurredEdgeTreatment::default();
        assert_eq!(treatment.shape(), Some(LayerShape::Rectangle));
        assert!(treatment.clip());
        assert_eq!(treatment.tile_mode(), TileMode::Clamp);
    }

    #[test]
    fn blurred_edge_treatment_unbounded_uses_decal_and_no_clip() {
        let treatment = BlurredEdgeTreatment::UNBOUNDED;
        assert_eq!(treatment.shape(), None);
        assert!(!treatment.clip());
        assert_eq!(treatment.tile_mode(), TileMode::Decal);
    }

    #[test]
    fn blurred_edge_treatment_with_shape_uses_bounded_mode() {
        let rounded = LayerShape::Rounded(RoundedCornerShape::uniform(8.0));
        let treatment = BlurredEdgeTreatment::with_shape(rounded);
        assert_eq!(treatment.shape(), Some(rounded));
        assert!(treatment.clip());
        assert_eq!(treatment.tile_mode(), TileMode::Clamp);
    }
}