#![allow(dead_code)]
use std::collections::{HashMap};
use std::path::{PathBuf};
use wgpu::*;
use crate::pipelines::hash_defines;
use crate::shader_preprocessing::compile_wgsl;
/// Per-dispatch options for `ComputeSystem::compute`.
pub struct ComputePipelineOptions {
/// Workgroup counts along x/y/z, forwarded to `dispatch_workgroups`.
pub dispatch_size: [u32; 3],
}
impl Default for ComputePipelineOptions {
    /// Defaults to dispatching a single workgroup (1 × 1 × 1).
    fn default() -> Self {
        Self { dispatch_size: [1; 3] }
    }
}
/// Cache key for a compiled compute pipeline: every input that can affect
/// the generated bind group layouts or the compiled shader.
///
/// Field order is part of the derived `Hash`/`PartialEq` and is preserved.
#[derive(Clone, PartialEq, Eq, Hash)]
struct PipelineKey {
    /// Shader source path (lossy string conversion of the `PathBuf`).
    shader_path: String,
    /// Per input texture: (format, sample count, is filterable).
    input_specs: Vec<(TextureFormat, u32, bool)>,
    /// Storage-texture output formats, in binding order.
    output_formats: Vec<TextureFormat>,
    /// Buffer binding types, in binding order.
    buffer_bindings: Vec<BufferBindingType>,
    /// Hash of the preprocessor defines (see `hash_defines`).
    defines_hash: u64,
}
/// A compiled compute pipeline together with the three bind group layouts it
/// was built against (0: sampled inputs + sampler, 1: storage-texture
/// outputs, 2: uniform/storage buffers).
struct CachedPipeline {
pipeline: ComputePipeline,
bind_group_layouts: [BindGroupLayout; 3],
}
/// A buffer paired with the access mode it should be bound with.
#[derive(Debug)]
pub struct BufferSet {
// Handle clone of the underlying wgpu buffer (cheap, reference-counted).
pub buffer: Buffer,
// For storage buffers: bind as read-only storage when true. Ignored for
// uniform buffers (see `buffer_binding_type`).
pub read_only: bool,
}
impl BufferSet {
    /// Wraps a uniform buffer for binding.
    ///
    /// Uniform bindings are inherently read-only in shaders, so the flag is
    /// set to `true` (previously `false`, which was misleading; the flag is
    /// ignored for uniform buffers by `buffer_binding_type` either way).
    ///
    /// # Panics
    /// Panics if the buffer was not created with `BufferUsages::UNIFORM`.
    pub fn from_uniform(uniform_buffer: &Buffer) -> Self {
        if !uniform_buffer.usage().contains(BufferUsages::UNIFORM) {
            panic!("Buffer is not a uniform buffer in BufferSet::from_uniform()");
        }
        Self {
            buffer: uniform_buffer.clone(),
            read_only: true,
        }
    }

    /// Wraps a storage buffer for binding; `read_only` selects between
    /// read-only and read-write storage access in the layout.
    ///
    /// # Panics
    /// Panics if the buffer was not created with `BufferUsages::STORAGE`.
    pub fn from_storage(storage_buffer: &Buffer, read_only: bool) -> Self {
        if !storage_buffer.usage().contains(BufferUsages::STORAGE) {
            panic!("Buffer is not a storage buffer in BufferSet::from_storage()");
        }
        Self {
            buffer: storage_buffer.clone(),
            read_only,
        }
    }
}
/// Dispatches compute shaders against textures and buffers, caching one
/// pipeline (plus layouts) per unique `PipelineKey`.
pub struct ComputeSystem {
device: Device,
queue: Queue,
// Compiled pipelines keyed by shader path, IO specs and defines.
pipeline_cache: HashMap<PipelineKey, CachedPipeline>,
// Shared samplers appended as the last binding of every non-empty input
// bind group; which one is chosen must match the layout's sampler type.
filtering_sampler: Sampler,
non_filtering_sampler: Sampler,
}
impl ComputeSystem {
    /// Creates a compute system. The device/queue handles are cheap clones;
    /// the two samplers are created once and shared by every input bind
    /// group this system builds.
    pub fn new(device: &Device, queue: &Queue) -> Self {
        let device = device.clone();
        let queue = queue.clone();
        // Linear sampler, used when every filterable-eligible input supports it.
        let filtering_sampler = device.create_sampler(&SamplerDescriptor {
            label: Some("compute_filtering_sampler"),
            mag_filter: FilterMode::Linear,
            min_filter: FilterMode::Linear,
            address_mode_u: AddressMode::ClampToEdge,
            address_mode_v: AddressMode::ClampToEdge,
            ..Default::default()
        });
        // Nearest sampler, used when some input cannot be filtered.
        let non_filtering_sampler = device.create_sampler(&SamplerDescriptor {
            label: Some("compute_non_filtering_sampler"),
            mag_filter: FilterMode::Nearest,
            min_filter: FilterMode::Nearest,
            address_mode_u: AddressMode::ClampToEdge,
            address_mode_v: AddressMode::ClampToEdge,
            ..Default::default()
        });
        Self {
            device,
            queue,
            pipeline_cache: HashMap::new(),
            filtering_sampler,
            non_filtering_sampler,
        }
    }

    /// Decides whether the shared sampler for the input bind group is a
    /// filtering sampler.
    ///
    /// Integer-format and multisampled textures can never be used with a
    /// filtering sampler, so they are excluded from the vote entirely; every
    /// remaining texture must be filterable (depth textures do not veto).
    /// No eligible textures at all defaults to filtering (`all` on an empty
    /// iterator is `true`).
    ///
    /// NOTE: this is the single source of truth for BOTH layout creation
    /// (`create_pipeline`) and bind-group creation (`compute`). Previously
    /// the two used different rules, so e.g. a single-sampled integer-format
    /// input produced a `Filtering` sampler layout paired with the
    /// non-filtering sampler — a wgpu validation error.
    fn use_filtering_sampler(input_specs: &[(TextureFormat, u32, bool)]) -> bool {
        input_specs
            .iter()
            .filter(|(format, sample_count, _)| {
                !is_integer_format(*format) && *sample_count == 1
            })
            .all(|(format, _, is_filterable)| format.has_depth_aspect() || *is_filterable)
    }

    /// Records (and possibly submits) one compute dispatch.
    ///
    /// * `encoder` — pass `Some` to record into an existing encoder; with
    ///   `None` a throwaway encoder is created, finished, and submitted
    ///   before this call returns.
    /// * `label` — debug label for the encoder/pass.
    /// * `input_views` — sampled textures, bound in group 0 at bindings
    ///   `0..n`, with one shared sampler at binding `n`.
    /// * `output_views` — write-only storage textures, bound in group 1.
    /// * `shader_path` — WGSL source path; part of the pipeline cache key.
    /// * `options` — workgroup counts for the dispatch.
    /// * `buffer_sets` — uniform/storage buffers, bound in group 2.
    /// * `defines` — shader preprocessor defines; hashed into the cache key.
    pub(crate) fn compute(
        &mut self,
        encoder: Option<&mut CommandEncoder>,
        label: &str,
        input_views: Vec<&TextureView>,
        output_views: Vec<&TextureView>,
        shader_path: &PathBuf,
        options: ComputePipelineOptions,
        buffer_sets: &[BufferSet],
        defines: &HashMap<String, bool>,
    ) {
        let encoder_is_none = encoder.is_none();
        // Debug-only guard: spawning a fresh encoder while others are still
        // open makes submission ordering relative to them easy to get wrong.
        #[cfg(debug_assertions)]
        {
            if encoder_is_none {
                let internal_counters = self.device.get_internal_counters();
                let command_encoder_count = internal_counters.hal.command_encoders.read();
                if command_encoder_count > 0 {
                    eprintln!(
                        "\n
You're creating a NEW CommandEncoder for this compute dispatch while {} encoder(s) are already open!\n\
                        This is a classic recipe for desynchronization disasters, resource hazards, validation errors, or straight-up crashes.\n\
                        The GPU might execute the submitted commands out of order relative to your other in-flight encoders.\n\
                        FIX: Pass an existing encoder with Some(&mut your_encoder) instead of None.\n\
                        Do NOT ignore this unless you really know what you're doing.\n",
                        command_encoder_count
                    );
                }
            }
        }
        // Per-input (format, sample_count, filterable): part of the cache key
        // and the input to the sampler-type decision.
        let input_specs: Vec<_> = input_views
            .iter()
            .map(|v| {
                let tex = v.texture();
                let format = tex.format();
                let sample_count = tex.sample_count();
                let is_filterable = self.is_format_filterable(format, sample_count);
                (format, sample_count, is_filterable)
            })
            .collect();
        let output_formats: Vec<_> = output_views.iter().map(|v| v.texture().format()).collect();
        let buffer_bindings: Vec<_> = buffer_sets
            .iter()
            .map(|b| buffer_binding_type(b))
            .collect();
        let key = PipelineKey {
            shader_path: shader_path.to_str().unwrap_or("").to_string(),
            input_specs: input_specs.clone(),
            output_formats: output_formats.clone(),
            buffer_bindings: buffer_bindings.clone(),
            defines_hash: hash_defines(defines),
        };
        // Build-then-lookup is kept separate (instead of the entry API)
        // because `create_pipeline` needs `&self` while an `entry` would hold
        // `self.pipeline_cache` mutably borrowed.
        if !self.pipeline_cache.contains_key(&key) {
            let cached = self.create_pipeline(
                shader_path,
                &input_specs,
                &output_formats,
                &buffer_bindings,
                defines,
            );
            self.pipeline_cache.insert(key.clone(), cached);
        }
        let cached = self.pipeline_cache.get(&key).unwrap();
        // Must agree with the sampler binding type baked into the cached
        // layout — both sides call the same helper.
        let use_filtering = Self::use_filtering_sampler(&input_specs);
        let input_bg =
            self.create_input_bind_group(&cached.bind_group_layouts[0], &input_views, use_filtering);
        let output_bg = self.create_output_bind_group(&cached.bind_group_layouts[1], &output_views);
        let uniform_bg = self.create_buffer_bind_group(
            &cached.bind_group_layouts[2],
            buffer_sets,
            Some("Compute Buffers Bind Group"),
        );
        // Use the caller's encoder when given; otherwise create one that is
        // finished and submitted at the end of this call.
        let mut owned_encoder = None;
        let enc = match encoder {
            Some(e) => e,
            None => {
                owned_encoder = Some(
                    self.device
                        .create_command_encoder(&CommandEncoderDescriptor { label: Some(label) }),
                );
                owned_encoder.as_mut().unwrap()
            }
        };
        {
            // Scoped so the pass's borrow of `enc` ends before submission.
            let mut pass = enc.begin_compute_pass(&ComputePassDescriptor {
                label: Some(label),
                timestamp_writes: None,
            });
            pass.set_pipeline(&cached.pipeline);
            pass.set_bind_group(0, &input_bg, &[]);
            pass.set_bind_group(1, &output_bg, &[]);
            pass.set_bind_group(2, &uniform_bg, &[]);
            pass.dispatch_workgroups(
                options.dispatch_size[0],
                options.dispatch_size[1],
                options.dispatch_size[2],
            );
        }
        if encoder_is_none {
            let finished = owned_encoder.unwrap().finish();
            self.queue.submit(std::iter::once(finished));
        }
    }

    /// Compiles the shader and builds the pipeline plus its three bind group
    /// layouts (0: sampled inputs + sampler, 1: storage outputs, 2: buffers).
    fn create_pipeline(
        &self,
        shader_path: &PathBuf,
        input_specs: &[(TextureFormat, u32, bool)],
        output_formats: &[TextureFormat],
        buffer_bindings: &[BufferBindingType],
        defines: &HashMap<String, bool>,
    ) -> CachedPipeline {
        let shader = compile_wgsl(&self.device, shader_path, defines);
        // Shared decision — must match what `compute` binds at dispatch time.
        let use_filtering = Self::use_filtering_sampler(input_specs);
        let mut input_entries: Vec<BindGroupLayoutEntry> = input_specs
            .iter()
            .enumerate()
            .map(|(i, (format, sample_count, is_filterable))| {
                let multisampled = *sample_count > 1;
                let sample_type = get_texture_sample_type(*format, *is_filterable, multisampled);
                BindGroupLayoutEntry {
                    binding: i as u32,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Texture {
                        sample_type,
                        view_dimension: TextureViewDimension::D2,
                        multisampled,
                    },
                    count: None,
                }
            })
            .collect();
        // One shared sampler at the binding slot just past the textures.
        if !input_specs.is_empty() {
            let sampler_type = if use_filtering {
                SamplerBindingType::Filtering
            } else {
                SamplerBindingType::NonFiltering
            };
            input_entries.push(BindGroupLayoutEntry {
                binding: input_specs.len() as u32,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Sampler(sampler_type),
                count: None,
            });
        }
        let input_layout = self
            .device
            .create_bind_group_layout(&BindGroupLayoutDescriptor {
                label: Some("compute_input_layout"),
                entries: &input_entries,
            });
        let output_entries: Vec<BindGroupLayoutEntry> = output_formats
            .iter()
            .enumerate()
            .map(|(i, format)| {
                // Debug-time sanity warning only; wgpu will still validate.
                #[cfg(debug_assertions)]
                if !is_storage_compatible(*format) {
                    eprintln!(
                        "Warning: Format {:?} may not be supported as a storage texture",
                        format
                    );
                }
                BindGroupLayoutEntry {
                    binding: i as u32,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::StorageTexture {
                        access: StorageTextureAccess::WriteOnly,
                        format: *format,
                        view_dimension: TextureViewDimension::D2,
                    },
                    count: None,
                }
            })
            .collect();
        let output_layout = self
            .device
            .create_bind_group_layout(&BindGroupLayoutDescriptor {
                label: Some("compute_output_layout"),
                entries: &output_entries,
            });
        let buffer_entries: Vec<BindGroupLayoutEntry> = buffer_bindings
            .iter()
            .enumerate()
            .map(|(i, ty)| BindGroupLayoutEntry {
                binding: i as u32,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Buffer {
                    ty: *ty,
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            })
            .collect();
        let buffer_layout = self
            .device
            .create_bind_group_layout(&BindGroupLayoutDescriptor {
                label: Some("compute_buffer_layout"),
                entries: &buffer_entries,
            });
        let bind_group_layouts = [input_layout, output_layout, buffer_layout];
        let pipeline_layout = self
            .device
            .create_pipeline_layout(&PipelineLayoutDescriptor {
                label: Some("compute_pipeline_layout"),
                bind_group_layouts: &bind_group_layouts.iter().map(Some).collect::<Vec<_>>(),
                immediate_size: 0,
            });
        let pipeline = self
            .device
            .create_compute_pipeline(&ComputePipelineDescriptor {
                label: Some(shader_path.to_str().unwrap_or("")),
                layout: Some(&pipeline_layout),
                module: &shader,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });
        CachedPipeline {
            pipeline,
            bind_group_layouts,
        }
    }

    /// Binds the input texture views at 0..n plus the shared sampler at n.
    /// `use_filtering` must match the sampler type in `layout`.
    fn create_input_bind_group(
        &self,
        layout: &BindGroupLayout,
        views: &[&TextureView],
        use_filtering: bool,
    ) -> BindGroup {
        let mut entries: Vec<BindGroupEntry> = views
            .iter()
            .enumerate()
            .map(|(i, view)| BindGroupEntry {
                binding: i as u32,
                resource: BindingResource::TextureView(view),
            })
            .collect();
        if !views.is_empty() {
            let sampler = if use_filtering {
                &self.filtering_sampler
            } else {
                &self.non_filtering_sampler
            };
            entries.push(BindGroupEntry {
                binding: views.len() as u32,
                resource: BindingResource::Sampler(sampler),
            });
        }
        self.device.create_bind_group(&BindGroupDescriptor {
            label: Some("compute_input_bg"),
            layout,
            entries: &entries,
        })
    }

    /// Binds the output storage-texture views at 0..n.
    fn create_output_bind_group(
        &self,
        layout: &BindGroupLayout,
        views: &[&TextureView],
    ) -> BindGroup {
        let entries: Vec<BindGroupEntry> = views
            .iter()
            .enumerate()
            .map(|(i, view)| BindGroupEntry {
                binding: i as u32,
                resource: BindingResource::TextureView(view),
            })
            .collect();
        self.device.create_bind_group(&BindGroupDescriptor {
            label: Some("compute_output_bg"),
            layout,
            entries: &entries,
        })
    }

    /// Binds each buffer at its index as an entire-buffer binding.
    fn create_buffer_bind_group(
        &self,
        layout: &BindGroupLayout,
        buffer_sets: &[BufferSet],
        label: Option<&str>,
    ) -> BindGroup {
        let entries: Vec<BindGroupEntry> = buffer_sets
            .iter()
            .enumerate()
            .map(|(i, buffer_set)| BindGroupEntry {
                binding: i as u32,
                resource: buffer_set.buffer.as_entire_binding(),
            })
            .collect();
        self.device.create_bind_group(&BindGroupDescriptor {
            label,
            layout,
            entries: &entries,
        })
    }

    /// Drops every cached pipeline (e.g. after shader hot-reload); pipelines
    /// are rebuilt lazily on the next `compute` call.
    pub fn invalidate_cache(&mut self) {
        self.pipeline_cache.clear();
    }

    /// True when `format` can be sampled with a filtering sampler on this
    /// device. Multisampled and depth/stencil textures are never filterable
    /// here; otherwise defer to `TextureFormat::sample_type` with the
    /// device's enabled features.
    fn is_format_filterable(&self, format: TextureFormat, sample_count: u32) -> bool {
        if sample_count > 1 {
            return false;
        }
        let aspect = figure_out_aspect(format);
        if format.has_depth_aspect() || format.has_stencil_aspect() {
            return false;
        }
        match format.sample_type(aspect, Some(self.device.features())) {
            Some(TextureSampleType::Float { filterable }) => filterable,
            _ => false,
        }
    }
}
/// Picks the aspect to query `TextureFormat::sample_type` with: depth wins
/// for combined depth-stencil formats, color formats use `None`.
///
/// (The previous version had a separate `depth && stencil` branch that
/// returned the same `DepthOnly` value as the depth-only branch — the
/// redundant arm is collapsed; behavior is unchanged.)
pub(crate) fn figure_out_aspect(format: TextureFormat) -> Option<TextureAspect> {
    if format.has_depth_aspect() {
        // Covers both depth-only and combined depth-stencil formats.
        Some(TextureAspect::DepthOnly)
    } else if format.has_stencil_aspect() {
        Some(TextureAspect::StencilOnly)
    } else {
        None
    }
}
/// Maps a buffer's usage flags to its bind-group binding type. Uniform usage
/// takes precedence when a buffer carries both UNIFORM and STORAGE flags.
///
/// # Panics
/// Panics when the buffer has neither UNIFORM nor STORAGE usage.
fn buffer_binding_type(buffer_set: &BufferSet) -> BufferBindingType {
    let usage = buffer_set.buffer.usage();
    if usage.contains(BufferUsages::UNIFORM) {
        return BufferBindingType::Uniform;
    }
    if usage.contains(BufferUsages::STORAGE) {
        return BufferBindingType::Storage {
            read_only: buffer_set.read_only,
        };
    }
    panic!(
        "Buffer {:?} has unsupported usage {:?} for compute binding",
        buffer_set,
        usage
    );
}
fn get_texture_sample_type(
format: TextureFormat,
is_filterable: bool,
multisampled: bool,
) -> TextureSampleType {
use wgpu::TextureFormat::*;
if format.has_depth_aspect() {
return TextureSampleType::Depth;
}
match format {
Stencil8 => TextureSampleType::Uint,
R8Uint | R16Uint | R32Uint
| Rg8Uint | Rg16Uint | Rg32Uint
| Rgba8Uint | Rgba16Uint | Rgba32Uint
| Rgb10a2Uint => TextureSampleType::Uint,
R8Sint | R16Sint | R32Sint
| Rg8Sint | Rg16Sint | Rg32Sint
| Rgba8Sint | Rgba16Sint | Rgba32Sint => TextureSampleType::Sint,
_ => TextureSampleType::Float {
filterable: is_filterable && !multisampled,
},
}
}
fn is_integer_format(format: TextureFormat) -> bool {
use wgpu::TextureFormat::*;
matches!(
format,
R8Uint | R16Uint | R32Uint
| Rg8Uint | Rg16Uint | Rg32Uint
| Rgba8Uint | Rgba16Uint | Rgba32Uint
| Rgb10a2Uint
| R8Sint | R16Sint | R32Sint
| Rg8Sint | Rg16Sint | Rg32Sint
| Rgba8Sint | Rgba16Sint | Rgba32Sint
| Stencil8
)
}
fn is_storage_compatible(format: TextureFormat) -> bool {
use wgpu::TextureFormat::*;
matches!(
format,
R32Float | R32Uint | R32Sint
| Rg32Float | Rg32Uint | Rg32Sint
| Rgba8Unorm | Rgba8Snorm | Rgba8Uint | Rgba8Sint
| Rgba16Float | Rgba16Uint | Rgba16Sint
| Rgba32Float | Rgba32Uint | Rgba32Sint
| R16Float | Rg16Float
| R8Unorm | Rg8Unorm
| R8Uint | R8Sint | Rg8Uint | Rg8Sint
| R16Uint | R16Sint | Rg16Uint | Rg16Sint
| Bgra8Unorm
)
}