use cranpose_ui_graphics::RuntimeShader;
use naga::back::glsl;
use naga::ShaderStage;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
/// Caches compiled render pipelines for user-supplied [`RuntimeShader`] effects.
///
/// Shaders that fail validation are remembered in a deny-list so they are
/// rejected without being re-parsed on later calls.
pub(crate) struct ShaderPipelineCache {
    /// Backend the pipelines are created for; selects backend-specific checks.
    backend: wgpu::Backend,
    /// Compiled pipelines keyed by shader source hash *and* colour target format.
    /// The target format is baked into the pipeline, so a pipeline built for one
    /// format cannot be reused for another.
    cache: HashMap<(u64, wgpu::TextureFormat), wgpu::RenderPipeline>,
    /// Source hashes of shaders that failed validation. Nothing ever removes an
    /// entry, so a disabled shader is never retried.
    disabled: HashSet<u64>,
}

impl ShaderPipelineCache {
    /// Creates an empty cache for pipelines targeting `backend`.
    pub fn new(backend: wgpu::Backend) -> Self {
        Self {
            backend,
            cache: HashMap::new(),
            disabled: HashSet::new(),
        }
    }

    /// Returns the render pipeline for `shader` writing into `format`, compiling
    /// and caching it on first use.
    ///
    /// Returns `None` when the shader has been disabled, or when it fails
    /// `validate_runtime_shader_source` on this call. A failure is logged once and
    /// the shader's source hash is disabled, so the caller can fall back to a
    /// pass-through.
    ///
    /// NOTE(review): the bind group layouts are not part of the cache key. This
    /// assumes the caller passes the same layouts for the lifetime of the cache.
    pub fn get_or_create(
        &mut self,
        device: &wgpu::Device,
        shader: &RuntimeShader,
        format: wgpu::TextureFormat,
        texture_bind_group_layout: &wgpu::BindGroupLayout,
        uniform_bind_group_layout: &wgpu::BindGroupLayout,
    ) -> Option<&wgpu::RenderPipeline> {
        let source_hash = shader.source_hash();
        if self.disabled.contains(&source_hash) {
            return None;
        }
        // Key on the target format as well: the same shader may be drawn into
        // targets of different formats, and each needs its own pipeline.
        let slot = match self.cache.entry((source_hash, format)) {
            Entry::Occupied(cached) => return Some(&*cached.into_mut()),
            Entry::Vacant(slot) => slot,
        };
        if let Err(err) = validate_runtime_shader_source(shader.source(), self.backend) {
            log::warn!(
                "Disabling RuntimeShader (hash={}): {}. Falling back to pass-through.",
                source_hash,
                err
            );
            self.disabled.insert(source_hash);
            return None;
        }
        let pipeline = Self::create_pipeline(
            device,
            shader,
            format,
            texture_bind_group_layout,
            uniform_bind_group_layout,
        );
        Some(&*slot.insert(pipeline))
    }

    /// Compiles `shader` into an effect pipeline that writes to a single colour
    /// target of `format`.
    ///
    /// Uses `fullscreen_vs` / `effect_fs` as entry points, binds no vertex
    /// buffers, draws a triangle strip, and replaces (does not blend with) the
    /// target. The source is expected to have already passed
    /// `validate_runtime_shader_source`.
    fn create_pipeline(
        device: &wgpu::Device,
        shader: &RuntimeShader,
        format: wgpu::TextureFormat,
        texture_bind_group_layout: &wgpu::BindGroupLayout,
        uniform_bind_group_layout: &wgpu::BindGroupLayout,
    ) -> wgpu::RenderPipeline {
        let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("RuntimeShader Effect"),
            source: wgpu::ShaderSource::Wgsl(shader.source().into()),
        });
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Effect Pipeline Layout"),
            bind_group_layouts: &[texture_bind_group_layout, uniform_bind_group_layout],
            immediate_size: 0,
        });
        device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
            label: Some("RuntimeShader Effect Pipeline"),
            layout: Some(&pipeline_layout),
            vertex: wgpu::VertexState {
                module: &shader_module,
                entry_point: Some("fullscreen_vs"),
                buffers: &[],
                compilation_options: wgpu::PipelineCompilationOptions::default(),
            },
            fragment: Some(wgpu::FragmentState {
                module: &shader_module,
                entry_point: Some("effect_fs"),
                targets: &[Some(wgpu::ColorTargetState {
                    format,
                    blend: Some(wgpu::BlendState::REPLACE),
                    write_mask: wgpu::ColorWrites::ALL,
                })],
                compilation_options: wgpu::PipelineCompilationOptions::default(),
            }),
            primitive: wgpu::PrimitiveState {
                topology: wgpu::PrimitiveTopology::TriangleStrip,
                strip_index_format: None,
                front_face: wgpu::FrontFace::Ccw,
                cull_mode: None,
                ..Default::default()
            },
            depth_stencil: None,
            multisample: wgpu::MultisampleState::default(),
            multiview_mask: None,
            cache: None,
        })
    }
}
/// Checks that `source` is a usable runtime-shader effect.
///
/// The checks run in this order, and the first failure is returned as a
/// human-readable message:
/// 1. the WGSL parses with naga;
/// 2. it declares a vertex entry point `fullscreen_vs`, then a fragment entry
///    point `effect_fs`;
/// 3. naga validation succeeds;
/// 4. backend-specific checks pass (see `validate_runtime_shader_backend_support`).
fn validate_runtime_shader_source(source: &str, backend: wgpu::Backend) -> Result<(), String> {
    let module = naga::front::wgsl::parse_str(source)
        .map_err(|err| format!("WGSL parse error: {err}"))?;

    // (stage, entry-point name, stage word used in the error message)
    let required_entry_points = [
        (ShaderStage::Vertex, "fullscreen_vs", "vertex"),
        (ShaderStage::Fragment, "effect_fs", "fragment"),
    ];
    for (stage, name, stage_word) in required_entry_points {
        let declared = module
            .entry_points
            .iter()
            .any(|ep| ep.stage == stage && ep.name == name);
        if !declared {
            return Err(format!("missing required {stage_word} entry point `{name}`"));
        }
    }

    // NOTE(review): `Capabilities::all()` accepts features the target device may
    // not actually support; those would only surface later, when wgpu creates
    // the module.
    let module_info = naga::valid::Validator::new(
        naga::valid::ValidationFlags::all(),
        naga::valid::Capabilities::all(),
    )
    .validate(&module)
    .map_err(|err| format!("WGSL validation error: {err}"))?;

    validate_runtime_shader_backend_support(&module, &module_info, backend)
}
/// Runs extra portability checks that only matter for `backend`.
///
/// Only the GL backend has extra checks: both required entry points
/// (`fullscreen_vs` as vertex, `effect_fs` as fragment) must be translatable
/// by `validate_glsl_portability`, checked in that order. Every other backend
/// is accepted unconditionally.
fn validate_runtime_shader_backend_support(
    module: &naga::Module,
    module_info: &naga::valid::ModuleInfo,
    backend: wgpu::Backend,
) -> Result<(), String> {
    if backend == wgpu::Backend::Gl {
        let stages = [
            ("fullscreen_vs", ShaderStage::Vertex),
            ("effect_fs", ShaderStage::Fragment),
        ];
        for (entry_point, stage) in stages {
            validate_glsl_portability(module, module_info, entry_point, stage)?;
        }
    }
    Ok(())
}
/// Checks whether `entry_point` (of stage `shader_stage`) can be emitted as
/// GLSL ES 3.00 by naga's GLSL backend.
///
/// The GLSL text is generated only to detect errors and is then discarded.
/// Writer construction errors and emission errors are reported with distinct
/// messages that name the entry point.
fn validate_glsl_portability(
    module: &naga::Module,
    module_info: &naga::valid::ModuleInfo,
    entry_point: &str,
    shader_stage: ShaderStage,
) -> Result<(), String> {
    let options = glsl::Options {
        version: glsl::Version::new_gles(300),
        writer_flags: glsl::WriterFlags::ADJUST_COORDINATE_SPACE,
        ..Default::default()
    };
    let pipeline_options = glsl::PipelineOptions {
        shader_stage,
        entry_point: entry_point.to_owned(),
        multiview: None,
    };

    // Scratch buffer for the generated GLSL; its contents are never inspected.
    let mut scratch = String::new();
    let mut writer = glsl::Writer::new(
        &mut scratch,
        module,
        module_info,
        &options,
        &pipeline_options,
        naga::proc::BoundsCheckPolicies::default(),
    )
    .map_err(|err| format!("GL/WebGL portability validation failed for `{entry_point}`: {err}"))?;

    match writer.write() {
        Ok(_reflection) => Ok(()),
        Err(err) => Err(format!(
            "GL/WebGL portability emission failed for `{entry_point}`: {err}"
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::validate_runtime_shader_source;
    use crate::pipeline::GPU_TEXT_BRUSH_EFFECT_SHADER;
    use cranpose_ui_graphics::{
        GRADIENT_CUT_MASK_WGSL, GRADIENT_FADE_DST_OUT_WGSL, LIQUID_GLASS_WGSL,
        ROUNDED_ALPHA_MASK_WGSL,
    };

    /// Minimal effect exposing both required entry points plus a texture/sampler
    /// pair in group 0 and a uniform array in group 1.
    const VALID_SHADER: &str = r#"
struct VertexOutput {
@builtin(position) position: vec4<f32>,
@location(0) uv: vec2<f32>,
}
@vertex
fn fullscreen_vs(@builtin(vertex_index) vertex_index: u32) -> VertexOutput {
var output: VertexOutput;
let x = f32(i32(vertex_index & 1u) * 2 - 1);
let y = f32(i32(vertex_index >> 1u) * 2 - 1);
output.uv = vec2<f32>(x * 0.5 + 0.5, 1.0 - (y * 0.5 + 0.5));
output.position = vec4<f32>(x, y, 0.0, 1.0);
return output;
}
@group(0) @binding(0) var input_texture: texture_2d<f32>;
@group(0) @binding(1) var input_sampler: sampler;
@group(1) @binding(0) var<uniform> u: array<vec4<f32>, 64>;
@fragment
fn effect_fs(input: VertexOutput) -> @location(0) vec4<f32> {
return textureSample(input_texture, input_sampler, input.uv);
}
"#;

    #[test]
    fn validator_accepts_valid_runtime_shader() {
        assert!(validate_runtime_shader_source(VALID_SHADER, wgpu::Backend::Vulkan).is_ok());
    }

    #[test]
    fn validator_rejects_invalid_wgsl() {
        let invalid = "this is not wgsl";
        assert!(validate_runtime_shader_source(invalid, wgpu::Backend::Vulkan).is_err());
    }

    #[test]
    fn validator_rejects_missing_required_entry_points() {
        let missing_effect_fs = r#"
@vertex
fn fullscreen_vs(@builtin(vertex_index) i: u32) -> @builtin(position) vec4<f32> {
let x = f32(i32(i & 1u) * 2 - 1);
let y = f32(i32(i >> 1u) * 2 - 1);
return vec4<f32>(x, y, 0.0, 1.0);
}
"#;
        let err = validate_runtime_shader_source(missing_effect_fs, wgpu::Backend::Vulkan)
            .expect_err("shader without `effect_fs` must be rejected");
        assert!(err.contains("effect_fs"), "unexpected error: {err}");
    }

    #[test]
    fn validator_rejects_missing_vertex_entry_point() {
        // Valid WGSL with only the fragment stage: the vertex check must fire.
        let missing_fullscreen_vs = r#"
@fragment
fn effect_fs() -> @location(0) vec4<f32> {
return vec4<f32>(1.0, 0.0, 0.0, 1.0);
}
"#;
        let err = validate_runtime_shader_source(missing_fullscreen_vs, wgpu::Backend::Vulkan)
            .expect_err("shader without `fullscreen_vs` must be rejected");
        assert!(err.contains("fullscreen_vs"), "unexpected error: {err}");
    }

    #[test]
    fn validator_accepts_gl_portable_builtin_runtime_shaders() {
        for (name, source) in [
            ("gradient_cut_mask", GRADIENT_CUT_MASK_WGSL),
            ("rounded_alpha_mask", ROUNDED_ALPHA_MASK_WGSL),
            ("gradient_fade_dst_out", GRADIENT_FADE_DST_OUT_WGSL),
            ("liquid_glass", LIQUID_GLASS_WGSL),
            ("gpu_text_brush_effect", GPU_TEXT_BRUSH_EFFECT_SHADER),
        ] {
            assert!(
                validate_runtime_shader_source(source, wgpu::Backend::Gl).is_ok(),
                "{name} should remain GL-portable"
            );
        }
    }
}