use crate::Core;
use std::collections::HashMap;
use wgpu;
/// Owns the GPU resources for multi-pass compute rendering: one pair of
/// storage textures per named buffer, a single output texture, and the bind
/// group layouts used to write to and read from them.
pub struct MultiPassManager {
/// Texture pair per buffer name (created with labels `{name}_0` / `{name}_1`).
buffers: HashMap<String, (wgpu::Texture, wgpu::Texture)>,
/// Storage bind group pair per buffer name, one bind group per texture of the pair.
bind_groups: HashMap<String, (wgpu::BindGroup, wgpu::BindGroup)>,
/// Per-buffer side flag: `false` selects element 0 of the pair for writing
/// and element 1 for reading; `true` selects the reverse.
write_side: HashMap<String, bool>,
/// Output texture, sized `width` x `height`.
output_texture: wgpu::Texture,
/// Storage bind group that targets `output_texture`.
output_bind_group: wgpu::BindGroup,
/// Layout with a single write-only storage texture at binding 0.
storage_layout: wgpu::BindGroupLayout,
/// Layout holding `max_input_deps` (texture, sampler) binding pairs.
input_layout: wgpu::BindGroupLayout,
/// Width taken from `core.size` at construction, or from the last `resize`.
width: u32,
/// Height taken from `core.size` at construction, or from the last `resize`.
height: u32,
/// Resolved texture size per pass name; names without an entry fall back
/// to (`width`, `height`).
buffer_dimensions: HashMap<String, (u32, u32)>,
/// Optional fixed `[width, height]` per pass name.
buffer_resolution: HashMap<String, Option<[u32; 2]>>,
/// Optional screen-size scale factor per pass name; only consulted when no
/// fixed resolution is set.
buffer_scale: HashMap<String, Option<f32>>,
/// Format used for every texture this manager creates.
texture_format: wgpu::TextureFormat,
/// Number of (texture, sampler) slots in `input_layout`.
max_input_deps: usize,
}
impl MultiPassManager {
/// Builds a manager with one texture pair per entry of `buffer_names` plus an
/// output texture sized to `core.size`.
///
/// * `core` - supplies the device and the current size (`core.size`).
/// * `buffer_names` - names of the buffers to allocate.
/// * `texture_format` - format of every texture and of the storage layout.
/// * `_storage_layout` - never read; a write-only storage layout is created
///   here instead.
/// * `max_input_deps` - number of (texture, sampler) slots in the input layout.
/// * `passes` - per-pass sizing (`resolution` / `resolution_scale`), keyed by
///   pass name. Buffers whose name matches no pass use the full size.
pub fn new(
    core: &Core,
    buffer_names: &[String],
    texture_format: wgpu::TextureFormat,
    _storage_layout: wgpu::BindGroupLayout,
    max_input_deps: usize,
    passes: &[crate::compute::PassDescription],
) -> Self {
    let device = &core.device;
    let (width, height) = (core.size.width, core.size.height);

    // One write-only storage texture at binding 0, visible to compute shaders.
    let storage_entry = wgpu::BindGroupLayoutEntry {
        binding: 0,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::StorageTexture {
            access: wgpu::StorageTextureAccess::WriteOnly,
            format: texture_format,
            view_dimension: wgpu::TextureViewDimension::D2,
        },
        count: None,
    };
    let storage_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some("Multi-Pass Storage Layout"),
        entries: &[storage_entry],
    });
    let input_layout = Self::create_input_layout(device, max_input_deps);

    // Per-pass sizing tables; a later pass with the same name overrides an earlier one.
    let buffer_resolution: HashMap<String, Option<[u32; 2]>> = passes
        .iter()
        .map(|pass| (pass.name.clone(), pass.resolution))
        .collect();
    let buffer_scale: HashMap<String, Option<f32>> = passes
        .iter()
        .map(|pass| (pass.name.clone(), pass.resolution_scale))
        .collect();
    let buffer_dimensions: HashMap<String, (u32, u32)> = passes
        .iter()
        .map(|pass| {
            let dims = Self::compute_buffer_dims(
                width,
                height,
                pass.resolution,
                pass.resolution_scale,
            );
            (pass.name.clone(), dims)
        })
        .collect();

    let mut buffers = HashMap::new();
    let mut bind_groups = HashMap::new();
    let mut write_side = HashMap::new();
    for name in buffer_names {
        let (buf_w, buf_h) = match buffer_dimensions.get(name) {
            Some(&dims) => dims,
            None => (width, height),
        };

        let tex_a = Self::create_storage_texture(
            device,
            buf_w,
            buf_h,
            texture_format,
            &format!("{name}_0"),
        );
        let tex_b = Self::create_storage_texture(
            device,
            buf_w,
            buf_h,
            texture_format,
            &format!("{name}_1"),
        );
        let group_a = Self::create_storage_bind_group(
            device,
            &storage_layout,
            &tex_a,
            &format!("{name}_0_bind"),
        );
        let group_b = Self::create_storage_bind_group(
            device,
            &storage_layout,
            &tex_b,
            &format!("{name}_1_bind"),
        );

        buffers.insert(name.clone(), (tex_a, tex_b));
        bind_groups.insert(name.clone(), (group_a, group_b));
        // Every buffer starts with element 0 of its pair as the write target.
        write_side.insert(name.clone(), false);
    }

    let output_texture = Self::create_storage_texture(
        device,
        width,
        height,
        texture_format,
        "multipass_output",
    );
    let output_bind_group = Self::create_storage_bind_group(
        device,
        &storage_layout,
        &output_texture,
        "output_bind",
    );

    Self {
        buffers,
        bind_groups,
        write_side,
        output_texture,
        output_bind_group,
        storage_layout,
        input_layout,
        width,
        height,
        buffer_dimensions,
        buffer_resolution,
        buffer_scale,
        texture_format,
        max_input_deps,
    }
}
/// Resolves the texture size for a buffer.
///
/// Precedence:
/// 1. `resolution` - an explicit `[width, height]`, used as-is.
/// 2. `scale` - the screen size multiplied by the factor and rounded.
/// 3. Neither - the screen size itself.
///
/// Every branch clamps each dimension to at least 1, so a zero-sized screen
/// (or a zero/negative/NaN scale) never yields a zero texture extent.
fn compute_buffer_dims(
    screen_w: u32,
    screen_h: u32,
    resolution: Option<[u32; 2]>,
    scale: Option<f32>,
) -> (u32, u32) {
    match (resolution, scale) {
        (Some([w, h]), _) => (w.max(1), h.max(1)),
        (None, Some(s)) => {
            // `f32::max` ignores NaN and the float-to-int cast saturates,
            // so the result is always >= 1.
            let scaled = |extent: u32| (extent as f32 * s).round().max(1.0) as u32;
            (scaled(screen_w), scaled(screen_h))
        }
        // Clamp here too, matching the other branches: the screen size can be 0x0.
        (None, None) => (screen_w.max(1), screen_h.max(1)),
    }
}
/// Creates a single-mip, single-sample 2D texture of `width` x `height` in
/// `format`, tagged with `label`.
///
/// The texture is usable as a storage binding, as a texture binding, and as
/// both source and destination of copies.
fn create_storage_texture(
    device: &wgpu::Device,
    width: u32,
    height: u32,
    format: wgpu::TextureFormat,
    label: &str,
) -> wgpu::Texture {
    let size = wgpu::Extent3d {
        width,
        height,
        depth_or_array_layers: 1,
    };
    let usage = wgpu::TextureUsages::STORAGE_BINDING
        | wgpu::TextureUsages::TEXTURE_BINDING
        | wgpu::TextureUsages::COPY_SRC
        | wgpu::TextureUsages::COPY_DST;
    let descriptor = wgpu::TextureDescriptor {
        label: Some(label),
        size,
        mip_level_count: 1,
        sample_count: 1,
        dimension: wgpu::TextureDimension::D2,
        format,
        usage,
        view_formats: &[],
    };
    device.create_texture(&descriptor)
}
/// Creates a bind group for `layout` that exposes a default view of `texture`
/// at binding 0, tagged with `label`.
fn create_storage_bind_group(
    device: &wgpu::Device,
    layout: &wgpu::BindGroupLayout,
    texture: &wgpu::Texture,
    label: &str,
) -> wgpu::BindGroup {
    let view = texture.create_view(&wgpu::TextureViewDescriptor::default());
    let entry = wgpu::BindGroupEntry {
        binding: 0,
        resource: wgpu::BindingResource::TextureView(&view),
    };
    let descriptor = wgpu::BindGroupDescriptor {
        label: Some(label),
        layout,
        entries: &[entry],
    };
    device.create_bind_group(&descriptor)
}
/// Creates the layout used to feed earlier results into a compute pass.
///
/// For each slot `i` in `0..max_input_deps` the layout holds a filterable
/// float 2D texture at binding `2 * i` and a filtering sampler at binding
/// `2 * i + 1`, both visible to the compute stage.
fn create_input_layout(device: &wgpu::Device, max_input_deps: usize) -> wgpu::BindGroupLayout {
    let mut entries = Vec::with_capacity(max_input_deps * 2);
    entries.extend((0..max_input_deps).flat_map(|slot| {
        let texture_entry = wgpu::BindGroupLayoutEntry {
            binding: (slot * 2) as u32,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Texture {
                multisampled: false,
                sample_type: wgpu::TextureSampleType::Float { filterable: true },
                view_dimension: wgpu::TextureViewDimension::D2,
            },
            count: None,
        };
        let sampler_entry = wgpu::BindGroupLayoutEntry {
            binding: (slot * 2 + 1) as u32,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
            count: None,
        };
        [texture_entry, sampler_entry]
    }));

    let descriptor = wgpu::BindGroupLayoutDescriptor {
        label: Some("Multi-Pass Input Layout"),
        entries: &entries,
    };
    device.create_bind_group_layout(&descriptor)
}
/// Returns the storage bind group that the next write to `buffer_name`
/// should target: element 0 of the pair while the side flag is `false`
/// (or missing), element 1 once it is `true`.
///
/// # Panics
/// Panics with "Buffer not found" if `buffer_name` has no bind groups.
pub fn get_write_bind_group(&self, buffer_name: &str) -> &wgpu::BindGroup {
    let (group_0, group_1) = self.bind_groups.get(buffer_name).expect("Buffer not found");
    match self.write_side.get(buffer_name) {
        Some(true) => group_1,
        _ => group_0,
    }
}
/// Returns the texture that the next write to `buffer_name` should target:
/// element 0 of the pair while the side flag is `false` (or missing),
/// element 1 once it is `true`.
///
/// # Panics
/// Panics with "Buffer not found" if `buffer_name` has no textures.
pub fn get_write_texture(&self, buffer_name: &str) -> &wgpu::Texture {
    let (texture_0, texture_1) = self.buffers.get(buffer_name).expect("Buffer not found");
    let use_second = matches!(self.write_side.get(buffer_name), Some(true));
    if use_second {
        texture_1
    } else {
        texture_0
    }
}
/// Returns the texture of `buffer_name` that is NOT the current write
/// target: element 1 of the pair while the side flag is `false` (or
/// missing), element 0 once it is `true`.
///
/// # Panics
/// Panics with "Buffer not found" if `buffer_name` has no textures.
pub fn get_read_texture(&self, buffer_name: &str) -> &wgpu::Texture {
    let (texture_0, texture_1) = self.buffers.get(buffer_name).expect("Buffer not found");
    match self.write_side.get(buffer_name) {
        Some(true) => texture_0,
        _ => texture_1,
    }
}
/// Returns the storage bind group that targets the output texture.
pub fn get_output_bind_group(&self) -> &wgpu::BindGroup {
&self.output_bind_group
}
/// Returns the output texture, which is sized to the manager's current
/// width and height.
pub fn get_output_texture(&self) -> &wgpu::Texture {
&self.output_texture
}
/// Toggles the side flag of `buffer_name`, swapping which element of its
/// pair is the write target and which is the read source.
///
/// Names with no side flag are silently ignored.
pub fn mark_written(&mut self, buffer_name: &str) {
    match self.write_side.get_mut(buffer_name) {
        Some(flag) => *flag ^= true,
        None => {}
    }
}
/// Toggles the side flag of every buffer at once.
pub fn flip_buffers(&mut self) {
    self.write_side
        .values_mut()
        .for_each(|flag| *flag = !*flag);
}
/// Replaces every buffer texture pair, the output texture and all their bind
/// groups with freshly created ones, then resets every side flag to `false`.
///
/// Buffer sizes come from `buffer_dimensions`, falling back to the manager's
/// width and height; the output texture always uses the width and height.
// NOTE(review): "clearing" relies on newly created textures starting out
// zeroed - confirm against the wgpu version in use.
pub fn clear_all(&mut self, core: &Core) {
    let device = &core.device;

    // Only disjoint fields of `self` are touched while `buffers` is mutably
    // borrowed, so the pairs can be replaced in place.
    for (name, pair) in self.buffers.iter_mut() {
        let (buf_w, buf_h) = match self.buffer_dimensions.get(name) {
            Some(&dims) => dims,
            None => (self.width, self.height),
        };

        let tex_a = Self::create_storage_texture(
            device,
            buf_w,
            buf_h,
            self.texture_format,
            &format!("{name}_0"),
        );
        let tex_b = Self::create_storage_texture(
            device,
            buf_w,
            buf_h,
            self.texture_format,
            &format!("{name}_1"),
        );
        let group_a = Self::create_storage_bind_group(
            device,
            &self.storage_layout,
            &tex_a,
            &format!("{name}_0_bind"),
        );
        let group_b = Self::create_storage_bind_group(
            device,
            &self.storage_layout,
            &tex_b,
            &format!("{name}_1_bind"),
        );

        *pair = (tex_a, tex_b);
        self.bind_groups.insert(name.clone(), (group_a, group_b));
    }

    self.output_texture = Self::create_storage_texture(
        device,
        self.width,
        self.height,
        self.texture_format,
        "multipass_output",
    );
    self.output_bind_group = Self::create_storage_bind_group(
        device,
        &self.storage_layout,
        &self.output_texture,
        "output_bind",
    );

    self.write_side.values_mut().for_each(|flag| *flag = false);
}
/// Records a new size, recomputes the dimensions of every pass-sized buffer,
/// and recreates all textures and bind groups via `clear_all`.
///
/// `width` and `height` are clamped to at least 1. A zero extent (e.g. a 0x0
/// size reported by the windowing layer) would otherwise be passed straight
/// to texture creation for the output texture, and wgpu rejects zero-sized
/// textures.
pub fn resize(&mut self, core: &Core, width: u32, height: u32) {
    let width = width.max(1);
    let height = height.max(1);
    self.width = width;
    self.height = height;

    // Update the sizes in place; only disjoint fields of `self` are read
    // while `buffer_dimensions` is mutably borrowed.
    for (name, dims) in self.buffer_dimensions.iter_mut() {
        let resolution = self.buffer_resolution.get(name).copied().flatten();
        let scale = self.buffer_scale.get(name).copied().flatten();
        *dims = Self::compute_buffer_dims(width, height, resolution, scale);
    }

    self.clear_all(core);
}
/// Returns the bind group layout holding the (texture, sampler) input slots.
pub fn get_input_layout(&self) -> &wgpu::BindGroupLayout {
&self.input_layout
}
/// Returns the bind group layout with the single write-only storage texture
/// at binding 0, which every storage bind group of this manager uses.
pub fn get_storage_layout(&self) -> &wgpu::BindGroupLayout {
&self.storage_layout
}
/// Returns the side flag of `buffer_name`; `false` when the name has no flag.
///
/// `false` means element 0 of the pair is the write target, `true` means
/// element 1 is.
pub fn get_write_side(&self, buffer_name: &str) -> bool {
    matches!(self.write_side.get(buffer_name), Some(true))
}
/// Returns both textures of `buffer_name` as stored (element 0, element 1),
/// regardless of which side is currently the write target, or `None` if the
/// name is unknown.
pub fn get_buffer_pair(&self, buffer_name: &str) -> Option<&(wgpu::Texture, wgpu::Texture)> {
self.buffers.get(buffer_name)
}
/// Returns the name of some buffer, or `None` if there are no buffers.
///
/// NOTE(review): the names live in a `HashMap`, whose iteration order is
/// unspecified, so "first" is an arbitrary entry rather than the first name
/// passed to `new` - confirm callers do not depend on a particular one.
pub fn first_buffer_name(&self) -> Option<&String> {
self.buffers.keys().next()
}
/// Returns the number of (texture, sampler) slots the input layout was built with.
pub fn max_input_deps(&self) -> usize {
self.max_input_deps
}
/// Returns the `(width, height)` recorded for `buffer_name`, or the manager's
/// own width and height when the name has no recorded dimensions.
pub fn get_buffer_dimensions(&self, buffer_name: &str) -> (u32, u32) {
    match self.buffer_dimensions.get(buffer_name) {
        Some(&dims) => dims,
        None => (self.width, self.height),
    }
}
}