use tracing::warn;
pub use wgpu::CompareFunction;
use std::borrow::Cow;
use wgpu::ShaderModuleDescriptor;
use wgpu::ShaderSource;
use crate::Gpu;
use crate::GpuError;
use crate::RenderPipeline;
/// Chainable helpers for configuring a color target's blending and write mask.
///
/// Every method takes the target by value and returns it, so calls can be
/// chained, e.g. `target.blend_add().write_mask(mask)`.
pub trait ColorTargetBuilderExt {
    /// Uses straight (non-premultiplied) alpha "over" blending.
    fn blend_over(self) -> Self;
    /// Uses premultiplied-alpha "over" blending.
    fn blend_over_premult(self) -> Self;
    /// Uses an additive blend on the color channels.
    fn blend_add(self) -> Self;
    /// Uses a subtractive blend on the color channels.
    fn blend_subtract(self) -> Self;
    /// Replaces the write mask with the `wgpu::ColorWrites` bits given in `mask`.
    fn write_mask(self, mask: u32) -> Self;
}
impl ColorTargetBuilderExt for wgpu::ColorTargetState {
    /// Straight-alpha "over" compositing (`wgpu::BlendState::ALPHA_BLENDING`).
    fn blend_over(mut self) -> Self {
        self.blend = Some(wgpu::BlendState::ALPHA_BLENDING);
        self
    }

    /// Premultiplied-alpha "over" compositing
    /// (`wgpu::BlendState::PREMULTIPLIED_ALPHA_BLENDING`).
    fn blend_over_premult(mut self) -> Self {
        self.blend = Some(wgpu::BlendState::PREMULTIPLIED_ALPHA_BLENDING);
        self
    }

    /// Additive blending: `color = dst + src * src_alpha`.
    ///
    /// Alpha is composited with `wgpu::BlendComponent::OVER`.
    fn blend_add(mut self) -> Self {
        self.blend = Some(wgpu::BlendState {
            color: wgpu::BlendComponent {
                src_factor: wgpu::BlendFactor::SrcAlpha,
                // The destination must be kept at full weight: scaling it by the
                // source alpha would erase the background wherever the source is
                // transparent instead of leaving it untouched.
                dst_factor: wgpu::BlendFactor::One,
                operation: wgpu::BlendOperation::Add,
            },
            alpha: wgpu::BlendComponent::OVER,
        });
        self
    }

    /// Subtractive blending: `color = dst - src * src_alpha`.
    ///
    /// Alpha is composited with `wgpu::BlendComponent::OVER`.
    fn blend_subtract(mut self) -> Self {
        self.blend = Some(wgpu::BlendState {
            color: wgpu::BlendComponent {
                src_factor: wgpu::BlendFactor::SrcAlpha,
                // Full-weight destination plus `ReverseSubtract` (dst - src) makes a
                // fully transparent source a no-op, matching `blend_add`.
                dst_factor: wgpu::BlendFactor::One,
                operation: wgpu::BlendOperation::ReverseSubtract,
            },
            alpha: wgpu::BlendComponent::OVER,
        });
        self
    }

    /// Replaces the write mask with the `wgpu::ColorWrites` flags in `mask`.
    ///
    /// # Panics
    /// Panics if `mask` contains bits that are not valid `wgpu::ColorWrites` flags.
    fn write_mask(mut self, mask: u32) -> Self {
        self.write_mask = wgpu::ColorWrites::from_bits(mask)
            .expect("write_mask: mask contains bits that are not wgpu::ColorWrites flags");
        self
    }
}
/// Builder for a [`RenderPipeline`].
///
/// Start with `PipelineBuilder::new`, adjust it with the chained `with_*`,
/// `load_*`, depth, culling and topology methods, then call
/// `PipelineBuilder::create` to build the pipeline.
pub struct PipelineBuilder<'a> {
    /// GPU handle; its `device` creates the shader modules, layout and pipeline.
    gpu: Gpu,
    /// Debug label of the render pipeline; `create` also derives the pipeline
    /// layout's label from it.
    label: Option<&'a str>,
    /// Pipeline-layout and fixed-function state (bind groups, primitive, depth, ...).
    desc: PipelineDescriptor<'a>,
    /// Vertex shader module descriptor.
    vertex: ShaderModuleDescriptor<'a>,
    /// Fragment shader module descriptor; when `None`, `create` builds the
    /// pipeline without a fragment stage.
    fragment: Option<ShaderModuleDescriptor<'a>>,
    /// Name of the vertex shader entry point.
    vertex_entry: &'a str,
    /// Name of the fragment shader entry point.
    fragment_entry: &'a str,
    /// Color targets of the fragment stage.
    fragment_targets: &'a [wgpu::ColorTargetState],
}
/// Pipeline-layout and fixed-function state collected by `PipelineBuilder`.
///
/// The derived `Default` yields empty slices, no depth/stencil state, and the
/// `Default` values of `wgpu::PrimitiveState` and `wgpu::MultisampleState`.
#[derive(Default)]
struct PipelineDescriptor<'a> {
    /// Bind group layouts passed to the pipeline layout.
    bind_group_layouts: &'a [&'a wgpu::BindGroupLayout],
    /// Push-constant ranges passed to the pipeline layout.
    push_constant_ranges: &'a [wgpu::PushConstantRange],
    /// Topology, culling and polygon mode.
    primitive: wgpu::PrimitiveState,
    /// Depth/stencil state; `None` until `with_depth`/`with_depth_stencil` sets it.
    depth_stencil: Option<wgpu::DepthStencilState>,
    /// Multisample state.
    multisample: wgpu::MultisampleState,
    /// Vertex buffer layouts of the vertex stage.
    vertex_layouts: &'a [wgpu::VertexBufferLayout<'a>],
}
impl PipelineBuilder<'_> {
pub fn make_spirv(bytes: &[u8]) -> Result<ShaderSource, GpuError> {
let prev_hook = std::panic::take_hook();
std::panic::set_hook(Box::new(|_| {}));
let result = std::panic::catch_unwind(|| wgpu::util::make_spirv(bytes))
.map_err(|_| GpuError::ShaderParseError);
std::panic::set_hook(prev_hook);
result
}
pub fn make_spirv_owned<'f>(mut vec8: Vec<u8>) -> Result<ShaderSource<'f>, GpuError> {
let vec32 = unsafe {
let ratio = std::mem::size_of::<u8>() / std::mem::size_of::<u32>();
let length = vec8.len() * ratio;
let capacity = vec8.capacity() * ratio;
let ptr = vec8.as_mut_ptr() as *mut u32;
std::mem::forget(vec8);
Vec::from_raw_parts(ptr, length, capacity)
};
Ok(ShaderSource::SpirV(Cow::Owned(vec32)))
}
pub fn make_wgsl(wgsl: &str) -> Result<ShaderSource, GpuError> {
Ok(ShaderSource::Wgsl(Cow::Borrowed(wgsl)))
}
pub fn make_wgsl_owned<'f>(wgsl: String) -> Result<ShaderSource<'f>, GpuError> {
Ok(ShaderSource::Wgsl(Cow::Owned(wgsl)))
}
pub fn shader_auto_load<'a, 'b>(path: &'b str) -> Result<ShaderSource<'a>, GpuError> {
if let Ok(spirv) = Self::make_spirv_owned(std::fs::read(path).unwrap()) {
Ok(spirv)
} else if let Ok(wgsl) = Self::make_wgsl_owned(std::fs::read_to_string(path).unwrap()) {
Ok(wgsl)
} else {
Err(GpuError::ShaderParseError)
}
}
pub fn shader_auto(bytes: &[u8]) -> Result<ShaderSource, GpuError> {
if let Ok(spirv) = Self::make_spirv(bytes) {
Ok(spirv)
} else if let Ok(wgsl) = Self::make_wgsl(Self::str_from_bytes(bytes)?) {
Ok(wgsl)
} else {
Err(GpuError::ShaderParseError)
}
}
}
impl<'a> PipelineBuilder<'a> {
pub fn new(gpu: Gpu, label: &'a str) -> Self {
const DEFAULT_FRAGMENT_FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Bgra8UnormSrgb;
let vertex = wgpu::util::make_spirv(include_bytes!("../../shader/screen.vert.spv"));
let fragment = wgpu::util::make_spirv(include_bytes!("../../shader/uv.frag.spv"));
let vertex = ShaderModuleDescriptor {
label: Some("Default vertex shader"),
source: vertex,
};
let fragment = Some(ShaderModuleDescriptor {
label: Some("Default fragment shader"),
source: fragment,
});
Self {
gpu,
label: Some(label),
desc: PipelineDescriptor::default(),
vertex,
fragment,
vertex_entry: "main",
fragment_entry: "main",
fragment_targets: &[wgpu::ColorTargetState {
format: DEFAULT_FRAGMENT_FORMAT,
blend: Some(wgpu::BlendState::ALPHA_BLENDING),
write_mask: wgpu::ColorWrites::ALL,
}],
}
}
pub fn with_label(mut self, label: &'a str) -> Self {
self.label = Some(label);
self
}
pub fn with_vertex_layouts(mut self, layouts: &'a [wgpu::VertexBufferLayout<'a>]) -> Self {
self.desc.vertex_layouts = layouts;
self
}
pub fn with_fragment_targets(mut self, targets: &'a [wgpu::ColorTargetState]) -> Self {
self.fragment_targets = targets;
self
}
pub fn with_depth(mut self) -> Self {
self.desc.depth_stencil = Some(wgpu::DepthStencilState {
depth_write_enabled: true,
depth_compare: wgpu::CompareFunction::Less,
stencil: wgpu::StencilState::default(),
format: wgpu::TextureFormat::Depth32Float,
bias: wgpu::DepthBiasState::default(),
});
self
}
fn do_depth<F>(&mut self, op: F)
where
F: FnOnce(&mut wgpu::DepthStencilState),
{
if let Some(desc) = self.desc.depth_stencil.as_mut() {
op(desc);
} else {
warn!("Depth mod was called before with_depth() was called in pipeline builder");
}
}
pub fn depth_bias(mut self, constant: i32, slope: f32) -> Self {
self.do_depth(|desc| {
desc.bias.constant = constant;
desc.bias.slope_scale = slope;
});
self
}
pub fn depth_bias_clamp(mut self, clamp: f32) -> Self {
self.do_depth(|desc| {
desc.bias.clamp = clamp;
});
self
}
pub fn depth_compare(mut self, compare: CompareFunction) -> Self {
self.do_depth(|desc| {
desc.depth_compare = compare;
});
self
}
pub fn with_depth_stencil(mut self) -> Self {
self.desc.depth_stencil = Some(wgpu::DepthStencilState {
depth_write_enabled: true,
depth_compare: wgpu::CompareFunction::Less,
stencil: wgpu::StencilState::default(),
format: wgpu::TextureFormat::Depth24PlusStencil8,
bias: wgpu::DepthBiasState::default(),
});
self
}
fn str_from_bytes(bytes: &[u8]) -> Result<&str, GpuError> {
std::str::from_utf8(bytes).map_err(|_| GpuError::ShaderParseError)
}
pub fn load_vertex(mut self, path: &'a str) -> Self {
self.vertex.source = Self::shader_auto_load(path).expect("Load vertex shader");
self
}
pub fn with_vertex(mut self, bytes: &'a [u8]) -> Self {
self.vertex.source = Self::shader_auto(bytes).expect("Parse vertex shader");
self
}
pub fn with_fragment(mut self, bytes: &'static [u8]) -> Self {
self.fragment = Some(ShaderModuleDescriptor {
label: Some("Default fragment shader"),
source: Self::shader_auto(bytes).expect("Parse fragment shader"),
});
self
}
pub const fn with_fragment_entry(mut self, entry: &'a str) -> Self {
self.fragment_entry = entry;
self
}
pub const fn with_vertex_entry(mut self, entry: &'a str) -> Self {
self.vertex_entry = entry;
self
}
pub fn with_vertex_fragment(mut self, bytes: &'static [u8]) -> Self {
self.vertex_entry = "vs_main";
self.fragment_entry = "fs_main";
self.with_vertex(bytes).with_fragment(bytes)
}
pub fn with_fragment_opt(self, fragment_bytes: Option<&'static [u8]>) -> Self {
if let Some(bytes) = fragment_bytes {
self.with_fragment(bytes)
} else {
self
}
}
pub fn load_fragment(mut self, fragment: &'a str) -> Self {
self.fragment = Some(ShaderModuleDescriptor {
label: Some("Default fragment shader"),
source: Self::shader_auto_load(fragment).expect("Load fragment shader"),
});
self
}
pub const fn with_bind_groups(mut self, bind_groups: &'a [&wgpu::BindGroupLayout]) -> Self {
self.desc.bind_group_layouts = bind_groups;
self
}
pub const fn cull_front(mut self) -> Self {
self.desc.primitive.cull_mode = Some(wgpu::Face::Front);
self
}
pub const fn cull_back(mut self) -> Self {
self.desc.primitive.cull_mode = Some(wgpu::Face::Back);
self
}
pub const fn wireframe(mut self) -> Self {
self.desc.primitive.polygon_mode = wgpu::PolygonMode::Line;
self
}
pub const fn vertex_points(mut self) -> Self {
self.desc.primitive.topology = wgpu::PrimitiveTopology::PointList;
self
}
pub const fn vertex_lines(mut self, strip: bool) -> Self {
self.desc.primitive.topology = if strip {
wgpu::PrimitiveTopology::LineStrip
} else {
wgpu::PrimitiveTopology::LineList
};
self
}
pub const fn vertex_triangles(mut self, strip: bool) -> Self {
self.desc.primitive.topology = if strip {
wgpu::PrimitiveTopology::TriangleStrip
} else {
wgpu::PrimitiveTopology::TriangleList
};
self
}
#[must_use]
pub fn create(&self) -> RenderPipeline {
let vertex_module = self.gpu.device.create_shader_module(&self.vertex);
let fragment_module = self
.fragment
.as_ref()
.map(|fragment| self.gpu.device.create_shader_module(fragment));
let fragment = fragment_module
.as_ref()
.map(|fs_module| wgpu::FragmentState {
module: fs_module,
entry_point: self.fragment_entry,
targets: self.fragment_targets,
});
let layout = self
.gpu
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: self.label_suffix("pipeline layout").as_deref(),
bind_group_layouts: self.desc.bind_group_layouts,
push_constant_ranges: self.desc.push_constant_ranges,
});
let pipeline_desc = wgpu::RenderPipelineDescriptor {
layout: Some(&layout),
label: self.label,
vertex: wgpu::VertexState {
module: &vertex_module,
entry_point: self.vertex_entry,
buffers: self.desc.vertex_layouts,
},
primitive: self.desc.primitive,
depth_stencil: self.desc.depth_stencil.clone(),
multisample: self.desc.multisample,
fragment,
multiview: None,
};
let pipeline = self.gpu.device.create_render_pipeline(&pipeline_desc);
RenderPipeline {
depth_stencil: self.desc.depth_stencil.clone(),
gpu: self.gpu.clone(),
inner: pipeline,
}
}
fn label_suffix(&self, suffix: &str) -> Option<String> {
self.label.map(|label| format!("{} {}", label, suffix))
}
}