use wgpu::util::DeviceExt;
/// Point in a frame at which a [`CustomRenderPass`] is run.
///
/// A pass reports its stage through `CustomRenderPass::stage`, and
/// `CustomPassManager::execute_stage` only renders passes whose stage matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RenderStage {
    /// The pre-render stage (ordering relative to the main scene is decided by the caller).
    PreRender,
    /// The post-processing stage; `PostProcessChain` always reports this stage.
    PostProcess,
}
/// Per-frame inputs handed to `CustomRenderPass::render`.
///
/// The GPU handles are borrowed for the lifetime `'a`.
pub struct RenderPassContext<'a> {
    /// Device used to create any per-frame GPU objects (encoders, bind groups, ...).
    pub device: &'a wgpu::Device,
    /// Queue that work for this frame is submitted to.
    pub queue: &'a wgpu::Queue,
    /// View of the texture the frame is rendered into.
    pub target: &'a wgpu::TextureView,
    /// Viewport width; `PostProcessChain` uses it as the width of its intermediate textures.
    pub viewport_width: u32,
    /// Viewport height; `PostProcessChain` uses it as the height of its intermediate textures.
    pub viewport_height: u32,
    /// Texture format for this frame (presumably the format of `target` — confirm with caller).
    pub texture_format: wgpu::TextureFormat,
    /// Display scale factor supplied by the caller (assumed to be the HiDPI ratio — confirm).
    pub scale_factor: f64,
}
/// A user-defined render pass that is invoked at a fixed [`RenderStage`].
///
/// On native targets implementors must be `Send`. The `wasm32` declaration of this
/// trait has the same methods but omits that bound.
#[cfg(not(target_arch = "wasm32"))]
pub trait CustomRenderPass: Send {
    /// Name of the pass; `CustomPassManager::remove` matches passes by this string.
    fn label(&self) -> &str;

    /// Stage during which `render` is called for this pass.
    fn stage(&self) -> RenderStage;

    /// Setup hook that receives the device, the queue and the target texture format.
    fn initialize(
        &mut self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        format: wgpu::TextureFormat,
    );

    /// Does the pass's work for one frame using the handles in `ctx`.
    fn render(&mut self, ctx: &RenderPassContext);

    /// Notified of a `width` × `height` size change. The default does nothing.
    fn resize(&mut self, _device: &wgpu::Device, _width: u32, _height: u32) {}

    /// Whether the pass should currently run; disabled passes are skipped. Defaults to `true`.
    fn enabled(&self) -> bool {
        true
    }
}
/// A user-defined render pass that is invoked at a fixed [`RenderStage`].
///
/// This is the `wasm32` declaration: same methods as the native trait, but without
/// the `Send` supertrait.
#[cfg(target_arch = "wasm32")]
pub trait CustomRenderPass {
    /// Name of the pass; `CustomPassManager::remove` matches passes by this string.
    fn label(&self) -> &str;

    /// Stage during which `render` is called for this pass.
    fn stage(&self) -> RenderStage;

    /// Setup hook that receives the device, the queue and the target texture format.
    fn initialize(
        &mut self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        format: wgpu::TextureFormat,
    );

    /// Does the pass's work for one frame using the handles in `ctx`.
    fn render(&mut self, ctx: &RenderPassContext);

    /// Notified of a `width` × `height` size change. The default does nothing.
    fn resize(&mut self, _device: &wgpu::Device, _width: u32, _height: u32) {}

    /// Whether the pass should currently run; disabled passes are skipped. Defaults to `true`.
    fn enabled(&self) -> bool {
        true
    }
}
/// A registered pass together with whether its `initialize` hook has already been called.
struct RegisteredPass {
    pass: Box<dyn CustomRenderPass>,
    initialized: bool,
}

/// Owns the registered [`CustomRenderPass`]es and forwards lifecycle calls to them.
pub(crate) struct CustomPassManager {
    passes: Vec<RegisteredPass>,
}

impl Default for CustomPassManager {
    fn default() -> Self {
        Self::new()
    }
}

impl CustomPassManager {
    /// Creates a manager with no registered passes.
    pub fn new() -> Self {
        Self { passes: Vec::new() }
    }

    /// Registers `pass`. It stays pending until the next [`Self::initialize_pending`] call.
    pub fn register(&mut self, pass: Box<dyn CustomRenderPass>) {
        self.passes.push(RegisteredPass {
            pass,
            initialized: false,
        });
    }

    /// Calls `initialize` on every pass that has not been initialized yet.
    ///
    /// Passes initialized by an earlier call are skipped. Calling this repeatedly
    /// therefore does not invoke their `initialize` hook again.
    pub fn initialize_pending(
        &mut self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        format: wgpu::TextureFormat,
    ) {
        for entry in self.passes.iter_mut().filter(|entry| !entry.initialized) {
            entry.pass.initialize(device, queue, format);
            entry.initialized = true;
        }
    }

    /// Renders, in registration order, every enabled pass whose stage equals `stage`.
    pub fn execute_stage(&mut self, stage: RenderStage, ctx: &RenderPassContext) {
        for RegisteredPass { pass, .. } in &mut self.passes {
            if pass.stage() == stage && pass.enabled() {
                pass.render(ctx);
            }
        }
    }

    /// Forwards a size change to every registered pass.
    pub fn resize(&mut self, device: &wgpu::Device, width: u32, height: u32) {
        for RegisteredPass { pass, .. } in &mut self.passes {
            pass.resize(device, width, height);
        }
    }

    /// Returns `true` if at least one enabled pass is registered for `stage`.
    pub fn has_passes(&self, stage: RenderStage) -> bool {
        self.passes
            .iter()
            .any(|entry| entry.pass.stage() == stage && entry.pass.enabled())
    }

    /// Removes every pass whose label equals `label`; returns whether any was removed.
    pub fn remove(&mut self, label: &str) -> bool {
        let len_before = self.passes.len();
        self.passes.retain(|entry| entry.pass.label() != label);
        self.passes.len() < len_before
    }
}
/// One resource queued in a `BindGroupBuilder`.
///
/// Each variant carries what is needed to derive both its layout entry and its
/// bind-group entry. The binding index is the entry's position in the builder.
enum BindGroupEntry<'a> {
    /// Uniform buffer binding (no dynamic offset, no minimum size).
    UniformBuffer(wgpu::BindingResource<'a>),
    /// Storage buffer binding; `read_only` selects the access mode in the layout.
    StorageBuffer {
        resource: wgpu::BindingResource<'a>,
        read_only: bool,
    },
    /// Single-sample 2D texture, declared in the layout as filterable float.
    Texture(&'a wgpu::TextureView),
    /// 2D storage texture with an explicit format and access mode.
    StorageTexture {
        view: &'a wgpu::TextureView,
        format: wgpu::TextureFormat,
        access: wgpu::StorageTextureAccess,
    },
    /// Sampler declared with the filtering binding type.
    Sampler(&'a wgpu::Sampler),
    /// Sampler declared with the comparison binding type.
    ComparisonSampler(&'a wgpu::Sampler),
}
/// Collects resources and produces a matching `wgpu::BindGroupLayout` and
/// `wgpu::BindGroup` in one `build` call.
///
/// Resources receive consecutive binding indices starting at 0, in the order they are added.
pub struct BindGroupBuilder<'a> {
    /// Debug label used for both the layout and the bind group.
    label: &'a str,
    /// Queued resources, in binding order.
    entries: Vec<BindGroupEntry<'a>>,
    /// Shader stages that can see every binding (one value shared by all entries).
    visibility: wgpu::ShaderStages,
}
impl<'a> BindGroupBuilder<'a> {
pub fn new(label: &'a str) -> Self {
Self {
label,
entries: Vec::new(),
visibility: wgpu::ShaderStages::VERTEX | wgpu::ShaderStages::FRAGMENT,
}
}
pub fn with_visibility(mut self, visibility: wgpu::ShaderStages) -> Self {
self.visibility = visibility;
self
}
pub fn add_uniform_buffer(&mut self, resource: wgpu::BindingResource<'a>) -> &mut Self {
self.entries.push(BindGroupEntry::UniformBuffer(resource));
self
}
pub fn add_storage_buffer(
&mut self,
resource: wgpu::BindingResource<'a>,
read_only: bool,
) -> &mut Self {
self.entries.push(BindGroupEntry::StorageBuffer {
resource,
read_only,
});
self
}
pub fn add_texture(&mut self, view: &'a wgpu::TextureView) -> &mut Self {
self.entries.push(BindGroupEntry::Texture(view));
self
}
pub fn add_storage_texture(
&mut self,
view: &'a wgpu::TextureView,
format: wgpu::TextureFormat,
access: wgpu::StorageTextureAccess,
) -> &mut Self {
self.entries.push(BindGroupEntry::StorageTexture {
view,
format,
access,
});
self
}
pub fn add_sampler(&mut self, sampler: &'a wgpu::Sampler) -> &mut Self {
self.entries.push(BindGroupEntry::Sampler(sampler));
self
}
pub fn add_comparison_sampler(&mut self, sampler: &'a wgpu::Sampler) -> &mut Self {
self.entries
.push(BindGroupEntry::ComparisonSampler(sampler));
self
}
pub fn build(self, device: &wgpu::Device) -> (wgpu::BindGroupLayout, wgpu::BindGroup) {
let vis = self.visibility;
let layout_entries: Vec<wgpu::BindGroupLayoutEntry> = self
.entries
.iter()
.enumerate()
.map(|(i, entry)| {
let ty = match entry {
BindGroupEntry::UniformBuffer(_) => wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
BindGroupEntry::StorageBuffer { read_only, .. } => wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage {
read_only: *read_only,
},
has_dynamic_offset: false,
min_binding_size: None,
},
BindGroupEntry::Texture(_) => wgpu::BindingType::Texture {
multisampled: false,
view_dimension: wgpu::TextureViewDimension::D2,
sample_type: wgpu::TextureSampleType::Float { filterable: true },
},
BindGroupEntry::StorageTexture { format, access, .. } => {
wgpu::BindingType::StorageTexture {
access: *access,
format: *format,
view_dimension: wgpu::TextureViewDimension::D2,
}
}
BindGroupEntry::Sampler(_) => {
wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering)
}
BindGroupEntry::ComparisonSampler(_) => {
wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Comparison)
}
};
wgpu::BindGroupLayoutEntry {
binding: i as u32,
visibility: vis,
ty,
count: None,
}
})
.collect();
let layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some(self.label),
entries: &layout_entries,
});
let bind_entries: Vec<wgpu::BindGroupEntry> = self
.entries
.iter()
.enumerate()
.map(|(i, entry)| wgpu::BindGroupEntry {
binding: i as u32,
resource: match entry {
BindGroupEntry::UniformBuffer(r)
| BindGroupEntry::StorageBuffer { resource: r, .. } => r.clone(),
BindGroupEntry::Texture(v) | BindGroupEntry::StorageTexture { view: v, .. } => {
wgpu::BindingResource::TextureView(v)
}
BindGroupEntry::Sampler(s) | BindGroupEntry::ComparisonSampler(s) => {
wgpu::BindingResource::Sampler(s)
}
},
})
.collect();
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some(self.label),
layout: &layout,
entries: &bind_entries,
});
(layout, bind_group)
}
}
/// Everything needed for one compute dispatch.
///
/// Holds a pipeline, the bind group bound at group 0, and the workgroup counts.
pub struct ComputeDispatch<'a> {
    /// Compute pipeline to bind.
    pub pipeline: &'a wgpu::ComputePipeline,
    /// Bind group set at index 0 by `execute`.
    pub bind_group: &'a wgpu::BindGroup,
    /// Number of workgroups along x, y and z.
    pub workgroups: (u32, u32, u32),
    /// Debug label for the command encoder and the compute pass.
    pub label: &'a str,
}
impl<'a> ComputeDispatch<'a> {
    /// Records one compute pass for this dispatch and submits it to `queue` right away.
    ///
    /// The bind group is set at index 0 without dynamic offsets, and no timestamps are written.
    pub fn execute(&self, device: &wgpu::Device, queue: &wgpu::Queue) {
        let Self {
            pipeline,
            bind_group,
            workgroups: (x, y, z),
            label,
        } = *self;

        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some(label),
        });
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some(label),
            timestamp_writes: None,
        });
        pass.set_pipeline(pipeline);
        pass.set_bind_group(0, bind_group, &[]);
        pass.dispatch_workgroups(x, y, z);
        // The pass borrows the encoder; end it before finishing the encoder.
        drop(pass);

        queue.submit([encoder.finish()]);
    }
}
/// Compiles `wgsl_source` and builds a compute pipeline that runs `entry_point`.
///
/// The pipeline layout has a single bind group (`bind_group_layout` at index 0) and no
/// push constants. `label` is applied to the shader, the layout and the pipeline.
pub fn create_compute_pipeline(
    device: &wgpu::Device,
    label: &str,
    wgsl_source: &str,
    entry_point: &str,
    bind_group_layout: &wgpu::BindGroupLayout,
) -> wgpu::ComputePipeline {
    let label = Some(label);
    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label,
        source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
    });
    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label,
        bind_group_layouts: &[bind_group_layout],
        push_constant_ranges: &[],
    });
    let descriptor = wgpu::ComputePipelineDescriptor {
        label,
        layout: Some(&layout),
        module: &module,
        entry_point: Some(entry_point),
        compilation_options: Default::default(),
        cache: None,
    };
    device.create_compute_pipeline(&descriptor)
}
/// A single full-frame effect run by a `PostProcessChain`.
///
/// The `Send` requirement mirrors the cfg split used by `CustomRenderPass`. Effects must
/// be `Send` on native targets, where the owning chain is itself a `Send` render pass.
/// On `wasm32` neither trait requires it.
#[cfg(not(target_arch = "wasm32"))]
pub trait PostProcessEffect: Send {
    /// Name of the effect; `PostProcessChain::remove_effect` matches effects by this string.
    fn label(&self) -> &str;

    /// Setup hook that receives the device, the queue and the target texture format.
    fn initialize(
        &mut self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        format: wgpu::TextureFormat,
    );

    /// Reads `input` and writes the effect's result into `output`.
    ///
    /// When driven by `PostProcessChain`, `input` is one of its intermediate textures, and
    /// `output` is either the other intermediate texture or, for the last enabled effect,
    /// the frame's render target. `width`/`height` are the viewport size.
    fn apply(
        &mut self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        input: &wgpu::TextureView,
        output: &wgpu::TextureView,
        width: u32,
        height: u32,
    );

    /// Notified of a `width` × `height` size change. The default does nothing.
    fn resize(&mut self, _device: &wgpu::Device, _width: u32, _height: u32) {}

    /// Whether the effect should run; disabled effects are skipped. Defaults to `true`.
    fn enabled(&self) -> bool {
        true
    }
}

/// A single full-frame effect run by a `PostProcessChain`.
///
/// `wasm32` declaration: same methods as the native trait, but without the `Send`
/// supertrait. This matches `CustomRenderPass`, which also drops `Send` on this target.
#[cfg(target_arch = "wasm32")]
pub trait PostProcessEffect {
    /// Name of the effect; `PostProcessChain::remove_effect` matches effects by this string.
    fn label(&self) -> &str;

    /// Setup hook that receives the device, the queue and the target texture format.
    fn initialize(
        &mut self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        format: wgpu::TextureFormat,
    );

    /// Reads `input` and writes the effect's result into `output`.
    ///
    /// When driven by `PostProcessChain`, `input` is one of its intermediate textures, and
    /// `output` is either the other intermediate texture or, for the last enabled effect,
    /// the frame's render target. `width`/`height` are the viewport size.
    fn apply(
        &mut self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        input: &wgpu::TextureView,
        output: &wgpu::TextureView,
        width: u32,
        height: u32,
    );

    /// Notified of a `width` × `height` size change. The default does nothing.
    fn resize(&mut self, _device: &wgpu::Device, _width: u32, _height: u32) {}

    /// Whether the effect should run; disabled effects are skipped. Defaults to `true`.
    fn enabled(&self) -> bool {
        true
    }
}
/// A post-processing render pass that runs a sequence of [`PostProcessEffect`]s.
///
/// The processing order is:
/// 1. The current target is copied into the `ping` texture.
/// 2. The enabled effects alternate between `ping` and `pong`.
/// 3. The last enabled effect writes back into the target.
pub struct PostProcessChain {
    /// Label reported by `CustomRenderPass::label`.
    label: String,
    /// Effects in the order they are applied.
    effects: Vec<Box<dyn PostProcessEffect>>,
    /// First intermediate texture and its view; allocated lazily at viewport size.
    ping: Option<(wgpu::Texture, wgpu::TextureView)>,
    /// Second intermediate texture and its view; allocated together with `ping`.
    pong: Option<(wgpu::Texture, wgpu::TextureView)>,
    /// Pipeline that copies the target into `ping`.
    copy_pipeline: Option<wgpu::RenderPipeline>,
    /// Layout for the copy pipeline: texture at binding 0, sampler at binding 1.
    copy_bind_group_layout: Option<wgpu::BindGroupLayout>,
    /// Linear-filtering sampler used by the copy pipeline.
    copy_sampler: Option<wgpu::Sampler>,
    /// Format of the intermediate textures, as set by `initialize`.
    texture_format: wgpu::TextureFormat,
    /// Viewport size the intermediate textures were last allocated for.
    size: (u32, u32),
}
impl PostProcessChain {
pub fn new(label: impl Into<String>) -> Self {
Self {
label: label.into(),
effects: Vec::new(),
ping: None,
pong: None,
copy_pipeline: None,
copy_bind_group_layout: None,
copy_sampler: None,
texture_format: wgpu::TextureFormat::Bgra8Unorm,
size: (0, 0),
}
}
pub fn add_effect(&mut self, effect: Box<dyn PostProcessEffect>) {
self.effects.push(effect);
}
pub fn remove_effect(&mut self, label: &str) -> bool {
let before = self.effects.len();
self.effects.retain(|e| e.label() != label);
self.effects.len() < before
}
fn create_texture(
device: &wgpu::Device,
width: u32,
height: u32,
format: wgpu::TextureFormat,
label: &str,
) -> (wgpu::Texture, wgpu::TextureView) {
let tex = device.create_texture(&wgpu::TextureDescriptor {
label: Some(label),
size: wgpu::Extent3d {
width: width.max(1),
height: height.max(1),
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
format,
usage: wgpu::TextureUsages::RENDER_ATTACHMENT
| wgpu::TextureUsages::TEXTURE_BINDING
| wgpu::TextureUsages::COPY_SRC
| wgpu::TextureUsages::COPY_DST,
view_formats: &[],
});
let view = tex.create_view(&wgpu::TextureViewDescriptor::default());
(tex, view)
}
fn ensure_textures(&mut self, device: &wgpu::Device, width: u32, height: u32) {
if self.size == (width, height) && self.ping.is_some() {
return;
}
self.size = (width, height);
self.ping = Some(Self::create_texture(
device,
width,
height,
self.texture_format,
"postprocess_ping",
));
self.pong = Some(Self::create_texture(
device,
width,
height,
self.texture_format,
"postprocess_pong",
));
}
fn ensure_copy_pipeline(&mut self, device: &wgpu::Device, format: wgpu::TextureFormat) {
if self.copy_pipeline.is_some() {
return;
}
let shader_src = r#"
@group(0) @binding(0) var src_texture: texture_2d<f32>;
@group(0) @binding(1) var src_sampler: sampler;
struct VertexOutput {
@builtin(position) position: vec4<f32>,
@location(0) uv: vec2<f32>,
}
@vertex
fn vs_main(@builtin(vertex_index) vi: u32) -> VertexOutput {
var positions = array<vec2<f32>, 6>(
vec2(-1.0, -1.0), vec2(1.0, -1.0), vec2(-1.0, 1.0),
vec2(-1.0, 1.0), vec2(1.0, -1.0), vec2(1.0, 1.0),
);
var out: VertexOutput;
out.position = vec4(positions[vi], 0.0, 1.0);
out.uv = positions[vi] * vec2(0.5, -0.5) + 0.5;
return out;
}
@fragment
fn fs_main(input: VertexOutput) -> @location(0) vec4<f32> {
return textureSample(src_texture, src_sampler, input.uv);
}
"#;
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("PostProcess Copy Shader"),
source: wgpu::ShaderSource::Wgsl(shader_src.into()),
});
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("PostProcess Copy Layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::FRAGMENT,
ty: wgpu::BindingType::Texture {
multisampled: false,
view_dimension: wgpu::TextureViewDimension::D2,
sample_type: wgpu::TextureSampleType::Float { filterable: true },
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::FRAGMENT,
ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("PostProcess Copy Pipeline Layout"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
label: Some("PostProcess Copy Pipeline"),
layout: Some(&pipeline_layout),
vertex: wgpu::VertexState {
module: &shader,
entry_point: Some("vs_main"),
buffers: &[],
compilation_options: Default::default(),
},
fragment: Some(wgpu::FragmentState {
module: &shader,
entry_point: Some("fs_main"),
targets: &[Some(wgpu::ColorTargetState {
format,
blend: None,
write_mask: wgpu::ColorWrites::ALL,
})],
compilation_options: Default::default(),
}),
primitive: wgpu::PrimitiveState::default(),
depth_stencil: None,
multisample: wgpu::MultisampleState::default(),
multiview: None,
cache: None,
});
let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
label: Some("PostProcess Copy Sampler"),
mag_filter: wgpu::FilterMode::Linear,
min_filter: wgpu::FilterMode::Linear,
..Default::default()
});
self.copy_pipeline = Some(pipeline);
self.copy_bind_group_layout = Some(bind_group_layout);
self.copy_sampler = Some(sampler);
}
fn blit(
&self,
device: &wgpu::Device,
queue: &wgpu::Queue,
src: &wgpu::TextureView,
dst: &wgpu::TextureView,
) {
let layout = self.copy_bind_group_layout.as_ref().unwrap();
let sampler = self.copy_sampler.as_ref().unwrap();
let pipeline = self.copy_pipeline.as_ref().unwrap();
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("PostProcess Blit"),
layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: wgpu::BindingResource::TextureView(src),
},
wgpu::BindGroupEntry {
binding: 1,
resource: wgpu::BindingResource::Sampler(sampler),
},
],
});
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("PostProcess Blit"),
});
{
let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
label: Some("PostProcess Blit Pass"),
color_attachments: &[Some(wgpu::RenderPassColorAttachment {
view: dst,
resolve_target: None,
depth_slice: None,
ops: wgpu::Operations {
load: wgpu::LoadOp::Clear(wgpu::Color::TRANSPARENT),
store: wgpu::StoreOp::Store,
},
})],
depth_stencil_attachment: None,
..Default::default()
});
pass.set_pipeline(pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.draw(0..6, 0..1);
}
queue.submit(std::iter::once(encoder.finish()));
}
}
impl CustomRenderPass for PostProcessChain {
    fn label(&self) -> &str {
        &self.label
    }

    /// Post-processing chains always run in [`RenderStage::PostProcess`].
    fn stage(&self) -> RenderStage {
        RenderStage::PostProcess
    }

    /// Prepares the copy pipeline for `format` and initializes every effect.
    ///
    /// The cached copy pipeline and intermediate textures are tied to a texture format.
    /// If `format` differs from the chain's current format, they are discarded so they
    /// get rebuilt for the new format.
    fn initialize(
        &mut self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        format: wgpu::TextureFormat,
    ) {
        if self.texture_format != format {
            self.copy_pipeline = None;
            self.ping = None;
            self.pong = None;
            self.size = (0, 0);
        }
        self.texture_format = format;
        self.ensure_copy_pipeline(device, format);
        for effect in &mut self.effects {
            effect.initialize(device, queue, format);
        }
    }

    /// Applies all enabled effects to `ctx.target`.
    ///
    /// The processing order is:
    /// 1. The target is copied into `ping`.
    /// 2. Effects alternate between `ping` and `pong`.
    /// 3. The last enabled effect writes straight into `ctx.target`.
    ///
    /// Does nothing if no effect is enabled.
    fn render(&mut self, ctx: &RenderPassContext) {
        let active: Vec<usize> = self
            .effects
            .iter()
            .enumerate()
            .filter(|(_, e)| e.enabled())
            .map(|(i, _)| i)
            .collect();
        let Some(last_step) = active.len().checked_sub(1) else {
            return;
        };

        self.ensure_textures(ctx.device, ctx.viewport_width, ctx.viewport_height);
        let (_, ping_view) = self
            .ping
            .as_ref()
            .expect("ensure_textures always allocates the ping texture");
        let (_, pong_view) = self
            .pong
            .as_ref()
            .expect("ensure_textures always allocates the pong texture");
        self.blit(ctx.device, ctx.queue, ctx.target, ping_view);

        // The views borrow only the `ping`/`pong` fields. The borrow checker therefore
        // allows mutably borrowing the disjoint `effects` field inside the loop.
        let views = [ping_view, pong_view];
        for (step, &idx) in active.iter().enumerate() {
            let input_view = views[step % 2];
            let output_view = if step == last_step {
                ctx.target
            } else {
                views[(step + 1) % 2]
            };
            self.effects[idx].apply(
                ctx.device,
                ctx.queue,
                input_view,
                output_view,
                ctx.viewport_width,
                ctx.viewport_height,
            );
        }
    }

    /// Drops the intermediate textures (reallocated on the next `render`) and forwards the
    /// new size to every effect.
    fn resize(&mut self, device: &wgpu::Device, width: u32, height: u32) {
        self.ping = None;
        self.pong = None;
        self.size = (0, 0);
        for effect in &mut self.effects {
            effect.resize(device, width, height);
        }
    }

    /// The chain is enabled while at least one of its effects is enabled.
    fn enabled(&self) -> bool {
        self.effects.iter().any(|e| e.enabled())
    }
}
/// Creates a GPU buffer whose initial contents are `data`.
///
/// Thin wrapper over `DeviceExt::create_buffer_init`; `label` and `usage` are passed
/// through unchanged.
pub fn create_buffer(
    device: &wgpu::Device,
    label: &str,
    data: &[u8],
    usage: wgpu::BufferUsages,
) -> wgpu::Buffer {
    let descriptor = wgpu::util::BufferInitDescriptor {
        label: Some(label),
        usage,
        contents: data,
    };
    device.create_buffer_init(&descriptor)
}
/// Builds a render pipeline for full-screen passes from a WGSL module.
///
/// The shader must provide `vs_main` (which takes no vertex buffers) and `fs_main`. It
/// must also use `bind_group_layout` as bind group 0. The single color target has
/// format `format` and standard alpha blending.
pub fn create_fullscreen_pipeline(
    device: &wgpu::Device,
    label: &str,
    wgsl_source: &str,
    format: wgpu::TextureFormat,
    bind_group_layout: &wgpu::BindGroupLayout,
) -> wgpu::RenderPipeline {
    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some(label),
        source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
    });
    let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some(label),
        bind_group_layouts: &[bind_group_layout],
        push_constant_ranges: &[],
    });

    let color_targets = [Some(wgpu::ColorTargetState {
        format,
        blend: Some(wgpu::BlendState::ALPHA_BLENDING),
        write_mask: wgpu::ColorWrites::ALL,
    })];
    let vertex = wgpu::VertexState {
        module: &shader,
        entry_point: Some("vs_main"),
        buffers: &[],
        compilation_options: Default::default(),
    };
    let fragment = wgpu::FragmentState {
        module: &shader,
        entry_point: Some("fs_main"),
        targets: &color_targets,
        compilation_options: Default::default(),
    };

    device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
        label: Some(label),
        layout: Some(&layout),
        vertex,
        fragment: Some(fragment),
        primitive: wgpu::PrimitiveState::default(),
        depth_stencil: None,
        multisample: wgpu::MultisampleState::default(),
        multiview: None,
        cache: None,
    })
}