use crate::core::{Device, Pipeline, PipelineLayout};
use crate::ex::errors::PipelineError;
use ash::vk;
use std::ffi::CString;
/// Builder for Vulkan graphics pipelines that use dynamic rendering
/// (`VkPipelineRenderingCreateInfo` chained into the create info) rather than a render pass.
///
/// Configure it with the chained `self`-consuming setters and finish with `build_graphics`.
#[derive(Debug, Clone)]
#[must_use = "builder methods return an updated builder; call `build_graphics` to create the pipeline"]
pub struct PipelineBuilder {
    /// Vertex stage module and entry point; required by `build_graphics`.
    vertex_shader: Option<(vk::ShaderModule, CString)>,
    /// Fragment stage module and entry point; required unless rasterizer discard is enabled.
    fragment_shader: Option<(vk::ShaderModule, CString)>,
    /// Optional geometry stage.
    geometry_shader: Option<(vk::ShaderModule, CString)>,
    // No setter or build path touches the tessellation stages; presumably reserved for
    // future tessellation support (TODO confirm), hence the lint suppression.
    #[allow(dead_code)]
    tess_control_shader: Option<(vk::ShaderModule, CString)>,
    // See `tess_control_shader`: stored but not yet wired into any pipeline build.
    #[allow(dead_code)]
    tess_eval_shader: Option<(vk::ShaderModule, CString)>,
    /// Recorded by `compute_shader()` but not read by `build_graphics`.
    compute_shader: Option<(vk::ShaderModule, CString)>,
    /// Vertex buffer binding descriptions.
    vertex_bindings: Vec<vk::VertexInputBindingDescription>,
    /// Vertex attribute descriptions.
    vertex_attributes: Vec<vk::VertexInputAttributeDescription>,
    /// Primitive topology used by input assembly.
    topology: vk::PrimitiveTopology,
    /// Whether primitive restart is enabled for indexed draws.
    primitive_restart: bool,
    /// Static viewports; ignored when `dynamic_viewport` is set.
    viewports: Vec<vk::Viewport>,
    /// Static scissors; ignored when `dynamic_scissor` is set.
    scissors: Vec<vk::Rect2D>,
    /// Whether the viewport is supplied as dynamic state.
    dynamic_viewport: bool,
    /// Whether the scissor is supplied as dynamic state.
    dynamic_scissor: bool,
    /// Whether depth clamping is enabled.
    depth_clamp: bool,
    /// Whether primitives are discarded before rasterization.
    rasterizer_discard: bool,
    /// Polygon fill mode.
    polygon_mode: vk::PolygonMode,
    /// Face culling mode.
    cull_mode: vk::CullModeFlags,
    /// Winding order considered front-facing.
    front_face: vk::FrontFace,
    /// Whether depth bias is enabled.
    depth_bias_enable: bool,
    /// Rasterization sample count.
    sample_count: vk::SampleCountFlags,
    /// Whether the depth test is enabled.
    depth_test: bool,
    /// Whether depth writes are enabled.
    depth_write: bool,
    /// Depth comparison operator.
    depth_compare: vk::CompareOp,
    /// Color attachment formats passed to dynamic rendering.
    color_formats: Vec<vk::Format>,
    /// Per-attachment blend state; expected to have one entry per color format.
    blend_attachments: Vec<vk::PipelineColorBlendAttachmentState>,
    /// Depth attachment format, if any.
    depth_format: Option<vk::Format>,
    /// Descriptor set layouts for the pipeline layout.
    descriptor_layouts: Vec<vk::DescriptorSetLayout>,
    /// Push constant ranges for the pipeline layout.
    push_constant_ranges: Vec<vk::PushConstantRange>,
    /// Pipeline states supplied at command-recording time instead of being baked in.
    dynamic_states: Vec<vk::DynamicState>,
}
impl Default for PipelineBuilder {
fn default() -> Self {
Self::new()
}
}
impl PipelineBuilder {
    /// Creates a builder with these defaults:
    /// - triangle-list topology with primitive restart off;
    /// - filled polygons, back-face culling and counter-clockwise front faces;
    /// - one sample per pixel;
    /// - depth test and depth writes off, with compare op `LESS_OR_EQUAL`;
    /// - no shaders, vertex input, attachments, descriptor layouts, push constants or dynamic state.
    pub fn new() -> Self {
        Self {
            vertex_shader: None,
            fragment_shader: None,
            geometry_shader: None,
            tess_control_shader: None,
            tess_eval_shader: None,
            compute_shader: None,
            vertex_bindings: Vec::new(),
            vertex_attributes: Vec::new(),
            topology: vk::PrimitiveTopology::TRIANGLE_LIST,
            primitive_restart: false,
            viewports: Vec::new(),
            scissors: Vec::new(),
            dynamic_viewport: false,
            dynamic_scissor: false,
            depth_clamp: false,
            rasterizer_discard: false,
            polygon_mode: vk::PolygonMode::FILL,
            cull_mode: vk::CullModeFlags::BACK,
            front_face: vk::FrontFace::COUNTER_CLOCKWISE,
            depth_bias_enable: false,
            sample_count: vk::SampleCountFlags::TYPE_1,
            depth_test: false,
            depth_write: false,
            depth_compare: vk::CompareOp::LESS_OR_EQUAL,
            color_formats: Vec::new(),
            blend_attachments: Vec::new(),
            depth_format: None,
            descriptor_layouts: Vec::new(),
            push_constant_ranges: Vec::new(),
            dynamic_states: Vec::new(),
        }
    }

    /// Converts a shader entry-point name into the NUL-terminated string Vulkan expects.
    ///
    /// # Panics
    ///
    /// Panics if `entry` contains an interior NUL byte.
    fn entry_point(entry: &str) -> CString {
        CString::new(entry).expect("shader entry point name must not contain a NUL byte")
    }

    /// Maps a Rust `bool` onto `vk::TRUE` / `vk::FALSE`.
    fn bool32(value: bool) -> vk::Bool32 {
        if value {
            vk::TRUE
        } else {
            vk::FALSE
        }
    }

    /// Adds `state` to the dynamic-state list (without duplicating it) when `dynamic` is true,
    /// or removes it when `dynamic` is false, so the list always agrees with the builder flags.
    fn set_dynamic_state(&mut self, state: vk::DynamicState, dynamic: bool) {
        if dynamic {
            if !self.dynamic_states.contains(&state) {
                self.dynamic_states.push(state);
            }
        } else {
            self.dynamic_states.retain(|&existing| existing != state);
        }
    }

    /// Sets the vertex stage module and its entry point name.
    ///
    /// # Panics
    ///
    /// Panics if `entry` contains an interior NUL byte.
    pub fn vertex_shader(mut self, module: vk::ShaderModule, entry: &str) -> Self {
        self.vertex_shader = Some((module, Self::entry_point(entry)));
        self
    }

    /// Sets the fragment stage module and its entry point name.
    ///
    /// # Panics
    ///
    /// Panics if `entry` contains an interior NUL byte.
    pub fn fragment_shader(mut self, module: vk::ShaderModule, entry: &str) -> Self {
        self.fragment_shader = Some((module, Self::entry_point(entry)));
        self
    }

    /// Sets an optional geometry stage module and its entry point name.
    ///
    /// # Panics
    ///
    /// Panics if `entry` contains an interior NUL byte.
    pub fn geometry_shader(mut self, module: vk::ShaderModule, entry: &str) -> Self {
        self.geometry_shader = Some((module, Self::entry_point(entry)));
        self
    }

    /// Records a compute stage module and its entry point name.
    ///
    /// Note: `build_graphics` does not use this stage, and this builder has no compute build
    /// method, so the value is currently only stored.
    ///
    /// # Panics
    ///
    /// Panics if `entry` contains an interior NUL byte.
    pub fn compute_shader(mut self, module: vk::ShaderModule, entry: &str) -> Self {
        self.compute_shader = Some((module, Self::entry_point(entry)));
        self
    }

    /// Replaces the vertex buffer binding descriptions.
    pub fn vertex_bindings(mut self, bindings: Vec<vk::VertexInputBindingDescription>) -> Self {
        self.vertex_bindings = bindings;
        self
    }

    /// Replaces the vertex attribute descriptions.
    pub fn vertex_attributes(
        mut self,
        attributes: Vec<vk::VertexInputAttributeDescription>,
    ) -> Self {
        self.vertex_attributes = attributes;
        self
    }

    /// Sets the primitive topology (default: `TRIANGLE_LIST`).
    pub fn topology(mut self, topology: vk::PrimitiveTopology) -> Self {
        self.topology = topology;
        self
    }

    /// Enables or disables primitive restart (default: disabled).
    pub fn primitive_restart(mut self, enable: bool) -> Self {
        self.primitive_restart = enable;
        self
    }

    /// Uses a single static viewport baked into the pipeline.
    ///
    /// This also undoes a previous `dynamic_viewport()` call: `VIEWPORT` is removed from the
    /// dynamic-state list. Otherwise the driver would ignore this viewport and expect one at
    /// record time.
    pub fn viewport(mut self, viewport: vk::Viewport) -> Self {
        self.viewports = vec![viewport];
        self.dynamic_viewport = false;
        self.set_dynamic_state(vk::DynamicState::VIEWPORT, false);
        self
    }

    /// Uses a single static scissor rectangle baked into the pipeline.
    ///
    /// This also undoes a previous `dynamic_scissor()` call by removing `SCISSOR` from the
    /// dynamic-state list.
    pub fn scissor(mut self, scissor: vk::Rect2D) -> Self {
        self.scissors = vec![scissor];
        self.dynamic_scissor = false;
        self.set_dynamic_state(vk::DynamicState::SCISSOR, false);
        self
    }

    /// Makes the viewport dynamic state, with a viewport count of one.
    ///
    /// The viewport must then be set while recording commands.
    pub fn dynamic_viewport(mut self) -> Self {
        self.dynamic_viewport = true;
        self.set_dynamic_state(vk::DynamicState::VIEWPORT, true);
        self
    }

    /// Makes the scissor dynamic state, with a scissor count of one.
    ///
    /// The scissor must then be set while recording commands.
    pub fn dynamic_scissor(mut self) -> Self {
        self.dynamic_scissor = true;
        self.set_dynamic_state(vk::DynamicState::SCISSOR, true);
        self
    }

    /// Sets the polygon fill mode (default: `FILL`).
    pub fn polygon_mode(mut self, mode: vk::PolygonMode) -> Self {
        self.polygon_mode = mode;
        self
    }

    /// Sets the face culling mode (default: `BACK`).
    pub fn cull_mode(mut self, mode: vk::CullModeFlags) -> Self {
        self.cull_mode = mode;
        self
    }

    /// Sets which winding order is front-facing (default: `COUNTER_CLOCKWISE`).
    pub fn front_face(mut self, front_face: vk::FrontFace) -> Self {
        self.front_face = front_face;
        self
    }

    /// Enables or disables depth clamping (default: disabled).
    ///
    /// Enabling it requires the `depthClamp` device feature.
    pub fn depth_clamp(mut self, enable: bool) -> Self {
        self.depth_clamp = enable;
        self
    }

    /// Enables or disables rasterizer discard (default: disabled).
    ///
    /// When enabled, `build_graphics` omits the fragment stage and does not require a fragment
    /// shader or color attachment formats.
    pub fn rasterizer_discard(mut self, enable: bool) -> Self {
        self.rasterizer_discard = enable;
        self
    }

    /// Sets the rasterization sample count (default: `TYPE_1`).
    pub fn sample_count(mut self, samples: vk::SampleCountFlags) -> Self {
        self.sample_count = samples;
        self
    }

    /// Configures depth testing.
    ///
    /// `enable` toggles the depth test, `write` toggles depth writes, and `compare` is the
    /// comparison operator.
    pub fn depth_test(mut self, enable: bool, write: bool, compare: vk::CompareOp) -> Self {
        self.depth_test = enable;
        self.depth_write = write;
        self.depth_compare = compare;
        self
    }

    /// Sets the depth attachment format for dynamic rendering.
    ///
    /// Without this call, the depth format is `UNDEFINED`. The stencil attachment format is never
    /// set by this builder and is left at its default.
    pub fn depth_format(mut self, format: vk::Format) -> Self {
        self.depth_format = Some(format);
        self
    }

    /// Sets the color attachment formats used with dynamic rendering.
    ///
    /// This also resets the blend state to one opaque entry per format: blending is disabled
    /// and all RGBA channels are written. Call `blend_attachments` *afterwards* to customise
    /// blending. Vulkan requires the number of blend states to equal the number of color formats.
    pub fn color_attachment_formats(mut self, formats: Vec<vk::Format>) -> Self {
        let opaque = vk::PipelineColorBlendAttachmentState {
            blend_enable: vk::FALSE,
            src_color_blend_factor: vk::BlendFactor::ONE,
            dst_color_blend_factor: vk::BlendFactor::ZERO,
            color_blend_op: vk::BlendOp::ADD,
            src_alpha_blend_factor: vk::BlendFactor::ONE,
            dst_alpha_blend_factor: vk::BlendFactor::ZERO,
            alpha_blend_op: vk::BlendOp::ADD,
            color_write_mask: vk::ColorComponentFlags::RGBA,
        };
        self.blend_attachments = vec![opaque; formats.len()];
        self.color_formats = formats;
        self
    }

    /// Replaces the per-attachment blend states.
    ///
    /// Provide one entry per color attachment format. A later `color_attachment_formats` call
    /// overwrites these states.
    pub fn blend_attachments(
        mut self,
        attachments: Vec<vk::PipelineColorBlendAttachmentState>,
    ) -> Self {
        self.blend_attachments = attachments;
        self
    }

    /// Sets the descriptor set layouts used to create the pipeline layout.
    pub fn descriptor_layouts(mut self, layouts: Vec<vk::DescriptorSetLayout>) -> Self {
        self.descriptor_layouts = layouts;
        self
    }

    /// Sets the push constant ranges used to create the pipeline layout.
    pub fn push_constants(mut self, ranges: Vec<vk::PushConstantRange>) -> Self {
        self.push_constant_ranges = ranges;
        self
    }

    /// Creates a graphics pipeline, together with its pipeline layout, for dynamic rendering.
    ///
    /// Some Vulkan requirements are not validated here:
    /// - Unless `dynamic_viewport()` / `dynamic_scissor()` were used, at least one viewport
    ///   and one scissor must have been supplied.
    /// - The blend-state count must match the color-format count.
    ///
    /// # Errors
    ///
    /// - `PipelineError::NoVertexShader` if no vertex shader was set.
    /// - `PipelineError::NoFragmentShader` if no fragment shader was set and rasterizer discard
    ///   is disabled.
    /// - `PipelineError::NoColorAttachmentFormats` if no color formats were set and rasterizer
    ///   discard is disabled.
    /// - Errors from creating the pipeline layout or from `vkCreateGraphicsPipelines`, converted
    ///   into `PipelineError` via `?`.
    pub fn build_graphics(self, device: &Device) -> Result<Pipeline, PipelineError> {
        let (vert_module, vert_entry) = self.vertex_shader.ok_or(PipelineError::NoVertexShader)?;
        // With rasterizer discard no fragments are produced, so the fragment stage is skipped.
        let fragment = if self.rasterizer_discard {
            None
        } else {
            Some(self.fragment_shader.ok_or(PipelineError::NoFragmentShader)?)
        };
        if self.color_formats.is_empty() && !self.rasterizer_discard {
            return Err(PipelineError::NoColorAttachmentFormats);
        }

        // Entry-point CStrings are borrowed, not moved, so every `p_name` pointer stays valid
        // until the create call below returns.
        let mut shader_stages = Vec::with_capacity(3);
        shader_stages.push(vk::PipelineShaderStageCreateInfo {
            stage: vk::ShaderStageFlags::VERTEX,
            module: vert_module,
            p_name: vert_entry.as_ptr(),
            ..Default::default()
        });
        if let Some((module, entry)) = &fragment {
            shader_stages.push(vk::PipelineShaderStageCreateInfo {
                stage: vk::ShaderStageFlags::FRAGMENT,
                module: *module,
                p_name: entry.as_ptr(),
                ..Default::default()
            });
        }
        if let Some((module, entry)) = &self.geometry_shader {
            shader_stages.push(vk::PipelineShaderStageCreateInfo {
                stage: vk::ShaderStageFlags::GEOMETRY,
                module: *module,
                p_name: entry.as_ptr(),
                ..Default::default()
            });
        }

        let layout =
            PipelineLayout::new(device, &self.descriptor_layouts, &self.push_constant_ranges)?;

        let vertex_input_state = vk::PipelineVertexInputStateCreateInfo {
            vertex_binding_description_count: self.vertex_bindings.len() as u32,
            p_vertex_binding_descriptions: self.vertex_bindings.as_ptr(),
            vertex_attribute_description_count: self.vertex_attributes.len() as u32,
            p_vertex_attribute_descriptions: self.vertex_attributes.as_ptr(),
            ..Default::default()
        };

        let input_assembly_state = vk::PipelineInputAssemblyStateCreateInfo {
            topology: self.topology,
            primitive_restart_enable: Self::bool32(self.primitive_restart),
            ..Default::default()
        };

        // Dynamic viewports/scissors are provided at record time, so only a count of one is
        // baked into the pipeline and the array pointer is left null.
        let (viewport_count, p_viewports) = if self.dynamic_viewport {
            (1, std::ptr::null())
        } else {
            (self.viewports.len() as u32, self.viewports.as_ptr())
        };
        let (scissor_count, p_scissors) = if self.dynamic_scissor {
            (1, std::ptr::null())
        } else {
            (self.scissors.len() as u32, self.scissors.as_ptr())
        };
        let viewport_state = vk::PipelineViewportStateCreateInfo {
            viewport_count,
            p_viewports,
            scissor_count,
            p_scissors,
            ..Default::default()
        };

        let rasterization_state = vk::PipelineRasterizationStateCreateInfo {
            depth_clamp_enable: Self::bool32(self.depth_clamp),
            rasterizer_discard_enable: Self::bool32(self.rasterizer_discard),
            polygon_mode: self.polygon_mode,
            cull_mode: self.cull_mode,
            front_face: self.front_face,
            depth_bias_enable: Self::bool32(self.depth_bias_enable),
            line_width: 1.0,
            ..Default::default()
        };

        let multisample_state = vk::PipelineMultisampleStateCreateInfo {
            rasterization_samples: self.sample_count,
            ..Default::default()
        };

        let depth_stencil_state = vk::PipelineDepthStencilStateCreateInfo {
            depth_test_enable: Self::bool32(self.depth_test),
            depth_write_enable: Self::bool32(self.depth_write),
            depth_compare_op: self.depth_compare,
            ..Default::default()
        };

        let color_blend_state = vk::PipelineColorBlendStateCreateInfo {
            attachment_count: self.blend_attachments.len() as u32,
            p_attachments: self.blend_attachments.as_ptr(),
            ..Default::default()
        };

        // Only chain a dynamic-state struct when some state is actually dynamic.
        let dynamic_state = if self.dynamic_states.is_empty() {
            None
        } else {
            Some(vk::PipelineDynamicStateCreateInfo {
                dynamic_state_count: self.dynamic_states.len() as u32,
                p_dynamic_states: self.dynamic_states.as_ptr(),
                ..Default::default()
            })
        };

        // Dynamic rendering: attachment formats are declared here instead of via a render pass.
        let rendering_info = vk::PipelineRenderingCreateInfo {
            color_attachment_count: self.color_formats.len() as u32,
            p_color_attachment_formats: self.color_formats.as_ptr(),
            depth_attachment_format: self.depth_format.unwrap_or(vk::Format::UNDEFINED),
            ..Default::default()
        };

        let pipeline_info = vk::GraphicsPipelineCreateInfo {
            p_next: &rendering_info as *const _ as *const std::ffi::c_void,
            stage_count: shader_stages.len() as u32,
            p_stages: shader_stages.as_ptr(),
            p_vertex_input_state: &vertex_input_state,
            p_input_assembly_state: &input_assembly_state,
            p_viewport_state: &viewport_state,
            p_rasterization_state: &rasterization_state,
            p_multisample_state: &multisample_state,
            p_depth_stencil_state: &depth_stencil_state,
            p_color_blend_state: &color_blend_state,
            p_dynamic_state: dynamic_state.as_ref().map_or(std::ptr::null(), |s| s),
            layout: layout.handle(),
            ..Default::default()
        };

        // SAFETY: every pointer reachable from `pipeline_info` refers to locals or `self`
        // fields that are alive for the whole call:
        // - the shader stages and their entry-point CStrings;
        // - the fixed-function state structs;
        // - the viewport/scissor/blend/dynamic-state vectors;
        // - the chained `rendering_info`.
        let pipeline_handle = unsafe {
            device
                .handle()
                .create_graphics_pipelines(vk::PipelineCache::null(), &[pipeline_info], None)
                .map_err(|(_, err)| crate::core::PipelineError::CreationFailed(err))?[0]
        };

        Ok(Pipeline::from_handle(
            pipeline_handle,
            layout,
            vk::PipelineBindPoint::GRAPHICS,
        ))
    }
}