#![allow(dead_code)]
use std::collections::{HashMap};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::path::{Path, PathBuf};
use wgpu::*;
use crate::shader_preprocessing::compile_wgsl;
/// Shadow-map resources carried in [`PipelineOptions::shadow`].
///
/// NOTE(review): nothing in this file reads these values; how they are bound is decided by
/// the caller that consumes `PipelineOptions` — confirm at the use site.
#[derive(Clone, Debug)]
pub struct ShadowOptions {
/// Sampler used for the shadow map (presumably a comparison sampler — TODO confirm).
pub sampler: Sampler,
/// View of the shadow-map texture.
pub view: TextureView,
}
/// How the fragment stage of a pipeline built by [`PipelineCache`] is configured.
#[derive(Clone, Debug)]
pub enum FragmentOption {
/// No fragment stage at all (depth-only pipeline).
None,
/// Use the `fs_main` entry point of the pipeline's own shader module with these color targets.
Default {
targets: Vec<Option<ColorTargetState>>,
},
/// Use this caller-supplied fragment state as-is (its own module, entry point and targets).
Custom(FragmentState<'static>),
}
/// Everything besides the shader path and bind group layouts that shapes a pipeline built by
/// [`PipelineCache`]. Start from [`Default`] and refine with the `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct PipelineOptions<'a> {
/// Primitive topology for `PrimitiveState::topology`.
pub topology: PrimitiveTopology,
/// Sample count for `MultisampleState::count`.
pub msaa_samples: u32,
/// Depth/stencil state passed through to the render pipeline descriptor.
pub depth_stencil: Option<DepthStencilState>,
/// Vertex buffer layouts passed as `VertexState::buffers` for the `vs_main` entry point.
pub vertex_layouts: Vec<VertexBufferLayout<'static>>,
/// Face culling for `PrimitiveState::cull_mode`; `None` means no culling.
pub cull_mode: Option<Face>,
/// Fragment stage configuration.
pub fragment: FragmentOption,
/// Shadow-map resources. Not read by `PipelineCache` in this file and not part of its cache key.
pub shadow: Option<ShadowOptions>,
/// Material sampler descriptor. Not read by `PipelineCache` in this file and not part of its cache key.
pub sampler: SamplerDescriptor<'a>
}
impl Default for PipelineOptions<'_> {
    /// Triangle-list topology, one sample, no depth-stencil, no culling, no shadow resources,
    /// a default fragment stage with no color targets yet, and a repeat/linear material sampler.
    fn default() -> Self {
        let sampler = SamplerDescriptor {
            label: Some("material sampler"),
            address_mode_u: AddressMode::Repeat,
            address_mode_v: AddressMode::Repeat,
            address_mode_w: AddressMode::Repeat,
            mag_filter: FilterMode::Linear,
            min_filter: FilterMode::Linear,
            mipmap_filter: MipmapFilterMode::Linear,
            ..Default::default()
        };
        Self {
            sampler,
            fragment: FragmentOption::Default { targets: Vec::new() },
            vertex_layouts: Vec::new(),
            topology: PrimitiveTopology::TriangleList,
            msaa_samples: 1,
            depth_stencil: None,
            cull_mode: None,
            shadow: None,
        }
    }
}
impl PipelineOptions<'_> {
    /// Replaces the primitive topology.
    pub fn with_topology(self, topology: PrimitiveTopology) -> Self {
        Self { topology, ..self }
    }

    /// Replaces the multisample count.
    pub fn with_msaa(self, samples: u32) -> Self {
        Self { msaa_samples: samples, ..self }
    }

    /// Enables depth/stencil testing with `state`.
    pub fn with_depth_stencil(self, state: DepthStencilState) -> Self {
        Self { depth_stencil: Some(state), ..self }
    }

    /// Appends a vertex buffer layout after any already added.
    pub fn with_vertex_layout(mut self, layout: VertexBufferLayout<'static>) -> Self {
        self.vertex_layouts.push(layout);
        self
    }

    /// Culls faces of the given orientation.
    pub fn with_cull_mode(self, cull: Face) -> Self {
        Self { cull_mode: Some(cull), ..self }
    }

    /// Appends a color target to the fragment stage.
    ///
    /// If the fragment stage was disabled (`FragmentOption::None`, e.g. after
    /// [`Self::depth_only`]), it is re-enabled as `FragmentOption::Default` holding just this
    /// target. A `FragmentOption::Custom` stage is left untouched and `target` is discarded;
    /// configure targets on the custom `FragmentState` instead.
    pub fn with_target(mut self, target: ColorTargetState) -> Self {
        match self.fragment {
            FragmentOption::Default { ref mut targets } => targets.push(Some(target)),
            FragmentOption::None => {
                self.fragment = FragmentOption::Default { targets: vec![Some(target)] };
            }
            FragmentOption::Custom(_) => {}
        }
        self
    }

    /// Removes the fragment stage so the pipeline only writes depth.
    pub fn depth_only(self) -> Self {
        Self { fragment: FragmentOption::None, ..self }
    }
}
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct DepthStencilKey {
format: TextureFormat,
depth_write_enabled: bool,
depth_compare: CompareFunction,
}
impl From<&DepthStencilState> for DepthStencilKey {
fn from(d: &DepthStencilState) -> Self {
Self {
format: d.format,
depth_write_enabled: d.depth_write_enabled.unwrap_or(false),
depth_compare: d.depth_compare.unwrap_or(CompareFunction::Always),
}
}
}
/// Cache key identifying one render pipeline in [`PipelineCache`].
///
/// Layouts, fragment configuration and defines are stored only as `u64` hashes, so a
/// collision in any of them makes two different configurations compare equal.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct PipelineKey {
shader_path: PathBuf,
/// `hash_layouts` of the bind group layouts and vertex buffer layouts.
layout_hash: u64,
topology: PrimitiveTopology,
msaa_samples: u32,
depth_stencil: Option<DepthStencilKey>,
cull_mode: Option<Face>,
/// `hash_fragment` of the fragment configuration.
fragment_hash: u64,
/// `hash_defines` of the shader preprocessor defines.
defines_hash: u64,
}
/// A compiled shader module stored in [`PipelineCache`].
struct ShaderEntry {
module: ShaderModule,
}
/// Identifies one compiled variant of a shader: its path plus the `hash_defines` value of
/// the preprocessor defines it was compiled with.
#[derive(Hash, Eq, PartialEq)]
struct ShaderKey {
shader_path: PathBuf,
defines_hash: u64,
}
/// Caches compiled shader modules, render pipelines and uniform bind group layouts created
/// on a single [`Device`].
pub struct PipelineCache {
device: Device,
// Compiled modules, keyed by shader path + hash of the defines.
shaders: HashMap<ShaderKey, ShaderEntry>,
// Built pipelines, keyed by everything that went into building them.
pipelines: HashMap<PipelineKey, RenderPipeline>,
/// Uniform-buffer-only bind group layouts, keyed by the number of buffer bindings.
pub(crate) uniform_layouts: HashMap<usize, BindGroupLayout>,
}
impl PipelineCache {
    /// Creates an empty cache whose shaders, layouts and pipelines are all built on `device`.
    pub fn new(device: Device) -> Self {
        Self {
            device,
            shaders: HashMap::new(),
            pipelines: HashMap::new(),
            uniform_layouts: HashMap::new(),
        }
    }

    /// The device every cached resource is created on.
    pub(crate) fn device(&self) -> &Device {
        &self.device
    }

    /// Returns the bind group layout with `buffer_count` uniform-buffer bindings
    /// (bindings `0..buffer_count`, visible to the vertex and fragment stages),
    /// creating and caching it on first use.
    pub(crate) fn uniform_layout(&mut self, buffer_count: usize) -> &BindGroupLayout {
        // Borrow the device separately so the map entry can be held mutably at the same time.
        let device = &self.device;
        self.uniform_layouts.entry(buffer_count).or_insert_with(|| {
            let entries: Vec<BindGroupLayoutEntry> = (0..buffer_count)
                .map(|i| BindGroupLayoutEntry {
                    binding: i as u32,
                    visibility: ShaderStages::VERTEX | ShaderStages::FRAGMENT,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                })
                .collect();
            device.create_bind_group_layout(&BindGroupLayoutDescriptor {
                label: Some(&format!("uniform layout ({})", buffer_count)),
                entries: &entries,
            })
        })
    }

    /// Creates a bind group that binds `buffers[i]` (whole buffer) at binding `i`, using the
    /// cached uniform layout for `buffers.len()` bindings.
    pub(crate) fn create_uniform_bind_group(&mut self, buffers: &[&Buffer], label: &str) -> BindGroup {
        // Ensure the layout exists, then re-borrow it immutably so it can be used alongside
        // `self.device` without cloning it.
        self.uniform_layout(buffers.len());
        let layout = &self.uniform_layouts[&buffers.len()];
        let entries: Vec<BindGroupEntry> = buffers
            .iter()
            .enumerate()
            .map(|(i, buf)| BindGroupEntry {
                binding: i as u32,
                resource: buf.as_entire_binding(),
            })
            .collect();
        self.device.create_bind_group(&BindGroupDescriptor {
            label: Some(label),
            layout,
            entries: &entries,
        })
    }

    /// Returns the render pipeline for this shader, layout set, `options` and `defines`,
    /// compiling the shader and building the pipeline on first request.
    ///
    /// The cache key covers the shader path, a hash of the bind group and vertex layouts,
    /// topology, MSAA count, a summary of the depth-stencil state (`DepthStencilKey`), cull
    /// mode, a hash of the fragment configuration and a hash of `defines`.
    /// `options.shadow` and `options.sampler` are not part of the key.
    pub(crate) fn get_or_create(
        &mut self,
        shader_path: &Path,
        bind_group_layouts: &[Option<&BindGroupLayout>],
        options: &PipelineOptions,
        defines: &HashMap<String, bool>,
    ) -> &RenderPipeline {
        let key = PipelineKey {
            shader_path: shader_path.to_path_buf(),
            layout_hash: hash_layouts(bind_group_layouts, &options.vertex_layouts),
            topology: options.topology,
            msaa_samples: options.msaa_samples,
            depth_stencil: options.depth_stencil.as_ref().map(|d| d.into()),
            cull_mode: options.cull_mode,
            fragment_hash: hash_fragment(&options.fragment),
            defines_hash: hash_defines(defines),
        };
        if !self.pipelines.contains_key(&key) {
            // Build before touching the map: `create_pipeline` needs `&self`, which an
            // entry holding `&mut self.pipelines` would not allow.
            self.load_shader(shader_path, defines);
            let pipeline = self.create_pipeline(&key, bind_group_layouts, options, defines);
            self.pipelines.insert(key.clone(), pipeline);
        }
        &self.pipelines[&key]
    }

    /// Recompiles the shaders at `paths` and drops every pipeline built from them.
    ///
    /// All cached modules for these paths are evicted, whatever defines they were compiled
    /// with, so rebuilt pipelines can never pick up a stale module. Variants that were
    /// cached for `defines` are recompiled immediately; other define variants are recompiled
    /// lazily by the next `get_or_create` that needs them.
    pub(crate) fn reload_shaders(&mut self, paths: &[PathBuf], defines: &HashMap<String, bool>) {
        let defines_hash = hash_defines(defines);
        let eager: Vec<&PathBuf> = paths
            .iter()
            .filter(|path| {
                self.shaders.contains_key(&ShaderKey {
                    shader_path: (*path).clone(),
                    defines_hash,
                })
            })
            .collect();
        // Evict before recompiling: `load_shader` skips keys that are already cached.
        self.shaders.retain(|key, _| !paths.contains(&key.shader_path));
        self.pipelines.retain(|key, _| !paths.contains(&key.shader_path));
        for path in eager {
            self.load_shader(path, defines);
        }
    }

    /// Drops every cached shader module and pipeline. Uniform layouts are kept.
    pub(crate) fn clear(&mut self) {
        self.shaders.clear();
        self.pipelines.clear();
    }

    /// Compiles the shader at `path` with `defines` via `compile_wgsl`, unless that exact
    /// variant is already cached.
    fn load_shader(&mut self, path: &Path, defines: &HashMap<String, bool>) {
        let shader_key = ShaderKey {
            shader_path: path.to_path_buf(),
            defines_hash: hash_defines(defines),
        };
        let device = &self.device;
        self.shaders.entry(shader_key).or_insert_with(|| ShaderEntry {
            module: compile_wgsl(device, path, defines),
        });
    }

    /// Builds the render pipeline described by `key`/`options`.
    ///
    /// The vertex stage uses `vs_main`; a `FragmentOption::Default` fragment stage uses
    /// `fs_main` of the same module. Front faces are counter-clockwise and polygons filled.
    ///
    /// # Panics
    /// If the shader for `key.shader_path` + `defines` was not loaded first (`load_shader`).
    fn create_pipeline(
        &self,
        key: &PipelineKey,
        bind_group_layouts: &[Option<&BindGroupLayout>],
        options: &PipelineOptions,
        defines: &HashMap<String, bool>,
    ) -> RenderPipeline {
        let shader_key = ShaderKey {
            shader_path: key.shader_path.clone(),
            defines_hash: hash_defines(defines),
        };
        let shader = &self
            .shaders
            .get(&shader_key)
            .expect("load_shader must be called before create_pipeline")
            .module;
        let pipeline_layout = self.device.create_pipeline_layout(&PipelineLayoutDescriptor {
            label: Some(&format!("{} layout", key.shader_path.display())),
            bind_group_layouts,
            immediate_size: 0,
        });
        let fragment = match &options.fragment {
            FragmentOption::None => None,
            FragmentOption::Default { targets } => Some(FragmentState {
                module: shader,
                entry_point: Some("fs_main"),
                targets,
                compilation_options: Default::default(),
            }),
            FragmentOption::Custom(f) => Some(f.clone()),
        };
        self.device.create_render_pipeline(&RenderPipelineDescriptor {
            label: Some(&format!("{} Pipeline", key.shader_path.display())),
            layout: Some(&pipeline_layout),
            vertex: VertexState {
                module: shader,
                entry_point: Some("vs_main"),
                buffers: &options.vertex_layouts,
                compilation_options: Default::default(),
            },
            fragment,
            primitive: PrimitiveState {
                topology: options.topology,
                strip_index_format: None,
                front_face: FrontFace::Ccw,
                cull_mode: options.cull_mode,
                polygon_mode: PolygonMode::Fill,
                unclipped_depth: false,
                conservative: false,
            },
            depth_stencil: options.depth_stencil.clone(),
            multisample: MultisampleState {
                count: options.msaa_samples,
                mask: !0,
                alpha_to_coverage_enabled: false,
            },
            cache: None,
            multiview_mask: None,
        })
    }
}
/// Hashes a fragment-stage configuration for use in `PipelineKey::fragment_hash`.
///
/// A leading tag byte distinguishes the variants. `Custom` states also hash their override
/// constants, because two custom stages that differ only in constants compile to different
/// pipelines and must not share a cache entry.
fn hash_fragment(fragment: &FragmentOption) -> u64 {
    let mut hasher = DefaultHasher::new();
    match fragment {
        FragmentOption::None => {
            0u8.hash(&mut hasher);
        }
        FragmentOption::Default { targets } => {
            1u8.hash(&mut hasher);
            hash_targets(&mut hasher, targets);
        }
        FragmentOption::Custom(f) => {
            2u8.hash(&mut hasher);
            f.entry_point.hash(&mut hasher);
            f.module.hash(&mut hasher);
            hash_targets(&mut hasher, f.targets);
            let constants = f.compilation_options.constants;
            constants.len().hash(&mut hasher);
            for (name, value) in constants {
                name.hash(&mut hasher);
                // f64 is not `Hash`; its bit pattern is.
                value.to_bits().hash(&mut hasher);
            }
        }
    }
    hasher.finish()
}
/// Feeds a list of color targets into `hasher`: the slot count, then for every slot a
/// presence flag followed, when present, by its format, blend state and write-mask bits.
fn hash_targets(hasher: &mut impl Hasher, targets: &[Option<ColorTargetState>]) {
    targets.len().hash(hasher);
    for slot in targets.iter() {
        slot.is_some().hash(hasher);
        if let Some(color_target) = slot {
            color_target.format.hash(hasher);
            hash_blend_state(hasher, &color_target.blend);
            color_target.write_mask.bits().hash(hasher);
        }
    }
}
/// Feeds an optional blend state into `hasher`: a presence flag, then (when present) the
/// source factor, destination factor and operation of the color component followed by
/// those of the alpha component.
fn hash_blend_state(hasher: &mut impl Hasher, blend: &Option<BlendState>) {
    let Some(state) = blend else {
        false.hash(hasher);
        return;
    };
    true.hash(hasher);
    for component in [&state.color, &state.alpha] {
        component.src_factor.hash(hasher);
        component.dst_factor.hash(hasher);
        component.operation.hash(hasher);
    }
}
/// Hashes a pipeline's bind group layouts and vertex buffer layouts into one value for
/// `PipelineKey::layout_hash`.
///
/// Each slice is hashed as a whole, which writes its length before its elements. That
/// encodes where the bind-group list ends and the vertex-layout list begins, so different
/// splits of similar data cannot produce the same hash input. This matters because
/// `PipelineKey` compares only this hash, never the layouts themselves.
fn hash_layouts(bgls: &[Option<&BindGroupLayout>], vertex_layouts: &[VertexBufferLayout]) -> u64 {
    let mut hasher = DefaultHasher::new();
    bgls.hash(&mut hasher);
    vertex_layouts.hash(&mut hasher);
    hasher.finish()
}
/// Hashes a set of shader preprocessor defines into a single value.
///
/// Entries are visited in ascending key order, so the result depends only on the map's
/// contents and not on its iteration order. Each entry contributes its name, then its flag.
pub fn hash_defines(defines: &HashMap<String, bool>) -> u64 {
    let mut entries: Vec<(&String, &bool)> = defines.iter().collect();
    entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
    let mut hasher = DefaultHasher::new();
    for (name, enabled) in entries {
        name.hash(&mut hasher);
        enabled.hash(&mut hasher);
    }
    hasher.finish()
}