use std::sync::OnceLock;
use awsm_renderer_core::{
bind_groups::{
BindGroupLayoutResource, BufferBindingLayout, BufferBindingType, StorageTextureAccess,
StorageTextureBindingLayout, TextureBindingLayout,
},
buffers::{BufferDescriptor, BufferUsage},
renderer::AwsmRendererWebGpu,
shaders::{ShaderModuleDescriptor, ShaderModuleExt},
texture::{TextureFormat, TextureSampleType, TextureViewDimension},
};
use crate::{
bind_group_layout::{
BindGroupLayoutCacheKey, BindGroupLayoutCacheKeyEntry, BindGroupLayoutKey, BindGroupLayouts,
},
pipeline_layouts::{PipelineLayoutCacheKey, PipelineLayoutKey, PipelineLayouts},
pipelines::{
compute_pipeline::{ComputePipelineCacheKey, ComputePipelineKey},
Pipelines,
},
shaders::{ShaderKey, Shaders},
shadows::AwsmShadowError,
};
/// Largest blur radius (texels on each side of the centre) a parameter slot may carry.
/// `EvsmPass::write_params_slot` clamps its `blur_radius` argument to this value, and the
/// blur shaders in this file clamp to the same literal (`8u`) before indexing their
/// nine-entry weight table.
pub const MAX_BLUR_RADIUS: u32 = 8;
/// Byte distance between consecutive parameter slots in the EVSM params buffer; also the
/// multiplier used by `EvsmPass::params_dynamic_offset`. Only the first 48 bytes of each
/// slot are written by `EvsmPass::write_params_slot`.
/// NOTE(review): 256 presumably matches WebGPU's default
/// `minUniformBufferOffsetAlignment` — confirm against the device limits in use.
pub const EVSM_PARAMS_STRIDE: usize = 256;
/// Number of parameter slots allocated in the params buffer (and in its CPU-side mirror),
/// i.e. the buffer is `EVSM_PARAMS_STRIDE * MAX_EVSM_CASCADES_PER_FRAME` bytes long.
pub const MAX_EVSM_CASCADES_PER_FRAME: usize = 16;
/// Resources for the EVSM shadow pass: a moment-write compute pass followed by a
/// horizontal and a vertical blur compute pass, all fed from one uniform buffer of
/// per-cascade parameter slots addressed with dynamic offsets.
pub struct EvsmPass {
/// Bind group layout for the moment-write pass: depth 2D-array texture, write-only
/// `rgba16float` storage texture, dynamic-offset uniform buffer (compute visibility only).
pub moment_write_layout_key: BindGroupLayoutKey,
/// Bind group layout shared by both blur passes: float 2D texture, write-only
/// `rgba16float` storage texture, dynamic-offset uniform buffer (compute visibility only).
pub blur_layout_key: BindGroupLayoutKey,
/// Pipeline layout containing only `moment_write_layout_key`.
pub moment_write_pipeline_layout_key: PipelineLayoutKey,
/// Pipeline layout containing only `blur_layout_key`.
pub blur_pipeline_layout_key: PipelineLayoutKey,
/// Moment-write compute pipeline; `None` when `from_resolved` was not given exactly
/// three resolved keys.
pub moment_write_pipeline_key: Option<ComputePipelineKey>,
/// Horizontal blur compute pipeline; `None` under the same condition as above.
pub blur_h_pipeline_key: Option<ComputePipelineKey>,
/// Vertical blur compute pipeline; `None` under the same condition as above.
pub blur_v_pipeline_key: Option<ComputePipelineKey>,
/// GPU uniform buffer (uniform + copy-dst usage) holding
/// `MAX_EVSM_CASCADES_PER_FRAME` slots of `EVSM_PARAMS_STRIDE` bytes each.
pub params_buffer: web_sys::GpuBuffer,
/// CPU-side mirror of `params_buffer`; filled by `write_params_slot` and copied to the
/// GPU by `upload_params`.
pub params_bytes: Vec<u8>,
/// Number of leading slots `upload_params` copies to the GPU. Initialised to 0 by
/// `from_resolved`; nothing in this file updates it afterwards, so callers are
/// presumably expected to set it after writing slots — verify against call sites.
pub active_cascade_count: u32,
}
/// Intermediate product of `EvsmPass::build_descriptors`: everything an `EvsmPass` needs
/// except the resolved compute pipeline keys, which `EvsmPass::from_resolved` adds.
pub struct EvsmDescriptors {
/// Bind group layout key for the moment-write pass.
pub moment_write_layout_key: BindGroupLayoutKey,
/// Bind group layout key shared by the horizontal and vertical blur passes.
pub blur_layout_key: BindGroupLayoutKey,
/// Pipeline layout key wrapping `moment_write_layout_key`.
pub moment_write_pipeline_layout_key: PipelineLayoutKey,
/// Pipeline layout key wrapping `blur_layout_key`.
pub blur_pipeline_layout_key: PipelineLayoutKey,
/// Compiled shader modules, in the order: moment write, blur horizontal, blur vertical.
pub modules: [web_sys::GpuShaderModule; 3],
/// GPU uniform buffer for the per-cascade parameter slots.
pub params_buffer: web_sys::GpuBuffer,
/// Zero-initialised CPU-side mirror with the same byte length as `params_buffer`.
pub params_bytes: Vec<u8>,
}
impl EvsmDescriptors {
    /// Returns one validation future per shader module, in module order
    /// (moment write, blur horizontal, blur vertical).
    ///
    /// The futures borrow `self`; nothing is validated until they are polled.
    pub fn validate_shader_futures(
        &self,
    ) -> [impl std::future::Future<Output = Result<(), awsm_renderer_core::error::AwsmCoreError>> + '_;
           3] {
        let [moment_write, blur_h, blur_v] = &self.modules;
        [
            moment_write.validate_shader(),
            blur_h.validate_shader(),
            blur_v.validate_shader(),
        ]
    }

    /// Pairs each shader key with the pipeline layout it must be built against.
    ///
    /// `shader_keys` must be in module order (moment write, blur horizontal, blur
    /// vertical); the returned cache keys keep that order. Both blur shaders share the
    /// blur pipeline layout.
    pub fn pipeline_cache_keys(&self, shader_keys: [ShaderKey; 3]) -> Vec<ComputePipelineCacheKey> {
        let [moment_write_shader, blur_h_shader, blur_v_shader] = shader_keys;
        let moment_write_layout = self.moment_write_pipeline_layout_key;
        let blur_layout = self.blur_pipeline_layout_key;
        vec![
            ComputePipelineCacheKey::new(moment_write_shader, moment_write_layout),
            ComputePipelineCacheKey::new(blur_h_shader, blur_layout),
            ComputePipelineCacheKey::new(blur_v_shader, blur_layout),
        ]
    }
}
impl EvsmPass {
/// Builds a fully resolved EVSM pass.
///
/// Steps, in order: create layouts, shader modules and the params buffer via
/// [`Self::build_descriptors`]; await validation of all three shader modules together;
/// register the modules with `shaders`; resolve the three compute pipelines; assemble
/// the pass with [`Self::from_resolved`].
///
/// # Errors
/// Returns an error if descriptor creation fails, if any shader module fails validation
/// (wrapped as `AwsmShadowError::Core`), or if pipeline resolution fails.
pub async fn new(
    gpu: &AwsmRendererWebGpu,
    bind_group_layouts: &mut BindGroupLayouts,
    pipeline_layouts: &mut PipelineLayouts,
    pipelines: &mut Pipelines,
    shaders: &mut Shaders,
) -> Result<Self, AwsmShadowError> {
    let descs = Self::build_descriptors(gpu, bind_group_layouts, pipeline_layouts)?;

    // Wait for every validation to finish, then surface the first failure (module order).
    futures::future::join_all(descs.validate_shader_futures())
        .await
        .into_iter()
        .try_for_each(|outcome| outcome.map_err(AwsmShadowError::Core))?;

    // Register the modules in module order so the keys line up with the cache keys below.
    let shader_keys: [ShaderKey; 3] = descs
        .modules
        .clone()
        .map(|module| shaders.insert_uncached(module));

    let resolved = pipelines
        .compute
        .ensure_keys(
            gpu,
            shaders,
            pipeline_layouts,
            descs.pipeline_cache_keys(shader_keys),
        )
        .await?;

    Ok(Self::from_resolved(descs, resolved))
}
pub fn build_descriptors(
gpu: &AwsmRendererWebGpu,
bind_group_layouts: &mut BindGroupLayouts,
pipeline_layouts: &mut PipelineLayouts,
) -> Result<EvsmDescriptors, AwsmShadowError> {
let moment_write_layout_key = bind_group_layouts.get_key(
gpu,
BindGroupLayoutCacheKey::new(vec![
BindGroupLayoutCacheKeyEntry {
resource: BindGroupLayoutResource::Texture(
TextureBindingLayout::new()
.with_sample_type(TextureSampleType::Depth)
.with_view_dimension(TextureViewDimension::N2dArray),
),
visibility_vertex: false,
visibility_fragment: false,
visibility_compute: true,
},
BindGroupLayoutCacheKeyEntry {
resource: BindGroupLayoutResource::StorageTexture(
StorageTextureBindingLayout::new(TextureFormat::Rgba16float)
.with_access(StorageTextureAccess::WriteOnly)
.with_view_dimension(TextureViewDimension::N2d),
),
visibility_vertex: false,
visibility_fragment: false,
visibility_compute: true,
},
BindGroupLayoutCacheKeyEntry {
resource: BindGroupLayoutResource::Buffer(
BufferBindingLayout::new()
.with_binding_type(BufferBindingType::Uniform)
.with_dynamic_offset(true),
),
visibility_vertex: false,
visibility_fragment: false,
visibility_compute: true,
},
]),
)?;
let moment_write_pipeline_layout_key = pipeline_layouts.get_key(
gpu,
bind_group_layouts,
PipelineLayoutCacheKey::new(vec![moment_write_layout_key]),
)?;
let blur_layout_key = bind_group_layouts.get_key(
gpu,
BindGroupLayoutCacheKey::new(vec![
BindGroupLayoutCacheKeyEntry {
resource: BindGroupLayoutResource::Texture(
TextureBindingLayout::new()
.with_sample_type(TextureSampleType::Float)
.with_view_dimension(TextureViewDimension::N2d),
),
visibility_vertex: false,
visibility_fragment: false,
visibility_compute: true,
},
BindGroupLayoutCacheKeyEntry {
resource: BindGroupLayoutResource::StorageTexture(
StorageTextureBindingLayout::new(TextureFormat::Rgba16float)
.with_access(StorageTextureAccess::WriteOnly)
.with_view_dimension(TextureViewDimension::N2d),
),
visibility_vertex: false,
visibility_fragment: false,
visibility_compute: true,
},
BindGroupLayoutCacheKeyEntry {
resource: BindGroupLayoutResource::Buffer(
BufferBindingLayout::new()
.with_binding_type(BufferBindingType::Uniform)
.with_dynamic_offset(true),
),
visibility_vertex: false,
visibility_fragment: false,
visibility_compute: true,
},
]),
)?;
let blur_pipeline_layout_key = pipeline_layouts.get_key(
gpu,
bind_group_layouts,
PipelineLayoutCacheKey::new(vec![blur_layout_key]),
)?;
let inline_specs: [(&str, &str); 3] = [
("Shadow EVSM Moment Write", MOMENT_WRITE_WGSL),
("Shadow EVSM Blur H", blur_h_wgsl()),
("Shadow EVSM Blur V", blur_v_wgsl()),
];
let mods: Vec<web_sys::GpuShaderModule> = inline_specs
.iter()
.map(|(label, code)| {
let desc: web_sys::GpuShaderModuleDescriptor =
ShaderModuleDescriptor::new(code, Some(label)).into();
gpu.compile_shader(&desc)
})
.collect();
let modules: [web_sys::GpuShaderModule; 3] =
[mods[0].clone(), mods[1].clone(), mods[2].clone()];
let params_buffer_size = EVSM_PARAMS_STRIDE * MAX_EVSM_CASCADES_PER_FRAME;
let params_buffer = gpu.create_buffer(
&BufferDescriptor::new(
Some("Shadow EVSM Params"),
params_buffer_size,
BufferUsage::new().with_uniform().with_copy_dst(),
)
.into(),
)?;
Ok(EvsmDescriptors {
moment_write_layout_key,
blur_layout_key,
moment_write_pipeline_layout_key,
blur_pipeline_layout_key,
modules,
params_buffer,
params_bytes: vec![0u8; params_buffer_size],
})
}
pub fn from_resolved(descs: EvsmDescriptors, resolved: Vec<ComputePipelineKey>) -> Self {
let (moment_write, blur_h, blur_v) = match resolved.len() {
3 => (Some(resolved[0]), Some(resolved[1]), Some(resolved[2])),
0 => (None, None, None),
other => {
debug_assert!(
other == 0 || other == 3,
"EvsmPass::from_resolved expects 0 or 3 resolved keys, got {other}"
);
(None, None, None)
}
};
Self {
moment_write_layout_key: descs.moment_write_layout_key,
blur_layout_key: descs.blur_layout_key,
moment_write_pipeline_layout_key: descs.moment_write_pipeline_layout_key,
blur_pipeline_layout_key: descs.blur_pipeline_layout_key,
moment_write_pipeline_key: moment_write,
blur_h_pipeline_key: blur_h,
blur_v_pipeline_key: blur_v,
params_buffer: descs.params_buffer,
params_bytes: descs.params_bytes,
active_cascade_count: 0,
}
}
/// Byte offset of parameter slot `index` inside the params buffer, i.e.
/// `index * EVSM_PARAMS_STRIDE`, for use as a dynamic uniform-buffer offset.
pub fn params_dynamic_offset(index: u32) -> u32 {
    let stride_bytes = EVSM_PARAMS_STRIDE as u32;
    stride_bytes * index
}
/// Fills the first 48 bytes of parameter slot `index` in the CPU-side mirror.
///
/// The slot is written as twelve native-endian 32-bit words, in this order:
/// `src_offset.xy`, `src_size.xy`, `dst_offset.xy`, `dst_size.xy`, `exponent` (raw f32
/// bits), `blur_radius` clamped to `MAX_BLUR_RADIUS`, `cascade_layer`, and a zero pad
/// word. This matches the field order of the shaders' `Params` struct.
///
/// Only `params_bytes` is touched: nothing is sent to the GPU (see `upload_params`) and
/// `active_cascade_count` is left unchanged.
///
/// # Panics
/// Panics if the slot does not fit inside `params_bytes`.
#[allow(clippy::too_many_arguments)]
pub fn write_params_slot(
    &mut self,
    index: usize,
    src_offset: [u32; 2],
    src_size: [u32; 2],
    dst_offset: [u32; 2],
    dst_size: [u32; 2],
    exponent: f32,
    blur_radius: u32,
    cascade_layer: u32,
) {
    let clamped_radius = blur_radius.min(MAX_BLUR_RADIUS);
    let words: [u32; 12] = [
        src_offset[0],
        src_offset[1],
        src_size[0],
        src_size[1],
        dst_offset[0],
        dst_offset[1],
        dst_size[0],
        dst_size[1],
        // Same bytes as `exponent.to_ne_bytes()`.
        exponent.to_bits(),
        clamped_radius,
        cascade_layer,
        0, // `_pad`
    ];

    let start = index * EVSM_PARAMS_STRIDE;
    let slot = &mut self.params_bytes[start..start + 48];
    for (cell, word) in slot.chunks_exact_mut(4).zip(words.iter()) {
        cell.copy_from_slice(&word.to_ne_bytes());
    }
}
/// Copies the first `active_cascade_count` parameter slots from the CPU-side mirror
/// into `params_buffer`. Does nothing when `active_cascade_count` is 0.
///
/// # Errors
/// Returns an error if the buffer write fails.
///
/// # Panics
/// Panics if `active_cascade_count * EVSM_PARAMS_STRIDE` exceeds `params_bytes.len()`.
pub fn upload_params(&self, gpu: &AwsmRendererWebGpu) -> Result<(), AwsmShadowError> {
    let slot_count = self.active_cascade_count as usize;
    if slot_count > 0 {
        let live_bytes = &self.params_bytes[..slot_count * EVSM_PARAMS_STRIDE];
        gpu.write_buffer(&self.params_buffer, None, live_bytes, None, None)?;
    }
    Ok(())
}
}
/// WGSL source of the moment-write compute shader (entry point `cs_main`, 8x8x1 workgroup).
///
/// For each destination texel inside `dst_size`, the shader picks the nearest source
/// texel of array layer `cascade_layer` in the depth texture, remaps the depth to
/// `z = 2 * depth - 1`, and stores
/// `(exp(c*z), exp(c*z)^2, -exp(-c*z), exp(-c*z)^2)` with `c = params.exponent` at
/// `dst_offset + gid.xy` in the `rgba16float` storage texture.
///
/// The `Params` struct declares twelve 4-byte fields in the same order as the 48 bytes
/// written by `EvsmPass::write_params_slot`. Bindings 0/1/2 correspond to the three
/// entries of the moment-write bind group layout created in `EvsmPass::build_descriptors`.
const MOMENT_WRITE_WGSL: &str = r#"
struct Params {
src_offset: vec2<u32>,
src_size: vec2<u32>,
dst_offset: vec2<u32>,
dst_size: vec2<u32>,
exponent: f32,
blur_radius: u32,
cascade_layer: u32,
_pad: u32,
}
@group(0) @binding(0) var src_depth: texture_depth_2d_array;
@group(0) @binding(1) var dst_moments: texture_storage_2d<rgba16float, write>;
@group(0) @binding(2) var<uniform> params: Params;
@compute @workgroup_size(8, 8, 1)
fn cs_main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.dst_size.x || gid.y >= params.dst_size.y) {
return;
}
// Map dst texel center back to src texel. Same-size atlases give
// 1:1; if PCF is larger, pick nearest source texel — depth is a
// sharp signal, bilinear averages across discontinuities.
let dst_uv = (vec2<f32>(f32(gid.x), f32(gid.y)) + vec2<f32>(0.5, 0.5))
/ vec2<f32>(f32(params.dst_size.x), f32(params.dst_size.y));
let src_xy = vec2<u32>(
params.src_offset.x + u32(dst_uv.x * f32(params.src_size.x)),
params.src_offset.y + u32(dst_uv.y * f32(params.src_size.y)),
);
let depth = textureLoad(
src_depth,
vec2<i32>(i32(src_xy.x), i32(src_xy.y)),
i32(params.cascade_layer),
0,
);
// Remap [0,1] → [-1,1] so the exponent space is symmetric.
let z = 2.0 * depth - 1.0;
let pos_exp = exp(params.exponent * z);
let neg_exp = -exp(-params.exponent * z);
let moments = vec4<f32>(pos_exp, pos_exp * pos_exp, neg_exp, neg_exp * neg_exp);
let store_xy = vec2<i32>(
i32(params.dst_offset.x + gid.x),
i32(params.dst_offset.y + gid.y),
);
textureStore(dst_moments, store_xy, moments);
}
"#;
/// WGSL declarations shared by both blur shaders; `blur_h_wgsl` / `blur_v_wgsl` prepend
/// this text to the direction-specific body.
///
/// Contains:
/// - a `Params` struct that declares only the first six fields (40 bytes) of the 48-byte
///   slot written by `EvsmPass::write_params_slot`; `cascade_layer` and the pad word are
///   not declared here;
/// - the three bindings matching the blur bind group layout (float 2D source texture,
///   write-only `rgba16float` storage texture, uniform params);
/// - `GAUSSIAN_W`, a one-sided table of nine weights indexed by distance from the centre
///   texel (index 0 = centre);
/// - `kernel_sum(radius)`, the total weight of the symmetric kernel truncated at
///   `radius`, used by the bodies to renormalise.
///
/// NOTE(review): the full table sums to about 0.9365 (centre + 2 x sides), so the
/// renormalisation is required even at radius 8; a radius-8 blur reads 17 texels per
/// output texel.
const BLUR_COMMON_PREFIX: &str = r#"
struct Params {
src_offset: vec2<u32>,
src_size: vec2<u32>,
dst_offset: vec2<u32>,
dst_size: vec2<u32>,
exponent: f32,
blur_radius: u32,
}
@group(0) @binding(0) var src_tex: texture_2d<f32>;
@group(0) @binding(1) var dst_tex: texture_storage_2d<rgba16float, write>;
@group(0) @binding(2) var<uniform> params: Params;
// 9-tap Gaussian (centre + 8 sides), σ ≈ 8/3 covering ~99.7%. Shaders
// pick the first `radius+1` weights and re-normalise via `kernel_sum`.
const GAUSSIAN_W: array<f32, 9> = array<f32, 9>(
0.150946,
0.139148,
0.108878,
0.072448,
0.040951,
0.019696,
0.008049,
0.002800,
0.000829,
);
fn kernel_sum(radius: u32) -> f32 {
var s = GAUSSIAN_W[0];
for (var i = 1u; i <= radius; i = i + 1u) {
s = s + 2.0 * GAUSSIAN_W[i];
}
return s;
}
"#;
/// WGSL entry point of the horizontal blur pass (`cs_main`, 64x1x1 workgroup); only valid
/// when appended to `BLUR_COMMON_PREFIX`.
///
/// Each invocation inside `dst_size` accumulates the centre texel plus `radius` texels on
/// either side along X, weighted by `GAUSSIAN_W` and renormalised by `kernel_sum`, then
/// stores the result at the same coordinate in the destination texture. `radius` is
/// `params.blur_radius` clamped to 8. Sample X coordinates are clamped to
/// `[dst_offset.x, dst_offset.x + dst_size.x - 1]`, so the blur never reads outside the
/// tile. Both reads and writes are addressed with `dst_offset`; `src_offset` and
/// `src_size` are not used by this shader.
const BLUR_H_BODY: &str = r#"
@compute @workgroup_size(64, 1, 1)
fn cs_main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.dst_size.x || gid.y >= params.dst_size.y) {
return;
}
let centre_xy = vec2<i32>(
i32(params.dst_offset.x + gid.x),
i32(params.dst_offset.y + gid.y),
);
let radius = min(params.blur_radius, 8u);
let inv_sum = 1.0 / kernel_sum(radius);
var acc = textureLoad(src_tex, centre_xy, 0) * GAUSSIAN_W[0];
let lo = i32(params.dst_offset.x);
let hi = i32(params.dst_offset.x + params.dst_size.x) - 1;
for (var i = 1u; i <= radius; i = i + 1u) {
let w = GAUSSIAN_W[i];
let off_pos = clamp(centre_xy.x + i32(i), lo, hi);
let off_neg = clamp(centre_xy.x - i32(i), lo, hi);
acc = acc + textureLoad(src_tex, vec2<i32>(off_pos, centre_xy.y), 0) * w;
acc = acc + textureLoad(src_tex, vec2<i32>(off_neg, centre_xy.y), 0) * w;
}
textureStore(dst_tex, centre_xy, acc * inv_sum);
}
"#;
/// WGSL entry point of the vertical blur pass (`cs_main`, 1x64x1 workgroup); only valid
/// when appended to `BLUR_COMMON_PREFIX`.
///
/// Each invocation inside `dst_size` accumulates the centre texel plus `radius` texels on
/// either side along Y, weighted by `GAUSSIAN_W` and renormalised by `kernel_sum`, then
/// stores the result at the same coordinate in the destination texture. `radius` is
/// `params.blur_radius` clamped to 8. Sample Y coordinates are clamped to
/// `[dst_offset.y, dst_offset.y + dst_size.y - 1]`, so the blur never reads outside the
/// tile. Both reads and writes are addressed with `dst_offset`; `src_offset` and
/// `src_size` are not used by this shader.
const BLUR_V_BODY: &str = r#"
@compute @workgroup_size(1, 64, 1)
fn cs_main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.dst_size.x || gid.y >= params.dst_size.y) {
return;
}
let centre_xy = vec2<i32>(
i32(params.dst_offset.x + gid.x),
i32(params.dst_offset.y + gid.y),
);
let radius = min(params.blur_radius, 8u);
let inv_sum = 1.0 / kernel_sum(radius);
var acc = textureLoad(src_tex, centre_xy, 0) * GAUSSIAN_W[0];
let lo = i32(params.dst_offset.y);
let hi = i32(params.dst_offset.y + params.dst_size.y) - 1;
for (var i = 1u; i <= radius; i = i + 1u) {
let w = GAUSSIAN_W[i];
let off_pos = clamp(centre_xy.y + i32(i), lo, hi);
let off_neg = clamp(centre_xy.y - i32(i), lo, hi);
acc = acc + textureLoad(src_tex, vec2<i32>(centre_xy.x, off_pos), 0) * w;
acc = acc + textureLoad(src_tex, vec2<i32>(centre_xy.x, off_neg), 0) * w;
}
textureStore(dst_tex, centre_xy, acc * inv_sum);
}
"#;
// Lazily built, process-wide copy of the horizontal blur shader source
// (`BLUR_COMMON_PREFIX` followed by `BLUR_H_BODY`); initialised by `blur_h_wgsl`.
static BLUR_H_ONCE: OnceLock<String> = OnceLock::new();
// Lazily built, process-wide copy of the vertical blur shader source
// (`BLUR_COMMON_PREFIX` followed by `BLUR_V_BODY`); initialised by `blur_v_wgsl`.
static BLUR_V_ONCE: OnceLock<String> = OnceLock::new();
/// Returns the complete horizontal blur shader source: `BLUR_COMMON_PREFIX` immediately
/// followed by `BLUR_H_BODY`. The string is built on first call and cached in
/// `BLUR_H_ONCE`; later calls return the same allocation.
fn blur_h_wgsl() -> &'static str {
    BLUR_H_ONCE.get_or_init(|| [BLUR_COMMON_PREFIX, BLUR_H_BODY].concat())
}
/// Returns the complete vertical blur shader source: `BLUR_COMMON_PREFIX` immediately
/// followed by `BLUR_V_BODY`. The string is built on first call and cached in
/// `BLUR_V_ONCE`; later calls return the same allocation.
fn blur_v_wgsl() -> &'static str {
    BLUR_V_ONCE.get_or_init(|| [BLUR_COMMON_PREFIX, BLUR_V_BODY].concat())
}